diff --git a/src/@types/handshake-interface.ts b/src/@types/handshake-interface.ts index d92b9ff..6098f23 100644 --- a/src/@types/handshake-interface.ts +++ b/src/@types/handshake-interface.ts @@ -1,8 +1,10 @@ import {bytes} from "./basic"; import {NoiseSession} from "./handshake"; +import PeerId from "peer-id"; export interface IHandshake { session: NoiseSession; + remotePeer: PeerId; encrypt(plaintext: bytes, session: NoiseSession): bytes; decrypt(ciphertext: bytes, session: NoiseSession): bytes; } diff --git a/src/handshake-ik.ts b/src/handshake-ik.ts index e2bfb35..ca7ea9a 100644 --- a/src/handshake-ik.ts +++ b/src/handshake-ik.ts @@ -6,7 +6,7 @@ import {KeyPair} from "./@types/libp2p"; import {IHandshake} from "./@types/handshake-interface"; import {Buffer} from "buffer"; import {decode0, decode1, encode0, encode1} from "./encoder"; -import {verifySignedPayload} from "./utils"; +import {getPeerIdFromPayload, verifySignedPayload} from "./utils"; import {FailedIKError} from "./errors"; import {logger} from "./logger"; import PeerId from "peer-id"; @@ -14,12 +14,12 @@ import PeerId from "peer-id"; export class IKHandshake implements IHandshake { public isInitiator: boolean; public session: NoiseSession; + public remotePeer!: PeerId; private payload: bytes; private prologue: bytes32; private staticKeypair: KeyPair; private connection: WrappedConnection; - private remotePeer: PeerId; private ik: IK; constructor( @@ -28,8 +28,8 @@ export class IKHandshake implements IHandshake { prologue: bytes32, staticKeypair: KeyPair, connection: WrappedConnection, - remotePeer: PeerId, remoteStaticKey: bytes, + remotePeer?: PeerId, handshake?: IK, ) { this.isInitiator = isInitiator; @@ -37,8 +37,9 @@ export class IKHandshake implements IHandshake { this.prologue = prologue; this.staticKeypair = staticKeypair; this.connection = connection; - this.remotePeer = remotePeer; - + if(remotePeer) { + this.remotePeer = remotePeer; + } this.ik = handshake || new IK(); this.session = this.ik.initSession(this.isInitiator, this.prologue, this.staticKeypair, remoteStaticKey); } @@ -55,12 +56,13 @@ export class IKHandshake implements IHandshake { try { const receivedMessageBuffer = decode1(receivedMsg); const plaintext = this.ik.recvMessage(this.session, receivedMessageBuffer); - + this.remotePeer = await getPeerIdFromPayload(plaintext); logger("IK Stage 0 - Responder got message, going to verify payload."); await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); logger("IK Stage 0 - Responder successfully verified payload!"); } catch (e) { logger("Responder breaking up with IK handshake in stage 0."); + throw new FailedIKError(receivedMsg, `Error occurred while verifying initiator's signed payload: ${e.message}`); } } diff --git a/src/handshake-xx-fallback.ts b/src/handshake-xx-fallback.ts index 301fc3e..21f8272 100644 --- a/src/handshake-xx-fallback.ts +++ b/src/handshake-xx-fallback.ts @@ -3,7 +3,7 @@ import {XXHandshake} from "./handshake-xx"; import {XX} from "./handshakes/xx"; import {KeyPair} from "./@types/libp2p"; import {bytes, bytes32} from "./@types/basic"; -import {verifySignedPayload,} from "./utils"; +import {getPeerIdFromPayload, verifySignedPayload,} from "./utils"; import {logger} from "./logger"; import {WrappedConnection} from "./noise"; import {decode0, decode1} from "./encoder"; @@ -19,8 +19,8 @@ export class XXFallbackHandshake extends XXHandshake { prologue: bytes32, staticKeypair: KeyPair, connection: WrappedConnection, - remotePeer: PeerId, initialMsg: bytes, + remotePeer?: PeerId, ephemeralKeys?: KeyPair, handshake?: XX, ) { @@ -57,6 +57,7 @@ export class XXFallbackHandshake extends XXHandshake { logger("Initiator going to check remote's signature..."); try { + this.remotePeer = await getPeerIdFromPayload(plaintext); await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); } catch (e) { throw new Error(`Error occurred while verifying signed payload from responder: ${e.message}`); diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index e6d0aff..8dba492 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -6,6 +6,7 @@ import { bytes, bytes32 } from "./@types/basic"; import { NoiseSession } from "./@types/handshake"; import {IHandshake} from "./@types/handshake-interface"; import { + getPeerIdFromPayload, verifySignedPayload, } from "./utils"; import { logger } from "./logger"; @@ -16,12 +17,12 @@ import PeerId from "peer-id"; export class XXHandshake implements IHandshake { public isInitiator: boolean; public session: NoiseSession; + public remotePeer!: PeerId; protected payload: bytes; protected connection: WrappedConnection; protected xx: XX; protected staticKeypair: KeyPair; - protected remotePeer: PeerId; private prologue: bytes32; @@ -31,7 +32,7 @@ export class XXHandshake implements IHandshake { prologue: bytes32, staticKeypair: KeyPair, connection: WrappedConnection, - remotePeer: PeerId, + remotePeer?: PeerId, handshake?: XX, ) { this.isInitiator = isInitiator; @@ -39,8 +40,9 @@ export class XXHandshake implements IHandshake { this.prologue = prologue; this.staticKeypair = staticKeypair; this.connection = connection; - this.remotePeer = remotePeer; - + if(remotePeer) { + this.remotePeer = remotePeer; + } this.xx = handshake || new XX(); this.session = this.xx.initSession(this.isInitiator, this.prologue, this.staticKeypair); } @@ -70,6 +72,7 @@ export class XXHandshake implements IHandshake { logger("Initiator going to check remote's signature..."); try { + this.remotePeer = await getPeerIdFromPayload(plaintext); await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); } catch (e) { throw new Error(`Error occurred while verifying signed payload: ${e.message}`); @@ -94,6 +97,7 @@ export class XXHandshake implements IHandshake { logger('Stage 2 - Responder waiting for third handshake message...'); const receivedMessageBuffer = decode1(await this.connection.readLP()); const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); + this.remotePeer = await getPeerIdFromPayload(plaintext); logger('Stage 2 - Responder received the message, finished handshake. Got remote\'s static key.'); try { diff --git a/src/keycache.ts b/src/keycache.ts index d095d80..4d72e82 100644 --- a/src/keycache.ts +++ b/src/keycache.ts @@ -7,12 +7,16 @@ import PeerId from "peer-id"; class Keycache { private storage = new Map(); - public store(peerId: PeerId, key: bytes32): void { + public store(peerId?: PeerId, key: bytes32): void { + if(!peerId) return; this.storage.set(peerId.id, key); } - public load(peerId: PeerId): bytes32|undefined { - return this.storage.get(peerId.id); + public load(peerId?: PeerId): bytes32 | null { + if(!peerId) { + return null; + } + return this.storage.get(peerId.id) || null; } public resetStorage(): void { diff --git a/src/noise.ts b/src/noise.ts index e089465..04b29fe 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -1,20 +1,20 @@ -import { x25519 } from 'bcrypto'; -import { Buffer } from "buffer"; +import {x25519} from 'bcrypto'; +import {Buffer} from "buffer"; import Wrap from 'it-pb-rpc'; import DuplexPair from 'it-pair/duplex'; import ensureBuffer from 'it-buffer'; import pipe from 'it-pipe'; import lp from 'it-length-prefixed'; -import { XXHandshake } from "./handshake-xx"; -import { IKHandshake } from "./handshake-ik"; -import { XXFallbackHandshake } from "./handshake-xx-fallback"; -import { generateKeypair, getPayload } from "./utils"; -import { uint16BEDecode, uint16BEEncode } from "./encoder"; -import { decryptStream, encryptStream } from "./crypto"; -import { bytes } from "./@types/basic"; -import { INoiseConnection, KeyPair, SecureOutbound } from "./@types/libp2p"; -import { Duplex } from "./@types/it-pair"; +import {XXHandshake} from "./handshake-xx"; +import {IKHandshake} from "./handshake-ik"; +import {XXFallbackHandshake} from "./handshake-xx-fallback"; +import {generateKeypair, getPayload} from "./utils"; +import {uint16BEDecode, uint16BEEncode} from "./encoder"; +import {decryptStream, encryptStream} from "./crypto"; +import {bytes, bytes32} from "./@types/basic"; +import {INoiseConnection, KeyPair, SecureOutbound} from "./@types/libp2p"; +import {Duplex} from "./@types/it-pair"; import {IHandshake} from "./@types/handshake-interface"; import {KeyCache} from "./keycache"; import {logger} from "./logger"; @@ -26,7 +26,7 @@ type HandshakeParams = { connection: WrappedConnection; isInitiator: boolean; localPeer: PeerId; - remotePeer: PeerId; + remotePeer?: PeerId; }; export class Noise implements INoiseConnection { @@ -88,7 +88,7 @@ export class Noise implements INoiseConnection { * @param {PeerId} remotePeer - optional PeerId of the initiating peer, if known. This may only exist during transport upgrades. * @returns {Promise} */ - public async secureInbound(localPeer: PeerId, connection: any, remotePeer: PeerId): Promise { + public async secureInbound(localPeer: PeerId, connection: any, remotePeer?: PeerId): Promise { const wrappedConnection = Wrap(connection); const handshake = await this.performHandshake({ connection: wrappedConnection, @@ -100,7 +100,7 @@ export class Noise implements INoiseConnection { return { conn, - remotePeer, + remotePeer: handshake.remotePeer }; } @@ -111,26 +111,38 @@ export class Noise implements INoiseConnection { */ private async performHandshake(params: HandshakeParams): Promise { const payload = await getPayload(params.localPeer, this.staticKeys.publicKey, this.earlyData); - - const remoteStaticKey = KeyCache.load(params.remotePeer); + let tryIK = this.useNoisePipes; + if(params.isInitiator && KeyCache.load(params.remotePeer) === null) { + //if we are initiator and remote static key is unknown, don't try IK + tryIK = false; + } // Try IK if acting as responder or initiator that has remote's static key. - if (this.useNoisePipes && remoteStaticKey) { + if (tryIK) { // Try IK first const { remotePeer, connection, isInitiator } = params; - const ikHandshake = new IKHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer, remoteStaticKey); + const ikHandshake = new IKHandshake( + isInitiator, + payload, + this.prologue, + this.staticKeys, + connection, + //safe to cast as we did checks + KeyCache.load(params.remotePeer) || Buffer.alloc(32), + remotePeer as PeerId, + ); try { return await this.performIKHandshake(ikHandshake); } catch (e) { // IK failed, go to XX fallback let ephemeralKeys; - if (params.isInitiator) { + if (!params.isInitiator) { ephemeralKeys = ikHandshake.getRemoteEphemeralKeys(); } return await this.performXXFallbackHandshake(params, payload, e.initialMsg, ephemeralKeys); } } else { - // Noise pipes not supported, use XX + // run XX handshake return await this.performXXHandshake(params, payload); } } @@ -143,7 +155,7 @@ export class Noise implements INoiseConnection { ): Promise { const { isInitiator, remotePeer, connection } = params; const handshake = - new XXFallbackHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer, initialMsg, ephemeralKeys); + new XXFallbackHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, initialMsg, remotePeer, ephemeralKeys); try { await handshake.propose(); diff --git a/src/utils.ts b/src/utils.ts index b225a14..8b27f15 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -63,6 +63,18 @@ export async function signPayload(peerId: PeerId, payload: bytes): Promise { + const decodedPayload = await decodePayload(payload); + return await PeerId.createFromPubKey(Buffer.from(decodedPayload.identityKey)); +} + +async function decodePayload(payload: bytes){ + const NoiseHandshakePayload = await loadPayloadProto(); + return NoiseHandshakePayload.toObject( + NoiseHandshakePayload.decode(payload) + ); +} + export const getHandshakePayload = (publicKey: bytes ) => Buffer.concat([Buffer.from("noise-libp2p-static-key:"), publicKey]); async function isValidPeerId(peerId: bytes, publicKeyProtobuf: bytes) { diff --git a/test/ik-handshake.test.ts b/test/ik-handshake.test.ts index 32779b9..3060f6a 100644 --- a/test/ik-handshake.test.ts +++ b/test/ik-handshake.test.ts @@ -25,10 +25,10 @@ describe("IK Handshake", () => { const staticKeysResponder = generateKeypair(); const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey); - const handshakeInit = new IKHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, peerB, staticKeysResponder.publicKey); + const handshakeInit = new IKHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, staticKeysResponder.publicKey, peerB); const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); - const handshakeResp = new IKHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA, staticKeysInitiator.publicKey); + const handshakeResp = new IKHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, staticKeysInitiator.publicKey); await handshakeInit.stage0(); await handshakeResp.stage0(); @@ -66,10 +66,10 @@ describe("IK Handshake", () => { const oldScammyKeys = generateKeypair(); const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey); - const handshakeInit = new IKHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, peerB, oldScammyKeys.publicKey); + const handshakeInit = new IKHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, oldScammyKeys.publicKey, peerB); const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); - const handshakeResp = new IKHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA, staticKeysInitiator.publicKey); + const handshakeResp = new IKHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, staticKeysInitiator.publicKey); await handshakeInit.stage0(); await handshakeResp.stage0(); diff --git a/test/noise.test.ts b/test/noise.test.ts index 1fe19df..38a23e4 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -275,7 +275,7 @@ describe("Noise", () => { // Prepare key cache for noise pipes KeyCache.resetStorage(); - KeyCache.store(remotePeer, staticKeysResponder.publicKey); + KeyCache.store(localPeer, staticKeysResponder.publicKey); const [inboundConnection, outboundConnection] = DuplexPair(); diff --git a/test/xx-fallback-handshake.test.ts b/test/xx-fallback-handshake.test.ts index ced9bf1..9f3802f 100644 --- a/test/xx-fallback-handshake.test.ts +++ b/test/xx-fallback-handshake.test.ts @@ -39,7 +39,7 @@ describe("XX Fallback Handshake", () => { const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); const handshakeResp = - new XXFallbackHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA, initialMsgR); + new XXFallbackHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, initialMsgR, peerA); await handshakeResp.propose(); await handshakeResp.exchange(); @@ -48,7 +48,7 @@ describe("XX Fallback Handshake", () => { // This is the point where initiator falls back from IK const initialMsgI = await connectionFrom.readLP(); const handshakeInit = - new XXFallbackHandshake(true, handshakePayload, prologue, staticKeysInitiator, connectionFrom, peerB, initialMsgI, ephemeralKeys); + new XXFallbackHandshake(true, handshakePayload, prologue, staticKeysInitiator, connectionFrom, initialMsgI, peerB, ephemeralKeys); await handshakeInit.propose(); await handshakeInit.exchange();