diff --git a/src/Application.ts b/src/Application.ts index a32cadc..02db79c 100644 --- a/src/Application.ts +++ b/src/Application.ts @@ -121,7 +121,7 @@ export default abstract class Application implements Extendable { req.flash('validation', bag.getMessages()); - res.redirectBack(); + res.redirect(req.getPreviousUrl() || Controller.route('home')); }, }); return; diff --git a/src/TestApp.ts b/src/TestApp.ts index 9909491..de24b16 100644 --- a/src/TestApp.ts +++ b/src/TestApp.ts @@ -9,7 +9,6 @@ import MailComponent from "../src/components/MailComponent"; import SessionComponent from "../src/components/SessionComponent"; import AuthComponent from "../src/auth/AuthComponent"; import FormHelperComponent from "../src/components/FormHelperComponent"; -import RedirectBackComponent from "../src/components/RedirectBackComponent"; import ServeStaticDirectoryComponent from "../src/components/ServeStaticDirectoryComponent"; import {Express} from "express"; import MagicLinkAuthMethod from "../src/auth/magic_link/MagicLinkAuthMethod"; @@ -31,6 +30,7 @@ import AccountController from "./auth/AccountController"; import MakeMagicLinksSessionNotUniqueMigration from "./auth/magic_link/MakeMagicLinksSessionNotUniqueMigration"; import AddUsedToMagicLinksMigration from "./auth/magic_link/AddUsedToMagicLinksMigration"; import packageJson = require('../package.json'); +import PreviousUrlComponent from "./components/PreviousUrlComponent"; export const MIGRATIONS = [ CreateMigrationsTable, @@ -72,7 +72,7 @@ export default class TestApp extends Application { // Dynamic views and routes this.use(new NunjucksComponent(['test/views', 'views'])); - this.use(new RedirectBackComponent()); + this.use(new PreviousUrlComponent()); // Services this.use(new MysqlComponent()); diff --git a/src/auth/AccountController.ts b/src/auth/AccountController.ts index 0893ea0..a68d0e4 100644 --- a/src/auth/AccountController.ts +++ b/src/auth/AccountController.ts @@ -58,7 +58,7 @@ export default class AccountController extends Controller { const passwordComponent = user.as(UserPasswordComponent); if (passwordComponent.hasPassword() && !await passwordComponent.verifyPassword(req.body.current_password)) { req.flash('error', 'Invalid current password.'); - res.redirectBack(Controller.route('account')); + res.redirect(Controller.route('account')); return; } @@ -66,7 +66,7 @@ export default class AccountController extends Controller { await user.save(); req.flash('success', 'Password changed successfully.'); - res.redirectBack(Controller.route('account')); + res.redirect(Controller.route('account')); } @@ -120,7 +120,7 @@ export default class AccountController extends Controller { await user.save(); req.flash('success', 'This email was successfully set as your main address.'); - res.redirectBack(); + res.redirect(Controller.route('account')); } protected async postRemoveEmail(req: Request, res: Response): Promise { @@ -140,6 +140,6 @@ export default class AccountController extends Controller { await userEmail.delete(); req.flash('success', 'This email was successfully removed from your account.'); - res.redirectBack(); + res.redirect(Controller.route('account')); } } diff --git a/src/auth/AuthComponent.ts b/src/auth/AuthComponent.ts index dc1141e..552959b 100644 --- a/src/auth/AuthComponent.ts +++ b/src/auth/AuthComponent.ts @@ -116,7 +116,7 @@ export class RequireGuestMiddleware extends Middleware { protected async handle(req: Request, res: Response, next: NextFunction): Promise { const proofs = await req.as(AuthMiddleware).getAuthGuard().getProofsForSession(req.getSession()); if (proofs.length > 0) { - res.redirectBack(); + res.redirect(Controller.route('home')); return; } diff --git a/src/auth/AuthController.ts b/src/auth/AuthController.ts index 113a646..81fd255 100644 --- a/src/auth/AuthController.ts +++ b/src/auth/AuthController.ts @@ -117,7 +117,7 @@ export default class AuthController extends Controller { } req.flash('success', 'Successfully logged out.'); - res.redirect(req.query.redirect_uri?.toString() || '/'); + res.redirect(req.getIntendedUrl() || '/'); } protected async redirectToRegistration(req: Request, res: Response, identifier: string): Promise { diff --git a/src/auth/magic_link/MagicLinkAuthMethod.ts b/src/auth/magic_link/MagicLinkAuthMethod.ts index d79748f..9048594 100644 --- a/src/auth/magic_link/MagicLinkAuthMethod.ts +++ b/src/auth/magic_link/MagicLinkAuthMethod.ts @@ -7,7 +7,6 @@ import {WhereTest} from "../../db/ModelQuery"; import Controller from "../../Controller"; import geoip from "geoip-lite"; import MagicLinkController from "./MagicLinkController"; -import RedirectBackComponent from "../../components/RedirectBackComponent"; import Application from "../../Application"; import {MailTemplate} from "../../mail/Mail"; import AuthMagicLinkActionType from "./AuthMagicLinkActionType"; @@ -55,7 +54,7 @@ export default class MagicLinkAuthMethod implements AuthMethod { if (pendingLink) { if (await pendingLink.isValid()) { res.redirect(Controller.route('magic_link_lobby', undefined, { - redirect_uri: req.query.redirect_uri?.toString() || pendingLink.original_url || undefined, + redirect_uri: req.getIntendedUrl() || pendingLink.original_url || undefined, })); return true; } else { @@ -105,7 +104,7 @@ export default class MagicLinkAuthMethod implements AuthMethod { req.getSession().id, actionType, Controller.route('auth', undefined, { - redirect_uri: req.query.redirect_uri?.toString() || undefined, + redirect_uri: req.getIntendedUrl() || undefined, }), email, this.magicLinkMailTemplate, @@ -120,7 +119,7 @@ export default class MagicLinkAuthMethod implements AuthMethod { ); res.redirect(Controller.route('magic_link_lobby', undefined, { - redirect_uri: req.query.redirect_uri?.toString() || RedirectBackComponent.getPreviousURL(req), + redirect_uri: req.getIntendedUrl(), })); } } diff --git a/src/auth/magic_link/MagicLinkController.ts b/src/auth/magic_link/MagicLinkController.ts index eaa2a07..3df778c 100644 --- a/src/auth/magic_link/MagicLinkController.ts +++ b/src/auth/magic_link/MagicLinkController.ts @@ -192,7 +192,7 @@ export default class MagicLinkController extends Controll if (!res.headersSent && user) { // Auth success req.flash('success', `Authentication success. Welcome, ${user.name}!`); - res.redirect(req.query.redirect_uri?.toString() || Controller.route('home')); + res.redirect(req.getIntendedUrl() || Controller.route('home')); } break; } diff --git a/src/auth/password/PasswordAuthMethod.ts b/src/auth/password/PasswordAuthMethod.ts index 0c46434..27bd85e 100644 --- a/src/auth/password/PasswordAuthMethod.ts +++ b/src/auth/password/PasswordAuthMethod.ts @@ -72,7 +72,7 @@ export default class PasswordAuthMethod implements AuthMethod if (e instanceof PendingApprovalAuthError) { req.flash('error', 'Your account is still being reviewed.'); - res.redirectBack(); + res.redirect(Controller.route('auth')); return; } else { const err = new InvalidFormatValidationError('Invalid password.'); @@ -85,7 +85,7 @@ export default class PasswordAuthMethod implements AuthMethod } req.flash('success', `Welcome, ${user.name}.`); - res.redirect(Controller.route('home')); + res.redirect(req.getIntendedUrl() || Controller.route('home')); } public async attemptRegister(req: Request, res: Response, identifier: string): Promise { @@ -123,7 +123,7 @@ export default class PasswordAuthMethod implements AuthMethod } catch (e) { if (e instanceof PendingApprovalAuthError) { req.flash('info', `Your account was successfully created and is pending review from an administrator.`); - res.redirect(Controller.route('home')); + res.redirect(Controller.route('auth')); return; } else { throw e; @@ -133,7 +133,7 @@ export default class PasswordAuthMethod implements AuthMethod const user = await passwordAuthProof.getResource(); req.flash('success', `Your account was successfully created! Welcome, ${user?.as(UserNameComponent).name}.`); - res.redirect(Controller.route('home')); + res.redirect(req.getIntendedUrl() || Controller.route('home')); } } diff --git a/src/components/PreviousUrlComponent.ts b/src/components/PreviousUrlComponent.ts new file mode 100644 index 0000000..9e18ee7 --- /dev/null +++ b/src/components/PreviousUrlComponent.ts @@ -0,0 +1,54 @@ +import ApplicationComponent from "../ApplicationComponent"; +import {Router} from "express"; +import onFinished from "on-finished"; +import {logger} from "../Logger"; +import SessionComponent from "./SessionComponent"; + +export default class PreviousUrlComponent extends ApplicationComponent { + + public async handle(router: Router): Promise { + router.use((req, res, next) => { + req.getPreviousUrl = () => { + let url = req.header('referer'); + if (url) { + if (url.indexOf('://') >= 0) url = '/' + url.split('/').slice(3).join('/'); + if (url !== req.originalUrl) return url; + } + + if (this.getApp().asOptional(SessionComponent)) { + const session = req.getSessionOptional(); + url = session?.previousUrl; + if (url && url !== req.originalUrl) return url; + } + + return null; + }; + res.locals.getPreviousUrl = req.getPreviousUrl; + + req.getIntendedUrl = () => { + return req.query.redirect_uri?.toString() || null; + }; + + if (this.getApp().asOptional(SessionComponent)) { + const session = req.getSessionOptional(); + if (session && req.method === 'GET') { + onFinished(res, (err) => { + if (err) return; + + const contentType = res.getHeader('content-type'); + if (res.statusCode === 200 && + contentType && typeof contentType !== 'number' && contentType.indexOf('text/html') >= 0) { + session.previousUrl = req.originalUrl; + + session.save((err) => { + if (err) logger.error(err, 'Error while saving session'); + else logger.debug('Prev url set to', session.previousUrl); + }); + } + }); + } + } + next(); + }); + } +} diff --git a/src/components/RedirectBackComponent.ts b/src/components/RedirectBackComponent.ts deleted file mode 100644 index 5e0e25b..0000000 --- a/src/components/RedirectBackComponent.ts +++ /dev/null @@ -1,45 +0,0 @@ -import ApplicationComponent from "../ApplicationComponent"; -import {Request, Router} from "express"; -import {ServerError} from "../HttpError"; -import onFinished from "on-finished"; -import {logger} from "../Logger"; - -export default class RedirectBackComponent extends ApplicationComponent { - public static getPreviousURL(req: Request, defaultUrl?: string): string | undefined { - return req.getSessionOptional()?.previousUrl || defaultUrl; - } - - public async handle(router: Router): Promise { - router.use((req, res, next) => { - res.redirectBack = (defaultUrl?: string): void => { - const previousUrl = RedirectBackComponent.getPreviousURL(req, defaultUrl); - if (!previousUrl) throw new ServerError(`Couldn't redirect you back.`); - res.redirect(previousUrl); - }; - - res.locals.getPreviousURL = (defaultUrl?: string) => { - return RedirectBackComponent.getPreviousURL(req, defaultUrl); - }; - - onFinished(res, (err) => { - const session = req.getSessionOptional(); - if (session) { - const contentType = res.getHeader('content-type'); - if (!err && res.statusCode === 200 && ( - contentType && typeof contentType !== 'number' && contentType.indexOf('text/html') >= 0 - )) { - session.previousUrl = req.originalUrl; - logger.debug('Prev url set to', session.previousUrl); - session.save((err) => { - if (err) { - logger.error(err, 'Error while saving session'); - } - }); - } - } - }); - - next(); - }); - } -} diff --git a/src/helpers/BackendController.ts b/src/helpers/BackendController.ts index e095d10..66fc1c0 100644 --- a/src/helpers/BackendController.ts +++ b/src/helpers/BackendController.ts @@ -82,7 +82,7 @@ export default class BackendController extends Controller { } req.flash('success', `Account successfully approved.`); - res.redirectBack(Controller.route('accounts-approval')); + res.redirect(Controller.route('accounts-approval')); } protected async postRejectAccount(req: Request, res: Response): Promise { @@ -97,7 +97,7 @@ export default class BackendController extends Controller { } req.flash('success', `Account successfully deleted.`); - res.redirectBack(Controller.route('accounts-approval')); + res.redirect(Controller.route('accounts-approval')); } protected async accountRequest(req: Request): Promise<{ diff --git a/src/types/Express.d.ts b/src/types/Express.d.ts index 33fba07..5b14407 100644 --- a/src/types/Express.d.ts +++ b/src/types/Express.d.ts @@ -26,10 +26,11 @@ declare global { flash(message: string): unknown[]; flash(event: string, message: unknown): void; - } - export interface Response { - redirectBack(defaultUrl?: string): void; + + getPreviousUrl(): string | null; + + getIntendedUrl(): string | null; } } } diff --git a/test/Authentication.test.ts b/test/Authentication.test.ts index bc165a6..ad4e907 100644 --- a/test/Authentication.test.ts +++ b/test/Authentication.test.ts @@ -10,6 +10,7 @@ import {popEmail} from "./_mail_server"; import AuthComponent from "../src/auth/AuthComponent"; import {followMagicLinkFromMail, testLogout} from "./_authentication_common"; import UserEmail from "../src/auth/models/UserEmail"; +import * as querystring from "querystring"; let app: TestApp; useApp(async (addr, port) => { @@ -87,7 +88,7 @@ describe('Register with username and password (password)', () => { terms: 'on', }) .expect(302) - .expect('Location', '/csrf'); + .expect('Location', '/'); const user2 = await User.select() .where('name', 'entrapta2') @@ -158,7 +159,7 @@ describe('Register with email (magic_link)', () => { const cookies = res.get('Set-Cookie'); const csrf = res.text; - await agent.post('/auth/register') + await agent.post('/auth/register?' + querystring.stringify({redirect_uri: '/redirect-uri'})) .set('Cookie', cookies) .send({ csrf: csrf, @@ -167,7 +168,7 @@ describe('Register with email (magic_link)', () => { name: 'glimmer', }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri=%2Fredirect-uri'); await followMagicLinkFromMail(agent, cookies); @@ -221,7 +222,7 @@ describe('Register with email (magic_link)', () => { name: 'angella', }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri='); await followMagicLinkFromMail(agent, cookies); @@ -263,7 +264,7 @@ describe('Register with email (magic_link)', () => { name: 'bow', }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri='); await followMagicLinkFromMail(agent, cookies); @@ -314,7 +315,7 @@ describe('Authenticate with username and password (password)', () => { expect(res.body.messages?.password?.name).toStrictEqual('InvalidFormatValidationError'); // Authenticate - await agent.post('/auth/login') + await agent.post('/auth/login?' + querystring.stringify({redirect_uri: '/redirect-uri'})) .set('Cookie', cookies) .send({ csrf: csrf, @@ -323,7 +324,7 @@ describe('Authenticate with username and password (password)', () => { auth_method: 'password', }) .expect(302) - .expect('Location', '/'); + .expect('Location', '/redirect-uri'); await testLogout(agent, cookies, csrf); }); @@ -458,7 +459,7 @@ describe('Authenticate with email (magic_link)', () => { await agent.get('/is-auth').set('Cookie', cookies).expect(401); // Authenticate - await agent.post('/auth/login') + await agent.post('/auth/login?' + querystring.stringify({redirect_uri: '/redirect-uri'})) .set('Cookie', cookies) .send({ csrf: csrf, @@ -466,7 +467,7 @@ describe('Authenticate with email (magic_link)', () => { auth_method: 'magic_link', }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri=%2Fredirect-uri'); await followMagicLinkFromMail(agent, cookies); @@ -489,7 +490,7 @@ describe('Authenticate with email (magic_link)', () => { identifier: 'angella@example.org', }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri='); await followMagicLinkFromMail(agent, cookies); @@ -553,7 +554,7 @@ describe('Authenticate with email and password (password)', () => { name: 'double-trouble', }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri='); await followMagicLinkFromMail(agent, cookies); @@ -664,7 +665,7 @@ describe('Change password', () => { name: 'aang', }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri='); await followMagicLinkFromMail(agent, cookies); }); @@ -699,7 +700,7 @@ describe('Change password', () => { 'new_password_confirmation': 'a_very_strong_password', }) .expect(302) - .expect('Location', '/csrf'); // TODO: because of buggy RedirectBackComponent, change to /account once fixed. + .expect('Location', '/account/'); const user = await User.select() .where('name', 'aang') @@ -741,7 +742,7 @@ describe('Change password', () => { 'new_password_confirmation': 'a_very_strong_password_but_different', }) .expect(302) - .expect('Location', '/csrf'); // TODO: because of buggy RedirectBackComponent, change to /account once fixed. + .expect('Location', '/account/'); const user = await User.select() .where('name', 'aang') @@ -839,7 +840,7 @@ describe('Manage email addresses', () => { name: 'katara', }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri='); await followMagicLinkFromMail(agent, cookies); @@ -933,7 +934,7 @@ describe('Manage email addresses', () => { id: beforeSecondaryEmail?.id, }) .expect(302) - .expect('Location', '/csrf'); // TODO: because of buggy RedirectBackComponent, change to /account once fixed. + .expect('Location', '/account/'); await testMainSecondaryState('katara3@example.org', 'katara@example.org'); }); @@ -981,7 +982,7 @@ describe('Manage email addresses', () => { id: (await UserEmail.select().where('email', 'katara2@example.org').first())?.id, }) .expect(302) - .expect('Location', '/csrf'); // TODO: because of buggy RedirectBackComponent, change to /account once fixed. + .expect('Location', '/account/'); expect(await UserEmail.select().where('email', 'katara2@example.org').count()).toBe(0); }); @@ -1046,7 +1047,7 @@ describe('Session persistence', () => { name: 'zuko', }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri='); await followMagicLinkFromMail(agent, cookies); @@ -1075,7 +1076,7 @@ describe('Session persistence', () => { persist_session: 'on', }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri='); await followMagicLinkFromMail(agent, cookies); @@ -1102,7 +1103,7 @@ describe('Session persistence', () => { persist_session: undefined, }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri='); await followMagicLinkFromMail(agent, cookies); diff --git a/test/AuthenticationNoUsername.test.ts b/test/AuthenticationNoUsername.test.ts index 31ef118..8930d7d 100644 --- a/test/AuthenticationNoUsername.test.ts +++ b/test/AuthenticationNoUsername.test.ts @@ -81,7 +81,7 @@ describe('Register with email (magic_link)', () => { identifier: 'glimmer@example.org', }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri='); await followMagicLinkFromMail(agent, cookies); @@ -116,7 +116,7 @@ describe('Register with email (magic_link)', () => { name: 'bow', }) .expect(302) - .expect('Location', '/magic/lobby?redirect_uri=%2Fcsrf'); + .expect('Location', '/magic/lobby?redirect_uri='); await followMagicLinkFromMail(agent, cookies); diff --git a/views/auth/auth.njk b/views/auth/auth.njk index 4cb30eb..545135e 100644 --- a/views/auth/auth.njk +++ b/views/auth/auth.njk @@ -8,8 +8,11 @@ {% block body %}