hermetically seal everything, bring it up to date

This commit is contained in:
Barak Michener 2020-12-04 13:00:54 -08:00
parent efd3d65269
commit 03beb6970c
12 changed files with 723 additions and 77 deletions

17
python/client.py Normal file
View file

@ -0,0 +1,17 @@
from ray import ray
ray.connect("localhost:50050")
@ray.remote
def plus2(x):
return x + 2
print(ray.get(plus2.remote(4)))
@ray.remote
def fact(x):
if x <= 0:
return 1
return x * ray.get(fact.remote(x - 1))
print(ray.get(fact.remote(20)))

107
python/ray/__init__.py Normal file
View file

@ -0,0 +1,107 @@
from ray.api import ClientAPI
from ray.api import APIImpl
from typing import Optional, List, Tuple
from contextlib import contextmanager
import logging
logger = logging.getLogger(__name__)
# _client_api has to be external to the API stub, below.
# Otherwise, ray.remote() that contains ray.remote()
# contains a reference to the RayAPIStub, therefore a
# reference to the _client_api, and then tries to pickle
# the thing.
_client_api: Optional[APIImpl] = None
_server_api: Optional[APIImpl] = None
_is_server: bool = False
@contextmanager
def stash_api_for_tests(in_test: bool):
global _is_server
is_server = _is_server
if in_test:
_is_server = True
yield _server_api
if in_test:
_is_server = is_server
def _set_client_api(val: Optional[APIImpl]):
global _client_api
global _is_server
if _client_api is not None:
raise Exception("Trying to set more than one client API")
_client_api = val
_is_server = False
def _set_server_api(val: Optional[APIImpl]):
global _server_api
global _is_server
if _server_api is not None:
raise Exception("Trying to set more than one server API")
_server_api = val
_is_server = True
def reset_api():
global _client_api
global _server_api
global _is_server
_client_api = None
_server_api = None
_is_server = False
def _get_client_api() -> APIImpl:
global _client_api
global _server_api
global _is_server
api = None
if _is_server:
api = _server_api
else:
api = _client_api
if api is None:
# We're inside a raylet worker
raise Exception("CoreRayAPI not supported")
return api
class RayAPIStub:
def connect(self,
conn_str: str,
secure: bool = False,
metadata: List[Tuple[str, str]] = None,
stub=None):
from ray.worker import Worker
_client_worker = Worker(
conn_str, secure=secure, metadata=metadata, stub=stub)
_set_client_api(ClientAPI(_client_worker))
def disconnect(self):
global _client_api
if _client_api is not None:
_client_api.close()
_client_api = None
def __getattr__(self, key: str):
global _get_client_api
api = _get_client_api()
return getattr(api, key)
ray = RayAPIStub()
# Someday we might add methods in this module so that someone who
# tries to `import ray_client as ray` -- as a module, instead of
# `from ray_client import ray` -- as the API stub
# still gets expected functionality. This is the way the ray package
# worked in the past.
#
# This really calls for PEP 562: https://www.python.org/dev/peps/pep-0562/
# But until Python 3.6 is EOL, here we are.

77
python/ray/api.py Normal file
View file

@ -0,0 +1,77 @@
# This file defines an interface and client-side API stub
# for referring either to the core Ray API or the same interface
# from the Ray client.
#
# In tandem with __init__.py, we want to expose an API that's
# close to `python/ray/__init__.py` but with more than one implementation.
# The stubs in __init__ should call into a well-defined interface.
# Only the core Ray API implementation should actually `import ray`
# (and thus import all the raylet worker C bindings and such).
# But to make sure that we're matching these calls, we define this API.
from abc import ABC
from abc import abstractmethod
class APIImpl(ABC):
@abstractmethod
def get(self, *args, **kwargs):
pass
@abstractmethod
def put(self, *args, **kwargs):
pass
@abstractmethod
def wait(self, *args, **kwargs):
pass
@abstractmethod
def remote(self, *args, **kwargs):
pass
@abstractmethod
def call_remote(self, instance, kind: int, *args, **kwargs):
pass
@abstractmethod
def get_actor_from_object(self, id):
pass
@abstractmethod
def close(self, *args, **kwargs):
pass
class ClientAPI(APIImpl):
def __init__(self, worker):
self.worker = worker
def get(self, *args, **kwargs):
return self.worker.get(*args, **kwargs)
def put(self, *args, **kwargs):
return self.worker.put(*args, **kwargs)
def wait(self, *args, **kwargs):
return self.worker.wait(*args, **kwargs)
def remote(self, *args, **kwargs):
return self.worker.remote(*args, **kwargs)
def call_remote(self, f, kind, *args, **kwargs):
return self.worker.call_remote(f, kind, *args, **kwargs)
def get_actor_from_object(self, id):
raise Exception("Calling get_actor_from_object on the client side")
def close(self, *args, **kwargs):
return self.worker.close()
def __getattr__(self, key: str):
if not key.startswith("_"):
raise NotImplementedError(
"Not available in Ray client: `ray.{}`. This method is only "
"available within Ray remote functions and is not yet "
"implemented in the client API.".format(key))
return self.__getattribute__(key)

