diff --git a/graph/mongo/indexed_linksto.go b/graph/mongo/indexed_linksto.go index 249ad0b..7406143 100644 --- a/graph/mongo/indexed_linksto.go +++ b/graph/mongo/indexed_linksto.go @@ -15,7 +15,6 @@ package mongo import ( - "github.com/barakmich/glog" "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" @@ -52,17 +51,25 @@ func NewLinksTo(qs *QuadStore, it graph.Iterator, collection string, d quad.Dire } } -func (it *LinksTo) buildIteratorFor(d quad.Direction, val graph.Value) *mgo.Iter { - name := it.qs.NameOf(val) - constraint := bson.M{d.String(): name} +func (it *LinksTo) buildConstraint() bson.M { + constraint := bson.M{} for _, set := range it.lset { var s []string for _, v := range set.Values { s = append(s, it.qs.NameOf(v)) } constraint[set.Dir.String()] = bson.M{"$in": s} + if len(s) == 1 { + constraint[set.Dir.String()] = s[0] + } } - glog.V(4).Infof("%#v", constraint) + return constraint +} + +func (it *LinksTo) buildIteratorFor(d quad.Direction, val graph.Value) *mgo.Iter { + name := it.qs.NameOf(val) + constraint := it.buildConstraint() + constraint[d.String()] = name return it.qs.db.C(it.collection).Find(constraint).Iter() } @@ -244,10 +251,17 @@ func (it *LinksTo) Stats() graph.IteratorStats { fanoutFactor := int64(20) checkConstant := int64(1) nextConstant := int64(2) + + size := fanoutFactor * subitStats.Size + csize, _ := it.qs.getSize(it.collection, it.buildConstraint()) + if size > csize { + size = csize + } + return graph.IteratorStats{ NextCost: nextConstant + subitStats.NextCost, ContainsCost: checkConstant + subitStats.ContainsCost, - Size: fanoutFactor * subitStats.Size, + Size: size, Next: it.runstats.Next, Contains: it.runstats.Contains, ContainsNext: it.runstats.ContainsNext, diff --git a/graph/mongo/iterator.go b/graph/mongo/iterator.go index b607236..5362971 100644 --- a/graph/mongo/iterator.go +++ b/graph/mongo/iterator.go @@ -198,7 +198,7 @@ func (it *Iterator) Contains(v graph.Value) bool { func (it *Iterator) Size() (int64, bool) { if it.size == -1 { var err error - it.size, err = it.qs.getSize(it.collection, &it.constraint) + it.size, err = it.qs.getSize(it.collection, it.constraint) if err != nil { it.err = err } diff --git a/graph/mongo/quadstore.go b/graph/mongo/quadstore.go index 7f01b86..78b85a7 100644 --- a/graph/mongo/quadstore.go +++ b/graph/mongo/quadstore.go @@ -380,7 +380,7 @@ func (qs *QuadStore) Type() string { return QuadStoreType } -func (qs *QuadStore) getSize(collection string, constraint *bson.M) (int64, error) { +func (qs *QuadStore) getSize(collection string, constraint bson.M) (int64, error) { var size int var err error bytes, err := bson.Marshal(constraint)