diff --git a/src/handshakes/abstract-handshake.ts b/src/handshakes/abstract-handshake.ts new file mode 100644 index 0000000..e34a901 --- /dev/null +++ b/src/handshakes/abstract-handshake.ts @@ -0,0 +1,60 @@ +import {Buffer} from "buffer"; +import { AEAD, x25519, HKDF, SHA256 } from 'bcrypto'; + +import {bytes, bytes32, uint32} from "../@types/basic"; +import {CipherState, SymmetricState} from "../@types/handshake"; +import {getHkdf} from "../utils"; + +export class AbstractHandshake { + protected minNonce = 0; + + protected incrementNonce(n: uint32): uint32 { + return n + 1; + } + + protected nonceToBytes(n: uint32): bytes { + const nonce = Buffer.alloc(12); + nonce.writeUInt32LE(n, 4); + + return nonce; + } + + protected encrypt(k: bytes32, n: uint32, ad: bytes, plaintext: bytes): bytes { + const nonce = this.nonceToBytes(n); + const ctx = new AEAD(); + + ctx.init(k, nonce); + ctx.aad(ad); + ctx.encrypt(plaintext); + + // Encryption is done on the sent reference + return plaintext; + } + + + protected dh(privateKey: bytes32, publicKey: bytes32): bytes32 { + const derived = x25519.derive(publicKey, privateKey); + const result = Buffer.alloc(32); + derived.copy(result); + return result; + } + + protected mixHash(ss: SymmetricState, data: bytes): void { + ss.h = this.getHash(ss.h, data); + } + + protected getHash(a: bytes, b: bytes): bytes32 { + return SHA256.digest(Buffer.from([...a, ...b])); + } + + protected mixKey(ss: SymmetricState, ikm: bytes32): void { + const [ ck, tempK ] = getHkdf(ss.ck, ikm); + ss.cs = this.initializeKey(tempK) as CipherState; + ss.ck = ck; + } + + protected initializeKey(k: bytes32): CipherState { + const n = this.minNonce; + return { k, n }; + } +} diff --git a/src/handshakes/ik.ts b/src/handshakes/ik.ts index e69de29..bf46def 100644 --- a/src/handshakes/ik.ts +++ b/src/handshakes/ik.ts @@ -0,0 +1,25 @@ +import {Buffer} from "buffer"; + +import {CipherState, HandshakeState, MessageBuffer, SymmetricState} from "../@types/handshake"; +import {bytes, bytes32} from "../@types/basic"; +import {generateKeypair, getHkdf} from "../utils"; +import {AbstractHandshake} from "./abstract-handshake"; + + +export class IKHandshake extends AbstractHandshake { + private writeMessageA(hs: HandshakeState, payload: bytes): MessageBuffer { + hs.e = generateKeypair(); + const ne = hs.e.publicKey; + this.mixHash(hs.ss, ne); + this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)); + const spk = Buffer.from(hs.s.publicKey); + const ns = this.encryptAndHash(hs.ss, spk); + + this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)); + const ciphertext = this.encryptAndHash(hs.ss, payload); + + return { ne, ns, ciphertext }; + } + + +} diff --git a/src/handshakes/xx.ts b/src/handshakes/xx.ts index 63f0dad..3273ed6 100644 --- a/src/handshakes/xx.ts +++ b/src/handshakes/xx.ts @@ -4,13 +4,12 @@ import { BN } from 'bn.js'; import { bytes32, uint32, uint64, bytes } from '../@types/basic' import { KeyPair } from '../@types/libp2p' -import { generateKeypair } from '../utils'; +import {generateKeypair, getHkdf} from '../utils'; import { CipherState, HandshakeState, Hkdf, MessageBuffer, NoiseSession, SymmetricState } from "../@types/handshake"; +import {AbstractHandshake} from "./abstract-handshake"; -const minNonce = 0; - -export class XXHandshake { +export class XXHandshake extends AbstractHandshake { private createEmptyKey(): bytes32 { return Buffer.alloc(32); } @@ -33,36 +32,6 @@ export class XXHandshake { return { ss, s, rs, psk, re }; } - private incrementNonce(n: uint32): uint32 { - return n + 1; - } - - private dh(privateKey: bytes32, publicKey: bytes32): bytes32 { - const derived = x25519.derive(publicKey, privateKey); - const result = Buffer.alloc(32); - derived.copy(result); - return result; - } - - private nonceToBytes(n: uint32): bytes { - const nonce = Buffer.alloc(12); - nonce.writeUInt32LE(n, 4); - - return nonce; - } - - private encrypt(k: bytes32, n: uint32, ad: bytes, plaintext: bytes): bytes { - const nonce = this.nonceToBytes(n); - const ctx = new AEAD(); - - ctx.init(k, nonce); - ctx.aad(ad); - ctx.encrypt(plaintext); - - // Encryption is done on the sent reference - return plaintext; - } - private decrypt(k: bytes32, n: uint32, ad: bytes, ciphertext: bytes): bytes { const nonce = this.nonceToBytes(n); const ctx = new AEAD(); @@ -81,11 +50,6 @@ export class XXHandshake { } // Cipher state related - private initializeKey(k: bytes32): CipherState { - const n = minNonce; - return { k, n }; - } - private hasKey(cs: CipherState): boolean { return !this.isEmptyKey(cs.k); } @@ -121,12 +85,6 @@ export class XXHandshake { return { cs, ck, h }; } - private mixKey(ss: SymmetricState, ikm: bytes32): void { - const [ ck, tempK ] = this.getHkdf(ss.ck, ikm); - ss.cs = this.initializeKey(tempK) as CipherState; - ss.ck = ck; - } - private hashProtocolName(protocolName: bytes): bytes32 { if (protocolName.length <= 32) { const h = Buffer.alloc(32); @@ -137,26 +95,6 @@ export class XXHandshake { } } - public getHkdf(ck: bytes32, ikm: bytes): Hkdf { - const info = Buffer.alloc(0); - const prk = HKDF.extract(SHA256, ikm, ck); - const okm = HKDF.expand(SHA256, prk, info, 96); - - const k1 = okm.slice(0, 32); - const k2 = okm.slice(32, 64); - const k3 = okm.slice(64, 96); - - return [ k1, k2, k3 ]; - } - - private mixHash(ss: SymmetricState, data: bytes) { - ss.h = this.getHash(ss.h, data); - } - - private getHash(a: bytes, b: bytes): bytes32 { - return SHA256.digest(Buffer.from([...a, ...b])); - } - private encryptAndHash(ss: SymmetricState, plaintext: bytes): bytes { let ciphertext; if (this.hasKey(ss.cs)) { @@ -182,7 +120,7 @@ export class XXHandshake { } private split (ss: SymmetricState) { - const [ tempk1, tempk2 ] = this.getHkdf(ss.ck, Buffer.alloc(0)); + const [ tempk1, tempk2 ] = getHkdf(ss.ck, Buffer.alloc(0)); const cs1 = this.initializeKey(tempk1); const cs2 = this.initializeKey(tempk2); diff --git a/src/utils.ts b/src/utils.ts index 65b7939..e749040 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,11 +1,12 @@ -import { x25519, ed25519 } from 'bcrypto'; +import { x25519, ed25519, HKDF, SHA256 } from 'bcrypto'; import protobuf from "protobufjs"; import { Buffer } from "buffer"; import PeerId from "peer-id"; import * as crypto from 'libp2p-crypto'; import { KeyPair } from "./@types/libp2p"; -import { bytes } from "./@types/basic"; +import {bytes, bytes32} from "./@types/basic"; +import {Hkdf} from "./@types/handshake"; export async function loadPayloadProto () { const payloadProtoBuf = await protobuf.load("protos/payload.proto"); @@ -88,3 +89,15 @@ export async function verifySignedPayload(noiseStaticKey: bytes, plaintext: byte throw new Error("Static key doesn't match to peer that signed payload!"); } } + +export function getHkdf(ck: bytes32, ikm: bytes): Hkdf { + const info = Buffer.alloc(0); + const prk = HKDF.extract(SHA256, ikm, ck); + const okm = HKDF.expand(SHA256, prk, info, 96); + + const k1 = okm.slice(0, 32); + const k2 = okm.slice(32, 64); + const k3 = okm.slice(64, 96); + + return [ k1, k2, k3 ]; +}