"""This file includes the Worker class which sits on the client side. It implements the Ray API functions that are forwarded through grpc calls to the server. """ import inspect import logging from typing import List from typing import Tuple import cloudpickle import grpc import rpc.ray_client_pb2 as ray_client_pb2 import rpc.ray_client_pb2_grpc as ray_client_pb2_grpc from ray.common import convert_to_arg from ray.common import ClientObjectRef from ray.common import ClientActorClass from ray.common import ClientRemoteFunc logger = logging.getLogger(__name__) class Worker: def __init__(self, conn_str: str = "", secure: bool = False, metadata: List[Tuple[str, str]] = None, stub=None): """Initializes the worker side grpc client. Args: stub: custom grpc stub. secure: whether to use SSL secure channel or not. metadata: additional metadata passed in the grpc request headers. """ self.metadata = metadata if stub is None: if secure: credentials = grpc.ssl_channel_credentials() self.channel = grpc.secure_channel(conn_str, credentials) else: self.channel = grpc.insecure_channel(conn_str) self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) else: self.server = stub def get(self, ids): to_get = [] single = False if isinstance(ids, list): to_get = [x.id for x in ids] elif isinstance(ids, ClientObjectRef): to_get = [ids.id] single = True else: raise Exception("Can't get something that's not a " "list of IDs or just an ID: %s" % type(ids)) out = [self._get(x) for x in to_get] if single: out = out[0] return out def _get(self, id: bytes): req = ray_client_pb2.GetRequest(id=id) data = self.server.GetObject(req, metadata=self.metadata) if not data.valid: raise Exception( "Client GetObject returned invalid data: id invalid?") return cloudpickle.loads(data.data) def put(self, vals): to_put = [] single = False if isinstance(vals, list): to_put = vals else: single = True to_put.append(vals) out = [self._put(x) for x in to_put] if single: out = out[0] return out def _put(self, val): data = cloudpickle.dumps(val) req = ray_client_pb2.PutRequest(data=data) resp = self.server.PutObject(req, metadata=self.metadata) return ClientObjectRef(resp.id) def wait(self, object_refs: List[ClientObjectRef], *, num_returns: int = 1, timeout: float = None ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]: assert isinstance(object_refs, list) for ref in object_refs: assert isinstance(ref, ClientObjectRef) data = { "object_refs": [ cloudpickle.dumps(object_ref) for object_ref in object_refs ], "num_returns": num_returns, "timeout": timeout if timeout else -1 } req = ray_client_pb2.WaitRequest(**data) resp = self.server.WaitObject(req, metadata=self.metadata) if not resp.valid: # TODO(ameer): improve error/exceptions messages. raise Exception("Client Wait request failed. Reference invalid?") client_ready_object_ids = [ ClientObjectRef(id) for id in resp.ready_object_ids ] client_remaining_object_ids = [ ClientObjectRef(id) for id in resp.remaining_object_ids ] return (client_ready_object_ids, client_remaining_object_ids) def remote(self, function_or_class, *args, **kwargs): # TODO(barakmich): Arguments to ray.remote # get captured here. if inspect.isfunction(function_or_class): return ClientRemoteFunc(function_or_class) elif inspect.isclass(function_or_class): return ClientActorClass(function_or_class) else: raise TypeError("The @ray.remote decorator must be applied to " "either a function or to a class.") def call_remote(self, instance, kind, *args, **kwargs): task = instance._prepare_client_task() for arg in args: pb_arg = convert_to_arg(arg) task.args.append(pb_arg) logging.debug("Scheduling %s" % task) ticket = self.server.Schedule(task, metadata=self.metadata) return ClientObjectRef(ticket.return_id) def close(self): self.channel.close()