hermetically seal everything, bring it up to date
This commit is contained in:
parent
efd3d65269
commit
03beb6970c
12 changed files with 723 additions and 77 deletions
17
python/client.py
Normal file
17
python/client.py
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
from ray import ray
|
||||||
|
|
||||||
|
# Point the module-level API stub at the client server before any other
# ray.* call is made; every call below is forwarded through this connection.
ray.connect("localhost:50050")


@ray.remote
def plus2(x):
    """Return ``x + 2``; wrapped by ``ray.remote`` so it is invoked via ``.remote()``."""
    return x + 2


# Schedule plus2 remotely and block until its result is available.
print(ray.get(plus2.remote(4)))


@ray.remote
def fact(x):
    """Return the factorial of ``x`` (1 for ``x <= 0``) via nested remote calls."""
    if x <= 0:
        return 1
    # Each level schedules fact(x - 1) remotely and blocks on its result.
    return x * ray.get(fact.remote(x - 1))


print(ray.get(fact.remote(20)))
|
||||||
107
python/ray/__init__.py
Normal file
107
python/ray/__init__.py
Normal file
|
|
@ -0,0 +1,107 @@
|
||||||
|
from ray.api import ClientAPI
|
||||||
|
from ray.api import APIImpl
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# _client_api must live at module level rather than on the API stub below.
# Otherwise, a ray.remote() function that itself uses ray.remote() captures
# a reference to the RayAPIStub, and through it a reference to the
# _client_api, so pickling that function would try to pickle the API
# object as well.
|
||||||
|
_client_api: Optional[APIImpl] = None
|
||||||
|
|
||||||
|
_server_api: Optional[APIImpl] = None
|
||||||
|
|
||||||
|
_is_server: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def stash_api_for_tests(in_test: bool):
    """Temporarily switch the module into server mode for the duration of a test.

    When ``in_test`` is true, ``_is_server`` is forced to True inside the
    ``with`` body and restored to its previous value afterwards. When
    ``in_test`` is false the flag is left untouched.

    Args:
        in_test: whether to force server mode inside the ``with`` body.

    Yields:
        The current ``_server_api`` (which may be None).
    """
    global _is_server
    previous = _is_server
    if in_test:
        _is_server = True
    try:
        yield _server_api
    finally:
        # Restore even when the body raises; otherwise one failing test
        # would leave the whole module stuck in server mode.
        if in_test:
            _is_server = previous
|
||||||
|
|
||||||
|
|
||||||
|
def _set_client_api(val: Optional[APIImpl]):
    """Install ``val`` as the client API and put the module in client mode.

    Raises:
        Exception: if a client API is already installed.
    """
    global _client_api, _is_server
    already_installed = _client_api is not None
    if already_installed:
        raise Exception("Trying to set more than one client API")
    _client_api, _is_server = val, False
|
||||||
|
|
||||||
|
|
||||||
|
def _set_server_api(val: Optional[APIImpl]):
    """Install ``val`` as the server API and put the module in server mode.

    Raises:
        Exception: if a server API is already installed.
    """
    global _server_api, _is_server
    already_installed = _server_api is not None
    if already_installed:
        raise Exception("Trying to set more than one server API")
    _server_api, _is_server = val, True
|
||||||
|
|
||||||
|
|
||||||
|
def reset_api():
    """Forget both installed APIs and return the module to client mode."""
    global _client_api, _server_api, _is_server
    _client_api = _server_api = None
    _is_server = False
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client_api() -> APIImpl:
    """Return the API implementation selected by the current mode.

    The server API is chosen when ``_is_server`` is set, the client API
    otherwise.

    Raises:
        Exception: if the selected API has not been installed.
    """
    selected = _server_api if _is_server else _client_api
    if selected is not None:
        return selected
    # We're inside a raylet worker
    raise Exception("CoreRayAPI not supported")
