diff --git a/src/@types/handshake-interface.ts b/src/@types/handshake-interface.ts index 9bd7240..a8de48b 100644 --- a/src/@types/handshake-interface.ts +++ b/src/@types/handshake-interface.ts @@ -5,6 +5,7 @@ import PeerId from "peer-id"; export interface IHandshake { session: NoiseSession; remotePeer: PeerId; + earlyData: Uint8Array; encrypt(plaintext: bytes, session: NoiseSession): bytes; decrypt(ciphertext: bytes, session: NoiseSession): {plaintext: bytes; valid: boolean}; } diff --git a/src/handshake-ik.ts b/src/handshake-ik.ts index 983cf2e..d2675a1 100644 --- a/src/handshake-ik.ts +++ b/src/handshake-ik.ts @@ -15,6 +15,7 @@ export class IKHandshake implements IHandshake { public isInitiator: boolean; public session: NoiseSession; public remotePeer!: PeerId; + public earlyData: Uint8Array; private payload: bytes; private prologue: bytes32; @@ -42,6 +43,7 @@ export class IKHandshake implements IHandshake { } this.ik = handshake || new IK(); this.session = this.ik.initSession(this.isInitiator, this.prologue, this.staticKeypair, remoteStaticKey); + this.earlyData = Buffer.alloc(0) } public async stage0(): Promise { @@ -63,6 +65,7 @@ export class IKHandshake implements IHandshake { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer); + this.setEarlyData(decodedPayload.data); logger("IK Stage 0 - Responder successfully verified payload!"); } catch (e) { logger("Responder breaking up with IK handshake in stage 0."); @@ -86,6 +89,7 @@ export class IKHandshake implements IHandshake { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); await verifySignedPayload(receivedMessageBuffer.ns.slice(0, 32), decodedPayload, this.remotePeer); + this.setEarlyData(decodedPayload.data); logger("IK Stage 1 - Initiator successfully verified payload!"); } catch (e) { logger("Initiator breaking up with IK handshake in stage 1."); @@ -128,4 +132,10 @@ export class IKHandshake implements IHandshake { return encryption ? session.cs2 : session.cs1; } } + + private setEarlyData(data: Uint8Array|null|undefined): void { + if(data){ + this.earlyData = data; + } + } } diff --git a/src/handshake-xx-fallback.ts b/src/handshake-xx-fallback.ts index 6217bc5..9b29e2b 100644 --- a/src/handshake-xx-fallback.ts +++ b/src/handshake-xx-fallback.ts @@ -66,6 +66,9 @@ export class XXFallbackHandshake extends XXHandshake { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer); + if(decodedPayload.data){ + this.earlyData = decodedPayload.data; + } } catch (e) { throw new Error(`Error occurred while verifying signed payload from responder: ${e.message}`); } diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index 41360e5..44baa1f 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -19,6 +19,7 @@ export class XXHandshake implements IHandshake { public isInitiator: boolean; public session: NoiseSession; public remotePeer!: PeerId; + public earlyData: Uint8Array; protected payload: bytes; protected connection: WrappedConnection; @@ -46,6 +47,7 @@ export class XXHandshake implements IHandshake { } this.xx = handshake || new XX(); this.session = this.xx.initSession(this.isInitiator, this.prologue, this.staticKeypair); + this.earlyData = Buffer.alloc(0) } // stage 0 @@ -82,6 +84,7 @@ export class XXHandshake implements IHandshake { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); this.remotePeer = await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer); + this.setEarlyData(decodedPayload.data) } catch (e) { throw new Error(`Error occurred while verifying signed payload: ${e.message}`); } @@ -114,6 +117,7 @@ export class XXHandshake implements IHandshake { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer); + this.setEarlyData(decodedPayload.data) } catch (e) { throw new Error(`Error occurred while verifying signed payload: ${e.message}`); } @@ -146,4 +150,10 @@ export class XXHandshake implements IHandshake { return encryption ? session.cs2 : session.cs1; } } + + private setEarlyData(data: Uint8Array|null|undefined): void { + if(data){ + this.earlyData = data + } + } }