diff --git a/src/constants.ts b/src/constants.ts index aedb1db..d17e594 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -1,2 +1,3 @@ export const NOISE_MSG_MAX_LENGTH_BYTES = 65535; +export const NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG = NOISE_MSG_MAX_LENGTH_BYTES - 16; diff --git a/src/crypto.ts b/src/crypto.ts index 83508ec..059c8f8 100644 --- a/src/crypto.ts +++ b/src/crypto.ts @@ -1,20 +1,19 @@ import { Buffer } from "buffer"; import {IHandshake} from "./@types/handshake-interface"; +import {NOISE_MSG_MAX_LENGTH_BYTES, NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG} from "./constants"; interface IReturnEncryptionWrapper { (source: Iterable): AsyncIterableIterator; } -const maxPlaintextLength = 65519; - // Returns generator that encrypts payload from the user export function encryptStream(handshake: IHandshake): IReturnEncryptionWrapper { return async function * (source) { for await (const chunk of source) { const chunkBuffer = Buffer.from(chunk.buffer, chunk.byteOffset, chunk.length); - for (let i = 0; i < chunkBuffer.length; i += maxPlaintextLength) { - let end = i + maxPlaintextLength; + for (let i = 0; i < chunkBuffer.length; i += NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG) { + let end = i + NOISE_MSG_MAX_LENGTH_BYTES_WITHOUT_TAG; if (end > chunkBuffer.length) { end = chunkBuffer.length; } @@ -33,8 +32,8 @@ export function decryptStream(handshake: IHandshake): IReturnEncryptionWrapper { for await (const chunk of source) { const chunkBuffer = Buffer.from(chunk.buffer, chunk.byteOffset, chunk.length); - for (let i = 0; i < chunkBuffer.length; i += maxPlaintextLength) { - let end = i + maxPlaintextLength; + for (let i = 0; i < chunkBuffer.length; i += NOISE_MSG_MAX_LENGTH_BYTES) { + let end = i + NOISE_MSG_MAX_LENGTH_BYTES; if (end > chunkBuffer.length) { end = chunkBuffer.length; } diff --git a/src/encoder.ts b/src/encoder.ts index 477dbb8..60eb1c6 100644 --- a/src/encoder.ts +++ b/src/encoder.ts @@ -38,13 +38,13 @@ export function decode0(input: bytes): MessageBuffer { } export function decode1(input: bytes): MessageBuffer { - if (input.length < 96) { + if (input.length < 80) { throw new Error("Cannot decode stage 0 MessageBuffer: length less than 96 bytes."); } return { ne: input.slice(0, 32), - ns: input.slice(32, 64), - ciphertext: input.slice(64, input.length), + ns: input.slice(32, 80), + ciphertext: input.slice(80, input.length), } } diff --git a/src/handshake-ik.ts b/src/handshake-ik.ts index 73b6c9b..1266346 100644 --- a/src/handshake-ik.ts +++ b/src/handshake-ik.ts @@ -59,7 +59,7 @@ export class IKHandshake implements IHandshake { logger("IK Stage 0 - Responder got message, going to verify payload."); const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); - await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer); + await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer); logger("IK Stage 0 - Responder successfully verified payload!"); } catch (e) { logger("Responder breaking up with IK handshake in stage 0."); @@ -80,7 +80,7 @@ export class IKHandshake implements IHandshake { try { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); - await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer); + await verifySignedPayload(receivedMessageBuffer.ns.slice(0, 32), decodedPayload, this.remotePeer); logger("IK Stage 1 - Initiator successfully verified payload!"); } catch (e) { logger("Initiator breaking up with IK handshake in stage 1."); diff --git a/src/handshake-xx-fallback.ts b/src/handshake-xx-fallback.ts index d98d27c..fb1d837 100644 --- a/src/handshake-xx-fallback.ts +++ b/src/handshake-xx-fallback.ts @@ -59,7 +59,7 @@ export class XXFallbackHandshake extends XXHandshake { try { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); - await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer); + await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer); } catch (e) { throw new Error(`Error occurred while verifying signed payload from responder: ${e.message}`); } diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index fed9e4c..153950b 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -104,7 +104,7 @@ export class XXHandshake implements IHandshake { try { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); - await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer); + await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer); } catch (e) { throw new Error(`Error occurred while verifying signed payload: ${e.message}`); } diff --git a/src/handshakes/abstract-handshake.ts b/src/handshakes/abstract-handshake.ts index 486dd96..25a630b 100644 --- a/src/handshakes/abstract-handshake.ts +++ b/src/handshakes/abstract-handshake.ts @@ -55,13 +55,13 @@ export abstract class AbstractHandshake { protected encrypt(k: bytes32, n: uint32, ad: bytes, plaintext: bytes): bytes { const nonce = this.nonceToBytes(n); const ctx = new AEAD(); - + plaintext = Buffer.from(plaintext); ctx.init(k, nonce); ctx.aad(ad); ctx.encrypt(plaintext); // Encryption is done on the sent reference - return plaintext; + return Buffer.concat([plaintext, ctx.final()]); } protected encryptAndHash(ss: SymmetricState, plaintext: bytes): bytes { @@ -79,7 +79,8 @@ export abstract class AbstractHandshake { protected decrypt(k: bytes32, n: uint32, ad: bytes, ciphertext: bytes): bytes { const nonce = this.nonceToBytes(n); const ctx = new AEAD(); - + ciphertext = Buffer.from(ciphertext); + ciphertext = ciphertext.slice(0, ciphertext.length - 16); ctx.init(k, nonce); ctx.aad(ad); ctx.decrypt(ciphertext); diff --git a/src/handshakes/ik.ts b/src/handshakes/ik.ts index 793aea3..9ad8c85 100644 --- a/src/handshakes/ik.ts +++ b/src/handshakes/ik.ts @@ -125,7 +125,7 @@ export class IK extends AbstractHandshake { this.mixHash(hs.ss, hs.re); this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)); const ns = this.decryptAndHash(hs.ss, message.ns); - if (ns.length === 32 && isValidPublicKey(message.ns)) { + if (ns.length === 32 && isValidPublicKey(ns)) { hs.rs = ns; } this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.rs)); diff --git a/src/handshakes/xx.ts b/src/handshakes/xx.ts index 6362bab..7c85b2a 100644 --- a/src/handshakes/xx.ts +++ b/src/handshakes/xx.ts @@ -91,7 +91,7 @@ export class XX extends AbstractHandshake { } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)); const ns = this.decryptAndHash(hs.ss, message.ns); - if (ns.length === 32 && isValidPublicKey(message.ns)) { + if (ns.length === 32 && isValidPublicKey(ns)) { hs.rs = ns; } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)); @@ -100,7 +100,7 @@ export class XX extends AbstractHandshake { private readMessageC(hs: HandshakeState, message: MessageBuffer) { const ns = this.decryptAndHash(hs.ss, message.ns); - if (ns.length === 32 && isValidPublicKey(message.ns)) { + if (ns.length === 32 && isValidPublicKey(ns)) { hs.rs = ns; } diff --git a/src/noise.ts b/src/noise.ts index 5a9c13c..5d25788 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -230,7 +230,7 @@ export class Noise implements INoiseConnection { encryptStream(handshake), // data is encrypted encode({ lengthEncoder: uint16BEEncode }), // prefix with message length network, // send to the remote peer - decode({ lengthDecoder: uint16BEDecode, maxDataLength: NOISE_MSG_MAX_LENGTH_BYTES }), // read message length prefix + decode({ lengthDecoder: uint16BEDecode}), // read message length prefix ensureBuffer, // ensure any type of data is converted to buffer decryptStream(handshake), // decrypt the incoming data secure // pipe to the wrapper diff --git a/src/utils.ts b/src/utils.ts index 9b8eb36..6bd7bdd 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -114,5 +114,5 @@ export function getHkdf(ck: bytes32, ikm: bytes): Hkdf { } export function isValidPublicKey(pk: bytes): boolean { - return x25519.publicKeyVerify(pk); + return x25519.publicKeyVerify(pk.slice(0, 32)); } diff --git a/test/handshakes/xx.test.ts b/test/handshakes/xx.test.ts index cee897a..bb05e77 100644 --- a/test/handshakes/xx.test.ts +++ b/test/handshakes/xx.test.ts @@ -114,9 +114,9 @@ describe("XX Handshake", () => { const ad = Buffer.from("authenticated"); const message = Buffer.from("HelloCrypto"); - xx.encryptWithAd(nsInit.cs1, ad, message); - assert(!Buffer.from("HelloCrypto").equals(message), "Encrypted message should not be same as plaintext."); - const decrypted = xx.decryptWithAd(nsResp.cs1, ad, message); + const ciphertext = xx.encryptWithAd(nsInit.cs1, ad, message); + assert(!Buffer.from("HelloCrypto").equals(ciphertext), "Encrypted message should not be same as plaintext."); + const decrypted = xx.decryptWithAd(nsResp.cs1, ad, ciphertext); assert(Buffer.from("HelloCrypto").equals(decrypted), "Decrypted text not equal to original message."); } catch (e) { @@ -125,22 +125,18 @@ describe("XX Handshake", () => { }); it("Test multiple messages encryption and decryption", async () => { - try { - const xx = new XX(); - const { nsInit, nsResp } = await doHandshake(xx); - const ad = Buffer.from("authenticated"); - const message = Buffer.from("ethereum1"); + const xx = new XX(); + const { nsInit, nsResp } = await doHandshake(xx); + const ad = Buffer.from("authenticated"); + const message = Buffer.from("ethereum1"); - xx.encryptWithAd(nsInit.cs1, ad, message); - const decrypted = xx.decryptWithAd(nsResp.cs1, ad, message); - assert(Buffer.from("ethereum1").equals(decrypted), "Decrypted text not equal to original message."); + const encrypted = xx.encryptWithAd(nsInit.cs1, ad, message); + const decrypted = xx.decryptWithAd(nsResp.cs1, ad, encrypted); + assert.equal("ethereum1", decrypted.toString("utf8"), "Decrypted text not equal to original message."); - const message2 = Buffer.from("ethereum2"); - xx.encryptWithAd(nsInit.cs1, ad, message2); - const decrypted2 = xx.decryptWithAd(nsResp.cs1, ad, message2); - assert(Buffer.from("ethereum2").equals(decrypted2), "Decrypted text not equal to original message."); - } catch (e) { - assert(false, e.message); - } + const message2 = Buffer.from("ethereum2"); + const encrypted2 = xx.encryptWithAd(nsInit.cs1, ad, message2); + const decrypted2 = xx.decryptWithAd(nsResp.cs1, ad, encrypted2); + assert.equal("ethereum2", decrypted2.toString("utf-8"), "Decrypted text not equal to original message."); }); }); diff --git a/test/noise.test.ts b/test/noise.test.ts index fa63b17..cc40a6c 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -113,7 +113,8 @@ describe("Noise", () => { }) - it("should test large payloads", async() => { + it("should test large payloads", async function() { + this.timeout(10000); try { const noiseInit = new Noise(undefined, undefined, false); const noiseResp = new Noise(undefined, undefined, false); @@ -128,11 +129,11 @@ describe("Noise", () => { const largePlaintext = random.randomBytes(100000); wrappedOutbound.writeLP(largePlaintext); - const response = await wrappedInbound.readLP(); + const response = await wrappedInbound.read(100000); expect(response.length).equals(largePlaintext.length); } catch (e) { - console.error(e); + console.log(e); assert(false, e.message); } });