diff --git a/src/handshake.ts b/src/handshake.ts index bd27605..255327e 100644 --- a/src/handshake.ts +++ b/src/handshake.ts @@ -53,13 +53,13 @@ export class Handshake { async propose(): Promise { if (this.isInitiator) { logger("Stage 0 - Initiator starting to send first message."); - const messageBuffer = await this.xx.sendMessage(this.session, Buffer.alloc(0)); + const messageBuffer = this.xx.sendMessage(this.session, Buffer.alloc(0)); this.connection.writeLP(encodeMessageBuffer(messageBuffer)); logger("Stage 0 - Initiator finished sending first message."); } else { logger("Stage 0 - Responder waiting to receive first message..."); const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); - await this.xx.recvMessage(this.session, receivedMessageBuffer); + this.xx.recvMessage(this.session, receivedMessageBuffer); logger("Stage 0 - Responder received first message."); } } @@ -69,7 +69,7 @@ export class Handshake { if (this.isInitiator) { logger('Stage 1 - Initiator waiting to receive first message from responder...'); const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); - const plaintext = await this.xx.recvMessage(this.session, receivedMessageBuffer); + const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); logger('Stage 1 - Initiator received the message. Got remote\'s static key.'); logger("Initiator going to check remote's signature..."); @@ -90,7 +90,7 @@ export class Handshake { signedEarlyDataPayload, ); - const messageBuffer = await this.xx.sendMessage(this.session, handshakePayload); + const messageBuffer = this.xx.sendMessage(this.session, handshakePayload); this.connection.writeLP(encodeMessageBuffer(messageBuffer)); logger('Stage 1 - Responder sent the second handshake message with signed payload.') } @@ -108,13 +108,13 @@ export class Handshake { signedPayload, signedEarlyDataPayload ); - const messageBuffer = await this.xx.sendMessage(this.session, handshakePayload); + const messageBuffer = this.xx.sendMessage(this.session, handshakePayload); this.connection.writeLP(encodeMessageBuffer(messageBuffer)); logger('Stage 2 - Initiator sent message with signed payload.'); } else { logger('Stage 2 - Responder waiting for third handshake message...'); const receivedMessageBuffer = decodeMessageBuffer((await this.connection.readLP()).slice()); - const plaintext = await this.xx.recvMessage(this.session, receivedMessageBuffer); + const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); logger('Stage 2 - Responder received the message, finished handshake. Got remote\'s static key.'); try { diff --git a/src/xx.ts b/src/xx.ts index 95aa619..d4b479a 100644 --- a/src/xx.ts +++ b/src/xx.ts @@ -191,7 +191,7 @@ export class XXHandshake { return SHA256.digest(Buffer.from([...a, ...b])); } - private async encryptAndHash(ss: SymmetricState, plaintext: bytes): Promise { + private encryptAndHash(ss: SymmetricState, plaintext: bytes): bytes { let ciphertext; if (this.hasKey(ss.cs)) { ciphertext = this.encryptWithAd(ss.cs, ss.h, plaintext); @@ -203,7 +203,7 @@ export class XXHandshake { return ciphertext; } - private async decryptAndHash(ss: SymmetricState, ciphertext: bytes): Promise { + private decryptAndHash(ss: SymmetricState, ciphertext: bytes): bytes { let plaintext; if (this.hasKey(ss.cs)) { plaintext = this.decryptWithAd(ss.cs, ss.h, ciphertext); @@ -223,38 +223,38 @@ export class XXHandshake { return { cs1, cs2 }; } - private async writeMessageA(hs: HandshakeState, payload: bytes): Promise { + private writeMessageA(hs: HandshakeState, payload: bytes): MessageBuffer { const ns = Buffer.alloc(0); hs.e = generateKeypair(); const ne = hs.e.publicKey; this.mixHash(hs.ss, ne); - const ciphertext = await this.encryptAndHash(hs.ss, payload); + const ciphertext = this.encryptAndHash(hs.ss, payload); return {ne, ns, ciphertext}; } - private async writeMessageB(hs: HandshakeState, payload: bytes): Promise { + private writeMessageB(hs: HandshakeState, payload: bytes): MessageBuffer { hs.e = generateKeypair(); const ne = hs.e.publicKey; this.mixHash(hs.ss, ne); this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)); const spk = Buffer.from(hs.s.publicKey); - const ns = await this.encryptAndHash(hs.ss, spk); + const ns = this.encryptAndHash(hs.ss, spk); this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)); - const ciphertext = await this.encryptAndHash(hs.ss, payload); + const ciphertext = this.encryptAndHash(hs.ss, payload); return { ne, ns, ciphertext }; } - private async writeMessageC(hs: HandshakeState, payload: bytes) { + private writeMessageC(hs: HandshakeState, payload: bytes) { const spk = Buffer.from(hs.s.publicKey); - const ns = await this.encryptAndHash(hs.ss, spk); + const ns = this.encryptAndHash(hs.ss, spk); this.mixKey(hs.ss, this.dh(hs.s.privateKey, hs.re)); - const ciphertext = await this.encryptAndHash(hs.ss, payload); + const ciphertext = this.encryptAndHash(hs.ss, payload); const ne = this.createEmptyKey(); const messageBuffer: MessageBuffer = {ne, ns, ciphertext}; const { cs1, cs2 } = this.split(hs.ss); @@ -262,7 +262,7 @@ export class XXHandshake { return { h: hs.ss.h, messageBuffer, cs1, cs2 }; } - private async writeMessageRegular(cs: CipherState, payload: bytes): Promise { + private writeMessageRegular(cs: CipherState, payload: bytes): MessageBuffer { const ciphertext = this.encryptWithAd(cs, Buffer.alloc(0), payload); const ne = this.createEmptyKey(); const ns = Buffer.alloc(0); @@ -270,16 +270,16 @@ export class XXHandshake { return { ne, ns, ciphertext }; } - private async readMessageA(hs: HandshakeState, message: MessageBuffer): Promise { + private readMessageA(hs: HandshakeState, message: MessageBuffer): bytes { if (x25519.publicKeyVerify(message.ne)) { hs.re = message.ne; } this.mixHash(hs.ss, hs.re); - return await this.decryptAndHash(hs.ss, message.ciphertext); + return this.decryptAndHash(hs.ss, message.ciphertext); } - private async readMessageB(hs: HandshakeState, message: MessageBuffer): Promise { + private readMessageB(hs: HandshakeState, message: MessageBuffer): bytes { if (x25519.publicKeyVerify(message.ne)) { hs.re = message.ne; } @@ -289,16 +289,16 @@ export class XXHandshake { throw new Error("Handshake state `e` param is missing."); } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.re)); - const ns = await this.decryptAndHash(hs.ss, message.ns); + const ns = this.decryptAndHash(hs.ss, message.ns); if (ns.length === 32 && x25519.publicKeyVerify(message.ns)) { hs.rs = ns; } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)); - return await this.decryptAndHash(hs.ss, message.ciphertext); + return this.decryptAndHash(hs.ss, message.ciphertext); } - private async readMessageC(hs: HandshakeState, message: MessageBuffer) { - const ns = await this.decryptAndHash(hs.ss, message.ns); + private readMessageC(hs: HandshakeState, message: MessageBuffer) { + const ns = this.decryptAndHash(hs.ss, message.ns); if (ns.length === 32 && x25519.publicKeyVerify(message.ns)) { hs.rs = ns; } @@ -308,7 +308,7 @@ export class XXHandshake { } this.mixKey(hs.ss, this.dh(hs.e.privateKey, hs.rs)); - const plaintext = await this.decryptAndHash(hs.ss, message.ciphertext); + const plaintext = this.decryptAndHash(hs.ss, message.ciphertext); const { cs1, cs2 } = this.split(hs.ss); return { h: hs.ss.h, plaintext, cs1, cs2 }; @@ -336,14 +336,14 @@ export class XXHandshake { }; } - public async sendMessage(session: NoiseSession, message: bytes): Promise { + public sendMessage(session: NoiseSession, message: bytes): MessageBuffer { let messageBuffer: MessageBuffer; if (session.mc.eqn(0)) { - messageBuffer = await this.writeMessageA(session.hs, message); + messageBuffer = this.writeMessageA(session.hs, message); } else if (session.mc.eqn(1)) { - messageBuffer = await this.writeMessageB(session.hs, message); + messageBuffer = this.writeMessageB(session.hs, message); } else if (session.mc.eqn(2)) { - const { h, messageBuffer: resultingBuffer, cs1, cs2 } = await this.writeMessageC(session.hs, message); + const { h, messageBuffer: resultingBuffer, cs1, cs2 } = this.writeMessageC(session.hs, message); messageBuffer = resultingBuffer; session.h = h; session.cs1 = cs1; @@ -354,13 +354,13 @@ export class XXHandshake { throw new Error("CS1 (cipher state) is not defined") } - messageBuffer = await this.writeMessageRegular(session.cs1, message); + messageBuffer = this.writeMessageRegular(session.cs1, message); } else { if (!session.cs2) { throw new Error("CS2 (cipher state) is not defined") } - messageBuffer = await this.writeMessageRegular(session.cs2, message); + messageBuffer = this.writeMessageRegular(session.cs2, message); } } else { throw new Error("Session invalid.") @@ -370,14 +370,14 @@ export class XXHandshake { return messageBuffer; } - public async recvMessage(session: NoiseSession, message: MessageBuffer): Promise { + public recvMessage(session: NoiseSession, message: MessageBuffer): bytes { let plaintext: bytes; if (session.mc.eqn(0)) { - plaintext = await this.readMessageA(session.hs, message); + plaintext = this.readMessageA(session.hs, message); } else if (session.mc.eqn(1)) { - plaintext = await this.readMessageB(session.hs, message); + plaintext = this.readMessageB(session.hs, message); } else if (session.mc.eqn(2)) { - const { h, plaintext: resultingPlaintext, cs1, cs2 } = await this.readMessageC(session.hs, message); + const { h, plaintext: resultingPlaintext, cs1, cs2 } = this.readMessageC(session.hs, message); plaintext = resultingPlaintext; session.h = h; session.cs1 = cs1; @@ -387,12 +387,12 @@ export class XXHandshake { if (!session.cs2) { throw new Error("CS1 (cipher state) is not defined") } - plaintext = await this.readMessageRegular(session.cs2, message); + plaintext = this.readMessageRegular(session.cs2, message); } else { if (!session.cs1) { throw new Error("CS1 (cipher state) is not defined") } - plaintext = await this.readMessageRegular(session.cs1, message); + plaintext = this.readMessageRegular(session.cs1, message); } } else { throw new Error("Session invalid."); diff --git a/test/noise.test.ts b/test/noise.test.ts index b409568..3acd8e2 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -66,18 +66,18 @@ describe("Noise", () => { let receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); // The first handshake message contains the initiator's ephemeral public key expect(receivedMessageBuffer.ne.length).equal(32); - await xx.recvMessage(handshake.session, receivedMessageBuffer); + xx.recvMessage(handshake.session, receivedMessageBuffer); // Stage 1 const signedPayload = signPayload(libp2pPrivKey, getHandshakePayload(staticKeys.publicKey)); const handshakePayload = await createHandshakePayload(libp2pPubKey, libp2pPrivKey, signedPayload); - const messageBuffer = await xx.sendMessage(handshake.session, handshakePayload); + const messageBuffer = xx.sendMessage(handshake.session, handshakePayload); wrapped.writeLP(encodeMessageBuffer(messageBuffer)); // Stage 2 - finish handshake receivedMessageBuffer = decodeMessageBuffer((await wrapped.readLP()).slice()); - await xx.recvMessage(handshake.session, receivedMessageBuffer); + xx.recvMessage(handshake.session, receivedMessageBuffer); return {wrapped, handshake}; })(), ]); diff --git a/test/xx.test.ts b/test/xx.test.ts index 75c3bc0..1eca6b4 100644 --- a/test/xx.test.ts +++ b/test/xx.test.ts @@ -48,9 +48,9 @@ describe("Index", () => { const respSignedPayload = await libp2pRespKeys.sign(getHandshakePayload(kpResp.publicKey)); // initiator: new XX noise session - const nsInit = await xx.initSession(true, prologue, kpInit); + const nsInit = xx.initSession(true, prologue, kpInit); // responder: new XX noise session - const nsResp = await xx.initSession(false, prologue, kpResp); + const nsResp = xx.initSession(false, prologue, kpResp); /* STAGE 0 */ @@ -62,12 +62,12 @@ describe("Index", () => { // initiator sends message const message = Buffer.concat([Buffer.alloc(0), payloadInitEnc]); - const messageBuffer = await xx.sendMessage(nsInit, message); + const messageBuffer = xx.sendMessage(nsInit, message); expect(messageBuffer.ne.length).not.equal(0); // responder receives message - const plaintext = await xx.recvMessage(nsResp, messageBuffer); + const plaintext = xx.recvMessage(nsResp, messageBuffer); console.log("Stage 0 responder payload: ", plaintext); /* STAGE 1 */ @@ -78,22 +78,22 @@ describe("Index", () => { const payloadRespEnc = await createHandshakePayload(libp2pRespPubKey, libp2pRespPrivKey, respSignedPayload); const message1 = Buffer.concat([message, payloadRespEnc]); - const messageBuffer2 = await xx.sendMessage(nsResp, message1); + const messageBuffer2 = xx.sendMessage(nsResp, message1); expect(messageBuffer2.ne.length).not.equal(0); expect(messageBuffer2.ns.length).not.equal(0); // initiator receive payload - const plaintext2 = await xx.recvMessage(nsInit, messageBuffer2); + const plaintext2 = xx.recvMessage(nsInit, messageBuffer2); console.log("Stage 1 responder payload: ", plaintext2); /* STAGE 2 */ // initiator send message - const messageBuffer3 = await xx.sendMessage(nsInit, Buffer.alloc(0)); + const messageBuffer3 = xx.sendMessage(nsInit, Buffer.alloc(0)); // responder receive message - const plaintext3 = await xx.recvMessage(nsResp, messageBuffer3); + const plaintext3 = xx.recvMessage(nsResp, messageBuffer3); console.log("Stage 2 responder payload: ", plaintext3); assert(nsInit.cs1.k.equals(nsResp.cs1.k));