commit e02e80708ae3e1650f9e1ab3ca52d7374f712b13 Author: Barak Michener Date: Sat Sep 12 15:57:57 2020 -0700 small python ray diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2d0787e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/bazel-* +*.pyc diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000..552e677 --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,78 @@ +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "rules_python", + url = "https://github.com/bazelbuild/rules_python/releases/download/0.0.2/rules_python-0.0.2.tar.gz", + strip_prefix = "rules_python-0.0.2", + sha256 = "b5668cde8bb6e3515057ef465a35ad712214962f0b3a314e551204266c7be90c", +) + +load("@rules_python//python:repositories.bzl", "py_repositories") + +py_repositories() + +load("@rules_python//python:pip.bzl", "pip_repositories") + +pip_repositories() + +http_archive( + name = "rules_proto_grpc", + urls = ["https://github.com/rules-proto-grpc/rules_proto_grpc/archive/1.0.2.tar.gz"], + sha256 = "5f0f2fc0199810c65a2de148a52ba0aff14d631d4e8202f41aff6a9d590a471b", + strip_prefix = "rules_proto_grpc-1.0.2", +) + +load("@rules_proto_grpc//:repositories.bzl", "rules_proto_grpc_toolchains", "rules_proto_grpc_repos") +rules_proto_grpc_toolchains() +rules_proto_grpc_repos() + + +load("@rules_proto_grpc//python:repositories.bzl", rules_proto_grpc_python_repos="python_repos") + +rules_proto_grpc_python_repos() + +load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") + +grpc_deps() + +rules_python_external_version = "0.1.5" + +http_archive( + name = "rules_python_external", + url = "https://github.com/dillon-giacoppo/rules_python_external/archive/v{version}.zip".format(version = rules_python_external_version), + sha256 = "bc655e6d402915944e014c3b2cad23d0a97b83a66cc22f20db09c9f8da2e2789", + strip_prefix = "rules_python_external-{version}".format(version = rules_python_external_version), +) + 
+load("@rules_python_external//:repositories.bzl", "rules_python_external_dependencies") +rules_python_external_dependencies() + + + +load("@rules_python//python:pip.bzl", "pip_import") +pip_import( + name = "rules_proto_grpc_py3_deps", + #python_interpreter = "python3", + requirements = "@rules_proto_grpc//python:requirements.txt", +) + +load("@rules_proto_grpc_py3_deps//:requirements.bzl", pip3_install="pip_install") +pip3_install() + +load("@rules_python_external//:defs.bzl", "pip_install") +pip_install( + name = "py_deps", + requirements = "//python:requirements.txt", + # (Optional) You can provide a python interpreter (by path): + #python_interpreter = "/usr/bin/python3.8", + # (Optional) Alternatively you can provide an in-build python interpreter, that is available as a Bazel target. + # This overrides `python_interpreter`. + # Note: You need to set up the interpreter target beforehand (not shown here). Please see the `example` folder for further details. + #python_interpreter_target = "@python_interpreter//:python_bin", +) + + + + + + diff --git a/bazel/BUILD b/bazel/BUILD new file mode 100644 index 0000000..e69de29 diff --git a/bazel/tools.bzl b/bazel/tools.bzl new file mode 100644 index 0000000..eb3e96c --- /dev/null +++ b/bazel/tools.bzl @@ -0,0 +1,36 @@ +def copy_to_workspace(name, srcs, dstdir = ""): + if dstdir.startswith("/") or dstdir.startswith("\\"): + fail("Subdirectory must be a relative path: " + dstdir) + src_locations = " ".join(["$(locations %s)" % (src,) for src in srcs]) + native.genrule( + name = name, + srcs = srcs, + outs = [name + ".out"], + # Keep this Bash script equivalent to the batch script below (or take out the batch script) + cmd = r""" + mkdir -p -- {dstdir} + for f in {locations}; do + rm -f -- {dstdir}$${{f##*/}} + cp -f -- "$$f" {dstdir} + done + date > $@ + """.format( + locations = src_locations, + dstdir = "." 
+ ("/" + dstdir.replace("\\", "/")).rstrip("/") + "/", + ), + # Keep this batch script equivalent to the Bash script above (or take out the batch script) + cmd_bat = """ + ( + if not exist {dstdir} mkdir {dstdir} + ) && ( + for %f in ({locations}) do @( + (if exist {dstdir}%~nxf del /f /q {dstdir}%~nxf) && + copy /B /Y %f {dstdir} >NUL + ) + ) && >$@ echo %TIME% + """.replace("\r", "").replace("\n", " ").format( + locations = src_locations, + dstdir = "." + ("\\" + dstdir.replace("/", "\\")).rstrip("\\") + "\\", + ), + local = 1, + ) diff --git a/proto/BUILD b/proto/BUILD new file mode 100644 index 0000000..76fe2d3 --- /dev/null +++ b/proto/BUILD @@ -0,0 +1,18 @@ +package(default_visibility = ["//visibility:public"]) + +load("@rules_proto_grpc//python:defs.bzl", "python_grpc_compile") + +proto_library( + name = "task_proto", + srcs = [ + "task.proto", + ], + deps = [ + "@com_google_protobuf//:any_proto", + ], +) + +python_grpc_compile( + name = "task_proto_grpc_py", + deps = [":task_proto"], +) diff --git a/proto/task.proto b/proto/task.proto new file mode 100644 index 0000000..a3ca2f3 --- /dev/null +++ b/proto/task.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +enum Type { + DEFAULT = 0; +} + +message Arg { + enum Locality { + INTERNED = 0; + REFERENCE = 1; + } + Locality local = 1; + bytes reference_id = 2; + bytes data = 3; + Type type = 4; +} + +message Task { + // Optionally Provided Task Name + string name = 1; + bytes payload_id = 2; + repeated Arg args = 3; +} + +message TaskTicket { + bytes return_id = 1; +} + +message PutRequest { + bytes data = 1; +} + +message PutResponse { + bytes id = 1; +} + +message GetRequest { + bytes id = 1; +} + +message GetResponse { + bool valid = 1; + bytes data = 2; +} + +service TaskServer { + rpc GetObject(GetRequest) returns (GetResponse) {} + rpc PutObject(PutRequest) returns (PutResponse) {} + rpc Schedule(Task) returns (TaskTicket) {} +} + +service WorkServer { + rpc GetObject(GetRequest) returns (GetResponse) {} + rpc 
# Monotonically increasing counter used to mint object ids. The increment in
# generate_id() is not independently thread-safe; callers are expected to hold
# ObjectStore.lock (put() does), which serializes id generation.
id_printer = 1