|
||||||
|
|
||||||
|
|
||||||
|
class RayAPIStub:
    """Module-level ``ray`` object that forwards calls to the active API.

    The stub holds no state of its own: the active API lives in module
    globals and is looked up on every attribute access.
    """

    def connect(self,
                conn_str: str,
                secure: bool = False,
                metadata: List[Tuple[str, str]] = None,
                stub=None):
        """Create a client worker and install it as the client API.

        Args:
            conn_str: address passed to the worker's gRPC channel.
            secure: whether the worker should use an SSL channel.
            metadata: extra gRPC metadata sent with each request.
            stub: optional pre-built gRPC stub used instead of a channel.

        Raises:
            Exception: if a client API is already installed.
        """
        # Imported lazily, as in the original, rather than at module import.
        from ray.worker import Worker
        _client_worker = Worker(
            conn_str, secure=secure, metadata=metadata, stub=stub)
        _set_client_api(ClientAPI(_client_worker))

    def disconnect(self):
        """Close and forget the client API, if one is installed."""
        global _client_api
        if _client_api is not None:
            _client_api.close()
        _client_api = None

    def __getattr__(self, key: str):
        """Resolve ``ray.<key>`` against the currently selected API.

        Raises:
            AttributeError: for dunder names, which are never forwarded.
            Exception: if no API is installed for the current mode.
        """
        # Protocol probes (hasattr, copy, pickle) look up dunder names and
        # expect AttributeError when they are absent. Forwarding them would
        # surface the generic "CoreRayAPI not supported" Exception instead.
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        return getattr(_get_client_api(), key)
|
||||||
|
|
||||||
|
|
||||||
|
ray = RayAPIStub()
|
||||||
|
|
||||||
|
# Someday we might add methods in this module so that someone who
|
||||||
|
# tries to `import ray_client as ray` -- as a module, instead of
|
||||||
|
# `from ray_client import ray` -- as the API stub
|
||||||
|
# still gets expected functionality. This is the way the ray package
|
||||||
|
# worked in the past.
|
||||||
|
#
|
||||||
|
# This really calls for PEP 562: https://www.python.org/dev/peps/pep-0562/
|
||||||
|
# But until Python 3.6 is EOL, here we are.
|
||||||
77
python/ray/api.py
Normal file
77
python/ray/api.py
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
# This file defines an interface and client-side API stub
|
||||||
|
# for referring either to the core Ray API or the same interface
|
||||||
|
# from the Ray client.
|
||||||
|
#
|
||||||
|
# In tandem with __init__.py, we want to expose an API that's
|
||||||
|
# close to `python/ray/__init__.py` but with more than one implementation.
|
||||||
|
# The stubs in __init__ should call into a well-defined interface.
|
||||||
|
# Only the core Ray API implementation should actually `import ray`
|
||||||
|
# (and thus import all the raylet worker C bindings and such).
|
||||||
|
# But to make sure that we're matching these calls, we define this API.
|
||||||
|
|
||||||
|
from abc import ABC
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class APIImpl(ABC):
    """Abstract interface every Ray API implementation must provide."""

    @abstractmethod
    def get(self, *args, **kwargs):
        """Fetch the value(s) behind one or more references."""

    @abstractmethod
    def put(self, *args, **kwargs):
        """Store value(s) and return reference(s) to them."""

    @abstractmethod
    def wait(self, *args, **kwargs):
        """Wait on a collection of references."""

    @abstractmethod
    def remote(self, *args, **kwargs):
        """Wrap a function or class for remote execution."""

    @abstractmethod
    def call_remote(self, instance, kind: int, *args, **kwargs):
        """Schedule ``instance`` as a remote task of the given ``kind``."""

    @abstractmethod
    def get_actor_from_object(self, id):
        """Look up the actor associated with ``id``."""

    @abstractmethod
    def close(self, *args, **kwargs):
        """Release whatever resources the implementation holds."""
|
||||||
|
|
||||||
|
|
||||||
|
class ClientAPI(APIImpl):
    """APIImpl that delegates every supported call to a client worker."""

    def __init__(self, worker):
        self.worker = worker

    def get(self, *args, **kwargs):
        """Delegate to ``worker.get``."""
        worker = self.worker
        return worker.get(*args, **kwargs)

    def put(self, *args, **kwargs):
        """Delegate to ``worker.put``."""
        worker = self.worker
        return worker.put(*args, **kwargs)

    def wait(self, *args, **kwargs):
        """Delegate to ``worker.wait``."""
        worker = self.worker
        return worker.wait(*args, **kwargs)

    def remote(self, *args, **kwargs):
        """Delegate to ``worker.remote``."""
        worker = self.worker
        return worker.remote(*args, **kwargs)

    def call_remote(self, f, kind, *args, **kwargs):
        """Delegate to ``worker.call_remote``."""
        worker = self.worker
        return worker.call_remote(f, kind, *args, **kwargs)

    def get_actor_from_object(self, id):
        """Always raise: this lookup is not available on the client side."""
        raise Exception("Calling get_actor_from_object on the client side")

    def close(self, *args, **kwargs):
        """Close the worker; any arguments are accepted but not forwarded."""
        return self.worker.close()

    def __getattr__(self, key: str):
        """Handle lookups of attributes this class does not define.

        Underscore-prefixed names go through normal attribute lookup;
        every other name is reported as not implemented in the client.
        """
        if key.startswith("_"):
            return self.__getattribute__(key)
        raise NotImplementedError(
            "Not available in Ray client: `ray.{}`. This method is only "
            "available within Ray remote functions and is not yet "
            "implemented in the client API.".format(key))
