implement through workers
This commit is contained in:
parent
39385bc8b2
commit
46524832de
13 changed files with 2213 additions and 91 deletions
1
go.mod
1
go.mod
|
|
@ -4,5 +4,6 @@ go 1.15
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/gogo/protobuf v1.3.1
|
github.com/gogo/protobuf v1.3.1
|
||||||
|
github.com/google/uuid v1.1.2
|
||||||
google.golang.org/grpc v1.34.0
|
google.golang.org/grpc v1.34.0
|
||||||
)
|
)
|
||||||
|
|
|
||||||
1
go.sum
1
go.sum
|
|
@ -27,6 +27,7 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
|
||||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
|
||||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
|
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
|
||||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
|
|
|
||||||
4
main.go
4
main.go
|
|
@ -35,7 +35,9 @@ func main() {
|
||||||
s := grpc.NewServer(
|
s := grpc.NewServer(
|
||||||
grpc.Creds(insecure.NewCredentials()),
|
grpc.Creds(insecure.NewCredentials()),
|
||||||
)
|
)
|
||||||
ray_rpc.RegisterRayletDriverServer(s, NewRayletServicer())
|
server := NewMemRaylet()
|
||||||
|
ray_rpc.RegisterRayletDriverServer(s, server)
|
||||||
|
ray_rpc.RegisterRayletWorkerConnectionServer(s, server)
|
||||||
|
|
||||||
// Serve gRPC Server
|
// Serve gRPC Server
|
||||||
log.Info("Serving gRPC on https://", addr)
|
log.Info("Serving gRPC on https://", addr)
|
||||||
|
|
|
||||||
53
object.go
53
object.go
|
|
@ -1,18 +1,20 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
|
||||||
|
|
||||||
type ObjectID int64
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
var ErrObjectNotFound = errors.New("object not found")
|
var ErrObjectNotFound = errors.New("object not found")
|
||||||
|
|
||||||
type ObjectStore interface {
|
type ObjectStore interface {
|
||||||
AwaitObjects(ids []ObjectID, c chan ObjectResult, timeout *time.Duration)
|
AwaitObjects(ids []ObjectID, c chan ObjectResult, timeout *time.Duration)
|
||||||
MakeObject(data []byte) (ObjectID, error)
|
PutObject(object *Object) error
|
||||||
|
MakeID() ObjectID
|
||||||
}
|
}
|
||||||
|
|
||||||
type ObjectResult struct {
|
type ObjectResult struct {
|
||||||
|
|
@ -32,7 +34,7 @@ func GetObject(s ObjectStore, id ObjectID) ([]byte, error) {
|
||||||
go func() {
|
go func() {
|
||||||
s.AwaitObjects(ids, c, nil)
|
s.AwaitObjects(ids, c, nil)
|
||||||
}()
|
}()
|
||||||
obj := <-c
|
obj, ok := <-c
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Couldn't get object")
|
return nil, errors.New("Couldn't get object")
|
||||||
}
|
}
|
||||||
|
|
@ -44,10 +46,12 @@ func GetObject(s ObjectStore, id ObjectID) ([]byte, error) {
|
||||||
|
|
||||||
type MemObjectStore struct {
|
type MemObjectStore struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
db map[ObjectID][]byte
|
db map[ObjectID][]byte
|
||||||
|
prefix uint64
|
||||||
|
printer uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mem *MemObjectStore) AwaitObjects(ids []ObjectID, c chan Object, timeout *time.Duration) {
|
func (mem *MemObjectStore) AwaitObjects(ids []ObjectID, c chan ObjectResult, timeout *time.Duration) {
|
||||||
if timeout != nil {
|
if timeout != nil {
|
||||||
panic("timeout not yet implemented")
|
panic("timeout not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
@ -58,22 +62,47 @@ func (mem *MemObjectStore) AwaitObjects(ids []ObjectID, c chan Object, timeout *
|
||||||
if !ok {
|
if !ok {
|
||||||
c <- ObjectResult{Error: ErrObjectNotFound}
|
c <- ObjectResult{Error: ErrObjectNotFound}
|
||||||
} else {
|
} else {
|
||||||
c <- ObjectResult{&Object{ID: id, Data: v}}
|
c <- ObjectResult{&Object{ID: id, Data: v}, nil}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
close(c)
|
close(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mem *MemObjectStore) MakeObject(data []byte) (ObjectID, error) {
|
func (mem *MemObjectStore) PutObject(object *Object) error {
|
||||||
mem.Lock()
|
mem.Lock()
|
||||||
defer mem.Unlock()
|
defer mem.Unlock()
|
||||||
id := mem.makeID()
|
mem.db[object.ID] = object.Data
|
||||||
mem.db[id] = data
|
return nil
|
||||||
return id, nil
|
}
|
||||||
|
|
||||||
|
func (mem *MemObjectStore) MakeID() ObjectID {
|
||||||
|
id := mem.prefix + mem.printer
|
||||||
|
mem.prefix++
|
||||||
|
return ObjectID(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMemObjectStore() *MemObjectStore {
|
func NewMemObjectStore() *MemObjectStore {
|
||||||
|
prefixUuid, err := uuid.NewRandom()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
prefix := uint64(prefixUuid.ID()) << 32
|
||||||
return &MemObjectStore{
|
return &MemObjectStore{
|
||||||
db: make(map[ObjectID][]byte),
|
db: make(map[ObjectID][]byte),
|
||||||
|
prefix: prefix,
|
||||||
|
printer: 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ObjectID uint64
|
||||||
|
|
||||||
|
func serializeObjectID(id ObjectID) []byte {
|
||||||
|
out := make([]byte, 8)
|
||||||
|
binary.BigEndian.PutUint64(out, uint64(id))
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func deserializeObjectID(id_bytes []byte) ObjectID {
|
||||||
|
id := binary.BigEndian.Uint64(id_bytes)
|
||||||
|
return ObjectID(id)
|
||||||
|
}
|
||||||
|
|
|
||||||
1
python/rpc/mk_proto.sh
Executable file
1
python/rpc/mk_proto.sh
Executable file
|
|
@ -0,0 +1 @@
|
||||||
|
python -m grpc_tools.protoc -I../../ray_rpc --python_out=. --grpc_python_out=. ../../ray_rpc/ray_client.proto
|
||||||
785
python/rpc/ray_client_pb2.py
Normal file
785
python/rpc/ray_client_pb2.py
Normal file
File diff suppressed because it is too large
Load diff
226
python/rpc/ray_client_pb2_grpc.py
Normal file
226
python/rpc/ray_client_pb2_grpc.py
Normal file
|
|
@ -0,0 +1,226 @@
|
||||||
|
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||||
|
"""Client and server classes corresponding to protobuf-defined services."""
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
import ray_client_pb2 as ray__client__pb2
|
||||||
|
|
||||||
|
|
||||||
|
class RayletDriverStub(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: A grpc.Channel.
|
||||||
|
"""
|
||||||
|
self.GetObject = channel.unary_unary(
|
||||||
|
'/ray.rpc.RayletDriver/GetObject',
|
||||||
|
request_serializer=ray__client__pb2.GetRequest.SerializeToString,
|
||||||
|
response_deserializer=ray__client__pb2.GetResponse.FromString,
|
||||||
|
)
|
||||||
|
self.PutObject = channel.unary_unary(
|
||||||
|
'/ray.rpc.RayletDriver/PutObject',
|
||||||
|
request_serializer=ray__client__pb2.PutRequest.SerializeToString,
|
||||||
|
response_deserializer=ray__client__pb2.PutResponse.FromString,
|
||||||
|
)
|
||||||
|
self.WaitObject = channel.unary_unary(
|
||||||
|
'/ray.rpc.RayletDriver/WaitObject',
|
||||||
|
request_serializer=ray__client__pb2.WaitRequest.SerializeToString,
|
||||||
|
response_deserializer=ray__client__pb2.WaitResponse.FromString,
|
||||||
|
)
|
||||||
|
self.Schedule = channel.unary_unary(
|
||||||
|
'/ray.rpc.RayletDriver/Schedule',
|
||||||
|
request_serializer=ray__client__pb2.ClientTask.SerializeToString,
|
||||||
|
response_deserializer=ray__client__pb2.ClientTaskTicket.FromString,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RayletDriverServicer(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
def GetObject(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def PutObject(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def WaitObject(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def Schedule(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
|
||||||
|
def add_RayletDriverServicer_to_server(servicer, server):
|
||||||
|
rpc_method_handlers = {
|
||||||
|
'GetObject': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.GetObject,
|
||||||
|
request_deserializer=ray__client__pb2.GetRequest.FromString,
|
||||||
|
response_serializer=ray__client__pb2.GetResponse.SerializeToString,
|
||||||
|
),
|
||||||
|
'PutObject': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.PutObject,
|
||||||
|
request_deserializer=ray__client__pb2.PutRequest.FromString,
|
||||||
|
response_serializer=ray__client__pb2.PutResponse.SerializeToString,
|
||||||
|
),
|
||||||
|
'WaitObject': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.WaitObject,
|
||||||
|
request_deserializer=ray__client__pb2.WaitRequest.FromString,
|
||||||
|
response_serializer=ray__client__pb2.WaitResponse.SerializeToString,
|
||||||
|
),
|
||||||
|
'Schedule': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.Schedule,
|
||||||
|
request_deserializer=ray__client__pb2.ClientTask.FromString,
|
||||||
|
response_serializer=ray__client__pb2.ClientTaskTicket.SerializeToString,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
|
'ray.rpc.RayletDriver', rpc_method_handlers)
|
||||||
|
server.add_generic_rpc_handlers((generic_handler,))
|
||||||
|
|
||||||
|
|
||||||
|
# This class is part of an EXPERIMENTAL API.
|
||||||
|
class RayletDriver(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def GetObject(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/ray.rpc.RayletDriver/GetObject',
|
||||||
|
ray__client__pb2.GetRequest.SerializeToString,
|
||||||
|
ray__client__pb2.GetResponse.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def PutObject(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/ray.rpc.RayletDriver/PutObject',
|
||||||
|
ray__client__pb2.PutRequest.SerializeToString,
|
||||||
|
ray__client__pb2.PutResponse.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def WaitObject(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/ray.rpc.RayletDriver/WaitObject',
|
||||||
|
ray__client__pb2.WaitRequest.SerializeToString,
|
||||||
|
ray__client__pb2.WaitResponse.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Schedule(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/ray.rpc.RayletDriver/Schedule',
|
||||||
|
ray__client__pb2.ClientTask.SerializeToString,
|
||||||
|
ray__client__pb2.ClientTaskTicket.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
class RayletWorkerConnectionStub(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: A grpc.Channel.
|
||||||
|
"""
|
||||||
|
self.Workstream = channel.stream_stream(
|
||||||
|
'/ray.rpc.RayletWorkerConnection/Workstream',
|
||||||
|
request_serializer=ray__client__pb2.WorkStatus.SerializeToString,
|
||||||
|
response_deserializer=ray__client__pb2.Work.FromString,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RayletWorkerConnectionServicer(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
def Workstream(self, request_iterator, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
|
||||||
|
def add_RayletWorkerConnectionServicer_to_server(servicer, server):
|
||||||
|
rpc_method_handlers = {
|
||||||
|
'Workstream': grpc.stream_stream_rpc_method_handler(
|
||||||
|
servicer.Workstream,
|
||||||
|
request_deserializer=ray__client__pb2.WorkStatus.FromString,
|
||||||
|
response_serializer=ray__client__pb2.Work.SerializeToString,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
|
'ray.rpc.RayletWorkerConnection', rpc_method_handlers)
|
||||||
|
server.add_generic_rpc_handlers((generic_handler,))
|
||||||
|
|
||||||
|
|
||||||
|
# This class is part of an EXPERIMENTAL API.
|
||||||
|
class RayletWorkerConnection(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Workstream(request_iterator,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.stream_stream(request_iterator, target, '/ray.rpc.RayletWorkerConnection/Workstream',
|
||||||
|
ray__client__pb2.WorkStatus.SerializeToString,
|
||||||
|
ray__client__pb2.Work.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
67
python/worker.py
Normal file
67
python/worker.py
Normal file
|
|
@ -0,0 +1,67 @@
|
||||||
|
from rpc import ray_client_pb2
|
||||||
|
from rpc import ray_client_pb2_grpc
|
||||||
|
import queue
|
||||||
|
|
||||||
|
|
||||||
|
class Worker:
|
||||||
|
def __init__(self, conn_str):
|
||||||
|
self.channel = grpc.insecure_channel(conn_str)
|
||||||
|
self.worker_stub = ray_client_pb2_grpc.RayletWorkerConnectionStub(self.channel)
|
||||||
|
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||||
|
|
||||||
|
def begin(self):
|
||||||
|
send_queue = queue.SimpleQueue()
|
||||||
|
work_stream = self.worker_stub.Workstream(iter(send_queue.get, None))
|
||||||
|
start = ray_client_pb2.WorkStatus(
|
||||||
|
status = ray_client_pb2.WorkStatus.StatusCode.READY
|
||||||
|
)
|
||||||
|
send_queue.put(start)
|
||||||
|
for work in work_stream:
|
||||||
|
task = work.task
|
||||||
|
if task.type != ray_client_pb2.ClientTask.RemoteExecType.FUNCTION:
|
||||||
|
send_queue.put(ray_client_pb2.WorkStatus(
|
||||||
|
status = ray_client_pb2.WorkStatus.StatusCode.ERROR,
|
||||||
|
error_msg = "unimplemented",
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
args = self.decode_args(task)
|
||||||
|
func_data = self.get(task.payload_id)
|
||||||
|
func = cloudpickle.loads(func_data)
|
||||||
|
res = func(*args)
|
||||||
|
out_data = cloudpickle.dumps(res)
|
||||||
|
send_queue.put(ray_client_pb2.WorkStatus(
|
||||||
|
status = ray_client_pb2.WorkStatus.StatusCode.COMPLETE,
|
||||||
|
complete_data = out_data,
|
||||||
|
finished_ticket = work.ticket,
|
||||||
|
))
|
||||||
|
|
||||||
|
def get(self, id_bytes):
|
||||||
|
data = self.server.GetObject(ray_client_pb2.GetRequest(
|
||||||
|
id = id_bytes,
|
||||||
|
))
|
||||||
|
return data.data
|
||||||
|
|
||||||
|
def decode_args(self, task):
    """Decode every argument of ``task`` into a Python value.

    Each argument is passed through ``self.convert_from_arg`` and the
    results are returned as a list in the same order.
    """
    # The previous version iterated an undefined name (`arg_list`) and so
    # raised NameError on the first task.
    # NOTE(review): assumes the task message exposes its arguments as
    # `task.args` -- confirm against ray_client.proto.
    return [self.convert_from_arg(arg) for arg in task.args]
|
||||||
|
|
||||||
|
def convert_from_arg(self, pb):
|
||||||
|
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
|
||||||
|
return self.get(pb.reference_id)
|
||||||
|
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
|
||||||
|
return cloudpickle.loads(pb.data)
|
||||||
|
raise Exception("convert_from_arg: Uncovered locality enum")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run a worker against the address given as the first CLI argument.

    The gRPC channel is closed on the way out even if the work loop raises.
    """
    # Imported locally: the module header does not import sys, and the
    # previous `os.args[1]` referenced a name that does not exist.
    import sys

    if len(sys.argv) < 2:
        sys.exit("usage: worker.py <address>")

    worker = Worker(sys.argv[1])
    try:
        worker.begin()
    finally:
        print("Shutting down...")
        worker.channel.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -84,3 +84,24 @@ service RayletDriver {
|
||||||
rpc Schedule(ClientTask) returns (ClientTaskTicket) {
|
rpc Schedule(ClientTask) returns (ClientTaskTicket) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
service RayletWorkerConnection {
|
||||||
|
rpc Workstream(stream WorkStatus) returns (stream Work) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
message WorkStatus {
|
||||||
|
enum StatusCode {
|
||||||
|
COMPLETE = 0;
|
||||||
|
ERROR = 1;
|
||||||
|
READY = 2;
|
||||||
|
}
|
||||||
|
StatusCode status = 1;
|
||||||
|
bytes complete_data = 2;
|
||||||
|
ClientTaskTicket finished_ticket = 3;
|
||||||
|
string error_msg = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Work {
|
||||||
|
ClientTask task = 1;
|
||||||
|
ClientTaskTicket ticket = 2;
|
||||||
|
}
|
||||||
|
|
|
||||||
65
raylet_grpc.go
Normal file
65
raylet_grpc.go
Normal file
|
|
@ -0,0 +1,65 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/barakmich/go_raylet/ray_rpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Raylet struct {
|
||||||
|
Objects ObjectStore
|
||||||
|
Workers WorkerPool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Raylet) GetObject(_ context.Context, req *ray_rpc.GetRequest) (*ray_rpc.GetResponse, error) {
|
||||||
|
data, err := GetObject(r.Objects, deserializeObjectID(req.Id))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &ray_rpc.GetResponse{
|
||||||
|
Valid: true,
|
||||||
|
Data: data,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Raylet) PutObject(_ context.Context, req *ray_rpc.PutRequest) (*ray_rpc.PutResponse, error) {
|
||||||
|
id := r.Objects.MakeID()
|
||||||
|
err := r.Objects.PutObject(&Object{id, req.Data})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &ray_rpc.PutResponse{
|
||||||
|
Id: serializeObjectID(id),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Raylet) WaitObject(_ context.Context, _ *ray_rpc.WaitRequest) (*ray_rpc.WaitResponse, error) {
|
||||||
|
panic("not implemented") // TODO: Implement
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Raylet) Schedule(_ context.Context, task *ray_rpc.ClientTask) (*ray_rpc.ClientTaskTicket, error) {
|
||||||
|
id := r.Objects.MakeID()
|
||||||
|
ticket := &ray_rpc.ClientTaskTicket{serializeObjectID(id)}
|
||||||
|
work := &ray_rpc.Work{}
|
||||||
|
work.Task = task
|
||||||
|
work.Ticket = ticket
|
||||||
|
err := r.Workers.Schedule(work)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ticket, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Raylet) Workstream(conn WorkstreamConnection) error {
|
||||||
|
return r.Workers.Workstream(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
//func (r *Raylet) Workstream()
|
||||||
|
|
||||||
|
func NewMemRaylet() *Raylet {
|
||||||
|
store := NewMemObjectStore()
|
||||||
|
return &Raylet{
|
||||||
|
Objects: store,
|
||||||
|
Workers: NewRoundRobinWorkerPool(store),
|
||||||
|
}
|
||||||
|
}
|
||||||
33
servicer.go
33
servicer.go
|
|
@ -1,33 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/barakmich/go_raylet/ray_rpc"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Raylet struct {
|
|
||||||
Objects ObjectStore
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Raylet) GetObject(_ context.Context, req *ray_rpc.GetRequest) (*ray_rpc.GetResponse, error) {
|
|
||||||
panic("not implemented") // TODO: Implement
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Raylet) PutObject(_ context.Context, _ *ray_rpc.PutRequest) (*ray_rpc.PutResponse, error) {
|
|
||||||
panic("not implemented") // TODO: Implement
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Raylet) WaitObject(_ context.Context, _ *ray_rpc.WaitRequest) (*ray_rpc.WaitResponse, error) {
|
|
||||||
panic("not implemented") // TODO: Implement
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Raylet) Schedule(_ context.Context, _ *ray_rpc.ClientTask) (*ray_rpc.ClientTaskTicket, error) {
|
|
||||||
panic("not implemented") // TODO: Implement
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMemRaylet() *Raylet {
|
|
||||||
return &Raylet{
|
|
||||||
Objects: NewMemObjectStore(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
134
worker_pool.go
Normal file
134
worker_pool.go
Normal file
|
|
@ -0,0 +1,134 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/barakmich/go_raylet/ray_rpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WorkstreamConnection = ray_rpc.RayletWorkerConnection_WorkstreamServer
|
||||||
|
|
||||||
|
type WorkerPool interface {
|
||||||
|
ray_rpc.RayletWorkerConnectionServer
|
||||||
|
Schedule(*ray_rpc.Work) error
|
||||||
|
Close() error
|
||||||
|
Finish(*ray_rpc.WorkStatus) error
|
||||||
|
Deregister(interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type SimpleRRWorkerPool struct {
|
||||||
|
sync.Mutex
|
||||||
|
workers []*SimpleWorker
|
||||||
|
store ObjectStore
|
||||||
|
offset int
|
||||||
|
}
|
||||||
|
|
||||||
|
type SimpleWorker struct {
|
||||||
|
workChan chan *ray_rpc.Work
|
||||||
|
clientConn WorkstreamConnection
|
||||||
|
pool WorkerPool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRoundRobinWorkerPool(obj ObjectStore) *SimpleRRWorkerPool {
|
||||||
|
return &SimpleRRWorkerPool{
|
||||||
|
store: obj,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wp *SimpleRRWorkerPool) Workstream(conn WorkstreamConnection) error {
|
||||||
|
wp.Lock()
|
||||||
|
defer wp.Unlock()
|
||||||
|
worker := &SimpleWorker{
|
||||||
|
workChan: make(chan *ray_rpc.Work, 10),
|
||||||
|
clientConn: conn,
|
||||||
|
pool: wp,
|
||||||
|
}
|
||||||
|
go worker.Main()
|
||||||
|
wp.workers = append(wp.workers, worker)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wp *SimpleRRWorkerPool) Schedule(work *ray_rpc.Work) error {
|
||||||
|
wp.Lock()
|
||||||
|
defer wp.Unlock()
|
||||||
|
if len(wp.workers) == 0 {
|
||||||
|
return errors.New("No workers available, try again later")
|
||||||
|
}
|
||||||
|
wp.workers[wp.offset].workChan <- work
|
||||||
|
wp.offset++
|
||||||
|
if wp.offset == len(wp.workers) {
|
||||||
|
wp.offset = 0
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wp *SimpleRRWorkerPool) Finish(status *ray_rpc.WorkStatus) error {
|
||||||
|
if status.Status != ray_rpc.COMPLETE {
|
||||||
|
panic("todo: Only call Finish on successfully completed work")
|
||||||
|
}
|
||||||
|
id := deserializeObjectID(status.FinishedTicket.ReturnId)
|
||||||
|
return wp.store.PutObject(&Object{id, status.CompleteData})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wp *SimpleRRWorkerPool) Close() error {
|
||||||
|
wp.Lock()
|
||||||
|
defer wp.Unlock()
|
||||||
|
for _, w := range wp.workers {
|
||||||
|
close(w.workChan)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wp *SimpleRRWorkerPool) Deregister(ptr interface{}) error {
|
||||||
|
wp.Lock()
|
||||||
|
defer wp.Unlock()
|
||||||
|
worker := ptr.(*SimpleWorker)
|
||||||
|
found := false
|
||||||
|
for i, w := range wp.workers {
|
||||||
|
if w == worker {
|
||||||
|
wp.workers = append(wp.workers[:i], wp.workers[i+1:]...)
|
||||||
|
if wp.offset == len(wp.workers) {
|
||||||
|
wp.offset = 0
|
||||||
|
}
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
panic("Trying to deregister a worker that was never created")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *SimpleWorker) Main() {
|
||||||
|
sentinel, err := w.clientConn.Recv()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
w.pool.Deregister(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if sentinel.Status != ray_rpc.READY {
|
||||||
|
fmt.Println("Sent wrong sentinel? Closing...")
|
||||||
|
w.pool.Deregister(w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
work, ok := <-w.workChan
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = w.clientConn.Send(work)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error sending: %s", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
result, err := w.clientConn.Recv()
|
||||||
|
err = w.pool.Finish(result)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error finishing: %s", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.pool.Deregister(w)
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue