diff --git a/src/handshake-ik.ts b/src/handshake-ik.ts index 30afa1a..f13e711 100644 --- a/src/handshake-ik.ts +++ b/src/handshake-ik.ts @@ -10,8 +10,7 @@ export class IKHandshake implements IHandshake { public isInitiator: boolean; public session: NoiseSession; - private libp2pPrivateKey: bytes; - private libp2pPublicKey: bytes; + private payload: bytes; private prologue: bytes32; private staticKeypair: KeyPair; private connection: WrappedConnection; @@ -20,8 +19,7 @@ export class IKHandshake implements IHandshake { constructor( isInitiator: boolean, - libp2pPrivateKey: bytes, - libp2pPublicKey: bytes, + payload: bytes, prologue: bytes32, staticKeypair: KeyPair, connection: WrappedConnection, @@ -29,8 +27,7 @@ export class IKHandshake implements IHandshake { handshake?: IK, ) { this.isInitiator = isInitiator; - this.libp2pPrivateKey = libp2pPrivateKey; - this.libp2pPublicKey = libp2pPublicKey; + this.payload = payload; this.prologue = prologue; this.staticKeypair = staticKeypair; this.connection = connection; diff --git a/src/handshake-xx-fallback.ts b/src/handshake-xx-fallback.ts index 1db6f2b..223ee2b 100644 --- a/src/handshake-xx-fallback.ts +++ b/src/handshake-xx-fallback.ts @@ -21,8 +21,7 @@ export class XXFallbackHandshake extends XXHandshake { constructor( isInitiator: boolean, - libp2pPrivateKey: bytes, - libp2pPublicKey: bytes, + payload: bytes, prologue: bytes32, staticKeypair: KeyPair, connection: WrappedConnection, @@ -31,7 +30,7 @@ export class XXFallbackHandshake extends XXHandshake { ephemeralKeys?: KeyPair, handshake?: XX, ) { - super(isInitiator, libp2pPrivateKey, libp2pPublicKey, prologue, staticKeypair, connection, remotePeer, handshake); + super(isInitiator, payload, prologue, staticKeypair, connection, remotePeer, handshake); if (ephemeralKeys) { this.ephemeralKeys = ephemeralKeys; } @@ -60,7 +59,6 @@ export class XXFallbackHandshake extends XXHandshake { if (this.isInitiator) { logger('XX Fallback Stage 1 - Initiator waiting to receive first message from responder...'); const receivedMessageBuffer = decode1(this.initialMsg); - logger("Initiator receivedMessageBuffer in stage 1", receivedMessageBuffer); const plaintext = this.xx.recvMessage(this.session, receivedMessageBuffer); logger('XX Fallback Stage 1 - Initiator received the message. Got remote\'s static key.'); @@ -73,16 +71,7 @@ export class XXFallbackHandshake extends XXHandshake { logger("All good with the signature!"); } else { logger('XX Fallback Stage 1 - Responder sending out first message with signed payload and static key.'); - const signedPayload = signPayload(this.libp2pPrivateKey, getHandshakePayload(this.staticKeypair.publicKey)); - const signedEarlyDataPayload = signEarlyDataPayload(this.libp2pPrivateKey, Buffer.alloc(0)); - const handshakePayload = await createHandshakePayload( - this.libp2pPublicKey, - this.libp2pPrivateKey, - signedPayload, - signedEarlyDataPayload, - ); - - const messageBuffer = this.xx.sendMessage(this.session, handshakePayload); + const messageBuffer = this.xx.sendMessage(this.session, this.payload); this.connection.writeLP(encode1(messageBuffer)); logger('XX Fallback Stage 1 - Responder sent the second handshake message with signed payload.') } diff --git a/src/handshake-xx.ts b/src/handshake-xx.ts index fff6698..7f6f1bf 100644 --- a/src/handshake-xx.ts +++ b/src/handshake-xx.ts @@ -16,15 +16,13 @@ export class XXHandshake implements IHandshake { public isInitiator: boolean; public session: NoiseSession; + protected payload: bytes; protected connection: WrappedConnection; protected xx: XX; - protected libp2pPrivateKey: bytes; - protected libp2pPublicKey: bytes; protected staticKeypair: KeyPair; protected remotePeer: PeerId; private prologue: bytes32; - private payload: bytes; constructor( isInitiator: boolean, diff --git a/src/noise.ts b/src/noise.ts index 5c904f0..6f99fab 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -22,7 +22,7 @@ export type WrappedConnection = ReturnType; type HandshakeParams = { connection: WrappedConnection; isInitiator: boolean; - libp2pPublicKey: bytes; + localPeer: PeerId; remotePeer: PeerId; }; @@ -30,7 +30,6 @@ export class Noise implements NoiseConnection { public protocol = "/noise"; private readonly prologue = Buffer.from(this.protocol); - private readonly privateKey: bytes; private readonly staticKeys: KeyPair; private readonly earlyData?: bytes; @@ -127,9 +126,9 @@ export class Noise implements NoiseConnection { ephemeralKeys: KeyPair, initialMsg: bytes, ): Promise { - const { isInitiator, libp2pPublicKey, remotePeer, connection } = params; + const { isInitiator, remotePeer, connection } = params; const handshake = - new XXFallbackHandshake(isInitiator, payload, this.privateKey, libp2pPublicKey, this.prologue, this.staticKeys, connection, remotePeer, initialMsg, ephemeralKeys); + new XXFallbackHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer, initialMsg, ephemeralKeys); try { await handshake.propose(); @@ -146,8 +145,8 @@ export class Noise implements NoiseConnection { params: HandshakeParams, payload: bytes, ): Promise { - const { isInitiator, libp2pPublicKey, remotePeer, connection } = params; - const handshake = new XXHandshake(isInitiator, payload, this.privateKey, libp2pPublicKey, this.prologue, this.staticKeys, connection, remotePeer); + const { isInitiator, remotePeer, connection } = params; + const handshake = new XXHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer); try { await handshake.propose(); @@ -164,8 +163,8 @@ export class Noise implements NoiseConnection { params: HandshakeParams, payload: bytes, ): Promise { - const { isInitiator, libp2pPublicKey, remotePeer, connection } = params; - const handshake = new IKHandshake(isInitiator, payload, this.privateKey, libp2pPublicKey, this.prologue, this.staticKeys, connection, remotePeer); + const { isInitiator, localPeer, remotePeer, connection } = params; + const handshake = new IKHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer); // TODO diff --git a/test/noise.test.ts b/test/noise.test.ts index 7a94809..dcf790f 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -57,9 +57,8 @@ describe("Noise", () => { const staticKeys = generateKeypair(); const xx = new XX(); - const handshake = new XXHandshake(false, libp2pPrivKey, libp2pPubKey, prologue, staticKeys, wrapped, localPeer, xx); const payload = await getPayload(remotePeer, staticKeys.publicKey); - const handshake = new Handshake(false, payload, prologue, staticKeys, wrapped, localPeer, xx); + const handshake = new XXHandshake(false, payload, prologue, staticKeys, wrapped, localPeer, xx); let receivedMessageBuffer = decode0((await wrapped.readLP()).slice()); // The first handshake message contains the initiator's ephemeral public key @@ -67,9 +66,9 @@ describe("Noise", () => { xx.recvMessage(handshake.session, receivedMessageBuffer); // Stage 1 - const { privateKey: libp2pPrivKey, publicKey: libp2pPubKey } = getKeyPairFromPeerId(remotePeer); + const { publicKey: libp2pPubKey } = getKeyPairFromPeerId(remotePeer); const signedPayload = await signPayload(remotePeer, getHandshakePayload(staticKeys.publicKey)); - const handshakePayload = await createHandshakePayload(libp2pPubKey, libp2pPrivKey, signedPayload); + const handshakePayload = await createHandshakePayload(libp2pPubKey, signedPayload); const messageBuffer = xx.sendMessage(handshake.session, handshakePayload); wrapped.writeLP(encode1(messageBuffer)); diff --git a/test/xx-fallback-handshake.test.ts b/test/xx-fallback-handshake.test.ts index 028f425..18bfeba 100644 --- a/test/xx-fallback-handshake.test.ts +++ b/test/xx-fallback-handshake.test.ts @@ -3,17 +3,13 @@ import {Buffer} from "buffer"; import Duplex from 'it-pair/duplex'; import { - createHandshakePayload, generateKeypair, - getHandshakePayload, - signPayload + getPayload, } from "../src/utils"; -import {generateEd25519Keys, getKeyPairFromPeerId} from "./utils"; import {XXFallbackHandshake} from "../src/handshake-xx-fallback"; import {createPeerIdsFromFixtures} from "./fixtures/peer"; import {assert} from "chai"; import {decode1, encode0, encode1} from "../src/encoder"; -import {XX} from "../src/handshakes/xx"; describe("XX Fallback Handshake", () => { let peerA, peerB, fakePeer; @@ -33,24 +29,17 @@ describe("XX Fallback Handshake", () => { const staticKeysResponder = generateKeypair(); const ephemeralKeys = generateKeypair(); - const {privateKey: initiatorPrivKey, publicKey: initiatorPubKey} = getKeyPairFromPeerId(peerA); - const {privateKey: responderPrivKey, publicKey: responderPubKey} = getKeyPairFromPeerId(peerB); - // Initial msg for responder is IK first message from initiator - const signedPayload = signPayload(initiatorPrivKey, getHandshakePayload(staticKeysInitiator.publicKey)); - const handshakePayload = await createHandshakePayload( - initiatorPubKey, - initiatorPrivKey, - signedPayload, - ); + const handshakePayload = await getPayload(peerA, staticKeysInitiator.publicKey); const initialMsgR = encode0({ ne: ephemeralKeys.publicKey, ns: Buffer.alloc(0), ciphertext: handshakePayload, }); + const respPayload = await getPayload(peerB, staticKeysResponder.publicKey); const handshakeResp = - new XXFallbackHandshake(false, responderPrivKey, responderPubKey, prologue, staticKeysResponder, connectionTo, peerA, initialMsgR); + new XXFallbackHandshake(false, respPayload, prologue, staticKeysResponder, connectionTo, peerA, initialMsgR); await handshakeResp.propose(); await handshakeResp.exchange(); @@ -59,7 +48,7 @@ describe("XX Fallback Handshake", () => { // This is the point where initiator falls back from IK const initialMsgI = await connectionFrom.readLP(); const handshakeInit = - new XXFallbackHandshake(true, initiatorPrivKey, initiatorPubKey, prologue, staticKeysInitiator, connectionFrom, peerB, initialMsgI, ephemeralKeys); + new XXFallbackHandshake(true, handshakePayload, prologue, staticKeysInitiator, connectionFrom, peerB, initialMsgI, ephemeralKeys); await handshakeInit.propose(); await handshakeInit.exchange(); diff --git a/test/xx-handshake.test.ts b/test/xx-handshake.test.ts index 569c35b..08851a4 100644 --- a/test/xx-handshake.test.ts +++ b/test/xx-handshake.test.ts @@ -4,9 +4,8 @@ import {Buffer} from "buffer"; import Wrap from "it-pb-rpc"; import {XXHandshake} from "../src/handshake-xx"; -import {generateKeypair} from "../src/utils"; +import {generateKeypair, getPayload} from "../src/utils"; import {createPeerIdsFromFixtures} from "./fixtures/peer"; -import {getKeyPairFromPeerId} from "./utils"; describe("XX Handshake", () => {