rayclient/python/dumb_raylet.py
2020-09-12 15:57:57 -07:00

104 lines
2.8 KiB
Python

import cloudpickle
import ray
import logging
from concurrent import futures
import grpc
import proto.task_pb2
import proto.task_pb2_grpc
import threading
import time
# Module-level counter backing generate_id(): the next integer id to hand out.
id_printer = 1


def generate_id():
    """Return the next object id as ASCII decimal bytes (b"1", b"2", ...).

    Advances the module-level ``id_printer`` counter. The function itself takes
    no lock; ObjectStore.put calls it while holding the store's lock.
    """
    global id_printer
    current = id_printer
    id_printer = current + 1
    return str(current).encode()
class ObjectStore:
    """In-memory, lock-protected mapping from object id to stored bytes.

    Ids are allocated by ``generate_id()`` when data is stored via ``put``.
    """

    def __init__(self):
        # id -> stored payload. Guarded by self.lock because serve() handles
        # RPCs on a thread pool that shares one store.
        self.data = {}
        self.lock = threading.Lock()

    def get(self, id):
        """Return the data stored under ``id``, or ``None`` if it is unknown.

        Callers (TaskServicer.GetObject, SimpleExecutor.execute) check the
        result against ``None`` to detect a missing object, so a miss must not
        raise ``KeyError``. (The parameter keeps its original name ``id`` for
        keyword-call compatibility, even though it shadows the builtin.)
        """
        with self.lock:
            return self.data.get(id)

    def put(self, data):
        """Store ``data`` under a freshly generated id and return that id."""
        with self.lock:
            object_id = generate_id()
            self.data[object_id] = data
        return object_id
class TaskServicer(proto.task_pb2_grpc.TaskServerServicer):
    """gRPC servicer exposing object get/put and task scheduling."""

    def __init__(self, objects, executor):
        # objects: store offering get(id) / put(data) -> id.
        # executor: offers execute(task, context) -> result id.
        self.objects = objects
        self.executor = executor

    def GetObject(self, request, context=None):
        """Return the object stored under ``request.id``.

        Responds with ``valid=False`` when the id is unknown instead of
        failing the RPC.
        """
        try:
            data = self.objects.get(request.id)
        except KeyError:
            # Treat a store whose get() raises on a miss the same as one that
            # returns None.
            data = None
        if data is None:
            return proto.task_pb2.GetResponse(valid=False)
        return proto.task_pb2.GetResponse(valid=True, data=data)

    def PutObject(self, request, context=None):
        """Store ``request.data`` and return the id it was stored under."""
        object_id = self.objects.put(request.data)
        return proto.task_pb2.PutResponse(id=object_id)

    def Schedule(self, task, context=None):
        """Run ``task`` via the executor (blocking) and return its result id."""
        return_val = self.executor.execute(task, context)
        return proto.task_pb2.TaskTicket(return_id=return_val)
class SimpleExecutor:
    """Executes tasks in the calling thread using blobs from an object store.

    The task function and any non-interned arguments are cloudpickled bytes
    fetched from the store. The cloudpickled return value is written back to
    the store, and its id is returned.
    """

    def __init__(self, object_store):
        self.objects = object_store
        # Context passed to the task currently executing; None when idle.
        # NOTE(review): one executor is shared by every server thread, so
        # concurrent execute() calls overwrite each other's value here.
        self.context = None

    def _fetch(self, object_id, what):
        """Return stored bytes for ``object_id``, raising ``KeyError`` if absent.

        Works with stores whose ``get`` returns None on a miss as well as with
        stores that raise ``KeyError``. ``what`` names the object in the error.
        """
        try:
            data = self.objects.get(object_id)
        except KeyError:
            data = None
        if data is None:
            raise KeyError(
                "%s %r not found in object store" % (what, object_id))
        return data

    def execute(self, task, context):
        """Run ``task`` and return the object id of its pickled result.

        Raises:
            KeyError: if the function payload or a referenced argument is not
                present in the object store.
        """
        self.context = context
        try:
            f = cloudpickle.loads(
                self._fetch(task.payload_id, "function payload"))
            args = []
            for a in task.args:
                if a.local == proto.task_pb2.Arg.Locality.INTERNED:
                    # Interned arguments carry their pickled bytes inline.
                    data = a.data
                else:
                    # Otherwise the argument refers to a stored object.
                    data = self._fetch(a.reference_id, "argument")
                args.append(cloudpickle.loads(data))
            out = f(*args)
            return self.objects.put(cloudpickle.dumps(out))
        finally:
            # Clear even when the task raised, so a stale context isn't kept.
            self.context = None

    def set_task_servicer(self, servicer):
        """Install a ``ray.Worker`` whose stub is ``servicer``.

        Assigns the process-global ``ray._global_worker``. Presumably this
        routes ray calls made inside executed tasks back through this
        servicer; TODO confirm against this ray fork's Worker.
        """
        ray._global_worker = ray.Worker(stub=servicer)
def serve(port=50051, max_workers=10):
    """Start the task gRPC server and block until interrupted with Ctrl-C.

    Args:
        port: TCP port to listen on, bound on ``[::]`` (all interfaces)
            without TLS. Defaults to 50051.
        max_workers: number of threads in the pool that handles RPCs.
            Defaults to 10.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    object_store = ObjectStore()
    executor = SimpleExecutor(object_store)
    task_servicer = TaskServicer(object_store, executor)
    executor.set_task_servicer(task_servicer)
    proto.task_pb2_grpc.add_TaskServerServicer_to_server(
        task_servicer, server)
    server.add_insecure_port("[::]:{}".format(port))
    server.start()
    try:
        # server.start() returns immediately; keep the main thread alive.
        while True:
            time.sleep(1000)
    except KeyboardInterrupt:
        # Zero grace period: abort any in-flight RPCs right away.
        server.stop(0)
if __name__ == "__main__":
    # Set up the root logger (stderr handler, WARNING level by default), then
    # run the server, which blocks until KeyboardInterrupt.
    logging.basicConfig()
    serve()