diff --git a/src/Application.ts b/src/Application.ts index 594dcda..d823e8c 100644 --- a/src/Application.ts +++ b/src/Application.ts @@ -16,14 +16,15 @@ import SecurityError from "./SecurityError"; import * as path from "path"; import CacheProvider from "./CacheProvider"; import RedisComponent from "./components/RedisComponent"; +import Extendable from "./Extendable"; import TemplateError = lib.TemplateError; -export default abstract class Application { +export default abstract class Application implements Extendable { private readonly version: string; private readonly ignoreCommandLine: boolean; private readonly controllers: Controller[] = []; - private readonly webSocketListeners: { [p: string]: WebSocketListener } = {}; - private readonly components: ApplicationComponent[] = []; + private readonly webSocketListeners: { [p: string]: WebSocketListener } = {}; + private readonly components: ApplicationComponent[] = []; private cacheProvider?: CacheProvider; private ready: boolean = false; @@ -37,8 +38,9 @@ export default abstract class Application { protected abstract async init(): Promise; - protected use(thing: Controller | WebSocketListener | ApplicationComponent) { + protected use(thing: Controller | WebSocketListener | ApplicationComponent) { if (thing instanceof Controller) { + thing.setApp(this); this.controllers.push(thing); } else if (thing instanceof WebSocketListener) { const path = thing.path(); @@ -151,13 +153,22 @@ export default abstract class Application { // Start components for (const component of this.components) { - await component.start(app); + await component.start?.(app); } // Components routes for (const component of this.components) { - await component.init(initRouter); - await component.handle(handleRouter); + if (component.init) { + component.setCurrentRouter(initRouter); + await component.init(initRouter); + } + + if (component.handle) { + component.setCurrentRouter(handleRouter); + await component.handle(handleRouter); + } + + component.setCurrentRouter(null); } // Routes @@ -203,7 +214,7 @@ export default abstract class Application { // Check security fields for (const component of this.components) { - await component.checkSecuritySettings(); + await component.checkSecuritySettings?.(); } } @@ -211,7 +222,7 @@ export default abstract class Application { Logger.info('Stopping application...'); for (const component of this.components) { - await component.stop(); + await component.stop?.(); } Logger.info(`${this.constructor.name} v${this.version} - bye`); @@ -219,7 +230,7 @@ export default abstract class Application { private routes(initRouter: Router, handleRouter: Router) { for (const controller of this.controllers) { - if (controller.hasGlobalHandlers()) { + if (controller.hasGlobalMiddlewares()) { controller.setupGlobalHandlers(handleRouter); Logger.info(`Registered global middlewares for controller ${controller.constructor.name}`); @@ -247,7 +258,7 @@ export default abstract class Application { return this.version; } - public getWebSocketListeners(): { [p: string]: WebSocketListener } { + public getWebSocketListeners(): { [p: string]: WebSocketListener } { return this.webSocketListeners; } @@ -255,7 +266,14 @@ export default abstract class Application { return this.cacheProvider; } - public getComponent>(type: Type): T | undefined { - return this.components.find(component => component.constructor === type); + public as(type: Type): C { + const component = this.components.find(component => component.constructor === type); + if (!component) throw new Error(`This app doesn't have a ${type.name} component.`); + return component as C; } -} \ No newline at end of file + + public asOptional(type: Type): C | null { + const component = this.components.find(component => component.constructor === type); + return component ? component as C : null; + } +} diff --git a/src/ApplicationComponent.ts b/src/ApplicationComponent.ts index 4feaa59..dd46b47 100644 --- a/src/ApplicationComponent.ts +++ b/src/ApplicationComponent.ts @@ -1,42 +1,24 @@ import {Express, Router} from "express"; import Logger from "./Logger"; -import {sleep} from "./Utils"; +import {sleep, Type} from "./Utils"; import Application from "./Application"; import config from "config"; import SecurityError from "./SecurityError"; +import Middleware from "./Middleware"; -export default abstract class ApplicationComponent { - private val?: T; - protected app?: Application; +export default abstract class ApplicationComponent { + private currentRouter?: Router; + private app?: Application; - public async checkSecuritySettings(): Promise { - } + public async checkSecuritySettings?(): Promise; - public async start(app: Express): Promise { - } + public async start?(expressApp: Express): Promise; - public async init(router: Router): Promise { - } + public async init?(router: Router): Promise; - public async handle(router: Router): Promise { - } + public async handle?(router: Router): Promise; - public async stop(): Promise { - - } - - protected export(val: T) { - this.val = val; - } - - public import(): T { - if (!this.val) throw 'Cannot import if nothing was exported.'; - return this.val; - } - - public setApp(app: Application) { - this.app = app; - } + public async stop?(): Promise; protected async prepare(name: string, prepare: () => Promise): Promise { let err; @@ -71,4 +53,34 @@ export default abstract class ApplicationComponent { throw new SecurityError(`${field} field not configured.`); } } + + protected use(middleware: Type): void { + if (!this.currentRouter) throw new Error('Cannot call this method outside init() and handle().'); + + const instance = new middleware(this.getApp()); + this.currentRouter.use(async (req, res, next) => { + try { + await instance.getRequestHandler()(req, res, next); + } catch (e) { + next(e); + } + }); + } + + protected getCurrentRouter(): Router | null { + return this.currentRouter || null; + } + + public setCurrentRouter(router: Router | null): void { + this.currentRouter = router || undefined; + } + + protected getApp(): Application { + if (!this.app) throw new Error('app field not initialized.'); + return this.app; + } + + public setApp(app: Application) { + this.app = app; + } } diff --git a/src/Controller.ts b/src/Controller.ts index 791bcdc..5e33e13 100644 --- a/src/Controller.ts +++ b/src/Controller.ts @@ -2,10 +2,13 @@ import express, {IRouter, RequestHandler, Router} from "express"; import {PathParams} from "express-serve-static-core"; import config from "config"; import Logger from "./Logger"; -import Validator, {FileError, ValidationBag} from "./db/Validator"; +import Validator, {ValidationBag} from "./db/Validator"; import FileUploadMiddleware from "./FileUploadMiddleware"; import * as querystring from "querystring"; import {ParsedUrlQueryInput} from "querystring"; +import Middleware from "./Middleware"; +import {Type} from "./Utils"; +import Application from "./Application"; export default abstract class Controller { private static readonly routes: { [p: string]: string } = {}; @@ -39,18 +42,19 @@ export default abstract class Controller { private readonly router: Router = express.Router(); private readonly fileUploadFormRouter: Router = express.Router(); + private app?: Application; - public getGlobalHandlers(): RequestHandler[] { + public getGlobalMiddlewares(): Middleware[] { return []; } - public hasGlobalHandlers(): boolean { - return this.getGlobalHandlers().length > 0; + public hasGlobalMiddlewares(): boolean { + return this.getGlobalMiddlewares().length > 0; } public setupGlobalHandlers(router: Router): void { - for (const globalHandler of this.getGlobalHandlers()) { - router.use(this.wrap(globalHandler)); + for (const middleware of this.getGlobalMiddlewares()) { + router.use(this.wrap(middleware.getRequestHandler())); } } @@ -75,19 +79,19 @@ export default abstract class Controller { this.router.use(handler); } - protected get(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (RequestHandler | FileUploadMiddleware)[]) { + protected get(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (Type)[]) { this.handle('get', path, handler, routeName, ...middlewares); } - protected post(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (RequestHandler | FileUploadMiddleware)[]) { + protected post(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (Type)[]) { this.handle('post', path, handler, routeName, ...middlewares); } - protected put(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (RequestHandler | FileUploadMiddleware)[]) { + protected put(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (Type)[]) { this.handle('put', path, handler, routeName, ...middlewares); } - protected delete(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (RequestHandler | FileUploadMiddleware)[]) { + protected delete(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: (Type)[]) { this.handle('delete', path, handler, routeName, ...middlewares); } @@ -96,14 +100,15 @@ export default abstract class Controller { path: PathParams, handler: RequestHandler, routeName?: string, - ...middlewares: (RequestHandler | FileUploadMiddleware)[] + ...middlewares: (Type)[] ): void { this.registerRoutes(path, handler, routeName); for (const middleware of middlewares) { - if (middleware instanceof FileUploadMiddleware) { - this.fileUploadFormRouter[action](path, this.wrap(FILE_UPLOAD_MIDDLEWARE(middleware))); + const instance = new middleware(this.getApp()); + if (instance instanceof FileUploadMiddleware) { + this.fileUploadFormRouter[action](path, this.wrap(instance.getRequestHandler())); } else { - this.router[action](path, this.wrap(middleware)); + this.router[action](path, this.wrap(instance.getRequestHandler())); } } this.router[action](path, this.wrap(handler)); @@ -164,33 +169,15 @@ export default abstract class Controller { if (bag.hasMessages()) throw bag; } + + protected getApp(): Application { + if (!this.app) throw new Error('Application not initialized.'); + return this.app; + } + + public setApp(app: Application) { + this.app = app; + } } export type RouteParams = { [p: string]: string } | string[] | string | number; - -const FILE_UPLOAD_MIDDLEWARE: (fileUploadMiddleware: FileUploadMiddleware) => RequestHandler = (fileUploadMiddleware: FileUploadMiddleware) => { - return async (req, res, next) => { - const form = fileUploadMiddleware.formFactory(); - try { - await new Promise((resolve, reject) => { - form.parse(req, (err, fields, files) => { - if (err) { - reject(err); - return; - } - req.body = fields; - req.files = files; - resolve(); - }); - }); - } catch (e) { - const bag = new ValidationBag(); - const fileError = new FileError(e); - fileError.thingName = fileUploadMiddleware.defaultField; - bag.addMessage(fileError); - next(bag); - return; - } - next(); - }; -}; diff --git a/src/Extendable.ts b/src/Extendable.ts new file mode 100644 index 0000000..25bce64 --- /dev/null +++ b/src/Extendable.ts @@ -0,0 +1,7 @@ +import {Type} from "./Utils"; + +export default interface Extendable { + as(type: Type): C; + + asOptional(type: Type): C | null; +} \ No newline at end of file diff --git a/src/FileUploadMiddleware.ts b/src/FileUploadMiddleware.ts index 499fb25..d474b1c 100644 --- a/src/FileUploadMiddleware.ts +++ b/src/FileUploadMiddleware.ts @@ -1,11 +1,35 @@ import {IncomingForm} from "formidable"; +import Middleware from "./Middleware"; +import {NextFunction, Request, Response} from "express"; +import {FileError, ValidationBag} from "./db/Validator"; -export default class FileUploadMiddleware { - public readonly formFactory: () => IncomingForm; - public readonly defaultField: string; +export default abstract class FileUploadMiddleware extends Middleware { + protected abstract makeForm(): IncomingForm; - public constructor(formFactory: () => IncomingForm, defaultField: string) { - this.formFactory = formFactory; - this.defaultField = defaultField; + protected abstract getDefaultField(): string; + + public async handle(req: Request, res: Response, next: NextFunction): Promise { + const form = this.makeForm(); + try { + await new Promise((resolve, reject) => { + form.parse(req, (err, fields, files) => { + if (err) { + reject(err); + return; + } + req.body = fields; + req.files = files; + resolve(); + }); + }); + } catch (e) { + const bag = new ValidationBag(); + const fileError = new FileError(e); + fileError.thingName = this.getDefaultField(); + bag.addMessage(fileError); + next(bag); + return; + } + next(); } } \ No newline at end of file diff --git a/src/Middleware.ts b/src/Middleware.ts new file mode 100644 index 0000000..344dbac --- /dev/null +++ b/src/Middleware.ts @@ -0,0 +1,28 @@ +import {RequestHandler} from "express"; +import {NextFunction, Request, Response} from "express-serve-static-core"; +import Application from "./Application"; + +export default abstract class Middleware { + public constructor( + protected readonly app: Application, + ) { + } + + + protected abstract async handle(req: Request, res: Response, next: NextFunction): Promise; + + public getRequestHandler(): RequestHandler { + return async (req, res, next): Promise => { + try { + if (req.middlewares.find(m => m.constructor === this.constructor)) { + next(); + } else { + req.middlewares.push(this); + return await this.handle(req, res, next); + } + } catch (e) { + next(e); + } + }; + } +} diff --git a/src/auth/AuthComponent.ts b/src/auth/AuthComponent.ts index ce92b7a..50f63c9 100644 --- a/src/auth/AuthComponent.ts +++ b/src/auth/AuthComponent.ts @@ -1,75 +1,111 @@ import ApplicationComponent from "../ApplicationComponent"; -import {NextFunction, Request, Response, Router} from "express"; +import {NextFunction, Request, Response} from "express"; import AuthGuard from "./AuthGuard"; import Controller from "../Controller"; import {ForbiddenHttpError} from "../HttpError"; +import Middleware from "../Middleware"; +import User from "./models/User"; +import AuthProof from "./AuthProof"; -export default class AuthComponent> extends ApplicationComponent { - private readonly authGuard: T; +export default class AuthComponent extends ApplicationComponent { + private readonly authGuard: AuthGuard>; - public constructor(authGuard: T) { + public constructor(authGuard: AuthGuard>) { super(); this.authGuard = authGuard; } - public async init(router: Router): Promise { - router.use(async (req, res, next) => { - req.authGuard = this.authGuard; - req.models.user = res.locals.user = await (await req.authGuard.getProof(req))?.getResource(); - next(); - }); + public async init(): Promise { + this.use(AuthMiddleware); } - public getAuthGuard(): T { + public getAuthGuard(): AuthGuard> { return this.authGuard; } } -export const REQUIRE_REQUEST_AUTH_MIDDLEWARE = async (req: Request, res: Response, next: NextFunction): Promise => { - let proof = await req.authGuard.isAuthenticatedViaRequest(req); - if (!proof) { - req.flash('error', `You must be logged in to access ${req.url}.`); - res.redirect(Controller.route('auth', undefined, { - redirect_uri: req.url, - })); - return; - } +export class AuthMiddleware extends Middleware { + private authGuard?: AuthGuard>; + private user: User | null = null; - next(); -}; + protected async handle(req: Request, res: Response, next: NextFunction): Promise { + this.authGuard = this.app.as(AuthComponent).getAuthGuard(); + + const proof = await this.authGuard.isAuthenticated(req.session!); + if (proof) { + this.user = await proof.getResource(); + res.locals.user = this.user; + } -export const REQUIRE_AUTH_MIDDLEWARE = async (req: Request, res: Response, next: NextFunction): Promise => { - // Via request - let proof = await req.authGuard.isAuthenticatedViaRequest(req); - if (proof) { next(); - return; } - // Via session - proof = await req.authGuard.isAuthenticated(req.session!); - if (!proof) { - req.flash('error', `You must be logged in to access ${req.url}.`); - res.redirect(Controller.route('auth', undefined, { - redirect_uri: req.url, - })); - return; + public getUser(): User | null { + return this.user; } - next(); -}; -export const REQUIRE_GUEST_MIDDLEWARE = async (req: Request, res: Response, next: NextFunction): Promise => { - if (await req.authGuard.isAuthenticated(req.session!)) { - res.redirectBack(); - return; + public getAuthGuard(): AuthGuard> { + if (!this.authGuard) throw new Error('AuthGuard was not initialized.'); + return this.authGuard; } +} - next(); -}; -export const REQUIRE_ADMIN_MIDDLEWARE = async (req: Request, res: Response, next: NextFunction): Promise => { - if (!req.models.user || !req.models.user.is_admin) { - throw new ForbiddenHttpError('secret tool', req.url); +export class RequireRequestAuthMiddleware extends Middleware { + protected async handle(req: Request, res: Response, next: NextFunction): Promise { + const proof = await req.as(AuthMiddleware).getAuthGuard().isAuthenticatedViaRequest(req); + if (!proof) { + req.flash('error', `You must be logged in to access ${req.url}.`); + res.redirect(Controller.route('auth', undefined, { + redirect_uri: req.url, + })); + return; + } + + next(); } +} - next(); -}; \ No newline at end of file +export class RequireAuthMiddleware extends Middleware { + protected async handle(req: Request, res: Response, next: NextFunction): Promise { + const authGuard = req.as(AuthMiddleware).getAuthGuard(); + // Via request + + if (await authGuard.isAuthenticatedViaRequest(req)) { + next(); + return; + } + + // Via session + if (!await authGuard.isAuthenticated(req.session!)) { + req.flash('error', `You must be logged in to access ${req.url}.`); + res.redirect(Controller.route('auth', undefined, { + redirect_uri: req.url, + })); + return; + } + + next(); + } +} + +export class RequireGuestMiddleware extends Middleware { + protected async handle(req: Request, res: Response, next: NextFunction): Promise { + if (await req.as(AuthMiddleware).getAuthGuard().isAuthenticated(req.session!)) { + res.redirectBack(); + return; + } + + next(); + } +} + +export class RequireAdminMiddleware extends Middleware { + protected async handle(req: Request, res: Response, next: NextFunction): Promise { + const user = req.as(AuthMiddleware).getUser(); + if (!user || !user.is_admin) { + throw new ForbiddenHttpError('secret tool', req.url); + } + + next(); + } +} diff --git a/src/auth/AuthController.ts b/src/auth/AuthController.ts index 5d09617..89db190 100644 --- a/src/auth/AuthController.ts +++ b/src/auth/AuthController.ts @@ -1,6 +1,6 @@ import Controller from "../Controller"; import {NextFunction, Request, Response} from "express"; -import {REQUIRE_AUTH_MIDDLEWARE, REQUIRE_GUEST_MIDDLEWARE} from "./AuthComponent"; +import {AuthMiddleware, RequireAuthMiddleware, RequireGuestMiddleware} from "./AuthComponent"; export default abstract class AuthController extends Controller { public getRoutesPrefix(): string { @@ -8,10 +8,10 @@ export default abstract class AuthController extends Controller { } public routes() { - this.get('/', this.getAuth, 'auth', REQUIRE_GUEST_MIDDLEWARE); - this.post('/', this.postAuth, 'auth', REQUIRE_GUEST_MIDDLEWARE); + this.get('/', this.getAuth, 'auth', RequireGuestMiddleware); + this.post('/', this.postAuth, 'auth', RequireGuestMiddleware); this.get('/check', this.getCheckAuth, 'check_auth'); - this.post('/logout', this.postLogout, 'logout', REQUIRE_AUTH_MIDDLEWARE); + this.post('/logout', this.postLogout, 'logout', RequireAuthMiddleware); } protected async getAuth(req: Request, res: Response, next: NextFunction): Promise { @@ -26,7 +26,7 @@ export default abstract class AuthController extends Controller { protected abstract async getCheckAuth(req: Request, res: Response, next: NextFunction): Promise; protected async postLogout(req: Request, res: Response, next: NextFunction): Promise { - const proof = await req.authGuard.getProof(req); + const proof = await req.as(AuthMiddleware).getAuthGuard().getProof(req); await proof?.revoke(); req.flash('success', 'Successfully logged out.'); res.redirect(req.query.redirect_uri?.toString() || '/'); diff --git a/src/auth/AuthGuard.ts b/src/auth/AuthGuard.ts index e43dd57..6f04542 100644 --- a/src/auth/AuthGuard.ts +++ b/src/auth/AuthGuard.ts @@ -53,7 +53,7 @@ export default abstract class AuthGuard

> { proof: P, onLogin?: (user: User) => Promise, onRegister?: (connection: Connection, user: User) => Promise - ): Promise { + ): Promise { if (!await proof.isValid()) throw new InvalidAuthProofError(); if (!await proof.isAuthorized()) throw new UnauthorizedAuthProofError(); @@ -95,6 +95,8 @@ export default abstract class AuthGuard

> { // Login session.is_authenticated = true; if (onLogin) await onLogin(user); + + return user; } } diff --git a/src/auth/magic_link/MagicLinkAuthController.ts b/src/auth/magic_link/MagicLinkAuthController.ts index 7602514..9493cf7 100644 --- a/src/auth/magic_link/MagicLinkAuthController.ts +++ b/src/auth/magic_link/MagicLinkAuthController.ts @@ -9,17 +9,19 @@ import {AuthError, PendingApprovalAuthError, RegisterCallback} from "../AuthGuar import geoip from "geoip-lite"; import AuthController from "../AuthController"; import RedirectBackComponent from "../../components/RedirectBackComponent"; +import {AuthMiddleware} from "../AuthComponent"; +import User from "../models/User"; export default abstract class MagicLinkAuthController extends AuthController { - public static async checkAndAuth(req: Request, res: Response, magicLink: MagicLink): Promise { + public static async checkAndAuth(req: Request, res: Response, magicLink: MagicLink): Promise { if (magicLink.getSessionID() !== req.sessionID!) throw new BadOwnerMagicLink(); if (!await magicLink.isAuthorized()) throw new UnauthorizedMagicLink(); if (!await magicLink.isValid()) throw new InvalidMagicLink(); // Auth try { - await req.authGuard.authenticateOrRegister(req.session!, magicLink, undefined, async (connection, user) => { + return await req.as(AuthMiddleware).getAuthGuard().authenticateOrRegister(req.session!, magicLink, undefined, async (connection, user) => { const callbacks: RegisterCallback[] = []; const userEmail = UserEmail.create({ @@ -46,7 +48,7 @@ export default abstract class MagicLinkAuthController extends AuthController { res.redirect('/'); } }); - return; + return null; } else { throw e; } @@ -139,20 +141,21 @@ export default abstract class MagicLinkAuthController extends AuthController { return; } - await MagicLinkAuthController.checkAndAuth(req, res, magicLink); + const user = await MagicLinkAuthController.checkAndAuth(req, res, magicLink); - // Auth success - const username = req.models.user?.name; - res.format({ - json: () => { - res.json({'status': 'success', 'message': `Welcome, ${username}!`}); - }, - default: () => { - req.flash('success', `Authentication success. Welcome, ${username}!`); - res.redirect('/'); - }, - }); - return; + if (user) { + // Auth success + const username = user.name; + res.format({ + json: () => { + res.json({'status': 'success', 'message': `Welcome, ${username}!`}); + }, + default: () => { + req.flash('success', `Authentication success. Welcome, ${username}!`); + res.redirect('/'); + }, + }); + } } } diff --git a/src/components/AutoUpdateComponent.ts b/src/components/AutoUpdateComponent.ts index 0f82972..10fa90c 100644 --- a/src/components/AutoUpdateComponent.ts +++ b/src/components/AutoUpdateComponent.ts @@ -5,7 +5,7 @@ import ApplicationComponent from "../ApplicationComponent"; import {ForbiddenHttpError} from "../HttpError"; import Logger from "../Logger"; -export default class AutoUpdateComponent extends ApplicationComponent { +export default class AutoUpdateComponent extends ApplicationComponent { public async checkSecuritySettings(): Promise { this.checkSecurityConfigField('gitlab_webhook_token'); } @@ -39,7 +39,7 @@ export default class AutoUpdateComponent extends ApplicationComponent { await this.runCommand(`yarn dist`); // Stop app - await this.app!.stop(); + await this.getApp().stop(); Logger.info('Success!'); } catch (e) { diff --git a/src/components/CsrfProtectionComponent.ts b/src/components/CsrfProtectionComponent.ts index bc6584d..467a959 100644 --- a/src/components/CsrfProtectionComponent.ts +++ b/src/components/CsrfProtectionComponent.ts @@ -2,8 +2,9 @@ import ApplicationComponent from "../ApplicationComponent"; import {Request, Router} from "express"; import crypto from "crypto"; import {BadRequestError} from "../HttpError"; +import {AuthMiddleware} from "../auth/AuthComponent"; -export default class CsrfProtectionComponent extends ApplicationComponent { +export default class CsrfProtectionComponent extends ApplicationComponent { private static readonly excluders: ((req: Request) => boolean)[] = []; public static getCSRFToken(session: Express.Session): string { @@ -33,18 +34,17 @@ export default class CsrfProtectionComponent extends ApplicationComponent if (!['GET', 'HEAD', 'OPTIONS'].find(s => s === req.method)) { try { - if (!(await req.authGuard.isAuthenticatedViaRequest(req))) { + if (!await req.as(AuthMiddleware).getAuthGuard().isAuthenticatedViaRequest(req)) { if (req.session.csrf === undefined) { - throw new InvalidCsrfTokenError(req.baseUrl, `You weren't assigned any CSRF token.`); + return next(new InvalidCsrfTokenError(req.baseUrl, `You weren't assigned any CSRF token.`)); } else if (req.body.csrf === undefined) { - throw new InvalidCsrfTokenError(req.baseUrl, `You didn't provide any CSRF token.`); + return next(new InvalidCsrfTokenError(req.baseUrl, `You didn't provide any CSRF token.`)); } else if (req.session.csrf !== req.body.csrf) { - throw new InvalidCsrfTokenError(req.baseUrl, `Tokens don't match.`); + return next(new InvalidCsrfTokenError(req.baseUrl, `Tokens don't match.`)); } } } catch (e) { - next(e); - return; + return next(e); } } next(); diff --git a/src/components/ExpressAppComponent.ts b/src/components/ExpressAppComponent.ts index 9544f4c..58a3d59 100644 --- a/src/components/ExpressAppComponent.ts +++ b/src/components/ExpressAppComponent.ts @@ -3,8 +3,10 @@ import express, {Express, Router} from "express"; import Logger from "../Logger"; import {Server} from "http"; import compression from "compression"; +import Middleware from "../Middleware"; +import {Type} from "../Utils"; -export default class ExpressAppComponent extends ApplicationComponent { +export default class ExpressAppComponent extends ApplicationComponent { private readonly addr: string; private readonly port: number; private server?: Server; @@ -39,8 +41,12 @@ export default class ExpressAppComponent extends ApplicationComponent { router.use(compression()); router.use((req, res, next) => { - req.models = {}; - req.modelCollections = {}; + req.middlewares = []; + req.as = (type: Type) => { + const middleware = req.middlewares.find(m => m.constructor === type); + if (!middleware) throw new Error('Middleware ' + type.name + ' not present in this request.'); + return middleware as M; + }; next(); }); } diff --git a/src/components/FormHelperComponent.ts b/src/components/FormHelperComponent.ts index ab9f7f4..0905aaa 100644 --- a/src/components/FormHelperComponent.ts +++ b/src/components/FormHelperComponent.ts @@ -1,7 +1,7 @@ import ApplicationComponent from "../ApplicationComponent"; import {Router} from "express"; -export default class FormHelperComponent extends ApplicationComponent { +export default class FormHelperComponent extends ApplicationComponent { public async init(router: Router): Promise { router.use((req, res, next) => { if (!req.session) { diff --git a/src/components/LogRequestsComponent.ts b/src/components/LogRequestsComponent.ts index daf1a3e..c492a48 100644 --- a/src/components/LogRequestsComponent.ts +++ b/src/components/LogRequestsComponent.ts @@ -3,7 +3,7 @@ import onFinished from "on-finished"; import Logger from "../Logger"; import {Request, Response, Router} from "express"; -export default class LogRequestsComponent extends ApplicationComponent { +export default class LogRequestsComponent extends ApplicationComponent { private static fullRequests: boolean = false; public static logFullHttpRequests() { diff --git a/src/components/MailComponent.ts b/src/components/MailComponent.ts index 18d8998..5e903bb 100644 --- a/src/components/MailComponent.ts +++ b/src/components/MailComponent.ts @@ -4,8 +4,7 @@ import Mail from "../Mail"; import config from "config"; import SecurityError from "../SecurityError"; -export default class MailComponent extends ApplicationComponent { - +export default class MailComponent extends ApplicationComponent { public async checkSecuritySettings(): Promise { if (!config.get('mail.secure')) { diff --git a/src/components/MaintenanceComponent.ts b/src/components/MaintenanceComponent.ts index 08442b6..e6ad250 100644 --- a/src/components/MaintenanceComponent.ts +++ b/src/components/MaintenanceComponent.ts @@ -4,7 +4,7 @@ import {ServiceUnavailableHttpError} from "../HttpError"; import Application from "../Application"; import config from "config"; -export default class MaintenanceComponent extends ApplicationComponent { +export default class MaintenanceComponent extends ApplicationComponent { private readonly application: Application; private readonly canServe: () => boolean; diff --git a/src/components/MysqlComponent.ts b/src/components/MysqlComponent.ts index a1ab499..b61c6ba 100644 --- a/src/components/MysqlComponent.ts +++ b/src/components/MysqlComponent.ts @@ -2,7 +2,7 @@ import ApplicationComponent from "../ApplicationComponent"; import {Express} from "express"; import MysqlConnectionManager from "../db/MysqlConnectionManager"; -export default class MysqlComponent extends ApplicationComponent { +export default class MysqlComponent extends ApplicationComponent { public async start(app: Express): Promise { await this.prepare('Mysql connection', () => MysqlConnectionManager.prepare()); } @@ -12,7 +12,7 @@ export default class MysqlComponent extends ApplicationComponent { } public canServe(): boolean { - return MysqlConnectionManager.pool !== undefined; + return MysqlConnectionManager.isReady(); } } \ No newline at end of file diff --git a/src/components/NunjucksComponent.ts b/src/components/NunjucksComponent.ts index 724484c..31fd54c 100644 --- a/src/components/NunjucksComponent.ts +++ b/src/components/NunjucksComponent.ts @@ -1,6 +1,6 @@ import nunjucks, {Environment} from "nunjucks"; import config from "config"; -import {Express, Router} from "express"; +import {Express, NextFunction, Request, Response, Router} from "express"; import ApplicationComponent from "../ApplicationComponent"; import Controller from "../Controller"; import {ServerError} from "../HttpError"; @@ -8,8 +8,11 @@ import * as querystring from "querystring"; import {ParsedUrlQueryInput} from "querystring"; import * as util from "util"; import * as path from "path"; +import * as fs from "fs"; +import Logger from "../Logger"; +import Middleware from "../Middleware"; -export default class NunjucksComponent extends ApplicationComponent { +export default class NunjucksComponent extends ApplicationComponent { private readonly viewsPath: string[]; private env?: Environment; @@ -20,13 +23,14 @@ export default class NunjucksComponent extends ApplicationComponent { public async start(app: Express): Promise { let coreVersion = 'unknown'; + const file = fs.existsSync(path.join(__dirname, '../../package.json')) ? + path.join(__dirname, '../../package.json') : + path.join(__dirname, '../package.json'); + try { - coreVersion = require('../../package.json').version; + coreVersion = JSON.parse(fs.readFileSync(file).toString()).version; } catch (e) { - try { - coreVersion = require('../package.json').version; - } catch (e) { - } + Logger.warn('Couldn\'t determine coreVersion.', e); } this.env = new nunjucks.Environment([ @@ -43,7 +47,7 @@ export default class NunjucksComponent extends ApplicationComponent { if (path === null) throw new ServerError(`Route ${route} not found.`); return path; }) - .addGlobal('app_version', this.app!.getVersion()) + .addGlobal('app_version', this.getApp().getVersion()) .addGlobal('core_version', coreVersion) .addGlobal('querystring', querystring) .addGlobal('app', config.get('app')) @@ -59,18 +63,29 @@ export default class NunjucksComponent extends ApplicationComponent { } public async init(router: Router): Promise { - router.use((req, res, next) => { - req.env = this.env!; - res.locals.url = req.url; - res.locals.params = req.params; - res.locals.query = req.query; - res.locals.body = req.body; - - next(); - }); + this.use(NunjucksMiddleware); } public getEnv(): Environment | undefined { return this.env; } -} \ No newline at end of file +} + +export class NunjucksMiddleware extends Middleware { + private env?: Environment; + + protected async handle(req: Request, res: Response, next: NextFunction): Promise { + this.env = this.app.as(NunjucksComponent).getEnv(); + res.locals.url = req.url; + res.locals.params = req.params; + res.locals.query = req.query; + res.locals.body = req.body; + + next(); + } + + public getEnvironment(): Environment { + if (!this.env) throw new Error('Environment not initialized.'); + return this.env; + } +} diff --git a/src/components/RedirectBackComponent.ts b/src/components/RedirectBackComponent.ts index a10fdbe..75be932 100644 --- a/src/components/RedirectBackComponent.ts +++ b/src/components/RedirectBackComponent.ts @@ -4,7 +4,7 @@ import {ServerError} from "../HttpError"; import onFinished from "on-finished"; import Logger from "../Logger"; -export default class RedirectBackComponent extends ApplicationComponent { +export default class RedirectBackComponent extends ApplicationComponent { public static getPreviousURL(req: Request, defaultUrl?: string): string | undefined { return req.session?.previousUrl || defaultUrl; } diff --git a/src/components/RedisComponent.ts b/src/components/RedisComponent.ts index c8594e0..30e12fb 100644 --- a/src/components/RedisComponent.ts +++ b/src/components/RedisComponent.ts @@ -9,7 +9,7 @@ import CacheProvider from "../CacheProvider"; const RedisStore = connect_redis(session); -export default class RedisComponent extends ApplicationComponent implements CacheProvider { +export default class RedisComponent extends ApplicationComponent implements CacheProvider { private redisClient?: RedisClient; private store?: Store; diff --git a/src/components/ServeStaticDirectoryComponent.ts b/src/components/ServeStaticDirectoryComponent.ts index f05685b..c2042e3 100644 --- a/src/components/ServeStaticDirectoryComponent.ts +++ b/src/components/ServeStaticDirectoryComponent.ts @@ -3,7 +3,7 @@ import express, {Router} from "express"; import {PathParams} from "express-serve-static-core"; import * as path from "path"; -export default class ServeStaticDirectoryComponent extends ApplicationComponent { +export default class ServeStaticDirectoryComponent extends ApplicationComponent { private readonly root: string; private readonly path?: PathParams; diff --git a/src/components/SessionComponent.ts b/src/components/SessionComponent.ts index 2409d8c..964bf35 100644 --- a/src/components/SessionComponent.ts +++ b/src/components/SessionComponent.ts @@ -6,7 +6,7 @@ import flash from "connect-flash"; import {Router} from "express"; import SecurityError from "../SecurityError"; -export default class SessionComponent extends ApplicationComponent { +export default class SessionComponent extends ApplicationComponent { private readonly storeComponent: RedisComponent; public constructor(storeComponent: RedisComponent) { diff --git a/src/components/WebSocketServerComponent.ts b/src/components/WebSocketServerComponent.ts index 25eba6d..ed5dbca 100644 --- a/src/components/WebSocketServerComponent.ts +++ b/src/components/WebSocketServerComponent.ts @@ -11,7 +11,7 @@ import RedisComponent from "./RedisComponent"; import WebSocketListener from "../WebSocketListener"; import NunjucksComponent from "./NunjucksComponent"; -export default class WebSocketServerComponent extends ApplicationComponent { +export default class WebSocketServerComponent extends ApplicationComponent { private wss?: WebSocket.Server; constructor( diff --git a/src/db/Model.ts b/src/db/Model.ts index f8c4150..29ec758 100644 --- a/src/db/Model.ts +++ b/src/db/Model.ts @@ -7,8 +7,9 @@ import ModelFactory from "./ModelFactory"; import ModelRelation from "./ModelRelation"; import ModelQuery, {ModelQueryResult, SelectFields} from "./ModelQuery"; import {Request} from "express"; +import Extendable from "../Extendable"; -export default abstract class Model { +export default abstract class Model implements Extendable> { public static get table(): string { return this.name .replace(/(?:^|\.?)([A-Z])/g, (x, y) => '_' + y.toLowerCase()) @@ -44,14 +45,14 @@ export default abstract class Model { return ModelFactory.get(this).paginate(request, perPage, query); } - protected readonly _factory: ModelFactory; - private readonly _components: ModelComponent[] = []; - private readonly _validators: { [key: string]: Validator } = {}; + protected readonly _factory: ModelFactory; + private readonly _components: ModelComponent[] = []; + private readonly _validators: { [key: string]: Validator | undefined } = {}; private _exists: boolean; [key: string]: any; - public constructor(factory: ModelFactory, isNew: boolean) { + public constructor(factory: ModelFactory, isNew: boolean) { if (!factory || !(factory instanceof ModelFactory)) throw new Error('Cannot instantiate model directly.'); this._factory = factory; this.init(); @@ -71,15 +72,26 @@ export default abstract class Model { this._components.push(modelComponent); } - public as>(type: Type): T { + public as>(type: Type): C { for (const component of this._components) { if (component instanceof type) { - return this; + return this as unknown as C; } } + throw new Error(`Component ${type.name} was not initialized for this ${this.constructor.name}.`); } + public asOptional>(type: Type): C | null { + for (const component of this._components) { + if (component instanceof type) { + return this as unknown as C; + } + } + + return null; + } + public updateWithData(data: any) { for (const property of this._properties) { if (data[property] !== undefined) { @@ -92,19 +104,16 @@ export default abstract class Model { * Override this to automatically fill obvious missing data i.e. from relation or default value that are fetched * asynchronously. */ - protected async autoFill(): Promise { - } + protected async autoFill?(): Promise; - protected async beforeSave(connection: Connection): Promise { - } + protected async beforeSave?(connection: Connection): Promise; - protected async afterSave(): Promise { - } + protected async afterSave?(): Promise; public async save(connection?: Connection, postHook?: (callback: () => Promise) => void): Promise { if (connection && !postHook) throw new Error('If connection is provided, postHook must be provided too.'); - await this.autoFill(); + await this.autoFill?.(); await this.validate(false, connection); const needs_full_update = connection ? @@ -120,7 +129,7 @@ export default abstract class Model { this.updateWithData(result.results[0]); } - await this.afterSave(); + await this.afterSave?.(); }; if (connection) { @@ -132,7 +141,7 @@ export default abstract class Model { private async saveTransaction(connection: Connection): Promise { // Before save - await this.beforeSave(connection); + await this.beforeSave?.(connection); if (!this.exists() && this.hasOwnProperty('created_at')) { this.created_at = new Date(); } diff --git a/src/db/ModelComponent.ts b/src/db/ModelComponent.ts index 967e465..6d37763 100644 --- a/src/db/ModelComponent.ts +++ b/src/db/ModelComponent.ts @@ -13,7 +13,7 @@ export default abstract class ModelComponent { } public applyToModel(): void { - this.init(); + this.init?.(); for (const property of this._properties) { if (!property.startsWith('_')) { @@ -32,7 +32,7 @@ export default abstract class ModelComponent { } } - protected abstract init(): void; + protected init?(): void; protected setValidation(propertyName: keyof this): Validator { const validator = new Validator(); diff --git a/src/db/MysqlConnectionManager.ts b/src/db/MysqlConnectionManager.ts index dfb64a0..5c6e53a 100644 --- a/src/db/MysqlConnectionManager.ts +++ b/src/db/MysqlConnectionManager.ts @@ -21,6 +21,10 @@ export default class MysqlConnectionManager { private static migrationsRegistered: boolean = false; private static readonly migrations: Migration[] = []; + public static isReady(): boolean { + return this.databaseReady && this.currentPool !== undefined; + } + public static registerMigrations(migrations: Type[]) { if (!this.migrationsRegistered) { this.migrationsRegistered = true; diff --git a/src/helpers/BackendController.ts b/src/helpers/BackendController.ts index 3c55dd7..da754e0 100644 --- a/src/helpers/BackendController.ts +++ b/src/helpers/BackendController.ts @@ -1,6 +1,5 @@ import config from "config"; import Controller from "../Controller"; -import {REQUIRE_ADMIN_MIDDLEWARE, REQUIRE_AUTH_MIDDLEWARE} from "../auth/AuthComponent"; import User from "../auth/models/User"; import {Request, Response} from "express"; import {BadRequestError, NotFoundHttpError} from "../HttpError"; @@ -8,6 +7,7 @@ import Mail from "../Mail"; import {ACCOUNT_REVIEW_NOTICE_MAIL_TEMPLATE} from "../Mails"; import UserEmail from "../auth/models/UserEmail"; import UserApprovedComponent from "../auth/models/UserApprovedComponent"; +import {RequireAdminMiddleware, RequireAuthMiddleware} from "../auth/AuthComponent"; export default class BackendController extends Controller { private static readonly menu: BackendMenuElement[] = []; @@ -37,11 +37,11 @@ export default class BackendController extends Controller { } public routes(): void { - this.get('/', this.getIndex, 'backend', REQUIRE_AUTH_MIDDLEWARE, REQUIRE_ADMIN_MIDDLEWARE); + this.get('/', this.getIndex, 'backend', RequireAuthMiddleware, RequireAdminMiddleware); if (User.isApprovalMode()) { - this.get('/accounts-approval', this.getAccountApproval, 'accounts-approval', REQUIRE_AUTH_MIDDLEWARE, REQUIRE_ADMIN_MIDDLEWARE); - this.post('/accounts-approval/approve', this.postApproveAccount, 'approve-account', REQUIRE_AUTH_MIDDLEWARE, REQUIRE_ADMIN_MIDDLEWARE); - this.post('/accounts-approval/reject', this.postRejectAccount, 'reject-account', REQUIRE_AUTH_MIDDLEWARE, REQUIRE_ADMIN_MIDDLEWARE); + this.get('/accounts-approval', this.getAccountApproval, 'accounts-approval', RequireAuthMiddleware, RequireAdminMiddleware); + this.post('/accounts-approval/approve', this.postApproveAccount, 'approve-account', RequireAuthMiddleware, RequireAdminMiddleware); + this.post('/accounts-approval/reject', this.postRejectAccount, 'reject-account', RequireAuthMiddleware, RequireAdminMiddleware); } } diff --git a/src/types/Express.d.ts b/src/types/Express.d.ts index 006e76c..6764ea9 100644 --- a/src/types/Express.d.ts +++ b/src/types/Express.d.ts @@ -1,21 +1,18 @@ -import {Environment} from "nunjucks"; -import Model from "../db/Model"; -import AuthGuard from "../auth/AuthGuard"; import {Files} from "formidable"; -import User from "../auth/models/User"; +import {Type} from "../Utils"; +import Middleware from "../Middleware"; declare global { namespace Express { export interface Request { - env: Environment; - models: { - user?: User | null, - [p: string]: Model | null | undefined, - }; - modelCollections: { [p: string]: Model[] | null }; - authGuard: AuthGuard; files: Files; + + middlewares: Middleware[]; + + as(type: Type): M; + + flash(): { [key: string]: string[] }; flash(message: string): any;