cayley/graph/sql/builder_iterator.go

643 lines
13 KiB
Go

// Copyright 2015 The Cayley Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sql
import (
"database/sql"
"fmt"
"strings"
"github.com/barakmich/glog"
"github.com/google/cayley/graph"
"github.com/google/cayley/graph/iterator"
"github.com/google/cayley/quad"
)
// sqlBuilderType is the graph.Type reported by StatementIterator.Type. It is
// assigned in init.
var sqlBuilderType graph.Type
// init registers the "sqlbuilder" iterator name with the graph package and
// stores the returned type in sqlBuilderType.
func init() {
sqlBuilderType = graph.RegisterIterator("sqlbuilder")
}
// tableDir identifies one direction column of one aliased table. It is
// rendered into SQL as "<table>.<dir>" by String.
type tableDir struct {
// table is the table alias. An empty alias means "not assigned yet";
// buildQuery and canonicalizeWhere fill it in with the iterator's own alias.
table string
// dir is the quad direction used as the column name.
dir quad.Direction
}
// String renders the pair as "<table>.<dir>". A pair whose table alias is
// still empty renders as "ERR".
func (td tableDir) String() string {
	if td.table == "" {
		return "ERR"
	}
	return fmt.Sprintf("%s.%s", td.table, td.dir)
}
// clause is one node in the tree of WHERE conditions attached to a
// StatementIterator.
type clause interface {
// toSQL returns the condition as SQL text using "?" placeholders, together
// with the values to bind to those placeholders in order.
toSQL() (string, []string)
// getTables returns the set of table aliases the condition refers to.
getTables() map[string]bool
// size returns the number of base conditions contained in the clause.
size() int
}
// baseClause is a single comparison: the column named by pair is compared
// either against literal values (strTarget) or, when strTarget is empty,
// against another column (target).
type baseClause struct {
// pair is the column on the left-hand side of the comparison.
pair tableDir
// strTarget holds the literal value(s) to compare against, bound as query
// parameters.
strTarget []string
// target is the right-hand column; it is only used when strTarget is empty.
target tableDir
}
// toSQL renders the comparison as a SQL condition with "?" placeholders and
// returns it together with the values to bind, in placeholder order:
//
//   - no strTarget:       "<pair> = <target>" (column-to-column), no values
//   - one strTarget:      "<pair> = ?"
//   - several strTargets: "<pair> IN (?, ?, ...)"
//
// The multi-value form treats strTarget as a set of acceptable values, which
// is how preFilter already interprets it. Previously this case returned an
// empty condition, which produced malformed SQL once joined with other
// conditions.
func (b baseClause) toSQL() (string, []string) {
	switch n := len(b.strTarget); n {
	case 0:
		return fmt.Sprintf("%s = %s", b.pair, b.target), []string{}
	case 1:
		return fmt.Sprintf("%s = ?", b.pair), []string{b.strTarget[0]}
	default:
		placeholders := strings.TrimSuffix(strings.Repeat("?, ", n), ", ")
		// Copy the values so that callers appending to the returned slice can
		// never write into strTarget's backing array.
		values := append([]string(nil), b.strTarget...)
		return fmt.Sprintf("%s IN (%s)", b.pair, placeholders), values
	}
}
// size reports the number of base conditions in this clause, which is
// always one.
func (b baseClause) size() int {
	return 1
}
// getTables returns the set of table aliases named by either side of the
// comparison. Sides with an empty alias are left out.
func (b baseClause) getTables() map[string]bool {
	tables := make(map[string]bool)
	for _, side := range [...]tableDir{b.pair, b.target} {
		if side.table == "" {
			continue
		}
		tables[side.table] = true
	}
	return tables
}
// joinClause combines two clauses with a boolean operator. Either side may
// be nil, in which case the clause behaves like the other side alone.
type joinClause struct {
// left is the first operand; may be nil.
left clause
// right is the second operand; may be nil.
right clause
// op selects the keyword (AND / OR) placed between the operands.
op clauseOp
}
// size returns the combined size of the non-nil operands.
func (jc joinClause) size() int {
	total := 0
	for _, side := range []clause{jc.left, jc.right} {
		if side != nil {
			total += side.size()
		}
	}
	return total
}
// toSQL renders the clause as "(<left> <OP> <right>)" and returns the bind
// values of the left operand followed by those of the right. When one
// operand is nil the other operand's SQL is returned unchanged; when both
// are nil the result is an empty condition with no values.
func (jc joinClause) toSQL() (string, []string) {
	switch {
	case jc.left == nil && jc.right == nil:
		return "", []string{}
	case jc.left == nil:
		return jc.right.toSQL()
	case jc.right == nil:
		return jc.left.toSQL()
	}

	leftSQL, values := jc.left.toSQL()
	rightSQL, rightValues := jc.right.toSQL()
	values = append(values, rightValues...)

	// An unrecognized operator leaves the keyword empty.
	keyword := ""
	if jc.op == andClause {
		keyword = "AND"
	} else if jc.op == orClause {
		keyword = "OR"
	}
	return fmt.Sprintf("(%s %s %s)", leftSQL, keyword, rightSQL), values
}
// getTables returns the union of the table aliases used by both operands.
// The map returned by the left operand (or a fresh one when it is nil) is
// extended with the right operand's aliases.
func (jc joinClause) getTables() map[string]bool {
	var tables map[string]bool
	if jc.left == nil {
		tables = make(map[string]bool)
	} else {
		tables = jc.left.getTables()
	}
	if jc.right == nil {
		return tables
	}
	for name := range jc.right.getTables() {
		tables[name] = true
	}
	return tables
}
// tag describes an extra column to include in the SELECT list.
type tag struct {
// pair is the column to select; an empty table alias is replaced with the
// iterator's own alias when the query is built.
pair tableDir
// t is the name the column is selected "as".
t string
}
// statementType selects what a StatementIterator yields as its result.
type statementType int

