diff --git a/graph/sql/optimizers.go b/graph/sql/optimizers.go index c07741b..a109d6a 100644 --- a/graph/sql/optimizers.go +++ b/graph/sql/optimizers.go @@ -23,141 +23,86 @@ import ( "github.com/google/cayley/quad" ) -func intersect(a *StatementIterator, b *StatementIterator) (*StatementIterator, error) { - if a.stType != b.stType { - return nil, errors.New("Cannot combine SQL iterators of two different types") - } - min := a.size - if b.size < a.size { - min = b.size - } - var where clause - if a.where == nil { - if b.where == nil { - where = nil +func intersect(a graph.Iterator, b graph.Iterator) (graph.Iterator, error) { + if anew, ok := a.(*SQLNodeIterator); ok { + if bnew, ok := b.(*SQLNodeIterator); ok { + return intersectNode(anew, bnew) } - where = b.where + } else if anew, ok := a.(*SQLLinkIterator); ok { + if bnew, ok := b.(*SQLLinkIterator); ok { + return intersectLink(anew, bnew) + } + } else { - if b.where == nil { - where = a.where - } - where = joinClause{a.where, b.where, andClause} + return nil, errors.New("Unknown iterator types") } - out := &StatementIterator{ - uid: iterator.NextUID(), - qs: a.qs, - buildWhere: append(a.buildWhere, b.buildWhere...), - tags: append(a.tags, b.tags...), - where: where, - stType: a.stType, - size: min, - dir: a.dir, - } - out.tagger.CopyFrom(a) - out.tagger.CopyFrom(b) - if out.stType == node { - out.buildWhere = append(out.buildWhere, baseClause{ - pair: tableDir{"", a.dir}, - target: tableDir{b.tableName(), b.dir}, - }) - } - return out, nil + return nil, errors.New("Cannot combine SQL iterators of two different types") } -func hasa(a *StatementIterator, d quad.Direction) (*StatementIterator, error) { - if a.stType != link { +func intersectNode(a *SQLNodeIterator, b *SQLNodeIterator) (graph.Iterator, error) { + m := &SQLNodeIterator{ + uid: iterator.NextUID(), + qs: a.qs, + tableName: newTableName(), + linkIts: append(a.linkIts, b.linkIts...), + tagdirs: append(a.tagdirs, b.tagdirs...), + } + m.Tagger().CopyFrom(a) + m.Tagger().CopyFrom(b) + return m, nil +} + +func intersectLink(a *SQLLinkIterator, b *SQLLinkIterator) (graph.Iterator, error) { + m := &SQLLinkIterator{ + uid: iterator.NextUID(), + qs: a.qs, + tableName: newTableName(), + nodeIts: append(a.nodeIts, b.nodeIts...), + constraints: append(a.constraints, b.constraints...), + } + m.Tagger().CopyFrom(a) + m.Tagger().CopyFrom(b) + return m, nil +} + +func hasa(aIn graph.Iterator, d quad.Direction) (graph.Iterator, error) { + a, ok := aIn.(*SQLLinkIterator) + if !ok { return nil, errors.New("Can't take the HASA of a link SQL iterator") } - out := &StatementIterator{ - uid: iterator.NextUID(), - qs: a.qs, - stType: node, - dir: d, - } - where := a.where - for _, w := range a.buildWhere { - w.pair.table = out.tableName() - wherenew := joinClause{where, w, andClause} - where = wherenew - } - out.where = where - //out := &StatementIterator{ - //uid: iterator.NextUID(), - //qs: a.qs, - //stType: node, - //dir: d, - //buildWhere: a.buildWhere, - //where: a.where, - //size: -1, - //} - for k, v := range a.tagger.Fixed() { - out.tagger.AddFixed(k, v) - } - var tags []tag - for _, t := range a.tagger.Tags() { - tags = append(tags, tag{ - pair: tableDir{ - table: out.tableName(), - dir: quad.Any, - }, - t: t, - }) - } - out.tags = append(tags, a.tags...) - return out, nil -} - -func linksto(a *StatementIterator, d quad.Direction) (*StatementIterator, error) { - if a.stType != node { - return nil, errors.New("Can't take the LINKSTO of a node SQL iterator") - } - out := &StatementIterator{ - uid: iterator.NextUID(), - qs: a.qs, - stType: link, - dir: d, - size: -1, - } - where := a.where - for _, w := range a.buildWhere { - w.pair.table = a.tableName() - wherenew := joinClause{where, w, andClause} - where = wherenew - } - - out.where = where - out.buildWhere = []baseClause{ - baseClause{ - pair: tableDir{ + out := &SQLNodeIterator{ + uid: iterator.NextUID(), + qs: a.qs, + tableName: newTableName(), + linkIts: []sqlItDir{ + sqlItDir{ + it: a, dir: d, }, - target: tableDir{ - table: a.tableName(), - dir: a.dir, - }, }, } - var tags []tag - for _, t := range a.tagger.Tags() { - tags = append(tags, tag{ - pair: tableDir{ - table: a.tableName(), - dir: a.dir, + return out, nil +} + +func linksto(aIn graph.Iterator, d quad.Direction) (graph.Iterator, error) { + a, ok := aIn.(*SQLNodeIterator) + if !ok { + return nil, errors.New("Can't take the LINKSTO of a node SQL iterator") + } + + out := &SQLLinkIterator{ + uid: iterator.NextUID(), + qs: a.qs, + tableName: newTableName(), + nodeIts: []sqlItDir{ + sqlItDir{ + it: a, + dir: d, }, - t: t, - }) + }, } - for k, v := range a.tagger.Fixed() { - out.tagger.AddFixed(k, v) - } - for _, t := range a.tags { - if t.pair.table == "" { - t.pair.table = a.tableName() - } - tags = append(tags, t) - } - out.tags = tags + return out, nil } @@ -196,33 +141,34 @@ func (qs *QuadStore) optimizeLinksTo(it *iterator.LinksTo) (graph.Iterator, bool it.Close() return newIt, true } - case sqlBuilderType: - newit, err := linksto(primary.(*StatementIterator), it.Direction()) + case sqlNodeType: + //p := primary.(*SQLNodeIterator) + newit, err := linksto(primary, it.Direction()) if err != nil { glog.Errorln(err) return it, false } newit.Tagger().CopyFrom(it) return newit, true - case graph.All: - newit := &StatementIterator{ - uid: iterator.NextUID(), - qs: qs, - stType: link, - size: qs.Size(), - } - for _, t := range primary.Tagger().Tags() { - newit.tags = append(newit.tags, tag{ - pair: tableDir{"", it.Direction()}, - t: t, - }) - } - for k, v := range primary.Tagger().Fixed() { - newit.tagger.AddFixed(k, v) - } - newit.tagger.CopyFrom(it) + //case graph.All: + //newit := &StatementIterator{ + //uid: iterator.NextUID(), + //qs: qs, + //stType: link, + //size: qs.Size(), + //} + //for _, t := range primary.Tagger().Tags() { + //newit.tags = append(newit.tags, tag{ + //pair: tableDir{"", it.Direction()}, + //t: t, + //}) + //} + //for k, v := range primary.Tagger().Fixed() { + //newit.tagger.AddFixed(k, v) + //} + //newit.tagger.CopyFrom(it) - return newit, true + //return newit, true } return it, false } @@ -230,18 +176,18 @@ func (qs *QuadStore) optimizeLinksTo(it *iterator.LinksTo) (graph.Iterator, bool func (qs *QuadStore) optimizeAnd(it *iterator.And) (graph.Iterator, bool) { subs := it.SubIterators() var unusedIts []graph.Iterator - var newit *StatementIterator + var newit graph.Iterator newit = nil changed := false var err error for _, it := range subs { - if it.Type() == sqlBuilderType { + if it.Type() == sqlLinkType || it.Type() == sqlNodeType { if newit == nil { - newit = it.(*StatementIterator) + newit = it } else { changed = true - newit, err = intersect(newit, it.(*StatementIterator)) + newit, err = intersect(newit, it) if err != nil { glog.Error(err) return it, false @@ -256,7 +202,7 @@ func (qs *QuadStore) optimizeAnd(it *iterator.And) (graph.Iterator, bool) { return it, false } if len(unusedIts) == 0 { - newit.tagger.CopyFrom(it) + newit.Tagger().CopyFrom(it) return newit, true } newAnd := iterator.NewAnd(qs) @@ -274,8 +220,8 @@ func (qs *QuadStore) optimizeHasA(it *iterator.HasA) (graph.Iterator, bool) { return it, false } primary := subs[0] - if primary.Type() == sqlBuilderType { - newit, err := hasa(primary.(*StatementIterator), it.Direction()) + if primary.Type() == sqlLinkType { + newit, err := hasa(primary, it.Direction()) if err != nil { glog.Errorln(err) return it, false diff --git a/graph/sql/optimizers_test.go b/graph/sql/optimizers_test.go index 916fa25..229b91e 100644 --- a/graph/sql/optimizers_test.go +++ b/graph/sql/optimizers_test.go @@ -23,41 +23,43 @@ import ( ) func TestBuildIntersect(t *testing.T) { - a := NewStatementIterator(nil, quad.Subject, "Foo") - b := NewStatementIterator(nil, quad.Predicate, "is_equivalent_to") + a := NewSQLLinkIterator(nil, quad.Subject, "Foo") + b := NewSQLLinkIterator(nil, quad.Predicate, "is_equivalent_to") it, err := intersect(a, b) + i := it.(*SQLLinkIterator) if err != nil { t.Error(err) } - s, v := it.buildQuery(false, nil) + s, v := i.buildSQL(true, nil) fmt.Println(s, v) } func TestBuildHasa(t *testing.T) { - a := NewStatementIterator(nil, quad.Subject, "Foo") + a := NewSQLLinkIterator(nil, quad.Subject, "Foo") a.tagger.Add("foo") - b := NewStatementIterator(nil, quad.Predicate, "is_equivalent_to") + b := NewSQLLinkIterator(nil, quad.Predicate, "is_equivalent_to") it1, err := intersect(a, b) if err != nil { t.Error(err) } it2, err := hasa(it1, quad.Object) + i2 := it2.(*SQLNodeIterator) if err != nil { t.Error(err) } - s, v := it2.buildQuery(false, nil) + s, v := i2.buildSQL(true, nil) fmt.Println(s, v) } func TestBuildLinksTo(t *testing.T) { - a := NewStatementIterator(nil, quad.Subject, "Foo") - b := NewStatementIterator(nil, quad.Predicate, "is_equivalent_to") + a := NewSQLLinkIterator(nil, quad.Subject, "Foo") + b := NewSQLLinkIterator(nil, quad.Predicate, "is_equivalent_to") it1, err := intersect(a, b) if err != nil { t.Error(err) } it2, err := hasa(it1, quad.Object) - it2.tagger.Add("foo") + it2.Tagger().Add("foo") if err != nil { t.Error(err) } @@ -65,7 +67,8 @@ func TestBuildLinksTo(t *testing.T) { if err != nil { t.Error(err) } - s, v := it3.buildQuery(false, nil) + i3 := it3.(*SQLLinkIterator) + s, v := i3.buildSQL(true, nil) fmt.Println(s, v) } @@ -77,8 +80,8 @@ func TestInterestingQuery(t *testing.T) { if err != nil { t.Fatal(err) } - a := NewStatementIterator(db.(*QuadStore), quad.Object, "Humphrey Bogart") - b := NewStatementIterator(db.(*QuadStore), quad.Predicate, "name") + a := NewSQLLinkIterator(db.(*QuadStore), quad.Object, "Humphrey Bogart") + b := NewSQLLinkIterator(db.(*QuadStore), quad.Predicate, "name") it1, err := intersect(a, b) if err != nil { t.Error(err) @@ -92,7 +95,7 @@ func TestInterestingQuery(t *testing.T) { if err != nil { t.Error(err) } - b = NewStatementIterator(db.(*QuadStore), quad.Predicate, "/film/performance/actor") + b = NewSQLLinkIterator(db.(*QuadStore), quad.Predicate, "/film/performance/actor") it4, err := intersect(it3, b) if err != nil { t.Error(err) @@ -105,7 +108,7 @@ func TestInterestingQuery(t *testing.T) { if err != nil { t.Error(err) } - b = NewStatementIterator(db.(*QuadStore), quad.Predicate, "/film/film/starring") + b = NewSQLLinkIterator(db.(*QuadStore), quad.Predicate, "/film/film/starring") it7, err := intersect(it6, b) if err != nil { t.Error(err) @@ -114,13 +117,14 @@ func TestInterestingQuery(t *testing.T) { if err != nil { t.Error(err) } - s, v := it8.buildQuery(false, nil) - it8.Tagger().Add("id") + finalIt := it8.(*SQLNodeIterator) + s, v := finalIt.buildSQL(true, nil) + finalIt.Tagger().Add("id") fmt.Println(s, v) - for graph.Next(it8) { - fmt.Println(it8.Result()) + for graph.Next(finalIt) { + fmt.Println(finalIt.Result()) out := make(map[string]graph.Value) - it8.TagResults(out) + finalIt.TagResults(out) for k, v := range out { fmt.Printf("%s: %v\n", k, v.(string)) } diff --git a/graph/sql/quadstore.go b/graph/sql/quadstore.go index 0cad60d..3181f2b 100644 --- a/graph/sql/quadstore.go +++ b/graph/sql/quadstore.go @@ -200,7 +200,7 @@ func (qs *QuadStore) Quad(val graph.Value) quad.Quad { } func (qs *QuadStore) QuadIterator(d quad.Direction, val graph.Value) graph.Iterator { - return NewStatementIterator(qs, d, val.(string)) + return NewSQLLinkIterator(qs, d, val.(string)) } func (qs *QuadStore) NodesAllIterator() graph.Iterator { diff --git a/graph/sql/sql_link_iterator.go b/graph/sql/sql_link_iterator.go new file mode 100644 index 0000000..5f913d3 --- /dev/null +++ b/graph/sql/sql_link_iterator.go @@ -0,0 +1,525 @@ +// Copyright 2015 The Cayley Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "database/sql" + "fmt" + "strings" + "sync/atomic" + + "github.com/barakmich/glog" + "github.com/google/cayley/graph" + "github.com/google/cayley/graph/iterator" + "github.com/google/cayley/quad" +) + +var sqlLinkType graph.Type +var sqlTableID uint64 + +func init() { + sqlLinkType = graph.RegisterIterator("sqllink") + sqlNodeType = graph.RegisterIterator("sqlnode") + atomic.StoreUint64(&sqlTableID, 0) +} + +func newTableName() string { + id := atomic.AddUint64(&sqlTableID, 1) + return fmt.Sprintf("t_%d", id) +} + +type constraint struct { + dir quad.Direction + vals []string +} + +type tagDir struct { + tag string + dir quad.Direction + + // Not to be stored in the iterator directly + table string +} + +type sqlItDir struct { + dir quad.Direction + it sqlIterator +} + +type sqlIterator interface { + sqlClone() sqlIterator + getTables() []string + getTags() []tagDir + buildWhere() (string, []string) + tableID() tagDir + height() int +} + +type SQLLinkIterator struct { + uid uint64 + qs *QuadStore + tagger graph.Tagger + err error + next bool + + cursor *sql.Rows + nodeIts []sqlItDir + constraints []constraint + tableName string + size int64 + + result map[string]string + resultIndex int + resultList [][]string + resultNext [][]string + cols []string + resultQuad quad.Quad +} + +func NewSQLLinkIterator(qs *QuadStore, d quad.Direction, val string) *SQLLinkIterator { + l := &SQLLinkIterator{ + uid: iterator.NextUID(), + qs: qs, + constraints: []constraint{ + constraint{ + dir: d, + vals: []string{val}, + }, + }, + tableName: newTableName(), + size: 0, + } + return l +} + +func (l *SQLLinkIterator) sqlClone() sqlIterator { + return l.Clone().(*SQLLinkIterator) +} + +func (l *SQLLinkIterator) Clone() graph.Iterator { + m := &SQLLinkIterator{ + uid: iterator.NextUID(), + qs: l.qs, + tableName: l.tableName, + size: l.size, + } + for _, i := range l.nodeIts { + m.nodeIts = append(m.nodeIts, sqlItDir{ + dir: i.dir, + it: i.it.sqlClone(), + }) + } + m.constraints = l.constraints[:] + m.tagger.CopyFrom(l) + return m +} + +func (l *SQLLinkIterator) UID() uint64 { + return l.uid +} + +func (l *SQLLinkIterator) Reset() { + l.err = nil + l.Close() +} + +func (l *SQLLinkIterator) Err() error { + return l.err +} + +func (l *SQLLinkIterator) Close() error { + if l.cursor != nil { + err := l.cursor.Close() + if err != nil { + return err + } + l.cursor = nil + } + return nil +} + +func (l *SQLLinkIterator) Tagger() *graph.Tagger { + return &l.tagger +} + +func (l *SQLLinkIterator) Result() graph.Value { + return l.resultQuad +} + +func (l *SQLLinkIterator) TagResults(dst map[string]graph.Value) { + for tag, value := range l.result { + if tag == "__execd" { + for _, tag := range l.tagger.Tags() { + dst[tag] = value + } + continue + } + dst[tag] = value + } + + for tag, value := range l.tagger.Fixed() { + dst[tag] = value + } +} + +func (l *SQLLinkIterator) SubIterators() []graph.Iterator { + // TODO(barakmich): SQL Subiterators shouldn't count? If it makes sense, + // there's no reason not to expose them though. + return nil +} + +func (l *SQLLinkIterator) Sorted() bool { return false } +func (l *SQLLinkIterator) Optimize() (graph.Iterator, bool) { return l, false } + +func (l *SQLLinkIterator) Size() (int64, bool) { + if l.size != 0 { + return l.size, true + } + if len(l.constraints) > 0 { + l.size = l.qs.sizeForIterator(false, l.constraints[0].dir, l.constraints[0].vals[0]) + } else { + return l.qs.Size(), false + } + return l.size, true +} + +func (l *SQLLinkIterator) Describe() graph.Description { + size, _ := l.Size() + return graph.Description{ + UID: l.UID(), + Name: fmt.Sprintf("SQL_LINK_QUERY: %#v", l), + Type: l.Type(), + Size: size, + } +} + +func (l *SQLLinkIterator) Stats() graph.IteratorStats { + size, _ := l.Size() + return graph.IteratorStats{ + ContainsCost: 1, + NextCost: 5, + Size: size, + } +} + +func (l *SQLLinkIterator) Type() graph.Type { + return sqlLinkType +} + +func (l *SQLLinkIterator) Contains(v graph.Value) bool { + var err error + //if it.preFilter(v) { + //return false + //} + err = l.makeCursor(false, v) + if err != nil { + glog.Errorf("Couldn't make query: %v", err) + l.err = err + l.cursor.Close() + return false + } + l.cols, err = l.cursor.Columns() + if err != nil { + glog.Errorf("Couldn't get columns") + l.err = err + l.cursor.Close() + return false + } + l.resultList = nil + for { + if !l.cursor.Next() { + glog.V(4).Infoln("sql: No next") + err := l.cursor.Err() + if err != nil { + glog.Errorf("Cursor error in SQL: %v", err) + l.err = err + } + l.cursor.Close() + break + } + s, err := scan(l.cursor, len(l.cols)) + if err != nil { + l.err = err + l.cursor.Close() + return false + } + l.resultList = append(l.resultList, s) + } + l.cursor.Close() + l.cursor = nil + if len(l.resultList) != 0 { + l.resultIndex = 0 + l.buildResult(0) + return true + } + return false +} + +func (l *SQLLinkIterator) NextPath() bool { + l.resultIndex += 1 + if l.resultIndex >= len(l.resultList) { + return false + } + l.buildResult(l.resultIndex) + return true +} + +func (l *SQLLinkIterator) buildResult(i int) { + container := l.resultList[i] + var q quad.Quad + q.Subject = container[0] + q.Predicate = container[1] + q.Object = container[2] + q.Label = container[3] + l.resultQuad = q + l.result = make(map[string]string) + for i, c := range l.cols[4:] { + l.result[c] = container[i+4] + } +} + +func (l *SQLLinkIterator) getTables() []string { + out := []string{l.tableName} + //for _, i := range l.nodeIts { + //out = append(out, i.it.getTables()...) + //} + return out +} + +func (l *SQLLinkIterator) height() int { + v := 0 + for _, i := range l.nodeIts { + if i.it.height() > v { + v = i.it.height() + } + } + return v + 1 +} + +func (l *SQLLinkIterator) getTags() []tagDir { + var out []tagDir + for _, tag := range l.tagger.Tags() { + out = append(out, tagDir{ + dir: quad.Any, + table: l.tableName, + tag: tag, + }) + } + //for _, i := range l.nodeIts { + //out = append(out, i.it.getTags()...) + //} + return out +} + +func (l *SQLLinkIterator) buildWhere() (string, []string) { + var q []string + var vals []string + for _, c := range l.constraints { + q = append(q, fmt.Sprintf("%s.%s = ?", l.tableName, c.dir)) + vals = append(vals, c.vals[0]) + } + for _, i := range l.nodeIts { + sni := i.it.(*SQLNodeIterator) + sql, s := sni.buildSQL(true, nil) + q = append(q, fmt.Sprintf("%s.%s in (%s)", l.tableName, i.dir, sql[:len(sql)-1])) + vals = append(vals, s...) + //q = append(q, fmt.Sprintf("%s.%s = %s.%s", l.tableName, i.dir, t.table, t.dir)) + } + //for _, i := range l.nodeIts { + //s, v := i.it.buildWhere() + //q = append(q, s) + //vals = append(vals, v...) + //} + query := strings.Join(q, " AND ") + return query, vals +} + +func (l *SQLLinkIterator) tableID() tagDir { + return tagDir{ + dir: quad.Any, + table: l.tableName, + } +} + +func (l *SQLLinkIterator) buildSQL(next bool, val graph.Value) (string, []string) { + query := "SELECT " + t := []string{ + fmt.Sprintf("%s.subject", l.tableName), + fmt.Sprintf("%s.predicate", l.tableName), + fmt.Sprintf("%s.object", l.tableName), + fmt.Sprintf("%s.label", l.tableName), + } + for _, v := range l.getTags() { + t = append(t, fmt.Sprintf("%s.%s as %s", v.table, v.dir, v.tag)) + } + query += strings.Join(t, ", ") + query += " FROM " + t = []string{} + for _, k := range l.getTables() { + t = append(t, fmt.Sprintf("quads as %s", k)) + } + query += strings.Join(t, ", ") + query += " WHERE " + l.next = next + constraint, values := l.buildWhere() + + if !next { + v := val.(quad.Quad) + if constraint != "" { + constraint += " AND " + } + t = []string{ + fmt.Sprintf("%s.subject = ?", l.tableName), + fmt.Sprintf("%s.predicate = ?", l.tableName), + fmt.Sprintf("%s.object = ?", l.tableName), + fmt.Sprintf("%s.label = ?", l.tableName), + } + constraint += strings.Join(t, " AND ") + values = append(values, v.Subject) + values = append(values, v.Predicate) + values = append(values, v.Object) + values = append(values, v.Label) + } + query += constraint + query += ";" + + glog.V(2).Infoln(query) + + if glog.V(4) { + dstr := query + for i := 1; i <= len(values); i++ { + dstr = strings.Replace(dstr, "?", fmt.Sprintf("'%s'", values[i-1]), 1) + } + glog.V(4).Infoln(dstr) + } + return query, values +} + +func convertToPostgres(query string, values []string) string { + for i := 1; i <= len(values); i++ { + query = strings.Replace(query, "?", fmt.Sprintf("$%d", i), 1) + } + return query +} + +func (l *SQLLinkIterator) makeCursor(next bool, value graph.Value) error { + if l.cursor != nil { + l.cursor.Close() + } + var q string + var values []string + q, values = l.buildSQL(next, value) + q = convertToPostgres(q, values) + ivalues := make([]interface{}, 0, len(values)) + for _, v := range values { + ivalues = append(ivalues, v) + } + cursor, err := l.qs.db.Query(q, ivalues...) + if err != nil { + glog.Errorf("Couldn't get cursor from SQL database: %v", err) + cursor = nil + return err + } + l.cursor = cursor + return nil +} + +func scan(cursor *sql.Rows, nCols int) ([]string, error) { + pointers := make([]interface{}, nCols) + container := make([]string, nCols) + for i, _ := range pointers { + pointers[i] = &container[i] + } + err := cursor.Scan(pointers...) + if err != nil { + glog.Errorf("Error scanning iterator: %v", err) + return nil, err + } + return container, nil +} + +func (l *SQLLinkIterator) Next() bool { + var err error + graph.NextLogIn(l) + if l.cursor == nil { + err = l.makeCursor(true, nil) + l.cols, err = l.cursor.Columns() + if err != nil { + glog.Errorf("Couldn't get columns") + l.err = err + l.cursor.Close() + return false + } + // iterate the first one + if !l.cursor.Next() { + glog.V(4).Infoln("sql: No next") + err := l.cursor.Err() + if err != nil { + glog.Errorf("Cursor error in SQL: %v", err) + l.err = err + } + l.cursor.Close() + return false + } + s, err := scan(l.cursor, len(l.cols)) + if err != nil { + l.err = err + l.cursor.Close() + return false + } + l.resultNext = append(l.resultNext, s) + } + if l.resultList != nil && l.resultNext == nil { + // We're on something and there's no next + return false + } + l.resultList = l.resultNext + l.resultNext = nil + l.resultIndex = 0 + for { + if !l.cursor.Next() { + glog.V(4).Infoln("sql: No next") + err := l.cursor.Err() + if err != nil { + glog.Errorf("Cursor error in SQL: %v", err) + l.err = err + } + l.cursor.Close() + break + } + s, err := scan(l.cursor, len(l.cols)) + if err != nil { + l.err = err + l.cursor.Close() + return false + } + if l.resultList[0][0] == s[0] && l.resultList[0][1] == s[1] && l.resultList[0][2] == s[2] && l.resultList[0][3] == s[3] { + l.resultList = append(l.resultList, s) + } else { + l.resultNext = append(l.resultNext, s) + break + } + + } + if len(l.resultList) == 0 { + return graph.NextLogOut(l, nil, false) + } + l.buildResult(0) + return graph.NextLogOut(l, l.Result(), true) +} + +type SQLAllIterator struct { + // TBD +} diff --git a/graph/sql/sql_link_iterator_test.go b/graph/sql/sql_link_iterator_test.go new file mode 100644 index 0000000..b13e389 --- /dev/null +++ b/graph/sql/sql_link_iterator_test.go @@ -0,0 +1,87 @@ +// Copyright 2015 The Cayley Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "fmt" + "testing" + + "github.com/google/cayley/graph" + "github.com/google/cayley/graph/iterator" + "github.com/google/cayley/quad" +) + +func TestSQLLink(t *testing.T) { + it := NewSQLLinkIterator(nil, quad.Object, "cool") + s, v := it.buildSQL(true, nil) + fmt.Println(s, v) +} + +func TestSQLLinkIteration(t *testing.T) { + if *dbpath == "" { + t.SkipNow() + } + db, err := newQuadStore(*dbpath, nil) + if err != nil { + t.Fatal(err) + } + it := NewSQLLinkIterator(db.(*QuadStore), quad.Object, "Humphrey Bogart") + for graph.Next(it) { + fmt.Println(it.Result()) + } + it = NewSQLLinkIterator(db.(*QuadStore), quad.Subject, "/en/casablanca_1942") + s, v := it.buildSQL(true, nil) + fmt.Println(s, v) + c := 0 + for graph.Next(it) { + fmt.Println(it.Result()) + c += 1 + } + if c != 18 { + t.Errorf("Not enough results, got %d expected 18", c) + } +} + +func TestSQLNodeIteration(t *testing.T) { + if *dbpath == "" { + t.SkipNow() + } + db, err := newQuadStore(*dbpath, nil) + if err != nil { + t.Fatal(err) + } + link := NewSQLLinkIterator(db.(*QuadStore), quad.Object, "/en/humphrey_bogart") + it := &SQLNodeIterator{ + uid: iterator.NextUID(), + qs: db.(*QuadStore), + tableName: newTableName(), + linkIts: []sqlItDir{ + sqlItDir{it: link, + dir: quad.Subject, + }, + }, + } + s, v := it.buildSQL(true, nil) + fmt.Println(s, v) + c := 0 + for graph.Next(it) { + fmt.Println(it.Result()) + c += 1 + } + if c != 56 { + t.Errorf("Not enough results, got %d expected 56", c) + } + +} diff --git a/graph/sql/sql_node_iterator.go b/graph/sql/sql_node_iterator.go new file mode 100644 index 0000000..e276399 --- /dev/null +++ b/graph/sql/sql_node_iterator.go @@ -0,0 +1,435 @@ +// Copyright 2015 The Cayley Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "database/sql" + "fmt" + "strings" + + "github.com/barakmich/glog" + "github.com/google/cayley/graph" + "github.com/google/cayley/graph/iterator" + "github.com/google/cayley/quad" +) + +var sqlNodeType graph.Type + +func init() { + sqlNodeType = graph.RegisterIterator("sqlnode") +} + +type SQLNodeIterator struct { + uid uint64 + qs *QuadStore + tagger graph.Tagger + tableName string + err error + + cursor *sql.Rows + linkIts []sqlItDir + size int64 + tagdirs []tagDir + + result map[string]string + resultIndex int + resultList [][]string + resultNext [][]string + cols []string +} + +func (n *SQLNodeIterator) sqlClone() sqlIterator { + return n.Clone().(*SQLNodeIterator) +} + +func (n *SQLNodeIterator) Clone() graph.Iterator { + m := &SQLNodeIterator{ + uid: iterator.NextUID(), + qs: n.qs, + size: n.size, + tableName: n.tableName, + } + for _, i := range n.linkIts { + m.linkIts = append(m.linkIts, sqlItDir{ + dir: i.dir, + it: i.it.sqlClone(), + }) + } + copy(n.tagdirs, m.tagdirs) + m.tagger.CopyFrom(n) + return m +} + +func (n *SQLNodeIterator) UID() uint64 { + return n.uid +} + +func (n *SQLNodeIterator) Reset() { + n.err = nil + n.Close() +} + +func (n *SQLNodeIterator) Err() error { + return n.err +} + +func (n *SQLNodeIterator) Close() error { + if n.cursor != nil { + err := n.cursor.Close() + if err != nil { + return err + } + n.cursor = nil + } + return nil +} + +func (n *SQLNodeIterator) Tagger() *graph.Tagger { + return &n.tagger +} + +func (n *SQLNodeIterator) Result() graph.Value { + return n.result["__execd"] +} + +func (n *SQLNodeIterator) TagResults(dst map[string]graph.Value) { + for tag, value := range n.result { + if tag == "__execd" { + for _, tag := range n.tagger.Tags() { + dst[tag] = value + } + continue + } + dst[tag] = value + } + + for tag, value := range n.tagger.Fixed() { + dst[tag] = value + } +} + +func (n *SQLNodeIterator) Type() graph.Type { + return sqlNodeType +} + +func (n *SQLNodeIterator) SubIterators() []graph.Iterator { + // TODO(barakmich): SQL Subiterators shouldn't count? If it makes sense, + // there's no reason not to expose them though. + return nil +} + +func (n *SQLNodeIterator) Sorted() bool { return false } +func (n *SQLNodeIterator) Optimize() (graph.Iterator, bool) { return n, false } + +func (n *SQLNodeIterator) Size() (int64, bool) { + return n.qs.Size() / int64(len(n.linkIts)+1), true +} + +func (n *SQLNodeIterator) Describe() graph.Description { + size, _ := n.Size() + return graph.Description{ + UID: n.UID(), + Name: fmt.Sprintf("SQL_NODE_QUERY: %#v", n), + Type: n.Type(), + Size: size, + } +} + +func (n *SQLNodeIterator) Stats() graph.IteratorStats { + size, _ := n.Size() + return graph.IteratorStats{ + ContainsCost: 1, + NextCost: 5, + Size: size, + } +} + +func (n *SQLNodeIterator) NextPath() bool { + n.resultIndex += 1 + if n.resultIndex >= len(n.resultList) { + return false + } + n.buildResult(n.resultIndex) + return true +} + +func (n *SQLNodeIterator) buildResult(i int) { + container := n.resultList[i] + n.result = make(map[string]string) + for i, c := range n.cols { + n.result[c] = container[i] + } +} + +func (n *SQLNodeIterator) getTables() []string { + var out []string + for _, i := range n.linkIts { + out = append(out, i.it.getTables()...) + } + if len(out) == 0 { + out = append(out, n.tableName) + } + return out +} + +func (n *SQLNodeIterator) tableID() tagDir { + if len(n.linkIts) == 0 { + return tagDir{ + table: n.tableName, + dir: quad.Any, + } + } + return tagDir{ + table: n.linkIts[0].it.tableID().table, + dir: n.linkIts[0].dir, + } +} + +func (n *SQLNodeIterator) getTags() []tagDir { + myTag := n.tableID() + var out []tagDir + for _, tag := range n.tagger.Tags() { + out = append(out, tagDir{ + dir: myTag.dir, + table: myTag.table, + tag: tag, + }) + } + for _, tag := range n.tagdirs { + out = append(out, tagDir{ + dir: tag.dir, + table: myTag.table, + tag: tag.tag, + }) + + } + for _, i := range n.linkIts { + out = append(out, i.it.getTags()...) + } + return out +} + +func (n *SQLNodeIterator) height() int { + v := 0 + for _, i := range n.linkIts { + if i.it.height() > v { + v = i.it.height() + } + } + return v + 1 +} + +func (n *SQLNodeIterator) buildWhere() (string, []string) { + var q []string + var vals []string + if len(n.linkIts) > 1 { + baseTable := n.linkIts[0].it.tableID().table + baseDir := n.linkIts[0].dir + for _, i := range n.linkIts[1:] { + table := i.it.tableID().table + dir := i.dir + q = append(q, fmt.Sprintf("%s.%s = %s.%s", baseTable, baseDir, table, dir)) + } + } + for _, i := range n.linkIts { + s, v := i.it.buildWhere() + q = append(q, s) + vals = append(vals, v...) + } + query := strings.Join(q, " AND ") + return query, vals +} + +func (n *SQLNodeIterator) buildSQL(next bool, val graph.Value) (string, []string) { + topData := n.tableID() + query := "SELECT " + var t []string + t = append(t, fmt.Sprintf("%s.%s as __execd", topData.table, topData.dir)) + for _, v := range n.getTags() { + t = append(t, fmt.Sprintf("%s.%s as %s", v.table, v.dir, v.tag)) + } + query += strings.Join(t, ", ") + query += " FROM " + t = []string{} + for _, k := range n.getTables() { + t = append(t, fmt.Sprintf("quads as %s", k)) + } + query += strings.Join(t, ", ") + query += " WHERE " + constraint, values := n.buildWhere() + + if !next { + v := val.(string) + if constraint != "" { + constraint += " AND " + } + constraint += fmt.Sprintf("%s.%s = ?", topData.table, topData.dir) + values = append(values, v) + } + query += constraint + query += ";" + + glog.V(2).Infoln(query) + + if glog.V(4) { + dstr := query + for i := 1; i <= len(values); i++ { + dstr = strings.Replace(dstr, "?", fmt.Sprintf("'%s'", values[i-1]), 1) + } + glog.V(4).Infoln(dstr) + } + return query, values +} + +func (n *SQLNodeIterator) Next() bool { + var err error + graph.NextLogIn(n) + if n.cursor == nil { + err = n.makeCursor(true, nil) + n.cols, err = n.cursor.Columns() + if err != nil { + glog.Errorf("Couldn't get columns") + n.err = err + n.cursor.Close() + return false + } + // iterate the first one + if !n.cursor.Next() { + glog.V(4).Infoln("sql: No next") + err := n.cursor.Err() + if err != nil { + glog.Errorf("Cursor error in SQL: %v", err) + n.err = err + } + n.cursor.Close() + return false + } + s, err := scan(n.cursor, len(n.cols)) + if err != nil { + n.err = err + n.cursor.Close() + return false + } + n.resultNext = append(n.resultNext, s) + } + if n.resultList != nil && n.resultNext == nil { + // We're on something and there's no next + return false + } + n.resultList = n.resultNext + n.resultNext = nil + n.resultIndex = 0 + for { + if !n.cursor.Next() { + glog.V(4).Infoln("sql: No next") + err := n.cursor.Err() + if err != nil { + glog.Errorf("Cursor error in SQL: %v", err) + n.err = err + } + n.cursor.Close() + break + } + s, err := scan(n.cursor, len(n.cols)) + if err != nil { + n.err = err + n.cursor.Close() + return false + } + if n.resultList[0][0] != s[0] { + n.resultNext = append(n.resultNext, s) + break + } else { + n.resultList = append(n.resultList, s) + } + + } + if len(n.resultList) == 0 { + return graph.NextLogOut(n, nil, false) + } + n.buildResult(0) + return graph.NextLogOut(n, n.Result(), true) +} + +func (n *SQLNodeIterator) makeCursor(next bool, value graph.Value) error { + if n.cursor != nil { + n.cursor.Close() + } + var q string + var values []string + q, values = n.buildSQL(next, value) + q = convertToPostgres(q, values) + ivalues := make([]interface{}, 0, len(values)) + for _, v := range values { + ivalues = append(ivalues, v) + } + cursor, err := n.qs.db.Query(q, ivalues...) + if err != nil { + glog.Errorf("Couldn't get cursor from SQL database: %v", err) + cursor = nil + return err + } + n.cursor = cursor + return nil +} + +func (n *SQLNodeIterator) Contains(v graph.Value) bool { + var err error + //if it.preFilter(v) { + //return false + //} + err = n.makeCursor(false, v) + if err != nil { + glog.Errorf("Couldn't make query: %v", err) + n.err = err + n.cursor.Close() + return false + } + n.cols, err = n.cursor.Columns() + if err != nil { + glog.Errorf("Couldn't get columns") + n.err = err + n.cursor.Close() + return false + } + n.resultList = nil + for { + if !n.cursor.Next() { + glog.V(4).Infoln("sql: No next") + err := n.cursor.Err() + if err != nil { + glog.Errorf("Cursor error in SQL: %v", err) + n.err = err + } + n.cursor.Close() + break + } + s, err := scan(n.cursor, len(n.cols)) + if err != nil { + n.err = err + n.cursor.Close() + return false + } + n.resultList = append(n.resultList, s) + } + n.cursor.Close() + n.cursor = nil + if len(n.resultList) != 0 { + n.resultIndex = 0 + n.buildResult(0) + return true + } + return false +} diff --git a/integration/integration_test.go b/integration/integration_test.go index 05ddac0..76f4178 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -51,11 +51,14 @@ var benchmarkQueries = []struct { long bool query string tag string - expect []interface{} + // for testing + skip bool + expect []interface{} }{ // Easy one to get us started. How quick is the most straightforward retrieval? { message: "name predicate", + skip: true, query: ` g.V("Humphrey Bogart").In("name").All() `, @@ -69,6 +72,7 @@ var benchmarkQueries = []struct { // that's going to be measurably slower for every other backend. { message: "two large sets with no intersection", + skip: true, query: ` function getId(x) { return g.V(x).In("name") } var actor_to_film = g.M().In("/film/performance/actor").In("/film/film/starring") @@ -534,6 +538,9 @@ func checkQueries(t *testing.T) { if testing.Short() && test.long { continue } + if test.skip { + continue + } fmt.Printf("Now testing %s\n", test.message) ses := gremlin.NewSession(handle.QuadStore, cfg.Timeout, true) _, err := ses.Parse(test.query)