small python ray

This commit is contained in:
Barak Michener 2020-09-12 15:57:57 -07:00
commit e02e80708a
15 changed files with 1261 additions and 0 deletions

25
python/BUILD Normal file
View file

@ -0,0 +1,25 @@
load("@py_deps//:requirements.bzl", "requirement")
load("//bazel:tools.bzl", "copy_to_workspace")

# All generated Python protobuf/gRPC sources for the task service.
filegroup(
    name = "all_py_proto",
    srcs = [
        "//proto:task_proto_grpc_py",
    ],
)

# Copy the generated sources into python/proto so they are importable
# as the "proto" package at runtime.
copy_to_workspace(
    name = "proto_folder",
    srcs = [":all_py_proto"],
    dstdir = "python/proto",
)

py_binary(
    name = "main",
    srcs = ["main.py"],
    data = [ ":proto_folder" ],
    deps = [
        requirement("cloudpickle"),
    ],
)

104
python/dumb_raylet.py Normal file
View file

@ -0,0 +1,104 @@
import cloudpickle
import ray
import logging
from concurrent import futures
import grpc
import proto.task_pb2
import proto.task_pb2_grpc
import threading
import time
# Next id to hand out; ids are 1-based decimal strings encoded as bytes.
id_printer = 1
# The gRPC server dispatches requests on a thread pool, so id generation
# must be serialized; a bare global increment is not atomic.
_id_lock = threading.Lock()


def generate_id():
    """Return a fresh, unique object id as bytes (b"1", b"2", ...).

    Thread-safe: the read-increment of the global counter happens under a
    lock so concurrent callers can never receive the same id.
    """
    global id_printer
    with _id_lock:
        out = str(id_printer).encode()
        id_printer += 1
    return out
class ObjectStore:
    """Thread-safe in-memory store mapping generated ids to pickled bytes."""

    def __init__(self):
        # id (bytes) -> payload (bytes); guarded by self.lock.
        self.data = {}
        self.lock = threading.Lock()

    def get(self, id):
        """Return the bytes stored under id, or None if the id is unknown.

        Callers (TaskServicer.GetObject, SimpleExecutor.execute) test the
        result against None, so a missing id must return None rather than
        raise KeyError as a plain dict lookup would.
        """
        with self.lock:
            return self.data.get(id)

    def put(self, data):
        """Store data under a freshly generated id and return that id."""
        with self.lock:
            id = generate_id()
            self.data[id] = data
            return id
class TaskServicer(proto.task_pb2_grpc.TaskServerServicer):
    """gRPC servicer exposing the object store and synchronous task execution."""

    def __init__(self, objects, executor):
        self.objects = objects
        self.executor = executor

    def GetObject(self, request, context=None):
        """Return the stored bytes for request.id; valid=False when absent."""
        payload = self.objects.get(request.id)
        if payload is None:
            return proto.task_pb2.GetResponse(valid=False)
        return proto.task_pb2.GetResponse(valid=True, data=payload)

    def PutObject(self, request, context=None):
        """Store request.data and reply with the newly generated id."""
        new_id = self.objects.put(request.data)
        return proto.task_pb2.PutResponse(id=new_id)

    def Schedule(self, task, context=None):
        """Execute the task inline and return a ticket naming its result id."""
        result_id = self.executor.execute(task, context)
        return proto.task_pb2.TaskTicket(return_id=result_id)
class SimpleExecutor:
    """Runs tasks inline: load the pickled function and args, call, store result.

    NOTE(review): execute() stashes the request context on the instance, so
    concurrent execute() calls on one executor would clobber each other's
    context — confirm single-threaded use or move it to a local.
    """

    def __init__(self, object_store):
        self.objects = object_store

    def execute(self, task, context):
        """Execute one Task and return the object-store id of its result.

        Args:
            task: proto Task with a payload_id naming the pickled function
                and a list of Arg entries (inline bytes or store references).
            context: gRPC context (may be None for in-process calls).

        Raises:
            Exception: if the function payload is not in the object store.
        """
        self.context = context
        func_data = self.objects.get(task.payload_id)
        if func_data is None:
            # Fail with an actionable message: the client never pushed the
            # function payload, or used a stale id.
            raise Exception(
                "task payload %r not found in object store" % task.payload_id)
        f = cloudpickle.loads(func_data)
        args = []
        for a in task.args:
            if a.local == proto.task_pb2.Arg.Locality.INTERNED:
                # Argument was pickled inline in the request.
                data = a.data
            else:
                # Argument passed by reference into the object store.
                data = self.objects.get(a.reference_id)
            args.append(cloudpickle.loads(data))
        out = f(*args)
        result_id = self.objects.put(cloudpickle.dumps(out))
        self.context = None
        return result_id

    def set_task_servicer(self, servicer):
        """Point the ray client library at this in-process servicer."""
        ray._global_worker = ray.Worker(stub=servicer)