|
||||||
90
python/ray/client_app.py
Normal file
90
python/ray/client_app.py
Normal file
|
|
@ -0,0 +1,90 @@
|
||||||
|
from ray import ray
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
# Point the module-level API stub at the client server before any other
# ray.* call is made; everything below is forwarded over this connection.
ray.connect("localhost:50051")


@ray.remote
class HelloActor:
    """Actor that greets callers and counts how many greetings it has sent."""

    def __init__(self):
        self.count = 0

    def say_hello(self, whom: str) -> Tuple[str, int]:
        """Return the greeting for ``whom`` together with the updated count."""
        self.count += 1
        return ("Hello " + whom, self.count)


# Instantiate the actor remotely and call it twice; the increasing count
# checks that actor state persists between method calls.
actor = HelloActor.remote()
s, count = ray.get(actor.say_hello.remote("you"))
print(s, count)
assert s == "Hello you"
assert count == 1
s, count = ray.get(actor.say_hello.remote("world"))
print(s, count)
assert s == "Hello world"
assert count == 2


@ray.remote
def plus2(x):
    """Return ``x + 2``."""
    return x + 2


@ray.remote
def fact(x):
    """Return the factorial of ``x`` (1 for ``x <= 0``) via nested remote calls."""
    print(x, type(fact))
    if x <= 0:
        return 1
    # This hits the "nested tasks" issue
    # https://github.com/ray-project/ray/issues/3644
    # So we're on the right track!
    return ray.get(fact.remote(x - 1)) * x


@ray.remote
def get_nodes():
    """Return ``ray.nodes()`` as evaluated inside the remote function."""
    return ray.nodes()  # Can access the full Ray API in remote methods.


print("Cluster nodes", ray.get(get_nodes.remote()))
# NOTE(review): this calls ray.nodes() from the client side. ClientAPI's
# __getattr__ raises NotImplementedError for public names it does not
# define, and `nodes` is not among them -- confirm this line is expected
# to succeed.
print(ray.nodes())

objectref = ray.put("hello world")

# `ClientObjectRef(...)`
print(objectref)

# `hello world`
print(ray.get(objectref))

ref2 = plus2.remote(234)
# `ClientObjectRef(...)`
print(ref2)
# `236`
print(ray.get(ref2))

ref3 = fact.remote(20)
# `ClientObjectRef(...)`
print(ref3)
# `2432902008176640000`
print(ray.get(ref3))

# Reuse the cached ClientRemoteFunc object
ref4 = fact.remote(5)
# `120`
print(ray.get(ref4))

# ref5 is never fetched with ray.get before the wait calls below.
ref5 = fact.remote(10)

print([ref2, ref3, ref4, ref5])
# should return ref2, ref3, ref4
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
print(res)
assert [ref2, ref3, ref4] == res[0]
assert [ref5] == res[1]

# should return ref2, ref3, ref4, ref5
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
print(res)
assert [ref2, ref3, ref4, ref5] == res[0]
assert [] == res[1]
|
||||||
205
python/ray/common.py
Normal file
205
python/ray/common.py
Normal file
|
|
@ -0,0 +1,205 @@
|
||||||
|
import rpc.ray_client_pb2 as ray_client_pb2
|
||||||
|
from ray import ray
|
||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
import cloudpickle
|
||||||
|
|
||||||
|
|
||||||
|
class ClientBaseRef:
    """Base class for client-side references identified by a raw ``id``.

    Two references are equal when their ids are equal, regardless of which
    ClientBaseRef subclass they are. References hash by id so they can be
    used in sets and as dict keys.
    """

    def __init__(self, id):
        self.id = id

    def __repr__(self):
        return "%s(%s)" % (
            type(self).__name__,
            self.id.hex(),
        )

    def __eq__(self, other):
        # Returning NotImplemented (rather than touching other.id) lets
        # comparisons against unrelated objects evaluate to False instead
        # of raising AttributeError.
        if not isinstance(other, ClientBaseRef):
            return NotImplemented
        return self.id == other.id

    def __hash__(self):
        # Defining __eq__ alone would set __hash__ to None and make every
        # reference unhashable; hash consistently with equality instead.
        return hash(self.id)

    def binary(self):
        """Return the raw id."""
        return self.id
|
||||||
|
|
||||||
|
|
||||||
|
class ClientObjectRef(ClientBaseRef):
    """Client-side reference to an object; adds nothing to ClientBaseRef."""
|
||||||
|
|
||||||
|
|
||||||
|
class ClientActorNameRef(ClientBaseRef):
    """Client-side reference to an actor; adds nothing to ClientBaseRef."""
