small python ray
This commit is contained in:
commit
e02e80708a
15 changed files with 1261 additions and 0 deletions
104
python/dumb_raylet.py
Normal file
104
python/dumb_raylet.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
import cloudpickle
|
||||
import ray
|
||||
import logging
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import proto.task_pb2
|
||||
import proto.task_pb2_grpc
|
||||
import threading
|
||||
import time
|
||||
|
||||
id_printer = 1
|
||||
|
||||
|
||||
def generate_id():
|
||||
global id_printer
|
||||
out = str(id_printer).encode()
|
||||
id_printer += 1
|
||||
return out
|
||||
|
||||
|
||||
class ObjectStore:
|
||||
def __init__(self):
|
||||
self.data = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def get(self, id):
|
||||
with self.lock:
|
||||
return self.data[id]
|
||||
|
||||
def put(self, data):
|
||||
with self.lock:
|
||||
id = generate_id()
|
||||
self.data[id] = data
|
||||
return id
|
||||
|
||||
|
||||
class TaskServicer(proto.task_pb2_grpc.TaskServerServicer):
|
||||
def __init__(self, objects, executor):
|
||||
self.objects = objects
|
||||
self.executor = executor
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
data = self.objects.get(request.id)
|
||||
if data is None:
|
||||
return proto.task_pb2.GetResponse(valid=False)
|
||||
return proto.task_pb2.GetResponse(valid=True, data=data)
|
||||
|
||||
def PutObject(self, request, context=None):
|
||||
id = self.objects.put(request.data)
|
||||
return proto.task_pb2.PutResponse(id=id)
|
||||
|
||||
def Schedule(self, task, context=None):
|
||||
return_val = self.executor.execute(task, context)
|
||||
return proto.task_pb2.TaskTicket(return_id=return_val)
|
||||
|
||||
|
||||
class SimpleExecutor:
|
||||
def __init__(self, object_store):
|
||||
self.objects = object_store
|
||||
|
||||
def execute(self, task, context):
|
||||
#print("Executing task", task.name)
|
||||
self.context = context
|
||||
func_data = self.objects.get(task.payload_id)
|
||||
if func_data is None:
|
||||
raise Exception("WTF")
|
||||
f = cloudpickle.loads(func_data)
|
||||
args = []
|
||||
for a in task.args:
|
||||
if a.local == proto.task_pb2.Arg.Locality.INTERNED:
|
||||
data = a.data
|
||||
else:
|
||||
data = self.objects.get(a.reference_id)
|
||||
realarg = cloudpickle.loads(data)
|
||||
args.append(realarg)
|
||||
out = f(*args)
|
||||
id = self.objects.put(cloudpickle.dumps(out))
|
||||
self.context = None
|
||||
return id
|
||||
|
||||
def set_task_servicer(self, servicer):
|
||||
ray._global_worker = ray.Worker(stub=servicer)
|
||||
|
||||
|
||||
def serve():
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
object_store = ObjectStore()
|
||||
executor = SimpleExecutor(object_store)
|
||||
task_servicer = TaskServicer(object_store, executor)
|
||||
executor.set_task_servicer(task_servicer)
|
||||
proto.task_pb2_grpc.add_TaskServerServicer_to_server(
|
||||
task_servicer, server)
|
||||
server.add_insecure_port('[::]:50051')
|
||||
server.start()
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1000)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig()
|
||||
serve()
|
||||
Loading…
Add table
Add a link
Reference in a new issue