cayley/graph/sql/sql_link_iterator.go

509 lines
11 KiB
Go

// Copyright 2015 The Cayley Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sql
import (
"database/sql"
"fmt"
"strings"
"sync/atomic"
"github.com/barakmich/glog"
"github.com/google/cayley/graph"
"github.com/google/cayley/graph/iterator"
"github.com/google/cayley/quad"
)
var (
	// sqlLinkType is the iterator type value reported by
	// SQLLinkIterator.Type; it is assigned in init.
	sqlLinkType graph.Type

	// sqlTableID is the counter behind newTableName. It is only accessed
	// through sync/atomic.
	sqlTableID uint64
)
func init() {
sqlLinkType = graph.RegisterIterator("sqllink")
atomic.StoreUint64(&sqlTableID, 0)
}
// newTableName returns a fresh alias of the form "t_<n>", where n is the
// value of sqlTableID after an atomic increment.
func newTableName() string {
	return fmt.Sprintf("t_%d", atomic.AddUint64(&sqlTableID, 1))
}
// constraint restricts one direction of a link to a set of values.
// buildWhere currently emits an equality test against vals[0] only.
type constraint struct {
dir quad.Direction
vals []string
}
// tagDir names a column of an aliased table. buildSQL selects
// "<table>.<dir> as <tag>" for every tagDir returned by getTags, and
// buildWhere joins against the table/dir pair returned by tableID.
type tagDir struct {
tag string
dir quad.Direction
// Not to be stored in the iterator directly
table string
}
// sqlItDir attaches a sub-iterator to one direction of a link. buildWhere
// equates the link's column for dir with the sub-iterator's tableID column.
type sqlItDir struct {
dir quad.Direction
it sqlIterator
}
// sqlIterator is implemented by iterators that can be composed into a single
// SQL statement by SQLLinkIterator.
type sqlIterator interface {
// sqlClone returns a copy of the iterator as a sqlIterator.
sqlClone() sqlIterator
// getTables returns the table aliases to list in the FROM clause.
getTables() []string
// getTags returns the tagged columns to add to the SELECT list.
getTags() []tagDir
// buildWhere returns a WHERE fragment using "?" placeholders together with
// the values to bind to them, in order.
buildWhere() (string, []string)
// tableID returns the table alias and direction other iterators join on.
tableID() tagDir
}
// SQLLinkIterator iterates over quads by issuing a SQL query built from its
// constraints and attached sub-iterators.
type SQLLinkIterator struct {
// uid is the identifier returned by UID.
uid uint64
// qs is the store whose database the queries run against.
qs *QuadStore
tagger graph.Tagger
// err is the last recorded error, returned by Err.
err error
// cursor holds the rows of the query in progress; nil when none is open.
cursor *sql.Rows
// nodeIts are sub-iterators joined into the query by direction.
nodeIts []sqlItDir
constraints []constraint
// tableName is the alias used for the quads table in generated SQL.
tableName string
// size caches the result of Size; zero means "not computed yet".
size int64
// result maps tag column names to values for the current row.
result map[string]string
// resultIndex is the index into resultList of the current row.
resultIndex int
// resultList holds the rows that share the current quad.
resultList [][]string
// resultNext holds the look-ahead row read by Next that belongs to the
// following quad.
resultNext [][]string
// cols are the column names of the current query.
cols []string
// resultQuad is the current quad, returned by Result.
resultQuad quad.Quad
}
// NewSQLLinkIterator returns a link iterator over qs constrained so that
// direction d equals val. The iterator gets a fresh UID and table alias.
func NewSQLLinkIterator(qs *QuadStore, d quad.Direction, val string) *SQLLinkIterator {
	initial := constraint{dir: d, vals: []string{val}}
	return &SQLLinkIterator{
		uid:         iterator.NextUID(),
		qs:          qs,
		constraints: []constraint{initial},
		tableName:   newTableName(),
	}
}
// sqlClone is Clone with the result typed as a sqlIterator.
func (l *SQLLinkIterator) sqlClone() sqlIterator {
	dup := l.Clone().(*SQLLinkIterator)
	return dup
}
// Clone returns a new iterator with its own UID that shares the quad store,
// table alias and cached size of l, and carries copies of its constraints,
// cloned sub-iterators and tags. Cursor and result state are not copied.
func (l *SQLLinkIterator) Clone() graph.Iterator {
	m := &SQLLinkIterator{
		uid:       iterator.NextUID(),
		qs:        l.qs,
		tableName: l.tableName,
		size:      l.size,
		// The destination must have the full length: copy fills at most
		// len(dst) elements, so a zero-length slice (even with capacity)
		// would silently drop every constraint.
		constraints: make([]constraint, len(l.constraints)),
	}
	copy(m.constraints, l.constraints)
	for _, i := range l.nodeIts {
		m.nodeIts = append(m.nodeIts, sqlItDir{
			dir: i.dir,
			it:  i.it.sqlClone(),
		})
	}
	m.tagger.CopyFrom(l)
	return m
}
// UID returns the identifier assigned to the iterator when it was created.
func (l *SQLLinkIterator) UID() uint64 { return l.uid }
// Reset returns the iterator to its initial state so that the next call to
// Next issues a fresh query. Any error from closing the open cursor is
// recorded and can be read through Err; otherwise the error is cleared.
func (l *SQLLinkIterator) Reset() {
	l.err = l.Close()
	// Drop the cursor even if Close failed, so Next starts over instead of
	// reading from a half-closed result set.
	l.cursor = nil
	// Discard buffered rows. A stale look-ahead row left in resultNext would
	// otherwise be returned as the first result of the next iteration.
	l.resultList = nil
	l.resultNext = nil
	l.resultIndex = 0
}
// Err returns the last error recorded by the iterator, or nil.
func (l *SQLLinkIterator) Err() error { return l.err }
// Close closes the open cursor, if there is one. On success the cursor is
// cleared; if closing fails the error is returned and the cursor is kept.
func (l *SQLLinkIterator) Close() error {
	if l.cursor == nil {
		return nil
	}
	if err := l.cursor.Close(); err != nil {
		return err
	}
	l.cursor = nil
	return nil
}
// Tagger returns a pointer to the iterator's tagger.
func (l *SQLLinkIterator) Tagger() *graph.Tagger { return &l.tagger }
// Result returns the quad built from the current row.
func (l *SQLLinkIterator) Result() graph.Value { return l.resultQuad }
// TagResults writes the tag values of the current row into dst. A value
// stored under the "__execd" key is written once for every tag of the
// iterator's tagger; all other keys are copied as they are. Fixed tags from
// the tagger are written last.
func (l *SQLLinkIterator) TagResults(dst map[string]graph.Value) {
	for name, val := range l.result {
		if name != "__execd" {
			dst[name] = val
			continue
		}
		for _, t := range l.tagger.Tags() {
			dst[t] = val
		}
	}
	for name, val := range l.tagger.Fixed() {
		dst[name] = val
	}
}
// SubIterators reports no sub-iterators.
//
// TODO(barakmich): SQL Subiterators shouldn't count? If it makes sense,
// there's no reason not to expose them though.
func (l *SQLLinkIterator) SubIterators() []graph.Iterator {
	var none []graph.Iterator
	return none
}
// Sorted always reports false.
func (l *SQLLinkIterator) Sorted() bool {
	return false
}
// Optimize performs no rewriting: it returns the receiver and false.
func (l *SQLLinkIterator) Optimize() (graph.Iterator, bool) {
	return l, false
}
// Size returns the iterator's size and whether that figure came from the
// per-constraint lookup. A non-zero cached size is returned as is. With no
// constraints the size of the whole store is returned with false. Otherwise
// the size is looked up for the first constraint's first value and cached.
func (l *SQLLinkIterator) Size() (int64, bool) {
	switch {
	case l.size != 0:
		return l.size, true
	case len(l.constraints) == 0:
		return l.qs.Size(), false
	}
	first := l.constraints[0]
	l.size = l.qs.sizeForIterator(false, first.dir, first.vals[0])
	return l.size, true
}
// Describe returns a description carrying the iterator's UID, type, size and
// a name made of a Go-syntax dump of the iterator.
func (l *SQLLinkIterator) Describe() graph.Description {
	// Size runs first because it may update the cached size that the dump
	// below prints.
	sz, _ := l.Size()
	var d graph.Description
	d.UID = l.UID()
	d.Name = fmt.Sprintf("SQL_LINK_QUERY: %#v", l)
	d.Type = l.Type()
	d.Size = sz
	return d
}
// Stats returns fixed cost figures (Contains 1, Next 5) together with the
// iterator's size.
func (l *SQLLinkIterator) Stats() graph.IteratorStats {
	var st graph.IteratorStats
	st.Size, _ = l.Size()
	st.ContainsCost = 1
	st.NextCost = 5
	return st
}
// Type returns the registered "sqllink" iterator type.
func (l *SQLLinkIterator) Type() graph.Type { return sqlLinkType }
// Contains reports whether the quad v satisfies the iterator's query. It
// runs the query restricted to v, buffers every matching row in resultList
// and, if there is at least one, makes the first row the current result.
// Failures are recorded for Err and reported as false.
func (l *SQLLinkIterator) Contains(v graph.Value) bool {
	// Always leave without a cursor. makeCursor does not touch l.cursor when
	// the query fails, so it may be nil here (calling Close on it would
	// panic) or an already-closed cursor from an earlier call. Rows.Close is
	// idempotent, so closing again is harmless.
	defer func() {
		if l.cursor != nil {
			l.cursor.Close()
			l.cursor = nil
		}
	}()

	var err error
	//if it.preFilter(v) {
	//return false
	//}
	if err = l.makeCursor(false, v); err != nil {
		glog.Errorf("Couldn't make query: %v", err)
		l.err = err
		return false
	}
	if l.cols, err = l.cursor.Columns(); err != nil {
		glog.Errorf("Couldn't get columns")
		l.err = err
		return false
	}

	l.resultList = nil
	for l.cursor.Next() {
		s, err := scan(l.cursor, len(l.cols))
		if err != nil {
			l.err = err
			return false
		}
		l.resultList = append(l.resultList, s)
	}
	glog.V(4).Infoln("sql: No next")
	// A cursor error is recorded but does not discard rows already read.
	if err := l.cursor.Err(); err != nil {
		glog.Errorf("Cursor error in SQL: %v", err)
		l.err = err
	}

	if len(l.resultList) == 0 {
		return false
	}
	l.resultIndex = 0
	l.buildResult(0)
	return true
}
// NextPath advances to the next buffered row for the current quad and makes
// it the current result. It returns false once the buffered rows run out.
func (l *SQLLinkIterator) NextPath() bool {
	l.resultIndex++
	if l.resultIndex < len(l.resultList) {
		l.buildResult(l.resultIndex)
		return true
	}
	return false
}
// buildResult makes buffered row i the current result. The first four
// columns become the subject, predicate, object and label of resultQuad;
// the remaining columns are stored in result under their column names.
func (l *SQLLinkIterator) buildResult(i int) {
	row := l.resultList[i]
	l.resultQuad = quad.Quad{
		Subject:   row[0],
		Predicate: row[1],
		Object:    row[2],
		Label:     row[3],
	}
	l.result = make(map[string]string)
	for offset, name := range l.cols[4:] {
		l.result[name] = row[4+offset]
	}
}
// getTables returns this iterator's table alias followed by the aliases of
// all sub-iterators, in order.
func (l *SQLLinkIterator) getTables() []string {
	tables := make([]string, 1)
	tables[0] = l.tableName
	for idx := range l.nodeIts {
		tables = append(tables, l.nodeIts[idx].it.getTables()...)
	}
	return tables
}
// getTags returns one tagDir per tag of this iterator, bound to its table
// alias with direction quad.Any, followed by the tags of all sub-iterators.
func (l *SQLLinkIterator) getTags() []tagDir {
	var tags []tagDir
	for _, name := range l.tagger.Tags() {
		td := tagDir{tag: name, dir: quad.Any, table: l.tableName}
		tags = append(tags, td)
	}
	for idx := range l.nodeIts {
		tags = append(tags, l.nodeIts[idx].it.getTags()...)
	}
	return tags
}
// buildWhere returns the WHERE conditions for this iterator, joined with
// " AND ", and the values to bind to their "?" placeholders in order. The
// conditions are, in sequence: one equality test per constraint (against its
// first value), one join condition per sub-iterator, and each sub-iterator's
// own conditions.
func (l *SQLLinkIterator) buildWhere() (string, []string) {
	var q []string
	var vals []string
	for _, c := range l.constraints {
		q = append(q, fmt.Sprintf("%s.%s = ?", l.tableName, c.dir))
		vals = append(vals, c.vals[0])
	}
	for _, i := range l.nodeIts {
		t := i.it.tableID()
		q = append(q, fmt.Sprintf("%s.%s = %s.%s", l.tableName, i.dir, t.table, t.dir))
	}
	for _, i := range l.nodeIts {
		s, v := i.it.buildWhere()
		// Skip an empty clause from a sub-iterator: joining it in would
		// leave a dangling " AND " and produce invalid SQL.
		if s != "" {
			q = append(q, s)
		}
		vals = append(vals, v...)
	}
	return strings.Join(q, " AND "), vals
}
// tableID returns this iterator's table alias with direction quad.Any; the
// tag field is left empty.
func (l *SQLLinkIterator) tableID() tagDir {
	var id tagDir
	id.table = l.tableName
	id.dir = quad.Any
	return id
}
// buildSQL assembles the SELECT statement for this iterator and returns it
// with the values for its "?" placeholders. The four quad columns are always
// selected first, followed by one aliased column per tag. When next is false
// the query is also restricted to the quad held in val, which must be a
// quad.Quad; when next is true val is not used.
func (l *SQLLinkIterator) buildSQL(next bool, val graph.Value) (string, []string) {
	cols := []string{
		l.tableName + ".subject",
		l.tableName + ".predicate",
		l.tableName + ".object",
		l.tableName + ".label",
	}
	// NOTE(review): tag names are interpolated into the statement as
	// unquoted identifiers; confirm callers only pass safe names.
	for _, t := range l.getTags() {
		cols = append(cols, fmt.Sprintf("%s.%s as %s", t.table, t.dir, t.tag))
	}

	var tables []string
	for _, name := range l.getTables() {
		tables = append(tables, "quads as "+name)
	}

	where, values := l.buildWhere()
	if !next {
		v := val.(quad.Quad)
		match := []string{
			l.tableName + ".subject = ?",
			l.tableName + ".predicate = ?",
			l.tableName + ".object = ?",
			l.tableName + ".label = ?",
		}
		if where != "" {
			where += " AND "
		}
		where += strings.Join(match, " AND ")
		values = append(values, v.Subject, v.Predicate, v.Object, v.Label)
	}

	query := "SELECT " + strings.Join(cols, ", ") + " FROM " + strings.Join(tables, ", ")
	// Only emit WHERE when there is a condition; "... WHERE ;" is not valid
	// SQL.
	if where != "" {
		query += " WHERE " + where
	}
	query += ";"

	glog.V(2).Infoln(query)
	if glog.V(4) {
		// Debug rendering only: substitute the values into the statement.
		dstr := query
		for _, v := range values {
			dstr = strings.Replace(dstr, "?", fmt.Sprintf("'%s'", v), 1)
		}
		glog.V(4).Infoln(dstr)
	}
	return query, values
}
// convertToPostgres rewrites the first len(values) "?" placeholders in query
// as "$1", "$2", ... in order of appearance. Placeholders beyond
// len(values) are left as they are.
func convertToPostgres(query string, values []string) string {
	var out strings.Builder
	rest := query
	for n := 1; n <= len(values); n++ {
		pos := strings.Index(rest, "?")
		if pos < 0 {
			break
		}
		out.WriteString(rest[:pos])
		fmt.Fprintf(&out, "$%d", n)
		rest = rest[pos+1:]
	}
	out.WriteString(rest)
	return out.String()
}
// makeCursor closes any cursor currently held, builds the query for the
// given mode and runs it. On success the new rows become l.cursor. On
// failure the error is logged and returned and l.cursor is left unchanged.
func (l *SQLLinkIterator) makeCursor(next bool, value graph.Value) error {
	if l.cursor != nil {
		l.cursor.Close()
	}
	query, values := l.buildSQL(next, value)
	query = convertToPostgres(query, values)

	args := make([]interface{}, len(values))
	for i := range values {
		args[i] = values[i]
	}

	rows, err := l.qs.db.Query(query, args...)
	if err != nil {
		glog.Errorf("Couldn't get cursor from SQL database: %v", err)
		return err
	}
	l.cursor = rows
	return nil
}
// scan reads the current row of cursor into a slice of nCols strings. A scan
// failure is logged and returned with a nil slice.
func scan(cursor *sql.Rows, nCols int) ([]string, error) {
	row := make([]string, nCols)
	dest := make([]interface{}, 0, nCols)
	for i := range row {
		dest = append(dest, &row[i])
	}
	if err := cursor.Scan(dest...); err != nil {
		glog.Errorf("Error scanning iterator: %v", err)
		return nil, err
	}
	return row, nil
}
// Next advances to the next distinct quad. Rows are read from the cursor one
// ahead: every consecutive row with the same subject, predicate, object and
// label is collected into resultList (alternatives reachable via NextPath),
// and the first row that differs is held in resultNext for the following
// call. It returns false when no rows remain or an error occurred; errors
// are available through Err.
func (l *SQLLinkIterator) Next() bool {
	var err error
	graph.NextLogIn(l)
	if l.cursor == nil {
		// Check the query error before touching the cursor: on failure
		// makeCursor leaves l.cursor nil, and the error would otherwise be
		// overwritten by the Columns call below.
		if err = l.makeCursor(true, nil); err != nil {
			l.err = err
			return false
		}
		l.cols, err = l.cursor.Columns()
		if err != nil {
			glog.Errorf("Couldn't get columns")
			l.err = err
			l.cursor.Close()
			return false
		}
		// iterate the first one
		if !l.cursor.Next() {
			glog.V(4).Infoln("sql: No next")
			err := l.cursor.Err()
			if err != nil {
				glog.Errorf("Cursor error in SQL: %v", err)
				l.err = err
			}
			l.cursor.Close()
			return false
		}
		s, err := scan(l.cursor, len(l.cols))
		if err != nil {
			l.err = err
			l.cursor.Close()
			return false
		}
		l.resultNext = append(l.resultNext, s)
	}
	if l.resultList != nil && l.resultNext == nil {
		// We're on something and there's no next
		return false
	}
	// The look-ahead row starts the new group.
	l.resultList = l.resultNext
	l.resultNext = nil
	l.resultIndex = 0
	for {
		if !l.cursor.Next() {
			glog.V(4).Infoln("sql: No next")
			err := l.cursor.Err()
			if err != nil {
				glog.Errorf("Cursor error in SQL: %v", err)
				l.err = err
			}
			l.cursor.Close()
			break
		}
		s, err := scan(l.cursor, len(l.cols))
		if err != nil {
			l.err = err
			l.cursor.Close()
			return false
		}
		if l.resultList[0][0] == s[0] && l.resultList[0][1] == s[1] && l.resultList[0][2] == s[2] && l.resultList[0][3] == s[3] {
			// Same quad: another tag binding for the current result.
			l.resultList = append(l.resultList, s)
		} else {
			// Different quad: keep it for the next call.
			l.resultNext = append(l.resultNext, s)
			break
		}
	}
	if len(l.resultList) == 0 {
		return graph.NextLogOut(l, nil, false)
	}
	l.buildResult(0)
	return graph.NextLogOut(l, l.Result(), true)
}
// SQLAllIterator is a placeholder type; it declares no fields yet.
type SQLAllIterator struct {
// TBD
}