Improve middleware definition and cleanup code

This commit is contained in:
Alice Gaudon 2020-09-25 22:03:22 +02:00
parent 0d6f7c0d90
commit b736f5f6cb
29 changed files with 382 additions and 235 deletions

View File

@ -16,14 +16,15 @@ import SecurityError from "./SecurityError";
import * as path from "path"; import * as path from "path";
import CacheProvider from "./CacheProvider"; import CacheProvider from "./CacheProvider";
import RedisComponent from "./components/RedisComponent"; import RedisComponent from "./components/RedisComponent";
import Extendable from "./Extendable";
import TemplateError = lib.TemplateError; import TemplateError = lib.TemplateError;
export default abstract class Application { export default abstract class Application implements Extendable<ApplicationComponent> {
private readonly version: string; private readonly version: string;
private readonly ignoreCommandLine: boolean; private readonly ignoreCommandLine: boolean;
private readonly controllers: Controller[] = []; private readonly controllers: Controller[] = [];
private readonly webSocketListeners: { [p: string]: WebSocketListener<any> } = {}; private readonly webSocketListeners: { [p: string]: WebSocketListener<Application> } = {};
private readonly components: ApplicationComponent<any>[] = []; private readonly components: ApplicationComponent[] = [];
private cacheProvider?: CacheProvider; private cacheProvider?: CacheProvider;
private ready: boolean = false; private ready: boolean = false;
@ -37,8 +38,9 @@ export default abstract class Application {
protected abstract async init(): Promise<void>; protected abstract async init(): Promise<void>;
protected use(thing: Controller | WebSocketListener<this> | ApplicationComponent<any>) { protected use(thing: Controller | WebSocketListener<this> | ApplicationComponent) {
if (thing instanceof Controller) { if (thing instanceof Controller) {
thing.setApp(this);
this.controllers.push(thing); this.controllers.push(thing);
} else if (thing instanceof WebSocketListener) { } else if (thing instanceof WebSocketListener) {
const path = thing.path(); const path = thing.path();
@ -151,15 +153,24 @@ export default abstract class Application {
// Start components // Start components
for (const component of this.components) { for (const component of this.components) {
await component.start(app); await component.start?.(app);
} }
// Components routes // Components routes
for (const component of this.components) { for (const component of this.components) {
if (component.init) {
component.setCurrentRouter(initRouter);
await component.init(initRouter); await component.init(initRouter);
}
if (component.handle) {
component.setCurrentRouter(handleRouter);
await component.handle(handleRouter); await component.handle(handleRouter);
} }
component.setCurrentRouter(null);
}
// Routes // Routes
this.routes(initRouter, handleRouter); this.routes(initRouter, handleRouter);
@ -203,7 +214,7 @@ export default abstract class Application {
// Check security fields // Check security fields
for (const component of this.components) { for (const component of this.components) {
await component.checkSecuritySettings(); await component.checkSecuritySettings?.();
} }
} }
@ -211,7 +222,7 @@ export default abstract class Application {
Logger.info('Stopping application...'); Logger.info('Stopping application...');
for (const component of this.components) { for (const component of this.components) {
await component.stop(); await component.stop?.();
} }
Logger.info(`${this.constructor.name} v${this.version} - bye`); Logger.info(`${this.constructor.name} v${this.version} - bye`);
@ -219,7 +230,7 @@ export default abstract class Application {
private routes(initRouter: Router, handleRouter: Router) { private routes(initRouter: Router, handleRouter: Router) {
for (const controller of this.controllers) { for (const controller of this.controllers) {
if (controller.hasGlobalHandlers()) { if (controller.hasGlobalMiddlewares()) {
controller.setupGlobalHandlers(handleRouter); controller.setupGlobalHandlers(handleRouter);
Logger.info(`Registered global middlewares for controller ${controller.constructor.name}`); Logger.info(`Registered global middlewares for controller ${controller.constructor.name}`);
@ -247,7 +258,7 @@ export default abstract class Application {
return this.version; return this.version;
} }
public getWebSocketListeners(): { [p: string]: WebSocketListener<any> } { public getWebSocketListeners(): { [p: string]: WebSocketListener<Application> } {
return this.webSocketListeners; return this.webSocketListeners;
} }
@ -255,7 +266,14 @@ export default abstract class Application {
return this.cacheProvider; return this.cacheProvider;
} }
public getComponent<T extends ApplicationComponent<any>>(type: Type<T>): T | undefined { public as<C extends ApplicationComponent>(type: Type<C>): C {
return <T>this.components.find(component => component.constructor === type); const component = this.components.find(component => component.constructor === type);
if (!component) throw new Error(`This app doesn't have a ${type.name} component.`);
return component as C;
}
public asOptional<C extends ApplicationComponent>(type: Type<C>): C | null {
const component = this.components.find(component => component.constructor === type);
return component ? component as C : null;
} }
} }

View File

@ -1,42 +1,24 @@
import {Express, Router} from "express"; import {Express, Router} from "express";
import Logger from "./Logger"; import Logger from "./Logger";
import {sleep} from "./Utils"; import {sleep, Type} from "./Utils";
import Application from "./Application"; import Application from "./Application";
import config from "config"; import config from "config";
import SecurityError from "./SecurityError"; import SecurityError from "./SecurityError";
import Middleware from "./Middleware";
export default abstract class ApplicationComponent<T> { export default abstract class ApplicationComponent {
private val?: T; private currentRouter?: Router;
protected app?: Application; private app?: Application;
public async checkSecuritySettings(): Promise<void> { public async checkSecuritySettings?(): Promise<void>;
}
public async start(app: Express): Promise<void> { public async start?(expressApp: Express): Promise<void>;
}
public async init(router: Router): Promise<void> { public async init?(router: Router): Promise<void>;
}
public async handle(router: Router): Promise<void> { public async handle?(router: Router): Promise<void>;
}
public async stop(): Promise<void> { public async stop?(): Promise<void>;
}
protected export(val: T) {
this.val = val;
}
public import(): T {
if (!this.val) throw 'Cannot import if nothing was exported.';
return this.val;
}
public setApp(app: Application) {
this.app = app;
}
protected async prepare(name: string, prepare: () => Promise<void>): Promise<void> { protected async prepare(name: string, prepare: () => Promise<void>): Promise<void> {
let err; let err;
@ -71,4 +53,34 @@ export default abstract class ApplicationComponent<T> {
throw new SecurityError(`${field} field not configured.`); throw new SecurityError(`${field} field not configured.`);
} }
} }
protected use<M extends Middleware>(middleware: Type<M>): void {
if (!this.currentRouter) throw new Error('Cannot call this method outside init() and handle().');
const instance = new middleware(this.getApp());
this.currentRouter.use(async (req, res, next) => {
try {
await instance.getRequestHandler()(req, res, next);
} catch (e) {
next(e);
}
});
}
protected getCurrentRouter(): Router | null {
return this.currentRouter || null;
}
public setCurrentRouter(router: Router | null): void {
this.currentRouter = router || undefined;
}
protected getApp(): Application {
if (!this.app) throw new Error('app field not initialized.');
return this.app;
}
public setApp(app: Application) {
this.app = app;
}
} }

View File

@ -2,10 +2,13 @@ import express, {IRouter, RequestHandler, Router} from "express";
import {PathParams} from "express-serve-static-core"; import {PathParams} from "express-serve-static-core";
import config from "config"; import config from "config";
import Logger from "./Logger"; import Logger from "./Logger";
import Validator, {FileError, ValidationBag} from "./db/Validator"; import Validator, {ValidationBag} from "./db/Validator";
import FileUploadMiddleware from "./FileUploadMiddleware"; import FileUploadMiddleware from "./FileUploadMiddleware";
import * as querystring from "querystring"; import * as querystring from "querystring";
import {ParsedUrlQueryInput} from "querystring"; import {ParsedUrlQueryInput} from "querystring";
import Middleware from "./Middleware";
import {Type} from "./Utils";
import Application from "./Application";
export default abstract class Controller { export default abstract class Controller {
private static readonly routes: { [p: string]: string } = {}; private static readonly routes: { [p: string]: string } = {};
@ -39,18 +42,19 @@ export default abstract class Controller {
private readonly router: Router = express.Router(); private readonly router: Router = express.Router();
private readonly fileUploadFormRouter: Router = express.Router(); private readonly fileUploadFormRouter: Router = express.Router();
private app?: Application;
public getGlobalHandlers(): RequestHandler[] { public getGlobalMiddlewares(): Middleware[] {
return []; return [];
} }
public hasGlobalHandlers(): boolean { public hasGlobalMiddlewares(): boolean {
return this.getGlobalHandlers().length > 0; return this.getGlobalMiddlewares().length > 0;
} }
public setupGlobalHandlers(router: Router): void { public setupGlobalHandlers(router: Router): void {
for (const globalHandler of this.getGlobalHandlers()) { for (const middleware of this.getGlobalMiddlewares()) {
router.use(this.wrap(globalHandler)); router.use(this.wrap(middleware.getRequestHandler()));
} }
} }
@ -75,19 +79,19 @@ export default abstract class Controller {
this.router.use(handler); this.router.use(handler);
} }
protected get(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (RequestHandler | FileUploadMiddleware)[]) { protected get(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (Type<Middleware>)[]) {
this.handle('get', path, handler, routeName, ...middlewares); this.handle('get', path, handler, routeName, ...middlewares);
} }
protected post(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (RequestHandler | FileUploadMiddleware)[]) { protected post(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (Type<Middleware>)[]) {
this.handle('post', path, handler, routeName, ...middlewares); this.handle('post', path, handler, routeName, ...middlewares);
} }
protected put(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (RequestHandler | FileUploadMiddleware)[]) { protected put(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (Type<Middleware>)[]) {
this.handle('put', path, handler, routeName, ...middlewares); this.handle('put', path, handler, routeName, ...middlewares);
} }
protected delete(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (RequestHandler | FileUploadMiddleware)[]) { protected delete(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (Type<Middleware>)[]) {
this.handle('delete', path, handler, routeName, ...middlewares); this.handle('delete', path, handler, routeName, ...middlewares);
} }
@ -96,14 +100,15 @@ export default abstract class Controller {
path: PathParams, path: PathParams,
handler: RequestHandler, handler: RequestHandler,
routeName?: string, routeName?: string,
...middlewares: (RequestHandler | FileUploadMiddleware)[] ...middlewares: (Type<Middleware>)[]
): void { ): void {
this.registerRoutes(path, handler, routeName); this.registerRoutes(path, handler, routeName);
for (const middleware of middlewares) { for (const middleware of middlewares) {
if (middleware instanceof FileUploadMiddleware) { const instance = new middleware(this.getApp());
this.fileUploadFormRouter[action](path, this.wrap(FILE_UPLOAD_MIDDLEWARE(middleware))); if (instance instanceof FileUploadMiddleware) {
this.fileUploadFormRouter[action](path, this.wrap(instance.getRequestHandler()));
} else { } else {
this.router[action](path, this.wrap(middleware)); this.router[action](path, this.wrap(instance.getRequestHandler()));
} }
} }
this.router[action](path, this.wrap(handler)); this.router[action](path, this.wrap(handler));
@ -164,33 +169,15 @@ export default abstract class Controller {
if (bag.hasMessages()) throw bag; if (bag.hasMessages()) throw bag;
} }
protected getApp(): Application {
if (!this.app) throw new Error('Application not initialized.');
return this.app;
}
public setApp(app: Application) {
this.app = app;
}
} }
export type RouteParams = { [p: string]: string } | string[] | string | number; export type RouteParams = { [p: string]: string } | string[] | string | number;
const FILE_UPLOAD_MIDDLEWARE: (fileUploadMiddleware: FileUploadMiddleware) => RequestHandler = (fileUploadMiddleware: FileUploadMiddleware) => {
return async (req, res, next) => {
const form = fileUploadMiddleware.formFactory();
try {
await new Promise<any>((resolve, reject) => {
form.parse(req, (err, fields, files) => {
if (err) {
reject(err);
return;
}
req.body = fields;
req.files = files;
resolve();
});
});
} catch (e) {
const bag = new ValidationBag();
const fileError = new FileError(e);
fileError.thingName = fileUploadMiddleware.defaultField;
bag.addMessage(fileError);
next(bag);
return;
}
next();
};
};

7
src/Extendable.ts Normal file
View File

@ -0,0 +1,7 @@
import {Type} from "./Utils";
export default interface Extendable<ComponentClass> {
as<C extends ComponentClass>(type: Type<C>): C;
asOptional<C extends ComponentClass>(type: Type<C>): C | null;
}

View File

@ -1,11 +1,35 @@
import {IncomingForm} from "formidable"; import {IncomingForm} from "formidable";
import Middleware from "./Middleware";
import {NextFunction, Request, Response} from "express";
import {FileError, ValidationBag} from "./db/Validator";
export default class FileUploadMiddleware { export default abstract class FileUploadMiddleware extends Middleware {
public readonly formFactory: () => IncomingForm; protected abstract makeForm(): IncomingForm;
public readonly defaultField: string;
public constructor(formFactory: () => IncomingForm, defaultField: string) { protected abstract getDefaultField(): string;
this.formFactory = formFactory;
this.defaultField = defaultField; public async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const form = this.makeForm();
try {
await new Promise<any>((resolve, reject) => {
form.parse(req, (err, fields, files) => {
if (err) {
reject(err);
return;
}
req.body = fields;
req.files = files;
resolve();
});
});
} catch (e) {
const bag = new ValidationBag();
const fileError = new FileError(e);
fileError.thingName = this.getDefaultField();
bag.addMessage(fileError);
next(bag);
return;
}
next();
} }
} }

28
src/Middleware.ts Normal file
View File

@ -0,0 +1,28 @@
import {RequestHandler} from "express";
import {NextFunction, Request, Response} from "express-serve-static-core";
import Application from "./Application";
export default abstract class Middleware {
public constructor(
protected readonly app: Application,
) {
}
protected abstract async handle(req: Request, res: Response, next: NextFunction): Promise<void>;
public getRequestHandler(): RequestHandler {
return async (req, res, next): Promise<void> => {
try {
if (req.middlewares.find(m => m.constructor === this.constructor)) {
next();
} else {
req.middlewares.push(this);
return await this.handle(req, res, next);
}
} catch (e) {
next(e);
}
};
}
}

View File

@ -1,32 +1,58 @@
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import {NextFunction, Request, Response, Router} from "express"; import {NextFunction, Request, Response} from "express";
import AuthGuard from "./AuthGuard"; import AuthGuard from "./AuthGuard";
import Controller from "../Controller"; import Controller from "../Controller";
import {ForbiddenHttpError} from "../HttpError"; import {ForbiddenHttpError} from "../HttpError";
import Middleware from "../Middleware";
import User from "./models/User";
import AuthProof from "./AuthProof";
export default class AuthComponent<T extends AuthGuard<any>> extends ApplicationComponent<void> { export default class AuthComponent extends ApplicationComponent {
private readonly authGuard: T; private readonly authGuard: AuthGuard<AuthProof<User>>;
public constructor(authGuard: T) { public constructor(authGuard: AuthGuard<AuthProof<User>>) {
super(); super();
this.authGuard = authGuard; this.authGuard = authGuard;
} }
public async init(router: Router): Promise<void> { public async init(): Promise<void> {
router.use(async (req, res, next) => { this.use(AuthMiddleware);
req.authGuard = this.authGuard;
req.models.user = res.locals.user = await (await req.authGuard.getProof(req))?.getResource();
next();
});
} }
public getAuthGuard(): T { public getAuthGuard(): AuthGuard<AuthProof<User>> {
return this.authGuard; return this.authGuard;
} }
} }
export const REQUIRE_REQUEST_AUTH_MIDDLEWARE = async (req: Request, res: Response, next: NextFunction): Promise<void> => { export class AuthMiddleware extends Middleware {
let proof = await req.authGuard.isAuthenticatedViaRequest(req); private authGuard?: AuthGuard<AuthProof<User>>;
private user: User | null = null;
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
this.authGuard = this.app.as(AuthComponent).getAuthGuard();
const proof = await this.authGuard.isAuthenticated(req.session!);
if (proof) {
this.user = await proof.getResource();
res.locals.user = this.user;
}
next();
}
public getUser(): User | null {
return this.user;
}
public getAuthGuard(): AuthGuard<AuthProof<User>> {
if (!this.authGuard) throw new Error('AuthGuard was not initialized.');
return this.authGuard;
}
}
export class RequireRequestAuthMiddleware extends Middleware {
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const proof = await req.as(AuthMiddleware).getAuthGuard().isAuthenticatedViaRequest(req);
if (!proof) { if (!proof) {
req.flash('error', `You must be logged in to access ${req.url}.`); req.flash('error', `You must be logged in to access ${req.url}.`);
res.redirect(Controller.route('auth', undefined, { res.redirect(Controller.route('auth', undefined, {
@ -36,19 +62,21 @@ export const REQUIRE_REQUEST_AUTH_MIDDLEWARE = async (req: Request, res: Respons
} }
next(); next();
}; }
}
export const REQUIRE_AUTH_MIDDLEWARE = async (req: Request, res: Response, next: NextFunction): Promise<void> => { export class RequireAuthMiddleware extends Middleware {
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const authGuard = req.as(AuthMiddleware).getAuthGuard();
// Via request // Via request
let proof = await req.authGuard.isAuthenticatedViaRequest(req);
if (proof) { if (await authGuard.isAuthenticatedViaRequest(req)) {
next(); next();
return; return;
} }
// Via session // Via session
proof = await req.authGuard.isAuthenticated(req.session!); if (!await authGuard.isAuthenticated(req.session!)) {
if (!proof) {
req.flash('error', `You must be logged in to access ${req.url}.`); req.flash('error', `You must be logged in to access ${req.url}.`);
res.redirect(Controller.route('auth', undefined, { res.redirect(Controller.route('auth', undefined, {
redirect_uri: req.url, redirect_uri: req.url,
@ -57,19 +85,27 @@ export const REQUIRE_AUTH_MIDDLEWARE = async (req: Request, res: Response, next:
} }
next(); next();
}; }
export const REQUIRE_GUEST_MIDDLEWARE = async (req: Request, res: Response, next: NextFunction): Promise<void> => { }
if (await req.authGuard.isAuthenticated(req.session!)) {
export class RequireGuestMiddleware extends Middleware {
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
if (await req.as(AuthMiddleware).getAuthGuard().isAuthenticated(req.session!)) {
res.redirectBack(); res.redirectBack();
return; return;
} }
next(); next();
}; }
export const REQUIRE_ADMIN_MIDDLEWARE = async (req: Request, res: Response, next: NextFunction): Promise<void> => { }
if (!req.models.user || !req.models.user.is_admin) {
export class RequireAdminMiddleware extends Middleware {
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const user = req.as(AuthMiddleware).getUser();
if (!user || !user.is_admin) {
throw new ForbiddenHttpError('secret tool', req.url); throw new ForbiddenHttpError('secret tool', req.url);
} }
next(); next();
}; }
}

View File

@ -1,6 +1,6 @@
import Controller from "../Controller"; import Controller from "../Controller";
import {NextFunction, Request, Response} from "express"; import {NextFunction, Request, Response} from "express";
import {REQUIRE_AUTH_MIDDLEWARE, REQUIRE_GUEST_MIDDLEWARE} from "./AuthComponent"; import {AuthMiddleware, RequireAuthMiddleware, RequireGuestMiddleware} from "./AuthComponent";
export default abstract class AuthController extends Controller { export default abstract class AuthController extends Controller {
public getRoutesPrefix(): string { public getRoutesPrefix(): string {
@ -8,10 +8,10 @@ export default abstract class AuthController extends Controller {
} }
public routes() { public routes() {
this.get('/', this.getAuth, 'auth', REQUIRE_GUEST_MIDDLEWARE); this.get('/', this.getAuth, 'auth', RequireGuestMiddleware);
this.post('/', this.postAuth, 'auth', REQUIRE_GUEST_MIDDLEWARE); this.post('/', this.postAuth, 'auth', RequireGuestMiddleware);
this.get('/check', this.getCheckAuth, 'check_auth'); this.get('/check', this.getCheckAuth, 'check_auth');
this.post('/logout', this.postLogout, 'logout', REQUIRE_AUTH_MIDDLEWARE); this.post('/logout', this.postLogout, 'logout', RequireAuthMiddleware);
} }
protected async getAuth(req: Request, res: Response, next: NextFunction): Promise<void> { protected async getAuth(req: Request, res: Response, next: NextFunction): Promise<void> {
@ -26,7 +26,7 @@ export default abstract class AuthController extends Controller {
protected abstract async getCheckAuth(req: Request, res: Response, next: NextFunction): Promise<void>; protected abstract async getCheckAuth(req: Request, res: Response, next: NextFunction): Promise<void>;
protected async postLogout(req: Request, res: Response, next: NextFunction): Promise<void> { protected async postLogout(req: Request, res: Response, next: NextFunction): Promise<void> {
const proof = await req.authGuard.getProof(req); const proof = await req.as(AuthMiddleware).getAuthGuard().getProof(req);
await proof?.revoke(); await proof?.revoke();
req.flash('success', 'Successfully logged out.'); req.flash('success', 'Successfully logged out.');
res.redirect(req.query.redirect_uri?.toString() || '/'); res.redirect(req.query.redirect_uri?.toString() || '/');

View File

@ -53,7 +53,7 @@ export default abstract class AuthGuard<P extends AuthProof<User>> {
proof: P, proof: P,
onLogin?: (user: User) => Promise<void>, onLogin?: (user: User) => Promise<void>,
onRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]> onRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]>
): Promise<void> { ): Promise<User> {
if (!await proof.isValid()) throw new InvalidAuthProofError(); if (!await proof.isValid()) throw new InvalidAuthProofError();
if (!await proof.isAuthorized()) throw new UnauthorizedAuthProofError(); if (!await proof.isAuthorized()) throw new UnauthorizedAuthProofError();
@ -95,6 +95,8 @@ export default abstract class AuthGuard<P extends AuthProof<User>> {
// Login // Login
session.is_authenticated = true; session.is_authenticated = true;
if (onLogin) await onLogin(user); if (onLogin) await onLogin(user);
return user;
} }
} }

View File

@ -9,17 +9,19 @@ import {AuthError, PendingApprovalAuthError, RegisterCallback} from "../AuthGuar
import geoip from "geoip-lite"; import geoip from "geoip-lite";
import AuthController from "../AuthController"; import AuthController from "../AuthController";
import RedirectBackComponent from "../../components/RedirectBackComponent"; import RedirectBackComponent from "../../components/RedirectBackComponent";
import {AuthMiddleware} from "../AuthComponent";
import User from "../models/User";
export default abstract class MagicLinkAuthController extends AuthController { export default abstract class MagicLinkAuthController extends AuthController {
public static async checkAndAuth(req: Request, res: Response, magicLink: MagicLink): Promise<void> { public static async checkAndAuth(req: Request, res: Response, magicLink: MagicLink): Promise<User | null> {
if (magicLink.getSessionID() !== req.sessionID!) throw new BadOwnerMagicLink(); if (magicLink.getSessionID() !== req.sessionID!) throw new BadOwnerMagicLink();
if (!await magicLink.isAuthorized()) throw new UnauthorizedMagicLink(); if (!await magicLink.isAuthorized()) throw new UnauthorizedMagicLink();
if (!await magicLink.isValid()) throw new InvalidMagicLink(); if (!await magicLink.isValid()) throw new InvalidMagicLink();
// Auth // Auth
try { try {
await req.authGuard.authenticateOrRegister(req.session!, magicLink, undefined, async (connection, user) => { return await req.as(AuthMiddleware).getAuthGuard().authenticateOrRegister(req.session!, magicLink, undefined, async (connection, user) => {
const callbacks: RegisterCallback[] = []; const callbacks: RegisterCallback[] = [];
const userEmail = UserEmail.create({ const userEmail = UserEmail.create({
@ -46,7 +48,7 @@ export default abstract class MagicLinkAuthController extends AuthController {
res.redirect('/'); res.redirect('/');
} }
}); });
return; return null;
} else { } else {
throw e; throw e;
} }
@ -139,10 +141,11 @@ export default abstract class MagicLinkAuthController extends AuthController {
return; return;
} }
await MagicLinkAuthController.checkAndAuth(req, res, magicLink); const user = await MagicLinkAuthController.checkAndAuth(req, res, magicLink);
if (user) {
// Auth success // Auth success
const username = req.models.user?.name; const username = user.name;
res.format({ res.format({
json: () => { json: () => {
res.json({'status': 'success', 'message': `Welcome, ${username}!`}); res.json({'status': 'success', 'message': `Welcome, ${username}!`});
@ -152,7 +155,7 @@ export default abstract class MagicLinkAuthController extends AuthController {
res.redirect('/'); res.redirect('/');
}, },
}); });
return; }
} }
} }

View File

@ -5,7 +5,7 @@ import ApplicationComponent from "../ApplicationComponent";
import {ForbiddenHttpError} from "../HttpError"; import {ForbiddenHttpError} from "../HttpError";
import Logger from "../Logger"; import Logger from "../Logger";
export default class AutoUpdateComponent extends ApplicationComponent<void> { export default class AutoUpdateComponent extends ApplicationComponent {
public async checkSecuritySettings(): Promise<void> { public async checkSecuritySettings(): Promise<void> {
this.checkSecurityConfigField('gitlab_webhook_token'); this.checkSecurityConfigField('gitlab_webhook_token');
} }
@ -39,7 +39,7 @@ export default class AutoUpdateComponent extends ApplicationComponent<void> {
await this.runCommand(`yarn dist`); await this.runCommand(`yarn dist`);
// Stop app // Stop app
await this.app!.stop(); await this.getApp().stop();
Logger.info('Success!'); Logger.info('Success!');
} catch (e) { } catch (e) {

View File

@ -2,8 +2,9 @@ import ApplicationComponent from "../ApplicationComponent";
import {Request, Router} from "express"; import {Request, Router} from "express";
import crypto from "crypto"; import crypto from "crypto";
import {BadRequestError} from "../HttpError"; import {BadRequestError} from "../HttpError";
import {AuthMiddleware} from "../auth/AuthComponent";
export default class CsrfProtectionComponent extends ApplicationComponent<void> { export default class CsrfProtectionComponent extends ApplicationComponent {
private static readonly excluders: ((req: Request) => boolean)[] = []; private static readonly excluders: ((req: Request) => boolean)[] = [];
public static getCSRFToken(session: Express.Session): string { public static getCSRFToken(session: Express.Session): string {
@ -33,18 +34,17 @@ export default class CsrfProtectionComponent extends ApplicationComponent<void>
if (!['GET', 'HEAD', 'OPTIONS'].find(s => s === req.method)) { if (!['GET', 'HEAD', 'OPTIONS'].find(s => s === req.method)) {
try { try {
if (!(await req.authGuard.isAuthenticatedViaRequest(req))) { if (!await req.as(AuthMiddleware).getAuthGuard().isAuthenticatedViaRequest(req)) {
if (req.session.csrf === undefined) { if (req.session.csrf === undefined) {
throw new InvalidCsrfTokenError(req.baseUrl, `You weren't assigned any CSRF token.`); return next(new InvalidCsrfTokenError(req.baseUrl, `You weren't assigned any CSRF token.`));
} else if (req.body.csrf === undefined) { } else if (req.body.csrf === undefined) {
throw new InvalidCsrfTokenError(req.baseUrl, `You didn't provide any CSRF token.`); return next(new InvalidCsrfTokenError(req.baseUrl, `You didn't provide any CSRF token.`));
} else if (req.session.csrf !== req.body.csrf) { } else if (req.session.csrf !== req.body.csrf) {
throw new InvalidCsrfTokenError(req.baseUrl, `Tokens don't match.`); return next(new InvalidCsrfTokenError(req.baseUrl, `Tokens don't match.`));
} }
} }
} catch (e) { } catch (e) {
next(e); return next(e);
return;
} }
} }
next(); next();

View File

@ -3,8 +3,10 @@ import express, {Express, Router} from "express";
import Logger from "../Logger"; import Logger from "../Logger";
import {Server} from "http"; import {Server} from "http";
import compression from "compression"; import compression from "compression";
import Middleware from "../Middleware";
import {Type} from "../Utils";
export default class ExpressAppComponent extends ApplicationComponent<void> { export default class ExpressAppComponent extends ApplicationComponent {
private readonly addr: string; private readonly addr: string;
private readonly port: number; private readonly port: number;
private server?: Server; private server?: Server;
@ -39,8 +41,12 @@ export default class ExpressAppComponent extends ApplicationComponent<void> {
router.use(compression()); router.use(compression());
router.use((req, res, next) => { router.use((req, res, next) => {
req.models = {}; req.middlewares = [];
req.modelCollections = {}; req.as = <M extends Middleware>(type: Type<M>) => {
const middleware = req.middlewares.find(m => m.constructor === type);
if (!middleware) throw new Error('Middleware ' + type.name + ' not present in this request.');
return middleware as M;
};
next(); next();
}); });
} }

View File

@ -1,7 +1,7 @@
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import {Router} from "express"; import {Router} from "express";
export default class FormHelperComponent extends ApplicationComponent<void> { export default class FormHelperComponent extends ApplicationComponent {
public async init(router: Router): Promise<void> { public async init(router: Router): Promise<void> {
router.use((req, res, next) => { router.use((req, res, next) => {
if (!req.session) { if (!req.session) {

View File

@ -3,7 +3,7 @@ import onFinished from "on-finished";
import Logger from "../Logger"; import Logger from "../Logger";
import {Request, Response, Router} from "express"; import {Request, Response, Router} from "express";
export default class LogRequestsComponent extends ApplicationComponent<void> { export default class LogRequestsComponent extends ApplicationComponent {
private static fullRequests: boolean = false; private static fullRequests: boolean = false;
public static logFullHttpRequests() { public static logFullHttpRequests() {

View File

@ -4,8 +4,7 @@ import Mail from "../Mail";
import config from "config"; import config from "config";
import SecurityError from "../SecurityError"; import SecurityError from "../SecurityError";
export default class MailComponent extends ApplicationComponent<void> { export default class MailComponent extends ApplicationComponent {
public async checkSecuritySettings(): Promise<void> { public async checkSecuritySettings(): Promise<void> {
if (!config.get<boolean>('mail.secure')) { if (!config.get<boolean>('mail.secure')) {

View File

@ -4,7 +4,7 @@ import {ServiceUnavailableHttpError} from "../HttpError";
import Application from "../Application"; import Application from "../Application";
import config from "config"; import config from "config";
export default class MaintenanceComponent extends ApplicationComponent<void> { export default class MaintenanceComponent extends ApplicationComponent {
private readonly application: Application; private readonly application: Application;
private readonly canServe: () => boolean; private readonly canServe: () => boolean;

View File

@ -2,7 +2,7 @@ import ApplicationComponent from "../ApplicationComponent";
import {Express} from "express"; import {Express} from "express";
import MysqlConnectionManager from "../db/MysqlConnectionManager"; import MysqlConnectionManager from "../db/MysqlConnectionManager";
export default class MysqlComponent extends ApplicationComponent<void> { export default class MysqlComponent extends ApplicationComponent {
public async start(app: Express): Promise<void> { public async start(app: Express): Promise<void> {
await this.prepare('Mysql connection', () => MysqlConnectionManager.prepare()); await this.prepare('Mysql connection', () => MysqlConnectionManager.prepare());
} }
@ -12,7 +12,7 @@ export default class MysqlComponent extends ApplicationComponent<void> {
} }
public canServe(): boolean { public canServe(): boolean {
return MysqlConnectionManager.pool !== undefined; return MysqlConnectionManager.isReady();
} }
} }

View File

@ -1,6 +1,6 @@
import nunjucks, {Environment} from "nunjucks"; import nunjucks, {Environment} from "nunjucks";
import config from "config"; import config from "config";
import {Express, Router} from "express"; import {Express, NextFunction, Request, Response, Router} from "express";
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import Controller from "../Controller"; import Controller from "../Controller";
import {ServerError} from "../HttpError"; import {ServerError} from "../HttpError";
@ -8,8 +8,11 @@ import * as querystring from "querystring";
import {ParsedUrlQueryInput} from "querystring"; import {ParsedUrlQueryInput} from "querystring";
import * as util from "util"; import * as util from "util";
import * as path from "path"; import * as path from "path";
import * as fs from "fs";
import Logger from "../Logger";
import Middleware from "../Middleware";
export default class NunjucksComponent extends ApplicationComponent<void> { export default class NunjucksComponent extends ApplicationComponent {
private readonly viewsPath: string[]; private readonly viewsPath: string[];
private env?: Environment; private env?: Environment;
@ -20,13 +23,14 @@ export default class NunjucksComponent extends ApplicationComponent<void> {
public async start(app: Express): Promise<void> { public async start(app: Express): Promise<void> {
let coreVersion = 'unknown'; let coreVersion = 'unknown';
const file = fs.existsSync(path.join(__dirname, '../../package.json')) ?
path.join(__dirname, '../../package.json') :
path.join(__dirname, '../package.json');
try { try {
coreVersion = require('../../package.json').version; coreVersion = JSON.parse(fs.readFileSync(file).toString()).version;
} catch (e) { } catch (e) {
try { Logger.warn('Couldn\'t determine coreVersion.', e);
coreVersion = require('../package.json').version;
} catch (e) {
}
} }
this.env = new nunjucks.Environment([ this.env = new nunjucks.Environment([
@ -43,7 +47,7 @@ export default class NunjucksComponent extends ApplicationComponent<void> {
if (path === null) throw new ServerError(`Route ${route} not found.`); if (path === null) throw new ServerError(`Route ${route} not found.`);
return path; return path;
}) })
.addGlobal('app_version', this.app!.getVersion()) .addGlobal('app_version', this.getApp().getVersion())
.addGlobal('core_version', coreVersion) .addGlobal('core_version', coreVersion)
.addGlobal('querystring', querystring) .addGlobal('querystring', querystring)
.addGlobal('app', config.get('app')) .addGlobal('app', config.get('app'))
@ -59,18 +63,29 @@ export default class NunjucksComponent extends ApplicationComponent<void> {
} }
public async init(router: Router): Promise<void> { public async init(router: Router): Promise<void> {
router.use((req, res, next) => { this.use(NunjucksMiddleware);
req.env = this.env!;
res.locals.url = req.url;
res.locals.params = req.params;
res.locals.query = req.query;
res.locals.body = req.body;
next();
});
} }
public getEnv(): Environment | undefined { public getEnv(): Environment | undefined {
return this.env; return this.env;
} }
} }
export class NunjucksMiddleware extends Middleware {
private env?: Environment;
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
this.env = this.app.as(NunjucksComponent).getEnv();
res.locals.url = req.url;
res.locals.params = req.params;
res.locals.query = req.query;
res.locals.body = req.body;
next();
}
public getEnvironment(): Environment {
if (!this.env) throw new Error('Environment not initialized.');
return this.env;
}
}

View File

@ -4,7 +4,7 @@ import {ServerError} from "../HttpError";
import onFinished from "on-finished"; import onFinished from "on-finished";
import Logger from "../Logger"; import Logger from "../Logger";
export default class RedirectBackComponent extends ApplicationComponent<void> { export default class RedirectBackComponent extends ApplicationComponent {
public static getPreviousURL(req: Request, defaultUrl?: string): string | undefined { public static getPreviousURL(req: Request, defaultUrl?: string): string | undefined {
return req.session?.previousUrl || defaultUrl; return req.session?.previousUrl || defaultUrl;
} }

View File

@ -9,7 +9,7 @@ import CacheProvider from "../CacheProvider";
const RedisStore = connect_redis(session); const RedisStore = connect_redis(session);
export default class RedisComponent extends ApplicationComponent<void> implements CacheProvider { export default class RedisComponent extends ApplicationComponent implements CacheProvider {
private redisClient?: RedisClient; private redisClient?: RedisClient;
private store?: Store; private store?: Store;

View File

@ -3,7 +3,7 @@ import express, {Router} from "express";
import {PathParams} from "express-serve-static-core"; import {PathParams} from "express-serve-static-core";
import * as path from "path"; import * as path from "path";
export default class ServeStaticDirectoryComponent extends ApplicationComponent<void> { export default class ServeStaticDirectoryComponent extends ApplicationComponent {
private readonly root: string; private readonly root: string;
private readonly path?: PathParams; private readonly path?: PathParams;

View File

@ -6,7 +6,7 @@ import flash from "connect-flash";
import {Router} from "express"; import {Router} from "express";
import SecurityError from "../SecurityError"; import SecurityError from "../SecurityError";
export default class SessionComponent extends ApplicationComponent<void> { export default class SessionComponent extends ApplicationComponent {
private readonly storeComponent: RedisComponent; private readonly storeComponent: RedisComponent;
public constructor(storeComponent: RedisComponent) { public constructor(storeComponent: RedisComponent) {

View File

@ -11,7 +11,7 @@ import RedisComponent from "./RedisComponent";
import WebSocketListener from "../WebSocketListener"; import WebSocketListener from "../WebSocketListener";
import NunjucksComponent from "./NunjucksComponent"; import NunjucksComponent from "./NunjucksComponent";
export default class WebSocketServerComponent extends ApplicationComponent<void> { export default class WebSocketServerComponent extends ApplicationComponent {
private wss?: WebSocket.Server; private wss?: WebSocket.Server;
constructor( constructor(

View File

@ -7,8 +7,9 @@ import ModelFactory from "./ModelFactory";
import ModelRelation from "./ModelRelation"; import ModelRelation from "./ModelRelation";
import ModelQuery, {ModelQueryResult, SelectFields} from "./ModelQuery"; import ModelQuery, {ModelQueryResult, SelectFields} from "./ModelQuery";
import {Request} from "express"; import {Request} from "express";
import Extendable from "../Extendable";
export default abstract class Model { export default abstract class Model implements Extendable<ModelComponent<Model>> {
public static get table(): string { public static get table(): string {
return this.name return this.name
.replace(/(?:^|\.?)([A-Z])/g, (x, y) => '_' + y.toLowerCase()) .replace(/(?:^|\.?)([A-Z])/g, (x, y) => '_' + y.toLowerCase())
@ -44,14 +45,14 @@ export default abstract class Model {
return ModelFactory.get(this).paginate(request, perPage, query); return ModelFactory.get(this).paginate(request, perPage, query);
} }
protected readonly _factory: ModelFactory<any>; protected readonly _factory: ModelFactory<Model>;
private readonly _components: ModelComponent<any>[] = []; private readonly _components: ModelComponent<this>[] = [];
private readonly _validators: { [key: string]: Validator<any> } = {}; private readonly _validators: { [key: string]: Validator<any> | undefined } = {};
private _exists: boolean; private _exists: boolean;
[key: string]: any; [key: string]: any;
public constructor(factory: ModelFactory<any>, isNew: boolean) { public constructor(factory: ModelFactory<Model>, isNew: boolean) {
if (!factory || !(factory instanceof ModelFactory)) throw new Error('Cannot instantiate model directly.'); if (!factory || !(factory instanceof ModelFactory)) throw new Error('Cannot instantiate model directly.');
this._factory = factory; this._factory = factory;
this.init(); this.init();
@ -71,15 +72,26 @@ export default abstract class Model {
this._components.push(modelComponent); this._components.push(modelComponent);
} }
public as<T extends ModelComponent<any>>(type: Type<T>): T { public as<C extends ModelComponent<Model>>(type: Type<C>): C {
for (const component of this._components) { for (const component of this._components) {
if (component instanceof type) { if (component instanceof type) {
return <any>this; return this as unknown as C;
} }
} }
throw new Error(`Component ${type.name} was not initialized for this ${this.constructor.name}.`); throw new Error(`Component ${type.name} was not initialized for this ${this.constructor.name}.`);
} }
public asOptional<C extends ModelComponent<Model>>(type: Type<C>): C | null {
for (const component of this._components) {
if (component instanceof type) {
return this as unknown as C;
}
}
return null;
}
public updateWithData(data: any) { public updateWithData(data: any) {
for (const property of this._properties) { for (const property of this._properties) {
if (data[property] !== undefined) { if (data[property] !== undefined) {
@ -92,19 +104,16 @@ export default abstract class Model {
* Override this to automatically fill obvious missing data i.e. from relation or default value that are fetched * Override this to automatically fill obvious missing data i.e. from relation or default value that are fetched
* asynchronously. * asynchronously.
*/ */
protected async autoFill(): Promise<void> { protected async autoFill?(): Promise<void>;
}
protected async beforeSave(connection: Connection): Promise<void> { protected async beforeSave?(connection: Connection): Promise<void>;
}
protected async afterSave(): Promise<void> { protected async afterSave?(): Promise<void>;
}
public async save(connection?: Connection, postHook?: (callback: () => Promise<void>) => void): Promise<void> { public async save(connection?: Connection, postHook?: (callback: () => Promise<void>) => void): Promise<void> {
if (connection && !postHook) throw new Error('If connection is provided, postHook must be provided too.'); if (connection && !postHook) throw new Error('If connection is provided, postHook must be provided too.');
await this.autoFill(); await this.autoFill?.();
await this.validate(false, connection); await this.validate(false, connection);
const needs_full_update = connection ? const needs_full_update = connection ?
@ -120,7 +129,7 @@ export default abstract class Model {
this.updateWithData(result.results[0]); this.updateWithData(result.results[0]);
} }
await this.afterSave(); await this.afterSave?.();
}; };
if (connection) { if (connection) {
@ -132,7 +141,7 @@ export default abstract class Model {
private async saveTransaction(connection: Connection): Promise<boolean> { private async saveTransaction(connection: Connection): Promise<boolean> {
// Before save // Before save
await this.beforeSave(connection); await this.beforeSave?.(connection);
if (!this.exists() && this.hasOwnProperty('created_at')) { if (!this.exists() && this.hasOwnProperty('created_at')) {
this.created_at = new Date(); this.created_at = new Date();
} }

View File

@ -13,7 +13,7 @@ export default abstract class ModelComponent<T extends Model> {
} }
public applyToModel(): void { public applyToModel(): void {
this.init(); this.init?.();
for (const property of this._properties) { for (const property of this._properties) {
if (!property.startsWith('_')) { if (!property.startsWith('_')) {
@ -32,7 +32,7 @@ export default abstract class ModelComponent<T extends Model> {
} }
} }
protected abstract init(): void; protected init?(): void;
protected setValidation<V>(propertyName: keyof this): Validator<V> { protected setValidation<V>(propertyName: keyof this): Validator<V> {
const validator = new Validator<V>(); const validator = new Validator<V>();

View File

@ -21,6 +21,10 @@ export default class MysqlConnectionManager {
private static migrationsRegistered: boolean = false; private static migrationsRegistered: boolean = false;
private static readonly migrations: Migration[] = []; private static readonly migrations: Migration[] = [];
public static isReady(): boolean {
return this.databaseReady && this.currentPool !== undefined;
}
public static registerMigrations(migrations: Type<Migration>[]) { public static registerMigrations(migrations: Type<Migration>[]) {
if (!this.migrationsRegistered) { if (!this.migrationsRegistered) {
this.migrationsRegistered = true; this.migrationsRegistered = true;

View File

@ -1,6 +1,5 @@
import config from "config"; import config from "config";
import Controller from "../Controller"; import Controller from "../Controller";
import {REQUIRE_ADMIN_MIDDLEWARE, REQUIRE_AUTH_MIDDLEWARE} from "../auth/AuthComponent";
import User from "../auth/models/User"; import User from "../auth/models/User";
import {Request, Response} from "express"; import {Request, Response} from "express";
import {BadRequestError, NotFoundHttpError} from "../HttpError"; import {BadRequestError, NotFoundHttpError} from "../HttpError";
@ -8,6 +7,7 @@ import Mail from "../Mail";
import {ACCOUNT_REVIEW_NOTICE_MAIL_TEMPLATE} from "../Mails"; import {ACCOUNT_REVIEW_NOTICE_MAIL_TEMPLATE} from "../Mails";
import UserEmail from "../auth/models/UserEmail"; import UserEmail from "../auth/models/UserEmail";
import UserApprovedComponent from "../auth/models/UserApprovedComponent"; import UserApprovedComponent from "../auth/models/UserApprovedComponent";
import {RequireAdminMiddleware, RequireAuthMiddleware} from "../auth/AuthComponent";
export default class BackendController extends Controller { export default class BackendController extends Controller {
private static readonly menu: BackendMenuElement[] = []; private static readonly menu: BackendMenuElement[] = [];
@ -37,11 +37,11 @@ export default class BackendController extends Controller {
} }
public routes(): void { public routes(): void {
this.get('/', this.getIndex, 'backend', REQUIRE_AUTH_MIDDLEWARE, REQUIRE_ADMIN_MIDDLEWARE); this.get('/', this.getIndex, 'backend', RequireAuthMiddleware, RequireAdminMiddleware);
if (User.isApprovalMode()) { if (User.isApprovalMode()) {
this.get('/accounts-approval', this.getAccountApproval, 'accounts-approval', REQUIRE_AUTH_MIDDLEWARE, REQUIRE_ADMIN_MIDDLEWARE); this.get('/accounts-approval', this.getAccountApproval, 'accounts-approval', RequireAuthMiddleware, RequireAdminMiddleware);
this.post('/accounts-approval/approve', this.postApproveAccount, 'approve-account', REQUIRE_AUTH_MIDDLEWARE, REQUIRE_ADMIN_MIDDLEWARE); this.post('/accounts-approval/approve', this.postApproveAccount, 'approve-account', RequireAuthMiddleware, RequireAdminMiddleware);
this.post('/accounts-approval/reject', this.postRejectAccount, 'reject-account', REQUIRE_AUTH_MIDDLEWARE, REQUIRE_ADMIN_MIDDLEWARE); this.post('/accounts-approval/reject', this.postRejectAccount, 'reject-account', RequireAuthMiddleware, RequireAdminMiddleware);
} }
} }

View File

@ -1,21 +1,18 @@
import {Environment} from "nunjucks";
import Model from "../db/Model";
import AuthGuard from "../auth/AuthGuard";
import {Files} from "formidable"; import {Files} from "formidable";
import User from "../auth/models/User"; import {Type} from "../Utils";
import Middleware from "../Middleware";
declare global { declare global {
namespace Express { namespace Express {
export interface Request { export interface Request {
env: Environment;
models: {
user?: User | null,
[p: string]: Model | null | undefined,
};
modelCollections: { [p: string]: Model[] | null };
authGuard: AuthGuard<any>;
files: Files; files: Files;
middlewares: Middleware[];
as<M extends Middleware>(type: Type<M>): M;
flash(): { [key: string]: string[] }; flash(): { [key: string]: string[] };
flash(message: string): any; flash(message: string): any;