validate aead decryption

This commit is contained in:
Marin Petrunić
2020-03-01 19:05:53 +01:00
parent b93a50c8b0
commit 638b0773e5
14 changed files with 115 additions and 115 deletions

View File

@ -6,5 +6,5 @@ export interface IHandshake {
session: NoiseSession;
remotePeer: PeerId;
encrypt(plaintext: bytes, session: NoiseSession): bytes;
decrypt(ciphertext: bytes, session: NoiseSession): bytes;
decrypt(ciphertext: bytes, session: NoiseSession): {plaintext: bytes; valid: boolean};
}

View File

@ -39,7 +39,10 @@ export function decryptStream(handshake: IHandshake): IReturnEncryptionWrapper {
}
const chunk = chunkBuffer.slice(i, end);
const decrypted = await handshake.decrypt(chunk, handshake.session);
const {plaintext: decrypted, valid} = await handshake.decrypt(chunk, handshake.session);
if(!valid) {
throw new Error("Failed to validate decrypted chunk");
}
yield decrypted;
}
}

View File

@ -55,7 +55,10 @@ export class IKHandshake implements IHandshake {
const receivedMsg = await this.connection.readLP();
try {
const receivedMessageBuffer = decode1(receivedMsg.slice());
const plaintext = this.ik.recvMessage(this.session, receivedMessageBuffer);
const {plaintext, valid} = this.ik.recvMessage(this.session, receivedMessageBuffer);
if(!valid) {
throw new Error("ik handshake stage 0 decryption validation fail");
}
logger("IK Stage 0 - Responder got message, going to verify payload.");
const decodedPayload = await decodePayload(plaintext);
this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload);
@ -74,10 +77,12 @@ export class IKHandshake implements IHandshake {
logger("IK Stage 1 - Initiator receiving message...");
const receivedMsg = (await this.connection.readLP()).slice();
const receivedMessageBuffer = decode0(Buffer.from(receivedMsg));
const plaintext = this.ik.recvMessage(this.session, receivedMessageBuffer);
const {plaintext, valid} = this.ik.recvMessage(this.session, receivedMessageBuffer);
logger("IK Stage 1 - Initiator got message, going to verify payload.");
try {
if(!valid) {
throw new Error("ik stage 1 decryption validation fail");
}
const decodedPayload = await decodePayload(plaintext);
this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload);
await verifySignedPayload(receivedMessageBuffer.ns.slice(0, 32), decodedPayload, this.remotePeer);
@ -94,7 +99,7 @@ export class IKHandshake implements IHandshake {
}
}
public decrypt(ciphertext: Buffer, session: NoiseSession): Buffer {
public decrypt(ciphertext: bytes, session: NoiseSession): {plaintext: bytes, valid: boolean} {
const cs = this.getCS(session, false);
return this.ik.decryptWithAd(cs, Buffer.alloc(0), ciphertext);
}

View File

@ -37,13 +37,16 @@ export class XXFallbackHandshake extends XXHandshake {
this.xx.sendMessage(this.session, Buffer.alloc(0), this.ephemeralKeys);
logger("XX Fallback Stage 0 - Initialized state as the first message was sent by initiator.");
} else {
logger("XX Fallback Stage 0 - Responder decoding initial msg from IK.")
logger("XX Fallback Stage 0 - Responder decoding initial msg from IK.");
const receivedMessageBuffer = decode0(this.initialMsg);
this.xx.recvMessage(this.session, {
const {valid} = this.xx.recvMessage(this.session, {
ne: receivedMessageBuffer.ne,
ns: Buffer.alloc(0),
ciphertext: Buffer.alloc(0),
});
if(!valid) {
throw new Error("xx fallback stage 0 decryption validation fail");
}
logger("XX Fallback Stage 0 - Responder used received message from IK.");
}
}
@ -52,7 +55,10 @@ export class XXFallbackHandshake extends XXHandshake {
public async exchange(): Promise<void> {
if (this.isInitiator) {
const receivedMessageBuffer = decode1(this.initialMsg);
const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer);
const {plaintext, valid} = this.xx.recvMessage(this.session, receivedMessageBuffer);
if(!valid) {
throw new Error("xx fallback stage 1 decryption validation fail");
}
logger('XX Fallback Stage 1 - Initiator used received message from IK.');
logger("Initiator going to check remote's signature...");

View File

@ -58,7 +58,10 @@ export class XXHandshake implements IHandshake {
} else {
logger("Stage 0 - Responder waiting to receive first message...");
const receivedMessageBuffer = decode0((await this.connection.readLP()).slice());
this.xx.recvMessage(this.session, receivedMessageBuffer);
const {valid} = this.xx.recvMessage(this.session, receivedMessageBuffer);
if(!valid) {
throw new Error("xx handshake stage 0 validation fail");
}
logger("Stage 0 - Responder received first message.");
}
}
@ -68,8 +71,11 @@ export class XXHandshake implements IHandshake {
if (this.isInitiator) {
logger('Stage 1 - Initiator waiting to receive first message from responder...');
const receivedMessageBuffer = decode1((await this.connection.readLP()).slice());
const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer);
logger('Stage 1 - Initiator received the message. Got remote\'s static key.');
const {plaintext, valid} = this.xx.recvMessage(this.session, receivedMessageBuffer);
if(!valid) {
throw new Error("xx handshake stage 1 validation fail");
}
logger('Stage 1 - Initiator received the message.');
logger("Initiator going to check remote's signature...");
try {
@ -98,8 +104,11 @@ export class XXHandshake implements IHandshake {
} else {
logger('Stage 2 - Responder waiting for third handshake message...');
const receivedMessageBuffer = decode1((await this.connection.readLP()).slice());
const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer);
logger('Stage 2 - Responder received the message, finished handshake. Got remote\'s static key.');
const {plaintext, valid} = this.xx.recvMessage(this.session, receivedMessageBuffer);
if(!valid) {
throw new Error("xx handshake stage 2 validation fail");
}
logger('Stage 2 - Responder received the message, finished handshake.');
try {
const decodedPayload = await decodePayload(plaintext);
@ -117,7 +126,7 @@ export class XXHandshake implements IHandshake {
return this.xx.encryptWithAd(cs, Buffer.alloc(0), plaintext);
}
public decrypt(ciphertext: bytes, session: NoiseSession): bytes {
public decrypt(ciphertext: bytes, session: NoiseSession): {plaintext: bytes; valid: boolean} {
const cs = this.getCS(session, false);
return this.xx.decryptWithAd(cs, Buffer.alloc(0), ciphertext);
}

View File

@ -4,6 +4,7 @@ import { AEAD, x25519, SHA256 } from 'bcrypto';
import {bytes, bytes32, uint32} from "../@types/basic";
import {CipherState, MessageBuffer, SymmetricState} from "../@types/handshake";
import {getHkdf} from "../utils";
import {logger} from "../logger";
export const MIN_NONCE = 0;
@ -15,11 +16,11 @@ export abstract class AbstractHandshake {
return e;
}
public decryptWithAd(cs: CipherState, ad: bytes, ciphertext: bytes): bytes {
const plaintext = this.decrypt(cs.k, cs.n, ad, ciphertext);
public decryptWithAd(cs: CipherState, ad: bytes, ciphertext: bytes): {plaintext: bytes; valid: boolean} {
const {plaintext, valid} = this.decrypt(cs.k, cs.n, ad, ciphertext);
this.setNonce(cs, this.incrementNonce(cs.n));
return plaintext;
return {plaintext, valid};
}
@ -76,36 +77,41 @@ export abstract class AbstractHandshake {
return ciphertext;
}
protected decrypt(k: bytes32, n: uint32, ad: bytes, ciphertext: bytes): bytes {
protected decrypt(k: bytes32, n: uint32, ad: bytes, ciphertext: bytes): {plaintext: bytes; valid: boolean} {
const nonce = this.nonceToBytes(n);
const ctx = new AEAD();
ciphertext = Buffer.from(ciphertext);
const tag = ciphertext.slice(ciphertext.length - 16);
ciphertext = ciphertext.slice(0, ciphertext.length - 16);
ctx.init(k, nonce);
ctx.aad(ad);
ctx.decrypt(ciphertext);
// Decryption is done on the sent reference
return ciphertext;
return {plaintext: ciphertext, valid: ctx.verify(tag)};
}
protected decryptAndHash(ss: SymmetricState, ciphertext: bytes): bytes {
let plaintext;
protected decryptAndHash(ss: SymmetricState, ciphertext: bytes): {plaintext: bytes; valid: boolean} {
let plaintext: bytes, valid = true;
if (this.hasKey(ss.cs)) {
plaintext = this.decryptWithAd(ss.cs, ss.h, ciphertext);
({plaintext, valid} = this.decryptWithAd(ss.cs, ss.h, ciphertext));
} else {
plaintext = ciphertext;
}
this.mixHash(ss, ciphertext);
return plaintext;
return {plaintext, valid};
}
protected dh(privateKey: bytes32, publicKey: bytes32): bytes32 {
const derived = x25519.derive(publicKey, privateKey);
const result = Buffer.alloc(32);
derived.copy(result);
return result;
try {
const derived = x25519.derive(publicKey, privateKey);
const result = Buffer.alloc(32);
derived.copy(result);
return result;
} catch (e) {
logger(e.message);
return Buffer.alloc(32);
}
}
protected mixHash(ss: SymmetricState, data: bytes): void {
@ -166,7 +172,7 @@ export abstract class AbstractHandshake {
return { ne, ns, ciphertext };
}
protected readMessageRegular(cs: CipherState, message: MessageBuffer): bytes {
protected readMessageRegular(cs: CipherState, message: MessageBuffer): {plaintext: bytes; valid: boolean} {
return this.decryptWithAd(cs, Buffer.alloc(0), message.ciphertext);
}
}

View File

@ -1,7 +1,7 @@
import {Buffer} from "buffer";
import {BN} from "bn.js";
import {HandshakeState, MessageBuffer, NoiseSession} from "../@types/handshake";
import {CipherState, HandshakeState, MessageBuffer, NoiseSession} from "../@types/handshake";
import {bytes, bytes32} from "../@types/basic";
import {generateKeypair, isValidPublicKey} from "../utils";
import {AbstractHandshake} from "./abstract-handshake";
@ -58,34 +58,21 @@ export class IK extends AbstractHandshake {
return messageBuffer;
}
public recvMessage(session: NoiseSession, message: MessageBuffer): bytes {
let plaintext: bytes;
public recvMessage(session: NoiseSession, message: MessageBuffer): {plaintext: bytes; valid: boolean} {
let plaintext = Buffer.alloc(0), valid = false;
if (session.mc.eqn(0)) {
plaintext = this.readMessageA(session.hs, message);
} else if (session.mc.eqn(1)) {
const { plaintext: pt, h, cs1, cs2 } = this.readMessageB(session.hs, message);
({plaintext, valid} = this.readMessageA(session.hs, message));
}
if (session.mc.eqn(1)) {
const { plaintext: pt, valid: v, h, cs1, cs2 } = this.readMessageB(session.hs, message);
plaintext = pt;
valid = v;
session.h = h;
session.cs1 = cs1;
session.cs2 = cs2;
} else if (session.mc.gtn(1)) {
if (session.i) {
if (!session.cs2) {
throw new Error("CS1 (cipher state) is not defined")
}
plaintext = this.readMessageRegular(session.cs2, message);
} else {
if (!session.cs1) {
throw new Error("CS1 (cipher state) is not defined")
}
plaintext = this.readMessageRegular(session.cs1, message);
}
} else {
throw new Error("Session invalid.");
}
session.mc = session.mc.add(new BN(1));
return plaintext;
return {plaintext, valid};
}
private writeMessageA(hs: HandshakeState, payload: bytes): MessageBuffer {
@ -117,22 +104,23 @@ export class IK extends AbstractHandshake {
return { messageBuffer, cs1, cs2, h: hs.ss.h }
}
private readMessageA(hs: HandshakeState, message: MessageBuffer): bytes {
private readMessageA(hs: HandshakeState, message: MessageBuffer): {plaintext: bytes; valid: boolean} {
if (isValidPublicKey(message.ne)) {
hs.re = message.ne;
}
this.mixHash(hs.ss, hs.re);
this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re));
const ns = this.decryptAndHash(hs.ss, message.ns);
if (ns.length === 32 && isValidPublicKey(ns)) {
const {plaintext: ns, valid: valid1} = this.decryptAndHash(hs.ss, message.ns);
if (valid1 && ns.length === 32 && isValidPublicKey(ns)) {
hs.rs = ns;
}
this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.rs));
return this.decryptAndHash(hs.ss, message.ciphertext);
const {plaintext, valid: valid2} = this.decryptAndHash(hs.ss, message.ciphertext);
return {plaintext, valid: (valid1 && valid2)};
}
private readMessageB(hs: HandshakeState, message: MessageBuffer) {
private readMessageB(hs: HandshakeState, message: MessageBuffer): {h: bytes; plaintext: bytes; valid: boolean; cs1: CipherState; cs2: CipherState} {
if (isValidPublicKey(message.ne)) {
hs.re = message.ne;
}
@ -143,10 +131,10 @@ export class IK extends AbstractHandshake {
}
this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re));
this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re));
const plaintext = this.decryptAndHash(hs.ss, message.ciphertext);
const {plaintext, valid} = this.decryptAndHash(hs.ss, message.ciphertext);
const { cs1, cs2 } = this.split(hs.ss);
return { h: hs.ss.h, plaintext, cs1, cs2 };
return { h: hs.ss.h, valid, plaintext, cs1, cs2 };
}
private initializeInitiator(prologue: bytes32, s: KeyPair, rs: bytes32, psk: bytes32): HandshakeState {

View File

@ -4,7 +4,7 @@ import { BN } from 'bn.js';
import { bytes32, bytes } from '../@types/basic'
import { KeyPair } from '../@types/libp2p'
import {generateKeypair, isValidPublicKey} from '../utils';
import { HandshakeState, MessageBuffer, NoiseSession } from "../@types/handshake";
import {CipherState, HandshakeState, MessageBuffer, NoiseSession} from "../@types/handshake";
import {AbstractHandshake} from "./abstract-handshake";
@ -71,7 +71,7 @@ export class XX extends AbstractHandshake {
return { h: hs.ss.h, messageBuffer, cs1, cs2 };
}
private readMessageA(hs: HandshakeState, message: MessageBuffer): bytes {
private readMessageA(hs: HandshakeState, message: MessageBuffer): {plaintext: bytes; valid: boolean} {
if (isValidPublicKey(message.ne)) {
hs.re = message.ne;
}
@ -80,7 +80,7 @@ export class XX extends AbstractHandshake {
return this.decryptAndHash(hs.ss, message.ciphertext);
}
private readMessageB(hs: HandshakeState, message: MessageBuffer): bytes {
private readMessageB(hs: HandshakeState, message: MessageBuffer): {plaintext: bytes; valid: boolean} {
if (isValidPublicKey(message.ne)) {
hs.re = message.ne;
}
@ -90,29 +90,29 @@ export class XX extends AbstractHandshake {
throw new Error("Handshake state `e` param is missing.");
}
this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re));
const ns = this.decryptAndHash(hs.ss, message.ns);
if (ns.length === 32 && isValidPublicKey(ns)) {
const {plaintext: ns, valid: valid1} = this.decryptAndHash(hs.ss, message.ns);
if (valid1 && ns.length === 32 && isValidPublicKey(ns)) {
hs.rs = ns;
}
this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs));
return this.decryptAndHash(hs.ss, message.ciphertext);
const {plaintext, valid: valid2} = this.decryptAndHash(hs.ss, message.ciphertext);
return {plaintext, valid: (valid1 && valid2)};
}
private readMessageC(hs: HandshakeState, message: MessageBuffer) {
const ns = this.decryptAndHash(hs.ss, message.ns);
if (ns.length === 32 && isValidPublicKey(ns)) {
private readMessageC(hs: HandshakeState, message: MessageBuffer): {h: bytes; plaintext: bytes; valid: boolean; cs1: CipherState; cs2: CipherState} {
const {plaintext: ns, valid: valid1} = this.decryptAndHash(hs.ss, message.ns);
if (valid1 && ns.length === 32 && isValidPublicKey(ns)) {
hs.rs = ns;
}
if (!hs.e) {
throw new Error("Handshake state `e` param is missing.");
}
this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs));
const plaintext = this.decryptAndHash(hs.ss, message.ciphertext);
const {plaintext, valid: valid2} = this.decryptAndHash(hs.ss, message.ciphertext);
const { cs1, cs2 } = this.split(hs.ss);
return { h: hs.ss.h, plaintext, cs1, cs2 };
return { h: hs.ss.h, plaintext, valid: (valid1 && valid2), cs1, cs2 };
}
public initSession(initiator: boolean, prologue: bytes32, s: KeyPair): NoiseSession {
@ -167,35 +167,22 @@ export class XX extends AbstractHandshake {
return messageBuffer;
}
public recvMessage(session: NoiseSession, message: MessageBuffer): bytes {
let plaintext: bytes;
public recvMessage(session: NoiseSession, message: MessageBuffer): {plaintext: bytes; valid: boolean} {
let plaintext: bytes = Buffer.alloc(0);
let valid = false;
if (session.mc.eqn(0)) {
plaintext = this.readMessageA(session.hs, message);
({plaintext, valid} = this.readMessageA(session.hs, message));
} else if (session.mc.eqn(1)) {
plaintext = this.readMessageB(session.hs, message);
({plaintext, valid} = this.readMessageB(session.hs, message));
} else if (session.mc.eqn(2)) {
const { h, plaintext: resultingPlaintext, cs1, cs2 } = this.readMessageC(session.hs, message);
const { h, plaintext: resultingPlaintext, valid: resultingValid, cs1, cs2 } = this.readMessageC(session.hs, message);
plaintext = resultingPlaintext;
valid = resultingValid;
session.h = h;
session.cs1 = cs1;
session.cs2 = cs2;
} else if (session.mc.gtn(2)) {
if (session.i) {
if (!session.cs2) {
throw new Error("CS1 (cipher state) is not defined")
}
plaintext = this.readMessageRegular(session.cs2, message);
} else {
if (!session.cs1) {
throw new Error("CS1 (cipher state) is not defined")
}
plaintext = this.readMessageRegular(session.cs1, message);
}
} else {
throw new Error("Session invalid.");
}
session.mc = session.mc.add(new BN(1));
return plaintext;
return {plaintext, valid};
}
}