Fix random distribution in bitArray.PickRandom (#2534)

* Fix random distribution in bitArray.PickRandom

Previously it was very biased. 63 "_" followed by a single "x" had
much greater odds of being chosen. Additionally, the last element was
skewed. This fixes that by first preprocessing the set of all true
indices, and then randomly selecting a single element from there.

This commit also makes the code here significantly simpler, and
improves test cases.

* unlock mtx right after we select true indices
This commit is contained in:
Dev Ojha 2018-10-05 00:00:50 -07:00 committed by Anton Kaliaev
parent 5b120d788a
commit c648c93807
3 changed files with 60 additions and 42 deletions

View File

@ -45,3 +45,4 @@ timeoutPrecommit before starting next round
- [evidence] \#2515 fix db iter leak (@goolAdapter) - [evidence] \#2515 fix db iter leak (@goolAdapter)
- [common/bit_array] Fixed a bug in the `Or` function - [common/bit_array] Fixed a bug in the `Or` function
- [common/bit_array] Fixed a bug in the `Sub` function (@bradyjoestar) - [common/bit_array] Fixed a bug in the `Sub` function (@bradyjoestar)
- [common] \#2534 make bit array's PickRandom choose uniformly from true bits

View File

@ -234,49 +234,53 @@ func (bA *BitArray) IsFull() bool {
return (lastElem+1)&((uint64(1)<<uint(lastElemBits))-1) == 0 return (lastElem+1)&((uint64(1)<<uint(lastElemBits))-1) == 0
} }
// PickRandom returns a random index in the bit array, and its value. // PickRandom returns a random index for a set bit in the bit array.
// If there is no such value, it returns 0, false.
// It uses the global randomness in `random.go` to get this index. // It uses the global randomness in `random.go` to get this index.
func (bA *BitArray) PickRandom() (int, bool) { func (bA *BitArray) PickRandom() (int, bool) {
if bA == nil { if bA == nil {
return 0, false return 0, false
} }
bA.mtx.Lock()
defer bA.mtx.Unlock()
length := len(bA.Elems) bA.mtx.Lock()
if length == 0 { trueIndices := bA.getTrueIndices()
bA.mtx.Unlock()
if len(trueIndices) == 0 { // no bits set to true
return 0, false return 0, false
} }
randElemStart := RandIntn(length)
for i := 0; i < length; i++ { return trueIndices[RandIntn(len(trueIndices))], true
elemIdx := ((i + randElemStart) % length) }
if elemIdx < length-1 {
if bA.Elems[elemIdx] > 0 { func (bA *BitArray) getTrueIndices() []int {
randBitStart := RandIntn(64) trueIndices := make([]int, 0, bA.Bits)
curBit := 0
numElems := len(bA.Elems)
// set all true indices
for i := 0; i < numElems-1; i++ {
elem := bA.Elems[i]
if elem == 0 {
curBit += 64
continue
}
for j := 0; j < 64; j++ { for j := 0; j < 64; j++ {
bitIdx := ((j + randBitStart) % 64) if (elem & (uint64(1) << uint64(j))) > 0 {
if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 { trueIndices = append(trueIndices, curBit)
return 64*elemIdx + bitIdx, true }
curBit++
} }
} }
PanicSanity("should not happen") // handle last element
lastElem := bA.Elems[numElems-1]
numFinalBits := bA.Bits - curBit
for i := 0; i < numFinalBits; i++ {
if (lastElem & (uint64(1) << uint64(i))) > 0 {
trueIndices = append(trueIndices, curBit)
} }
} else { curBit++
// Special case for last elem, to ignore straggler bits
elemBits := bA.Bits % 64
if elemBits == 0 {
elemBits = 64
} }
randBitStart := RandIntn(elemBits) return trueIndices
for j := 0; j < elemBits; j++ {
bitIdx := ((j + randBitStart) % elemBits)
if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 {
return 64*elemIdx + bitIdx, true
}
}
}
}
return 0, false
} }
// String returns a string representation of BitArray: BA{<bit-string>}, // String returns a string representation of BitArray: BA{<bit-string>},

View File

@ -107,16 +107,29 @@ func TestSub(t *testing.T) {
} }
func TestPickRandom(t *testing.T) { func TestPickRandom(t *testing.T) {
for idx := 0; idx < 123; idx++ { empty16Bits := "________________"
bA1 := NewBitArray(123) empty64Bits := empty16Bits + empty16Bits + empty16Bits + empty16Bits
bA1.SetIndex(idx, true) testCases := []struct {
index, ok := bA1.PickRandom() bA string
if !ok { ok bool
t.Fatal("Expected to pick element but got none") }{
} {`null`, false},
if index != idx { {`"x"`, true},
t.Fatalf("Expected to pick element at %v but got wrong index", idx) {`"` + empty16Bits + `"`, false},
{`"x` + empty16Bits + `"`, true},
{`"` + empty16Bits + `x"`, true},
{`"x` + empty16Bits + `x"`, true},
{`"` + empty64Bits + `"`, false},
{`"x` + empty64Bits + `"`, true},
{`"` + empty64Bits + `x"`, true},
{`"x` + empty64Bits + `x"`, true},
} }
for _, tc := range testCases {
var bitArr *BitArray
err := json.Unmarshal([]byte(tc.bA), &bitArr)
require.NoError(t, err)
_, ok := bitArr.PickRandom()
require.Equal(t, tc.ok, ok, "PickRandom got an unexpected result on input %s", tc.bA)
} }
} }