diff --git a/src/@types/handshake.ts b/src/@types/handshake.ts new file mode 100644 index 0000000..265f769 --- /dev/null +++ b/src/@types/handshake.ts @@ -0,0 +1,39 @@ +import {bytes, bytes32, uint32, uint64} from "./basic"; +import {KeyPair} from "./libp2p"; + +export type Hkdf = [bytes, bytes, bytes]; + +export interface MessageBuffer { + ne: bytes32; + ns: bytes; + ciphertext: bytes; +} + +export type CipherState = { + k: bytes32; + n: uint32; +} + +export type SymmetricState = { + cs: CipherState; + ck: bytes32; // chaining key + h: bytes32; // handshake hash +} + +export type HandshakeState = { + ss: SymmetricState; + s: KeyPair; + e?: KeyPair; + rs: bytes32; + re: bytes32; + psk: bytes32; +} + +export type NoiseSession = { + hs: HandshakeState; + h?: bytes32; + cs1?: CipherState; + cs2?: CipherState; + mc: uint64; + i: boolean; +} diff --git a/src/encoder.ts b/src/encoder.ts index 6c4cf75..e4e2c30 100644 --- a/src/encoder.ts +++ b/src/encoder.ts @@ -1,18 +1,18 @@ import {Buffer} from "buffer"; import {bytes} from "./@types/basic"; -import {MessageBuffer} from "./xx"; +import {MessageBuffer} from "./@types/handshake"; -export const int16BEEncode = (value, target, offset) => { +export const uint16BEEncode = (value, target, offset) => { target = target || Buffer.allocUnsafe(2); return target.writeUInt16BE(value, offset); }; -int16BEEncode.bytes = 2; +uint16BEEncode.bytes = 2; -export const int16BEDecode = data => { +export const uint16BEDecode = data => { if (data.length < 2) throw RangeError('Could not decode int16BE'); return data.readUInt16BE(0); }; -int16BEDecode.bytes = 2; +uint16BEDecode.bytes = 2; export function encodeMessageBuffer(message: MessageBuffer): bytes { return Buffer.concat([message.ne, message.ns, message.ciphertext]); diff --git a/src/handshake.ts b/src/handshake.ts index cb76381..b2d96dc 100644 --- a/src/handshake.ts +++ b/src/handshake.ts @@ -1,8 +1,9 @@ import { Buffer } from "buffer"; -import { bytes, bytes32 } from "./@types/basic"; -import { NoiseSession, XXHandshake } from "./xx"; +import { XXHandshake } from "./handshakes/xx"; import { KeyPair, PeerId } from "./@types/libp2p"; +import { bytes, bytes32 } from "./@types/basic"; +import { NoiseSession } from "./@types/handshake"; import { createHandshakePayload, getHandshakePayload, diff --git a/src/handshakes/abstract-handshake.ts b/src/handshakes/abstract-handshake.ts new file mode 100644 index 0000000..486dd96 --- /dev/null +++ b/src/handshakes/abstract-handshake.ts @@ -0,0 +1,171 @@ +import {Buffer} from "buffer"; +import { AEAD, x25519, SHA256 } from 'bcrypto'; + +import {bytes, bytes32, uint32} from "../@types/basic"; +import {CipherState, MessageBuffer, SymmetricState} from "../@types/handshake"; +import {getHkdf} from "../utils"; + +export const MIN_NONCE = 0; + +export abstract class AbstractHandshake { + public encryptWithAd(cs: CipherState, ad: bytes, plaintext: bytes): bytes { + const e = this.encrypt(cs.k, cs.n, ad, plaintext); + this.setNonce(cs, this.incrementNonce(cs.n)); + + return e; + } + + public decryptWithAd(cs: CipherState, ad: bytes, ciphertext: bytes): bytes { + const plaintext = this.decrypt(cs.k, cs.n, ad, ciphertext); + this.setNonce(cs, this.incrementNonce(cs.n)); + + return plaintext; + } + + + // Cipher state related + protected hasKey(cs: CipherState): boolean { + return !this.isEmptyKey(cs.k); + } + + protected setNonce(cs: CipherState, nonce: uint32): void { + cs.n = nonce; + } + + protected createEmptyKey(): bytes32 { + return Buffer.alloc(32); + } + + protected isEmptyKey(k: bytes32): boolean { + const emptyKey = this.createEmptyKey(); + return emptyKey.equals(k); + } + + protected incrementNonce(n: uint32): uint32 { + return n + 1; + } + + protected nonceToBytes(n: uint32): bytes { + const nonce = Buffer.alloc(12); + nonce.writeUInt32LE(n, 4); + + return nonce; + } + + protected encrypt(k: bytes32, n: uint32, ad: bytes, plaintext: bytes): bytes { + const nonce = this.nonceToBytes(n); + const ctx = new AEAD(); + + ctx.init(k, nonce); + ctx.aad(ad); + ctx.encrypt(plaintext); + + // Encryption is done on the sent reference + return plaintext; + } + + protected encryptAndHash(ss: SymmetricState, plaintext: bytes): bytes { + let ciphertext; + if (this.hasKey(ss.cs)) { + ciphertext = this.encryptWithAd(ss.cs, ss.h, plaintext); + } else { + ciphertext = plaintext; + } + + this.mixHash(ss, ciphertext); + return ciphertext; + } + + protected decrypt(k: bytes32, n: uint32, ad: bytes, ciphertext: bytes): bytes { + const nonce = this.nonceToBytes(n); + const ctx = new AEAD(); + + ctx.init(k, nonce); + ctx.aad(ad); + ctx.decrypt(ciphertext); + + // Decryption is done on the sent reference + return ciphertext; + } + + protected decryptAndHash(ss: SymmetricState, ciphertext: bytes): bytes { + let plaintext; + if (this.hasKey(ss.cs)) { + plaintext = this.decryptWithAd(ss.cs, ss.h, ciphertext); + } else { + plaintext = ciphertext; + } + + this.mixHash(ss, ciphertext); + return plaintext; + } + + protected dh(privateKey: bytes32, publicKey: bytes32): bytes32 { + const derived = x25519.derive(publicKey, privateKey); + const result = Buffer.alloc(32); + derived.copy(result); + return result; + } + + protected mixHash(ss: SymmetricState, data: bytes): void { + ss.h = this.getHash(ss.h, data); + } + + protected getHash(a: bytes, b: bytes): bytes32 { + return SHA256.digest(Buffer.from([...a, ...b])); + } + + protected mixKey(ss: SymmetricState, ikm: bytes32): void { + const [ ck, tempK ] = getHkdf(ss.ck, ikm); + ss.cs = this.initializeKey(tempK) as CipherState; + ss.ck = ck; + } + + protected initializeKey(k: bytes32): CipherState { + const n = MIN_NONCE; + return { k, n }; + } + + // Symmetric state related + + protected initializeSymmetric(protocolName: string): SymmetricState { + const protocolNameBytes: bytes = Buffer.from(protocolName, 'utf-8'); + const h = this.hashProtocolName(protocolNameBytes); + + const ck = h; + const key = this.createEmptyKey(); + const cs: CipherState = this.initializeKey(key); + + return { cs, ck, h }; + } + + protected hashProtocolName(protocolName: bytes): bytes32 { + if (protocolName.length <= 32) { + const h = Buffer.alloc(32); + protocolName.copy(h); + return h; + } else { + return this.getHash(protocolName, Buffer.alloc(0)); + } + } + + protected split(ss: SymmetricState) { + const [ tempk1, tempk2 ] = getHkdf(ss.ck, Buffer.alloc(0)); + const cs1 = this.initializeKey(tempk1); + const cs2 = this.initializeKey(tempk2); + + return { cs1, cs2 }; + } + + protected writeMessageRegular(cs: CipherState, payload: bytes): MessageBuffer { + const ciphertext = this.encryptWithAd(cs, Buffer.alloc(0), payload); + const ne = this.createEmptyKey(); + const ns = Buffer.alloc(0); + + return { ne, ns, ciphertext }; + } + + protected readMessageRegular(cs: CipherState, message: MessageBuffer): bytes { + return this.decryptWithAd(cs, Buffer.alloc(0), message.ciphertext); + } +} diff --git a/src/handshakes/ik.ts b/src/handshakes/ik.ts new file mode 100644 index 0000000..f7c04ff --- /dev/null +++ b/src/handshakes/ik.ts @@ -0,0 +1,172 @@ +import {Buffer} from "buffer"; +import {BN} from "bn.js"; + +import {HandshakeState, MessageBuffer, NoiseSession} from "../@types/handshake"; +import {bytes, bytes32} from "../@types/basic"; +import {generateKeypair, getHkdf, isValidPublicKey} from "../utils"; +import {AbstractHandshake} from "./abstract-handshake"; +import {KeyPair} from "../@types/libp2p"; + + +export class IKHandshake extends AbstractHandshake { + public initSession(initiator: boolean, prologue: bytes32, s: KeyPair, rs: bytes32): NoiseSession { + const psk = this.createEmptyKey(); + + let hs; + if (initiator) { + hs = this.initializeInitiator(prologue, s, rs, psk); + } else { + hs = this.initializeResponder(prologue, s, rs, psk); + } + + return { + hs, + i: initiator, + mc: new BN(0), + }; + } + + public sendMessage(session: NoiseSession, message: bytes): MessageBuffer { + let messageBuffer: MessageBuffer; + if (session.mc.eqn(0)) { + messageBuffer = this.writeMessageA(session.hs, message); + } else if (session.mc.eqn(1)) { + const { messageBuffer: mb, h, cs1, cs2 } = this.writeMessageB(session.hs, message); + messageBuffer = mb; + session.h = h; + session.cs1 = cs1; + session.cs2 = cs2; + } else if (session.mc.gtn(1)) { + if (session.i) { + if (!session.cs1) { + throw new Error("CS1 (cipher state) is not defined") + } + + messageBuffer = this.writeMessageRegular(session.cs1, message); + } else { + if (!session.cs2) { + throw new Error("CS2 (cipher state) is not defined") + } + + messageBuffer = this.writeMessageRegular(session.cs2, message); + } + } else { + throw new Error("Session invalid.") + } + + session.mc = session.mc.add(new BN(1)); + return messageBuffer; + } + + public recvMessage(session: NoiseSession, message: MessageBuffer): bytes { + let plaintext: bytes; + if (session.mc.eqn(0)) { + plaintext = this.readMessageA(session.hs, message); + } else if (session.mc.eqn(1)) { + const { plaintext: pt, h, cs1, cs2 } = this.readMessageB(session.hs, message); + plaintext = pt; + session.h = h; + session.cs1 = cs1; + session.cs2 = cs2; + delete session.hs; + } else if (session.mc.gtn(1)) { + if (session.i) { + if (!session.cs2) { + throw new Error("CS1 (cipher state) is not defined") + } + plaintext = this.readMessageRegular(session.cs2, message); + } else { + if (!session.cs1) { + throw new Error("CS1 (cipher state) is not defined") + } + plaintext = this.readMessageRegular(session.cs1, message); + } + } else { + throw new Error("Session invalid."); + } + + session.mc = session.mc.add(new BN(1)); + return plaintext; + } + + private writeMessageA(hs: HandshakeState, payload: bytes): MessageBuffer { + hs.e = generateKeypair(); + const ne = hs.e.publicKey; + this.mixHash(hs.ss, ne); + this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)); + const spk = Buffer.from(hs.s.publicKey); + const ns = this.encryptAndHash(hs.ss, spk); + + this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.rs)); + const ciphertext = this.encryptAndHash(hs.ss, payload); + + return { ne, ns, ciphertext }; + } + + private writeMessageB(hs: HandshakeState, payload: bytes) { + hs.e = generateKeypair(); + const ne = hs.e.publicKey; + this.mixHash(hs.ss, ne); + + this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)); + this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)); + const ciphertext = this.encryptAndHash(hs.ss, payload); + const ns = this.createEmptyKey(); + const messageBuffer: MessageBuffer = {ne, ns, ciphertext}; + const { cs1, cs2 } = this.split(hs.ss); + + return { messageBuffer, cs1, cs2, h: hs.ss.h } + } + + private readMessageA(hs: HandshakeState, message: MessageBuffer): bytes { + if (isValidPublicKey(message.ne)) { + hs.re = message.ne; + } + + this.mixHash(hs.ss, hs.re); + this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)); + const ns = this.decryptAndHash(hs.ss, message.ns); + if (ns.length === 32 && isValidPublicKey(message.ns)) { + hs.rs = ns; + } + this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.rs)); + return this.decryptAndHash(hs.ss, message.ciphertext); + } + + private readMessageB(hs: HandshakeState, message: MessageBuffer) { + if (isValidPublicKey(message.ne)) { + hs.re = message.ne; + } + + this.mixHash(hs.ss, hs.re); + if (!hs.e) { + throw new Error("Handshake state should contain ephemeral key by now."); + } + this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)); + this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)); + const plaintext = this.decryptAndHash(hs.ss, message.ciphertext); + const { cs1, cs2 } = this.split(hs.ss); + + return { h: hs.ss.h, plaintext, cs1, cs2 }; + } + + private initializeInitiator(prologue: bytes32, s: KeyPair, rs: bytes32, psk: bytes32): HandshakeState { + const name = "Noise_IK_25519_ChaChaPoly_SHA256"; + const ss = this.initializeSymmetric(name); + this.mixHash(ss, prologue); + this.mixHash(ss, rs); + const re = Buffer.alloc(32); + + return { ss, s, rs, re, psk }; + } + + private initializeResponder(prologue: bytes32, s: KeyPair, rs: bytes32, psk: bytes32): HandshakeState { + const name = "Noise_IK_25519_ChaChaPoly_SHA256"; + const ss = this.initializeSymmetric(name); + this.mixHash(ss, prologue); + this.mixHash(ss, s.publicKey); + const re = Buffer.alloc(32); + + return { ss, s, rs, re, psk }; + } +} diff --git a/src/xx.ts b/src/handshakes/xx.ts similarity index 50% rename from src/xx.ts rename to src/handshakes/xx.ts index d4b479a..3b8b0a9 100644 --- a/src/xx.ts +++ b/src/handshakes/xx.ts @@ -1,54 +1,14 @@ import { Buffer } from 'buffer'; -import { AEAD, x25519, HKDF, SHA256 } from 'bcrypto'; import { BN } from 'bn.js'; -import { bytes32, uint32, uint64, bytes } from './@types/basic' -import { KeyPair } from './@types/libp2p' -import { generateKeypair } from './utils'; +import { bytes32, bytes } from '../@types/basic' +import { KeyPair } from '../@types/libp2p' +import {generateKeypair, getHkdf, isValidPublicKey} from '../utils'; +import { HandshakeState, MessageBuffer, NoiseSession } from "../@types/handshake"; +import {AbstractHandshake} from "./abstract-handshake"; -export interface MessageBuffer { - ne: bytes32; - ns: bytes; - ciphertext: bytes; -} - -type CipherState = { - k: bytes32; - n: uint32; -} - -type SymmetricState = { - cs: CipherState; - ck: bytes32; // chaining key - h: bytes32; // handshake hash -} - -type HandshakeState = { - ss: SymmetricState; - s: KeyPair; - e?: KeyPair; - rs: bytes32; - re: bytes32; - psk: bytes32; -} - -export type NoiseSession = { - hs: HandshakeState; - h?: bytes32; - cs1?: CipherState; - cs2?: CipherState; - mc: uint64; - i: boolean; -} -export type Hkdf = [bytes, bytes, bytes]; - -const minNonce = 0; - -export class XXHandshake { - private createEmptyKey(): bytes32 { - return Buffer.alloc(32); - } +export class XXHandshake extends AbstractHandshake { private initializeInitiator(prologue: bytes32, s: KeyPair, rs: bytes32, psk: bytes32): HandshakeState { const name = "Noise_XX_25519_ChaChaPoly_SHA256"; const ss = this.initializeSymmetric(name); @@ -67,162 +27,6 @@ export class XXHandshake { return { ss, s, rs, psk, re }; } - private incrementNonce(n: uint32): uint32 { - return n + 1; - } - - private dh(privateKey: bytes32, publicKey: bytes32): bytes32 { - const derived = x25519.derive(publicKey, privateKey); - const result = Buffer.alloc(32); - derived.copy(result); - return result; - } - - private nonceToBytes(n: uint32): bytes { - const nonce = Buffer.alloc(12); - nonce.writeUInt32LE(n, 4); - - return nonce; - } - - private encrypt(k: bytes32, n: uint32, ad: bytes, plaintext: bytes): bytes { - const nonce = this.nonceToBytes(n); - const ctx = new AEAD(); - - ctx.init(k, nonce); - ctx.aad(ad); - ctx.encrypt(plaintext); - - // Encryption is done on the sent reference - return plaintext; - } - - private decrypt(k: bytes32, n: uint32, ad: bytes, ciphertext: bytes): bytes { - const nonce = this.nonceToBytes(n); - const ctx = new AEAD(); - - ctx.init(k, nonce); - ctx.aad(ad); - ctx.decrypt(ciphertext); - - // Decryption is done on the sent reference - return ciphertext; - } - - private isEmptyKey(k: bytes32): boolean { - const emptyKey = this.createEmptyKey(); - return emptyKey.equals(k); - } - - // Cipher state related - private initializeKey(k: bytes32): CipherState { - const n = minNonce; - return { k, n }; - } - - private hasKey(cs: CipherState): boolean { - return !this.isEmptyKey(cs.k); - } - - private setNonce(cs: CipherState, nonce: uint32): void { - cs.n = nonce; - } - - public encryptWithAd(cs: CipherState, ad: bytes, plaintext: bytes): bytes { - const e = this.encrypt(cs.k, cs.n, ad, plaintext); - this.setNonce(cs, this.incrementNonce(cs.n)); - - return e; - } - - public decryptWithAd(cs: CipherState, ad: bytes, ciphertext: bytes): bytes { - const plaintext = this.decrypt(cs.k, cs.n, ad, ciphertext); - this.setNonce(cs, this.incrementNonce(cs.n)); - - return plaintext; - } - - // Symmetric state related - - private initializeSymmetric(protocolName: string): SymmetricState { - const protocolNameBytes: bytes = Buffer.from(protocolName, 'utf-8'); - const h = this.hashProtocolName(protocolNameBytes); - - const ck = h; - const key = this.createEmptyKey(); - const cs: CipherState = this.initializeKey(key); - - return { cs, ck, h }; - } - - private mixKey(ss: SymmetricState, ikm: bytes32): void { - const [ ck, tempK ] = this.getHkdf(ss.ck, ikm); - ss.cs = this.initializeKey(tempK) as CipherState; - ss.ck = ck; - } - - private hashProtocolName(protocolName: bytes): bytes32 { - if (protocolName.length <= 32) { - const h = Buffer.alloc(32); - protocolName.copy(h); - return h; - } else { - return this.getHash(protocolName, Buffer.alloc(0)); - } - } - - public getHkdf(ck: bytes32, ikm: bytes): Hkdf { - const info = Buffer.alloc(0); - const prk = HKDF.extract(SHA256, ikm, ck); - const okm = HKDF.expand(SHA256, prk, info, 96); - - const k1 = okm.slice(0, 32); - const k2 = okm.slice(32, 64); - const k3 = okm.slice(64, 96); - - return [ k1, k2, k3 ]; - } - - private mixHash(ss: SymmetricState, data: bytes) { - ss.h = this.getHash(ss.h, data); - } - - private getHash(a: bytes, b: bytes): bytes32 { - return SHA256.digest(Buffer.from([...a, ...b])); - } - - private encryptAndHash(ss: SymmetricState, plaintext: bytes): bytes { - let ciphertext; - if (this.hasKey(ss.cs)) { - ciphertext = this.encryptWithAd(ss.cs, ss.h, plaintext); - } else { - ciphertext = plaintext; - } - - this.mixHash(ss, ciphertext); - return ciphertext; - } - - private decryptAndHash(ss: SymmetricState, ciphertext: bytes): bytes { - let plaintext; - if (this.hasKey(ss.cs)) { - plaintext = this.decryptWithAd(ss.cs, ss.h, ciphertext); - } else { - plaintext = ciphertext; - } - - this.mixHash(ss, ciphertext); - return plaintext; - } - - private split (ss: SymmetricState) { - const [ tempk1, tempk2 ] = this.getHkdf(ss.ck, Buffer.alloc(0)); - const cs1 = this.initializeKey(tempk1); - const cs2 = this.initializeKey(tempk2); - - return { cs1, cs2 }; - } - private writeMessageA(hs: HandshakeState, payload: bytes): MessageBuffer { const ns = Buffer.alloc(0); hs.e = generateKeypair(); @@ -262,16 +66,8 @@ export class XXHandshake { return { h: hs.ss.h, messageBuffer, cs1, cs2 }; } - private writeMessageRegular(cs: CipherState, payload: bytes): MessageBuffer { - const ciphertext = this.encryptWithAd(cs, Buffer.alloc(0), payload); - const ne = this.createEmptyKey(); - const ns = Buffer.alloc(0); - - return { ne, ns, ciphertext }; - } - private readMessageA(hs: HandshakeState, message: MessageBuffer): bytes { - if (x25519.publicKeyVerify(message.ne)) { + if (isValidPublicKey(message.ne)) { hs.re = message.ne; } @@ -280,7 +76,7 @@ export class XXHandshake { } private readMessageB(hs: HandshakeState, message: MessageBuffer): bytes { - if (x25519.publicKeyVerify(message.ne)) { + if (isValidPublicKey(message.ne)) { hs.re = message.ne; } @@ -290,7 +86,7 @@ export class XXHandshake { } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)); const ns = this.decryptAndHash(hs.ss, message.ns); - if (ns.length === 32 && x25519.publicKeyVerify(message.ns)) { + if (ns.length === 32 && isValidPublicKey(message.ns)) { hs.rs = ns; } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)); @@ -299,7 +95,7 @@ export class XXHandshake { private readMessageC(hs: HandshakeState, message: MessageBuffer) { const ns = this.decryptAndHash(hs.ss, message.ns); - if (ns.length === 32 && x25519.publicKeyVerify(message.ns)) { + if (ns.length === 32 && isValidPublicKey(message.ns)) { hs.rs = ns; } @@ -314,10 +110,6 @@ export class XXHandshake { return { h: hs.ss.h, plaintext, cs1, cs2 }; } - private readMessageRegular(cs: CipherState, message: MessageBuffer): bytes { - return this.decryptWithAd(cs, Buffer.alloc(0), message.ciphertext); - } - public initSession(initiator: boolean, prologue: bytes32, s: KeyPair): NoiseSession { const psk = this.createEmptyKey(); const rs = Buffer.alloc(32); // no static key yet @@ -348,6 +140,7 @@ export class XXHandshake { session.h = h; session.cs1 = cs1; session.cs2 = cs2; + delete session.hs; } else if (session.mc.gtn(2)) { if (session.i) { if (!session.cs1) { diff --git a/src/noise.ts b/src/noise.ts index 83b313e..068ade3 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -8,7 +8,7 @@ import lp from 'it-length-prefixed'; import { Handshake } from "./handshake"; import { generateKeypair } from "./utils"; -import { int16BEDecode, int16BEEncode } from "./encoder"; +import { uint16BEDecode, uint16BEEncode } from "./encoder"; import { decryptStream, encryptStream } from "./crypto"; import { bytes } from "./@types/basic"; import { NoiseConnection, PeerId, KeyPair, SecureOutbound } from "./@types/libp2p"; @@ -108,9 +108,9 @@ export class Noise implements NoiseConnection { secure, // write to wrapper ensureBuffer, // ensure any type of data is converted to buffer encryptStream(handshake), // data is encrypted - lp.encode({ lengthEncoder: int16BEEncode }), // prefix with message length + lp.encode({ lengthEncoder: uint16BEEncode }), // prefix with message length network, // send to the remote peer - lp.decode({ lengthDecoder: int16BEDecode }), // read message length prefix + lp.decode({ lengthDecoder: uint16BEDecode }), // read message length prefix ensureBuffer, // ensure any type of data is converted to buffer decryptStream(handshake), // decrypt the incoming data secure // pipe to the wrapper diff --git a/src/utils.ts b/src/utils.ts index 65b7939..005f826 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,11 +1,12 @@ -import { x25519, ed25519 } from 'bcrypto'; +import { x25519, ed25519, HKDF, SHA256 } from 'bcrypto'; import protobuf from "protobufjs"; import { Buffer } from "buffer"; import PeerId from "peer-id"; import * as crypto from 'libp2p-crypto'; import { KeyPair } from "./@types/libp2p"; -import { bytes } from "./@types/basic"; +import {bytes, bytes32} from "./@types/basic"; +import {Hkdf} from "./@types/handshake"; export async function loadPayloadProto () { const payloadProtoBuf = await protobuf.load("protos/payload.proto"); @@ -88,3 +89,19 @@ export async function verifySignedPayload(noiseStaticKey: bytes, plaintext: byte throw new Error("Static key doesn't match to peer that signed payload!"); } } + +export function getHkdf(ck: bytes32, ikm: bytes): Hkdf { + const info = Buffer.alloc(0); + const prk = HKDF.extract(SHA256, ikm, ck); + const okm = HKDF.expand(SHA256, prk, info, 96); + + const k1 = okm.slice(0, 32); + const k2 = okm.slice(32, 64); + const k3 = okm.slice(64, 96); + + return [ k1, k2, k3 ]; +} + +export function isValidPublicKey(pk: bytes): boolean { + return x25519.publicKeyVerify(pk); +} diff --git a/test/handshakes/ik.test.ts b/test/handshakes/ik.test.ts new file mode 100644 index 0000000..6a4bf8e --- /dev/null +++ b/test/handshakes/ik.test.ts @@ -0,0 +1,67 @@ +import {Buffer} from "buffer"; +import {IKHandshake} from "../../src/handshakes/ik"; +import {KeyPair} from "../../src/@types/libp2p"; +import {createHandshakePayload, generateKeypair, getHandshakePayload} from "../../src/utils"; +import {assert, expect} from "chai"; +import {generateEd25519Keys} from "../utils"; + +describe("Index", () => { + const prologue = Buffer.from("/noise", "utf-8"); + + it("Test complete IK handshake", async () => { + try { + const ikI = new IKHandshake(); + const ikR = new IKHandshake(); + + // Generate static noise keys + const kpInitiator: KeyPair = await generateKeypair(); + const kpResponder: KeyPair = await generateKeypair(); + + // Generate libp2p keys + const libp2pInitKeys = await generateEd25519Keys(); + const libp2pRespKeys = await generateEd25519Keys(); + + // Create sessions + const initiatorSession = await ikI.initSession(true, prologue, kpInitiator, kpResponder.publicKey); + const responderSession = await ikR.initSession(false, prologue, kpResponder, Buffer.alloc(32)); + + /* Stage 0 */ + + // initiator creates payload + const initSignedPayload = await libp2pInitKeys.sign(getHandshakePayload(kpInitiator.publicKey)); + const libp2pInitPrivKey = libp2pInitKeys.marshal().slice(0, 32); + const libp2pInitPubKey = libp2pInitKeys.marshal().slice(32, 64); + const payloadInitEnc = await createHandshakePayload(libp2pInitPubKey, libp2pInitPrivKey, initSignedPayload); + + // initiator sends message + const message = Buffer.concat([Buffer.alloc(0), payloadInitEnc]); + const messageBuffer = ikI.sendMessage(initiatorSession, message); + + expect(messageBuffer.ne.length).not.equal(0); + + // responder receives message + const plaintext = ikR.recvMessage(responderSession, messageBuffer); + + /* Stage 1 */ + + // responder creates payload + const libp2pRespPrivKey = libp2pRespKeys.marshal().slice(0, 32); + const libp2pRespPubKey = libp2pRespKeys.marshal().slice(32, 64); + const respSignedPayload = await libp2pRespKeys.sign(getHandshakePayload(kpResponder.publicKey)); + const payloadRespEnc = await createHandshakePayload(libp2pRespPubKey, libp2pRespPrivKey, respSignedPayload); + + const message1 = Buffer.concat([message, payloadRespEnc]); + const messageBuffer2 = ikR.sendMessage(responderSession, message1); + + // initiator receives message + const plaintext2 = ikI.recvMessage(initiatorSession, messageBuffer2); + + assert(initiatorSession.cs1.k.equals(responderSession.cs1.k)); + assert(initiatorSession.cs2.k.equals(responderSession.cs2.k)); + + } catch (e) { + console.log(e); + assert(false, e.message); + } + }); +}); diff --git a/test/xx.test.ts b/test/handshakes/xx.test.ts similarity index 95% rename from test/xx.test.ts rename to test/handshakes/xx.test.ts index 1eca6b4..b6d93f4 100644 --- a/test/xx.test.ts +++ b/test/handshakes/xx.test.ts @@ -1,10 +1,10 @@ import { expect, assert } from "chai"; import { Buffer } from 'buffer'; -import { XXHandshake } from "../src/xx"; -import { KeyPair } from "../src/@types/libp2p"; -import { generateEd25519Keys } from "./utils"; -import {createHandshakePayload, generateKeypair, getHandshakePayload} from "../src/utils"; +import { XXHandshake } from "../../src/handshakes/xx"; +import { KeyPair } from "../../src/@types/libp2p"; +import { generateEd25519Keys } from "../utils"; +import {createHandshakePayload, generateKeypair, getHandshakePayload, getHkdf} from "../../src/utils"; describe("Index", () => { const prologue = Buffer.from("/noise", "utf-8"); @@ -29,7 +29,7 @@ describe("Index", () => { const ck = Buffer.alloc(32); ckBytes.copy(ck); - const [k1, k2, k3] = xx.getHkdf(ck, ikm); + const [k1, k2, k3] = getHkdf(ck, ikm); expect(k1.toString('hex')).to.equal('cc5659adff12714982f806e2477a8d5ddd071def4c29bb38777b7e37046f6914'); expect(k2.toString('hex')).to.equal('a16ada915e551ab623f38be674bb4ef15d428ae9d80688899c9ef9b62ef208fa'); expect(k3.toString('hex')).to.equal('ff67bf9727e31b06efc203907e6786667d2c7a74ac412b4d31a80ba3fd766f68'); diff --git a/test/noise.test.ts b/test/noise.test.ts index d6a8dd2..6594473 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -13,7 +13,7 @@ import { signPayload } from "../src/utils"; import { decodeMessageBuffer, encodeMessageBuffer } from "../src/encoder"; -import {XXHandshake} from "../src/xx"; +import {XXHandshake} from "../src/handshakes/xx"; import {Buffer} from "buffer"; import {getKeyPairFromPeerId} from "./utils";