def generate_id():
    """Return a new unique object id as bytes (b"1", b"2", ...)."""
    global id_printer
    out = str(id_printer).encode()
    id_printer += 1
    return out


class ObjectStore:
    """Thread-safe in-memory mapping of object id (bytes) -> payload (bytes)."""

    def __init__(self):
        self.data = {}  # id (bytes) -> serialized object payload (bytes)
        self.lock = threading.Lock()

    def get(self, id):
        """Return the payload stored under *id*, or None if *id* is unknown.

        Bug fix: the original indexed ``self.data[id]`` directly, which raised
        KeyError for a missing id even though callers (TaskServicer.GetObject,
        SimpleExecutor.execute) explicitly compare the result against None.
        """
        with self.lock:
            return self.data.get(id)

    def put(self, data):
        """Store *data* under a freshly generated id and return that id."""
        with self.lock:
            id = generate_id()
            self.data[id] = data
            return id
class SimpleExecutor:
    """Executes Task protos inline against a local ObjectStore.

    Functions and arguments arrive as cloudpickle payloads referenced by the
    task; the result is pickled, stored back into the object store, and its
    id returned to the caller.
    """

    def __init__(self, object_store):
        self.objects = object_store

    def execute(self, task, context):
        """Run *task* and return the object id (bytes) of its result.

        Raises:
            KeyError: if the task's payload id is not present in the store.
        """
        # NOTE(review): stashing the per-call gRPC context on the executor
        # instance is racy when the server dispatches handlers on multiple
        # threads (serve() uses a 10-worker pool); kept to preserve the
        # existing behavior — confirm before relying on self.context.
        self.context = context
        func_data = self.objects.get(task.payload_id)
        if func_data is None:
            # Bug fix: replaced the uninformative Exception("WTF") with a
            # lookup error that names the missing payload id. KeyError is a
            # subclass of Exception, so existing broad handlers still match.
            raise KeyError(
                "Task payload %r not found in object store" % (task.payload_id,))
        f = cloudpickle.loads(func_data)
        args = []
        for a in task.args:
            # INTERNED args carry their pickled value inline with the task;
            # REFERENCE args point at an object stored earlier.
            if a.local == proto.task_pb2.Arg.Locality.INTERNED:
                data = a.data
            else:
                data = self.objects.get(a.reference_id)
            realarg = cloudpickle.loads(data)
            args.append(realarg)
        out = f(*args)
        id = self.objects.put(cloudpickle.dumps(out))
        self.context = None
        return id

    def set_task_servicer(self, servicer):
        # Install the in-process servicer as the global ray worker so that
        # functions executed here can themselves call ray.* APIs.
        ray._global_worker = ray.Worker(stub=servicer)