|
||||||
|
|
||||||
|
|
||||||
|
class ClientRemoteFunc:
    """Client-side wrapper produced by applying ``ray.remote`` to a function.

    The wrapped function is uploaded with ``ray.put`` the first time a task
    is prepared, and the resulting reference is reused for later calls.
    """

    def __init__(self, f):
        self._func = f
        self._name = f.__name__
        self.id = None
        # Reference to the uploaded function; filled in lazily.
        self._ref = None
        # Result of ray.remote(self._func); filled in lazily.
        self._raylet_remote = None

    def __call__(self, *args, **kwargs):
        """Always raise: remote functions must be invoked via ``.remote``."""
        # The f prefix must sit on the fragment containing the placeholder;
        # otherwise "{self._name}" is emitted literally.
        raise TypeError("Remote function cannot be called directly. "
                        f"Use {self._name}.remote method instead")

    def remote(self, *args, **kwargs):
        """Schedule this function through ``ray.call_remote``."""
        return ray.call_remote(self, ray_client_pb2.ClientTask.FUNCTION, *args,
                               **kwargs)

    def _get_ray_remote_impl(self):
        """Return ``ray.remote(self._func)``, creating and caching it once."""
        if self._raylet_remote is None:
            self._raylet_remote = ray.remote(self._func)
        return self._raylet_remote

    def __repr__(self):
        return "ClientRemoteFunc(%s, %s)" % (self._name, self.id)

    def _prepare_client_task(self) -> "ray_client_pb2.ClientTask":
        """Build the FUNCTION task message, uploading the function if needed."""
        if self._ref is None:
            self._ref = ray.put(self._func)
        task = ray_client_pb2.ClientTask()
        task.type = ray_client_pb2.ClientTask.FUNCTION
        task.name = self._name
        task.payload_id = self._ref.id
        return task
|
||||||
|
|
||||||
|
|
||||||
|
class ClientActorClass:
    """Client-side wrapper produced by applying ``ray.remote`` to a class.

    The wrapped class is uploaded with ``ray.put`` the first time a task is
    prepared, and the resulting reference is reused for later calls.
    """

    def __init__(self, actor_cls):
        self.actor_cls = actor_cls
        self._name = actor_cls.__name__
        # Reference to the uploaded class; filled in lazily.
        self._ref = None
        self._raylet_remote = None

    def __call__(self, *args, **kwargs):
        """Always raise: actors must be created via ``.remote()``."""
        # The f prefix must sit on the fragment containing the placeholder;
        # otherwise "{self._name}" is emitted literally.
        raise TypeError("Remote actor cannot be instantiated directly. "
                        f"Use {self._name}.remote() instead")

    def __getstate__(self) -> Dict:
        # _raylet_remote is intentionally left out of the pickled state.
        state = {
            "actor_cls": self.actor_cls,
            "_name": self._name,
            "_ref": self._ref,
        }
        return state

    def __setstate__(self, state: Dict) -> None:
        self.actor_cls = state["actor_cls"]
        self._name = state["_name"]
        self._ref = state["_ref"]
        # Not part of the pickled state; reset it so the attribute exists
        # after unpickling instead of falling through to __getattr__.
        self._raylet_remote = None

    def remote(self, *args, **kwargs):
        """Instantiate the actor remotely and return a handle to it."""
        # Actually instantiate the actor
        ref = ray.call_remote(self, ray_client_pb2.ClientTask.ACTOR, *args,
                              **kwargs)
        return ClientActorHandle(ClientActorNameRef(ref.id), self)

    def __repr__(self):
        return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)

    def __getattr__(self, key):
        """Reject access to attributes this wrapper does not define.

        Raises:
            AttributeError: for underscore-prefixed names, so ``hasattr``
                and protocol probes behave normally.
            NotImplementedError: for public names (static methods are not
                supported).
        """
        if key.startswith("_"):
            raise AttributeError(key)
        raise NotImplementedError("static methods")

    def _prepare_client_task(self) -> "ray_client_pb2.ClientTask":
        """Build the ACTOR task message, uploading the class if needed."""
        if self._ref is None:
            self._ref = ray.put(self.actor_cls)
        task = ray_client_pb2.ClientTask()
        task.type = ray_client_pb2.ClientTask.ACTOR
        task.name = self._name
        task.payload_id = self._ref.id
        return task
