import { bytes32, bytes16, uint32, uint64, bytes } from './types/basic';
import { Buffer } from 'buffer';
import * as crypto from 'libp2p-crypto';
import { AEAD, x25519, HKDF, SHA256 } from 'bcrypto';

/** A Curve25519 (X25519) key pair as raw 32-byte buffers. */
export interface KeyPair {
  publicKey: bytes32;
  privateKey: bytes32;
}

/** One handshake/transport message as produced by the write* methods. */
interface MessageBuffer {
  ne: bytes32;      // sender's ephemeral public key (all zeroes when the message carries none)
  ns: bytes;        // sender's static public key, encrypted once a cipher key is available
  ciphertext: bytes; // encrypted payload (ciphertext followed by the 16-byte tag once keyed)
}

type CipherState = {
  k: bytes32;
  n: uint32;
};

type SymmetricState = {
  cs: CipherState;
  ck: bytes32; // chaining key
  h: bytes32;  // handshake hash
};

type HandshakeState = {
  ss: SymmetricState;
  s: KeyPair;
  e?: KeyPair;
  rs: bytes32;
  re?: bytes32;
  psk: bytes32;
};

type NoiseSession = {
  hs: HandshakeState;
  h?: bytes32;
  cs1?: CipherState;
  cs2?: CipherState;
  mc: uint64;
  i: boolean;
};

const minNonce = 0;
/** Length in bytes of the Poly1305 authentication tag appended by ChaChaPoly. */
const MAC_LENGTH = 16;
/** Output length of SHA-256 (Noise HASHLEN). */
const HASH_LENGTH = 32;

/**
 * Writer side of the `Noise_XX_25519_ChaChaPoly_SHA256` handshake:
 * builds handshake messages A (`-> e`), B (`<- e, ee, s, es`) and C (`-> s, se`),
 * then transport messages with the split cipher states.
 */
export class XXHandshake {
  private createEmptyKey(): bytes32 {
    return Buffer.alloc(32);
  }

  /** Shared setup for both roles: derive the initial symmetric state and mix in the prologue. */
  private async initializeHandshake(prologue: bytes32, s: KeyPair, rs: bytes32, psk: bytes32): Promise<HandshakeState> {
    const name = "Noise_XX_25519_ChaChaPoly_SHA256";
    const ss = await this.initializeSymmetric(name);
    await this.mixHash(ss, prologue);
    return { ss, s, rs, psk };
  }

  private async initializeInitiator(prologue: bytes32, s: KeyPair, rs: bytes32, psk: bytes32): Promise<HandshakeState> {
    return this.initializeHandshake(prologue, s, rs, psk);
  }

  private async initializeResponder(prologue: bytes32, s: KeyPair, rs: bytes32, psk: bytes32): Promise<HandshakeState> {
    return this.initializeHandshake(prologue, s, rs, psk);
  }

  private incrementNonce(n: uint32): uint32 {
    return n + 1;
  }

  /** X25519 shared secret. bcrypto's `derive` takes the public key first, then the private key. */
  private dh(privateKey: bytes32, publicKey: bytes32): bytes32 {
    return x25519.derive(publicKey, privateKey);
  }

  /**
   * ChaChaPoly nonce: 4 zero bytes followed by the little-endian counter.
   * `writeUInt32LE` throws a RangeError once `n` exceeds 2^32 - 1.
   */
  private convertNonce(n: uint32): bytes {
    const nonce = Buffer.alloc(12);
    nonce.writeUInt32LE(n, 4);
    return nonce;
  }

  /**
   * ChaCha20-Poly1305 encryption.
   * @returns ciphertext followed by the 16-byte authentication tag.
   */
  private encrypt(k: bytes32, n: uint32, ad: bytes, plaintext: bytes): bytes {
    const nonce = this.convertNonce(n);
    const ctx = new AEAD();
    // bcrypto encrypts in place; copy so the caller's buffer (e.g. a static public key) is untouched.
    const data = Buffer.from(plaintext);
    ctx.init(k, nonce);
    ctx.aad(ad);
    ctx.encrypt(data);
    return Buffer.concat([data, ctx.final()]);
  }

  /**
   * ChaCha20-Poly1305 decryption of `ciphertext || tag`.
   * @throws Error when the input is shorter than a tag or the tag does not verify.
   */
  private decrypt(k: bytes32, n: uint32, ad: bytes, ciphertext: bytes): bytes {
    if (ciphertext.length < MAC_LENGTH) {
      throw new Error("Ciphertext is too short to contain an authentication tag.");
    }
    const nonce = this.convertNonce(n);
    const ctx = new AEAD();
    const bodyLength = ciphertext.length - MAC_LENGTH;
    // Copy because bcrypto decrypts in place.
    const data = Buffer.from(ciphertext.slice(0, bodyLength));
    const tag = ciphertext.slice(bodyLength);
    ctx.init(k, nonce);
    ctx.aad(ad);
    ctx.decrypt(data);
    if (!ctx.verify(tag)) {
      throw new Error("Decryption failed: invalid authentication tag.");
    }
    return data;
  }

