diff --git a/cmd/priv_val_server/main.go b/cmd/priv_val_server/main.go index 57c3355f..0d18f8ed 100644 --- a/cmd/priv_val_server/main.go +++ b/cmd/priv_val_server/main.go @@ -12,36 +12,41 @@ import ( func main() { var ( + addr = flag.String("addr", ":46659", "Address of client to connect to") chainID = flag.String("chain-id", "mychain", "chain id") - listenAddr = flag.String("laddr", ":46659", "Validator listen address (0.0.0.0:0 means any interface, any port") - maxConn = flag.Int("clients", 3, "maximum of concurrent connections") privValPath = flag.String("priv", "", "priv val file path") - logger = log.NewTMLogger(log.NewSyncWriter(os.Stdout)).With("module", "priv_val") + logger = log.NewTMLogger( + log.NewSyncWriter(os.Stdout), + ).With("module", "priv_val") ) flag.Parse() logger.Info( "Starting private validator", + "addr", *addr, "chainID", *chainID, - "listenAddr", *listenAddr, - "maxConn", *maxConn, "privPath", *privValPath, ) privVal := priv_val.LoadPrivValidatorJSON(*privValPath) - pvss := priv_val.NewPrivValidatorSocketServer( + rs := priv_val.NewRemoteSigner( logger, *chainID, - *listenAddr, - *maxConn, + *addr, privVal, nil, ) - pvss.Start() + err := rs.Start() + if err != nil { + panic(err) + } cmn.TrapSignal(func() { - pvss.Stop() + err := rs.Stop() + if err != nil { + panic(err) + } }) } diff --git a/types/priv_validator/socket.go b/types/priv_validator/socket.go index 0d40a1eb..05bc7771 100644 --- a/types/priv_validator/socket.go +++ b/types/priv_validator/socket.go @@ -19,12 +19,16 @@ import ( const ( defaultConnDeadlineSeconds = 3 - defaultDialRetryMax = 10 + defaultConnWaitSeconds = 60 + defaultDialRetries = 10 + defaultSignersMax = 1 ) // Socket errors. var ( - ErrDialRetryMax = errors.New("Error max client retries") + ErrDialRetryMax = errors.New("Error max client retries") + ErrConnWaitTimeout = errors.New("Error waiting for external connection") + ErrConnTimeout = errors.New("Error connection timed out") ) var ( @@ -34,10 +38,16 @@ var ( // SocketClientOption sets an optional parameter on the SocketClient. type SocketClientOption func(*SocketClient) -// SocketClientTimeout sets the timeout for connecting to the external socket -// address. -func SocketClientTimeout(timeout time.Duration) SocketClientOption { - return func(sc *SocketClient) { sc.connectTimeout = timeout } +// SocketClientConnDeadline sets the read and write deadline for connections +// from external signing processes. +func SocketClientConnDeadline(deadline time.Duration) SocketClientOption { + return func(sc *SocketClient) { sc.connDeadline = deadline } +} + +// SocketClientConnWait sets the timeout duration before connection of external +// signing processes are considered to be unsuccessful. +func SocketClientConnWait(timeout time.Duration) SocketClientOption { + return func(sc *SocketClient) { sc.connWaitTimeout = timeout } } // SocketClient implements PrivValidator, it uses a socket to request signatures @@ -45,11 +55,13 @@ func SocketClientTimeout(timeout time.Duration) SocketClientOption { type SocketClient struct { cmn.BaseService - conn net.Conn - privKey *crypto.PrivKeyEd25519 + addr string + connDeadline time.Duration + connWaitTimeout time.Duration + privKey *crypto.PrivKeyEd25519 - addr string - connectTimeout time.Duration + conn net.Conn + listener net.Listener } // Check that SocketClient implements PrivValidator2. @@ -62,24 +74,37 @@ func NewSocketClient( privKey *crypto.PrivKeyEd25519, ) *SocketClient { sc := &SocketClient{ - addr: socketAddr, - connectTimeout: time.Second * defaultConnDeadlineSeconds, - privKey: privKey, + addr: socketAddr, + connDeadline: time.Second * defaultConnDeadlineSeconds, + connWaitTimeout: time.Second * defaultConnWaitSeconds, + privKey: privKey, } - sc.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketClient", sc) + sc.BaseService = *cmn.NewBaseService(logger, "SocketClient", sc) return sc } // OnStart implements cmn.Service. func (sc *SocketClient) OnStart() error { - if err := sc.BaseService.OnStart(); err != nil { - return err + if sc.listener == nil { + if err := sc.listen(); err != nil { + sc.Logger.Error( + "OnStart", + "err", errors.Wrap(err, "failed to listen"), + ) + + return err + } } - conn, err := sc.connect() + conn, err := sc.waitConnection() if err != nil { + sc.Logger.Error( + "OnStart", + "err", errors.Wrap(err, "failed to accept connection"), + ) + return err } @@ -93,7 +118,21 @@ func (sc *SocketClient) OnStop() { sc.BaseService.OnStop() if sc.conn != nil { - sc.conn.Close() + if err := sc.conn.Close(); err != nil { + sc.Logger.Error( + "OnStop", + "err", errors.Wrap(err, "failed to close connection"), + ) + } + } + + if sc.listener != nil { + if err := sc.listener.Close(); err != nil { + sc.Logger.Error( + "OnStop", + "err", errors.Wrap(err, "failed to close listener"), + ) + } } } @@ -162,7 +201,10 @@ func (sc *SocketClient) SignVote(chainID string, vote *types.Vote) error { } // SignProposal implements PrivValidator2. -func (sc *SocketClient) SignProposal(chainID string, proposal *types.Proposal) error { +func (sc *SocketClient) SignProposal( + chainID string, + proposal *types.Proposal, +) error { err := writeMsg(sc.conn, &SignProposalMsg{Proposal: proposal}) if err != nil { return err @@ -179,7 +221,10 @@ func (sc *SocketClient) SignProposal(chainID string, proposal *types.Proposal) e } // SignHeartbeat implements PrivValidator2. -func (sc *SocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat) error { +func (sc *SocketClient) SignHeartbeat( + chainID string, + heartbeat *types.Heartbeat, +) error { err := writeMsg(sc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat}) if err != nil { return err @@ -195,22 +240,164 @@ func (sc *SocketClient) SignHeartbeat(chainID string, heartbeat *types.Heartbeat return nil } -func (sc *SocketClient) connect() (net.Conn, error) { - retries := defaultDialRetryMax +func (sc *SocketClient) acceptConnection() (net.Conn, error) { + conn, err := sc.listener.Accept() + if err != nil { + if !sc.IsRunning() { + return nil, nil // Ignore error from listener closing. + } + return nil, err + + } + + if err := conn.SetDeadline(time.Now().Add(sc.connDeadline)); err != nil { + return nil, err + } + + if sc.privKey != nil { + conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey.Wrap()) + if err != nil { + return nil, err + } + } + + return conn, nil +} + +func (sc *SocketClient) listen() error { + ln, err := net.Listen(cmn.ProtocolAndAddress(sc.addr)) + if err != nil { + return err + } + + sc.listener = netutil.LimitListener(ln, defaultSignersMax) + + return nil +} + +// waitConnection uses the configured wait timeout to error if no external +// process connects in the time period. +func (sc *SocketClient) waitConnection() (net.Conn, error) { + var ( + connc = make(chan net.Conn, 1) + errc = make(chan error, 1) + ) + + go func(connc chan<- net.Conn, errc chan<- error) { + conn, err := sc.acceptConnection() + if err != nil { + errc <- err + return + } + + connc <- conn + }(connc, errc) + + select { + case conn := <-connc: + return conn, nil + case err := <-errc: + return nil, err + case <-time.After(sc.connWaitTimeout): + return nil, ErrConnWaitTimeout + } +} + +//--------------------------------------------------------- + +// RemoteSignerOption sets an optional parameter on the RemoteSigner. +type RemoteSignerOption func(*RemoteSigner) + +// RemoteSignerConnDeadline sets the read and write deadline for connections +// from external signing processes. +func RemoteSignerConnDeadline(deadline time.Duration) RemoteSignerOption { + return func(ss *RemoteSigner) { ss.connDeadline = deadline } +} + +// RemoteSignerConnRetries sets the amount of attempted retries to connect. +func RemoteSignerConnRetries(retries int) RemoteSignerOption { + return func(ss *RemoteSigner) { ss.connRetries = retries } +} + +// RemoteSigner implements PrivValidator. +// It responds to requests over a socket +type RemoteSigner struct { + cmn.BaseService + + addr string + chainID string + connDeadline time.Duration + connRetries int + privKey *crypto.PrivKeyEd25519 + privVal PrivValidator + + conn net.Conn +} + +// NewRemoteSigner returns an instance of +// RemoteSigner. +func NewRemoteSigner( + logger log.Logger, + chainID, socketAddr string, + privVal PrivValidator, + privKey *crypto.PrivKeyEd25519, +) *RemoteSigner { + rs := &RemoteSigner{ + addr: socketAddr, + chainID: chainID, + connDeadline: time.Second * defaultConnDeadlineSeconds, + connRetries: defaultDialRetries, + privKey: privKey, + privVal: privVal, + } + + rs.BaseService = *cmn.NewBaseService(logger, "RemoteSigner", rs) + + return rs +} + +// OnStart implements cmn.Service. +func (rs *RemoteSigner) OnStart() error { + conn, err := rs.connect() + if err != nil { + rs.Logger.Error("OnStart", "err", errors.Wrap(err, "connect")) + + return err + } + + go rs.handleConnection(conn) + + return nil +} + +// OnStop implements cmn.Service. +func (rs *RemoteSigner) OnStop() { + if rs.conn == nil { + return + } + + if err := rs.conn.Close(); err != nil { + rs.Logger.Error("OnStop", "err", errors.Wrap(err, "closing listener failed")) + } +} + +func (rs *RemoteSigner) connect() (net.Conn, error) { + retries := defaultDialRetries RETRY_LOOP: for retries > 0 { - if retries != defaultDialRetryMax { - time.Sleep(sc.connectTimeout) + // Don't sleep if it is the first retry. + if retries != defaultDialRetries { + time.Sleep(rs.connDeadline) } retries-- - conn, err := cmn.Connect(sc.addr) + conn, err := cmn.Connect(rs.addr) if err != nil { - sc.Logger.Error( - "sc connect", - "addr", sc.addr, + rs.Logger.Error( + "connect", + "addr", rs.addr, "err", errors.Wrap(err, "connection failed"), ) @@ -218,17 +405,17 @@ RETRY_LOOP: } if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil { - sc.Logger.Error( - "sc connect", + rs.Logger.Error( + "connect", "err", errors.Wrap(err, "setting connection timeout failed"), ) continue } - if sc.privKey != nil { - conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey.Wrap()) + if rs.privKey != nil { + conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey.Wrap()) if err != nil { - sc.Logger.Error( + rs.Logger.Error( "sc connect", "err", errors.Wrap(err, "encrypting connection failed"), ) @@ -243,118 +430,16 @@ RETRY_LOOP: return nil, ErrDialRetryMax } -//--------------------------------------------------------- - -// PrivValidatorSocketServer implements PrivValidator. -// It responds to requests over a socket -type PrivValidatorSocketServer struct { - cmn.BaseService - - proto, addr string - listener net.Listener - maxConnections int - privKey *crypto.PrivKeyEd25519 - - privVal PrivValidator - chainID string -} - -// NewPrivValidatorSocketServer returns an instance of -// PrivValidatorSocketServer. -func NewPrivValidatorSocketServer( - logger log.Logger, - chainID, socketAddr string, - maxConnections int, - privVal PrivValidator, - privKey *crypto.PrivKeyEd25519, -) *PrivValidatorSocketServer { - proto, addr := cmn.ProtocolAndAddress(socketAddr) - pvss := &PrivValidatorSocketServer{ - proto: proto, - addr: addr, - maxConnections: maxConnections, - privKey: privKey, - privVal: privVal, - chainID: chainID, - } - pvss.BaseService = *cmn.NewBaseService(logger, "privValidatorSocketServer", pvss) - return pvss -} - -// OnStart implements cmn.Service. -func (pvss *PrivValidatorSocketServer) OnStart() error { - ln, err := net.Listen(pvss.proto, pvss.addr) - if err != nil { - return err - } - - pvss.listener = netutil.LimitListener(ln, pvss.maxConnections) - - go pvss.acceptConnections() - - return nil -} - -// OnStop implements cmn.Service. -func (pvss *PrivValidatorSocketServer) OnStop() { - if pvss.listener == nil { - return - } - - if err := pvss.listener.Close(); err != nil { - pvss.Logger.Error("OnStop", "err", errors.Wrap(err, "closing listener failed")) - } -} - -func (pvss *PrivValidatorSocketServer) acceptConnections() { +func (rs *RemoteSigner) handleConnection(conn net.Conn) { for { - conn, err := pvss.listener.Accept() - if err != nil { - if !pvss.IsRunning() { - return // Ignore error from listener closing. - } - pvss.Logger.Error( - "acceptConnections", - "err", errors.Wrap(err, "failed to accept connection"), - ) - continue - } - - if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil { - pvss.Logger.Error( - "acceptConnetions", - "err", errors.Wrap(err, "setting connection timeout failed"), - ) - continue - } - - if pvss.privKey != nil { - conn, err = p2pconn.MakeSecretConnection(conn, pvss.privKey.Wrap()) - if err != nil { - pvss.Logger.Error( - "acceptConnections", - "err", errors.Wrap(err, "secret connection failed"), - ) - continue - } - } - - go pvss.handleConnection(conn) - } -} - -func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) { - defer conn.Close() - - for { - if !pvss.IsRunning() { + if !rs.IsRunning() { return // Ignore error from listener closing. } req, err := readMsg(conn) if err != nil { if err != io.EOF { - pvss.Logger.Error("handleConnection", "err", err) + rs.Logger.Error("handleConnection", "err", err) } return } @@ -365,29 +450,29 @@ func (pvss *PrivValidatorSocketServer) handleConnection(conn net.Conn) { case *PubKeyMsg: var p crypto.PubKey - p, err = pvss.privVal.PubKey() + p, err = rs.privVal.PubKey() res = &PubKeyMsg{p} case *SignVoteMsg: - err = pvss.privVal.SignVote(pvss.chainID, r.Vote) + err = rs.privVal.SignVote(rs.chainID, r.Vote) res = &SignVoteMsg{r.Vote} case *SignProposalMsg: - err = pvss.privVal.SignProposal(pvss.chainID, r.Proposal) + err = rs.privVal.SignProposal(rs.chainID, r.Proposal) res = &SignProposalMsg{r.Proposal} case *SignHeartbeatMsg: - err = pvss.privVal.SignHeartbeat(pvss.chainID, r.Heartbeat) + err = rs.privVal.SignHeartbeat(rs.chainID, r.Heartbeat) res = &SignHeartbeatMsg{r.Heartbeat} default: err = fmt.Errorf("unknown msg: %v", r) } if err != nil { - pvss.Logger.Error("handleConnection", "err", err) + rs.Logger.Error("handleConnection", "err", err) return } err = writeMsg(conn, res) if err != nil { - pvss.Logger.Error("handleConnection", "err", err) + rs.Logger.Error("handleConnection", "err", err) return } } @@ -442,6 +527,10 @@ func readMsg(r io.Reader) (PrivValidatorSocketMsg, error) { read := wire.ReadBinary(struct{ PrivValidatorSocketMsg }{}, r, 0, &n, &err) if err != nil { + if opErr, ok := err.(*net.OpError); ok { + return nil, errors.Wrapf(ErrConnTimeout, opErr.Addr.String()) + } + return nil, err } @@ -461,6 +550,9 @@ func writeMsg(w io.Writer, msg interface{}) error { // TODO(xla): This extra wrap should be gone with the sdk-2 update. wire.WriteBinary(struct{ PrivValidatorSocketMsg }{msg}, w, &n, &err) + if opErr, ok := err.(*net.OpError); ok { + return errors.Wrapf(ErrConnTimeout, opErr.Addr.String()) + } return err } diff --git a/types/priv_validator/socket_test.go b/types/priv_validator/socket_test.go index d3d81580..36f09f40 100644 --- a/types/priv_validator/socket_test.go +++ b/types/priv_validator/socket_test.go @@ -4,10 +4,12 @@ import ( "testing" "time" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" crypto "github.com/tendermint/go-crypto" + cmn "github.com/tendermint/tmlibs/common" "github.com/tendermint/tmlibs/log" "github.com/tendermint/tendermint/types" @@ -16,13 +18,13 @@ import ( func TestSocketClientAddress(t *testing.T) { var ( assert, require = assert.New(t), require.New(t) - chainID = "test-chain-secret" - sc, pvss = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) ) defer sc.Stop() - defer pvss.Stop() + defer rs.Stop() - serverAddr, err := pvss.privVal.Address() + serverAddr, err := rs.privVal.Address() require.NoError(err) clientAddr, err := sc.Address() @@ -38,16 +40,16 @@ func TestSocketClientAddress(t *testing.T) { func TestSocketClientPubKey(t *testing.T) { var ( assert, require = assert.New(t), require.New(t) - chainID = "test-chain-secret" - sc, pvss = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) ) defer sc.Stop() - defer pvss.Stop() + defer rs.Stop() clientKey, err := sc.PubKey() require.NoError(err) - privKey, err := pvss.privVal.PubKey() + privKey, err := rs.privVal.PubKey() require.NoError(err) assert.Equal(privKey, clientKey) @@ -59,17 +61,17 @@ func TestSocketClientPubKey(t *testing.T) { func TestSocketClientProposal(t *testing.T) { var ( assert, require = assert.New(t), require.New(t) - chainID = "test-chain-secret" - sc, pvss = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) ts = time.Now() privProposal = &types.Proposal{Timestamp: ts} clientProposal = &types.Proposal{Timestamp: ts} ) defer sc.Stop() - defer pvss.Stop() + defer rs.Stop() - require.NoError(pvss.privVal.SignProposal(chainID, privProposal)) + require.NoError(rs.privVal.SignProposal(chainID, privProposal)) require.NoError(sc.SignProposal(chainID, clientProposal)) assert.Equal(privProposal.Signature, clientProposal.Signature) } @@ -77,8 +79,8 @@ func TestSocketClientProposal(t *testing.T) { func TestSocketClientVote(t *testing.T) { var ( assert, require = assert.New(t), require.New(t) - chainID = "test-chain-secret" - sc, pvss = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) ts = time.Now() vType = types.VoteTypePrecommit @@ -86,9 +88,9 @@ func TestSocketClientVote(t *testing.T) { have = &types.Vote{Timestamp: ts, Type: vType} ) defer sc.Stop() - defer pvss.Stop() + defer rs.Stop() - require.NoError(pvss.privVal.SignVote(chainID, want)) + require.NoError(rs.privVal.SignVote(chainID, want)) require.NoError(sc.SignVote(chainID, have)) assert.Equal(want.Signature, have.Signature) } @@ -96,69 +98,129 @@ func TestSocketClientVote(t *testing.T) { func TestSocketClientHeartbeat(t *testing.T) { var ( assert, require = assert.New(t), require.New(t) - chainID = "test-chain-secret" - sc, pvss = testSetupSocketPair(t, chainID) + chainID = cmn.RandStr(12) + sc, rs = testSetupSocketPair(t, chainID) want = &types.Heartbeat{} have = &types.Heartbeat{} ) defer sc.Stop() - defer pvss.Stop() + defer rs.Stop() - require.NoError(pvss.privVal.SignHeartbeat(chainID, want)) + require.NoError(rs.privVal.SignHeartbeat(chainID, want)) require.NoError(sc.SignHeartbeat(chainID, have)) assert.Equal(want.Signature, have.Signature) } -func TestSocketClientConnectRetryMax(t *testing.T) { +func TestSocketClientDeadline(t *testing.T) { var ( - assert, _ = assert.New(t), require.New(t) - logger = log.TestingLogger() - clientPrivKey = crypto.GenPrivKeyEd25519() - sc = NewSocketClient( - logger, + assert, require = assert.New(t), require.New(t) + readyc = make(chan struct{}) + sc = NewSocketClient( + log.TestingLogger(), "127.0.0.1:0", - &clientPrivKey, + nil, ) ) defer sc.Stop() - SocketClientTimeout(time.Millisecond)(sc) + SocketClientConnDeadline(time.Millisecond)(sc) - assert.EqualError(sc.Start(), ErrDialRetryMax.Error()) + require.NoError(sc.listen()) + + go func(sc *SocketClient) { + require.NoError(sc.Start()) + assert.True(sc.IsRunning()) + + readyc <- struct{}{} + }(sc) + + _, err := cmn.Connect(sc.listener.Addr().String()) + require.NoError(err) + + <-readyc + + _, err = sc.PubKey() + assert.Equal(errors.Cause(err), ErrConnTimeout) } -func testSetupSocketPair(t *testing.T, chainID string) (*SocketClient, *PrivValidatorSocketServer) { +func TestSocketClientWait(t *testing.T) { + var ( + assert, _ = assert.New(t), require.New(t) + logger = log.TestingLogger() + privKey = crypto.GenPrivKeyEd25519() + sc = NewSocketClient( + logger, + "127.0.0.1:0", + &privKey, + ) + ) + defer sc.Stop() + + SocketClientConnWait(time.Millisecond)(sc) + + assert.EqualError(sc.Start(), ErrConnWaitTimeout.Error()) +} + +func TestRemoteSignerRetry(t *testing.T) { + var ( + assert, _ = assert.New(t), require.New(t) + privKey = crypto.GenPrivKeyEd25519() + rs = NewRemoteSigner( + log.TestingLogger(), + cmn.RandStr(12), + "127.0.0.1:0", + NewTestPrivValidator(types.GenSigner()), + &privKey, + ) + ) + defer rs.Stop() + + RemoteSignerConnDeadline(time.Millisecond)(rs) + RemoteSignerConnRetries(2)(rs) + + assert.EqualError(rs.Start(), ErrDialRetryMax.Error()) +} + +func testSetupSocketPair( + t *testing.T, + chainID string, +) (*SocketClient, *RemoteSigner) { var ( assert, require = assert.New(t), require.New(t) logger = log.TestingLogger() signer = types.GenSigner() clientPrivKey = crypto.GenPrivKeyEd25519() - serverPrivKey = crypto.GenPrivKeyEd25519() + remotePrivKey = crypto.GenPrivKeyEd25519() privVal = NewTestPrivValidator(signer) - pvss = NewPrivValidatorSocketServer( + readyc = make(chan struct{}) + sc = NewSocketClient( logger, - chainID, "127.0.0.1:0", - 1, - privVal, - &serverPrivKey, + &clientPrivKey, ) ) - err := pvss.Start() - require.NoError(err) - assert.True(pvss.IsRunning()) + require.NoError(sc.listen()) - sc := NewSocketClient( + go func(sc *SocketClient) { + require.NoError(sc.Start()) + assert.True(sc.IsRunning()) + + readyc <- struct{}{} + }(sc) + + rs := NewRemoteSigner( logger, - pvss.listener.Addr().String(), - &clientPrivKey, + chainID, + sc.listener.Addr().String(), + privVal, + &remotePrivKey, ) + require.NoError(rs.Start()) + assert.True(rs.IsRunning()) - err = sc.Start() - require.NoError(err) - assert.True(sc.IsRunning()) + <-readyc - return sc, pvss + return sc, rs }