mirror of
https://github.com/fluencelabs/go-libp2p-kad-dht
synced 2025-04-24 14:22:13 +00:00
fix: obey the context when sending messages to peers
Related to #453 but not a fix. This will cause us to actually return early when we start blocking on sending to some peers, but it won't really _unblock_ those peers. For that, we need to write with a context.
This commit is contained in:
parent
5552b3ff8d
commit
0b029388bd
28
ctx_mutex.go
Normal file
28
ctx_mutex.go
Normal file
@ -0,0 +1,28 @@
|
||||
package dht
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type ctxMutex chan struct{}
|
||||
|
||||
func newCtxMutex() ctxMutex {
|
||||
return make(ctxMutex, 1)
|
||||
}
|
||||
|
||||
func (m ctxMutex) Lock(ctx context.Context) error {
|
||||
select {
|
||||
case m <- struct{}{}:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (m ctxMutex) Unlock() {
|
||||
select {
|
||||
case <-m:
|
||||
default:
|
||||
panic("not locked")
|
||||
}
|
||||
}
|
19
dht_net.go
19
dht_net.go
@ -234,7 +234,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
|
||||
dht.smlk.Unlock()
|
||||
return ms, nil
|
||||
}
|
||||
ms = &messageSender{p: p, dht: dht}
|
||||
ms = &messageSender{p: p, dht: dht, lk: newCtxMutex()}
|
||||
dht.strmap[p] = ms
|
||||
dht.smlk.Unlock()
|
||||
|
||||
@ -262,7 +262,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
|
||||
type messageSender struct {
|
||||
s network.Stream
|
||||
r msgio.ReadCloser
|
||||
lk sync.Mutex
|
||||
lk ctxMutex
|
||||
p peer.ID
|
||||
dht *IpfsDHT
|
||||
|
||||
@ -282,8 +282,11 @@ func (ms *messageSender) invalidate() {
|
||||
}
|
||||
|
||||
func (ms *messageSender) prepOrInvalidate(ctx context.Context) error {
|
||||
ms.lk.Lock()
|
||||
if err := ms.lk.Lock(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
defer ms.lk.Unlock()
|
||||
|
||||
if err := ms.prep(ctx); err != nil {
|
||||
ms.invalidate()
|
||||
return err
|
||||
@ -316,8 +319,11 @@ func (ms *messageSender) prep(ctx context.Context) error {
|
||||
const streamReuseTries = 3
|
||||
|
||||
func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
|
||||
ms.lk.Lock()
|
||||
if err := ms.lk.Lock(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
defer ms.lk.Unlock()
|
||||
|
||||
retry := false
|
||||
for {
|
||||
if err := ms.prep(ctx); err != nil {
|
||||
@ -351,8 +357,11 @@ func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) erro
|
||||
}
|
||||
|
||||
func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
|
||||
ms.lk.Lock()
|
||||
if err := ms.lk.Lock(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ms.lk.Unlock()
|
||||
|
||||
retry := false
|
||||
for {
|
||||
if err := ms.prep(ctx); err != nil {
|
||||
|
43
ext_test.go
43
ext_test.go
@ -18,6 +18,49 @@ import (
|
||||
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
|
||||
)
|
||||
|
||||
func TestHang(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mn, err := mocknet.FullMeshConnected(ctx, 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hosts := mn.Hosts()
|
||||
|
||||
os := []opts.Option{opts.DisableAutoRefresh()}
|
||||
d, err := New(ctx, hosts[0], os...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Hang on every request.
|
||||
hosts[1].SetStreamHandler(d.protocols[0], func(s network.Stream) {
|
||||
defer s.Reset()
|
||||
<-ctx.Done()
|
||||
})
|
||||
d.Update(ctx, hosts[1].ID())
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
peers, err := d.GetClosestPeers(ctx1, testCaseCids[0].KeyString())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
ctx2, cancel2 := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
defer cancel2()
|
||||
_ = d.Provide(ctx2, testCaseCids[0], true)
|
||||
if ctx2.Err() != context.DeadlineExceeded {
|
||||
t.Errorf("expected to fail with deadline exceeded, got: %s", ctx2.Err())
|
||||
}
|
||||
select {
|
||||
case <-peers:
|
||||
t.Error("GetClosestPeers should not have returned yet")
|
||||
default:
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestGetFailures(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.SkipNow()
|
||||
|
4
notif.go
4
notif.go
@ -1,6 +1,8 @@
|
||||
package dht
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/helpers"
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
|
||||
@ -130,7 +132,7 @@ func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) {
|
||||
|
||||
// Do this asynchronously as ms.lk can block for a while.
|
||||
go func() {
|
||||
ms.lk.Lock()
|
||||
ms.lk.Lock(context.Background())
|
||||
defer ms.lk.Unlock()
|
||||
ms.invalidate()
|
||||
}()
|
||||
|
Loading…
x
Reference in New Issue
Block a user