|
||||||
|
|
||||||
|
|
||||||
|
class ClientActorHandle:
    """Client-side handle to an actor created through ClientActorClass."""

    def __init__(self, actor_id: ClientActorNameRef,
                 actor_class: ClientActorClass):
        self.actor_id = actor_id
        self.actor_class = actor_class
        # Filled in lazily by _get_ray_remote_impl.
        self._real_actor_handle = None

    def _get_ray_remote_impl(self):
        """Return the handle from ``ray.get_actor_from_object``, cached."""
        cached = self._real_actor_handle
        if cached is None:
            cached = ray.get_actor_from_object(self.actor_id)
            self._real_actor_handle = cached
        return cached

    def __getstate__(self) -> Dict:
        return {
            "actor_id": self.actor_id,
            "actor_class": self.actor_class,
            "_real_actor_handle": self._real_actor_handle,
        }

    def __setstate__(self, state: Dict) -> None:
        for field in ("actor_id", "actor_class", "_real_actor_handle"):
            setattr(self, field, state[field])

    def __getattr__(self, key):
        # Any attribute not found by normal lookup is treated as the name
        # of a method on the actor.
        return ClientRemoteMethod(self, key)

    def __repr__(self):
        parts = (self.actor_id, self.actor_class, self._real_actor_handle)
        return "ClientActorHandle(%s, %s, %s)" % parts
|
||||||
|
|
||||||
|
|
||||||
|
class ClientRemoteMethod:
    """A named method on an actor handle, invocable via ``.remote()``."""

    def __init__(self, actor_handle: "ClientActorHandle", method_name: str):
        self.actor_handle = actor_handle
        self.method_name = method_name

    def __call__(self, *args, **kwargs):
        """Always raise: actor methods must be invoked via ``.remote()``."""
        # The original message lacked the f prefix on this fragment and
        # referred to a nonexistent self._name; the method's name lives in
        # self.method_name.
        raise TypeError("Remote method cannot be called directly. "
                        f"Use {self.method_name}.remote() instead")

    def _get_ray_remote_impl(self):
        """Return this method looked up on the actor handle's remote impl."""
        return getattr(self.actor_handle._get_ray_remote_impl(),
                       self.method_name)

    def __getstate__(self) -> Dict:
        state = {
            "actor_handle": self.actor_handle,
            "method_name": self.method_name,
        }
        return state

    def __setstate__(self, state: Dict) -> None:
        self.actor_handle = state["actor_handle"]
        self.method_name = state["method_name"]

    def remote(self, *args, **kwargs):
        """Schedule this method call through ``ray.call_remote``."""
        return ray.call_remote(self, ray_client_pb2.ClientTask.METHOD, *args,
                               **kwargs)

    def __repr__(self):
        name = "%s.%s" % (self.actor_handle.actor_class._name,
                          self.method_name)
        return "ClientRemoteMethod(%s, %s)" % (name,
                                               self.actor_handle.actor_id)

    def _prepare_client_task(self) -> "ray_client_pb2.ClientTask":
        """Build the METHOD task message targeting the handle's actor id."""
        task = ray_client_pb2.ClientTask()
        task.type = ray_client_pb2.ClientTask.METHOD
        task.name = self.method_name
        task.payload_id = self.actor_handle.actor_id.id
        return task
|
||||||
|
|
||||||
|
|
||||||
|
def convert_from_arg(pb) -> Any:
    """Turn an ``Arg`` message back into a Python value.

    REFERENCE args become a ClientObjectRef around ``pb.reference_id``;
    INTERNED args are unpickled from ``pb.data``.

    Raises:
        Exception: if ``pb.local`` is neither REFERENCE nor INTERNED.
    """
    locality = pb.local
    localities = ray_client_pb2.Arg.Locality
    if locality == localities.REFERENCE:
        return ClientObjectRef(pb.reference_id)
    if locality == localities.INTERNED:
        return cloudpickle.loads(pb.data)
    raise Exception("convert_from_arg: Uncovered locality enum")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_arg(val):
    """Encode ``val`` as an ``Arg`` message.

    A ClientObjectRef is sent by reference (its id only); any other value
    is pickled into the message.
    """
    arg = ray_client_pb2.Arg()
    if not isinstance(val, ClientObjectRef):
        arg.local = ray_client_pb2.Arg.Locality.INTERNED
        arg.data = cloudpickle.dumps(val)
        return arg
    arg.local = ray_client_pb2.Arg.Locality.REFERENCE
    arg.reference_id = val.id
    return arg
|
||||||
142
python/ray/worker.py
Normal file
142
python/ray/worker.py
Normal file
|
|
@ -0,0 +1,142 @@
|
||||||
|
"""This file includes the Worker class which sits on the client side.
|
||||||
|
It implements the Ray API functions that are forwarded through grpc calls
|
||||||
|
to the server.
|
||||||
|
"""
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import cloudpickle
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
import rpc.ray_client_pb2 as ray_client_pb2
|
||||||
|
import rpc.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||||
|
from ray.common import convert_to_arg
|
||||||
|
from ray.common import ClientObjectRef
|
||||||
|
from ray.common import ClientActorClass
|
||||||
|
from ray.common import ClientRemoteFunc
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Worker:
    """Client-side worker that forwards Ray API calls to the server via gRPC."""

    def __init__(self,
                 conn_str: str = "",
                 secure: bool = False,
                 metadata: List[Tuple[str, str]] = None,
                 stub=None):
        """Initializes the worker side grpc client.

        Args:
            conn_str: address to open the gRPC channel to; unused when
                ``stub`` is given.
            stub: custom grpc stub.
            secure: whether to use SSL secure channel or not.
            metadata: additional metadata passed in the grpc request headers.
        """
        self.metadata = metadata
        # Only set when this worker opens its own channel; stays None when
        # a stub is injected so that close() has nothing to tear down.
        self.channel = None
        if stub is None:
            if secure:
                credentials = grpc.ssl_channel_credentials()
                self.channel = grpc.secure_channel(conn_str, credentials)
            else:
                self.channel = grpc.insecure_channel(conn_str)
            self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
        else:
            self.server = stub

    def get(self, ids):
        """Fetch the value for one reference, or a list of values for a list.

        Raises:
            Exception: if ``ids`` is neither a list nor a ClientObjectRef.
        """
        if isinstance(ids, list):
            # Collect every id before issuing the first request.
            object_ids = [x.id for x in ids]
            return [self._get(object_id) for object_id in object_ids]
        if isinstance(ids, ClientObjectRef):
            return self._get(ids.id)
        raise Exception("Can't get something that's not a "
                        "list of IDs or just an ID: %s" % type(ids))

    def _get(self, id: bytes):
        """Fetch and unpickle a single object by its raw id."""
        req = ray_client_pb2.GetRequest(id=id)
        data = self.server.GetObject(req, metadata=self.metadata)
        if not data.valid:
            raise Exception(
                "Client GetObject returned invalid data: id invalid?")
        return cloudpickle.loads(data.data)

    def put(self, vals):
        """Store a value and return its reference.

        A list argument is treated as several values and yields a list of
        references, one per element; anything else yields one reference.
        """
        if isinstance(vals, list):
            return [self._put(x) for x in vals]
        return self._put(vals)

    def _put(self, val):
        """Pickle and upload a single value, returning its reference."""
        data = cloudpickle.dumps(val)
        req = ray_client_pb2.PutRequest(data=data)
        resp = self.server.PutObject(req, metadata=self.metadata)
        return ClientObjectRef(resp.id)

    def wait(self,
             object_refs: List["ClientObjectRef"],
             *,
             num_returns: int = 1,
             timeout: float = None
             ) -> Tuple[List["ClientObjectRef"], List["ClientObjectRef"]]:
        """Ask the server to wait on ``object_refs``.

        Args:
            object_refs: references to wait on; must be a list.
            num_returns: forwarded to the server as ``num_returns``.
            timeout: forwarded as-is; ``None`` is sent as -1.

        Returns:
            A ``(ready, remaining)`` pair of reference lists.

        Raises:
            Exception: if the server marks the response invalid.
        """
        assert isinstance(object_refs, list)
        for ref in object_refs:
            assert isinstance(ref, ClientObjectRef)
        data = {
            "object_refs": [
                cloudpickle.dumps(object_ref) for object_ref in object_refs
            ],
            "num_returns": num_returns,
            # Compare against None explicitly: a truthiness test would turn
            # an explicit timeout of 0 into the -1 "no timeout" sentinel.
            "timeout": timeout if timeout is not None else -1
        }
        req = ray_client_pb2.WaitRequest(**data)
        resp = self.server.WaitObject(req, metadata=self.metadata)
        if not resp.valid:
            # TODO(ameer): improve error/exceptions messages.
            raise Exception("Client Wait request failed. Reference invalid?")
        client_ready_object_ids = [
            ClientObjectRef(id) for id in resp.ready_object_ids
        ]
        client_remaining_object_ids = [
            ClientObjectRef(id) for id in resp.remaining_object_ids
        ]

        return (client_ready_object_ids, client_remaining_object_ids)

    def remote(self, function_or_class, *args, **kwargs):
        """Wrap a function or class in its client-side remote wrapper.

        Raises:
            TypeError: if the argument is neither a function nor a class.
        """
        # TODO(barakmich): Arguments to ray.remote
        # get captured here.
        if inspect.isfunction(function_or_class):
            return ClientRemoteFunc(function_or_class)
        elif inspect.isclass(function_or_class):
            return ClientActorClass(function_or_class)
        else:
            raise TypeError("The @ray.remote decorator must be applied to "
                            "either a function or to a class.")

    def call_remote(self, instance, kind, *args, **kwargs):
        """Schedule ``instance`` on the server and return the result reference.

        Only positional ``args`` are encoded into the task; ``kind`` and
        ``kwargs`` are accepted but not used here.
        """
        task = instance._prepare_client_task()
        for arg in args:
            pb_arg = convert_to_arg(arg)
            task.args.append(pb_arg)
        # Use the module logger with lazy %-formatting so the task is only
        # rendered when debug logging is enabled.
        logger.debug("Scheduling %s", task)
        ticket = self.server.Schedule(task, metadata=self.metadata)
        return ClientObjectRef(ticket.return_id)

    def close(self):
        """Close the gRPC channel if this worker opened one."""
        if self.channel is not None:
            self.channel.close()
            self.channel = None
