From 3e02bb2b714cbe14c6c12040658a12ee388df3ba Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Wed, 29 Jul 2015 15:56:15 -0400 Subject: [PATCH] refactor to SQL builder iterators and standard iterator wrapper --- graph/iterator.go | 12 +- graph/sql/optimizers.go | 74 ++++----- graph/sql/sql_iterator.go | 341 +++++++++++++++++++++++++++++++++++++++++ graph/sql/sql_link_iterator.go | 333 +++++----------------------------------- graph/sql/sql_node_iterator.go | 286 ++++------------------------------ 5 files changed, 453 insertions(+), 593 deletions(-) create mode 100644 graph/sql/sql_iterator.go diff --git a/graph/iterator.go b/graph/iterator.go index 0f3b76e..c2d46d9 100644 --- a/graph/iterator.go +++ b/graph/iterator.go @@ -67,8 +67,10 @@ func (t *Tagger) Fixed() map[string]Value { } func (t *Tagger) CopyFrom(src Iterator) { - st := src.Tagger() + t.CopyFromTagger(src.Tagger()) +} +func (t *Tagger) CopyFromTagger(st *Tagger) { t.tags = append(t.tags, st.tags...) if t.fixedTags == nil { @@ -331,16 +333,16 @@ func DumpStats(it Iterator) StatsContainer { func ContainsLogIn(it Iterator, val Value) { if glog.V(4) { - glog.V(4).Infof("%s %d CHECK CONTAINS %d", strings.ToUpper(it.Type().String()), it.UID(), val) + glog.V(4).Infof("%s %d CHECK CONTAINS %v", strings.ToUpper(it.Type().String()), it.UID(), val) } } func ContainsLogOut(it Iterator, val Value, good bool) bool { if glog.V(4) { if good { - glog.V(4).Infof("%s %d CHECK CONTAINS %d GOOD", strings.ToUpper(it.Type().String()), it.UID(), val) + glog.V(4).Infof("%s %d CHECK CONTAINS %v GOOD", strings.ToUpper(it.Type().String()), it.UID(), val) } else { - glog.V(4).Infof("%s %d CHECK CONTAINS %d BAD", strings.ToUpper(it.Type().String()), it.UID(), val) + glog.V(4).Infof("%s %d CHECK CONTAINS %v BAD", strings.ToUpper(it.Type().String()), it.UID(), val) } } return good @@ -355,7 +357,7 @@ func NextLogIn(it Iterator) { func NextLogOut(it Iterator, val Value, ok bool) bool { if glog.V(4) { if ok { - glog.V(4).Infof("%s %d NEXT IS %d", strings.ToUpper(it.Type().String()), it.UID(), val) + glog.V(4).Infof("%s %d NEXT IS %v", strings.ToUpper(it.Type().String()), it.UID(), val) } else { glog.V(4).Infof("%s %d NEXT DONE", strings.ToUpper(it.Type().String()), it.UID()) } diff --git a/graph/sql/optimizers.go b/graph/sql/optimizers.go index 386270c..debf7a8 100644 --- a/graph/sql/optimizers.go +++ b/graph/sql/optimizers.go @@ -23,14 +23,14 @@ import ( "github.com/google/cayley/quad" ) -func intersect(a graph.Iterator, b graph.Iterator) (graph.Iterator, error) { +func intersect(a sqlIterator, b sqlIterator, qs *QuadStore) (*SQLIterator, error) { if anew, ok := a.(*SQLNodeIterator); ok { if bnew, ok := b.(*SQLNodeIterator); ok { - return intersectNode(anew, bnew) + return intersectNode(anew, bnew, qs) } } else if anew, ok := a.(*SQLLinkIterator); ok { if bnew, ok := b.(*SQLLinkIterator); ok { - return intersectLink(anew, bnew) + return intersectLink(anew, bnew, qs) } } else { @@ -39,41 +39,37 @@ func intersect(a graph.Iterator, b graph.Iterator) (graph.Iterator, error) { return nil, errors.New("Cannot combine SQL iterators of two different types") } -func intersectNode(a *SQLNodeIterator, b *SQLNodeIterator) (graph.Iterator, error) { +func intersectNode(a *SQLNodeIterator, b *SQLNodeIterator, qs *QuadStore) (*SQLIterator, error) { m := &SQLNodeIterator{ - uid: iterator.NextUID(), - qs: a.qs, tableName: newTableName(), linkIts: append(a.linkIts, b.linkIts...), } - m.Tagger().CopyFrom(a) - m.Tagger().CopyFrom(b) - return m, nil + m.Tagger().CopyFromTagger(a.Tagger()) + m.Tagger().CopyFromTagger(b.Tagger()) + it := NewSQLIterator(qs, m) + return it, nil } -func intersectLink(a *SQLLinkIterator, b *SQLLinkIterator) (graph.Iterator, error) { +func intersectLink(a *SQLLinkIterator, b *SQLLinkIterator, qs *QuadStore) (*SQLIterator, error) { m := &SQLLinkIterator{ - uid: iterator.NextUID(), - qs: a.qs, tableName: newTableName(), nodeIts: append(a.nodeIts, b.nodeIts...), constraints: append(a.constraints, b.constraints...), tagdirs: append(a.tagdirs, b.tagdirs...), } - m.Tagger().CopyFrom(a) - m.Tagger().CopyFrom(b) - return m, nil + m.Tagger().CopyFromTagger(a.Tagger()) + m.Tagger().CopyFromTagger(b.Tagger()) + it := NewSQLIterator(qs, m) + return it, nil } -func hasa(aIn graph.Iterator, d quad.Direction) (graph.Iterator, error) { +func hasa(aIn sqlIterator, d quad.Direction, qs *QuadStore) (graph.Iterator, error) { a, ok := aIn.(*SQLLinkIterator) if !ok { return nil, errors.New("Can't take the HASA of a link SQL iterator") } out := &SQLNodeIterator{ - uid: iterator.NextUID(), - qs: a.qs, tableName: newTableName(), linkIts: []sqlItDir{ sqlItDir{ @@ -82,18 +78,17 @@ func hasa(aIn graph.Iterator, d quad.Direction) (graph.Iterator, error) { }, }, } - return out, nil + it := NewSQLIterator(qs, out) + return it, nil } -func linksto(aIn graph.Iterator, d quad.Direction) (graph.Iterator, error) { +func linksto(aIn sqlIterator, d quad.Direction, qs *QuadStore) (graph.Iterator, error) { a, ok := aIn.(*SQLNodeIterator) if !ok { return nil, errors.New("Can't take the LINKSTO of a node SQL iterator") } out := &SQLLinkIterator{ - uid: iterator.NextUID(), - qs: a.qs, tableName: newTableName(), nodeIts: []sqlItDir{ sqlItDir{ @@ -102,8 +97,8 @@ func linksto(aIn graph.Iterator, d quad.Direction) (graph.Iterator, error) { }, }, } - - return out, nil + it := NewSQLIterator(qs, out) + return it, nil } func (qs *QuadStore) OptimizeIterator(it graph.Iterator) (graph.Iterator, bool) { @@ -141,9 +136,9 @@ func (qs *QuadStore) optimizeLinksTo(it *iterator.LinksTo) (graph.Iterator, bool it.Close() return newIt, true } - case sqlNodeType: - //p := primary.(*SQLNodeIterator) - newit, err := linksto(primary, it.Direction()) + case sqlType: + p := primary.(*SQLIterator) + newit, err := linksto(p.sql, it.Direction(), qs) if err != nil { glog.Errorln(err) return it, false @@ -151,22 +146,20 @@ func (qs *QuadStore) optimizeLinksTo(it *iterator.LinksTo) (graph.Iterator, bool newit.Tagger().CopyFrom(it) return newit, true case graph.All: - newit := &SQLLinkIterator{ - uid: iterator.NextUID(), - qs: qs, + linkit := &SQLLinkIterator{ size: qs.Size(), } for _, t := range primary.Tagger().Tags() { - newit.tagdirs = append(newit.tagdirs, tagDir{ + linkit.tagdirs = append(linkit.tagdirs, tagDir{ dir: it.Direction(), tag: t, }) } for k, v := range primary.Tagger().Fixed() { - newit.tagger.AddFixed(k, v) + linkit.tagger.AddFixed(k, v) } - newit.tagger.CopyFrom(it) - + linkit.tagger.CopyFrom(it) + newit := NewSQLIterator(qs, linkit) return newit, true } return it, false @@ -175,18 +168,18 @@ func (qs *QuadStore) optimizeLinksTo(it *iterator.LinksTo) (graph.Iterator, bool func (qs *QuadStore) optimizeAnd(it *iterator.And) (graph.Iterator, bool) { subs := it.SubIterators() var unusedIts []graph.Iterator - var newit graph.Iterator + var newit *SQLIterator newit = nil changed := false var err error for _, it := range subs { - if it.Type() == sqlLinkType || it.Type() == sqlNodeType { + if it.Type() == sqlType { if newit == nil { - newit = it + newit = it.(*SQLIterator) } else { changed = true - newit, err = intersect(newit, it) + newit, err = intersect(newit.sql, it.(*SQLIterator).sql, qs) if err != nil { glog.Error(err) return it, false @@ -219,8 +212,9 @@ func (qs *QuadStore) optimizeHasA(it *iterator.HasA) (graph.Iterator, bool) { return it, false } primary := subs[0] - if primary.Type() == sqlLinkType { - newit, err := hasa(primary, it.Direction()) + if primary.Type() == sqlType { + p := primary.(*SQLIterator) + newit, err := hasa(p.sql, it.Direction(), qs) if err != nil { glog.Errorln(err) return it, false diff --git a/graph/sql/sql_iterator.go b/graph/sql/sql_iterator.go new file mode 100644 index 0000000..74ca0c2 --- /dev/null +++ b/graph/sql/sql_iterator.go @@ -0,0 +1,341 @@ +// Copyright 2015 The Cayley Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "database/sql" + "fmt" + "strings" + + "github.com/barakmich/glog" + "github.com/google/cayley/graph" + "github.com/google/cayley/graph/iterator" + "github.com/google/cayley/quad" +) + +var sqlType graph.Type + +func init() { + sqlType = graph.RegisterIterator("sql") +} + +type SQLIterator struct { + uid uint64 + qs *QuadStore + cursor *sql.Rows + err error + + sql sqlIterator + + result map[string]string + resultIndex int + resultList [][]string + resultNext [][]string + cols []string +} + +func (it *SQLIterator) Clone() graph.Iterator { + m := &SQLIterator{ + uid: iterator.NextUID(), + qs: it.qs, + sql: it.sql.sqlClone(), + } + return m +} + +func (it *SQLIterator) UID() uint64 { + return it.uid +} + +func (it *SQLIterator) Reset() { + it.err = nil + it.Close() +} + +func (it *SQLIterator) Err() error { + return it.err +} + +func (it *SQLIterator) Close() error { + if it.cursor != nil { + err := it.cursor.Close() + if err != nil { + return err + } + it.cursor = nil + } + return nil +} + +func (it *SQLIterator) Tagger() *graph.Tagger { + return it.sql.Tagger() +} + +func (it *SQLIterator) Result() graph.Value { + return it.sql.Result() +} + +func (it *SQLIterator) TagResults(dst map[string]graph.Value) { + for tag, value := range it.result { + if tag == "__execd" { + for _, tag := range it.Tagger().Tags() { + dst[tag] = value + } + continue + } + dst[tag] = value + } + + for tag, value := range it.Tagger().Fixed() { + dst[tag] = value + } +} + +func (it *SQLIterator) Type() graph.Type { + return sqlType +} + +func (it *SQLIterator) SubIterators() []graph.Iterator { + return nil +} + +func (it *SQLIterator) Sorted() bool { return false } +func (it *SQLIterator) Optimize() (graph.Iterator, bool) { return it, false } + +func (it *SQLIterator) Size() (int64, bool) { + return it.sql.Size(it.qs) +} + +func (it *SQLIterator) Describe() graph.Description { + size, _ := it.Size() + return graph.Description{ + UID: it.UID(), + Name: it.sql.Describe(), + Type: it.Type(), + Size: size, + } +} + +func (it *SQLIterator) Stats() graph.IteratorStats { + size, _ := it.Size() + return graph.IteratorStats{ + ContainsCost: 1, + NextCost: 5, + Size: size, + } +} + +func (it *SQLIterator) NextPath() bool { + it.resultIndex += 1 + if it.resultIndex >= len(it.resultList) { + return false + } + it.buildResult(it.resultIndex) + return true +} + +func (it *SQLIterator) Next() bool { + var err error + graph.NextLogIn(it) + if it.cursor == nil { + err = it.makeCursor(true, nil) + it.cols, err = it.cursor.Columns() + if err != nil { + glog.Errorf("Couldn't get columns") + it.err = err + it.cursor.Close() + return false + } + // iterate the first one + if !it.cursor.Next() { + glog.V(4).Infoln("sql: No next") + err := it.cursor.Err() + if err != nil { + glog.Errorf("Cursor error in SQL: %v", err) + it.err = err + } + it.cursor.Close() + return false + } + s, err := scan(it.cursor, len(it.cols)) + if err != nil { + it.err = err + it.cursor.Close() + return false + } + it.resultNext = append(it.resultNext, s) + } + if it.resultList != nil && it.resultNext == nil { + // We're on something and there's no next + return false + } + it.resultList = it.resultNext + it.resultNext = nil + it.resultIndex = 0 + for { + if !it.cursor.Next() { + glog.V(4).Infoln("sql: No next") + err := it.cursor.Err() + if err != nil { + glog.Errorf("Cursor error in SQL: %v", err) + it.err = err + } + it.cursor.Close() + break + } + s, err := scan(it.cursor, len(it.cols)) + if err != nil { + it.err = err + it.cursor.Close() + return false + } + + if it.sql.sameTopResult(it.resultList[0], s) { + it.resultList = append(it.resultList, s) + } else { + it.resultNext = append(it.resultNext, s) + break + } + } + + if len(it.resultList) == 0 { + return graph.NextLogOut(it, nil, false) + } + it.buildResult(0) + return graph.NextLogOut(it, it.Result(), true) +} + +func (it *SQLIterator) Contains(v graph.Value) bool { + var err error + if ok, res := it.sql.quickContains(v); ok { + return res + } + err = it.makeCursor(false, v) + if err != nil { + glog.Errorf("Couldn't make query: %v", err) + it.err = err + it.cursor.Close() + return false + } + it.cols, err = it.cursor.Columns() + if err != nil { + glog.Errorf("Couldn't get columns") + it.err = err + it.cursor.Close() + return false + } + it.resultList = nil + for { + if !it.cursor.Next() { + glog.V(4).Infoln("sql: No next") + err := it.cursor.Err() + if err != nil { + glog.Errorf("Cursor error in SQL: %v", err) + it.err = err + } + it.cursor.Close() + break + } + s, err := scan(it.cursor, len(it.cols)) + if err != nil { + it.err = err + it.cursor.Close() + return false + } + it.resultList = append(it.resultList, s) + } + it.cursor.Close() + it.cursor = nil + if len(it.resultList) != 0 { + it.resultIndex = 0 + it.buildResult(0) + return true + } + return false +} + +func scan(cursor *sql.Rows, nCols int) ([]string, error) { + pointers := make([]interface{}, nCols) + container := make([]string, nCols) + for i, _ := range pointers { + pointers[i] = &container[i] + } + err := cursor.Scan(pointers...) + if err != nil { + glog.Errorf("Error scanning iterator: %v", err) + return nil, err + } + return container, nil +} + +func (it *SQLIterator) buildResult(i int) { + it.result = it.sql.buildResult(it.resultList[i], it.cols) +} + +func (it *SQLIterator) makeCursor(next bool, value graph.Value) error { + if it.cursor != nil { + it.cursor.Close() + } + var q string + var values []string + q, values = it.sql.buildSQL(next, value) + q = convertToPostgres(q, values) + ivalues := make([]interface{}, 0, len(values)) + for _, v := range values { + ivalues = append(ivalues, v) + } + cursor, err := it.qs.db.Query(q, ivalues...) + if err != nil { + glog.Errorf("Couldn't get cursor from SQL database: %v", err) + cursor = nil + return err + } + it.cursor = cursor + return nil +} + +func convertToPostgres(query string, values []string) string { + for i := 1; i <= len(values); i++ { + query = strings.Replace(query, "?", fmt.Sprintf("$%d", i), 1) + } + return query +} + +func NewSQLLinkIterator(qs *QuadStore, d quad.Direction, val string) *SQLIterator { + l := &SQLIterator{ + uid: iterator.NextUID(), + qs: qs, + sql: &SQLLinkIterator{ + constraints: []constraint{ + constraint{ + dir: d, + vals: []string{val}, + }, + }, + tableName: newTableName(), + size: 0, + }, + } + return l +} + +func NewSQLIterator(qs *QuadStore, sql sqlIterator) *SQLIterator { + l := &SQLIterator{ + uid: iterator.NextUID(), + qs: qs, + sql: sql, + } + return l +} diff --git a/graph/sql/sql_link_iterator.go b/graph/sql/sql_link_iterator.go index 1d71a16..8e7a805 100644 --- a/graph/sql/sql_link_iterator.go +++ b/graph/sql/sql_link_iterator.go @@ -15,22 +15,18 @@ package sql import ( - "database/sql" "fmt" "strings" "sync/atomic" "github.com/barakmich/glog" "github.com/google/cayley/graph" - "github.com/google/cayley/graph/iterator" "github.com/google/cayley/quad" ) -var sqlLinkType graph.Type var sqlTableID uint64 func init() { - sqlLinkType = graph.RegisterIterator("sqllink") atomic.StoreUint64(&sqlTableID, 0) } @@ -73,59 +69,39 @@ type sqlItDir struct { } type sqlIterator interface { - buildSQL(next bool, val graph.Value) (string, []string) sqlClone() sqlIterator + + buildSQL(next bool, val graph.Value) (string, []string) getTables() []tableDef getTags() []tagDir buildWhere() (string, []string) tableID() tagDir + + quickContains(graph.Value) (ok bool, result bool) + buildResult(result []string, cols []string) map[string]string + sameTopResult(target []string, test []string) bool + + Result() graph.Value + Size(*QuadStore) (int64, bool) + Describe() string + Type() sqlQueryType + Tagger() *graph.Tagger } type SQLLinkIterator struct { - uid uint64 - qs *QuadStore tagger graph.Tagger - err error - cursor *sql.Rows nodeIts []sqlItDir constraints []constraint tableName string size int64 tagdirs []tagDir - result map[string]string - resultIndex int - resultList [][]string - resultNext [][]string - cols []string - resultQuad quad.Quad -} - -func NewSQLLinkIterator(qs *QuadStore, d quad.Direction, val string) *SQLLinkIterator { - l := &SQLLinkIterator{ - uid: iterator.NextUID(), - qs: qs, - constraints: []constraint{ - constraint{ - dir: d, - vals: []string{val}, - }, - }, - tableName: newTableName(), - size: 0, - } - return l + resultQuad quad.Quad } func (l *SQLLinkIterator) sqlClone() sqlIterator { - return l.Clone().(*SQLLinkIterator) -} - -func (l *SQLLinkIterator) Clone() graph.Iterator { m := &SQLLinkIterator{ - uid: iterator.NextUID(), - qs: l.qs, tableName: l.tableName, size: l.size, constraints: make([]constraint, len(l.constraints)), @@ -139,34 +115,10 @@ func (l *SQLLinkIterator) Clone() graph.Iterator { } copy(m.constraints, l.constraints) copy(m.tagdirs, l.tagdirs) - m.tagger.CopyFrom(l) + m.tagger.CopyFromTagger(l.Tagger()) return m } -func (l *SQLLinkIterator) UID() uint64 { - return l.uid -} - -func (l *SQLLinkIterator) Reset() { - l.err = nil - l.Close() -} - -func (l *SQLLinkIterator) Err() error { - return l.err -} - -func (l *SQLLinkIterator) Close() error { - if l.cursor != nil { - err := l.cursor.Close() - if err != nil { - return err - } - l.cursor = nil - } - return nil -} - func (l *SQLLinkIterator) Tagger() *graph.Tagger { return &l.tagger } @@ -175,70 +127,30 @@ func (l *SQLLinkIterator) Result() graph.Value { return l.resultQuad } -func (l *SQLLinkIterator) TagResults(dst map[string]graph.Value) { - for tag, value := range l.result { - if tag == "__execd" { - for _, tag := range l.tagger.Tags() { - dst[tag] = value - } - continue - } - dst[tag] = value - } - - for tag, value := range l.tagger.Fixed() { - dst[tag] = value - } -} - -func (l *SQLLinkIterator) SubIterators() []graph.Iterator { - // TODO(barakmich): SQL Subiterators shouldn't count? If it makes sense, - // there's no reason not to expose them though. - return nil -} - -func (l *SQLLinkIterator) Sorted() bool { return false } -func (l *SQLLinkIterator) Optimize() (graph.Iterator, bool) { return l, false } - -func (l *SQLLinkIterator) Size() (int64, bool) { +func (l *SQLLinkIterator) Size(qs *QuadStore) (int64, bool) { if l.size != 0 { return l.size, true } if len(l.constraints) > 0 { - l.size = l.qs.sizeForIterator(false, l.constraints[0].dir, l.constraints[0].vals[0]) + l.size = qs.sizeForIterator(false, l.constraints[0].dir, l.constraints[0].vals[0]) } else if len(l.nodeIts) > 1 { - subsize, _ := l.nodeIts[0].it.(*SQLNodeIterator).Size() + subsize, _ := l.nodeIts[0].it.(*SQLNodeIterator).Size(qs) return subsize * 20, false } else { - return l.qs.Size(), false + return qs.Size(), false } return l.size, true } -func (l *SQLLinkIterator) Describe() graph.Description { - size, _ := l.Size() - return graph.Description{ - UID: l.UID(), - Name: fmt.Sprintf("SQL_LINK_QUERY: %#v", l), - Type: l.Type(), - Size: size, - } +func (l *SQLLinkIterator) Describe() string { + return fmt.Sprintf("SQL_LINK_QUERY: %#v", l) } -func (l *SQLLinkIterator) Stats() graph.IteratorStats { - size, _ := l.Size() - return graph.IteratorStats{ - ContainsCost: 1, - NextCost: 5, - Size: size, - } +func (l *SQLLinkIterator) Type() sqlQueryType { + return link } -func (l *SQLLinkIterator) Type() graph.Type { - return sqlLinkType -} - -func (l *SQLLinkIterator) preFilter(v graph.Value) bool { +func (l *SQLLinkIterator) quickContains(v graph.Value) (bool, bool) { for _, c := range l.constraints { none := true desired := v.(quad.Quad).Get(c.dir) @@ -249,85 +161,27 @@ func (l *SQLLinkIterator) preFilter(v graph.Value) bool { } } if none { - return true + return true, false } } - return false -} - -func (l *SQLLinkIterator) Contains(v graph.Value) bool { - var err error - if l.preFilter(v) { - return false - } if len(l.nodeIts) == 0 { - return true + return true, true } - err = l.makeCursor(false, v) - if err != nil { - glog.Errorf("Couldn't make query: %v", err) - l.err = err - l.cursor.Close() - return false - } - l.cols, err = l.cursor.Columns() - if err != nil { - glog.Errorf("Couldn't get columns") - l.err = err - l.cursor.Close() - return false - } - l.resultList = nil - for { - if !l.cursor.Next() { - glog.V(4).Infoln("sql: No next") - err := l.cursor.Err() - if err != nil { - glog.Errorf("Cursor error in SQL: %v", err) - l.err = err - } - l.cursor.Close() - break - } - s, err := scan(l.cursor, len(l.cols)) - if err != nil { - l.err = err - l.cursor.Close() - return false - } - l.resultList = append(l.resultList, s) - } - l.cursor.Close() - l.cursor = nil - if len(l.resultList) != 0 { - l.resultIndex = 0 - l.buildResult(0) - return true - } - return false + return false, false } -func (l *SQLLinkIterator) NextPath() bool { - l.resultIndex += 1 - if l.resultIndex >= len(l.resultList) { - return false - } - l.buildResult(l.resultIndex) - return true -} - -func (l *SQLLinkIterator) buildResult(i int) { - container := l.resultList[i] +func (l *SQLLinkIterator) buildResult(result []string, cols []string) map[string]string { var q quad.Quad - q.Subject = container[0] - q.Predicate = container[1] - q.Object = container[2] - q.Label = container[3] + q.Subject = result[0] + q.Predicate = result[1] + q.Object = result[2] + q.Label = result[3] l.resultQuad = q - l.result = make(map[string]string) - for i, c := range l.cols[4:] { - l.result[c] = container[i+4] + m := make(map[string]string) + for i, c := range cols[4:] { + m[c] = result[i+4] } + return m } func (l *SQLLinkIterator) getTables() []tableDef { @@ -448,119 +302,6 @@ func (l *SQLLinkIterator) buildSQL(next bool, val graph.Value) (string, []string return query, values } -func convertToPostgres(query string, values []string) string { - for i := 1; i <= len(values); i++ { - query = strings.Replace(query, "?", fmt.Sprintf("$%d", i), 1) - } - return query -} - -func (l *SQLLinkIterator) makeCursor(next bool, value graph.Value) error { - if l.cursor != nil { - l.cursor.Close() - } - var q string - var values []string - q, values = l.buildSQL(next, value) - q = convertToPostgres(q, values) - ivalues := make([]interface{}, 0, len(values)) - for _, v := range values { - ivalues = append(ivalues, v) - } - cursor, err := l.qs.db.Query(q, ivalues...) - if err != nil { - glog.Errorf("Couldn't get cursor from SQL database: %v", err) - cursor = nil - return err - } - l.cursor = cursor - return nil -} - -func scan(cursor *sql.Rows, nCols int) ([]string, error) { - pointers := make([]interface{}, nCols) - container := make([]string, nCols) - for i, _ := range pointers { - pointers[i] = &container[i] - } - err := cursor.Scan(pointers...) - if err != nil { - glog.Errorf("Error scanning iterator: %v", err) - return nil, err - } - return container, nil -} - -func (l *SQLLinkIterator) Next() bool { - var err error - graph.NextLogIn(l) - if l.cursor == nil { - err = l.makeCursor(true, nil) - l.cols, err = l.cursor.Columns() - if err != nil { - glog.Errorf("Couldn't get columns") - l.err = err - l.cursor.Close() - return false - } - // iterate the first one - if !l.cursor.Next() { - glog.V(4).Infoln("sql: No next") - err := l.cursor.Err() - if err != nil { - glog.Errorf("Cursor error in SQL: %v", err) - l.err = err - } - l.cursor.Close() - return false - } - s, err := scan(l.cursor, len(l.cols)) - if err != nil { - l.err = err - l.cursor.Close() - return false - } - l.resultNext = append(l.resultNext, s) - } - if l.resultList != nil && l.resultNext == nil { - // We're on something and there's no next - return false - } - l.resultList = l.resultNext - l.resultNext = nil - l.resultIndex = 0 - for { - if !l.cursor.Next() { - glog.V(4).Infoln("sql: No next") - err := l.cursor.Err() - if err != nil { - glog.Errorf("Cursor error in SQL: %v", err) - l.err = err - } - l.cursor.Close() - break - } - s, err := scan(l.cursor, len(l.cols)) - if err != nil { - l.err = err - l.cursor.Close() - return false - } - if l.resultList[0][0] == s[0] && l.resultList[0][1] == s[1] && l.resultList[0][2] == s[2] && l.resultList[0][3] == s[3] { - l.resultList = append(l.resultList, s) - } else { - l.resultNext = append(l.resultNext, s) - break - } - - } - if len(l.resultList) == 0 { - return graph.NextLogOut(l, nil, false) - } - l.buildResult(0) - return graph.NextLogOut(l, l.Result(), true) -} - -type SQLAllIterator struct { - // TBD +func (l *SQLLinkIterator) sameTopResult(target []string, test []string) bool { + return target[0] == test[0] && target[1] == test[1] && target[2] == test[2] && target[3] == test[3] } diff --git a/graph/sql/sql_node_iterator.go b/graph/sql/sql_node_iterator.go index 58efc9b..90e80e8 100644 --- a/graph/sql/sql_node_iterator.go +++ b/graph/sql/sql_node_iterator.go @@ -15,22 +15,25 @@ package sql import ( - "database/sql" "fmt" "strings" "sync/atomic" "github.com/barakmich/glog" "github.com/google/cayley/graph" - "github.com/google/cayley/graph/iterator" "github.com/google/cayley/quad" ) -var sqlNodeType graph.Type var sqlNodeTableID uint64 +type sqlQueryType int + +const ( + node sqlQueryType = iota + link +) + func init() { - sqlNodeType = graph.RegisterIterator("sqlnode") atomic.StoreUint64(&sqlNodeTableID, 0) } @@ -40,34 +43,20 @@ func newNodeTableName() string { } type SQLNodeIterator struct { - uid uint64 - qs *QuadStore - tagger graph.Tagger tableName string - err error - cursor *sql.Rows linkIts []sqlItDir nodetables []string size int64 + tagger graph.Tagger - result map[string]string - resultIndex int - resultList [][]string - resultNext [][]string - cols []string + result string } func (n *SQLNodeIterator) sqlClone() sqlIterator { - return n.Clone().(*SQLNodeIterator) -} - -func (n *SQLNodeIterator) Clone() graph.Iterator { m := &SQLNodeIterator{ - uid: iterator.NextUID(), - qs: n.qs, - size: n.size, tableName: n.tableName, + size: n.size, } for _, i := range n.linkIts { m.linkIts = append(m.linkIts, sqlItDir{ @@ -75,109 +64,39 @@ func (n *SQLNodeIterator) Clone() graph.Iterator { it: i.it.sqlClone(), }) } - m.tagger.CopyFrom(n) + m.tagger.CopyFromTagger(n.Tagger()) return m } -func (n *SQLNodeIterator) UID() uint64 { - return n.uid -} - -func (n *SQLNodeIterator) Reset() { - n.err = nil - n.Close() -} - -func (n *SQLNodeIterator) Err() error { - return n.err -} - -func (n *SQLNodeIterator) Close() error { - if n.cursor != nil { - err := n.cursor.Close() - if err != nil { - return err - } - n.cursor = nil - } - return nil -} - func (n *SQLNodeIterator) Tagger() *graph.Tagger { return &n.tagger } func (n *SQLNodeIterator) Result() graph.Value { - return n.result["__execd"] + return n.result } -func (n *SQLNodeIterator) TagResults(dst map[string]graph.Value) { - for tag, value := range n.result { - if tag == "__execd" { - for _, tag := range n.tagger.Tags() { - dst[tag] = value - } - continue +func (n *SQLNodeIterator) Type() sqlQueryType { + return node +} + +func (n *SQLNodeIterator) Size(qs *QuadStore) (int64, bool) { + return qs.Size() / int64(len(n.linkIts)+1), true +} + +func (n *SQLNodeIterator) Describe() string { + return fmt.Sprintf("SQL_NODE_QUERY: %#v", n) +} + +func (n *SQLNodeIterator) buildResult(result []string, cols []string) map[string]string { + m := make(map[string]string) + for i, c := range cols { + if c == "__execd" { + n.result = result[i] } - dst[tag] = value - } - - for tag, value := range n.tagger.Fixed() { - dst[tag] = value - } -} - -func (n *SQLNodeIterator) Type() graph.Type { - return sqlNodeType -} - -func (n *SQLNodeIterator) SubIterators() []graph.Iterator { - // TODO(barakmich): SQL Subiterators shouldn't count? If it makes sense, - // there's no reason not to expose them though. - return nil -} - -func (n *SQLNodeIterator) Sorted() bool { return false } -func (n *SQLNodeIterator) Optimize() (graph.Iterator, bool) { return n, false } - -func (n *SQLNodeIterator) Size() (int64, bool) { - return n.qs.Size() / int64(len(n.linkIts)+1), true -} - -func (n *SQLNodeIterator) Describe() graph.Description { - size, _ := n.Size() - return graph.Description{ - UID: n.UID(), - Name: fmt.Sprintf("SQL_NODE_QUERY: %#v", n), - Type: n.Type(), - Size: size, - } -} - -func (n *SQLNodeIterator) Stats() graph.IteratorStats { - size, _ := n.Size() - return graph.IteratorStats{ - ContainsCost: 1, - NextCost: 5, - Size: size, - } -} - -func (n *SQLNodeIterator) NextPath() bool { - n.resultIndex += 1 - if n.resultIndex >= len(n.resultList) { - return false - } - n.buildResult(n.resultIndex) - return true -} - -func (n *SQLNodeIterator) buildResult(i int) { - container := n.resultList[i] - n.result = make(map[string]string) - for i, c := range n.cols { - n.result[c] = container[i] + m[c] = result[i] } + return m } func (n *SQLNodeIterator) makeNodeTableNames() { @@ -215,7 +134,6 @@ func (n *SQLNodeIterator) buildSubqueries() []tableDef { // separate SQL iterators to build a similar tree as we're doing here, and // have a single graph.Iterator 'caddy' structure around it. subNode := &SQLNodeIterator{ - uid: iterator.NextUID(), tableName: newTableName(), linkIts: []sqlItDir{it}, } @@ -351,144 +269,8 @@ func (n *SQLNodeIterator) buildSQL(next bool, val graph.Value) (string, []string return query, values } -func (n *SQLNodeIterator) Next() bool { - var err error - graph.NextLogIn(n) - if n.cursor == nil { - err = n.makeCursor(true, nil) - n.cols, err = n.cursor.Columns() - if err != nil { - glog.Errorf("Couldn't get columns") - n.err = err - n.cursor.Close() - return false - } - // iterate the first one - if !n.cursor.Next() { - glog.V(4).Infoln("sql: No next") - err := n.cursor.Err() - if err != nil { - glog.Errorf("Cursor error in SQL: %v", err) - n.err = err - } - n.cursor.Close() - return false - } - s, err := scan(n.cursor, len(n.cols)) - if err != nil { - n.err = err - n.cursor.Close() - return false - } - n.resultNext = append(n.resultNext, s) - } - if n.resultList != nil && n.resultNext == nil { - // We're on something and there's no next - return false - } - n.resultList = n.resultNext - n.resultNext = nil - n.resultIndex = 0 - for { - if !n.cursor.Next() { - glog.V(4).Infoln("sql: No next") - err := n.cursor.Err() - if err != nil { - glog.Errorf("Cursor error in SQL: %v", err) - n.err = err - } - n.cursor.Close() - break - } - s, err := scan(n.cursor, len(n.cols)) - if err != nil { - n.err = err - n.cursor.Close() - return false - } - if n.resultList[0][0] != s[0] { - n.resultNext = append(n.resultNext, s) - break - } else { - n.resultList = append(n.resultList, s) - } - - } - if len(n.resultList) == 0 { - return graph.NextLogOut(n, nil, false) - } - n.buildResult(0) - return graph.NextLogOut(n, n.Result(), true) +func (n *SQLNodeIterator) sameTopResult(target []string, test []string) bool { + return target[0] == test[0] } -func (n *SQLNodeIterator) makeCursor(next bool, value graph.Value) error { - if n.cursor != nil { - n.cursor.Close() - } - var q string - var values []string - q, values = n.buildSQL(next, value) - q = convertToPostgres(q, values) - ivalues := make([]interface{}, 0, len(values)) - for _, v := range values { - ivalues = append(ivalues, v) - } - cursor, err := n.qs.db.Query(q, ivalues...) - if err != nil { - glog.Errorf("Couldn't get cursor from SQL database: %v", err) - glog.Errorf("Query: %v", q) - cursor = nil - return err - } - n.cursor = cursor - return nil -} - -func (n *SQLNodeIterator) Contains(v graph.Value) bool { - var err error - //if it.preFilter(v) { - //return false - //} - err = n.makeCursor(false, v) - if err != nil { - glog.Errorf("Couldn't make query: %v", err) - n.err = err - n.cursor.Close() - return false - } - n.cols, err = n.cursor.Columns() - if err != nil { - glog.Errorf("Couldn't get columns") - n.err = err - n.cursor.Close() - return false - } - n.resultList = nil - for { - if !n.cursor.Next() { - glog.V(4).Infoln("sql: No next") - err := n.cursor.Err() - if err != nil { - glog.Errorf("Cursor error in SQL: %v", err) - n.err = err - } - n.cursor.Close() - break - } - s, err := scan(n.cursor, len(n.cols)) - if err != nil { - n.err = err - n.cursor.Close() - return false - } - n.resultList = append(n.resultList, s) - } - n.cursor.Close() - n.cursor = nil - if len(n.resultList) != 0 { - n.resultIndex = 0 - n.buildResult(0) - return true - } - return false -} +func (n *SQLNodeIterator) quickContains(_ graph.Value) (bool, bool) { return false, false }