diff --git a/lite/proxy/proxy.go b/lite/proxy/proxy.go index d7ffb27d..4ac3cc0d 100644 --- a/lite/proxy/proxy.go +++ b/lite/proxy/proxy.go @@ -1,7 +1,9 @@ package proxy import ( + "context" "net/http" + "time" amino "github.com/tendermint/go-amino" "github.com/tendermint/tendermint/libs/log" @@ -34,7 +36,14 @@ func StartProxy(c rpcclient.Client, listenAddr string, logger log.Logger, maxOpe mux := http.NewServeMux() rpcserver.RegisterRPCFuncs(mux, r, cdc, logger) - wm := rpcserver.NewWebsocketManager(r, cdc, rpcserver.EventSubscriber(c)) + wm := rpcserver.NewWebsocketManager(r, cdc, rpcserver.OnDisconnect(func(remoteAddr string) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := c.UnsubscribeAll(ctx, remoteAddr) + if err != nil { + logger.Error("Failed to unsubscribe from events", "err", err) + } + })) wm.SetLogger(logger) core.SetLogger(logger) mux.HandleFunc(wsEndpoint, wm.WebsocketHandler) @@ -51,13 +60,11 @@ func StartProxy(c rpcclient.Client, listenAddr string, logger log.Logger, maxOpe // // if we want security, the client must implement it as a secure client func RPCRoutes(c rpcclient.Client) map[string]*rpcserver.RPCFunc { - return map[string]*rpcserver.RPCFunc{ // Subscribe/unsubscribe are reserved for websocket events. - // We can just use the core tendermint impl, which uses the - // EventSwitch we registered in NewWebsocketManager above - "subscribe": rpcserver.NewWSRPCFunc(core.Subscribe, "query"), - "unsubscribe": rpcserver.NewWSRPCFunc(core.Unsubscribe, "query"), + "subscribe": rpcserver.NewWSRPCFunc(c.Subscribe, "query"), + "unsubscribe": rpcserver.NewWSRPCFunc(c.Unsubscribe, "query"), + "unsubscribe_all": rpcserver.NewWSRPCFunc(c.UnsubscribeAll, ""), // info API "status": rpcserver.NewRPCFunc(c.Status, ""), diff --git a/node/node.go b/node/node.go index 6ba84203..061517c0 100644 --- a/node/node.go +++ b/node/node.go @@ -678,8 +678,17 @@ func (n *Node) startRPC() ([]net.Listener, error) { for i, listenAddr := range listenAddrs { mux := http.NewServeMux() rpcLogger := n.Logger.With("module", "rpc-server") - wm := rpcserver.NewWebsocketManager(rpccore.Routes, coreCodec, rpcserver.EventSubscriber(n.eventBus)) - wm.SetLogger(rpcLogger.With("protocol", "websocket")) + wmLogger := rpcLogger.With("protocol", "websocket") + wm := rpcserver.NewWebsocketManager(rpccore.Routes, coreCodec, + rpcserver.OnDisconnect(func(remoteAddr string) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := n.eventBus.UnsubscribeAll(ctx, remoteAddr) + if err != nil { + wmLogger.Error("Failed to unsubscribe addr from events", "addr", remoteAddr, "err", err) + } + })) + wm.SetLogger(wmLogger) mux.HandleFunc("/websocket", wm.WebsocketHandler) rpcserver.RegisterRPCFuncs(mux, rpccore.Routes, coreCodec, rpcLogger) diff --git a/rpc/core/events.go b/rpc/core/events.go index e2a7b3c4..5139cd38 100644 --- a/rpc/core/events.go +++ b/rpc/core/events.go @@ -10,7 +10,6 @@ import ( tmquery "github.com/tendermint/tendermint/libs/pubsub/query" ctypes "github.com/tendermint/tendermint/rpc/core/types" rpctypes "github.com/tendermint/tendermint/rpc/lib/types" - tmtypes "github.com/tendermint/tendermint/types" ) // Subscribe for events via WebSocket. @@ -94,9 +93,9 @@ import ( func Subscribe(wsCtx rpctypes.WSRPCContext, query string) (*ctypes.ResultSubscribe, error) { addr := wsCtx.GetRemoteAddr() - if eventBusFor(wsCtx).NumClients() > MaxSubscriptionClients { + if eventBus.NumClients() > MaxSubscriptionClients { return nil, fmt.Errorf("max_subscription_clients %d reached", MaxSubscriptionClients) - } else if eventBusFor(wsCtx).NumClientSubscriptions(addr) > MaxSubscriptionsPerClient { + } else if eventBus.NumClientSubscriptions(addr) > MaxSubscriptionsPerClient { return nil, fmt.Errorf("max_subscriptions_per_client %d reached", MaxSubscriptionsPerClient) } @@ -109,7 +108,7 @@ func Subscribe(wsCtx rpctypes.WSRPCContext, query string) (*ctypes.ResultSubscri ctx, cancel := context.WithTimeout(context.Background(), subscribeTimeout) defer cancel() - sub, err := eventBusFor(wsCtx).Subscribe(ctx, addr, q) + sub, err := eventBus.Subscribe(ctx, addr, q) if err != nil { return nil, err } @@ -179,7 +178,7 @@ func Unsubscribe(wsCtx rpctypes.WSRPCContext, query string) (*ctypes.ResultUnsub if err != nil { return nil, errors.Wrap(err, "failed to parse query") } - err = eventBusFor(wsCtx).Unsubscribe(context.Background(), addr, q) + err = eventBus.Unsubscribe(context.Background(), addr, q) if err != nil { return nil, err } @@ -213,17 +212,9 @@ func Unsubscribe(wsCtx rpctypes.WSRPCContext, query string) (*ctypes.ResultUnsub func UnsubscribeAll(wsCtx rpctypes.WSRPCContext) (*ctypes.ResultUnsubscribe, error) { addr := wsCtx.GetRemoteAddr() logger.Info("Unsubscribe from all", "remote", addr) - err := eventBusFor(wsCtx).UnsubscribeAll(context.Background(), addr) + err := eventBus.UnsubscribeAll(context.Background(), addr) if err != nil { return nil, err } return &ctypes.ResultUnsubscribe{}, nil } - -func eventBusFor(wsCtx rpctypes.WSRPCContext) tmtypes.EventBusSubscriber { - es := wsCtx.GetEventSubscriber() - if es == nil { - es = eventBus - } - return es -} diff --git a/rpc/lib/server/handlers.go b/rpc/lib/server/handlers.go index 80eb4308..d3967727 100644 --- a/rpc/lib/server/handlers.go +++ b/rpc/lib/server/handlers.go @@ -2,7 +2,6 @@ package rpcserver import ( "bytes" - "context" "encoding/hex" "encoding/json" "fmt" @@ -434,8 +433,8 @@ type wsConnection struct { // Send pings to server with this period. Must be less than readWait, but greater than zero. pingPeriod time.Duration - // object that is used to subscribe / unsubscribe from events - eventSub types.EventSubscriber + // callback which is called upon disconnect + onDisconnect func(remoteAddr string) } // NewWSConnection wraps websocket.Conn. @@ -468,12 +467,11 @@ func NewWSConnection( return wsc } -// EventSubscriber sets object that is used to subscribe / unsubscribe from -// events - not Goroutine-safe. If none given, default node's eventBus will be -// used. -func EventSubscriber(eventSub types.EventSubscriber) func(*wsConnection) { +// OnDisconnect sets a callback which is used upon disconnect - not +// Goroutine-safe. Nop by default. +func OnDisconnect(onDisconnect func(remoteAddr string)) func(*wsConnection) { return func(wsc *wsConnection) { - wsc.eventSub = eventSub + wsc.onDisconnect = onDisconnect } } @@ -527,8 +525,8 @@ func (wsc *wsConnection) OnStop() { // Both read and write loops close the websocket connection when they exit their loops. // The writeChan is never closed, to allow WriteRPCResponse() to fail. - if wsc.eventSub != nil { - wsc.eventSub.UnsubscribeAll(context.TODO(), wsc.remoteAddr) + if wsc.onDisconnect != nil { + wsc.onDisconnect(wsc.remoteAddr) } } @@ -538,11 +536,6 @@ func (wsc *wsConnection) GetRemoteAddr() string { return wsc.remoteAddr } -// GetEventSubscriber implements WSRPCConnection by returning event subscriber. -func (wsc *wsConnection) GetEventSubscriber() types.EventSubscriber { - return wsc.eventSub -} - // WriteRPCResponse pushes a response to the writeChan, and blocks until it is accepted. // It implements WSRPCConnection. It is Goroutine-safe. func (wsc *wsConnection) WriteRPCResponse(resp types.RPCResponse) { diff --git a/rpc/lib/types/types.go b/rpc/lib/types/types.go index 13b30cda..ceb7be83 100644 --- a/rpc/lib/types/types.go +++ b/rpc/lib/types/types.go @@ -1,7 +1,6 @@ package rpctypes import ( - "context" "encoding/json" "fmt" "reflect" @@ -10,9 +9,6 @@ import ( "github.com/pkg/errors" amino "github.com/tendermint/go-amino" - - tmpubsub "github.com/tendermint/tendermint/libs/pubsub" - tmtypes "github.com/tendermint/tendermint/types" ) // a wrapper to emulate a sum type: jsonrpcid = string | int @@ -241,7 +237,6 @@ type WSRPCConnection interface { GetRemoteAddr() string WriteRPCResponse(resp RPCResponse) TryWriteRPCResponse(resp RPCResponse) bool - GetEventSubscriber() EventSubscriber Codec() *amino.Codec } @@ -251,16 +246,6 @@ type WSRPCContext struct { WSRPCConnection } -// EventSubscriber mirrors tendermint/tendermint/types.EventBusSubscriber -type EventSubscriber interface { - Subscribe(ctx context.Context, subscriber string, query tmpubsub.Query, outCapacity ...int) (tmtypes.Subscription, error) - Unsubscribe(ctx context.Context, subscriber string, query tmpubsub.Query) error - UnsubscribeAll(ctx context.Context, subscriber string) error - - NumClients() int - NumClientSubscriptions(clientID string) int -} - //---------------------------------------- // SOCKETS //