diff --git a/src/handshake.ts b/src/handshake.ts index f897029..6684d8a 100644 --- a/src/handshake.ts +++ b/src/handshake.ts @@ -45,9 +45,55 @@ export class Handshake { } // stage 0 - async propose(earlyData?: bytes): Promise { + async propose(): Promise { if (this.isInitiator) { logger("Stage 0 - Initiator starting to send first message."); + const messageBuffer = await this.xx.sendMessage(this.session, Buffer.alloc(0)); + this.connection.writeLP(encodeMessageBuffer(messageBuffer)); + logger("Stage 0 - Initiator finished sending first message."); + } else { + logger("Stage 0 - Responder waiting to receive first message..."); + const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); + await this.xx.recvMessage(this.session, receivedMessageBuffer); + logger("Stage 0 - Responder received first message."); + } + } + + // stage 1 + async exchange(): Promise { + if (this.isInitiator) { + logger('Stage 1 - Initiator waiting to receive first message from responder...'); + const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); + const plaintext = await this.xx.recvMessage(this.session, receivedMessageBuffer); + logger('Stage 1 - Initiator received the message. Got remote\'s static key.'); + + // if (!libp2pRemotekey) { + // throw new Error("Missing remote's libp2p public key, can't verify peer ID."); + // } + logger("Initiator going to check remote's signature..."); + await verifySignedPayload(receivedMessageBuffer.ns, plaintext); + logger("All good with the signature!"); + } else { + logger('Stage 1 - Responder sending out first message with signed payload and static key.'); + const signedPayload = signPayload(this.libp2pPrivateKey, getHandshakePayload(this.staticKeys.publicKey)); + const signedEarlyDataPayload = signEarlyDataPayload(this.libp2pPrivateKey, Buffer.alloc(0)); + const handshakePayload = await createHandshakePayload( + this.libp2pPublicKey, + this.libp2pPrivateKey, + signedPayload, + signedEarlyDataPayload, + ); + + const messageBuffer = await this.xx.sendMessage(this.session, handshakePayload); + this.connection.writeLP(encodeMessageBuffer(messageBuffer)); + logger('Stage 1 - Responder sent the second handshake message with signed payload.') + } + } + + // stage 2 + async finish(earlyData?: bytes): Promise { + if (this.isInitiator) { + logger('Stage 2 - Initiator sending third handshake message.'); const signedPayload = signPayload(this.libp2pPrivateKey, getHandshakePayload(this.staticKeys.publicKey)); const signedEarlyDataPayload = signEarlyDataPayload(this.libp2pPrivateKey, earlyData || Buffer.alloc(0)); const handshakePayload = await createHandshakePayload( @@ -58,56 +104,18 @@ export class Handshake { ); const messageBuffer = await this.xx.sendMessage(this.session, handshakePayload); this.connection.writeLP(encodeMessageBuffer(messageBuffer)); - - logger("Stage 0 - Initiator finished proposing, sent signed NoiseHandshake payload and static public key."); - } else { - logger("Stage 0 - Responder waiting to receive first message..."); - const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); - const plaintext = await this.xx.recvMessage(this.session, receivedMessageBuffer); - // TODO: Verify payload - logger("Stage 0 - Responder received first message."); - } - } - - // stage 1 - async exchange(libp2pRemotekey?: bytes): Promise { - if (this.isInitiator) { - logger('Stage 1 - Initiator waiting to receive first message from responder...'); - const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); - const plaintext = await this.xx.recvMessage(this.session, receivedMessageBuffer); - logger('Stage 1 - Initiator received the message. Got remote\'s static key.'); - - if (!libp2pRemotekey) { - throw new Error("Missing remote's libp2p public key, can't verify signature."); - } - await verifySignedPayload(receivedMessageBuffer.ns, plaintext, libp2pRemotekey); - } else { - logger('Stage 1 - Responder sending out first message with signed payload and static key.'); - const signedPayload = signPayload(this.libp2pPrivateKey, getHandshakePayload(this.staticKeys.publicKey)); - const handshakePayload = await createHandshakePayload( - this.libp2pPublicKey, - this.libp2pPrivateKey, - signedPayload, - ); - - const messageBuffer = await this.xx.sendMessage(this.session, handshakePayload); - this.connection.writeLP(encodeMessageBuffer(messageBuffer)); - logger('Stage 1 - Responder sent the second handshake message.') - } - } - - // stage 2 - async finish(): Promise { - if (this.isInitiator) { - logger('Stage 2 - Initiator sending third handshake message.'); - const messageBuffer = await this.xx.sendMessage(this.session, Buffer.alloc(0)); - this.connection.writeLP(encodeMessageBuffer(messageBuffer)); - logger('Stage 2 - Initiator sent message.'); + logger('Stage 2 - Initiator sent message with signed payload.'); } else { logger('Stage 2 - Responder waiting for third handshake message...'); - const receivedMessageBuffer = (await this.connection.readLP()).slice(); - const plaintext = await this.xx.recvMessage(this.session, decodeMessageBuffer(receivedMessageBuffer)); + const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); + const plaintext = await this.xx.recvMessage(this.session, receivedMessageBuffer); logger('Stage 2 - Responder received the message, finished handshake. Got remote\'s static key.'); + + // if (!libp2pRemotekey) { + // throw new Error("Missing remote's libp2p public key, can't verify signature."); + // } + + await verifySignedPayload(receivedMessageBuffer.ns, plaintext); } } diff --git a/src/noise.ts b/src/noise.ts index 8221a41..ec82433 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -84,9 +84,13 @@ export class Noise implements NoiseConnection { const prologue = Buffer.from(this.protocol); const handshake = new Handshake(isInitiator, this.privateKey, libp2pPublicKey, prologue, this.staticKeys, connection); - await handshake.propose(this.earlyData); - await handshake.exchange(remotePeer.pubKey.marshal()); - await handshake.finish(); + try { + await handshake.propose(); + await handshake.exchange(); + await handshake.finish(this.earlyData); + } catch (e) { + throw new Error(`Error occurred during handshake: ${e.message}`); + } return handshake; } diff --git a/src/utils.ts b/src/utils.ts index 107819d..4936708 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -86,17 +86,17 @@ export function decodeMessageBuffer(message: bytes): MessageBuffer { export async function verifyPeerId(peerId: bytes, publicKey: bytes) { const generatedPeerId = await PeerId.createFromPubKey(publicKey); if (!generatedPeerId.equals(peerId)) { - Promise.reject("Peer ID doesn't match libp2p public key."); + throw new Error("Peer ID doesn't match libp2p public key."); } } -export async function verifySignedPayload(noiseStaticKey: bytes, plaintext: bytes, libp2pPublicKey: bytes) { +export async function verifySignedPayload(noiseStaticKey: bytes, plaintext: bytes) { const NoiseHandshakePayload = await loadPayloadProto(); const receivedPayload = NoiseHandshakePayload.toObject(NoiseHandshakePayload.decode(plaintext)); const generatedPayload = getHandshakePayload(noiseStaticKey); - if (!ed25519.verify(generatedPayload, receivedPayload.noiseStaticKeySignature, libp2pPublicKey)) { - Promise.reject("Static key doesn't match to peer that signed payload!"); + if (!ed25519.verify(generatedPayload, receivedPayload.noiseStaticKeySignature, receivedPayload.libp2pKey)) { + throw new Error("Static key doesn't match to peer that signed payload!"); } } diff --git a/test/handshake.test.ts b/test/handshake.test.ts index 927bc2b..792e0d2 100644 --- a/test/handshake.test.ts +++ b/test/handshake.test.ts @@ -10,46 +10,50 @@ import {createPeerIds} from "./fixtures/peer"; describe("Handshake", () => { it("should propose, exchange and finish handshake", async() => { - const duplex = Duplex(); - const connectionFrom = Wrap(duplex[0]); - const connectionTo = Wrap(duplex[1]); + try { + const duplex = Duplex(); + const connectionFrom = Wrap(duplex[0]); + const connectionTo = Wrap(duplex[1]); - const prologue = Buffer.from('/noise'); - const staticKeysInitiator = generateKeypair(); - const staticKeysResponder = generateKeypair(); - const [peerA, peerB] = await createPeerIds(2); + const prologue = Buffer.from('/noise'); + const staticKeysInitiator = generateKeypair(); + const staticKeysResponder = generateKeypair(); + const [peerA, peerB] = await createPeerIds(2); - const initiatorPrivKey = peerA.privKey.marshal().slice(0, 32); - const initiatorPubKey = peerA.pubKey.marshal(); - const handshakeInitator = new Handshake(true, initiatorPrivKey, initiatorPubKey, prologue, staticKeysInitiator, connectionFrom); + const initiatorPrivKey = peerA.privKey.marshal().slice(0, 32); + const initiatorPubKey = peerA.pubKey.marshal(); + const handshakeInitator = new Handshake(true, initiatorPrivKey, initiatorPubKey, prologue, staticKeysInitiator, connectionFrom); - const responderPrivKey = peerB.privKey.marshal().slice(0, 32); - const responderPubKey = peerB.pubKey.marshal(); - const handshakeResponder = new Handshake(false, responderPrivKey, responderPubKey, prologue, staticKeysResponder, connectionTo); + const responderPrivKey = peerB.privKey.marshal().slice(0, 32); + const responderPubKey = peerB.pubKey.marshal(); + const handshakeResponder = new Handshake(false, responderPrivKey, responderPubKey, prologue, staticKeysResponder, connectionTo); - await handshakeInitator.propose(); - await handshakeResponder.propose(); + await handshakeInitator.propose(); + await handshakeResponder.propose(); - await handshakeResponder.exchange(); - await handshakeInitator.exchange(peerB.pubKey.marshal()); + await handshakeResponder.exchange(); + await handshakeInitator.exchange(); - await handshakeInitator.finish(); - await handshakeResponder.finish(); + await handshakeInitator.finish(); + await handshakeResponder.finish(); - const sessionInitator = handshakeInitator.session; - const sessionResponder = handshakeResponder.session; + const sessionInitator = handshakeInitator.session; + const sessionResponder = handshakeResponder.session; - // Test shared key - if (sessionInitator.cs1 && sessionResponder.cs1 && sessionInitator.cs2 && sessionResponder.cs2) { - assert(sessionInitator.cs1.k.equals(sessionResponder.cs1.k)); - assert(sessionInitator.cs2.k.equals(sessionResponder.cs2.k)); - } else { - assert(false); + // Test shared key + if (sessionInitator.cs1 && sessionResponder.cs1 && sessionInitator.cs2 && sessionResponder.cs2) { + assert(sessionInitator.cs1.k.equals(sessionResponder.cs1.k)); + assert(sessionInitator.cs2.k.equals(sessionResponder.cs2.k)); + } else { + assert(false); + } + + // Test encryption and decryption + const encrypted = handshakeInitator.encrypt(Buffer.from("encryptthis"), handshakeInitator.session); + const decrypted = handshakeResponder.decrypt(encrypted, handshakeResponder.session); + assert(decrypted.equals(Buffer.from("encryptthis"))); + } catch (e) { + assert(false, e.message); } - - // Test encryption and decryption - const encrypted = handshakeInitator.encrypt(Buffer.from("encryptthis"), handshakeInitator.session); - const decrypted = handshakeResponder.decrypt(encrypted, handshakeResponder.session); - assert(decrypted.equals(Buffer.from("encryptthis"))); }); }); diff --git a/test/noise.test.ts b/test/noise.test.ts index 22f5578..a1c3bfe 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -2,7 +2,6 @@ import { expect, assert } from "chai"; import DuplexPair from 'it-pair/duplex'; import { Noise } from "../src"; -import { generateEd25519Keys } from "./utils"; import {createPeerIds, createPeerIdsFromFixtures} from "./fixtures/peer"; import Wrap from "it-pb-rpc"; import {Handshake} from "../src/handshake"; @@ -11,7 +10,7 @@ import { decodeMessageBuffer, encodeMessageBuffer, generateKeypair, - getHandshakePayload, + getHandshakePayload, signEarlyDataPayload, signPayload } from "../src/utils"; import {XXHandshake} from "../src/xx"; @@ -25,28 +24,32 @@ describe("Noise", () => { }); it("should communicate through encrypted streams", async() => { - const libp2pInitPrivKey = localPeer.privKey.marshal().slice(0, 32); - const libp2pRespPrivKey = remotePeer.privKey.marshal().slice(0, 32); + try { + const libp2pInitPrivKey = localPeer.privKey.marshal().slice(0, 32); + const libp2pRespPrivKey = remotePeer.privKey.marshal().slice(0, 32); - const noiseInit = new Noise(libp2pInitPrivKey); - const noiseResp = new Noise(libp2pRespPrivKey); + const noiseInit = new Noise(libp2pInitPrivKey); + const noiseResp = new Noise(libp2pRespPrivKey); - const [inboundConnection, outboundConnection] = DuplexPair(); - const [outbound, inbound] = await Promise.all([ - noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), - noiseResp.secureInbound(remotePeer, inboundConnection, localPeer), - ]); - const wrappedInbound = Wrap(inbound.conn); - const wrappedOutbound = Wrap(outbound.conn); + const [inboundConnection, outboundConnection] = DuplexPair(); + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection, localPeer), + ]); + const wrappedInbound = Wrap(inbound.conn); + const wrappedOutbound = Wrap(outbound.conn); - wrappedOutbound.writeLP(Buffer.from("test")); - const response = await wrappedInbound.readLP(); - expect(response.toString()).equal("test"); + wrappedOutbound.writeLP(Buffer.from("test")); + const response = await wrappedInbound.readLP(); + expect(response.toString()).equal("test"); + } catch (e) { + assert(false, e.message); + } }); it("should test that secureOutbound is spec compliant", async() => { - const libp2pPrivKey = localPeer.privKey.marshal().slice(0, 32); - const noiseInit = new Noise(libp2pPrivKey); + const libp2pInitPrivKey = localPeer.privKey.marshal().slice(0, 32); + const noiseInit = new Noise(libp2pInitPrivKey); const [inboundConnection, outboundConnection] = DuplexPair(); const [outbound, { wrapped, handshake }] = await Promise.all([ @@ -56,7 +59,8 @@ describe("Noise", () => { const prologue = Buffer.from('/noise'); const staticKeys = generateKeypair(); const xx = new XXHandshake(); - const libp2pPubKey = remotePeer.pubKey.marshal().slice(32, 64); + const libp2pPubKey = remotePeer.pubKey.marshal(); + const libp2pPrivKey = remotePeer.privKey.marshal().slice(0, 32); const handshake = new Handshake(false, libp2pPrivKey, libp2pPubKey, prologue, staticKeys, wrapped, xx); let receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); @@ -74,19 +78,23 @@ describe("Noise", () => { // Stage 2 - finish handshake receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); await xx.recvMessage(handshake.session, receivedMessageBuffer); - return { wrapped, handshake }; + return {wrapped, handshake}; })(), ]); - const wrappedOutbound = Wrap(outbound.conn); - wrappedOutbound.write(Buffer.from("test")); + try { + const wrappedOutbound = Wrap(outbound.conn); + wrappedOutbound.write(Buffer.from("test")); - // Check that noise message is prefixed with 16-bit big-endian unsigned integer - const receivedEncryptedPayload = (await wrapped.read()).slice(); - const dataLength = receivedEncryptedPayload.readInt16BE(0); - const data = receivedEncryptedPayload.slice(2, dataLength + 2); - const decrypted = handshake.decrypt(data, handshake.session); - // Decrypted data should match - assert(decrypted.equals(Buffer.from("test"))); + // Check that noise message is prefixed with 16-bit big-endian unsigned integer + const receivedEncryptedPayload = (await wrapped.read()).slice(); + const dataLength = receivedEncryptedPayload.readInt16BE(0); + const data = receivedEncryptedPayload.slice(2, dataLength + 2); + const decrypted = handshake.decrypt(data, handshake.session); + // Decrypted data should match + assert(decrypted.equals(Buffer.from("test"))); + } catch (e) { + assert(false, e.message); + } }) });