diff --git a/rpc/client/httpclient.go b/rpc/client/httpclient.go index d068ee95..7b09b5bd 100644 --- a/rpc/client/httpclient.go +++ b/rpc/client/httpclient.go @@ -226,7 +226,9 @@ func (w *WSEvents) Start() (bool, error) { st, err := w.EventSwitch.Start() // if we did start, then OnStart here... if st && err == nil { - ws := rpcclient.NewWSClient(w.remote, w.endpoint) + ws := rpcclient.NewWSClient(w.remote, w.endpoint, rpcclient.OnReconnect(func() { + w.redoSubscriptions() + })) _, err = ws.Start() if err == nil { w.ws = ws @@ -335,8 +337,6 @@ func (w *WSEvents) eventListener() { // before cleaning up the w.ws stuff w.done <- true return - case <-w.ws.ReconnectCh: - w.redoSubscriptions() } } } diff --git a/rpc/lib/client/ws_client.go b/rpc/lib/client/ws_client.go index 788cb860..2bdfa5c9 100644 --- a/rpc/lib/client/ws_client.go +++ b/rpc/lib/client/ws_client.go @@ -41,9 +41,11 @@ type WSClient struct { PingPongLatencyTimer metrics.Timer // user facing channels, closed only when the client is being stopped. - ResultsCh chan json.RawMessage - ErrorsCh chan error - ReconnectCh chan bool + ResultsCh chan json.RawMessage + ErrorsCh chan error + + // Callback, which will be called each time after successful reconnect. + onReconnect func() // internal channels send chan types.RPCRequest // user requests @@ -125,6 +127,14 @@ func PingPeriod(pingPeriod time.Duration) func(*WSClient) { } } +// OnReconnect sets the callback, which will be called every time after +// successful reconnect. +func OnReconnect(cb func()) func(*WSClient) { + return func(c *WSClient) { + c.onReconnect = cb + } +} + // String returns WS client full address. func (c *WSClient) String() string { return fmt.Sprintf("%s (%s)", c.Address, c.Endpoint) @@ -140,7 +150,6 @@ func (c *WSClient) OnStart() error { c.ResultsCh = make(chan json.RawMessage) c.ErrorsCh = make(chan error) - c.ReconnectCh = make(chan bool) c.send = make(chan types.RPCRequest) // 1 additional error may come from the read/write @@ -256,7 +265,9 @@ func (c *WSClient) reconnect() error { c.Logger.Error("failed to redial", "err", err) } else { c.Logger.Info("reconnected") - c.ReconnectCh <- true + if c.onReconnect != nil { + go c.onReconnect() + } return nil } diff --git a/rpc/lib/client/ws_client_test.go b/rpc/lib/client/ws_client_test.go index e90fc29d..f5aa027f 100644 --- a/rpc/lib/client/ws_client_test.go +++ b/rpc/lib/client/ws_client_test.go @@ -186,8 +186,6 @@ func callWgDoneOnResult(t *testing.T, c *WSClient, wg *sync.WaitGroup) { if err != nil { t.Fatalf("unexpected error: %v", err) } - case <-c.ReconnectCh: - t.Log("Reconnected") case <-c.Quit: return }