"""Ray client worker: connects to a server over gRPC and executes FUNCTION tasks."""
import queue
import sys

import cloudpickle
import grpc

from rpc import ray_client_pb2
from rpc import ray_client_pb2_grpc

class Worker:
    """Worker that pulls tasks from a server over gRPC and runs them.

    The worker opens the bidirectional ``Workstream`` RPC. It sends
    ``WorkStatus`` messages (readiness, results, errors) and receives work
    items, each carrying a ``ClientTask``. Only tasks whose type is
    ``FUNCTION`` are executed; any other task type is answered with an
    ``ERROR`` status.
    """

    def __init__(self, conn_str):
        """Open an insecure gRPC channel to ``conn_str`` and build the stubs.

        Args:
            conn_str: target passed to ``grpc.insecure_channel``
                (presumably ``host:port`` -- confirm against callers).
        """
        self.channel = grpc.insecure_channel(conn_str)
        # Stub for the work stream: work items in, WorkStatus messages out.
        self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel)
        # Stub used by ``get`` to fetch object bytes (function payloads,
        # referenced arguments).
        self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)

    def begin(self):
        """Announce readiness, then process work items until the stream ends.

        For each received work item:
          * a non-FUNCTION task is answered with an ``ERROR`` status whose
            message is ``"unimplemented"``;
          * a FUNCTION task has its arguments decoded, its pickled function
            fetched with ``get`` and unpickled, and the function called. The
            pickled result is sent back in a ``COMPLETE`` status tagged with
            the work item's ticket.

        Exceptions raised by a task function are not caught here. They
        propagate out of this method and end the loop.
        """
        send_queue = queue.SimpleQueue()
        try:
            # gRPC consumes this iterator to produce outgoing messages. It
            # stops (closing our side of the stream) when it sees ``None``.
            work_stream = self.worker_stub.Workstream(iter(send_queue.get, None))
            send_queue.put(ray_client_pb2.WorkStatus(
                status=ray_client_pb2.WorkStatus.StatusCode.READY,
            ))
            for work in work_stream:
                task = work.task
                if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION:
                    send_queue.put(ray_client_pb2.WorkStatus(
                        status=ray_client_pb2.WorkStatus.StatusCode.ERROR,
                        error_msg="unimplemented",
                    ))
                    continue
                args = self.decode_args(task)
                # NOTE(security): cloudpickle.loads can execute arbitrary code;
                # this trusts whatever payload the server sends.
                func = cloudpickle.loads(self.get(task.payload_id))
                res = func(*args)
                send_queue.put(ray_client_pb2.WorkStatus(
                    status=ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
                    complete_data=cloudpickle.dumps(res),
                    finished_ticket=work.ticket,
                ))
        finally:
            # Without the sentinel, the request iterator blocks forever on an
            # empty queue and the outgoing stream is never half-closed.
            send_queue.put(None)

    def get(self, id_bytes):
        """Fetch an object's serialized bytes from the server.

        Args:
            id_bytes: object id, sent as ``GetRequest.id``.

        Returns:
            The ``data`` field of the ``GetObject`` response. The bytes are
            returned still pickled; callers unpickle them.
        """
        response = self.server.GetObject(ray_client_pb2.GetRequest(id=id_bytes))
        return response.data

    def decode_args(self, task):
        """Return ``task``'s arguments as Python values, in call order."""
        return [self.convert_from_arg(arg) for arg in task.args]

    def convert_from_arg(self, pb):
        """Convert one ``Arg`` message into a Python value.

        Args:
            pb: an ``Arg`` message. ``REFERENCE`` args name an object to fetch
                from the server; ``INTERNED`` args carry pickled bytes inline.

        Returns:
            The unpickled argument value.

        Raises:
            Exception: if ``pb.local`` is neither REFERENCE nor INTERNED.
        """
        if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
            # ``get`` returns pickled bytes (``begin`` unpickles its result
            # too), so unpickle here to match the INTERNED branch.
            return cloudpickle.loads(self.get(pb.reference_id))
        if pb.local == ray_client_pb2.Arg.Locality.INTERNED:
            return cloudpickle.loads(pb.data)
        raise Exception("convert_from_arg: Uncovered locality enum")

def main():
    """Run a worker against the server address given as the first CLI argument.

    Usage: ``python <this script> <server address>``. If no address is
    given, exits with a usage message. The gRPC channel is closed however
    the worker loop ends.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <server address>")
    worker = Worker(sys.argv[1])
    try:
        worker.begin()
    finally:
        print("Shutting down...")
        worker.channel.close()


if __name__ == "__main__":
    main()