diff --git a/src/@types/handshake-interface.ts b/src/@types/handshake-interface.ts index d92b9ff..6098f23 100644 --- a/src/@types/handshake-interface.ts +++ b/src/@types/handshake-interface.ts @@ -1,8 +1,10 @@ import {bytes} from "./basic"; import {NoiseSession} from "./handshake"; +import PeerId from "peer-id"; export interface IHandshake { session: NoiseSession; + remotePeer: PeerId; encrypt(plaintext: bytes, session: NoiseSession): bytes; decrypt(ciphertext: bytes, session: NoiseSession): bytes; } diff --git a/src/@types/handshake.ts b/src/@types/handshake.ts index cb67eeb..7e823b3 100644 --- a/src/@types/handshake.ts +++ b/src/@types/handshake.ts @@ -37,3 +37,9 @@ export type NoiseSession = { mc: uint64; i: boolean; } + +export interface INoisePayload { + identityKey: bytes; + identitySig: bytes; + data: bytes; +} diff --git a/src/handshake-ik.ts b/src/handshake-ik.ts index e2bfb35..ea29a4c 100644 --- a/src/handshake-ik.ts +++ b/src/handshake-ik.ts @@ -6,7 +6,7 @@ import {KeyPair} from "./@types/libp2p"; import {IHandshake} from "./@types/handshake-interface"; import {Buffer} from "buffer"; import {decode0, decode1, encode0, encode1} from "./encoder"; -import {verifySignedPayload} from "./utils"; +import {decodePayload, getPeerIdFromPayload, verifySignedPayload} from "./utils"; import {FailedIKError} from "./errors"; import {logger} from "./logger"; import PeerId from "peer-id"; @@ -14,12 +14,12 @@ import PeerId from "peer-id"; export class IKHandshake implements IHandshake { public isInitiator: boolean; public session: NoiseSession; + public remotePeer!: PeerId; private payload: bytes; private prologue: bytes32; private staticKeypair: KeyPair; private connection: WrappedConnection; - private remotePeer: PeerId; private ik: IK; constructor( @@ -28,8 +28,8 @@ export class IKHandshake implements IHandshake { prologue: bytes32, staticKeypair: KeyPair, connection: WrappedConnection, - remotePeer: PeerId, remoteStaticKey: bytes, + remotePeer?: PeerId, handshake?: IK, ) { this.isInitiator = isInitiator; @@ -37,8 +37,9 @@ export class IKHandshake implements IHandshake { this.prologue = prologue; this.staticKeypair = staticKeypair; this.connection = connection; - this.remotePeer = remotePeer; - + if(remotePeer) { + this.remotePeer = remotePeer; + } this.ik = handshake || new IK(); this.session = this.ik.initSession(this.isInitiator, this.prologue, this.staticKeypair, remoteStaticKey); } @@ -55,12 +56,14 @@ export class IKHandshake implements IHandshake { try { const receivedMessageBuffer = decode1(receivedMsg); const plaintext = this.ik.recvMessage(this.session, receivedMessageBuffer); - logger("IK Stage 0 - Responder got message, going to verify payload."); - await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); + const decodedPayload = await decodePayload(plaintext); + this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); + await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer); logger("IK Stage 0 - Responder successfully verified payload!"); } catch (e) { logger("Responder breaking up with IK handshake in stage 0."); + throw new FailedIKError(receivedMsg, `Error occurred while verifying initiator's signed payload: ${e.message}`); } } @@ -75,7 +78,9 @@ export class IKHandshake implements IHandshake { logger("IK Stage 1 - Initiator got message, going to verify payload."); try { - await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); + const decodedPayload = await decodePayload(plaintext); + this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); + await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer); logger("IK Stage 1 - Initiator successfully verified payload!"); } catch (e) { logger("Initiator breaking up with IK handshake in stage 1."); @@ -99,7 +104,7 @@ export class IKHandshake implements IHandshake { return this.ik.encryptWithAd(cs, Buffer.alloc(0), plaintext); } - public getRemoteEphemeralKeys(): KeyPair { + public getLocalEphemeralKeys(): KeyPair { if (!this.session.hs.e) { throw new Error("Ephemeral keys do not exist."); } diff --git a/src/handshake-xx-fallback.ts b/src/handshake-xx-fallback.ts index 301fc3e..d98d27c 100644 --- a/src/handshake-xx-fallback.ts +++ b/src/handshake-xx-fallback.ts @@ -3,7 +3,7 @@ import {XXHandshake} from "./handshake-xx"; import {XX} from "./handshakes/xx"; import {KeyPair} from "./@types/libp2p"; import {bytes, bytes32} from "./@types/basic"; -import {verifySignedPayload,} from "./utils"; +import {decodePayload, getPeerIdFromPayload, verifySignedPayload,} from "./utils"; import {logger} from "./logger"; import {WrappedConnection} from "./noise"; import {decode0, decode1} from "./encoder"; @@ -19,8 +19,8 @@ export class XXFallbackHandshake extends XXHandshake { prologue: bytes32, staticKeypair: KeyPair, connection: WrappedConnection, - remotePeer: PeerId, initialMsg: bytes, + remotePeer?: PeerId, ephemeralKeys?: KeyPair, handshake?: XX, ) { @@ -57,14 +57,16 @@ export class XXFallbackHandshake extends XXHandshake { logger("Initiator going to check remote's signature..."); try { - await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); + const decodedPayload = await decodePayload(plaintext); + this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); + await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer); } catch (e) { throw new Error(`Error occurred while verifying signed payload from responder: ${e.message}`); } logger("All good with the signature!"); } else { logger("XX Fallback Stage 1 - Responder start"); - super.exchange(); + await super.exchange(); logger("XX Fallback Stage 1 - Responder end"); } } diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index e6d0aff..0085f13 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -6,6 +6,8 @@ import { bytes, bytes32 } from "./@types/basic"; import { NoiseSession } from "./@types/handshake"; import {IHandshake} from "./@types/handshake-interface"; import { + decodePayload, + getPeerIdFromPayload, verifySignedPayload, } from "./utils"; import { logger } from "./logger"; @@ -16,12 +18,12 @@ import PeerId from "peer-id"; export class XXHandshake implements IHandshake { public isInitiator: boolean; public session: NoiseSession; + public remotePeer!: PeerId; protected payload: bytes; protected connection: WrappedConnection; protected xx: XX; protected staticKeypair: KeyPair; - protected remotePeer: PeerId; private prologue: bytes32; @@ -31,7 +33,7 @@ export class XXHandshake implements IHandshake { prologue: bytes32, staticKeypair: KeyPair, connection: WrappedConnection, - remotePeer: PeerId, + remotePeer?: PeerId, handshake?: XX, ) { this.isInitiator = isInitiator; @@ -39,8 +41,9 @@ export class XXHandshake implements IHandshake { this.prologue = prologue; this.staticKeypair = staticKeypair; this.connection = connection; - this.remotePeer = remotePeer; - + if(remotePeer) { + this.remotePeer = remotePeer; + } this.xx = handshake || new XX(); this.session = this.xx.initSession(this.isInitiator, this.prologue, this.staticKeypair); } @@ -70,7 +73,9 @@ export class XXHandshake implements IHandshake { logger("Initiator going to check remote's signature..."); try { - await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); + const decodedPayload = await decodePayload(plaintext); + this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); + this.remotePeer = await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer); } catch (e) { throw new Error(`Error occurred while verifying signed payload: ${e.message}`); } @@ -97,7 +102,9 @@ export class XXHandshake implements IHandshake { logger('Stage 2 - Responder received the message, finished handshake. Got remote\'s static key.'); try { - await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); + const decodedPayload = await decodePayload(plaintext); + this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); + await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer); } catch (e) { throw new Error(`Error occurred while verifying signed payload: ${e.message}`); } diff --git a/src/keycache.ts b/src/keycache.ts index d095d80..56bb013 100644 --- a/src/keycache.ts +++ b/src/keycache.ts @@ -11,8 +11,11 @@ class Keycache { this.storage.set(peerId.id, key); } - public load(peerId: PeerId): bytes32|undefined { - return this.storage.get(peerId.id); + public load(peerId?: PeerId): bytes32 | null { + if(!peerId) { + return null; + } + return this.storage.get(peerId.id) || null; } public resetStorage(): void { diff --git a/src/noise.ts b/src/noise.ts index e089465..bf6c989 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -1,20 +1,20 @@ -import { x25519 } from 'bcrypto'; -import { Buffer } from "buffer"; +import {x25519} from 'bcrypto'; +import {Buffer} from "buffer"; import Wrap from 'it-pb-rpc'; import DuplexPair from 'it-pair/duplex'; import ensureBuffer from 'it-buffer'; import pipe from 'it-pipe'; import lp from 'it-length-prefixed'; -import { XXHandshake } from "./handshake-xx"; -import { IKHandshake } from "./handshake-ik"; -import { XXFallbackHandshake } from "./handshake-xx-fallback"; -import { generateKeypair, getPayload } from "./utils"; -import { uint16BEDecode, uint16BEEncode } from "./encoder"; -import { decryptStream, encryptStream } from "./crypto"; -import { bytes } from "./@types/basic"; -import { INoiseConnection, KeyPair, SecureOutbound } from "./@types/libp2p"; -import { Duplex } from "./@types/it-pair"; +import {XXHandshake} from "./handshake-xx"; +import {IKHandshake} from "./handshake-ik"; +import {XXFallbackHandshake} from "./handshake-xx-fallback"; +import {generateKeypair, getPayload} from "./utils"; +import {uint16BEDecode, uint16BEEncode} from "./encoder"; +import {decryptStream, encryptStream} from "./crypto"; +import {bytes} from "./@types/basic"; +import {INoiseConnection, KeyPair, SecureOutbound} from "./@types/libp2p"; +import {Duplex} from "./@types/it-pair"; import {IHandshake} from "./@types/handshake-interface"; import {KeyCache} from "./keycache"; import {logger} from "./logger"; @@ -26,7 +26,7 @@ type HandshakeParams = { connection: WrappedConnection; isInitiator: boolean; localPeer: PeerId; - remotePeer: PeerId; + remotePeer?: PeerId; }; export class Noise implements INoiseConnection { @@ -88,7 +88,7 @@ export class Noise implements INoiseConnection { * @param {PeerId} remotePeer - optional PeerId of the initiating peer, if known. This may only exist during transport upgrades. * @returns {Promise} */ - public async secureInbound(localPeer: PeerId, connection: any, remotePeer: PeerId): Promise { + public async secureInbound(localPeer: PeerId, connection: any, remotePeer?: PeerId): Promise { const wrappedConnection = Wrap(connection); const handshake = await this.performHandshake({ connection: wrappedConnection, @@ -100,7 +100,7 @@ export class Noise implements INoiseConnection { return { conn, - remotePeer, + remotePeer: handshake.remotePeer }; } @@ -111,13 +111,25 @@ export class Noise implements INoiseConnection { */ private async performHandshake(params: HandshakeParams): Promise { const payload = await getPayload(params.localPeer, this.staticKeys.publicKey, this.earlyData); - - const remoteStaticKey = KeyCache.load(params.remotePeer); + let tryIK = this.useNoisePipes; + if(params.isInitiator && KeyCache.load(params.remotePeer) === null) { + //if we are initiator and remote static key is unknown, don't try IK + tryIK = false; + } // Try IK if acting as responder or initiator that has remote's static key. - if (this.useNoisePipes && remoteStaticKey) { + if (tryIK) { // Try IK first const { remotePeer, connection, isInitiator } = params; - const ikHandshake = new IKHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer, remoteStaticKey); + const ikHandshake = new IKHandshake( + isInitiator, + payload, + this.prologue, + this.staticKeys, + connection, + //safe to cast as we did checks + KeyCache.load(params.remotePeer) || Buffer.alloc(32), + remotePeer as PeerId, + ); try { return await this.performIKHandshake(ikHandshake); @@ -125,12 +137,12 @@ export class Noise implements INoiseConnection { // IK failed, go to XX fallback let ephemeralKeys; if (params.isInitiator) { - ephemeralKeys = ikHandshake.getRemoteEphemeralKeys(); + ephemeralKeys = ikHandshake.getLocalEphemeralKeys(); } return await this.performXXFallbackHandshake(params, payload, e.initialMsg, ephemeralKeys); } } else { - // Noise pipes not supported, use XX + // run XX handshake return await this.performXXHandshake(params, payload); } } @@ -143,7 +155,7 @@ export class Noise implements INoiseConnection { ): Promise { const { isInitiator, remotePeer, connection } = params; const handshake = - new XXFallbackHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer, initialMsg, ephemeralKeys); + new XXFallbackHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, initialMsg, remotePeer, ephemeralKeys); try { await handshake.propose(); @@ -169,8 +181,8 @@ export class Noise implements INoiseConnection { await handshake.exchange(); await handshake.finish(); - if (this.useNoisePipes) { - KeyCache.store(remotePeer, handshake.getRemoteStaticKey()); + if (this.useNoisePipes && handshake.remotePeer) { + KeyCache.store(handshake.remotePeer, handshake.getRemoteStaticKey()); } } catch (e) { throw new Error(`Error occurred during XX handshake: ${e.message}`); diff --git a/src/utils.ts b/src/utils.ts index b225a14..d4e847d 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -5,10 +5,10 @@ import PeerId from "peer-id"; import * as crypto from 'libp2p-crypto'; import { KeyPair } from "./@types/libp2p"; import {bytes, bytes32} from "./@types/basic"; -import {Hkdf} from "./@types/handshake"; +import {Hkdf, INoisePayload} from "./@types/handshake"; import payloadProto from "./proto/payload.json"; -export async function loadPayloadProto () { +async function loadPayloadProto () { const payloadProtoBuf = await protobuf.Root.fromJSON(payloadProto); return payloadProtoBuf.lookupType("NoiseHandshakePayload"); } @@ -63,6 +63,17 @@ export async function signPayload(peerId: PeerId, payload: bytes): Promise { + return await PeerId.createFromPubKey(Buffer.from(payload.identityKey)); +} + +export async function decodePayload(payload: bytes): Promise { + const NoiseHandshakePayload = await loadPayloadProto(); + return NoiseHandshakePayload.toObject( + NoiseHandshakePayload.decode(payload) + ) as INoisePayload; +} + export const getHandshakePayload = (publicKey: bytes ) => Buffer.concat([Buffer.from("noise-libp2p-static-key:"), publicKey]); async function isValidPeerId(peerId: bytes, publicKeyProtobuf: bytes) { @@ -70,32 +81,40 @@ async function isValidPeerId(peerId: bytes, publicKeyProtobuf: bytes) { return generatedPeerId.id.equals(peerId); } -export async function verifySignedPayload(noiseStaticKey: bytes, plaintext: bytes, peerId: bytes) { - let receivedPayload; +/** + * Verifies signed payload, throws on any irregularities. + * @param {bytes} noiseStaticKey - owner's noise static key + * @param {bytes} payload - decoded payload + * @param {PeerId} remotePeer - owner's libp2p peer ID + * @returns {Promise} - peer ID of payload owner + */ +export async function verifySignedPayload( + noiseStaticKey: bytes, + payload: INoisePayload, + remotePeer: PeerId +): Promise { try { - const NoiseHandshakePayload = await loadPayloadProto(); - receivedPayload = NoiseHandshakePayload.toObject( - NoiseHandshakePayload.decode(plaintext) - ); //temporary fix until protobufsjs conversion options starts working //by default it ends up as Uint8Array - receivedPayload.identityKey = Buffer.from(receivedPayload.identityKey); - receivedPayload.identitySig = Buffer.from(receivedPayload.identitySig); + payload.identityKey = Buffer.from(payload.identityKey); + payload.identitySig = Buffer.from(payload.identitySig); } catch (e) { throw new Error("Failed to decode received payload. Reason: " + e.message); } - if (!(await isValidPeerId(peerId, receivedPayload.identityKey)) ) { + if (!(await isValidPeerId(remotePeer.id, payload.identityKey)) ) { throw new Error("Peer ID doesn't match libp2p public key."); } const generatedPayload = getHandshakePayload(noiseStaticKey); // Unmarshaling from PublicKey protobuf - const publicKey = crypto.keys.unmarshalPublicKey(receivedPayload.identityKey); - if (!publicKey.verify(generatedPayload, receivedPayload.identitySig)) { + const publicKey = crypto.keys.unmarshalPublicKey(payload.identityKey); + if (!publicKey.verify(generatedPayload, payload.identitySig)) { throw new Error("Static key doesn't match to peer that signed payload!"); } + + return remotePeer; } export function getHkdf(ck: bytes32, ikm: bytes): Hkdf { diff --git a/test/ik-handshake.test.ts b/test/ik-handshake.test.ts index 32779b9..3060f6a 100644 --- a/test/ik-handshake.test.ts +++ b/test/ik-handshake.test.ts @@ -25,10 +25,10 @@ describe("IK Handshake", () => { const staticKeysResponder = generateKeypair(); const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey); - const handshakeInit = new IKHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, peerB, staticKeysResponder.publicKey); + const handshakeInit = new IKHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, staticKeysResponder.publicKey, peerB); const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); - const handshakeResp = new IKHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA, staticKeysInitiator.publicKey); + const handshakeResp = new IKHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, staticKeysInitiator.publicKey); await handshakeInit.stage0(); await handshakeResp.stage0(); @@ -66,10 +66,10 @@ describe("IK Handshake", () => { const oldScammyKeys = generateKeypair(); const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey); - const handshakeInit = new IKHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, peerB, oldScammyKeys.publicKey); + const handshakeInit = new IKHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, oldScammyKeys.publicKey, peerB); const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); - const handshakeResp = new IKHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA, staticKeysInitiator.publicKey); + const handshakeResp = new IKHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, staticKeysInitiator.publicKey); await handshakeInit.stage0(); await handshakeResp.stage0(); diff --git a/test/noise.test.ts b/test/noise.test.ts index 1fe19df..dcee669 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -263,15 +263,16 @@ describe("Noise", () => { } }); - it("IK -> XX Fallback: responder has no remote static key", async() => { + it("IK: responder has no remote static key", async() => { try { const staticKeysInitiator = generateKeypair(); const noiseInit = new Noise(staticKeysInitiator.privateKey); const staticKeysResponder = generateKeypair(); const noiseResp = new Noise(staticKeysResponder.privateKey); + const ikInitSpy = sandbox.spy(noiseInit, "performIKHandshake"); const xxFallbackInitSpy = sandbox.spy(noiseInit, "performXXFallbackHandshake"); - const xxRespSpy = sandbox.spy(noiseResp, "performXXHandshake"); + const ikRespSpy = sandbox.spy(noiseResp, "performIKHandshake"); // Prepare key cache for noise pipes KeyCache.resetStorage(); @@ -291,8 +292,40 @@ describe("Noise", () => { const response = await wrappedInbound.readLP(); expect(response.toString()).equal("test fallback"); - assert(xxFallbackInitSpy.calledOnce, "XX Fallback method was not called."); - assert(xxRespSpy.calledOnce, "XX method was not called."); + assert(ikInitSpy.calledOnce, "IK handshake was not called."); + assert(ikRespSpy.calledOnce, "IK handshake was not called."); + assert(xxFallbackInitSpy.notCalled, "XX Fallback method was called."); + } catch (e) { + console.error(e); + assert(false, e.message); + } + }); + + it("should working without remote peer provided in incoming connection", async() => { + try { + const staticKeysInitiator = generateKeypair(); + const noiseInit = new Noise(staticKeysInitiator.privateKey); + const staticKeysResponder = generateKeypair(); + const noiseResp = new Noise(staticKeysResponder.privateKey); + + // Prepare key cache for noise pipes + KeyCache.store(localPeer, staticKeysInitiator.publicKey); + KeyCache.store(remotePeer, staticKeysResponder.publicKey); + + const [inboundConnection, outboundConnection] = DuplexPair(); + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection), + ]); + const wrappedInbound = Wrap(inbound.conn); + const wrappedOutbound = Wrap(outbound.conn); + + wrappedOutbound.writeLP(Buffer.from("test v2")); + const response = await wrappedInbound.readLP(); + expect(response.toString()).equal("test v2"); + + assert(inbound.remotePeer.marshalPubKey().equals(localPeer.marshalPubKey())); + assert(outbound.remotePeer.marshalPubKey().equals(remotePeer.marshalPubKey())); } catch (e) { console.error(e); assert(false, e.message); diff --git a/test/xx-fallback-handshake.test.ts b/test/xx-fallback-handshake.test.ts index ced9bf1..9f3802f 100644 --- a/test/xx-fallback-handshake.test.ts +++ b/test/xx-fallback-handshake.test.ts @@ -39,7 +39,7 @@ describe("XX Fallback Handshake", () => { const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); const handshakeResp = - new XXFallbackHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA, initialMsgR); + new XXFallbackHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, initialMsgR, peerA); await handshakeResp.propose(); await handshakeResp.exchange(); @@ -48,7 +48,7 @@ describe("XX Fallback Handshake", () => { // This is the point where initiator falls back from IK const initialMsgI = await connectionFrom.readLP(); const handshakeInit = - new XXFallbackHandshake(true, handshakePayload, prologue, staticKeysInitiator, connectionFrom, peerB, initialMsgI, ephemeralKeys); + new XXFallbackHandshake(true, handshakePayload, prologue, staticKeysInitiator, connectionFrom, initialMsgI, peerB, ephemeralKeys); await handshakeInit.propose(); await handshakeInit.exchange();