hermetically seal everything, bring it up to date
This commit is contained in:
parent
efd3d65269
commit
03beb6970c
12 changed files with 723 additions and 77 deletions
17
python/client.py
Normal file
17
python/client.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
from ray import ray
|
||||
|
||||
ray.connect("localhost:50050")
|
||||
|
||||
@ray.remote
|
||||
def plus2(x):
|
||||
return x + 2
|
||||
|
||||
print(ray.get(plus2.remote(4)))
|
||||
|
||||
@ray.remote
|
||||
def fact(x):
|
||||
if x <= 0:
|
||||
return 1
|
||||
return x * ray.get(fact.remote(x - 1))
|
||||
|
||||
print(ray.get(fact.remote(20)))
|
||||
107
python/ray/__init__.py
Normal file
107
python/ray/__init__.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
from ray.api import ClientAPI
|
||||
from ray.api import APIImpl
|
||||
from typing import Optional, List, Tuple
|
||||
from contextlib import contextmanager
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# _client_api has to be external to the API stub, below.
|
||||
# Otherwise, ray.remote() that contains ray.remote()
|
||||
# contains a reference to the RayAPIStub, therefore a
|
||||
# reference to the _client_api, and then tries to pickle
|
||||
# the thing.
|
||||
_client_api: Optional[APIImpl] = None
|
||||
|
||||
_server_api: Optional[APIImpl] = None
|
||||
|
||||
_is_server: bool = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def stash_api_for_tests(in_test: bool):
|
||||
global _is_server
|
||||
is_server = _is_server
|
||||
if in_test:
|
||||
_is_server = True
|
||||
yield _server_api
|
||||
if in_test:
|
||||
_is_server = is_server
|
||||
|
||||
|
||||
def _set_client_api(val: Optional[APIImpl]):
|
||||
global _client_api
|
||||
global _is_server
|
||||
if _client_api is not None:
|
||||
raise Exception("Trying to set more than one client API")
|
||||
_client_api = val
|
||||
_is_server = False
|
||||
|
||||
|
||||
def _set_server_api(val: Optional[APIImpl]):
|
||||
global _server_api
|
||||
global _is_server
|
||||
if _server_api is not None:
|
||||
raise Exception("Trying to set more than one server API")
|
||||
_server_api = val
|
||||
_is_server = True
|
||||
|
||||
|
||||
def reset_api():
|
||||
global _client_api
|
||||
global _server_api
|
||||
global _is_server
|
||||
_client_api = None
|
||||
_server_api = None
|
||||
_is_server = False
|
||||
|
||||
|
||||
def _get_client_api() -> APIImpl:
|
||||
global _client_api
|
||||
global _server_api
|
||||
global _is_server
|
||||
api = None
|
||||
if _is_server:
|
||||
api = _server_api
|
||||
else:
|
||||
api = _client_api
|
||||
if api is None:
|
||||
# We're inside a raylet worker
|
||||
raise Exception("CoreRayAPI not supported")
|
||||
return api
|
||||
|
||||
|
||||
class RayAPIStub:
|
||||
def connect(self,
|
||||
conn_str: str,
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
stub=None):
|
||||
from ray.worker import Worker
|
||||
_client_worker = Worker(
|
||||
conn_str, secure=secure, metadata=metadata, stub=stub)
|
||||
_set_client_api(ClientAPI(_client_worker))
|
||||
|
||||
def disconnect(self):
|
||||
global _client_api
|
||||
if _client_api is not None:
|
||||
_client_api.close()
|
||||
_client_api = None
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
global _get_client_api
|
||||
api = _get_client_api()
|
||||
return getattr(api, key)
|
||||
|
||||
|
||||
ray = RayAPIStub()
|
||||
|
||||
# Someday we might add methods in this module so that someone who
|
||||
# tries to `import ray_client as ray` -- as a module, instead of
|
||||
# `from ray_client import ray` -- as the API stub
|
||||
# still gets expected functionality. This is the way the ray package
|
||||
# worked in the past.
|
||||
#
|
||||
# This really calls for PEP 562: https://www.python.org/dev/peps/pep-0562/
|
||||
# But until Python 3.6 is EOL, here we are.
|
||||
77
python/ray/api.py
Normal file
77
python/ray/api.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
# This file defines an interface and client-side API stub
|
||||
# for referring either to the core Ray API or the same interface
|
||||
# from the Ray client.
|
||||
#
|
||||
# In tandem with __init__.py, we want to expose an API that's
|
||||
# close to `python/ray/__init__.py` but with more than one implementation.
|
||||
# The stubs in __init__ should call into a well-defined interface.
|
||||
# Only the core Ray API implementation should actually `import ray`
|
||||
# (and thus import all the raylet worker C bindings and such).
|
||||
# But to make sure that we're matching these calls, we define this API.
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class APIImpl(ABC):
|
||||
@abstractmethod
|
||||
def get(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def put(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remote(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def call_remote(self, instance, kind: int, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_actor_from_object(self, id):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def close(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ClientAPI(APIImpl):
|
||||
def __init__(self, worker):
|
||||
self.worker = worker
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
return self.worker.get(*args, **kwargs)
|
||||
|
||||
def put(self, *args, **kwargs):
|
||||
return self.worker.put(*args, **kwargs)
|
||||
|
||||
def wait(self, *args, **kwargs):
|
||||
return self.worker.wait(*args, **kwargs)
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return self.worker.remote(*args, **kwargs)
|
||||
|
||||
def call_remote(self, f, kind, *args, **kwargs):
|
||||
return self.worker.call_remote(f, kind, *args, **kwargs)
|
||||
|
||||
def get_actor_from_object(self, id):
|
||||
raise Exception("Calling get_actor_from_object on the client side")
|
||||
|
||||
def close(self, *args, **kwargs):
|
||||
return self.worker.close()
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
if not key.startswith("_"):
|
||||
raise NotImplementedError(
|
||||
"Not available in Ray client: `ray.{}`. This method is only "
|
||||
"available within Ray remote functions and is not yet "
|
||||
"implemented in the client API.".format(key))
|
||||
return self.__getattribute__(key)
|
||||
90
python/ray/client_app.py
Normal file
90
python/ray/client_app.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
from ray import ray
|
||||
from typing import Tuple
|
||||
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
|
||||
@ray.remote
|
||||
class HelloActor:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
def say_hello(self, whom: str) -> Tuple[str, int]:
|
||||
self.count += 1
|
||||
return ("Hello " + whom, self.count)
|
||||
|
||||
|
||||
actor = HelloActor.remote()
|
||||
s, count = ray.get(actor.say_hello.remote("you"))
|
||||
print(s, count)
|
||||
assert s == "Hello you"
|
||||
assert count == 1
|
||||
s, count = ray.get(actor.say_hello.remote("world"))
|
||||
print(s, count)
|
||||
assert s == "Hello world"
|
||||
assert count == 2
|
||||
|
||||
|
||||
@ray.remote
|
||||
def plus2(x):
|
||||
return x + 2
|
||||
|
||||
|
||||
@ray.remote
|
||||
def fact(x):
|
||||
print(x, type(fact))
|
||||
if x <= 0:
|
||||
return 1
|
||||
# This hits the "nested tasks" issue
|
||||
# https://github.com/ray-project/ray/issues/3644
|
||||
# So we're on the right track!
|
||||
return ray.get(fact.remote(x - 1)) * x
|
||||
|
||||
|
||||
@ray.remote
|
||||
def get_nodes():
|
||||
return ray.nodes() # Can access the full Ray API in remote methods.
|
||||
|
||||
|
||||
print("Cluster nodes", ray.get(get_nodes.remote()))
|
||||
print(ray.nodes())
|
||||
|
||||
objectref = ray.put("hello world")
|
||||
|
||||
# `ClientObjectRef(...)`
|
||||
print(objectref)
|
||||
|
||||
# `hello world`
|
||||
print(ray.get(objectref))
|
||||
|
||||
ref2 = plus2.remote(234)
|
||||
# `ClientObjectRef(...)`
|
||||
print(ref2)
|
||||
# `236`
|
||||
print(ray.get(ref2))
|
||||
|
||||
ref3 = fact.remote(20)
|
||||
# `ClientObjectRef(...)`
|
||||
print(ref3)
|
||||
# `2432902008176640000`
|
||||
print(ray.get(ref3))
|
||||
|
||||
# Reuse the cached ClientRemoteFunc object
|
||||
ref4 = fact.remote(5)
|
||||
# `120`
|
||||
print(ray.get(ref4))
|
||||
|
||||
ref5 = fact.remote(10)
|
||||
|
||||
print([ref2, ref3, ref4, ref5])
|
||||
# should return ref2, ref3, ref4
|
||||
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
|
||||
print(res)
|
||||
assert [ref2, ref3, ref4] == res[0]
|
||||
assert [ref5] == res[1]
|
||||
|
||||
# should return ref2, ref3, ref4, ref5
|
||||
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
|
||||
print(res)
|
||||
assert [ref2, ref3, ref4, ref5] == res[0]
|
||||
assert [] == res[1]
|
||||
205
python/ray/common.py
Normal file
205
python/ray/common.py
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
import rpc.ray_client_pb2 as ray_client_pb2
|
||||
from ray import ray
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
import cloudpickle
|
||||
|
||||
|
||||
class ClientBaseRef:
|
||||
def __init__(self, id):
|
||||
self.id = id
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%s)" % (
|
||||
type(self).__name__,
|
||||
self.id.hex(),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.id == other.id
|
||||
|
||||
def binary(self):
|
||||
return self.id
|
||||
|
||||
|
||||
class ClientObjectRef(ClientBaseRef):
|
||||
pass
|
||||
|
||||
|
||||
class ClientActorNameRef(ClientBaseRef):
|
||||
pass
|
||||
|
||||
|
||||
class ClientRemoteFunc:
|
||||
def __init__(self, f):
|
||||
self._func = f
|
||||
self._name = f.__name__
|
||||
self.id = None
|
||||
self._ref = None
|
||||
self._raylet_remote = None
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise TypeError(f"Remote function cannot be called directly. "
|
||||
"Use {self._name}.remote method instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.call_remote(self, ray_client_pb2.ClientTask.FUNCTION, *args,
|
||||
**kwargs)
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
if self._raylet_remote is None:
|
||||
self._raylet_remote = ray.remote(self._func)
|
||||
return self._raylet_remote
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteFunc(%s, %s)" % (self._name, self.id)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
if self._ref is None:
|
||||
self._ref = ray.put(self._func)
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.FUNCTION
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.id
|
||||
return task
|
||||
|
||||
|
||||
class ClientActorClass:
|
||||
def __init__(self, actor_cls):
|
||||
self.actor_cls = actor_cls
|
||||
self._name = actor_cls.__name__
|
||||
self._ref = None
|
||||
self._raylet_remote = None
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise TypeError(f"Remote actor cannot be instantiated directly. "
|
||||
"Use {self._name}.remote() instead")
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_cls": self.actor_cls,
|
||||
"_name": self._name,
|
||||
"_ref": self._ref,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_cls = state["actor_cls"]
|
||||
self._name = state["_name"]
|
||||
self._ref = state["_ref"]
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
# Actually instantiate the actor
|
||||
ref = ray.call_remote(self, ray_client_pb2.ClientTask.ACTOR, *args,
|
||||
**kwargs)
|
||||
return ClientActorHandle(ClientActorNameRef(ref.id), self)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
|
||||
|
||||
def __getattr__(self, key):
|
||||
raise NotImplementedError("static methods")
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
if self._ref is None:
|
||||
self._ref = ray.put(self.actor_cls)
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.ACTOR
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.id
|
||||
return task
|
||||
|
||||
|
||||
class ClientActorHandle:
|
||||
def __init__(self, actor_id: ClientActorNameRef,
|
||||
actor_class: ClientActorClass):
|
||||
self.actor_id = actor_id
|
||||
self.actor_class = actor_class
|
||||
self._real_actor_handle = None
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
if self._real_actor_handle is None:
|
||||
self._real_actor_handle = ray.get_actor_from_object(self.actor_id)
|
||||
return self._real_actor_handle
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_id": self.actor_id,
|
||||
"actor_class": self.actor_class,
|
||||
"_real_actor_handle": self._real_actor_handle,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_id = state["actor_id"]
|
||||
self.actor_class = state["actor_class"]
|
||||
self._real_actor_handle = state["_real_actor_handle"]
|
||||
|
||||
def __getattr__(self, key):
|
||||
return ClientRemoteMethod(self, key)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientActorHandle(%s, %s, %s)" % (
|
||||
self.actor_id, self.actor_class, self._real_actor_handle)
|
||||
|
||||
|
||||
class ClientRemoteMethod:
|
||||
def __init__(self, actor_handle: ClientActorHandle, method_name: str):
|
||||
self.actor_handle = actor_handle
|
||||
self.method_name = method_name
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise TypeError(f"Remote method cannot be called directly. "
|
||||
"Use {self._name}.remote() instead")
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
return getattr(self.actor_handle._get_ray_remote_impl(),
|
||||
self.method_name)
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_handle": self.actor_handle,
|
||||
"method_name": self.method_name,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_handle = state["actor_handle"]
|
||||
self.method_name = state["method_name"]
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.call_remote(self, ray_client_pb2.ClientTask.METHOD, *args,
|
||||
**kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
name = "%s.%s" % (self.actor_handle.actor_class._name,
|
||||
self.method_name)
|
||||
return "ClientRemoteMethod(%s, %s)" % (name,
|
||||
self.actor_handle.actor_id)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.METHOD
|
||||
task.name = self.method_name
|
||||
task.payload_id = self.actor_handle.actor_id.id
|
||||
return task
|
||||
|
||||
|
||||
def convert_from_arg(pb) -> Any:
|
||||
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
|
||||
return ClientObjectRef(pb.reference_id)
|
||||
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
|
||||
return cloudpickle.loads(pb.data)
|
||||
|
||||
raise Exception("convert_from_arg: Uncovered locality enum")
|
||||
|
||||
|
||||
def convert_to_arg(val):
|
||||
out = ray_client_pb2.Arg()
|
||||
if isinstance(val, ClientObjectRef):
|
||||
out.local = ray_client_pb2.Arg.Locality.REFERENCE
|
||||
out.reference_id = val.id
|
||||
else:
|
||||
out.local = ray_client_pb2.Arg.Locality.INTERNED
|
||||
out.data = cloudpickle.dumps(val)
|
||||
return out
|
||||
142
python/ray/worker.py
Normal file
142
python/ray/worker.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
"""This file includes the Worker class which sits on the client side.
|
||||
It implements the Ray API functions that are forwarded through grpc calls
|
||||
to the server.
|
||||
"""
|
||||
import inspect
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import cloudpickle
|
||||
import grpc
|
||||
|
||||
import rpc.ray_client_pb2 as ray_client_pb2
|
||||
import rpc.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.common import convert_to_arg
|
||||
from ray.common import ClientObjectRef
|
||||
from ray.common import ClientActorClass
|
||||
from ray.common import ClientRemoteFunc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Worker:
|
||||
def __init__(self,
|
||||
conn_str: str = "",
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
stub=None):
|
||||
"""Initializes the worker side grpc client.
|
||||
|
||||
Args:
|
||||
stub: custom grpc stub.
|
||||
secure: whether to use SSL secure channel or not.
|
||||
metadata: additional metadata passed in the grpc request headers.
|
||||
"""
|
||||
self.metadata = metadata
|
||||
if stub is None:
|
||||
if secure:
|
||||
credentials = grpc.ssl_channel_credentials()
|
||||
self.channel = grpc.secure_channel(conn_str, credentials)
|
||||
else:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
else:
|
||||
self.server = stub
|
||||
|
||||
def get(self, ids):
|
||||
to_get = []
|
||||
single = False
|
||||
if isinstance(ids, list):
|
||||
to_get = [x.id for x in ids]
|
||||
elif isinstance(ids, ClientObjectRef):
|
||||
to_get = [ids.id]
|
||||
single = True
|
||||
else:
|
||||
raise Exception("Can't get something that's not a "
|
||||
"list of IDs or just an ID: %s" % type(ids))
|
||||
out = [self._get(x) for x in to_get]
|
||||
if single:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
def _get(self, id: bytes):
|
||||
req = ray_client_pb2.GetRequest(id=id)
|
||||
data = self.server.GetObject(req, metadata=self.metadata)
|
||||
if not data.valid:
|
||||
raise Exception(
|
||||
"Client GetObject returned invalid data: id invalid?")
|
||||
return cloudpickle.loads(data.data)
|
||||
|
||||
def put(self, vals):
|
||||
to_put = []
|
||||
single = False
|
||||
if isinstance(vals, list):
|
||||
to_put = vals
|
||||
else:
|
||||
single = True
|
||||
to_put.append(vals)
|
||||
|
||||
out = [self._put(x) for x in to_put]
|
||||
if single:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
def _put(self, val):
|
||||
data = cloudpickle.dumps(val)
|
||||
req = ray_client_pb2.PutRequest(data=data)
|
||||
resp = self.server.PutObject(req, metadata=self.metadata)
|
||||
return ClientObjectRef(resp.id)
|
||||
|
||||
def wait(self,
|
||||
object_refs: List[ClientObjectRef],
|
||||
*,
|
||||
num_returns: int = 1,
|
||||
timeout: float = None
|
||||
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
|
||||
assert isinstance(object_refs, list)
|
||||
for ref in object_refs:
|
||||
assert isinstance(ref, ClientObjectRef)
|
||||
data = {
|
||||
"object_refs": [
|
||||
cloudpickle.dumps(object_ref) for object_ref in object_refs
|
||||
],
|
||||
"num_returns": num_returns,
|
||||
"timeout": timeout if timeout else -1
|
||||
}
|
||||
req = ray_client_pb2.WaitRequest(**data)
|
||||
resp = self.server.WaitObject(req, metadata=self.metadata)
|
||||
if not resp.valid:
|
||||
# TODO(ameer): improve error/exceptions messages.
|
||||
raise Exception("Client Wait request failed. Reference invalid?")
|
||||
client_ready_object_ids = [
|
||||
ClientObjectRef(id) for id in resp.ready_object_ids
|
||||
]
|
||||
client_remaining_object_ids = [
|
||||
ClientObjectRef(id) for id in resp.remaining_object_ids
|
||||
]
|
||||
|
||||
return (client_ready_object_ids, client_remaining_object_ids)
|
||||
|
||||
def remote(self, function_or_class, *args, **kwargs):
|
||||
# TODO(barakmich): Arguments to ray.remote
|
||||
# get captured here.
|
||||
if inspect.isfunction(function_or_class):
|
||||
return ClientRemoteFunc(function_or_class)
|
||||
elif inspect.isclass(function_or_class):
|
||||
return ClientActorClass(function_or_class)
|
||||
else:
|
||||
raise TypeError("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
|
||||
def call_remote(self, instance, kind, *args, **kwargs):
|
||||
task = instance._prepare_client_task()
|
||||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg)
|
||||
task.args.append(pb_arg)
|
||||
logging.debug("Scheduling %s" % task)
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
return ClientObjectRef(ticket.return_id)
|
||||
|
||||
def close(self):
|
||||
self.channel.close()
|
||||
74
python/ray_worker/__init__.py
Normal file
74
python/ray_worker/__init__.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
import rpc.ray_client_pb2 as ray_client_pb2
|
||||
import rpc.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
import queue
|
||||
import grpc
|
||||
import cloudpickle
|
||||
import threading
|
||||
import concurrent.futures
|
||||
from ray import ray
|
||||
from ray.common import ClientObjectRef
|
||||
|
||||
|
||||
class Worker:
|
||||
def __init__(self, conn_str):
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel)
|
||||
self.send_queue = None
|
||||
self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
|
||||
def begin(self):
|
||||
self.send_queue = queue.Queue()
|
||||
work_stream = self.worker_stub.Workstream(iter(self.send_queue.get, None))
|
||||
start = ray_client_pb2.WorkStatus(
|
||||
status = ray_client_pb2.WorkStatus.StatusCode.READY,
|
||||
error_msg = "python_worker",
|
||||
)
|
||||
self.send_queue.put(start)
|
||||
for work in work_stream:
|
||||
print("Got work")
|
||||
task = work.task
|
||||
if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION:
|
||||
send_queue.put(ray_client_pb2.WorkStatus(
|
||||
status = ray_client_pb2.WorkStatus.StatusCode.ERROR,
|
||||
error_msg = "unimplemented",
|
||||
))
|
||||
continue
|
||||
args = self.decode_args(task)
|
||||
func = self.get(task.payload_id)
|
||||
#self.pool.submit(self.run_and_return, func, args, work.ticket)
|
||||
t = threading.Thread(target=self.run_and_return, args=(func, args, work.ticket))
|
||||
t.start()
|
||||
|
||||
|
||||
def run_and_return(self, func, args, ticket):
|
||||
res = func(*args)
|
||||
out_data = cloudpickle.dumps(res)
|
||||
self.send_queue.put(ray_client_pb2.WorkStatus(
|
||||
status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
|
||||
complete_data = out_data,
|
||||
finished_ticket = ticket,
|
||||
))
|
||||
|
||||
# def get(self, id_bytes):
|
||||
# data = self.server.GetObject(ray_client_pb2.GetRequest(
|
||||
# id = id_bytes,
|
||||
# ))
|
||||
# return cloudpickle.loads(data.data)
|
||||
def get(self, id_bytes):
|
||||
return ray.get(ClientObjectRef(id_bytes))
|
||||
|
||||
def decode_args(self, task):
|
||||
out = []
|
||||
for arg in task.args:
|
||||
t = self.convert_from_arg(arg)
|
||||
out.append(t)
|
||||
return out
|
||||
|
||||
def convert_from_arg(self, pb):
|
||||
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
|
||||
return self.get(pb.reference_id)
|
||||
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
|
||||
return cloudpickle.loads(pb.data)
|
||||
raise Exception("convert_from_arg: Uncovered locality enum")
|
||||
|
|
@ -1 +1,7 @@
|
|||
#!/bin/sh
|
||||
python -m grpc_tools.protoc -I../../ray_rpc --python_out=. --grpc_python_out=. ../../ray_rpc/ray_client.proto
|
||||
SEDCMD="sed"
|
||||
if [ -n "`which gsed`" ]; then
|
||||
SEDCMD="gsed"
|
||||
fi
|
||||
$SEDCMD -i 's/^import ray_client_pb2/import rpc.ray_client_pb2/g' *.py
|
||||
|
|
|
|||
|
|
@ -1,81 +1,8 @@
|
|||
# from rpc import ray_client_pb2
|
||||
# from rpc import ray_client_pb2_grpc
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
import queue
|
||||
from ray_worker import Worker
|
||||
import sys
|
||||
import grpc
|
||||
import cloudpickle
|
||||
import threading
|
||||
import concurrent.futures
|
||||
from ray.experimental.client import ray
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
|
||||
|
||||
|
||||
class Worker:
|
||||
def __init__(self, conn_str):
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel)
|
||||
self.send_queue = None
|
||||
self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
|
||||
def begin(self):
|
||||
self.send_queue = queue.Queue()
|
||||
work_stream = self.worker_stub.Workstream(iter(self.send_queue.get, None))
|
||||
start = ray_client_pb2.WorkStatus(
|
||||
status = ray_client_pb2.WorkStatus.StatusCode.READY,
|
||||
error_msg = "python_worker",
|
||||
)
|
||||
self.send_queue.put(start)
|
||||
for work in work_stream:
|
||||
print("Got work")
|
||||
task = work.task
|
||||
if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION:
|
||||
send_queue.put(ray_client_pb2.WorkStatus(
|
||||
status = ray_client_pb2.WorkStatus.StatusCode.ERROR,
|
||||
error_msg = "unimplemented",
|
||||
))
|
||||
continue
|
||||
args = self.decode_args(task)
|
||||
func = self.get(task.payload_id)
|
||||
#self.pool.submit(self.run_and_return, func, args, work.ticket)
|
||||
t = threading.Thread(target=self.run_and_return, args=(func, args, work.ticket))
|
||||
t.start()
|
||||
|
||||
|
||||
def run_and_return(self, func, args, ticket):
|
||||
res = func(*args)
|
||||
out_data = cloudpickle.dumps(res)
|
||||
self.send_queue.put(ray_client_pb2.WorkStatus(
|
||||
status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
|
||||
complete_data = out_data,
|
||||
finished_ticket = ticket,
|
||||
))
|
||||
|
||||
# def get(self, id_bytes):
|
||||
# data = self.server.GetObject(ray_client_pb2.GetRequest(
|
||||
# id = id_bytes,
|
||||
# ))
|
||||
# return cloudpickle.loads(data.data)
|
||||
def get(self, id_bytes):
|
||||
return ray.get(ClientObjectRef(id_bytes))
|
||||
|
||||
def decode_args(self, task):
|
||||
out = []
|
||||
for arg in task.args:
|
||||
t = self.convert_from_arg(arg)
|
||||
out.append(t)
|
||||
return out
|
||||
|
||||
def convert_from_arg(self, pb):
|
||||
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
|
||||
return self.get(pb.reference_id)
|
||||
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
|
||||
return cloudpickle.loads(pb.data)
|
||||
raise Exception("convert_from_arg: Uncovered locality enum")
|
||||
from ray import ray
|
||||
|
||||
|
||||
def main():
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ func (r *Raylet) Workstream(conn WorkstreamConnection) error {
|
|||
worker := &SimpleWorker{
|
||||
workChan: make(chan *ray_rpc.Work),
|
||||
clientConn: conn,
|
||||
pool: wp,
|
||||
pool: r.Workers,
|
||||
}
|
||||
r.Workers.Register(worker)
|
||||
err := worker.Run()
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ func (s *SimpleWorker) AssignWork(work *ray_rpc.Work) error {
|
|||
}
|
||||
|
||||
func (s *SimpleWorker) Close() error {
|
||||
close(w.workChan)
|
||||
close(s.workChan)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ func (wp *SimpleRRWorkerPool) Register(worker Worker) error {
|
|||
wp.Lock()
|
||||
defer wp.Unlock()
|
||||
wp.workers = append(wp.workers, worker)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wp *SimpleRRWorkerPool) Schedule(work *ray_rpc.Work) error {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue