Optimize by collapsing trees into single SQL queries

This commit is contained in:
Barak Michener 2015-07-22 19:11:59 -04:00
parent 185e236f15
commit 621acae945
7 changed files with 501 additions and 50 deletions

View file

@ -46,6 +46,7 @@ func (td tableDir) String() string {
type clause interface {
toSQL() (string, []string)
getTables() map[string]bool
size() int
}
type baseClause struct {
@ -65,6 +66,8 @@ func (b baseClause) toSQL() (string, []string) {
return fmt.Sprintf("%s = ?", b.pair), []string{b.strTarget[0]}
}
// size reports the weight of a single base clause. A base clause is one
// comparison, so it always counts as exactly one.
func (b baseClause) size() int {
	const unit = 1
	return unit
}
func (b baseClause) getTables() map[string]bool {
out := make(map[string]bool)
if b.pair.table != "" {
@ -82,7 +85,27 @@ type joinClause struct {
op clauseOp
}
// size returns the combined size of the join's operands. The left operand is
// consulted first, then the right; an operand that is nil contributes nothing,
// so a join with no operands has size zero.
func (jc joinClause) size() int {
	var total int
	if l := jc.left; l != nil {
		total += l.size()
	}
	if r := jc.right; r != nil {
		total += r.size()
	}
	return total
}
func (jc joinClause) toSQL() (string, []string) {
if jc.left == nil {
if jc.right == nil {
return "", []string{}
}
return jc.right.toSQL()
}
if jc.right == nil {
return jc.left.toSQL()
}
l, lstr := jc.left.toSQL()
r, rstr := jc.right.toSQL()
lstr = append(lstr, rstr...)
@ -93,13 +116,20 @@ func (jc joinClause) toSQL() (string, []string) {
case orClause:
op = "OR"
}
return fmt.Sprint("(%s %s %s)", l, op, r), lstr
return fmt.Sprintf("(%s %s %s)", l, op, r), lstr
}
func (jc joinClause) getTables() map[string]bool {
m := jc.left.getTables()
for k, _ := range jc.right.getTables() {
m[k] = true
var m map[string]bool
if jc.left != nil {
m = jc.left.getTables()
} else {
m = make(map[string]bool)
}
if jc.right != nil {
for k, _ := range jc.right.getTables() {
m[k] = true
}
}
return m
}
@ -166,8 +196,14 @@ func (it *StatementIterator) buildQuery(contains bool, v graph.Value) (string, [
t = []string{fmt.Sprintf("%s.%s as __execd", it.tableName(), it.dir)}
}
for _, v := range it.tags {
if v.pair.table == "" {
v.pair.table = it.tableName()
}
t = append(t, fmt.Sprintf("%s as %s", v.pair, v.t))
}
for _, v := range it.tagger.Tags() {
t = append(t, fmt.Sprintf("%s as %s", tableDir{it.tableName(), it.dir}, v))
}
str += strings.Join(t, ", ")
str += " FROM "
t = []string{fmt.Sprintf("quads as %s", it.tableName())}
@ -180,7 +216,7 @@ func (it *StatementIterator) buildQuery(contains bool, v graph.Value) (string, [
str += " WHERE "
var values []string
var s string
if it.stType != node {
if len(it.buildWhere) != 0 {
s, values = it.canonicalizeWhere()
}
if it.where != nil {
@ -191,28 +227,31 @@ func (it *StatementIterator) buildQuery(contains bool, v graph.Value) (string, [
s += where
values = append(values, v2...)
}
str += s
if contains {
if s != "" {
s += " AND "
}
if it.stType == link {
q := v.(quad.Quad)
str += " AND "
t = []string{
fmt.Sprintf("%s.subject = ?", it.tableName()),
fmt.Sprintf("%s.predicate = ?", it.tableName()),
fmt.Sprintf("%s.object = ?", it.tableName()),
fmt.Sprintf("%s.label = ?", it.tableName()),
}
str += " " + strings.Join(t, " AND ") + " "
s += " " + strings.Join(t, " AND ") + " "
values = append(values, q.Subject)
values = append(values, q.Predicate)
values = append(values, q.Object)
values = append(values, q.Label)
} else {
str += fmt.Sprintf(" AND %s.%s = ? ", it.tableName(), it.dir)
s += fmt.Sprintf("%s.%s = ? ", it.tableName(), it.dir)
values = append(values, v.(string))
}
}
str += s
if it.stType == node {
str += " ORDER BY __execd "
}
@ -220,6 +259,14 @@ func (it *StatementIterator) buildQuery(contains bool, v graph.Value) (string, [
for i := 1; i <= len(values); i++ {
str = strings.Replace(str, "?", fmt.Sprintf("$%d", i), 1)
}
glog.V(2).Infoln(str)
if glog.V(4) {
dstr := str
for i := 1; i <= len(values); i++ {
dstr = strings.Replace(dstr, fmt.Sprintf("$%d", i), fmt.Sprintf("'%s'", values[i-1]), 1)
}
glog.V(4).Infoln(dstr)
}
return str, values
}
@ -254,6 +301,7 @@ func (it *StatementIterator) Clone() graph.Iterator {
where: it.where,
stType: it.stType,
size: it.size,
dir: it.dir,
}
copy(it.tags, m.tags)
m.tagger.CopyFrom(it)
@ -364,6 +412,12 @@ func (it *StatementIterator) Contains(v graph.Value) bool {
ivalues = append(ivalues, v)
}
it.cursor, err = it.qs.db.Query(q, ivalues...)
if err != nil {
glog.Errorf("Couldn't make query: %v", err)
it.err = err
it.cursor.Close()
return false
}
it.cols, err = it.cursor.Columns()
if err != nil {
glog.Errorf("Couldn't get columns")
@ -414,10 +468,17 @@ func (it *StatementIterator) Size() (int64, bool) {
return it.size, true
}
if it.stType == node {
return it.qs.Size(), true
if it.where == nil {
return it.qs.Size() / int64(len(it.buildWhere)+1), true
}
return it.qs.Size() / int64(it.where.size()+len(it.buildWhere)+1), true
}
b := it.buildWhere[0]
it.size = it.qs.sizeForIterator(false, b.pair.dir, b.strTarget[0])
if len(b.strTarget) > 0 {
it.size = it.qs.sizeForIterator(false, b.pair.dir, b.strTarget[0])
} else {
return it.qs.Size(), false
}
return it.size, true
}
@ -425,7 +486,7 @@ func (it *StatementIterator) Describe() graph.Description {
size, _ := it.Size()
return graph.Description{
UID: it.UID(),
Name: "SQL_QUERY",
Name: fmt.Sprintf("SQL_QUERY: %#v", it),
Type: it.Type(),
Size: size,
}
@ -451,7 +512,7 @@ func (it *StatementIterator) makeCursor() {
}
cursor, err := it.qs.db.Query(q, ivalues...)
if err != nil {
glog.Errorln("Couldn't get cursor from SQL database: %v", err)
glog.Errorf("Couldn't get cursor from SQL database: %v", err)
cursor = nil
}
it.cursor = cursor
@ -542,7 +603,7 @@ func (it *StatementIterator) Next() bool {
return graph.NextLogOut(it, nil, false)
}
it.buildResult(0)
return graph.NextLogOut(it, it.result, true)
return graph.NextLogOut(it, it.Result(), true)
}
func (it *StatementIterator) scan() ([]string, error) {