90
python/ray/client_app.py Normal file
View file

@ -0,0 +1,90 @@
from ray import ray
from typing import Tuple
ray.connect("localhost:50051")
@ray.remote
class HelloActor:
def __init__(self):
self.count = 0
def say_hello(self, whom: str) -> Tuple[str, int]:
self.count += 1
return ("Hello " + whom, self.count)
actor = HelloActor.remote()
s, count = ray.get(actor.say_hello.remote("you"))
print(s, count)
assert s == "Hello you"
assert count == 1
s, count = ray.get(actor.say_hello.remote("world"))
print(s, count)
assert s == "Hello world"
assert count == 2
@ray.remote
def plus2(x):
return x + 2
@ray.remote
def fact(x):
print(x, type(fact))
if x <= 0:
return 1
# This hits the "nested tasks" issue
# https://github.com/ray-project/ray/issues/3644
# So we're on the right track!
return ray.get(fact.remote(x - 1)) * x
@ray.remote
def get_nodes():
return ray.nodes() # Can access the full Ray API in remote methods.
print("Cluster nodes", ray.get(get_nodes.remote()))
print(ray.nodes())
objectref = ray.put("hello world")
# `ClientObjectRef(...)`
print(objectref)
# `hello world`
print(ray.get(objectref))
ref2 = plus2.remote(234)
# `ClientObjectRef(...)`
print(ref2)
# `236`
print(ray.get(ref2))
ref3 = fact.remote(20)
# `ClientObjectRef(...)`
print(ref3)
# `2432902008176640000`
print(ray.get(ref3))
# Reuse the cached ClientRemoteFunc object
ref4 = fact.remote(5)
# `120`
print(ray.get(ref4))
ref5 = fact.remote(10)
print([ref2, ref3, ref4, ref5])
# should return ref2, ref3, ref4
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
print(res)
assert [ref2, ref3, ref4] == res[0]
assert [ref5] == res[1]
# should return ref2, ref3, ref4, ref5
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
print(res)
assert [ref2, ref3, ref4, ref5] == res[0]
assert [] == res[1]

205
python/ray/common.py Normal file
View file

