Move common function to abstract class

This commit is contained in:
Belma Gutlic
2019-12-24 21:15:38 +01:00
parent dbbf579288
commit 0290df8685
4 changed files with 104 additions and 68 deletions

View File

@ -0,0 +1,60 @@
import {Buffer} from "buffer";
import { AEAD, x25519, HKDF, SHA256 } from 'bcrypto';
import {bytes, bytes32, uint32} from "../@types/basic";
import {CipherState, SymmetricState} from "../@types/handshake";
import {getHkdf} from "../utils";
export class AbstractHandshake {
protected minNonce = 0;
protected incrementNonce(n: uint32): uint32 {
return n + 1;
}
protected nonceToBytes(n: uint32): bytes {
const nonce = Buffer.alloc(12);
nonce.writeUInt32LE(n, 4);
return nonce;
}
protected encrypt(k: bytes32, n: uint32, ad: bytes, plaintext: bytes): bytes {
const nonce = this.nonceToBytes(n);
const ctx = new AEAD();
ctx.init(k, nonce);
ctx.aad(ad);
ctx.encrypt(plaintext);
// Encryption is done on the sent reference
return plaintext;
}
protected dh(privateKey: bytes32, publicKey: bytes32): bytes32 {
const derived = x25519.derive(publicKey, privateKey);
const result = Buffer.alloc(32);
derived.copy(result);
return result;
}
protected mixHash(ss: SymmetricState, data: bytes): void {
ss.h = this.getHash(ss.h, data);
}
protected getHash(a: bytes, b: bytes): bytes32 {
return SHA256.digest(Buffer.from([...a, ...b]));
}
protected mixKey(ss: SymmetricState, ikm: bytes32): void {
const [ ck, tempK ] = getHkdf(ss.ck, ikm);
ss.cs = this.initializeKey(tempK) as CipherState;
ss.ck = ck;
}
protected initializeKey(k: bytes32): CipherState {
const n = this.minNonce;
return { k, n };
}
}

View File

@ -0,0 +1,25 @@
import {Buffer} from "buffer";
import {CipherState, HandshakeState, MessageBuffer, SymmetricState} from "../@types/handshake";
import {bytes, bytes32} from "../@types/basic";
import {generateKeypair, getHkdf} from "../utils";
import {AbstractHandshake} from "./abstract-handshake";
export class IKHandshake extends AbstractHandshake {
private writeMessageA(hs: HandshakeState, payload: bytes): MessageBuffer {
hs.e = generateKeypair();
const ne = hs.e.publicKey;
this.mixHash(hs.ss, ne);
this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs));
const spk = Buffer.from(hs.s.publicKey);
const ns = this.encryptAndHash(hs.ss, spk);
this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re));
const ciphertext = this.encryptAndHash(hs.ss, payload);
return { ne, ns, ciphertext };
}
}

View File

