From 8481c49c824e2d71f9c2d00ff5a8d1ee7ad045d0 Mon Sep 17 00:00:00 2001 From: Jae Kwon Date: Thu, 9 Nov 2017 17:42:32 -0500 Subject: [PATCH] CacheDB (#67) * Add CacheDB & SimpleMap * Generic memBatch; Fix cLevelDB tests * CacheWrap() for CacheDB and MemDB * Change Iterator to match LeviGo Iterator * Fixes from review * cacheWrapWriteMutex and some race fixes * Use tmlibs/common * NewCWWMutex is exposed. DB can be CacheWrap'd * Remove GetOK, not needed * Fsdb (#72) * Add FSDB * Review fixes from Anton * Review changes * Fixes from review --- .gitignore | 2 +- Makefile | 19 ++-- db/backend_test.go | 43 +++++++ db/c_level_db.go | 103 +++++++++++++---- db/c_level_db_test.go | 8 +- db/cache_db.go | 230 +++++++++++++++++++++++++++++++++++++ db/cache_db_test.go | 83 ++++++++++++++ db/common_test.go | 172 ++++++++++++++++++++++++++++ db/db.go | 59 +++++++++- db/fsdb.go | 231 ++++++++++++++++++++++++++++++++++++++ db/go_level_db.go | 127 ++++++++++++++------- db/go_level_db_test.go | 8 +- db/mem_batch.go | 50 +++++++++ db/mem_db.go | 182 +++++++++++++++++------------- db/mem_db_test.go | 2 +- db/stats.go | 7 ++ db/util.go | 82 ++++++++++++++ db/util_test.go | 209 ++++++++++++++++++++++++++++++++++ merkle/kvpairs.go | 48 ++++++++ merkle/simple_map.go | 26 +++++ merkle/simple_map_test.go | 47 ++++++++ merkle/simple_proof.go | 131 +++++++++++++++++++++ merkle/simple_tree.go | 184 ------------------------------ 23 files changed, 1699 insertions(+), 354 deletions(-) create mode 100644 db/backend_test.go create mode 100644 db/cache_db.go create mode 100644 db/cache_db_test.go create mode 100644 db/common_test.go create mode 100644 db/fsdb.go create mode 100644 db/mem_batch.go create mode 100644 db/stats.go create mode 100644 db/util.go create mode 100644 db/util_test.go create mode 100644 merkle/kvpairs.go create mode 100644 merkle/simple_map.go create mode 100644 merkle/simple_map_test.go create mode 100644 merkle/simple_proof.go diff --git a/.gitignore b/.gitignore index 
e0a06eaf..a2ebfde2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -*.swp +*.sw[opqr] vendor .glide diff --git a/Makefile b/Makefile index 25773ed3..a24306f3 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ all: test NOVENDOR = go list github.com/tendermint/tmlibs/... | grep -v /vendor/ test: - go test `glide novendor` + go test -tags gcc `glide novendor` get_vendor_deps: ensure_tools @rm -rf vendor/ @@ -32,20 +32,19 @@ metalinter_test: ensure_tools --enable=gas \ --enable=goconst \ --enable=gosimple \ - --enable=ineffassign \ - --enable=interfacer \ + --enable=ineffassign \ + --enable=interfacer \ --enable=megacheck \ - --enable=misspell \ - --enable=staticcheck \ + --enable=misspell \ + --enable=staticcheck \ --enable=safesql \ - --enable=structcheck \ - --enable=unconvert \ + --enable=structcheck \ + --enable=unconvert \ --enable=unused \ - --enable=varcheck \ + --enable=varcheck \ --enable=vetshadow \ --enable=vet \ ./... - #--enable=aligncheck \ #--enable=dupl \ #--enable=errcheck \ @@ -53,4 +52,4 @@ metalinter_test: ensure_tools #--enable=goimports \ #--enable=golint \ <== comments on anything exported #--enable=gotype \ - #--enable=unparam \ + #--enable=unparam \ diff --git a/db/backend_test.go b/db/backend_test.go new file mode 100644 index 00000000..b4ffecdc --- /dev/null +++ b/db/backend_test.go @@ -0,0 +1,43 @@ +package db + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + cmn "github.com/tendermint/tmlibs/common" +) + +func testBackend(t *testing.T, backend string) { + // Default + dir, dirname := cmn.Tempdir(fmt.Sprintf("test_backend_%s_", backend)) + defer dir.Close() + db := NewDB("testdb", backend, dirname) + require.Nil(t, db.Get([]byte(""))) + require.Nil(t, db.Get(nil)) + + // Set empty ("") + db.Set([]byte(""), []byte("")) + require.NotNil(t, db.Get([]byte(""))) + require.NotNil(t, db.Get(nil)) + require.Empty(t, db.Get([]byte(""))) + require.Empty(t, db.Get(nil)) + + // Set empty (nil) + 
db.Set([]byte(""), nil) + require.NotNil(t, db.Get([]byte(""))) + require.NotNil(t, db.Get(nil)) + require.Empty(t, db.Get([]byte(""))) + require.Empty(t, db.Get(nil)) + + // Delete + db.Delete([]byte("")) + require.Nil(t, db.Get([]byte(""))) + require.Nil(t, db.Get(nil)) +} + +func TestBackends(t *testing.T) { + testBackend(t, CLevelDBBackendStr) + testBackend(t, GoLevelDBBackendStr) + testBackend(t, MemDBBackendStr) +} diff --git a/db/c_level_db.go b/db/c_level_db.go index b1ae49a1..95651c0a 100644 --- a/db/c_level_db.go +++ b/db/c_level_db.go @@ -7,8 +7,6 @@ import ( "path" "github.com/jmhodges/levigo" - - . "github.com/tendermint/tmlibs/common" ) func init() { @@ -24,6 +22,8 @@ type CLevelDB struct { ro *levigo.ReadOptions wo *levigo.WriteOptions woSync *levigo.WriteOptions + + cwwMutex } func NewCLevelDB(name string, dir string) (*CLevelDB, error) { @@ -45,6 +45,8 @@ func NewCLevelDB(name string, dir string) (*CLevelDB, error) { ro: ro, wo: wo, woSync: woSync, + + cwwMutex: NewCWWMutex(), } return database, nil } @@ -52,7 +54,7 @@ func NewCLevelDB(name string, dir string) (*CLevelDB, error) { func (db *CLevelDB) Get(key []byte) []byte { res, err := db.db.Get(db.ro, key) if err != nil { - PanicCrisis(err) + panic(err) } return res } @@ -60,28 +62,28 @@ func (db *CLevelDB) Get(key []byte) []byte { func (db *CLevelDB) Set(key []byte, value []byte) { err := db.db.Put(db.wo, key, value) if err != nil { - PanicCrisis(err) + panic(err) } } func (db *CLevelDB) SetSync(key []byte, value []byte) { err := db.db.Put(db.woSync, key, value) if err != nil { - PanicCrisis(err) + panic(err) } } func (db *CLevelDB) Delete(key []byte) { err := db.db.Delete(db.wo, key) if err != nil { - PanicCrisis(err) + panic(err) } } func (db *CLevelDB) DeleteSync(key []byte) { err := db.db.Delete(db.woSync, key) if err != nil { - PanicCrisis(err) + panic(err) } } @@ -97,11 +99,11 @@ func (db *CLevelDB) Close() { } func (db *CLevelDB) Print() { - iter := db.db.NewIterator(db.ro) - defer 
iter.Close() - for iter.Seek(nil); iter.Valid(); iter.Next() { - key := iter.Key() - value := iter.Value() + itr := db.Iterator() + defer itr.Close() + for itr.Seek(nil); itr.Valid(); itr.Next() { + key := itr.Key() + value := itr.Value() fmt.Printf("[%X]:\t[%X]\n", key, value) } } @@ -112,25 +114,24 @@ func (db *CLevelDB) Stats() map[string]string { stats := make(map[string]string) for _, key := range keys { - str, err := db.db.GetProperty(key) - if err == nil { - stats[key] = str - } + str := db.db.PropertyValue(key) + stats[key] = str } return stats } -func (db *CLevelDB) Iterator() Iterator { - return db.db.NewIterator(nil, nil) +func (db *CLevelDB) CacheWrap() interface{} { + return NewCacheDB(db, db.GetWriteLockVersion()) } +//---------------------------------------- +// Batch + func (db *CLevelDB) NewBatch() Batch { batch := levigo.NewWriteBatch() return &cLevelDBBatch{db, batch} } -//-------------------------------------------------------------------------------- - type cLevelDBBatch struct { db *CLevelDB batch *levigo.WriteBatch @@ -147,6 +148,66 @@ func (mBatch *cLevelDBBatch) Delete(key []byte) { func (mBatch *cLevelDBBatch) Write() { err := mBatch.db.db.Write(mBatch.db.wo, mBatch.batch) if err != nil { - PanicCrisis(err) + panic(err) } } + +//---------------------------------------- +// Iterator + +func (db *CLevelDB) Iterator() Iterator { + itr := db.db.NewIterator(db.ro) + itr.Seek([]byte{0x00}) + return cLevelDBIterator{itr} +} + +type cLevelDBIterator struct { + itr *levigo.Iterator +} + +func (c cLevelDBIterator) Seek(key []byte) { + if key == nil { + key = []byte{0x00} + } + c.itr.Seek(key) +} + +func (c cLevelDBIterator) Valid() bool { + return c.itr.Valid() +} + +func (c cLevelDBIterator) Key() []byte { + if !c.itr.Valid() { + panic("cLevelDBIterator Key() called when invalid") + } + return c.itr.Key() +} + +func (c cLevelDBIterator) Value() []byte { + if !c.itr.Valid() { + panic("cLevelDBIterator Value() called when invalid") + } + return 
c.itr.Value() +} + +func (c cLevelDBIterator) Next() { + if !c.itr.Valid() { + panic("cLevelDBIterator Next() called when invalid") + } + c.itr.Next() +} + +func (c cLevelDBIterator) Prev() { + if !c.itr.Valid() { + panic("cLevelDBIterator Prev() called when invalid") + } + c.itr.Prev() +} + +func (c cLevelDBIterator) Close() { + c.itr.Close() +} + +func (c cLevelDBIterator) GetError() error { + return c.itr.GetError() +} diff --git a/db/c_level_db_test.go b/db/c_level_db_test.go index e7336cc5..86436233 100644 --- a/db/c_level_db_test.go +++ b/db/c_level_db_test.go @@ -7,7 +7,7 @@ import ( "fmt" "testing" - . "github.com/tendermint/tmlibs/common" + cmn "github.com/tendermint/tmlibs/common" ) func BenchmarkRandomReadsWrites2(b *testing.B) { @@ -18,7 +18,7 @@ func BenchmarkRandomReadsWrites2(b *testing.B) { for i := 0; i < int(numItems); i++ { internal[int64(i)] = int64(0) } - db, err := NewCLevelDB(Fmt("test_%x", RandStr(12)), "") + db, err := NewCLevelDB(cmn.Fmt("test_%x", cmn.RandStr(12)), "") if err != nil { b.Fatal(err.Error()) return @@ -30,7 +30,7 @@ func BenchmarkRandomReadsWrites2(b *testing.B) { for i := 0; i < b.N; i++ { // Write something { - idx := (int64(RandInt()) % numItems) + idx := (int64(cmn.RandInt()) % numItems) internal[idx] += 1 val := internal[idx] idxBytes := int642Bytes(int64(idx)) @@ -43,7 +43,7 @@ func BenchmarkRandomReadsWrites2(b *testing.B) { } // Read something { - idx := (int64(RandInt()) % numItems) + idx := (int64(cmn.RandInt()) % numItems) val := internal[idx] idxBytes := int642Bytes(int64(idx)) valBytes := db.Get(idxBytes) diff --git a/db/cache_db.go b/db/cache_db.go new file mode 100644 index 00000000..a41680c1 --- /dev/null +++ b/db/cache_db.go @@ -0,0 +1,230 @@ +package db + +import ( + "fmt" + "sort" + "sync" + "sync/atomic" +) + +// If value is nil but deleted is false, +// it means the parent doesn't have the key. 
+// (No need to delete upon Write()) +type cDBValue struct { + value []byte + deleted bool + dirty bool +} + +// CacheDB wraps an in-memory cache around an underlying DB. +type CacheDB struct { + mtx sync.Mutex + cache map[string]cDBValue + parent DB + lockVersion interface{} + + cwwMutex +} + +// Needed by MultiStore.CacheWrap(). +var _ atomicSetDeleter = (*CacheDB)(nil) + +// Users should typically not be required to call NewCacheDB directly, as the +// DB implementations here provide a .CacheWrap() function already. +// `lockVersion` is typically provided by parent.GetWriteLockVersion(). +func NewCacheDB(parent DB, lockVersion interface{}) *CacheDB { + db := &CacheDB{ + cache: make(map[string]cDBValue), + parent: parent, + lockVersion: lockVersion, + cwwMutex: NewCWWMutex(), + } + return db +} + +func (db *CacheDB) Get(key []byte) []byte { + db.mtx.Lock() + defer db.mtx.Unlock() + + dbValue, ok := db.cache[string(key)] + if !ok { + data := db.parent.Get(key) + dbValue = cDBValue{value: data, deleted: false, dirty: false} + db.cache[string(key)] = dbValue + } + return dbValue.value +} + +func (db *CacheDB) Set(key []byte, value []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.SetNoLock(key, value) +} + +func (db *CacheDB) SetSync(key []byte, value []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.SetNoLock(key, value) +} + +func (db *CacheDB) SetNoLock(key []byte, value []byte) { + db.cache[string(key)] = cDBValue{value: value, deleted: false, dirty: true} +} + +func (db *CacheDB) Delete(key []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.DeleteNoLock(key) +} + +func (db *CacheDB) DeleteSync(key []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.DeleteNoLock(key) +} + +func (db *CacheDB) DeleteNoLock(key []byte) { + db.cache[string(key)] = cDBValue{value: nil, deleted: true, dirty: true} +} + +func (db *CacheDB) Close() { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.parent.Close() +} + +func (db *CacheDB) Print() { + db.mtx.Lock() 
+ defer db.mtx.Unlock() + + fmt.Println("CacheDB\ncache:") + for key, value := range db.cache { + fmt.Printf("[%X]:\t[%v]\n", []byte(key), value) + } + fmt.Println("\nparent:") + db.parent.Print() +} + +func (db *CacheDB) Stats() map[string]string { + db.mtx.Lock() + defer db.mtx.Unlock() + + stats := make(map[string]string) + stats["cache.size"] = fmt.Sprintf("%d", len(db.cache)) + stats["cache.lock_version"] = fmt.Sprintf("%v", db.lockVersion) + mergeStats(db.parent.Stats(), stats, "parent.") + return stats +} + +func (db *CacheDB) Iterator() Iterator { + panic("CacheDB.Iterator() not yet supported") +} + +func (db *CacheDB) NewBatch() Batch { + return &memBatch{db, nil} +} + +// Implements `atomicSetDeleter` for Batch support. +func (db *CacheDB) Mutex() *sync.Mutex { + return &(db.mtx) +} + +// Write writes pending updates to the parent database and clears the cache. +func (db *CacheDB) Write() { + db.mtx.Lock() + defer db.mtx.Unlock() + + // Optional sanity check to ensure that CacheDB is valid + if parent, ok := db.parent.(WriteLocker); ok { + if parent.TryWriteLock(db.lockVersion) { + // All good! + } else { + panic("CacheDB.Write() failed. Did this CacheDB expire?") + } + } + + // We need a copy of all of the keys. + // Not the best, but probably not a bottleneck depending. + keys := make([]string, 0, len(db.cache)) + for key, dbValue := range db.cache { + if dbValue.dirty { + keys = append(keys, key) + } + } + sort.Strings(keys) + + batch := db.parent.NewBatch() + for _, key := range keys { + dbValue := db.cache[key] + if dbValue.deleted { + batch.Delete([]byte(key)) + } else if dbValue.value == nil { + // Skip, it already doesn't exist in parent. + } else { + batch.Set([]byte(key), dbValue.value) + } + } + batch.Write() + + // Clear the cache + db.cache = make(map[string]cDBValue) +} + +//---------------------------------------- +// To CacheWrap this CacheDB further. 
+ +func (db *CacheDB) CacheWrap() interface{} { + return NewCacheDB(db, db.GetWriteLockVersion()) +} + +// If the parent parent DB implements this, (e.g. such as a CacheDB parent to a +// CacheDB child), CacheDB will call `parent.TryWriteLock()` before attempting +// to write. +type WriteLocker interface { + GetWriteLockVersion() (lockVersion interface{}) + TryWriteLock(lockVersion interface{}) bool +} + +// Implements TryWriteLocker. Embed this in DB structs if desired. +type cwwMutex struct { + mtx sync.Mutex + // CONTRACT: reading/writing to `*written` should use `atomic.*`. + // CONTRACT: replacing `written` with another *int32 should use `.mtx`. + written *int32 +} + +func NewCWWMutex() cwwMutex { + return cwwMutex{ + written: new(int32), + } +} + +func (cww *cwwMutex) GetWriteLockVersion() interface{} { + cww.mtx.Lock() + defer cww.mtx.Unlock() + + // `written` works as a "version" object because it gets replaced upon + // successful TryWriteLock. + return cww.written +} + +func (cww *cwwMutex) TryWriteLock(version interface{}) bool { + cww.mtx.Lock() + defer cww.mtx.Unlock() + + if version != cww.written { + return false // wrong "WriteLockVersion" + } + if !atomic.CompareAndSwapInt32(cww.written, 0, 1) { + return false // already written + } + + // New "WriteLockVersion" + cww.written = new(int32) + return true +} diff --git a/db/cache_db_test.go b/db/cache_db_test.go new file mode 100644 index 00000000..1de08e3f --- /dev/null +++ b/db/cache_db_test.go @@ -0,0 +1,83 @@ +package db + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func bz(s string) []byte { return []byte(s) } + +func TestCacheDB(t *testing.T) { + mem := NewMemDB() + cdb := mem.CacheWrap().(*CacheDB) + + require.Empty(t, cdb.Get(bz("key1")), "Expected `key1` to be empty") + + mem.Set(bz("key1"), bz("value1")) + cdb.Set(bz("key1"), bz("value1")) + require.Equal(t, bz("value1"), cdb.Get(bz("key1"))) + + cdb.Set(bz("key1"), bz("value2")) + require.Equal(t, bz("value2"), 
cdb.Get(bz("key1"))) + require.Equal(t, bz("value1"), mem.Get(bz("key1"))) + + cdb.Write() + require.Equal(t, bz("value2"), mem.Get(bz("key1"))) + + require.Panics(t, func() { cdb.Write() }, "Expected second cdb.Write() to fail") + + cdb = mem.CacheWrap().(*CacheDB) + cdb.Delete(bz("key1")) + require.Empty(t, cdb.Get(bz("key1"))) + require.Equal(t, mem.Get(bz("key1")), bz("value2")) + + cdb.Write() + require.Empty(t, cdb.Get(bz("key1")), "Expected `key1` to be empty") + require.Empty(t, mem.Get(bz("key1")), "Expected `key1` to be empty") +} + +func TestCacheDBWriteLock(t *testing.T) { + mem := NewMemDB() + cdb := mem.CacheWrap().(*CacheDB) + require.NotPanics(t, func() { cdb.Write() }) + require.Panics(t, func() { cdb.Write() }) + cdb = mem.CacheWrap().(*CacheDB) + require.NotPanics(t, func() { cdb.Write() }) + require.Panics(t, func() { cdb.Write() }) +} + +func TestCacheDBWriteLockNested(t *testing.T) { + mem := NewMemDB() + cdb := mem.CacheWrap().(*CacheDB) + cdb2 := cdb.CacheWrap().(*CacheDB) + require.NotPanics(t, func() { cdb2.Write() }) + require.Panics(t, func() { cdb2.Write() }) + cdb2 = cdb.CacheWrap().(*CacheDB) + require.NotPanics(t, func() { cdb2.Write() }) + require.Panics(t, func() { cdb2.Write() }) +} + +func TestCacheDBNested(t *testing.T) { + mem := NewMemDB() + cdb := mem.CacheWrap().(*CacheDB) + cdb.Set(bz("key1"), bz("value1")) + + require.Empty(t, mem.Get(bz("key1"))) + require.Equal(t, bz("value1"), cdb.Get(bz("key1"))) + cdb2 := cdb.CacheWrap().(*CacheDB) + require.Equal(t, bz("value1"), cdb2.Get(bz("key1"))) + + cdb2.Set(bz("key1"), bz("VALUE2")) + require.Equal(t, []byte(nil), mem.Get(bz("key1"))) + require.Equal(t, bz("value1"), cdb.Get(bz("key1"))) + require.Equal(t, bz("VALUE2"), cdb2.Get(bz("key1"))) + + cdb2.Write() + require.Equal(t, []byte(nil), mem.Get(bz("key1"))) + require.Equal(t, bz("VALUE2"), cdb.Get(bz("key1"))) + + cdb.Write() + require.Equal(t, bz("VALUE2"), mem.Get(bz("key1"))) + +} diff --git a/db/common_test.go 
b/db/common_test.go new file mode 100644 index 00000000..505864c2 --- /dev/null +++ b/db/common_test.go @@ -0,0 +1,172 @@ +package db + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + cmn "github.com/tendermint/tmlibs/common" +) + +func checkValid(t *testing.T, itr Iterator, expected bool) { + valid := itr.Valid() + assert.Equal(t, expected, valid) +} + +func checkNext(t *testing.T, itr Iterator, expected bool) { + itr.Next() + valid := itr.Valid() + assert.Equal(t, expected, valid) +} + +func checkNextPanics(t *testing.T, itr Iterator) { + assert.Panics(t, func() { itr.Next() }, "checkNextPanics expected panic but didn't") +} + +func checkPrevPanics(t *testing.T, itr Iterator) { + assert.Panics(t, func() { itr.Prev() }, "checkPrevPanics expected panic but didn't") +} + +func checkPrev(t *testing.T, itr Iterator, expected bool) { + itr.Prev() + valid := itr.Valid() + assert.Equal(t, expected, valid) +} + +func checkItem(t *testing.T, itr Iterator, key []byte, value []byte) { + k, v := itr.Key(), itr.Value() + assert.Exactly(t, key, k) + assert.Exactly(t, value, v) +} + +func checkInvalid(t *testing.T, itr Iterator) { + checkValid(t, itr, false) + checkKeyPanics(t, itr) + checkValuePanics(t, itr) + checkNextPanics(t, itr) + checkPrevPanics(t, itr) +} + +func checkKeyPanics(t *testing.T, itr Iterator) { + assert.Panics(t, func() { itr.Key() }, "checkKeyPanics expected panic but didn't") +} + +func checkValuePanics(t *testing.T, itr Iterator) { + assert.Panics(t, func() { itr.Value() }, "checkValuePanics expected panic but didn't") +} + +func newTempDB(t *testing.T, backend string) (db DB) { + dir, dirname := cmn.Tempdir("test_go_iterator") + db = NewDB("testdb", backend, dirname) + dir.Close() + return db +} + +func TestDBIteratorSingleKey(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("1"), bz("value_1")) + itr := db.Iterator() 
+ + checkValid(t, itr, true) + checkNext(t, itr, false) + checkValid(t, itr, false) + checkNextPanics(t, itr) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} + +func TestDBIteratorTwoKeys(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("1"), bz("value_1")) + db.SetSync(bz("2"), bz("value_1")) + + { // Fail by calling Next too much + itr := db.Iterator() + checkValid(t, itr, true) + + for i := 0; i < 10; i++ { + checkNext(t, itr, true) + checkValid(t, itr, true) + + checkPrev(t, itr, true) + checkValid(t, itr, true) + } + + checkNext(t, itr, true) + checkValid(t, itr, true) + + checkNext(t, itr, false) + checkValid(t, itr, false) + + checkNextPanics(t, itr) + + // Once invalid... + checkInvalid(t, itr) + } + + { // Fail by calling Prev too much + itr := db.Iterator() + checkValid(t, itr, true) + + for i := 0; i < 10; i++ { + checkNext(t, itr, true) + checkValid(t, itr, true) + + checkPrev(t, itr, true) + checkValid(t, itr, true) + } + + checkPrev(t, itr, false) + checkValid(t, itr, false) + + checkPrevPanics(t, itr) + + // Once invalid... 
+ checkInvalid(t, itr) + } + }) + } +} + +func TestDBIteratorEmpty(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := db.Iterator() + + checkInvalid(t, itr) + }) + } +} + +func TestDBIteratorEmptySeek(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := db.Iterator() + itr.Seek(bz("1")) + + checkInvalid(t, itr) + }) + } +} + +func TestDBIteratorBadSeek(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("1"), bz("value_1")) + itr := db.Iterator() + itr.Seek(bz("2")) + + checkInvalid(t, itr) + }) + } +} diff --git a/db/db.go b/db/db.go index 8156c1e9..6c8bd480 100644 --- a/db/db.go +++ b/db/db.go @@ -3,7 +3,7 @@ package db import . "github.com/tendermint/tmlibs/common" type DB interface { - Get([]byte) []byte + Get([]byte) []byte // NOTE: returns nil iff never set or deleted. Set([]byte, []byte) SetSync([]byte, []byte) Delete([]byte) @@ -11,11 +11,15 @@ type DB interface { Close() NewBatch() Batch Iterator() Iterator - IteratorPrefix([]byte) Iterator // For debugging Print() + + // Stats returns a map of property values for all keys and the size of the cache. Stats() map[string]string + + // CacheWrap wraps the DB w/ a CacheDB. + CacheWrap() interface{} } type Batch interface { @@ -24,23 +28,66 @@ type Batch interface { Write() } -type Iterator interface { - Next() bool +/* + Usage: + for itr.Seek(mykey); itr.Valid(); itr.Next() { + k, v := itr.Key(); itr.Value() + .... + } +*/ +type Iterator interface { + + // Seek moves the iterator the position of the key given or, if the key + // doesn't exist, the next key that does exist in the database. If the key + // doesn't exist, and there is no next key, the Iterator becomes invalid. 
+ Seek(key []byte) + + // Valid returns false only when an Iterator has iterated past either the + // first or the last key in the database. + Valid() bool + + // Next moves the iterator to the next sequential key in the database, as + // defined by the Comparator in the ReadOptions used to create this Iterator. + // + // If Valid returns false, this method will panic. + Next() + + // Prev moves the iterator to the previous sequential key in the database, as + // defined by the Comparator in the ReadOptions used to create this Iterator. + // + // If Valid returns false, this method will panic. + Prev() + + // Key returns the key of the cursor. + // + // If Valid returns false, this method will panic. Key() []byte + + // Value returns the key of the cursor. + // + // If Valid returns false, this method will panic. Value() []byte - Release() - Error() error + // GetError returns an IteratorError from LevelDB if it had one during + // iteration. + // + // This method is safe to call when Valid returns false. + GetError() error + + // Close deallocates the given Iterator. + Close() } //----------------------------------------------------------------------------- +// Main entry const ( LevelDBBackendStr = "leveldb" // legacy, defaults to goleveldb. 
CLevelDBBackendStr = "cleveldb" GoLevelDBBackendStr = "goleveldb" MemDBBackendStr = "memdb" + FSDBBackendStr = "fsdb" // using the filesystem naively ) type dbCreator func(name string, dir string) (DB, error) diff --git a/db/fsdb.go b/db/fsdb.go new file mode 100644 index 00000000..65ac3c38 --- /dev/null +++ b/db/fsdb.go @@ -0,0 +1,231 @@ +package db + +import ( + "fmt" + "io/ioutil" + "net/url" + "os" + "path" + "path/filepath" + "sort" + "sync" + + "github.com/pkg/errors" +) + +const ( + keyPerm = os.FileMode(0600) + dirPerm = os.FileMode(0700) +) + +func init() { + registerDBCreator(FSDBBackendStr, func(name string, dir string) (DB, error) { + dbPath := filepath.Join(dir, name+".db") + return NewFSDB(dbPath), nil + }, false) +} + +// It's slow. +type FSDB struct { + mtx sync.Mutex + dir string + + cwwMutex +} + +func NewFSDB(dir string) *FSDB { + err := os.MkdirAll(dir, dirPerm) + if err != nil { + panic(errors.Wrap(err, "Creating FSDB dir "+dir)) + } + database := &FSDB{ + dir: dir, + cwwMutex: NewCWWMutex(), + } + return database +} + +func (db *FSDB) Get(key []byte) []byte { + db.mtx.Lock() + defer db.mtx.Unlock() + + path := db.nameToPath(key) + value, err := read(path) + if os.IsNotExist(err) { + return nil + } else if err != nil { + panic(errors.Wrap(err, fmt.Sprintf("Getting key %s (0x%X)", string(key), key))) + } + return value +} + +func (db *FSDB) Set(key []byte, value []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.SetNoLock(key, value) +} + +func (db *FSDB) SetSync(key []byte, value []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.SetNoLock(key, value) +} + +// NOTE: Implements atomicSetDeleter. 
+func (db *FSDB) SetNoLock(key []byte, value []byte) { + if value == nil { + value = []byte{} + } + path := db.nameToPath(key) + err := write(path, value) + if err != nil { + panic(errors.Wrap(err, fmt.Sprintf("Setting key %s (0x%X)", string(key), key))) + } +} + +func (db *FSDB) Delete(key []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.DeleteNoLock(key) +} + +func (db *FSDB) DeleteSync(key []byte) { + db.mtx.Lock() + defer db.mtx.Unlock() + + db.DeleteNoLock(key) +} + +// NOTE: Implements atomicSetDeleter. +func (db *FSDB) DeleteNoLock(key []byte) { + err := remove(db.nameToPath(key)) + if os.IsNotExist(err) { + return + } else if err != nil { + panic(errors.Wrap(err, fmt.Sprintf("Removing key %s (0x%X)", string(key), key))) + } +} + +func (db *FSDB) Close() { + // Nothing to do. +} + +func (db *FSDB) Print() { + db.mtx.Lock() + defer db.mtx.Unlock() + + panic("FSDB.Print not yet implemented") +} + +func (db *FSDB) Stats() map[string]string { + db.mtx.Lock() + defer db.mtx.Unlock() + + panic("FSDB.Stats not yet implemented") +} + +func (db *FSDB) NewBatch() Batch { + db.mtx.Lock() + defer db.mtx.Unlock() + + // Not sure we would ever want to try... + // It doesn't seem easy for general filesystems. + panic("FSDB.NewBatch not yet implemented") +} + +func (db *FSDB) Mutex() *sync.Mutex { + return &(db.mtx) +} + +func (db *FSDB) CacheWrap() interface{} { + return NewCacheDB(db, db.GetWriteLockVersion()) +} + +func (db *FSDB) Iterator() Iterator { + it := newMemDBIterator() + it.db = db + it.cur = 0 + + db.mtx.Lock() + defer db.mtx.Unlock() + + // We need a copy of all of the keys. + // Not the best, but probably not a bottleneck depending. + keys, err := list(db.dir) + if err != nil { + panic(errors.Wrap(err, fmt.Sprintf("Listing keys in %s", db.dir))) + } + sort.Strings(keys) + it.keys = keys + return it +} + +func (db *FSDB) nameToPath(name []byte) string { + n := url.PathEscape(string(name)) + return path.Join(db.dir, n) +} + +// Read some bytes from a file. 
+// CONTRACT: returns os errors directly without wrapping. +func read(path string) ([]byte, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + d, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + return d, nil +} + +// Write some bytes from a file. +// CONTRACT: returns os errors directly without wrapping. +func write(path string, d []byte) error { + f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, keyPerm) + if err != nil { + return err + } + defer f.Close() + _, err = f.Write(d) + if err != nil { + return err + } + err = f.Sync() + return err +} + +// Remove a file. +// CONTRACT: returns os errors directly without wrapping. +func remove(path string) error { + return os.Remove(path) +} + +// List files of a path. +// Paths will NOT include dir as the prefix. +// CONTRACT: returns os errors directly without wrapping. +func list(dirPath string) (paths []string, err error) { + dir, err := os.Open(dirPath) + if err != nil { + return nil, err + } + defer dir.Close() + + names, err := dir.Readdirnames(0) + if err != nil { + return nil, err + } + for i, name := range names { + n, err := url.PathUnescape(name) + if err != nil { + return nil, fmt.Errorf("Failed to unescape %s while listing", name) + } + names[i] = n + } + return names, nil +} diff --git a/db/go_level_db.go b/db/go_level_db.go index 4abd7611..d9cec519 100644 --- a/db/go_level_db.go +++ b/db/go_level_db.go @@ -8,7 +8,6 @@ import ( "github.com/syndtr/goleveldb/leveldb/errors" "github.com/syndtr/goleveldb/leveldb/iterator" "github.com/syndtr/goleveldb/leveldb/opt" - "github.com/syndtr/goleveldb/leveldb/util" . 
"github.com/tendermint/tmlibs/common" ) @@ -23,6 +22,8 @@ func init() { type GoLevelDB struct { db *leveldb.DB + + cwwMutex } func NewGoLevelDB(name string, dir string) (*GoLevelDB, error) { @@ -31,7 +32,10 @@ func NewGoLevelDB(name string, dir string) (*GoLevelDB, error) { if err != nil { return nil, err } - database := &GoLevelDB{db: db} + database := &GoLevelDB{ + db: db, + cwwMutex: NewCWWMutex(), + } return database, nil } @@ -117,55 +121,18 @@ func (db *GoLevelDB) Stats() map[string]string { return stats } -type goLevelDBIterator struct { - source iterator.Iterator +func (db *GoLevelDB) CacheWrap() interface{} { + return NewCacheDB(db, db.GetWriteLockVersion()) } -// Key returns a copy of the current key. -func (it *goLevelDBIterator) Key() []byte { - key := it.source.Key() - k := make([]byte, len(key)) - copy(k, key) - - return k -} - -// Value returns a copy of the current value. -func (it *goLevelDBIterator) Value() []byte { - val := it.source.Value() - v := make([]byte, len(val)) - copy(v, val) - - return v -} - -func (it *goLevelDBIterator) Error() error { - return it.source.Error() -} - -func (it *goLevelDBIterator) Next() bool { - return it.source.Next() -} - -func (it *goLevelDBIterator) Release() { - it.source.Release() -} - -func (db *GoLevelDB) Iterator() Iterator { - return &goLevelDBIterator{db.db.NewIterator(nil, nil)} -} - -func (db *GoLevelDB) IteratorPrefix(prefix []byte) Iterator { - return &goLevelDBIterator{db.db.NewIterator(util.BytesPrefix(prefix), nil)} -} +//---------------------------------------- +// Batch func (db *GoLevelDB) NewBatch() Batch { batch := new(leveldb.Batch) return &goLevelDBBatch{db, batch} } -//-------------------------------------------------------------------------------- - type goLevelDBBatch struct { db *GoLevelDB batch *leveldb.Batch @@ -185,3 +152,77 @@ func (mBatch *goLevelDBBatch) Write() { PanicCrisis(err) } } + +//---------------------------------------- +// Iterator + +func (db *GoLevelDB) Iterator() 
Iterator { + itr := &goLevelDBIterator{ + source: db.db.NewIterator(nil, nil), + } + itr.Seek(nil) + return itr +} + +type goLevelDBIterator struct { + source iterator.Iterator + invalid bool +} + +// Key returns a copy of the current key. +func (it *goLevelDBIterator) Key() []byte { + if !it.Valid() { + panic("goLevelDBIterator Key() called when invalid") + } + key := it.source.Key() + k := make([]byte, len(key)) + copy(k, key) + + return k +} + +// Value returns a copy of the current value. +func (it *goLevelDBIterator) Value() []byte { + if !it.Valid() { + panic("goLevelDBIterator Value() called when invalid") + } + val := it.source.Value() + v := make([]byte, len(val)) + copy(v, val) + + return v +} + +func (it *goLevelDBIterator) GetError() error { + return it.source.Error() +} + +func (it *goLevelDBIterator) Seek(key []byte) { + it.source.Seek(key) +} + +func (it *goLevelDBIterator) Valid() bool { + if it.invalid { + return false + } + it.invalid = !it.source.Valid() + return !it.invalid +} + +func (it *goLevelDBIterator) Next() { + if !it.Valid() { + panic("goLevelDBIterator Next() called when invalid") + } + it.source.Next() +} + +func (it *goLevelDBIterator) Prev() { + if !it.Valid() { + panic("goLevelDBIterator Prev() called when invalid") + } + it.source.Prev() +} + +func (it *goLevelDBIterator) Close() { + it.source.Release() +} diff --git a/db/go_level_db_test.go b/db/go_level_db_test.go index 2cd3192c..88b6730f 100644 --- a/db/go_level_db_test.go +++ b/db/go_level_db_test.go @@ -6,7 +6,7 @@ import ( "fmt" "testing" - . 
"github.com/tendermint/tmlibs/common" + cmn "github.com/tendermint/tmlibs/common" ) func BenchmarkRandomReadsWrites(b *testing.B) { @@ -17,7 +17,7 @@ func BenchmarkRandomReadsWrites(b *testing.B) { for i := 0; i < int(numItems); i++ { internal[int64(i)] = int64(0) } - db, err := NewGoLevelDB(Fmt("test_%x", RandStr(12)), "") + db, err := NewGoLevelDB(cmn.Fmt("test_%x", cmn.RandStr(12)), "") if err != nil { b.Fatal(err.Error()) return @@ -29,7 +29,7 @@ func BenchmarkRandomReadsWrites(b *testing.B) { for i := 0; i < b.N; i++ { // Write something { - idx := (int64(RandInt()) % numItems) + idx := (int64(cmn.RandInt()) % numItems) internal[idx] += 1 val := internal[idx] idxBytes := int642Bytes(int64(idx)) @@ -42,7 +42,7 @@ func BenchmarkRandomReadsWrites(b *testing.B) { } // Read something { - idx := (int64(RandInt()) % numItems) + idx := (int64(cmn.RandInt()) % numItems) val := internal[idx] idxBytes := int642Bytes(int64(idx)) valBytes := db.Get(idxBytes) diff --git a/db/mem_batch.go b/db/mem_batch.go new file mode 100644 index 00000000..7072d931 --- /dev/null +++ b/db/mem_batch.go @@ -0,0 +1,50 @@ +package db + +import "sync" + +type atomicSetDeleter interface { + Mutex() *sync.Mutex + SetNoLock(key, value []byte) + DeleteNoLock(key []byte) +} + +type memBatch struct { + db atomicSetDeleter + ops []operation +} + +type opType int + +const ( + opTypeSet opType = 1 + opTypeDelete opType = 2 +) + +type operation struct { + opType + key []byte + value []byte +} + +func (mBatch *memBatch) Set(key, value []byte) { + mBatch.ops = append(mBatch.ops, operation{opTypeSet, key, value}) +} + +func (mBatch *memBatch) Delete(key []byte) { + mBatch.ops = append(mBatch.ops, operation{opTypeDelete, key, nil}) +} + +func (mBatch *memBatch) Write() { + mtx := mBatch.db.Mutex() + mtx.Lock() + defer mtx.Unlock() + + for _, op := range mBatch.ops { + switch op.opType { + case opTypeSet: + mBatch.db.SetNoLock(op.key, op.value) + case opTypeDelete: + mBatch.db.DeleteNoLock(op.key) + } + } +} 
diff --git a/db/mem_db.go b/db/mem_db.go index 07742750..30697adc 100644 --- a/db/mem_db.go +++ b/db/mem_db.go @@ -1,8 +1,9 @@ package db import ( + "bytes" "fmt" - "strings" + "sort" "sync" ) @@ -15,40 +16,63 @@ func init() { type MemDB struct { mtx sync.Mutex db map[string][]byte + + cwwMutex } func NewMemDB() *MemDB { - database := &MemDB{db: make(map[string][]byte)} + database := &MemDB{ + db: make(map[string][]byte), + cwwMutex: NewCWWMutex(), + } return database } func (db *MemDB) Get(key []byte) []byte { db.mtx.Lock() defer db.mtx.Unlock() + return db.db[string(key)] } func (db *MemDB) Set(key []byte, value []byte) { db.mtx.Lock() defer db.mtx.Unlock() - db.db[string(key)] = value + + db.SetNoLock(key, value) } func (db *MemDB) SetSync(key []byte, value []byte) { db.mtx.Lock() defer db.mtx.Unlock() + + db.SetNoLock(key, value) +} + +// NOTE: Implements atomicSetDeleter +func (db *MemDB) SetNoLock(key []byte, value []byte) { + if value == nil { + value = []byte{} + } db.db[string(key)] = value } func (db *MemDB) Delete(key []byte) { db.mtx.Lock() defer db.mtx.Unlock() + delete(db.db, string(key)) } func (db *MemDB) DeleteSync(key []byte) { db.mtx.Lock() defer db.mtx.Unlock() + + delete(db.db, string(key)) +} + +// NOTE: Implements atomicSetDeleter +func (db *MemDB) DeleteNoLock(key []byte) { delete(db.db, string(key)) } @@ -63,115 +87,113 @@ func (db *MemDB) Close() { func (db *MemDB) Print() { db.mtx.Lock() defer db.mtx.Unlock() + for key, value := range db.db { fmt.Printf("[%X]:\t[%X]\n", []byte(key), value) } } func (db *MemDB) Stats() map[string]string { + db.mtx.Lock() + defer db.mtx.Unlock() + stats := make(map[string]string) stats["database.type"] = "memDB" + stats["database.size"] = fmt.Sprintf("%d", len(db.db)) return stats } +func (db *MemDB) NewBatch() Batch { + db.mtx.Lock() + defer db.mtx.Unlock() + + return &memBatch{db, nil} +} + +func (db *MemDB) Mutex() *sync.Mutex { + return &(db.mtx) +} + +func (db *MemDB) CacheWrap() interface{} { + return 
NewCacheDB(db, db.GetWriteLockVersion()) +} + +//---------------------------------------- + +func (db *MemDB) Iterator() Iterator { + it := newMemDBIterator() + it.db = db + it.cur = 0 + + db.mtx.Lock() + defer db.mtx.Unlock() + + // We need a copy of all of the keys. + // Not the best, but probably not a bottleneck depending. + for key, _ := range db.db { + it.keys = append(it.keys, key) + } + sort.Strings(it.keys) + return it +} + type memDBIterator struct { - last int + cur int keys []string - db *MemDB + db DB } func newMemDBIterator() *memDBIterator { return &memDBIterator{} } -func (it *memDBIterator) Next() bool { - if it.last >= len(it.keys)-1 { - return false +func (it *memDBIterator) Seek(key []byte) { + for i, ik := range it.keys { + it.cur = i + if bytes.Compare(key, []byte(ik)) <= 0 { + return + } } - it.last++ - return true + it.cur += 1 // If not found, becomes invalid. +} + +func (it *memDBIterator) Valid() bool { + return 0 <= it.cur && it.cur < len(it.keys) +} + +func (it *memDBIterator) Next() { + if !it.Valid() { + panic("memDBIterator Next() called when invalid") + } + it.cur++ +} + +func (it *memDBIterator) Prev() { + if !it.Valid() { + panic("memDBIterator Next() called when invalid") + } + it.cur-- } func (it *memDBIterator) Key() []byte { - return []byte(it.keys[it.last]) + if !it.Valid() { + panic("memDBIterator Key() called when invalid") + } + return []byte(it.keys[it.cur]) } func (it *memDBIterator) Value() []byte { + if !it.Valid() { + panic("memDBIterator Value() called when invalid") + } return it.db.Get(it.Key()) } -func (it *memDBIterator) Release() { +func (it *memDBIterator) Close() { it.db = nil it.keys = nil } -func (it *memDBIterator) Error() error { +func (it *memDBIterator) GetError() error { return nil } - -func (db *MemDB) Iterator() Iterator { - return db.IteratorPrefix([]byte{}) -} - -func (db *MemDB) IteratorPrefix(prefix []byte) Iterator { - it := newMemDBIterator() - it.db = db - it.last = -1 - - db.mtx.Lock() - defer 
db.mtx.Unlock() - - // unfortunately we need a copy of all of the keys - for key, _ := range db.db { - if strings.HasPrefix(key, string(prefix)) { - it.keys = append(it.keys, key) - } - } - return it -} - -func (db *MemDB) NewBatch() Batch { - return &memDBBatch{db, nil} -} - -//-------------------------------------------------------------------------------- - -type memDBBatch struct { - db *MemDB - ops []operation -} - -type opType int - -const ( - opTypeSet = 1 - opTypeDelete = 2 -) - -type operation struct { - opType - key []byte - value []byte -} - -func (mBatch *memDBBatch) Set(key, value []byte) { - mBatch.ops = append(mBatch.ops, operation{opTypeSet, key, value}) -} - -func (mBatch *memDBBatch) Delete(key []byte) { - mBatch.ops = append(mBatch.ops, operation{opTypeDelete, key, nil}) -} - -func (mBatch *memDBBatch) Write() { - mBatch.db.mtx.Lock() - defer mBatch.db.mtx.Unlock() - - for _, op := range mBatch.ops { - if op.opType == opTypeSet { - mBatch.db.db[string(op.key)] = op.value - } else if op.opType == opTypeDelete { - delete(mBatch.db.db, string(op.key)) - } - } - -} diff --git a/db/mem_db_test.go b/db/mem_db_test.go index 503e361f..b5c9167c 100644 --- a/db/mem_db_test.go +++ b/db/mem_db_test.go @@ -21,7 +21,7 @@ func TestMemDbIterator(t *testing.T) { iter := db.Iterator() i := 0 - for iter.Next() { + for ; iter.Valid(); iter.Next() { assert.Equal(t, db.Get(iter.Key()), iter.Value(), "values dont match for key") i += 1 } diff --git a/db/stats.go b/db/stats.go new file mode 100644 index 00000000..ef4b0dd0 --- /dev/null +++ b/db/stats.go @@ -0,0 +1,7 @@ +package db + +func mergeStats(src, dest map[string]string, prefix string) { + for key, value := range src { + dest[prefix+key] = value + } +} diff --git a/db/util.go b/db/util.go new file mode 100644 index 00000000..5f381a5b --- /dev/null +++ b/db/util.go @@ -0,0 +1,82 @@ +package db + +import "bytes" + +// A wrapper around itr that tries to keep the iterator +// within the bounds as defined by `prefix` 
+type prefixIterator struct { + itr Iterator + prefix []byte + invalid bool +} + +func (pi *prefixIterator) Seek(key []byte) { + if !bytes.HasPrefix(key, pi.prefix) { + pi.invalid = true + return + } + pi.itr.Seek(key) + pi.checkInvalid() +} + +func (pi *prefixIterator) checkInvalid() { + if !pi.itr.Valid() { + pi.invalid = true + } +} + +func (pi *prefixIterator) Valid() bool { + if pi.invalid { + return false + } + key := pi.itr.Key() + ok := bytes.HasPrefix(key, pi.prefix) + if !ok { + pi.invalid = true + return false + } + return true +} + +func (pi *prefixIterator) Next() { + if pi.invalid { + panic("prefixIterator Next() called when invalid") + } + pi.itr.Next() + pi.checkInvalid() +} + +func (pi *prefixIterator) Prev() { + if pi.invalid { + panic("prefixIterator Prev() called when invalid") + } + pi.itr.Prev() + pi.checkInvalid() +} + +func (pi *prefixIterator) Key() []byte { + if pi.invalid { + panic("prefixIterator Key() called when invalid") + } + return pi.itr.Key() +} + +func (pi *prefixIterator) Value() []byte { + if pi.invalid { + panic("prefixIterator Value() called when invalid") + } + return pi.itr.Value() +} + +func (pi *prefixIterator) Close() { pi.itr.Close() } +func (pi *prefixIterator) GetError() error { return pi.itr.GetError() } + +func IteratePrefix(db DB, prefix []byte) Iterator { + itr := db.Iterator() + pi := &prefixIterator{ + itr: itr, + prefix: prefix, + } + pi.Seek(prefix) + return pi +} diff --git a/db/util_test.go b/db/util_test.go new file mode 100644 index 00000000..55a41bf5 --- /dev/null +++ b/db/util_test.go @@ -0,0 +1,209 @@ +package db + +import ( + "fmt" + "testing" +) + +func TestPrefixIteratorNoMatchNil(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := IteratePrefix(db, []byte("2")) + + checkInvalid(t, itr) + }) + } +} + +func TestPrefixIteratorNoMatch1(t *testing.T) { + for backend, _ := range backends { + 
t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := IteratePrefix(db, []byte("2")) + db.SetSync(bz("1"), bz("value_1")) + + checkInvalid(t, itr) + }) + } +} + +func TestPrefixIteratorMatch2(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("2"), bz("value_2")) + itr := IteratePrefix(db, []byte("2")) + + checkValid(t, itr, true) + checkItem(t, itr, bz("2"), bz("value_2")) + checkNext(t, itr, false) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} + +func TestPrefixIteratorMatch3(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("3"), bz("value_3")) + itr := IteratePrefix(db, []byte("2")) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} + +// Search for a/1, fail by too much Next() +func TestPrefixIteratorMatches1N(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("a/1"), bz("value_1")) + db.SetSync(bz("a/3"), bz("value_3")) + itr := IteratePrefix(db, []byte("a/")) + itr.Seek(bz("a/1")) + + checkValid(t, itr, true) + checkItem(t, itr, bz("a/1"), bz("value_1")) + checkNext(t, itr, true) + checkItem(t, itr, bz("a/3"), bz("value_3")) + + // Bad! + checkNext(t, itr, false) + + // Once invalid... 
+ checkInvalid(t, itr) + }) + } +} + +// Search for a/1, fail by too much Prev() +func TestPrefixIteratorMatches1P(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("a/1"), bz("value_1")) + db.SetSync(bz("a/3"), bz("value_3")) + itr := IteratePrefix(db, []byte("a/")) + itr.Seek(bz("a/1")) + + checkValid(t, itr, true) + checkItem(t, itr, bz("a/1"), bz("value_1")) + checkNext(t, itr, true) + checkItem(t, itr, bz("a/3"), bz("value_3")) + checkPrev(t, itr, true) + checkItem(t, itr, bz("a/1"), bz("value_1")) + + // Bad! + checkPrev(t, itr, false) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} + +// Search for a/2, fail by too much Next() +func TestPrefixIteratorMatches2N(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("a/1"), bz("value_1")) + db.SetSync(bz("a/3"), bz("value_3")) + itr := IteratePrefix(db, []byte("a/")) + itr.Seek(bz("a/2")) + + checkValid(t, itr, true) + checkItem(t, itr, bz("a/3"), bz("value_3")) + checkPrev(t, itr, true) + checkItem(t, itr, bz("a/1"), bz("value_1")) + checkNext(t, itr, true) + checkItem(t, itr, bz("a/3"), bz("value_3")) + + // Bad! + checkNext(t, itr, false) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} + +// Search for a/2, fail by too much Prev() +func TestPrefixIteratorMatches2P(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("a/1"), bz("value_1")) + db.SetSync(bz("a/3"), bz("value_3")) + itr := IteratePrefix(db, []byte("a/")) + itr.Seek(bz("a/2")) + + checkValid(t, itr, true) + checkItem(t, itr, bz("a/3"), bz("value_3")) + checkPrev(t, itr, true) + checkItem(t, itr, bz("a/1"), bz("value_1")) + + // Bad! 
+ checkPrev(t, itr, false) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} + +// Search for a/3, fail by too much Next() +func TestPrefixIteratorMatches3N(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("a/1"), bz("value_1")) + db.SetSync(bz("a/3"), bz("value_3")) + itr := IteratePrefix(db, []byte("a/")) + itr.Seek(bz("a/3")) + + checkValid(t, itr, true) + checkItem(t, itr, bz("a/3"), bz("value_3")) + checkPrev(t, itr, true) + checkItem(t, itr, bz("a/1"), bz("value_1")) + checkNext(t, itr, true) + checkItem(t, itr, bz("a/3"), bz("value_3")) + + // Bad! + checkNext(t, itr, false) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} + +// Search for a/3, fail by too much Prev() +func TestPrefixIteratorMatches3P(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("a/1"), bz("value_1")) + db.SetSync(bz("a/3"), bz("value_3")) + itr := IteratePrefix(db, []byte("a/")) + itr.Seek(bz("a/3")) + + checkValid(t, itr, true) + checkItem(t, itr, bz("a/3"), bz("value_3")) + checkPrev(t, itr, true) + checkItem(t, itr, bz("a/1"), bz("value_1")) + + // Bad! + checkPrev(t, itr, false) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} diff --git a/merkle/kvpairs.go b/merkle/kvpairs.go new file mode 100644 index 00000000..3d67049f --- /dev/null +++ b/merkle/kvpairs.go @@ -0,0 +1,48 @@ +package merkle + +import ( + "sort" + + wire "github.com/tendermint/go-wire" + "golang.org/x/crypto/ripemd160" +) + +// NOTE: Behavior is undefined with dup keys. +type KVPair struct { + Key string + Value interface{} // Can be Hashable or not. 
+} + +func (kv KVPair) Hash() []byte { + hasher, n, err := ripemd160.New(), new(int), new(error) + wire.WriteString(kv.Key, hasher, n, err) + if kvH, ok := kv.Value.(Hashable); ok { + wire.WriteByteSlice(kvH.Hash(), hasher, n, err) + } else { + wire.WriteBinary(kv.Value, hasher, n, err) + } + if *err != nil { + panic(*err) + } + return hasher.Sum(nil) +} + +type KVPairs []KVPair + +func (kvps KVPairs) Len() int { return len(kvps) } +func (kvps KVPairs) Less(i, j int) bool { return kvps[i].Key < kvps[j].Key } +func (kvps KVPairs) Swap(i, j int) { kvps[i], kvps[j] = kvps[j], kvps[i] } +func (kvps KVPairs) Sort() { sort.Sort(kvps) } + +func MakeSortedKVPairs(m map[string]interface{}) []Hashable { + kvPairs := make([]KVPair, 0, len(m)) + for k, v := range m { + kvPairs = append(kvPairs, KVPair{k, v}) + } + KVPairs(kvPairs).Sort() + kvPairsH := make([]Hashable, 0, len(kvPairs)) + for _, kvp := range kvPairs { + kvPairsH = append(kvPairsH, kvp) + } + return kvPairsH +} diff --git a/merkle/simple_map.go b/merkle/simple_map.go new file mode 100644 index 00000000..43dce990 --- /dev/null +++ b/merkle/simple_map.go @@ -0,0 +1,26 @@ +package merkle + +type SimpleMap struct { + kvz KVPairs +} + +func NewSimpleMap() *SimpleMap { + return &SimpleMap{ + kvz: nil, + } +} + +func (sm *SimpleMap) Set(k string, o interface{}) { + sm.kvz = append(sm.kvz, KVPair{Key: k, Value: o}) +} + +// Merkle root hash of items sorted by key. +// NOTE: Behavior is undefined when key is duplicate. 
+func (sm *SimpleMap) Hash() []byte { + sm.kvz.Sort() + kvPairsH := make([]Hashable, 0, len(sm.kvz)) + for _, kvp := range sm.kvz { + kvPairsH = append(kvPairsH, kvp) + } + return SimpleHashFromHashables(kvPairsH) +} diff --git a/merkle/simple_map_test.go b/merkle/simple_map_test.go new file mode 100644 index 00000000..5eb21827 --- /dev/null +++ b/merkle/simple_map_test.go @@ -0,0 +1,47 @@ +package merkle + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSimpleMap(t *testing.T) { + { + db := NewSimpleMap() + db.Set("key1", "value1") + assert.Equal(t, "376bf717ebe3659a34f68edb833dfdcf4a2d3c10", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key1", "value2") + assert.Equal(t, "72fd3a7224674377952214cb10ef21753ec803eb", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key1", "value1") + db.Set("key2", "value2") + assert.Equal(t, "23a160bd4eea5b2fcc0755d722f9112a15999abc", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key2", "value2") // NOTE: out of order + db.Set("key1", "value1") + assert.Equal(t, "23a160bd4eea5b2fcc0755d722f9112a15999abc", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key1", "value1") + db.Set("key2", "value2") + db.Set("key3", "value3") + assert.Equal(t, "40df7416429148d03544cfafa86e1080615cd2bc", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } + { + db := NewSimpleMap() + db.Set("key2", "value2") // NOTE: out of order + db.Set("key1", "value1") + db.Set("key3", "value3") + assert.Equal(t, "40df7416429148d03544cfafa86e1080615cd2bc", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + } +} diff --git a/merkle/simple_proof.go b/merkle/simple_proof.go new file mode 100644 index 00000000..f75568fd --- /dev/null +++ b/merkle/simple_proof.go @@ -0,0 +1,131 @@ +package merkle + +import ( + "bytes" + "fmt" +) + +type SimpleProof 
struct { + Aunts [][]byte `json:"aunts"` // Hashes from leaf's sibling to a root's child. +} + +// proofs[0] is the proof for items[0]. +func SimpleProofsFromHashables(items []Hashable) (rootHash []byte, proofs []*SimpleProof) { + trails, rootSPN := trailsFromHashables(items) + rootHash = rootSPN.Hash + proofs = make([]*SimpleProof, len(items)) + for i, trail := range trails { + proofs[i] = &SimpleProof{ + Aunts: trail.FlattenAunts(), + } + } + return +} + +// Verify that leafHash is a leaf hash of the simple-merkle-tree +// which hashes to rootHash. +func (sp *SimpleProof) Verify(index int, total int, leafHash []byte, rootHash []byte) bool { + computedHash := computeHashFromAunts(index, total, leafHash, sp.Aunts) + return computedHash != nil && bytes.Equal(computedHash, rootHash) +} + +func (sp *SimpleProof) String() string { + return sp.StringIndented("") +} + +func (sp *SimpleProof) StringIndented(indent string) string { + return fmt.Sprintf(`SimpleProof{ +%s Aunts: %X +%s}`, + indent, sp.Aunts, + indent) +} + +// Use the leafHash and innerHashes to get the root merkle hash. +// If the length of the innerHashes slice isn't exactly correct, the result is nil. +func computeHashFromAunts(index int, total int, leafHash []byte, innerHashes [][]byte) []byte { + // Recursive impl. 
+ if index >= total { + return nil + } + switch total { + case 0: + panic("Cannot call computeHashFromAunts() with 0 total") + case 1: + if len(innerHashes) != 0 { + return nil + } + return leafHash + default: + if len(innerHashes) == 0 { + return nil + } + numLeft := (total + 1) / 2 + if index < numLeft { + leftHash := computeHashFromAunts(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + if leftHash == nil { + return nil + } + return SimpleHashFromTwoHashes(leftHash, innerHashes[len(innerHashes)-1]) + } else { + rightHash := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + if rightHash == nil { + return nil + } + return SimpleHashFromTwoHashes(innerHashes[len(innerHashes)-1], rightHash) + } + } +} + +// Helper structure to construct merkle proof. +// The node and the tree is thrown away afterwards. +// Exactly one of node.Left and node.Right is nil, unless node is the root, in which case both are nil. +// node.Parent.Hash = hash(node.Hash, node.Right.Hash) or +// hash(node.Left.Hash, node.Hash), depending on whether node is a left/right child. +type SimpleProofNode struct { + Hash []byte + Parent *SimpleProofNode + Left *SimpleProofNode // Left sibling (only one of Left,Right is set) + Right *SimpleProofNode // Right sibling (only one of Left,Right is set) +} + +// Starting from a leaf SimpleProofNode, FlattenAunts() will return +// the inner hashes for the item corresponding to the leaf. +func (spn *SimpleProofNode) FlattenAunts() [][]byte { + // Nonrecursive impl. + innerHashes := [][]byte{} + for spn != nil { + if spn.Left != nil { + innerHashes = append(innerHashes, spn.Left.Hash) + } else if spn.Right != nil { + innerHashes = append(innerHashes, spn.Right.Hash) + } else { + break + } + spn = spn.Parent + } + return innerHashes +} + +// trails[0].Hash is the leaf hash for items[0]. +// trails[i].Parent.Parent....Parent == root for all i. 
+func trailsFromHashables(items []Hashable) (trails []*SimpleProofNode, root *SimpleProofNode) { + // Recursive impl. + switch len(items) { + case 0: + return nil, nil + case 1: + trail := &SimpleProofNode{items[0].Hash(), nil, nil, nil} + return []*SimpleProofNode{trail}, trail + default: + lefts, leftRoot := trailsFromHashables(items[:(len(items)+1)/2]) + rights, rightRoot := trailsFromHashables(items[(len(items)+1)/2:]) + rootHash := SimpleHashFromTwoHashes(leftRoot.Hash, rightRoot.Hash) + root := &SimpleProofNode{rootHash, nil, nil, nil} + leftRoot.Parent = root + leftRoot.Right = rightRoot + rightRoot.Parent = root + rightRoot.Left = leftRoot + return append(lefts, rights...), root + } +} diff --git a/merkle/simple_tree.go b/merkle/simple_tree.go index 8106246d..d64082b4 100644 --- a/merkle/simple_tree.go +++ b/merkle/simple_tree.go @@ -25,10 +25,6 @@ For larger datasets, use IAVLTree. package merkle import ( - "bytes" - "fmt" - "sort" - "golang.org/x/crypto/ripemd160" "github.com/tendermint/go-wire" @@ -95,183 +91,3 @@ func SimpleHashFromMap(m map[string]interface{}) []byte { kpPairsH := MakeSortedKVPairs(m) return SimpleHashFromHashables(kpPairsH) } - -//-------------------------------------------------------------------------------- - -/* Convenience struct for key-value pairs. -A list of KVPairs is hashed via `SimpleHashFromHashables`. -NOTE: Each `Value` is encoded for hashing without extra type information, -so the user is presumed to be aware of the Value types. 
-*/ -type KVPair struct { - Key string - Value interface{} -} - -func (kv KVPair) Hash() []byte { - hasher, n, err := ripemd160.New(), new(int), new(error) - wire.WriteString(kv.Key, hasher, n, err) - if kvH, ok := kv.Value.(Hashable); ok { - wire.WriteByteSlice(kvH.Hash(), hasher, n, err) - } else { - wire.WriteBinary(kv.Value, hasher, n, err) - } - if *err != nil { - PanicSanity(*err) - } - return hasher.Sum(nil) -} - -type KVPairs []KVPair - -func (kvps KVPairs) Len() int { return len(kvps) } -func (kvps KVPairs) Less(i, j int) bool { return kvps[i].Key < kvps[j].Key } -func (kvps KVPairs) Swap(i, j int) { kvps[i], kvps[j] = kvps[j], kvps[i] } -func (kvps KVPairs) Sort() { sort.Sort(kvps) } - -func MakeSortedKVPairs(m map[string]interface{}) []Hashable { - kvPairs := []KVPair{} - for k, v := range m { - kvPairs = append(kvPairs, KVPair{k, v}) - } - KVPairs(kvPairs).Sort() - kvPairsH := []Hashable{} - for _, kvp := range kvPairs { - kvPairsH = append(kvPairsH, kvp) - } - return kvPairsH -} - -//-------------------------------------------------------------------------------- - -type SimpleProof struct { - Aunts [][]byte `json:"aunts"` // Hashes from leaf's sibling to a root's child. -} - -// proofs[0] is the proof for items[0]. -func SimpleProofsFromHashables(items []Hashable) (rootHash []byte, proofs []*SimpleProof) { - trails, rootSPN := trailsFromHashables(items) - rootHash = rootSPN.Hash - proofs = make([]*SimpleProof, len(items)) - for i, trail := range trails { - proofs[i] = &SimpleProof{ - Aunts: trail.FlattenAunts(), - } - } - return -} - -// Verify that leafHash is a leaf hash of the simple-merkle-tree -// which hashes to rootHash. 
-func (sp *SimpleProof) Verify(index int, total int, leafHash []byte, rootHash []byte) bool { - computedHash := computeHashFromAunts(index, total, leafHash, sp.Aunts) - if computedHash == nil { - return false - } - if !bytes.Equal(computedHash, rootHash) { - return false - } - return true -} - -func (sp *SimpleProof) String() string { - return sp.StringIndented("") -} - -func (sp *SimpleProof) StringIndented(indent string) string { - return fmt.Sprintf(`SimpleProof{ -%s Aunts: %X -%s}`, - indent, sp.Aunts, - indent) -} - -// Use the leafHash and innerHashes to get the root merkle hash. -// If the length of the innerHashes slice isn't exactly correct, the result is nil. -func computeHashFromAunts(index int, total int, leafHash []byte, innerHashes [][]byte) []byte { - // Recursive impl. - if index >= total { - return nil - } - switch total { - case 0: - PanicSanity("Cannot call computeHashFromAunts() with 0 total") - return nil - case 1: - if len(innerHashes) != 0 { - return nil - } - return leafHash - default: - if len(innerHashes) == 0 { - return nil - } - numLeft := (total + 1) / 2 - if index < numLeft { - leftHash := computeHashFromAunts(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1]) - if leftHash == nil { - return nil - } - return SimpleHashFromTwoHashes(leftHash, innerHashes[len(innerHashes)-1]) - } else { - rightHash := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1]) - if rightHash == nil { - return nil - } - return SimpleHashFromTwoHashes(innerHashes[len(innerHashes)-1], rightHash) - } - } -} - -// Helper structure to construct merkle proof. -// The node and the tree is thrown away afterwards. -// Exactly one of node.Left and node.Right is nil, unless node is the root, in which case both are nil. -// node.Parent.Hash = hash(node.Hash, node.Right.Hash) or -// hash(node.Left.Hash, node.Hash), depending on whether node is a left/right child. 
-type SimpleProofNode struct { - Hash []byte - Parent *SimpleProofNode - Left *SimpleProofNode // Left sibling (only one of Left,Right is set) - Right *SimpleProofNode // Right sibling (only one of Left,Right is set) -} - -// Starting from a leaf SimpleProofNode, FlattenAunts() will return -// the inner hashes for the item corresponding to the leaf. -func (spn *SimpleProofNode) FlattenAunts() [][]byte { - // Nonrecursive impl. - innerHashes := [][]byte{} - for spn != nil { - if spn.Left != nil { - innerHashes = append(innerHashes, spn.Left.Hash) - } else if spn.Right != nil { - innerHashes = append(innerHashes, spn.Right.Hash) - } else { - break - } - spn = spn.Parent - } - return innerHashes -} - -// trails[0].Hash is the leaf hash for items[0]. -// trails[i].Parent.Parent....Parent == root for all i. -func trailsFromHashables(items []Hashable) (trails []*SimpleProofNode, root *SimpleProofNode) { - // Recursive impl. - switch len(items) { - case 0: - return nil, nil - case 1: - trail := &SimpleProofNode{items[0].Hash(), nil, nil, nil} - return []*SimpleProofNode{trail}, trail - default: - lefts, leftRoot := trailsFromHashables(items[:(len(items)+1)/2]) - rights, rightRoot := trailsFromHashables(items[(len(items)+1)/2:]) - rootHash := SimpleHashFromTwoHashes(leftRoot.Hash, rightRoot.Hash) - root := &SimpleProofNode{rootHash, nil, nil, nil} - leftRoot.Parent = root - leftRoot.Right = rightRoot - rightRoot.Parent = root - rightRoot.Left = leftRoot - return append(lefts, rights...), root - } -}