From 29471d75cb50eb4cea5878b8bd1be25e8150564c Mon Sep 17 00:00:00 2001 From: Emmanuel Odeke Date: Wed, 13 Dec 2017 22:53:02 -0700 Subject: [PATCH] common: no more relying on math/rand.DefaultSource Fixes https://github.com/tendermint/tmlibs/issues/99 Updates https://github.com/tendermint/tendermint/issues/973 Removed usages of math/rand.DefaultSource in favour of our own source that's seeded with a completely random source and is safe for use in concurrent in multiple goroutines. Also extend some functionality that the stdlib exposes such as * RandPerm * RandIntn * RandInt31 * RandInt63 Also added an integration test whose purpose is to be run as a consistency check to ensure that our results never repeat hence that our internal PRNG is uniquely seeded each time. This integration test can be triggered by setting environment variable: `TENDERMINT_INTEGRATION_TESTS=true` for example ```shell TENDERMINT_INTEGRATION_TESTS=true go test ``` --- common/bit_array.go | 7 ++- common/random.go | 89 +++++++++++++++++++++++++--------- common/random_test.go | 108 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 178 insertions(+), 26 deletions(-) create mode 100644 common/random_test.go diff --git a/common/bit_array.go b/common/bit_array.go index 5590fe61..848763b4 100644 --- a/common/bit_array.go +++ b/common/bit_array.go @@ -3,7 +3,6 @@ package common import ( "encoding/binary" "fmt" - "math/rand" "strings" "sync" ) @@ -212,12 +211,12 @@ func (bA *BitArray) PickRandom() (int, bool) { if length == 0 { return 0, false } - randElemStart := rand.Intn(length) + randElemStart := RandIntn(length) for i := 0; i < length; i++ { elemIdx := ((i + randElemStart) % length) if elemIdx < length-1 { if bA.Elems[elemIdx] > 0 { - randBitStart := rand.Intn(64) + randBitStart := RandIntn(64) for j := 0; j < 64; j++ { bitIdx := ((j + randBitStart) % 64) if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 { @@ -232,7 +231,7 @@ func (bA *BitArray) PickRandom() (int, bool) { if elemBits == 0 { elemBits = 64 } - randBitStart := rand.Intn(elemBits) + randBitStart := RandIntn(elemBits) for j := 0; j < elemBits; j++ { bitIdx := ((j + randBitStart) % elemBits) if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 { diff --git a/common/random.go b/common/random.go index 73bd1635..f0d169e0 100644 --- a/common/random.go +++ b/common/random.go @@ -3,6 +3,7 @@ package common import ( crand "crypto/rand" "math/rand" + "sync" "time" ) @@ -10,6 +11,11 @@ const ( strChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" // 62 characters ) +var rng struct { + sync.Mutex + *rand.Rand +} + func init() { b := cRandBytes(8) var seed uint64 @@ -17,7 +23,7 @@ func init() { seed |= uint64(b[i]) seed <<= 8 } - rand.Seed(int64(seed)) + rng.Rand = rand.New(rand.NewSource(int64(seed))) } // Constructs an alphanumeric string of given length. @@ -25,7 +31,7 @@ func RandStr(length int) string { chars := []byte{} MAIN_LOOP: for { - val := rand.Int63() + val := rng.Int63() for i := 0; i < 10; i++ { v := int(val & 0x3f) // rightmost 6 bits if v >= 62 { // only 62 characters in strChars @@ -45,72 +51,98 @@ MAIN_LOOP: } func RandUint16() uint16 { - return uint16(rand.Uint32() & (1<<16 - 1)) + return uint16(RandUint32() & (1<<16 - 1)) } func RandUint32() uint32 { - return rand.Uint32() + rng.Lock() + u32 := rng.Uint32() + rng.Unlock() + return u32 } func RandUint64() uint64 { - return uint64(rand.Uint32())<<32 + uint64(rand.Uint32()) + return uint64(RandUint32())<<32 + uint64(RandUint32()) } func RandUint() uint { - return uint(rand.Int()) + rng.Lock() + i := rng.Int() + rng.Unlock() + return uint(i) } func RandInt16() int16 { - return int16(rand.Uint32() & (1<<16 - 1)) + return int16(RandUint32() & (1<<16 - 1)) } func RandInt32() int32 { - return int32(rand.Uint32()) + return int32(RandUint32()) } func RandInt64() int64 { - return int64(rand.Uint32())<<32 + int64(rand.Uint32()) + return int64(RandUint64()) } func RandInt() int { - return rand.Int() + rng.Lock() + i := rng.Int() + rng.Unlock() + return i +} + +func RandInt31() int32 { + rng.Lock() + i31 := rng.Int31() + rng.Unlock() + return i31 +} + +func RandInt63() int64 { + rng.Lock() + i63 := rng.Int63() + rng.Unlock() + return i63 } // Distributed pseudo-exponentially to test for various cases func RandUint16Exp() uint16 { - bits := rand.Uint32() % 16 + bits := RandUint32() % 16 if bits == 0 { return 0 } n := uint16(1 << (bits - 1)) - n += uint16(rand.Int31()) & ((1 << (bits - 1)) - 1) + n += uint16(RandInt31()) & ((1 << (bits - 1)) - 1) return n } // Distributed pseudo-exponentially to test for various cases func RandUint32Exp() uint32 { - bits := rand.Uint32() % 32 + bits := RandUint32() % 32 if bits == 0 { return 0 } n := uint32(1 << (bits - 1)) - n += uint32(rand.Int31()) & ((1 << (bits - 1)) - 1) + n += uint32(RandInt31()) & ((1 << (bits - 1)) - 1) return n } // Distributed pseudo-exponentially to test for various cases func RandUint64Exp() uint64 { - bits := rand.Uint32() % 64 + bits := RandUint32() % 64 if bits == 0 { return 0 } n := uint64(1 << (bits - 1)) - n += uint64(rand.Int63()) & ((1 << (bits - 1)) - 1) + n += uint64(RandInt63()) & ((1 << (bits - 1)) - 1) return n } func RandFloat32() float32 { - return rand.Float32() + rng.Lock() + f32 := rng.Float32() + rng.Unlock() + return f32 } func RandTime() time.Time { @@ -118,11 +150,24 @@ func RandTime() time.Time { } func RandBytes(n int) []byte { - bs := make([]byte, n) - for i := 0; i < n; i++ { - bs[i] = byte(rand.Intn(256)) - } - return bs + return cRandBytes(n) +} + +// RandIntn returns, as an int, a non-negative pseudo-random number in [0, n). +// It panics if n <= 0 +func RandIntn(n int) int { + rng.Lock() + i := rng.Intn(n) + rng.Unlock() + return i +} + +// RandPerm returns a pseudo-random permutation of n integers in [0, n). +func RandPerm(n int) []int { + rng.Lock() + perm := rng.Perm(n) + rng.Unlock() + return perm } // NOTE: This relies on the os's random number generator. diff --git a/common/random_test.go b/common/random_test.go new file mode 100644 index 00000000..dd803b3f --- /dev/null +++ b/common/random_test.go @@ -0,0 +1,108 @@ +package common_test + +import ( + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/tendermint/tmlibs/common" +) + +// It is essential that these tests run and never repeat their outputs +// lest we've been pwned and the behavior of our randomness is controlled. +// See Issues: +// * https://github.com/tendermint/tmlibs/issues/99 +// * https://github.com/tendermint/tendermint/issues/973 +func TestUniqueRng(t *testing.T) { + if os.Getenv("TENDERMINT_INTEGRATION_TESTS") == "" { + t.Skipf("Can only be run as an integration test") + } + + // The goal of this test is to invoke the + // Rand* tests externally with no repeating results, booted up. + // Any repeated results indicate that the seed is the same or that + // perhaps we are using math/rand directly. + tmpDir, err := ioutil.TempDir("", "rng-tests") + if err != nil { + t.Fatalf("Creating tempDir: %v", err) + } + defer os.RemoveAll(tmpDir) + + outpath := filepath.Join(tmpDir, "main.go") + f, err := os.Create(outpath) + if err != nil { + t.Fatalf("Setting up %q err: %v", outpath, err) + } + f.Write([]byte(integrationTestProgram)) + if err := f.Close(); err != nil { + t.Fatalf("Closing: %v", err) + } + + outputs := make(map[string][]int) + for i := 0; i < 100; i++ { + cmd := exec.Command("go", "run", outpath) + bOutput, err := cmd.CombinedOutput() + if err != nil { + t.Errorf("Run #%d: err: %v output: %s", i, err, bOutput) + continue + } + output := string(bOutput) + runs, seen := outputs[output] + if seen { + t.Errorf("Run #%d's output was already seen in previous runs: %v", i, runs) + } + outputs[output] = append(outputs[output], i) + } +} + +const integrationTestProgram = ` +package main + +import ( + "encoding/json" + "fmt" + "math/rand" + + "github.com/tendermint/tmlibs/common" +) + +func main() { + // Set math/rand's Seed so that any direct invocations + // of math/rand will reveal themselves. + rand.Seed(1) + perm := common.RandPerm(10) + blob, _ := json.Marshal(perm) + fmt.Printf("perm: %s\n", blob) + + fmt.Printf("randInt: %d\n", common.RandInt()) + fmt.Printf("randUint: %d\n", common.RandUint()) + fmt.Printf("randIntn: %d\n", common.RandIntn(97)) + fmt.Printf("randInt31: %d\n", common.RandInt31()) + fmt.Printf("randInt32: %d\n", common.RandInt32()) + fmt.Printf("randInt63: %d\n", common.RandInt63()) + fmt.Printf("randInt64: %d\n", common.RandInt64()) + fmt.Printf("randUint32: %d\n", common.RandUint32()) + fmt.Printf("randUint64: %d\n", common.RandUint64()) + fmt.Printf("randUint16Exp: %d\n", common.RandUint16Exp()) + fmt.Printf("randUint32Exp: %d\n", common.RandUint32Exp()) + fmt.Printf("randUint64Exp: %d\n", common.RandUint64Exp()) +}` + +func TestRngConcurrencySafety(t *testing.T) { + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + _ = common.RandUint64() + <-time.After(time.Millisecond * time.Duration(common.RandIntn(100))) + _ = common.RandPerm(3) + }() + } + wg.Wait() +}