From 03beb6970c2117c00d36f27427f77a76b3dcc519 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Fri, 4 Dec 2020 13:00:54 -0800 Subject: [PATCH] hermetically seal everything, bring it up to date --- python/client.py | 17 ++++ python/ray/__init__.py | 107 ++++++++++++++++++++++ python/ray/api.py | 77 ++++++++++++++++ python/ray/client_app.py | 90 +++++++++++++++++++ python/ray/common.py | 205 ++++++++++++++++++++++++++++++++++++++++++ python/ray/worker.py | 142 +++++++++++++++++++++++++++++ python/ray_worker/__init__.py | 74 +++++++++++++++ python/rpc/mk_proto.sh | 6 ++ python/worker.py | 77 +--------------- raylet_grpc.go | 2 +- worker.go | 2 +- worker_pool.go | 1 + 12 files changed, 723 insertions(+), 77 deletions(-) create mode 100644 python/client.py create mode 100644 python/ray/__init__.py create mode 100644 python/ray/api.py create mode 100644 python/ray/client_app.py create mode 100644 python/ray/common.py create mode 100644 python/ray/worker.py create mode 100644 python/ray_worker/__init__.py diff --git a/python/client.py b/python/client.py new file mode 100644 index 0000000..e80cb46 --- /dev/null +++ b/python/client.py @@ -0,0 +1,17 @@ +from ray import ray + +ray.connect("localhost:50050") + +@ray.remote +def plus2(x): + return x + 2 + +print(ray.get(plus2.remote(4))) + +@ray.remote +def fact(x): + if x <= 0: + return 1 + return x * ray.get(fact.remote(x - 1)) + +print(ray.get(fact.remote(20))) diff --git a/python/ray/__init__.py b/python/ray/__init__.py new file mode 100644 index 0000000..f088770 --- /dev/null +++ b/python/ray/__init__.py @@ -0,0 +1,107 @@ +from ray.api import ClientAPI +from ray.api import APIImpl +from typing import Optional, List, Tuple +from contextlib import contextmanager + +import logging + +logger = logging.getLogger(__name__) + +# _client_api has to be external to the API stub, below. +# Otherwise, ray.remote() that contains ray.remote() +# contains a reference to the RayAPIStub, therefore a +# reference to the _client_api, and then tries to pickle +# the thing. +_client_api: Optional[APIImpl] = None + +_server_api: Optional[APIImpl] = None + +_is_server: bool = False + + +@contextmanager +def stash_api_for_tests(in_test: bool): + global _is_server + is_server = _is_server + if in_test: + _is_server = True + yield _server_api + if in_test: + _is_server = is_server + + +def _set_client_api(val: Optional[APIImpl]): + global _client_api + global _is_server + if _client_api is not None: + raise Exception("Trying to set more than one client API") + _client_api = val + _is_server = False + + +def _set_server_api(val: Optional[APIImpl]): + global _server_api + global _is_server + if _server_api is not None: + raise Exception("Trying to set more than one server API") + _server_api = val + _is_server = True + + +def reset_api(): + global _client_api + global _server_api + global _is_server + _client_api = None + _server_api = None + _is_server = False + + +def _get_client_api() -> APIImpl: + global _client_api + global _server_api + global _is_server + api = None + if _is_server: + api = _server_api + else: + api = _client_api + if api is None: + # We're inside a raylet worker + raise Exception("CoreRayAPI not supported") + return api + + +class RayAPIStub: + def connect(self, + conn_str: str, + secure: bool = False, + metadata: List[Tuple[str, str]] = None, + stub=None): + from ray.worker import Worker + _client_worker = Worker( + conn_str, secure=secure, metadata=metadata, stub=stub) + _set_client_api(ClientAPI(_client_worker)) + + def disconnect(self): + global _client_api + if _client_api is not None: + _client_api.close() + _client_api = None + + def __getattr__(self, key: str): + global _get_client_api + api = _get_client_api() + return getattr(api, key) + + +ray = RayAPIStub() + +# Someday we might add methods in this module so that someone who +# tries to `import ray_client as ray` -- as a module, instead of +# `from ray_client import ray` -- as the API stub +# still gets expected functionality. This is the way the ray package +# worked in the past. +# +# This really calls for PEP 562: https://www.python.org/dev/peps/pep-0562/ +# But until Python 3.6 is EOL, here we are. diff --git a/python/ray/api.py b/python/ray/api.py new file mode 100644 index 0000000..d5ade03 --- /dev/null +++ b/python/ray/api.py @@ -0,0 +1,77 @@ +# This file defines an interface and client-side API stub +# for referring either to the core Ray API or the same interface +# from the Ray client. +# +# In tandem with __init__.py, we want to expose an API that's +# close to `python/ray/__init__.py` but with more than one implementation. +# The stubs in __init__ should call into a well-defined interface. +# Only the core Ray API implementation should actually `import ray` +# (and thus import all the raylet worker C bindings and such). +# But to make sure that we're matching these calls, we define this API. + +from abc import ABC +from abc import abstractmethod + + +class APIImpl(ABC): + @abstractmethod + def get(self, *args, **kwargs): + pass + + @abstractmethod + def put(self, *args, **kwargs): + pass + + @abstractmethod + def wait(self, *args, **kwargs): + pass + + @abstractmethod + def remote(self, *args, **kwargs): + pass + + @abstractmethod + def call_remote(self, instance, kind: int, *args, **kwargs): + pass + + @abstractmethod + def get_actor_from_object(self, id): + pass + + @abstractmethod + def close(self, *args, **kwargs): + pass + + +class ClientAPI(APIImpl): + def __init__(self, worker): + self.worker = worker + + def get(self, *args, **kwargs): + return self.worker.get(*args, **kwargs) + + def put(self, *args, **kwargs): + return self.worker.put(*args, **kwargs) + + def wait(self, *args, **kwargs): + return self.worker.wait(*args, **kwargs) + + def remote(self, *args, **kwargs): + return self.worker.remote(*args, **kwargs) + + def call_remote(self, f, kind, *args, **kwargs): + return self.worker.call_remote(f, kind, *args, **kwargs) + + def get_actor_from_object(self, id): + raise Exception("Calling get_actor_from_object on the client side") + + def close(self, *args, **kwargs): + return self.worker.close() + + def __getattr__(self, key: str): + if not key.startswith("_"): + raise NotImplementedError( + "Not available in Ray client: `ray.{}`. This method is only " + "available within Ray remote functions and is not yet " + "implemented in the client API.".format(key)) + return self.__getattribute__(key) diff --git a/python/ray/client_app.py b/python/ray/client_app.py new file mode 100644 index 0000000..3b99f81 --- /dev/null +++ b/python/ray/client_app.py @@ -0,0 +1,90 @@ +from ray import ray +from typing import Tuple + +ray.connect("localhost:50051") + + +@ray.remote +class HelloActor: + def __init__(self): + self.count = 0 + + def say_hello(self, whom: str) -> Tuple[str, int]: + self.count += 1 + return ("Hello " + whom, self.count) + + +actor = HelloActor.remote() +s, count = ray.get(actor.say_hello.remote("you")) +print(s, count) +assert s == "Hello you" +assert count == 1 +s, count = ray.get(actor.say_hello.remote("world")) +print(s, count) +assert s == "Hello world" +assert count == 2 + + +@ray.remote +def plus2(x): + return x + 2 + + +@ray.remote +def fact(x): + print(x, type(fact)) + if x <= 0: + return 1 + # This hits the "nested tasks" issue + # https://github.com/ray-project/ray/issues/3644 + # So we're on the right track! + return ray.get(fact.remote(x - 1)) * x + + +@ray.remote +def get_nodes(): + return ray.nodes() # Can access the full Ray API in remote methods. + + +print("Cluster nodes", ray.get(get_nodes.remote())) +print(ray.nodes()) + +objectref = ray.put("hello world") + +# `ClientObjectRef(...)` +print(objectref) + +# `hello world` +print(ray.get(objectref)) + +ref2 = plus2.remote(234) +# `ClientObjectRef(...)` +print(ref2) +# `236` +print(ray.get(ref2)) + +ref3 = fact.remote(20) +# `ClientObjectRef(...)` +print(ref3) +# `2432902008176640000` +print(ray.get(ref3)) + +# Reuse the cached ClientRemoteFunc object +ref4 = fact.remote(5) +# `120` +print(ray.get(ref4)) + +ref5 = fact.remote(10) + +print([ref2, ref3, ref4, ref5]) +# should return ref2, ref3, ref4 +res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3) +print(res) +assert [ref2, ref3, ref4] == res[0] +assert [ref5] == res[1] + +# should return ref2, ref3, ref4, ref5 +res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4) +print(res) +assert [ref2, ref3, ref4, ref5] == res[0] +assert [] == res[1] diff --git a/python/ray/common.py b/python/ray/common.py new file mode 100644 index 0000000..9aefbae --- /dev/null +++ b/python/ray/common.py @@ -0,0 +1,205 @@ +import rpc.ray_client_pb2 as ray_client_pb2 +from ray import ray +from typing import Any +from typing import Dict +import cloudpickle + + +class ClientBaseRef: + def __init__(self, id): + self.id = id + + def __repr__(self): + return "%s(%s)" % ( + type(self).__name__, + self.id.hex(), + ) + + def __eq__(self, other): + return self.id == other.id + + def binary(self): + return self.id + + +class ClientObjectRef(ClientBaseRef): + pass + + +class ClientActorNameRef(ClientBaseRef): + pass + + +class ClientRemoteFunc: + def __init__(self, f): + self._func = f + self._name = f.__name__ + self.id = None + self._ref = None + self._raylet_remote = None + + def __call__(self, *args, **kwargs): + raise TypeError(f"Remote function cannot be called directly. " + "Use {self._name}.remote method instead") + + def remote(self, *args, **kwargs): + return ray.call_remote(self, ray_client_pb2.ClientTask.FUNCTION, *args, + **kwargs) + + def _get_ray_remote_impl(self): + if self._raylet_remote is None: + self._raylet_remote = ray.remote(self._func) + return self._raylet_remote + + def __repr__(self): + return "ClientRemoteFunc(%s, %s)" % (self._name, self.id) + + def _prepare_client_task(self) -> ray_client_pb2.ClientTask: + if self._ref is None: + self._ref = ray.put(self._func) + task = ray_client_pb2.ClientTask() + task.type = ray_client_pb2.ClientTask.FUNCTION + task.name = self._name + task.payload_id = self._ref.id + return task + + +class ClientActorClass: + def __init__(self, actor_cls): + self.actor_cls = actor_cls + self._name = actor_cls.__name__ + self._ref = None + self._raylet_remote = None + + def __call__(self, *args, **kwargs): + raise TypeError(f"Remote actor cannot be instantiated directly. " + "Use {self._name}.remote() instead") + + def __getstate__(self) -> Dict: + state = { + "actor_cls": self.actor_cls, + "_name": self._name, + "_ref": self._ref, + } + return state + + def __setstate__(self, state: Dict) -> None: + self.actor_cls = state["actor_cls"] + self._name = state["_name"] + self._ref = state["_ref"] + + def remote(self, *args, **kwargs): + # Actually instantiate the actor + ref = ray.call_remote(self, ray_client_pb2.ClientTask.ACTOR, *args, + **kwargs) + return ClientActorHandle(ClientActorNameRef(ref.id), self) + + def __repr__(self): + return "ClientRemoteActor(%s, %s)" % (self._name, self._ref) + + def __getattr__(self, key): + raise NotImplementedError("static methods") + + def _prepare_client_task(self) -> ray_client_pb2.ClientTask: + if self._ref is None: + self._ref = ray.put(self.actor_cls) + task = ray_client_pb2.ClientTask() + task.type = ray_client_pb2.ClientTask.ACTOR + task.name = self._name + task.payload_id = self._ref.id + return task + + +class ClientActorHandle: + def __init__(self, actor_id: ClientActorNameRef, + actor_class: ClientActorClass): + self.actor_id = actor_id + self.actor_class = actor_class + self._real_actor_handle = None + + def _get_ray_remote_impl(self): + if self._real_actor_handle is None: + self._real_actor_handle = ray.get_actor_from_object(self.actor_id) + return self._real_actor_handle + + def __getstate__(self) -> Dict: + state = { + "actor_id": self.actor_id, + "actor_class": self.actor_class, + "_real_actor_handle": self._real_actor_handle, + } + return state + + def __setstate__(self, state: Dict) -> None: + self.actor_id = state["actor_id"] + self.actor_class = state["actor_class"] + self._real_actor_handle = state["_real_actor_handle"] + + def __getattr__(self, key): + return ClientRemoteMethod(self, key) + + def __repr__(self): + return "ClientActorHandle(%s, %s, %s)" % ( + self.actor_id, self.actor_class, self._real_actor_handle) + + +class ClientRemoteMethod: + def __init__(self, actor_handle: ClientActorHandle, method_name: str): + self.actor_handle = actor_handle + self.method_name = method_name + + def __call__(self, *args, **kwargs): + raise TypeError(f"Remote method cannot be called directly. " + "Use {self._name}.remote() instead") + + def _get_ray_remote_impl(self): + return getattr(self.actor_handle._get_ray_remote_impl(), + self.method_name) + + def __getstate__(self) -> Dict: + state = { + "actor_handle": self.actor_handle, + "method_name": self.method_name, + } + return state + + def __setstate__(self, state: Dict) -> None: + self.actor_handle = state["actor_handle"] + self.method_name = state["method_name"] + + def remote(self, *args, **kwargs): + return ray.call_remote(self, ray_client_pb2.ClientTask.METHOD, *args, + **kwargs) + + def __repr__(self): + name = "%s.%s" % (self.actor_handle.actor_class._name, + self.method_name) + return "ClientRemoteMethod(%s, %s)" % (name, + self.actor_handle.actor_id) + + def _prepare_client_task(self) -> ray_client_pb2.ClientTask: + task = ray_client_pb2.ClientTask() + task.type = ray_client_pb2.ClientTask.METHOD + task.name = self.method_name + task.payload_id = self.actor_handle.actor_id.id + return task + + +def convert_from_arg(pb) -> Any: + if pb.local == ray_client_pb2.Arg.Locality.REFERENCE: + return ClientObjectRef(pb.reference_id) + elif pb.local == ray_client_pb2.Arg.Locality.INTERNED: + return cloudpickle.loads(pb.data) + + raise Exception("convert_from_arg: Uncovered locality enum") + + +def convert_to_arg(val): + out = ray_client_pb2.Arg() + if isinstance(val, ClientObjectRef): + out.local = ray_client_pb2.Arg.Locality.REFERENCE + out.reference_id = val.id + else: + out.local = ray_client_pb2.Arg.Locality.INTERNED + out.data = cloudpickle.dumps(val) + return out diff --git a/python/ray/worker.py b/python/ray/worker.py new file mode 100644 index 0000000..dd5b84b --- /dev/null +++ b/python/ray/worker.py @@ -0,0 +1,142 @@ +"""This file includes the Worker class which sits on the client side. +It implements the Ray API functions that are forwarded through grpc calls +to the server. +""" +import inspect +import logging +from typing import List +from typing import Tuple + +import cloudpickle +import grpc + +import rpc.ray_client_pb2 as ray_client_pb2 +import rpc.ray_client_pb2_grpc as ray_client_pb2_grpc +from ray.common import convert_to_arg +from ray.common import ClientObjectRef +from ray.common import ClientActorClass +from ray.common import ClientRemoteFunc + +logger = logging.getLogger(__name__) + + +class Worker: + def __init__(self, + conn_str: str = "", + secure: bool = False, + metadata: List[Tuple[str, str]] = None, + stub=None): + """Initializes the worker side grpc client. + + Args: + stub: custom grpc stub. + secure: whether to use SSL secure channel or not. + metadata: additional metadata passed in the grpc request headers. + """ + self.metadata = metadata + if stub is None: + if secure: + credentials = grpc.ssl_channel_credentials() + self.channel = grpc.secure_channel(conn_str, credentials) + else: + self.channel = grpc.insecure_channel(conn_str) + self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) + else: + self.server = stub + + def get(self, ids): + to_get = [] + single = False + if isinstance(ids, list): + to_get = [x.id for x in ids] + elif isinstance(ids, ClientObjectRef): + to_get = [ids.id] + single = True + else: + raise Exception("Can't get something that's not a " + "list of IDs or just an ID: %s" % type(ids)) + out = [self._get(x) for x in to_get] + if single: + out = out[0] + return out + + def _get(self, id: bytes): + req = ray_client_pb2.GetRequest(id=id) + data = self.server.GetObject(req, metadata=self.metadata) + if not data.valid: + raise Exception( + "Client GetObject returned invalid data: id invalid?") + return cloudpickle.loads(data.data) + + def put(self, vals): + to_put = [] + single = False + if isinstance(vals, list): + to_put = vals + else: + single = True + to_put.append(vals) + + out = [self._put(x) for x in to_put] + if single: + out = out[0] + return out + + def _put(self, val): + data = cloudpickle.dumps(val) + req = ray_client_pb2.PutRequest(data=data) + resp = self.server.PutObject(req, metadata=self.metadata) + return ClientObjectRef(resp.id) + + def wait(self, + object_refs: List[ClientObjectRef], + *, + num_returns: int = 1, + timeout: float = None + ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]: + assert isinstance(object_refs, list) + for ref in object_refs: + assert isinstance(ref, ClientObjectRef) + data = { + "object_refs": [ + cloudpickle.dumps(object_ref) for object_ref in object_refs + ], + "num_returns": num_returns, + "timeout": timeout if timeout else -1 + } + req = ray_client_pb2.WaitRequest(**data) + resp = self.server.WaitObject(req, metadata=self.metadata) + if not resp.valid: + # TODO(ameer): improve error/exceptions messages. + raise Exception("Client Wait request failed. Reference invalid?") + client_ready_object_ids = [ + ClientObjectRef(id) for id in resp.ready_object_ids + ] + client_remaining_object_ids = [ + ClientObjectRef(id) for id in resp.remaining_object_ids + ] + + return (client_ready_object_ids, client_remaining_object_ids) + + def remote(self, function_or_class, *args, **kwargs): + # TODO(barakmich): Arguments to ray.remote + # get captured here. + if inspect.isfunction(function_or_class): + return ClientRemoteFunc(function_or_class) + elif inspect.isclass(function_or_class): + return ClientActorClass(function_or_class) + else: + raise TypeError("The @ray.remote decorator must be applied to " + "either a function or to a class.") + + def call_remote(self, instance, kind, *args, **kwargs): + task = instance._prepare_client_task() + for arg in args: + pb_arg = convert_to_arg(arg) + task.args.append(pb_arg) + logging.debug("Scheduling %s" % task) + ticket = self.server.Schedule(task, metadata=self.metadata) + return ClientObjectRef(ticket.return_id) + + def close(self): + self.channel.close() diff --git a/python/ray_worker/__init__.py b/python/ray_worker/__init__.py new file mode 100644 index 0000000..f751afc --- /dev/null +++ b/python/ray_worker/__init__.py @@ -0,0 +1,74 @@ +import rpc.ray_client_pb2 as ray_client_pb2 +import rpc.ray_client_pb2_grpc as ray_client_pb2_grpc + +import queue +import grpc +import cloudpickle +import threading +import concurrent.futures +from ray import ray +from ray.common import ClientObjectRef + + +class Worker: + def __init__(self, conn_str): + self.channel = grpc.insecure_channel(conn_str) + self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel) + self.send_queue = None + self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=3) + self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) + + def begin(self): + self.send_queue = queue.Queue() + work_stream = self.worker_stub.Workstream(iter(self.send_queue.get, None)) + start = ray_client_pb2.WorkStatus( + status = ray_client_pb2.WorkStatus.StatusCode.READY, + error_msg = "python_worker", + ) + self.send_queue.put(start) + for work in work_stream: + print("Got work") + task = work.task + if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION: + send_queue.put(ray_client_pb2.WorkStatus( + status = ray_client_pb2.WorkStatus.StatusCode.ERROR, + error_msg = "unimplemented", + )) + continue + args = self.decode_args(task) + func = self.get(task.payload_id) + #self.pool.submit(self.run_and_return, func, args, work.ticket) + t = threading.Thread(target=self.run_and_return, args=(func, args, work.ticket)) + t.start() + + + def run_and_return(self, func, args, ticket): + res = func(*args) + out_data = cloudpickle.dumps(res) + self.send_queue.put(ray_client_pb2.WorkStatus( + status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE, + complete_data = out_data, + finished_ticket = ticket, + )) + + # def get(self, id_bytes): + # data = self.server.GetObject(ray_client_pb2.GetRequest( + # id = id_bytes, + # )) + # return cloudpickle.loads(data.data) + def get(self, id_bytes): + return ray.get(ClientObjectRef(id_bytes)) + + def decode_args(self, task): + out = [] + for arg in task.args: + t = self.convert_from_arg(arg) + out.append(t) + return out + + def convert_from_arg(self, pb): + if pb.local == ray_client_pb2.Arg.Locality.REFERENCE: + return self.get(pb.reference_id) + elif pb.local == ray_client_pb2.Arg.Locality.INTERNED: + return cloudpickle.loads(pb.data) + raise Exception("convert_from_arg: Uncovered locality enum") diff --git a/python/rpc/mk_proto.sh b/python/rpc/mk_proto.sh index 557869b..0213298 100755 --- a/python/rpc/mk_proto.sh +++ b/python/rpc/mk_proto.sh @@ -1 +1,7 @@ +#!/bin/sh python -m grpc_tools.protoc -I../../ray_rpc --python_out=. --grpc_python_out=. ../../ray_rpc/ray_client.proto +SEDCMD="sed" +if [ -n "`which gsed`" ]; then + SEDCMD="gsed" +fi +$SEDCMD -i 's/^import ray_client_pb2/import rpc.ray_client_pb2/g' *.py diff --git a/python/worker.py b/python/worker.py index b466bb0..27463ce 100644 --- a/python/worker.py +++ b/python/worker.py @@ -1,81 +1,8 @@ # from rpc import ray_client_pb2 # from rpc import ray_client_pb2_grpc -import ray.core.generated.ray_client_pb2 as ray_client_pb2 -import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc - -import queue +from ray_worker import Worker import sys -import grpc -import cloudpickle -import threading -import concurrent.futures -from ray.experimental.client import ray -from ray.experimental.client.common import ClientObjectRef - - - -class Worker: - def __init__(self, conn_str): - self.channel = grpc.insecure_channel(conn_str) - self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel) - self.send_queue = None - self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=3) - self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) - - def begin(self): - self.send_queue = queue.Queue() - work_stream = self.worker_stub.Workstream(iter(self.send_queue.get, None)) - start = ray_client_pb2.WorkStatus( - status = ray_client_pb2.WorkStatus.StatusCode.READY, - error_msg = "python_worker", - ) - self.send_queue.put(start) - for work in work_stream: - print("Got work") - task = work.task - if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION: - send_queue.put(ray_client_pb2.WorkStatus( - status = ray_client_pb2.WorkStatus.StatusCode.ERROR, - error_msg = "unimplemented", - )) - continue - args = self.decode_args(task) - func = self.get(task.payload_id) - #self.pool.submit(self.run_and_return, func, args, work.ticket) - t = threading.Thread(target=self.run_and_return, args=(func, args, work.ticket)) - t.start() - - - def run_and_return(self, func, args, ticket): - res = func(*args) - out_data = cloudpickle.dumps(res) - self.send_queue.put(ray_client_pb2.WorkStatus( - status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE, - complete_data = out_data, - finished_ticket = ticket, - )) - - # def get(self, id_bytes): - # data = self.server.GetObject(ray_client_pb2.GetRequest( - # id = id_bytes, - # )) - # return cloudpickle.loads(data.data) - def get(self, id_bytes): - return ray.get(ClientObjectRef(id_bytes)) - - def decode_args(self, task): - out = [] - for arg in task.args: - t = self.convert_from_arg(arg) - out.append(t) - return out - - def convert_from_arg(self, pb): - if pb.local == ray_client_pb2.Arg.Locality.REFERENCE: - return self.get(pb.reference_id) - elif pb.local == ray_client_pb2.Arg.Locality.INTERNED: - return cloudpickle.loads(pb.data) - raise Exception("convert_from_arg: Uncovered locality enum") +from ray import ray def main(): diff --git a/raylet_grpc.go b/raylet_grpc.go index d0a71a1..a4e0b20 100644 --- a/raylet_grpc.go +++ b/raylet_grpc.go @@ -58,7 +58,7 @@ func (r *Raylet) Workstream(conn WorkstreamConnection) error { worker := &SimpleWorker{ workChan: make(chan *ray_rpc.Work), clientConn: conn, - pool: wp, + pool: r.Workers, } r.Workers.Register(worker) err := worker.Run() diff --git a/worker.go b/worker.go index 139ad37..b1af58a 100644 --- a/worker.go +++ b/worker.go @@ -27,7 +27,7 @@ func (s *SimpleWorker) AssignWork(work *ray_rpc.Work) error { } func (s *SimpleWorker) Close() error { - close(w.workChan) + close(s.workChan) return nil } diff --git a/worker_pool.go b/worker_pool.go index 1dcb568..ed1da4a 100644 --- a/worker_pool.go +++ b/worker_pool.go @@ -34,6 +34,7 @@ func (wp *SimpleRRWorkerPool) Register(worker Worker) error { wp.Lock() defer wp.Unlock() wp.workers = append(wp.workers, worker) + return nil } func (wp *SimpleRRWorkerPool) Schedule(work *ray_rpc.Work) error {