package main import ( "errors" "fmt" "sync" "github.com/barakmich/go_raylet/ray_rpc" "go.uber.org/zap" ) type WorkstreamConnection = ray_rpc.RayletWorkerConnection_WorkstreamServer type WorkerPool interface { ray_rpc.RayletWorkerConnectionServer Schedule(*ray_rpc.Work) error Close() error Finish(*ray_rpc.WorkStatus) error Deregister(interface{}) error } type SimpleRRWorkerPool struct { sync.Mutex workers []*SimpleWorker store ObjectStore offset int } type SimpleWorker struct { workChan chan *ray_rpc.Work clientConn WorkstreamConnection pool WorkerPool } func NewRoundRobinWorkerPool(obj ObjectStore) *SimpleRRWorkerPool { return &SimpleRRWorkerPool{ store: obj, } } func (wp *SimpleRRWorkerPool) Workstream(conn WorkstreamConnection) error { wp.Lock() worker := &SimpleWorker{ workChan: make(chan *ray_rpc.Work), clientConn: conn, pool: wp, } wp.workers = append(wp.workers, worker) wp.Unlock() err := worker.Main() wp.Deregister(worker) return err } func (wp *SimpleRRWorkerPool) Schedule(work *ray_rpc.Work) error { wp.Lock() defer wp.Unlock() if len(wp.workers) == 0 { return errors.New("No workers available, try again later") } zap.S().Info("Sending work to worker", wp.offset) wp.workers[wp.offset].workChan <- work wp.offset++ if wp.offset == len(wp.workers) { wp.offset = 0 } return nil } func (wp *SimpleRRWorkerPool) Finish(status *ray_rpc.WorkStatus) error { if status.Status != ray_rpc.COMPLETE { panic("todo: Only call Finish on successfully completed work") } id := deserializeObjectID(status.FinishedTicket.ReturnId) return wp.store.PutObject(&Object{id, status.CompleteData}) } func (wp *SimpleRRWorkerPool) Close() error { wp.Lock() defer wp.Unlock() for _, w := range wp.workers { close(w.workChan) } return nil } func (wp *SimpleRRWorkerPool) Deregister(ptr interface{}) error { wp.Lock() defer wp.Unlock() fmt.Println("Deregistering worker") worker := ptr.(*SimpleWorker) found := false for i, w := range wp.workers { if w == worker { wp.workers = append(wp.workers[:i], wp.workers[i+1:]...) if wp.offset == len(wp.workers) { wp.offset = 0 } close(worker.workChan) found = true } } if !found { panic("Trying to deregister a worker that was never created") } return nil } func (w *SimpleWorker) Main() error { sentinel, err := w.clientConn.Recv() if err != nil { return err } if sentinel.Status != ray_rpc.READY { return errors.New("Sent wrong sentinel? Closing...") } fmt.Println("New worker:", sentinel.ErrorMsg) go func() { for work := range w.workChan { fmt.Println("sending work") err = w.clientConn.Send(work) if err != nil { fmt.Println("Error sending:", err) return } } }() for { result, err := w.clientConn.Recv() if err != nil { fmt.Println("Error on channel:", err) return err } err = w.pool.Finish(result) if err != nil { fmt.Println("Error finishing:", err) return err } } return nil }