diff --git a/src/errors.ts b/src/errors.ts new file mode 100644 index 0000000..09d2b10 --- /dev/null +++ b/src/errors.ts @@ -0,0 +1,11 @@ +export class FailedIKError extends Error { + public initialMsg; + + constructor(initialMsg, message?: string) { + super(message); + + this.initialMsg = initialMsg; + this.name = "FailedIKhandshake"; + this.stack = new Error().stack; + } +}; diff --git a/src/handshake-ik.ts b/src/handshake-ik.ts index 63a8ffc..5683c99 100644 --- a/src/handshake-ik.ts +++ b/src/handshake-ik.ts @@ -7,6 +7,7 @@ import {IHandshake} from "./@types/handshake-interface"; import {Buffer} from "buffer"; import {decode0, decode1, encode0, encode1} from "./encoder"; import {verifySignedPayload} from "./utils"; +import {FailedIKError} from "./errors"; export class IKHandshake implements IHandshake { public isInitiator: boolean; @@ -45,26 +46,28 @@ export class IKHandshake implements IHandshake { const messageBuffer = this.ik.sendMessage(this.session, this.payload); this.connection.writeLP(encode0(messageBuffer)); } else { - const receivedMessageBuffer = decode0(await this.connection.readLP()); + const receivedMsg = await this.connection.readLP(); + const receivedMessageBuffer = decode0(receivedMsg); const plaintext = this.ik.recvMessage(this.session, receivedMessageBuffer); try { await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); } catch (e) { - throw new Error(`Error occurred while verifying initiator's signed payload: ${e.message}`); + throw new FailedIKError(receivedMsg, `Error occurred while verifying initiator's signed payload: ${e.message}`); } } } public async stage1(): Promise { if (this.isInitiator) { - const receivedMessageBuffer = decode1(await this.connection.readLP()); + const receivedMsg = await this.connection.readLP(); + const receivedMessageBuffer = decode1(receivedMsg); const plaintext = this.ik.recvMessage(this.session, receivedMessageBuffer); try { await verifySignedPayload(receivedMessageBuffer.ns, plaintext, this.remotePeer.id); } catch (e) { - throw new Error(`Error occurred while verifying responder's signed payload: ${e.message}`); + throw new FailedIKError(receivedMsg, `Error occurred while verifying responder's signed payload: ${e.message}`); } } else { const messageBuffer = this.ik.sendMessage(this.session, this.payload); diff --git a/src/handshakes/ik.ts b/src/handshakes/ik.ts index 75d646e..793aea3 100644 --- a/src/handshakes/ik.ts +++ b/src/handshakes/ik.ts @@ -68,7 +68,6 @@ export class IK extends AbstractHandshake { session.h = h; session.cs1 = cs1; session.cs2 = cs2; - delete session.hs; } else if (session.mc.gtn(1)) { if (session.i) { if (!session.cs2) { diff --git a/src/keycache.ts b/src/keycache.ts index bf96f5e..fec7ef3 100644 --- a/src/keycache.ts +++ b/src/keycache.ts @@ -11,8 +11,8 @@ class Keycache { this.storage.set(peerId.id, key); } - public async load(peerId: PeerId): Promise { - return this.storage.get(peerId.id) || null; + public async load(peerId: PeerId): Promise { + return this.storage.get(peerId.id); } public resetStorage(): void { diff --git a/src/noise.ts b/src/noise.ts index 4d708dc..c678042 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -108,23 +108,32 @@ export class Noise implements INoiseConnection { private async performHandshake(params: HandshakeParams): Promise { const payload = await getPayload(params.localPeer, this.staticKeys.publicKey, this.earlyData); - let foundRemoteStaticKey: bytes|null = null; - if (this.useNoisePipes && params.isInitiator) { - logger("Initiator using noise pipes. Going to load cached static key..."); - foundRemoteStaticKey = await KeyCache.load(params.remotePeer); - logger(`Static key has been found: ${!!foundRemoteStaticKey}`) + let tryIK = this.useNoisePipes; + const foundRemoteStaticKey = await KeyCache.load(params.remotePeer); + if (tryIK && params.isInitiator && !foundRemoteStaticKey) { + tryIK = false; + logger(`Static key not found.`) } - if (foundRemoteStaticKey) { + // Try IK if acting as responder or initiator that has remote's static key. + if (tryIK) { // Try IK first const { remotePeer, connection, isInitiator } = params; + if (!foundRemoteStaticKey) { + // TODO: Recheck. Possible that responder should not have it here. + throw new Error("Remote static key should be initialized."); + } + const IKhandshake = new IKHandshake(isInitiator, payload, this.prologue, this.staticKeys, connection, remotePeer, foundRemoteStaticKey); try { - return await this.performIKHandshake(IKhandshake, payload); + return await this.performIKHandshake(IKhandshake); } catch (e) { // IK failed, go to XX fallback - const ephemeralKeys = IKhandshake.getRemoteEphemeralKeys(); - return await this.performXXFallbackHandshake(params, payload, ephemeralKeys, e.initialMsg); + let ephemeralKeys; + if (params.isInitiator) { + ephemeralKeys = IKhandshake.getRemoteEphemeralKeys(); + } + return await this.performXXFallbackHandshake(params, payload, e.initialMsg, ephemeralKeys); } } else { // Noise pipes not supported, use XX @@ -135,8 +144,8 @@ export class Noise implements INoiseConnection { private async performXXFallbackHandshake( params: HandshakeParams, payload: bytes, - ephemeralKeys: KeyPair, initialMsg: bytes, + ephemeralKeys?: KeyPair, ): Promise { const { isInitiator, remotePeer, connection } = params; const handshake = @@ -147,6 +156,7 @@ export class Noise implements INoiseConnection { await handshake.exchange(); await handshake.finish(); } catch (e) { + logger(e); throw new Error(`Error occurred during XX Fallback handshake: ${e.message}`); } @@ -179,8 +189,13 @@ export class Noise implements INoiseConnection { handshake: IKHandshake, ): Promise { - await handshake.stage0(); - await handshake.stage1(); + try { + await handshake.stage0(); + await handshake.stage1(); + } catch (e) { + console.error("Error in IK handshake: ", e); + throw e; + } return handshake; } diff --git a/src/utils.ts b/src/utils.ts index 003159c..6469215 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -89,8 +89,13 @@ async function isValidPeerId(peerId: bytes, publicKeyProtobuf: bytes) { } export async function verifySignedPayload(noiseStaticKey: bytes, plaintext: bytes, peerId: bytes) { - const NoiseHandshakePayload = await loadPayloadProto(); - const receivedPayload = NoiseHandshakePayload.toObject(NoiseHandshakePayload.decode(plaintext)); + let receivedPayload; + try { + const NoiseHandshakePayload = await loadPayloadProto(); + receivedPayload = NoiseHandshakePayload.toObject(NoiseHandshakePayload.decode(plaintext)); + } catch (e) { + throw new Error("Failed to decode received payload."); + } if (!(await isValidPeerId(peerId, receivedPayload.libp2pKey)) ) { throw new Error("Peer ID doesn't match libp2p public key."); diff --git a/test/noise.test.ts b/test/noise.test.ts index 872f825..2678d9e 100644 --- a/test/noise.test.ts +++ b/test/noise.test.ts @@ -16,6 +16,7 @@ import {decode0, decode1, encode1} from "../src/encoder"; import {XX} from "../src/handshakes/xx"; import {Buffer} from "buffer"; import {getKeyPairFromPeerId} from "./utils"; +import {KeyCache} from "../src/keycache"; describe("Noise", () => { let remotePeer, localPeer; @@ -24,7 +25,7 @@ describe("Noise", () => { [localPeer, remotePeer] = await createPeerIdsFromFixtures(2); }); - it("should communicate through encrypted streams", async() => { + it("should communicate through encrypted streams without noise pipes", async() => { try { const noiseInit = new Noise(undefined, undefined, false); const noiseResp = new Noise(undefined, undefined, false); @@ -120,4 +121,54 @@ describe("Noise", () => { assert(false, e.message); } }); + + it("should communicate through encrypted streams with noise pipes", async() => { + try { + const staticKeysInitiator = generateKeypair(); + const noiseInit = new Noise(staticKeysInitiator.privateKey); + const staticKeysResponder = generateKeypair(); + const noiseResp = new Noise(staticKeysResponder.privateKey); + + // Prepare key cache for noise pipes + await KeyCache.store(localPeer, staticKeysInitiator.publicKey); + await KeyCache.store(remotePeer, staticKeysResponder.publicKey); + + const [inboundConnection, outboundConnection] = DuplexPair(); + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection, localPeer), + ]); + const wrappedInbound = Wrap(inbound.conn); + const wrappedOutbound = Wrap(outbound.conn); + + wrappedOutbound.writeLP(Buffer.from("test v2")); + const response = await wrappedInbound.readLP(); + expect(response.toString()).equal("test v2"); + } catch (e) { + console.error(e); + assert(false, e.message); + } + }); + + it("should switch to XX fallback because of invalid remote static key", async() => { + try { + const staticKeysInitiator = generateKeypair(); + const noiseInit = new Noise(staticKeysInitiator.privateKey); + const noiseResp = new Noise(); + + // Prepare key cache for noise pipes + await KeyCache.store(localPeer, staticKeysInitiator.publicKey); + await KeyCache.store(remotePeer, generateKeypair().publicKey); + + const [inboundConnection, outboundConnection] = DuplexPair(); + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection, localPeer), + ]); + assert(false, "Should throw error"); + } catch (e) { + console.error(e); + assert(true, e.message); + } + }); });