tinkerbell/web_python/ray/worker.py
2020-12-04 18:01:27 -08:00

88 lines
2.5 KiB
Python

"""This file includes the Worker class which sits on the client side.
It implements the Ray API functions, forwarding each call to the server
through the `ray.web` call layer (imported here as `webcall`).
"""
import inspect
import logging
from typing import List
from typing import Tuple
import cloudpickle
import ray.webpb as webpb
import ray.web as webcall
from ray.common import convert_to_arg
from ray.common import ClientObjectRef
from ray.common import ClientActorClass
from ray.common import ClientRemoteFunc
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class Worker:
    """Client-side implementation of the Ray API.

    Each public method mirrors a Ray API call (``get``, ``put``, ``remote``)
    and forwards the work to the server through ``webcall`` (``ray.web``).
    """

    def __init__(self):
        pass

    def get(self, ids):
        """Fetch the value(s) referenced by one or more object refs.

        Args:
            ids: A single ``ClientObjectRef`` or a list of them.

        Returns:
            The value for a single ref, or a list of values (same order as
            the input) when ``ids`` is a list.

        Raises:
            TypeError: if ``ids`` is neither a list nor a ``ClientObjectRef``
                (a subclass of ``Exception``, so existing handlers still work).
            Exception: if the server reports an object as invalid.
        """
        if isinstance(ids, list):
            # Collect every id up front so a malformed element fails before
            # any request is sent.
            to_get = [ref.id for ref in ids]
            return [self._get(object_id) for object_id in to_get]
        if isinstance(ids, ClientObjectRef):
            return self._get(ids.id)
        raise TypeError("Can't get something that's not a "
                        "list of IDs or just an ID: %s" % type(ids))

    def _get(self, id: bytes):
        """Fetch and decode a single object by its raw id.

        Raises:
            Exception: if the response's ``"valid"`` field is falsy.
        """
        # presumably a dict-like response with "valid" and "data" keys;
        # verify against ray.web.get.
        data = webcall.get(id)
        if not data["valid"]:
            raise Exception(
                "Client GetObject returned invalid data: id invalid?")
        # NOTE(review): if webpb.loads deserializes via pickle, this trusts
        # the server's payload completely -- confirm that is intended.
        return webpb.loads(data["data"])

    def put(self, vals):
        """Store one value, or each element of a list, on the server.

        Args:
            vals: Any value. A ``list`` is treated as many values; anything
                else (including tuples) is stored as one value.

        Returns:
            A ``ClientObjectRef`` for a single value, or a list of refs for
            a list input.
        """
        if isinstance(vals, list):
            return [self._put(val) for val in vals]
        return self._put(vals)

    def _put(self, val):
        """Send a single value to the server and wrap the returned id."""
        object_id = webcall.put(val)
        return ClientObjectRef(object_id)

    def remote(self, function_or_class, *args, **kwargs):
        """Wrap a function or class for remote execution.

        ``*args``/``**kwargs`` (options to ``ray.remote``) are currently
        accepted but ignored.

        Raises:
            TypeError: if ``function_or_class`` is neither a function nor a
                class.
        """
        # TODO(barakmich): Arguments to ray.remote
        # get captured here.
        if inspect.isfunction(function_or_class):
            return ClientRemoteFunc(function_or_class)
        if inspect.isclass(function_or_class):
            return ClientActorClass(function_or_class)
        raise TypeError("The @ray.remote decorator must be applied to "
                        "either a function or to a class.")

    def call_remote(self, instance, kind, *args, **kwargs):
        """Schedule a remote invocation on the server.

        Args:
            instance: Object providing ``_prepare_client_task()``, which
                returns the task to be sent.
            kind: Not used by this method; accepted for caller compatibility.
            *args: Positional call arguments, each converted with
                ``convert_to_arg`` and appended to the task.
            **kwargs: Not forwarded to the server yet; a warning is logged
                and they are dropped.

        Returns:
            A ``ClientObjectRef`` for the scheduled task's return value.
        """
        if kwargs:
            # Keyword arguments are not sent to the server yet. Warn loudly
            # rather than silently running the call without them.
            logger.warning(
                "Keyword arguments are not supported for remote calls and "
                "will be ignored: %s", sorted(kwargs))
        task = instance._prepare_client_task()
        for arg in args:
            task.args.append(convert_to_arg(arg))
        # Module logger with lazy %-args: it does not touch the root logger
        # (so no implicit basicConfig()), and it skips formatting the task
        # when debug is off.
        logger.debug("Scheduling %s", task)
        return_id = webcall.schedule(task.toJsonable())
        return ClientObjectRef(return_id)

    def close(self):
        pass