From b04fa64365729244a9c50a6b54b12e9bcc9e55d0 Mon Sep 17 00:00:00 2001 From: Damien Arrachequesne Date: Sat, 21 Sep 2024 08:11:15 +0200 Subject: [PATCH] fix(sio): allow to join a room in a middleware (uws) Related: - https://github.com/socketio/socket.io/issues/4810 - https://github.com/socketio/socket.io/issues/5139 --- packages/socket.io/lib/namespace.ts | 23 ++++++++++++++------ packages/socket.io/lib/uws.ts | 6 ++++-- packages/socket.io/test/middleware.ts | 9 ++++++++ packages/socket.io/test/uws.ts | 31 ++++++++++++++++++++++++--- 4 files changed, 57 insertions(+), 12 deletions(-) diff --git a/packages/socket.io/lib/namespace.ts b/packages/socket.io/lib/namespace.ts index f7a5942829..d7ae1308fe 100644 --- a/packages/socket.io/lib/namespace.ts +++ b/packages/socket.io/lib/namespace.ts @@ -10,7 +10,6 @@ import { AllButLast, Last, DecorateAcknowledgementsWithMultipleResponses, - DecorateAcknowledgements, RemoveAcknowledgements, EventNamesWithAck, FirstNonErrorArg, @@ -135,11 +134,23 @@ export class Namespace< > > { public readonly name: string; + + /** + * A map of currently connected sockets. + */ public readonly sockets: Map< SocketId, Socket > = new Map(); + /** + * A map of currently connecting sockets. + */ + private _preConnectSockets: Map< + SocketId, + Socket + > = new Map(); + public adapter: Adapter; /** @private */ @@ -327,6 +338,8 @@ export class Namespace< debug("adding socket to nsp %s", this.name); const socket = await this._createSocket(client, auth); + this._preConnectSockets.set(socket.id, socket); + if ( // @ts-ignore this.server.opts.connectionStateRecovery?.skipMiddlewares && @@ -394,7 +407,7 @@ export class Namespace< socket: Socket, ) => void, ) { - // track socket + this._preConnectSockets.delete(socket.id); this.sockets.set(socket.id, socket); // it's paramount that the internal `onconnect` logic @@ -417,11 +430,7 @@ export class Namespace< _remove( socket: Socket, ): void { - if (this.sockets.has(socket.id)) { - this.sockets.delete(socket.id); - } else { - debug("ignoring remove for %s", socket.id); - } + this.sockets.delete(socket.id) || this._preConnectSockets.delete(socket.id); } /** diff --git a/packages/socket.io/lib/uws.ts b/packages/socket.io/lib/uws.ts index 14bcc0aebb..8a2621230e 100644 --- a/packages/socket.io/lib/uws.ts +++ b/packages/socket.io/lib/uws.ts @@ -14,7 +14,8 @@ export function patchAdapter(app /* : TemplatedApp */) { Adapter.prototype.addAll = function (id, rooms) { const isNew = !this.sids.has(id); addAll.call(this, id, rooms); - const socket: Socket = this.nsp.sockets.get(id); + const socket: Socket = + this.nsp.sockets.get(id) || this.nsp._preConnectSockets.get(id); if (!socket) { return; } @@ -34,7 +35,8 @@ export function patchAdapter(app /* : TemplatedApp */) { Adapter.prototype.del = function (id, room) { del.call(this, id, room); - const socket: Socket = this.nsp.sockets.get(id); + const socket: Socket = + this.nsp.sockets.get(id) || this.nsp._preConnectSockets.get(id); if (socket && socket.conn.transport.name === "websocket") { // @ts-ignore const sessionId = socket.conn.id; diff --git a/packages/socket.io/test/middleware.ts b/packages/socket.io/test/middleware.ts index 5211c05a64..4e7a584963 100644 --- a/packages/socket.io/test/middleware.ts +++ b/packages/socket.io/test/middleware.ts @@ -197,6 +197,11 @@ describe("middleware", () => { io.use((socket, next) => { expect(socket.connected).to.be(false); expect(socket.disconnected).to.be(true); + + expect(io.of("/").sockets.size).to.eql(0); + // @ts-expect-error + expect(io.of("/")._preConnectSockets.size).to.eql(1); + next(); }); @@ -204,6 +209,10 @@ describe("middleware", () => { expect(socket.connected).to.be(true); expect(socket.disconnected).to.be(false); + expect(io.of("/").sockets.size).to.eql(1); + // @ts-expect-error + expect(io.of("/")._preConnectSockets.size).to.eql(0); + success(done, io, clientSocket); }); }); diff --git a/packages/socket.io/test/uws.ts b/packages/socket.io/test/uws.ts index 52e5c32d75..aeb815688c 100644 --- a/packages/socket.io/test/uws.ts +++ b/packages/socket.io/test/uws.ts @@ -53,10 +53,10 @@ describe("socket.io with uWebSocket.js-based engine", () => { }); const partialDone = createPartialDone(done, 4); - client.on("connect", partialDone); + client.once("connect", partialDone); clientWSOnly.once("connect", partialDone); - clientPollingOnly.on("connect", partialDone); - clientCustomNamespace.on("connect", partialDone); + clientPollingOnly.once("connect", partialDone); + clientCustomNamespace.once("connect", partialDone); }); afterEach(() => { @@ -176,6 +176,31 @@ describe("socket.io with uWebSocket.js-based engine", () => { io.except("room2").emit("hello"); }); + it("should work when joining a room in a middleware", (done) => { + io.use((socket, next) => { + socket.join("test"); + next(); + }); + + client.disconnect().connect(); + clientPollingOnly.disconnect().connect(); + clientWSOnly.disconnect().connect(); + clientCustomNamespace.disconnect().connect(); + + const partialDone = createPartialDone(done, 3); + + client.on("hello", partialDone); + clientWSOnly.on("hello", partialDone); + clientPollingOnly.on("hello", partialDone); + clientCustomNamespace.on("hello", shouldNotHappen(done)); + + io.on("connection", () => { + if (io.of("/").sockets.size === 3) { + io.to("test").emit("hello"); + } + }); + }); + it("should work even after leaving room", (done) => { const partialDone = createPartialDone(done, 2);