From fe9ca5ffcc17f3c0737069da089182a853ad793f Mon Sep 17 00:00:00 2001 From: kortschak Date: Sat, 2 Aug 2014 22:01:02 +0930 Subject: [PATCH] Fix data race in gremlin timeout handling Fixes issue #95. --- query/gremlin/finals.go | 28 +++++++++++++++++++++------- query/gremlin/session.go | 24 ++++++++++++++++++------ 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/query/gremlin/finals.go b/query/gremlin/finals.go index 022a394..a16e661 100644 --- a/query/gremlin/finals.go +++ b/query/gremlin/finals.go @@ -148,8 +148,10 @@ func runIteratorToArray(it graph.Iterator, ses *Session, limit int) []map[string count := 0 it, _ = it.Optimize() for { - if ses.doHalt { + select { + case <-ses.kill: return nil + default: } _, ok := graph.Next(it) if !ok { @@ -163,8 +165,10 @@ func runIteratorToArray(it graph.Iterator, ses *Session, limit int) []map[string break } for it.NextResult() == true { - if ses.doHalt { + select { + case <-ses.kill: return nil + default: } tags := make(map[string]graph.Value) it.TagResults(tags) @@ -184,8 +188,10 @@ func runIteratorToArrayNoTags(it graph.Iterator, ses *Session, limit int) []stri count := 0 it, _ = it.Optimize() for { - if ses.doHalt { + select { + case <-ses.kill: return nil + default: } val, ok := graph.Next(it) if !ok { @@ -205,8 +211,10 @@ func runIteratorWithCallback(it graph.Iterator, ses *Session, callback otto.Valu count := 0 it, _ = it.Optimize() for { - if ses.doHalt { + select { + case <-ses.kill: return + default: } _, ok := graph.Next(it) if !ok { @@ -221,8 +229,10 @@ func runIteratorWithCallback(it graph.Iterator, ses *Session, callback otto.Valu break } for it.NextResult() == true { - if ses.doHalt { + select { + case <-ses.kill: return + default: } tags := make(map[string]graph.Value) it.TagResults(tags) @@ -246,8 +256,10 @@ func runIteratorOnSession(it graph.Iterator, ses *Session) { glog.V(2).Infoln(it.DebugString(0)) for { // TODO(barakmich): Better halting. - if ses.doHalt { + select { + case <-ses.kill: return + default: } _, ok := graph.Next(it) if !ok { @@ -260,8 +272,10 @@ func runIteratorOnSession(it graph.Iterator, ses *Session) { break } for it.NextResult() == true { - if ses.doHalt { + select { + case <-ses.kill: return + default: } tags := make(map[string]graph.Value) it.TagResults(tags) diff --git a/query/gremlin/session.go b/query/gremlin/session.go index 00532b8..e100b43 100644 --- a/query/gremlin/session.go +++ b/query/gremlin/session.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "sort" + "sync" "time" "github.com/robertkrimen/otto" @@ -30,6 +31,7 @@ type Session struct { ts graph.TripleStore currentChannel chan interface{} env *otto.Otto + envLock sync.Mutex debug bool limit int count int @@ -38,7 +40,7 @@ type Session struct { queryShape map[string]interface{} err error script *otto.Script - doHalt bool + kill chan struct{} timeoutSec time.Duration emptyEnv *otto.Otto } @@ -95,8 +97,10 @@ func (s *Session) SendResult(result *GremlinResult) bool { if s.limit >= 0 && s.limit == s.count { return false } - if s.doHalt { + select { + case <-s.kill: return false + default: } if s.currentChannel != nil { s.currentChannel <- result @@ -113,7 +117,7 @@ func (s *Session) SendResult(result *GremlinResult) bool { var halt = errors.New("Query Timeout") func (s *Session) runUnsafe(input interface{}) (otto.Value, error) { - s.doHalt = false + s.kill = make(chan struct{}) defer func() { if caught := recover(); caught != nil { if caught == halt { @@ -129,7 +133,9 @@ func (s *Session) runUnsafe(input interface{}) (otto.Value, error) { if s.timeoutSec != -1 { go func() { time.Sleep(s.timeoutSec * time.Second) // Stop after two seconds - s.doHalt = true + close(s.kill) + s.envLock.Lock() + defer s.envLock.Unlock() if s.env != nil { s.env.Interrupt <- func() { panic(halt) @@ -139,6 +145,8 @@ func (s *Session) runUnsafe(input interface{}) (otto.Value, error) { }() } + s.envLock.Lock() + defer s.envLock.Unlock() return s.env.Run(input) // Here be dragons (risky code) } @@ -166,7 +174,9 @@ func (s *Session) ExecInput(input string, out chan interface{}, limit int) { } s.currentChannel = nil s.script = nil + s.envLock.Lock() s.env = s.emptyEnv + s.envLock.Unlock() return } @@ -256,10 +266,12 @@ func (ses *Session) GetJson() ([]interface{}, error) { if ses.err != nil { return nil, ses.err } - if ses.doHalt { + select { + case <-ses.kill: return nil, halt + default: + return ses.dataOutput, nil } - return ses.dataOutput, nil } func (ses *Session) ClearJson() {