implement through workers

This commit is contained in:
Barak Michener 2020-12-04 04:36:37 +00:00
parent 39385bc8b2
commit 46524832de
13 changed files with 2213 additions and 91 deletions

1
python/rpc/mk_proto.sh Executable file
View file

@ -0,0 +1 @@
python -m grpc_tools.protoc -I../../ray_rpc --python_out=. --grpc_python_out=. ../../ray_rpc/ray_client.proto

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,226 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import ray_client_pb2 as ray__client__pb2
class RayletDriverStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.GetObject = channel.unary_unary(
'/ray.rpc.RayletDriver/GetObject',
request_serializer=ray__client__pb2.GetRequest.SerializeToString,
response_deserializer=ray__client__pb2.GetResponse.FromString,
)
self.PutObject = channel.unary_unary(
'/ray.rpc.RayletDriver/PutObject',
request_serializer=ray__client__pb2.PutRequest.SerializeToString,
response_deserializer=ray__client__pb2.PutResponse.FromString,
)
self.WaitObject = channel.unary_unary(
'/ray.rpc.RayletDriver/WaitObject',
request_serializer=ray__client__pb2.WaitRequest.SerializeToString,
response_deserializer=ray__client__pb2.WaitResponse.FromString,
)
self.Schedule = channel.unary_unary(
'/ray.rpc.RayletDriver/Schedule',
request_serializer=ray__client__pb2.ClientTask.SerializeToString,
response_deserializer=ray__client__pb2.ClientTaskTicket.FromString,
)
class RayletDriverServicer(object):
"""Missing associated documentation comment in .proto file."""
def GetObject(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def PutObject(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def WaitObject(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Schedule(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_RayletDriverServicer_to_server(servicer, server):
rpc_method_handlers = {
'GetObject': grpc.unary_unary_rpc_method_handler(
servicer.GetObject,
request_deserializer=ray__client__pb2.GetRequest.FromString,
response_serializer=ray__client__pb2.GetResponse.SerializeToString,
),
'PutObject': grpc.unary_unary_rpc_method_handler(
servicer.PutObject,
request_deserializer=ray__client__pb2.PutRequest.FromString,
response_serializer=ray__client__pb2.PutResponse.SerializeToString,
),
'WaitObject': grpc.unary_unary_rpc_method_handler(
servicer.WaitObject,
request_deserializer=ray__client__pb2.WaitRequest.FromString,
response_serializer=ray__client__pb2.WaitResponse.SerializeToString,
),
'Schedule': grpc.unary_unary_rpc_method_handler(
servicer.Schedule,
request_deserializer=ray__client__pb2.ClientTask.FromString,
response_serializer=ray__client__pb2.ClientTaskTicket.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'ray.rpc.RayletDriver', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class RayletDriver(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def GetObject(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/ray.rpc.RayletDriver/GetObject',
ray__client__pb2.GetRequest.SerializeToString,
ray__client__pb2.GetResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def PutObject(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/ray.rpc.RayletDriver/PutObject',
ray__client__pb2.PutRequest.SerializeToString,
ray__client__pb2.PutResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def WaitObject(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/ray.rpc.RayletDriver/WaitObject',
ray__client__pb2.WaitRequest.SerializeToString,
ray__client__pb2.WaitResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Schedule(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/ray.rpc.RayletDriver/Schedule',
ray__client__pb2.ClientTask.SerializeToString,
ray__client__pb2.ClientTaskTicket.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
class RayletWorkerConnectionStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Workstream = channel.stream_stream(
'/ray.rpc.RayletWorkerConnection/Workstream',
request_serializer=ray__client__pb2.WorkStatus.SerializeToString,
response_deserializer=ray__client__pb2.Work.FromString,
)
class RayletWorkerConnectionServicer(object):
"""Missing associated documentation comment in .proto file."""
def Workstream(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_RayletWorkerConnectionServicer_to_server(servicer, server):
rpc_method_handlers = {
'Workstream': grpc.stream_stream_rpc_method_handler(
servicer.Workstream,
request_deserializer=ray__client__pb2.WorkStatus.FromString,
response_serializer=ray__client__pb2.Work.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'ray.rpc.RayletWorkerConnection', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class RayletWorkerConnection(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def Workstream(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/ray.rpc.RayletWorkerConnection/Workstream',
ray__client__pb2.WorkStatus.SerializeToString,
ray__client__pb2.Work.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

67
python/worker.py Normal file
View file

@ -0,0 +1,67 @@
from rpc import ray_client_pb2
from rpc import ray_client_pb2_grpc
import queue
class Worker:
def __init__(self, conn_str):
self.channel = grpc.insecure_channel(conn_str)
self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
def begin(self):
send_queue = queue.SimpleQueue()
work_stream = self.worker_stub.Workstream(iter(send_queue.get, None))
start = ray_client_pb2.WorkStatus(
status = ray_client_pb2.WorkStatus.StatusCode.READY
)
send_queue.put(start)
for work in work_stream:
task = work.task
if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION:
send_queue.put(ray_client_pb2.WorkStatus(
status = ray_client_pb2.WorkStatus.StatusCode.ERROR,
error_msg = "unimplemented",
))
continue
args = self.decode_args(task)
func_data = self.get(task.payload_id)
func = cloudpickle.loads(func_data)
res = func(*args)
out_data = cloudpickle.dumps(res)
send_queue.put(ray_client_pb2.WorkStatus(
status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
complete_data = out_data,
finished_ticket = work.ticket,
))
def get(self, id_bytes):
data = self.server.GetObject(ray_client_pb2.GetRequest(
id = id_bytes,
))
return data.data
def decode_args(self, task):
out = []
for arg in arg_list:
t = self.convert_from_arg(arg)
out.append(t)
return out
def convert_from_arg(self, pb):
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
return self.get(pb.reference_id)
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
return cloudpickle.loads(pb.data)
raise Exception("convert_from_arg: Uncovered locality enum")
def main():
worker = Worker(os.args[1])
worker.begin()
print("Shutting down...")
worker.channel.close()
if __name__ == "__main__":
main()