diff --git a/mempool/mempool.go b/mempool/mempool.go index 65cd5535..5e52989a 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -25,12 +25,12 @@ import ( // PreCheckFunc is an optional filter executed before CheckTx and rejects // transaction if false is returned. An example would be to ensure that a // transaction doesn't exceeded the block size. -type PreCheckFunc func(types.Tx) bool +type PreCheckFunc func(types.Tx) error // PostCheckFunc is an optional filter executed after CheckTx and rejects // transaction if false is returned. An example would be to ensure a // transaction doesn't require more gas than available for the block. -type PostCheckFunc func(types.Tx, *abci.ResponseCheckTx) bool +type PostCheckFunc func(types.Tx, *abci.ResponseCheckTx) error /* @@ -68,24 +68,48 @@ var ( ErrMempoolIsFull = errors.New("Mempool is full") ) +// ErrPreCheck is returned when tx is too big +type ErrPreCheck struct { + Reason error +} + +func (e ErrPreCheck) Error() string { + return e.Reason.Error() +} + +// IsPreCheckError returns true if err is due to pre check failure. +func IsPreCheckError(err error) bool { + _, ok := err.(ErrPreCheck) + return ok +} + // PreCheckAminoMaxBytes checks that the size of the transaction plus the amino // overhead is smaller or equal to the expected maxBytes. func PreCheckAminoMaxBytes(maxBytes int64) PreCheckFunc { - return func(tx types.Tx) bool { + return func(tx types.Tx) error { // We have to account for the amino overhead in the tx size as well aminoOverhead := amino.UvarintSize(uint64(len(tx))) - return int64(len(tx)+aminoOverhead) <= maxBytes + txSize := int64(len(tx) + aminoOverhead) + if txSize > maxBytes { + return fmt.Errorf("Tx size (including amino overhead) is too big: %d, max: %d", + txSize, maxBytes) + } + return nil } } // PostCheckMaxGas checks that the wanted gas is smaller or equal to the passed -// maxGas. Returns true if maxGas is -1. +// maxGas. Returns nil if maxGas is -1. func PostCheckMaxGas(maxGas int64) PostCheckFunc { - return func(tx types.Tx, res *abci.ResponseCheckTx) bool { + return func(tx types.Tx, res *abci.ResponseCheckTx) error { if maxGas == -1 { - return true + return nil } - return res.GasWanted <= maxGas + if res.GasWanted > maxGas { + return fmt.Errorf("gas wanted %d is greater than max gas %d", + res.GasWanted, maxGas) + } + return nil } } @@ -285,8 +309,10 @@ func (mem *Mempool) CheckTx(tx types.Tx, cb func(*abci.Response)) (err error) { return ErrMempoolIsFull } - if mem.preCheck != nil && !mem.preCheck(tx) { - return + if mem.preCheck != nil { + if err := mem.preCheck(tx); err != nil { + return ErrPreCheck{err} + } } // CACHE @@ -346,7 +372,13 @@ func (mem *Mempool) resCbNormal(req *abci.Request, res *abci.Response) { tx: tx, } mem.txs.PushBack(memTx) - mem.logger.Info("Added good transaction", "tx", TxID(tx), "res", r, "total", mem.Size()) + mem.logger.Info("Added good transaction", + "tx", TxID(tx), + "res", r, + "height", memTx.height, + "total", mem.Size(), + "counter", memTx.counter, + ) mem.metrics.TxSizeBytes.Observe(float64(len(tx))) mem.notifyTxsAvailable() } else { @@ -566,7 +598,13 @@ func (mem *Mempool) recheckTxs(goodTxs []types.Tx) { } func (mem *Mempool) isPostCheckPass(tx types.Tx, r *abci.ResponseCheckTx) bool { - return mem.postCheck == nil || mem.postCheck(tx, r) + if mem.postCheck == nil { + return true + } + if err := mem.postCheck(tx, r); err != nil { + return false + } + return true } //-------------------------------------------------------------------------------- diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 5aabd00e..44917afb 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -14,7 +14,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - amino "github.com/tendermint/go-amino" "github.com/tendermint/tendermint/abci/example/counter" "github.com/tendermint/tendermint/abci/example/kvstore" abci "github.com/tendermint/tendermint/abci/types" @@ -66,7 +65,10 @@ func checkTxs(t *testing.T, mempool *Mempool, count int) types.Txs { t.Error(err) } if err := mempool.CheckTx(txBytes, nil); err != nil { - t.Fatalf("Error after CheckTx: %v", err) + if IsPreCheckError(err) { + continue + } + t.Fatalf("CheckTx failed: %v while checking #%d tx", err, i) } } return txs @@ -126,47 +128,29 @@ func TestMempoolFilters(t *testing.T) { mempool := newMempoolWithApp(cc) emptyTxArr := []types.Tx{[]byte{}} - nopPreFilter := func(tx types.Tx) bool { return true } - nopPostFilter := func(tx types.Tx, res *abci.ResponseCheckTx) bool { return true } - - // This is the same filter we expect to be used within node/node.go and state/execution.go - nBytePreFilter := func(n int) func(tx types.Tx) bool { - return func(tx types.Tx) bool { - // We have to account for the amino overhead in the tx size as well - aminoOverhead := amino.UvarintSize(uint64(len(tx))) - return (len(tx) + aminoOverhead) <= n - } - } - - nGasPostFilter := func(n int64) func(tx types.Tx, res *abci.ResponseCheckTx) bool { - return func(tx types.Tx, res *abci.ResponseCheckTx) bool { - if n == -1 { - return true - } - return res.GasWanted <= n - } - } + nopPreFilter := func(tx types.Tx) error { return nil } + nopPostFilter := func(tx types.Tx, res *abci.ResponseCheckTx) error { return nil } // each table driven test creates numTxsToCreate txs with checkTx, and at the end clears all remaining txs. // each tx has 20 bytes + amino overhead = 21 bytes, 1 gas tests := []struct { numTxsToCreate int - preFilter func(tx types.Tx) bool - postFilter func(tx types.Tx, res *abci.ResponseCheckTx) bool + preFilter PreCheckFunc + postFilter PostCheckFunc expectedNumTxs int }{ {10, nopPreFilter, nopPostFilter, 10}, - {10, nBytePreFilter(10), nopPostFilter, 0}, - {10, nBytePreFilter(20), nopPostFilter, 0}, - {10, nBytePreFilter(21), nopPostFilter, 10}, - {10, nopPreFilter, nGasPostFilter(-1), 10}, - {10, nopPreFilter, nGasPostFilter(0), 0}, - {10, nopPreFilter, nGasPostFilter(1), 10}, - {10, nopPreFilter, nGasPostFilter(3000), 10}, - {10, nBytePreFilter(10), nGasPostFilter(20), 0}, - {10, nBytePreFilter(30), nGasPostFilter(20), 10}, - {10, nBytePreFilter(21), nGasPostFilter(1), 10}, - {10, nBytePreFilter(21), nGasPostFilter(0), 0}, + {10, PreCheckAminoMaxBytes(10), nopPostFilter, 0}, + {10, PreCheckAminoMaxBytes(20), nopPostFilter, 0}, + {10, PreCheckAminoMaxBytes(21), nopPostFilter, 10}, + {10, nopPreFilter, PostCheckMaxGas(-1), 10}, + {10, nopPreFilter, PostCheckMaxGas(0), 0}, + {10, nopPreFilter, PostCheckMaxGas(1), 10}, + {10, nopPreFilter, PostCheckMaxGas(3000), 10}, + {10, PreCheckAminoMaxBytes(10), PostCheckMaxGas(20), 0}, + {10, PreCheckAminoMaxBytes(30), PostCheckMaxGas(20), 10}, + {10, PreCheckAminoMaxBytes(21), PostCheckMaxGas(1), 10}, + {10, PreCheckAminoMaxBytes(21), PostCheckMaxGas(0), 0}, } for tcIndex, tt := range tests { mempool.Update(1, emptyTxArr, tt.preFilter, tt.postFilter)