diff --git a/.eslintrc b/.eslintrc index a621e68..b182f12 100644 --- a/.eslintrc +++ b/.eslintrc @@ -15,6 +15,7 @@ "@typescript-eslint/indent": ["error", 2], "@typescript-eslint/no-use-before-define": "off", "@typescript-eslint/no-explicit-any": "off", + "@typescript-eslint/interface-name-prefix": ["error", { "prefixWithI": "always" }], "no-console": "warn" } -} \ No newline at end of file +} diff --git a/src/@types/handshake-interface.ts b/src/@types/handshake-interface.ts new file mode 100644 index 0000000..d92b9ff --- /dev/null +++ b/src/@types/handshake-interface.ts @@ -0,0 +1,8 @@ +import {bytes} from "./basic"; +import {NoiseSession} from "./handshake"; + +export interface IHandshake { + session: NoiseSession; + encrypt(plaintext: bytes, session: NoiseSession): bytes; + decrypt(ciphertext: bytes, session: NoiseSession): bytes; +} diff --git a/src/@types/handshake.ts b/src/@types/handshake.ts index 265f769..cb67eeb 100644 --- a/src/@types/handshake.ts +++ b/src/@types/handshake.ts @@ -3,7 +3,7 @@ import {KeyPair} from "./libp2p"; export type Hkdf = [bytes, bytes, bytes]; -export interface MessageBuffer { +export type MessageBuffer = { ne: bytes32; ns: bytes; ciphertext: bytes; diff --git a/src/@types/libp2p.ts b/src/@types/libp2p.ts index 58ecd87..8f0ac27 100644 --- a/src/@types/libp2p.ts +++ b/src/@types/libp2p.ts @@ -1,7 +1,7 @@ import { bytes, bytes32 } from "./basic"; import { Duplex } from "it-pair"; -export interface KeyPair { +export type KeyPair = { publicKey: bytes32; privateKey: bytes32; } @@ -18,7 +18,7 @@ export type PeerId = { marshalPrivKey(): bytes; }; -export interface NoiseConnection { +export interface INoiseConnection { remoteEarlyData?(): bytes; secureOutbound(localPeer: PeerId, insecure: any, remotePeer: PeerId): Promise; secureInbound(localPeer: PeerId, insecure: any, remotePeer: PeerId): Promise; diff --git a/src/crypto.ts b/src/crypto.ts index 34f8c06..83508ec 100644 --- a/src/crypto.ts +++ b/src/crypto.ts @@ -1,14 +1,14 @@ import { Buffer } from "buffer"; -import { Handshake } from "./handshake"; +import {IHandshake} from "./@types/handshake-interface"; -interface ReturnEncryptionWrapper { +interface IReturnEncryptionWrapper { (source: Iterable): AsyncIterableIterator; } const maxPlaintextLength = 65519; // Returns generator that encrypts payload from the user -export function encryptStream(handshake: Handshake): ReturnEncryptionWrapper { +export function encryptStream(handshake: IHandshake): IReturnEncryptionWrapper { return async function * (source) { for await (const chunk of source) { const chunkBuffer = Buffer.from(chunk.buffer, chunk.byteOffset, chunk.length); @@ -28,7 +28,7 @@ export function encryptStream(handshake: Handshake): ReturnEncryptionWrapper { // Decrypt received payload to the user -export function decryptStream(handshake: Handshake): ReturnEncryptionWrapper { +export function decryptStream(handshake: IHandshake): IReturnEncryptionWrapper { return async function * (source) { for await (const chunk of source) { const chunkBuffer = Buffer.from(chunk.buffer, chunk.byteOffset, chunk.length); diff --git a/src/encoder.ts b/src/encoder.ts index e4e2c30..4f869e6 100644 --- a/src/encoder.ts +++ b/src/encoder.ts @@ -14,14 +14,34 @@ export const uint16BEDecode = data => { }; uint16BEDecode.bytes = 2; -export function encodeMessageBuffer(message: MessageBuffer): bytes { +export function encode0(message: MessageBuffer): bytes { + return Buffer.concat([message.ne, message.ciphertext]); +} + +export function encode1(message: MessageBuffer): bytes { return Buffer.concat([message.ne, message.ns, message.ciphertext]); } -export function decodeMessageBuffer(message: bytes): MessageBuffer { +export function decode0(input: bytes): MessageBuffer { + if (input.length < 32) { + throw new Error("Cannot decode stage 0 MessageBuffer: length less than 32 bytes."); + } + return { - ne: message.slice(0, 32), - ns: message.slice(32, 64), - ciphertext: message.slice(64, message.length), + ne: input.slice(0, 32), + ciphertext: input.slice(32, input.length), + ns: Buffer.alloc(0), + } +} + +export function decode1(input: bytes): MessageBuffer { + if (input.length < 96) { + throw new Error("Cannot decode stage 0 MessageBuffer: length less than 96 bytes."); + } + + return { + ne: input.slice(0, 32), + ns: input.slice(32, 64), + ciphertext: input.slice(64, input.length), } } diff --git a/src/handshake-ik.ts b/src/handshake-ik.ts new file mode 100644 index 0000000..f13e711 --- /dev/null +++ b/src/handshake-ik.ts @@ -0,0 +1,73 @@ +import {WrappedConnection} from "./noise"; +import {IK} from "./handshakes/ik"; +import {NoiseSession} from "./@types/handshake"; +import {bytes, bytes32} from "./@types/basic"; +import {KeyPair, PeerId} from "./@types/libp2p"; +import {IHandshake} from "./@types/handshake-interface"; +import {Buffer} from "buffer"; + +export class IKHandshake implements IHandshake { + public isInitiator: boolean; + public session: NoiseSession; + + private payload: bytes; + private prologue: bytes32; + private staticKeypair: KeyPair; + private connection: WrappedConnection; + private remotePeer: PeerId; + private ik: IK; + + constructor( + isInitiator: boolean, + payload: bytes, + prologue: bytes32, + staticKeypair: KeyPair, + connection: WrappedConnection, + remotePeer: PeerId, + handshake?: IK, + ) { + this.isInitiator = isInitiator; + this.payload = payload; + this.prologue = prologue; + this.staticKeypair = staticKeypair; + this.connection = connection; + this.remotePeer = remotePeer; + + this.ik = handshake || new IK(); + + // Dummy data + // TODO: Load remote static keys if found + const remoteStaticKeys = this.staticKeypair; + this.session = this.ik.initSession(this.isInitiator, this.prologue, this.staticKeypair, remoteStaticKeys.publicKey); + } + + public decrypt(ciphertext: Buffer, session: NoiseSession): Buffer { + const cs = this.getCS(session, false); + return this.ik.decryptWithAd(cs, Buffer.alloc(0), ciphertext); + } + + public encrypt(plaintext: Buffer, session: NoiseSession): Buffer { + const cs = this.getCS(session); + return this.ik.encryptWithAd(cs, Buffer.alloc(0), plaintext); + } + + public getRemoteEphemeralKeys(): KeyPair { + if (!this.session.hs.e) { + throw new Error("Ephemeral keys do not exist."); + } + + return this.session.hs.e; + } + + private getCS(session: NoiseSession, encryption = true) { + if (!session.cs1 || !session.cs2) { + throw new Error("Handshake not completed properly, cipher state does not exist."); + } + + if (this.isInitiator) { + return encryption ? session.cs1 : session.cs2; + } else { + return encryption ? session.cs2 : session.cs1; + } + } +} diff --git a/src/handshake-xx-fallback.ts b/src/handshake-xx-fallback.ts new file mode 100644 index 0000000..c01539b --- /dev/null +++ b/src/handshake-xx-fallback.ts @@ -0,0 +1,75 @@ +import { Buffer } from "buffer"; + +import { XXHandshake } from "./handshake-xx"; +import { XX } from "./handshakes/xx"; +import { KeyPair, PeerId } from "./@types/libp2p"; +import { bytes, bytes32 } from "./@types/basic"; +import { + verifySignedPayload, +} from "./utils"; +import { logger } from "./logger"; +import { WrappedConnection } from "./noise"; +import {decode0, decode1, encode1} from "./encoder"; + +export class XXFallbackHandshake extends XXHandshake { + private ephemeralKeys?: KeyPair; + private initialMsg: bytes; + + constructor( + isInitiator: boolean, + payload: bytes, + prologue: bytes32, + staticKeypair: KeyPair, + connection: WrappedConnection, + remotePeer: PeerId, + initialMsg: bytes, + ephemeralKeys?: KeyPair, + handshake?: XX, + ) { + super(isInitiator, payload, prologue, staticKeypair, connection, remotePeer, handshake); + if (ephemeralKeys) { + this.ephemeralKeys = ephemeralKeys; + } + this.initialMsg = initialMsg; + } + + // stage 0 + public async propose(): Promise { + if (this.isInitiator) { + this.xx.sendMessage(this.session, Buffer.alloc(0), this.ephemeralKeys); + logger("XX Fallback Stage 0 - Initialized state as the first message was sent by initiator."); + } else { + logger("XX Fallback Stage 0 - Responder waiting to receive first message..."); + const receivedMessageBuffer = decode0(this.initialMsg); + this.xx.recvMessage(this.session, { + ne: receivedMessageBuffer.ne, + ns: Buffer.alloc(0), + ciphertext: Buffer.alloc(0), + }); + logger("XX Fallback Stage 0 - Responder received first message."); + } + } + + // stage 1 + public async exchange(): Promise { + if (this.isInitiator) { + logger('XX Fallback Stage 1 - Initiator waiting to receive first message from responder...'); + const receivedMessageBuffer = decode1(this.initialMsg); + const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); + logger('XX Fallback Stage 1 - Initiator received the message. Got remote\'s static key.'); + + logger("Initiator going to check remote's signature..."); + try { + await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); + } catch (e) { + throw new Error(`Error occurred while verifying signed payload: ${e.message}`); + } + logger("All good with the signature!"); + } else { + logger('XX Fallback Stage 1 - Responder sending out first message with signed payload and static key.'); + const messageBuffer = this.xx.sendMessage(this.session, this.payload); + this.connection.writeLP(encode1(messageBuffer)); + logger('XX Fallback Stage 1 - Responder sent the second handshake message with signed payload.') + } + } +} diff --git a/src/handshake.ts b/src/handshake-xx.ts similarity index 77% rename from src/handshake.ts rename to src/handshake-xx.ts index 646b07c..7f6f1bf 100644 --- a/src/handshake.ts +++ b/src/handshake-xx.ts @@ -1,67 +1,69 @@ import { Buffer } from "buffer"; -import { XXHandshake } from "./handshakes/xx"; +import { XX } from "./handshakes/xx"; import { KeyPair, PeerId } from "./@types/libp2p"; import { bytes, bytes32 } from "./@types/basic"; import { NoiseSession } from "./@types/handshake"; +import {IHandshake} from "./@types/handshake-interface"; import { verifySignedPayload, } from "./utils"; import { logger } from "./logger"; -import { decodeMessageBuffer, encodeMessageBuffer } from "./encoder"; +import { decode0, decode1, encode0, encode1 } from "./encoder"; import { WrappedConnection } from "./noise"; -export class Handshake { +export class XXHandshake implements IHandshake { public isInitiator: boolean; public session: NoiseSession; - private payload: bytes; + protected payload: bytes; + protected connection: WrappedConnection; + protected xx: XX; + protected staticKeypair: KeyPair; + protected remotePeer: PeerId; + private prologue: bytes32; - private staticKeys: KeyPair; - private connection: WrappedConnection; - private remotePeer: PeerId; - private xx: XXHandshake; constructor( isInitiator: boolean, payload: bytes, prologue: bytes32, - staticKeys: KeyPair, + staticKeypair: KeyPair, connection: WrappedConnection, remotePeer: PeerId, - handshake?: XXHandshake, + handshake?: XX, ) { this.isInitiator = isInitiator; this.payload = payload; this.prologue = prologue; - this.staticKeys = staticKeys; + this.staticKeypair = staticKeypair; this.connection = connection; this.remotePeer = remotePeer; - this.xx = handshake || new XXHandshake(); - this.session = this.xx.initSession(this.isInitiator, this.prologue, this.staticKeys); + this.xx = handshake || new XX(); + this.session = this.xx.initSession(this.isInitiator, this.prologue, this.staticKeypair); } // stage 0 - async propose(): Promise { + public async propose(): Promise { if (this.isInitiator) { logger("Stage 0 - Initiator starting to send first message."); const messageBuffer = this.xx.sendMessage(this.session, Buffer.alloc(0)); - this.connection.writeLP(encodeMessageBuffer(messageBuffer)); + this.connection.writeLP(encode0(messageBuffer)); logger("Stage 0 - Initiator finished sending first message."); } else { logger("Stage 0 - Responder waiting to receive first message..."); - const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); + const receivedMessageBuffer = decode0((await this.connection.readLP()).slice()); this.xx.recvMessage(this.session, receivedMessageBuffer); logger("Stage 0 - Responder received first message."); } } // stage 1 - async exchange(): Promise { + public async exchange(): Promise { if (this.isInitiator) { logger('Stage 1 - Initiator waiting to receive first message from responder...'); - const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); + const receivedMessageBuffer = decode1((await this.connection.readLP()).slice()); const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); logger('Stage 1 - Initiator received the message. Got remote\'s static key.'); @@ -75,21 +77,21 @@ export class Handshake { } else { logger('Stage 1 - Responder sending out first message with signed payload and static key.'); const messageBuffer = this.xx.sendMessage(this.session, this.payload); - this.connection.writeLP(encodeMessageBuffer(messageBuffer)); + this.connection.writeLP(encode1(messageBuffer)); logger('Stage 1 - Responder sent the second handshake message with signed payload.') } } // stage 2 - async finish(): Promise { + public async finish(): Promise { if (this.isInitiator) { logger('Stage 2 - Initiator sending third handshake message.'); const messageBuffer = this.xx.sendMessage(this.session, this.payload); - this.connection.writeLP(encodeMessageBuffer(messageBuffer)); + this.connection.writeLP(encode1(messageBuffer)); logger('Stage 2 - Initiator sent message with signed payload.'); } else { logger('Stage 2 - Responder waiting for third handshake message...'); - const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); + const receivedMessageBuffer = decode1((await this.connection.readLP()).slice()); const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); logger('Stage 2 - Responder received the message, finished handshake. Got remote\'s static key.'); diff --git a/src/handshakes/ik.ts b/src/handshakes/ik.ts index f7c04ff..8d41e49 100644 --- a/src/handshakes/ik.ts +++ b/src/handshakes/ik.ts @@ -8,7 +8,7 @@ import {AbstractHandshake} from "./abstract-handshake"; import {KeyPair} from "../@types/libp2p"; -export class IKHandshake extends AbstractHandshake { +export class IK extends AbstractHandshake { public initSession(initiator: boolean, prologue: bytes32, s: KeyPair, rs: bytes32): NoiseSession { const psk = this.createEmptyKey(); diff --git a/src/handshakes/xx.ts b/src/handshakes/xx.ts index 694d8b8..110dc7e 100644 --- a/src/handshakes/xx.ts +++ b/src/handshakes/xx.ts @@ -8,7 +8,7 @@ import { HandshakeState, MessageBuffer, NoiseSession } from "../@types/handshake import {AbstractHandshake} from "./abstract-handshake"; -export class XXHandshake extends AbstractHandshake { +export class XX extends AbstractHandshake { private initializeInitiator(prologue: bytes32, s: KeyPair, rs: bytes32, psk: bytes32): HandshakeState { const name = "Noise_XX_25519_ChaChaPoly_SHA256"; const ss = this.initializeSymmetric(name); @@ -27,9 +27,14 @@ export class XXHandshake extends AbstractHandshake { return { ss, s, rs, psk, re }; } - private writeMessageA(hs: HandshakeState, payload: bytes): MessageBuffer { + private writeMessageA(hs: HandshakeState, payload: bytes, e?: KeyPair): MessageBuffer { const ns = Buffer.alloc(0); - hs.e = generateKeypair(); + + if (e) { + hs.e = e; + } else { + hs.e = generateKeypair(); + } const ne = hs.e.publicKey; @@ -128,10 +133,10 @@ export class XXHandshake extends AbstractHandshake { }; } - public sendMessage(session: NoiseSession, message: bytes): MessageBuffer { + public sendMessage(session: NoiseSession, message: bytes, ephemeral?: KeyPair): MessageBuffer { let messageBuffer: MessageBuffer; if (session.mc.eqn(0)) { - messageBuffer = this.writeMessageA(session.hs, message); + messageBuffer = this.writeMessageA(session.hs, message, ephemeral); } else if (session.mc.eqn(1)) { messageBuffer = this.writeMessageB(session.hs, message); } else if (session.mc.eqn(2)) { diff --git a/src/noise.ts b/src/noise.ts index 79dc158..676beb3 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -6,22 +6,30 @@ import ensureBuffer from 'it-buffer'; import pipe from 'it-pipe'; import lp from 'it-length-prefixed'; -import { Handshake } from "./handshake"; -import { - generateKeypair, - getPayload, -} from "./utils"; +import { XXHandshake } from "./handshake-xx"; +import { IKHandshake } from "./handshake-ik"; +import { XXFallbackHandshake } from "./handshake-xx-fallback"; +import { generateKeypair, getPayload } from "./utils"; import { uint16BEDecode, uint16BEEncode } from "./encoder"; import { decryptStream, encryptStream } from "./crypto"; import { bytes } from "./@types/basic"; -import { NoiseConnection, PeerId, KeyPair, SecureOutbound } from "./@types/libp2p"; +import { INoiseConnection, PeerId, KeyPair, SecureOutbound } from "./@types/libp2p"; import { Duplex } from "./@types/it-pair"; +import {IHandshake} from "./@types/handshake-interface"; export type WrappedConnection = ReturnType; -export class Noise implements NoiseConnection { +type HandshakeParams = { + connection: WrappedConnection; + isInitiator: boolean; + localPeer: PeerId; + remotePeer: PeerId; +}; + +export class Noise implements INoiseConnection { public protocol = "/noise"; + private readonly prologue = Buffer.from(this.protocol); private readonly staticKeys: KeyPair; private readonly earlyData?: bytes; @@ -48,7 +56,12 @@ export class Noise implements NoiseConnection { */ public async secureOutbound(localPeer: PeerId, connection: any, remotePeer: PeerId): Promise { const wrappedConnection = Wrap(connection); - const handshake = await this.performHandshake(wrappedConnection, true, localPeer, remotePeer); + const handshake = await this.performHandshake({ + connection: wrappedConnection, + isInitiator: true, + localPeer, + remotePeer, + }); const conn = await this.createSecureConnection(wrappedConnection, handshake); return { @@ -66,7 +79,12 @@ export class Noise implements NoiseConnection { */ public async secureInbound(localPeer: PeerId, connection: any, remotePeer: PeerId): Promise { const wrappedConnection = Wrap(connection); - const handshake = await this.performHandshake(wrappedConnection, false, localPeer, remotePeer); + const handshake = await this.performHandshake({ + connection: wrappedConnection, + isInitiator: false, + localPeer, + remotePeer + }); const conn = await this.createSecureConnection(wrappedConnection, handshake); return { @@ -75,30 +93,87 @@ export class Noise implements NoiseConnection { }; } - private async performHandshake( - connection: WrappedConnection, - isInitiator: boolean, - localPeer: PeerId, - remotePeer: PeerId, - ): Promise { - const prologue = Buffer.from(this.protocol); - const payload = await getPayload(localPeer, this.staticKeys.publicKey, this.earlyData); - const handshake = new Handshake(isInitiator, payload, prologue, this.staticKeys, connection, remotePeer); + /** + * If Noise pipes supported, tries IK handshake first with XX as fallback if it fails. + * If remote peer static key is unknown, use XX. + * @param connection + * @param isInitiator + * @param libp2pPublicKey + * @param remotePeer + */ + private async performHandshake(params: HandshakeParams): Promise { + // TODO: Implement noise pipes + const payload = await getPayload(params.localPeer, this.staticKeys.publicKey, this.earlyData); + + if (false) { + let IKhandshake; + try { + IKhandshake = await this.performIKHandshake(params, payload); + return IKhandshake; + } catch (e) { + // XX fallback + const ephemeralKeys = IKhandshake.getRemoteEphemeralKeys(); + return await this.performXXFallbackHandshake(params, payload, ephemeralKeys, e.initialMsg); + } + } else { + return await this.performXXHandshake(params, payload); + } + } + + private async performXXFallbackHandshake( + params: HandshakeParams, + payload: bytes, + ephemeralKeys: KeyPair, + initialMsg: bytes, + ): Promise { + const { isInitiator, remotePeer, connection } = params; + const handshake = + new XXFallbackHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer, initialMsg, ephemeralKeys); try { await handshake.propose(); await handshake.exchange(); await handshake.finish(); } catch (e) { - throw new Error(`Error occurred during handshake: ${e.message}`); + throw new Error(`Error occurred during XX Fallback handshake: ${e.message}`); } return handshake; } + private async performXXHandshake( + params: HandshakeParams, + payload: bytes, + ): Promise { + const { isInitiator, remotePeer, connection } = params; + const handshake = new XXHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer); + + try { + await handshake.propose(); + await handshake.exchange(); + await handshake.finish(); + } catch (e) { + throw new Error(`Error occurred during XX handshake: ${e.message}`); + } + + return handshake; + } + + private async performIKHandshake( + params: HandshakeParams, + payload: bytes, + ): Promise { + const { isInitiator, remotePeer, connection } = params; + const handshake = new IKHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer); + + // TODO + + return handshake; + } + private async createSecureConnection( connection: WrappedConnection, - handshake: Handshake, + handshake: IHandshake, ): Promise { // Create encryption box/unbox wrapper const [secure, user] = DuplexPair(); @@ -119,4 +194,5 @@ export class Noise implements NoiseConnection { return user; } + } diff --git a/test/handshakes/ik.test.ts b/test/handshakes/ik.test.ts index 69a674f..1ce50e2 100644 --- a/test/handshakes/ik.test.ts +++ b/test/handshakes/ik.test.ts @@ -1,5 +1,5 @@ import {Buffer} from "buffer"; -import {IKHandshake} from "../../src/handshakes/ik"; +import {IK} from "../../src/handshakes/ik"; import {KeyPair} from "../../src/@types/libp2p"; import {createHandshakePayload, generateKeypair, getHandshakePayload} from "../../src/utils"; import {assert, expect} from "chai"; @@ -10,8 +10,8 @@ describe("Index", () => { it("Test complete IK handshake", async () => { try { - const ikI = new IKHandshake(); - const ikR = new IKHandshake(); + const ikI = new IK(); + const ikR = new IK(); // Generate static noise keys const kpInitiator: KeyPair = await generateKeypair(); diff --git a/test/handshakes/xx.test.ts b/test/handshakes/xx.test.ts index ddd9137..656d21e 100644 --- a/test/handshakes/xx.test.ts +++ b/test/handshakes/xx.test.ts @@ -1,7 +1,7 @@ import { expect, assert } from "chai"; import { Buffer } from 'buffer'; -import { XXHandshake } from "../../src/handshakes/xx"; +import { XX } from "../../src/handshakes/xx"; import { KeyPair } from "../../src/@types/libp2p"; import { generateEd25519Keys } from "../utils"; import {createHandshakePayload, generateKeypair, getHandshakePayload, getHkdf} from "../../src/utils"; @@ -11,7 +11,7 @@ describe("Index", () => { it("Test creating new XX session", async () => { try { - const xx = new XXHandshake(); + const xx = new XX(); const kpInitiator: KeyPair = await generateKeypair(); const kpResponder: KeyPair = await generateKeypair(); @@ -23,7 +23,7 @@ describe("Index", () => { }); it("Test get HKDF", async () => { - const xx = new XXHandshake(); + const xx = new XX(); const ckBytes = Buffer.from('4e6f6973655f58585f32353531395f58436861436861506f6c795f53484132353600000000000000000000000000000000000000000000000000000000000000', 'hex'); const ikm = Buffer.from('a3eae50ea37a47e8a7aa0c7cd8e16528670536dcd538cebfd724fb68ce44f1910ad898860666227d4e8dd50d22a9a64d1c0a6f47ace092510161e9e442953da3', 'hex'); const ck = Buffer.alloc(32); @@ -104,7 +104,7 @@ describe("Index", () => { it("Test handshake", async () => { try { - const xx = new XXHandshake(); + const xx = new XX(); await doHandshake(xx); } catch (e) { assert(false, e.message); @@ -113,7 +113,7 @@ describe("Index", () => { it("Test symmetric encrypt and decrypt", async () => { try { - const xx = new XXHandshake(); + const xx = new XX(); const { nsInit, nsResp } = await doHandshake(xx); const ad = Buffer.from("authenticated"); const message = Buffer.from("HelloCrypto"); @@ -130,7 +130,7 @@ describe("Index", () => { it("Test multiple messages encryption and decryption", async () => { try { - const xx = new XXHandshake(); + const xx = new XX(); const { nsInit, nsResp } = await doHandshake(xx); const ad = Buffer.from("authenticated"); const message = Buffer.from("ethereum1"); diff --git a/test/noise.test.ts b/test/noise.test.ts index f10b0b1..dcf790f 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -5,15 +5,15 @@ import { Noise } from "../src"; import {createPeerIdsFromFixtures} from "./fixtures/peer"; import Wrap from "it-pb-rpc"; import { random } from "bcrypto"; -import {Handshake} from "../src/handshake"; +import {XXHandshake} from "../src/handshake-xx"; import { createHandshakePayload, generateKeypair, getHandshakePayload, getPayload, signPayload } from "../src/utils"; -import { decodeMessageBuffer, encodeMessageBuffer } from "../src/encoder"; -import {XXHandshake} from "../src/handshakes/xx"; +import {decode0, decode1, encode1} from "../src/encoder"; +import {XX} from "../src/handshakes/xx"; import {Buffer} from "buffer"; import {getKeyPairFromPeerId} from "./utils"; @@ -55,26 +55,26 @@ describe("Noise", () => { const wrapped = Wrap(inboundConnection); const prologue = Buffer.from('/noise'); const staticKeys = generateKeypair(); - const xx = new XXHandshake(); + const xx = new XX(); const payload = await getPayload(remotePeer, staticKeys.publicKey); - const handshake = new Handshake(false, payload, prologue, staticKeys, wrapped, localPeer, xx); + const handshake = new XXHandshake(false, payload, prologue, staticKeys, wrapped, localPeer, xx); - let receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); + let receivedMessageBuffer = decode0((await wrapped.readLP()).slice()); // The first handshake message contains the initiator's ephemeral public key expect(receivedMessageBuffer.ne.length).equal(32); xx.recvMessage(handshake.session, receivedMessageBuffer); // Stage 1 - const { privateKey: libp2pPrivKey, publicKey: libp2pPubKey } = getKeyPairFromPeerId(remotePeer); + const { publicKey: libp2pPubKey } = getKeyPairFromPeerId(remotePeer); const signedPayload = await signPayload(remotePeer, getHandshakePayload(staticKeys.publicKey)); - const handshakePayload = await createHandshakePayload(libp2pPubKey, libp2pPrivKey, signedPayload); + const handshakePayload = await createHandshakePayload(libp2pPubKey, signedPayload); const messageBuffer = xx.sendMessage(handshake.session, handshakePayload); - wrapped.writeLP(encodeMessageBuffer(messageBuffer)); + wrapped.writeLP(encode1(messageBuffer)); // Stage 2 - finish handshake - receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); + receivedMessageBuffer = decode1((await wrapped.readLP()).slice()); xx.recvMessage(handshake.session, receivedMessageBuffer); return {wrapped, handshake}; })(), diff --git a/test/xx-fallback-handshake.test.ts b/test/xx-fallback-handshake.test.ts new file mode 100644 index 0000000..18bfeba --- /dev/null +++ b/test/xx-fallback-handshake.test.ts @@ -0,0 +1,74 @@ +import Wrap from "it-pb-rpc"; +import {Buffer} from "buffer"; +import Duplex from 'it-pair/duplex'; + +import { + generateKeypair, + getPayload, +} from "../src/utils"; +import {XXFallbackHandshake} from "../src/handshake-xx-fallback"; +import {createPeerIdsFromFixtures} from "./fixtures/peer"; +import {assert} from "chai"; +import {decode1, encode0, encode1} from "../src/encoder"; + +describe("XX Fallback Handshake", () => { + let peerA, peerB, fakePeer; + + before(async () => { + [peerA, peerB] = await createPeerIdsFromFixtures(2); + }); + + it("should test that both parties can fallback to XX and finish handshake", async () => { + try { + const duplex = Duplex(); + const connectionFrom = Wrap(duplex[0]); + const connectionTo = Wrap(duplex[1]); + + const prologue = Buffer.from('/noise'); + const staticKeysInitiator = generateKeypair(); + const staticKeysResponder = generateKeypair(); + const ephemeralKeys = generateKeypair(); + + // Initial msg for responder is IK first message from initiator + const handshakePayload = await getPayload(peerA, staticKeysInitiator.publicKey); + const initialMsgR = encode0({ + ne: ephemeralKeys.publicKey, + ns: Buffer.alloc(0), + ciphertext: handshakePayload, + }); + + const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); + const handshakeResp = + new XXFallbackHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA, initialMsgR); + + await handshakeResp.propose(); + await handshakeResp.exchange(); + + // Initial message for initiator is XX Message B from responder + // This is the point where initiator falls back from IK + const initialMsgI = await connectionFrom.readLP(); + const handshakeInit = + new XXFallbackHandshake(true, handshakePayload, prologue, staticKeysInitiator, connectionFrom, peerB, initialMsgI, ephemeralKeys); + + await handshakeInit.propose(); + await handshakeInit.exchange(); + + await handshakeInit.finish(); + await handshakeResp.finish(); + + const sessionInitator = handshakeInit.session; + const sessionResponder = handshakeResp.session; + + // Test shared key + if (sessionInitator.cs1 && sessionResponder.cs1 && sessionInitator.cs2 && sessionResponder.cs2) { + assert(sessionInitator.cs1.k.equals(sessionResponder.cs1.k)); + assert(sessionInitator.cs2.k.equals(sessionResponder.cs2.k)); + } else { + assert(false); + } + } catch (e) { + console.error(e); + assert(false, e.message); + } + }); +}) diff --git a/test/handshake.test.ts b/test/xx-handshake.test.ts similarity index 82% rename from test/handshake.test.ts rename to test/xx-handshake.test.ts index 1c93e84..08851a4 100644 --- a/test/handshake.test.ts +++ b/test/xx-handshake.test.ts @@ -3,13 +3,12 @@ import Duplex from 'it-pair/duplex'; import {Buffer} from "buffer"; import Wrap from "it-pb-rpc"; -import {Handshake} from "../src/handshake"; +import {XXHandshake} from "../src/handshake-xx"; import {generateKeypair, getPayload} from "../src/utils"; import {createPeerIdsFromFixtures} from "./fixtures/peer"; -import {getKeyPairFromPeerId} from "./utils"; -describe("Handshake", () => { +describe("XX Handshake", () => { let peerA, peerB, fakePeer; before(async () => { @@ -27,10 +26,10 @@ describe("Handshake", () => { const staticKeysResponder = generateKeypair(); const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey); - const handshakeInitator = new Handshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, peerB); + const handshakeInitator = new XXHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, peerB); const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); - const handshakeResponder = new Handshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA); + const handshakeResponder = new XXHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA); await handshakeInitator.propose(); await handshakeResponder.propose(); @@ -72,10 +71,10 @@ describe("Handshake", () => { const staticKeysResponder = generateKeypair(); const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey); - const handshakeInitator = new Handshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, fakePeer); + const handshakeInitator = new XXHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, fakePeer); const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); - const handshakeResponder = new Handshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA); + const handshakeResponder = new XXHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA); await handshakeInitator.propose(); await handshakeResponder.propose(); @@ -100,10 +99,10 @@ describe("Handshake", () => { const staticKeysResponder = generateKeypair(); const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey); - const handshakeInitator = new Handshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, peerB); + const handshakeInitator = new XXHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, peerB); const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); - const handshakeResponder = new Handshake(false, respPayload, prologue, staticKeysResponder, connectionTo, fakePeer); + const handshakeResponder = new XXHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, fakePeer); await handshakeInitator.propose(); await handshakeResponder.propose();