better ik flow

This commit is contained in:
Marin Petrunić
2020-02-07 20:21:27 +01:00
parent 43c901c1b8
commit 082b0c560a
10 changed files with 80 additions and 43 deletions

View File

@ -1,8 +1,10 @@
import {bytes} from "./basic"; import {bytes} from "./basic";
import {NoiseSession} from "./handshake"; import {NoiseSession} from "./handshake";
import PeerId from "peer-id";
export interface IHandshake { export interface IHandshake {
session: NoiseSession; session: NoiseSession;
remotePeer: PeerId;
encrypt(plaintext: bytes, session: NoiseSession): bytes; encrypt(plaintext: bytes, session: NoiseSession): bytes;
decrypt(ciphertext: bytes, session: NoiseSession): bytes; decrypt(ciphertext: bytes, session: NoiseSession): bytes;
} }

View File

@ -6,7 +6,7 @@ import {KeyPair} from "./@types/libp2p";
import {IHandshake} from "./@types/handshake-interface"; import {IHandshake} from "./@types/handshake-interface";
import {Buffer} from "buffer"; import {Buffer} from "buffer";
import {decode0, decode1, encode0, encode1} from "./encoder"; import {decode0, decode1, encode0, encode1} from "./encoder";
import {verifySignedPayload} from "./utils"; import {getPeerIdFromPayload, verifySignedPayload} from "./utils";
import {FailedIKError} from "./errors"; import {FailedIKError} from "./errors";
import {logger} from "./logger"; import {logger} from "./logger";
import PeerId from "peer-id"; import PeerId from "peer-id";
@ -14,12 +14,12 @@ import PeerId from "peer-id";
export class IKHandshake implements IHandshake { export class IKHandshake implements IHandshake {
public isInitiator: boolean; public isInitiator: boolean;
public session: NoiseSession; public session: NoiseSession;
public remotePeer!: PeerId;
private payload: bytes; private payload: bytes;
private prologue: bytes32; private prologue: bytes32;
private staticKeypair: KeyPair; private staticKeypair: KeyPair;
private connection: WrappedConnection; private connection: WrappedConnection;
private remotePeer: PeerId;
private ik: IK; private ik: IK;
constructor( constructor(
@ -28,8 +28,8 @@ export class IKHandshake implements IHandshake {
prologue: bytes32, prologue: bytes32,
staticKeypair: KeyPair, staticKeypair: KeyPair,
connection: WrappedConnection, connection: WrappedConnection,
remotePeer: PeerId,
remoteStaticKey: bytes, remoteStaticKey: bytes,
remotePeer?: PeerId,
handshake?: IK, handshake?: IK,
) { ) {
this.isInitiator = isInitiator; this.isInitiator = isInitiator;
@ -37,8 +37,9 @@ export class IKHandshake implements IHandshake {
this.prologue = prologue; this.prologue = prologue;
this.staticKeypair = staticKeypair; this.staticKeypair = staticKeypair;
this.connection = connection; this.connection = connection;
this.remotePeer = remotePeer; if(remotePeer) {
this.remotePeer = remotePeer;
}
this.ik = handshake || new IK(); this.ik = handshake || new IK();
this.session = this.ik.initSession(this.isInitiator, this.prologue, this.staticKeypair, remoteStaticKey); this.session = this.ik.initSession(this.isInitiator, this.prologue, this.staticKeypair, remoteStaticKey);
} }
@ -55,12 +56,13 @@ export class IKHandshake implements IHandshake {
try { try {
const receivedMessageBuffer = decode1(receivedMsg); const receivedMessageBuffer = decode1(receivedMsg);
const plaintext = this.ik.recvMessage(this.session, receivedMessageBuffer); const plaintext = this.ik.recvMessage(this.session, receivedMessageBuffer);
this.remotePeer = await getPeerIdFromPayload(plaintext);
logger("IK Stage 0 - Responder got message, going to verify payload."); logger("IK Stage 0 - Responder got message, going to verify payload.");
await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id);
logger("IK Stage 0 - Responder successfully verified payload!"); logger("IK Stage 0 - Responder successfully verified payload!");
} catch (e) { } catch (e) {
logger("Responder breaking up with IK handshake in stage 0."); logger("Responder breaking up with IK handshake in stage 0.");
throw new FailedIKError(receivedMsg, `Error occurred while verifying initiator's signed payload: ${e.message}`); throw new FailedIKError(receivedMsg, `Error occurred while verifying initiator's signed payload: ${e.message}`);
} }
} }

View File

@ -3,7 +3,7 @@ import {XXHandshake} from "./handshake-xx";
import {XX} from "./handshakes/xx"; import {XX} from "./handshakes/xx";
import {KeyPair} from "./@types/libp2p"; import {KeyPair} from "./@types/libp2p";
import {bytes, bytes32} from "./@types/basic"; import {bytes, bytes32} from "./@types/basic";
import {verifySignedPayload,} from "./utils"; import {getPeerIdFromPayload, verifySignedPayload,} from "./utils";
import {logger} from "./logger"; import {logger} from "./logger";
import {WrappedConnection} from "./noise"; import {WrappedConnection} from "./noise";
import {decode0, decode1} from "./encoder"; import {decode0, decode1} from "./encoder";
@ -19,8 +19,8 @@ export class XXFallbackHandshake extends XXHandshake {
prologue: bytes32, prologue: bytes32,
staticKeypair: KeyPair, staticKeypair: KeyPair,
connection: WrappedConnection, connection: WrappedConnection,
remotePeer: PeerId,
initialMsg: bytes, initialMsg: bytes,
remotePeer?: PeerId,
ephemeralKeys?: KeyPair, ephemeralKeys?: KeyPair,
handshake?: XX, handshake?: XX,
) { ) {
@ -57,6 +57,7 @@ export class XXFallbackHandshake extends XXHandshake {
logger("Initiator going to check remote's signature..."); logger("Initiator going to check remote's signature...");
try { try {
this.remotePeer = await getPeerIdFromPayload(plaintext);
await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id);
} catch (e) { } catch (e) {
throw new Error(`Error occurred while verifying signed payload from responder: ${e.message}`); throw new Error(`Error occurred while verifying signed payload from responder: ${e.message}`);

View File

@ -6,6 +6,7 @@ import { bytes, bytes32 } from "./@types/basic";
import { NoiseSession } from "./@types/handshake"; import { NoiseSession } from "./@types/handshake";
import {IHandshake} from "./@types/handshake-interface"; import {IHandshake} from "./@types/handshake-interface";
import { import {
getPeerIdFromPayload,
verifySignedPayload, verifySignedPayload,
} from "./utils"; } from "./utils";
import { logger } from "./logger"; import { logger } from "./logger";
@ -16,12 +17,12 @@ import PeerId from "peer-id";
export class XXHandshake implements IHandshake { export class XXHandshake implements IHandshake {
public isInitiator: boolean; public isInitiator: boolean;
public session: NoiseSession; public session: NoiseSession;
public remotePeer!: PeerId;
protected payload: bytes; protected payload: bytes;
protected connection: WrappedConnection; protected connection: WrappedConnection;
protected xx: XX; protected xx: XX;
protected staticKeypair: KeyPair; protected staticKeypair: KeyPair;
protected remotePeer: PeerId;
private prologue: bytes32; private prologue: bytes32;
@ -31,7 +32,7 @@ export class XXHandshake implements IHandshake {
prologue: bytes32, prologue: bytes32,
staticKeypair: KeyPair, staticKeypair: KeyPair,
connection: WrappedConnection, connection: WrappedConnection,
remotePeer: PeerId, remotePeer?: PeerId,
handshake?: XX, handshake?: XX,
) { ) {
this.isInitiator = isInitiator; this.isInitiator = isInitiator;
@ -39,8 +40,9 @@ export class XXHandshake implements IHandshake {
this.prologue = prologue; this.prologue = prologue;
this.staticKeypair = staticKeypair; this.staticKeypair = staticKeypair;
this.connection = connection; this.connection = connection;
this.remotePeer = remotePeer; if(remotePeer) {
this.remotePeer = remotePeer;
}
this.xx = handshake || new XX(); this.xx = handshake || new XX();
this.session = this.xx.initSession(this.isInitiator, this.prologue, this.staticKeypair); this.session = this.xx.initSession(this.isInitiator, this.prologue, this.staticKeypair);
} }
@ -70,6 +72,7 @@ export class XXHandshake implements IHandshake {
logger("Initiator going to check remote's signature..."); logger("Initiator going to check remote's signature...");
try { try {
this.remotePeer = await getPeerIdFromPayload(plaintext);
await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id);
} catch (e) { } catch (e) {
throw new Error(`Error occurred while verifying signed payload: ${e.message}`); throw new Error(`Error occurred while verifying signed payload: ${e.message}`);
@ -94,6 +97,7 @@ export class XXHandshake implements IHandshake {
logger('Stage 2 - Responder waiting for third handshake message...'); logger('Stage 2 - Responder waiting for third handshake message...');
const receivedMessageBuffer = decode1(await this.connection.readLP()); const receivedMessageBuffer = decode1(await this.connection.readLP());
const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer);
this.remotePeer = await getPeerIdFromPayload(plaintext);
logger('Stage 2 - Responder received the message, finished handshake. Got remote\'s static key.'); logger('Stage 2 - Responder received the message, finished handshake. Got remote\'s static key.');
try { try {

View File

@ -7,12 +7,16 @@ import PeerId from "peer-id";
class Keycache { class Keycache {
private storage = new Map<bytes, bytes32>(); private storage = new Map<bytes, bytes32>();
public store(peerId: PeerId, key: bytes32): void { public store(peerId?: PeerId, key: bytes32): void {
if(!peerId) return;
this.storage.set(peerId.id, key); this.storage.set(peerId.id, key);
} }
public load(peerId: PeerId): bytes32|undefined { public load(peerId?: PeerId): bytes32 | null {
return this.storage.get(peerId.id); if(!peerId) {
return null;
}
return this.storage.get(peerId.id) || null;
} }
public resetStorage(): void { public resetStorage(): void {

View File

@ -1,20 +1,20 @@
import { x25519 } from 'bcrypto'; import {x25519} from 'bcrypto';
import { Buffer } from "buffer"; import {Buffer} from "buffer";
import Wrap from 'it-pb-rpc'; import Wrap from 'it-pb-rpc';
import DuplexPair from 'it-pair/duplex'; import DuplexPair from 'it-pair/duplex';
import ensureBuffer from 'it-buffer'; import ensureBuffer from 'it-buffer';
import pipe from 'it-pipe'; import pipe from 'it-pipe';
import lp from 'it-length-prefixed'; import lp from 'it-length-prefixed';
import { XXHandshake } from "./handshake-xx"; import {XXHandshake} from "./handshake-xx";
import { IKHandshake } from "./handshake-ik"; import {IKHandshake} from "./handshake-ik";
import { XXFallbackHandshake } from "./handshake-xx-fallback"; import {XXFallbackHandshake} from "./handshake-xx-fallback";
import { generateKeypair, getPayload } from "./utils"; import {generateKeypair, getPayload} from "./utils";
import { uint16BEDecode, uint16BEEncode } from "./encoder"; import {uint16BEDecode, uint16BEEncode} from "./encoder";
import { decryptStream, encryptStream } from "./crypto"; import {decryptStream, encryptStream} from "./crypto";
import { bytes } from "./@types/basic"; import {bytes, bytes32} from "./@types/basic";
import { INoiseConnection, KeyPair, SecureOutbound } from "./@types/libp2p"; import {INoiseConnection, KeyPair, SecureOutbound} from "./@types/libp2p";
import { Duplex } from "./@types/it-pair"; import {Duplex} from "./@types/it-pair";
import {IHandshake} from "./@types/handshake-interface"; import {IHandshake} from "./@types/handshake-interface";
import {KeyCache} from "./keycache"; import {KeyCache} from "./keycache";
import {logger} from "./logger"; import {logger} from "./logger";
@ -26,7 +26,7 @@ type HandshakeParams = {
connection: WrappedConnection; connection: WrappedConnection;
isInitiator: boolean; isInitiator: boolean;
localPeer: PeerId; localPeer: PeerId;
remotePeer: PeerId; remotePeer?: PeerId;
}; };
export class Noise implements INoiseConnection { export class Noise implements INoiseConnection {
@ -88,7 +88,7 @@ export class Noise implements INoiseConnection {
* @param {PeerId} remotePeer - optional PeerId of the initiating peer, if known. This may only exist during transport upgrades. * @param {PeerId} remotePeer - optional PeerId of the initiating peer, if known. This may only exist during transport upgrades.
* @returns {Promise<SecureOutbound>} * @returns {Promise<SecureOutbound>}
*/ */
public async secureInbound(localPeer: PeerId, connection: any, remotePeer: PeerId): Promise<SecureOutbound> { public async secureInbound(localPeer: PeerId, connection: any, remotePeer?: PeerId): Promise<SecureOutbound> {
const wrappedConnection = Wrap(connection); const wrappedConnection = Wrap(connection);
const handshake = await this.performHandshake({ const handshake = await this.performHandshake({
connection: wrappedConnection, connection: wrappedConnection,
@ -100,7 +100,7 @@ export class Noise implements INoiseConnection {
return { return {
conn, conn,
remotePeer, remotePeer: handshake.remotePeer
}; };
} }
@ -111,26 +111,38 @@ export class Noise implements INoiseConnection {
*/ */
private async performHandshake(params: HandshakeParams): Promise<IHandshake> { private async performHandshake(params: HandshakeParams): Promise<IHandshake> {
const payload = await getPayload(params.localPeer, this.staticKeys.publicKey, this.earlyData); const payload = await getPayload(params.localPeer, this.staticKeys.publicKey, this.earlyData);
let tryIK = this.useNoisePipes;
const remoteStaticKey = KeyCache.load(params.remotePeer); if(params.isInitiator && KeyCache.load(params.remotePeer) === null) {
//if we are initiator and remote static key is unknown, don't try IK
tryIK = false;
}
// Try IK if acting as responder or initiator that has remote's static key. // Try IK if acting as responder or initiator that has remote's static key.
if (this.useNoisePipes && remoteStaticKey) { if (tryIK) {
// Try IK first // Try IK first
const { remotePeer, connection, isInitiator } = params; const { remotePeer, connection, isInitiator } = params;
const ikHandshake = new IKHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer, remoteStaticKey); const ikHandshake = new IKHandshake(
isInitiator,
payload,
this.prologue,
this.staticKeys,
connection,
//safe to cast as we did checks
KeyCache.load(params.remotePeer) || Buffer.alloc(32),
remotePeer as PeerId,
);
try { try {
return await this.performIKHandshake(ikHandshake); return await this.performIKHandshake(ikHandshake);
} catch (e) { } catch (e) {
// IK failed, go to XX fallback // IK failed, go to XX fallback
let ephemeralKeys; let ephemeralKeys;
if (params.isInitiator) { if (!params.isInitiator) {
ephemeralKeys = ikHandshake.getRemoteEphemeralKeys(); ephemeralKeys = ikHandshake.getRemoteEphemeralKeys();
} }
return await this.performXXFallbackHandshake(params, payload, e.initialMsg, ephemeralKeys); return await this.performXXFallbackHandshake(params, payload, e.initialMsg, ephemeralKeys);
} }
} else { } else {
// Noise pipes not supported, use XX // run XX handshake
return await this.performXXHandshake(params, payload); return await this.performXXHandshake(params, payload);
} }
} }
@ -143,7 +155,7 @@ export class Noise implements INoiseConnection {
): Promise<XXFallbackHandshake> { ): Promise<XXFallbackHandshake> {
const { isInitiator, remotePeer, connection } = params; const { isInitiator, remotePeer, connection } = params;
const handshake = const handshake =
new XXFallbackHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer, initialMsg, ephemeralKeys); new XXFallbackHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, initialMsg, remotePeer, ephemeralKeys);
try { try {
await handshake.propose(); await handshake.propose();

View File

@ -63,6 +63,18 @@ export async function signPayload(peerId: PeerId, payload: bytes): Promise<bytes
return peerId.privKey.sign(payload); return peerId.privKey.sign(payload);
} }
export async function getPeerIdFromPayload(payload: bytes): Promise<PeerId> {
const decodedPayload = await decodePayload(payload);
return await PeerId.createFromPubKey(Buffer.from(decodedPayload.identityKey));
}
async function decodePayload(payload: bytes){
const NoiseHandshakePayload = await loadPayloadProto();
return NoiseHandshakePayload.toObject(
NoiseHandshakePayload.decode(payload)
);
}
export const getHandshakePayload = (publicKey: bytes ) => Buffer.concat([Buffer.from("noise-libp2p-static-key:"), publicKey]); export const getHandshakePayload = (publicKey: bytes ) => Buffer.concat([Buffer.from("noise-libp2p-static-key:"), publicKey]);
async function isValidPeerId(peerId: bytes, publicKeyProtobuf: bytes) { async function isValidPeerId(peerId: bytes, publicKeyProtobuf: bytes) {

View File

@ -25,10 +25,10 @@ describe("IK Handshake", () => {
const staticKeysResponder = generateKeypair(); const staticKeysResponder = generateKeypair();
const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey); const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey);
const handshakeInit = new IKHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, peerB, staticKeysResponder.publicKey); const handshakeInit = new IKHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, staticKeysResponder.publicKey, peerB);
const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); const respPayload = await getPayload(peerB, staticKeysResponder.publicKey);
const handshakeResp = new IKHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA, staticKeysInitiator.publicKey); const handshakeResp = new IKHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, staticKeysInitiator.publicKey);
await handshakeInit.stage0(); await handshakeInit.stage0();
await handshakeResp.stage0(); await handshakeResp.stage0();
@ -66,10 +66,10 @@ describe("IK Handshake", () => {
const oldScammyKeys = generateKeypair(); const oldScammyKeys = generateKeypair();
const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey); const initPayload = await getPayload(peerA, staticKeysInitiator.publicKey);
const handshakeInit = new IKHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, peerB, oldScammyKeys.publicKey); const handshakeInit = new IKHandshake(true, initPayload, prologue, staticKeysInitiator, connectionFrom, oldScammyKeys.publicKey, peerB);
const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); const respPayload = await getPayload(peerB, staticKeysResponder.publicKey);
const handshakeResp = new IKHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA, staticKeysInitiator.publicKey); const handshakeResp = new IKHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, staticKeysInitiator.publicKey);
await handshakeInit.stage0(); await handshakeInit.stage0();
await handshakeResp.stage0(); await handshakeResp.stage0();

View File

@ -275,7 +275,7 @@ describe("Noise", () => {
// Prepare key cache for noise pipes // Prepare key cache for noise pipes
KeyCache.resetStorage(); KeyCache.resetStorage();
KeyCache.store(remotePeer, staticKeysResponder.publicKey); KeyCache.store(localPeer, staticKeysResponder.publicKey);
const [inboundConnection, outboundConnection] = DuplexPair(); const [inboundConnection, outboundConnection] = DuplexPair();

View File

@ -39,7 +39,7 @@ describe("XX Fallback Handshake", () => {
const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); const respPayload = await getPayload(peerB, staticKeysResponder.publicKey);
const handshakeResp = const handshakeResp =
new XXFallbackHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA, initialMsgR); new XXFallbackHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, initialMsgR, peerA);
await handshakeResp.propose(); await handshakeResp.propose();
await handshakeResp.exchange(); await handshakeResp.exchange();
@ -48,7 +48,7 @@ describe("XX Fallback Handshake", () => {
// This is the point where initiator falls back from IK // This is the point where initiator falls back from IK
const initialMsgI = await connectionFrom.readLP(); const initialMsgI = await connectionFrom.readLP();
const handshakeInit = const handshakeInit =
new XXFallbackHandshake(true, handshakePayload, prologue, staticKeysInitiator, connectionFrom, peerB, initialMsgI, ephemeralKeys); new XXFallbackHandshake(true, handshakePayload, prologue, staticKeysInitiator, connectionFrom, initialMsgI, peerB, ephemeralKeys);
await handshakeInit.propose(); await handshakeInit.propose();
await handshakeInit.exchange(); await handshakeInit.exchange();