@ -0,0 +1,205 @@
import rpc.ray_client_pb2 as ray_client_pb2
from ray import ray
from typing import Any
from typing import Dict
import cloudpickle
class ClientBaseRef:
def __init__(self, id):
self.id = id
def __repr__(self):
return "%s(%s)" % (
type(self).__name__,
self.id.hex(),
)
def __eq__(self, other):
return self.id == other.id
def binary(self):
return self.id
class ClientObjectRef(ClientBaseRef):
pass
class ClientActorNameRef(ClientBaseRef):
pass
class ClientRemoteFunc:
def __init__(self, f):
self._func = f
self._name = f.__name__
self.id = None
self._ref = None
self._raylet_remote = None
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote function cannot be called directly. "
"Use {self._name}.remote method instead")
def remote(self, *args, **kwargs):
return ray.call_remote(self, ray_client_pb2.ClientTask.FUNCTION, *args,
**kwargs)
def _get_ray_remote_impl(self):
if self._raylet_remote is None:
self._raylet_remote = ray.remote(self._func)
return self._raylet_remote
def __repr__(self):
return "ClientRemoteFunc(%s, %s)" % (self._name, self.id)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
if self._ref is None:
self._ref = ray.put(self._func)
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.FUNCTION
task.name = self._name
task.payload_id = self._ref.id
return task
class ClientActorClass:
def __init__(self, actor_cls):
self.actor_cls = actor_cls
self._name = actor_cls.__name__
self._ref = None
self._raylet_remote = None
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote actor cannot be instantiated directly. "
"Use {self._name}.remote() instead")
def __getstate__(self) -> Dict:
state = {
"actor_cls": self.actor_cls,
"_name": self._name,
"_ref": self._ref,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_cls = state["actor_cls"]
self._name = state["_name"]
self._ref = state["_ref"]
def remote(self, *args, **kwargs):
# Actually instantiate the actor
ref = ray.call_remote(self, ray_client_pb2.ClientTask.ACTOR, *args,
**kwargs)
return ClientActorHandle(ClientActorNameRef(ref.id), self)
def __repr__(self):
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
def __getattr__(self, key):
raise NotImplementedError("static methods")
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
if self._ref is None:
self._ref = ray.put(self.actor_cls)
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.ACTOR
task.name = self._name
task.payload_id = self._ref.id
return task
class ClientActorHandle:
def __init__(self, actor_id: ClientActorNameRef,
actor_class: ClientActorClass):
self.actor_id = actor_id
self.actor_class = actor_class
self._real_actor_handle = None
def _get_ray_remote_impl(self):
if self._real_actor_handle is None:
self._real_actor_handle = ray.get_actor_from_object(self.actor_id)
return self._real_actor_handle
def __getstate__(self) -> Dict:
state = {
"actor_id": self.actor_id,
"actor_class": self.actor_class,
"_real_actor_handle": self._real_actor_handle,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_id = state["actor_id"]
self.actor_class = state["actor_class"]
self._real_actor_handle = state["_real_actor_handle"]
def __getattr__(self, key):
return ClientRemoteMethod(self, key)
def __repr__(self):
return "ClientActorHandle(%s, %s, %s)" % (
self.actor_id, self.actor_class, self._real_actor_handle)
class ClientRemoteMethod:
def __init__(self, actor_handle: ClientActorHandle, method_name: str):
self.actor_handle = actor_handle
self.method_name = method_name
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote method cannot be called directly. "
"Use {self._name}.remote() instead")
def _get_ray_remote_impl(self):
return getattr(self.actor_handle._get_ray_remote_impl(),
self.method_name)
def __getstate__(self) -> Dict:
state = {
"actor_handle": self.actor_handle,
"method_name": self.method_name,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_handle = state["actor_handle"]
self.method_name = state["method_name"]
def remote(self, *args, **kwargs):
return ray.call_remote(self, ray_client_pb2.ClientTask.METHOD, *args,
**kwargs)
def __repr__(self):
name = "%s.%s" % (self.actor_handle.actor_class._name,
self.method_name)
return "ClientRemoteMethod(%s, %s)" % (name,
self.actor_handle.actor_id)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.METHOD
task.name = self.method_name
task.payload_id = self.actor_handle.actor_id.id
return task
def convert_from_arg(pb) -> Any:
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
return ClientObjectRef(pb.reference_id)
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
return cloudpickle.loads(pb.data)
raise Exception("convert_from_arg: Uncovered locality enum")
def convert_to_arg(val):
out = ray_client_pb2.Arg()
if isinstance(val, ClientObjectRef):
out.local = ray_client_pb2.Arg.Locality.REFERENCE
out.reference_id = val.id
else:
out.local = ray_client_pb2.Arg.Locality.INTERNED
out.data = cloudpickle.dumps(val)
return out

142
python/ray/worker.py Normal file
View file

@ -0,0 +1,142 @@
"""This file includes the Worker class which sits on the client side.
It implements the Ray API functions that are forwarded through grpc calls
to the server.
"""
import inspect
import logging
from typing import List
from typing import Tuple
import cloudpickle
import grpc
import rpc.ray_client_pb2 as ray_client_pb2
import rpc.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.common import convert_to_arg
from ray.common import ClientObjectRef
from ray.common import ClientActorClass
from ray.common import ClientRemoteFunc
logger = logging.getLogger(__name__)
class Worker:
def __init__(self,
conn_str: str = "",
secure: bool = False,
metadata: List[Tuple[str, str]] = None,
stub=None):
"""Initializes the worker side grpc client.
Args:
stub: custom grpc stub.
secure: whether to use SSL secure channel or not.
metadata: additional metadata passed in the grpc request headers.
"""
self.metadata = metadata
if stub is None:
if secure:
credentials = grpc.ssl_channel_credentials()
self.channel = grpc.secure_channel(conn_str, credentials)
else:
self.channel = grpc.insecure_channel(conn_str)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
else:
self.server = stub
def get(self, ids):
to_get = []
single = False
if isinstance(ids, list):
to_get = [x.id for x in ids]
elif isinstance(ids, ClientObjectRef):
to_get = [ids.id]
single = True
else:
raise Exception("Can't get something that's not a "
"list of IDs or just an ID: %s" % type(ids))
out = [self._get(x) for x in to_get]
if single:
out = out[0]
return out
def _get(self, id: bytes):
req = ray_client_pb2.GetRequest(id=id)
data = self.server.GetObject(req, metadata=self.metadata)
if not data.valid:
raise Exception(
"Client GetObject returned invalid data: id invalid?")
return cloudpickle.loads(data.data)
def put(self, vals):
to_put = []
single = False
if isinstance(vals, list):
to_put = vals
else:
single = True
to_put.append(vals)
out = [self._put(x) for x in to_put]
if single:
out = out[0]
return out
def _put(self, val):
data = cloudpickle.dumps(val)
req = ray_client_pb2.PutRequest(data=data)
resp = self.server.PutObject(req, metadata=self.metadata)
return ClientObjectRef(resp.id)
def wait(self,
object_refs: List[ClientObjectRef],
*,
num_returns: int = 1,
timeout: float = None
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
assert isinstance(object_refs, list)
for ref in object_refs:
assert isinstance(ref, ClientObjectRef)
data = {
"object_refs": [
cloudpickle.dumps(object_ref) for object_ref in object_refs
],
"num_returns": num_returns,
"timeout": timeout if timeout else -1
}
req = ray_client_pb2.WaitRequest(**data)
resp = self.server.WaitObject(req, metadata=self.metadata)
if not resp.valid:
# TODO(ameer): improve error/exceptions messages.
raise Exception("Client Wait request failed. Reference invalid?")
client_ready_object_ids = [
ClientObjectRef(id) for id in resp.ready_object_ids
]
client_remaining_object_ids = [
ClientObjectRef(id) for id in resp.remaining_object_ids
]
return (client_ready_object_ids, client_remaining_object_ids)
def remote(self, function_or_class, *args, **kwargs):
# TODO(barakmich): Arguments to ray.remote
# get captured here.
if inspect.isfunction(function_or_class):
return ClientRemoteFunc(function_or_class)
elif inspect.isclass(function_or_class):
return ClientActorClass(function_or_class)
else:
raise TypeError("The @ray.remote decorator must be applied to "
"either a function or to a class.")
def call_remote(self, instance, kind, *args, **kwargs):
task = instance._prepare_client_task()
for arg in args:
pb_arg = convert_to_arg(arg)
task.args.append(pb_arg)
logging.debug("Scheduling %s" % task)
ticket = self.server.Schedule(task, metadata=self.metadata)
return ClientObjectRef(ticket.return_id)
def close(self):
self.channel.close()

View file

@ -0,0 +1,74 @@
import rpc.ray_client_pb2 as ray_client_pb2
import rpc.ray_client_pb2_grpc as ray_client_pb2_grpc
import queue
import grpc
import cloudpickle
import threading
import concurrent.futures
from ray import ray
from ray.common import ClientObjectRef
class Worker:
def __init__(self, conn_str):
self.channel = grpc.insecure_channel(conn_str)
self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel)
self.send_queue = None
self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=3)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
def begin(self):
self.send_queue = queue.Queue()
work_stream = self.worker_stub.Workstream(iter(self.send_queue.get, None))
start = ray_client_pb2.WorkStatus(
status = ray_client_pb2.WorkStatus.StatusCode.READY,
error_msg = "python_worker",
)
self.send_queue.put(start)
for work in work_stream:
print("Got work")
task = work.task
if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION:
send_queue.put(ray_client_pb2.WorkStatus(
status = ray_client_pb2.WorkStatus.StatusCode.ERROR,
error_msg = "unimplemented",
))
continue
args = self.decode_args(task)
func = self.get(task.payload_id)
#self.pool.submit(self.run_and_return, func, args, work.ticket)
t = threading.Thread(target=self.run_and_return, args=(func, args, work.ticket))
t.start()
def run_and_return(self, func, args, ticket):
res = func(*args)
out_data = cloudpickle.dumps(res)
self.send_queue.put(ray_client_pb2.WorkStatus(
status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
complete_data = out_data,
finished_ticket = ticket,
))
# def get(self, id_bytes):
# data = self.server.GetObject(ray_client_pb2.GetRequest(
# id = id_bytes,
# ))
# return cloudpickle.loads(data.data)
def get(self, id_bytes):
return ray.get(ClientObjectRef(id_bytes))
def decode_args(self, task):
out = []
for arg in task.args:
t = self.convert_from_arg(arg)
out.append(t)
return out
def convert_from_arg(self, pb):
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
return self.get(pb.reference_id)
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
return cloudpickle.loads(pb.data)
raise Exception("convert_from_arg: Uncovered locality enum")

View file

@ -1 +1,7 @@
#!/bin/sh
python -m grpc_tools.protoc -I../../ray_rpc --python_out=. --grpc_python_out=. ../../ray_rpc/ray_client.proto
SEDCMD="sed"
if [ -n "`which gsed`" ]; then
SEDCMD="gsed"
fi
$SEDCMD -i 's/^import ray_client_pb2/import rpc.ray_client_pb2/g' *.py

View file

@ -1,81 +1,8 @@
# from rpc import ray_client_pb2
# from rpc import ray_client_pb2_grpc
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import queue
from ray_worker import Worker
import sys
import grpc
import cloudpickle
import threading
import concurrent.futures
from ray.experimental.client import ray
from ray.experimental.client.common import ClientObjectRef
class Worker:
def __init__(self, conn_str):
self.channel = grpc.insecure_channel(conn_str)
self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel)
self.send_queue = None
self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=3)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
def begin(self):
self.send_queue = queue.Queue()
work_stream = self.worker_stub.Workstream(iter(self.send_queue.get, None))
start = ray_client_pb2.WorkStatus(
status = ray_client_pb2.WorkStatus.StatusCode.READY,
error_msg = "python_worker",
)
self.send_queue.put(start)
for work in work_stream:
print("Got work")
task = work.task
if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION:
send_queue.put(ray_client_pb2.WorkStatus(
status = ray_client_pb2.WorkStatus.StatusCode.ERROR,
error_msg = "unimplemented",
))
continue
args = self.decode_args(task)
func = self.get(task.payload_id)
#self.pool.submit(self.run_and_return, func, args, work.ticket)
t = threading.Thread(target=self.run_and_return, args=(func, args, work.ticket))
t.start()
def run_and_return(self, func, args, ticket):
res = func(*args)
out_data = cloudpickle.dumps(res)
self.send_queue.put(ray_client_pb2.WorkStatus(
status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
complete_data = out_data,
finished_ticket = ticket,
))
# def get(self, id_bytes):
# data = self.server.GetObject(ray_client_pb2.GetRequest(
# id = id_bytes,
# ))
# return cloudpickle.loads(data.data)
def get(self, id_bytes):
return ray.get(ClientObjectRef(id_bytes))
def decode_args(self, task):
out = []
for arg in task.args:
t = self.convert_from_arg(arg)
out.append(t)
return out
def convert_from_arg(self, pb):
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
return self.get(pb.reference_id)
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
return cloudpickle.loads(pb.data)
raise Exception("convert_from_arg: Uncovered locality enum")
from ray import ray
def main():

View file

@ -58,7 +58,7 @@ func (r *Raylet) Workstream(conn WorkstreamConnection) error {
worker := &SimpleWorker{
workChan: make(chan *ray_rpc.Work),
clientConn: conn,
pool: wp,
pool: r.Workers,
}
r.Workers.Register(worker)
err := worker.Run()

View file

@ -27,7 +27,7 @@ func (s *SimpleWorker) AssignWork(work *ray_rpc.Work) error {
}
func (s *SimpleWorker) Close() error {
close(w.workChan)
close(s.workChan)
return nil
}

View file

@ -34,6 +34,7 @@ func (wp *SimpleRRWorkerPool) Register(worker Worker) error {
wp.Lock()
defer wp.Unlock()
wp.workers = append(wp.workers, worker)
return nil
}
func (wp *SimpleRRWorkerPool) Schedule(work *ray_rpc.Work) error {