Fix data race in gremlin timeout handling

Fixes issue #95.
This commit is contained in:
kortschak 2014-08-02 22:01:02 +09:30
parent cb177aa390
commit fe9ca5ffcc
2 changed files with 39 additions and 13 deletions

View file

@ -148,8 +148,10 @@ func runIteratorToArray(it graph.Iterator, ses *Session, limit int) []map[string
count := 0 count := 0
it, _ = it.Optimize() it, _ = it.Optimize()
for { for {
if ses.doHalt { select {
case <-ses.kill:
return nil return nil
default:
} }
_, ok := graph.Next(it) _, ok := graph.Next(it)
if !ok { if !ok {
@ -163,8 +165,10 @@ func runIteratorToArray(it graph.Iterator, ses *Session, limit int) []map[string
break break
} }
for it.NextResult() == true { for it.NextResult() == true {
if ses.doHalt { select {
case <-ses.kill:
return nil return nil
default:
} }
tags := make(map[string]graph.Value) tags := make(map[string]graph.Value)
it.TagResults(tags) it.TagResults(tags)
@ -184,8 +188,10 @@ func runIteratorToArrayNoTags(it graph.Iterator, ses *Session, limit int) []stri
count := 0 count := 0
it, _ = it.Optimize() it, _ = it.Optimize()
for { for {
if ses.doHalt { select {
case <-ses.kill:
return nil return nil
default:
} }
val, ok := graph.Next(it) val, ok := graph.Next(it)
if !ok { if !ok {
@ -205,8 +211,10 @@ func runIteratorWithCallback(it graph.Iterator, ses *Session, callback otto.Valu
count := 0 count := 0
it, _ = it.Optimize() it, _ = it.Optimize()
for { for {
if ses.doHalt { select {
case <-ses.kill:
return return
default:
} }
_, ok := graph.Next(it) _, ok := graph.Next(it)
if !ok { if !ok {
@ -221,8 +229,10 @@ func runIteratorWithCallback(it graph.Iterator, ses *Session, callback otto.Valu
break break
} }
for it.NextResult() == true { for it.NextResult() == true {
if ses.doHalt { select {
case <-ses.kill:
return return
default:
} }
tags := make(map[string]graph.Value) tags := make(map[string]graph.Value)
it.TagResults(tags) it.TagResults(tags)
@ -246,8 +256,10 @@ func runIteratorOnSession(it graph.Iterator, ses *Session) {
glog.V(2).Infoln(it.DebugString(0)) glog.V(2).Infoln(it.DebugString(0))
for { for {
// TODO(barakmich): Better halting. // TODO(barakmich): Better halting.
if ses.doHalt { select {
case <-ses.kill:
return return
default:
} }
_, ok := graph.Next(it) _, ok := graph.Next(it)
if !ok { if !ok {
@ -260,8 +272,10 @@ func runIteratorOnSession(it graph.Iterator, ses *Session) {
break break
} }
for it.NextResult() == true { for it.NextResult() == true {
if ses.doHalt { select {
case <-ses.kill:
return return
default:
} }
tags := make(map[string]graph.Value) tags := make(map[string]graph.Value)
it.TagResults(tags) it.TagResults(tags)

View file

@ -18,6 +18,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"sort" "sort"
"sync"
"time" "time"
"github.com/robertkrimen/otto" "github.com/robertkrimen/otto"
@ -30,6 +31,7 @@ type Session struct {
ts graph.TripleStore ts graph.TripleStore
currentChannel chan interface{} currentChannel chan interface{}
env *otto.Otto env *otto.Otto
envLock sync.Mutex
debug bool debug bool
limit int limit int
count int count int
@ -38,7 +40,7 @@ type Session struct {
queryShape map[string]interface{} queryShape map[string]interface{}
err error err error
script *otto.Script script *otto.Script
doHalt bool kill chan struct{}
timeoutSec time.Duration timeoutSec time.Duration
emptyEnv *otto.Otto emptyEnv *otto.Otto
} }
@ -95,8 +97,10 @@ func (s *Session) SendResult(result *GremlinResult) bool {
if s.limit >= 0 && s.limit == s.count { if s.limit >= 0 && s.limit == s.count {
return false return false
} }
if s.doHalt { select {
case <-s.kill:
return false return false
default:
} }
if s.currentChannel != nil { if s.currentChannel != nil {
s.currentChannel <- result s.currentChannel <- result
@ -113,7 +117,7 @@ func (s *Session) SendResult(result *GremlinResult) bool {
var halt = errors.New("Query Timeout") var halt = errors.New("Query Timeout")
func (s *Session) runUnsafe(input interface{}) (otto.Value, error) { func (s *Session) runUnsafe(input interface{}) (otto.Value, error) {
s.doHalt = false s.kill = make(chan struct{})
defer func() { defer func() {
if caught := recover(); caught != nil { if caught := recover(); caught != nil {
if caught == halt { if caught == halt {
@ -129,7 +133,9 @@ func (s *Session) runUnsafe(input interface{}) (otto.Value, error) {
if s.timeoutSec != -1 { if s.timeoutSec != -1 {
go func() { go func() {
time.Sleep(s.timeoutSec * time.Second) // Stop after two seconds time.Sleep(s.timeoutSec * time.Second) // Stop after two seconds
s.doHalt = true close(s.kill)
s.envLock.Lock()
defer s.envLock.Unlock()
if s.env != nil { if s.env != nil {
s.env.Interrupt <- func() { s.env.Interrupt <- func() {
panic(halt) panic(halt)
@ -139,6 +145,8 @@ func (s *Session) runUnsafe(input interface{}) (otto.Value, error) {
}() }()
} }
s.envLock.Lock()
defer s.envLock.Unlock()
return s.env.Run(input) // Here be dragons (risky code) return s.env.Run(input) // Here be dragons (risky code)
} }
@ -166,7 +174,9 @@ func (s *Session) ExecInput(input string, out chan interface{}, limit int) {
} }
s.currentChannel = nil s.currentChannel = nil
s.script = nil s.script = nil
s.envLock.Lock()
s.env = s.emptyEnv s.env = s.emptyEnv
s.envLock.Unlock()
return return
} }
@ -256,10 +266,12 @@ func (ses *Session) GetJson() ([]interface{}, error) {
if ses.err != nil { if ses.err != nil {
return nil, ses.err return nil, ses.err
} }
if ses.doHalt { select {
case <-ses.kill:
return nil, halt return nil, halt
default:
return ses.dataOutput, nil
} }
return ses.dataOutput, nil
} }
func (ses *Session) ClearJson() { func (ses *Session) ClearJson() {