104 lines
2.8 KiB
Python
104 lines
2.8 KiB
Python
import cloudpickle
|
|
import ray
|
|
import logging
|
|
from concurrent import futures
|
|
import grpc
|
|
import proto.task_pb2
|
|
import proto.task_pb2_grpc
|
|
import threading
|
|
import time
|
|
|
|
# Next object id to hand out; advanced by generate_id(). Starts at 1.
id_printer = 1


def generate_id():
    """Return the current counter value as ASCII bytes, then advance it by one.

    Not synchronised by itself; in this module it is only called from
    ObjectStore.put while that store's lock is held.
    """
    global id_printer
    current, id_printer = id_printer, id_printer + 1
    return b"%d" % current
|
|
|
|
|
|
class ObjectStore:
    """In-memory map from generated byte ids to stored payloads.

    Every read and write of ``self.data`` happens while ``self.lock`` is held.
    """

    def __init__(self):
        self.data = {}
        # Guards self.data and, in put(), the call to the module-level id counter.
        self.lock = threading.Lock()

    def get(self, id):
        """Return the payload stored under *id*, or None if there is none.

        Callers (TaskServicer.GetObject, SimpleExecutor.execute) signal a
        missing object by checking the result against None, so an unknown
        id must not raise KeyError.
        """
        with self.lock:
            return self.data.get(id)

    def put(self, data):
        """Store *data* under a freshly generated id and return that id (bytes)."""
        with self.lock:
            # Id generation stays under the lock so concurrent puts never
            # receive the same id from the shared counter.
            object_id = generate_id()
            self.data[object_id] = data
        return object_id
|
|
|
|
|
|
class TaskServicer(proto.task_pb2_grpc.TaskServerServicer):
    """TaskServer gRPC servicer backed by an object store and an executor."""

    def __init__(self, objects, executor):
        self.objects = objects
        self.executor = executor

    def GetObject(self, request, context=None):
        """Fetch request.id from the store; reply valid=False if it returns None."""
        payload = self.objects.get(request.id)
        if payload is not None:
            return proto.task_pb2.GetResponse(valid=True, data=payload)
        return proto.task_pb2.GetResponse(valid=False)

    def PutObject(self, request, context=None):
        """Store request.data and reply with the id the store assigned."""
        new_id = self.objects.put(request.data)
        return proto.task_pb2.PutResponse(id=new_id)

    def Schedule(self, task, context=None):
        """Run *task* via the executor and reply with the stored result's id."""
        result_id = self.executor.execute(task, context)
        return proto.task_pb2.TaskTicket(return_id=result_id)
|
|
|
|
|
|
class SimpleExecutor:
    """Runs cloudpickled tasks synchronously in the calling thread."""

    def __init__(self, object_store):
        self.objects = object_store
        # Context passed to the task currently executing, or None between tasks.
        # NOTE(review): this attribute is shared by every server thread, so
        # concurrent execute() calls overwrite each other's value; confirm
        # whether any reader of it relies on per-task isolation.
        self.context = None

    def _load_arg(self, arg):
        """Unpickle one task argument, carried inline or fetched by reference."""
        if arg.local == proto.task_pb2.Arg.Locality.INTERNED:
            data = arg.data
        else:
            data = self.objects.get(arg.reference_id)
            if data is None:
                raise Exception(
                    "no argument payload for id %r" % (arg.reference_id,))
        return cloudpickle.loads(data)

    def execute(self, task, context):
        """Execute *task* and store its pickled return value.

        The function is loaded from the object store by ``task.payload_id``.
        Each entry of ``task.args`` is either INTERNED (pickled bytes in
        ``arg.data``) or a reference whose pickled bytes are fetched from the
        store by ``arg.reference_id``.

        Returns:
            The object-store id under which the pickled result was stored.

        Raises:
            Exception: if the function payload or a referenced argument is
                not present in the store.
        """
        # SECURITY: cloudpickle.loads can execute arbitrary code. These payloads
        # come from RPC clients, so only trusted clients may reach this server.
        self.context = context
        try:
            func_data = self.objects.get(task.payload_id)
            if func_data is None:
                raise Exception(
                    "no function payload for id %r" % (task.payload_id,))
            f = cloudpickle.loads(func_data)
            args = [self._load_arg(a) for a in task.args]
            out = f(*args)
            return self.objects.put(cloudpickle.dumps(out))
        finally:
            # Clear on both success and failure so a finished task's context
            # is never left behind.
            self.context = None

    def set_task_servicer(self, servicer):
        """Replace ray's global worker with a ``ray.Worker`` using *servicer* as stub.

        NOTE(review): presumably this routes ray calls made inside tasks back
        through this servicer; confirm against the ray.Worker implementation.
        """
        ray._global_worker = ray.Worker(stub=servicer)
|
|
|
|
|
|
def serve(address='[::]:50051', max_workers=10):
    """Start the TaskServer gRPC service and block until interrupted.

    A single ObjectStore is shared by the executor and the servicer. The
    servicer is installed as the stub of ray's global worker, via
    SimpleExecutor.set_task_servicer, before the server starts.

    Args:
        address: Host:port to bind. It is bound with add_insecure_port, so
            there is no TLS. Defaults to all interfaces on port 50051.
        max_workers: Number of threads serving RPCs concurrently.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    object_store = ObjectStore()
    executor = SimpleExecutor(object_store)
    task_servicer = TaskServicer(object_store, executor)
    executor.set_task_servicer(task_servicer)
    proto.task_pb2_grpc.add_TaskServerServicer_to_server(task_servicer, server)
    server.add_insecure_port(address)
    server.start()
    try:
        # server.start() returns immediately; keep the main thread alive
        # until Ctrl-C, then stop the server without a grace period.
        while True:
            time.sleep(1000)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == '__main__':
    # Install a default stderr logging handler before starting the server.
    logging.basicConfig()
    serve()
|