"""gRPC task server: stores pickled objects and executes pickled functions.

Clients upload cloudpickled functions/arguments with ``PutObject``, read
objects back with ``GetObject`` and run a stored function with ``Schedule``;
the function's pickled return value is stored and its id returned.
"""
import logging
import threading
import time
from concurrent import futures

import cloudpickle
import grpc
import ray

import proto.task_pb2
import proto.task_pb2_grpc

# Next object id to hand out; only mutated by generate_id().
id_printer = 1


def generate_id():
    """Return the next object id as ASCII decimal bytes (b"1", b"2", ...).

    Performs an unlocked read-modify-write of the module-level counter, so
    concurrent callers must serialise access themselves (ObjectStore.put
    calls it while holding its lock).
    """
    global id_printer
    out = str(id_printer).encode()
    id_printer += 1
    return out


class ObjectStore:
    """In-memory mapping from object id (bytes) to stored payload, guarded by a lock."""

    def __init__(self):
        self.data = {}
        self.lock = threading.Lock()

    def get(self, id):
        """Return the payload stored under ``id``, or ``None`` if it is absent.

        Callers (TaskServicer.GetObject, SimpleExecutor.execute) test for
        ``None`` to detect a missing object, so a miss must not raise.
        """
        with self.lock:
            return self.data.get(id)

    def put(self, data):
        """Store ``data`` under a freshly generated id and return that id."""
        with self.lock:
            # generate_id() has no lock of its own; holding the store lock
            # serialises id allocation across concurrent put() calls.
            id = generate_id()
            self.data[id] = data
            return id


class TaskServicer(proto.task_pb2_grpc.TaskServerServicer):
    """gRPC servicer exposing an object store and a task executor."""

    def __init__(self, objects, executor):
        self.objects = objects
        self.executor = executor

    def GetObject(self, request, context=None):
        """Return object ``request.id``, or a response with ``valid=False`` if unknown."""
        data = self.objects.get(request.id)
        if data is None:
            return proto.task_pb2.GetResponse(valid=False)
        return proto.task_pb2.GetResponse(valid=True, data=data)

    def PutObject(self, request, context=None):
        """Store ``request.data`` and return the id it was stored under."""
        id = self.objects.put(request.data)
        return proto.task_pb2.PutResponse(id=id)

    def Schedule(self, task, context=None):
        """Run ``task`` synchronously and return a ticket naming its result object."""
        return_val = self.executor.execute(task, context)
        return proto.task_pb2.TaskTicket(return_id=return_val)


class SimpleExecutor:
    """Executes tasks in the calling thread, reading/writing an ObjectStore."""

    def __init__(self, object_store):
        self.objects = object_store
        self.context = None

    def execute(self, task, context):
        """Run ``task`` and return the object id of its pickled return value.

        The function is cloudpickle-loaded from ``task.payload_id``. Each
        argument is either inline pickled bytes (``INTERNED``) or a reference
        into the object store.

        Raises:
            KeyError: the function payload or a referenced argument is not in
                the object store.
        """
        # NOTE(review): self.context is one attribute shared by every server
        # thread, so concurrent Schedule calls overwrite each other's value.
        self.context = context
        try:
            # SECURITY: unpickling executes arbitrary client-supplied code.
            func = cloudpickle.loads(self._fetch(task.payload_id, "function payload"))
            args = [cloudpickle.loads(self._arg_bytes(a)) for a in task.args]
            out = func(*args)
            return self.objects.put(cloudpickle.dumps(out))
        finally:
            # Reset even if the task raised, so no stale context is left behind.
            self.context = None

    def _fetch(self, object_id, what):
        """Return the bytes stored under ``object_id``; raise KeyError if missing."""
        data = self.objects.get(object_id)
        if data is None:
            raise KeyError("%s %r not found in object store" % (what, object_id))
        return data

    def _arg_bytes(self, arg):
        """Return the pickled bytes for one task argument (inline or by reference)."""
        if arg.local == proto.task_pb2.Arg.Locality.INTERNED:
            return arg.data
        return self._fetch(arg.reference_id, "argument")

    def set_task_servicer(self, servicer):
        """Install ``ray.Worker(stub=servicer)`` as ``ray._global_worker``.

        Presumably lets task code that calls into ``ray`` reach this servicer
        in-process — TODO confirm against the ``ray`` module in use.
        """
        ray._global_worker = ray.Worker(stub=servicer)


def serve(address='[::]:50051', max_workers=10):
    """Start the task server on ``address`` and block until KeyboardInterrupt.

    Args:
        address: host:port passed to ``server.add_insecure_port``.
        max_workers: size of the thread pool handling RPCs.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    object_store = ObjectStore()
    executor = SimpleExecutor(object_store)
    task_servicer = TaskServicer(object_store, executor)
    executor.set_task_servicer(task_servicer)
    proto.task_pb2_grpc.add_TaskServerServicer_to_server(task_servicer, server)
    # SECURITY: unauthenticated, unencrypted port (all interfaces by default)
    # and Schedule runs client-supplied pickled code; expose only on trusted
    # networks.
    server.add_insecure_port(address)
    server.start()
    try:
        # server.start() returns immediately; keep the main thread alive.
        while True:
            time.sleep(1000)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == '__main__':
    logging.basicConfig()
    serve()