small python ray
This commit is contained in:
commit
e02e80708a
15 changed files with 1261 additions and 0 deletions
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
/bazel-*
|
||||
*.pyc
|
||||
78
WORKSPACE
Normal file
78
WORKSPACE
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
|
||||
http_archive(
|
||||
name = "rules_python",
|
||||
url = "https://github.com/bazelbuild/rules_python/releases/download/0.0.2/rules_python-0.0.2.tar.gz",
|
||||
strip_prefix = "rules_python-0.0.2",
|
||||
sha256 = "b5668cde8bb6e3515057ef465a35ad712214962f0b3a314e551204266c7be90c",
|
||||
)
|
||||
|
||||
load("@rules_python//python:repositories.bzl", "py_repositories")
|
||||
|
||||
py_repositories()
|
||||
|
||||
load("@rules_python//python:pip.bzl", "pip_repositories")
|
||||
|
||||
pip_repositories()
|
||||
|
||||
http_archive(
|
||||
name = "rules_proto_grpc",
|
||||
urls = ["https://github.com/rules-proto-grpc/rules_proto_grpc/archive/1.0.2.tar.gz"],
|
||||
sha256 = "5f0f2fc0199810c65a2de148a52ba0aff14d631d4e8202f41aff6a9d590a471b",
|
||||
strip_prefix = "rules_proto_grpc-1.0.2",
|
||||
)
|
||||
|
||||
load("@rules_proto_grpc//:repositories.bzl", "rules_proto_grpc_toolchains", "rules_proto_grpc_repos")
|
||||
rules_proto_grpc_toolchains()
|
||||
rules_proto_grpc_repos()
|
||||
|
||||
|
||||
load("@rules_proto_grpc//python:repositories.bzl", rules_proto_grpc_python_repos="python_repos")
|
||||
|
||||
rules_proto_grpc_python_repos()
|
||||
|
||||
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
|
||||
|
||||
grpc_deps()
|
||||
|
||||
rules_python_external_version = "0.1.5"
|
||||
|
||||
http_archive(
|
||||
name = "rules_python_external",
|
||||
url = "https://github.com/dillon-giacoppo/rules_python_external/archive/v{version}.zip".format(version = rules_python_external_version),
|
||||
sha256 = "bc655e6d402915944e014c3b2cad23d0a97b83a66cc22f20db09c9f8da2e2789",
|
||||
strip_prefix = "rules_python_external-{version}".format(version = rules_python_external_version),
|
||||
)
|
||||
|
||||
load("@rules_python_external//:repositories.bzl", "rules_python_external_dependencies")
|
||||
rules_python_external_dependencies()
|
||||
|
||||
|
||||
|
||||
load("@rules_python//python:pip.bzl", "pip_import")
|
||||
pip_import(
|
||||
name = "rules_proto_grpc_py3_deps",
|
||||
#python_interpreter = "python3",
|
||||
requirements = "@rules_proto_grpc//python:requirements.txt",
|
||||
)
|
||||
|
||||
load("@rules_proto_grpc_py3_deps//:requirements.bzl", pip3_install="pip_install")
|
||||
pip3_install()
|
||||
|
||||
load("@rules_python_external//:defs.bzl", "pip_install")
|
||||
pip_install(
|
||||
name = "py_deps",
|
||||
requirements = "//python:requirements.txt",
|
||||
# (Optional) You can provide a python interpreter (by path):
|
||||
#python_interpreter = "/usr/bin/python3.8",
|
||||
# (Optional) Alternatively you can provide an in-build python interpreter, that is available as a Bazel target.
|
||||
# This overrides `python_interpreter`.
|
||||
# Note: You need to set up the interpreter target beforehand (not shown here). Please see the `example` folder for further details.
|
||||
#python_interpreter_target = "@python_interpreter//:python_bin",
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
0
bazel/BUILD
Normal file
0
bazel/BUILD
Normal file
36
bazel/tools.bzl
Normal file
36
bazel/tools.bzl
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
def copy_to_workspace(name, srcs, dstdir = ""):
|
||||
if dstdir.startswith("/") or dstdir.startswith("\\"):
|
||||
fail("Subdirectory must be a relative path: " + dstdir)
|
||||
src_locations = " ".join(["$(locations %s)" % (src,) for src in srcs])
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = srcs,
|
||||
outs = [name + ".out"],
|
||||
# Keep this Bash script equivalent to the batch script below (or take out the batch script)
|
||||
cmd = r"""
|
||||
mkdir -p -- {dstdir}
|
||||
for f in {locations}; do
|
||||
rm -f -- {dstdir}$${{f##*/}}
|
||||
cp -f -- "$$f" {dstdir}
|
||||
done
|
||||
date > $@
|
||||
""".format(
|
||||
locations = src_locations,
|
||||
dstdir = "." + ("/" + dstdir.replace("\\", "/")).rstrip("/") + "/",
|
||||
),
|
||||
# Keep this batch script equivalent to the Bash script above (or take out the batch script)
|
||||
cmd_bat = """
|
||||
(
|
||||
if not exist {dstdir} mkdir {dstdir}
|
||||
) && (
|
||||
for %f in ({locations}) do @(
|
||||
(if exist {dstdir}%~nxf del /f /q {dstdir}%~nxf) &&
|
||||
copy /B /Y %f {dstdir} >NUL
|
||||
)
|
||||
) && >$@ echo %TIME%
|
||||
""".replace("\r", "").replace("\n", " ").format(
|
||||
locations = src_locations,
|
||||
dstdir = "." + ("\\" + dstdir.replace("/", "\\")).rstrip("\\") + "\\",
|
||||
),
|
||||
local = 1,
|
||||
)
|
||||
18
proto/BUILD
Normal file
18
proto/BUILD
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
load("@rules_proto_grpc//python:defs.bzl", "python_grpc_compile")
|
||||
|
||||
proto_library(
|
||||
name = "task_proto",
|
||||
srcs = [
|
||||
"task.proto",
|
||||
],
|
||||
deps = [
|
||||
"@com_google_protobuf//:any_proto",
|
||||
],
|
||||
)
|
||||
|
||||
python_grpc_compile(
|
||||
name = "task_proto_grpc_py",
|
||||
deps = [":task_proto"],
|
||||
)
|
||||
56
proto/task.proto
Normal file
56
proto/task.proto
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
syntax = "proto3";
|
||||
|
||||
enum Type {
|
||||
DEFAULT = 0;
|
||||
}
|
||||
|
||||
message Arg {
|
||||
enum Locality {
|
||||
INTERNED = 0;
|
||||
REFERENCE = 1;
|
||||
}
|
||||
Locality local = 1;
|
||||
bytes reference_id = 2;
|
||||
bytes data = 3;
|
||||
Type type = 4;
|
||||
}
|
||||
|
||||
message Task {
|
||||
// Optionally Provided Task Name
|
||||
string name = 1;
|
||||
bytes payload_id = 2;
|
||||
repeated Arg args = 3;
|
||||
}
|
||||
|
||||
message TaskTicket {
|
||||
bytes return_id = 1;
|
||||
}
|
||||
|
||||
message PutRequest {
|
||||
bytes data = 1;
|
||||
}
|
||||
|
||||
message PutResponse {
|
||||
bytes id = 1;
|
||||
}
|
||||
|
||||
message GetRequest {
|
||||
bytes id = 1;
|
||||
}
|
||||
|
||||
message GetResponse {
|
||||
bool valid = 1;
|
||||
bytes data = 2;
|
||||
}
|
||||
|
||||
service TaskServer {
|
||||
rpc GetObject(GetRequest) returns (GetResponse) {}
|
||||
rpc PutObject(PutRequest) returns (PutResponse) {}
|
||||
rpc Schedule(Task) returns (TaskTicket) {}
|
||||
}
|
||||
|
||||
service WorkServer {
|
||||
rpc GetObject(GetRequest) returns (GetResponse) {}
|
||||
rpc PutObject(PutRequest) returns (PutResponse) {}
|
||||
rpc Execute(stream TaskTicket) returns (stream Task) {}
|
||||
}
|
||||
25
python/BUILD
Normal file
25
python/BUILD
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
load("@py_deps//:requirements.bzl", "requirement")
|
||||
load("//bazel:tools.bzl", "copy_to_workspace")
|
||||
|
||||
filegroup(
|
||||
name = "all_py_proto",
|
||||
srcs = [
|
||||
"//proto:task_proto_grpc_py",
|
||||
],
|
||||
)
|
||||
|
||||
copy_to_workspace(
|
||||
name = "proto_folder",
|
||||
srcs = [":all_py_proto"],
|
||||
dstdir = "python/proto",
|
||||
)
|
||||
|
||||
|
||||
py_binary(
|
||||
name = "main",
|
||||
srcs = ["main.py"],
|
||||
data = [ ":proto_folder" ],
|
||||
deps = [
|
||||
requirement("cloudpickle"),
|
||||
],
|
||||
)
|
||||
104
python/dumb_raylet.py
Normal file
104
python/dumb_raylet.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
import cloudpickle
|
||||
import ray
|
||||
import logging
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import proto.task_pb2
|
||||
import proto.task_pb2_grpc
|
||||
import threading
|
||||
import time
|
||||
|
||||
id_printer = 1
|
||||
|
||||
|
||||
def generate_id():
|
||||
global id_printer
|
||||
out = str(id_printer).encode()
|
||||
id_printer += 1
|
||||
return out
|
||||
|
||||
|
||||
class ObjectStore:
|
||||
def __init__(self):
|
||||
self.data = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def get(self, id):
|
||||
with self.lock:
|
||||
return self.data[id]
|
||||
|
||||
def put(self, data):
|
||||
with self.lock:
|
||||
id = generate_id()
|
||||
self.data[id] = data
|
||||
return id
|
||||
|
||||
|
||||
class TaskServicer(proto.task_pb2_grpc.TaskServerServicer):
|
||||
def __init__(self, objects, executor):
|
||||
self.objects = objects
|
||||
self.executor = executor
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
data = self.objects.get(request.id)
|
||||
if data is None:
|
||||
return proto.task_pb2.GetResponse(valid=False)
|
||||
return proto.task_pb2.GetResponse(valid=True, data=data)
|
||||
|
||||
def PutObject(self, request, context=None):
|
||||
id = self.objects.put(request.data)
|
||||
return proto.task_pb2.PutResponse(id=id)
|
||||
|
||||
def Schedule(self, task, context=None):
|
||||
return_val = self.executor.execute(task, context)
|
||||
return proto.task_pb2.TaskTicket(return_id=return_val)
|
||||
|
||||
|
||||
class SimpleExecutor:
|
||||
def __init__(self, object_store):
|
||||
self.objects = object_store
|
||||
|
||||
def execute(self, task, context):
|
||||
#print("Executing task", task.name)
|
||||
self.context = context
|
||||
func_data = self.objects.get(task.payload_id)
|
||||
if func_data is None:
|
||||
raise Exception("WTF")
|
||||
f = cloudpickle.loads(func_data)
|
||||
args = []
|
||||
for a in task.args:
|
||||
if a.local == proto.task_pb2.Arg.Locality.INTERNED:
|
||||
data = a.data
|
||||
else:
|
||||
data = self.objects.get(a.reference_id)
|
||||
realarg = cloudpickle.loads(data)
|
||||
args.append(realarg)
|
||||
out = f(*args)
|
||||
id = self.objects.put(cloudpickle.dumps(out))
|
||||
self.context = None
|
||||
return id
|
||||
|
||||
def set_task_servicer(self, servicer):
|
||||
ray._global_worker = ray.Worker(stub=servicer)
|
||||
|
||||
|
||||
def serve():
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
object_store = ObjectStore()
|
||||
executor = SimpleExecutor(object_store)
|
||||
task_servicer = TaskServicer(object_store, executor)
|
||||
executor.set_task_servicer(task_servicer)
|
||||
proto.task_pb2_grpc.add_TaskServerServicer_to_server(
|
||||
task_servicer, server)
|
||||
server.add_insecure_port('[::]:50051')
|
||||
server.start()
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1000)
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig()
|
||||
serve()
|
||||
34
python/example_client.py
Normal file
34
python/example_client.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import cloudpickle
|
||||
import logging
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import ray
|
||||
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
|
||||
@ray.remote
|
||||
def fact(x):
|
||||
out = 1
|
||||
for i in range(1, x+1):
|
||||
out = i * out
|
||||
return out
|
||||
|
||||
|
||||
@ray.remote
|
||||
def fib(x):
|
||||
if x <= 1:
|
||||
return 1
|
||||
return ray.get(fib.remote(x - 1)) + ray.get(fib.remote(x - 2))
|
||||
|
||||
|
||||
def run():
|
||||
# out = fact.remote(5)
|
||||
# out2 = fact.remote(out)
|
||||
# print("Got fact: ", ray.get(out2))
|
||||
fib_out = fib.remote(6)
|
||||
# print("Fib out:", ray.get(fib_out))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
96
python/example_test.py
Normal file
96
python/example_test.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
import ray
|
||||
import numpy as np
|
||||
import time
|
||||
import dumb_raylet
|
||||
|
||||
#ray.connect("localhost:50051")
|
||||
object_store = dumb_raylet.ObjectStore()
|
||||
executor = dumb_raylet.SimpleExecutor(object_store)
|
||||
task_servicer = dumb_raylet.TaskServicer(object_store, executor)
|
||||
executor.set_task_servicer(task_servicer)
|
||||
|
||||
|
||||
def test_timing():
|
||||
@ray.remote
|
||||
def empty_function():
|
||||
pass
|
||||
|
||||
@ray.remote
|
||||
def trivial_function():
|
||||
return 1
|
||||
|
||||
# Measure the time required to submit a remote task to the scheduler.
|
||||
elapsed_times = []
|
||||
for _ in range(1000):
|
||||
start_time = time.time()
|
||||
empty_function.remote()
|
||||
end_time = time.time()
|
||||
elapsed_times.append(end_time - start_time)
|
||||
elapsed_times = np.sort(elapsed_times)
|
||||
average_elapsed_time = sum(elapsed_times) / 1000
|
||||
print("Time required to submit an empty function call:")
|
||||
print(" Average: {}".format(average_elapsed_time))
|
||||
print(" 90th percentile: {}".format(elapsed_times[900]))
|
||||
print(" 99th percentile: {}".format(elapsed_times[990]))
|
||||
print(" worst: {}".format(elapsed_times[999]))
|
||||
# average_elapsed_time should be about 0.00038.
|
||||
|
||||
# Measure the time required to submit a remote task to the scheduler
|
||||
# (where the remote task returns one value).
|
||||
elapsed_times = []
|
||||
for _ in range(1000):
|
||||
start_time = time.time()
|
||||
trivial_function.remote()
|
||||
end_time = time.time()
|
||||
elapsed_times.append(end_time - start_time)
|
||||
elapsed_times = np.sort(elapsed_times)
|
||||
average_elapsed_time = sum(elapsed_times) / 1000
|
||||
print("Time required to submit a trivial function call:")
|
||||
print(" Average: {}".format(average_elapsed_time))
|
||||
print(" 90th percentile: {}".format(elapsed_times[900]))
|
||||
print(" 99th percentile: {}".format(elapsed_times[990]))
|
||||
print(" worst: {}".format(elapsed_times[999]))
|
||||
|
||||
# Measure the time required to submit a remote task to the scheduler
|
||||
# and get the result.
|
||||
elapsed_times = []
|
||||
for _ in range(1000):
|
||||
start_time = time.time()
|
||||
x = trivial_function.remote()
|
||||
ray.get(x)
|
||||
end_time = time.time()
|
||||
elapsed_times.append(end_time - start_time)
|
||||
elapsed_times = np.sort(elapsed_times)
|
||||
average_elapsed_time = sum(elapsed_times) / 1000
|
||||
print("Time required to submit a trivial function call and get the "
|
||||
"result:")
|
||||
print(" Average: {}".format(average_elapsed_time))
|
||||
print(" 90th percentile: {}".format(elapsed_times[900]))
|
||||
print(" 99th percentile: {}".format(elapsed_times[990]))
|
||||
print(" worst: {}".format(elapsed_times[999]))
|
||||
# average_elapsed_time should be about 0.0013.
|
||||
|
||||
# Measure the time required to do do a put.
|
||||
elapsed_times = []
|
||||
for _ in range(1000):
|
||||
start_time = time.time()
|
||||
ray.put(1)
|
||||
end_time = time.time()
|
||||
elapsed_times.append(end_time - start_time)
|
||||
elapsed_times = np.sort(elapsed_times)
|
||||
average_elapsed_time = sum(elapsed_times) / 1000
|
||||
print("Time required to put an int:")
|
||||
print(" Average: {}".format(average_elapsed_time))
|
||||
print(" 90th percentile: {}".format(elapsed_times[900]))
|
||||
print(" 99th percentile: {}".format(elapsed_times[990]))
|
||||
print(" worst: {}".format(elapsed_times[999]))
|
||||
# average_elapsed_time should be about 0.00087.
|
||||
|
||||
|
||||
|
||||
def run():
|
||||
test_timing()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
11
python/main.py
Normal file
11
python/main.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import cloudpickle
|
||||
import proto.task_pb2_grpc
|
||||
|
||||
|
||||
def main():
|
||||
a = cloudpickle.dumps("hello world")
|
||||
print(a)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
477
python/proto/task_pb2.py
Executable file
477
python/proto/task_pb2.py
Executable file
|
|
@ -0,0 +1,477 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: proto/task.proto
|
||||
|
||||
from google.protobuf.internal import enum_type_wrapper
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import message as _message
|
||||
from google.protobuf import reflection as _reflection
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor.FileDescriptor(
|
||||
name='proto/task.proto',
|
||||
package='',
|
||||
syntax='proto3',
|
||||
serialized_options=None,
|
||||
serialized_pb=b'\n\x10proto/task.proto\"\xa5\x01\n\x03\x41rg\x12#\n\x05local\x18\x01 \x01(\x0e\x32\r.Arg.LocalityR\x05local\x12!\n\x0creference_id\x18\x02 \x01(\x0cR\x0breferenceId\x12\x12\n\x04\x64\x61ta\x18\x03 \x01(\x0cR\x04\x64\x61ta\x12\x19\n\x04type\x18\x04 \x01(\x0e\x32\x05.TypeR\x04type\"\'\n\x08Locality\x12\x0c\n\x08INTERNED\x10\x00\x12\r\n\tREFERENCE\x10\x01\"S\n\x04Task\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1d\n\npayload_id\x18\x02 \x01(\x0cR\tpayloadId\x12\x18\n\x04\x61rgs\x18\x03 \x03(\x0b\x32\x04.ArgR\x04\x61rgs\")\n\nTaskTicket\x12\x1b\n\treturn_id\x18\x01 \x01(\x0cR\x08returnId\" \n\nPutRequest\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\"\x1d\n\x0bPutResponse\x12\x0e\n\x02id\x18\x01 \x01(\x0cR\x02id\"\x1c\n\nGetRequest\x12\x0e\n\x02id\x18\x01 \x01(\x0cR\x02id\"7\n\x0bGetResponse\x12\x14\n\x05valid\x18\x01 \x01(\x08R\x05valid\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta*\x13\n\x04Type\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x32\x82\x01\n\nTaskServer\x12(\n\tGetObject\x12\x0b.GetRequest\x1a\x0c.GetResponse\"\x00\x12(\n\tPutObject\x12\x0b.PutRequest\x1a\x0c.PutResponse\"\x00\x12 \n\x08Schedule\x12\x05.Task\x1a\x0b.TaskTicket\"\x00\x32\x85\x01\n\nWorkServer\x12(\n\tGetObject\x12\x0b.GetRequest\x1a\x0c.GetResponse\"\x00\x12(\n\tPutObject\x12\x0b.PutRequest\x1a\x0c.PutResponse\"\x00\x12#\n\x07\x45xecute\x12\x0b.TaskTicket\x1a\x05.Task\"\x00(\x01\x30\x01\x62\x06proto3'
|
||||
)
|
||||
|
||||
_TYPE = _descriptor.EnumDescriptor(
|
||||
name='Type',
|
||||
full_name='Type',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
values=[
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='DEFAULT', index=0, number=0,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
],
|
||||
containing_type=None,
|
||||
serialized_options=None,
|
||||
serialized_start=468,
|
||||
serialized_end=487,
|
||||
)
|
||||
_sym_db.RegisterEnumDescriptor(_TYPE)
|
||||
|
||||
Type = enum_type_wrapper.EnumTypeWrapper(_TYPE)
|
||||
DEFAULT = 0
|
||||
|
||||
|
||||
_ARG_LOCALITY = _descriptor.EnumDescriptor(
|
||||
name='Locality',
|
||||
full_name='Arg.Locality',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
values=[
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='INTERNED', index=0, number=0,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
_descriptor.EnumValueDescriptor(
|
||||
name='REFERENCE', index=1, number=1,
|
||||
serialized_options=None,
|
||||
type=None),
|
||||
],
|
||||
containing_type=None,
|
||||
serialized_options=None,
|
||||
serialized_start=147,
|
||||
serialized_end=186,
|
||||
)
|
||||
_sym_db.RegisterEnumDescriptor(_ARG_LOCALITY)
|
||||
|
||||
|
||||
_ARG = _descriptor.Descriptor(
|
||||
name='Arg',
|
||||
full_name='Arg',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='local', full_name='Arg.local', index=0,
|
||||
number=1, type=14, cpp_type=8, label=1,
|
||||
has_default_value=False, default_value=0,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='local', file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='reference_id', full_name='Arg.reference_id', index=1,
|
||||
number=2, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='referenceId', file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='data', full_name='Arg.data', index=2,
|
||||
number=3, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='data', file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='type', full_name='Arg.type', index=3,
|
||||
number=4, type=14, cpp_type=8, label=1,
|
||||
has_default_value=False, default_value=0,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='type', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
_ARG_LOCALITY,
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=21,
|
||||
serialized_end=186,
|
||||
)
|
||||
|
||||
|
||||
_TASK = _descriptor.Descriptor(
|
||||
name='Task',
|
||||
full_name='Task',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='name', full_name='Task.name', index=0,
|
||||
number=1, type=9, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"".decode('utf-8'),
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='name', file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='payload_id', full_name='Task.payload_id', index=1,
|
||||
number=2, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='payloadId', file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='args', full_name='Task.args', index=2,
|
||||
number=3, type=11, cpp_type=10, label=3,
|
||||
has_default_value=False, default_value=[],
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='args', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=188,
|
||||
serialized_end=271,
|
||||
)
|
||||
|
||||
|
||||
_TASKTICKET = _descriptor.Descriptor(
|
||||
name='TaskTicket',
|
||||
full_name='TaskTicket',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='return_id', full_name='TaskTicket.return_id', index=0,
|
||||
number=1, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='returnId', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=273,
|
||||
serialized_end=314,
|
||||
)
|
||||
|
||||
|
||||
_PUTREQUEST = _descriptor.Descriptor(
|
||||
name='PutRequest',
|
||||
full_name='PutRequest',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='data', full_name='PutRequest.data', index=0,
|
||||
number=1, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='data', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=316,
|
||||
serialized_end=348,
|
||||
)
|
||||
|
||||
|
||||
_PUTRESPONSE = _descriptor.Descriptor(
|
||||
name='PutResponse',
|
||||
full_name='PutResponse',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='id', full_name='PutResponse.id', index=0,
|
||||
number=1, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='id', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=350,
|
||||
serialized_end=379,
|
||||
)
|
||||
|
||||
|
||||
_GETREQUEST = _descriptor.Descriptor(
|
||||
name='GetRequest',
|
||||
full_name='GetRequest',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='id', full_name='GetRequest.id', index=0,
|
||||
number=1, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='id', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=381,
|
||||
serialized_end=409,
|
||||
)
|
||||
|
||||
|
||||
_GETRESPONSE = _descriptor.Descriptor(
|
||||
name='GetResponse',
|
||||
full_name='GetResponse',
|
||||
filename=None,
|
||||
file=DESCRIPTOR,
|
||||
containing_type=None,
|
||||
fields=[
|
||||
_descriptor.FieldDescriptor(
|
||||
name='valid', full_name='GetResponse.valid', index=0,
|
||||
number=1, type=8, cpp_type=7, label=1,
|
||||
has_default_value=False, default_value=False,
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='valid', file=DESCRIPTOR),
|
||||
_descriptor.FieldDescriptor(
|
||||
name='data', full_name='GetResponse.data', index=1,
|
||||
number=2, type=12, cpp_type=9, label=1,
|
||||
has_default_value=False, default_value=b"",
|
||||
message_type=None, enum_type=None, containing_type=None,
|
||||
is_extension=False, extension_scope=None,
|
||||
serialized_options=None, json_name='data', file=DESCRIPTOR),
|
||||
],
|
||||
extensions=[
|
||||
],
|
||||
nested_types=[],
|
||||
enum_types=[
|
||||
],
|
||||
serialized_options=None,
|
||||
is_extendable=False,
|
||||
syntax='proto3',
|
||||
extension_ranges=[],
|
||||
oneofs=[
|
||||
],
|
||||
serialized_start=411,
|
||||
serialized_end=466,
|
||||
)
|
||||
|
||||
_ARG.fields_by_name['local'].enum_type = _ARG_LOCALITY
|
||||
_ARG.fields_by_name['type'].enum_type = _TYPE
|
||||
_ARG_LOCALITY.containing_type = _ARG
|
||||
_TASK.fields_by_name['args'].message_type = _ARG
|
||||
DESCRIPTOR.message_types_by_name['Arg'] = _ARG
|
||||
DESCRIPTOR.message_types_by_name['Task'] = _TASK
|
||||
DESCRIPTOR.message_types_by_name['TaskTicket'] = _TASKTICKET
|
||||
DESCRIPTOR.message_types_by_name['PutRequest'] = _PUTREQUEST
|
||||
DESCRIPTOR.message_types_by_name['PutResponse'] = _PUTRESPONSE
|
||||
DESCRIPTOR.message_types_by_name['GetRequest'] = _GETREQUEST
|
||||
DESCRIPTOR.message_types_by_name['GetResponse'] = _GETRESPONSE
|
||||
DESCRIPTOR.enum_types_by_name['Type'] = _TYPE
|
||||
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
||||
|
||||
Arg = _reflection.GeneratedProtocolMessageType('Arg', (_message.Message,), {
|
||||
'DESCRIPTOR' : _ARG,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:Arg)
|
||||
})
|
||||
_sym_db.RegisterMessage(Arg)
|
||||
|
||||
Task = _reflection.GeneratedProtocolMessageType('Task', (_message.Message,), {
|
||||
'DESCRIPTOR' : _TASK,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:Task)
|
||||
})
|
||||
_sym_db.RegisterMessage(Task)
|
||||
|
||||
TaskTicket = _reflection.GeneratedProtocolMessageType('TaskTicket', (_message.Message,), {
|
||||
'DESCRIPTOR' : _TASKTICKET,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:TaskTicket)
|
||||
})
|
||||
_sym_db.RegisterMessage(TaskTicket)
|
||||
|
||||
PutRequest = _reflection.GeneratedProtocolMessageType('PutRequest', (_message.Message,), {
|
||||
'DESCRIPTOR' : _PUTREQUEST,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:PutRequest)
|
||||
})
|
||||
_sym_db.RegisterMessage(PutRequest)
|
||||
|
||||
PutResponse = _reflection.GeneratedProtocolMessageType('PutResponse', (_message.Message,), {
|
||||
'DESCRIPTOR' : _PUTRESPONSE,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:PutResponse)
|
||||
})
|
||||
_sym_db.RegisterMessage(PutResponse)
|
||||
|
||||
GetRequest = _reflection.GeneratedProtocolMessageType('GetRequest', (_message.Message,), {
|
||||
'DESCRIPTOR' : _GETREQUEST,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:GetRequest)
|
||||
})
|
||||
_sym_db.RegisterMessage(GetRequest)
|
||||
|
||||
GetResponse = _reflection.GeneratedProtocolMessageType('GetResponse', (_message.Message,), {
|
||||
'DESCRIPTOR' : _GETRESPONSE,
|
||||
'__module__' : 'proto.task_pb2'
|
||||
# @@protoc_insertion_point(class_scope:GetResponse)
|
||||
})
|
||||
_sym_db.RegisterMessage(GetResponse)
|
||||
|
||||
|
||||
|
||||
_TASKSERVER = _descriptor.ServiceDescriptor(
|
||||
name='TaskServer',
|
||||
full_name='TaskServer',
|
||||
file=DESCRIPTOR,
|
||||
index=0,
|
||||
serialized_options=None,
|
||||
serialized_start=490,
|
||||
serialized_end=620,
|
||||
methods=[
|
||||
_descriptor.MethodDescriptor(
|
||||
name='GetObject',
|
||||
full_name='TaskServer.GetObject',
|
||||
index=0,
|
||||
containing_service=None,
|
||||
input_type=_GETREQUEST,
|
||||
output_type=_GETRESPONSE,
|
||||
serialized_options=None,
|
||||
),
|
||||
_descriptor.MethodDescriptor(
|
||||
name='PutObject',
|
||||
full_name='TaskServer.PutObject',
|
||||
index=1,
|
||||
containing_service=None,
|
||||
input_type=_PUTREQUEST,
|
||||
output_type=_PUTRESPONSE,
|
||||
serialized_options=None,
|
||||
),
|
||||
_descriptor.MethodDescriptor(
|
||||
name='Schedule',
|
||||
full_name='TaskServer.Schedule',
|
||||
index=2,
|
||||
containing_service=None,
|
||||
input_type=_TASK,
|
||||
output_type=_TASKTICKET,
|
||||
serialized_options=None,
|
||||
),
|
||||
])
|
||||
_sym_db.RegisterServiceDescriptor(_TASKSERVER)
|
||||
|
||||
DESCRIPTOR.services_by_name['TaskServer'] = _TASKSERVER
|
||||
|
||||
|
||||
_WORKSERVER = _descriptor.ServiceDescriptor(
|
||||
name='WorkServer',
|
||||
full_name='WorkServer',
|
||||
file=DESCRIPTOR,
|
||||
index=1,
|
||||
serialized_options=None,
|
||||
serialized_start=623,
|
||||
serialized_end=756,
|
||||
methods=[
|
||||
_descriptor.MethodDescriptor(
|
||||
name='GetObject',
|
||||
full_name='WorkServer.GetObject',
|
||||
index=0,
|
||||
containing_service=None,
|
||||
input_type=_GETREQUEST,
|
||||
output_type=_GETRESPONSE,
|
||||
serialized_options=None,
|
||||
),
|
||||
_descriptor.MethodDescriptor(
|
||||
name='PutObject',
|
||||
full_name='WorkServer.PutObject',
|
||||
index=1,
|
||||
containing_service=None,
|
||||
input_type=_PUTREQUEST,
|
||||
output_type=_PUTRESPONSE,
|
||||
serialized_options=None,
|
||||
),
|
||||
_descriptor.MethodDescriptor(
|
||||
name='Execute',
|
||||
full_name='WorkServer.Execute',
|
||||
index=2,
|
||||
containing_service=None,
|
||||
input_type=_TASKTICKET,
|
||||
output_type=_TASK,
|
||||
serialized_options=None,
|
||||
),
|
||||
])
|
||||
_sym_db.RegisterServiceDescriptor(_WORKSERVER)
|
||||
|
||||
DESCRIPTOR.services_by_name['WorkServer'] = _WORKSERVER
|
||||
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
156
python/proto/task_pb2_grpc.py
Executable file
156
python/proto/task_pb2_grpc.py
Executable file
|
|
@ -0,0 +1,156 @@
|
|||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
import grpc
|
||||
|
||||
from proto import task_pb2 as proto_dot_task__pb2
|
||||
|
||||
|
||||
class TaskServerStub(object):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.GetObject = channel.unary_unary(
|
||||
'/TaskServer/GetObject',
|
||||
request_serializer=proto_dot_task__pb2.GetRequest.SerializeToString,
|
||||
response_deserializer=proto_dot_task__pb2.GetResponse.FromString,
|
||||
)
|
||||
self.PutObject = channel.unary_unary(
|
||||
'/TaskServer/PutObject',
|
||||
request_serializer=proto_dot_task__pb2.PutRequest.SerializeToString,
|
||||
response_deserializer=proto_dot_task__pb2.PutResponse.FromString,
|
||||
)
|
||||
self.Schedule = channel.unary_unary(
|
||||
'/TaskServer/Schedule',
|
||||
request_serializer=proto_dot_task__pb2.Task.SerializeToString,
|
||||
response_deserializer=proto_dot_task__pb2.TaskTicket.FromString,
|
||||
)
|
||||
|
||||
|
||||
class TaskServerServicer(object):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
|
||||
def GetObject(self, request, context):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def PutObject(self, request, context):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Schedule(self, request, context):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_TaskServerServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'GetObject': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetObject,
|
||||
request_deserializer=proto_dot_task__pb2.GetRequest.FromString,
|
||||
response_serializer=proto_dot_task__pb2.GetResponse.SerializeToString,
|
||||
),
|
||||
'PutObject': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.PutObject,
|
||||
request_deserializer=proto_dot_task__pb2.PutRequest.FromString,
|
||||
response_serializer=proto_dot_task__pb2.PutResponse.SerializeToString,
|
||||
),
|
||||
'Schedule': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Schedule,
|
||||
request_deserializer=proto_dot_task__pb2.Task.FromString,
|
||||
response_serializer=proto_dot_task__pb2.TaskTicket.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'TaskServer', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
class WorkServerStub(object):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.GetObject = channel.unary_unary(
|
||||
'/WorkServer/GetObject',
|
||||
request_serializer=proto_dot_task__pb2.GetRequest.SerializeToString,
|
||||
response_deserializer=proto_dot_task__pb2.GetResponse.FromString,
|
||||
)
|
||||
self.PutObject = channel.unary_unary(
|
||||
'/WorkServer/PutObject',
|
||||
request_serializer=proto_dot_task__pb2.PutRequest.SerializeToString,
|
||||
response_deserializer=proto_dot_task__pb2.PutResponse.FromString,
|
||||
)
|
||||
self.Execute = channel.stream_stream(
|
||||
'/WorkServer/Execute',
|
||||
request_serializer=proto_dot_task__pb2.TaskTicket.SerializeToString,
|
||||
response_deserializer=proto_dot_task__pb2.Task.FromString,
|
||||
)
|
||||
|
||||
|
||||
class WorkServerServicer(object):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
|
||||
def GetObject(self, request, context):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def PutObject(self, request, context):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def Execute(self, request_iterator, context):
|
||||
# missing associated documentation comment in .proto file
|
||||
pass
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_WorkServerServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'GetObject': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetObject,
|
||||
request_deserializer=proto_dot_task__pb2.GetRequest.FromString,
|
||||
response_serializer=proto_dot_task__pb2.GetResponse.SerializeToString,
|
||||
),
|
||||
'PutObject': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.PutObject,
|
||||
request_deserializer=proto_dot_task__pb2.PutRequest.FromString,
|
||||
response_serializer=proto_dot_task__pb2.PutResponse.SerializeToString,
|
||||
),
|
||||
'Execute': grpc.stream_stream_rpc_method_handler(
|
||||
servicer.Execute,
|
||||
request_deserializer=proto_dot_task__pb2.TaskTicket.FromString,
|
||||
response_serializer=proto_dot_task__pb2.Task.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'WorkServer', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
163
python/ray.py
Normal file
163
python/ray.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
import cloudpickle
|
||||
import grpc
|
||||
import proto.task_pb2
|
||||
import proto.task_pb2_grpc
|
||||
import uuid
|
||||
|
||||
|
||||
class ObjectID:
|
||||
def __init__(self, id):
|
||||
self.id = id
|
||||
|
||||
def __repr__(self):
|
||||
return "ObjectID(%s)" % self.id.decode()
|
||||
|
||||
|
||||
worker_registry = {}
|
||||
|
||||
|
||||
def set_global_worker(worker):
|
||||
global _global_worker
|
||||
_global_worker = worker
|
||||
|
||||
|
||||
def register_worker(worker):
|
||||
id = uuid.uuid4()
|
||||
worker_registry[id] = worker
|
||||
return id
|
||||
|
||||
|
||||
def get_worker_registry(id):
|
||||
out = worker_registry.get(id)
|
||||
if out is None:
|
||||
return _global_worker
|
||||
return out
|
||||
|
||||
|
||||
class Worker:
|
||||
def __init__(self, conn_str="", stub=None):
|
||||
if stub is None:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = proto.task_pb2_grpc.TaskServerStub(self.channel)
|
||||
else:
|
||||
self.server = stub
|
||||
self.uuid = register_worker(self)
|
||||
|
||||
def get(self, ids):
|
||||
to_get = []
|
||||
single = False
|
||||
if isinstance(ids, list):
|
||||
to_get = [x.id for x in ids]
|
||||
elif isinstance(ids, ObjectID):
|
||||
to_get = [ids.id]
|
||||
single = True
|
||||
else:
|
||||
raise Exception("Can't get something that's not a list of IDs or just an ID")
|
||||
|
||||
out = [self._get(x) for x in to_get]
|
||||
if single:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
def _get(self, id: bytes):
|
||||
req = proto.task_pb2.GetRequest(id=id)
|
||||
data = self.server.GetObject(req)
|
||||
return cloudpickle.loads(data.data)
|
||||
|
||||
def put(self, vals):
|
||||
to_put = []
|
||||
single = False
|
||||
if isinstance(vals, list):
|
||||
to_put = vals
|
||||
else:
|
||||
single = True
|
||||
to_put.append(vals)
|
||||
|
||||
out = [self._put(x) for x in to_put]
|
||||
if single:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
def _put(self, val):
|
||||
data = cloudpickle.dumps(val)
|
||||
#print("val: %s\ndata: %s"%(val, data))
|
||||
req = proto.task_pb2.PutRequest(data=data)
|
||||
resp = self.server.PutObject(req)
|
||||
return ObjectID(resp.id)
|
||||
|
||||
def remote(self, func):
|
||||
return RemoteFunc(self, func)
|
||||
|
||||
def schedule(self, task):
|
||||
return self.server.Schedule(task)
|
||||
|
||||
def close(self):
|
||||
self.channel.close()
|
||||
|
||||
|
||||
class RemoteFunc:
|
||||
def __init__(self, worker, f):
|
||||
self._func = f
|
||||
self._name = f.__name__
|
||||
self.id = None
|
||||
self._worker_id = worker.uuid
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("Matching the old API")
|
||||
|
||||
def remote(self, *args):
|
||||
if self.id is None:
|
||||
self._push_func()
|
||||
t = proto.task_pb2.Task()
|
||||
t.name = self._name
|
||||
t.payload_id = self.id.id
|
||||
for a in args:
|
||||
arg = proto.task_pb2.Arg()
|
||||
if isinstance(a, ObjectID):
|
||||
arg.local = proto.task_pb2.Arg.Locality.REFERENCE
|
||||
arg.reference_id = a.id
|
||||
else:
|
||||
arg.local = proto.task_pb2.Arg.Locality.INTERNED
|
||||
arg.data = cloudpickle.dumps(a)
|
||||
t.args.append(arg)
|
||||
worker = get_worker_registry(self._worker_id)
|
||||
ticket = worker.schedule(t)
|
||||
return ObjectID(ticket.return_id)
|
||||
|
||||
def _push_func(self):
|
||||
worker = get_worker_registry(self._worker_id)
|
||||
self.id = worker.put(self._func)
|
||||
|
||||
def __repr__(self):
|
||||
return "RemoteFunc(%s, %s)" % (self._name, self.id)
|
||||
|
||||
|
||||
_global_worker = None
|
||||
|
||||
|
||||
def connect(*args, **kwargs):
|
||||
global _global_worker
|
||||
if _global_worker is not None:
|
||||
raise Exception("Can't connect a second global worker")
|
||||
_global_worker = Worker(*args, **kwargs)
|
||||
|
||||
|
||||
def get(*args, **kwargs):
|
||||
global _global_worker
|
||||
if _global_worker is None:
|
||||
raise Exception("Need a connection before calling")
|
||||
return _global_worker.get(*args, **kwargs)
|
||||
|
||||
|
||||
def put(*args, **kwargs):
|
||||
global _global_worker
|
||||
if _global_worker is None:
|
||||
raise Exception("Need a connection before calling")
|
||||
return _global_worker.put(*args, **kwargs)
|
||||
|
||||
|
||||
def remote(*args, **kwargs):
|
||||
global _global_worker
|
||||
if _global_worker is None:
|
||||
raise Exception("Need a connection before calling")
|
||||
return _global_worker.remote(*args, **kwargs)
|
||||
5
python/requirements.txt
Normal file
5
python/requirements.txt
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
cloudpickle
|
||||
futures
|
||||
grpcio
|
||||
google
|
||||
protobuf
|
||||
Loading…
Add table
Add a link
Reference in a new issue