def serve():
    """Run the toy raylet gRPC server on port 50051 until interrupted."""
    # Wire up the in-process components first, then hang them off the server.
    object_store = ObjectStore()
    executor = SimpleExecutor(object_store)
    task_servicer = TaskServicer(object_store, executor)
    executor.set_task_servicer(task_servicer)

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    proto.task_pb2_grpc.add_TaskServerServicer_to_server(task_servicer, server)
    server.add_insecure_port('[::]:50051')
    server.start()
    try:
        # Block the main thread; gRPC serves on its own pool.
        while True:
            time.sleep(1000)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == '__main__':
    logging.basicConfig()
    serve()

34
python/example_client.py Normal file
View file

@ -0,0 +1,34 @@
import cloudpickle
import logging
from concurrent import futures
import grpc
import ray
# Module-import side effect: dial the local raylet and install the global
# worker so the @ray.remote decorators below can register their functions.
ray.connect("localhost:50051")
@ray.remote
def fact(x):
    """Return x! computed iteratively (1 for x <= 0)."""
    product = 1
    for i in range(1, x + 1):
        product *= i
    return product
@ray.remote
def fib(x):
    """Naive recursive Fibonacci (fib(0) == fib(1) == 1).

    Inside the task body, `fib` is the RemoteFunc wrapper, so each recursive
    call is itself submitted to the server and fetched with ray.get —
    exercising nested task scheduling.
    """
    if x <= 1:
        return 1
    return ray.get(fib.remote(x - 1)) + ray.get(fib.remote(x - 2))
def run():
    """Submit a remote fib(6) computation to the connected server.

    The result is scheduled but intentionally not fetched; call
    ray.get(...) on the returned ObjectID to wait for the value.
    """
    fib.remote(6)


if __name__ == "__main__":
    run()

96
python/example_test.py Normal file
View file

@ -0,0 +1,96 @@
import ray
import numpy as np
import time
import dumb_raylet
#ray.connect("localhost:50051")
# In-process setup: instead of dialing a gRPC server, wire the ray client
# library directly to a local TaskServicer so the benchmarks measure pure
# scheduling overhead without the network.
object_store = dumb_raylet.ObjectStore()
executor = dumb_raylet.SimpleExecutor(object_store)
task_servicer = dumb_raylet.TaskServicer(object_store, executor)
executor.set_task_servicer(task_servicer)
def test_timing():
@ray.remote
def empty_function():
pass
@ray.remote
def trivial_function():
return 1
# Measure the time required to submit a remote task to the scheduler.
elapsed_times = []
for _ in range(1000):
start_time = time.time()
empty_function.remote()
end_time = time.time()
elapsed_times.append(end_time - start_time)
elapsed_times = np.sort(elapsed_times)
average_elapsed_time = sum(elapsed_times) / 1000
print("Time required to submit an empty function call:")
print(" Average: {}".format(average_elapsed_time))
print(" 90th percentile: {}".format(elapsed_times[900]))
print(" 99th percentile: {}".format(elapsed_times[990]))
print(" worst: {}".format(elapsed_times[999]))
# average_elapsed_time should be about 0.00038.
# Measure the time required to submit a remote task to the scheduler
# (where the remote task returns one value).
elapsed_times = []
for _ in range(1000):
start_time = time.time()
trivial_function.remote()
end_time = time.time()
elapsed_times.append(end_time - start_time)
elapsed_times = np.sort(elapsed_times)
average_elapsed_time = sum(elapsed_times) / 1000
print("Time required to submit a trivial function call:")
print(" Average: {}".format(average_elapsed_time))
print(" 90th percentile: {}".format(elapsed_times[900]))
print(" 99th percentile: {}".format(elapsed_times[990]))
print(" worst: {}".format(elapsed_times[999]))
# Measure the time required to submit a remote task to the scheduler
# and get the result.
elapsed_times = []
for _ in range(1000):
start_time = time.time()
x = trivial_function.remote()
ray.get(x)
end_time = time.time()
elapsed_times.append(end_time - start_time)
elapsed_times = np.sort(elapsed_times)
average_elapsed_time = sum(elapsed_times) / 1000
print("Time required to submit a trivial function call and get the "
"result:")
print(" Average: {}".format(average_elapsed_time))
print(" 90th percentile: {}".format(elapsed_times[900]))
print(" 99th percentile: {}".format(elapsed_times[990]))
print(" worst: {}".format(elapsed_times[999]))
# average_elapsed_time should be about 0.0013.
# Measure the time required to do do a put.
elapsed_times = []
for _ in range(1000):
start_time = time.time()
ray.put(1)
end_time = time.time()
elapsed_times.append(end_time - start_time)
elapsed_times = np.sort(elapsed_times)
average_elapsed_time = sum(elapsed_times) / 1000
print("Time required to put an int:")
print(" Average: {}".format(average_elapsed_time))
print(" 90th percentile: {}".format(elapsed_times[900]))
print(" 99th percentile: {}".format(elapsed_times[990]))
print(" worst: {}".format(elapsed_times[999]))
# average_elapsed_time should be about 0.00087.
def run():
    """Entry point: execute the timing benchmarks."""
    test_timing()


