small python ray
This commit is contained in:
commit
e02e80708a
15 changed files with 1261 additions and 0 deletions
25
python/BUILD
Normal file
25
python/BUILD
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
load("@py_deps//:requirements.bzl", "requirement")
|
||||
load("//bazel:tools.bzl", "copy_to_workspace")
|
||||
|
||||
filegroup(
|
||||
name = "all_py_proto",
|
||||
srcs = [
|
||||
"//proto:task_proto_grpc_py",
|
||||
],
|
||||
)
|
||||
|
||||
copy_to_workspace(
|
||||
name = "proto_folder",
|
||||
srcs = [":all_py_proto"],
|
||||
dstdir = "python/proto",
|
||||
)
|
||||
|
||||
|
||||
py_binary(
|
||||
name = "main",
|
||||
srcs = ["main.py"],
|
||||
data = [ ":proto_folder" ],
|
||||
deps = [
|
||||
requirement("cloudpickle"),
|
||||
],
|
||||
)
|
||||
104
python/dumb_raylet.py
Normal file
104
python/dumb_raylet.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
import cloudpickle
|
||||
import ray
|
||||
import logging
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import proto.task_pb2
|
||||
import proto.task_pb2_grpc
|
||||
import threading
|
||||
import time
|
||||
|
||||
id_printer = 1
|
||||
|
||||
|
||||
def generate_id():
|
||||
global id_printer
|
||||
out = str(id_printer).encode()
|
||||
id_printer += 1
|
||||
return out
|
||||
|
||||
|
||||
class ObjectStore:
|
||||
def __init__(self):
|
||||
self.data = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def get(self, id):
|
||||
with self.lock:
|
||||
return self.data[id]
|
||||
|
||||
def put(self, data):
|
||||
with self.lock:
|
||||
id = generate_id()
|
||||
self.data[id] = data
|
||||
return id
|
||||
|
||||
|
||||
class TaskServicer(proto.task_pb2_grpc.TaskServerServicer):
|
||||
def __init__(self, objects, executor):
|
||||
self.objects = objects
|
||||
self.executor = executor
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
data = self.objects.get(request.id)
|
||||
if data is None:
|
||||
return proto.task_pb2.GetResponse(valid=False)
|
||||
return proto.task_pb2.GetResponse(valid=True, data=data)
|
||||
|
||||
def PutObject(self, request, context=None):
|
||||
id = self.objects.put(request.data)
|
||||
return proto.task_pb2.PutResponse(id=id)
|
||||
|
||||
def Schedule(self, task, context=None):
|
||||
return_val = self.executor.execute(task, context)
|
||||
return proto.task_pb2.TaskTicket(return_id=return_val)
|
||||
|
||||
|
||||
class SimpleExecutor:
|
||||
def __init__(self, object_store):
|
||||
self.objects = object_store
|
||||
|
||||
def execute(self, task, context):
|
||||
#print("Executing task", task.name)
|
||||
self.context = context
|
||||
func_data = self.objects.get(task.payload_id)
|
||||
if func_data is None:
|
||||
raise Exception("WTF")
|
||||
f = cloudpickle.loads(func_data)
|
||||
args = []
|
||||
for a in task.args:
|
||||
if a.local == proto.task_pb2.Arg.Locality.INTERNED:
|
||||
data = a.data
|
||||
else:
|
||||
data = self.objects.get(a.reference_id)
|
||||
realarg = cloudpickle.loads(data)
|
||||
args.append(realarg)
|
||||
out = f(*args)
|
||||
id = self.objects.put(cloudpickle.dumps(out))
|
||||
self.context = None
|
||||
return id
|
||||
|
||||
def set_task_servicer(self, servicer):
|
||||
ray._global_worker = ray.Worker(stub=servicer)
|
||||
|
||||
|
||||
def serve():
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
object_store = ObjectStore()
|
||||
executor = SimpleExecutor(object_store)
|
||||
task_servicer = TaskServicer(object_store, executor)
|
||||
executor.set_task_servicer(task_servicer)
|
||||
proto.task_pb2_grpc.add_TaskServerServicer_to_server(
|
||||
task_servicer, server)
|
||||
server.add_insecure_port('[::]:50051')
|
||||
server.start()
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1000)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig()
|
||||
serve()
|
||||
34
python/example_client.py
Normal file
34
python/example_client.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import cloudpickle
|
||||
import logging
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import ray
|
||||
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
|
||||
@ray.remote
|
||||
def fact(x):
|
||||
out = 1
|
||||
for i in range(1, x+1):
|
||||
out = i * out
|
||||
return out
|
||||
|
||||
|
||||
@ray.remote
|
||||
def fib(x):
|
||||
if x <= 1:
|
||||
return 1
|
||||
return ray.get(fib.remote(x - 1)) + ray.get(fib.remote(x - 2))
|
||||
|
||||
|
||||
def run():
|
||||
# out = fact.remote(5)
|
||||
# out2 = fact.remote(out)
|
||||
# print("Got fact: ", ray.get(out2))
|
||||
fib_out = fib.remote(6)
|
||||
# print("Fib out:", ray.get(fib_out))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
96
python/example_test.py
Normal file
96
python/example_test.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
import ray
|
||||
import numpy as np
|
||||
import time
|
||||
import dumb_raylet
|
||||
|
||||
#ray.connect("localhost:50051")
|
||||
object_store = dumb_raylet.ObjectStore()
|
||||
executor = dumb_raylet.SimpleExecutor(object_store)
|
||||
task_servicer = dumb_raylet.TaskServicer(object_store, executor)
|
||||
executor.set_task_servicer(task_servicer)
|
||||
|
||||
|
||||
def test_timing():
|
||||
@ray.remote
|
||||
def empty_function():
|
||||
pass
|
||||
|
||||
@ray.remote
|
||||
def trivial_function():
|
||||
return 1
|
||||
|
||||
# Measure the time required to submit a remote task to the scheduler.
|
||||
elapsed_times = []
|
||||
for _ in range(1000):
|
||||
start_time = time.time()
|
||||
empty_function.remote()
|
||||
end_time = time.time()
|
||||
elapsed_times.append(end_time - start_time)
|
||||
elapsed_times = np.sort(elapsed_times)
|
||||
average_elapsed_time = sum(elapsed_times) / 1000
|
||||
print("Time required to submit an empty function call:")
|
||||
print(" Average: {}".format(average_elapsed_time))
|
||||
print(" 90th percentile: {}".format(elapsed_times[900]))
|
||||
print(" 99th percentile: {}".format(elapsed_times[990]))
|
||||
print(" worst: {}".format(elapsed_times[999]))
|
||||
# average_elapsed_time should be about 0.00038.
|
||||
|
||||
# Measure the time required to submit a remote task to the scheduler
|
||||
# (where the remote task returns one value).
|
||||
elapsed_times = []
|
||||
for _ in range(1000):
|
||||
start_time = time.time()
|
||||
trivial_function.remote()
|
||||
end_time = time.time()
|
||||
elapsed_times.append(end_time - start_time)
|
||||
elapsed_times = np.sort(elapsed_times)
|
||||
average_elapsed_time = sum(elapsed_times) / 1000
|
||||
print("Time required to submit a trivial function call:")
|
||||
print(" Average: {}".format(average_elapsed_time))
|
||||
print(" 90th percentile: {}".format(elapsed_times[900]))
|
||||
print(" 99th percentile: {}".format(elapsed_times[990]))
|
||||
print(" worst: {}".format(elapsed_times[999]))
|
||||
|
||||
# Measure the time required to submit a remote task to the scheduler
|
||||
# and get the result.
|
||||
elapsed_times = []
|
||||
for _ in range(1000):
|
||||
start_time = time.time()
|
||||
x = trivial_function.remote()
|
||||
ray.get(x)
|
||||
end_time = time.time()
|
||||
elapsed_times.append(end_time - start_time)
|
||||
elapsed_times = np.sort(elapsed_times)
|
||||
average_elapsed_time = sum(elapsed_times) / 1000
|
||||
print("Time required to submit a trivial function call and get the "
|
||||
"result:")
|
||||
print(" Average: {}".format(average_elapsed_time))
|
||||
print(" 90th percentile: {}".format(elapsed_times[900]))
|
||||
print(" 99th percentile: {}".format(elapsed_times[990]))
|
||||
print(" worst: {}".format(elapsed_times[999]))
|
||||
# average_elapsed_time should be about 0.0013.
|
||||
|
||||
# Measure the time required to do do a put.
|
||||
elapsed_times = []
|
||||
for _ in range(1000):
|
||||
start_time = time.time()
|
||||
ray.put(1)
|
||||
end_time = time.time()
|
||||
elapsed_times.append(end_time - start_time)
|
||||
elapsed_times = np.sort(elapsed_times)
|
||||
average_elapsed_time = sum(elapsed_times) / 1000
|
||||
print("Time required to put an int:")
|
||||
print(" Average: {}".format(average_elapsed_time))
|
||||
print(" 90th percentile: {}".format(elapsed_times[900]))
|
||||
print(" 99th percentile: {}".format(elapsed_times[990]))
|
||||
print(" worst: {}".format(elapsed_times[999]))
|
||||
# average_elapsed_time should be about 0.00087.
|
||||
|
||||
|
||||
|
||||
def run():
|
||||
test_timing()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
11
python/main.py
Normal file
11
python/main.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import cloudpickle
|
||||
import proto.task_pb2_grpc
|
||||
|
||||
|
||||
def main():
|
||||
a = cloudpickle.dumps("hello world")
|
||||
print(a)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
477
python/proto/task_pb2.py
Executable file
477
python/proto/task_pb2.py
Executable file
|
|
@ -0,0 +1,477 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: proto/task.proto
|
||||
|
||||
from google.protobuf.internal import enum_type_wrapper
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from google.protobuf import reflection as _reflection
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor.FileDescriptor(
|
||||
name='proto/task.proto',
|
||||
package='',
|
||||
syntax='proto3',
|
||||
serialized_options=None,
|
||||
serialized_pb=b'\n\x10proto/task.proto\"\xa5\x01\n\x03\x41rg\x12#\n\x05local\x18\x01 \x01(\x0e\x32\r.Arg.LocalityR\x05local\x12!\n\x0creference_id\x18\x02 \x01(\x0cR\x0breferenceId\x12\x12\n\x04\x64\x61ta\x18\x03 \x01(\x0cR\x04\x64\x61ta\x12\x19\n\x04type\x18\x04 \x01(\x0e\x32\x05.TypeR\x04type\"\'\n\x08Locality\x12\x0c\n\x08INTERNED\x10\x00\x12\r\n\tREFERENCE\x10\x01\"S\n\x04Task\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1d\n\npayload_id\x18\x02 \x01(\x0cR\tpayloadId\x12\x18\n\x04\x61rgs\x18\x03 \x03(\x0b\x32\x04.ArgR\x04\x61rgs\")\n\nTaskTicket\x12\x1b\n\treturn_id\x18\x01 \x01(\x0cR\x08returnId\" \n\nPutRequest\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\"\x1d\n\x0bPutResponse\x12\x0e\n\x02id\x18\x01 \x01(\x0cR\x02id\"\x1c\n\nGetRequest\x12\x0e\n\x02id\x18\x01 \x01(\x0cR\x02id\"7\n\x0bGetResponse\x12\x14\n\x05valid\x18\x01 \x01(\x08R\x05valid\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta*\x13\n\x04Type\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x32\x82\x01\n\nTaskServer\x12(\n\tGetObject\x12\x0b.GetRequest\x1a\x0c.GetResponse\"\x00\x12(\n\tPutObject\x12\x0b.PutRequest\x1a\x0c.PutResponse\"\x00\x12 \n\x08Schedule\x12\x05.Task\x1a\x0b.TaskTicket\"\x00\x32\x85\x01\n\nWorkServer\x12(\n\tGetObject\x12\x0b.GetRequest\x1a\x0c.GetResponse\"\x00\x12(\n\tPutObject\x12\x0b.PutRequest\x1a\x0c.PutResponse\"\x00\x12#\n\x07\x45xecute\x12\x0b.TaskTicket\x1a\x05.Task\"\x00(\x01\x30\x01\x62\x06proto3'
|
||||
)
|
||||
|
||||
_TYPE = _descriptor.EnumDescriptor(
|
||||
name='Type',
|
||||
full_name='Type',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
values=[
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='DEFAULT', index=0, number=0,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
],
|
||||
containing_type=None,
|
||||
serialized_options=None,
|
||||
serialized_start=468,
|
||||
serialized_end=487,
|
||||
)
|
||||
_sym_db.RegisterEnumDescriptor(_TYPE)
|
||||
|
||||
Type = enum_type_wrapper.EnumTypeWrapper(_TYPE)
|
||||
DEFAULT = 0
|
||||
|
||||
|
||||
_ARG_LOCALITY = _descriptor.EnumDescriptor(
|
||||
name='Locality',
|
||||
full_name='Arg.Locality',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
values=[
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='INTERNED', index=0, number=0,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='REFERENCE', index=1, number=1,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
],
|
||||
containing_type=None,
|
||||
serialized_options=None,
|
||||
serialized_start=147,
|
||||
serialized_end=186,
|
||||
)
|
||||
_sym_db.RegisterEnumDescriptor(_ARG_LOCALITY)
|
||||
|
||||
|
||||
_ARG = _descriptor.Descriptor(
|
||||
name='Arg',
|
||||
full_name='Arg',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='local', full_name='Arg.local', index=0,
|
||||
number=1, type=14, cpp_type=8, label=1,
|
||||
has_default_value=False, default_value=0,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='local', file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='reference_id', full_name='Arg.reference_id', index=1,
|
||||
number=2, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='referenceId', file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='data', full_name='Arg.data', index=2,
|
||||
number=3, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='data', file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='type', full_name='Arg.type', index=3,
|
||||
number=4, type=14, cpp_type=8, label=1,
|
||||
has_default_value=False, default_value=0,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='type', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
_ARG_LOCALITY,
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=21,
|
||||
serialized_end=186,
|
||||
)
|
||||
|
||||
|
||||
_TASK = _descriptor.Descriptor(
|
||||
name='Task',
|
||||
full_name='Task',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='name', full_name='Task.name', index=0,
|
||||
number=1, type=9, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"".decode('utf-8'),
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='name', file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='payload_id', full_name='Task.payload_id', index=1,
|
||||
number=2, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='payloadId', file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='args', full_name='Task.args', index=2,
|
||||
number=3, type=11, cpp_type=10, label=3,
|
||||
has_default_value=False, default_value=[],
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='args', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=188,
|
||||
serialized_end=271,
|
||||
)
|
||||
|
||||
|
||||
_TASKTICKET = _descriptor.Descriptor(
|
||||
name='TaskTicket',
|
||||
full_name='TaskTicket',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='return_id', full_name='TaskTicket.return_id', index=0,
|
||||
number=1, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='returnId', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=273,
|
||||
serialized_end=314,
|
||||
)
|
||||
|
||||
|
||||
_PUTREQUEST = _descriptor.Descriptor(
|
||||
name='PutRequest',
|
||||
full_name='PutRequest',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='data', full_name='PutRequest.data', index=0,
|
||||
number=1, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='data', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=316,
|
||||
serialized_end=348,
|
||||
)
|
||||
|
||||
|
||||
_PUTRESPONSE = _descriptor.Descriptor(
|
||||
name='PutResponse',
|
||||
full_name='PutResponse',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='id', full_name='PutResponse.id', index=0,
|
||||
number=1, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='id', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=350,
|
||||
serialized_end=379,
|
||||
)
|
||||
|
||||
|
||||
_GETREQUEST = _descriptor.Descriptor(
|
||||
name='GetRequest',
|
||||
full_name='GetRequest',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='id', full_name='GetRequest.id', index=0,
|
||||
number=1, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='id', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=381,
|
||||
serialized_end=409,
|
||||
)
|
||||
|
||||
|
||||
_GETRESPONSE = _descriptor.Descriptor(
|
||||
name='GetResponse',
|
||||
full_name='GetResponse',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='valid', full_name='GetResponse.valid', index=0,
|
||||
number=1, type=8, cpp_type=7, label=1,
|
||||
has_default_value=False, default_value=False,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='valid', file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='data', full_name='GetResponse.data', index=1,
|
||||
number=2, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='data', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=411,
|
||||
serialized_end=466,
|
||||
)
|
||||
|
||||
_ARG.fields_by_name['local'].enum_type = _ARG_LOCALITY
|
||||
_ARG.fields_by_name['type'].enum_type = _TYPE
|
||||
_ARG_LOCALITY.containing_type = _ARG
|
||||
_TASK.fields_by_name['args'].message_type = _ARG
|
||||
DESCRIPTOR.message_types_by_name['Arg'] = _ARG
|
||||
DESCRIPTOR.message_types_by_name['Task'] = _TASK
|
||||
DESCRIPTOR.message_types_by_name['TaskTicket'] = _TASKTICKET
|
||||
DESCRIPTOR.message_types_by_name['PutRequest'] = _PUTREQUEST
|
||||
DESCRIPTOR.message_types_by_name['PutResponse'] = _PUTRESPONSE
|
||||
DESCRIPTOR.message_types_by_name['GetRequest'] = _GETREQUEST
|
||||
DESCRIPTOR.message_types_by_name['GetResponse'] = _GETRESPONSE
|
||||
DESCRIPTOR.enum_types_by_name['Type'] = _TYPE
|
||||
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
||||
|
||||
Arg = _reflection.GeneratedProtocolMessageType('Arg', (_message.Message,), {
|
||||
'DESCRIPTOR' : _ARG,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:Arg)
|
||||
})
|
||||
_sym_db.RegisterMessage(Arg)
|
||||
|
||||
Task = _reflection.GeneratedProtocolMessageType('Task', (_message.Message,), {
|
||||
'DESCRIPTOR' : _TASK,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:Task)
|
||||
})
|
||||
_sym_db.RegisterMessage(Task)
|
||||
|
||||
TaskTicket = _reflection.GeneratedProtocolMessageType('TaskTicket', (_message.Message,), {
|
||||
'DESCRIPTOR' : _TASKTICKET,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:TaskTicket)
|
||||
})
|
||||
_sym_db.RegisterMessage(TaskTicket)
|
||||
|
||||
PutRequest = _reflection.GeneratedProtocolMessageType('PutRequest', (_message.Message,), {
|
||||
'DESCRIPTOR' : _PUTREQUEST,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:PutRequest)
|
||||
})
|
||||
_sym_db.RegisterMessage(PutRequest)
|
||||
|
||||
PutResponse = _reflection.GeneratedProtocolMessageType('PutResponse', (_message.Message,), {
|
||||
'DESCRIPTOR' : _PUTRESPONSE,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:PutResponse)
|
||||
})
|
||||
_sym_db.RegisterMessage(PutResponse)
|
||||
|
||||
GetRequest = _reflection.GeneratedProtocolMessageType('GetRequest', (_message.Message,), {
|
||||
'DESCRIPTOR' : _GETREQUEST,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:GetRequest)
|
||||
})
|
||||
_sym_db.RegisterMessage(GetRequest)
|
||||
|
||||
GetResponse = _reflection.GeneratedProtocolMessageType('GetResponse', (_message.Message,), {
|
||||
'DESCRIPTOR' : _GETRESPONSE,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:GetResponse)
|
||||
})
|
||||
_sym_db.RegisterMessage(GetResponse)
|
||||
|
||||
|
||||
|
||||
_TASKSERVER = _descriptor.ServiceDescriptor(
|
||||
name='TaskServer',
|
||||
full_name='TaskServer',
|
||||
file=DESCRIPTOR,
|
||||
index=0,
|
||||
serialized_options=None,
|
||||
serialized_start=490,
|
||||
serialized_end=620,
|
||||
methods=[
|
||||
_descriptor.MethodDescriptor(
|
||||
name='GetObject',
|
||||
full_name='TaskServer.GetObject',
|
||||
index=0,
|
||||
containing_service=None,
|
||||
input_type=_GETREQUEST,
|
||||
output_type=_GETRESPONSE,
|
||||
serialized_options=None,
|
||||
),
|
||||
_descriptor.MethodDescriptor(
|
||||
name='PutObject',
|
||||
full_name='TaskServer.PutObject',
|
||||
index=1,
|
||||
containing_service=None,
|
||||
input_type=_PUTREQUEST,
|
||||
output_type=_PUTRESPONSE,
|
||||
serialized_options=None,
|
||||
),
|
||||
_descriptor.MethodDescriptor(
|
||||
name='Schedule',
|
||||
full_name='TaskServer.Schedule',
|
||||
index=2,
|
||||
containing_service=None,
|
||||
input_type=_TASK,
|
||||
output_type=_TASKTICKET,
|
||||
serialized_options=None,
|
||||
),
|
||||
])
|
||||
_sym_db.RegisterServiceDescriptor(_TASKSERVER)
|
||||
|
||||
DESCRIPTOR.services_by_name['TaskServer'] = _TASKSERVER
|
||||
|
||||
|
||||
_WORKSERVER = _descriptor.ServiceDescriptor(
|
||||
name='WorkServer',
|
||||
full_name='WorkServer',
|
||||
file=DESCRIPTOR,
|
||||
index=1,
|
||||
serialized_options=None,
|
||||
serialized_start=623,
|
||||
serialized_end=756,
|
||||
methods=[
|
||||
_descriptor.MethodDescriptor(
|
||||
name='GetObject',
|
||||
full_name='WorkServer.GetObject',
|
||||
index=0,
|
||||
containing_service=None,
|
||||
input_type=_GETREQUEST,
|
||||
output_type=_GETRESPONSE,
|
||||
serialized_options=None,
|
||||
),
|
||||
_descriptor.MethodDescriptor(
|
||||
name='PutObject',
|
||||
full_name='WorkServer.PutObject',
|
||||
index=1,
|
||||
containing_service=None,
|
||||
input_type=_PUTREQUEST,
|
||||
output_type=_PUTRESPONSE,
|
||||
serialized_options=None,
|
||||
),
|
||||
_descriptor.MethodDescriptor(
|
||||
name='Execute',
|
||||
full_name='WorkServer.Execute',
|
||||
index=2,
|
||||
containing_service=None,
|
||||
input_type=_TASKTICKET,
|
||||
output_type=_TASK,
|
||||
serialized_options=None,
|
||||
),
|
||||
])
|
||||
_sym_db.RegisterServiceDescriptor(_WORKSERVER)
|
||||
|
||||
DESCRIPTOR.services_by_name['WorkServer'] = _WORKSERVER
|
||||
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
156
python/proto/task_pb2_grpc.py
Executable file
156
python/proto/task_pb2_grpc.py
Executable file
|
|
@ -0,0 +1,156 @@
|
|||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
import grpc
|
||||
|
||||
from proto import task_pb2 as proto_dot_task__pb2
|
||||
|
||||
|
||||
class TaskServerStub(object):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.GetObject = channel.unary_unary(
|
||||
'/TaskServer/GetObject',
|
||||
request_serializer=proto_dot_task__pb2.GetRequest.SerializeToString,
|
||||
response_deserializer=proto_dot_task__pb2.GetResponse.FromString,
|
||||
)
|
||||
self.PutObject = channel.unary_unary(
|
||||
'/TaskServer/PutObject',
|
||||
request_serializer=proto_dot_task__pb2.PutRequest.SerializeToString,
|
||||
response_deserializer=proto_dot_task__pb2.PutResponse.FromString,
|
||||
)
|
||||
self.Schedule = channel.unary_unary(
|
||||
'/TaskServer/Schedule',
|
||||
request_serializer=proto_dot_task__pb2.Task.SerializeToString,
|
||||
response_deserializer=proto_dot_task__pb2.TaskTicket.FromString,
|
||||
)
|
||||
|
||||
|
||||
class TaskServerServicer(object):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
|
||||
def GetObject(self, request, context):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def PutObject(self, request, context):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Schedule(self, request, context):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_TaskServerServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'GetObject': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetObject,
|
||||
request_deserializer=proto_dot_task__pb2.GetRequest.FromString,
|
||||
response_serializer=proto_dot_task__pb2.GetResponse.SerializeToString,
|
||||
),
|
||||
'PutObject': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.PutObject,
|
||||
request_deserializer=proto_dot_task__pb2.PutRequest.FromString,
|
||||
response_serializer=proto_dot_task__pb2.PutResponse.SerializeToString,
|
||||
),
|
||||
'Schedule': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Schedule,
|
||||
request_deserializer=proto_dot_task__pb2.Task.FromString,
|
||||
response_serializer=proto_dot_task__pb2.TaskTicket.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'TaskServer', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
class WorkServerStub(object):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.GetObject = channel.unary_unary(
|
||||
'/WorkServer/GetObject',
|
||||
request_serializer=proto_dot_task__pb2.GetRequest.SerializeToString,
|
||||
response_deserializer=proto_dot_task__pb2.GetResponse.FromString,
|
||||
)
|
||||
self.PutObject = channel.unary_unary(
|
||||
'/WorkServer/PutObject',
|
||||
request_serializer=proto_dot_task__pb2.PutRequest.SerializeToString,
|
||||
response_deserializer=proto_dot_task__pb2.PutResponse.FromString,
|
||||
)
|
||||
self.Execute = channel.stream_stream(
|
||||
'/WorkServer/Execute',
|
||||
request_serializer=proto_dot_task__pb2.TaskTicket.SerializeToString,
|
||||
response_deserializer=proto_dot_task__pb2.Task.FromString,
|
||||
)
|
||||
|
||||
|
||||
class WorkServerServicer(object):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
|
||||
def GetObject(self, request, context):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def PutObject(self, request, context):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Execute(self, request_iterator, context):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_WorkServerServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'GetObject': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetObject,
|
||||
request_deserializer=proto_dot_task__pb2.GetRequest.FromString,
|
||||
response_serializer=proto_dot_task__pb2.GetResponse.SerializeToString,
|
||||
),
|
||||
'PutObject': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.PutObject,
|
||||
request_deserializer=proto_dot_task__pb2.PutRequest.FromString,
|
||||
response_serializer=proto_dot_task__pb2.PutResponse.SerializeToString,
|
||||
),
|
||||
'Execute': grpc.stream_stream_rpc_method_handler(
|
||||
servicer.Execute,
|
||||
request_deserializer=proto_dot_task__pb2.TaskTicket.FromString,
|
||||
response_serializer=proto_dot_task__pb2.Task.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'WorkServer', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
163
python/ray.py
Normal file
163
python/ray.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
import cloudpickle
|
||||
import grpc
|
||||
import proto.task_pb2
|
||||
import proto.task_pb2_grpc
|
||||
import uuid
|
||||
|
||||
|
||||
class ObjectID:
|
||||
def __init__(self, id):
|
||||
self.id = id
|
||||
|
||||
def __repr__(self):
|
||||
return "ObjectID(%s)" % self.id.decode()
|
||||
|
||||
|
||||
worker_registry = {}
|
||||
|
||||
|
||||
def set_global_worker(worker):
|
||||
global _global_worker
|
||||
_global_worker = worker
|
||||
|
||||
|
||||
def register_worker(worker):
|
||||
id = uuid.uuid4()
|
||||
worker_registry[id] = worker
|
||||
return id
|
||||
|
||||
|
||||
def get_worker_registry(id):
|
||||
out = worker_registry.get(id)
|
||||
if out is None:
|
||||
return _global_worker
|
||||
return out
|
||||
|
||||
|
||||
class Worker:
|
||||
def __init__(self, conn_str="", stub=None):
|
||||
if stub is None:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = proto.task_pb2_grpc.TaskServerStub(self.channel)
|
||||
else:
|
||||
self.server = stub
|
||||
self.uuid = register_worker(self)
|
||||
|
||||
def get(self, ids):
|
||||
to_get = []
|
||||
single = False
|
||||
if isinstance(ids, list):
|
||||
to_get = [x.id for x in ids]
|
||||
elif isinstance(ids, ObjectID):
|
||||
to_get = [ids.id]
|
||||
single = True
|
||||
else:
|
||||
raise Exception("Can't get something that's not a list of IDs or just an ID")
|
||||
|
||||
out = [self._get(x) for x in to_get]
|
||||
if single:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
def _get(self, id: bytes):
|
||||
req = proto.task_pb2.GetRequest(id=id)
|
||||
data = self.server.GetObject(req)
|
||||
return cloudpickle.loads(data.data)
|
||||
|
||||
def put(self, vals):
|
||||
to_put = []
|
||||
single = False
|
||||
if isinstance(vals, list):
|
||||
to_put = vals
|
||||
else:
|
||||
single = True
|
||||
to_put.append(vals)
|
||||
|
||||
out = [self._put(x) for x in to_put]
|
||||
if single:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
def _put(self, val):
|
||||
data = cloudpickle.dumps(val)
|
||||
#print("val: %s\ndata: %s"%(val, data))
|
||||
req = proto.task_pb2.PutRequest(data=data)
|
||||
resp = self.server.PutObject(req)
|
||||
return ObjectID(resp.id)
|
||||
|
||||
def remote(self, func):
|
||||
return RemoteFunc(self, func)
|
||||
|
||||
def schedule(self, task):
|
||||
return self.server.Schedule(task)
|
||||
|
||||
def close(self):
|
||||
self.channel.close()
|
||||
|
||||
|
||||
class RemoteFunc:
|
||||
def __init__(self, worker, f):
|
||||
self._func = f
|
||||
self._name = f.__name__
|
||||
self.id = None
|
||||
self._worker_id = worker.uuid
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("Matching the old API")
|
||||
|
||||
def remote(self, *args):
|
||||
if self.id is None:
|
||||
self._push_func()
|
||||
t = proto.task_pb2.Task()
|
||||
t.name = self._name
|
||||
t.payload_id = self.id.id
|
||||
for a in args:
|
||||
arg = proto.task_pb2.Arg()
|
||||
if isinstance(a, ObjectID):
|
||||
arg.local = proto.task_pb2.Arg.Locality.REFERENCE
|
||||
arg.reference_id = a.id
|
||||
else:
|
||||
arg.local = proto.task_pb2.Arg.Locality.INTERNED
|
||||
arg.data = cloudpickle.dumps(a)
|
||||
t.args.append(arg)
|
||||
worker = get_worker_registry(self._worker_id)
|
||||
ticket = worker.schedule(t)
|
||||
return ObjectID(ticket.return_id)
|
||||
|
||||
def _push_func(self):
|
||||
worker = get_worker_registry(self._worker_id)
|
||||
self.id = worker.put(self._func)
|
||||
|
||||
def __repr__(self):
|
||||
return "RemoteFunc(%s, %s)" % (self._name, self.id)
|
||||
|
||||
|
||||
_global_worker = None
|
||||
|
||||
|
||||
def connect(*args, **kwargs):
|
||||
global _global_worker
|
||||
if _global_worker is not None:
|
||||
raise Exception("Can't connect a second global worker")
|
||||
_global_worker = Worker(*args, **kwargs)
|
||||
|
||||
|
||||
def get(*args, **kwargs):
|
||||
global _global_worker
|
||||
if _global_worker is None:
|
||||
raise Exception("Need a connection before calling")
|
||||
return _global_worker.get(*args, **kwargs)
|
||||
|
||||
|
||||
def put(*args, **kwargs):
|
||||
global _global_worker
|
||||
if _global_worker is None:
|
||||
raise Exception("Need a connection before calling")
|
||||
return _global_worker.put(*args, **kwargs)
|
||||
|
||||
|
||||
def remote(*args, **kwargs):
|
||||
global _global_worker
|
||||
if _global_worker is None:
|
||||
raise Exception("Need a connection before calling")
|
||||
return _global_worker.remote(*args, **kwargs)
|
||||
5
python/requirements.txt
Normal file
5
python/requirements.txt
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
cloudpickle
|
||||
futures
|
||||
grpcio
|
||||
google
|
||||
protobuf
|
||||
Loading…
Add table
Add a link
Reference in a new issue