diff --git a/python/client.py b/python/client.py index e80cb46..933aff1 100644 --- a/python/client.py +++ b/python/client.py @@ -1,6 +1,7 @@ from ray import ray +import sys -ray.connect("localhost:50050") +ray.connect(sys.argv[1]) @ray.remote def plus2(x): @@ -14,4 +15,17 @@ def fact(x): return 1 return x * ray.get(fact.remote(x - 1)) -print(ray.get(fact.remote(20))) +#print(ray.get(fact.remote(20))) + +@ray.remote +def sleeper(x): + import time + time.sleep(1) + return x * 2 + +holder = [] +for i in range(20): + holder.append(sleeper.remote(i)) + +print([ray.get(x) for x in holder]) + diff --git a/python/ray_worker/__init__.py b/python/ray_worker/__init__.py index f751afc..180e2b5 100644 --- a/python/ray_worker/__init__.py +++ b/python/ray_worker/__init__.py @@ -10,6 +10,7 @@ from ray import ray from ray.common import ClientObjectRef + class Worker: def __init__(self, conn_str): self.channel = grpc.insecure_channel(conn_str) @@ -37,9 +38,9 @@ class Worker: continue args = self.decode_args(task) func = self.get(task.payload_id) - #self.pool.submit(self.run_and_return, func, args, work.ticket) - t = threading.Thread(target=self.run_and_return, args=(func, args, work.ticket)) - t.start() + self.pool.submit(self.run_and_return, func, args, work.ticket) + #t = threading.Thread(target=self.run_and_return, args=(func, args, work.ticket)) + #t.start() def run_and_return(self, func, args, ticket): @@ -50,6 +51,7 @@ class Worker: complete_data = out_data, finished_ticket = ticket, )) + print("Finished Work") # def get(self, id_bytes): # data = self.server.GetObject(ray_client_pb2.GetRequest( diff --git a/raylet_grpc.go b/raylet_grpc.go index a4e0b20..95f39a7 100644 --- a/raylet_grpc.go +++ b/raylet_grpc.go @@ -59,6 +59,7 @@ func (r *Raylet) Workstream(conn WorkstreamConnection) error { workChan: make(chan *ray_rpc.Work), clientConn: conn, pool: r.Workers, + max: 3, } r.Workers.Register(worker) err := worker.Run() diff --git a/web/raylet.js b/web/raylet.js index f36a5de..cd9dfe1 100644 --- a/web/raylet.js +++ b/web/raylet.js @@ -17,25 +17,40 @@ languagePluginLoader.then(() => { wsprotocol = "wss:" } var wspath = wsprotocol + "//" + window.location.host + "/api/ws" - var c = new WebSocket(wspath) - c.onmessage = function(msg) { - var workText = workTerms[Math.floor(Math.random() * workTerms.length)]; - $("#output").append("
" + workText + "...
") - pyodide.globals.torun = msg.data - pyodide.runPythonAsync("exec_work(torun)").then((res) => { - $("#output").append("Did work! 👏
") - c.send(res) - }) - } - c.onopen = function() { - $("#status").text("Status: connected!") - c.send(JSON.stringify({ - status: 2, - error_msg: "WebsocketWorker" - })) - } - c.onclose = function() { - $("#status").text("Status: disconnected") - } -}) -}) + function connect() { + var c = new WebSocket(wspath) + + c.onopen = function() { + $("#status").text("Status: connected!") + c.send(JSON.stringify({ + status: 2, + error_msg: "WebsocketWorker" + })) + } + + c.onmessage = function(msg) { + var workText = workTerms[Math.floor(Math.random() * workTerms.length)]; + $("#output").append("" + workText + "...
") + pyodide.globals.torun = msg.data + pyodide.runPythonAsync("exec_work(torun)").then((res) => { + $("#output").append("Did work! 👏
") + c.send(res) + }) + }; + + c.onclose = function(e) { + $("#status").text("Status: disconnected. reconnecting...") + console.log('Socket is closed. Reconnect will be attempted in 1 second.', e.reason); + setTimeout(function() { + connect(); + }, 500); + }; + + c.onerror = function(err) { + console.error('Socket encountered error: ', err.message, 'Closing socket'); + c.close(); + }; + }; + + connect(); +}) }) diff --git a/worker.go b/worker.go index 6c90d5f..fc35fb6 100644 --- a/worker.go +++ b/worker.go @@ -20,10 +20,12 @@ type SimpleWorker struct { workChan chan *ray_rpc.Work clientConn WorkstreamConnection pool WorkerPool + max int + curr int } func (s *SimpleWorker) Schedulable() bool { - return true + return s.curr < s.max } func (s *SimpleWorker) AssignWork(work *ray_rpc.Work) error { @@ -48,6 +50,7 @@ func (w *SimpleWorker) Run() error { go func() { for work := range w.workChan { zap.S().Debug("Sending work") + w.curr++ err = w.clientConn.Send(work) if err != nil { zap.S().Error("Error sending:", err) @@ -61,6 +64,7 @@ func (w *SimpleWorker) Run() error { zap.S().Error("Error on channel:", err) return err } + w.curr-- err = w.pool.Finish(result) if err != nil { zap.S().Error("Error finishing:", err) diff --git a/worker_pool.go b/worker_pool.go index 71cb137..816545b 100644 --- a/worker_pool.go +++ b/worker_pool.go @@ -21,6 +21,7 @@ type SimpleRRWorkerPool struct { workers []Worker store ObjectStore offset int + pending []chan bool } func NewRoundRobinWorkerPool(obj ObjectStore) *SimpleRRWorkerPool { @@ -57,16 +58,27 @@ func (wp *SimpleRRWorkerPool) Schedule(work *ray_rpc.Work) error { wp.offset = 0 } if wp.offset == origOffset && !done { - return errors.New("No workers schedulable") + c := make(chan bool) + wp.pending = append(wp.pending, c) + wp.Unlock() + <-c + wp.Lock() } } return nil } func (wp *SimpleRRWorkerPool) Finish(status *ray_rpc.WorkStatus) error { + wp.Lock() + defer wp.Unlock() if status.Status != ray_rpc.COMPLETE { panic("todo: Only call Finish on successfully completed work") } + if len(wp.pending) != 0 { + c := wp.pending[0] + wp.pending = wp.pending[1:] + close(c) + } id := deserializeObjectID(status.FinishedTicket.ReturnId) return wp.store.PutObject(&Object{id, status.CompleteData}) }