tinkerbell/object.go

167 lines
3 KiB
Go

package main
import (
"encoding/binary"
"errors"
"sync"
"time"
"github.com/google/uuid"
)
var (
	// ErrObjectNotFound is a sentinel error whose message is
	// "object not found". Compare against it with errors.Is.
	ErrObjectNotFound = errors.New("object not found")
)
// ObjectStore is the storage abstraction consumed by GetObject: it hands
// out IDs, stores objects and delivers stored objects to waiters.
type ObjectStore interface {
	// MakeID returns an ObjectID for a new object. Whether IDs are unique
	// is up to the implementation.
	MakeID() ObjectID

	// PutObject stores object, reporting failure through the returned error.
	PutObject(object *Object) error

	// AwaitObjects sends ObjectResults for the requested ids on c. GetObject
	// passes a nil timeout and treats a closed c as "no result", so
	// implementations presumably close c when finished — confirm per
	// implementation.
	AwaitObjects(ids []ObjectID, c chan ObjectResult, timeout *time.Duration)
}
type (
	// ObjectResult is one value delivered on the channel given to
	// ObjectStore.AwaitObjects. GetObject treats a non-nil Error as failure
	// and otherwise reads Object.Data.
	ObjectResult struct {
		Object *Object
		Error  error
	}

	// Object is a byte payload (Data) stored under an ObjectID (ID).
	Object struct {
		ID   ObjectID
		Data []byte
	}
)
// GetObject returns the Data of the object stored under id in s.
//
// It runs s.AwaitObjects for the single id with no timeout in its own
// goroutine, so that this function can receive from the result channel,
// and then waits for the first value on that channel. The wait can last
// indefinitely if the store never delivers a result or closes the channel.
//
// It returns an error when:
//   - the channel is closed without a result;
//   - the result carries a non-nil Error (returned unchanged);
//   - the result holds neither an Error nor an Object.
func GetObject(s ObjectStore, id ObjectID) ([]byte, error) {
	results := make(chan ObjectResult)
	go s.AwaitObjects([]ObjectID{id}, results, nil)

	res, ok := <-results
	if !ok {
		// NOTE(review): consider returning ErrObjectNotFound here so callers
		// can test for it with errors.Is.
		return nil, errors.New("couldn't get object")
	}
	if res.Error != nil {
		return nil, res.Error
	}
	if res.Object == nil {
		// An empty result would otherwise panic on the dereference below.
		return nil, errors.New("couldn't get object")
	}
	return res.Object.Data, nil
}
// MemObjectStore is a map-backed, in-memory store; *MemObjectStore has the
// AwaitObjects, PutObject and MakeID methods of ObjectStore.
//
// The embedded RWMutex is locked around accesses to db and subscribers.
// The visible methods only ever take the exclusive (write) lock. Build
// instances with NewMemObjectStore: PutObject writes into db, which is nil
// in the zero value.
type MemObjectStore struct {
	sync.RWMutex
	// db maps each stored ID to the Data it was stored with.
	db map[ObjectID][]byte
	// prefix and printer feed MakeID, which returns prefix+printer and then
	// increments prefix.
	prefix uint64
	printer uint64
	// subscribers receive every object passed to PutObject; AwaitObjects
	// registers one here while it waits for missing IDs.
	subscribers []chan *Object
}
// AwaitObjects sends one ObjectResult on c per entry in ids and then closes c.
//
// IDs already in the store are delivered first, in the order they appear in
// ids. For the rest, the call subscribes to PutObject notifications and
// forwards each matching object once it is stored. An id listed several
// times yields one result per occurrence.
//
// If timeout is non-nil, waiting for new objects stops once it elapses.
// Results already found are still delivered, and c is closed either way, so
// a receiver that sees fewer results than ids knows the rest timed out.
// With a nil timeout it waits until every id has been stored.
//
// The store lock is never held while sending on c. While subscribed, the
// loop always keeps receiving from the subscription channel, so a slow
// consumer of c cannot stall PutObject.
func (mem *MemObjectStore) AwaitObjects(ids []ObjectID, c chan ObjectResult, timeout *time.Duration) {
	defer close(c)

	waitchan := make(chan *Object)
	queue, pending := mem.collectAndSubscribe(ids, waitchan)
	subscribed := len(pending) > 0

	// A nil channel never fires, which models "no timeout".
	var expired <-chan time.Time
	if timeout != nil && subscribed {
		timer := time.NewTimer(*timeout)
		defer timer.Stop()
		expired = timer.C
	}

	for len(pending) > 0 {
		// Offer the next queued result only when there is one; a nil
		// channel disables that select case.
		var out chan ObjectResult
		var next ObjectResult
		if len(queue) > 0 {
			out, next = c, ObjectResult{Object: queue[0]}
		}
		select {
		case out <- next:
			queue = queue[1:]
		case o := <-waitchan:
			for n := pending[o.ID]; n > 0; n-- {
				queue = append(queue, o)
			}
			delete(pending, o.ID)
		case <-expired:
			pending = nil // give up on whatever is still missing
		}
	}

	if subscribed {
		mem.unsubscribeDraining(waitchan)
	}
	for _, o := range queue {
		c <- ObjectResult{Object: o}
	}
}

