diff --git a/src/dialer/index.js b/src/dialer/index.js index fc02b357..606a85e5 100644 --- a/src/dialer/index.js +++ b/src/dialer/index.js @@ -74,7 +74,21 @@ class Dialer { if (!dialTarget.addrs.length) { throw errCode(new Error('The dial request has no addresses'), codes.ERR_NO_VALID_ADDRESSES) } - const pendingDial = this._pendingDials.get(dialTarget.id) || this._createPendingDial(dialTarget, options) + + // Used for subsequent dials pending + let subsequentDialAborted = false + const onAbort = () => { + subsequentDialAborted = true + pendingDial.controller.abort() + } + + let pendingDial = this._pendingDials.get(dialTarget.id) + if (!pendingDial) { + pendingDial = this._createPendingDial(dialTarget, options) + } else { + // track subsequent dial abort + options.signal && options.signal.addEventListener('abort', onAbort) + } try { const connection = await pendingDial.promise @@ -82,12 +96,16 @@ class Dialer { return connection } catch (err) { // Error is a timeout - if (pendingDial.controller.signal.aborted) { + if (pendingDial.controller.signal.aborted && !subsequentDialAborted) { err.code = codes.ERR_TIMEOUT + // Error is a subsequent dial abort + } else if (subsequentDialAborted) { + err.code = codes.ERR_SUBSEQUENT_DIAL_ABORT } log.error(err) throw err } finally { + options.signal && options.signal.removeEventListener('abort', onAbort) pendingDial.destroy() } } diff --git a/src/errors.js b/src/errors.js index 18e600c6..49b4d042 100644 --- a/src/errors.js +++ b/src/errors.js @@ -26,6 +26,7 @@ exports.codes = { ERR_INVALID_PEER: 'ERR_INVALID_PEER', ERR_MUXER_UNAVAILABLE: 'ERR_MUXER_UNAVAILABLE', ERR_TIMEOUT: 'ERR_TIMEOUT', + ERR_SUBSEQUENT_DIAL_ABORT: 'ERR_SUBSEQUENT_DIAL_ABORT', ERR_TRANSPORT_UNAVAILABLE: 'ERR_TRANSPORT_UNAVAILABLE', ERR_TRANSPORT_DIAL_FAILED: 'ERR_TRANSPORT_DIAL_FAILED', ERR_UNSUPPORTED_PROTOCOL: 'ERR_UNSUPPORTED_PROTOCOL', diff --git a/test/dialing/direct.spec.js b/test/dialing/direct.spec.js index e3d60e59..0836ad25 100644 --- a/test/dialing/direct.spec.js +++ b/test/dialing/direct.spec.js @@ -14,6 +14,7 @@ const Muxer = require('libp2p-mplex') const { NOISE: Crypto } = require('libp2p-noise') const multiaddr = require('multiaddr') const AggregateError = require('aggregate-error') +const AbortController = require('abort-controller') const { AbortError } = require('libp2p-interfaces/src/transport/errors') const { codes: ErrorCodes } = require('../../src/errors') @@ -179,6 +180,47 @@ describe('Dialing (direct, WebSockets)', () => { .and.to.have.property('code', ErrorCodes.ERR_TIMEOUT) }) + it('should abort subsequent dials', async () => { + const dialer = new Dialer({ + transportManager: localTM, + timeout: 1000, + peerStore: { + addressBook: { + add: () => { }, + getMultiaddrsForPeer: () => [remoteAddr] + } + } + }) + const controller = new AbortController() + + const deferredDial = pDefer() + const deferredAbort = pDefer() + + sinon.stub(localTM, 'dial').callsFake(async (_, options) => { + deferredDial.resolve() + expect(options.signal).to.exist() + expect(options.signal.aborted).to.equal(false) + await deferredAbort.promise + expect(options.signal.aborted).to.equal(true) + throw new AbortError() + }) + + const dialPromise1 = dialer.connectToPeer(peerId) + const dialPromise2 = dialer.connectToPeer(peerId, { signal: controller.signal }) + + await deferredDial.promise + controller.abort() + + deferredAbort.resolve() + + await expect(dialPromise1) + .to.eventually.be.rejected() + .and.to.have.property('code', ErrorCodes.ERR_SUBSEQUENT_DIAL_ABORT) + await expect(dialPromise2) + .to.eventually.be.rejected() + .and.to.have.property('code', ErrorCodes.ERR_SUBSEQUENT_DIAL_ABORT) + }) + it('should dial to the max concurrency', async () => { const dialer = new Dialer({ transportManager: localTM,