mirror of
https://github.com/fluencelabs/go-libp2p-kad-dht
synced 2025-04-24 22:32:13 +00:00
require that the validator be explicitly passed in
Note: this does mean that the DHT won't work with peer keys by default and that the constructor signature changes. Given all the changes that'll come with the libp2p refactor, I don't feel too bad about this.
This commit is contained in:
parent
cad57471f5
commit
3befc403d7
21
dht.go
21
dht.go
@ -54,8 +54,7 @@ type IpfsDHT struct {
|
||||
|
||||
birth time.Time // When this peer started up
|
||||
|
||||
Validator record.Validator // record validator funcs
|
||||
Selector record.Selector // record selection funcs
|
||||
Validator record.Validator
|
||||
|
||||
ctx context.Context
|
||||
proc goprocess.Process
|
||||
@ -69,8 +68,8 @@ type IpfsDHT struct {
|
||||
// NewDHT creates a new DHT object with the given peer as the 'local' host.
|
||||
// IpfsDHT's initialized with this function will respond to DHT requests,
|
||||
// whereas IpfsDHT's initialized with NewDHTClient will not.
|
||||
func NewDHT(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT {
|
||||
dht := NewDHTClient(ctx, h, dstore)
|
||||
func NewDHT(ctx context.Context, h host.Host, dstore ds.Batching, validator record.Validator) *IpfsDHT {
|
||||
dht := NewDHTClient(ctx, h, dstore, validator)
|
||||
|
||||
h.SetStreamHandler(ProtocolDHT, dht.handleNewStream)
|
||||
h.SetStreamHandler(ProtocolDHTOld, dht.handleNewStream)
|
||||
@ -81,7 +80,8 @@ func NewDHT(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT {
|
||||
// NewDHTClient creates a new DHT object with the given peer as the 'local'
|
||||
// host. IpfsDHT clients initialized with this function will not respond to DHT
|
||||
// requests. If you need a peer to respond to DHT requests, use NewDHT instead.
|
||||
func NewDHTClient(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT {
|
||||
// NewDHTClient creates a new DHT object with the given peer as the 'local' host
|
||||
func NewDHTClient(ctx context.Context, h host.Host, dstore ds.Batching, validator record.Validator) *IpfsDHT {
|
||||
dht := makeDHT(ctx, h, dstore)
|
||||
|
||||
// register for network notifs.
|
||||
@ -94,9 +94,7 @@ func NewDHTClient(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT
|
||||
})
|
||||
|
||||
dht.proc.AddChild(dht.providers.Process())
|
||||
|
||||
dht.Validator["pk"] = record.PublicKeyValidator
|
||||
dht.Selector["pk"] = record.PublicKeySelector
|
||||
dht.Validator = validator
|
||||
|
||||
return dht
|
||||
}
|
||||
@ -122,9 +120,6 @@ func makeDHT(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT {
|
||||
providers: providers.NewProviderManager(ctx, h.ID(), dstore),
|
||||
birth: time.Now(),
|
||||
routingTable: rt,
|
||||
|
||||
Validator: make(record.Validator),
|
||||
Selector: make(record.Selector),
|
||||
}
|
||||
}
|
||||
|
||||
@ -176,7 +171,7 @@ func (dht *IpfsDHT) getValueOrPeers(ctx context.Context, p peer.ID, key string)
|
||||
log.Debug("getValueOrPeers: got value")
|
||||
|
||||
// make sure record is valid.
|
||||
err = dht.Validator.VerifyRecord(record)
|
||||
err = dht.Validator.Validate(record.GetKey(), record.GetValue())
|
||||
if err != nil {
|
||||
log.Info("Received invalid record! (discarded)")
|
||||
// return a sentinal to signify an invalid record was received
|
||||
@ -239,7 +234,7 @@ func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = dht.Validator.VerifyRecord(rec)
|
||||
err = dht.Validator.Validate(rec.GetKey(), rec.GetValue())
|
||||
if err != nil {
|
||||
log.Debugf("local record verify failed: %s (discarded)", err)
|
||||
return nil, err
|
||||
|
82
dht_test.go
82
dht_test.go
@ -41,19 +41,50 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
type blankValidator struct{}
|
||||
|
||||
func (blankValidator) Validate(_ string, _ []byte) error { return nil }
|
||||
func (blankValidator) Select(_ string, _ [][]byte) (int, error) { return 0, nil }
|
||||
|
||||
type testValidator struct{}
|
||||
|
||||
func (testValidator) Select(_ string, bs [][]byte) (int, error) {
|
||||
index := -1
|
||||
for i, b := range bs {
|
||||
if bytes.Compare(b, []byte("newer")) == 0 {
|
||||
index = i
|
||||
} else if bytes.Compare(b, []byte("valid")) == 0 {
|
||||
if index == -1 {
|
||||
index = i
|
||||
}
|
||||
}
|
||||
}
|
||||
if index == -1 {
|
||||
return -1, errors.New("no rec found")
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
func (testValidator) Validate(_ string, b []byte) error {
|
||||
if bytes.Compare(b, []byte("expired")) == 0 {
|
||||
return errors.New("expired")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupDHT(ctx context.Context, t *testing.T, client bool) *IpfsDHT {
|
||||
h := bhost.New(netutil.GenSwarmNetwork(t, ctx))
|
||||
|
||||
dss := dssync.MutexWrap(ds.NewMapDatastore())
|
||||
var d *IpfsDHT
|
||||
if client {
|
||||
d = NewDHTClient(ctx, h, dss)
|
||||
} else {
|
||||
d = NewDHT(ctx, h, dss)
|
||||
validator := record.NamespacedValidator{
|
||||
"v": blankValidator{},
|
||||
"pk": record.PublicKeyValidator{},
|
||||
}
|
||||
if client {
|
||||
d = NewDHTClient(ctx, h, dss, validator)
|
||||
} else {
|
||||
d = NewDHT(ctx, h, dss, validator)
|
||||
}
|
||||
|
||||
d.Validator["v"] = func(*record.ValidationRecord) error { return nil }
|
||||
d.Selector["v"] = func(_ string, bs [][]byte) (int, error) { return 0, nil }
|
||||
return d
|
||||
}
|
||||
|
||||
@ -148,14 +179,6 @@ func TestValueGetSet(t *testing.T) {
|
||||
defer dhtA.host.Close()
|
||||
defer dhtB.host.Close()
|
||||
|
||||
vf := func(*record.ValidationRecord) error { return nil }
|
||||
nulsel := func(_ string, bs [][]byte) (int, error) { return 0, nil }
|
||||
|
||||
dhtA.Validator["v"] = vf
|
||||
dhtB.Validator["v"] = vf
|
||||
dhtA.Selector["v"] = nulsel
|
||||
dhtB.Selector["v"] = nulsel
|
||||
|
||||
connect(t, ctx, dhtA, dhtB)
|
||||
|
||||
log.Debug("adding value on: ", dhtA.self)
|
||||
@ -203,33 +226,8 @@ func TestValueSetInvalid(t *testing.T) {
|
||||
defer dhtA.host.Close()
|
||||
defer dhtB.host.Close()
|
||||
|
||||
vf := func(r *record.ValidationRecord) error {
|
||||
if bytes.Compare(r.Value, []byte("expired")) == 0 {
|
||||
return errors.New("expired")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
nulsel := func(k string, bs [][]byte) (int, error) {
|
||||
index := -1
|
||||
for i, b := range bs {
|
||||
if bytes.Compare(b, []byte("newer")) == 0 {
|
||||
index = i
|
||||
} else if bytes.Compare(b, []byte("valid")) == 0 {
|
||||
if index == -1 {
|
||||
index = i
|
||||
}
|
||||
}
|
||||
}
|
||||
if index == -1 {
|
||||
return -1, errors.New("no rec found")
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
|
||||
dhtA.Validator["v"] = vf
|
||||
dhtB.Validator["v"] = vf
|
||||
dhtA.Selector["v"] = nulsel
|
||||
dhtB.Selector["v"] = nulsel
|
||||
dhtA.Validator.(record.NamespacedValidator)["v"] = testValidator{}
|
||||
dhtB.Validator.(record.NamespacedValidator)["v"] = testValidator{}
|
||||
|
||||
connect(t, ctx, dhtA, dhtB)
|
||||
|
||||
|
@ -32,7 +32,7 @@ func TestGetFailures(t *testing.T) {
|
||||
hosts := mn.Hosts()
|
||||
|
||||
tsds := dssync.MutexWrap(ds.NewMapDatastore())
|
||||
d := NewDHT(ctx, hosts[0], tsds)
|
||||
d := NewDHT(ctx, hosts[0], tsds, record.NamespacedValidator{})
|
||||
d.Update(ctx, hosts[1].ID())
|
||||
|
||||
// Reply with failures to every message
|
||||
@ -149,7 +149,7 @@ func TestNotFound(t *testing.T) {
|
||||
}
|
||||
hosts := mn.Hosts()
|
||||
tsds := dssync.MutexWrap(ds.NewMapDatastore())
|
||||
d := NewDHT(ctx, hosts[0], tsds)
|
||||
d := NewDHT(ctx, hosts[0], tsds, record.NamespacedValidator{})
|
||||
|
||||
for _, p := range hosts {
|
||||
d.Update(ctx, p.ID())
|
||||
@ -226,7 +226,7 @@ func TestLessThanKResponses(t *testing.T) {
|
||||
hosts := mn.Hosts()
|
||||
|
||||
tsds := dssync.MutexWrap(ds.NewMapDatastore())
|
||||
d := NewDHT(ctx, hosts[0], tsds)
|
||||
d := NewDHT(ctx, hosts[0], tsds, record.NamespacedValidator{})
|
||||
|
||||
for i := 1; i < 5; i++ {
|
||||
d.Update(ctx, hosts[i].ID())
|
||||
@ -293,7 +293,7 @@ func TestMultipleQueries(t *testing.T) {
|
||||
}
|
||||
hosts := mn.Hosts()
|
||||
tsds := dssync.MutexWrap(ds.NewMapDatastore())
|
||||
d := NewDHT(ctx, hosts[0], tsds)
|
||||
d := NewDHT(ctx, hosts[0], tsds, record.NamespacedValidator{})
|
||||
|
||||
d.Update(ctx, hosts[1].ID())
|
||||
|
||||
|
15
handlers.go
15
handlers.go
@ -163,21 +163,26 @@ func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Mess
|
||||
eip.Done()
|
||||
}()
|
||||
|
||||
dskey := convertToDsKey(pmes.GetKey())
|
||||
|
||||
rec := pmes.GetRecord()
|
||||
if rec == nil {
|
||||
log.Infof("Got nil record from: %s", p.Pretty())
|
||||
return nil, errors.New("nil record")
|
||||
}
|
||||
|
||||
if pmes.GetKey() != rec.GetKey() {
|
||||
return nil, errors.New("put key doesn't match record key")
|
||||
}
|
||||
|
||||
cleanRecord(rec)
|
||||
|
||||
// Make sure the record is valid (not expired, valid signature etc)
|
||||
if err = dht.Validator.VerifyRecord(rec); err != nil {
|
||||
if err = dht.Validator.Validate(rec.GetKey(), rec.GetValue()); err != nil {
|
||||
log.Warningf("Bad dht record in PUT from: %s. %s", p.Pretty(), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dskey := convertToDsKey(rec.GetKey())
|
||||
|
||||
// Make sure the new record is "better" than the record we have locally.
|
||||
// This prevents a record with for example a lower sequence number from
|
||||
// overwriting a record with a higher sequence number.
|
||||
@ -188,7 +193,7 @@ func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Mess
|
||||
|
||||
if existing != nil {
|
||||
recs := [][]byte{rec.GetValue(), existing.GetValue()}
|
||||
i, err := dht.Selector.BestRecord(pmes.GetKey(), recs)
|
||||
i, err := dht.Validator.Select(rec.GetKey(), recs)
|
||||
if err != nil {
|
||||
log.Warningf("Bad dht record in PUT from %s: %s", p.Pretty(), err)
|
||||
return nil, err
|
||||
@ -237,7 +242,7 @@ func (dht *IpfsDHT) getRecordFromDatastore(dskey ds.Key) (*recpb.Record, error)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
err = dht.Validator.VerifyRecord(rec)
|
||||
err = dht.Validator.Validate(rec.GetKey(), rec.GetValue())
|
||||
if err != nil {
|
||||
// Invalid record in datastore, probably expired but don't return an error,
|
||||
// we'll just overwrite it
|
||||
|
@ -72,9 +72,9 @@
|
||||
},
|
||||
{
|
||||
"author": "whyrusleeping",
|
||||
"hash": "QmZ9V14gpwKsTUG7y5mHZDnHSF4Fa4rKsXNx7jSTEQ4JWs",
|
||||
"hash": "QmTUyK82BVPA6LmSzEJpfEunk9uBaQzWtMsNP917tVj4sT",
|
||||
"name": "go-libp2p-record",
|
||||
"version": "4.0.1"
|
||||
"version": "4.1.0"
|
||||
},
|
||||
{
|
||||
"author": "whyrusleeping",
|
||||
@ -168,3 +168,4 @@
|
||||
"releaseCmd": "git commit -a -m \"gx publish $VERSION\"",
|
||||
"version": "4.0.4"
|
||||
}
|
||||
|
||||
|
@ -125,7 +125,7 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...ropts.Opti
|
||||
return nil, routing.ErrNotFound
|
||||
}
|
||||
|
||||
i, err := dht.Selector.BestRecord(key, recs)
|
||||
i, err := dht.Validator.Select(key, recs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user