if __name__ == "__main__":
    run()

11
python/main.py Normal file
View file

@ -0,0 +1,11 @@
import cloudpickle
import proto.task_pb2_grpc
def main():
    """Smoke test: pickle a string with cloudpickle and show the raw bytes."""
    pickled = cloudpickle.dumps("hello world")
    print(pickled)


if __name__ == "__main__":
    main()

477
python/proto/task_pb2.py Executable file
View file

@ -0,0 +1,477 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: proto/task.proto
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='proto/task.proto',
package='',
syntax='proto3',
serialized_options=None,
serialized_pb=b'\n\x10proto/task.proto\"\xa5\x01\n\x03\x41rg\x12#\n\x05local\x18\x01 \x01(\x0e\x32\r.Arg.LocalityR\x05local\x12!\n\x0creference_id\x18\x02 \x01(\x0cR\x0breferenceId\x12\x12\n\x04\x64\x61ta\x18\x03 \x01(\x0cR\x04\x64\x61ta\x12\x19\n\x04type\x18\x04 \x01(\x0e\x32\x05.TypeR\x04type\"\'\n\x08Locality\x12\x0c\n\x08INTERNED\x10\x00\x12\r\n\tREFERENCE\x10\x01\"S\n\x04Task\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1d\n\npayload_id\x18\x02 \x01(\x0cR\tpayloadId\x12\x18\n\x04\x61rgs\x18\x03 \x03(\x0b\x32\x04.ArgR\x04\x61rgs\")\n\nTaskTicket\x12\x1b\n\treturn_id\x18\x01 \x01(\x0cR\x08returnId\" \n\nPutRequest\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\"\x1d\n\x0bPutResponse\x12\x0e\n\x02id\x18\x01 \x01(\x0cR\x02id\"\x1c\n\nGetRequest\x12\x0e\n\x02id\x18\x01 \x01(\x0cR\x02id\"7\n\x0bGetResponse\x12\x14\n\x05valid\x18\x01 \x01(\x08R\x05valid\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta*\x13\n\x04Type\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x32\x82\x01\n\nTaskServer\x12(\n\tGetObject\x12\x0b.GetRequest\x1a\x0c.GetResponse\"\x00\x12(\n\tPutObject\x12\x0b.PutRequest\x1a\x0c.PutResponse\"\x00\x12 \n\x08Schedule\x12\x05.Task\x1a\x0b.TaskTicket\"\x00\x32\x85\x01\n\nWorkServer\x12(\n\tGetObject\x12\x0b.GetRequest\x1a\x0c.GetResponse\"\x00\x12(\n\tPutObject\x12\x0b.PutRequest\x1a\x0c.PutResponse\"\x00\x12#\n\x07\x45xecute\x12\x0b.TaskTicket\x1a\x05.Task\"\x00(\x01\x30\x01\x62\x06proto3'
)
_TYPE = _descriptor.EnumDescriptor(
name='Type',
full_name='Type',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='DEFAULT', index=0, number=0,
serialized_options=None,
type=None),
],
containing_type=None,
serialized_options=None,
serialized_start=468,
serialized_end=487,
)
_sym_db.RegisterEnumDescriptor(_TYPE)
Type = enum_type_wrapper.EnumTypeWrapper(_TYPE)
DEFAULT = 0
_ARG_LOCALITY = _descriptor.EnumDescriptor(
name='Locality',
full_name='Arg.Locality',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='INTERNED', index=0, number=0,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='REFERENCE', index=1, number=1,
serialized_options=None,
type=None),
],
containing_type=None,
serialized_options=None,
serialized_start=147,
serialized_end=186,
)
_sym_db.RegisterEnumDescriptor(_ARG_LOCALITY)
_ARG = _descriptor.Descriptor(
name='Arg',
full_name='Arg',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='local', full_name='Arg.local', index=0,
number=1, type=14, cpp_type=8, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='local', file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='reference_id', full_name='Arg.reference_id', index=1,
number=2, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='referenceId', file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='data', full_name='Arg.data', index=2,
number=3, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='data', file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='type', full_name='Arg.type', index=3,
number=4, type=14, cpp_type=8, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='type', file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
_ARG_LOCALITY,
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=21,
serialized_end=186,
)
_TASK = _descriptor.Descriptor(
name='Task',
full_name='Task',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='Task.name', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='name', file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='payload_id', full_name='Task.payload_id', index=1,
number=2, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='payloadId', file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='args', full_name='Task.args', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='args', file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=188,
serialized_end=271,
)
_TASKTICKET = _descriptor.Descriptor(
name='TaskTicket',
full_name='TaskTicket',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='return_id', full_name='TaskTicket.return_id', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='returnId', file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=273,
serialized_end=314,
)
_PUTREQUEST = _descriptor.Descriptor(
name='PutRequest',
full_name='PutRequest',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='data', full_name='PutRequest.data', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='data', file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=316,
serialized_end=348,
)
_PUTRESPONSE = _descriptor.Descriptor(
name='PutResponse',
full_name='PutResponse',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='id', full_name='PutResponse.id', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='id', file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=350,
serialized_end=379,
)
_GETREQUEST = _descriptor.Descriptor(
name='GetRequest',
full_name='GetRequest',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='id', full_name='GetRequest.id', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='id', file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=381,
serialized_end=409,
)
_GETRESPONSE = _descriptor.Descriptor(
name='GetResponse',
full_name='GetResponse',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='valid', full_name='GetResponse.valid', index=0,
number=1, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='valid', file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='data', full_name='GetResponse.data', index=1,
number=2, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, json_name='data', file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=411,
serialized_end=466,
)
_ARG.fields_by_name['local'].enum_type = _ARG_LOCALITY
_ARG.fields_by_name['type'].enum_type = _TYPE
_ARG_LOCALITY.containing_type = _ARG
_TASK.fields_by_name['args'].message_type = _ARG
DESCRIPTOR.message_types_by_name['Arg'] = _ARG
DESCRIPTOR.message_types_by_name['Task'] = _TASK
DESCRIPTOR.message_types_by_name['TaskTicket'] = _TASKTICKET
DESCRIPTOR.message_types_by_name['PutRequest'] = _PUTREQUEST
DESCRIPTOR.message_types_by_name['PutResponse'] = _PUTRESPONSE
DESCRIPTOR.message_types_by_name['GetRequest'] = _GETREQUEST
DESCRIPTOR.message_types_by_name['GetResponse'] = _GETRESPONSE
DESCRIPTOR.enum_types_by_name['Type'] = _TYPE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Arg = _reflection.GeneratedProtocolMessageType('Arg', (_message.Message,), {
'DESCRIPTOR' : _ARG,
'__module__' : 'proto.task_pb2'
# @@protoc_insertion_point(class_scope:Arg)
})
_sym_db.RegisterMessage(Arg)
Task = _reflection.GeneratedProtocolMessageType('Task', (_message.Message,), {
'DESCRIPTOR' : _TASK,
'__module__' : 'proto.task_pb2'
# @@protoc_insertion_point(class_scope:Task)
})
_sym_db.RegisterMessage(Task)
TaskTicket = _reflection.GeneratedProtocolMessageType('TaskTicket', (_message.Message,), {
'DESCRIPTOR' : _TASKTICKET,
'__module__' : 'proto.task_pb2'
# @@protoc_insertion_point(class_scope:TaskTicket)
})
_sym_db.RegisterMessage(TaskTicket)
PutRequest = _reflection.GeneratedProtocolMessageType('PutRequest', (_message.Message,), {
'DESCRIPTOR' : _PUTREQUEST,
'__module__' : 'proto.task_pb2'
# @@protoc_insertion_point(class_scope:PutRequest)
})
_sym_db.RegisterMessage(PutRequest)
PutResponse = _reflection.GeneratedProtocolMessageType('PutResponse', (_message.Message,), {
'DESCRIPTOR' : _PUTRESPONSE,
'__module__' : 'proto.task_pb2'
# @@protoc_insertion_point(class_scope:PutResponse)
})
_sym_db.RegisterMessage(PutResponse)
GetRequest = _reflection.GeneratedProtocolMessageType('GetRequest', (_message.Message,), {
'DESCRIPTOR' : _GETREQUEST,
'__module__' : 'proto.task_pb2'
# @@protoc_insertion_point(class_scope:GetRequest)
})
_sym_db.RegisterMessage(GetRequest)
GetResponse = _reflection.GeneratedProtocolMessageType('GetResponse', (_message.Message,), {
'DESCRIPTOR' : _GETRESPONSE,
'__module__' : 'proto.task_pb2'
# @@protoc_insertion_point(class_scope:GetResponse)
})
_sym_db.RegisterMessage(GetResponse)
_TASKSERVER = _descriptor.ServiceDescriptor(
name='TaskServer',
full_name='TaskServer',
file=DESCRIPTOR,
index=0,
serialized_options=None,
serialized_start=490,
serialized_end=620,
methods=[
_descriptor.MethodDescriptor(
name='GetObject',
full_name='TaskServer.GetObject',
index=0,
containing_service=None,
input_type=_GETREQUEST,
output_type=_GETRESPONSE,
serialized_options=None,
),
_descriptor.MethodDescriptor(
name='PutObject',
full_name='TaskServer.PutObject',
index=1,
containing_service=None,
input_type=_PUTREQUEST,
output_type=_PUTRESPONSE,
serialized_options=None,
),
_descriptor.MethodDescriptor(
name='Schedule',
full_name='TaskServer.Schedule',
index=2,
containing_service=None,
input_type=_TASK,
output_type=_TASKTICKET,
serialized_options=None,
),
])
_sym_db.RegisterServiceDescriptor(_TASKSERVER)
DESCRIPTOR.services_by_name['TaskServer'] = _TASKSERVER
_WORKSERVER = _descriptor.ServiceDescriptor(
name='WorkServer',
full_name='WorkServer',
file=DESCRIPTOR,
index=1,
serialized_options=None,
serialized_start=623,
serialized_end=756,
methods=[
_descriptor.MethodDescriptor(
name='GetObject',
full_name='WorkServer.GetObject',
index=0,
containing_service=None,
input_type=_GETREQUEST,
output_type=_GETRESPONSE,
serialized_options=None,
),
_descriptor.MethodDescriptor(
name='PutObject',
full_name='WorkServer.PutObject',
index=1,
containing_service=None,
input_type=_PUTREQUEST,
output_type=_PUTRESPONSE,
serialized_options=None,
),
_descriptor.MethodDescriptor(
name='Execute',
full_name='WorkServer.Execute',
index=2,
containing_service=None,
input_type=_TASKTICKET,
output_type=_TASK,
serialized_options=None,
),
])
_sym_db.RegisterServiceDescriptor(_WORKSERVER)
DESCRIPTOR.services_by_name['WorkServer'] = _WORKSERVER
# @@protoc_insertion_point(module_scope)

