diff --git a/src/errors.ts b/src/errors.ts index 09d2b10..46a5c0a 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -6,6 +6,5 @@ export class FailedIKError extends Error { this.initialMsg = initialMsg; this.name = "FailedIKhandshake"; - this.stack = new Error().stack; } }; diff --git a/src/noise.ts b/src/noise.ts index 73c6edd..7d88c73 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -108,23 +108,13 @@ export class Noise implements INoiseConnection { private async performHandshake(params: HandshakeParams): Promise { const payload = await getPayload(params.localPeer, this.staticKeys.publicKey, this.earlyData); - let tryIK = this.useNoisePipes; - const foundRemoteStaticKey = KeyCache.load(params.remotePeer); - if (tryIK && params.isInitiator && !foundRemoteStaticKey) { - tryIK = false; - logger(`Static key not found.`) - } - + const remoteStaticKey = KeyCache.load(params.remotePeer); // Try IK if acting as responder or initiator that has remote's static key. - if (tryIK) { + if (this.useNoisePipes && remoteStaticKey) { // Try IK first const { remotePeer, connection, isInitiator } = params; - if (!foundRemoteStaticKey) { - // TODO: Recheck. Possible that responder should not have it here. - throw new Error("Remote static key should be initialized."); - } + const IKhandshake = new IKHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer, remoteStaticKey); - const IKhandshake = new IKHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer, foundRemoteStaticKey); try { return await this.performIKHandshake(IKhandshake); } catch (e) { diff --git a/test/noise.test.ts b/test/noise.test.ts index dca824c..f40346e 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -163,7 +163,7 @@ describe("Noise", () => { } }); - it("should switch to XX fallback because of invalid remote static key", async() => { + it("IK -> XX fallback: initiator has invalid remote static key", async() => { try { const staticKeysInitiator = generateKeypair(); const noiseInit = new Noise(staticKeysInitiator.privateKey); @@ -171,6 +171,7 @@ describe("Noise", () => { const xxSpy = sandbox.spy(noiseInit, "performXXFallbackHandshake"); // Prepare key cache for noise pipes + KeyCache.resetStorage(); KeyCache.store(localPeer, staticKeysInitiator.publicKey); KeyCache.store(remotePeer, generateKeypair().publicKey); @@ -194,12 +195,12 @@ describe("Noise", () => { } }); - it("XX fallback with XX as responder has noise pipes disabled", async() => { + it("IK -> XX fallback: responder has disabled noise pipes", async() => { try { const staticKeysInitiator = generateKeypair(); const noiseInit = new Noise(staticKeysInitiator.privateKey); - const staticKeysResponder = generateKeypair(); + const staticKeysResponder = generateKeypair(); const noiseResp = new Noise(staticKeysResponder.privateKey, undefined, false); const xxSpy = sandbox.spy(noiseInit, "performXXFallbackHandshake"); @@ -208,7 +209,6 @@ describe("Noise", () => { KeyCache.store(remotePeer, staticKeysResponder.publicKey); const [inboundConnection, outboundConnection] = DuplexPair(); - const [outbound, inbound] = await Promise.all([ noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), noiseResp.secureInbound(remotePeer, inboundConnection, localPeer), @@ -228,7 +228,7 @@ describe("Noise", () => { } }); - it("Initiator starts with XX (pipes disabled) responder has noise pipes", async() => { + it("Initiator starts with XX (pipes disabled), responder has enabled noise pipes", async() => { try { const staticKeysInitiator = generateKeypair(); const noiseInit = new Noise(staticKeysInitiator.privateKey, undefined, false); @@ -262,4 +262,40 @@ describe("Noise", () => { assert(false, e.message); } }); + + it("IK -> XX Fallback: responder has no remote static key", async() => { + try { + const staticKeysInitiator = generateKeypair(); + const noiseInit = new Noise(staticKeysInitiator.privateKey); + const staticKeysResponder = generateKeypair(); + + const noiseResp = new Noise(staticKeysResponder.privateKey); + const xxFallbackInitSpy = sandbox.spy(noiseInit, "performXXFallbackHandshake"); + const xxRespSpy = sandbox.spy(noiseResp, "performXXHandshake"); + + // Prepare key cache for noise pipes + KeyCache.resetStorage(); + KeyCache.store(remotePeer, staticKeysResponder.publicKey); + + const [inboundConnection, outboundConnection] = DuplexPair(); + + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection, localPeer), + ]); + + const wrappedInbound = Wrap(inbound.conn); + const wrappedOutbound = Wrap(outbound.conn); + + wrappedOutbound.writeLP(Buffer.from("test fallback")); + const response = await wrappedInbound.readLP(); + expect(response.toString()).equal("test fallback"); + + assert(xxFallbackInitSpy.calledOnce, "XX Fallback method was not called."); + assert(xxRespSpy.calledOnce, "XX method was not called."); + } catch (e) { + console.error(e); + assert(false, e.message); + } + }); });