# from rpc import ray_client_pb2
# from rpc import ray_client_pb2_grpc
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import queue
import sys
import grpc
import cloudpickle
import threading
import concurrent.futures
from ray.experimental.client import ray
from ray.experimental.client.common import ClientObjectRef


class Worker:
    """Python worker that executes FUNCTION tasks received over a gRPC Workstream.

    Outgoing ``WorkStatus`` messages are pushed onto ``send_queue``, which feeds
    the request side of the bidirectional ``Workstream`` call; incoming work
    items each carry a ``ClientTask`` and a ticket that results are reported
    against.
    """

    def __init__(self, conn_str):
        """Open an insecure gRPC channel to ``conn_str`` and create the stubs."""
        self.channel = grpc.insecure_channel(conn_str)
        self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(
            self.channel)
        # Created in begin(); feeds the outgoing half of the Workstream.
        self.send_queue = None
        # NOTE(review): not used by this class -- tasks run on dedicated
        # threads in begin(). Kept so external references keep working.
        self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=3)
        self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)

    def begin(self):
        """Announce readiness, then process work items until the stream ends.

        FUNCTION tasks have their function and arguments resolved here and are
        then executed on a new thread each; any other task type, or a failure
        while resolving the task, is answered with an ERROR status for that
        ticket.
        """
        self.send_queue = queue.Queue()
        # The request iterator yields queued statuses until a None sentinel.
        work_stream = self.worker_stub.Workstream(
            iter(self.send_queue.get, None))
        start = ray_client_pb2.WorkStatus(
            status=ray_client_pb2.WorkStatus.StatusCode.READY,
            error_msg="python_worker",
        )
        self.send_queue.put(start)
        try:
            for work in work_stream:
                print("Got work")
                task = work.task
                if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION:
                    self._send_error("unimplemented", work.ticket)
                    continue
                try:
                    args = self.decode_args(task)
                    func = self.get(task.payload_id)
                except Exception as e:
                    # Report the failure for this ticket instead of letting
                    # one bad task tear down the whole work loop.
                    self._send_error(repr(e), work.ticket)
                    continue
                t = threading.Thread(
                    target=self.run_and_return,
                    args=(func, args, work.ticket),
                )
                t.start()
        finally:
            # Terminate the outgoing request iterator so the stream can close.
            self.send_queue.put(None)

    def _send_error(self, error_msg, ticket):
        """Queue an ERROR status carrying ``error_msg`` for ``ticket``."""
        self.send_queue.put(ray_client_pb2.WorkStatus(
            status=ray_client_pb2.WorkStatus.StatusCode.ERROR,
            error_msg=error_msg,
            finished_ticket=ticket,
        ))

    def run_and_return(self, func, args, ticket):
        """Run ``func(*args)`` and report the outcome for ``ticket``.

        On success a COMPLETE status with the cloudpickled result is queued.
        If the call (or pickling its result) raises, an ERROR status is queued
        instead so the requester is not left waiting on a dead thread.
        """
        try:
            res = func(*args)
            out_data = cloudpickle.dumps(res)
        except Exception as e:
            self._send_error(repr(e), ticket)
            return
        self.send_queue.put(ray_client_pb2.WorkStatus(
            status=ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
            complete_data=out_data,
            finished_ticket=ticket,
        ))

    def get(self, id_bytes):
        """Fetch the object with id ``id_bytes`` through the connected client API."""
        return ray.get(ClientObjectRef(id_bytes))

    def decode_args(self, task):
        """Return ``task.args`` converted to Python values, in order."""
        return [self.convert_from_arg(arg) for arg in task.args]

    def convert_from_arg(self, pb):
        """Convert one ``Arg`` message into a Python value.

        REFERENCE args are fetched by id; INTERNED args are unpickled from
        ``pb.data``.

        Raises:
            ValueError: if ``pb.local`` is not a handled locality.
        """
        if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
            return self.get(pb.reference_id)
        if pb.local == ray_client_pb2.Arg.Locality.INTERNED:
            # NOTE(review): cloudpickle.loads executes arbitrary code; the
            # server sending these bytes must be trusted.
            return cloudpickle.loads(pb.data)
        raise ValueError("convert_from_arg: Uncovered locality enum")


def main():
    """Entry point: connect to ``sys.argv[1]`` and serve work until done."""
    worker = Worker(sys.argv[1])
    ray.connect(sys.argv[1], stub=worker.server)
    try:
        worker.begin()
    finally:
        # Always release the channel, even if the work loop raised.
        print("Shutting down...")
        worker.channel.close()


if __name__ == "__main__":
    main()