diff --git a/src/@types/it-pb-rpc/index.d.ts b/src/@types/it-pb-rpc/index.d.ts index 012bc39..c8741f3 100644 --- a/src/@types/it-pb-rpc/index.d.ts +++ b/src/@types/it-pb-rpc/index.d.ts @@ -2,7 +2,7 @@ declare module "it-pb-rpc" { import { Buffer } from "buffer"; import { Duplex } from "it-pair"; type WrappedDuplex = { - read(bytes: number): Promise, + read(bytes?: number): Promise, readLP(): Promise, write(input: Buffer): void, writeLP(input: Buffer): void, diff --git a/src/handshake.ts b/src/handshake.ts index 4fc9a0e..29581ac 100644 --- a/src/handshake.ts +++ b/src/handshake.ts @@ -32,6 +32,7 @@ export class Handshake { prologue: bytes32, staticKeys: KeyPair, connection: WrappedConnection, + handshake?: XXHandshake, ) { this.type = type; this.isInitiator = isInitiator; @@ -40,7 +41,7 @@ export class Handshake { this.staticKeys = staticKeys; this.connection = connection; - this.xx = new XXHandshake(); + this.xx = handshake || new XXHandshake(); } // stage 0 @@ -55,8 +56,7 @@ export class Handshake { earlyData, this.staticKeys.privateKey ); - const message = Buffer.concat([Buffer.alloc(0), handshakePayload]); - const messageBuffer = await this.xx.sendMessage(ns, message); + const messageBuffer = await this.xx.sendMessage(ns, handshakePayload); this.connection.writeLP(encodeMessageBuffer(messageBuffer)); logger("Stage 0 - Initiator finished proposing"); @@ -80,8 +80,7 @@ export class Handshake { const signedPayload = signPayload(this.staticKeys.privateKey, getHandshakePayload(this.staticKeys.publicKey)); const handshakePayload = await createHandshakePayload(this.remotePublicKey, signedPayload); - const message = Buffer.concat([Buffer.alloc(0), handshakePayload]); - const messageBuffer = await this.xx.sendMessage(session, message); + const messageBuffer = await this.xx.sendMessage(session, handshakePayload); this.connection.writeLP(encodeMessageBuffer(messageBuffer)); logger('Stage 1 - Responder sent the message.') } diff --git a/src/noise.ts b/src/noise.ts index 7cb06f4..2795c4c 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -5,10 +5,9 @@ import DuplexPair from 'it-pair/duplex'; import ensureBuffer from 'it-buffer'; import pipe from 'it-pipe'; import lp from 'it-length-prefixed'; -const { int16BEEncode, int16BEDecode } = lp; import { Handshake } from "./handshake"; -import { generateKeypair } from "./utils"; +import { generateKeypair, int16BEDecode, int16BEEncode } from "./utils"; import { decryptStream, encryptStream } from "./crypto"; import { bytes } from "./@types/basic"; import { NoiseConnection, PeerId, KeyPair, SecureOutbound } from "./@types/libp2p"; diff --git a/src/utils.ts b/src/utils.ts index 31c03ce..f87609e 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -74,3 +74,13 @@ export function decodeMessageBuffer(message: bytes) : MessageBuffer { } } +export const int16BEEncode = (value, target, offset) => { + target = target || Buffer.allocUnsafe(2); + return target.writeInt16BE(value, offset); +}; +int16BEEncode.bytes = 2; + +export const int16BEDecode = data => { + if (data.length < 2) throw RangeError('Could not decode int16BE'); + return data.readInt16BE(0);} +int16BEDecode.bytes = 2; diff --git a/src/xx.ts b/src/xx.ts index fa30fae..851013f 100644 --- a/src/xx.ts +++ b/src/xx.ts @@ -308,6 +308,7 @@ export class XXHandshake { throw new Error("Handshake state `e` param is missing."); } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)); + const plaintext = await this.decryptAndHash(hs.ss, message.ciphertext); const { cs1, cs2 } = this.split(hs.ss); diff --git a/test/noise.test.ts b/test/noise.test.ts index ce4208f..25eaa98 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -1,4 +1,4 @@ -import { expect } from "chai"; +import { expect, assert } from "chai"; import DuplexPair from 'it-pair/duplex'; import { Noise } from "../src"; @@ -63,7 +63,7 @@ describe("Noise", () => { const noiseInit = new Noise(libp2pKeys._key, localPeer.privKey.bytes); const [inboundConnection, outboundConnection] = DuplexPair(); - const [outbound] = await Promise.all([ + const [outbound, { wrapped, ns, handshake }] = await Promise.all([ noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), (async () => { const wrapped = Wrap(inboundConnection); @@ -72,16 +72,38 @@ describe("Noise", () => { privateKey: remotePeer.privKey.bytes, publicKey: remotePeer.pubKey.bytes, }; - const handshake = new Handshake('XX', false, localPeer.pubKey.bytes, prologue, staticKeys, wrapped); + const xx = new XXHandshake(); + const handshake = new Handshake('XX', false, localPeer.pubKey.bytes, prologue, staticKeys, wrapped, xx); + const ns = await xx.initSession(false, prologue, staticKeys, localPeer.pubKey.bytes); - // Finish handshake - const sessionResponder = await handshake.propose(Buffer.alloc(0)); - await handshake.exchange(sessionResponder); - await handshake.finish(sessionResponder); + let receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); + // The first handshake message contains the initiator's ephemeral public key + expect(receivedMessageBuffer.ne.length).equal(32); + await xx.recvMessage(ns, receivedMessageBuffer); - // Create the encrypted streams - console.log(sessionResponder); + // Stage 1 + const signedPayload = signPayload(staticKeys.privateKey, getHandshakePayload(staticKeys.publicKey)); + const handshakePayload = await createHandshakePayload(localPeer.pubKey.bytes, signedPayload); + + const messageBuffer = await xx.sendMessage(ns, handshakePayload); + wrapped.writeLP(encodeMessageBuffer(messageBuffer)); + + // Stage 2 - finish handshake + receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); + await xx.recvMessage(ns, receivedMessageBuffer); + return { wrapped, ns, handshake }; })(), ]); + + const wrappedOutbound = Wrap(outbound.conn); + wrappedOutbound.write(Buffer.from("test")); + + // Check that noise message is prefixed with 16-bit big-endian unsigned integer + const receivedEncryptedPayload = (await wrapped.read()).slice(); + const dataLength = receivedEncryptedPayload.readInt16BE(0); + const data = receivedEncryptedPayload.slice(2, dataLength + 2); + const decrypted = handshake.decrypt(data, ns); + // Decrypted data should match + assert(decrypted.equals(Buffer.from("test"))); }) });