ray.get(fib_out)) + + +if __name__ == "__main__": + run() diff --git a/python/example_test.py b/python/example_test.py new file mode 100644 index 0000000..6b58def --- /dev/null +++ b/python/example_test.py @@ -0,0 +1,96 @@ +import ray +import numpy as np +import time +import dumb_raylet + +#ray.connect("localhost:50051") +object_store = dumb_raylet.ObjectStore() +executor = dumb_raylet.SimpleExecutor(object_store) +task_servicer = dumb_raylet.TaskServicer(object_store, executor) +executor.set_task_servicer(task_servicer) + + +def test_timing(): + @ray.remote + def empty_function(): + pass + + @ray.remote + def trivial_function(): + return 1 + + # Measure the time required to submit a remote task to the scheduler. + elapsed_times = [] + for _ in range(1000): + start_time = time.time() + empty_function.remote() + end_time = time.time() + elapsed_times.append(end_time - start_time) + elapsed_times = np.sort(elapsed_times) + average_elapsed_time = sum(elapsed_times) / 1000 + print("Time required to submit an empty function call:") + print(" Average: {}".format(average_elapsed_time)) + print(" 90th percentile: {}".format(elapsed_times[900])) + print(" 99th percentile: {}".format(elapsed_times[990])) + print(" worst: {}".format(elapsed_times[999])) + # average_elapsed_time should be about 0.00038. + + # Measure the time required to submit a remote task to the scheduler + # (where the remote task returns one value). 
+ elapsed_times = [] + for _ in range(1000): + start_time = time.time() + trivial_function.remote() + end_time = time.time() + elapsed_times.append(end_time - start_time) + elapsed_times = np.sort(elapsed_times) + average_elapsed_time = sum(elapsed_times) / 1000 + print("Time required to submit a trivial function call:") + print(" Average: {}".format(average_elapsed_time)) + print(" 90th percentile: {}".format(elapsed_times[900])) + print(" 99th percentile: {}".format(elapsed_times[990])) + print(" worst: {}".format(elapsed_times[999])) + + # Measure the time required to submit a remote task to the scheduler + # and get the result. + elapsed_times = [] + for _ in range(1000): + start_time = time.time() + x = trivial_function.remote() + ray.get(x) + end_time = time.time() + elapsed_times.append(end_time - start_time) + elapsed_times = np.sort(elapsed_times) + average_elapsed_time = sum(elapsed_times) / 1000 + print("Time required to submit a trivial function call and get the " + "result:") + print(" Average: {}".format(average_elapsed_time)) + print(" 90th percentile: {}".format(elapsed_times[900])) + print(" 99th percentile: {}".format(elapsed_times[990])) + print(" worst: {}".format(elapsed_times[999])) + # average_elapsed_time should be about 0.0013. + + # Measure the time required to do do a put. + elapsed_times = [] + for _ in range(1000): + start_time = time.time() + ray.put(1) + end_time = time.time() + elapsed_times.append(end_time - start_time) + elapsed_times = np.sort(elapsed_times) + average_elapsed_time = sum(elapsed_times) / 1000 + print("Time required to put an int:") + print(" Average: {}".format(average_elapsed_time)) + print(" 90th percentile: {}".format(elapsed_times[900])) + print(" 99th percentile: {}".format(elapsed_times[990])) + print(" worst: {}".format(elapsed_times[999])) + # average_elapsed_time should be about 0.00087. 
+ + + +def run(): + test_timing() + + +if __name__ == "__main__": + run() diff --git a/python/main.py b/python/main.py new file mode 100644 index 0000000..ebfe863 --- /dev/null +++ b/python/main.py @@ -0,0 +1,11 @@ +import cloudpickle +import proto.task_pb2_grpc + + +def main(): + a = cloudpickle.dumps("hello world") + print(a) + + +if __name__ == "__main__": + main() diff --git a/python/proto/task_pb2.py b/python/proto/task_pb2.py new file mode 100755 index 0000000..116b6e6 --- /dev/null +++ b/python/proto/task_pb2.py @@ -0,0 +1,477 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: proto/task.proto + +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='proto/task.proto', + package='', + syntax='proto3', + serialized_options=None, + serialized_pb=b'\n\x10proto/task.proto\"\xa5\x01\n\x03\x41rg\x12#\n\x05local\x18\x01 \x01(\x0e\x32\r.Arg.LocalityR\x05local\x12!\n\x0creference_id\x18\x02 \x01(\x0cR\x0breferenceId\x12\x12\n\x04\x64\x61ta\x18\x03 \x01(\x0cR\x04\x64\x61ta\x12\x19\n\x04type\x18\x04 \x01(\x0e\x32\x05.TypeR\x04type\"\'\n\x08Locality\x12\x0c\n\x08INTERNED\x10\x00\x12\r\n\tREFERENCE\x10\x01\"S\n\x04Task\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1d\n\npayload_id\x18\x02 \x01(\x0cR\tpayloadId\x12\x18\n\x04\x61rgs\x18\x03 \x03(\x0b\x32\x04.ArgR\x04\x61rgs\")\n\nTaskTicket\x12\x1b\n\treturn_id\x18\x01 \x01(\x0cR\x08returnId\" \n\nPutRequest\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\"\x1d\n\x0bPutResponse\x12\x0e\n\x02id\x18\x01 \x01(\x0cR\x02id\"\x1c\n\nGetRequest\x12\x0e\n\x02id\x18\x01 
\x01(\x0cR\x02id\"7\n\x0bGetResponse\x12\x14\n\x05valid\x18\x01 \x01(\x08R\x05valid\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta*\x13\n\x04Type\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x32\x82\x01\n\nTaskServer\x12(\n\tGetObject\x12\x0b.GetRequest\x1a\x0c.GetResponse\"\x00\x12(\n\tPutObject\x12\x0b.PutRequest\x1a\x0c.PutResponse\"\x00\x12 \n\x08Schedule\x12\x05.Task\x1a\x0b.TaskTicket\"\x00\x32\x85\x01\n\nWorkServer\x12(\n\tGetObject\x12\x0b.GetRequest\x1a\x0c.GetResponse\"\x00\x12(\n\tPutObject\x12\x0b.PutRequest\x1a\x0c.PutResponse\"\x00\x12#\n\x07\x45xecute\x12\x0b.TaskTicket\x1a\x05.Task\"\x00(\x01\x30\x01\x62\x06proto3' +) + +_TYPE = _descriptor.EnumDescriptor( + name='Type', + full_name='Type', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='DEFAULT', index=0, number=0, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=468, + serialized_end=487, +) +_sym_db.RegisterEnumDescriptor(_TYPE) + +Type = enum_type_wrapper.EnumTypeWrapper(_TYPE) +DEFAULT = 0 + + +_ARG_LOCALITY = _descriptor.EnumDescriptor( + name='Locality', + full_name='Arg.Locality', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='INTERNED', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='REFERENCE', index=1, number=1, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=147, + serialized_end=186, +) +_sym_db.RegisterEnumDescriptor(_ARG_LOCALITY) + + +_ARG = _descriptor.Descriptor( + name='Arg', + full_name='Arg', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='local', full_name='Arg.local', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + serialized_options=None, json_name='local', file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='reference_id', full_name='Arg.reference_id', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, json_name='referenceId', file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='data', full_name='Arg.data', index=2, + number=3, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, json_name='data', file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type', full_name='Arg.type', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, json_name='type', file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _ARG_LOCALITY, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=21, + serialized_end=186, +) + + +_TASK = _descriptor.Descriptor( + name='Task', + full_name='Task', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='Task.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, json_name='name', file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='payload_id', full_name='Task.payload_id', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, 
default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, json_name='payloadId', file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='args', full_name='Task.args', index=2, + number=3, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, json_name='args', file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=188, + serialized_end=271, +) + + +_TASKTICKET = _descriptor.Descriptor( + name='TaskTicket', + full_name='TaskTicket', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='return_id', full_name='TaskTicket.return_id', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, json_name='returnId', file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=273, + serialized_end=314, +) + + +_PUTREQUEST = _descriptor.Descriptor( + name='PutRequest', + full_name='PutRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='data', full_name='PutRequest.data', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, json_name='data', file=DESCRIPTOR), + ], + extensions=[ + 
], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=316, + serialized_end=348, +) + + +_PUTRESPONSE = _descriptor.Descriptor( + name='PutResponse', + full_name='PutResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='id', full_name='PutResponse.id', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, json_name='id', file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=350, + serialized_end=379, +) + + +_GETREQUEST = _descriptor.Descriptor( + name='GetRequest', + full_name='GetRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='id', full_name='GetRequest.id', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, json_name='id', file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=381, + serialized_end=409, +) + + +_GETRESPONSE = _descriptor.Descriptor( + name='GetResponse', + full_name='GetResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='valid', full_name='GetResponse.valid', index=0, + number=1, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, 
enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, json_name='valid', file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='data', full_name='GetResponse.data', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, json_name='data', file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=411, + serialized_end=466, +) + +_ARG.fields_by_name['local'].enum_type = _ARG_LOCALITY +_ARG.fields_by_name['type'].enum_type = _TYPE +_ARG_LOCALITY.containing_type = _ARG +_TASK.fields_by_name['args'].message_type = _ARG +DESCRIPTOR.message_types_by_name['Arg'] = _ARG +DESCRIPTOR.message_types_by_name['Task'] = _TASK +DESCRIPTOR.message_types_by_name['TaskTicket'] = _TASKTICKET +DESCRIPTOR.message_types_by_name['PutRequest'] = _PUTREQUEST +DESCRIPTOR.message_types_by_name['PutResponse'] = _PUTRESPONSE +DESCRIPTOR.message_types_by_name['GetRequest'] = _GETREQUEST +DESCRIPTOR.message_types_by_name['GetResponse'] = _GETRESPONSE +DESCRIPTOR.enum_types_by_name['Type'] = _TYPE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Arg = _reflection.GeneratedProtocolMessageType('Arg', (_message.Message,), { + 'DESCRIPTOR' : _ARG, + '__module__' : 'proto.task_pb2' + # @@protoc_insertion_point(class_scope:Arg) + }) +_sym_db.RegisterMessage(Arg) + +Task = _reflection.GeneratedProtocolMessageType('Task', (_message.Message,), { + 'DESCRIPTOR' : _TASK, + '__module__' : 'proto.task_pb2' + # @@protoc_insertion_point(class_scope:Task) + }) +_sym_db.RegisterMessage(Task) + +TaskTicket = _reflection.GeneratedProtocolMessageType('TaskTicket', (_message.Message,), { + 'DESCRIPTOR' : _TASKTICKET, + '__module__' : 
'proto.task_pb2' + # @@protoc_insertion_point(class_scope:TaskTicket) + }) +_sym_db.RegisterMessage(TaskTicket) + +PutRequest = _reflection.GeneratedProtocolMessageType('PutRequest', (_message.Message,), { + 'DESCRIPTOR' : _PUTREQUEST, + '__module__' : 'proto.task_pb2' + # @@protoc_insertion_point(class_scope:PutRequest) + }) +_sym_db.RegisterMessage(PutRequest) + +PutResponse = _reflection.GeneratedProtocolMessageType('PutResponse', (_message.Message,), { + 'DESCRIPTOR' : _PUTRESPONSE, + '__module__' : 'proto.task_pb2' + # @@protoc_insertion_point(class_scope:PutResponse) + }) +_sym_db.RegisterMessage(PutResponse) + +GetRequest = _reflection.GeneratedProtocolMessageType('GetRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETREQUEST, + '__module__' : 'proto.task_pb2' + # @@protoc_insertion_point(class_scope:GetRequest) + }) +_sym_db.RegisterMessage(GetRequest) + +GetResponse = _reflection.GeneratedProtocolMessageType('GetResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETRESPONSE, + '__module__' : 'proto.task_pb2' + # @@protoc_insertion_point(class_scope:GetResponse) + }) +_sym_db.RegisterMessage(GetResponse) + + + +_TASKSERVER = _descriptor.ServiceDescriptor( + name='TaskServer', + full_name='TaskServer', + file=DESCRIPTOR, + index=0, + serialized_options=None, + serialized_start=490, + serialized_end=620, + methods=[ + _descriptor.MethodDescriptor( + name='GetObject', + full_name='TaskServer.GetObject', + index=0, + containing_service=None, + input_type=_GETREQUEST, + output_type=_GETRESPONSE, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='PutObject', + full_name='TaskServer.PutObject', + index=1, + containing_service=None, + input_type=_PUTREQUEST, + output_type=_PUTRESPONSE, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='Schedule', + full_name='TaskServer.Schedule', + index=2, + containing_service=None, + input_type=_TASK, + output_type=_TASKTICKET, + serialized_options=None, + ), +]) 
+_sym_db.RegisterServiceDescriptor(_TASKSERVER) + +DESCRIPTOR.services_by_name['TaskServer'] = _TASKSERVER + + +_WORKSERVER = _descriptor.ServiceDescriptor( + name='WorkServer', + full_name='WorkServer', + file=DESCRIPTOR, + index=1, + serialized_options=None, + serialized_start=623, + serialized_end=756, + methods=[ + _descriptor.MethodDescriptor( + name='GetObject', + full_name='WorkServer.GetObject', + index=0, + containing_service=None, + input_type=_GETREQUEST, + output_type=_GETRESPONSE, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='PutObject', + full_name='WorkServer.PutObject', + index=1, + containing_service=None, + input_type=_PUTREQUEST, + output_type=_PUTRESPONSE, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='Execute', + full_name='WorkServer.Execute', + index=2, + containing_service=None, + input_type=_TASKTICKET, + output_type=_TASK, + serialized_options=None, + ), +]) +_sym_db.RegisterServiceDescriptor(_WORKSERVER) + +DESCRIPTOR.services_by_name['WorkServer'] = _WORKSERVER + +# @@protoc_insertion_point(module_scope) diff --git a/python/proto/task_pb2_grpc.py b/python/proto/task_pb2_grpc.py new file mode 100755 index 0000000..79a8526 --- /dev/null +++ b/python/proto/task_pb2_grpc.py @@ -0,0 +1,156 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + +from proto import task_pb2 as proto_dot_task__pb2 + + +class TaskServerStub(object): + # missing associated documentation comment in .proto file + pass + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.GetObject = channel.unary_unary( + '/TaskServer/GetObject', + request_serializer=proto_dot_task__pb2.GetRequest.SerializeToString, + response_deserializer=proto_dot_task__pb2.GetResponse.FromString, + ) + self.PutObject = channel.unary_unary( + '/TaskServer/PutObject', + request_serializer=proto_dot_task__pb2.PutRequest.SerializeToString, + response_deserializer=proto_dot_task__pb2.PutResponse.FromString, + ) + self.Schedule = channel.unary_unary( + '/TaskServer/Schedule', + request_serializer=proto_dot_task__pb2.Task.SerializeToString, + response_deserializer=proto_dot_task__pb2.TaskTicket.FromString, + ) + + +class TaskServerServicer(object): + # missing associated documentation comment in .proto file + pass + + def GetObject(self, request, context): + # missing associated documentation comment in .proto file + pass + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PutObject(self, request, context): + # missing associated documentation comment in .proto file + pass + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Schedule(self, request, context): + # missing associated documentation comment in .proto file + pass + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_TaskServerServicer_to_server(servicer, server): + rpc_method_handlers = { + 'GetObject': grpc.unary_unary_rpc_method_handler( + servicer.GetObject, + request_deserializer=proto_dot_task__pb2.GetRequest.FromString, + response_serializer=proto_dot_task__pb2.GetResponse.SerializeToString, + ), + 'PutObject': grpc.unary_unary_rpc_method_handler( + servicer.PutObject, + request_deserializer=proto_dot_task__pb2.PutRequest.FromString, + 
response_serializer=proto_dot_task__pb2.PutResponse.SerializeToString, + ), + 'Schedule': grpc.unary_unary_rpc_method_handler( + servicer.Schedule, + request_deserializer=proto_dot_task__pb2.Task.FromString, + response_serializer=proto_dot_task__pb2.TaskTicket.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'TaskServer', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + +class WorkServerStub(object): + # missing associated documentation comment in .proto file + pass + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.GetObject = channel.unary_unary( + '/WorkServer/GetObject', + request_serializer=proto_dot_task__pb2.GetRequest.SerializeToString, + response_deserializer=proto_dot_task__pb2.GetResponse.FromString, + ) + self.PutObject = channel.unary_unary( + '/WorkServer/PutObject', + request_serializer=proto_dot_task__pb2.PutRequest.SerializeToString, + response_deserializer=proto_dot_task__pb2.PutResponse.FromString, + ) + self.Execute = channel.stream_stream( + '/WorkServer/Execute', + request_serializer=proto_dot_task__pb2.TaskTicket.SerializeToString, + response_deserializer=proto_dot_task__pb2.Task.FromString, + ) + + +class WorkServerServicer(object): + # missing associated documentation comment in .proto file + pass + + def GetObject(self, request, context): + # missing associated documentation comment in .proto file + pass + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PutObject(self, request, context): + # missing associated documentation comment in .proto file + pass + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Execute(self, request_iterator, context): + # missing associated documentation comment in .proto file 
def add_WorkServerServicer_to_server(servicer, server):
  # Generated registration helper: wires the servicer's methods into the
  # server's generic handler table under the 'WorkServer' service name.
  rpc_method_handlers = {
      'GetObject': grpc.unary_unary_rpc_method_handler(
          servicer.GetObject,
          request_deserializer=proto_dot_task__pb2.GetRequest.FromString,
          response_serializer=proto_dot_task__pb2.GetResponse.SerializeToString,
      ),
      'PutObject': grpc.unary_unary_rpc_method_handler(
          servicer.PutObject,
          request_deserializer=proto_dot_task__pb2.PutRequest.FromString,
          response_serializer=proto_dot_task__pb2.PutResponse.SerializeToString,
      ),
      'Execute': grpc.stream_stream_rpc_method_handler(
          servicer.Execute,
          request_deserializer=proto_dot_task__pb2.TaskTicket.FromString,
          response_serializer=proto_dot_task__pb2.Task.SerializeToString,
      ),
  }
  generic_handler = grpc.method_handlers_generic_handler(
      'WorkServer', rpc_method_handlers)
  server.add_generic_rpc_handlers((generic_handler,))


class ObjectID:
    """Opaque handle to a value stored in the object store."""

    def __init__(self, id):
        # `id` shadows the builtin; kept as-is for interface compatibility.
        self.id = id

    def __repr__(self):
        # NOTE(review): assumes self.id is bytes (it comes from a protobuf
        # bytes field); .decode() would raise on a plain str -- confirm that
        # all callers pass bytes.
        return "ObjectID(%s)" % self.id.decode()


# Maps worker uuid -> Worker instance, so a RemoteFunc can look its owning
# worker back up by id instead of holding a direct reference.
worker_registry = {}


def set_global_worker(worker):
    """Install `worker` as the process-wide fallback worker."""
    global _global_worker
    _global_worker = worker


def register_worker(worker):
    """Assign a fresh uuid to `worker`, record it in the registry, and
    return the uuid."""
    id = uuid.uuid4()
    worker_registry[id] = worker
    return id


def get_worker_registry(id):
    """Look up a worker by uuid; falls back to the global worker when the
    uuid is not registered."""
    out = worker_registry.get(id)
    if out is None:
        return _global_worker
    return out
class Worker:
    """Client-side connection to a TaskServer.

    Either dials `conn_str` over an insecure gRPC channel, or wraps an
    already-constructed `stub` (useful for tests / in-process servers).
    """

    def __init__(self, conn_str="", stub=None):
        if stub is None:
            self.channel = grpc.insecure_channel(conn_str)
            self.server = proto.task_pb2_grpc.TaskServerStub(self.channel)
        else:
            # Fix: record the absence of a channel so close() stays safe.
            self.channel = None
            self.server = stub
        # Registry id lets RemoteFunc find this worker without a hard ref.
        self.uuid = register_worker(self)

    def get(self, ids):
        """Fetch one ObjectID or a list of ObjectIDs; the shape of the
        result mirrors the shape of the argument.

        Raises:
            TypeError: if `ids` is neither an ObjectID nor a list of them.
        """
        single = False
        if isinstance(ids, list):
            to_get = [x.id for x in ids]
        elif isinstance(ids, ObjectID):
            to_get = [ids.id]
            single = True
        else:
            # TypeError is an Exception subclass, so existing callers that
            # caught Exception still work.
            raise TypeError(
                "Can't get something that's not a list of IDs or just an ID")

        out = [self._get(x) for x in to_get]
        if single:
            out = out[0]
        return out

    def _get(self, id: bytes):
        # Round-trip: GetRequest -> server -> cloudpickle payload.
        req = proto.task_pb2.GetRequest(id=id)
        data = self.server.GetObject(req)
        return cloudpickle.loads(data.data)

    def put(self, vals):
        """Store a value (or a list of values); returns ObjectID(s) with the
        same shape as the argument."""
        single = not isinstance(vals, list)
        to_put = [vals] if single else vals
        out = [self._put(x) for x in to_put]
        if single:
            out = out[0]
        return out

    def _put(self, val):
        data = cloudpickle.dumps(val)
        req = proto.task_pb2.PutRequest(data=data)
        resp = self.server.PutObject(req)
        return ObjectID(resp.id)

    def remote(self, func):
        """Wrap `func` as a RemoteFunc bound to this worker."""
        return RemoteFunc(self, func)

    def schedule(self, task):
        """Submit a Task proto for execution; returns the server's
        TaskTicket."""
        return self.server.Schedule(task)

    def close(self):
        # Fix: when this worker wraps a pre-built stub there is no channel
        # to close (previously this raised AttributeError).
        if getattr(self, "channel", None) is not None:
            self.channel.close()


class RemoteFunc:
    """A function pushed to the object store, schedulable by name and id."""

    def __init__(self, worker, f):
        self._func = f
        self._name = f.__name__
        self.id = None  # ObjectID of the pickled function; set lazily.
        self._worker_id = worker.uuid

    def __call__(self, *args, **kwargs):
        raise Exception("Matching the old API")

    def remote(self, *args):
        """Schedule a call of the wrapped function.

        ObjectID args are passed by reference; anything else is pickled
        inline. Returns the ObjectID of the eventual result.
        """
        if self.id is None:
            self._push_func()
        t = proto.task_pb2.Task()
        t.name = self._name
        t.payload_id = self.id.id
        for a in args:
            arg = proto.task_pb2.Arg()
            if isinstance(a, ObjectID):
                arg.local = proto.task_pb2.Arg.Locality.REFERENCE
                arg.reference_id = a.id
            else:
                arg.local = proto.task_pb2.Arg.Locality.INTERNED
                arg.data = cloudpickle.dumps(a)
            t.args.append(arg)
        worker = get_worker_registry(self._worker_id)
        ticket = worker.schedule(t)
        return ObjectID(ticket.return_id)

    def _push_func(self):
        # Upload the pickled function once; its id is reused afterwards.
        worker = get_worker_registry(self._worker_id)
        self.id = worker.put(self._func)

    def __repr__(self):
        return "RemoteFunc(%s, %s)" % (self._name, self.id)
# The process-wide default worker; set exactly once via connect().
_global_worker = None


def connect(*args, **kwargs):
    """Create and install the global Worker.

    Raises:
        Exception: if a global worker is already connected.
    """
    global _global_worker
    if _global_worker is not None:
        raise Exception("Can't connect a second global worker")
    _global_worker = Worker(*args, **kwargs)


def _require_worker():
    # Shared guard for the module-level convenience wrappers below.
    # (Reading _global_worker needs no `global` declaration.)
    if _global_worker is None:
        raise Exception("Need a connection before calling")
    return _global_worker


def get(*args, **kwargs):
    """Module-level convenience wrapper for Worker.get."""
    return _require_worker().get(*args, **kwargs)


def put(*args, **kwargs):
    """Module-level convenience wrapper for Worker.put."""
    return _require_worker().put(*args, **kwargs)


def remote(*args, **kwargs):
    """Module-level convenience wrapper for Worker.remote."""
    return _require_worker().remote(*args, **kwargs)