Skip to content

bson encoding for wshrpc #1895

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions emain/emain-wavesrv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {
getXdgCurrentDesktop,
WaveConfigHomeVarName,
WaveDataHomeVarName,
WshEncTypeVarName,
} from "./platform";
import { updater } from "./updater";

Expand Down Expand Up @@ -62,6 +63,7 @@ export function runWaveSrv(handleWSEvent: (evtMsg: WSEventType) => void): Promis
envCopy[WaveAuthKeyEnv] = AuthKey;
envCopy[WaveDataHomeVarName] = getWaveDataDir();
envCopy[WaveConfigHomeVarName] = getWaveConfigDir();
envCopy[WshEncTypeVarName] = "bson";
const waveSrvCmd = getWaveSrvPath();
console.log("trying to run local server", waveSrvCmd);
const proc = child_process.spawn(getWaveSrvPath(), {
Expand Down
2 changes: 2 additions & 0 deletions emain/platform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ keyutil.setKeyUtilPlatform(unamePlatform);
const WaveConfigHomeVarName = "WAVETERM_CONFIG_HOME";
const WaveDataHomeVarName = "WAVETERM_DATA_HOME";
const WaveHomeVarName = "WAVETERM_HOME";
const WshEncTypeVarName = "WSH_ENC_TYPE";

export function checkIfRunningUnderARM64Translation(fullConfig: FullConfigType) {
if (!fullConfig.settings["app:dismissarchitecturewarning"] && app.runningUnderARM64Translation) {
Expand Down Expand Up @@ -270,4 +271,5 @@ export {
unamePlatform,
WaveConfigHomeVarName,
WaveDataHomeVarName,
WshEncTypeVarName,
};
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ require (
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.mongodb.org/mongo-driver/v2 v2.0.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
go.opentelemetry.io/otel v1.32.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ github.com/wavetermdev/ssh_config v0.0.0-20241219203747-6409e4292f34 h1:I8VZVTZE
github.com/wavetermdev/ssh_config v0.0.0-20241219203747-6409e4292f34/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.mongodb.org/mongo-driver/v2 v2.0.0 h1:Jfd7XpdZa9yk3eY774bO7SWVb30noLSirL9nKTpavhI=
go.mongodb.org/mongo-driver/v2 v2.0.0/go.mod h1:nSjmNq4JUstE8IRZKTktLgMHM4F1fccL6HGX1yh+8RA=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
Expand Down
2 changes: 2 additions & 0 deletions pkg/blockcontroller/blockcontroller.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshencode"
"github.com/wavetermdev/waveterm/pkg/wshutil"
"github.com/wavetermdev/waveterm/pkg/wslconn"
"github.com/wavetermdev/waveterm/pkg/wstore"
Expand Down Expand Up @@ -391,6 +392,7 @@ func (bc *BlockController) makeSwapToken(ctx context.Context, logCtx context.Con
token.Env["WAVETERM_BLOCKID"] = bc.BlockId
token.Env["WAVETERM_VERSION"] = wavebase.WaveVersion
token.Env["WAVETERM"] = "1"
token.Env[wshencode.EncTypeEnvVar] = wshencode.GetEncTypeFromEnv()
tabId, err := wstore.DBFindTabForBlockId(ctx, bc.BlockId)
if err != nil {
log.Printf("error finding tab for block: %v\n", err)
Expand Down
52 changes: 52 additions & 0 deletions pkg/wshrpc/wshencode/wshencode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package wshencode

import (
"encoding/json"
"fmt"
"os"

"go.mongodb.org/mongo-driver/v2/bson"
)

const (
EncTypeJson = "json"
EncTypeBson = "bson"
EncTypeEnvVar = "WSH_ENC_TYPE"
UnsupportedEncTypeErr = "unsupported encoding type: %s"
)

type EncoderDecoder struct {
EncType string
}

func MakeEncoderDecoder() *EncoderDecoder {
return &EncoderDecoder{
EncType: GetEncTypeFromEnv(),
}
}

func (e EncoderDecoder) Marshal(v interface{}) ([]byte, error) {
if e.EncType == EncTypeJson {
return json.Marshal(v)
} else if e.EncType == EncTypeBson {
return bson.MarshalExtJSON(v, true, false)
}
return nil, fmt.Errorf(UnsupportedEncTypeErr, e.EncType)
}

func (e EncoderDecoder) Unmarshal(data []byte, v interface{}) error {
if e.EncType == EncTypeJson {
return json.Unmarshal(data, v)
} else if e.EncType == EncTypeBson {
return bson.UnmarshalExtJSON(data, true, v)
}
return fmt.Errorf(UnsupportedEncTypeErr, e.EncType)
}

func GetEncTypeFromEnv() string {
encType := EncTypeJson
if envEncType := os.Getenv(EncTypeEnvVar); envEncType != "" {
encType = envEncType
}
return encType
}
20 changes: 11 additions & 9 deletions pkg/wshutil/wshrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package wshutil

import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
Expand All @@ -17,6 +16,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshencode"
)

const (
Expand Down Expand Up @@ -55,6 +55,7 @@ type WshRouter struct {
RpcMap map[string]*routeInfo // rpcid => routeinfo
SimpleRequestMap map[string]chan *RpcMessage // simple reqid => response channel
InputCh chan msgAndRoute
EncDec *wshencode.EncoderDecoder // abstract encoder/decoder
}

func MakeConnectionRouteId(connId string) string {
Expand Down Expand Up @@ -87,6 +88,7 @@ func NewWshRouter() *WshRouter {
RpcMap: make(map[string]*routeInfo),
SimpleRequestMap: make(map[string]chan *RpcMessage),
InputCh: make(chan msgAndRoute, DefaultInputChSize),
EncDec: wshencode.MakeEncoderDecoder(),
}
go rtn.runServer()
return rtn
Expand All @@ -112,7 +114,7 @@ func (router *WshRouter) SendEvent(routeId string, event wps.WaveEvent) {
Route: routeId,
Data: event,
}
msgBytes, err := json.Marshal(msg)
msgBytes, err := router.EncDec.Marshal(msg)
if err != nil {
// nothing to do
return
Expand All @@ -129,7 +131,7 @@ func (router *WshRouter) handleNoRoute(msg RpcMessage) {
}
// no response needed, but send message back to source
respMsg := RpcMessage{Command: wshrpc.Command_Message, Route: msg.Source, Data: wshrpc.CommandMessageData{Message: nrErr.Error()}}
respBytes, _ := json.Marshal(respMsg)
respBytes, _ := router.EncDec.Marshal(respMsg)
router.InputCh <- msgAndRoute{msgBytes: respBytes, fromRouteId: SysRoute}
return
}
Expand All @@ -138,7 +140,7 @@ func (router *WshRouter) handleNoRoute(msg RpcMessage) {
ResId: msg.ReqId,
Error: nrErr.Error(),
}
respBytes, _ := json.Marshal(response)
respBytes, _ := router.EncDec.Marshal(response)
router.sendRoutedMessage(respBytes, msg.Source)
}

Expand Down Expand Up @@ -220,7 +222,7 @@ func (router *WshRouter) runServer() {
for input := range router.InputCh {
msgBytes := input.msgBytes
var msg RpcMessage
err := json.Unmarshal(msgBytes, &msg)
err := router.EncDec.Unmarshal(msgBytes, &msg)
if err != nil {
fmt.Println("error unmarshalling message: ", err)
continue
Expand Down Expand Up @@ -315,7 +317,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, sh
// announce
if shouldAnnounce && !alreadyExists && router.GetUpstreamClient() != nil {
announceMsg := RpcMessage{Command: wshrpc.Command_RouteAnnounce, Source: routeId}
announceBytes, _ := json.Marshal(announceMsg)
announceBytes, _ := router.EncDec.Marshal(announceMsg)
router.GetUpstreamClient().SendRpcMessage(announceBytes)
}
for {
Expand All @@ -324,7 +326,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, sh
break
}
var rpcMsg RpcMessage
err := json.Unmarshal(msgBytes, &rpcMsg)
err := router.EncDec.Unmarshal(msgBytes, &rpcMsg)
if err != nil {
continue
}
Expand All @@ -335,7 +337,7 @@ func (router *WshRouter) RegisterRoute(routeId string, rpc AbstractRpcClient, sh
if rpcMsg.Route == "" {
rpcMsg.Route = DefaultRoute
}
msgBytes, err = json.Marshal(rpcMsg)
msgBytes, err = router.EncDec.Marshal(rpcMsg)
if err != nil {
continue
}
Expand Down Expand Up @@ -418,7 +420,7 @@ func (router *WshRouter) RunSimpleRawCommand(ctx context.Context, msg RpcMessage
if msg.Command == "" {
return nil, errors.New("no command")
}
msgBytes, err := json.Marshal(msg)
msgBytes, err := router.EncDec.Marshal(msg)
if err != nil {
return nil, err
}
Expand Down
16 changes: 9 additions & 7 deletions pkg/wshutil/wshrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package wshutil

import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
Expand All @@ -20,6 +19,7 @@ import (
"github.com/wavetermdev/waveterm/pkg/util/utilfn"
"github.com/wavetermdev/waveterm/pkg/wps"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshencode"
)

const DefaultTimeoutMs = 5000
Expand Down Expand Up @@ -57,6 +57,7 @@ type WshRpc struct {
Debug bool
DebugName string
ServerDone bool
EncDec *wshencode.EncoderDecoder
}

type wshRpcContextKey struct{}
Expand Down Expand Up @@ -219,6 +220,7 @@ func MakeWshRpc(inputCh chan []byte, outputCh chan []byte, rpcCtx wshrpc.RpcCont
EventListener: MakeEventListener(),
ServerImpl: serverImpl,
ResponseHandlerMap: make(map[string]*RpcResponseHandler),
EncDec: wshencode.MakeEncoderDecoder(),
}
rtn.RpcContext.Store(&rpcCtx)
go rtn.runServer()
Expand Down Expand Up @@ -358,7 +360,7 @@ outer:
}

var msg RpcMessage
err := json.Unmarshal(msgBytes, &msg)
err := w.EncDec.Unmarshal(msgBytes, &msg)
if err != nil {
log.Printf("wshrpc received bad message: %v\n", err)
continue
Expand Down Expand Up @@ -503,7 +505,7 @@ func (handler *RpcRequestHandler) SendCancel() {
ReqId: handler.reqId,
AuthToken: handler.w.GetAuthToken(),
}
barr, _ := json.Marshal(msg) // will never fail
barr, _ := handler.w.EncDec.Marshal(msg) // will never fail
handler.w.OutputCh <- barr
handler.finalize()
}
Expand Down Expand Up @@ -600,7 +602,7 @@ func (handler *RpcResponseHandler) SendMessage(msg string) {
},
AuthToken: handler.w.GetAuthToken(),
}
msgBytes, _ := json.Marshal(rpcMsg) // will never fail
msgBytes, _ := handler.w.EncDec.Marshal(rpcMsg) // will never fail
handler.w.OutputCh <- msgBytes
}

Expand All @@ -623,7 +625,7 @@ func (handler *RpcResponseHandler) SendResponse(data any, done bool) error {
Cont: !done,
AuthToken: handler.w.GetAuthToken(),
}
barr, err := json.Marshal(msg)
barr, err := handler.w.EncDec.Marshal(msg)
if err != nil {
return err
}
Expand All @@ -644,7 +646,7 @@ func (handler *RpcResponseHandler) SendResponseError(err error) {
Error: err.Error(),
AuthToken: handler.w.GetAuthToken(),
}
barr, _ := json.Marshal(msg) // will never fail
barr, _ := handler.w.EncDec.Marshal(msg) // will never fail
handler.w.OutputCh <- barr
}

Expand Down Expand Up @@ -710,7 +712,7 @@ func (w *WshRpc) SendComplexRequest(command string, data any, opts *wshrpc.RpcOp
Route: opts.Route,
AuthToken: w.GetAuthToken(),
}
barr, err := json.Marshal(req)
barr, err := w.EncDec.Marshal(req)
if err != nil {
return nil, err
}
Expand Down
17 changes: 0 additions & 17 deletions pkg/wshutil/wshutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,23 +221,6 @@ func SetupTerminalRpcClient(serverImpl ServerImpl, debugStr string) (*WshRpc, io
return rpcClient, ptyBuf
}

func SetupPacketRpcClient(input io.Reader, output io.Writer, serverImpl ServerImpl, debugStr string) (*WshRpc, chan []byte) {
messageCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
rawCh := make(chan []byte, DefaultOutputChSize)
rpcClient := MakeWshRpc(messageCh, outputCh, wshrpc.RpcContext{}, serverImpl, debugStr)
go packetparser.Parse(input, messageCh, rawCh)
go func() {
defer func() {
panichandler.PanicHandler("SetupPacketRpcClient:outputloop", recover())
}()
for msg := range outputCh {
packetparser.WritePacket(output, msg)
}
}()
return rpcClient, rawCh
}

func SetupConnRpcClient(conn net.Conn, serverImpl ServerImpl, debugStr string) (*WshRpc, chan error, error) {
inputCh := make(chan []byte, DefaultInputChSize)
outputCh := make(chan []byte, DefaultOutputChSize)
Expand Down