diff --git a/go.mod b/go.mod index c554d2e..c6c3edd 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,6 @@ go 1.15 require ( github.com/gogo/protobuf v1.3.1 + github.com/google/uuid v1.1.2 google.golang.org/grpc v1.34.0 ) diff --git a/go.sum b/go.sum index 4a4ac87..0dbbc7b 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,7 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= diff --git a/main.go b/main.go index be59112..86a1fcf 100644 --- a/main.go +++ b/main.go @@ -35,7 +35,9 @@ func main() { s := grpc.NewServer( grpc.Creds(insecure.NewCredentials()), ) - ray_rpc.RegisterRayletDriverServer(s, NewRayletServicer()) + server := NewMemRaylet() + ray_rpc.RegisterRayletDriverServer(s, server) + ray_rpc.RegisterRayletWorkerConnectionServer(s, server) // Serve gRPC Server log.Info("Serving gRPC on https://", addr) diff --git a/object.go b/object.go index 5ef6a38..7aeb45b 100644 --- a/object.go +++ b/object.go @@ -1,18 +1,20 @@ package main import ( + "encoding/binary" "errors" "sync" "time" -) -type ObjectID int64 + "github.com/google/uuid" +) var ErrObjectNotFound = errors.New("object not found") type ObjectStore interface { AwaitObjects(ids []ObjectID, c chan ObjectResult, timeout *time.Duration) - MakeObject(data []byte) (ObjectID, error) + PutObject(object *Object) error + MakeID() ObjectID } type ObjectResult struct { @@ -32,7 +34,7 @@ func GetObject(s ObjectStore, id ObjectID) ([]byte, error) { go func() { s.AwaitObjects(ids, c, nil) }() - obj := <-c + obj, ok := <-c if !ok { return nil, errors.New("Couldn't get object") } @@ -44,10 +46,12 @@ func GetObject(s ObjectStore, id ObjectID) ([]byte, error) { type MemObjectStore struct { sync.RWMutex - db map[ObjectID][]byte + db map[ObjectID][]byte + prefix uint64 + printer uint64 } -func (mem *MemObjectStore) AwaitObjects(ids []ObjectID, c chan Object, timeout *time.Duration) { +func (mem *MemObjectStore) AwaitObjects(ids []ObjectID, c chan ObjectResult, timeout *time.Duration) { if timeout != nil { panic("timeout not yet implemented") } @@ -58,22 +62,47 @@ func (mem *MemObjectStore) AwaitObjects(ids []ObjectID, c chan Object, timeout * if !ok { c <- ObjectResult{Error: ErrObjectNotFound} } else { - c <- ObjectResult{&Object{ID: id, Data: v}} + c <- ObjectResult{&Object{ID: id, Data: v}, nil} } } close(c) } -func (mem *MemObjectStore) MakeObject(data []byte) (ObjectID, error) { +func (mem *MemObjectStore) PutObject(object *Object) error { mem.Lock() defer mem.Unlock() - id := mem.makeID() - mem.db[id] = data - return id, nil + mem.db[object.ID] = object.Data + return nil +} + +func (mem *MemObjectStore) MakeID() ObjectID { + id := mem.prefix + mem.printer + mem.prefix++ + return ObjectID(id) } func NewMemObjectStore() *MemObjectStore { + prefixUuid, err := uuid.NewRandom() + if err != nil { + panic(err) + } + prefix := uint64(prefixUuid.ID()) << 32 return &MemObjectStore{ - db: make(map[ObjectID][]byte), + db: make(map[ObjectID][]byte), + prefix: prefix, + printer: 1, } } + +type ObjectID uint64 + +func serializeObjectID(id ObjectID) []byte { + out := make([]byte, 8) + binary.BigEndian.PutUint64(out, uint64(id)) + return out +} + +func deserializeObjectID(id_bytes []byte) ObjectID { + id := binary.BigEndian.Uint64(id_bytes) + return ObjectID(id) +} diff --git a/python/rpc/mk_proto.sh b/python/rpc/mk_proto.sh new file mode 100755 index 0000000..557869b --- /dev/null +++ b/python/rpc/mk_proto.sh @@ -0,0 +1 @@ +python -m grpc_tools.protoc -I../../ray_rpc --python_out=. --grpc_python_out=. ../../ray_rpc/ray_client.proto diff --git a/python/rpc/ray_client_pb2.py b/python/rpc/ray_client_pb2.py new file mode 100644 index 0000000..3f4a1d0 --- /dev/null +++ b/python/rpc/ray_client_pb2.py @@ -0,0 +1,785 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: ray_client.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='ray_client.proto', + package='ray.rpc', + syntax='proto3', + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\x10ray_client.proto\x12\x07ray.rpc\"\x95\x01\n\x03\x41rg\x12$\n\x05local\x18\x01 \x01(\x0e\x32\x15.ray.rpc.Arg.Locality\x12\x14\n\x0creference_id\x18\x02 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\x1b\n\x04type\x18\x04 \x01(\x0e\x32\r.ray.rpc.Type\"\'\n\x08Locality\x12\x0c\n\x08INTERNED\x10\x00\x12\r\n\tREFERENCE\x10\x01\"\xc6\x01\n\nClientTask\x12\x30\n\x04type\x18\x01 \x01(\x0e\x32\".ray.rpc.ClientTask.RemoteExecType\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x12\n\npayload_id\x18\x03 \x01(\x0c\x12\x1a\n\x04\x61rgs\x18\x04 \x03(\x0b\x32\x0c.ray.rpc.Arg\"H\n\x0eRemoteExecType\x12\x0c\n\x08\x46UNCTION\x10\x00\x12\t\n\x05\x41\x43TOR\x10\x01\x12\n\n\x06METHOD\x10\x02\x12\x11\n\rSTATIC_METHOD\x10\x03\"%\n\x10\x43lientTaskTicket\x12\x11\n\treturn_id\x18\x01 \x01(\x0c\"\x1a\n\nPutRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x19\n\x0bPutResponse\x12\n\n\x02id\x18\x01 \x01(\x0c\"\x18\n\nGetRequest\x12\n\n\x02id\x18\x01 \x01(\x0c\"*\n\x0bGetResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"H\n\x0bWaitRequest\x12\x13\n\x0bobject_refs\x18\x01 \x03(\x0c\x12\x13\n\x0bnum_returns\x18\x02 \x01(\x03\x12\x0f\n\x07timeout\x18\x03 \x01(\x01\"U\n\x0cWaitResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x18\n\x10ready_object_ids\x18\x02 \x03(\x0c\x12\x1c\n\x14remaining_object_ids\x18\x03 \x03(\x0c\"\xcc\x01\n\nWorkStatus\x12.\n\x06status\x18\x01 \x01(\x0e\x32\x1e.ray.rpc.WorkStatus.StatusCode\x12\x15\n\rcomplete_data\x18\x02 \x01(\x0c\x12\x32\n\x0f\x66inished_ticket\x18\x03 \x01(\x0b\x32\x19.ray.rpc.ClientTaskTicket\x12\x11\n\terror_msg\x18\x04 \x01(\t\"0\n\nStatusCode\x12\x0c\n\x08\x43OMPLETE\x10\x00\x12\t\n\x05\x45RROR\x10\x01\x12\t\n\x05READY\x10\x02\"T\n\x04Work\x12!\n\x04task\x18\x01 \x01(\x0b\x32\x13.ray.rpc.ClientTask\x12)\n\x06ticket\x18\x02 \x01(\x0b\x32\x19.ray.rpc.ClientTaskTicket*\x13\n\x04Type\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x32\xfd\x01\n\x0cRayletDriver\x12\x38\n\tGetObject\x12\x13.ray.rpc.GetRequest\x1a\x14.ray.rpc.GetResponse\"\x00\x12\x38\n\tPutObject\x12\x13.ray.rpc.PutRequest\x1a\x14.ray.rpc.PutResponse\"\x00\x12;\n\nWaitObject\x12\x14.ray.rpc.WaitRequest\x1a\x15.ray.rpc.WaitResponse\"\x00\x12<\n\x08Schedule\x12\x13.ray.rpc.ClientTask\x1a\x19.ray.rpc.ClientTaskTicket\"\x00\x32P\n\x16RayletWorkerConnection\x12\x36\n\nWorkstream\x12\x13.ray.rpc.WorkStatus\x1a\r.ray.rpc.Work\"\x00(\x01\x30\x01\x62\x06proto3' +) + +_TYPE = _descriptor.EnumDescriptor( + name='Type', + full_name='ray.rpc.Type', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='DEFAULT', index=0, number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=1000, + serialized_end=1019, +) +_sym_db.RegisterEnumDescriptor(_TYPE) + +Type = enum_type_wrapper.EnumTypeWrapper(_TYPE) +DEFAULT = 0 + + +_ARG_LOCALITY = _descriptor.EnumDescriptor( + name='Locality', + full_name='ray.rpc.Arg.Locality', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='INTERNED', index=0, number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='REFERENCE', index=1, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=140, + serialized_end=179, +) +_sym_db.RegisterEnumDescriptor(_ARG_LOCALITY) + +_CLIENTTASK_REMOTEEXECTYPE = _descriptor.EnumDescriptor( + name='RemoteExecType', + full_name='ray.rpc.ClientTask.RemoteExecType', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='FUNCTION', index=0, number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='ACTOR', index=1, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='METHOD', index=2, number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='STATIC_METHOD', index=3, number=3, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=308, + serialized_end=380, +) +_sym_db.RegisterEnumDescriptor(_CLIENTTASK_REMOTEEXECTYPE) + +_WORKSTATUS_STATUSCODE = _descriptor.EnumDescriptor( + name='StatusCode', + full_name='ray.rpc.WorkStatus.StatusCode', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='COMPLETE', index=0, number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='ERROR', index=1, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='READY', index=2, number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=864, + serialized_end=912, +) +_sym_db.RegisterEnumDescriptor(_WORKSTATUS_STATUSCODE) + + +_ARG = _descriptor.Descriptor( + name='Arg', + full_name='ray.rpc.Arg', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='local', full_name='ray.rpc.Arg.local', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='reference_id', full_name='ray.rpc.Arg.reference_id', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='data', full_name='ray.rpc.Arg.data', index=2, + number=3, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='type', full_name='ray.rpc.Arg.type', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _ARG_LOCALITY, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=30, + serialized_end=179, +) + + +_CLIENTTASK = _descriptor.Descriptor( + name='ClientTask', + full_name='ray.rpc.ClientTask', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='type', full_name='ray.rpc.ClientTask.type', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='name', full_name='ray.rpc.ClientTask.name', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='payload_id', full_name='ray.rpc.ClientTask.payload_id', index=2, + number=3, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='args', full_name='ray.rpc.ClientTask.args', index=3, + number=4, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _CLIENTTASK_REMOTEEXECTYPE, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=182, + serialized_end=380, +) + + +_CLIENTTASKTICKET = _descriptor.Descriptor( + name='ClientTaskTicket', + full_name='ray.rpc.ClientTaskTicket', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='return_id', full_name='ray.rpc.ClientTaskTicket.return_id', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=382, + serialized_end=419, +) + + +_PUTREQUEST = _descriptor.Descriptor( + name='PutRequest', + full_name='ray.rpc.PutRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='data', full_name='ray.rpc.PutRequest.data', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=421, + serialized_end=447, +) + + +_PUTRESPONSE = _descriptor.Descriptor( + name='PutResponse', + full_name='ray.rpc.PutResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='id', full_name='ray.rpc.PutResponse.id', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=449, + serialized_end=474, +) + + +_GETREQUEST = _descriptor.Descriptor( + name='GetRequest', + full_name='ray.rpc.GetRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='id', full_name='ray.rpc.GetRequest.id', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=476, + serialized_end=500, +) + + +_GETRESPONSE = _descriptor.Descriptor( + name='GetResponse', + full_name='ray.rpc.GetResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='valid', full_name='ray.rpc.GetResponse.valid', index=0, + number=1, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='data', full_name='ray.rpc.GetResponse.data', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=502, + serialized_end=544, +) + + +_WAITREQUEST = _descriptor.Descriptor( + name='WaitRequest', + full_name='ray.rpc.WaitRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='object_refs', full_name='ray.rpc.WaitRequest.object_refs', index=0, + number=1, type=12, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='num_returns', full_name='ray.rpc.WaitRequest.num_returns', index=1, + number=2, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='timeout', full_name='ray.rpc.WaitRequest.timeout', index=2, + number=3, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=546, + serialized_end=618, +) + + +_WAITRESPONSE = _descriptor.Descriptor( + name='WaitResponse', + full_name='ray.rpc.WaitResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='valid', full_name='ray.rpc.WaitResponse.valid', index=0, + number=1, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='ready_object_ids', full_name='ray.rpc.WaitResponse.ready_object_ids', index=1, + number=2, type=12, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='remaining_object_ids', full_name='ray.rpc.WaitResponse.remaining_object_ids', index=2, + number=3, type=12, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=620, + serialized_end=705, +) + + +_WORKSTATUS = _descriptor.Descriptor( + name='WorkStatus', + full_name='ray.rpc.WorkStatus', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='status', full_name='ray.rpc.WorkStatus.status', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='complete_data', full_name='ray.rpc.WorkStatus.complete_data', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='finished_ticket', full_name='ray.rpc.WorkStatus.finished_ticket', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='error_msg', full_name='ray.rpc.WorkStatus.error_msg', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _WORKSTATUS_STATUSCODE, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=708, + serialized_end=912, +) + + +_WORK = _descriptor.Descriptor( + name='Work', + full_name='ray.rpc.Work', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='task', full_name='ray.rpc.Work.task', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='ticket', full_name='ray.rpc.Work.ticket', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=914, + serialized_end=998, +) + +_ARG.fields_by_name['local'].enum_type = _ARG_LOCALITY +_ARG.fields_by_name['type'].enum_type = _TYPE +_ARG_LOCALITY.containing_type = _ARG +_CLIENTTASK.fields_by_name['type'].enum_type = _CLIENTTASK_REMOTEEXECTYPE +_CLIENTTASK.fields_by_name['args'].message_type = _ARG +_CLIENTTASK_REMOTEEXECTYPE.containing_type = _CLIENTTASK +_WORKSTATUS.fields_by_name['status'].enum_type = _WORKSTATUS_STATUSCODE +_WORKSTATUS.fields_by_name['finished_ticket'].message_type = _CLIENTTASKTICKET +_WORKSTATUS_STATUSCODE.containing_type = _WORKSTATUS +_WORK.fields_by_name['task'].message_type = _CLIENTTASK +_WORK.fields_by_name['ticket'].message_type = _CLIENTTASKTICKET +DESCRIPTOR.message_types_by_name['Arg'] = _ARG +DESCRIPTOR.message_types_by_name['ClientTask'] = _CLIENTTASK +DESCRIPTOR.message_types_by_name['ClientTaskTicket'] = _CLIENTTASKTICKET +DESCRIPTOR.message_types_by_name['PutRequest'] = _PUTREQUEST +DESCRIPTOR.message_types_by_name['PutResponse'] = _PUTRESPONSE +DESCRIPTOR.message_types_by_name['GetRequest'] = _GETREQUEST +DESCRIPTOR.message_types_by_name['GetResponse'] = _GETRESPONSE +DESCRIPTOR.message_types_by_name['WaitRequest'] = _WAITREQUEST +DESCRIPTOR.message_types_by_name['WaitResponse'] = _WAITRESPONSE +DESCRIPTOR.message_types_by_name['WorkStatus'] = _WORKSTATUS +DESCRIPTOR.message_types_by_name['Work'] = _WORK +DESCRIPTOR.enum_types_by_name['Type'] = _TYPE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Arg = _reflection.GeneratedProtocolMessageType('Arg', (_message.Message,), { + 'DESCRIPTOR' : _ARG, + '__module__' : 'ray_client_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.Arg) + }) +_sym_db.RegisterMessage(Arg) + +ClientTask = _reflection.GeneratedProtocolMessageType('ClientTask', (_message.Message,), { + 'DESCRIPTOR' : _CLIENTTASK, + '__module__' : 'ray_client_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.ClientTask) + }) +_sym_db.RegisterMessage(ClientTask) + +ClientTaskTicket = _reflection.GeneratedProtocolMessageType('ClientTaskTicket', (_message.Message,), { + 'DESCRIPTOR' : _CLIENTTASKTICKET, + '__module__' : 'ray_client_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.ClientTaskTicket) + }) +_sym_db.RegisterMessage(ClientTaskTicket) + +PutRequest = _reflection.GeneratedProtocolMessageType('PutRequest', (_message.Message,), { + 'DESCRIPTOR' : _PUTREQUEST, + '__module__' : 'ray_client_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.PutRequest) + }) +_sym_db.RegisterMessage(PutRequest) + +PutResponse = _reflection.GeneratedProtocolMessageType('PutResponse', (_message.Message,), { + 'DESCRIPTOR' : _PUTRESPONSE, + '__module__' : 'ray_client_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.PutResponse) + }) +_sym_db.RegisterMessage(PutResponse) + +GetRequest = _reflection.GeneratedProtocolMessageType('GetRequest', (_message.Message,), { + 'DESCRIPTOR' : _GETREQUEST, + '__module__' : 'ray_client_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.GetRequest) + }) +_sym_db.RegisterMessage(GetRequest) + +GetResponse = _reflection.GeneratedProtocolMessageType('GetResponse', (_message.Message,), { + 'DESCRIPTOR' : _GETRESPONSE, + '__module__' : 'ray_client_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.GetResponse) + }) +_sym_db.RegisterMessage(GetResponse) + +WaitRequest = _reflection.GeneratedProtocolMessageType('WaitRequest', (_message.Message,), { + 'DESCRIPTOR' : _WAITREQUEST, + '__module__' : 'ray_client_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.WaitRequest) + }) +_sym_db.RegisterMessage(WaitRequest) + +WaitResponse = _reflection.GeneratedProtocolMessageType('WaitResponse', (_message.Message,), { + 'DESCRIPTOR' : _WAITRESPONSE, + '__module__' : 'ray_client_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.WaitResponse) + }) +_sym_db.RegisterMessage(WaitResponse) + +WorkStatus = _reflection.GeneratedProtocolMessageType('WorkStatus', (_message.Message,), { + 'DESCRIPTOR' : _WORKSTATUS, + '__module__' : 'ray_client_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.WorkStatus) + }) +_sym_db.RegisterMessage(WorkStatus) + +Work = _reflection.GeneratedProtocolMessageType('Work', (_message.Message,), { + 'DESCRIPTOR' : _WORK, + '__module__' : 'ray_client_pb2' + # @@protoc_insertion_point(class_scope:ray.rpc.Work) + }) +_sym_db.RegisterMessage(Work) + + + +_RAYLETDRIVER = _descriptor.ServiceDescriptor( + name='RayletDriver', + full_name='ray.rpc.RayletDriver', + file=DESCRIPTOR, + index=0, + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_start=1022, + serialized_end=1275, + methods=[ + _descriptor.MethodDescriptor( + name='GetObject', + full_name='ray.rpc.RayletDriver.GetObject', + index=0, + containing_service=None, + input_type=_GETREQUEST, + output_type=_GETRESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.MethodDescriptor( + name='PutObject', + full_name='ray.rpc.RayletDriver.PutObject', + index=1, + containing_service=None, + input_type=_PUTREQUEST, + output_type=_PUTRESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.MethodDescriptor( + name='WaitObject', + full_name='ray.rpc.RayletDriver.WaitObject', + index=2, + containing_service=None, + input_type=_WAITREQUEST, + output_type=_WAITRESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.MethodDescriptor( + name='Schedule', + full_name='ray.rpc.RayletDriver.Schedule', + index=3, + containing_service=None, + input_type=_CLIENTTASK, + output_type=_CLIENTTASKTICKET, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), +]) +_sym_db.RegisterServiceDescriptor(_RAYLETDRIVER) + +DESCRIPTOR.services_by_name['RayletDriver'] = _RAYLETDRIVER + + +_RAYLETWORKERCONNECTION = _descriptor.ServiceDescriptor( + name='RayletWorkerConnection', + full_name='ray.rpc.RayletWorkerConnection', + file=DESCRIPTOR, + index=1, + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_start=1277, + serialized_end=1357, + methods=[ + _descriptor.MethodDescriptor( + name='Workstream', + full_name='ray.rpc.RayletWorkerConnection.Workstream', + index=0, + containing_service=None, + input_type=_WORKSTATUS, + output_type=_WORK, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), +]) +_sym_db.RegisterServiceDescriptor(_RAYLETWORKERCONNECTION) + +DESCRIPTOR.services_by_name['RayletWorkerConnection'] = _RAYLETWORKERCONNECTION + +# @@protoc_insertion_point(module_scope) diff --git a/python/rpc/ray_client_pb2_grpc.py b/python/rpc/ray_client_pb2_grpc.py new file mode 100644 index 0000000..6b3a4b4 --- /dev/null +++ b/python/rpc/ray_client_pb2_grpc.py @@ -0,0 +1,226 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import ray_client_pb2 as ray__client__pb2 + + +class RayletDriverStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.GetObject = channel.unary_unary( + '/ray.rpc.RayletDriver/GetObject', + request_serializer=ray__client__pb2.GetRequest.SerializeToString, + response_deserializer=ray__client__pb2.GetResponse.FromString, + ) + self.PutObject = channel.unary_unary( + '/ray.rpc.RayletDriver/PutObject', + request_serializer=ray__client__pb2.PutRequest.SerializeToString, + response_deserializer=ray__client__pb2.PutResponse.FromString, + ) + self.WaitObject = channel.unary_unary( + '/ray.rpc.RayletDriver/WaitObject', + request_serializer=ray__client__pb2.WaitRequest.SerializeToString, + response_deserializer=ray__client__pb2.WaitResponse.FromString, + ) + self.Schedule = channel.unary_unary( + '/ray.rpc.RayletDriver/Schedule', + request_serializer=ray__client__pb2.ClientTask.SerializeToString, + response_deserializer=ray__client__pb2.ClientTaskTicket.FromString, + ) + + +class RayletDriverServicer(object): + """Missing associated documentation comment in .proto file.""" + + def GetObject(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PutObject(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def WaitObject(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Schedule(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_RayletDriverServicer_to_server(servicer, server): + rpc_method_handlers = { + 'GetObject': grpc.unary_unary_rpc_method_handler( + servicer.GetObject, + request_deserializer=ray__client__pb2.GetRequest.FromString, + response_serializer=ray__client__pb2.GetResponse.SerializeToString, + ), + 'PutObject': grpc.unary_unary_rpc_method_handler( + servicer.PutObject, + request_deserializer=ray__client__pb2.PutRequest.FromString, + response_serializer=ray__client__pb2.PutResponse.SerializeToString, + ), + 'WaitObject': grpc.unary_unary_rpc_method_handler( + servicer.WaitObject, + request_deserializer=ray__client__pb2.WaitRequest.FromString, + response_serializer=ray__client__pb2.WaitResponse.SerializeToString, + ), + 'Schedule': grpc.unary_unary_rpc_method_handler( + servicer.Schedule, + request_deserializer=ray__client__pb2.ClientTask.FromString, + response_serializer=ray__client__pb2.ClientTaskTicket.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'ray.rpc.RayletDriver', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class RayletDriver(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def GetObject(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/ray.rpc.RayletDriver/GetObject', + ray__client__pb2.GetRequest.SerializeToString, + ray__client__pb2.GetResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def PutObject(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/ray.rpc.RayletDriver/PutObject', + ray__client__pb2.PutRequest.SerializeToString, + ray__client__pb2.PutResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def WaitObject(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/ray.rpc.RayletDriver/WaitObject', + ray__client__pb2.WaitRequest.SerializeToString, + ray__client__pb2.WaitResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Schedule(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/ray.rpc.RayletDriver/Schedule', + ray__client__pb2.ClientTask.SerializeToString, + ray__client__pb2.ClientTaskTicket.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + +class RayletWorkerConnectionStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Workstream = channel.stream_stream( + '/ray.rpc.RayletWorkerConnection/Workstream', + request_serializer=ray__client__pb2.WorkStatus.SerializeToString, + response_deserializer=ray__client__pb2.Work.FromString, + ) + + +class RayletWorkerConnectionServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Workstream(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_RayletWorkerConnectionServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Workstream': grpc.stream_stream_rpc_method_handler( + servicer.Workstream, + request_deserializer=ray__client__pb2.WorkStatus.FromString, + response_serializer=ray__client__pb2.Work.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'ray.rpc.RayletWorkerConnection', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class RayletWorkerConnection(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Workstream(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream(request_iterator, target, '/ray.rpc.RayletWorkerConnection/Workstream', + ray__client__pb2.WorkStatus.SerializeToString, + ray__client__pb2.Work.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/python/worker.py b/python/worker.py new file mode 100644 index 0000000..261a2cb --- /dev/null +++ b/python/worker.py @@ -0,0 +1,67 @@ +from rpc import ray_client_pb2 +from rpc import ray_client_pb2_grpc +import queue + + +class Worker: + def __init__(self, conn_str): + self.channel = grpc.insecure_channel(conn_str) + self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel) + self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) + + def begin(self): + send_queue = queue.SimpleQueue() + work_stream = self.worker_stub.Workstream(iter(send_queue.get, None)) + start = ray_client_pb2.WorkStatus( + status = ray_client_pb2.WorkStatus.StatusCode.READY + ) + send_queue.put(start) + for work in work_stream: + task = work.task + if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION: + send_queue.put(ray_client_pb2.WorkStatus( + status = ray_client_pb2.WorkStatus.StatusCode.ERROR, + error_msg = "unimplemented", + )) + continue + args = self.decode_args(task) + func_data = self.get(task.payload_id) + func = cloudpickle.loads(func_data) + res = func(*args) + out_data = cloudpickle.dumps(res) + send_queue.put(ray_client_pb2.WorkStatus( + status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE, + complete_data = out_data, + finished_ticket = work.ticket, + )) + + def get(self, id_bytes): + data = self.server.GetObject(ray_client_pb2.GetRequest( + id = id_bytes, + )) + return data.data + + def decode_args(self, task): + out = [] + for arg in arg_list: + t = self.convert_from_arg(arg) + out.append(t) + return out + + def convert_from_arg(self, pb): + if pb.local == ray_client_pb2.Arg.Locality.REFERENCE: + return self.get(pb.reference_id) + elif pb.local == ray_client_pb2.Arg.Locality.INTERNED: + return cloudpickle.loads(pb.data) + raise Exception("convert_from_arg: Uncovered locality enum") + + +def main(): + worker = Worker(os.args[1]) + worker.begin() + print("Shutting down...") + worker.channel.close() + + +if __name__ == "__main__": + main() diff --git a/ray_rpc/ray_client.pb.go b/ray_rpc/ray_client.pb.go index c432ac5..5d37935 100644 --- a/ray_rpc/ray_client.pb.go +++ b/ray_rpc/ray_client.pb.go @@ -97,6 +97,30 @@ func (ClientTask_RemoteExecType) EnumDescriptor() ([]byte, []int) { return fileDescriptor_b171423e2687a533, []int{1, 0} } +type WorkStatus_StatusCode int32 + +const ( + COMPLETE WorkStatus_StatusCode = 0 + ERROR WorkStatus_StatusCode = 1 + READY WorkStatus_StatusCode = 2 +) + +var WorkStatus_StatusCode_name = map[int32]string{ + 0: "COMPLETE", + 1: "ERROR", + 2: "READY", +} + +var WorkStatus_StatusCode_value = map[string]int32{ + "COMPLETE": 0, + "ERROR": 1, + "READY": 2, +} + +func (WorkStatus_StatusCode) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_b171423e2687a533, []int{9, 0} +} + type Arg struct { Local Arg_Locality `protobuf:"varint,1,opt,name=local,proto3,enum=ray.rpc.Arg_Locality" json:"local,omitempty"` ReferenceId []byte `protobuf:"bytes,2,opt,name=reference_id,json=referenceId,proto3" json:"reference_id,omitempty"` @@ -572,10 +596,129 @@ func (m *WaitResponse) GetRemainingObjectIds() [][]byte { return nil } +type WorkStatus struct { + Status WorkStatus_StatusCode `protobuf:"varint,1,opt,name=status,proto3,enum=ray.rpc.WorkStatus_StatusCode" json:"status,omitempty"` + CompleteData []byte `protobuf:"bytes,2,opt,name=complete_data,json=completeData,proto3" json:"complete_data,omitempty"` + FinishedTicket *ClientTaskTicket `protobuf:"bytes,3,opt,name=finished_ticket,json=finishedTicket,proto3" json:"finished_ticket,omitempty"` + ErrorMsg string `protobuf:"bytes,4,opt,name=error_msg,json=errorMsg,proto3" json:"error_msg,omitempty"` +} + +func (m *WorkStatus) Reset() { *m = WorkStatus{} } +func (*WorkStatus) ProtoMessage() {} +func (*WorkStatus) Descriptor() ([]byte, []int) { + return fileDescriptor_b171423e2687a533, []int{9} +} +func (m *WorkStatus) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *WorkStatus) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_WorkStatus.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *WorkStatus) XXX_Merge(src proto.Message) { + xxx_messageInfo_WorkStatus.Merge(m, src) +} +func (m *WorkStatus) XXX_Size() int { + return m.Size() +} +func (m *WorkStatus) XXX_DiscardUnknown() { + xxx_messageInfo_WorkStatus.DiscardUnknown(m) +} + +var xxx_messageInfo_WorkStatus proto.InternalMessageInfo + +func (m *WorkStatus) GetStatus() WorkStatus_StatusCode { + if m != nil { + return m.Status + } + return COMPLETE +} + +func (m *WorkStatus) GetCompleteData() []byte { + if m != nil { + return m.CompleteData + } + return nil +} + +func (m *WorkStatus) GetFinishedTicket() *ClientTaskTicket { + if m != nil { + return m.FinishedTicket + } + return nil +} + +func (m *WorkStatus) GetErrorMsg() string { + if m != nil { + return m.ErrorMsg + } + return "" +} + +type Work struct { + Task *ClientTask `protobuf:"bytes,1,opt,name=task,proto3" json:"task,omitempty"` + Ticket *ClientTaskTicket `protobuf:"bytes,2,opt,name=ticket,proto3" json:"ticket,omitempty"` +} + +func (m *Work) Reset() { *m = Work{} } +func (*Work) ProtoMessage() {} +func (*Work) Descriptor() ([]byte, []int) { + return fileDescriptor_b171423e2687a533, []int{10} +} +func (m *Work) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Work) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Work.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Work) XXX_Merge(src proto.Message) { + xxx_messageInfo_Work.Merge(m, src) +} +func (m *Work) XXX_Size() int { + return m.Size() +} +func (m *Work) XXX_DiscardUnknown() { + xxx_messageInfo_Work.DiscardUnknown(m) +} + +var xxx_messageInfo_Work proto.InternalMessageInfo + +func (m *Work) GetTask() *ClientTask { + if m != nil { + return m.Task + } + return nil +} + +func (m *Work) GetTicket() *ClientTaskTicket { + if m != nil { + return m.Ticket + } + return nil +} + func init() { proto.RegisterEnum("ray.rpc.Type", Type_name, Type_value) proto.RegisterEnum("ray.rpc.Arg_Locality", Arg_Locality_name, Arg_Locality_value) proto.RegisterEnum("ray.rpc.ClientTask_RemoteExecType", ClientTask_RemoteExecType_name, ClientTask_RemoteExecType_value) + proto.RegisterEnum("ray.rpc.WorkStatus_StatusCode", WorkStatus_StatusCode_name, WorkStatus_StatusCode_value) proto.RegisterType((*Arg)(nil), "ray.rpc.Arg") proto.RegisterType((*ClientTask)(nil), "ray.rpc.ClientTask") proto.RegisterType((*ClientTaskTicket)(nil), "ray.rpc.ClientTaskTicket") @@ -585,56 +728,70 @@ func init() { proto.RegisterType((*GetResponse)(nil), "ray.rpc.GetResponse") proto.RegisterType((*WaitRequest)(nil), "ray.rpc.WaitRequest") proto.RegisterType((*WaitResponse)(nil), "ray.rpc.WaitResponse") + proto.RegisterType((*WorkStatus)(nil), "ray.rpc.WorkStatus") + proto.RegisterType((*Work)(nil), "ray.rpc.Work") } func init() { proto.RegisterFile("ray_client.proto", fileDescriptor_b171423e2687a533) } var fileDescriptor_b171423e2687a533 = []byte{ - // 694 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x54, 0xc1, 0x4e, 0xdb, 0x4a, - 0x14, 0xf5, 0x24, 0x01, 0x92, 0x9b, 0x10, 0xf9, 0x0d, 0x20, 0xe5, 0xf1, 0x1e, 0x7e, 0xc1, 0x9b, - 0x17, 0xb5, 0x52, 0x5a, 0xd1, 0xaa, 0xad, 0xd4, 0x6e, 0xd2, 0xc4, 0x40, 0x24, 0x9a, 0xa0, 0xc1, - 0xa8, 0xcb, 0x68, 0xb0, 0x87, 0xd4, 0x25, 0xb1, 0xd3, 0xf1, 0x18, 0xd5, 0x3b, 0x3e, 0xa1, 0x9f, - 0xd1, 0x1f, 0xe8, 0x3f, 0x74, 0xc9, 0x92, 0x65, 0x49, 0x36, 0x5d, 0xf2, 0x03, 0x95, 0x2a, 0x8f, - 0x8d, 0x6d, 0x68, 0xd5, 0x9d, 0xe7, 0xdc, 0x73, 0xee, 0xb9, 0x67, 0xee, 0xc8, 0xa0, 0x72, 0x1a, - 0x8e, 0xac, 0x89, 0xc3, 0x5c, 0xd1, 0x9e, 0x71, 0x4f, 0x78, 0x78, 0x85, 0xd3, 0xb0, 0xcd, 0x67, - 0x96, 0xfe, 0x05, 0x41, 0xb1, 0xc3, 0xc7, 0xf8, 0x21, 0x2c, 0x4d, 0x3c, 0x8b, 0x4e, 0x1a, 0xa8, - 0x89, 0x5a, 0xf5, 0x9d, 0x8d, 0x76, 0x42, 0x68, 0x77, 0xf8, 0xb8, 0x7d, 0x10, 0x55, 0x1c, 0x11, - 0x92, 0x98, 0x83, 0xb7, 0xa1, 0xc6, 0xd9, 0x29, 0xe3, 0xcc, 0xb5, 0xd8, 0xc8, 0xb1, 0x1b, 0x85, - 0x26, 0x6a, 0xd5, 0x48, 0x35, 0xc5, 0xfa, 0x36, 0xc6, 0x50, 0xb2, 0xa9, 0xa0, 0x8d, 0xa2, 0x2c, - 0xc9, 0x6f, 0xbc, 0x0d, 0x25, 0x11, 0xce, 0x58, 0xa3, 0x24, 0x2d, 0x56, 0x53, 0x0b, 0x33, 0x9c, - 0x31, 0x22, 0x4b, 0xfa, 0xff, 0x50, 0xbe, 0x35, 0xc3, 0x35, 0x28, 0xf7, 0x07, 0xa6, 0x41, 0x06, - 0x46, 0x4f, 0x55, 0xf0, 0x2a, 0x54, 0x88, 0xb1, 0x6b, 0x10, 0x63, 0xd0, 0x35, 0x54, 0xa4, 0x2f, - 0x10, 0x40, 0x57, 0x26, 0x32, 0xa9, 0x7f, 0x86, 0x9f, 0x25, 0xad, 0xe3, 0xe9, 0xf5, 0xb4, 0x75, - 0x46, 0x69, 0x13, 0x36, 0xf5, 0x04, 0x33, 0x3e, 0x32, 0x2b, 0xf3, 0x8b, 0xc6, 0x74, 0xe9, 0x94, - 0xc9, 0x04, 0x15, 0x22, 0xbf, 0xf1, 0x16, 0xc0, 0x8c, 0x86, 0x13, 0x8f, 0xda, 0x51, 0xb6, 0x38, - 0x40, 0x25, 0x41, 0xfa, 0x36, 0x6e, 0x42, 0x89, 0xf2, 0xb1, 0xdf, 0x28, 0x35, 0x8b, 0xad, 0xea, - 0x4e, 0x2d, 0x7f, 0x51, 0x44, 0x56, 0xf4, 0x7d, 0xa8, 0xdf, 0x35, 0x8b, 0xa2, 0xec, 0x1e, 0x0f, - 0xba, 0x66, 0x7f, 0x38, 0x50, 0x15, 0x5c, 0x81, 0xa5, 0x4e, 0xd7, 0x1c, 0x12, 0x15, 0x61, 0x80, - 0xe5, 0x37, 0x86, 0xb9, 0x3f, 0xec, 0xa9, 0x05, 0xfc, 0x17, 0xac, 0x1e, 0x99, 0x1d, 0xb3, 0xdf, - 0x1d, 0x25, 0x50, 0x51, 0x7f, 0x04, 0x6a, 0x96, 0xc0, 0x74, 0xac, 0x33, 0x26, 0xf0, 0x3f, 0x50, - 0xe1, 0x4c, 0x04, 0xdc, 0x8d, 0xa6, 0x43, 0x72, 0xba, 0x72, 0x0c, 0xf4, 0x6d, 0xbd, 0x09, 0x70, - 0x18, 0x08, 0xc2, 0x3e, 0x04, 0xcc, 0x17, 0xe9, 0x12, 0x50, 0xb6, 0x04, 0x7d, 0x0b, 0xaa, 0x92, - 0xe1, 0xcf, 0x3c, 0xd7, 0x67, 0xb8, 0x0e, 0x85, 0xb4, 0x4d, 0xc1, 0xb1, 0xf5, 0x7f, 0x01, 0xf6, - 0x58, 0xda, 0xe0, 0x7e, 0xf5, 0x39, 0x54, 0x65, 0x35, 0x11, 0xaf, 0xc3, 0xd2, 0x39, 0x9d, 0x24, - 0x8c, 0x32, 0x89, 0x0f, 0xa9, 0x6b, 0x21, 0xe7, 0xea, 0x40, 0xf5, 0x2d, 0x75, 0xd2, 0xbe, 0xff, - 0x41, 0xd5, 0x3b, 0x79, 0xcf, 0x2c, 0x31, 0xe2, 0xec, 0xd4, 0x6f, 0xa0, 0x66, 0xb1, 0x55, 0x23, - 0x10, 0x43, 0x84, 0x9d, 0xfa, 0x11, 0xc1, 0x0d, 0xa6, 0xa3, 0x38, 0x97, 0x2f, 0x5b, 0x15, 0x09, - 0xb8, 0xc1, 0x94, 0xc4, 0x08, 0x6e, 0xc0, 0x8a, 0x70, 0xa6, 0xcc, 0x0b, 0x84, 0xdc, 0x10, 0x22, - 0xb7, 0x47, 0xfd, 0x02, 0x41, 0x2d, 0xf6, 0xfa, 0xe3, 0x94, 0x2d, 0x50, 0x39, 0xa3, 0x76, 0x38, - 0x4a, 0x06, 0x71, 0xec, 0xc8, 0x26, 0x9a, 0xa3, 0x2e, 0xf1, 0xa1, 0x84, 0xfb, 0xb6, 0x8f, 0x1f, - 0xc3, 0x3a, 0x67, 0x53, 0xea, 0xb8, 0x8e, 0x3b, 0xce, 0xb3, 0x8b, 0x92, 0x8d, 0xd3, 0x5a, 0xaa, - 0x78, 0xb0, 0x06, 0x25, 0xb9, 0xf6, 0x2a, 0xac, 0xf4, 0x8c, 0xdd, 0xce, 0xf1, 0x81, 0xa9, 0x2a, - 0x3b, 0x3f, 0x10, 0xd4, 0x08, 0x0d, 0x27, 0x4c, 0xf4, 0xb8, 0x73, 0xce, 0x38, 0x7e, 0x01, 0x95, - 0x3d, 0x26, 0x62, 0x15, 0x5e, 0x4b, 0xdf, 0x51, 0x76, 0xfd, 0x9b, 0xeb, 0x77, 0xc1, 0x38, 0x8f, - 0xae, 0x44, 0xca, 0xc3, 0xe0, 0x57, 0x65, 0xb6, 0xf9, 0x9c, 0x32, 0xb7, 0x6c, 0x5d, 0xc1, 0x2f, - 0x01, 0xa2, 0xbb, 0x49, 0xa4, 0x19, 0x2b, 0xb7, 0x9c, 0xcd, 0x8d, 0x7b, 0x68, 0x2a, 0x7e, 0x05, - 0xe5, 0x23, 0xeb, 0x1d, 0xb3, 0x83, 0x09, 0xcb, 0xb9, 0x66, 0x0f, 0x74, 0xf3, 0xef, 0xdf, 0x80, - 0xf1, 0xab, 0xd5, 0x95, 0xd7, 0x4f, 0x2f, 0xaf, 0x35, 0xe5, 0xea, 0x5a, 0x53, 0x6e, 0xae, 0x35, - 0x74, 0x31, 0xd7, 0xd0, 0xe7, 0xb9, 0x86, 0xbe, 0xce, 0x35, 0x74, 0x39, 0xd7, 0xd0, 0xb7, 0xb9, - 0x86, 0xbe, 0xcf, 0x35, 0xe5, 0x66, 0xae, 0xa1, 0x4f, 0x0b, 0x4d, 0xb9, 0x5c, 0x68, 0xca, 0xd5, - 0x42, 0x53, 0x4e, 0x96, 0xe5, 0xff, 0xea, 0xc9, 0xcf, 0x00, 0x00, 0x00, 0xff, 0xff, 0xd7, 0xa9, - 0x9f, 0x54, 0xc3, 0x04, 0x00, 0x00, + // 887 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xdb, 0x46, + 0x10, 0xe6, 0x4a, 0xb2, 0x2d, 0x8d, 0x7e, 0xca, 0x6e, 0x9c, 0x42, 0x75, 0x1b, 0x56, 0x61, 0x0f, + 0x11, 0x5a, 0x40, 0x75, 0xd5, 0xc2, 0x2d, 0xd0, 0x5e, 0x14, 0x89, 0x4e, 0x04, 0xd8, 0x92, 0xb1, + 0x66, 0x10, 0xf4, 0x24, 0xac, 0xc9, 0xb1, 0xc2, 0x5a, 0x24, 0xd5, 0xe5, 0x32, 0xa8, 0x6e, 0x79, + 0x84, 0x3e, 0x46, 0x5f, 0xa0, 0xef, 0xd0, 0xa3, 0x8f, 0x39, 0xd6, 0xf2, 0xa5, 0xc7, 0x3c, 0x40, + 0x0b, 0x14, 0x5c, 0xd2, 0xa4, 0x9c, 0xba, 0xcd, 0x49, 0xcb, 0x6f, 0xbe, 0x99, 0xf9, 0x66, 0xf7, + 0x1b, 0x08, 0x74, 0xc1, 0x57, 0x33, 0x67, 0xe1, 0x61, 0x20, 0x7b, 0x4b, 0x11, 0xca, 0x90, 0xee, + 0x08, 0xbe, 0xea, 0x89, 0xa5, 0x63, 0xfe, 0x46, 0xa0, 0x3c, 0x10, 0x73, 0xfa, 0x39, 0x6c, 0x2d, + 0x42, 0x87, 0x2f, 0xda, 0xa4, 0x43, 0xba, 0xad, 0xfe, 0xfd, 0x5e, 0x46, 0xe8, 0x0d, 0xc4, 0xbc, + 0x77, 0x94, 0x44, 0x3c, 0xb9, 0x62, 0x29, 0x87, 0x3e, 0x84, 0x86, 0xc0, 0x73, 0x14, 0x18, 0x38, + 0x38, 0xf3, 0xdc, 0x76, 0xa9, 0x43, 0xba, 0x0d, 0x56, 0xcf, 0xb1, 0xb1, 0x4b, 0x29, 0x54, 0x5c, + 0x2e, 0x79, 0xbb, 0xac, 0x42, 0xea, 0x4c, 0x1f, 0x42, 0x45, 0xae, 0x96, 0xd8, 0xae, 0xa8, 0x16, + 0xcd, 0xbc, 0x85, 0xbd, 0x5a, 0x22, 0x53, 0x21, 0xf3, 0x11, 0x54, 0x6f, 0x9a, 0xd1, 0x06, 0x54, + 0xc7, 0x13, 0xdb, 0x62, 0x13, 0x6b, 0xa4, 0x6b, 0xb4, 0x09, 0x35, 0x66, 0x1d, 0x5a, 0xcc, 0x9a, + 0x0c, 0x2d, 0x9d, 0x98, 0xd7, 0x04, 0x60, 0xa8, 0x26, 0xb2, 0x79, 0x74, 0x41, 0x0f, 0xb2, 0xd2, + 0xa9, 0x7a, 0x33, 0x2f, 0x5d, 0x50, 0x7a, 0x0c, 0xfd, 0x50, 0xa2, 0xf5, 0x33, 0x3a, 0x45, 0xbf, + 0x44, 0x66, 0xc0, 0x7d, 0x54, 0x13, 0xd4, 0x98, 0x3a, 0xd3, 0x07, 0x00, 0x4b, 0xbe, 0x5a, 0x84, + 0xdc, 0x4d, 0x66, 0x4b, 0x07, 0xa8, 0x65, 0xc8, 0xd8, 0xa5, 0x1d, 0xa8, 0x70, 0x31, 0x8f, 0xda, + 0x95, 0x4e, 0xb9, 0x5b, 0xef, 0x37, 0x36, 0x2f, 0x8a, 0xa9, 0x88, 0xf9, 0x14, 0x5a, 0xb7, 0x9b, + 0x25, 0xa3, 0x1c, 0x3e, 0x9b, 0x0c, 0xed, 0xf1, 0x74, 0xa2, 0x6b, 0xb4, 0x06, 0x5b, 0x83, 0xa1, + 0x3d, 0x65, 0x3a, 0xa1, 0x00, 0xdb, 0xc7, 0x96, 0xfd, 0x74, 0x3a, 0xd2, 0x4b, 0xf4, 0x7d, 0x68, + 0x9e, 0xda, 0x03, 0x7b, 0x3c, 0x9c, 0x65, 0x50, 0xd9, 0xfc, 0x02, 0xf4, 0x62, 0x02, 0xdb, 0x73, + 0x2e, 0x50, 0xd2, 0x8f, 0xa0, 0x26, 0x50, 0xc6, 0x22, 0x48, 0xd4, 0x11, 0xa5, 0xae, 0x9a, 0x02, + 0x63, 0xd7, 0xec, 0x00, 0x9c, 0xc4, 0x92, 0xe1, 0x4f, 0x31, 0x46, 0x32, 0x7f, 0x04, 0x52, 0x3c, + 0x82, 0xf9, 0x00, 0xea, 0x8a, 0x11, 0x2d, 0xc3, 0x20, 0x42, 0xda, 0x82, 0x52, 0x5e, 0xa6, 0xe4, + 0xb9, 0xe6, 0xc7, 0x00, 0x4f, 0x30, 0x2f, 0xf0, 0x76, 0xf4, 0x1b, 0xa8, 0xab, 0x68, 0x96, 0xbc, + 0x0b, 0x5b, 0x2f, 0xf9, 0x22, 0x63, 0x54, 0x59, 0xfa, 0x91, 0x77, 0x2d, 0x6d, 0x74, 0xf5, 0xa0, + 0xfe, 0x9c, 0x7b, 0x79, 0xdd, 0x4f, 0xa0, 0x1e, 0x9e, 0xfd, 0x88, 0x8e, 0x9c, 0x09, 0x3c, 0x8f, + 0xda, 0xa4, 0x53, 0xee, 0x36, 0x18, 0xa4, 0x10, 0xc3, 0xf3, 0x28, 0x21, 0x04, 0xb1, 0x3f, 0x4b, + 0xe7, 0x8a, 0x54, 0xa9, 0x32, 0x83, 0x20, 0xf6, 0x59, 0x8a, 0xd0, 0x36, 0xec, 0x48, 0xcf, 0xc7, + 0x30, 0x96, 0xea, 0x85, 0x08, 0xbb, 0xf9, 0x34, 0x5f, 0x11, 0x68, 0xa4, 0xbd, 0xfe, 0x57, 0x65, + 0x17, 0x74, 0x81, 0xdc, 0x5d, 0xcd, 0x32, 0x21, 0x9e, 0x9b, 0xb4, 0x49, 0x74, 0xb4, 0x14, 0x3e, + 0x55, 0xf0, 0xd8, 0x8d, 0xe8, 0x3e, 0xec, 0x0a, 0xf4, 0xb9, 0x17, 0x78, 0xc1, 0x7c, 0x93, 0x5d, + 0x56, 0x6c, 0x9a, 0xc7, 0xf2, 0x0c, 0xf3, 0x2f, 0x02, 0xf0, 0x3c, 0x14, 0x17, 0xa7, 0x92, 0xcb, + 0x38, 0xa2, 0x07, 0xb0, 0x1d, 0xa9, 0x53, 0x66, 0x4f, 0x23, 0xf7, 0x4c, 0x41, 0xea, 0xa5, 0x3f, + 0xc3, 0xd0, 0x45, 0x96, 0xb1, 0xe9, 0xa7, 0xd0, 0x74, 0x42, 0x7f, 0xb9, 0x40, 0x89, 0xb3, 0x8d, + 0x1b, 0x6d, 0xdc, 0x80, 0xa3, 0x64, 0xa9, 0x1e, 0xc3, 0x7b, 0xe7, 0x5e, 0xe0, 0x45, 0x2f, 0xd0, + 0x9d, 0x49, 0xe5, 0x10, 0x75, 0x21, 0xf5, 0xfe, 0x87, 0x77, 0x2c, 0x41, 0x6a, 0x21, 0xd6, 0xba, + 0xc9, 0x28, 0x2c, 0x85, 0x42, 0x84, 0x62, 0xe6, 0x47, 0x73, 0xb5, 0x9d, 0x35, 0x56, 0x55, 0xc0, + 0x71, 0x34, 0x37, 0xf7, 0x01, 0x0a, 0x6d, 0x89, 0x93, 0x87, 0xd3, 0xe3, 0x93, 0x23, 0xcb, 0xb6, + 0x52, 0x27, 0x5b, 0x8c, 0x29, 0x27, 0xd7, 0x60, 0x8b, 0x59, 0x83, 0xd1, 0x0f, 0x7a, 0xc9, 0x3c, + 0x83, 0x4a, 0x32, 0x18, 0x7d, 0x04, 0x15, 0xc9, 0xa3, 0x0b, 0x35, 0x75, 0xbd, 0x7f, 0xef, 0x0e, + 0x3d, 0x4c, 0x11, 0xe8, 0x97, 0xb0, 0x9d, 0x49, 0x2f, 0xbd, 0x4b, 0x7a, 0x46, 0xfc, 0xec, 0x1e, + 0x54, 0xd4, 0x66, 0xd5, 0x61, 0x67, 0x64, 0x1d, 0x0e, 0x9e, 0x1d, 0xd9, 0xba, 0xd6, 0xff, 0x9b, + 0x40, 0x83, 0xf1, 0xd5, 0x02, 0xe5, 0x48, 0x78, 0x2f, 0x51, 0xd0, 0x6f, 0xa1, 0xf6, 0x04, 0x65, + 0xfa, 0x30, 0xb4, 0x10, 0x50, 0x38, 0x7c, 0x6f, 0xf7, 0x36, 0x98, 0x5a, 0xc6, 0xd4, 0x92, 0xcc, + 0x93, 0xf8, 0xdf, 0x99, 0xc5, 0x72, 0x6d, 0x64, 0x6e, 0xec, 0x93, 0xa9, 0xd1, 0xef, 0x00, 0x12, + 0xfb, 0x65, 0xa9, 0x05, 0x6b, 0xc3, 0xff, 0x7b, 0xf7, 0xdf, 0x42, 0xf3, 0xe4, 0xef, 0xa1, 0x7a, + 0xea, 0xbc, 0x40, 0x37, 0x5e, 0x20, 0xbd, 0xeb, 0xc2, 0xf6, 0xfe, 0xfb, 0x6a, 0x4c, 0xad, 0x7f, + 0x02, 0x1f, 0xa4, 0xe3, 0x27, 0xd7, 0x8f, 0x62, 0x18, 0x06, 0x01, 0x3a, 0xd2, 0x0b, 0x03, 0x7a, + 0x90, 0x1a, 0x32, 0x92, 0x02, 0xb9, 0xbf, 0x51, 0xb9, 0x30, 0xe0, 0x5e, 0xf3, 0x16, 0x68, 0x6a, + 0x5d, 0xb2, 0x4f, 0x1e, 0x7f, 0x7d, 0x79, 0x65, 0x68, 0xaf, 0xaf, 0x0c, 0xed, 0xcd, 0x95, 0x41, + 0x5e, 0xad, 0x0d, 0xf2, 0xeb, 0xda, 0x20, 0xbf, 0xaf, 0x0d, 0x72, 0xb9, 0x36, 0xc8, 0x1f, 0x6b, + 0x83, 0xfc, 0xb9, 0x36, 0xb4, 0x37, 0x6b, 0x83, 0xfc, 0x72, 0x6d, 0x68, 0x97, 0xd7, 0x86, 0xf6, + 0xfa, 0xda, 0xd0, 0xce, 0xb6, 0xd5, 0x9f, 0xcc, 0x57, 0xff, 0x04, 0x00, 0x00, 0xff, 0xff, 0x3d, + 0xce, 0xb0, 0x31, 0x78, 0x06, 0x00, 0x00, } func (x Type) String() string { @@ -658,6 +815,13 @@ func (x ClientTask_RemoteExecType) String() string { } return strconv.Itoa(int(x)) } +func (x WorkStatus_StatusCode) String() string { + s, ok := WorkStatus_StatusCode_name[int32(x)] + if ok { + return s + } + return strconv.Itoa(int(x)) +} func (this *Arg) Equal(that interface{}) bool { if that == nil { return this == nil @@ -927,6 +1091,66 @@ func (this *WaitResponse) Equal(that interface{}) bool { } return true } +func (this *WorkStatus) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + that1, ok := that.(*WorkStatus) + if !ok { + that2, ok := that.(WorkStatus) + if ok { + that1 = &that2 + } else { + return false + } + } + if that1 == nil { + return this == nil + } else if this == nil { + return false + } + if this.Status != that1.Status { + return false + } + if !bytes.Equal(this.CompleteData, that1.CompleteData) { + return false + } + if !this.FinishedTicket.Equal(that1.FinishedTicket) { + return false + } + if this.ErrorMsg != that1.ErrorMsg { + return false + } + return true +} +func (this *Work) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + that1, ok := that.(*Work) + if !ok { + that2, ok := that.(Work) + if ok { + that1 = &that2 + } else { + return false + } + } + if that1 == nil { + return this == nil + } else if this == nil { + return false + } + if !this.Task.Equal(that1.Task) { + return false + } + if !this.Ticket.Equal(that1.Ticket) { + return false + } + return true +} func (this *Arg) GoString() string { if this == nil { return "nil" @@ -1030,6 +1254,36 @@ func (this *WaitResponse) GoString() string { s = append(s, "}") return strings.Join(s, "") } +func (this *WorkStatus) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 8) + s = append(s, "&ray_rpc.WorkStatus{") + s = append(s, "Status: "+fmt.Sprintf("%#v", this.Status)+",\n") + s = append(s, "CompleteData: "+fmt.Sprintf("%#v", this.CompleteData)+",\n") + if this.FinishedTicket != nil { + s = append(s, "FinishedTicket: "+fmt.Sprintf("%#v", this.FinishedTicket)+",\n") + } + s = append(s, "ErrorMsg: "+fmt.Sprintf("%#v", this.ErrorMsg)+",\n") + s = append(s, "}") + return strings.Join(s, "") +} +func (this *Work) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 6) + s = append(s, "&ray_rpc.Work{") + if this.Task != nil { + s = append(s, "Task: "+fmt.Sprintf("%#v", this.Task)+",\n") + } + if this.Ticket != nil { + s = append(s, "Ticket: "+fmt.Sprintf("%#v", this.Ticket)+",\n") + } + s = append(s, "}") + return strings.Join(s, "") +} func valueToGoStringRayClient(v interface{}, typ string) string { rv := reflect.ValueOf(v) if rv.IsNil() { @@ -1227,6 +1481,110 @@ var _RayletDriver_serviceDesc = grpc.ServiceDesc{ Metadata: "ray_client.proto", } +// RayletWorkerConnectionClient is the client API for RayletWorkerConnection service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type RayletWorkerConnectionClient interface { + Workstream(ctx context.Context, opts ...grpc.CallOption) (RayletWorkerConnection_WorkstreamClient, error) +} + +type rayletWorkerConnectionClient struct { + cc *grpc.ClientConn +} + +func NewRayletWorkerConnectionClient(cc *grpc.ClientConn) RayletWorkerConnectionClient { + return &rayletWorkerConnectionClient{cc} +} + +func (c *rayletWorkerConnectionClient) Workstream(ctx context.Context, opts ...grpc.CallOption) (RayletWorkerConnection_WorkstreamClient, error) { + stream, err := c.cc.NewStream(ctx, &_RayletWorkerConnection_serviceDesc.Streams[0], "/ray.rpc.RayletWorkerConnection/Workstream", opts...) + if err != nil { + return nil, err + } + x := &rayletWorkerConnectionWorkstreamClient{stream} + return x, nil +} + +type RayletWorkerConnection_WorkstreamClient interface { + Send(*WorkStatus) error + Recv() (*Work, error) + grpc.ClientStream +} + +type rayletWorkerConnectionWorkstreamClient struct { + grpc.ClientStream +} + +func (x *rayletWorkerConnectionWorkstreamClient) Send(m *WorkStatus) error { + return x.ClientStream.SendMsg(m) +} + +func (x *rayletWorkerConnectionWorkstreamClient) Recv() (*Work, error) { + m := new(Work) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// RayletWorkerConnectionServer is the server API for RayletWorkerConnection service. +type RayletWorkerConnectionServer interface { + Workstream(RayletWorkerConnection_WorkstreamServer) error +} + +// UnimplementedRayletWorkerConnectionServer can be embedded to have forward compatible implementations. +type UnimplementedRayletWorkerConnectionServer struct { +} + +func (*UnimplementedRayletWorkerConnectionServer) Workstream(srv RayletWorkerConnection_WorkstreamServer) error { + return status.Errorf(codes.Unimplemented, "method Workstream not implemented") +} + +func RegisterRayletWorkerConnectionServer(s *grpc.Server, srv RayletWorkerConnectionServer) { + s.RegisterService(&_RayletWorkerConnection_serviceDesc, srv) +} + +func _RayletWorkerConnection_Workstream_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(RayletWorkerConnectionServer).Workstream(&rayletWorkerConnectionWorkstreamServer{stream}) +} + +type RayletWorkerConnection_WorkstreamServer interface { + Send(*Work) error + Recv() (*WorkStatus, error) + grpc.ServerStream +} + +type rayletWorkerConnectionWorkstreamServer struct { + grpc.ServerStream +} + +func (x *rayletWorkerConnectionWorkstreamServer) Send(m *Work) error { + return x.ServerStream.SendMsg(m) +} + +func (x *rayletWorkerConnectionWorkstreamServer) Recv() (*WorkStatus, error) { + m := new(WorkStatus) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _RayletWorkerConnection_serviceDesc = grpc.ServiceDesc{ + ServiceName: "ray.rpc.RayletWorkerConnection", + HandlerType: (*RayletWorkerConnectionServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Workstream", + Handler: _RayletWorkerConnection_Workstream_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "ray_client.proto", +} + func (m *Arg) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) @@ -1584,6 +1942,107 @@ func (m *WaitResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } +func (m *WorkStatus) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *WorkStatus) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *WorkStatus) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.ErrorMsg) > 0 { + i -= len(m.ErrorMsg) + copy(dAtA[i:], m.ErrorMsg) + i = encodeVarintRayClient(dAtA, i, uint64(len(m.ErrorMsg))) + i-- + dAtA[i] = 0x22 + } + if m.FinishedTicket != nil { + { + size, err := m.FinishedTicket.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintRayClient(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x1a + } + if len(m.CompleteData) > 0 { + i -= len(m.CompleteData) + copy(dAtA[i:], m.CompleteData) + i = encodeVarintRayClient(dAtA, i, uint64(len(m.CompleteData))) + i-- + dAtA[i] = 0x12 + } + if m.Status != 0 { + i = encodeVarintRayClient(dAtA, i, uint64(m.Status)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *Work) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Work) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Work) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Ticket != nil { + { + size, err := m.Ticket.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintRayClient(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x12 + } + if m.Task != nil { + { + size, err := m.Task.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintRayClient(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + func encodeVarintRayClient(dAtA []byte, offset int, v uint64) int { offset -= sovRayClient(v) base := offset @@ -1757,6 +2216,47 @@ func (m *WaitResponse) Size() (n int) { return n } +func (m *WorkStatus) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Status != 0 { + n += 1 + sovRayClient(uint64(m.Status)) + } + l = len(m.CompleteData) + if l > 0 { + n += 1 + l + sovRayClient(uint64(l)) + } + if m.FinishedTicket != nil { + l = m.FinishedTicket.Size() + n += 1 + l + sovRayClient(uint64(l)) + } + l = len(m.ErrorMsg) + if l > 0 { + n += 1 + l + sovRayClient(uint64(l)) + } + return n +} + +func (m *Work) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Task != nil { + l = m.Task.Size() + n += 1 + l + sovRayClient(uint64(l)) + } + if m.Ticket != nil { + l = m.Ticket.Size() + n += 1 + l + sovRayClient(uint64(l)) + } + return n +} + func sovRayClient(x uint64) (n int) { return (math_bits.Len64(x|1) + 6) / 7 } @@ -1869,6 +2369,30 @@ func (this *WaitResponse) String() string { }, "") return s } +func (this *WorkStatus) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&WorkStatus{`, + `Status:` + fmt.Sprintf("%v", this.Status) + `,`, + `CompleteData:` + fmt.Sprintf("%v", this.CompleteData) + `,`, + `FinishedTicket:` + strings.Replace(this.FinishedTicket.String(), "ClientTaskTicket", "ClientTaskTicket", 1) + `,`, + `ErrorMsg:` + fmt.Sprintf("%v", this.ErrorMsg) + `,`, + `}`, + }, "") + return s +} +func (this *Work) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&Work{`, + `Task:` + strings.Replace(this.Task.String(), "ClientTask", "ClientTask", 1) + `,`, + `Ticket:` + strings.Replace(this.Ticket.String(), "ClientTaskTicket", "ClientTaskTicket", 1) + `,`, + `}`, + }, "") + return s +} func valueToStringRayClient(v interface{}) string { rv := reflect.ValueOf(v) if rv.IsNil() { @@ -2915,6 +3439,305 @@ func (m *WaitResponse) Unmarshal(dAtA []byte) error { } return nil } +func (m *WorkStatus) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRayClient + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: WorkStatus: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: WorkStatus: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Status", wireType) + } + m.Status = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRayClient + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Status |= WorkStatus_StatusCode(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field CompleteData", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRayClient + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthRayClient + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthRayClient + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.CompleteData = append(m.CompleteData[:0], dAtA[iNdEx:postIndex]...) + if m.CompleteData == nil { + m.CompleteData = []byte{} + } + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field FinishedTicket", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRayClient + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthRayClient + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthRayClient + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.FinishedTicket == nil { + m.FinishedTicket = &ClientTaskTicket{} + } + if err := m.FinishedTicket.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ErrorMsg", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRayClient + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthRayClient + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthRayClient + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ErrorMsg = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipRayClient(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthRayClient + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthRayClient + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Work) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRayClient + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Work: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Work: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Task", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRayClient + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthRayClient + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthRayClient + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Task == nil { + m.Task = &ClientTask{} + } + if err := m.Task.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Ticket", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRayClient + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthRayClient + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthRayClient + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Ticket == nil { + m.Ticket = &ClientTaskTicket{} + } + if err := m.Ticket.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipRayClient(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthRayClient + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthRayClient + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} func skipRayClient(dAtA []byte) (n int, err error) { l := len(dAtA) iNdEx := 0 diff --git a/ray_rpc/ray_client.proto b/ray_rpc/ray_client.proto index fd8fe53..384ffe1 100644 --- a/ray_rpc/ray_client.proto +++ b/ray_rpc/ray_client.proto @@ -84,3 +84,24 @@ service RayletDriver { rpc Schedule(ClientTask) returns (ClientTaskTicket) { } } + +service RayletWorkerConnection { + rpc Workstream(stream WorkStatus) returns (stream Work) {} +} + +message WorkStatus { + enum StatusCode { + COMPLETE = 0; + ERROR = 1; + READY = 2; + } + StatusCode status = 1; + bytes complete_data = 2; + ClientTaskTicket finished_ticket = 3; + string error_msg = 4; +} + +message Work { + ClientTask task = 1; + ClientTaskTicket ticket = 2; +} diff --git a/raylet_grpc.go b/raylet_grpc.go new file mode 100644 index 0000000..e6bd119 --- /dev/null +++ b/raylet_grpc.go @@ -0,0 +1,65 @@ +package main + +import ( + "context" + + "github.com/barakmich/go_raylet/ray_rpc" +) + +type Raylet struct { + Objects ObjectStore + Workers WorkerPool +} + +func (r *Raylet) GetObject(_ context.Context, req *ray_rpc.GetRequest) (*ray_rpc.GetResponse, error) { + data, err := GetObject(r.Objects, deserializeObjectID(req.Id)) + if err != nil { + return nil, err + } + return &ray_rpc.GetResponse{ + Valid: true, + Data: data, + }, nil +} + +func (r *Raylet) PutObject(_ context.Context, req *ray_rpc.PutRequest) (*ray_rpc.PutResponse, error) { + id := r.Objects.MakeID() + err := r.Objects.PutObject(&Object{id, req.Data}) + if err != nil { + return nil, err + } + return &ray_rpc.PutResponse{ + Id: serializeObjectID(id), + }, nil +} + +func (r *Raylet) WaitObject(_ context.Context, _ *ray_rpc.WaitRequest) (*ray_rpc.WaitResponse, error) { + panic("not implemented") // TODO: Implement +} + +func (r *Raylet) Schedule(_ context.Context, task *ray_rpc.ClientTask) (*ray_rpc.ClientTaskTicket, error) { + id := r.Objects.MakeID() + ticket := &ray_rpc.ClientTaskTicket{serializeObjectID(id)} + work := &ray_rpc.Work{} + work.Task = task + work.Ticket = ticket + err := r.Workers.Schedule(work) + if err != nil { + return nil, err + } + return ticket, nil +} + +func (r *Raylet) Workstream(conn WorkstreamConnection) error { + return r.Workers.Workstream(conn) +} + +//func (r *Raylet) Workstream() + +func NewMemRaylet() *Raylet { + store := NewMemObjectStore() + return &Raylet{ + Objects: store, + Workers: NewRoundRobinWorkerPool(store), + } +} diff --git a/servicer.go b/servicer.go deleted file mode 100644 index bab6114..0000000 --- a/servicer.go +++ /dev/null @@ -1,33 +0,0 @@ -package main - -import ( - "context" - - "github.com/barakmich/go_raylet/ray_rpc" -) - -type Raylet struct { - Objects ObjectStore -} - -func (r *Raylet) GetObject(_ context.Context, req *ray_rpc.GetRequest) (*ray_rpc.GetResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (r *Raylet) PutObject(_ context.Context, _ *ray_rpc.PutRequest) (*ray_rpc.PutResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (r *Raylet) WaitObject(_ context.Context, _ *ray_rpc.WaitRequest) (*ray_rpc.WaitResponse, error) { - panic("not implemented") // TODO: Implement -} - -func (r *Raylet) Schedule(_ context.Context, _ *ray_rpc.ClientTask) (*ray_rpc.ClientTaskTicket, error) { - panic("not implemented") // TODO: Implement -} - -func NewMemRaylet() *Raylet { - return &Raylet{ - Objects: NewMemObjectStore(), - } -} diff --git a/worker_pool.go b/worker_pool.go new file mode 100644 index 0000000..653a4e6 --- /dev/null +++ b/worker_pool.go @@ -0,0 +1,134 @@ +package main + +import ( + "errors" + "fmt" + "sync" + + "github.com/barakmich/go_raylet/ray_rpc" +) + +type WorkstreamConnection = ray_rpc.RayletWorkerConnection_WorkstreamServer + +type WorkerPool interface { + ray_rpc.RayletWorkerConnectionServer + Schedule(*ray_rpc.Work) error + Close() error + Finish(*ray_rpc.WorkStatus) error + Deregister(interface{}) error +} + +type SimpleRRWorkerPool struct { + sync.Mutex + workers []*SimpleWorker + store ObjectStore + offset int +} + +type SimpleWorker struct { + workChan chan *ray_rpc.Work + clientConn WorkstreamConnection + pool WorkerPool +} + +func NewRoundRobinWorkerPool(obj ObjectStore) *SimpleRRWorkerPool { + return &SimpleRRWorkerPool{ + store: obj, + } +} + +func (wp *SimpleRRWorkerPool) Workstream(conn WorkstreamConnection) error { + wp.Lock() + defer wp.Unlock() + worker := &SimpleWorker{ + workChan: make(chan *ray_rpc.Work, 10), + clientConn: conn, + pool: wp, + } + go worker.Main() + wp.workers = append(wp.workers, worker) + return nil +} + +func (wp *SimpleRRWorkerPool) Schedule(work *ray_rpc.Work) error { + wp.Lock() + defer wp.Unlock() + if len(wp.workers) == 0 { + return errors.New("No workers available, try again later") + } + wp.workers[wp.offset].workChan <- work + wp.offset++ + if wp.offset == len(wp.workers) { + wp.offset = 0 + } + return nil +} + +func (wp *SimpleRRWorkerPool) Finish(status *ray_rpc.WorkStatus) error { + if status.Status != ray_rpc.COMPLETE { + panic("todo: Only call Finish on successfully completed work") + } + id := deserializeObjectID(status.FinishedTicket.ReturnId) + return wp.store.PutObject(&Object{id, status.CompleteData}) +} + +func (wp *SimpleRRWorkerPool) Close() error { + wp.Lock() + defer wp.Unlock() + for _, w := range wp.workers { + close(w.workChan) + } + return nil +} + +func (wp *SimpleRRWorkerPool) Deregister(ptr interface{}) error { + wp.Lock() + defer wp.Unlock() + worker := ptr.(*SimpleWorker) + found := false + for i, w := range wp.workers { + if w == worker { + wp.workers = append(wp.workers[:i], wp.workers[i+1:]...) + if wp.offset == len(wp.workers) { + wp.offset = 0 + } + found = true + } + } + if !found { + panic("Trying to deregister a worker that was never created") + } + return nil +} + +func (w *SimpleWorker) Main() { + sentinel, err := w.clientConn.Recv() + if err != nil { + fmt.Println(err) + w.pool.Deregister(w) + return + } + if sentinel.Status != ray_rpc.READY { + fmt.Println("Sent wrong sentinel? Closing...") + w.pool.Deregister(w) + return + } + for { + work, ok := <-w.workChan + if !ok { + break + } + err = w.clientConn.Send(work) + if err != nil { + fmt.Println("Error sending: %s", err) + break + } + result, err := w.clientConn.Recv() + err = w.pool.Finish(result) + if err != nil { + fmt.Println("Error finishing: %s", err) + break + } + } + w.pool.Deregister(w) +}