@ -4,13 +4,12 @@ import { BN } from 'bn.js';
import { bytes32, uint32, uint64, bytes } from '../@types/basic'
import { KeyPair } from '../@types/libp2p'
import { generateKeypair } from '../utils';
import {generateKeypair, getHkdf} from '../utils';
import { CipherState, HandshakeState, Hkdf, MessageBuffer, NoiseSession, SymmetricState } from "../@types/handshake";
import {AbstractHandshake} from "./abstract-handshake";
const minNonce = 0;
export class XXHandshake {
export class XXHandshake extends AbstractHandshake {
private createEmptyKey(): bytes32 {
return Buffer.alloc(32);
}
@ -33,36 +32,6 @@ export class XXHandshake {
return { ss, s, rs, psk, re };
}
private incrementNonce(n: uint32): uint32 {
return n + 1;
}
private dh(privateKey: bytes32, publicKey: bytes32): bytes32 {
const derived = x25519.derive(publicKey, privateKey);
const result = Buffer.alloc(32);
derived.copy(result);
return result;
}
private nonceToBytes(n: uint32): bytes {
const nonce = Buffer.alloc(12);
nonce.writeUInt32LE(n, 4);
return nonce;
}
private encrypt(k: bytes32, n: uint32, ad: bytes, plaintext: bytes): bytes {
const nonce = this.nonceToBytes(n);
const ctx = new AEAD();
ctx.init(k, nonce);
ctx.aad(ad);
ctx.encrypt(plaintext);
// Encryption is done on the sent reference
return plaintext;
}
private decrypt(k: bytes32, n: uint32, ad: bytes, ciphertext: bytes): bytes {
const nonce = this.nonceToBytes(n);
const ctx = new AEAD();
@ -81,11 +50,6 @@ export class XXHandshake {
}
// Cipher state related
private initializeKey(k: bytes32): CipherState {
const n = minNonce;
return { k, n };
}
private hasKey(cs: CipherState): boolean {
return !this.isEmptyKey(cs.k);
}
@ -121,12 +85,6 @@ export class XXHandshake {
return { cs, ck, h };
}
private mixKey(ss: SymmetricState, ikm: bytes32): void {
const [ ck, tempK ] = this.getHkdf(ss.ck, ikm);
ss.cs = this.initializeKey(tempK) as CipherState;
ss.ck = ck;
}
private hashProtocolName(protocolName: bytes): bytes32 {
if (protocolName.length <= 32) {
const h = Buffer.alloc(32);
@ -137,26 +95,6 @@ export class XXHandshake {
}
}
public getHkdf(ck: bytes32, ikm: bytes): Hkdf {
const info = Buffer.alloc(0);
const prk = HKDF.extract(SHA256, ikm, ck);
const okm = HKDF.expand(SHA256, prk, info, 96);
const k1 = okm.slice(0, 32);
const k2 = okm.slice(32, 64);
const k3 = okm.slice(64, 96);
return [ k1, k2, k3 ];
}
private mixHash(ss: SymmetricState, data: bytes) {
ss.h = this.getHash(ss.h, data);
}
private getHash(a: bytes, b: bytes): bytes32 {
return SHA256.digest(Buffer.from([...a, ...b]));
}
private encryptAndHash(ss: SymmetricState, plaintext: bytes): bytes {
let ciphertext;
if (this.hasKey(ss.cs)) {
@ -182,7 +120,7 @@ export class XXHandshake {
}
private split (ss: SymmetricState) {
const [ tempk1, tempk2 ] = this.getHkdf(ss.ck, Buffer.alloc(0));
const [ tempk1, tempk2 ] = getHkdf(ss.ck, Buffer.alloc(0));
const cs1 = this.initializeKey(tempk1);
const cs2 = this.initializeKey(tempk2);

View File

@ -1,11 +1,12 @@
import { x25519, ed25519 } from 'bcrypto';
import { x25519, ed25519, HKDF, SHA256 } from 'bcrypto';
import protobuf from "protobufjs";
import { Buffer } from "buffer";
import PeerId from "peer-id";
import * as crypto from 'libp2p-crypto';
import { KeyPair } from "./@types/libp2p";
import { bytes } from "./@types/basic";
import {bytes, bytes32} from "./@types/basic";
import {Hkdf} from "./@types/handshake";
export async function loadPayloadProto () {
const payloadProtoBuf = await protobuf.load("protos/payload.proto");
@ -88,3 +89,15 @@ export async function verifySignedPayload(noiseStaticKey: bytes, plaintext: byte
throw new Error("Static key doesn't match to peer that signed payload!");
}
}
export function getHkdf(ck: bytes32, ikm: bytes): Hkdf {
const info = Buffer.alloc(0);
const prk = HKDF.extract(SHA256, ikm, ck);
const okm = HKDF.expand(SHA256, prk, info, 96);
const k1 = okm.slice(0, 32);
const k2 = okm.slice(32, 64);
const k3 = okm.slice(64, 96);
return [ k1, k2, k3 ];
}