SearchValues: more review addressing

This commit is contained in:
Łukasz Magiera 2018-08-07 12:43:52 +02:00
parent d72432caf1
commit 6ca5dd7bf4
2 changed files with 27 additions and 25 deletions

View File

@ -11,9 +11,9 @@ import (
pb "github.com/libp2p/go-libp2p-kad-dht/pb" pb "github.com/libp2p/go-libp2p-kad-dht/pb"
inet "github.com/libp2p/go-libp2p-net" inet "github.com/libp2p/go-libp2p-net"
pstore "github.com/libp2p/go-libp2p-peerstore" pstore "github.com/libp2p/go-libp2p-peerstore"
"github.com/libp2p/go-libp2p-record" record "github.com/libp2p/go-libp2p-record"
"github.com/libp2p/go-libp2p-routing" routing "github.com/libp2p/go-libp2p-routing"
"github.com/libp2p/go-libp2p/p2p/net/mock" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
) )
func TestGetFailures(t *testing.T) { func TestGetFailures(t *testing.T) {

View File

@ -173,17 +173,20 @@ func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...ropts.O
maxVals = defaultQuorum * 4 // we want some upper bound on how maxVals = defaultQuorum * 4 // we want some upper bound on how
// much correctional entries we will send // much correctional entries we will send
} }
// vals is used collect entries we got so far and send corrections to peers
// when we exit this function
vals := make([]RecvdVal, 0, maxVals) vals := make([]RecvdVal, 0, maxVals)
best := -1 var best *RecvdVal
defer func() { defer func() {
if len(vals) <= 1 || best < 0 { if len(vals) <= 1 || best == nil {
return return
} }
fixupRec := record.MakePutRecord(key, vals[best].Val) fixupRec := record.MakePutRecord(key, best.Val)
for _, v := range vals { for _, v := range vals {
// if someone sent us a different 'less-valid' record, lets correct them // if someone sent us a different 'less-valid' record, lets correct them
if !bytes.Equal(v.Val, vals[best].Val) { if !bytes.Equal(v.Val, best.Val) {
go func(v RecvdVal) { go func(v RecvdVal) {
if v.From == dht.self { if v.From == dht.self {
err := dht.putLocal(key, fixupRec) err := dht.putLocal(key, fixupRec)
@ -210,32 +213,28 @@ func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...ropts.O
return return
} }
i := len(vals)
if len(vals) < maxVals { if len(vals) < maxVals {
vals = append(vals, v) vals = append(vals, v)
} else {
i = (best + 1) % maxVals
vals[i] = v
} }
if v.Val == nil { if v.Val == nil {
continue continue
} }
// Select best value // Select best value
if best > -1 { if best != nil {
sel, err := dht.Validator.Select(key, [][]byte{vals[best].Val, v.Val}) sel, err := dht.Validator.Select(key, [][]byte{best.Val, v.Val})
if err != nil { if err != nil {
log.Warning("Failed to select dht key: ", err) log.Warning("Failed to select dht key: ", err)
continue continue
} }
if sel == 1 && !bytes.Equal(v.Val, vals[best].Val) { if sel == 1 && !bytes.Equal(v.Val, best.Val) {
best = i best = &v
out <- v.Val out <- v.Val
} }
} else { } else {
// Output first valid value // Output first valid value
if err := dht.Validator.Validate(key, v.Val); err == nil { if err := dht.Validator.Validate(key, v.Val); err == nil {
best = len(vals) - 1 best = &v
out <- v.Val out <- v.Val
} }
} }
@ -250,8 +249,14 @@ func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...ropts.O
// GetValues gets nvals values corresponding to the given key. // GetValues gets nvals values corresponding to the given key.
func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []RecvdVal, err error) { func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []RecvdVal, err error) {
eip := log.EventBegin(ctx, "GetValues")
eip.Append(loggableKey(key))
defer eip.Done()
valCh, err := dht.getValues(ctx, key, nvals) valCh, err := dht.getValues(ctx, key, nvals)
if err != nil { if err != nil {
eip.SetError(err)
return nil, err return nil, err
} }
@ -264,18 +269,10 @@ func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []R
} }
func (dht *IpfsDHT) getValues(ctx context.Context, key string, nvals int) (<-chan RecvdVal, error) { func (dht *IpfsDHT) getValues(ctx context.Context, key string, nvals int) (<-chan RecvdVal, error) {
eip := log.EventBegin(ctx, "GetValues")
vals := make(chan RecvdVal, 1) vals := make(chan RecvdVal, 1)
done := func(err error) (<-chan RecvdVal, error) { done := func(err error) (<-chan RecvdVal, error) {
defer close(vals) defer close(vals)
eip.Append(loggableKey(key))
if err != nil {
eip.SetError(err)
}
eip.Done()
return vals, err return vals, err
} }
@ -348,7 +345,12 @@ func (dht *IpfsDHT) getValues(ctx context.Context, key string, nvals int) (<-cha
From: p, From: p,
} }
valslock.Lock() valslock.Lock()
vals <- rv select {
case vals <- rv:
case <-ctx.Done():
valslock.Unlock()
return nil, ctx.Err()
}
got++ got++
// If we have collected enough records, we're done // If we have collected enough records, we're done