hermetically seal everything, bring it up to date
This commit is contained in:
parent
efd3d65269
commit
03beb6970c
12 changed files with 723 additions and 77 deletions
74
python/ray_worker/__init__.py
Normal file
74
python/ray_worker/__init__.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
import rpc.ray_client_pb2 as ray_client_pb2
|
||||
import rpc.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
import queue
|
||||
import grpc
|
||||
import cloudpickle
|
||||
import threading
|
||||
import concurrent.futures
|
||||
from ray import ray
|
||||
from ray.common import ClientObjectRef
|
||||
|
||||
|
||||
class Worker:
|
||||
def __init__(self, conn_str):
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel)
|
||||
self.send_queue = None
|
||||
self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
|
||||
def begin(self):
|
||||
self.send_queue = queue.Queue()
|
||||
work_stream = self.worker_stub.Workstream(iter(self.send_queue.get, None))
|
||||
start = ray_client_pb2.WorkStatus(
|
||||
status = ray_client_pb2.WorkStatus.StatusCode.READY,
|
||||
error_msg = "python_worker",
|
||||
)
|
||||
self.send_queue.put(start)
|
||||
for work in work_stream:
|
||||
print("Got work")
|
||||
task = work.task
|
||||
if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION:
|
||||
send_queue.put(ray_client_pb2.WorkStatus(
|
||||
status = ray_client_pb2.WorkStatus.StatusCode.ERROR,
|
||||
error_msg = "unimplemented",
|
||||
))
|
||||
continue
|
||||
args = self.decode_args(task)
|
||||
func = self.get(task.payload_id)
|
||||
#self.pool.submit(self.run_and_return, func, args, work.ticket)
|
||||
t = threading.Thread(target=self.run_and_return, args=(func, args, work.ticket))
|
||||
t.start()
|
||||
|
||||
|
||||
def run_and_return(self, func, args, ticket):
|
||||
res = func(*args)
|
||||
out_data = cloudpickle.dumps(res)
|
||||
self.send_queue.put(ray_client_pb2.WorkStatus(
|
||||
status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
|
||||
complete_data = out_data,
|
||||
finished_ticket = ticket,
|
||||
))
|
||||
|
||||
# def get(self, id_bytes):
|
||||
# data = self.server.GetObject(ray_client_pb2.GetRequest(
|
||||
# id = id_bytes,
|
||||
# ))
|
||||
# return cloudpickle.loads(data.data)
|
||||
def get(self, id_bytes):
|
||||
return ray.get(ClientObjectRef(id_bytes))
|
||||
|
||||
def decode_args(self, task):
|
||||
out = []
|
||||
for arg in task.args:
|
||||
t = self.convert_from_arg(arg)
|
||||
out.append(t)
|
||||
return out
|
||||
|
||||
def convert_from_arg(self, pb):
|
||||
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
|
||||
return self.get(pb.reference_id)
|
||||
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
|
||||
return cloudpickle.loads(pb.data)
|
||||
raise Exception("convert_from_arg: Uncovered locality enum")
|
||||
Loading…
Add table
Add a link
Reference in a new issue