|
||||||
74
python/ray_worker/__init__.py
Normal file
74
python/ray_worker/__init__.py
Normal file
|
|
@ -0,0 +1,74 @@
|
||||||
|
import rpc.ray_client_pb2 as ray_client_pb2
|
||||||
|
import rpc.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||||
|
|
||||||
|
import queue
|
||||||
|
import grpc
|
||||||
|
import cloudpickle
|
||||||
|
import threading
|
||||||
|
import concurrent.futures
|
||||||
|
from ray import ray
|
||||||
|
from ray.common import ClientObjectRef
|
||||||
|
|
||||||
|
|
||||||
|
class Worker:
    """Python worker that pulls tasks from a work stream and runs them."""

    def __init__(self, conn_str):
        self.channel = grpc.insecure_channel(conn_str)
        self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel)
        # Created in begin(); statuses put here are streamed to the server.
        self.send_queue = None
        self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=3)
        self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)

    def begin(self):
        """Announce readiness, then process incoming work until the stream ends.

        Only FUNCTION tasks are executed; any other task type is answered
        with an ERROR status. Each accepted task runs on its own thread.
        """
        self.send_queue = queue.Queue()
        # The request iterator blocks on the queue and stops when None is put.
        work_stream = self.worker_stub.Workstream(iter(self.send_queue.get, None))
        start = ray_client_pb2.WorkStatus(
            status=ray_client_pb2.WorkStatus.StatusCode.READY,
            error_msg="python_worker",
        )
        self.send_queue.put(start)
        for work in work_stream:
            print("Got work")
            task = work.task
            if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION:
                # Must go through self.send_queue; a bare `send_queue` is
                # an undefined name and would raise NameError here.
                self.send_queue.put(ray_client_pb2.WorkStatus(
                    status=ray_client_pb2.WorkStatus.StatusCode.ERROR,
                    error_msg="unimplemented",
                ))
                continue
            args = self.decode_args(task)
            func = self.get(task.payload_id)
            # A dedicated thread is used per task; self.pool is not used
            # for scheduling here.
            t = threading.Thread(target=self.run_and_return, args=(func, args, work.ticket))
            t.start()

    def run_and_return(self, func, args, ticket):
        """Run ``func(*args)`` and queue a COMPLETE status with its result.

        NOTE(review): an exception raised by ``func`` propagates out of
        the thread and no status is queued for ``ticket``.
        """
        res = func(*args)
        out_data = cloudpickle.dumps(res)
        self.send_queue.put(ray_client_pb2.WorkStatus(
            status=ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
            complete_data=out_data,
            finished_ticket=ticket,
        ))

    def get(self, id_bytes):
        """Fetch the object with raw id ``id_bytes`` through ``ray.get``."""
        return ray.get(ClientObjectRef(id_bytes))

    def decode_args(self, task):
        """Return the task's arguments as a list of Python values."""
        return [self.convert_from_arg(arg) for arg in task.args]

    def convert_from_arg(self, pb):
        """Decode one ``Arg``: fetch REFERENCE args, unpickle INTERNED ones.

        Raises:
            Exception: if ``pb.local`` is neither REFERENCE nor INTERNED.
        """
        if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
            return self.get(pb.reference_id)
        elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
            return cloudpickle.loads(pb.data)
        raise Exception("convert_from_arg: Uncovered locality enum")
