141 lines
2.9 KiB
Go
141 lines
2.9 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/barakmich/go_raylet/ray_rpc"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// WorkstreamConnection is a shorter local name for
// ray_rpc.RayletWorkerConnection_WorkstreamServer. In this file values of this
// type are used to Recv status messages from a worker and to Send work to it.
type WorkstreamConnection = ray_rpc.RayletWorkerConnection_WorkstreamServer
|
|
|
|
type WorkerPool interface {
|
|
ray_rpc.RayletWorkerConnectionServer
|
|
Schedule(*ray_rpc.Work) error
|
|
Close() error
|
|
Finish(*ray_rpc.WorkStatus) error
|
|
Deregister(interface{}) error
|
|
}
|
|
|
|
type SimpleRRWorkerPool struct {
|
|
sync.Mutex
|
|
workers []*SimpleWorker
|
|
store ObjectStore
|
|
offset int
|
|
}
|
|
|
|
type SimpleWorker struct {
|
|
workChan chan *ray_rpc.Work
|
|
clientConn WorkstreamConnection
|
|
pool WorkerPool
|
|
}
|
|
|
|
func NewRoundRobinWorkerPool(obj ObjectStore) *SimpleRRWorkerPool {
|
|
return &SimpleRRWorkerPool{
|
|
store: obj,
|
|
}
|
|
}
|
|
|
|
func (wp *SimpleRRWorkerPool) Workstream(conn WorkstreamConnection) error {
|
|
wp.Lock()
|
|
worker := &SimpleWorker{
|
|
workChan: make(chan *ray_rpc.Work),
|
|
clientConn: conn,
|
|
pool: wp,
|
|
}
|
|
wp.workers = append(wp.workers, worker)
|
|
wp.Unlock()
|
|
err := worker.Main()
|
|
wp.Deregister(worker)
|
|
return err
|
|
}
|
|
|
|
func (wp *SimpleRRWorkerPool) Schedule(work *ray_rpc.Work) error {
|
|
wp.Lock()
|
|
defer wp.Unlock()
|
|
if len(wp.workers) == 0 {
|
|
return errors.New("No workers available, try again later")
|
|
}
|
|
zap.S().Info("Sending work to worker", wp.offset)
|
|
wp.workers[wp.offset].workChan <- work
|
|
wp.offset++
|
|
if wp.offset == len(wp.workers) {
|
|
wp.offset = 0
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (wp *SimpleRRWorkerPool) Finish(status *ray_rpc.WorkStatus) error {
|
|
if status.Status != ray_rpc.COMPLETE {
|
|
panic("todo: Only call Finish on successfully completed work")
|
|
}
|
|
id := deserializeObjectID(status.FinishedTicket.ReturnId)
|
|
return wp.store.PutObject(&Object{id, status.CompleteData})
|
|
}
|
|
|
|
func (wp *SimpleRRWorkerPool) Close() error {
|
|
wp.Lock()
|
|
defer wp.Unlock()
|
|
for _, w := range wp.workers {
|
|
close(w.workChan)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (wp *SimpleRRWorkerPool) Deregister(ptr interface{}) error {
|
|
wp.Lock()
|
|
defer wp.Unlock()
|
|
fmt.Println("Deregistering worker")
|
|
worker := ptr.(*SimpleWorker)
|
|
found := false
|
|
for i, w := range wp.workers {
|
|
if w == worker {
|
|
wp.workers = append(wp.workers[:i], wp.workers[i+1:]...)
|
|
if wp.offset == len(wp.workers) {
|
|
wp.offset = 0
|
|
}
|
|
close(worker.workChan)
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
panic("Trying to deregister a worker that was never created")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (w *SimpleWorker) Main() error {
|
|
sentinel, err := w.clientConn.Recv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if sentinel.Status != ray_rpc.READY {
|
|
return errors.New("Sent wrong sentinel? Closing...")
|
|
}
|
|
fmt.Println("New worker:", sentinel.ErrorMsg)
|
|
go func() {
|
|
for work := range w.workChan {
|
|
fmt.Println("sending work")
|
|
err = w.clientConn.Send(work)
|
|
if err != nil {
|
|
fmt.Println("Error sending:", err)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
for {
|
|
result, err := w.clientConn.Recv()
|
|
if err != nil {
|
|
fmt.Println("Error on channel:", err)
|
|
return err
|
|
}
|
|
err = w.pool.Finish(result)
|
|
if err != nil {
|
|
fmt.Println("Error finishing:", err)
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|