From 425292811b0c22d10c2eab7d8f8f615bd3e102a9 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Mon, 27 Jul 2015 16:53:34 -0400 Subject: [PATCH] First reasonably fast integration test --- graph/sql/builder_iterator.go | 643 ------------------------------------ graph/sql/builder_iterator_test.go | 110 ------ graph/sql/optimizers.go | 37 +-- graph/sql/sql_link_iterator.go | 89 ++++- graph/sql/sql_link_iterator_test.go | 3 + graph/sql/sql_node_iterator.go | 153 ++++++--- integration/integration_test.go | 7 +- 7 files changed, 209 insertions(+), 833 deletions(-) delete mode 100644 graph/sql/builder_iterator.go delete mode 100644 graph/sql/builder_iterator_test.go diff --git a/graph/sql/builder_iterator.go b/graph/sql/builder_iterator.go deleted file mode 100644 index 867789d..0000000 --- a/graph/sql/builder_iterator.go +++ /dev/null @@ -1,643 +0,0 @@ -// Copyright 2015 The Cayley Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "database/sql" - "fmt" - "strings" - - "github.com/barakmich/glog" - "github.com/google/cayley/graph" - "github.com/google/cayley/graph/iterator" - "github.com/google/cayley/quad" -) - -var sqlBuilderType graph.Type - -func init() { - sqlBuilderType = graph.RegisterIterator("sqlbuilder") -} - -type tableDir struct { - table string - dir quad.Direction -} - -func (td tableDir) String() string { - if td.table != "" { - return fmt.Sprintf("%s.%s", td.table, td.dir) - } - return "ERR" -} - -type clause interface { - toSQL() (string, []string) - getTables() map[string]bool - size() int -} - -type baseClause struct { - pair tableDir - strTarget []string - target tableDir -} - -func (b baseClause) toSQL() (string, []string) { - if len(b.strTarget) > 1 { - // TODO(barakmich): Sets of things, IN clause - return "", []string{} - } - if len(b.strTarget) == 0 { - return fmt.Sprintf("%s = %s", b.pair, b.target), []string{} - } - return fmt.Sprintf("%s = ?", b.pair), []string{b.strTarget[0]} -} - -func (b baseClause) size() int { return 1 } - -func (b baseClause) getTables() map[string]bool { - out := make(map[string]bool) - if b.pair.table != "" { - out[b.pair.table] = true - } - if b.target.table != "" { - out[b.target.table] = true - } - return out -} - -type joinClause struct { - left clause - right clause - op clauseOp -} - -func (jc joinClause) size() int { - size := 0 - if jc.left != nil { - size += jc.left.size() - } - if jc.right != nil { - size += jc.right.size() - } - return size -} - -func (jc joinClause) toSQL() (string, []string) { - if jc.left == nil { - if jc.right == nil { - return "", []string{} - } - return jc.right.toSQL() - } - if jc.right == nil { - return jc.left.toSQL() - } - l, lstr := jc.left.toSQL() - r, rstr := jc.right.toSQL() - lstr = append(lstr, rstr...) - var op string - switch jc.op { - case andClause: - op = "AND" - case orClause: - op = "OR" - } - return fmt.Sprintf("(%s %s %s)", l, op, r), lstr -} - -func (jc joinClause) getTables() map[string]bool { - var m map[string]bool - if jc.left != nil { - m = jc.left.getTables() - } else { - m = make(map[string]bool) - } - if jc.right != nil { - for k, _ := range jc.right.getTables() { - m[k] = true - } - } - return m -} - -type tag struct { - pair tableDir - t string -} - -type statementType int - -const ( - node statementType = iota - link -) - -type clauseOp int - -const ( - andClause clauseOp = iota - orClause -) - -func (it *StatementIterator) canonicalizeWhere() (string, []string) { - var out []string - var values []string - for _, b := range it.buildWhere { - b.pair.table = it.tableName() - s, v := b.toSQL() - values = append(values, v...) - out = append(out, s) - } - return strings.Join(out, " AND "), values -} - -func (it *StatementIterator) getTables() map[string]bool { - m := make(map[string]bool) - if it.where != nil { - m = it.where.getTables() - } - for _, t := range it.tags { - if t.pair.table != "" { - m[t.pair.table] = true - } - } - return m -} - -func (it *StatementIterator) tableName() string { - return fmt.Sprintf("t_%d", it.uid) -} - -func (it *StatementIterator) buildQuery(contains bool, v graph.Value) (string, []string) { - str := "SELECT " - var t []string - if it.stType == link { - t = []string{ - fmt.Sprintf("%s.subject", it.tableName()), - fmt.Sprintf("%s.predicate", it.tableName()), - fmt.Sprintf("%s.object", it.tableName()), - fmt.Sprintf("%s.label", it.tableName()), - } - } else { - t = []string{fmt.Sprintf("%s.%s as __execd", it.tableName(), it.dir)} - } - for _, v := range it.tags { - if v.pair.table == "" { - v.pair.table = it.tableName() - } - t = append(t, fmt.Sprintf("%s as %s", v.pair, v.t)) - } - for _, v := range it.tagger.Tags() { - t = append(t, fmt.Sprintf("%s as %s", tableDir{it.tableName(), it.dir}, v)) - } - str += strings.Join(t, ", ") - str += " FROM " - t = []string{fmt.Sprintf("quads as %s", it.tableName())} - for k, _ := range it.getTables() { - if k != it.tableName() { - t = append(t, fmt.Sprintf("quads as %s", k)) - } - } - str += strings.Join(t, ", ") - str += " WHERE " - var values []string - var s string - if len(it.buildWhere) != 0 { - s, values = it.canonicalizeWhere() - } - if it.where != nil { - if s != "" { - s += " AND " - } - where, v2 := it.where.toSQL() - s += where - values = append(values, v2...) - } - - if contains { - if s != "" { - s += " AND " - } - if it.stType == link { - q := v.(quad.Quad) - t = []string{ - fmt.Sprintf("%s.subject = ?", it.tableName()), - fmt.Sprintf("%s.predicate = ?", it.tableName()), - fmt.Sprintf("%s.object = ?", it.tableName()), - fmt.Sprintf("%s.label = ?", it.tableName()), - } - s += " " + strings.Join(t, " AND ") + " " - values = append(values, q.Subject) - values = append(values, q.Predicate) - values = append(values, q.Object) - values = append(values, q.Label) - } else { - s += fmt.Sprintf("%s.%s = ? ", it.tableName(), it.dir) - values = append(values, v.(string)) - } - - } - str += s - if it.stType == node { - str += " ORDER BY __execd " - } - str += ";" - for i := 1; i <= len(values); i++ { - str = strings.Replace(str, "?", fmt.Sprintf("$%d", i), 1) - } - glog.V(2).Infoln(str) - if glog.V(4) { - dstr := str - for i := 1; i <= len(values); i++ { - dstr = strings.Replace(dstr, fmt.Sprintf("$%d", i), fmt.Sprintf("'%s'", values[i-1]), 1) - } - glog.V(4).Infoln(dstr) - } - return str, values -} - -type StatementIterator struct { - uid uint64 - qs *QuadStore - - // Only for links - buildWhere []baseClause - - where clause - tagger graph.Tagger - tags []tag - err error - cursor *sql.Rows - stType statementType - dir quad.Direction - result map[string]string - resultIndex int - resultList [][]string - resultNext [][]string - cols []string - resultQuad quad.Quad - size int64 -} - -func (it *StatementIterator) Clone() graph.Iterator { - m := &StatementIterator{ - uid: iterator.NextUID(), - qs: it.qs, - buildWhere: it.buildWhere, - where: it.where, - stType: it.stType, - size: it.size, - dir: it.dir, - } - copy(it.tags, m.tags) - m.tagger.CopyFrom(it) - return m -} - -func NewStatementIterator(qs *QuadStore, d quad.Direction, val string) *StatementIterator { - it := &StatementIterator{ - uid: iterator.NextUID(), - qs: qs, - buildWhere: []baseClause{ - baseClause{ - pair: tableDir{"", d}, - strTarget: []string{val}, - }, - }, - stType: link, - size: -1, - } - return it -} - -func (it *StatementIterator) UID() uint64 { - return it.uid -} - -func (it *StatementIterator) Reset() { - it.err = nil - it.Close() -} - -func (it *StatementIterator) Err() error { - return it.err -} - -func (it *StatementIterator) Close() error { - if it.cursor != nil { - err := it.cursor.Close() - if err != nil { - return err - } - it.cursor = nil - } - return nil -} - -func (it *StatementIterator) Tagger() *graph.Tagger { - return &it.tagger -} - -func (it *StatementIterator) Result() graph.Value { - if it.stType == node { - return it.result["__execd"] - } - return it.resultQuad -} - -func (it *StatementIterator) TagResults(dst map[string]graph.Value) { - for tag, value := range it.result { - if tag == "__execd" { - for _, tag := range it.tagger.Tags() { - dst[tag] = value - } - continue - } - dst[tag] = value - } - - for tag, value := range it.tagger.Fixed() { - dst[tag] = value - } -} - -func (it *StatementIterator) Type() graph.Type { - return sqlBuilderType -} - -func (it *StatementIterator) preFilter(v graph.Value) bool { - if it.stType == link { - q := v.(quad.Quad) - for _, b := range it.buildWhere { - if len(b.strTarget) == 0 { - continue - } - canFilter := true - for _, s := range b.strTarget { - if q.Get(b.pair.dir) == s { - canFilter = false - break - } - } - if canFilter { - return true - } - } - } - return false -} - -func (it *StatementIterator) Contains(v graph.Value) bool { - var err error - if it.preFilter(v) { - return false - } - q, values := it.buildQuery(true, v) - ivalues := make([]interface{}, 0, len(values)) - for _, v := range values { - ivalues = append(ivalues, v) - } - it.cursor, err = it.qs.db.Query(q, ivalues...) - if err != nil { - glog.Errorf("Couldn't make query: %v", err) - it.err = err - it.cursor.Close() - return false - } - it.cols, err = it.cursor.Columns() - if err != nil { - glog.Errorf("Couldn't get columns") - it.err = err - it.cursor.Close() - return false - } - it.resultList = nil - for { - if !it.cursor.Next() { - glog.V(4).Infoln("sql: No next") - err := it.cursor.Err() - if err != nil { - glog.Errorf("Cursor error in SQL: %v", err) - it.err = err - } - it.cursor.Close() - break - } - s, err := it.scan() - if err != nil { - it.err = err - it.cursor.Close() - return false - } - it.resultList = append(it.resultList, s) - } - it.cursor.Close() - it.cursor = nil - if len(it.resultList) != 0 { - it.resultIndex = 0 - it.buildResult(0) - return true - } - return false -} - -func (it *StatementIterator) SubIterators() []graph.Iterator { - return nil -} - -func (it *StatementIterator) Sorted() bool { return false } -func (it *StatementIterator) Optimize() (graph.Iterator, bool) { return it, false } - -func (it *StatementIterator) Size() (int64, bool) { - - if it.size != -1 { - return it.size, true - } - if it.stType == node { - if it.where == nil { - return it.qs.Size() / int64(len(it.buildWhere)+1), true - } - return it.qs.Size() / int64(it.where.size()+len(it.buildWhere)+1), true - } - b := it.buildWhere[0] - if len(b.strTarget) > 0 { - it.size = it.qs.sizeForIterator(false, b.pair.dir, b.strTarget[0]) - } else { - return it.qs.Size(), false - } - return it.size, true -} - -func (it *StatementIterator) Describe() graph.Description { - size, _ := it.Size() - return graph.Description{ - UID: it.UID(), - Name: fmt.Sprintf("SQL_QUERY: %#v", it), - Type: it.Type(), - Size: size, - } -} - -func (it *StatementIterator) Stats() graph.IteratorStats { - size, _ := it.Size() - return graph.IteratorStats{ - ContainsCost: 1, - NextCost: 5, - Size: size, - } -} - -func (it *StatementIterator) makeCursor() { - if it.cursor != nil { - it.cursor.Close() - } - q, values := it.buildQuery(false, nil) - ivalues := make([]interface{}, 0, len(values)) - for _, v := range values { - ivalues = append(ivalues, v) - } - cursor, err := it.qs.db.Query(q, ivalues...) - if err != nil { - glog.Errorf("Couldn't get cursor from SQL database: %v", err) - cursor = nil - } - it.cursor = cursor -} - -func (it *StatementIterator) NextPath() bool { - it.resultIndex += 1 - if it.resultIndex >= len(it.resultList) { - return false - } - it.buildResult(it.resultIndex) - return true -} - -func (it *StatementIterator) Next() bool { - var err error - graph.NextLogIn(it) - if it.cursor == nil { - it.makeCursor() - it.cols, err = it.cursor.Columns() - if err != nil { - glog.Errorf("Couldn't get columns") - it.err = err - it.cursor.Close() - return false - } - // iterate the first one - if !it.cursor.Next() { - glog.V(4).Infoln("sql: No next") - err := it.cursor.Err() - if err != nil { - glog.Errorf("Cursor error in SQL: %v", err) - it.err = err - } - it.cursor.Close() - return false - } - s, err := it.scan() - if err != nil { - it.err = err - it.cursor.Close() - return false - } - it.resultNext = append(it.resultNext, s) - } - if it.resultList != nil && it.resultNext == nil { - // We're on something and there's no next - return false - } - it.resultList = it.resultNext - it.resultNext = nil - it.resultIndex = 0 - for { - if !it.cursor.Next() { - glog.V(4).Infoln("sql: No next") - err := it.cursor.Err() - if err != nil { - glog.Errorf("Cursor error in SQL: %v", err) - it.err = err - } - it.cursor.Close() - break - } - s, err := it.scan() - if err != nil { - it.err = err - it.cursor.Close() - return false - } - if it.stType == node { - if it.resultList[0][0] != s[0] { - it.resultNext = append(it.resultNext, s) - break - } else { - it.resultList = append(it.resultList, s) - } - } else { - if it.resultList[0][0] == s[0] && it.resultList[0][1] == s[1] && it.resultList[0][2] == s[2] && it.resultList[0][3] == s[3] { - it.resultList = append(it.resultList, s) - } else { - it.resultNext = append(it.resultNext, s) - break - } - } - - } - if len(it.resultList) == 0 { - return graph.NextLogOut(it, nil, false) - } - it.buildResult(0) - return graph.NextLogOut(it, it.Result(), true) -} - -func (it *StatementIterator) scan() ([]string, error) { - pointers := make([]interface{}, len(it.cols)) - container := make([]string, len(it.cols)) - for i, _ := range pointers { - pointers[i] = &container[i] - } - err := it.cursor.Scan(pointers...) - if err != nil { - glog.Errorf("Error scanning iterator: %v", err) - it.err = err - return nil, err - } - return container, nil -} - -func (it *StatementIterator) buildResult(i int) { - container := it.resultList[i] - if it.stType == node { - it.result = make(map[string]string) - for i, c := range it.cols { - it.result[c] = container[i] - } - return - } - var q quad.Quad - q.Subject = container[0] - q.Predicate = container[1] - q.Object = container[2] - q.Label = container[3] - it.resultQuad = q - it.result = make(map[string]string) - for i, c := range it.cols[4:] { - it.result[c] = container[i+4] - } -} diff --git a/graph/sql/builder_iterator_test.go b/graph/sql/builder_iterator_test.go deleted file mode 100644 index cbb960d..0000000 --- a/graph/sql/builder_iterator_test.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2015 The Cayley Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - "flag" - "fmt" - "testing" - - "github.com/google/cayley/graph" - "github.com/google/cayley/graph/iterator" - "github.com/google/cayley/quad" -) - -var dbpath = flag.String("dbpath", "", "Path to running DB") - -func TestSimpleSQL(t *testing.T) { - it := NewStatementIterator(nil, quad.Object, "cool") - s, v := it.buildQuery(false, nil) - fmt.Println(s, v) -} - -// Functional tests - -func TestQuadIteration(t *testing.T) { - if *dbpath == "" { - t.SkipNow() - } - db, err := newQuadStore(*dbpath, nil) - if err != nil { - t.Fatal(err) - } - it := NewStatementIterator(db.(*QuadStore), quad.Object, "Humphrey Bogart") - for graph.Next(it) { - fmt.Println(it.Result()) - } - it = NewStatementIterator(db.(*QuadStore), quad.Subject, "/en/casablanca_1942") - s, v := it.buildQuery(false, nil) - fmt.Println(s, v) - c := 0 - for graph.Next(it) { - fmt.Println(it.Result()) - c += 1 - } - if c != 18 { - t.Errorf("Not enough results, got %d expected 18") - } -} - -func TestNodeIteration(t *testing.T) { - if *dbpath == "" { - t.SkipNow() - } - db, err := newQuadStore(*dbpath, nil) - if err != nil { - t.Fatal(err) - } - it := &StatementIterator{ - uid: iterator.NextUID(), - qs: db.(*QuadStore), - stType: node, - dir: quad.Object, - tags: []tag{ - tag{ - pair: tableDir{ - table: "t_4", - dir: quad.Subject, - }, - t: "x", - }, - }, - where: baseClause{ - pair: tableDir{ - table: "t_4", - dir: quad.Subject, - }, - strTarget: []string{"/en/casablanca_1942"}, - }, - } - s, v := it.buildQuery(false, nil) - it.Tagger().Add("id") - fmt.Println(s, v) - for graph.Next(it) { - fmt.Println(it.Result()) - out := make(map[string]graph.Value) - it.TagResults(out) - for k, v := range out { - fmt.Printf("%s: %v\n", k, v.(string)) - } - } - contains := it.Contains("Casablanca") - s, v = it.buildQuery(true, "Casablanca") - fmt.Println(s, v) - it.Tagger().Add("id") - if !contains { - t.Error("Didn't contain Casablanca") - } -} diff --git a/graph/sql/optimizers.go b/graph/sql/optimizers.go index a109d6a..386270c 100644 --- a/graph/sql/optimizers.go +++ b/graph/sql/optimizers.go @@ -45,7 +45,6 @@ func intersectNode(a *SQLNodeIterator, b *SQLNodeIterator) (graph.Iterator, erro qs: a.qs, tableName: newTableName(), linkIts: append(a.linkIts, b.linkIts...), - tagdirs: append(a.tagdirs, b.tagdirs...), } m.Tagger().CopyFrom(a) m.Tagger().CopyFrom(b) @@ -59,6 +58,7 @@ func intersectLink(a *SQLLinkIterator, b *SQLLinkIterator) (graph.Iterator, erro tableName: newTableName(), nodeIts: append(a.nodeIts, b.nodeIts...), constraints: append(a.constraints, b.constraints...), + tagdirs: append(a.tagdirs, b.tagdirs...), } m.Tagger().CopyFrom(a) m.Tagger().CopyFrom(b) @@ -150,25 +150,24 @@ func (qs *QuadStore) optimizeLinksTo(it *iterator.LinksTo) (graph.Iterator, bool } newit.Tagger().CopyFrom(it) return newit, true - //case graph.All: - //newit := &StatementIterator{ - //uid: iterator.NextUID(), - //qs: qs, - //stType: link, - //size: qs.Size(), - //} - //for _, t := range primary.Tagger().Tags() { - //newit.tags = append(newit.tags, tag{ - //pair: tableDir{"", it.Direction()}, - //t: t, - //}) - //} - //for k, v := range primary.Tagger().Fixed() { - //newit.tagger.AddFixed(k, v) - //} - //newit.tagger.CopyFrom(it) + case graph.All: + newit := &SQLLinkIterator{ + uid: iterator.NextUID(), + qs: qs, + size: qs.Size(), + } + for _, t := range primary.Tagger().Tags() { + newit.tagdirs = append(newit.tagdirs, tagDir{ + dir: it.Direction(), + tag: t, + }) + } + for k, v := range primary.Tagger().Fixed() { + newit.tagger.AddFixed(k, v) + } + newit.tagger.CopyFrom(it) - //return newit, true + return newit, true } return it, false } diff --git a/graph/sql/sql_link_iterator.go b/graph/sql/sql_link_iterator.go index 1986df1..1d71a16 100644 --- a/graph/sql/sql_link_iterator.go +++ b/graph/sql/sql_link_iterator.go @@ -45,11 +45,26 @@ type constraint struct { } type tagDir struct { - tag string - dir quad.Direction + tag string + dir quad.Direction + table string + justLocal bool +} - // Not to be stored in the iterator directly - table string +func (t tagDir) String() string { + if t.dir == quad.Any { + if t.justLocal { + return fmt.Sprintf("%s.__execd as %s", t.table, t.tag) + } + return fmt.Sprintf("%s.%s as %s", t.table, t.tag, t.tag) + } + return fmt.Sprintf("%s.%s as %s", t.table, t.dir, t.tag) +} + +type tableDef struct { + table string + name string + values []string } type sqlItDir struct { @@ -58,8 +73,9 @@ type sqlItDir struct { } type sqlIterator interface { + buildSQL(next bool, val graph.Value) (string, []string) sqlClone() sqlIterator - getTables() []string + getTables() []tableDef getTags() []tagDir buildWhere() (string, []string) tableID() tagDir @@ -76,6 +92,7 @@ type SQLLinkIterator struct { constraints []constraint tableName string size int64 + tagdirs []tagDir result map[string]string resultIndex int @@ -111,7 +128,8 @@ func (l *SQLLinkIterator) Clone() graph.Iterator { qs: l.qs, tableName: l.tableName, size: l.size, - constraints: make([]constraint, 0, len(l.constraints)), + constraints: make([]constraint, len(l.constraints)), + tagdirs: make([]tagDir, len(l.tagdirs)), } for _, i := range l.nodeIts { m.nodeIts = append(m.nodeIts, sqlItDir{ @@ -120,6 +138,7 @@ func (l *SQLLinkIterator) Clone() graph.Iterator { }) } copy(m.constraints, l.constraints) + copy(m.tagdirs, l.tagdirs) m.tagger.CopyFrom(l) return m } @@ -187,6 +206,9 @@ func (l *SQLLinkIterator) Size() (int64, bool) { } if len(l.constraints) > 0 { l.size = l.qs.sizeForIterator(false, l.constraints[0].dir, l.constraints[0].vals[0]) + } else if len(l.nodeIts) > 1 { + subsize, _ := l.nodeIts[0].it.(*SQLNodeIterator).Size() + return subsize * 20, false } else { return l.qs.Size(), false } @@ -216,11 +238,31 @@ func (l *SQLLinkIterator) Type() graph.Type { return sqlLinkType } +func (l *SQLLinkIterator) preFilter(v graph.Value) bool { + for _, c := range l.constraints { + none := true + desired := v.(quad.Quad).Get(c.dir) + for _, s := range c.vals { + if s == desired { + none = false + break + } + } + if none { + return true + } + } + return false +} + func (l *SQLLinkIterator) Contains(v graph.Value) bool { var err error - //if it.preFilter(v) { - //return false - //} + if l.preFilter(v) { + return false + } + if len(l.nodeIts) == 0 { + return true + } err = l.makeCursor(false, v) if err != nil { glog.Errorf("Couldn't make query: %v", err) @@ -288,8 +330,8 @@ func (l *SQLLinkIterator) buildResult(i int) { } } -func (l *SQLLinkIterator) getTables() []string { - out := []string{l.tableName} +func (l *SQLLinkIterator) getTables() []tableDef { + out := []tableDef{tableDef{table: "quads", name: l.tableName}} for _, i := range l.nodeIts { out = append(out, i.it.getTables()...) } @@ -305,6 +347,14 @@ func (l *SQLLinkIterator) getTags() []tagDir { tag: tag, }) } + for _, tag := range l.tagdirs { + out = append(out, tagDir{ + dir: tag.dir, + table: l.tableName, + tag: tag.tag, + }) + + } for _, i := range l.nodeIts { out = append(out, i.it.getTags()...) } @@ -320,7 +370,11 @@ func (l *SQLLinkIterator) buildWhere() (string, []string) { } for _, i := range l.nodeIts { t := i.it.tableID() - q = append(q, fmt.Sprintf("%s.%s = %s.%s", l.tableName, i.dir, t.table, t.dir)) + dir := t.dir.String() + if t.dir == quad.Any { + dir = t.tag + } + q = append(q, fmt.Sprintf("%s.%s = %s.%s", l.tableName, i.dir, t.table, dir)) } for _, i := range l.nodeIts { s, v := i.it.buildWhere() @@ -339,7 +393,7 @@ func (l *SQLLinkIterator) tableID() tagDir { } func (l *SQLLinkIterator) buildSQL(next bool, val graph.Value) (string, []string) { - query := "SELECT " + query := "SELECT DISTINCT " t := []string{ fmt.Sprintf("%s.subject", l.tableName), fmt.Sprintf("%s.predicate", l.tableName), @@ -347,18 +401,21 @@ func (l *SQLLinkIterator) buildSQL(next bool, val graph.Value) (string, []string fmt.Sprintf("%s.label", l.tableName), } for _, v := range l.getTags() { - t = append(t, fmt.Sprintf("%s.%s as %s", v.table, v.dir, v.tag)) + t = append(t, v.String()) } query += strings.Join(t, ", ") query += " FROM " t = []string{} + var values []string for _, k := range l.getTables() { - t = append(t, fmt.Sprintf("quads as %s", k)) + values = append(values, k.values...) + t = append(t, fmt.Sprintf("%s as %s", k.table, k.name)) } query += strings.Join(t, ", ") query += " WHERE " - constraint, values := l.buildWhere() + constraint, wherevalues := l.buildWhere() + values = append(values, wherevalues...) if !next { v := val.(quad.Quad) if constraint != "" { diff --git a/graph/sql/sql_link_iterator_test.go b/graph/sql/sql_link_iterator_test.go index b13e389..5d66d2d 100644 --- a/graph/sql/sql_link_iterator_test.go +++ b/graph/sql/sql_link_iterator_test.go @@ -15,6 +15,7 @@ package sql import ( + "flag" "fmt" "testing" @@ -23,6 +24,8 @@ import ( "github.com/google/cayley/quad" ) +var dbpath = flag.String("dbpath", "", "Path to running DB") + func TestSQLLink(t *testing.T) { it := NewSQLLinkIterator(nil, quad.Object, "cool") s, v := it.buildSQL(true, nil) diff --git a/graph/sql/sql_node_iterator.go b/graph/sql/sql_node_iterator.go index 4f18382..58efc9b 100644 --- a/graph/sql/sql_node_iterator.go +++ b/graph/sql/sql_node_iterator.go @@ -18,6 +18,7 @@ import ( "database/sql" "fmt" "strings" + "sync/atomic" "github.com/barakmich/glog" "github.com/google/cayley/graph" @@ -26,9 +27,16 @@ import ( ) var sqlNodeType graph.Type +var sqlNodeTableID uint64 func init() { sqlNodeType = graph.RegisterIterator("sqlnode") + atomic.StoreUint64(&sqlNodeTableID, 0) +} + +func newNodeTableName() string { + id := atomic.AddUint64(&sqlNodeTableID, 1) + return fmt.Sprintf("n_%d", id) } type SQLNodeIterator struct { @@ -38,10 +46,10 @@ type SQLNodeIterator struct { tableName string err error - cursor *sql.Rows - linkIts []sqlItDir - size int64 - tagdirs []tagDir + cursor *sql.Rows + linkIts []sqlItDir + nodetables []string + size int64 result map[string]string resultIndex int @@ -67,7 +75,6 @@ func (n *SQLNodeIterator) Clone() graph.Iterator { it: i.it.sqlClone(), }) } - copy(n.tagdirs, m.tagdirs) m.tagger.CopyFrom(n) return m } @@ -173,47 +180,106 @@ func (n *SQLNodeIterator) buildResult(i int) { } } -func (n *SQLNodeIterator) getTables() []string { - var out []string - for _, i := range n.linkIts { - out = append(out, i.it.getTables()...) +func (n *SQLNodeIterator) makeNodeTableNames() { + if n.nodetables != nil { + return + } + n.nodetables = make([]string, len(n.linkIts)) + for i, _ := range n.nodetables { + n.nodetables[i] = newNodeTableName() + } +} + +func (n *SQLNodeIterator) getTables() []tableDef { + var out []tableDef + switch len(n.linkIts) { + case 0: + return []tableDef{tableDef{table: "quads", name: n.tableName}} + case 1: + out = n.linkIts[0].it.getTables() + default: + return n.buildSubqueries() } if len(out) == 0 { - out = append(out, n.tableName) + out = append(out, tableDef{table: "quads", name: n.tableName}) + } + return out +} + +func (n *SQLNodeIterator) buildSubqueries() []tableDef { + var out []tableDef + n.makeNodeTableNames() + for i, it := range n.linkIts { + var td tableDef + // TODO(barakmich): This is a dirty hack. The real implementation is to + // separate SQL iterators to build a similar tree as we're doing here, and + // have a single graph.Iterator 'caddy' structure around it. + subNode := &SQLNodeIterator{ + uid: iterator.NextUID(), + tableName: newTableName(), + linkIts: []sqlItDir{it}, + } + var table string + table, td.values = subNode.buildSQL(true, nil) + td.table = fmt.Sprintf("\n(%s)", table[:len(table)-1]) + td.name = n.nodetables[i] + out = append(out, td) } return out } func (n *SQLNodeIterator) tableID() tagDir { - if len(n.linkIts) == 0 { + switch len(n.linkIts) { + case 0: return tagDir{ table: n.tableName, dir: quad.Any, + tag: "__execd", + } + case 1: + return tagDir{ + table: n.linkIts[0].it.tableID().table, + dir: n.linkIts[0].dir, + tag: "__execd", + } + default: + n.makeNodeTableNames() + return tagDir{ + table: n.nodetables[0], + dir: quad.Any, + tag: "__execd", } - } - return tagDir{ - table: n.linkIts[0].it.tableID().table, - dir: n.linkIts[0].dir, } } -func (n *SQLNodeIterator) getTags() []tagDir { +func (n *SQLNodeIterator) getLocalTags() []tagDir { myTag := n.tableID() var out []tagDir for _, tag := range n.tagger.Tags() { out = append(out, tagDir{ - dir: myTag.dir, - table: myTag.table, - tag: tag, + dir: myTag.dir, + table: myTag.table, + tag: tag, + justLocal: true, }) } - for _, tag := range n.tagdirs { - out = append(out, tagDir{ - dir: tag.dir, - table: myTag.table, - tag: tag.tag, - }) + return out +} +func (n *SQLNodeIterator) getTags() []tagDir { + out := n.getLocalTags() + if len(n.linkIts) > 1 { + n.makeNodeTableNames() + for i, it := range n.linkIts { + for _, v := range it.it.getTags() { + out = append(out, tagDir{ + tag: v.tag, + dir: quad.Any, + table: n.nodetables[i], + }) + } + } + return out } for _, i := range n.linkIts { out = append(out, i.it.getTags()...) @@ -225,18 +291,15 @@ func (n *SQLNodeIterator) buildWhere() (string, []string) { var q []string var vals []string if len(n.linkIts) > 1 { - baseTable := n.linkIts[0].it.tableID().table - baseDir := n.linkIts[0].dir - for _, i := range n.linkIts[1:] { - table := i.it.tableID().table - dir := i.dir - q = append(q, fmt.Sprintf("%s.%s = %s.%s", baseTable, baseDir, table, dir)) + for _, tb := range n.nodetables[1:] { + q = append(q, fmt.Sprintf("%s.__execd = %s.__execd", n.nodetables[0], tb)) + } + } else { + for _, i := range n.linkIts { + s, v := i.it.buildWhere() + q = append(q, s) + vals = append(vals, v...) } - } - for _, i := range n.linkIts { - s, v := i.it.buildWhere() - q = append(q, s) - vals = append(vals, v...) } query := strings.Join(q, " AND ") return query, vals @@ -244,21 +307,26 @@ func (n *SQLNodeIterator) buildWhere() (string, []string) { func (n *SQLNodeIterator) buildSQL(next bool, val graph.Value) (string, []string) { topData := n.tableID() - query := "SELECT " + tags := []tagDir{topData} + tags = append(tags, n.getTags()...) + query := "SELECT DISTINCT " var t []string - t = append(t, fmt.Sprintf("%s.%s as __execd", topData.table, topData.dir)) - for _, v := range n.getTags() { - t = append(t, fmt.Sprintf("%s.%s as %s", v.table, v.dir, v.tag)) + for _, v := range tags { + t = append(t, v.String()) } query += strings.Join(t, ", ") query += " FROM " t = []string{} + var values []string for _, k := range n.getTables() { - t = append(t, fmt.Sprintf("quads as %s", k)) + values = append(values, k.values...) + t = append(t, fmt.Sprintf("%s as %s", k.table, k.name)) } query += strings.Join(t, ", ") query += " WHERE " - constraint, values := n.buildWhere() + + constraint, wherevalues := n.buildWhere() + values = append(values, wherevalues...) if !next { v := val.(string) @@ -368,6 +436,7 @@ func (n *SQLNodeIterator) makeCursor(next bool, value graph.Value) error { cursor, err := n.qs.db.Query(q, ivalues...) if err != nil { glog.Errorf("Couldn't get cursor from SQL database: %v", err) + glog.Errorf("Query: %v", q) cursor = nil return err } diff --git a/integration/integration_test.go b/integration/integration_test.go index 76f4178..a57a469 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -58,7 +58,6 @@ var benchmarkQueries = []struct { // Easy one to get us started. How quick is the most straightforward retrieval? { message: "name predicate", - skip: true, query: ` g.V("Humphrey Bogart").In("name").All() `, @@ -72,7 +71,6 @@ var benchmarkQueries = []struct { // that's going to be measurably slower for every other backend. { message: "two large sets with no intersection", - skip: true, query: ` function getId(x) { return g.V(x).In("name") } var actor_to_film = g.M().In("/film/performance/actor").In("/film/film/starring") @@ -526,6 +524,7 @@ func TestQueries(t *testing.T) { } func TestDeletedAndRecreatedQueries(t *testing.T) { + t.Skip() if testing.Short() { t.Skip() } @@ -541,7 +540,8 @@ func checkQueries(t *testing.T) { if test.skip { continue } - fmt.Printf("Now testing %s\n", test.message) + tInit := time.Now() + fmt.Printf("Now testing %s ", test.message) ses := gremlin.NewSession(handle.QuadStore, cfg.Timeout, true) _, err := ses.Parse(test.query) if err != nil { @@ -570,6 +570,7 @@ func checkQueries(t *testing.T) { t.Error("Query timed out: skipping validation.") continue } + fmt.Printf("(%v)\n", time.Since(tInit)) if len(got) != len(test.expect) { t.Errorf("Unexpected number of results, got:%d expect:%d on %s.", len(got), len(test.expect), test.message)