working worker

This commit is contained in:
Barak Michener 2020-12-04 20:01:46 +00:00
parent 46524832de
commit 51f92278ac
9 changed files with 194 additions and 61 deletions

View file

@ -1,22 +1,37 @@
from rpc import ray_client_pb2
from rpc import ray_client_pb2_grpc
# from rpc import ray_client_pb2
# from rpc import ray_client_pb2_grpc
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import queue
import sys
import grpc
import cloudpickle
import threading
import concurrent.futures
from ray.experimental.client import ray
from ray.experimental.client.common import ClientObjectRef
class Worker:
def __init__(self, conn_str):
self.channel = grpc.insecure_channel(conn_str)
self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel)
self.send_queue = None
self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=3)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
def begin(self):
send_queue = queue.SimpleQueue()
work_stream = self.worker_stub.Workstream(iter(send_queue.get, None))
self.send_queue = queue.Queue()
work_stream = self.worker_stub.Workstream(iter(self.send_queue.get, None))
start = ray_client_pb2.WorkStatus(
status = ray_client_pb2.WorkStatus.StatusCode.READY
status = ray_client_pb2.WorkStatus.StatusCode.READY,
error_msg = "python_worker",
)
send_queue.put(start)
self.send_queue.put(start)
for work in work_stream:
print("Got work")
task = work.task
if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION:
send_queue.put(ray_client_pb2.WorkStatus(
@ -25,25 +40,32 @@ class Worker:
))
continue
args = self.decode_args(task)
func_data = self.get(task.payload_id)
func = cloudpickle.loads(func_data)
res = func(*args)
out_data = cloudpickle.dumps(res)
send_queue.put(ray_client_pb2.WorkStatus(
status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
complete_data = out_data,
finished_ticket = work.ticket,
))
func = self.get(task.payload_id)
#self.pool.submit(self.run_and_return, func, args, work.ticket)
t = threading.Thread(target=self.run_and_return, args=(func, args, work.ticket))
t.start()
def get(self, id_bytes):
data = self.server.GetObject(ray_client_pb2.GetRequest(
id = id_bytes,
def run_and_return(self, func, args, ticket):
res = func(*args)
out_data = cloudpickle.dumps(res)
self.send_queue.put(ray_client_pb2.WorkStatus(
status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
complete_data = out_data,
finished_ticket = ticket,
))
return data.data
# def get(self, id_bytes):
# data = self.server.GetObject(ray_client_pb2.GetRequest(
# id = id_bytes,
# ))
# return cloudpickle.loads(data.data)
def get(self, id_bytes):
return ray.get(ClientObjectRef(id_bytes))
def decode_args(self, task):
out = []
for arg in arg_list:
for arg in task.args:
t = self.convert_from_arg(arg)
out.append(t)
return out
@ -57,7 +79,8 @@ class Worker:
def main():
worker = Worker(os.args[1])
worker = Worker(sys.argv[1])
ray.connect(sys.argv[1], stub=worker.server)
worker.begin()
print("Shutting down...")
worker.channel.close()