require that the validator be explicitly passed in

Note: this does mean that the DHT won't work with peer keys by default and that
the constructor signature changes. Given all the changes that'll come with the
libp2p refactor, I don't feel too bad about this.
This commit is contained in:
Steven Allen 2018-05-03 16:36:48 -07:00
parent cad57471f5
commit 3befc403d7
6 changed files with 66 additions and 67 deletions

21
dht.go
View File

@ -54,8 +54,7 @@ type IpfsDHT struct {
birth time.Time // When this peer started up
Validator record.Validator // record validator funcs
Selector record.Selector // record selection funcs
Validator record.Validator
ctx context.Context
proc goprocess.Process
@ -69,8 +68,8 @@ type IpfsDHT struct {
// NewDHT creates a new DHT object with the given peer as the 'local' host.
// IpfsDHT's initialized with this function will respond to DHT requests,
// whereas IpfsDHT's initialized with NewDHTClient will not.
func NewDHT(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT {
dht := NewDHTClient(ctx, h, dstore)
func NewDHT(ctx context.Context, h host.Host, dstore ds.Batching, validator record.Validator) *IpfsDHT {
dht := NewDHTClient(ctx, h, dstore, validator)
h.SetStreamHandler(ProtocolDHT, dht.handleNewStream)
h.SetStreamHandler(ProtocolDHTOld, dht.handleNewStream)
@ -81,7 +80,8 @@ func NewDHT(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT {
// NewDHTClient creates a new DHT object with the given peer as the 'local'
// host. IpfsDHT clients initialized with this function will not respond to DHT
// requests. If you need a peer to respond to DHT requests, use NewDHT instead.
func NewDHTClient(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT {
// NewDHTClient creates a new DHT object with the given peer as the 'local' host
func NewDHTClient(ctx context.Context, h host.Host, dstore ds.Batching, validator record.Validator) *IpfsDHT {
dht := makeDHT(ctx, h, dstore)
// register for network notifs.
@ -94,9 +94,7 @@ func NewDHTClient(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT
})
dht.proc.AddChild(dht.providers.Process())
dht.Validator["pk"] = record.PublicKeyValidator
dht.Selector["pk"] = record.PublicKeySelector
dht.Validator = validator
return dht
}
@ -122,9 +120,6 @@ func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT {
providers: providers.NewProviderManager(ctx, h.ID(), dstore),
birth: time.Now(),
routingTable: rt,
Validator: make(record.Validator),
Selector: make(record.Selector),
}
}
@ -176,7 +171,7 @@ func (dht *IpfsDHT) getValueOrPeers(ctx context.Context, p peer.ID, key string)
log.Debug("getValueOrPeers: got value")
// make sure record is valid.
err = dht.Validator.VerifyRecord(record)
err = dht.Validator.Validate(record.GetKey(), record.GetValue())
if err != nil {
log.Info("Received invalid record! (discarded)")
// return a sentinal to signify an invalid record was received
@ -239,7 +234,7 @@ func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) {
return nil, err
}
err = dht.Validator.VerifyRecord(rec)
err = dht.Validator.Validate(rec.GetKey(), rec.GetValue())
if err != nil {
log.Debugf("local record verify failed: %s (discarded)", err)
return nil, err

View File

@ -41,19 +41,50 @@ func init() {
}
}
type blankValidator struct{}
func (blankValidator) Validate(_ string, _ []byte) error { return nil }
func (blankValidator) Select(_ string, _ [][]byte) (int, error) { return 0, nil }
type testValidator struct{}
func (testValidator) Select(_ string, bs [][]byte) (int, error) {
index := -1
for i, b := range bs {
if bytes.Compare(b, []byte("newer")) == 0 {
index = i
} else if bytes.Compare(b, []byte("valid")) == 0 {
if index == -1 {
index = i
}
}
}
if index == -1 {
return -1, errors.New("no rec found")
}
return index, nil
}
func (testValidator) Validate(_ string, b []byte) error {
if bytes.Compare(b, []byte("expired")) == 0 {
return errors.New("expired")
}
return nil
}
func setupDHT(ctx context.Context, t *testing.T, client bool) *IpfsDHT {
h := bhost.New(netutil.GenSwarmNetwork(t, ctx))
dss := dssync.MutexWrap(ds.NewMapDatastore())
var d *IpfsDHT
if client {
d = NewDHTClient(ctx, h, dss)
} else {
d = NewDHT(ctx, h, dss)
validator := record.NamespacedValidator{
"v": blankValidator{},
"pk": record.PublicKeyValidator{},
}
if client {
d = NewDHTClient(ctx, h, dss, validator)
} else {
d = NewDHT(ctx, h, dss, validator)
}
d.Validator["v"] = func(*record.ValidationRecord) error { return nil }
d.Selector["v"] = func(_ string, bs [][]byte) (int, error) { return 0, nil }
return d
}
@ -148,14 +179,6 @@ func TestValueGetSet(t *testing.T) {
defer dhtA.host.Close()
defer dhtB.host.Close()
vf := func(*record.ValidationRecord) error { return nil }
nulsel := func(_ string, bs [][]byte) (int, error) { return 0, nil }
dhtA.Validator["v"] = vf
dhtB.Validator["v"] = vf
dhtA.Selector["v"] = nulsel
dhtB.Selector["v"] = nulsel
connect(t, ctx, dhtA, dhtB)
log.Debug("adding value on: ", dhtA.self)
@ -203,33 +226,8 @@ func TestValueSetInvalid(t *testing.T) {
defer dhtA.host.Close()
defer dhtB.host.Close()
vf := func(r *record.ValidationRecord) error {
if bytes.Compare(r.Value, []byte("expired")) == 0 {
return errors.New("expired")
}
return nil
}
nulsel := func(k string, bs [][]byte) (int, error) {
index := -1
for i, b := range bs {
if bytes.Compare(b, []byte("newer")) == 0 {
index = i
} else if bytes.Compare(b, []byte("valid")) == 0 {
if index == -1 {
index = i
}
}
}
if index == -1 {
return -1, errors.New("no rec found")
}
return index, nil
}
dhtA.Validator["v"] = vf
dhtB.Validator["v"] = vf
dhtA.Selector["v"] = nulsel
dhtB.Selector["v"] = nulsel
dhtA.Validator.(record.NamespacedValidator)["v"] = testValidator{}
dhtB.Validator.(record.NamespacedValidator)["v"] = testValidator{}
connect(t, ctx, dhtA, dhtB)

View File

@ -32,7 +32,7 @@ func TestGetFailures(t *testing.T) {
hosts := mn.Hosts()
tsds := dssync.MutexWrap(ds.NewMapDatastore())
d := NewDHT(ctx, hosts[0], tsds)
d := NewDHT(ctx, hosts[0], tsds, record.NamespacedValidator{})
d.Update(ctx, hosts[1].ID())
// Reply with failures to every message
@ -149,7 +149,7 @@ func TestNotFound(t *testing.T) {
}
hosts := mn.Hosts()
tsds := dssync.MutexWrap(ds.NewMapDatastore())
d := NewDHT(ctx, hosts[0], tsds)
d := NewDHT(ctx, hosts[0], tsds, record.NamespacedValidator{})
for _, p := range hosts {
d.Update(ctx, p.ID())
@ -226,7 +226,7 @@ func TestLessThanKResponses(t *testing.T) {
hosts := mn.Hosts()
tsds := dssync.MutexWrap(ds.NewMapDatastore())
d := NewDHT(ctx, hosts[0], tsds)
d := NewDHT(ctx, hosts[0], tsds, record.NamespacedValidator{})
for i := 1; i < 5; i++ {
d.Update(ctx, hosts[i].ID())
@ -293,7 +293,7 @@ func TestMultipleQueries(t *testing.T) {
}
hosts := mn.Hosts()
tsds := dssync.MutexWrap(ds.NewMapDatastore())
d := NewDHT(ctx, hosts[0], tsds)
d := NewDHT(ctx, hosts[0], tsds, record.NamespacedValidator{})
d.Update(ctx, hosts[1].ID())

View File

@ -163,21 +163,26 @@ func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Mess
eip.Done()
}()
dskey := convertToDsKey(pmes.GetKey())
rec := pmes.GetRecord()
if rec == nil {
log.Infof("Got nil record from: %s", p.Pretty())
return nil, errors.New("nil record")
}
if pmes.GetKey() != rec.GetKey() {
return nil, errors.New("put key doesn't match record key")
}
cleanRecord(rec)
// Make sure the record is valid (not expired, valid signature etc)
if err = dht.Validator.VerifyRecord(rec); err != nil {
if err = dht.Validator.Validate(rec.GetKey(), rec.GetValue()); err != nil {
log.Warningf("Bad dht record in PUT from: %s. %s", p.Pretty(), err)
return nil, err
}
dskey := convertToDsKey(rec.GetKey())
// Make sure the new record is "better" than the record we have locally.
// This prevents a record with for example a lower sequence number from
// overwriting a record with a higher sequence number.
@ -188,7 +193,7 @@ func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Mess
if existing != nil {
recs := [][]byte{rec.GetValue(), existing.GetValue()}
i, err := dht.Selector.BestRecord(pmes.GetKey(), recs)
i, err := dht.Validator.Select(rec.GetKey(), recs)
if err != nil {
log.Warningf("Bad dht record in PUT from %s: %s", p.Pretty(), err)
return nil, err
@ -237,7 +242,7 @@ func (dht *IpfsDHT) getRecordFromDatastore(dskey ds.Key) (*recpb.Record, error)
return nil, nil
}
err = dht.Validator.VerifyRecord(rec)
err = dht.Validator.Validate(rec.GetKey(), rec.GetValue())
if err != nil {
// Invalid record in datastore, probably expired but don't return an error,
// we'll just overwrite it

View File

@ -72,9 +72,9 @@
},
{
"author": "whyrusleeping",
"hash": "QmZ9V14gpwKsTUG7y5mHZDnHSF4Fa4rKsXNx7jSTEQ4JWs",
"hash": "QmTUyK82BVPA6LmSzEJpfEunk9uBaQzWtMsNP917tVj4sT",
"name": "go-libp2p-record",
"version": "4.0.1"
"version": "4.1.0"
},
{
"author": "whyrusleeping",
@ -168,3 +168,4 @@
"releaseCmd": "git commit -a -m \"gx publish $VERSION\"",
"version": "4.0.4"
}

View File

@ -125,7 +125,7 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...ropts.Opti
return nil, routing.ErrNotFound
}
i, err := dht.Selector.BestRecord(key, recs)
i, err := dht.Validator.Select(key, recs)
if err != nil {
return nil, err
}