diff --git a/src/xx.ts b/src/xx.ts index 2709f25..6bfcd04 100644 --- a/src/xx.ts +++ b/src/xx.ts @@ -196,6 +196,18 @@ export class XXHandshake { return ciphertext; } + private async decryptAndHash(ss: SymmetricState, ciphertext: bytes) : Promise { + let plaintext; + if (this.hasKey(ss.cs)) { + plaintext = this.decryptWithAd(ss.cs, ss.h, ciphertext); + } else { + plaintext = ciphertext; + } + + await this.mixHash(ss, ciphertext); + return plaintext; + } + private split (ss: SymmetricState) { const [ tempk1, tempk2 ] = this.getHkdf(ss.ck, Buffer.alloc(0)); const cs1 = this.initializeKey(tempk1); @@ -246,6 +258,46 @@ export class XXHandshake { return { ne, ns, ciphertext }; } + private async readMessageA(hs: HandshakeState, message: MessageBuffer) : Promise { + // TODO: validate public key here + + await this.mixHash(hs.ss, hs.re); + return await this.decryptAndHash(hs.ss, message.ciphertext); + } + + private async readMessageB(hs: HandshakeState, message: MessageBuffer) : Promise { + // TODO: validate public key here + + await this.mixHash(hs.ss, hs.re); + if (!hs.e) { + throw new Error("Handshake state `e` param is missing."); + } + this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)); + const ns = await this.decryptAndHash(hs.ss, message.ns); + // TODO: validate ns here as public key + hs.rs = ns; + this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)); + return await this.decryptAndHash(hs.ss, message.ciphertext); + } + + private async readMessageC(hs: HandshakeState, message: MessageBuffer) { + const ns = await this.decryptAndHash(hs.ss, message.ns); + // TODO: validate ns here as public key + hs.rs = ns; + if (!hs.e) { + throw new Error("Handshake state `e` param is missing."); + } + this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)); + const plaintext = await this.decryptAndHash(hs.ss, message.ciphertext); + const { cs1, cs2 } = this.split(hs.ss); + + return { h: hs.ss.h, plaintext, cs1, cs2 }; + } + + private readMessageRegular(cs: CipherState, message: MessageBuffer) : bytes { + return this.decryptWithAd(cs, Buffer.alloc(0), message.ciphertext); + } + public async generateKeypair() : Promise { return await crypto.keys.generateKeyPair('ed25519'); } @@ -274,7 +326,8 @@ export class XXHandshake { } else if (session.mc === 1) { messageBuffer = await this.writeMessageB(session.hs, message); } else if (session.mc === 2) { - const { h, messageBuffer, cs1, cs2 } = await this.writeMessageC(session.hs, message); + const { h, messageBuffer: resultingBuffer, cs1, cs2 } = await this.writeMessageC(session.hs, message); + messageBuffer = resultingBuffer; session.h = h; session.cs1 = cs1; session.cs2 = cs2; @@ -299,4 +352,36 @@ export class XXHandshake { session.mc++; return messageBuffer; } + + public async RecvMessage(session: NoiseSession, message: MessageBuffer) : Promise { + let plaintext: bytes; + if (session.mc === 0) { + plaintext = await this.readMessageA(session.hs, message); + } else if (session.mc === 1) { + plaintext = await this.readMessageB(session.hs, message); + } else if (session.mc === 2) { + const { h, plaintext: resultingPlaintext, cs1, cs2 } = await this.readMessageC(session.hs, message); + plaintext = resultingPlaintext; + session.h = h; + session.cs1 = cs1; + session.cs2 = cs2; + } else if (session.mc > 2) { + if (session.i) { + if (!session.cs2) { + throw new Error("CS1 (cipher state) is not defined") + } + plaintext = await this.readMessageRegular(session.cs2, message); + } else { + if (!session.cs1) { + throw new Error("CS1 (cipher state) is not defined") + } + plaintext = await this.readMessageRegular(session.cs1, message); + } + } else { + throw new Error("Session invalid."); + } + + session.mc++; + return plaintext; + } }