  private isEmptyKey(k: bytes32): boolean {
    const emptyKey = this.createEmptyKey();
    return emptyKey.equals(k);
  }

  // Cipher state related

  private initializeKey(k: bytes32): CipherState {
    const n = minNonce;
    return { k, n };
  }

  private hasKey(cs: CipherState): boolean {
    return !this.isEmptyKey(cs.k);
  }

  private setNonce(cs: CipherState, nonce: uint32): void {
    cs.n = nonce;
  }

  private encryptWithAd(cs: CipherState, ad: bytes, plaintext: bytes): bytes {
    const e = this.encrypt(cs.k, cs.n, ad, plaintext);
    this.setNonce(cs, this.incrementNonce(cs.n));
    return e;
  }

  /** The nonce is only advanced when decryption (and tag verification) succeeds. */
  private decryptWithAd(cs: CipherState, ad: bytes, ciphertext: bytes): bytes {
    const plaintext = this.decrypt(cs.k, cs.n, ad, ciphertext);
    this.setNonce(cs, this.incrementNonce(cs.n));
    return plaintext;
  }

  // Symmetric state related

  private async initializeSymmetric(protocolName: string): Promise<SymmetricState> {
    const protocolNameBytes: bytes = Buffer.from(protocolName, 'utf-8');
    const h = await this.hashProtocolName(protocolNameBytes);
    const ck = h;
    const key = this.createEmptyKey();
    const cs = this.initializeKey(key);
    return { cs, ck, h };
  }

  private mixKey(ss: SymmetricState, ikm: bytes32): void {
    const [ck, tempK] = this.getHkdf(ss.ck, ikm);
    ss.cs = this.initializeKey(tempK);
    ss.ck = ck;
  }

  /** Names of at most HASHLEN bytes are zero-padded; longer names are hashed. */
  private async hashProtocolName(protocolName: bytes): Promise<bytes32> {
    if (protocolName.length <= HASH_LENGTH) {
      const h = Buffer.alloc(HASH_LENGTH);
      protocolName.copy(h);
      return h;
    }
    return this.getHash(protocolName, Buffer.alloc(0));
  }

  /**
   * Noise HKDF with SHA-256 (equivalent to RFC 5869 with salt = ck and empty info).
   * Always derives three HASHLEN outputs, so it also works with empty `ikm` (as in `split`).
   */
  private getHkdf(ck: bytes32, ikm: bytes): [bytes32, bytes32, bytes32] {
    const info = Buffer.alloc(0);
    const prk = HKDF.extract(SHA256, ikm, ck);
    const okm = HKDF.expand(SHA256, prk, info, 3 * HASH_LENGTH);
    const k1 = okm.slice(0, HASH_LENGTH);
    const k2 = okm.slice(HASH_LENGTH, 2 * HASH_LENGTH);
    const k3 = okm.slice(2 * HASH_LENGTH, 3 * HASH_LENGTH);
    return [k1, k2, k3];
  }

  private async mixHash(ss: SymmetricState, data: bytes): Promise<void> {
    ss.h = await this.getHash(ss.h, data);
  }

  /** SHA-256 over the concatenation `a || b`. */
  private async getHash(a: bytes, b: bytes): Promise<bytes32> {
    return SHA256.digest(Buffer.concat([a, b]));
  }

  /** Encrypts with the handshake hash as AD when a key is set (otherwise passes through), then mixes the result into h. */
  private async encryptAndHash(ss: SymmetricState, plaintext: bytes): Promise<bytes> {
    let ciphertext: bytes;
    if (this.hasKey(ss.cs)) {
      ciphertext = this.encryptWithAd(ss.cs, ss.h, plaintext);
    } else {
      ciphertext = plaintext;
    }
    await this.mixHash(ss, ciphertext);
    return ciphertext;
  }

  private split(ss: SymmetricState): { cs1: CipherState; cs2: CipherState } {
    const [tempk1, tempk2] = this.getHkdf(ss.ck, Buffer.alloc(0));
    const cs1 = this.initializeKey(tempk1);
    const cs2 = this.initializeKey(tempk2);
    return { cs1, cs2 };
  }

