diff --git a/src/encoder.ts b/src/encoder.ts index e4e2c30..4f869e6 100644 --- a/src/encoder.ts +++ b/src/encoder.ts @@ -14,14 +14,34 @@ export const uint16BEDecode = data => { }; uint16BEDecode.bytes = 2; -export function encodeMessageBuffer(message: MessageBuffer): bytes { +export function encode0(message: MessageBuffer): bytes { + return Buffer.concat([message.ne, message.ciphertext]); +} + +export function encode1(message: MessageBuffer): bytes { return Buffer.concat([message.ne, message.ns, message.ciphertext]); } -export function decodeMessageBuffer(message: bytes): MessageBuffer { +export function decode0(input: bytes): MessageBuffer { + if (input.length < 32) { + throw new Error("Cannot decode stage 0 MessageBuffer: length less than 32 bytes."); + } + return { - ne: message.slice(0, 32), - ns: message.slice(32, 64), - ciphertext: message.slice(64, message.length), + ne: input.slice(0, 32), + ciphertext: input.slice(32, input.length), + ns: Buffer.alloc(0), + } +} + +export function decode1(input: bytes): MessageBuffer { + if (input.length < 96) { + throw new Error("Cannot decode stage 0 MessageBuffer: length less than 96 bytes."); + } + + return { + ne: input.slice(0, 32), + ns: input.slice(32, 64), + ciphertext: input.slice(64, input.length), } } diff --git a/src/handshake-xx-fallback.ts b/src/handshake-xx-fallback.ts index 594f6e4..d25a7cf 100644 --- a/src/handshake-xx-fallback.ts +++ b/src/handshake-xx-fallback.ts @@ -12,8 +12,8 @@ import { verifySignedPayload, } from "./utils"; import { logger } from "./logger"; -import { decodeMessageBuffer, encodeMessageBuffer } from "./encoder"; import { WrappedConnection } from "./noise"; +import {decode0, decode1, encode1} from "./encoder"; export class Handshake extends XXHandshake { private ephemeralKeys: KeyPair; @@ -43,7 +43,7 @@ export class Handshake extends XXHandshake { logger("XX Fallback Stage 0 - Initialized state as the first message was sent by initiator."); } else { logger("XX Fallback Stage 0 - Responder waiting to receive first message..."); - const receivedMessageBuffer = decodeMessageBuffer(this.initialMsg); + const receivedMessageBuffer = decode0(this.initialMsg); console.log("receivedMessageBuffer: ", receivedMessageBuffer) this.xx.recvMessage(this.session, { ne: receivedMessageBuffer.ne, @@ -58,7 +58,7 @@ export class Handshake extends XXHandshake { public async exchange(): Promise { if (this.isInitiator) { logger('XX Fallback Stage 1 - Initiator waiting to receive first message from responder...'); - const receivedMessageBuffer = decodeMessageBuffer(this.initialMsg); + const receivedMessageBuffer = decode1(this.initialMsg); const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); logger('XX Fallback Stage 1 - Initiator received the message. Got remote\'s static key.'); @@ -81,7 +81,7 @@ export class Handshake extends XXHandshake { ); const messageBuffer = this.xx.sendMessage(this.session, handshakePayload); - this.connection.writeLP(encodeMessageBuffer(messageBuffer)); + this.connection.writeLP(encode1(messageBuffer)); logger('Stage 1 - Responder sent the second handshake message with signed payload.') } } diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index 61611cc..f5fd353 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -13,7 +13,7 @@ import { verifySignedPayload, } from "./utils"; import { logger } from "./logger"; -import { decodeMessageBuffer, encodeMessageBuffer } from "./encoder"; +import { decode0, decode1, encode0, encode1 } from "./encoder"; import { WrappedConnection } from "./noise"; export class Handshake implements HandshakeInterface { @@ -56,11 +56,11 @@ export class Handshake implements HandshakeInterface { if (this.isInitiator) { logger("Stage 0 - Initiator starting to send first message."); const messageBuffer = this.xx.sendMessage(this.session, Buffer.alloc(0)); - this.connection.writeLP(encodeMessageBuffer(messageBuffer)); + this.connection.writeLP(encode0(messageBuffer)); logger("Stage 0 - Initiator finished sending first message."); } else { logger("Stage 0 - Responder waiting to receive first message..."); - const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); + const receivedMessageBuffer = decode0((await this.connection.readLP()).slice()); this.xx.recvMessage(this.session, receivedMessageBuffer); logger("Stage 0 - Responder received first message."); } @@ -70,7 +70,7 @@ export class Handshake implements HandshakeInterface { public async exchange(): Promise { if (this.isInitiator) { logger('Stage 1 - Initiator waiting to receive first message from responder...'); - const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); + const receivedMessageBuffer = decode1((await this.connection.readLP()).slice()); const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); logger('Stage 1 - Initiator received the message. Got remote\'s static key.'); @@ -93,7 +93,7 @@ export class Handshake implements HandshakeInterface { ); const messageBuffer = this.xx.sendMessage(this.session, handshakePayload); - this.connection.writeLP(encodeMessageBuffer(messageBuffer)); + this.connection.writeLP(encode1(messageBuffer)); logger('Stage 1 - Responder sent the second handshake message with signed payload.') } } @@ -111,11 +111,11 @@ export class Handshake implements HandshakeInterface { signedEarlyDataPayload ); const messageBuffer = this.xx.sendMessage(this.session, handshakePayload); - this.connection.writeLP(encodeMessageBuffer(messageBuffer)); + this.connection.writeLP(encode1(messageBuffer)); logger('Stage 2 - Initiator sent message with signed payload.'); } else { logger('Stage 2 - Responder waiting for third handshake message...'); - const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); + const receivedMessageBuffer = decode1((await this.connection.readLP()).slice()); const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); logger('Stage 2 - Responder received the message, finished handshake. Got remote\'s static key.'); diff --git a/src/handshakes/xx.ts b/src/handshakes/xx.ts index 857a9b5..110dc7e 100644 --- a/src/handshakes/xx.ts +++ b/src/handshakes/xx.ts @@ -199,28 +199,4 @@ export class XX extends AbstractHandshake { session.mc = session.mc.add(new BN(1)); return plaintext; } - - public decode0(input: bytes): MessageBuffer { - if (input.length < 32) { - throw new Error("Cannot decode stage 0 MessageBuffer: length less than 32 bytes."); - } - - return { - ne: input.slice(0, 32), - ciphertext: input.slice(32, input.length), - ns: Buffer.alloc(0), - } - } - - public decode1(input: bytes): MessageBuffer { - if (input.length < 96) { - throw new Error("Cannot decode stage 0 MessageBuffer: length less than 96 bytes."); - } - - return { - ne: input.slice(0, 32), - ns: input.slice(32, 80), - ciphertext: input.slice(80, input.length), - } - } } diff --git a/test/noise.test.ts b/test/noise.test.ts index 8f78b76..00297a0 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -12,7 +12,7 @@ import { getHandshakePayload, signPayload } from "../src/utils"; -import { decodeMessageBuffer, encodeMessageBuffer } from "../src/encoder"; +import {decode0, decode1, encode1} from "../src/encoder"; import {XX} from "../src/handshakes/xx"; import {Buffer} from "buffer"; import {getKeyPairFromPeerId} from "./utils"; @@ -63,7 +63,7 @@ describe("Noise", () => { const handshake = new Handshake(false, libp2pPrivKey, libp2pPubKey, prologue, staticKeys, wrapped, localPeer, xx); - let receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); + let receivedMessageBuffer = decode0((await wrapped.readLP()).slice()); // The first handshake message contains the initiator's ephemeral public key expect(receivedMessageBuffer.ne.length).equal(32); xx.recvMessage(handshake.session, receivedMessageBuffer); @@ -73,10 +73,10 @@ describe("Noise", () => { const handshakePayload = await createHandshakePayload(libp2pPubKey, libp2pPrivKey, signedPayload); const messageBuffer = xx.sendMessage(handshake.session, handshakePayload); - wrapped.writeLP(encodeMessageBuffer(messageBuffer)); + wrapped.writeLP(encode1(messageBuffer)); // Stage 2 - finish handshake - receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); + receivedMessageBuffer = decode1((await wrapped.readLP()).slice()); xx.recvMessage(handshake.session, receivedMessageBuffer); return {wrapped, handshake}; })(), diff --git a/test/xx-fallback-handshake.test.ts b/test/xx-fallback-handshake.test.ts index 9eb7b70..8c742ce 100644 --- a/test/xx-fallback-handshake.test.ts +++ b/test/xx-fallback-handshake.test.ts @@ -13,7 +13,7 @@ import {generateEd25519Keys, getKeyPairFromPeerId} from "./utils"; import {Handshake} from "../src/handshake-xx-fallback"; import {createPeerIdsFromFixtures} from "./fixtures/peer"; import {assert} from "chai"; -import {encodeMessageBuffer} from "../src/encoder"; +import {encode0} from "../src/encoder"; describe("XX Fallback Handshake", () => { let peerA, peerB, fakePeer; @@ -43,7 +43,7 @@ describe("XX Fallback Handshake", () => { signedPayload, signedEarlyDataPayload, ); - const initialMsg = encodeMessageBuffer({ + const initialMsg = encode0({ ne: staticKeysInitiator.publicKey, ns: Buffer.alloc(32), ciphertext: handshakePayload,