diff --git a/query/gremlin/session.go b/query/gremlin/session.go index 9e329c2..d1c84b5 100644 --- a/query/gremlin/session.go +++ b/query/gremlin/session.go @@ -94,8 +94,11 @@ func (s *Session) SendResult(r *Result) bool { if s.limit >= 0 && s.limit == s.count { return false } + s.envLock.Lock() + kill := s.kill + s.envLock.Unlock() select { - case <-s.kill: + case <-kill: return false default: } @@ -112,7 +115,6 @@ func (s *Session) SendResult(r *Result) bool { } func (s *Session) runUnsafe(input interface{}) (otto.Value, error) { - s.kill = make(chan struct{}) defer func() { if r := recover(); r != nil { if r == ErrKillTimeout { @@ -126,25 +128,41 @@ func (s *Session) runUnsafe(input interface{}) (otto.Value, error) { // Use buffered chan to prevent blocking. s.env.Interrupt = make(chan func(), 1) + ready := make(chan struct{}) + done := make(chan struct{}) if s.timeout >= 0 { go func() { time.Sleep(s.timeout) - close(s.kill) - s.envLock.Lock() - defer s.envLock.Unlock() - if s.env != nil { - s.env.Interrupt <- func() { - panic(ErrKillTimeout) + <-ready + select { + case <-done: + return + default: + close(s.kill) + s.envLock.Lock() + defer s.envLock.Unlock() + s.kill = nil + if s.env != nil { + s.env.Interrupt <- func() { + panic(ErrKillTimeout) + } + s.env = s.emptyEnv } - s.env = s.emptyEnv + return } }() } s.envLock.Lock() env := s.env + if s.kill == nil { + s.kill = make(chan struct{}) + } s.envLock.Unlock() - return env.Run(input) + close(ready) + out, err := env.Run(input) + close(done) + return out, err } func (s *Session) ExecInput(input string, out chan interface{}, limit int) { @@ -255,8 +273,11 @@ func (s *Session) GetJson() ([]interface{}, error) { if s.err != nil { return nil, s.err } + s.envLock.Lock() + kill := s.kill + s.envLock.Unlock() select { - case <-s.kill: + case <-kill: return nil, ErrKillTimeout default: return s.dataOutput, nil