diff --git a/README.md b/README.md index 519c887..b7d6a7e 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,10 @@ [![](https://img.shields.io/badge/project-libp2p-yellow.svg?style=flat-square)](https://libp2p.io/) ![](https://img.shields.io/github/issues-raw/nodefactoryio/js-libp2p-noise) ![](https://img.shields.io/github/license/nodefactoryio/js-libp2p-noise) +[![Build Status](https://travis-ci.com/NodeFactoryIo/js-libp2p-noise.svg?branch=master)](https://travis-ci.com/NodeFactoryIo/js-libp2p-noise) +![](https://img.shields.io/badge/yarn-%3E%3D1.17.0-orange.svg?style=flat-square) +![](https://img.shields.io/badge/Node.js-%3E%3D12.4.0-orange.svg?style=flat-square) +[![Discourse posts](https://img.shields.io/discourse/https/discuss.libp2p.io/posts.svg)](https://discuss.libp2p.io) > Noise libp2p handshake for js-libp2p @@ -10,8 +14,39 @@ This repository contains TypeScript implementation of noise protocol, an encrypt ## Usage -TBD +When published, package should be imported as: `import { Noise } from 'libp2p-noise'`. + +Example of instantiating noise and passing it to the libp2p config: +``` +const NOISE = new Noise(privateKey); + +const libp2p = new Libp2p({ + modules: { + connEncryption: [NOISE], + }, +}); +``` + +Where parameters for Noise constructor are: + - *private key* - required parameter (32 bytes libp2p peer private key) + - *static Noise key* - (optional) existing private Noise static key + - *early data* - (optional) an early data payload to be sent in handshake messages + + ## API -TBD +This module exposes a crypto interface, as defined in the repository [js-interfaces](https://github.com/libp2p/js-interfaces). + +[ยป API Docs](https://github.com/libp2p/js-interfaces/tree/master/src/crypto#api) + + +## Contribute + +Feel free to join in. All welcome. Open an issue! + +[![](https://cdn.rawgit.com/jbenet/contribute-ipfs-gif/master/img/contribute.gif)](https://github.com/ipfs/community/blob/master/contributing.md) + +## License + +[MIT](LICENSE) diff --git a/package.json b/package.json index 62758b8..2babf57 100644 --- a/package.json +++ b/package.json @@ -5,6 +5,11 @@ "repository": "git@github.com:NodeFactoryIo/js-libp2p-noise.git", "author": "NodeFactory ", "license": "MIT", + "keywords": [ + "libp2p", + "noise", + "crypto" + ], "scripts": { "prebuild": "rm -rf lib", "build": "babel src -x .ts -d lib --source-maps", @@ -13,27 +18,6 @@ "pretest": "yarn check-types", "test": "DEBUG=libp2p:noise mocha -r ./babel-register.js \"test/**/*.test.ts\"" }, - "devDependencies": { - "@babel/cli": "^7.6.4", - "@babel/core": "^7.6.4", - "@babel/plugin-proposal-async-generator-functions": "^7.7.0", - "@babel/plugin-proposal-object-rest-spread": "^7.6.2", - "@babel/preset-env": "^7.6.3", - "@babel/preset-typescript": "^7.6.0", - "@babel/register": "^7.6.2", - "@babel/runtime": "^7.6.3", - "@types/chai": "^4.2.4", - "@types/mocha": "^5.2.7", - "@typescript-eslint/eslint-plugin": "^2.6.0", - "@typescript-eslint/parser": "^2.6.0", - "bn.js-typings": "^1.0.1", - "chai": "^4.2.0", - "eslint": "^6.6.0", - "libp2p-crypto": "^0.17.1", - "mocha": "^6.2.2", - "peer-id": "^0.13.5", - "typescript": "^3.6.4" - }, "babel": { "presets": [ [ @@ -51,6 +35,25 @@ "@babel/plugin-proposal-async-generator-functions" ] }, + "devDependencies": { + "@babel/cli": "^7.6.4", + "@babel/core": "^7.6.4", + "@babel/plugin-proposal-async-generator-functions": "^7.7.0", + "@babel/plugin-proposal-object-rest-spread": "^7.6.2", + "@babel/preset-env": "^7.6.3", + "@babel/preset-typescript": "^7.6.0", + "@babel/register": "^7.6.2", + "@babel/runtime": "^7.6.3", + "@types/chai": "^4.2.4", + "@types/mocha": "^5.2.7", + "@typescript-eslint/eslint-plugin": "^2.6.0", + "@typescript-eslint/parser": "^2.6.0", + "bn.js-typings": "^1.0.1", + "chai": "^4.2.0", + "eslint": "^6.6.0", + "mocha": "^6.2.2", + "typescript": "^3.6.4" + }, "dependencies": { "bcrypto": "^4.2.3", "bn.js": "^5.0.0", @@ -61,6 +64,8 @@ "it-pair": "^1.0.0", "it-pb-rpc": "^0.1.3", "it-pipe": "^1.1.0", + "libp2p-crypto": "^0.17.1", + "peer-id": "^0.13.5", "protobufjs": "~6.8.8" } } diff --git a/src/@types/libp2p.ts b/src/@types/libp2p.ts index 46bf95f..58ecd87 100644 --- a/src/@types/libp2p.ts +++ b/src/@types/libp2p.ts @@ -14,6 +14,8 @@ export type PeerId = { pubKey: { marshal(): bytes; }; + marshalPubKey(): bytes; + marshalPrivKey(): bytes; }; export interface NoiseConnection { diff --git a/src/crypto.ts b/src/crypto.ts index 86ee648..5d583d2 100644 --- a/src/crypto.ts +++ b/src/crypto.ts @@ -1,4 +1,3 @@ -import { Duplex } from "it-pair"; import { Handshake } from "./handshake"; import { Buffer } from "buffer"; diff --git a/src/encoder.ts b/src/encoder.ts new file mode 100644 index 0000000..ea946ef --- /dev/null +++ b/src/encoder.ts @@ -0,0 +1,27 @@ +import {Buffer} from "buffer"; +import {bytes} from "./@types/basic"; +import {MessageBuffer} from "./xx"; + +export const int16BEEncode = (value, target, offset) => { + target = target || Buffer.allocUnsafe(2); + return target.writeInt16BE(value, offset); +}; +int16BEEncode.bytes = 2; + +export const int16BEDecode = data => { + if (data.length < 2) throw RangeError('Could not decode int16BE'); + return data.readInt16BE(0); +}; +int16BEDecode.bytes = 2; + +export function encodeMessageBuffer(message: MessageBuffer): bytes { + return Buffer.concat([message.ne, message.ns, message.ciphertext]); +} + +export function decodeMessageBuffer(message: bytes): MessageBuffer { + return { + ne: message.slice(0, 32), + ns: message.slice(32, 64), + ciphertext: message.slice(64, message.length), + } +} diff --git a/src/handshake.ts b/src/handshake.ts index 5cbe50f..cb76381 100644 --- a/src/handshake.ts +++ b/src/handshake.ts @@ -2,15 +2,16 @@ import { Buffer } from "buffer"; import { bytes, bytes32 } from "./@types/basic"; import { NoiseSession, XXHandshake } from "./xx"; -import { KeyPair } from "./@types/libp2p"; +import { KeyPair, PeerId } from "./@types/libp2p"; import { createHandshakePayload, - decodeMessageBuffer, - encodeMessageBuffer, getHandshakePayload, - logger, signEarlyDataPayload, + signEarlyDataPayload, signPayload, + verifySignedPayload, } from "./utils"; +import { logger } from "./logger"; +import { decodeMessageBuffer, encodeMessageBuffer } from "./encoder"; import { WrappedConnection } from "./noise"; export class Handshake { @@ -22,6 +23,7 @@ export class Handshake { private prologue: bytes32; private staticKeys: KeyPair; private connection: WrappedConnection; + private remotePeer: PeerId; private xx: XXHandshake; constructor( @@ -31,6 +33,7 @@ export class Handshake { prologue: bytes32, staticKeys: KeyPair, connection: WrappedConnection, + remotePeer: PeerId, handshake?: XXHandshake, ) { this.isInitiator = isInitiator; @@ -39,32 +42,23 @@ export class Handshake { this.prologue = prologue; this.staticKeys = staticKeys; this.connection = connection; + this.remotePeer = remotePeer; this.xx = handshake || new XXHandshake(); this.session = this.xx.initSession(this.isInitiator, this.prologue, this.staticKeys); } // stage 0 - async propose(earlyData?: bytes): Promise { + async propose(): Promise { if (this.isInitiator) { logger("Stage 0 - Initiator starting to send first message."); - const signedPayload = signPayload(this.libp2pPrivateKey, getHandshakePayload(this.staticKeys.publicKey)); - const signedEarlyDataPayload = signEarlyDataPayload(this.libp2pPrivateKey, earlyData || Buffer.alloc(0)); - const handshakePayload = await createHandshakePayload( - this.libp2pPublicKey, - this.libp2pPrivateKey, - signedPayload, - signedEarlyDataPayload - ); - const messageBuffer = await this.xx.sendMessage(this.session, handshakePayload); + const messageBuffer = this.xx.sendMessage(this.session, Buffer.alloc(0)); this.connection.writeLP(encodeMessageBuffer(messageBuffer)); - - logger("Stage 0 - Initiator finished proposing, sent signed NoiseHandshake payload and static public key."); + logger("Stage 0 - Initiator finished sending first message."); } else { logger("Stage 0 - Responder waiting to receive first message..."); const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); - const plaintext = await this.xx.recvMessage(this.session, receivedMessageBuffer); - // TODO: Verify payload + this.xx.recvMessage(this.session, receivedMessageBuffer); logger("Stage 0 - Responder received first message."); } } @@ -74,36 +68,59 @@ export class Handshake { if (this.isInitiator) { logger('Stage 1 - Initiator waiting to receive first message from responder...'); const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); - const plaintext = await this.xx.recvMessage(this.session, receivedMessageBuffer); - // TODO: Verify payload + const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); logger('Stage 1 - Initiator received the message. Got remote\'s static key.'); + + logger("Initiator going to check remote's signature..."); + try { + await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); + } catch (e) { + throw new Error(`Error occurred while verifying signed payload: ${e.message}`); + } + logger("All good with the signature!"); } else { logger('Stage 1 - Responder sending out first message with signed payload and static key.'); const signedPayload = signPayload(this.libp2pPrivateKey, getHandshakePayload(this.staticKeys.publicKey)); + const signedEarlyDataPayload = signEarlyDataPayload(this.libp2pPrivateKey, Buffer.alloc(0)); const handshakePayload = await createHandshakePayload( this.libp2pPublicKey, this.libp2pPrivateKey, signedPayload, + signedEarlyDataPayload, ); - const messageBuffer = await this.xx.sendMessage(this.session, handshakePayload); + const messageBuffer = this.xx.sendMessage(this.session, handshakePayload); this.connection.writeLP(encodeMessageBuffer(messageBuffer)); - logger('Stage 1 - Responder sent the second handshake message.') + logger('Stage 1 - Responder sent the second handshake message with signed payload.') } } // stage 2 - async finish(): Promise { + async finish(earlyData?: bytes): Promise { if (this.isInitiator) { logger('Stage 2 - Initiator sending third handshake message.'); - const messageBuffer = await this.xx.sendMessage(this.session, Buffer.alloc(0)); + const signedPayload = signPayload(this.libp2pPrivateKey, getHandshakePayload(this.staticKeys.publicKey)); + const signedEarlyDataPayload = signEarlyDataPayload(this.libp2pPrivateKey, earlyData || Buffer.alloc(0)); + const handshakePayload = await createHandshakePayload( + this.libp2pPublicKey, + this.libp2pPrivateKey, + signedPayload, + signedEarlyDataPayload + ); + const messageBuffer = this.xx.sendMessage(this.session, handshakePayload); this.connection.writeLP(encodeMessageBuffer(messageBuffer)); - logger('Stage 2 - Initiator sent message.'); + logger('Stage 2 - Initiator sent message with signed payload.'); } else { logger('Stage 2 - Responder waiting for third handshake message...'); - const receivedMessageBuffer = (await this.connection.readLP()).slice(); - const plaintext = await this.xx.recvMessage(this.session, decodeMessageBuffer(receivedMessageBuffer)); + const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); + const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); logger('Stage 2 - Responder received the message, finished handshake. Got remote\'s static key.'); + + try { + await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); + } catch (e) { + throw new Error(`Error occurred while verifying signed payload: ${e.message}`); + } } } diff --git a/src/logger.ts b/src/logger.ts new file mode 100644 index 0000000..150e061 --- /dev/null +++ b/src/logger.ts @@ -0,0 +1,2 @@ +import debug from "debug"; +export const logger = debug('libp2p:noise'); diff --git a/src/noise.ts b/src/noise.ts index fe33aed..83b313e 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -7,7 +7,8 @@ import pipe from 'it-pipe'; import lp from 'it-length-prefixed'; import { Handshake } from "./handshake"; -import { generateKeypair, int16BEDecode, int16BEEncode } from "./utils"; +import { generateKeypair } from "./utils"; +import { int16BEDecode, int16BEEncode } from "./encoder"; import { decryptStream, encryptStream } from "./crypto"; import { bytes } from "./@types/basic"; import { NoiseConnection, PeerId, KeyPair, SecureOutbound } from "./@types/libp2p"; @@ -46,8 +47,8 @@ export class Noise implements NoiseConnection { */ public async secureOutbound(localPeer: PeerId, connection: any, remotePeer: PeerId): Promise { const wrappedConnection = Wrap(connection); - const libp2pPublicKey = localPeer.pubKey.marshal(); - const handshake = await this.performHandshake(wrappedConnection, true, libp2pPublicKey); + const libp2pPublicKey = localPeer.marshalPubKey(); + const handshake = await this.performHandshake(wrappedConnection, true, libp2pPublicKey, remotePeer); const conn = await this.createSecureConnection(wrappedConnection, handshake); return { @@ -65,8 +66,8 @@ export class Noise implements NoiseConnection { */ public async secureInbound(localPeer: PeerId, connection: any, remotePeer: PeerId): Promise { const wrappedConnection = Wrap(connection); - const libp2pPublicKey = localPeer.pubKey.marshal(); - const handshake = await this.performHandshake(wrappedConnection, false, libp2pPublicKey); + const libp2pPublicKey = localPeer.marshalPubKey(); + const handshake = await this.performHandshake(wrappedConnection, false, libp2pPublicKey, remotePeer); const conn = await this.createSecureConnection(wrappedConnection, handshake); return { @@ -79,13 +80,18 @@ export class Noise implements NoiseConnection { connection: WrappedConnection, isInitiator: boolean, libp2pPublicKey: bytes, + remotePeer: PeerId, ): Promise { const prologue = Buffer.from(this.protocol); - const handshake = new Handshake(isInitiator, this.privateKey, libp2pPublicKey, prologue, this.staticKeys, connection); + const handshake = new Handshake(isInitiator, this.privateKey, libp2pPublicKey, prologue, this.staticKeys, connection, remotePeer); - await handshake.propose(this.earlyData); - await handshake.exchange(); - await handshake.finish(); + try { + await handshake.propose(); + await handshake.exchange(); + await handshake.finish(this.earlyData); + } catch (e) { + throw new Error(`Error occurred during handshake: ${e.message}`); + } return handshake; } diff --git a/src/utils.ts b/src/utils.ts index f65da02..65b7939 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,13 +1,11 @@ import { x25519, ed25519 } from 'bcrypto'; import protobuf from "protobufjs"; import { Buffer } from "buffer"; -import debug from "debug"; +import PeerId from "peer-id"; +import * as crypto from 'libp2p-crypto'; import { KeyPair } from "./@types/libp2p"; import { bytes } from "./@types/basic"; -import { MessageBuffer } from "./xx"; - -export const logger = debug('libp2p:noise'); export async function loadPayloadProto () { const payloadProtoBuf = await protobuf.load("protos/payload.proto"); @@ -70,26 +68,23 @@ export const getHandshakePayload = (publicKey: bytes ) => Buffer.concat([Buffer. export const getEarlyDataPayload = (earlyData: bytes) => Buffer.concat([Buffer.from("noise-libp2p-early-data:"), earlyData]); -export function encodeMessageBuffer(message: MessageBuffer): bytes { - return Buffer.concat([message.ne, message.ns, message.ciphertext]); +async function isValidPeerId(peerId: bytes, publicKeyProtobuf: bytes) { + const generatedPeerId = await PeerId.createFromPubKey(publicKeyProtobuf); + return generatedPeerId.id.equals(peerId); } -export function decodeMessageBuffer(message: bytes): MessageBuffer { - return { - ne: message.slice(0, 32), - ns: message.slice(32, 64), - ciphertext: message.slice(64, message.length), +export async function verifySignedPayload(noiseStaticKey: bytes, plaintext: bytes, peerId: bytes) { + const NoiseHandshakePayload = await loadPayloadProto(); + const receivedPayload = NoiseHandshakePayload.toObject(NoiseHandshakePayload.decode(plaintext)); + + if (!(await isValidPeerId(peerId, receivedPayload.libp2pKey)) ) { + throw new Error("Peer ID doesn't match libp2p public key."); + } + + const generatedPayload = getHandshakePayload(noiseStaticKey); + // Unmarshaling from PublicKey protobuf and taking key buffer only. + const publicKey = crypto.keys.unmarshalPublicKey(receivedPayload.libp2pKey).marshal(); + if (!ed25519.verify(generatedPayload, receivedPayload.noiseStaticKeySignature, publicKey)) { + throw new Error("Static key doesn't match to peer that signed payload!"); } } - -export const int16BEEncode = (value, target, offset) => { - target = target || Buffer.allocUnsafe(2); - return target.writeInt16BE(value, offset); -}; -int16BEEncode.bytes = 2; - -export const int16BEDecode = data => { - if (data.length < 2) throw RangeError('Could not decode int16BE'); - return data.readInt16BE(0); -}; -int16BEDecode.bytes = 2; diff --git a/src/xx.ts b/src/xx.ts index 95aa619..d4b479a 100644 --- a/src/xx.ts +++ b/src/xx.ts @@ -191,7 +191,7 @@ export class XXHandshake { return SHA256.digest(Buffer.from([...a, ...b])); } - private async encryptAndHash(ss: SymmetricState, plaintext: bytes): Promise { + private encryptAndHash(ss: SymmetricState, plaintext: bytes): bytes { let ciphertext; if (this.hasKey(ss.cs)) { ciphertext = this.encryptWithAd(ss.cs, ss.h, plaintext); @@ -203,7 +203,7 @@ export class XXHandshake { return ciphertext; } - private async decryptAndHash(ss: SymmetricState, ciphertext: bytes): Promise { + private decryptAndHash(ss: SymmetricState, ciphertext: bytes): bytes { let plaintext; if (this.hasKey(ss.cs)) { plaintext = this.decryptWithAd(ss.cs, ss.h, ciphertext); @@ -223,38 +223,38 @@ export class XXHandshake { return { cs1, cs2 }; } - private async writeMessageA(hs: HandshakeState, payload: bytes): Promise { + private writeMessageA(hs: HandshakeState, payload: bytes): MessageBuffer { const ns = Buffer.alloc(0); hs.e = generateKeypair(); const ne = hs.e.publicKey; this.mixHash(hs.ss, ne); - const ciphertext = await this.encryptAndHash(hs.ss, payload); + const ciphertext = this.encryptAndHash(hs.ss, payload); return {ne, ns, ciphertext}; } - private async writeMessageB(hs: HandshakeState, payload: bytes): Promise { + private writeMessageB(hs: HandshakeState, payload: bytes): MessageBuffer { hs.e = generateKeypair(); const ne = hs.e.publicKey; this.mixHash(hs.ss, ne); this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)); const spk = Buffer.from(hs.s.publicKey); - const ns = await this.encryptAndHash(hs.ss, spk); + const ns = this.encryptAndHash(hs.ss, spk); this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)); - const ciphertext = await this.encryptAndHash(hs.ss, payload); + const ciphertext = this.encryptAndHash(hs.ss, payload); return { ne, ns, ciphertext }; } - private async writeMessageC(hs: HandshakeState, payload: bytes) { + private writeMessageC(hs: HandshakeState, payload: bytes) { const spk = Buffer.from(hs.s.publicKey); - const ns = await this.encryptAndHash(hs.ss, spk); + const ns = this.encryptAndHash(hs.ss, spk); this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)); - const ciphertext = await this.encryptAndHash(hs.ss, payload); + const ciphertext = this.encryptAndHash(hs.ss, payload); const ne = this.createEmptyKey(); const messageBuffer: MessageBuffer = {ne, ns, ciphertext}; const { cs1, cs2 } = this.split(hs.ss); @@ -262,7 +262,7 @@ export class XXHandshake { return { h: hs.ss.h, messageBuffer, cs1, cs2 }; } - private async writeMessageRegular(cs: CipherState, payload: bytes): Promise { + private writeMessageRegular(cs: CipherState, payload: bytes): MessageBuffer { const ciphertext = this.encryptWithAd(cs, Buffer.alloc(0), payload); const ne = this.createEmptyKey(); const ns = Buffer.alloc(0); @@ -270,16 +270,16 @@ export class XXHandshake { return { ne, ns, ciphertext }; } - private async readMessageA(hs: HandshakeState, message: MessageBuffer): Promise { + private readMessageA(hs: HandshakeState, message: MessageBuffer): bytes { if (x25519.publicKeyVerify(message.ne)) { hs.re = message.ne; } this.mixHash(hs.ss, hs.re); - return await this.decryptAndHash(hs.ss, message.ciphertext); + return this.decryptAndHash(hs.ss, message.ciphertext); } - private async readMessageB(hs: HandshakeState, message: MessageBuffer): Promise { + private readMessageB(hs: HandshakeState, message: MessageBuffer): bytes { if (x25519.publicKeyVerify(message.ne)) { hs.re = message.ne; } @@ -289,16 +289,16 @@ export class XXHandshake { throw new Error("Handshake state `e` param is missing."); } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)); - const ns = await this.decryptAndHash(hs.ss, message.ns); + const ns = this.decryptAndHash(hs.ss, message.ns); if (ns.length === 32 && x25519.publicKeyVerify(message.ns)) { hs.rs = ns; } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)); - return await this.decryptAndHash(hs.ss, message.ciphertext); + return this.decryptAndHash(hs.ss, message.ciphertext); } - private async readMessageC(hs: HandshakeState, message: MessageBuffer) { - const ns = await this.decryptAndHash(hs.ss, message.ns); + private readMessageC(hs: HandshakeState, message: MessageBuffer) { + const ns = this.decryptAndHash(hs.ss, message.ns); if (ns.length === 32 && x25519.publicKeyVerify(message.ns)) { hs.rs = ns; } @@ -308,7 +308,7 @@ export class XXHandshake { } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)); - const plaintext = await this.decryptAndHash(hs.ss, message.ciphertext); + const plaintext = this.decryptAndHash(hs.ss, message.ciphertext); const { cs1, cs2 } = this.split(hs.ss); return { h: hs.ss.h, plaintext, cs1, cs2 }; @@ -336,14 +336,14 @@ export class XXHandshake { }; } - public async sendMessage(session: NoiseSession, message: bytes): Promise { + public sendMessage(session: NoiseSession, message: bytes): MessageBuffer { let messageBuffer: MessageBuffer; if (session.mc.eqn(0)) { - messageBuffer = await this.writeMessageA(session.hs, message); + messageBuffer = this.writeMessageA(session.hs, message); } else if (session.mc.eqn(1)) { - messageBuffer = await this.writeMessageB(session.hs, message); + messageBuffer = this.writeMessageB(session.hs, message); } else if (session.mc.eqn(2)) { - const { h, messageBuffer: resultingBuffer, cs1, cs2 } = await this.writeMessageC(session.hs, message); + const { h, messageBuffer: resultingBuffer, cs1, cs2 } = this.writeMessageC(session.hs, message); messageBuffer = resultingBuffer; session.h = h; session.cs1 = cs1; @@ -354,13 +354,13 @@ export class XXHandshake { throw new Error("CS1 (cipher state) is not defined") } - messageBuffer = await this.writeMessageRegular(session.cs1, message); + messageBuffer = this.writeMessageRegular(session.cs1, message); } else { if (!session.cs2) { throw new Error("CS2 (cipher state) is not defined") } - messageBuffer = await this.writeMessageRegular(session.cs2, message); + messageBuffer = this.writeMessageRegular(session.cs2, message); } } else { throw new Error("Session invalid.") @@ -370,14 +370,14 @@ export class XXHandshake { return messageBuffer; } - public async recvMessage(session: NoiseSession, message: MessageBuffer): Promise { + public recvMessage(session: NoiseSession, message: MessageBuffer): bytes { let plaintext: bytes; if (session.mc.eqn(0)) { - plaintext = await this.readMessageA(session.hs, message); + plaintext = this.readMessageA(session.hs, message); } else if (session.mc.eqn(1)) { - plaintext = await this.readMessageB(session.hs, message); + plaintext = this.readMessageB(session.hs, message); } else if (session.mc.eqn(2)) { - const { h, plaintext: resultingPlaintext, cs1, cs2 } = await this.readMessageC(session.hs, message); + const { h, plaintext: resultingPlaintext, cs1, cs2 } = this.readMessageC(session.hs, message); plaintext = resultingPlaintext; session.h = h; session.cs1 = cs1; @@ -387,12 +387,12 @@ export class XXHandshake { if (!session.cs2) { throw new Error("CS1 (cipher state) is not defined") } - plaintext = await this.readMessageRegular(session.cs2, message); + plaintext = this.readMessageRegular(session.cs2, message); } else { if (!session.cs1) { throw new Error("CS1 (cipher state) is not defined") } - plaintext = await this.readMessageRegular(session.cs1, message); + plaintext = this.readMessageRegular(session.cs1, message); } } else { throw new Error("Session invalid."); diff --git a/test/handshake.test.ts b/test/handshake.test.ts index 7166f49..6019711 100644 --- a/test/handshake.test.ts +++ b/test/handshake.test.ts @@ -1,55 +1,122 @@ -import {assert} from "chai"; +import {assert, expect} from "chai"; import Duplex from 'it-pair/duplex'; import {Buffer} from "buffer"; import Wrap from "it-pb-rpc"; import {Handshake} from "../src/handshake"; import {generateKeypair} from "../src/utils"; -import {createPeerIds} from "./fixtures/peer"; +import {createPeerIdsFromFixtures} from "./fixtures/peer"; +import {getKeyPairFromPeerId} from "./utils"; describe("Handshake", () => { + let peerA, peerB, fakePeer; + + before(async () => { + [peerA, peerB, fakePeer] = await createPeerIdsFromFixtures(3); + }); + it("should propose, exchange and finish handshake", async() => { - const duplex = Duplex(); - const connectionFrom = Wrap(duplex[0]); - const connectionTo = Wrap(duplex[1]); + try { + const duplex = Duplex(); + const connectionFrom = Wrap(duplex[0]); + const connectionTo = Wrap(duplex[1]); - const prologue = Buffer.from('/noise'); - const staticKeysInitiator = generateKeypair(); - const staticKeysResponder = generateKeypair(); - const [peerA, peerB] = await createPeerIds(2); + const prologue = Buffer.from('/noise'); + const staticKeysInitiator = generateKeypair(); + const staticKeysResponder = generateKeypair(); - const initiatorPrivKey = peerA.privKey.marshal().slice(0, 32); - const initiatorPubKey = peerA.pubKey.marshal(); - const handshakeInitator = new Handshake(true, initiatorPrivKey, initiatorPubKey, prologue, staticKeysInitiator, connectionFrom); + const { privateKey: initiatorPrivKey, publicKey: initiatorPubKey } = getKeyPairFromPeerId(peerA); + const handshakeInitator = new Handshake(true, initiatorPrivKey, initiatorPubKey, prologue, staticKeysInitiator, connectionFrom, peerB); - const responderPrivKey = peerB.privKey.marshal().slice(0, 32); - const responderPubKey = peerB.pubKey.marshal(); - const handshakeResponder = new Handshake(false, responderPrivKey, responderPubKey, prologue, staticKeysResponder, connectionTo); + const { privateKey: responderPrivKey, publicKey: responderPubKey } = getKeyPairFromPeerId(peerB); + const handshakeResponder = new Handshake(false, responderPrivKey, responderPubKey, prologue, staticKeysResponder, connectionTo, peerA); - await handshakeInitator.propose(); - await handshakeResponder.propose(); + await handshakeInitator.propose(); + await handshakeResponder.propose(); - await handshakeResponder.exchange(); - await handshakeInitator.exchange(); + await handshakeResponder.exchange(); + await handshakeInitator.exchange(); - await handshakeInitator.finish(); - await handshakeResponder.finish(); + await handshakeInitator.finish(); + await handshakeResponder.finish(); - const sessionInitator = handshakeInitator.session; - const sessionResponder = handshakeResponder.session; + const sessionInitator = handshakeInitator.session; + const sessionResponder = handshakeResponder.session; - // Test shared key - if (sessionInitator.cs1 && sessionResponder.cs1 && sessionInitator.cs2 && sessionResponder.cs2) { - assert(sessionInitator.cs1.k.equals(sessionResponder.cs1.k)); - assert(sessionInitator.cs2.k.equals(sessionResponder.cs2.k)); - } else { - assert(false); + // Test shared key + if (sessionInitator.cs1 && sessionResponder.cs1 && sessionInitator.cs2 && sessionResponder.cs2) { + assert(sessionInitator.cs1.k.equals(sessionResponder.cs1.k)); + assert(sessionInitator.cs2.k.equals(sessionResponder.cs2.k)); + } else { + assert(false); + } + + // Test encryption and decryption + const encrypted = handshakeInitator.encrypt(Buffer.from("encryptthis"), handshakeInitator.session); + const decrypted = handshakeResponder.decrypt(encrypted, handshakeResponder.session); + assert(decrypted.equals(Buffer.from("encryptthis"))); + } catch (e) { + assert(false, e.message); } + }); - // Test encryption and decryption - const encrypted = handshakeInitator.encrypt(Buffer.from("encryptthis"), handshakeInitator.session); - const decrypted = handshakeResponder.decrypt(encrypted, handshakeResponder.session); - assert(decrypted.equals(Buffer.from("encryptthis"))); + it("Initiator should fail to exchange handshake if given wrong public key in payload", async() => { + try { + const duplex = Duplex(); + const connectionFrom = Wrap(duplex[0]); + const connectionTo = Wrap(duplex[1]); + + const prologue = Buffer.from('/noise'); + const staticKeysInitiator = generateKeypair(); + const staticKeysResponder = generateKeypair(); + + const { privateKey: initiatorPrivKey, publicKey: initiatorPubKey } = getKeyPairFromPeerId(peerA); + const handshakeInitator = new Handshake(true, initiatorPrivKey, initiatorPubKey, prologue, staticKeysInitiator, connectionFrom, fakePeer); + + const { privateKey: responderPrivKey, publicKey: responderPubKey } = getKeyPairFromPeerId(peerB); + const handshakeResponder = new Handshake(false, responderPrivKey, responderPubKey, prologue, staticKeysResponder, connectionTo, peerA); + + await handshakeInitator.propose(); + await handshakeResponder.propose(); + + await handshakeResponder.exchange(); + await handshakeInitator.exchange(); + + assert(false, "Should throw exception"); + } catch (e) { + expect(e.message).equals("Error occurred while verifying signed payload: Peer ID doesn't match libp2p public key.") + } + }); + + it("Responder should fail to exchange handshake if given wrong public key in payload", async() => { + try { + const duplex = Duplex(); + const connectionFrom = Wrap(duplex[0]); + const connectionTo = Wrap(duplex[1]); + + const prologue = Buffer.from('/noise'); + const staticKeysInitiator = generateKeypair(); + const staticKeysResponder = generateKeypair(); + + const { privateKey: initiatorPrivKey, publicKey: initiatorPubKey } = getKeyPairFromPeerId(peerA); + const handshakeInitator = new Handshake(true, initiatorPrivKey, initiatorPubKey, prologue, staticKeysInitiator, connectionFrom, peerB); + + const { privateKey: responderPrivKey, publicKey: responderPubKey } = getKeyPairFromPeerId(peerB); + const handshakeResponder = new Handshake(false, responderPrivKey, responderPubKey, prologue, staticKeysResponder, connectionTo, fakePeer); + + await handshakeInitator.propose(); + await handshakeResponder.propose(); + + await handshakeResponder.exchange(); + await handshakeInitator.exchange(); + + await handshakeInitator.finish(); + await handshakeResponder.finish(); + + assert(false, "Should throw exception"); + } catch (e) { + expect(e.message).equals("Error occurred while verifying signed payload: Peer ID doesn't match libp2p public key.") + } }); }); diff --git a/test/noise.test.ts b/test/noise.test.ts index 22f5578..6d42cd5 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -2,51 +2,53 @@ import { expect, assert } from "chai"; import DuplexPair from 'it-pair/duplex'; import { Noise } from "../src"; -import { generateEd25519Keys } from "./utils"; -import {createPeerIds, createPeerIdsFromFixtures} from "./fixtures/peer"; +import {createPeerIdsFromFixtures} from "./fixtures/peer"; import Wrap from "it-pb-rpc"; import {Handshake} from "../src/handshake"; import { createHandshakePayload, - decodeMessageBuffer, - encodeMessageBuffer, generateKeypair, getHandshakePayload, signPayload } from "../src/utils"; +import { decodeMessageBuffer, encodeMessageBuffer } from "../src/encoder"; import {XXHandshake} from "../src/xx"; import {Buffer} from "buffer"; +import {getKeyPairFromPeerId} from "./utils"; describe("Noise", () => { let remotePeer, localPeer; before(async () => { - [localPeer, remotePeer] = await createPeerIds(2); + [localPeer, remotePeer] = await createPeerIdsFromFixtures(2); }); it("should communicate through encrypted streams", async() => { - const libp2pInitPrivKey = localPeer.privKey.marshal().slice(0, 32); - const libp2pRespPrivKey = remotePeer.privKey.marshal().slice(0, 32); + try { + const { privateKey: libp2pInitPrivKey } = getKeyPairFromPeerId(localPeer); + const { privateKey: libp2pRespPrivKey } = getKeyPairFromPeerId(remotePeer); + const noiseInit = new Noise(libp2pInitPrivKey); + const noiseResp = new Noise(libp2pRespPrivKey); - const noiseInit = new Noise(libp2pInitPrivKey); - const noiseResp = new Noise(libp2pRespPrivKey); + const [inboundConnection, outboundConnection] = DuplexPair(); + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection, localPeer), + ]); + const wrappedInbound = Wrap(inbound.conn); + const wrappedOutbound = Wrap(outbound.conn); - const [inboundConnection, outboundConnection] = DuplexPair(); - const [outbound, inbound] = await Promise.all([ - noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), - noiseResp.secureInbound(remotePeer, inboundConnection, localPeer), - ]); - const wrappedInbound = Wrap(inbound.conn); - const wrappedOutbound = Wrap(outbound.conn); - - wrappedOutbound.writeLP(Buffer.from("test")); - const response = await wrappedInbound.readLP(); - expect(response.toString()).equal("test"); + wrappedOutbound.writeLP(Buffer.from("test")); + const response = await wrappedInbound.readLP(); + expect(response.toString()).equal("test"); + } catch (e) { + assert(false, e.message); + } }); it("should test that secureOutbound is spec compliant", async() => { - const libp2pPrivKey = localPeer.privKey.marshal().slice(0, 32); - const noiseInit = new Noise(libp2pPrivKey); + const { privateKey: libp2pInitPrivKey } = getKeyPairFromPeerId(localPeer); + const noiseInit = new Noise(libp2pInitPrivKey); const [inboundConnection, outboundConnection] = DuplexPair(); const [outbound, { wrapped, handshake }] = await Promise.all([ @@ -56,37 +58,42 @@ describe("Noise", () => { const prologue = Buffer.from('/noise'); const staticKeys = generateKeypair(); const xx = new XXHandshake(); - const libp2pPubKey = remotePeer.pubKey.marshal().slice(32, 64); - const handshake = new Handshake(false, libp2pPrivKey, libp2pPubKey, prologue, staticKeys, wrapped, xx); + const { privateKey: libp2pPrivKey, publicKey: libp2pPubKey } = getKeyPairFromPeerId(remotePeer); + + const handshake = new Handshake(false, libp2pPrivKey, libp2pPubKey, prologue, staticKeys, wrapped, localPeer, xx); let receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); // The first handshake message contains the initiator's ephemeral public key expect(receivedMessageBuffer.ne.length).equal(32); - await xx.recvMessage(handshake.session, receivedMessageBuffer); + xx.recvMessage(handshake.session, receivedMessageBuffer); // Stage 1 const signedPayload = signPayload(libp2pPrivKey, getHandshakePayload(staticKeys.publicKey)); const handshakePayload = await createHandshakePayload(libp2pPubKey, libp2pPrivKey, signedPayload); - const messageBuffer = await xx.sendMessage(handshake.session, handshakePayload); + const messageBuffer = xx.sendMessage(handshake.session, handshakePayload); wrapped.writeLP(encodeMessageBuffer(messageBuffer)); // Stage 2 - finish handshake receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); - await xx.recvMessage(handshake.session, receivedMessageBuffer); - return { wrapped, handshake }; + xx.recvMessage(handshake.session, receivedMessageBuffer); + return {wrapped, handshake}; })(), ]); - const wrappedOutbound = Wrap(outbound.conn); - wrappedOutbound.write(Buffer.from("test")); + try { + const wrappedOutbound = Wrap(outbound.conn); + wrappedOutbound.write(Buffer.from("test")); - // Check that noise message is prefixed with 16-bit big-endian unsigned integer - const receivedEncryptedPayload = (await wrapped.read()).slice(); - const dataLength = receivedEncryptedPayload.readInt16BE(0); - const data = receivedEncryptedPayload.slice(2, dataLength + 2); - const decrypted = handshake.decrypt(data, handshake.session); - // Decrypted data should match - assert(decrypted.equals(Buffer.from("test"))); + // Check that noise message is prefixed with 16-bit big-endian unsigned integer + const receivedEncryptedPayload = (await wrapped.read()).slice(); + const dataLength = receivedEncryptedPayload.readInt16BE(0); + const data = receivedEncryptedPayload.slice(2, dataLength + 2); + const decrypted = handshake.decrypt(data, handshake.session); + // Decrypted data should match + assert(decrypted.equals(Buffer.from("test"))); + } catch (e) { + assert(false, e.message); + } }) }); diff --git a/test/utils.ts b/test/utils.ts index 6a6580a..2ac5b7b 100644 --- a/test/utils.ts +++ b/test/utils.ts @@ -1,5 +1,13 @@ import * as crypto from 'libp2p-crypto'; +import {KeyPair, PeerId} from "../src/@types/libp2p"; export async function generateEd25519Keys() { return await crypto.keys.generateKeyPair('ed25519'); } + +export function getKeyPairFromPeerId(peerId: PeerId): KeyPair { + return { + privateKey: peerId.privKey.marshal().slice(0, 32), + publicKey: peerId.marshalPubKey(), + } +} diff --git a/test/xx.test.ts b/test/xx.test.ts index 9ff5c6b..1eca6b4 100644 --- a/test/xx.test.ts +++ b/test/xx.test.ts @@ -48,25 +48,26 @@ describe("Index", () => { const respSignedPayload = await libp2pRespKeys.sign(getHandshakePayload(kpResp.publicKey)); // initiator: new XX noise session - const nsInit = await xx.initSession(true, prologue, kpInit); + const nsInit = xx.initSession(true, prologue, kpInit); // responder: new XX noise session - const nsResp = await xx.initSession(false, prologue, kpResp); + const nsResp = xx.initSession(false, prologue, kpResp); /* STAGE 0 */ // initiator creates payload const libp2pInitPrivKey = libp2pInitKeys.marshal().slice(0, 32); const libp2pInitPubKey = libp2pInitKeys.marshal().slice(32, 64); + const payloadInitEnc = await createHandshakePayload(libp2pInitPubKey, libp2pInitPrivKey, initSignedPayload); // initiator sends message const message = Buffer.concat([Buffer.alloc(0), payloadInitEnc]); - const messageBuffer = await xx.sendMessage(nsInit, message); + const messageBuffer = xx.sendMessage(nsInit, message); expect(messageBuffer.ne.length).not.equal(0); // responder receives message - const plaintext = await xx.recvMessage(nsResp, messageBuffer); + const plaintext = xx.recvMessage(nsResp, messageBuffer); console.log("Stage 0 responder payload: ", plaintext); /* STAGE 1 */ @@ -77,22 +78,22 @@ describe("Index", () => { const payloadRespEnc = await createHandshakePayload(libp2pRespPubKey, libp2pRespPrivKey, respSignedPayload); const message1 = Buffer.concat([message, payloadRespEnc]); - const messageBuffer2 = await xx.sendMessage(nsResp, message1); + const messageBuffer2 = xx.sendMessage(nsResp, message1); expect(messageBuffer2.ne.length).not.equal(0); expect(messageBuffer2.ns.length).not.equal(0); // initiator receive payload - const plaintext2 = await xx.recvMessage(nsInit, messageBuffer2); + const plaintext2 = xx.recvMessage(nsInit, messageBuffer2); console.log("Stage 1 responder payload: ", plaintext2); /* STAGE 2 */ // initiator send message - const messageBuffer3 = await xx.sendMessage(nsInit, Buffer.alloc(0)); + const messageBuffer3 = xx.sendMessage(nsInit, Buffer.alloc(0)); // responder receive message - const plaintext3 = await xx.recvMessage(nsResp, messageBuffer3); + const plaintext3 = xx.recvMessage(nsResp, messageBuffer3); console.log("Stage 2 responder payload: ", plaintext3); assert(nsInit.cs1.k.equals(nsResp.cs1.k));