Improve middleware definition and cleanup code

This commit is contained in:
Alice Gaudon 2020-09-25 22:03:22 +02:00
parent 0d6f7c0d90
commit b736f5f6cb
29 changed files with 382 additions and 235 deletions

View File

@ -16,14 +16,15 @@ import SecurityError from "./SecurityError";
import * as path from "path";
import CacheProvider from "./CacheProvider";
import RedisComponent from "./components/RedisComponent";
import Extendable from "./Extendable";
import TemplateError = lib.TemplateError;
export default abstract class Application {
export default abstract class Application implements Extendable<ApplicationComponent> {
private readonly version: string;
private readonly ignoreCommandLine: boolean;
private readonly controllers: Controller[] = [];
private readonly webSocketListeners: { [p: string]: WebSocketListener<any> } = {};
private readonly components: ApplicationComponent<any>[] = [];
private readonly webSocketListeners: { [p: string]: WebSocketListener<Application> } = {};
private readonly components: ApplicationComponent[] = [];
private cacheProvider?: CacheProvider;
private ready: boolean = false;
@ -37,8 +38,9 @@ export default abstract class Application {
protected abstract async init(): Promise<void>;
protected use(thing: Controller | WebSocketListener<this> | ApplicationComponent<any>) {
protected use(thing: Controller | WebSocketListener<this> | ApplicationComponent) {
if (thing instanceof Controller) {
thing.setApp(this);
this.controllers.push(thing);
} else if (thing instanceof WebSocketListener) {
const path = thing.path();
@ -151,13 +153,22 @@ export default abstract class Application {
// Start components
for (const component of this.components) {
await component.start(app);
await component.start?.(app);
}
// Components routes
for (const component of this.components) {
await component.init(initRouter);
await component.handle(handleRouter);
if (component.init) {
component.setCurrentRouter(initRouter);
await component.init(initRouter);
}
if (component.handle) {
component.setCurrentRouter(handleRouter);
await component.handle(handleRouter);
}
component.setCurrentRouter(null);
}
// Routes
@ -203,7 +214,7 @@ export default abstract class Application {
// Check security fields
for (const component of this.components) {
await component.checkSecuritySettings();
await component.checkSecuritySettings?.();
}
}
@ -211,7 +222,7 @@ export default abstract class Application {
Logger.info('Stopping application...');
for (const component of this.components) {
await component.stop();
await component.stop?.();
}
Logger.info(`${this.constructor.name} v${this.version} - bye`);
@ -219,7 +230,7 @@ export default abstract class Application {
private routes(initRouter: Router, handleRouter: Router) {
for (const controller of this.controllers) {
if (controller.hasGlobalHandlers()) {
if (controller.hasGlobalMiddlewares()) {
controller.setupGlobalHandlers(handleRouter);
Logger.info(`Registered global middlewares for controller ${controller.constructor.name}`);
@ -247,7 +258,7 @@ export default abstract class Application {
return this.version;
}
public getWebSocketListeners(): { [p: string]: WebSocketListener<any> } {
public getWebSocketListeners(): { [p: string]: WebSocketListener<Application> } {
return this.webSocketListeners;
}
@ -255,7 +266,14 @@ export default abstract class Application {
return this.cacheProvider;
}
public getComponent<T extends ApplicationComponent<any>>(type: Type<T>): T | undefined {
return <T>this.components.find(component => component.constructor === type);
public as<C extends ApplicationComponent>(type: Type<C>): C {
const component = this.components.find(component => component.constructor === type);
if (!component) throw new Error(`This app doesn't have a ${type.name} component.`);
return component as C;
}
}
public asOptional<C extends ApplicationComponent>(type: Type<C>): C | null {
const component = this.components.find(component => component.constructor === type);
return component ? component as C : null;
}
}

View File

@ -1,42 +1,24 @@
import {Express, Router} from "express";
import Logger from "./Logger";
import {sleep} from "./Utils";
import {sleep, Type} from "./Utils";
import Application from "./Application";
import config from "config";
import SecurityError from "./SecurityError";
import Middleware from "./Middleware";
export default abstract class ApplicationComponent<T> {
private val?: T;
protected app?: Application;
export default abstract class ApplicationComponent {
private currentRouter?: Router;
private app?: Application;
public async checkSecuritySettings(): Promise<void> {
}
public async checkSecuritySettings?(): Promise<void>;
public async start(app: Express): Promise<void> {
}
public async start?(expressApp: Express): Promise<void>;
public async init(router: Router): Promise<void> {
}
public async init?(router: Router): Promise<void>;
public async handle(router: Router): Promise<void> {
}
public async handle?(router: Router): Promise<void>;
public async stop(): Promise<void> {
}
protected export(val: T) {
this.val = val;
}
public import(): T {
if (!this.val) throw 'Cannot import if nothing was exported.';
return this.val;
}
public setApp(app: Application) {
this.app = app;
}
public async stop?(): Promise<void>;
protected async prepare(name: string, prepare: () => Promise<void>): Promise<void> {
let err;
@ -71,4 +53,34 @@ export default abstract class ApplicationComponent<T> {
throw new SecurityError(`${field} field not configured.`);
}
}
protected use<M extends Middleware>(middleware: Type<M>): void {
if (!this.currentRouter) throw new Error('Cannot call this method outside init() and handle().');
const instance = new middleware(this.getApp());
this.currentRouter.use(async (req, res, next) => {
try {
await instance.getRequestHandler()(req, res, next);
} catch (e) {
next(e);
}
});
}
protected getCurrentRouter(): Router | null {
return this.currentRouter || null;
}
public setCurrentRouter(router: Router | null): void {
this.currentRouter = router || undefined;
}
protected getApp(): Application {
if (!this.app) throw new Error('app field not initialized.');
return this.app;
}
public setApp(app: Application) {
this.app = app;
}
}

View File

@ -2,10 +2,13 @@ import express, {IRouter, RequestHandler, Router} from "express";
import {PathParams} from "express-serve-static-core";
import config from "config";
import Logger from "./Logger";
import Validator, {FileError, ValidationBag} from "./db/Validator";
import Validator, {ValidationBag} from "./db/Validator";
import FileUploadMiddleware from "./FileUploadMiddleware";
import * as querystring from "querystring";
import {ParsedUrlQueryInput} from "querystring";
import Middleware from "./Middleware";
import {Type} from "./Utils";
import Application from "./Application";
export default abstract class Controller {
private static readonly routes: { [p: string]: string } = {};
@ -39,18 +42,19 @@ export default abstract class Controller {
private readonly router: Router = express.Router();
private readonly fileUploadFormRouter: Router = express.Router();
private app?: Application;
public getGlobalHandlers(): RequestHandler[] {
public getGlobalMiddlewares(): Middleware[] {
return [];
}
public hasGlobalHandlers(): boolean {
return this.getGlobalHandlers().length > 0;
public hasGlobalMiddlewares(): boolean {
return this.getGlobalMiddlewares().length > 0;
}
public setupGlobalHandlers(router: Router): void {
for (const globalHandler of this.getGlobalHandlers()) {
router.use(this.wrap(globalHandler));
for (const middleware of this.getGlobalMiddlewares()) {
router.use(this.wrap(middleware.getRequestHandler()));
}
}
@ -75,19 +79,19 @@ export default abstract class Controller {
this.router.use(handler);
}
protected get(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (RequestHandler | FileUploadMiddleware)[]) {
protected get(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (Type<Middleware>)[]) {
this.handle('get', path, handler, routeName, ...middlewares);
}
protected post(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (RequestHandler | FileUploadMiddleware)[]) {
protected post(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (Type<Middleware>)[]) {
this.handle('post', path, handler, routeName, ...middlewares);
}
protected put(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (RequestHandler | FileUploadMiddleware)[]) {
protected put(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (Type<Middleware>)[]) {
this.handle('put', path, handler, routeName, ...middlewares);
}
protected delete(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (RequestHandler | FileUploadMiddleware)[]) {
protected delete(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (Type<Middleware>)[]) {
this.handle('delete', path, handler, routeName, ...middlewares);
}
@ -96,14 +100,15 @@ export default abstract class Controller {
path: PathParams,
handler: RequestHandler,
routeName?: string,
...middlewares: (RequestHandler | FileUploadMiddleware)[]
...middlewares: (Type<Middleware>)[]
): void {
this.registerRoutes(path, handler, routeName);
for (const middleware of middlewares) {
if (middleware instanceof FileUploadMiddleware) {
this.fileUploadFormRouter[action](path, this.wrap(FILE_UPLOAD_MIDDLEWARE(middleware)));
const instance = new middleware(this.getApp());
if (instance instanceof FileUploadMiddleware) {
this.fileUploadFormRouter[action](path, this.wrap(instance.getRequestHandler()));
} else {
this.router[action](path, this.wrap(middleware));
this.router[action](path, this.wrap(instance.getRequestHandler()));
}
}
this.router[action](path, this.wrap(handler));
@ -164,33 +169,15 @@ export default abstract class Controller {
if (bag.hasMessages()) throw bag;
}
protected getApp(): Application {
if (!this.app) throw new Error('Application not initialized.');
return this.app;
}
public setApp(app: Application) {
this.app = app;
}
}
export type RouteParams = { [p: string]: string } | string[] | string | number;
const FILE_UPLOAD_MIDDLEWARE: (fileUploadMiddleware: FileUploadMiddleware) => RequestHandler = (fileUploadMiddleware: FileUploadMiddleware) => {
return async (req, res, next) => {
const form = fileUploadMiddleware.formFactory();
try {
await new Promise<any>((resolve, reject) => {
form.parse(req, (err, fields, files) => {
if (err) {
reject(err);
return;
}
req.body = fields;
req.files = files;
resolve();
});
});
} catch (e) {
const bag = new ValidationBag();
const fileError = new FileError(e);
fileError.thingName = fileUploadMiddleware.defaultField;
bag.addMessage(fileError);
next(bag);
return;
}
next();
};
};

7
src/Extendable.ts Normal file
View File

@ -0,0 +1,7 @@
import {Type} from "./Utils";
export default interface Extendable<ComponentClass> {
as<C extends ComponentClass>(type: Type<C>): C;
asOptional<C extends ComponentClass>(type: Type<C>): C | null;
}

View File

@ -1,11 +1,35 @@
import {IncomingForm} from "formidable";
import Middleware from "./Middleware";
import {NextFunction, Request, Response} from "express";
import {FileError, ValidationBag} from "./db/Validator";
export default class FileUploadMiddleware {
public readonly formFactory: () => IncomingForm;
public readonly defaultField: string;
export default abstract class FileUploadMiddleware extends Middleware {
protected abstract makeForm(): IncomingForm;
public constructor(formFactory: () => IncomingForm, defaultField: string) {
this.formFactory = formFactory;
this.defaultField = defaultField;
protected abstract getDefaultField(): string;
public async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const form = this.makeForm();
try {
await new Promise<any>((resolve, reject) => {
form.parse(req, (err, fields, files) => {
if (err) {
reject(err);
return;
}
req.body = fields;
req.files = files;
resolve();
});
});
} catch (e) {
const bag = new ValidationBag();
const fileError = new FileError(e);
fileError.thingName = this.getDefaultField();
bag.addMessage(fileError);
next(bag);
return;
}
next();
}
}

28
src/Middleware.ts Normal file
View File

@ -0,0 +1,28 @@
import {RequestHandler} from "express";
import {NextFunction, Request, Response} from "express-serve-static-core";
import Application from "./Application";
export default abstract class Middleware {
public constructor(
protected readonly app: Application,
) {
}
protected abstract async handle(req: Request, res: Response, next: NextFunction): Promise<void>;
public getRequestHandler(): RequestHandler {
return async (req, res, next): Promise<void> => {
try {
if (req.middlewares.find(m => m.constructor === this.constructor)) {
next();
} else {
req.middlewares.push(this);
return await this.handle(req, res, next);
}
} catch (e) {
next(e);
}
};
}
}

View File

@ -1,75 +1,111 @@
import ApplicationComponent from "../ApplicationComponent";
import {NextFunction, Request, Response, Router} from "express";
import {NextFunction, Request, Response} from "express";
import AuthGuard from "./AuthGuard";
import Controller from "../Controller";
import {ForbiddenHttpError} from "../HttpError";
import Middleware from "../Middleware";
import User from "./models/User";
import AuthProof from "./AuthProof";
export default class AuthComponent<T extends AuthGuard<any>> extends ApplicationComponent<void> {
private readonly authGuard: T;
export default class AuthComponent extends ApplicationComponent {
private readonly authGuard: AuthGuard<AuthProof<User>>;
public constructor(authGuard: T) {
public constructor(authGuard: AuthGuard<AuthProof<User>>) {
super();
this.authGuard = authGuard;
}
public async init(router: Router): Promise<void> {
router.use(async (req, res, next) => {
req.authGuard = this.authGuard;
req.models.user = res.locals.user = await (await req.authGuard.getProof(req))?.getResource();
next();
});
public async init(): Promise<void> {
this.use(AuthMiddleware);
}
public getAuthGuard(): T {
public getAuthGuard(): AuthGuard<AuthProof<User>> {
return this.authGuard;
}
}
export const REQUIRE_REQUEST_AUTH_MIDDLEWARE = async (req: Request, res: Response, next: NextFunction): Promise<void> => {
let proof = await req.authGuard.isAuthenticatedViaRequest(req);
if (!proof) {
req.flash('error', `You must be logged in to access ${req.url}.`);
res.redirect(Controller.route('auth', undefined, {
redirect_uri: req.url,
}));
return;
}
export class AuthMiddleware extends Middleware {
private authGuard?: AuthGuard<AuthProof<User>>;
private user: User | null = null;
next();
};
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
this.authGuard = this.app.as(AuthComponent).getAuthGuard();
const proof = await this.authGuard.isAuthenticated(req.session!);
if (proof) {
this.user = await proof.getResource();
res.locals.user = this.user;
}
export const REQUIRE_AUTH_MIDDLEWARE = async (req: Request, res: Response, next: NextFunction): Promise<void> => {
// Via request
let proof = await req.authGuard.isAuthenticatedViaRequest(req);
if (proof) {
next();
return;
}
// Via session
proof = await req.authGuard.isAuthenticated(req.session!);
if (!proof) {
req.flash('error', `You must be logged in to access ${req.url}.`);
res.redirect(Controller.route('auth', undefined, {
redirect_uri: req.url,
}));
return;
public getUser(): User | null {
return this.user;
}
next();
};
export const REQUIRE_GUEST_MIDDLEWARE = async (req: Request, res: Response, next: NextFunction): Promise<void> => {
if (await req.authGuard.isAuthenticated(req.session!)) {
res.redirectBack();
return;
public getAuthGuard(): AuthGuard<AuthProof<User>> {
if (!this.authGuard) throw new Error('AuthGuard was not initialized.');
return this.authGuard;
}
}
next();
};
export const REQUIRE_ADMIN_MIDDLEWARE = async (req: Request, res: Response, next: NextFunction): Promise<void> => {
if (!req.models.user || !req.models.user.is_admin) {
throw new ForbiddenHttpError('secret tool', req.url);
export class RequireRequestAuthMiddleware extends Middleware {
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const proof = await req.as(AuthMiddleware).getAuthGuard().isAuthenticatedViaRequest(req);
if (!proof) {
req.flash('error', `You must be logged in to access ${req.url}.`);
res.redirect(Controller.route('auth', undefined, {
redirect_uri: req.url,
}));
return;
}
next();
}
}
next();
};
export class RequireAuthMiddleware extends Middleware {
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const authGuard = req.as(AuthMiddleware).getAuthGuard();
// Via request
if (await authGuard.isAuthenticatedViaRequest(req)) {
next();
return;
}
// Via session
if (!await authGuard.isAuthenticated(req.session!)) {
req.flash('error', `You must be logged in to access ${req.url}.`);
res.redirect(Controller.route('auth', undefined, {
redirect_uri: req.url,
}));
return;
}
next();
}
}
export class RequireGuestMiddleware extends Middleware {
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
if (await req.as(AuthMiddleware).getAuthGuard().isAuthenticated(req.session!)) {
res.redirectBack();
return;
}
next();
}
}
export class RequireAdminMiddleware extends Middleware {
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const user = req.as(AuthMiddleware).getUser();
if (!user || !user.is_admin) {
throw new ForbiddenHttpError('secret tool', req.url);
}
next();
}
}

View File

@ -1,6 +1,6 @@
import Controller from "../Controller";
import {NextFunction, Request, Response} from "express";
import {REQUIRE_AUTH_MIDDLEWARE, REQUIRE_GUEST_MIDDLEWARE} from "./AuthComponent";
import {AuthMiddleware, RequireAuthMiddleware, RequireGuestMiddleware} from "./AuthComponent";
export default abstract class AuthController extends Controller {
public getRoutesPrefix(): string {
@ -8,10 +8,10 @@ export default abstract class AuthController extends Controller {
}
public routes() {
this.get('/', this.getAuth, 'auth', REQUIRE_GUEST_MIDDLEWARE);
this.post('/', this.postAuth, 'auth', REQUIRE_GUEST_MIDDLEWARE);
this.get('/', this.getAuth, 'auth', RequireGuestMiddleware);
this.post('/', this.postAuth, 'auth', RequireGuestMiddleware);
this.get('/check', this.getCheckAuth, 'check_auth');
this.post('/logout', this.postLogout, 'logout', REQUIRE_AUTH_MIDDLEWARE);
this.post('/logout', this.postLogout, 'logout', RequireAuthMiddleware);
}
protected async getAuth(req: Request, res: Response, next: NextFunction): Promise<void> {
@ -26,7 +26,7 @@ export default abstract class AuthController extends Controller {
protected abstract async getCheckAuth(req: Request, res: Response, next: NextFunction): Promise<void>;
protected async postLogout(req: Request, res: Response, next: NextFunction): Promise<void> {
const proof = await req.authGuard.getProof(req);
const proof = await req.as(AuthMiddleware).getAuthGuard().getProof(req);
await proof?.revoke();
req.flash('success', 'Successfully logged out.');
res.redirect(req.query.redirect_uri?.toString() || '/');

View File

@ -53,7 +53,7 @@ export default abstract class AuthGuard<P extends AuthProof<User>> {
proof: P,
onLogin?: (user: User) => Promise<void>,
onRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]>
): Promise<void> {
): Promise<User> {
if (!await proof.isValid()) throw new InvalidAuthProofError();
if (!await proof.isAuthorized()) throw new UnauthorizedAuthProofError();
@ -95,6 +95,8 @@ export default abstract class AuthGuard<P extends AuthProof<User>> {
// Login
session.is_authenticated = true;
if (onLogin) await onLogin(user);
return user;
}
}

View File

@ -9,17 +9,19 @@ import {AuthError, PendingApprovalAuthError, RegisterCallback} from "../AuthGuar
import geoip from "geoip-lite";
import AuthController from "../AuthController";
import RedirectBackComponent from "../../components/RedirectBackComponent";
import {AuthMiddleware} from "../AuthComponent";
import User from "../models/User";
export default abstract class MagicLinkAuthController extends AuthController {
public static async checkAndAuth(req: Request, res: Response, magicLink: MagicLink): Promise<void> {
public static async checkAndAuth(req: Request, res: Response, magicLink: MagicLink): Promise<User | null> {
if (magicLink.getSessionID() !== req.sessionID!) throw new BadOwnerMagicLink();
if (!await magicLink.isAuthorized()) throw new UnauthorizedMagicLink();
if (!await magicLink.isValid()) throw new InvalidMagicLink();
// Auth
try {
await req.authGuard.authenticateOrRegister(req.session!, magicLink, undefined, async (connection, user) => {
return await req.as(AuthMiddleware).getAuthGuard().authenticateOrRegister(req.session!, magicLink, undefined, async (connection, user) => {
const callbacks: RegisterCallback[] = [];
const userEmail = UserEmail.create({
@ -46,7 +48,7 @@ export default abstract class MagicLinkAuthController extends AuthController {
res.redirect('/');
}
});
return;
return null;
} else {
throw e;
}
@ -139,20 +141,21 @@ export default abstract class MagicLinkAuthController extends AuthController {
return;
}
await MagicLinkAuthController.checkAndAuth(req, res, magicLink);
const user = await MagicLinkAuthController.checkAndAuth(req, res, magicLink);
// Auth success
const username = req.models.user?.name;
res.format({
json: () => {
res.json({'status': 'success', 'message': `Welcome, ${username}!`});
},
default: () => {
req.flash('success', `Authentication success. Welcome, ${username}!`);
res.redirect('/');
},
});
return;
if (user) {
// Auth success
const username = user.name;
res.format({
json: () => {
res.json({'status': 'success', 'message': `Welcome, ${username}!`});
},
default: () => {
req.flash('success', `Authentication success. Welcome, ${username}!`);
res.redirect('/');
},
});
}
}
}

View File

@ -5,7 +5,7 @@ import ApplicationComponent from "../ApplicationComponent";
import {ForbiddenHttpError} from "../HttpError";
import Logger from "../Logger";
export default class AutoUpdateComponent extends ApplicationComponent<void> {
export default class AutoUpdateComponent extends ApplicationComponent {
public async checkSecuritySettings(): Promise<void> {
this.checkSecurityConfigField('gitlab_webhook_token');
}
@ -39,7 +39,7 @@ export default class AutoUpdateComponent extends ApplicationComponent<void> {
await this.runCommand(`yarn dist`);
// Stop app
await this.app!.stop();
await this.getApp().stop();
Logger.info('Success!');
} catch (e) {

View File

@ -2,8 +2,9 @@ import ApplicationComponent from "../ApplicationComponent";
import {Request, Router} from "express";
import crypto from "crypto";
import {BadRequestError} from "../HttpError";
import {AuthMiddleware} from "../auth/AuthComponent";
export default class CsrfProtectionComponent extends ApplicationComponent<void> {
export default class CsrfProtectionComponent extends ApplicationComponent {
private static readonly excluders: ((req: Request) => boolean)[] = [];
public static getCSRFToken(session: Express.Session): string {
@ -33,18 +34,17 @@ export default class CsrfProtectionComponent extends ApplicationComponent<void>
if (!['GET', 'HEAD', 'OPTIONS'].find(s => s === req.method)) {
try {
if (!(await req.authGuard.isAuthenticatedViaRequest(req))) {
if (!await req.as(AuthMiddleware).getAuthGuard().isAuthenticatedViaRequest(req)) {
if (req.session.csrf === undefined) {
throw new InvalidCsrfTokenError(req.baseUrl, `You weren't assigned any CSRF token.`);
return next(new InvalidCsrfTokenError(req.baseUrl, `You weren't assigned any CSRF token.`));
} else if (req.body.csrf === undefined) {
throw new InvalidCsrfTokenError(req.baseUrl, `You didn't provide any CSRF token.`);
return next(new InvalidCsrfTokenError(req.baseUrl, `You didn't provide any CSRF token.`));
} else if (req.session.csrf !== req.body.csrf) {
throw new InvalidCsrfTokenError(req.baseUrl, `Tokens don't match.`);
return next(new InvalidCsrfTokenError(req.baseUrl, `Tokens don't match.`));
}
}
} catch (e) {
next(e);
return;
return next(e);
}
}
next();

View File

@ -3,8 +3,10 @@ import express, {Express, Router} from "express";
import Logger from "../Logger";
import {Server} from "http";
import compression from "compression";
import Middleware from "../Middleware";
import {Type} from "../Utils";
export default class ExpressAppComponent extends ApplicationComponent<void> {
export default class ExpressAppComponent extends ApplicationComponent {
private readonly addr: string;
private readonly port: number;
private server?: Server;
@ -39,8 +41,12 @@ export default class ExpressAppComponent extends ApplicationComponent<void> {
router.use(compression());
router.use((req, res, next) => {
req.models = {};
req.modelCollections = {};
req.middlewares = [];
req.as = <M extends Middleware>(type: Type<M>) => {
const middleware = req.middlewares.find(m => m.constructor === type);
if (!middleware) throw new Error('Middleware ' + type.name + ' not present in this request.');
return middleware as M;
};
next();
});
}

View File

@ -1,7 +1,7 @@
import ApplicationComponent from "../ApplicationComponent";
import {Router} from "express";
export default class FormHelperComponent extends ApplicationComponent<void> {
export default class FormHelperComponent extends ApplicationComponent {
public async init(router: Router): Promise<void> {
router.use((req, res, next) => {
if (!req.session) {

View File

@ -3,7 +3,7 @@ import onFinished from "on-finished";
import Logger from "../Logger";
import {Request, Response, Router} from "express";
export default class LogRequestsComponent extends ApplicationComponent<void> {
export default class LogRequestsComponent extends ApplicationComponent {
private static fullRequests: boolean = false;
public static logFullHttpRequests() {

View File

@ -4,8 +4,7 @@ import Mail from "../Mail";
import config from "config";
import SecurityError from "../SecurityError";
export default class MailComponent extends ApplicationComponent<void> {
export default class MailComponent extends ApplicationComponent {
public async checkSecuritySettings(): Promise<void> {
if (!config.get<boolean>('mail.secure')) {

View File

@ -4,7 +4,7 @@ import {ServiceUnavailableHttpError} from "../HttpError";
import Application from "../Application";
import config from "config";
export default class MaintenanceComponent extends ApplicationComponent<void> {
export default class MaintenanceComponent extends ApplicationComponent {
private readonly application: Application;
private readonly canServe: () => boolean;

View File

@ -2,7 +2,7 @@ import ApplicationComponent from "../ApplicationComponent";
import {Express} from "express";
import MysqlConnectionManager from "../db/MysqlConnectionManager";
export default class MysqlComponent extends ApplicationComponent<void> {
export default class MysqlComponent extends ApplicationComponent {
public async start(app: Express): Promise<void> {
await this.prepare('Mysql connection', () => MysqlConnectionManager.prepare());
}
@ -12,7 +12,7 @@ export default class MysqlComponent extends ApplicationComponent<void> {
}
public canServe(): boolean {
return MysqlConnectionManager.pool !== undefined;
return MysqlConnectionManager.isReady();
}
}

View File

@ -1,6 +1,6 @@
import nunjucks, {Environment} from "nunjucks";
import config from "config";
import {Express, Router} from "express";
import {Express, NextFunction, Request, Response, Router} from "express";
import ApplicationComponent from "../ApplicationComponent";
import Controller from "../Controller";
import {ServerError} from "../HttpError";
@ -8,8 +8,11 @@ import * as querystring from "querystring";
import {ParsedUrlQueryInput} from "querystring";
import * as util from "util";
import * as path from "path";
import * as fs from "fs";
import Logger from "../Logger";
import Middleware from "../Middleware";
export default class NunjucksComponent extends ApplicationComponent<void> {
export default class NunjucksComponent extends ApplicationComponent {
private readonly viewsPath: string[];
private env?: Environment;
@ -20,13 +23,14 @@ export default class NunjucksComponent extends ApplicationComponent<void> {
public async start(app: Express): Promise<void> {
let coreVersion = 'unknown';
const file = fs.existsSync(path.join(__dirname, '../../package.json')) ?
path.join(__dirname, '../../package.json') :
path.join(__dirname, '../package.json');
try {
coreVersion = require('../../package.json').version;
coreVersion = JSON.parse(fs.readFileSync(file).toString()).version;
} catch (e) {
try {
coreVersion = require('../package.json').version;
} catch (e) {
}
Logger.warn('Couldn\'t determine coreVersion.', e);
}
this.env = new nunjucks.Environment([
@ -43,7 +47,7 @@ export default class NunjucksComponent extends ApplicationComponent<void> {
if (path === null) throw new ServerError(`Route ${route} not found.`);
return path;
})
.addGlobal('app_version', this.app!.getVersion())
.addGlobal('app_version', this.getApp().getVersion())
.addGlobal('core_version', coreVersion)
.addGlobal('querystring', querystring)
.addGlobal('app', config.get('app'))
@ -59,18 +63,29 @@ export default class NunjucksComponent extends ApplicationComponent<void> {
}
public async init(router: Router): Promise<void> {
router.use((req, res, next) => {
req.env = this.env!;
res.locals.url = req.url;
res.locals.params = req.params;
res.locals.query = req.query;
res.locals.body = req.body;
next();
});
this.use(NunjucksMiddleware);
}
public getEnv(): Environment | undefined {
return this.env;
}
}
}
export class NunjucksMiddleware extends Middleware {
private env?: Environment;
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
this.env = this.app.as(NunjucksComponent).getEnv();
res.locals.url = req.url;
res.locals.params = req.params;
res.locals.query = req.query;
res.locals.body = req.body;
next();
}
public getEnvironment(): Environment {
if (!this.env) throw new Error('Environment not initialized.');
return this.env;
}
}

View File

@ -4,7 +4,7 @@ import {ServerError} from "../HttpError";
import onFinished from "on-finished";
import Logger from "../Logger";
export default class RedirectBackComponent extends ApplicationComponent<void> {
export default class RedirectBackComponent extends ApplicationComponent {
public static getPreviousURL(req: Request, defaultUrl?: string): string | undefined {
return req.session?.previousUrl || defaultUrl;
}

View File

@ -9,7 +9,7 @@ import CacheProvider from "../CacheProvider";
const RedisStore = connect_redis(session);
export default class RedisComponent extends ApplicationComponent<void> implements CacheProvider {
export default class RedisComponent extends ApplicationComponent implements CacheProvider {
private redisClient?: RedisClient;
private store?: Store;

View File

@ -3,7 +3,7 @@ import express, {Router} from "express";
import {PathParams} from "express-serve-static-core";
import * as path from "path";
export default class ServeStaticDirectoryComponent extends ApplicationComponent<void> {
export default class ServeStaticDirectoryComponent extends ApplicationComponent {
private readonly root: string;
private readonly path?: PathParams;

View File

@ -6,7 +6,7 @@ import flash from "connect-flash";
import {Router} from "express";
import SecurityError from "../SecurityError";
export default class SessionComponent extends ApplicationComponent<void> {
export default class SessionComponent extends ApplicationComponent {
private readonly storeComponent: RedisComponent;
public constructor(storeComponent: RedisComponent) {

View File

@ -11,7 +11,7 @@ import RedisComponent from "./RedisComponent";
import WebSocketListener from "../WebSocketListener";
import NunjucksComponent from "./NunjucksComponent";
export default class WebSocketServerComponent extends ApplicationComponent<void> {
export default class WebSocketServerComponent extends ApplicationComponent {
private wss?: WebSocket.Server;
constructor(

View File

@ -7,8 +7,9 @@ import ModelFactory from "./ModelFactory";
import ModelRelation from "./ModelRelation";
import ModelQuery, {ModelQueryResult, SelectFields} from "./ModelQuery";
import {Request} from "express";
import Extendable from "../Extendable";
export default abstract class Model {
export default abstract class Model implements Extendable<ModelComponent<Model>> {
public static get table(): string {
return this.name
.replace(/(?:^|\.?)([A-Z])/g, (x, y) => '_' + y.toLowerCase())
@ -44,14 +45,14 @@ export default abstract class Model {
return ModelFactory.get(this).paginate(request, perPage, query);
}
protected readonly _factory: ModelFactory<any>;
private readonly _components: ModelComponent<any>[] = [];
private readonly _validators: { [key: string]: Validator<any> } = {};
protected readonly _factory: ModelFactory<Model>;
private readonly _components: ModelComponent<this>[] = [];
private readonly _validators: { [key: string]: Validator<any> | undefined } = {};
private _exists: boolean;
[key: string]: any;
public constructor(factory: ModelFactory<any>, isNew: boolean) {
public constructor(factory: ModelFactory<Model>, isNew: boolean) {
if (!factory || !(factory instanceof ModelFactory)) throw new Error('Cannot instantiate model directly.');
this._factory = factory;
this.init();
@ -71,15 +72,26 @@ export default abstract class Model {
this._components.push(modelComponent);
}
public as<T extends ModelComponent<any>>(type: Type<T>): T {
public as<C extends ModelComponent<Model>>(type: Type<C>): C {
for (const component of this._components) {
if (component instanceof type) {
return <any>this;
return this as unknown as C;
}
}
throw new Error(`Component ${type.name} was not initialized for this ${this.constructor.name}.`);
}
public asOptional<C extends ModelComponent<Model>>(type: Type<C>): C | null {
for (const component of this._components) {
if (component instanceof type) {
return this as unknown as C;
}
}
return null;
}
public updateWithData(data: any) {
for (const property of this._properties) {
if (data[property] !== undefined) {
@ -92,19 +104,16 @@ export default abstract class Model {
* Override this to automatically fill obvious missing data i.e. from relation or default value that are fetched
* asynchronously.
*/
protected async autoFill(): Promise<void> {
}
protected async autoFill?(): Promise<void>;
protected async beforeSave(connection: Connection): Promise<void> {
}
protected async beforeSave?(connection: Connection): Promise<void>;
protected async afterSave(): Promise<void> {
}
protected async afterSave?(): Promise<void>;
public async save(connection?: Connection, postHook?: (callback: () => Promise<void>) => void): Promise<void> {
if (connection && !postHook) throw new Error('If connection is provided, postHook must be provided too.');
await this.autoFill();
await this.autoFill?.();
await this.validate(false, connection);
const needs_full_update = connection ?
@ -120,7 +129,7 @@ export default abstract class Model {
this.updateWithData(result.results[0]);
}
await this.afterSave();
await this.afterSave?.();
};
if (connection) {
@ -132,7 +141,7 @@ export default abstract class Model {
private async saveTransaction(connection: Connection): Promise<boolean> {
// Before save
await this.beforeSave(connection);
await this.beforeSave?.(connection);
if (!this.exists() && this.hasOwnProperty('created_at')) {
this.created_at = new Date();
}

View File

@ -13,7 +13,7 @@ export default abstract class ModelComponent<T extends Model> {
}
public applyToModel(): void {
this.init();
this.init?.();
for (const property of this._properties) {
if (!property.startsWith('_')) {
@ -32,7 +32,7 @@ export default abstract class ModelComponent<T extends Model> {
}
}
protected abstract init(): void;
protected init?(): void;
protected setValidation<V>(propertyName: keyof this): Validator<V> {
const validator = new Validator<V>();

View File

@ -21,6 +21,10 @@ export default class MysqlConnectionManager {
private static migrationsRegistered: boolean = false;
private static readonly migrations: Migration[] = [];
public static isReady(): boolean {
return this.databaseReady && this.currentPool !== undefined;
}
public static registerMigrations(migrations: Type<Migration>[]) {
if (!this.migrationsRegistered) {
this.migrationsRegistered = true;

View File

@ -1,6 +1,5 @@
import config from "config";
import Controller from "../Controller";
import {REQUIRE_ADMIN_MIDDLEWARE, REQUIRE_AUTH_MIDDLEWARE} from "../auth/AuthComponent";
import User from "../auth/models/User";
import {Request, Response} from "express";
import {BadRequestError, NotFoundHttpError} from "../HttpError";
@ -8,6 +7,7 @@ import Mail from "../Mail";
import {ACCOUNT_REVIEW_NOTICE_MAIL_TEMPLATE} from "../Mails";
import UserEmail from "../auth/models/UserEmail";
import UserApprovedComponent from "../auth/models/UserApprovedComponent";
import {RequireAdminMiddleware, RequireAuthMiddleware} from "../auth/AuthComponent";
export default class BackendController extends Controller {
private static readonly menu: BackendMenuElement[] = [];
@ -37,11 +37,11 @@ export default class BackendController extends Controller {
}
public routes(): void {
this.get('/', this.getIndex, 'backend', REQUIRE_AUTH_MIDDLEWARE, REQUIRE_ADMIN_MIDDLEWARE);
this.get('/', this.getIndex, 'backend', RequireAuthMiddleware, RequireAdminMiddleware);
if (User.isApprovalMode()) {
this.get('/accounts-approval', this.getAccountApproval, 'accounts-approval', REQUIRE_AUTH_MIDDLEWARE, REQUIRE_ADMIN_MIDDLEWARE);
this.post('/accounts-approval/approve', this.postApproveAccount, 'approve-account', REQUIRE_AUTH_MIDDLEWARE, REQUIRE_ADMIN_MIDDLEWARE);
this.post('/accounts-approval/reject', this.postRejectAccount, 'reject-account', REQUIRE_AUTH_MIDDLEWARE, REQUIRE_ADMIN_MIDDLEWARE);
this.get('/accounts-approval', this.getAccountApproval, 'accounts-approval', RequireAuthMiddleware, RequireAdminMiddleware);
this.post('/accounts-approval/approve', this.postApproveAccount, 'approve-account', RequireAuthMiddleware, RequireAdminMiddleware);
this.post('/accounts-approval/reject', this.postRejectAccount, 'reject-account', RequireAuthMiddleware, RequireAdminMiddleware);
}
}

View File

@ -1,21 +1,18 @@
import {Environment} from "nunjucks";
import Model from "../db/Model";
import AuthGuard from "../auth/AuthGuard";
import {Files} from "formidable";
import User from "../auth/models/User";
import {Type} from "../Utils";
import Middleware from "../Middleware";
declare global {
namespace Express {
export interface Request {
env: Environment;
models: {
user?: User | null,
[p: string]: Model | null | undefined,
};
modelCollections: { [p: string]: Model[] | null };
authGuard: AuthGuard<any>;
files: Files;
middlewares: Middleware[];
as<M extends Middleware>(type: Type<M>): M;
flash(): { [key: string]: string[] };
flash(message: string): any;