diff --git a/src/SessionWebSocketListener.ts b/src/SessionWebSocketListener.ts new file mode 100644 index 0000000..a4241a3 --- /dev/null +++ b/src/SessionWebSocketListener.ts @@ -0,0 +1,51 @@ +import config from "config"; +import cookie from "cookie"; +import cookieParser from "cookie-parser"; +import {Request} from "express"; +import {Session} from "express-session"; +import {IncomingMessage} from "http"; +import {WebSocket} from "ws"; + +import Application from "./Application.js"; +import RedisComponent from "./components/RedisComponent.js"; +import {logger} from "./Logger.js"; +import WebSocketListener from "./WebSocketListener.js"; + +export default abstract class SessionWebSocketListener extends WebSocketListener { + + public async handle(socket: WebSocket, request: IncomingMessage): Promise { + socket.once('message', (data, isBinary) => { + if (isBinary) return socket.close(1003); + + const cookies = cookie.parse(data.toString()); + const sid = cookieParser.signedCookie(cookies['connect.sid'], config.get('session.secret')); + + if (!sid) { + socket.close(1002, 'Could not decrypt provided session cookie.'); + return; + } + + const store = this.getApp().as(RedisComponent).getStore(); + store.get(sid, (err, session) => { + if (err || !session) { + logger.error(err, 'Error while initializing session in websocket for sid ' + sid); + socket.close(1011); + return; + } + + session.id = sid; + + store.createSession(request, session); + this.handleSessionSocket(socket, request, session as Session).catch(err => { + logger.error(err, 'Error in websocket listener.'); + }); + }); + }); + } + + protected abstract handleSessionSocket( + socket: WebSocket, + request: IncomingMessage, + session: Session, + ): Promise; +} diff --git a/src/WebSocketListener.ts b/src/WebSocketListener.ts index d934774..62316b9 100644 --- a/src/WebSocketListener.ts +++ b/src/WebSocketListener.ts @@ -1,4 +1,3 @@ -import {Session} from "express-session"; import {IncomingMessage} from "http"; import WebSocket from "ws"; @@ -20,6 +19,5 @@ export default abstract class WebSocketListener { public abstract handle( socket: WebSocket, request: IncomingMessage, - session: Session | null, ): Promise; } diff --git a/src/assets/ts/WebsocketClient.ts b/src/assets/ts/WebsocketClient.ts index 5017d64..db01210 100644 --- a/src/assets/ts/WebsocketClient.ts +++ b/src/assets/ts/WebsocketClient.ts @@ -17,6 +17,7 @@ export default class WebsocketClient { const websocket = new WebSocket(this.websocketUrl); websocket.onopen = () => { console.debug('Websocket connected'); + websocket.send(document.cookie); }; websocket.onmessage = (e) => { this.listener(websocket, e); diff --git a/src/auth/magic_link/MagicLinkWebSocketListener.ts b/src/auth/magic_link/MagicLinkWebSocketListener.ts index 5eced34..6ff3dd1 100644 --- a/src/auth/magic_link/MagicLinkWebSocketListener.ts +++ b/src/auth/magic_link/MagicLinkWebSocketListener.ts @@ -3,10 +3,10 @@ import {IncomingMessage} from "http"; import WebSocket from "ws"; import Application from "../../Application.js"; -import WebSocketListener from "../../WebSocketListener.js"; +import SessionWebSocketListener from "../../SessionWebSocketListener.js"; import MagicLink from "../models/MagicLink.js"; -export default class MagicLinkWebSocketListener extends WebSocketListener { +export default class MagicLinkWebSocketListener extends SessionWebSocketListener { private readonly connections: { [p: string]: (() => void)[] | undefined } = {}; public refreshMagicLink(sessionId: string): void { @@ -16,13 +16,7 @@ export default class MagicLinkWebSocketListener extends W } } - public async handle(socket: WebSocket, request: IncomingMessage, session: Session | null): Promise { - // Drop if requested without session - if (!session) { - socket.close(1002, 'Session is required for this request.'); - return; - } - + public async handleSessionSocket(socket: WebSocket, request: IncomingMessage, session: Session): Promise { // Refuse any incoming data socket.on('message', () => { socket.close(1003); @@ -37,19 +31,22 @@ export default class MagicLinkWebSocketListener extends W // Refresh if immediately applicable if (!magicLink || !await magicLink.isValid() || await magicLink.isAuthorized()) { socket.send('refresh'); - socket.close(1000); + const reason = magicLink ? + 'Magic link state changed.' : + 'Magic link not found for session ' + session.id; + socket.close(1000, reason); return; } const validityTimeout = setTimeout(() => { socket.send('refresh'); - socket.close(1000); + socket.close(1000, 'Timed out'); }, magicLink.getExpirationDate().getTime() - new Date().getTime()); const f = () => { clearTimeout(validityTimeout); socket.send('refresh'); - socket.close(1000); + socket.close(1000, 'Closed by server'); }; socket.on('close', () => { diff --git a/src/components/SessionComponent.ts b/src/components/SessionComponent.ts index 08774df..4f8a8b5 100644 --- a/src/components/SessionComponent.ts +++ b/src/components/SessionComponent.ts @@ -29,8 +29,9 @@ export default class SessionComponent extends ApplicationComponent { store: this.storeComponent.getStore(), resave: false, cookie: { - httpOnly: true, + httpOnly: false, secure: config.get('session.cookie.secure'), + sameSite: 'strict', }, rolling: true, })); diff --git a/src/components/WebSocketServerComponent.ts b/src/components/WebSocketServerComponent.ts index d4ec4d4..ffb4bd6 100644 --- a/src/components/WebSocketServerComponent.ts +++ b/src/components/WebSocketServerComponent.ts @@ -1,8 +1,5 @@ import config from "config"; -import cookie from "cookie"; -import cookieParser from "cookie-parser"; -import {Express, Request, Router} from "express"; -import {Session} from "express-session"; +import {Express, Router} from "express"; import {WebSocketServer} from "ws"; import Application from "../Application.js"; @@ -45,37 +42,11 @@ export default class WebSocketServerComponent extends ApplicationComponent { if (!listener) { socket.close(1002, `Path not found ${request.url}`); return; - } else if (!request.headers.cookie) { - listener.handle(socket, request, null).catch(err => { - logger.error(err, 'Error in websocket listener.'); - }); - return; } logger.debug(`Websocket on ${request.url}`); - - const cookies = cookie.parse(request.headers.cookie); - const sid = cookieParser.signedCookie(cookies['connect.sid'], config.get('session.secret')); - - if (!sid) { - socket.close(1002); - return; - } - - const store = app.as(RedisComponent).getStore(); - store.get(sid, (err, session) => { - if (err || !session) { - logger.error(err, 'Error while initializing session in websocket.'); - socket.close(1011); - return; - } - - session.id = sid; - - store.createSession(request, session); - listener.handle(socket, request, session as Session).catch(err => { - logger.error(err, 'Error in websocket listener.'); - }); + listener.handle(socket, request).catch(err => { + logger.error(err, 'Error in websocket listener.'); }); }); } diff --git a/test/Authentication.test.ts b/test/Authentication.test.ts index 2c5bbfc..48bd864 100644 --- a/test/Authentication.test.ts +++ b/test/Authentication.test.ts @@ -1185,12 +1185,12 @@ describe('Session persistence', () => { await followMagicLinkFromMail(agent, cookies); - expect(cookies[0]).toMatch(/^connect\.sid=.+; Path=\/; HttpOnly$/); + expect(cookies[0]).toMatch(/^connect\.sid=.+; Path=\/; SameSite=Strict$/); res = await agent.get('/csrf') .set('Cookie', cookies) .expect(200); - expect(res.get('Set-Cookie')[0]).toMatch(/^connect\.sid=.+; Path=\/; Expires=.+; HttpOnly$/); + expect(res.get('Set-Cookie')[0]).toMatch(/^connect\.sid=.+; Path=\/; Expires=.+; SameSite=Strict$/); // Logout await agent.post('/auth/logout') @@ -1217,7 +1217,7 @@ describe('Session persistence', () => { const res = await agent.get('/csrf') .set('Cookie', cookies) .expect(200); - expect(res.get('Set-Cookie')[0]).toMatch(/^connect\.sid=.+; Path=\/; Expires=.+; HttpOnly$/); + expect(res.get('Set-Cookie')[0]).toMatch(/^connect\.sid=.+; Path=\/; Expires=.+; SameSite=Strict$/); // Logout await agent.post('/auth/logout') @@ -1244,7 +1244,7 @@ describe('Session persistence', () => { const res = await agent.get('/csrf') .set('Cookie', cookies) .expect(200); - expect(res.get('Set-Cookie')[0]).toMatch(/^connect\.sid=.+; Path=\/; HttpOnly$/); + expect(res.get('Set-Cookie')[0]).toMatch(/^connect\.sid=.+; Path=\/; SameSite=Strict$/); // Logout await agent.post('/auth/logout')