diff --git a/src/@types/handshake-interface.ts b/src/@types/handshake-interface.ts index 9bd7240..f432bda 100644 --- a/src/@types/handshake-interface.ts +++ b/src/@types/handshake-interface.ts @@ -5,6 +5,7 @@ import PeerId from "peer-id"; export interface IHandshake { session: NoiseSession; remotePeer: PeerId; + remoteEarlyData: Buffer; encrypt(plaintext: bytes, session: NoiseSession): bytes; decrypt(ciphertext: bytes, session: NoiseSession): {plaintext: bytes; valid: boolean}; } diff --git a/src/@types/libp2p.ts b/src/@types/libp2p.ts index 151d0ca..3c02747 100644 --- a/src/@types/libp2p.ts +++ b/src/@types/libp2p.ts @@ -14,6 +14,7 @@ export interface INoiseConnection { export type SecureOutbound = { conn: any; + remoteEarlyData: Buffer; remotePeer: PeerId; } diff --git a/src/handshake-ik.ts b/src/handshake-ik.ts index e67c2c7..e34b283 100644 --- a/src/handshake-ik.ts +++ b/src/handshake-ik.ts @@ -22,6 +22,7 @@ export class IKHandshake implements IHandshake { public isInitiator: boolean; public session: NoiseSession; public remotePeer!: PeerId; + public remoteEarlyData: Buffer; private payload: bytes; private prologue: bytes32; @@ -49,6 +50,7 @@ export class IKHandshake implements IHandshake { } this.ik = handshake || new IK(); this.session = this.ik.initSession(this.isInitiator, this.prologue, this.staticKeypair, remoteStaticKey); + this.remoteEarlyData = Buffer.alloc(0) } public async stage0(): Promise { @@ -73,6 +75,7 @@ export class IKHandshake implements IHandshake { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer); + this.setRemoteEarlyData(decodedPayload.data); logger("IK Stage 0 - Responder successfully verified payload!"); logRemoteEphemeralKey(this.session.hs.re) } catch (e) { @@ -97,6 +100,7 @@ export class IKHandshake implements IHandshake { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); await verifySignedPayload(receivedMessageBuffer.ns.slice(0, 32), decodedPayload, this.remotePeer); + this.setRemoteEarlyData(decodedPayload.data); logger("IK Stage 1 - Initiator successfully verified payload!"); logRemoteEphemeralKey(this.session.hs.re) } catch (e) { @@ -142,4 +146,10 @@ export class IKHandshake implements IHandshake { return encryption ? session.cs2 : session.cs1; } } + + private setRemoteEarlyData(data: Uint8Array|null|undefined): void { + if(data){ + this.remoteEarlyData = Buffer.from(data.buffer, data.byteOffset, data.length); + } + } } diff --git a/src/handshake-xx-fallback.ts b/src/handshake-xx-fallback.ts index f56df5c..1d29f70 100644 --- a/src/handshake-xx-fallback.ts +++ b/src/handshake-xx-fallback.ts @@ -70,6 +70,7 @@ export class XXFallbackHandshake extends XXHandshake { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer); + this.setRemoteEarlyData(decodedPayload.data) } catch (e) { throw new Error(`Error occurred while verifying signed payload from responder: ${e.message}`); } diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index bbcbab1..286acd0 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -26,6 +26,7 @@ export class XXHandshake implements IHandshake { public isInitiator: boolean; public session: NoiseSession; public remotePeer!: PeerId; + public remoteEarlyData: Buffer; protected payload: bytes; protected connection: WrappedConnection; @@ -53,6 +54,7 @@ export class XXHandshake implements IHandshake { } this.xx = handshake || new XX(); this.session = this.xx.initSession(this.isInitiator, this.prologue, this.staticKeypair); + this.remoteEarlyData = Buffer.alloc(0) } // stage 0 @@ -94,6 +96,7 @@ export class XXHandshake implements IHandshake { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); this.remotePeer = await verifySignedPayload(receivedMessageBuffer.ns, decodedPayload, this.remotePeer); + this.setRemoteEarlyData(decodedPayload.data) } catch (e) { throw new Error(`Error occurred while verifying signed payload: ${e.message}`); } @@ -127,6 +130,7 @@ export class XXHandshake implements IHandshake { const decodedPayload = await decodePayload(plaintext); this.remotePeer = this.remotePeer || await getPeerIdFromPayload(decodedPayload); await verifySignedPayload(this.session.hs.rs, decodedPayload, this.remotePeer); + this.setRemoteEarlyData(decodedPayload.data) } catch (e) { throw new Error(`Error occurred while verifying signed payload: ${e.message}`); } @@ -160,4 +164,10 @@ export class XXHandshake implements IHandshake { return encryption ? session.cs2 : session.cs1; } } + + protected setRemoteEarlyData(data: Uint8Array|null|undefined): void { + if(data){ + this.remoteEarlyData = Buffer.from(data.buffer, data.byteOffset, data.length); + } + } } diff --git a/src/noise.ts b/src/noise.ts index 746949f..6f4e7fe 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -85,6 +85,7 @@ export class Noise implements INoiseConnection { return { conn, + remoteEarlyData: handshake.remoteEarlyData, remotePeer: handshake.remotePeer, } } @@ -115,6 +116,7 @@ export class Noise implements INoiseConnection { return { conn, + remoteEarlyData: handshake.remoteEarlyData, remotePeer: handshake.remotePeer }; } diff --git a/test/noise.test.ts b/test/noise.test.ts index d8e7f84..8d6e54c 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -336,4 +336,30 @@ describe("Noise", () => { assert(false, e.message); } }); + + it("should accept and return early data from remote peer", async() => { + try { + const localPeerEarlyData = Buffer.from('early data') + const staticKeysInitiator = generateKeypair(); + const noiseInit = new Noise(staticKeysInitiator.privateKey, localPeerEarlyData); + const staticKeysResponder = generateKeypair(); + const noiseResp = new Noise(staticKeysResponder.privateKey); + + // Prepare key cache for noise pipes + KeyCache.store(localPeer, staticKeysInitiator.publicKey); + KeyCache.store(remotePeer, staticKeysResponder.publicKey); + + const [inboundConnection, outboundConnection] = DuplexPair(); + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection), + ]); + + assert(inbound.remoteEarlyData.equals(localPeerEarlyData)) + assert(outbound.remoteEarlyData.equals(Buffer.alloc(0))) + } catch (e) { + console.error(e); + assert(false, e.message); + } + }); });