|
||||||
|
|
@ -1 +1,7 @@
|
||||||
|
#!/bin/sh
|
||||||
python -m grpc_tools.protoc -I../../ray_rpc --python_out=. --grpc_python_out=. ../../ray_rpc/ray_client.proto
|
python -m grpc_tools.protoc -I../../ray_rpc --python_out=. --grpc_python_out=. ../../ray_rpc/ray_client.proto
|
||||||
|
SEDCMD="sed"
|
||||||
|
if [ -n "`which gsed`" ]; then
|
||||||
|
SEDCMD="gsed"
|
||||||
|
fi
|
||||||
|
$SEDCMD -i 's/^import ray_client_pb2/import rpc.ray_client_pb2/g' *.py
|
||||||
|
|
|
||||||
|
|
@ -1,81 +1,8 @@
|
||||||
# from rpc import ray_client_pb2
|
# from rpc import ray_client_pb2
|
||||||
# from rpc import ray_client_pb2_grpc
|
# from rpc import ray_client_pb2_grpc
|
||||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
from ray_worker import Worker
|
||||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
|
||||||
|
|
||||||
import queue
|
|
||||||
import sys
|
import sys
|
||||||
import grpc
|
from ray import ray
|
||||||
import cloudpickle
|
|
||||||
import threading
|
|
||||||
import concurrent.futures
|
|
||||||
from ray.experimental.client import ray
|
|
||||||
from ray.experimental.client.common import ClientObjectRef
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
|
||||||
def __init__(self, conn_str):
|
|
||||||
self.channel = grpc.insecure_channel(conn_str)
|
|
||||||
self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel)
|
|
||||||
self.send_queue = None
|
|
||||||
self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
|
||||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
|
||||||
|
|
||||||
def begin(self):
|
|
||||||
self.send_queue = queue.Queue()
|
|
||||||
work_stream = self.worker_stub.Workstream(iter(self.send_queue.get, None))
|
|
||||||
start = ray_client_pb2.WorkStatus(
|
|
||||||
status = ray_client_pb2.WorkStatus.StatusCode.READY,
|
|
||||||
error_msg = "python_worker",
|
|
||||||
)
|
|
||||||
self.send_queue.put(start)
|
|
||||||
for work in work_stream:
|
|
||||||
print("Got work")
|
|
||||||
task = work.task
|
|
||||||
if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION:
|
|
||||||
send_queue.put(ray_client_pb2.WorkStatus(
|
|
||||||
status = ray_client_pb2.WorkStatus.StatusCode.ERROR,
|
|
||||||
error_msg = "unimplemented",
|
|
||||||
))
|
|
||||||
continue
|
|
||||||
args = self.decode_args(task)
|
|
||||||
func = self.get(task.payload_id)
|
|
||||||
#self.pool.submit(self.run_and_return, func, args, work.ticket)
|
|
||||||
t = threading.Thread(target=self.run_and_return, args=(func, args, work.ticket))
|
|
||||||
t.start()
|
|
||||||
|
|
||||||
|
|
||||||
def run_and_return(self, func, args, ticket):
|
|
||||||
res = func(*args)
|
|
||||||
out_data = cloudpickle.dumps(res)
|
|
||||||
self.send_queue.put(ray_client_pb2.WorkStatus(
|
|
||||||
status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
|
|
||||||
complete_data = out_data,
|
|
||||||
finished_ticket = ticket,
|
|
||||||
))
|
|
||||||
|
|
||||||
# def get(self, id_bytes):
|
|
||||||
# data = self.server.GetObject(ray_client_pb2.GetRequest(
|
|
||||||
# id = id_bytes,
|
|
||||||
# ))
|
|
||||||
# return cloudpickle.loads(data.data)
|
|
||||||
def get(self, id_bytes):
|
|
||||||
return ray.get(ClientObjectRef(id_bytes))
|
|
||||||
|
|
||||||
def decode_args(self, task):
|
|
||||||
out = []
|
|
||||||
for arg in task.args:
|
|
||||||
t = self.convert_from_arg(arg)
|
|
||||||
out.append(t)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def convert_from_arg(self, pb):
|
|
||||||
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
|
|
||||||
return self.get(pb.reference_id)
|
|
||||||
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
|
|
||||||
return cloudpickle.loads(pb.data)
|
|
||||||
raise Exception("convert_from_arg: Uncovered locality enum")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ func (r *Raylet) Workstream(conn WorkstreamConnection) error {
|
||||||
worker := &SimpleWorker{
|
worker := &SimpleWorker{
|
||||||
workChan: make(chan *ray_rpc.Work),
|
workChan: make(chan *ray_rpc.Work),
|
||||||
clientConn: conn,
|
clientConn: conn,
|
||||||
pool: wp,
|
pool: r.Workers,
|
||||||
}
|
}
|
||||||
r.Workers.Register(worker)
|
r.Workers.Register(worker)
|
||||||
err := worker.Run()
|
err := worker.Run()
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ func (s *SimpleWorker) AssignWork(work *ray_rpc.Work) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SimpleWorker) Close() error {
|
func (s *SimpleWorker) Close() error {
|
||||||
close(w.workChan)
|
close(s.workChan)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ func (wp *SimpleRRWorkerPool) Register(worker Worker) error {
|
||||||
wp.Lock()
|
wp.Lock()
|
||||||
defer wp.Unlock()
|
defer wp.Unlock()
|
||||||
wp.workers = append(wp.workers, worker)
|
wp.workers = append(wp.workers, worker)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wp *SimpleRRWorkerPool) Schedule(work *ray_rpc.Work) error {
|
func (wp *SimpleRRWorkerPool) Schedule(work *ray_rpc.Work) error {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue