diff --git a/src/encoder.ts b/src/encoder.ts index 4f869e6..d8c93f7 100644 --- a/src/encoder.ts +++ b/src/encoder.ts @@ -14,6 +14,8 @@ export const uint16BEDecode = data => { }; uint16BEDecode.bytes = 2; +// Note: IK and XX encoder usage is opposite (XX uses in stages encode0 where IK uses encode1) + export function encode0(message: MessageBuffer): bytes { return Buffer.concat([message.ne, message.ciphertext]); } diff --git a/src/handshake-ik.ts b/src/handshake-ik.ts index 5683c99..62a02f7 100644 --- a/src/handshake-ik.ts +++ b/src/handshake-ik.ts @@ -8,6 +8,7 @@ import {Buffer} from "buffer"; import {decode0, decode1, encode0, encode1} from "./encoder"; import {verifySignedPayload} from "./utils"; import {FailedIKError} from "./errors"; +import {logger} from "./logger"; export class IKHandshake implements IHandshake { public isInitiator: boolean; @@ -44,15 +45,16 @@ export class IKHandshake implements IHandshake { public async stage0(): Promise { if (this.isInitiator) { const messageBuffer = this.ik.sendMessage(this.session, this.payload); - this.connection.writeLP(encode0(messageBuffer)); + this.connection.writeLP(encode1(messageBuffer)); } else { - const receivedMsg = await this.connection.readLP(); - const receivedMessageBuffer = decode0(receivedMsg); + const receivedMsg = (await this.connection.readLP()).slice(); + const receivedMessageBuffer = decode1(Buffer.from(receivedMsg)); const plaintext = this.ik.recvMessage(this.session, receivedMessageBuffer); try { await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); } catch (e) { + logger("Responder breaking up with IK handshake in stage 0."); throw new FailedIKError(receivedMsg, `Error occurred while verifying initiator's signed payload: ${e.message}`); } } @@ -60,18 +62,19 @@ export class IKHandshake implements IHandshake { public async stage1(): Promise { if (this.isInitiator) { - const receivedMsg = await this.connection.readLP(); - const receivedMessageBuffer = decode1(receivedMsg); + const receivedMsg = (await this.connection.readLP()).slice(); + const receivedMessageBuffer = decode0(Buffer.from(receivedMsg)); const plaintext = this.ik.recvMessage(this.session, receivedMessageBuffer); try { await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); } catch (e) { + logger("Initiator breaking up with IK handshake in stage 1."); throw new FailedIKError(receivedMsg, `Error occurred while verifying responder's signed payload: ${e.message}`); } } else { const messageBuffer = this.ik.sendMessage(this.session, this.payload); - this.connection.writeLP(encode1(messageBuffer)); + this.connection.writeLP(encode0(messageBuffer)); } } diff --git a/src/handshake-xx-fallback.ts b/src/handshake-xx-fallback.ts index c01539b..ed14f4d 100644 --- a/src/handshake-xx-fallback.ts +++ b/src/handshake-xx-fallback.ts @@ -39,37 +39,34 @@ export class XXFallbackHandshake extends XXHandshake { this.xx.sendMessage(this.session, Buffer.alloc(0), this.ephemeralKeys); logger("XX Fallback Stage 0 - Initialized state as the first message was sent by initiator."); } else { - logger("XX Fallback Stage 0 - Responder waiting to receive first message..."); const receivedMessageBuffer = decode0(this.initialMsg); this.xx.recvMessage(this.session, { ne: receivedMessageBuffer.ne, ns: Buffer.alloc(0), ciphertext: Buffer.alloc(0), }); - logger("XX Fallback Stage 0 - Responder received first message."); + logger("XX Fallback Stage 0 - Responder used received message from IK."); } } // stage 1 public async exchange(): Promise { if (this.isInitiator) { - logger('XX Fallback Stage 1 - Initiator waiting to receive first message from responder...'); const receivedMessageBuffer = decode1(this.initialMsg); const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); - logger('XX Fallback Stage 1 - Initiator received the message. Got remote\'s static key.'); + logger('XX Fallback Stage 1 - Initiator used received message from IK.'); logger("Initiator going to check remote's signature..."); try { await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); } catch (e) { - throw new Error(`Error occurred while verifying signed payload: ${e.message}`); + throw new Error(`Error occurred while verifying signed payload from responder: ${e.message}`); } logger("All good with the signature!"); } else { - logger('XX Fallback Stage 1 - Responder sending out first message with signed payload and static key.'); - const messageBuffer = this.xx.sendMessage(this.session, this.payload); - this.connection.writeLP(encode1(messageBuffer)); - logger('XX Fallback Stage 1 - Responder sent the second handshake message with signed payload.') + logger("XX Fallback Stage 1 - Responder start"); + super.exchange(); + logger("XX Fallback Stage 1 - Responder end"); } } } diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index 5f42238..ed65093 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -63,7 +63,7 @@ export class XXHandshake implements IHandshake { public async exchange(): Promise { if (this.isInitiator) { logger('Stage 1 - Initiator waiting to receive first message from responder...'); - const receivedMessageBuffer = decode1((await this.connection.readLP()).slice()); + const receivedMessageBuffer = decode1(await this.connection.readLP()); const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); logger('Stage 1 - Initiator received the message. Got remote\'s static key.'); diff --git a/src/noise.ts b/src/noise.ts index c678042..b0d65ef 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -124,7 +124,7 @@ export class Noise implements INoiseConnection { throw new Error("Remote static key should be initialized."); } - const IKhandshake = new IKHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer, foundRemoteStaticKey); + const IKhandshake = new IKHandshake(isInitiator, Buffer.from(payload), this.prologue, this.staticKeys, connection, remotePeer, foundRemoteStaticKey); try { return await this.performIKHandshake(IKhandshake); } catch (e) { @@ -189,13 +189,8 @@ export class Noise implements INoiseConnection { handshake: IKHandshake, ): Promise { - try { - await handshake.stage0(); - await handshake.stage1(); - } catch (e) { - console.error("Error in IK handshake: ", e); - throw e; - } + await handshake.stage0(); + await handshake.stage1(); return handshake; } diff --git a/test/noise.test.ts b/test/noise.test.ts index 4281874..da52381 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -168,6 +168,7 @@ describe("Noise", () => { const staticKeysInitiator = generateKeypair(); const noiseInit = new Noise(staticKeysInitiator.privateKey); const noiseResp = new Noise(); + const xxSpy = sandbox.spy(noiseInit, "performXXFallbackHandshake"); // Prepare key cache for noise pipes await KeyCache.store(localPeer, staticKeysInitiator.publicKey); @@ -185,9 +186,46 @@ describe("Noise", () => { wrappedOutbound.writeLP(Buffer.from("test fallback")); const response = await wrappedInbound.readLP(); expect(response.toString()).equal("test fallback"); + + assert(xxSpy.calledOnce, "XX Fallback method was never called."); } catch (e) { console.error(e); assert(false, e.message); } }); + + it("XX fallback with XX as responder has noise pipes disabled", async() => { + try { + const staticKeysInitiator = generateKeypair(); + const noiseInit = new Noise(staticKeysInitiator.privateKey); + const staticKeysResponder = generateKeypair(); + console.log("staticKeysInitiator: ", staticKeysInitiator) + console.log("staticKeysResponder: ", staticKeysResponder) + const noiseResp = new Noise(staticKeysResponder.privateKey, undefined, false); + const xxSpy = sandbox.spy(noiseInit, "performXXFallbackHandshake"); + + // Prepare key cache for noise pipes + await KeyCache.store(localPeer, staticKeysInitiator.publicKey); + await KeyCache.store(remotePeer, staticKeysResponder.publicKey); + + const [inboundConnection, outboundConnection] = DuplexPair(); + + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection, localPeer), + ]); + + const wrappedInbound = Wrap(inbound.conn); + const wrappedOutbound = Wrap(outbound.conn); + + wrappedOutbound.writeLP(Buffer.from("test fallback")); + const response = await wrappedInbound.readLP(); + expect(response.toString()).equal("test fallback"); + + assert(xxSpy.calledOnce, "XX Fallback method was never called."); + } catch (e) { + console.error(e); + assert(false, e.message); + } + }); });