156
python/proto/task_pb2_grpc.py Executable file
View file

@ -0,0 +1,156 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
import grpc
from proto import task_pb2 as proto_dot_task__pb2
class TaskServerStub(object):
# missing associated documentation comment in .proto file
pass
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.GetObject = channel.unary_unary(
'/TaskServer/GetObject',
request_serializer=proto_dot_task__pb2.GetRequest.SerializeToString,
response_deserializer=proto_dot_task__pb2.GetResponse.FromString,
)
self.PutObject = channel.unary_unary(
'/TaskServer/PutObject',
request_serializer=proto_dot_task__pb2.PutRequest.SerializeToString,
response_deserializer=proto_dot_task__pb2.PutResponse.FromString,
)
self.Schedule = channel.unary_unary(
'/TaskServer/Schedule',
request_serializer=proto_dot_task__pb2.Task.SerializeToString,
response_deserializer=proto_dot_task__pb2.TaskTicket.FromString,
)
class TaskServerServicer(object):
# missing associated documentation comment in .proto file
pass
def GetObject(self, request, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def PutObject(self, request, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Schedule(self, request, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_TaskServerServicer_to_server(servicer, server):
rpc_method_handlers = {
'GetObject': grpc.unary_unary_rpc_method_handler(
servicer.GetObject,
request_deserializer=proto_dot_task__pb2.GetRequest.FromString,
response_serializer=proto_dot_task__pb2.GetResponse.SerializeToString,
),
'PutObject': grpc.unary_unary_rpc_method_handler(
servicer.PutObject,
request_deserializer=proto_dot_task__pb2.PutRequest.FromString,
response_serializer=proto_dot_task__pb2.PutResponse.SerializeToString,
),
'Schedule': grpc.unary_unary_rpc_method_handler(
servicer.Schedule,
request_deserializer=proto_dot_task__pb2.Task.FromString,
response_serializer=proto_dot_task__pb2.TaskTicket.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'TaskServer', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
class WorkServerStub(object):
# missing associated documentation comment in .proto file
pass
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.GetObject = channel.unary_unary(
'/WorkServer/GetObject',
request_serializer=proto_dot_task__pb2.GetRequest.SerializeToString,
response_deserializer=proto_dot_task__pb2.GetResponse.FromString,
)
self.PutObject = channel.unary_unary(
'/WorkServer/PutObject',
request_serializer=proto_dot_task__pb2.PutRequest.SerializeToString,
response_deserializer=proto_dot_task__pb2.PutResponse.FromString,
)
self.Execute = channel.stream_stream(
'/WorkServer/Execute',
request_serializer=proto_dot_task__pb2.TaskTicket.SerializeToString,
response_deserializer=proto_dot_task__pb2.Task.FromString,
)
class WorkServerServicer(object):
# missing associated documentation comment in .proto file
pass
def GetObject(self, request, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def PutObject(self, request, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Execute(self, request_iterator, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_WorkServerServicer_to_server(servicer, server):
rpc_method_handlers = {
'GetObject': grpc.unary_unary_rpc_method_handler(
servicer.GetObject,
request_deserializer=proto_dot_task__pb2.GetRequest.FromString,
response_serializer=proto_dot_task__pb2.GetResponse.SerializeToString,
),
'PutObject': grpc.unary_unary_rpc_method_handler(
servicer.PutObject,
request_deserializer=proto_dot_task__pb2.PutRequest.FromString,
response_serializer=proto_dot_task__pb2.PutResponse.SerializeToString,
),
'Execute': grpc.stream_stream_rpc_method_handler(
servicer.Execute,
request_deserializer=proto_dot_task__pb2.TaskTicket.FromString,
response_serializer=proto_dot_task__pb2.Task.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'WorkServer', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))

163
python/ray.py Normal file
View file

@ -0,0 +1,163 @@
import cloudpickle
import grpc
import proto.task_pb2
import proto.task_pb2_grpc
import uuid
class ObjectID:
    """Opaque handle naming an object held in the server's object store."""

    def __init__(self, id):
        # Raw id bytes as assigned by the server (e.g. b"42").
        self.id = id

    def __repr__(self):
        return "ObjectID(%s)" % self.id.decode()
# Maps uuid.UUID -> Worker for every worker constructed in this process.
worker_registry = {}


def set_global_worker(worker):
    """Install worker as the module-wide default worker."""
    global _global_worker
    _global_worker = worker


def register_worker(worker):
    """Record the worker under a fresh UUID and return that UUID."""
    worker_id = uuid.uuid4()
    worker_registry[worker_id] = worker
    return worker_id


def get_worker_registry(id):
    """Look up a worker by UUID, falling back to the global worker."""
    registered = worker_registry.get(id)
    return _global_worker if registered is None else registered
class Worker:
def __init__(self, conn_str="", stub=None):
if stub is None:
self.channel = grpc.insecure_channel(conn_str)
self.server = proto.task_pb2_grpc.TaskServerStub(self.channel)
else:
self.server = stub
self.uuid = register_worker(self)
def get(self, ids):
to_get = []
single = False
if isinstance(ids, list):
to_get = [x.id for x in ids]
elif isinstance(ids, ObjectID):
to_get = [ids.id]
single = True
else:
raise Exception("Can't get something that's not a list of IDs or just an ID")
out = [self._get(x) for x in to_get]
if single:
out = out[0]
return out
def _get(self, id: bytes):
req = proto.task_pb2.GetRequest(id=id)
data = self.server.GetObject(req)
return cloudpickle.loads(data.data)
def put(self, vals):
to_put = []
single = False
if isinstance(vals, list):
to_put = vals
else:
single = True
to_put.append(vals)
out = [self._put(x) for x in to_put]
if single:
out = out[0]
return out
def _put(self, val):
data = cloudpickle.dumps(val)
#print("val: %s\ndata: %s"%(val, data))
req = proto.task_pb2.PutRequest(data=data)
resp = self.server.PutObject(req)
return ObjectID(resp.id)
def remote(self, func):
return RemoteFunc(self, func)
def schedule(self, task):
return self.server.Schedule(task)
def close(self):
self.channel.close()
class RemoteFunc:
    """Wrapper produced by @ray.remote; invoke via .remote(...), not directly."""

    def __init__(self, worker, f):
        self._func = f
        self._name = f.__name__
        # ObjectID of the pickled function on the server; pushed lazily on
        # the first .remote() call.
        self.id = None
        self._worker_id = worker.uuid

    def __call__(self, *args, **kwargs):
        raise Exception("Matching the old API")

    def remote(self, *args):
        """Ship the function (once) and schedule one invocation of it."""
        if self.id is None:
            self._push_func()
        task = proto.task_pb2.Task()
        task.name = self._name
        task.payload_id = self.id.id
        for value in args:
            arg = proto.task_pb2.Arg()
            if isinstance(value, ObjectID):
                # Pass by reference: the server resolves it from the store.
                arg.local = proto.task_pb2.Arg.Locality.REFERENCE
                arg.reference_id = value.id
            else:
                # Pass by value: pickle the bytes inline in the request.
                arg.local = proto.task_pb2.Arg.Locality.INTERNED
                arg.data = cloudpickle.dumps(value)
            task.args.append(arg)
        ticket = get_worker_registry(self._worker_id).schedule(task)
        return ObjectID(ticket.return_id)

    def _push_func(self):
        """Upload the pickled function to the server's object store."""
        self.id = get_worker_registry(self._worker_id).put(self._func)

    def __repr__(self):
        return "RemoteFunc(%s, %s)" % (self._name, self.id)
# The process-wide default worker, installed by connect() (or directly by
# dumb_raylet for in-process use).
_global_worker = None


def _require_worker():
    """Return the connected global worker, raising if connect() wasn't called."""
    if _global_worker is None:
        raise Exception("Need a connection before calling")
    return _global_worker


def connect(*args, **kwargs):
    """Create the global Worker; see Worker.__init__ for arguments.

    Raises:
        Exception: if a global worker is already connected.
    """
    global _global_worker
    if _global_worker is not None:
        raise Exception("Can't connect a second global worker")
    _global_worker = Worker(*args, **kwargs)


def get(*args, **kwargs):
    """Fetch object(s) via the global worker (see Worker.get)."""
    return _require_worker().get(*args, **kwargs)


def put(*args, **kwargs):
    """Store object(s) via the global worker (see Worker.put)."""
    return _require_worker().put(*args, **kwargs)


def remote(*args, **kwargs):
    """Decorator: wrap a function as a RemoteFunc on the global worker."""
    return _require_worker().remote(*args, **kwargs)

5
python/requirements.txt Normal file
View file

@ -0,0 +1,5 @@
cloudpickle
futures  # NOTE(review): Python 2 backport of concurrent.futures; unnecessary on Python 3
grpcio
google  # NOTE(review): the PyPI "google" package is unrelated to protobuf; google.protobuf comes from "protobuf" below — confirm this line is needed
protobuf