88 lines
2.5 KiB
Python
88 lines
2.5 KiB
Python
"""This file includes the Worker class which sits on the client side.
|
|
It implements the Ray API functions that are forwarded to the server
through ``ray.web`` calls.
|
|
"""
|
|
import inspect
|
|
import logging
|
|
from typing import List
|
|
from typing import Tuple
|
|
|
|
import cloudpickle
|
|
|
|
import ray.webpb as webpb
|
|
import ray.web as webcall
|
|
from ray.common import convert_to_arg
|
|
from ray.common import ClientObjectRef
|
|
from ray.common import ClientActorClass
|
|
from ray.common import ClientRemoteFunc
|
|
|
|
# Logger named after this module, so records can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class Worker:
    """Client-side implementation of the Ray API calls.

    Each public method mirrors a Ray API function (``get``, ``put``,
    ``remote``) and forwards the actual work to the server through the
    ``webcall`` helpers instead of executing it locally.
    """

    def __init__(self):
        # No client-side state is kept; every call goes to the server.
        pass

    def get(self, ids):
        """Fetch the value(s) behind one object ref or a list of refs.

        Args:
            ids: A single ``ClientObjectRef``, or a list of objects exposing
                an ``id`` attribute (normally ``ClientObjectRef``s).

        Returns:
            The deserialized value for a single ref, or a list of values in
            the same order as ``ids`` when a list is given.

        Raises:
            TypeError: If ``ids`` is neither a list nor a ``ClientObjectRef``.
                ``TypeError`` subclasses ``Exception``, so existing
                ``except Exception`` handlers keep working.
        """
        if isinstance(ids, list):
            # Collect every id up front so a malformed element fails before
            # any request is sent to the server.
            object_ids = [ref.id for ref in ids]
            return [self._get(object_id) for object_id in object_ids]
        if isinstance(ids, ClientObjectRef):
            return self._get(ids.id)
        raise TypeError("Can't get something that's not a "
                        "list of IDs or just an ID: %s" % type(ids))

    def _get(self, object_id: bytes):
        """Fetch and deserialize a single object by its raw id.

        Raises:
            Exception: If the server response has a falsy ``"valid"`` field.
        """
        data = webcall.get(object_id)
        if not data["valid"]:
            raise Exception(
                "Client GetObject returned invalid data: id invalid?")
        return webpb.loads(data["data"])

    def put(self, vals):
        """Store a value, or each element of a list, on the server.

        Args:
            vals: Any value. A list is split, and each element is stored as
                its own object. The list itself is not stored as one object.

        Returns:
            A single ``ClientObjectRef`` for a non-list argument, or a list of
            refs (one per element, in order) for a list argument.
        """
        if isinstance(vals, list):
            return [self._put(val) for val in vals]
        return self._put(vals)

    def _put(self, val):
        """Send ``val`` to the server and wrap the returned id in a ref."""
        # NOTE(review): ``val`` is handed to ``webcall.put`` as-is, while
        # ``_get`` decodes with ``webpb.loads``; presumably ``webcall.put``
        # performs the matching serialization — confirm.
        object_id = webcall.put(val)
        return ClientObjectRef(object_id)

    def remote(self, function_or_class, *args, **kwargs):
        """Wrap a function or class for remote execution (``@ray.remote``).

        Returns:
            A ``ClientRemoteFunc`` for a plain function, or a
            ``ClientActorClass`` for a class.

        Raises:
            TypeError: If ``function_or_class`` is neither.
        """
        # TODO(barakmich): Arguments to ray.remote
        # get captured here.
        if inspect.isfunction(function_or_class):
            return ClientRemoteFunc(function_or_class)
        elif inspect.isclass(function_or_class):
            return ClientActorClass(function_or_class)
        else:
            raise TypeError("The @ray.remote decorator must be applied to "
                            "either a function or to a class.")

    def call_remote(self, instance, kind, *args, **kwargs):
        """Schedule a remote call on the server and return a ref to its result.

        Args:
            instance: Client-side wrapper whose ``_prepare_client_task()``
                builds the task message to send.
            kind: Currently unused.
            *args: Positional arguments. Each one is converted with
                ``convert_to_arg`` and appended to ``task.args``.
            **kwargs: Not forwarded yet. A warning is logged, and the
                arguments are ignored.

        Returns:
            A ``ClientObjectRef`` wrapping the id returned by
            ``webcall.schedule``.
        """
        task = instance._prepare_client_task()
        for arg in args:
            task.args.append(convert_to_arg(arg))
        if kwargs:
            # TODO: forward keyword arguments once the task format supports
            # them; until then, at least make the loss visible.
            logger.warning(
                "Keyword arguments are not supported for remote calls "
                "and will be ignored: %s", sorted(kwargs))
        # Use the module logger with lazy %-args. Module-level logging.debug()
        # logs on the root logger and implicitly calls basicConfig() when no
        # handlers are configured.
        logger.debug("Scheduling %s", task)
        return_id = webcall.schedule(task.toJsonable())
        return ClientObjectRef(return_id)

    def close(self):
        """Release client resources; currently a no-op."""
        pass