diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index 9dd8788..fff6698 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -6,10 +6,6 @@ import { bytes, bytes32 } from "./@types/basic"; import { NoiseSession } from "./@types/handshake"; import {IHandshake} from "./@types/handshake-interface"; import { - createHandshakePayload, - getHandshakePayload, - signEarlyDataPayload, - signPayload, verifySignedPayload, } from "./utils"; import { logger } from "./logger"; @@ -28,11 +24,11 @@ export class XXHandshake implements IHandshake { protected remotePeer: PeerId; private prologue: bytes32; + private payload: bytes; constructor( isInitiator: boolean, - libp2pPrivateKey: bytes, - libp2pPublicKey: bytes, + payload: bytes, prologue: bytes32, staticKeypair: KeyPair, connection: WrappedConnection, @@ -40,8 +36,7 @@ export class XXHandshake implements IHandshake { handshake?: XX, ) { this.isInitiator = isInitiator; - this.libp2pPrivateKey = libp2pPrivateKey; - this.libp2pPublicKey = libp2pPublicKey; + this.payload = payload; this.prologue = prologue; this.staticKeypair = staticKeypair; this.connection = connection; @@ -83,34 +78,17 @@ export class XXHandshake implements IHandshake { logger("All good with the signature!"); } else { logger('Stage 1 - Responder sending out first message with signed payload and static key.'); - const signedPayload = signPayload(this.libp2pPrivateKey, getHandshakePayload(this.staticKeypair.publicKey)); - const signedEarlyDataPayload = signEarlyDataPayload(this.libp2pPrivateKey, Buffer.alloc(0)); - const handshakePayload = await createHandshakePayload( - this.libp2pPublicKey, - this.libp2pPrivateKey, - signedPayload, - signedEarlyDataPayload, - ); - - const messageBuffer = this.xx.sendMessage(this.session, handshakePayload); + const messageBuffer = this.xx.sendMessage(this.session, this.payload); this.connection.writeLP(encode1(messageBuffer)); logger('Stage 1 - Responder sent the second handshake message with signed payload.') } } // stage 2 - public async finish(earlyData?: bytes): Promise { + public async finish(): Promise { if (this.isInitiator) { logger('Stage 2 - Initiator sending third handshake message.'); - const signedPayload = signPayload(this.libp2pPrivateKey, getHandshakePayload(this.staticKeypair.publicKey)); - const signedEarlyDataPayload = signEarlyDataPayload(this.libp2pPrivateKey, earlyData || Buffer.alloc(0)); - const handshakePayload = await createHandshakePayload( - this.libp2pPublicKey, - this.libp2pPrivateKey, - signedPayload, - signedEarlyDataPayload - ); - const messageBuffer = this.xx.sendMessage(this.session, handshakePayload); + const messageBuffer = this.xx.sendMessage(this.session, this.payload); this.connection.writeLP(encode1(messageBuffer)); logger('Stage 2 - Initiator sent message with signed payload.'); } else { diff --git a/src/noise.ts b/src/noise.ts index 0b9a3ef..5c904f0 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -9,7 +9,7 @@ import lp from 'it-length-prefixed'; import { XXHandshake } from "./handshake-xx"; import { IKHandshake } from "./handshake-ik"; import { XXFallbackHandshake } from "./handshake-xx-fallback"; -import { generateKeypair } from "./utils"; +import { generateKeypair, getPayload } from "./utils"; import { uint16BEDecode, uint16BEEncode } from "./encoder"; import { decryptStream, encryptStream } from "./crypto"; import { bytes } from "./@types/basic"; @@ -34,8 +34,7 @@ export class Noise implements NoiseConnection { private readonly staticKeys: KeyPair; private readonly earlyData?: bytes; - constructor(privateKey: bytes, staticNoiseKey?: bytes, earlyData?: bytes) { - this.privateKey = privateKey; + constructor(staticNoiseKey?: bytes, earlyData?: bytes) { this.earlyData = earlyData || Buffer.alloc(0); if (staticNoiseKey) { @@ -58,11 +57,10 @@ export class Noise implements NoiseConnection { */ public async secureOutbound(localPeer: PeerId, connection: any, remotePeer: PeerId): Promise { const wrappedConnection = Wrap(connection); - const libp2pPublicKey = localPeer.marshalPubKey(); const handshake = await this.performHandshake({ connection: wrappedConnection, isInitiator: true, - libp2pPublicKey, + localPeer, remotePeer, }); const conn = await this.createSecureConnection(wrappedConnection, handshake); @@ -82,11 +80,10 @@ export class Noise implements NoiseConnection { */ public async secureInbound(localPeer: PeerId, connection: any, remotePeer: PeerId): Promise { const wrappedConnection = Wrap(connection); - const libp2pPublicKey = localPeer.marshalPubKey(); const handshake = await this.performHandshake({ connection: wrappedConnection, isInitiator: false, - libp2pPublicKey, + localPeer, remotePeer }); const conn = await this.createSecureConnection(wrappedConnection, handshake); @@ -107,35 +104,37 @@ export class Noise implements NoiseConnection { */ private async performHandshake(params: HandshakeParams): Promise { // TODO: Implement noise pipes + const payload = await getPayload(params.localPeer, this.staticKeys.publicKey, this.earlyData); if (false) { let IKhandshake; try { - IKhandshake = await this.performIKHandshake(params); + IKhandshake = await this.performIKHandshake(params, payload); return IKhandshake; } catch (e) { // XX fallback const ephemeralKeys = IKhandshake.getRemoteEphemeralKeys(); - return await this.performXXFallbackHandshake(params, ephemeralKeys, e.initialMsg); + return await this.performXXFallbackHandshake(params, payload, ephemeralKeys, e.initialMsg); } } else { - return await this.performXXHandshake(params); + return await this.performXXHandshake(params, payload); } } private async performXXFallbackHandshake( params: HandshakeParams, + payload: bytes, ephemeralKeys: KeyPair, initialMsg: bytes, ): Promise { const { isInitiator, libp2pPublicKey, remotePeer, connection } = params; const handshake = - new XXFallbackHandshake(isInitiator, this.privateKey, libp2pPublicKey, this.prologue, this.staticKeys, connection, remotePeer, initialMsg, ephemeralKeys); + new XXFallbackHandshake(isInitiator, payload, this.privateKey, libp2pPublicKey, this.prologue, this.staticKeys, connection, remotePeer, initialMsg, ephemeralKeys); try { await handshake.propose(); await handshake.exchange(); - await handshake.finish(this.earlyData); + await handshake.finish(); } catch (e) { throw new Error(`Error occurred during XX Fallback handshake: ${e.message}`); } @@ -145,14 +144,15 @@ export class Noise implements NoiseConnection { private async performXXHandshake( params: HandshakeParams, + payload: bytes, ): Promise { const { isInitiator, libp2pPublicKey, remotePeer, connection } = params; - const handshake = new XXHandshake(isInitiator, this.privateKey, libp2pPublicKey, this.prologue, this.staticKeys, connection, remotePeer); + const handshake = new XXHandshake(isInitiator, payload, this.privateKey, libp2pPublicKey, this.prologue, this.staticKeys, connection, remotePeer); try { await handshake.propose(); await handshake.exchange(); - await handshake.finish(this.earlyData); + await handshake.finish(); } catch (e) { throw new Error(`Error occurred during XX handshake: ${e.message}`); } @@ -162,9 +162,10 @@ export class Noise implements NoiseConnection { private async performIKHandshake( params: HandshakeParams, + payload: bytes, ): Promise { const { isInitiator, libp2pPublicKey, remotePeer, connection } = params; - const handshake = new IKHandshake(isInitiator, this.privateKey, libp2pPublicKey, this.prologue, this.staticKeys, connection, remotePeer); + const handshake = new IKHandshake(isInitiator, payload, this.privateKey, libp2pPublicKey, this.prologue, this.staticKeys, connection, remotePeer); // TODO diff --git a/src/utils.ts b/src/utils.ts index ccd6271..003159c 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,4 +1,4 @@ -import { x25519, ed25519, HKDF, SHA256 } from 'bcrypto'; +import { x25519, HKDF, SHA256 } from 'bcrypto'; import protobuf from "protobufjs"; import { Buffer } from "buffer"; import PeerId from "peer-id"; @@ -23,9 +23,23 @@ export function generateKeypair(): KeyPair { } } +export async function getPayload( + localPeer: PeerId, + staticPublicKey: bytes, + earlyData?: bytes, +): Promise { + const signedPayload = await signPayload(localPeer, getHandshakePayload(staticPublicKey)); + const signedEarlyDataPayload = await signEarlyDataPayload(localPeer, earlyData || Buffer.alloc(0)); + + return await createHandshakePayload( + localPeer.marshalPubKey(), + signedPayload, + signedEarlyDataPayload + ); +} + export async function createHandshakePayload( libp2pPublicKey: bytes, - libp2pPrivateKey: bytes, signedPayload: bytes, signedEarlyData?: EarlyDataPayload, ): Promise { @@ -46,8 +60,8 @@ export async function createHandshakePayload( } -export function signPayload(libp2pPrivateKey: bytes, payload: bytes): bytes { - return ed25519.sign(payload, libp2pPrivateKey); +export async function signPayload(peerId: PeerId, payload: bytes): Promise { + return peerId.privKey.sign(payload); } type EarlyDataPayload = { @@ -55,9 +69,9 @@ type EarlyDataPayload = { libp2pDataSignature: bytes; } -export function signEarlyDataPayload(libp2pPrivateKey: bytes, earlyData: bytes): EarlyDataPayload { +export async function signEarlyDataPayload(peerId: PeerId, earlyData: bytes): Promise { const payload = getEarlyDataPayload(earlyData); - const signedPayload = signPayload(libp2pPrivateKey, payload); + const signedPayload = await signPayload(peerId, payload); return { libp2pData: payload, @@ -83,9 +97,10 @@ export async function verifySignedPayload(noiseStaticKey: bytes, plaintext: byte } const generatedPayload = getHandshakePayload(noiseStaticKey); - // Unmarshaling from PublicKey protobuf and taking key buffer only. - const publicKey = crypto.keys.unmarshalPublicKey(receivedPayload.libp2pKey).marshal(); - if (!ed25519.verify(generatedPayload, receivedPayload.noiseStaticKeySignature, publicKey)) { + + // Unmarshaling from PublicKey protobuf + const publicKey = crypto.keys.unmarshalPublicKey(receivedPayload.libp2pKey); + if (!publicKey.verify(generatedPayload, receivedPayload.noiseStaticKeySignature)) { throw new Error("Static key doesn't match to peer that signed payload!"); } } diff --git a/test/handshakes/ik.test.ts b/test/handshakes/ik.test.ts index 4597e0f..1ce50e2 100644 --- a/test/handshakes/ik.test.ts +++ b/test/handshakes/ik.test.ts @@ -31,7 +31,7 @@ describe("Index", () => { const initSignedPayload = await libp2pInitKeys.sign(getHandshakePayload(kpInitiator.publicKey)); const libp2pInitPrivKey = libp2pInitKeys.marshal().slice(0, 32); const libp2pInitPubKey = libp2pInitKeys.marshal().slice(32, 64); - const payloadInitEnc = await createHandshakePayload(libp2pInitPubKey, libp2pInitPrivKey, initSignedPayload); + const payloadInitEnc = await createHandshakePayload(libp2pInitPubKey, initSignedPayload); // initiator sends message const message = Buffer.concat([Buffer.alloc(0), payloadInitEnc]); @@ -48,7 +48,7 @@ describe("Index", () => { const libp2pRespPrivKey = libp2pRespKeys.marshal().slice(0, 32); const libp2pRespPubKey = libp2pRespKeys.marshal().slice(32, 64); const respSignedPayload = await libp2pRespKeys.sign(getHandshakePayload(kpResponder.publicKey)); - const payloadRespEnc = await createHandshakePayload(libp2pRespPubKey, libp2pRespPrivKey, respSignedPayload); + const payloadRespEnc = await createHandshakePayload(libp2pRespPubKey, respSignedPayload); const message1 = Buffer.concat([message, payloadRespEnc]); const messageBuffer2 = ikR.sendMessage(responderSession, message1); diff --git a/test/handshakes/xx.test.ts b/test/handshakes/xx.test.ts index b0713b3..656d21e 100644 --- a/test/handshakes/xx.test.ts +++ b/test/handshakes/xx.test.ts @@ -58,7 +58,7 @@ describe("Index", () => { const libp2pInitPrivKey = libp2pInitKeys.marshal().slice(0, 32); const libp2pInitPubKey = libp2pInitKeys.marshal().slice(32, 64); - const payloadInitEnc = await createHandshakePayload(libp2pInitPubKey, libp2pInitPrivKey, initSignedPayload); + const payloadInitEnc = await createHandshakePayload(libp2pInitPubKey, initSignedPayload); // initiator sends message const message = Buffer.concat([Buffer.alloc(0), payloadInitEnc]); @@ -75,7 +75,7 @@ describe("Index", () => { // responder creates payload const libp2pRespPrivKey = libp2pRespKeys.marshal().slice(0, 32); const libp2pRespPubKey = libp2pRespKeys.marshal().slice(32, 64); - const payloadRespEnc = await createHandshakePayload(libp2pRespPubKey, libp2pRespPrivKey, respSignedPayload); + const payloadRespEnc = await createHandshakePayload(libp2pRespPubKey, respSignedPayload); const message1 = Buffer.concat([message, payloadRespEnc]); const messageBuffer2 = xx.sendMessage(nsResp, message1); diff --git a/test/index.test.ts b/test/index.test.ts index 808094f..f1287cf 100644 --- a/test/index.test.ts +++ b/test/index.test.ts @@ -3,7 +3,7 @@ import { Noise } from "../src"; describe("Index", () => { it("should expose class with tag and required functions", () => { - const noise = new Noise(Buffer.from("privatekey")); + const noise = new Noise(); expect(noise.protocol).to.equal('/noise'); expect(typeof(noise.secureInbound)).to.equal('function'); expect(typeof(noise.secureOutbound)).to.equal('function'); diff --git a/test/noise.test.ts b/test/noise.test.ts index 34617c1..7a94809 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -9,7 +9,7 @@ import {XXHandshake} from "../src/handshake-xx"; import { createHandshakePayload, generateKeypair, - getHandshakePayload, + getHandshakePayload, getPayload, signPayload } from "../src/utils"; import {decode0, decode1, encode1} from "../src/encoder"; @@ -26,10 +26,8 @@ describe("Noise", () => { it("should communicate through encrypted streams", async() => { try { - const { privateKey: libp2pInitPrivKey } = getKeyPairFromPeerId(localPeer); - const { privateKey: libp2pRespPrivKey } = getKeyPairFromPeerId(remotePeer); - const noiseInit = new Noise(libp2pInitPrivKey); - const noiseResp = new Noise(libp2pRespPrivKey); + const noiseInit = new Noise(); + const noiseResp = new Noise(); const [inboundConnection, outboundConnection] = DuplexPair(); const [outbound, inbound] = await Promise.all([ @@ -48,8 +46,7 @@ describe("Noise", () => { }); it("should test that secureOutbound is spec compliant", async() => { - const { privateKey: libp2pInitPrivKey } = getKeyPairFromPeerId(localPeer); - const noiseInit = new Noise(libp2pInitPrivKey); + const noiseInit = new Noise(); const [inboundConnection, outboundConnection] = DuplexPair(); const [outbound, { wrapped, handshake }] = await Promise.all([ @@ -59,9 +56,10 @@ describe("Noise", () => { const prologue = Buffer.from('/noise'); const staticKeys = generateKeypair(); const xx = new XX(); - const { privateKey: libp2pPrivKey, publicKey: libp2pPubKey } = getKeyPairFromPeerId(remotePeer); const handshake = new XXHandshake(false, libp2pPrivKey, libp2pPubKey, prologue, staticKeys, wrapped, localPeer, xx); + const payload = await getPayload(remotePeer, staticKeys.publicKey); + const handshake = new Handshake(false, payload, prologue, staticKeys, wrapped, localPeer, xx); let receivedMessageBuffer = decode0((await wrapped.readLP()).slice()); // The first handshake message contains the initiator's ephemeral public key @@ -69,7 +67,8 @@ describe("Noise", () => { xx.recvMessage(handshake.session, receivedMessageBuffer); // Stage 1 - const signedPayload = signPayload(libp2pPrivKey, getHandshakePayload(staticKeys.publicKey)); + const { privateKey: libp2pPrivKey, publicKey: libp2pPubKey } = getKeyPairFromPeerId(remotePeer); + const signedPayload = await signPayload(remotePeer, getHandshakePayload(staticKeys.publicKey)); const handshakePayload = await createHandshakePayload(libp2pPubKey, libp2pPrivKey, signedPayload); const messageBuffer = xx.sendMessage(handshake.session, handshakePayload); @@ -101,10 +100,8 @@ describe("Noise", () => { it("should test large payloads", async() => { try { - const { privateKey: libp2pInitPrivKey } = getKeyPairFromPeerId(localPeer); - const { privateKey: libp2pRespPrivKey } = getKeyPairFromPeerId(remotePeer); - const noiseInit = new Noise(libp2pInitPrivKey); - const noiseResp = new Noise(libp2pRespPrivKey); + const noiseInit = new Noise(); + const noiseResp = new Noise(); const [inboundConnection, outboundConnection] = DuplexPair(); const [outbound, inbound] = await Promise.all([ diff --git a/test/utils.ts b/test/utils.ts index 5e8b430..2ac5b7b 100644 --- a/test/utils.ts +++ b/test/utils.ts @@ -1,6 +1,5 @@ import * as crypto from 'libp2p-crypto'; import {KeyPair, PeerId} from "../src/@types/libp2p"; -import {bytes} from "../src/@types/basic"; export async function generateEd25519Keys() { return await crypto.keys.generateKeyPair('ed25519'); diff --git a/test/xx-handshake.test.ts b/test/xx-handshake.test.ts index 3b651f8..569c35b 100644 --- a/test/xx-handshake.test.ts +++ b/test/xx-handshake.test.ts @@ -26,11 +26,11 @@ describe("XX Handshake", () => { const staticKeysInitiator = generateKeypair(); const staticKeysResponder = generateKeypair(); - const { privateKey: initiatorPrivKey, publicKey: initiatorPubKey } = getKeyPairFromPeerId(peerA); - const handshakeInitator = new XXHandshake(true, initiatorPrivKey, initiatorPubKey, prologue, staticKeysInitiator, connectionFrom, peerB); + const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey); + const handshakeInitator = new XXHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, peerB); - const { privateKey: responderPrivKey, publicKey: responderPubKey } = getKeyPairFromPeerId(peerB); - const handshakeResponder = new XXHandshake(false, responderPrivKey, responderPubKey, prologue, staticKeysResponder, connectionTo, peerA); + const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); + const handshakeResponder = new XXHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA); await handshakeInitator.propose(); await handshakeResponder.propose(); @@ -71,11 +71,11 @@ describe("XX Handshake", () => { const staticKeysInitiator = generateKeypair(); const staticKeysResponder = generateKeypair(); - const { privateKey: initiatorPrivKey, publicKey: initiatorPubKey } = getKeyPairFromPeerId(peerA); - const handshakeInitator = new XXHandshake(true, initiatorPrivKey, initiatorPubKey, prologue, staticKeysInitiator, connectionFrom, fakePeer); + const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey); + const handshakeInitator = new XXHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, fakePeer); - const { privateKey: responderPrivKey, publicKey: responderPubKey } = getKeyPairFromPeerId(peerB); - const handshakeResponder = new XXHandshake(false, responderPrivKey, responderPubKey, prologue, staticKeysResponder, connectionTo, peerA); + const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); + const handshakeResponder = new XXHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA); await handshakeInitator.propose(); await handshakeResponder.propose(); @@ -99,11 +99,11 @@ describe("XX Handshake", () => { const staticKeysInitiator = generateKeypair(); const staticKeysResponder = generateKeypair(); - const { privateKey: initiatorPrivKey, publicKey: initiatorPubKey } = getKeyPairFromPeerId(peerA); - const handshakeInitator = new XXHandshake(true, initiatorPrivKey, initiatorPubKey, prologue, staticKeysInitiator, connectionFrom, peerB); + const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey); + const handshakeInitator = new XXHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, peerB); - const { privateKey: responderPrivKey, publicKey: responderPubKey } = getKeyPairFromPeerId(peerB); - const handshakeResponder = new XXHandshake(false, responderPrivKey, responderPubKey, prologue, staticKeysResponder, connectionTo, fakePeer); + const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); + const handshakeResponder = new XXHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, fakePeer); await handshakeInitator.propose(); await handshakeResponder.propose();