diff --git a/src/handshake-ik.ts b/src/handshake-ik.ts index da001bc..e5b8aa9 100644 --- a/src/handshake-ik.ts +++ b/src/handshake-ik.ts @@ -1,10 +1,12 @@ +import {WrappedConnection} from "./noise"; +import {IKHandshake} from "./handshakes/ik"; import {NoiseSession} from "./@types/handshake"; import {bytes, bytes32} from "./@types/basic"; import {KeyPair, PeerId} from "./@types/libp2p"; -import {WrappedConnection} from "./noise"; -import {IKHandshake} from "./handshakes/ik"; +import {HandshakeInterface} from "./@types/handshake-interface"; +import {Buffer} from "buffer"; -export class Handshake { // implements HandshakeHandler +export class Handshake implements HandshakeInterface { public isInitiator: boolean; public session: NoiseSession; @@ -41,4 +43,30 @@ export class Handshake { // implements HandshakeHandler const remoteStaticKeys = this.staticKeys; this.session = this.ik.initSession(this.isInitiator, this.prologue, this.staticKeys, remoteStaticKeys.publicKey); } + + public decrypt(ciphertext: Buffer, session: NoiseSession): Buffer { + const cs = this.getCS(session, false); + return this.ik.decryptWithAd(cs, Buffer.alloc(0), ciphertext); + } + + public encrypt(plaintext: Buffer, session: NoiseSession): Buffer { + const cs = this.getCS(session); + return this.ik.encryptWithAd(cs, Buffer.alloc(0), plaintext); + } + + public getRemoteEphemeralKeys(): KeyPair | undefined { + return this.session.hs.e; + } + + private getCS(session: NoiseSession, encryption = true) { + if (!session.cs1 || !session.cs2) { + throw new Error("Handshake not completed properly, cipher state does not exist."); + } + + if (this.isInitiator) { + return encryption ? session.cs1 : session.cs2; + } else { + return encryption ? session.cs2 : session.cs1; + } + } } diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index 3500077..19e06a3 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -36,6 +36,7 @@ export class Handshake implements HandshakeInterface { staticKeys: KeyPair, connection: WrappedConnection, remotePeer: PeerId, + ephemeralKeys?: KeyPair, handshake?: XXHandshake, ) { this.isInitiator = isInitiator; diff --git a/src/handshakes/xx.ts b/src/handshakes/xx.ts index 3b8b0a9..7445f6a 100644 --- a/src/handshakes/xx.ts +++ b/src/handshakes/xx.ts @@ -27,9 +27,14 @@ export class XXHandshake extends AbstractHandshake { return { ss, s, rs, psk, re }; } - private writeMessageA(hs: HandshakeState, payload: bytes): MessageBuffer { + private writeMessageA(hs: HandshakeState, payload: bytes, e?: KeyPair): MessageBuffer { const ns = Buffer.alloc(0); - hs.e = generateKeypair(); + + if (e) { + hs.e = e; + } else { + hs.e = generateKeypair(); + } const ne = hs.e.publicKey; @@ -128,10 +133,10 @@ export class XXHandshake extends AbstractHandshake { }; } - public sendMessage(session: NoiseSession, message: bytes): MessageBuffer { + public sendMessage(session: NoiseSession, message: bytes, ephemeral?: KeyPair): MessageBuffer { let messageBuffer: MessageBuffer; if (session.mc.eqn(0)) { - messageBuffer = this.writeMessageA(session.hs, message); + messageBuffer = this.writeMessageA(session.hs, message, ephemeral); } else if (session.mc.eqn(1)) { messageBuffer = this.writeMessageB(session.hs, message); } else if (session.mc.eqn(2)) { diff --git a/src/noise.ts b/src/noise.ts index 9e06718..2451539 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -7,6 +7,7 @@ import pipe from 'it-pipe'; import lp from 'it-length-prefixed'; import { Handshake as XX } from "./handshake-xx"; +import { Handshake as IK } from "./handshake-ik"; import { generateKeypair } from "./utils"; import { uint16BEDecode, uint16BEEncode } from "./encoder"; import { decryptStream, encryptStream } from "./crypto"; @@ -21,6 +22,7 @@ export type WrappedConnection = ReturnType; export class Noise implements NoiseConnection { public protocol = "/noise"; + private readonly prologue = Buffer.from(this.protocol); private readonly privateKey: bytes; private readonly staticKeys: KeyPair; private readonly earlyData?: bytes; @@ -92,21 +94,29 @@ export class Noise implements NoiseConnection { libp2pPublicKey: bytes, remotePeer: PeerId, ): Promise { + // TODO: Implement noise pipes - if (false) { + const IKhandshake = new IK(isInitiator, this.privateKey, libp2pPublicKey, this.prologue, this.staticKeys, connection, remotePeer); + if(true) { + // XX fallback + const ephemeralKeys = IKhandshake.getRemoteEphemeralKeys(); + return await this.performXXHandshake(connection, isInitiator, libp2pPublicKey, remotePeer, ephemeralKeys); } else { - const prologue = Buffer.from(this.protocol); - const handshake = new XX(isInitiator, this.privateKey, libp2pPublicKey, prologue, this.staticKeys, connection, remotePeer); - - return await this.performXXHandshake(handshake); + return await this.performXXHandshake(connection, isInitiator, libp2pPublicKey, remotePeer); } } private async performXXHandshake( - handshake: XX + connection: WrappedConnection, + isInitiator: boolean, + libp2pPublicKey: bytes, + remotePeer: PeerId, + ephemeralKeys?: KeyPair, ): Promise { + const handshake = new XX(isInitiator, this.privateKey, libp2pPublicKey, this.prologue, this.staticKeys, connection, remotePeer, ephemeralKeys); + try { await handshake.propose(); await handshake.exchange(); diff --git a/test/noise.test.ts b/test/noise.test.ts index 9e8ded3..230c7e8 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -61,7 +61,7 @@ describe("Noise", () => { const xx = new XXHandshake(); const { privateKey: libp2pPrivKey, publicKey: libp2pPubKey } = getKeyPairFromPeerId(remotePeer); - const handshake = new Handshake(false, libp2pPrivKey, libp2pPubKey, prologue, staticKeys, wrapped, localPeer, xx); + const handshake = new Handshake(false, libp2pPrivKey, libp2pPubKey, prologue, staticKeys, wrapped, localPeer, undefined, xx); let receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); // The first handshake message contains the initiator's ephemeral public key