mirror of
https://github.com/fluencelabs/go-libp2p-kad-dht
synced 2025-04-24 14:22:13 +00:00
Merge pull request #462 from libp2p/fix/observe-context-in-message-sender
fix: obey the context when sending messages to peers
This commit is contained in:
commit
dbb3d2c0a2
28
ctx_mutex.go
Normal file
28
ctx_mutex.go
Normal file
@ -0,0 +1,28 @@
|
||||
package dht
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type ctxMutex chan struct{}
|
||||
|
||||
func newCtxMutex() ctxMutex {
|
||||
return make(ctxMutex, 1)
|
||||
}
|
||||
|
||||
func (m ctxMutex) Lock(ctx context.Context) error {
|
||||
select {
|
||||
case m <- struct{}{}:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (m ctxMutex) Unlock() {
|
||||
select {
|
||||
case <-m:
|
||||
default:
|
||||
panic("not locked")
|
||||
}
|
||||
}
|
19
dht_net.go
19
dht_net.go
@ -246,7 +246,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
|
||||
dht.smlk.Unlock()
|
||||
return ms, nil
|
||||
}
|
||||
ms = &messageSender{p: p, dht: dht}
|
||||
ms = &messageSender{p: p, dht: dht, lk: newCtxMutex()}
|
||||
dht.strmap[p] = ms
|
||||
dht.smlk.Unlock()
|
||||
|
||||
@ -274,7 +274,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
|
||||
type messageSender struct {
|
||||
s network.Stream
|
||||
r msgio.ReadCloser
|
||||
lk sync.Mutex
|
||||
lk ctxMutex
|
||||
p peer.ID
|
||||
dht *IpfsDHT
|
||||
|
||||
@ -294,8 +294,11 @@ func (ms *messageSender) invalidate() {
|
||||
}
|
||||
|
||||
func (ms *messageSender) prepOrInvalidate(ctx context.Context) error {
|
||||
ms.lk.Lock()
|
||||
if err := ms.lk.Lock(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
defer ms.lk.Unlock()
|
||||
|
||||
if err := ms.prep(ctx); err != nil {
|
||||
ms.invalidate()
|
||||
return err
|
||||
@ -328,8 +331,11 @@ func (ms *messageSender) prep(ctx context.Context) error {
|
||||
const streamReuseTries = 3
|
||||
|
||||
func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
|
||||
ms.lk.Lock()
|
||||
if err := ms.lk.Lock(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
defer ms.lk.Unlock()
|
||||
|
||||
retry := false
|
||||
for {
|
||||
if err := ms.prep(ctx); err != nil {
|
||||
@ -363,8 +369,11 @@ func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) erro
|
||||
}
|
||||
|
||||
func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
|
||||
ms.lk.Lock()
|
||||
if err := ms.lk.Lock(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer ms.lk.Unlock()
|
||||
|
||||
retry := false
|
||||
for {
|
||||
if err := ms.prep(ctx); err != nil {
|
||||
|
43
ext_test.go
43
ext_test.go
@ -18,6 +18,49 @@ import (
|
||||
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
|
||||
)
|
||||
|
||||
func TestHang(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mn, err := mocknet.FullMeshConnected(ctx, 2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hosts := mn.Hosts()
|
||||
|
||||
os := []opts.Option{opts.DisableAutoRefresh()}
|
||||
d, err := New(ctx, hosts[0], os...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Hang on every request.
|
||||
hosts[1].SetStreamHandler(d.protocols[0], func(s network.Stream) {
|
||||
defer s.Reset()
|
||||
<-ctx.Done()
|
||||
})
|
||||
d.Update(ctx, hosts[1].ID())
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
peers, err := d.GetClosestPeers(ctx1, testCaseCids[0].KeyString())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
ctx2, cancel2 := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
defer cancel2()
|
||||
_ = d.Provide(ctx2, testCaseCids[0], true)
|
||||
if ctx2.Err() != context.DeadlineExceeded {
|
||||
t.Errorf("expected to fail with deadline exceeded, got: %s", ctx2.Err())
|
||||
}
|
||||
select {
|
||||
case <-peers:
|
||||
t.Error("GetClosestPeers should not have returned yet")
|
||||
default:
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestGetFailures(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.SkipNow()
|
||||
|
4
notif.go
4
notif.go
@ -1,6 +1,8 @@
|
||||
package dht
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/helpers"
|
||||
"github.com/libp2p/go-libp2p-core/network"
|
||||
|
||||
@ -130,7 +132,7 @@ func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) {
|
||||
|
||||
// Do this asynchronously as ms.lk can block for a while.
|
||||
go func() {
|
||||
ms.lk.Lock()
|
||||
ms.lk.Lock(context.Background())
|
||||
defer ms.lk.Unlock()
|
||||
ms.invalidate()
|
||||
}()
|
||||
|
Loading…
x
Reference in New Issue
Block a user