  /** The remote ephemeral key must have been set (by reading a peer message) before DH tokens use it. */
  private requireRemoteEphemeral(hs: HandshakeState): bytes32 {
    if (!hs.re) {
      throw new Error("Remote ephemeral public key is not set.");
    }
    return hs.re;
  }

  /** Message A (`-> e`). */
  private async writeMessageA(hs: HandshakeState, payload: bytes): Promise<MessageBuffer> {
    const ns = Buffer.alloc(0);
    hs.e = await this.generateKeypair();
    const ne = hs.e.publicKey;
    await this.mixHash(hs.ss, ne);
    const ciphertext = await this.encryptAndHash(hs.ss, payload);
    return { ne, ns, ciphertext };
  }

  /** Message B (`<- e, ee, s, es`), written by the responder. */
  private async writeMessageB(hs: HandshakeState, payload: bytes): Promise<MessageBuffer> {
    const re = this.requireRemoteEphemeral(hs);
    hs.e = await this.generateKeypair();
    const ne = hs.e.publicKey;
    await this.mixHash(hs.ss, ne);                         // e
    this.mixKey(hs.ss, this.dh(hs.e.privateKey, re));      // ee
    const ns = await this.encryptAndHash(hs.ss, hs.s.publicKey); // s
    this.mixKey(hs.ss, this.dh(hs.s.privateKey, re));      // es (responder side)
    const ciphertext = await this.encryptAndHash(hs.ss, payload);
    return { ne, ns, ciphertext };
  }

  /** Message C (`-> s, se`), written by the initiator; completes the handshake and splits the cipher states. */
  private async writeMessageC(
    hs: HandshakeState,
    payload: bytes,
  ): Promise<{ h: bytes32; messageBuffer: MessageBuffer; cs1: CipherState; cs2: CipherState }> {
    const re = this.requireRemoteEphemeral(hs);
    const ns = await this.encryptAndHash(hs.ss, hs.s.publicKey); // s
    this.mixKey(hs.ss, this.dh(hs.s.privateKey, re));           // se (initiator side)
    const ciphertext = await this.encryptAndHash(hs.ss, payload);
    const ne = this.createEmptyKey();
    const messageBuffer: MessageBuffer = { ne, ns, ciphertext };
    const { cs1, cs2 } = this.split(hs.ss);
    return { h: hs.ss.h, messageBuffer, cs1, cs2 };
  }

  /** Transport message encrypted with empty associated data. */
  private async writeMessageRegular(cs: CipherState, payload: bytes): Promise<MessageBuffer> {
    const ciphertext = this.encryptWithAd(cs, Buffer.alloc(0), payload);
    const ne = this.createEmptyKey();
    const ns = Buffer.alloc(0);
    return { ne, ns, ciphertext };
  }

  /** Generates a fresh X25519 key pair as raw buffers. */
  public async generateKeypair(): Promise<KeyPair> {
    const privateKey = x25519.privateKeyGenerate();
    const publicKey = x25519.publicKeyCreate(privateKey);
    return { publicKey, privateKey };
  }

  /** Creates a session for the given role with an all-zero PSK and message counter 0. */
  public async initSession(initiator: boolean, prologue: bytes32, s: KeyPair, rs: bytes32): Promise<NoiseSession> {
    const psk = this.createEmptyKey();
    const hs = initiator
      ? await this.initializeInitiator(prologue, s, rs, psk)
      : await this.initializeResponder(prologue, s, rs, psk);
    return { hs, i: initiator, mc: 0 };
  }

  /**
   * Writes the next message for this session based on its message counter `mc`.
   * @throws Error when the counter is invalid or transport keys are not available yet.
   */
  public async sendMessage(session: NoiseSession, message: bytes): Promise<MessageBuffer> {
    let messageBuffer: MessageBuffer;
    if (session.mc === 0) {
      messageBuffer = await this.writeMessageA(session.hs, message);
    } else if (session.mc === 1) {
      messageBuffer = await this.writeMessageB(session.hs, message);
    } else if (session.mc === 2) {
      const result = await this.writeMessageC(session.hs, message);
      // Assign (not destructure) so the outer messageBuffer is the one returned.
      messageBuffer = result.messageBuffer;
      session.h = result.h;
      session.cs1 = result.cs1;
      session.cs2 = result.cs2;
    } else if (session.mc > 2) {
      // Initiator sends with cs1, responder with cs2.
      const cs = session.i ? session.cs1 : session.cs2;
      if (!cs) {
        throw new Error("Session invalid: transport cipher state is not initialized.");
      }
      messageBuffer = await this.writeMessageRegular(cs, message);
    } else {
      throw new Error("Session invalid.");
    }
    session.mc++;
    return messageBuffer;
  }
}