// collectAndSubscribe looks up ids under the store lock. It returns the
// objects already present (one entry per occurrence in ids, in ids order)
// and, for each missing id, how many times it was requested. When anything
// is missing, waitchan is registered as a subscriber before the lock is
// released, so no PutObject can land between the lookup and the
// subscription.
func (mem *MemObjectStore) collectAndSubscribe(ids []ObjectID, waitchan chan *Object) ([]*Object, map[ObjectID]int) {
	mem.Lock()
	defer mem.Unlock()

	var found []*Object
	missing := make(map[ObjectID]int)
	for _, id := range ids {
		if data, ok := mem.db[id]; ok {
			found = append(found, &Object{ID: id, Data: data})
		} else {
			missing[id]++
		}
	}
	if len(missing) > 0 {
		mem.subscribe(waitchan)
	}
	return found, missing
}

// unsubscribeDraining removes waitchan from the subscriber list. While doing
// so, it discards any object PutObject is concurrently delivering on it.
// PutObject sends to subscribers while holding the store lock, so simply
// taking the lock here could deadlock against a pending send. Once
// Unsubscribe returns, no sender can still reference waitchan.
func (mem *MemObjectStore) unsubscribeDraining(waitchan chan *Object) {
	done := make(chan struct{})
	go func() {
		mem.Unsubscribe(waitchan)
		close(done)
	}()
	for {
		select {
		case <-waitchan:
		case <-done:
			return
		}
	}
}
// subscribe appends c to the list of channels that PutObject delivers
// stored objects to. It does not lock; callers are expected to hold mem's
// lock, as AwaitObjects does.
func (mem *MemObjectStore) subscribe(c chan *Object) {
	current := mem.subscribers
	mem.subscribers = append(current, c)
}
// Unsubscribe stops delivery of stored objects to c. It acquires mem's lock
// itself, so it must not be called while that lock is already held.
func (mem *MemObjectStore) Unsubscribe(c chan *Object) {
	mem.Lock()
	mem.unsubscribe(c)
	mem.Unlock()
}
// unsubscribe removes the first occurrence of c from the subscriber list,
// preserving the order of the remaining subscribers. It does not lock;
// callers must hold mem's lock.
//
// The vacated tail slot is cleared so the removed channel is not kept
// reachable through the slice's backing array.
func (mem *MemObjectStore) unsubscribe(c chan *Object) {
	subs := mem.subscribers
	for i, s := range subs {
		if s != c {
			continue
		}
		last := len(subs) - 1
		copy(subs[i:], subs[i+1:])
		subs[last] = nil
		mem.subscribers = subs[:last]
		return
	}
}
// PutObject stores object.Data under object.ID, replacing any existing
// entry, and then hands object to every current subscriber.
//
// Delivery happens while the store lock is held, and each send blocks until
// that subscriber receives. PutObject therefore does not return until all
// subscribers have taken the object. The Data slice is stored as-is, not
// copied. A nil object is rejected with an error instead of causing a
// nil-pointer panic.
func (mem *MemObjectStore) PutObject(object *Object) error {
	if object == nil {
		return errors.New("cannot put nil object")
	}

	mem.Lock()
	defer mem.Unlock()

	mem.db[object.ID] = object.Data
	for _, sub := range mem.subscribers {
		sub <- object
	}
	return nil
}
// MakeID returns prefix+printer and then advances prefix by one, so
// successive IDs from one store increase by one. For a store built by
// NewMemObjectStore, the sequence starts at a random value in the upper 32
// bits plus 1.
//
// The counter is updated under the store lock so concurrent callers never
// observe the same value. The lock is not reentrant, so MakeID must not be
// called while mem's lock is already held.
func (mem *MemObjectStore) MakeID() ObjectID {
	mem.Lock()
	defer mem.Unlock()

	id := mem.prefix + mem.printer
	mem.prefix++
	return ObjectID(id)
}
// NewMemObjectStore returns an empty in-memory store ready for use.
//
// The MakeID sequence is seeded from a freshly generated random UUID: the
// 32-bit value from the UUID's ID method is placed in the upper half of
// prefix, and printer starts at 1. It panics if uuid.NewRandom returns an
// error.
func NewMemObjectStore() *MemObjectStore {
	seed, err := uuid.NewRandom()
	if err != nil {
		panic(err)
	}

	store := &MemObjectStore{
		db:      make(map[ObjectID][]byte),
		printer: 1,
	}
	store.prefix = uint64(seed.ID()) << 32
	return store
}
// ObjectID identifies an object within an ObjectStore.
type ObjectID uint64

// serializeObjectID encodes id as exactly 8 bytes, most significant byte
// first (big-endian).
func serializeObjectID(id ObjectID) []byte {
	var buf [8]byte
	binary.BigEndian.PutUint64(buf[:], uint64(id))
	return buf[:]
}

// deserializeObjectID decodes the first 8 bytes of raw as a big-endian
// ObjectID, the inverse of serializeObjectID. Bytes beyond the eighth are
// ignored. It panics if raw is shorter than 8 bytes.
func deserializeObjectID(raw []byte) ObjectID {
	return ObjectID(binary.BigEndian.Uint64(raw))
}