Merge pull request #37 from NodeFactoryIo/mpetrunic/fix-aead

handle aead auth tag
This commit is contained in:
Marin Petrunić 2020-02-20 22:19:37 +01:00 committed by GitHub
commit 6e0d1a84ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 40 additions and 42 deletions

View File

@ -1,2 +1,3 @@
export const NOISE_MSG_MAX_LENGTH_BYTES = 65535;
export const NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG = NOISE_MSG_MAX_LENGTH_BYTES - 16;

View File

@ -1,20 +1,19 @@
import { Buffer } from "buffer";
import {IHandshake} from "./@types/handshake-interface";
import {NOISE_MSG_MAX_LENGTH_BYTES, NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG} from "./constants";
interface IReturnEncryptionWrapper {
(source: Iterable<Uint8Array>): AsyncIterableIterator<Uint8Array>;
}
const maxPlaintextLength = 65519;
// Returns generator that encrypts payload from the user
export function encryptStream(handshake: IHandshake): IReturnEncryptionWrapper {
return async function * (source) {
for await (const chunk of source) {
const chunkBuffer = Buffer.from(chunk.buffer, chunk.byteOffset, chunk.length);
for (let i = 0; i < chunkBuffer.length; i += maxPlaintextLength) {
let end = i + maxPlaintextLength;
for (let i = 0; i < chunkBuffer.length; i += NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG) {
let end = i + NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG;
if (end > chunkBuffer.length) {
end = chunkBuffer.length;
}
@ -33,8 +32,8 @@ export function decryptStream(handshake: IHandshake): IReturnEncryptionWrapper {
for await (const chunk of source) {
const chunkBuffer = Buffer.from(chunk.buffer, chunk.byteOffset, chunk.length);
for (let i = 0; i < chunkBuffer.length; i += maxPlaintextLength) {
let end = i + maxPlaintextLength;
for (let i = 0; i < chunkBuffer.length; i += NOISE_MSG_MAX_LENGTH_BYTES) {
let end = i + NOISE_MSG_MAX_LENGTH_BYTES;
if (end > chunkBuffer.length) {
end = chunkBuffer.length;
}

View File

@ -38,13 +38,13 @@ export function decode0(input: bytes): MessageBuffer {
}
export function decode1(input: bytes): MessageBuffer {
if (input.length < 96) {
if (input.length < 80) {
throw new Error("Cannot decode stage 0 MessageBuffer: length less than 96 bytes.");
}
return {
ne: input.slice(0, 32),
ns: input.slice(32, 64),
ciphertext: input.slice(64, input.length),
ns: input.slice(32, 80),
ciphertext: input.slice(80, input.length),
}
}

View File

@ -59,7 +59,7 @@ export class IKHandshake implements IHandshake {
logger("IK Stage 0 - Responder got message, going to verify payload.");
const decodedPayload = await decodePayload(plaintext);
this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload);
await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer);
await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer);
logger("IK Stage 0 - Responder successfully verified payload!");
} catch (e) {
logger("Responder breaking up with IK handshake in stage 0.");
@ -80,7 +80,7 @@ export class IKHandshake implements IHandshake {
try {
const decodedPayload = await decodePayload(plaintext);
this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload);
await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer);
await verifySignedPayload(receivedMessageBuffer.ns.slice(0, 32), decodedPayload, this.remotePeer);
logger("IK Stage 1 - Initiator successfully verified payload!");
} catch (e) {
logger("Initiator breaking up with IK handshake in stage 1.");

View File

@ -59,7 +59,7 @@ export class XXFallbackHandshake extends XXHandshake {
try {
const decodedPayload = await decodePayload(plaintext);
this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload);
await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer);
await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer);
} catch (e) {
throw new Error(`Error occurred while verifying signed payload from responder: ${e.message}`);
}

View File

@ -104,7 +104,7 @@ export class XXHandshake implements IHandshake {
try {
const decodedPayload = await decodePayload(plaintext);
this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload);
await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer);
await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer);
} catch (e) {
throw new Error(`Error occurred while verifying signed payload: ${e.message}`);
}

View File

@ -55,13 +55,13 @@ export abstract class AbstractHandshake {
protected encrypt(k: bytes32, n: uint32, ad: bytes, plaintext: bytes): bytes {
const nonce = this.nonceToBytes(n);
const ctx = new AEAD();
plaintext = Buffer.from(plaintext);
ctx.init(k, nonce);
ctx.aad(ad);
ctx.encrypt(plaintext);
// Encryption is done on the sent reference
return plaintext;
return Buffer.concat([plaintext, ctx.final()]);
}
protected encryptAndHash(ss: SymmetricState, plaintext: bytes): bytes {
@ -79,7 +79,8 @@ export abstract class AbstractHandshake {
protected decrypt(k: bytes32, n: uint32, ad: bytes, ciphertext: bytes): bytes {
const nonce = this.nonceToBytes(n);
const ctx = new AEAD();
ciphertext = Buffer.from(ciphertext);
ciphertext = ciphertext.slice(0, ciphertext.length - 16);
ctx.init(k, nonce);
ctx.aad(ad);
ctx.decrypt(ciphertext);

View File

@ -125,7 +125,7 @@ export class IK extends AbstractHandshake {
this.mixHash(hs.ss, hs.re);
this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re));
const ns = this.decryptAndHash(hs.ss, message.ns);
if (ns.length === 32 && isValidPublicKey(message.ns)) {
if (ns.length === 32 && isValidPublicKey(ns)) {
hs.rs = ns;
}
this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.rs));

View File

@ -91,7 +91,7 @@ export class XX extends AbstractHandshake {
}
this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re));
const ns = this.decryptAndHash(hs.ss, message.ns);
if (ns.length === 32 && isValidPublicKey(message.ns)) {
if (ns.length === 32 && isValidPublicKey(ns)) {
hs.rs = ns;
}
this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs));
@ -100,7 +100,7 @@ export class XX extends AbstractHandshake {
private readMessageC(hs: HandshakeState, message: MessageBuffer) {
const ns = this.decryptAndHash(hs.ss, message.ns);
if (ns.length === 32 && isValidPublicKey(message.ns)) {
if (ns.length === 32 && isValidPublicKey(ns)) {
hs.rs = ns;
}

View File

@ -230,7 +230,7 @@ export class Noise implements INoiseConnection {
encryptStream(handshake), // data is encrypted
encode({ lengthEncoder: uint16BEEncode }), // prefix with message length
network, // send to the remote peer
decode({ lengthDecoder: uint16BEDecode, maxDataLength: NOISE_MSG_MAX_LENGTH_BYTES }), // read message length prefix
decode({ lengthDecoder: uint16BEDecode}), // read message length prefix
ensureBuffer, // ensure any type of data is converted to buffer
decryptStream(handshake), // decrypt the incoming data
secure // pipe to the wrapper

View File

@ -114,5 +114,5 @@ export function getHkdf(ck: bytes32, ikm: bytes): Hkdf {
}
export function isValidPublicKey(pk: bytes): boolean {
return x25519.publicKeyVerify(pk);
return x25519.publicKeyVerify(pk.slice(0, 32));
}

View File

@ -114,9 +114,9 @@ describe("XX Handshake", () => {
const ad = Buffer.from("authenticated");
const message = Buffer.from("HelloCrypto");
xx.encryptWithAd(nsInit.cs1, ad, message);
assert(!Buffer.from("HelloCrypto").equals(message), "Encrypted message should not be same as plaintext.");
const decrypted = xx.decryptWithAd(nsResp.cs1, ad, message);
const ciphertext = xx.encryptWithAd(nsInit.cs1, ad, message);
assert(!Buffer.from("HelloCrypto").equals(ciphertext), "Encrypted message should not be same as plaintext.");
const decrypted = xx.decryptWithAd(nsResp.cs1, ad, ciphertext);
assert(Buffer.from("HelloCrypto").equals(decrypted), "Decrypted text not equal to original message.");
} catch (e) {
@ -125,22 +125,18 @@ describe("XX Handshake", () => {
});
it("Test multiple messages encryption and decryption", async () => {
try {
const xx = new XX();
const { nsInit, nsResp } = await doHandshake(xx);
const ad = Buffer.from("authenticated");
const message = Buffer.from("ethereum1");
xx.encryptWithAd(nsInit.cs1, ad, message);
const decrypted = xx.decryptWithAd(nsResp.cs1, ad, message);
assert(Buffer.from("ethereum1").equals(decrypted), "Decrypted text not equal to original message.");
const encrypted = xx.encryptWithAd(nsInit.cs1, ad, message);
const decrypted = xx.decryptWithAd(nsResp.cs1, ad, encrypted);
assert.equal("ethereum1", decrypted.toString("utf8"), "Decrypted text not equal to original message.");
const message2 = Buffer.from("ethereum2");
xx.encryptWithAd(nsInit.cs1, ad, message2);
const decrypted2 = xx.decryptWithAd(nsResp.cs1, ad, message2);
assert(Buffer.from("ethereum2").equals(decrypted2), "Decrypted text not equal to original message.");
} catch (e) {
assert(false, e.message);
}
const encrypted2 = xx.encryptWithAd(nsInit.cs1, ad, message2);
const decrypted2 = xx.decryptWithAd(nsResp.cs1, ad, encrypted2);
assert.equal("ethereum2", decrypted2.toString("utf-8"), "Decrypted text not equal to original message.");
});
});

View File

@ -113,7 +113,8 @@ describe("Noise", () => {
})
it("should test large payloads", async() => {
it("should test large payloads", async function() {
this.timeout(10000);
try {
const noiseInit = new Noise(undefined, undefined, false);
const noiseResp = new Noise(undefined, undefined, false);
@ -128,11 +129,11 @@ describe("Noise", () => {
const largePlaintext = random.randomBytes(100000);
wrappedOutbound.writeLP(largePlaintext);
const response = await wrappedInbound.readLP();
const response = await wrappedInbound.read(100000);
expect(response.length).equals(largePlaintext.length);
} catch (e) {
console.error(e);
console.log(e);
assert(false, e.message);
}
});