diff --git a/dht_net.go b/dht_net.go
index a91e0f5..cfbc812 100644
--- a/dht_net.go
+++ b/dht_net.go
@@ -7,6 +7,7 @@ import (
 	inet "github.com/jbenet/go-ipfs/net"
 	peer "github.com/jbenet/go-ipfs/peer"
 	pb "github.com/jbenet/go-ipfs/routing/dht/pb"
+	ctxutil "github.com/jbenet/go-ipfs/util/ctx"

 	context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
 	ggio "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/gogoprotobuf/io"
@@ -21,8 +22,10 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) {
 	defer s.Close()

 	ctx := dht.Context()
-	r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
-	w := ggio.NewDelimitedWriter(s)
+	cr := ctxutil.NewReader(ctx, s) // ok to use. we defer close stream in this func
+	cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func
+	r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax)
+	w := ggio.NewDelimitedWriter(cw)
 	mPeer := s.Conn().RemotePeer()

 	// receive msg
@@ -76,8 +79,10 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message
 	}
 	defer s.Close()

-	r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
-	w := ggio.NewDelimitedWriter(s)
+	cr := ctxutil.NewReader(ctx, s) // ok to use. we defer close stream in this func
+	cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func
+	r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax)
+	w := ggio.NewDelimitedWriter(cw)

 	start := time.Now()

@@ -113,7 +118,8 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message
 	}
 	defer s.Close()

-	w := ggio.NewDelimitedWriter(s)
+	cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func
+	w := ggio.NewDelimitedWriter(cw)

 	log.Debugf("%s writing", dht.self)
 	if err := w.WriteMsg(pmes); err != nil {
diff --git a/query.go b/query.go
index c45fa23..3d3f940 100644
--- a/query.go
+++ b/query.go
@@ -12,6 +12,7 @@ import (
 	todoctr "github.com/jbenet/go-ipfs/util/todocounter"

 	context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
+	ctxgroup "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-ctxgroup"
 )

 var maxQueryConcurrency = AlphaValue
@@ -78,9 +79,8 @@ type dhtQueryRunner struct {
 	// peersRemaining is a counter of peers remaining (toQuery + processing)
 	peersRemaining todoctr.Counter

-	// context
-	ctx    context.Context
-	cancel context.CancelFunc
+	// context group
+	cg ctxgroup.ContextGroup

 	// result
 	result *dhtQueryResult
@@ -93,16 +93,13 @@ type dhtQueryRunner struct {
 }

 func newQueryRunner(ctx context.Context, q *dhtQuery) *dhtQueryRunner {
-	ctx, cancel := context.WithCancel(ctx)
-
 	return &dhtQueryRunner{
-		ctx:            ctx,
-		cancel:         cancel,
 		query:          q,
 		peersToQuery:   queue.NewChanQueue(ctx, queue.NewXORDistancePQ(q.key)),
 		peersRemaining: todoctr.NewSyncCounter(),
 		peersSeen:      peer.Set{},
 		rateLimit:      make(chan struct{}, q.concurrency),
+		cg:             ctxgroup.WithContext(ctx),
 	}
 }

@@ -120,11 +117,13 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {

 	// add all the peers we got first.
 	for _, p := range peers {
-		r.addPeerToQuery(p, "") // don't have access to self here...
+		r.addPeerToQuery(r.cg.Context(), p, "") // don't have access to self here...
 	}

 	// go do this thing.
-	go r.spawnWorkers()
+	// do it as a child func to make sure Run exits
+	// ONLY AFTER spawn workers has exited.
+	r.cg.AddChildFunc(r.spawnWorkers)

 	// so workers are working.

@@ -133,7 +132,7 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {

 	select {
 	case <-r.peersRemaining.Done():
-		r.cancel() // ran all and nothing. cancel all outstanding workers.
+		r.cg.Close()
 		r.RLock()
 		defer r.RUnlock()

@@ -141,10 +140,10 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
 			err = r.errs[0]
 		}

-	case <-r.ctx.Done():
+	case <-r.cg.Closed():
 		r.RLock()
 		defer r.RUnlock()
-		err = r.ctx.Err()
+		err = r.cg.Context().Err() // collect the error.
 	}

 	if r.result != nil && r.result.success {
@@ -154,7 +153,7 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) {
 	return nil, err
 }

-func (r *dhtQueryRunner) addPeerToQuery(next peer.ID, benchmark peer.ID) {
+func (r *dhtQueryRunner) addPeerToQuery(ctx context.Context, next peer.ID, benchmark peer.ID) {
 	// if new peer is ourselves...
 	if next == r.query.dialer.LocalPeer() {
 		return
@@ -186,37 +185,42 @@ func (r *dhtQueryRunner) addPeerToQuery(next peer.ID, benchmark peer.ID) {
 	r.peersRemaining.Increment(1)
 	select {
 	case r.peersToQuery.EnqChan <- next:
-	case <-r.ctx.Done():
+	case <-ctx.Done():
 	}
 }

-func (r *dhtQueryRunner) spawnWorkers() {
+func (r *dhtQueryRunner) spawnWorkers(parent ctxgroup.ContextGroup) {
 	for {

 		select {
 		case <-r.peersRemaining.Done():
 			return

-		case <-r.ctx.Done():
+		case <-r.cg.Closing():
 			return

 		case p, more := <-r.peersToQuery.DeqChan:
 			if !more {
 				return // channel closed.
 			}
-			log.Debugf("spawning worker for: %v\n", p)
-			go r.queryPeer(p)
+			log.Debugf("spawning worker for: %v", p)
+
+			// do it as a child func to make sure Run exits
+			// ONLY AFTER spawn workers has exited.
+			parent.AddChildFunc(func(cg ctxgroup.ContextGroup) {
+				r.queryPeer(cg, p)
+			})
 		}
 	}
 }

-func (r *dhtQueryRunner) queryPeer(p peer.ID) {
+func (r *dhtQueryRunner) queryPeer(cg ctxgroup.ContextGroup, p peer.ID) {
 	log.Debugf("spawned worker for: %v", p)

 	// make sure we rate limit concurrency.
 	select {
 	case <-r.rateLimit:
-	case <-r.ctx.Done():
+	case <-cg.Closing():
 		r.peersRemaining.Decrement(1)
 		return
 	}
@@ -233,7 +237,7 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) {
 	}()

 	// make sure we're connected to the peer.
-	err := r.query.dialer.DialPeer(r.ctx, p)
+	err := r.query.dialer.DialPeer(cg.Context(), p)
 	if err != nil {
 		log.Debugf("ERROR worker for: %v -- err connecting: %v", p, err)
 		r.Lock()
@@ -243,7 +247,7 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) {
 	}

 	// finally, run the query against this peer
-	res, err := r.query.qfunc(r.ctx, p)
+	res, err := r.query.qfunc(cg.Context(), p)
 	if err != nil {
 		log.Debugf("ERROR worker for: %v %v", p, err)
@@ -256,14 +260,20 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) {
 		r.Lock()
 		r.result = res
 		r.Unlock()
-		r.cancel() // signal to everyone that we're done.
+		go r.cg.Close() // signal to everyone that we're done.
+		// must be async, as we're one of the children, and Close blocks.

 	} else if len(res.closerPeers) > 0 {
 		log.Debugf("PEERS CLOSER -- worker for: %v (%d closer peers)", p, len(res.closerPeers))
 		for _, next := range res.closerPeers {
 			// add their addresses to the dialer's peerstore
+			conns := r.query.dialer.ConnsToPeer(next.ID)
+			if len(conns) == 0 {
+				log.Infof("PEERS CLOSER -- worker for %v FOUND NEW PEER: %s %s", p, next.ID, next.Addrs)
+			}
+
 			r.query.dialer.Peerstore().AddAddresses(next.ID, next.Addrs)
-			r.addPeerToQuery(next.ID, p)
+			r.addPeerToQuery(cg.Context(), next.ID, p)
 			log.Debugf("PEERS CLOSER -- worker for: %v added %v (%v)", p, next.ID, next.Addrs)
 		}
 	} else {
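
Note on the dht_net.go hunks: ctxutil.NewReader and ctxutil.NewWriter wrap the raw stream so that a blocked read or write gives up once the context is cancelled. The sketch below is a minimal, hypothetical illustration of that wrapper pattern using the standard library context package; it is not the actual go-ipfs util/ctx code, and the ctxReader/ioRet names are invented for the example. It also makes the inline comments concrete: the background read only terminates when the underlying stream is torn down, which is why each call site defers s.Close().

// ctxio is an illustrative sketch, not the real go-ipfs util/ctx package.
package ctxio

import (
	"context"
	"io"
)

type ioRet struct {
	n   int
	err error
}

// ctxReader wraps an io.Reader so that Read returns ctx.Err() once the
// context is cancelled, instead of blocking forever on the stream.
type ctxReader struct {
	ctx context.Context
	r   io.Reader
}

// NewReader mirrors the shape of ctxutil.NewReader(ctx, s) used above.
func NewReader(ctx context.Context, r io.Reader) io.Reader {
	return &ctxReader{ctx: ctx, r: r}
}

func (cr *ctxReader) Read(p []byte) (int, error) {
	// Run the blocking read in a goroutine so we can also watch the context.
	// Reading into a private buffer avoids racing on p if we return early.
	buf := make([]byte, len(p))
	done := make(chan ioRet, 1)
	go func() {
		n, err := cr.r.Read(buf)
		done <- ioRet{n, err}
	}()

	select {
	case ret := <-done:
		return copy(p, buf[:ret.n]), ret.err
	case <-cr.ctx.Done():
		// The goroutine above stays blocked until the underlying stream is
		// closed -- hence "ok to use. we defer close stream in this func".
		return 0, cr.ctx.Err()
	}
}

A context-aware writer follows the same shape, swapping Read for Write.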
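
The query.go hunks lean on two ContextGroup behaviors that the diff's comments reference: a child added with AddChildFunc keeps the group open until it returns, and Close cancels the group and then blocks until every child has exited. The following self-contained sketch captures those semantics as an assumption about go-ctxgroup's behavior, not its implementation (group, withContext, and the demo in main are invented names). It shows why Run now returns only after spawnWorkers and its children have exited, and why queryPeer must use `go r.cg.Close()`: calling Close inline from a child would wait on that very child and deadlock.

// Illustrative stand-in for the ContextGroup semantics used above; not the
// go-ctxgroup implementation. Names (group, withContext, ...) are invented.
package main

import (
	"context"
	"fmt"
	"sync"
)

type group struct {
	ctx      context.Context
	cancel   context.CancelFunc
	children sync.WaitGroup
	closed   chan struct{}
	once     sync.Once
}

func withContext(ctx context.Context) *group {
	cctx, cancel := context.WithCancel(ctx)
	return &group{ctx: cctx, cancel: cancel, closed: make(chan struct{})}
}

func (g *group) Context() context.Context { return g.ctx }

// Closing fires as soon as teardown begins (compare cg.Closing()).
func (g *group) Closing() <-chan struct{} { return g.ctx.Done() }

// Closed fires only after every child has returned (compare cg.Closed()).
func (g *group) Closed() <-chan struct{} { return g.closed }

// AddChildFunc ties f's lifetime to the group: Close waits for it.
func (g *group) AddChildFunc(f func(g *group)) {
	g.children.Add(1)
	go func() {
		defer g.children.Done()
		f(g)
	}()
}

// Close cancels the group, then BLOCKS until all children have returned.
// A child that wants to shut the group down must therefore call Close in a
// fresh goroutine, as queryPeer does with `go r.cg.Close()`.
func (g *group) Close() {
	g.cancel()
	g.children.Wait()
	g.once.Do(func() { close(g.closed) })
}

func main() {
	g := withContext(context.Background())

	// a worker that runs until the group starts closing, like spawnWorkers.
	g.AddChildFunc(func(g *group) {
		<-g.Closing()
	})

	// a worker that "finds the result" and tears everything down, like the
	// success branch of queryPeer. Note the asynchronous Close.
	g.AddChildFunc(func(g *group) {
		go g.Close()
	})

	<-g.Closed() // like Run selecting on r.cg.Closed()
	fmt.Println("all workers have exited")
}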