From 185e236f158794e2bcfe77ef852edbe8cdd11402 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Wed, 15 Jul 2015 18:29:55 -0400 Subject: [PATCH] attempt to build more interesting SQL queries Subcommits: fix old iterator, and flesh out new builder iterator fix contains for builder iterator Working replacement iterator --- graph/sql/builder_iterator.go | 582 +++++++++++++++++++++++++++++++++++++ graph/sql/builder_iterator_test.go | 110 +++++++ graph/sql/iterator.go | 7 +- graph/sql/quadstore.go | 2 +- integration/integration_test.go | 15 +- 5 files changed, 712 insertions(+), 4 deletions(-) create mode 100644 graph/sql/builder_iterator.go create mode 100644 graph/sql/builder_iterator_test.go diff --git a/graph/sql/builder_iterator.go b/graph/sql/builder_iterator.go new file mode 100644 index 0000000..f470845 --- /dev/null +++ b/graph/sql/builder_iterator.go @@ -0,0 +1,582 @@ +// Copyright 2015 The Cayley Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "database/sql" + "fmt" + "strings" + + "github.com/barakmich/glog" + "github.com/google/cayley/graph" + "github.com/google/cayley/graph/iterator" + "github.com/google/cayley/quad" +) + +var sqlBuilderType graph.Type + +func init() { + sqlBuilderType = graph.RegisterIterator("sqlbuilder") +} + +type tableDir struct { + table string + dir quad.Direction +} + +func (td tableDir) String() string { + if td.table != "" { + return fmt.Sprintf("%s.%s", td.table, td.dir) + } + return "ERR" +} + +type clause interface { + toSQL() (string, []string) + getTables() map[string]bool +} + +type baseClause struct { + pair tableDir + strTarget []string + target tableDir +} + +func (b baseClause) toSQL() (string, []string) { + if len(b.strTarget) > 1 { + // TODO(barakmich): Sets of things, IN clause + return "", []string{} + } + if len(b.strTarget) == 0 { + return fmt.Sprintf("%s = %s", b.pair, b.target), []string{} + } + return fmt.Sprintf("%s = ?", b.pair), []string{b.strTarget[0]} +} + +func (b baseClause) getTables() map[string]bool { + out := make(map[string]bool) + if b.pair.table != "" { + out[b.pair.table] = true + } + if b.target.table != "" { + out[b.target.table] = true + } + return out +} + +type joinClause struct { + left clause + right clause + op clauseOp +} + +func (jc joinClause) toSQL() (string, []string) { + l, lstr := jc.left.toSQL() + r, rstr := jc.right.toSQL() + lstr = append(lstr, rstr...) + var op string + switch jc.op { + case andClause: + op = "AND" + case orClause: + op = "OR" + } + return fmt.Sprint("(%s %s %s)", l, op, r), lstr +} + +func (jc joinClause) getTables() map[string]bool { + m := jc.left.getTables() + for k, _ := range jc.right.getTables() { + m[k] = true + } + return m +} + +type tag struct { + pair tableDir + t string +} + +type statementType int + +const ( + node statementType = iota + link +) + +type clauseOp int + +const ( + andClause clauseOp = iota + orClause +) + +func (it *StatementIterator) canonicalizeWhere() (string, []string) { + var out []string + var values []string + for _, b := range it.buildWhere { + b.pair.table = it.tableName() + s, v := b.toSQL() + values = append(values, v...) + out = append(out, s) + } + return strings.Join(out, " AND "), values +} + +func (it *StatementIterator) getTables() map[string]bool { + m := make(map[string]bool) + if it.where != nil { + m = it.where.getTables() + } + for _, t := range it.tags { + if t.pair.table != "" { + m[t.pair.table] = true + } + } + return m +} + +func (it *StatementIterator) tableName() string { + return fmt.Sprintf("t_%d", it.uid) +} + +func (it *StatementIterator) buildQuery(contains bool, v graph.Value) (string, []string) { + str := "SELECT " + var t []string + if it.stType == link { + t = []string{ + fmt.Sprintf("%s.subject", it.tableName()), + fmt.Sprintf("%s.predicate", it.tableName()), + fmt.Sprintf("%s.object", it.tableName()), + fmt.Sprintf("%s.label", it.tableName()), + } + } else { + t = []string{fmt.Sprintf("%s.%s as __execd", it.tableName(), it.dir)} + } + for _, v := range it.tags { + t = append(t, fmt.Sprintf("%s as %s", v.pair, v.t)) + } + str += strings.Join(t, ", ") + str += " FROM " + t = []string{fmt.Sprintf("quads as %s", it.tableName())} + for k, _ := range it.getTables() { + if k != it.tableName() { + t = append(t, fmt.Sprintf("quads as %s", k)) + } + } + str += strings.Join(t, ", ") + str += " WHERE " + var values []string + var s string + if it.stType != node { + s, values = it.canonicalizeWhere() + } + if it.where != nil { + if s != "" { + s += " AND " + } + where, v2 := it.where.toSQL() + s += where + values = append(values, v2...) + } + str += s + if contains { + if it.stType == link { + q := v.(quad.Quad) + str += " AND " + t = []string{ + fmt.Sprintf("%s.subject = ?", it.tableName()), + fmt.Sprintf("%s.predicate = ?", it.tableName()), + fmt.Sprintf("%s.object = ?", it.tableName()), + fmt.Sprintf("%s.label = ?", it.tableName()), + } + str += " " + strings.Join(t, " AND ") + " " + values = append(values, q.Subject) + values = append(values, q.Predicate) + values = append(values, q.Object) + values = append(values, q.Label) + } else { + str += fmt.Sprintf(" AND %s.%s = ? ", it.tableName(), it.dir) + values = append(values, v.(string)) + } + + } + if it.stType == node { + str += " ORDER BY __execd " + } + str += ";" + for i := 1; i <= len(values); i++ { + str = strings.Replace(str, "?", fmt.Sprintf("$%d", i), 1) + } + return str, values +} + +type StatementIterator struct { + uid uint64 + qs *QuadStore + + // Only for links + buildWhere []baseClause + + where clause + tagger graph.Tagger + tags []tag + err error + cursor *sql.Rows + stType statementType + dir quad.Direction + result map[string]string + resultIndex int + resultList [][]string + resultNext [][]string + cols []string + resultQuad quad.Quad + size int64 +} + +func (it *StatementIterator) Clone() graph.Iterator { + m := &StatementIterator{ + uid: iterator.NextUID(), + qs: it.qs, + buildWhere: it.buildWhere, + where: it.where, + stType: it.stType, + size: it.size, + } + copy(it.tags, m.tags) + m.tagger.CopyFrom(it) + return m +} + +func NewStatementIterator(qs *QuadStore, d quad.Direction, val string) *StatementIterator { + it := &StatementIterator{ + uid: iterator.NextUID(), + qs: qs, + buildWhere: []baseClause{ + baseClause{ + pair: tableDir{"", d}, + strTarget: []string{val}, + }, + }, + stType: link, + size: -1, + } + return it +} + +func (it *StatementIterator) UID() uint64 { + return it.uid +} + +func (it *StatementIterator) Reset() { + it.err = nil + it.Close() +} + +func (it *StatementIterator) Err() error { + return it.err +} + +func (it *StatementIterator) Close() error { + if it.cursor != nil { + err := it.cursor.Close() + if err != nil { + return err + } + it.cursor = nil + } + return nil +} + +func (it *StatementIterator) Tagger() *graph.Tagger { + return &it.tagger +} + +func (it *StatementIterator) Result() graph.Value { + if it.stType == node { + return it.result["__execd"] + } + return it.resultQuad +} + +func (it *StatementIterator) TagResults(dst map[string]graph.Value) { + for tag, value := range it.result { + if tag == "__execd" { + for _, tag := range it.tagger.Tags() { + dst[tag] = value + } + continue + } + dst[tag] = value + } + + for tag, value := range it.tagger.Fixed() { + dst[tag] = value + } +} + +func (it *StatementIterator) Type() graph.Type { + return sqlBuilderType +} + +func (it *StatementIterator) preFilter(v graph.Value) bool { + if it.stType == link { + q := v.(quad.Quad) + for _, b := range it.buildWhere { + if len(b.strTarget) == 0 { + continue + } + canFilter := true + for _, s := range b.strTarget { + if q.Get(b.pair.dir) == s { + canFilter = false + break + } + } + if canFilter { + return true + } + } + } + return false +} + +func (it *StatementIterator) Contains(v graph.Value) bool { + var err error + if it.preFilter(v) { + return false + } + q, values := it.buildQuery(true, v) + ivalues := make([]interface{}, 0, len(values)) + for _, v := range values { + ivalues = append(ivalues, v) + } + it.cursor, err = it.qs.db.Query(q, ivalues...) + it.cols, err = it.cursor.Columns() + if err != nil { + glog.Errorf("Couldn't get columns") + it.err = err + it.cursor.Close() + return false + } + it.resultList = nil + for { + if !it.cursor.Next() { + glog.V(4).Infoln("sql: No next") + err := it.cursor.Err() + if err != nil { + glog.Errorf("Cursor error in SQL: %v", err) + it.err = err + } + it.cursor.Close() + break + } + s, err := it.scan() + if err != nil { + it.err = err + it.cursor.Close() + return false + } + it.resultList = append(it.resultList, s) + } + it.cursor.Close() + it.cursor = nil + if len(it.resultList) != 0 { + it.resultIndex = 0 + it.buildResult(0) + return true + } + return false +} + +func (it *StatementIterator) SubIterators() []graph.Iterator { + return nil +} + +func (it *StatementIterator) Sorted() bool { return false } +func (it *StatementIterator) Optimize() (graph.Iterator, bool) { return it, false } + +func (it *StatementIterator) Size() (int64, bool) { + + if it.size != -1 { + return it.size, true + } + if it.stType == node { + return it.qs.Size(), true + } + b := it.buildWhere[0] + it.size = it.qs.sizeForIterator(false, b.pair.dir, b.strTarget[0]) + return it.size, true +} + +func (it *StatementIterator) Describe() graph.Description { + size, _ := it.Size() + return graph.Description{ + UID: it.UID(), + Name: "SQL_QUERY", + Type: it.Type(), + Size: size, + } +} + +func (it *StatementIterator) Stats() graph.IteratorStats { + size, _ := it.Size() + return graph.IteratorStats{ + ContainsCost: 1, + NextCost: 5, + Size: size, + } +} + +func (it *StatementIterator) makeCursor() { + if it.cursor != nil { + it.cursor.Close() + } + q, values := it.buildQuery(false, nil) + ivalues := make([]interface{}, 0, len(values)) + for _, v := range values { + ivalues = append(ivalues, v) + } + cursor, err := it.qs.db.Query(q, ivalues...) + if err != nil { + glog.Errorln("Couldn't get cursor from SQL database: %v", err) + cursor = nil + } + it.cursor = cursor +} + +func (it *StatementIterator) NextPath() bool { + it.resultIndex += 1 + if it.resultIndex >= len(it.resultList) { + return false + } + it.buildResult(it.resultIndex) + return true +} + +func (it *StatementIterator) Next() bool { + var err error + graph.NextLogIn(it) + if it.cursor == nil { + it.makeCursor() + it.cols, err = it.cursor.Columns() + if err != nil { + glog.Errorf("Couldn't get columns") + it.err = err + it.cursor.Close() + return false + } + // iterate the first one + if !it.cursor.Next() { + glog.V(4).Infoln("sql: No next") + err := it.cursor.Err() + if err != nil { + glog.Errorf("Cursor error in SQL: %v", err) + it.err = err + } + it.cursor.Close() + return false + } + s, err := it.scan() + if err != nil { + it.err = err + it.cursor.Close() + return false + } + it.resultNext = append(it.resultNext, s) + } + if it.resultList != nil && it.resultNext == nil { + // We're on something and there's no next + return false + } + it.resultList = it.resultNext + it.resultNext = nil + it.resultIndex = 0 + for { + if !it.cursor.Next() { + glog.V(4).Infoln("sql: No next") + err := it.cursor.Err() + if err != nil { + glog.Errorf("Cursor error in SQL: %v", err) + it.err = err + } + it.cursor.Close() + break + } + s, err := it.scan() + if err != nil { + it.err = err + it.cursor.Close() + return false + } + if it.stType == node { + if it.resultList[0][0] != s[0] { + it.resultNext = append(it.resultNext, s) + break + } else { + it.resultList = append(it.resultList, s) + } + } else { + if it.resultList[0][0] == s[0] && it.resultList[0][1] == s[1] && it.resultList[0][2] == s[2] && it.resultList[0][3] == s[3] { + it.resultList = append(it.resultList, s) + } else { + it.resultNext = append(it.resultNext, s) + break + } + } + + } + if len(it.resultList) == 0 { + return graph.NextLogOut(it, nil, false) + } + it.buildResult(0) + return graph.NextLogOut(it, it.result, true) +} + +func (it *StatementIterator) scan() ([]string, error) { + pointers := make([]interface{}, len(it.cols)) + container := make([]string, len(it.cols)) + for i, _ := range pointers { + pointers[i] = &container[i] + } + err := it.cursor.Scan(pointers...) + if err != nil { + glog.Errorf("Error scanning iterator: %v", err) + it.err = err + return nil, err + } + return container, nil +} + +func (it *StatementIterator) buildResult(i int) { + container := it.resultList[i] + if it.stType == node { + it.result = make(map[string]string) + for i, c := range it.cols { + it.result[c] = container[i] + } + return + } + var q quad.Quad + q.Subject = container[0] + q.Predicate = container[1] + q.Object = container[2] + q.Label = container[3] + it.resultQuad = q + it.result = make(map[string]string) + for i, c := range it.cols[4:] { + it.result[c] = container[i+4] + } +} diff --git a/graph/sql/builder_iterator_test.go b/graph/sql/builder_iterator_test.go new file mode 100644 index 0000000..cbb960d --- /dev/null +++ b/graph/sql/builder_iterator_test.go @@ -0,0 +1,110 @@ +// Copyright 2015 The Cayley Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "flag" + "fmt" + "testing" + + "github.com/google/cayley/graph" + "github.com/google/cayley/graph/iterator" + "github.com/google/cayley/quad" +) + +var dbpath = flag.String("dbpath", "", "Path to running DB") + +func TestSimpleSQL(t *testing.T) { + it := NewStatementIterator(nil, quad.Object, "cool") + s, v := it.buildQuery(false, nil) + fmt.Println(s, v) +} + +// Functional tests + +func TestQuadIteration(t *testing.T) { + if *dbpath == "" { + t.SkipNow() + } + db, err := newQuadStore(*dbpath, nil) + if err != nil { + t.Fatal(err) + } + it := NewStatementIterator(db.(*QuadStore), quad.Object, "Humphrey Bogart") + for graph.Next(it) { + fmt.Println(it.Result()) + } + it = NewStatementIterator(db.(*QuadStore), quad.Subject, "/en/casablanca_1942") + s, v := it.buildQuery(false, nil) + fmt.Println(s, v) + c := 0 + for graph.Next(it) { + fmt.Println(it.Result()) + c += 1 + } + if c != 18 { + t.Errorf("Not enough results, got %d expected 18") + } +} + +func TestNodeIteration(t *testing.T) { + if *dbpath == "" { + t.SkipNow() + } + db, err := newQuadStore(*dbpath, nil) + if err != nil { + t.Fatal(err) + } + it := &StatementIterator{ + uid: iterator.NextUID(), + qs: db.(*QuadStore), + stType: node, + dir: quad.Object, + tags: []tag{ + tag{ + pair: tableDir{ + table: "t_4", + dir: quad.Subject, + }, + t: "x", + }, + }, + where: baseClause{ + pair: tableDir{ + table: "t_4", + dir: quad.Subject, + }, + strTarget: []string{"/en/casablanca_1942"}, + }, + } + s, v := it.buildQuery(false, nil) + it.Tagger().Add("id") + fmt.Println(s, v) + for graph.Next(it) { + fmt.Println(it.Result()) + out := make(map[string]graph.Value) + it.TagResults(out) + for k, v := range out { + fmt.Printf("%s: %v\n", k, v.(string)) + } + } + contains := it.Contains("Casablanca") + s, v = it.buildQuery(true, "Casablanca") + fmt.Println(s, v) + it.Tagger().Add("id") + if !contains { + t.Error("Didn't contain Casablanca") + } +} diff --git a/graph/sql/iterator.go b/graph/sql/iterator.go index a0750ef..1482eaa 100644 --- a/graph/sql/iterator.go +++ b/graph/sql/iterator.go @@ -199,10 +199,12 @@ func (it *Iterator) Next() bool { func (it *Iterator) Contains(v graph.Value) bool { graph.ContainsLogIn(it, v) if it.isAll { + it.result = v return graph.ContainsLogOut(it, v, true) } q := v.(quad.Quad) if q.Get(it.dir) == it.val.(string) { + it.result = v return graph.ContainsLogOut(it, v, true) } return graph.ContainsLogOut(it, v, false) @@ -217,6 +219,9 @@ func (it *Iterator) Size() (int64, bool) { } func (it *Iterator) Result() graph.Value { + if it.result == nil { + glog.Fatalln("result was nil", it) + } return it.result } @@ -239,7 +244,7 @@ func (it *Iterator) Type() graph.Type { return sqlType } -func (it *Iterator) Sorted() bool { return true } +func (it *Iterator) Sorted() bool { return false } func (it *Iterator) Optimize() (graph.Iterator, bool) { return it, false } func (it *Iterator) Describe() graph.Description { diff --git a/graph/sql/quadstore.go b/graph/sql/quadstore.go index c52e942..26d1ff8 100644 --- a/graph/sql/quadstore.go +++ b/graph/sql/quadstore.go @@ -200,7 +200,7 @@ func (qs *QuadStore) Quad(val graph.Value) quad.Quad { } func (qs *QuadStore) QuadIterator(d quad.Direction, val graph.Value) graph.Iterator { - return NewIterator(qs, d, val) + return NewStatementIterator(qs, d, val.(string)) } func (qs *QuadStore) NodesAllIterator() graph.Iterator { diff --git a/integration/integration_test.go b/integration/integration_test.go index 4b5b726..05ddac0 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -37,12 +37,14 @@ import ( _ "github.com/google/cayley/graph/leveldb" _ "github.com/google/cayley/graph/memstore" _ "github.com/google/cayley/graph/mongo" + _ "github.com/google/cayley/graph/sql" // Load writer registry _ "github.com/google/cayley/writer" ) var backend = flag.String("backend", "memstore", "Which backend to test. Loads test data to /tmp if not present.") +var backendPath = flag.String("backend_path", "", "Path to the chosen backend. Will have sane testing defaults if not specified") var benchmarkQueries = []struct { message string @@ -422,6 +424,7 @@ var ( ) func prepare(t testing.TB) { + var remote bool cfg.DatabaseType = *backend switch *backend { case "memstore": @@ -436,14 +439,21 @@ func prepare(t testing.TB) { cfg.DatabaseOptions = map[string]interface{}{ "database_name": "cayley_test", // provide a default test database } + remote = true + case "sql": + cfg.DatabasePath = "postgres://localhost/cayley_test" + remote = true default: t.Fatalf("Untestable backend store %s", *backend) } + if *backendPath != "" { + cfg.DatabasePath = *backendPath + } var err error create.Do(func() { needsLoad := true - if graph.IsPersistent(cfg.DatabaseType) { + if graph.IsPersistent(cfg.DatabaseType) && !remote { if _, err := os.Stat(cfg.DatabasePath); os.IsNotExist(err) { err = db.Init(cfg) if err != nil { @@ -459,7 +469,7 @@ func prepare(t testing.TB) { t.Fatalf("Failed to open %q: %v", cfg.DatabasePath, err) } - if needsLoad { + if needsLoad && !remote { err = internal.Load(handle.QuadWriter, cfg, "../data/30kmoviedata.nq.gz", "cquad") if err != nil { t.Fatalf("Failed to load %q: %v", cfg.DatabasePath, err) @@ -524,6 +534,7 @@ func checkQueries(t *testing.T) { if testing.Short() && test.long { continue } + fmt.Printf("Now testing %s\n", test.message) ses := gremlin.NewSession(handle.QuadStore, cfg.Timeout, true) _, err := ses.Parse(test.query) if err != nil {