mirror of
https://github.com/fluencelabs/go-libp2p-kad-dht
synced 2025-04-24 22:32:13 +00:00
Merge pull request #400 from libp2p/feat/disable-providers
feat: allow disabling value and provider storage/messages
This commit is contained in:
commit
2e6adb8c2b
7
dht.go
7
dht.go
@ -75,6 +75,11 @@ type IpfsDHT struct {
|
||||
triggerRtRefresh chan chan<- error
|
||||
|
||||
maxRecordAge time.Duration
|
||||
|
||||
// Allows disabling dht subsystems. These should _only_ be set on
|
||||
// "forked" DHTs (e.g., DHTs with custom protocols and/or private
|
||||
// networks).
|
||||
enableProviders, enableValues bool
|
||||
}
|
||||
|
||||
// Assert that IPFS assumptions about interfaces aren't broken. These aren't a
|
||||
@ -100,6 +105,8 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er
|
||||
dht.rtRefreshQueryTimeout = cfg.RoutingTable.RefreshQueryTimeout
|
||||
|
||||
dht.maxRecordAge = cfg.MaxRecordAge
|
||||
dht.enableProviders = cfg.EnableProviders
|
||||
dht.enableValues = cfg.EnableValues
|
||||
|
||||
// register for network notifs.
|
||||
dht.host.Network().Notify((*netNotifiee)(dht))
|
||||
|
76
dht_test.go
76
dht_test.go
@ -107,13 +107,15 @@ func (testAtomicPutValidator) Select(_ string, bs [][]byte) (int, error) {
|
||||
return index, nil
|
||||
}
|
||||
|
||||
func setupDHT(ctx context.Context, t *testing.T, client bool) *IpfsDHT {
|
||||
func setupDHT(ctx context.Context, t *testing.T, client bool, options ...opts.Option) *IpfsDHT {
|
||||
d, err := New(
|
||||
ctx,
|
||||
bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)),
|
||||
opts.Client(client),
|
||||
opts.NamespacedValidator("v", blankValidator{}),
|
||||
opts.DisableAutoRefresh(),
|
||||
append([]opts.Option{
|
||||
opts.Client(client),
|
||||
opts.NamespacedValidator("v", blankValidator{}),
|
||||
opts.DisableAutoRefresh(),
|
||||
}, options...)...,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -1500,6 +1502,72 @@ func TestFindClosestPeers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvideDisabled(t *testing.T) {
|
||||
k := testCaseCids[0]
|
||||
for i := 0; i < 3; i++ {
|
||||
enabledA := (i & 0x1) > 0
|
||||
enabledB := (i & 0x2) > 0
|
||||
t.Run(fmt.Sprintf("a=%v/b=%v", enabledA, enabledB), func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var (
|
||||
optsA, optsB []opts.Option
|
||||
)
|
||||
if !enabledA {
|
||||
optsA = append(optsA, opts.DisableProviders())
|
||||
}
|
||||
if !enabledB {
|
||||
optsB = append(optsB, opts.DisableProviders())
|
||||
}
|
||||
|
||||
dhtA := setupDHT(ctx, t, false, optsA...)
|
||||
dhtB := setupDHT(ctx, t, false, optsB...)
|
||||
|
||||
defer dhtA.Close()
|
||||
defer dhtB.Close()
|
||||
defer dhtA.host.Close()
|
||||
defer dhtB.host.Close()
|
||||
|
||||
connect(t, ctx, dhtA, dhtB)
|
||||
|
||||
err := dhtB.Provide(ctx, k, true)
|
||||
if enabledB {
|
||||
if err != nil {
|
||||
t.Fatal("put should have succeeded on node B", err)
|
||||
}
|
||||
} else {
|
||||
if err != routing.ErrNotSupported {
|
||||
t.Fatal("should not have put the value to node B", err)
|
||||
}
|
||||
_, err = dhtB.FindProviders(ctx, k)
|
||||
if err != routing.ErrNotSupported {
|
||||
t.Fatal("get should have failed on node B")
|
||||
}
|
||||
provs := dhtB.providers.GetProviders(ctx, k)
|
||||
if len(provs) != 0 {
|
||||
t.Fatal("node B should not have found local providers")
|
||||
}
|
||||
}
|
||||
|
||||
provs, err := dhtA.FindProviders(ctx, k)
|
||||
if enabledA {
|
||||
if len(provs) != 0 {
|
||||
t.Fatal("node A should not have found providers")
|
||||
}
|
||||
} else {
|
||||
if err != routing.ErrNotSupported {
|
||||
t.Fatal("node A should not have found providers")
|
||||
}
|
||||
}
|
||||
provAddrs := dhtA.providers.GetProviders(ctx, k)
|
||||
if len(provAddrs) != 0 {
|
||||
t.Fatal("node A should not have found local providers")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSetPluggedProtocol(t *testing.T) {
|
||||
t.Run("PutValue/GetValue - same protocol", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
30
handlers.go
30
handlers.go
@ -26,21 +26,31 @@ type dhtHandler func(context.Context, peer.ID, *pb.Message) (*pb.Message, error)
|
||||
|
||||
func (dht *IpfsDHT) handlerForMsgType(t pb.Message_MessageType) dhtHandler {
|
||||
switch t {
|
||||
case pb.Message_GET_VALUE:
|
||||
return dht.handleGetValue
|
||||
case pb.Message_PUT_VALUE:
|
||||
return dht.handlePutValue
|
||||
case pb.Message_FIND_NODE:
|
||||
return dht.handleFindPeer
|
||||
case pb.Message_ADD_PROVIDER:
|
||||
return dht.handleAddProvider
|
||||
case pb.Message_GET_PROVIDERS:
|
||||
return dht.handleGetProviders
|
||||
case pb.Message_PING:
|
||||
return dht.handlePing
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
if dht.enableValues {
|
||||
switch t {
|
||||
case pb.Message_GET_VALUE:
|
||||
return dht.handleGetValue
|
||||
case pb.Message_PUT_VALUE:
|
||||
return dht.handlePutValue
|
||||
}
|
||||
}
|
||||
|
||||
if dht.enableProviders {
|
||||
switch t {
|
||||
case pb.Message_ADD_PROVIDER:
|
||||
return dht.handleAddProvider
|
||||
case pb.Message_GET_PROVIDERS:
|
||||
return dht.handleGetProviders
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dht *IpfsDHT) handleGetValue(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, err error) {
|
||||
|
@ -21,12 +21,14 @@ var (
|
||||
|
||||
// Options is a structure containing all the options that can be used when constructing a DHT.
|
||||
type Options struct {
|
||||
Datastore ds.Batching
|
||||
Validator record.Validator
|
||||
Client bool
|
||||
Protocols []protocol.ID
|
||||
BucketSize int
|
||||
MaxRecordAge time.Duration
|
||||
Datastore ds.Batching
|
||||
Validator record.Validator
|
||||
Client bool
|
||||
Protocols []protocol.ID
|
||||
BucketSize int
|
||||
MaxRecordAge time.Duration
|
||||
EnableProviders bool
|
||||
EnableValues bool
|
||||
|
||||
RoutingTable struct {
|
||||
RefreshQueryTimeout time.Duration
|
||||
@ -56,6 +58,8 @@ var Defaults = func(o *Options) error {
|
||||
}
|
||||
o.Datastore = dssync.MutexWrap(ds.NewMapDatastore())
|
||||
o.Protocols = DefaultProtocols
|
||||
o.EnableProviders = true
|
||||
o.EnableValues = true
|
||||
|
||||
o.RoutingTable.RefreshQueryTimeout = 10 * time.Second
|
||||
o.RoutingTable.RefreshPeriod = 1 * time.Hour
|
||||
@ -177,3 +181,30 @@ func DisableAutoRefresh() Option {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// DisableProviders disables storing and retrieving provider records.
|
||||
//
|
||||
// Defaults to enabled.
|
||||
//
|
||||
// WARNING: do not change this unless you're using a forked DHT (i.e., a private
|
||||
// network and/or distinct DHT protocols with the `Protocols` option).
|
||||
func DisableProviders() Option {
|
||||
return func(o *Options) error {
|
||||
o.EnableProviders = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// DisableProviders disables storing and retrieving value records (including
|
||||
// public keys).
|
||||
//
|
||||
// Defaults to enabled.
|
||||
//
|
||||
// WARNING: do not change this unless you're using a forked DHT (i.e., a private
|
||||
// network and/or distinct DHT protocols with the `Protocols` option).
|
||||
func DisableValues() Option {
|
||||
return func(o *Options) error {
|
||||
o.EnableValues = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package dht
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"github.com/libp2p/go-libp2p-core/test"
|
||||
"testing"
|
||||
"time"
|
||||
@ -13,6 +14,8 @@ import (
|
||||
"github.com/libp2p/go-libp2p-core/routing"
|
||||
record "github.com/libp2p/go-libp2p-record"
|
||||
tnet "github.com/libp2p/go-libp2p-testing/net"
|
||||
|
||||
dhtopt "github.com/libp2p/go-libp2p-kad-dht/opts"
|
||||
)
|
||||
|
||||
// Check that GetPublicKey() correctly extracts a public key
|
||||
@ -305,3 +308,75 @@ func TestPubkeyGoodKeyFromDHTGoodKeyDirect(t *testing.T) {
|
||||
t.Fatal("got incorrect public key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesDisabled(t *testing.T) {
|
||||
for i := 0; i < 3; i++ {
|
||||
enabledA := (i & 0x1) > 0
|
||||
enabledB := (i & 0x2) > 0
|
||||
t.Run(fmt.Sprintf("a=%v/b=%v", enabledA, enabledB), func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var (
|
||||
optsA, optsB []dhtopt.Option
|
||||
)
|
||||
if !enabledA {
|
||||
optsA = append(optsA, dhtopt.DisableValues())
|
||||
}
|
||||
if !enabledB {
|
||||
optsB = append(optsB, dhtopt.DisableValues())
|
||||
}
|
||||
|
||||
dhtA := setupDHT(ctx, t, false, optsA...)
|
||||
dhtB := setupDHT(ctx, t, false, optsB...)
|
||||
|
||||
defer dhtA.Close()
|
||||
defer dhtB.Close()
|
||||
defer dhtA.host.Close()
|
||||
defer dhtB.host.Close()
|
||||
|
||||
connect(t, ctx, dhtA, dhtB)
|
||||
|
||||
pubk := dhtB.peerstore.PubKey(dhtB.self)
|
||||
pkbytes, err := pubk.Bytes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pkkey := routing.KeyForPublicKey(dhtB.self)
|
||||
err = dhtB.PutValue(ctx, pkkey, pkbytes)
|
||||
if enabledB {
|
||||
if err != nil {
|
||||
t.Fatal("put should have succeeded on node B", err)
|
||||
}
|
||||
} else {
|
||||
if err != routing.ErrNotSupported {
|
||||
t.Fatal("should not have put the value to node B", err)
|
||||
}
|
||||
_, err = dhtB.GetValue(ctx, pkkey)
|
||||
if err != routing.ErrNotSupported {
|
||||
t.Fatal("get should have failed on node B")
|
||||
}
|
||||
rec, _ := dhtB.getLocal(pkkey)
|
||||
if rec != nil {
|
||||
t.Fatal("node B should not have found the value locally")
|
||||
}
|
||||
}
|
||||
|
||||
_, err = dhtA.GetValue(ctx, pkkey)
|
||||
if enabledA {
|
||||
if err != routing.ErrNotFound {
|
||||
t.Fatal("node A should not have found the value")
|
||||
}
|
||||
} else {
|
||||
if err != routing.ErrNotSupported {
|
||||
t.Fatal("node A should not have found the value")
|
||||
}
|
||||
}
|
||||
rec, _ := dhtA.getLocal(pkkey)
|
||||
if rec != nil {
|
||||
t.Fatal("node A should not have found the value locally")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
30
routing.go
30
routing.go
@ -34,6 +34,10 @@ var asyncQueryBuffer = 10
|
||||
// PutValue adds value corresponding to given Key.
|
||||
// This is the top level "Store" operation of the DHT
|
||||
func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte, opts ...routing.Option) (err error) {
|
||||
if !dht.enableValues {
|
||||
return routing.ErrNotSupported
|
||||
}
|
||||
|
||||
eip := logger.EventBegin(ctx, "PutValue")
|
||||
defer func() {
|
||||
eip.Append(loggableKey(key))
|
||||
@ -110,6 +114,10 @@ type RecvdVal struct {
|
||||
|
||||
// GetValue searches for the value corresponding to given Key.
|
||||
func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...routing.Option) (_ []byte, err error) {
|
||||
if !dht.enableValues {
|
||||
return nil, routing.ErrNotSupported
|
||||
}
|
||||
|
||||
eip := logger.EventBegin(ctx, "GetValue")
|
||||
defer func() {
|
||||
eip.Append(loggableKey(key))
|
||||
@ -148,6 +156,10 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...routing.Op
|
||||
}
|
||||
|
||||
func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...routing.Option) (<-chan []byte, error) {
|
||||
if !dht.enableValues {
|
||||
return nil, routing.ErrNotSupported
|
||||
}
|
||||
|
||||
var cfg routing.Options
|
||||
if err := cfg.Apply(opts...); err != nil {
|
||||
return nil, err
|
||||
@ -250,8 +262,11 @@ func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...routing
|
||||
|
||||
// GetValues gets nvals values corresponding to the given key.
|
||||
func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []RecvdVal, err error) {
|
||||
eip := logger.EventBegin(ctx, "GetValues")
|
||||
if !dht.enableValues {
|
||||
return nil, routing.ErrNotSupported
|
||||
}
|
||||
|
||||
eip := logger.EventBegin(ctx, "GetValues")
|
||||
eip.Append(loggableKey(key))
|
||||
defer eip.Done()
|
||||
|
||||
@ -398,6 +413,9 @@ func (dht *IpfsDHT) getValues(ctx context.Context, key string, nvals int) (<-cha
|
||||
|
||||
// Provide makes this node announce that it can provide a value for the given key
|
||||
func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err error) {
|
||||
if !dht.enableProviders {
|
||||
return routing.ErrNotSupported
|
||||
}
|
||||
eip := logger.EventBegin(ctx, "Provide", key, logging.LoggableMap{"broadcast": brdcst})
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@ -477,6 +495,9 @@ func (dht *IpfsDHT) makeProvRecord(skey cid.Cid) (*pb.Message, error) {
|
||||
|
||||
// FindProviders searches until the context expires.
|
||||
func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrInfo, error) {
|
||||
if !dht.enableProviders {
|
||||
return nil, routing.ErrNotSupported
|
||||
}
|
||||
var providers []peer.AddrInfo
|
||||
for p := range dht.FindProvidersAsync(ctx, c, dht.bucketSize) {
|
||||
providers = append(providers, p)
|
||||
@ -488,8 +509,13 @@ func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrIn
|
||||
// Peers will be returned on the channel as soon as they are found, even before
|
||||
// the search query completes.
|
||||
func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count int) <-chan peer.AddrInfo {
|
||||
logger.Event(ctx, "findProviders", key)
|
||||
peerOut := make(chan peer.AddrInfo, count)
|
||||
if !dht.enableProviders {
|
||||
close(peerOut)
|
||||
return peerOut
|
||||
}
|
||||
|
||||
logger.Event(ctx, "findProviders", key)
|
||||
|
||||
go dht.findProvidersAsyncRoutine(ctx, key, count, peerOut)
|
||||
return peerOut
|
||||
|
Loading…
x
Reference in New Issue
Block a user