const (
	// node iterators yield the value of a single direction column.
	node statementType = iota
	// link iterators yield whole quads.
	link statementType = iota
)
// clauseOp is the boolean operator a joinClause places between its operands.
type clauseOp int

const (
	// andClause joins the operands with AND.
	andClause clauseOp = iota
	// orClause joins the operands with OR.
	orClause clauseOp = iota
)
// canonicalizeWhere renders every buildWhere clause against this iterator's
// own table alias and joins the conditions with " AND ". It returns the
// joined SQL and the bind values in placeholder order. The stored clauses
// are not modified; the alias is set on a per-clause copy.
func (it *StatementIterator) canonicalizeWhere() (string, []string) {
	var (
		conds []string
		args  []string
	)
	alias := it.tableName()
	for i := range it.buildWhere {
		bc := it.buildWhere[i]
		bc.pair.table = alias
		cond, bound := bc.toSQL()
		args = append(args, bound...)
		conds = append(conds, cond)
	}
	return strings.Join(conds, " AND "), args
}
// getTables returns the set of table aliases referenced by the where clause
// and by any tag that names an explicit table.
func (it *StatementIterator) getTables() map[string]bool {
	var tables map[string]bool
	if it.where == nil {
		tables = make(map[string]bool)
	} else {
		tables = it.where.getTables()
	}
	for i := range it.tags {
		if name := it.tags[i].pair.table; name != "" {
			tables[name] = true
		}
	}
	return tables
}
// tableName returns the table alias used for this iterator in generated SQL:
// "t_" followed by the iterator's uid.
func (it *StatementIterator) tableName() string {
	return "t_" + fmt.Sprint(it.uid)
}
// buildQuery assembles the SQL text and bind values for this iterator.
//
// The SELECT list holds the iterator's own columns (the four quad columns for
// link iterators, or the direction column aliased "__execd" for node
// iterators), followed by one column per entry in it.tags and one per tagger
// tag. The FROM list aliases the quads table once for this iterator and once
// for every other table alias referenced by the clauses and tags.
//
// Conditions come from, in order: buildWhere, the where clause tree, and,
// when contains is true, an equality check against v. In that case v must be
// a quad.Quad for link iterators and a string for node iterators; any other
// dynamic type panics on the type assertion. Empty conditions are skipped and
// the WHERE keyword is only emitted when at least one condition remains, so
// an unconstrained iterator no longer yields a dangling "WHERE".
//
// Node queries are ordered by "__execd". "?" placeholders are rewritten to
// "$1", "$2", ... before the query is returned.
//
// NOTE(review): tag names and table aliases are interpolated into the SQL
// text unescaped; they must not come from untrusted input.
func (it *StatementIterator) buildQuery(contains bool, v graph.Value) (string, []string) {
	alias := it.tableName()

	// SELECT list.
	var cols []string
	if it.stType == link {
		cols = []string{
			fmt.Sprintf("%s.subject", alias),
			fmt.Sprintf("%s.predicate", alias),
			fmt.Sprintf("%s.object", alias),
			fmt.Sprintf("%s.label", alias),
		}
	} else {
		cols = []string{fmt.Sprintf("%s.%s as __execd", alias, it.dir)}
	}
	for _, tg := range it.tags {
		// tg is a copy, so defaulting the alias here leaves it.tags untouched.
		if tg.pair.table == "" {
			tg.pair.table = alias
		}
		cols = append(cols, fmt.Sprintf("%s as %s", tg.pair, tg.t))
	}
	for _, name := range it.tagger.Tags() {
		cols = append(cols, fmt.Sprintf("%s as %s", tableDir{alias, it.dir}, name))
	}

	// FROM list: this iterator's alias first, then every other alias in use.
	from := []string{fmt.Sprintf("quads as %s", alias)}
	for k := range it.getTables() {
		if k != alias {
			from = append(from, fmt.Sprintf("quads as %s", k))
		}
	}

	// WHERE conditions. Values are appended in the same order as their
	// placeholders appear in the final text.
	var conds, values []string
	addCond := func(cond string, bound []string) {
		if cond == "" {
			return
		}
		conds = append(conds, cond)
		values = append(values, bound...)
	}
	if len(it.buildWhere) != 0 {
		addCond(it.canonicalizeWhere())
	}
	if it.where != nil {
		addCond(it.where.toSQL())
	}
	if contains {
		if it.stType == link {
			q := v.(quad.Quad)
			addCond(strings.Join([]string{
				fmt.Sprintf("%s.subject = ?", alias),
				fmt.Sprintf("%s.predicate = ?", alias),
				fmt.Sprintf("%s.object = ?", alias),
				fmt.Sprintf("%s.label = ?", alias),
			}, " AND "), []string{q.Subject, q.Predicate, q.Object, q.Label})
		} else {
			addCond(fmt.Sprintf("%s.%s = ?", alias, it.dir), []string{v.(string)})
		}
	}

	str := "SELECT " + strings.Join(cols, ", ") + " FROM " + strings.Join(from, ", ")
	if len(conds) != 0 {
		str += " WHERE " + strings.Join(conds, " AND ")
	}
	if it.stType == node {
		str += " ORDER BY __execd "
	}
	str += ";"

	// Number the placeholders in order of appearance.
	for i := 1; i <= len(values); i++ {
		str = strings.Replace(str, "?", fmt.Sprintf("$%d", i), 1)
	}
	glog.V(2).Infoln(str)
	if glog.V(4) {
		// Debug-only rendering with the values substituted in.
		dstr := str
		for i := 1; i <= len(values); i++ {
			dstr = strings.Replace(dstr, fmt.Sprintf("$%d", i), fmt.Sprintf("'%s'", values[i-1]), 1)
		}
		glog.V(4).Infoln(dstr)
	}
	return str, values
}
// StatementIterator is a graph iterator that produces its results by
// building a SQL query from its clauses and running it against the
// QuadStore's database.
type StatementIterator struct {
// uid is the iterator's unique id; it also determines the table alias used
// for this iterator in generated SQL (see tableName).
uid uint64
// qs is the store whose database the queries run against.
qs *QuadStore
// Only for links: simple comparisons on this iterator's own table. Their
// table alias is filled in when the query is built.
buildWhere []baseClause
// where is an optional tree of further conditions; may be nil.
where clause
// tagger holds the tags attached to this iterator's own result.
tagger graph.Tagger
// tags lists extra columns to select under the given names.
tags []tag
// err is the most recent error recorded by the iterator (see Err).
err error
// cursor is the open result set of the current query, or nil when no query
// is in progress.
cursor *sql.Rows
// stType selects whether results are node values or whole quads.
stType statementType
// dir is the direction column selected by node iterators and used for the
// tagger's tags.
dir quad.Direction
// result maps column names to values for the current result row.
result map[string]string
// resultIndex is the index into resultList of the row currently exposed.
resultIndex int
// resultList holds the buffered rows for the current result.
resultList [][]string
// resultNext holds the row read ahead by Next that belongs to the following
// result.
resultNext [][]string
// cols holds the column names reported by the cursor.
cols []string
// resultQuad is the current quad for link iterators.
resultQuad quad.Quad
// size is the cached size estimate; -1 means it has not been computed.
size int64
}
// Clone returns a new iterator with a fresh uid that shares this iterator's
// store and clauses and carries copies of its tags and tagger state. Cursor
// and buffered results are not carried over, so the clone starts from the
// beginning.
func (it *StatementIterator) Clone() graph.Iterator {
	m := &StatementIterator{
		uid:        iterator.NextUID(),
		qs:         it.qs,
		buildWhere: it.buildWhere,
		where:      it.where,
		stType:     it.stType,
		size:       it.size,
		dir:        it.dir,
		// Give the clone its own tags slice. copy's signature is
		// copy(dst, src); copying into it.tags from the clone's still-nil
		// slice transferred nothing, so the clone lost every tag.
		tags: append([]tag(nil), it.tags...),
	}
	m.tagger.CopyFrom(it)
	return m
}
// NewStatementIterator returns a link iterator over qs restricted to quads
// whose direction d equals val. Its size starts out as unknown (-1).
func NewStatementIterator(qs *QuadStore, d quad.Direction, val string) *StatementIterator {
	// The table alias is left empty here; it is filled in when the query is
	// built.
	filter := baseClause{
		pair:      tableDir{table: "", dir: d},
		strTarget: []string{val},
	}
	return &StatementIterator{
		uid:        iterator.NextUID(),
		qs:         qs,
		buildWhere: []baseClause{filter},
		stType:     link,
		size:       -1,
	}
}
// UID returns the unique id assigned to this iterator when it was created.
func (it *StatementIterator) UID() uint64 { return it.uid }
// Reset returns the iterator to its initial state: the recorded error is
// cleared, any open cursor is closed and dropped, and all buffered rows are
// discarded so that the next call to Next starts a fresh query.
func (it *StatementIterator) Reset() {
	it.err = nil
	// Reset has no way to report a failure, so closing is best-effort; the
	// cursor is dropped either way so a new query is issued next time.
	_ = it.Close()
	it.cursor = nil
	// Discard buffered rows. Next appends the first row of a new query to
	// resultNext, so a leftover look-ahead row from the previous run would
	// otherwise be returned as the first result.
	it.resultList = nil
	it.resultNext = nil
	it.resultIndex = 0
}
// Err returns the most recent error recorded by the iterator, or nil.
func (it *StatementIterator) Err() error { return it.err }
// Close closes the open cursor, if any, and forgets it. If closing fails the
// error is returned and the cursor is kept.
func (it *StatementIterator) Close() error {
	if it.cursor == nil {
		return nil
	}
	if err := it.cursor.Close(); err != nil {
		return err
	}
	it.cursor = nil
	return nil
}
// Tagger returns a pointer to the iterator's own tagger, so callers modify
// the iterator's tags in place.
func (it *StatementIterator) Tagger() *graph.Tagger {
	tagger := &it.tagger
	return tagger
}
// Result returns the current result: the "__execd" column value for node
// iterators, or the current quad for link iterators.
func (it *StatementIterator) Result() graph.Value {
	if it.stType != node {
		return it.resultQuad
	}
	return it.result["__execd"]
}
// TagResults copies the current row into dst. Every selected column is
// stored under its column name, except "__execd", whose value is stored
// under each of the tagger's tags instead. The tagger's fixed tags are
// written last.
func (it *StatementIterator) TagResults(dst map[string]graph.Value) {
	for col, val := range it.result {
		if col != "__execd" {
			dst[col] = val
			continue
		}
		for _, name := range it.tagger.Tags() {
			dst[name] = val
		}
	}
	for name, fixed := range it.tagger.Fixed() {
		dst[name] = fixed
	}
}
// Type returns the iterator type registered for this package.
func (it *StatementIterator) Type() graph.Type { return sqlBuilderType }
// preFilter reports whether v can be rejected without querying the database.
// For link iterators v must be a quad.Quad; it is rejected when, for some
// buildWhere clause with literal targets, the quad's value in that clause's
// direction matches none of the targets. Node iterators are never
// pre-filtered.
func (it *StatementIterator) preFilter(v graph.Value) bool {
	if it.stType != link {
		return false
	}
	q := v.(quad.Quad)
	for _, bc := range it.buildWhere {
		if len(bc.strTarget) == 0 {
			continue
		}
		got := q.Get(bc.pair.dir)
		matched := false
		for _, want := range bc.strTarget {
			if got == want {
				matched = true
				break
			}
		}
		if !matched {
			return true
		}
	}
	return false
}
// Contains reports whether v is among the results of this iterator. After
// the cheap preFilter check it runs the iterator's query with an additional
// equality condition on v, buffers every returned row in resultList and, if
// there is at least one, exposes the first as the current result.
//
// Failures are recorded in it.err and reported as false, except that a
// cursor error after some rows were already read still yields true for the
// rows that were read. The cursor never outlives the call.
func (it *StatementIterator) Contains(v graph.Value) bool {
	if it.preFilter(v) {
		return false
	}
	q, values := it.buildQuery(true, v)
	ivalues := make([]interface{}, 0, len(values))
	for _, val := range values {
		ivalues = append(ivalues, val)
	}

	// Release any cursor still open before replacing it, as makeCursor does;
	// otherwise its rows would never be closed.
	if it.cursor != nil {
		it.cursor.Close()
		it.cursor = nil
	}

	cursor, err := it.qs.db.Query(q, ivalues...)
	if err != nil {
		// There is no cursor to close when the query itself failed.
		glog.Errorf("Couldn't make query: %v", err)
		it.err = err
		return false
	}
	it.cursor = cursor
	defer func() {
		// Best-effort close; all rows have been consumed or abandoned.
		it.cursor.Close()
		it.cursor = nil
	}()

	it.cols, err = it.cursor.Columns()
	if err != nil {
		glog.Errorf("Couldn't get columns")
		it.err = err
		return false
	}

	it.resultList = nil
	for it.cursor.Next() {
		s, err := it.scan()
		if err != nil {
			it.err = err
			return false
		}
		it.resultList = append(it.resultList, s)
	}
	glog.V(4).Infoln("sql: No next")
	if err := it.cursor.Err(); err != nil {
		glog.Errorf("Cursor error in SQL: %v", err)
		it.err = err
	}

	if len(it.resultList) == 0 {
		return false
	}
	it.resultIndex = 0
	it.buildResult(0)
	return true
}
// SubIterators returns nil; this iterator has no sub-iterators.
func (it *StatementIterator) SubIterators() []graph.Iterator {
	var subs []graph.Iterator
	return subs
}
// Sorted always reports false.
func (it *StatementIterator) Sorted() bool {
	return false
}
// Optimize returns the iterator itself and false: no replacement is made.
func (it *StatementIterator) Optimize() (graph.Iterator, bool) {
	return it, false
}
// Size returns an estimate of the number of results and whether that
// estimate could be computed.
//
// A cached size (anything other than -1) is returned as is. Node iterators
// return the store size divided by one plus the number of conditions. Link
// iterators ask the store for the size of their first buildWhere comparison
// and cache the answer. When there is no buildWhere clause, or it has no
// literal target, the store size is returned with false.
func (it *StatementIterator) Size() (int64, bool) {
	if it.size != -1 {
		return it.size, true
	}
	if it.stType == node {
		divisor := len(it.buildWhere) + 1
		if it.where != nil {
			divisor += it.where.size()
		}
		return it.qs.Size() / int64(divisor), true
	}
	// Guard the index below: a link iterator without buildWhere entries has
	// nothing to look up.
	if len(it.buildWhere) == 0 || len(it.buildWhere[0].strTarget) == 0 {
		return it.qs.Size(), false
	}
	b := it.buildWhere[0]
	it.size = it.qs.sizeForIterator(false, b.pair.dir, b.strTarget[0])
	return it.size, true
}
// Describe returns a description of the iterator: its uid, type, size
// estimate and a name containing a Go-syntax dump of the iterator.
func (it *StatementIterator) Describe() graph.Description {
	// Size is evaluated first because it may cache its answer in the
	// iterator, which the dump below then reflects.
	estimate, _ := it.Size()
	var desc graph.Description
	desc.UID = it.UID()
	desc.Name = fmt.Sprintf("SQL_QUERY: %#v", it)
	desc.Type = it.Type()
	desc.Size = estimate
	return desc
}
// Stats returns the iterator's cost figures: a Contains cost of 1, a Next
// cost of 5 and the current size estimate.
func (it *StatementIterator) Stats() graph.IteratorStats {
	estimate, _ := it.Size()
	stats := graph.IteratorStats{Size: estimate}
	stats.ContainsCost = 1
	stats.NextCost = 5
	return stats
}
// makeCursor closes any open cursor, builds the iterator's query (without a
// containment condition) and runs it. On success the new cursor is stored in
// it.cursor. On failure the error is logged and recorded in it.err, and
// it.cursor is left nil.
func (it *StatementIterator) makeCursor() {
	if it.cursor != nil {
		it.cursor.Close()
		it.cursor = nil
	}
	q, values := it.buildQuery(false, nil)
	ivalues := make([]interface{}, 0, len(values))
	for _, val := range values {
		ivalues = append(ivalues, val)
	}
	cursor, err := it.qs.db.Query(q, ivalues...)
	if err != nil {
		glog.Errorf("Couldn't get cursor from SQL database: %v", err)
		// Record the failure so Err reports it instead of looking like an
		// empty result set.
		it.err = err
		return
	}
	it.cursor = cursor
}
// NextPath advances to the next buffered row for the current result and
// reports whether there was one. The index is advanced even when no further
// row exists.
func (it *StatementIterator) NextPath() bool {
	it.resultIndex++
	if it.resultIndex < len(it.resultList) {
		it.buildResult(it.resultIndex)
		return true
	}
	return false
}
// Next advances to the next distinct result and reports whether there is
// one.
//
// On the first call (no cursor) the query is started and its first row is
// read. Each call then takes the row read ahead by the previous call as the
// new current result and keeps reading rows while they match it: the first
// column for node iterators, the first four columns for link iterators.
// Matching rows are buffered in resultList so NextPath can walk them; the
// first non-matching row is kept in resultNext for the following call.
//
// Errors are recorded in it.err and reported as false.
func (it *StatementIterator) Next() bool {
	var err error
	graph.NextLogIn(it)
	if it.cursor == nil {
		it.makeCursor()
		if it.cursor == nil {
			// The query could not be started; there is nothing to read from.
			if it.err == nil {
				it.err = fmt.Errorf("sql: couldn't get cursor from SQL database")
			}
			return false
		}
		it.cols, err = it.cursor.Columns()
		if err != nil {
			glog.Errorf("Couldn't get columns")
			it.err = err
			it.cursor.Close()
			return false
		}
		// iterate the first one
		if !it.cursor.Next() {
			glog.V(4).Infoln("sql: No next")
			err := it.cursor.Err()
			if err != nil {
				glog.Errorf("Cursor error in SQL: %v", err)
				it.err = err
			}
			it.cursor.Close()
			return false
		}
		s, err := it.scan()
		if err != nil {
			it.err = err
			it.cursor.Close()
			return false
		}
		// A fresh query starts with exactly this row as look-ahead; any row
		// left over from an earlier cursor is stale and must not be kept.
		it.resultNext = [][]string{s}
	}
	if it.resultList != nil && it.resultNext == nil {
		// We're on something and there's no next
		return false
	}
	it.resultList = it.resultNext
	it.resultNext = nil
	it.resultIndex = 0
	for {
		if !it.cursor.Next() {
			glog.V(4).Infoln("sql: No next")
			err := it.cursor.Err()
			if err != nil {
				glog.Errorf("Cursor error in SQL: %v", err)
				it.err = err
			}
			it.cursor.Close()
			break
		}
		s, err := it.scan()
		if err != nil {
			it.err = err
			it.cursor.Close()
			return false
		}
		first := it.resultList[0]
		var same bool
		if it.stType == node {
			same = first[0] == s[0]
		} else {
			same = first[0] == s[0] && first[1] == s[1] && first[2] == s[2] && first[3] == s[3]
		}
		if !same {
			// This row starts the next result; keep it for the next call.
			it.resultNext = append(it.resultNext, s)
			break
		}
		it.resultList = append(it.resultList, s)
	}
	if len(it.resultList) == 0 {
		return graph.NextLogOut(it, nil, false)
	}
	it.buildResult(0)
	return graph.NextLogOut(it, it.Result(), true)
}
// scan reads the cursor's current row into a slice of strings, one per
// column in it.cols. A scan failure is logged, recorded in it.err and
// returned.
func (it *StatementIterator) scan() ([]string, error) {
	n := len(it.cols)
	row := make([]string, n)
	dest := make([]interface{}, 0, n)
	for i := range row {
		dest = append(dest, &row[i])
	}
	if err := it.cursor.Scan(dest...); err != nil {
		glog.Errorf("Error scanning iterator: %v", err)
		it.err = err
		return nil, err
	}
	return row, nil
}
// buildResult makes buffered row i the current result. For node iterators
// every column is stored in it.result under its column name. For link
// iterators the first four columns become the current quad (subject,
// predicate, object, label) and only the remaining columns go into
// it.result.
func (it *StatementIterator) buildResult(i int) {
	row := it.resultList[i]
	if it.stType == node {
		it.result = make(map[string]string)
		for idx, name := range it.cols {
			it.result[name] = row[idx]
		}
		return
	}
	it.resultQuad = quad.Quad{
		Subject:   row[0],
		Predicate: row[1],
		Object:    row[2],
		Label:     row[3],
	}
	it.result = make(map[string]string)
	const quadCols = 4
	for idx, name := range it.cols[quadCols:] {
		it.result[name] = row[idx+quadCols]
	}
}