From 45f20bb5c3bdd29bb17de5035f90fa4e2abded37 Mon Sep 17 00:00:00 2001 From: sawka Date: Wed, 29 May 2024 23:17:23 -0700 Subject: [PATCH] wsh rpc client --- cmd/wsh/main-wsh.go | 2 - pkg/wshrpc/rpc_client.go | 264 +++++++++++++++++++++++++++++++++++++ pkg/wshrpc/rpc_test.go | 115 ++++++++++++++++ pkg/wshrpc/wshrpc.go | 58 ++++++++ pkg/wshutil/wshcommands.go | 12 ++ 5 files changed, 449 insertions(+), 2 deletions(-) create mode 100644 pkg/wshrpc/rpc_client.go create mode 100644 pkg/wshrpc/rpc_test.go create mode 100644 pkg/wshrpc/wshrpc.go diff --git a/cmd/wsh/main-wsh.go b/cmd/wsh/main-wsh.go index bcfe879d8..927af8c97 100644 --- a/cmd/wsh/main-wsh.go +++ b/cmd/wsh/main-wsh.go @@ -8,8 +8,6 @@ import ( "os" ) -const WaveOSC = "23198" - func main() { barr, err := os.ReadFile("/Users/mike/Downloads/2.png") if err != nil { diff --git a/pkg/wshrpc/rpc_client.go b/pkg/wshrpc/rpc_client.go new file mode 100644 index 000000000..f0388d1e7 --- /dev/null +++ b/pkg/wshrpc/rpc_client.go @@ -0,0 +1,264 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wshprc + +import ( + "context" + "errors" + "fmt" + "log" + "runtime/debug" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" +) + +// there is a single go-routine that reads from RecvCh +type RpcClient struct { + CVar *sync.Cond + NextSeqNum *atomic.Int64 + ReqPacketsInFlight map[int64]string // seqnum -> rpcId + AckList []int64 + RpcReqs map[string]*RpcInfo + SendCh chan *RpcPacket + RecvCh chan *RpcPacket +} + +type RpcInfo struct { + CloseSync *sync.Once + RpcId string + ReqPacketsInFlight map[int64]bool // seqnum -> bool + RespCh chan *RpcPacket +} + +func MakeRpcClient(sendCh chan *RpcPacket, recvCh chan *RpcPacket) *RpcClient { + if cap(sendCh) < MaxInFlightPackets { + panic(fmt.Errorf("sendCh buffer size must be at least MaxInFlightPackets(%d)", MaxInFlightPackets)) + } + rtn := &RpcClient{ + CVar: sync.NewCond(&sync.Mutex{}), + NextSeqNum: &atomic.Int64{}, + ReqPacketsInFlight: make(map[int64]string), + AckList: nil, + RpcReqs: make(map[string]*RpcInfo), + SendCh: sendCh, + RecvCh: recvCh, + } + go rtn.runRecvLoop() + return rtn +} + +func (c *RpcClient) runRecvLoop() { + defer func() { + if r := recover(); r != nil { + log.Printf("RpcClient.runRecvLoop() panic: %v", r) + debug.PrintStack() + } + }() + for pk := range c.RecvCh { + if pk.RpcType == RpcType_Resp { + c.handleResp(pk) + continue + } + log.Printf("RpcClient.runRecvLoop() bad packet type: %v", pk) + } + log.Printf("RpcClient.runRecvLoop() normal exit") +} + +func (c *RpcClient) getRpcInfo(rpcId string) *RpcInfo { + c.CVar.L.Lock() + defer c.CVar.L.Unlock() + return c.RpcReqs[rpcId] +} + +func (c *RpcClient) handleResp(pk *RpcPacket) { + c.handleAcks(pk.Acks) + if pk.RpcId == "" { + c.ackResp(pk.SeqNum) + log.Printf("RpcClient.handleResp() missing rpcId: %v", pk) + return + } + rpcInfo := c.getRpcInfo(pk.RpcId) + if rpcInfo == nil { + c.ackResp(pk.SeqNum) + log.Printf("RpcClient.handleResp() unknown rpcId: %v", pk) + return + } + select { + case rpcInfo.RespCh <- pk: + default: + log.Printf("RpcClient.handleResp() respCh full, dropping packet") + } + if pk.RespDone { + c.removeReqInfo(pk.RpcId, false) + } +} + +func (c *RpcClient) grabAcks() []int64 { + c.CVar.L.Lock() + defer c.CVar.L.Unlock() + acks := c.AckList + c.AckList = nil + return acks +} + +func (c *RpcClient) ackResp(seqNum int64) { + if seqNum == 0 { + return + } + c.CVar.L.Lock() + defer c.CVar.L.Unlock() + c.AckList = append(c.AckList, seqNum) +} + +func (c *RpcClient) waitForReq(ctx context.Context, req *RpcPacket) (*RpcInfo, error) { + c.CVar.L.Lock() + defer c.CVar.L.Unlock() + // issue with ctx timeout sync -- we need the cvar to be signaled fairly regularly so we can check ctx.Err() + for { + if ctx.Err() != nil { + return nil, ctx.Err() + } + if len(c.RpcReqs) >= MaxOpenRpcs { + c.CVar.Wait() + continue + } + if len(c.ReqPacketsInFlight) >= MaxOpenRpcs { + c.CVar.Wait() + continue + } + if rpcInfo, ok := c.RpcReqs[req.RpcId]; ok { + if len(rpcInfo.ReqPacketsInFlight) >= MaxUnackedPerRpc { + c.CVar.Wait() + continue + } + } + break + } + select { + case c.SendCh <- req: + default: + return nil, errors.New("SendCh Full") + } + c.ReqPacketsInFlight[req.SeqNum] = req.RpcId + rpcInfo := c.RpcReqs[req.RpcId] + if rpcInfo == nil { + rpcInfo = &RpcInfo{ + CloseSync: &sync.Once{}, + RpcId: req.RpcId, + ReqPacketsInFlight: make(map[int64]bool), + RespCh: make(chan *RpcPacket, MaxUnackedPerRpc), + } + rpcInfo.ReqPacketsInFlight[req.SeqNum] = true + c.RpcReqs[req.RpcId] = rpcInfo + } + return rpcInfo, nil +} + +func (c *RpcClient) handleAcks(acks []int64) { + if len(acks) == 0 { + return + } + c.CVar.L.Lock() + defer c.CVar.L.Unlock() + for _, ack := range acks { + rpcId, ok := c.ReqPacketsInFlight[ack] + if !ok { + continue + } + rpcInfo := c.RpcReqs[rpcId] + if rpcInfo != nil { + delete(rpcInfo.ReqPacketsInFlight, ack) + } + delete(c.ReqPacketsInFlight, ack) + } + c.CVar.Broadcast() +} + +func (c *RpcClient) removeReqInfo(rpcId string, clearSend bool) { + c.CVar.L.Lock() + defer c.CVar.L.Unlock() + rpcInfo := c.RpcReqs[rpcId] + delete(c.RpcReqs, rpcId) + if rpcInfo != nil { + if clearSend { + // unblock the recv loop if it happens to be waiting + // because the delete has already happens, it will not be able to send again on the channel + select { + case <-rpcInfo.RespCh: + default: + } + } + rpcInfo.CloseSync.Do(func() { + close(rpcInfo.RespCh) + }) + } +} + +func (c *RpcClient) SimpleReq(ctx context.Context, command string, data any) (any, error) { + rpcId := uuid.New().String() + seqNum := c.NextSeqNum.Add(1) + var timeoutInfo *TimeoutInfo + deadline, ok := ctx.Deadline() + if ok { + timeoutInfo = &TimeoutInfo{Deadline: deadline.UnixMilli()} + } + req := &RpcPacket{ + Command: command, + RpcId: rpcId, + RpcType: RpcType_Req, + SeqNum: seqNum, + ReqDone: true, + Acks: c.grabAcks(), + Timeout: timeoutInfo, + Data: data, + } + rpcInfo, err := c.waitForReq(ctx, req) + if err != nil { + return nil, err + } + defer c.removeReqInfo(rpcId, true) + var rtnPacket *RpcPacket + select { + case <-ctx.Done(): + return nil, ctx.Err() + case rtnPacket = <-rpcInfo.RespCh: + // fallthrough + } + if rtnPacket.Error != "" { + return nil, errors.New(rtnPacket.Error) + } + return rtnPacket.Data, nil +} + +func (c *RpcClient) StreamReq(ctx context.Context, command string, data any, respTimeout time.Duration) (chan *RpcPacket, error) { + rpcId := uuid.New().String() + seqNum := c.NextSeqNum.Add(1) + var timeoutInfo *TimeoutInfo = &TimeoutInfo{RespPacketTimeout: respTimeout.Milliseconds()} + deadline, ok := ctx.Deadline() + if ok { + timeoutInfo.Deadline = deadline.UnixMilli() + } + req := &RpcPacket{ + Command: command, + RpcId: rpcId, + RpcType: RpcType_Req, + SeqNum: seqNum, + ReqDone: true, + Acks: c.grabAcks(), + Timeout: timeoutInfo, + Data: data, + } + rpcInfo, err := c.waitForReq(ctx, req) + if err != nil { + return nil, err + } + return rpcInfo.RespCh, nil +} + +func (c *RpcClient) EndStreamReq(rpcId string) { + c.removeReqInfo(rpcId, true) +} diff --git a/pkg/wshrpc/rpc_test.go b/pkg/wshrpc/rpc_test.go new file mode 100644 index 000000000..4f8ed783b --- /dev/null +++ b/pkg/wshrpc/rpc_test.go @@ -0,0 +1,115 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wshprc + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestSimple(t *testing.T) { + sendCh := make(chan *RpcPacket, MaxInFlightPackets) + recvCh := make(chan *RpcPacket, MaxInFlightPackets) + client := MakeRpcClient(sendCh, recvCh) + ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelFn() + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + resp, err := client.SimpleReq(ctx, "test", "hello") + if err != nil { + t.Errorf("SimpleReq() failed: %v", err) + return + } + if resp != "world" { + t.Errorf("SimpleReq() failed: expected 'world', got '%s'", resp) + } + }() + go func() { + defer wg.Done() + req := <-sendCh + if req.Command != "test" { + t.Errorf("expected 'test', got '%s'", req.Command) + } + if req.Data != "hello" { + t.Errorf("expected 'hello', got '%s'", req.Data) + } + resp := &RpcPacket{ + Command: "test", + RpcId: req.RpcId, + RpcType: RpcType_Resp, + SeqNum: 1, + RespDone: true, + Acks: []int64{req.SeqNum}, + Data: "world", + } + recvCh <- resp + }() + wg.Wait() +} + +func makeRpcResp(req *RpcPacket, data any, seqNum int64, done bool) *RpcPacket { + return &RpcPacket{ + Command: req.Command, + RpcId: req.RpcId, + RpcType: RpcType_Resp, + SeqNum: seqNum, + RespDone: done, + Data: data, + } +} + +func TestStream(t *testing.T) { + sendCh := make(chan *RpcPacket, MaxInFlightPackets) + recvCh := make(chan *RpcPacket, MaxInFlightPackets) + client := MakeRpcClient(sendCh, recvCh) + ctx, cancelFn := context.WithTimeout(context.Background(), 2*time.Second) + defer cancelFn() + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + respCh, err := client.StreamReq(ctx, "test", "hello", 1000) + if err != nil { + t.Errorf("StreamReq() failed: %v", err) + return + } + var output []string + for resp := range respCh { + if resp.Error != "" { + t.Errorf("StreamReq() failed: %v", resp.Error) + return + } + output = append(output, resp.Data.(string)) + } + if len(output) != 3 { + t.Errorf("expected 3 responses, got %d (%v)", len(output), output) + return + } + if output[0] != "one" || output[1] != "two" || output[2] != "three" { + t.Errorf("expected 'one', 'two', 'three', got %v", output) + return + } + }() + go func() { + defer wg.Done() + req := <-sendCh + if req.Command != "test" { + t.Errorf("expected 'test', got '%s'", req.Command) + } + if req.Data != "hello" { + t.Errorf("expected 'hello', got '%s'", req.Data) + } + resp := makeRpcResp(req, "one", 1, false) + recvCh <- resp + resp = makeRpcResp(req, "two", 2, false) + recvCh <- resp + resp = makeRpcResp(req, "three", 3, true) + recvCh <- resp + }() + wg.Wait() +} diff --git a/pkg/wshrpc/wshrpc.go b/pkg/wshrpc/wshrpc.go new file mode 100644 index 000000000..fe9034fb8 --- /dev/null +++ b/pkg/wshrpc/wshrpc.go @@ -0,0 +1,58 @@ +// Copyright 2024, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package wshprc + +import ( + "context" +) + +const ( + MaxOpenRpcs = 10 + MaxUnackedPerRpc = 10 + MaxInFlightPackets = MaxOpenRpcs * MaxUnackedPerRpc +) + +const ( + RpcType_Req = "req" + RpcType_Resp = "resp" +) + +const ( + CommandType_Ack = ":ack" + CommandType_Ping = ":ping" + CommandType_Cancel = ":cancel" + CommandType_Timeout = ":timeout" +) + +var rpcClientContextKey = struct{}{} + +type TimeoutInfo struct { + Deadline int64 `json:"deadline,omitempty"` + ReqPacketTimeout int64 `json:"reqpackettimeout,omitempty"` // for streaming requests + RespPacketTimeout int64 `json:"resppackettimeout,omitempty"` // for streaming responses +} + +type RpcPacket struct { + Command string `json:"command"` + RpcId string `json:"rpcid"` + RpcType string `json:"rpctype"` + SeqNum int64 `json:"seqnum"` + ReqDone bool `json:"reqdone"` + RespDone bool `json:"resdone"` + Acks []int64 `json:"acks,omitempty"` // seqnums acked + Timeout *TimeoutInfo `json:"timeout,omitempty"` // for initial request only + Data any `json:"data"` // json data for command + Error string `json:"error,omitempty"` +} + +func GetRpcClient(ctx context.Context) *RpcClient { + if ctx == nil { + return nil + } + val := ctx.Value(rpcClientContextKey) + if val == nil { + return nil + } + return val.(*RpcClient) +} diff --git a/pkg/wshutil/wshcommands.go b/pkg/wshutil/wshcommands.go index 5051b09ae..53fc7b8f7 100644 --- a/pkg/wshutil/wshcommands.go +++ b/pkg/wshutil/wshcommands.go @@ -9,6 +9,7 @@ const ( CommandSetView = "setview" CommandSetMeta = "setmeta" CommandBlockFileAppend = "blockfile:append" + CommandStreamFile = "streamfile" ) var CommandToTypeMap = map[string]reflect.Type{ @@ -23,6 +24,8 @@ type Command interface { // for unmarshalling type baseCommand struct { Command string `json:"command"` + RpcID string `json:"rpcid"` + RpcType string `json:"rpctype"` } type SetViewCommand struct { @@ -52,3 +55,12 @@ type BlockFileAppendCommand struct { func (bfac *BlockFileAppendCommand) GetCommand() string { return CommandBlockFileAppend } + +type StreamFileCommand struct { + Command string `json:"command"` + FileName string `json:"filename"` +} + +func (c *StreamFileCommand) GetCommand() string { + return CommandStreamFile +}