Merge branch 'develop'

This commit is contained in:
Alice Gaudon 2021-01-25 10:56:36 +01:00
commit 1f9fcbec58
128 changed files with 12246 additions and 6316 deletions

110
.eslintrc.json Normal file
View File

@ -0,0 +1,110 @@
{
"root": true,
"parser": "@typescript-eslint/parser",
"plugins": [
"@typescript-eslint"
],
"parserOptions": {
"project": [
"./tsconfig.json",
"./tsconfig.test.json"
]
},
"extends": [
"eslint:recommended",
"plugin:@typescript-eslint/recommended"
],
"rules": {
"indent": [
"error",
4,
{
"SwitchCase": 1
}
],
"no-trailing-spaces": "error",
"max-len": [
"error",
{
"code": 120,
"ignoreTemplateLiterals": true,
"ignoreRegExpLiterals": true
}
],
"semi": "off",
"@typescript-eslint/semi": [
"error"
],
"no-extra-semi": "error",
"eol-last": "error",
"comma-dangle": "off",
"@typescript-eslint/comma-dangle": [
"error",
{
"arrays": "always-multiline",
"objects": "always-multiline",
"imports": "always-multiline",
"exports": "always-multiline",
"functions": "always-multiline",
"enums": "always-multiline",
"generics": "always-multiline",
"tuples": "always-multiline"
}
],
"no-extra-parens": "off",
"@typescript-eslint/no-extra-parens": [
"error"
],
"no-nested-ternary": "error",
"@typescript-eslint/no-inferrable-types": "off",
"@typescript-eslint/explicit-module-boundary-types": "error",
"@typescript-eslint/no-unnecessary-condition": "error",
"@typescript-eslint/no-unused-vars": [
"error",
{
"argsIgnorePattern": "^_"
}
],
"@typescript-eslint/no-non-null-assertion": "error",
"no-useless-return": "error",
"no-useless-constructor": "off",
"@typescript-eslint/no-useless-constructor": [
"error"
],
"no-return-await": "off",
"@typescript-eslint/return-await": [
"error",
"always"
],
"@typescript-eslint/explicit-member-accessibility": [
"error",
{
"accessibility": "explicit"
}
],
"@typescript-eslint/no-floating-promises": "error"
},
"ignorePatterns": [
"jest.config.js",
"dist/**/*",
"config/**/*"
],
"overrides": [
{
"files": [
"test/**/*"
],
"rules": {
"max-len": [
"error",
{
"code": 120,
"ignoreTemplateLiterals": true,
"ignoreRegExpLiterals": true,
"ignoreStrings": true
}
]
}
}
]
}

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea .idea
node_modules node_modules
dist dist
yarn-error.log

8
LICENSE Normal file
View File

@ -0,0 +1,8 @@
Copyright 2020 Alice Gaudon
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -1 +1,33 @@
# WMS-Core # Structure Web Application Framework
A NodeJS TypeScript web application framework (duh).
## /!\ Still in development! There are not near enough tests /!\
Use at your own risk. Also please feel free to contribute with issues, bug reports and pull requests.
## Features
### Application building
- Model, View, Controller
- Uses express
- Custom Middleware classes that enable advanced modularity
- Modular models (you can add components with some definition automation)
- Simple database migrations (raw sql queries for now)
- Nunjucks for the view template engine
- Mail template system using Nunjucks + MJML
- Beautiful logging thanks to `tslog`
### Databases
- MySQL (persistent data)
- Redis (cache, session)
- (more to come)
### Common systems
- Advanced modular multi-factor authentication system
- CSRF protection
- WebSocket server with Controller-style endpoint listeners
- WIP: automatic updates

54
config/default.json5 Normal file
View File

@ -0,0 +1,54 @@
{
app: {
name: 'Example App',
contact_email: 'contact@example.net',
display_email_warning: true,
},
log: {
level: "DEBUG",
verbose: true,
db_level: "ERROR",
},
public_url: "http://localhost:4899",
public_websocket_url: "ws://localhost:4899",
listen_addr: '127.0.0.1',
port: 4899,
gitlab_webhook_token: 'default',
mysql: {
connectionLimit: 10,
host: "localhost",
user: "root",
password: "",
database: "swaf",
create_database_automatically: false,
},
redis: {
host: "127.0.0.1",
port: 6379,
prefix: 'swaf',
},
session: {
secret: 'default',
cookie: {
secure: false,
maxAge: 31557600000, // 1 year
},
},
mail: {
host: "127.0.0.1",
port: "1025",
secure: false,
username: "",
password: "",
allow_invalid_tls: true,
from: 'contact@example.net',
from_name: 'Example App',
},
view: {
cache: false,
},
magic_link: {
validity_period: 20,
},
approval_mode: false,
}

View File

@ -1,36 +0,0 @@
export default {
log_level: "DEV",
db_log_level: "ERROR",
public_url: "http://localhost:4899",
public_websocket_url: "ws://localhost:4899",
port: 4899,
mysql: {
connectionLimit: 10,
host: "localhost",
user: "root",
password: "",
database: "wms2",
create_database_automatically: false
},
redis: {
host: "127.0.0.1",
port: 6379
},
session: {
secret: "very_secret_not_known",
cookie: {
secure: false
}
},
mail: {
host: "127.0.0.1",
port: "1025",
secure: false,
username: "",
password: "",
allow_invalid_tls: true
},
view: {
cache: false
}
};

21
config/production.json5 Normal file
View File

@ -0,0 +1,21 @@
{
log: {
level: "DEV",
verbose: false,
db_level: "ERROR",
},
public_url: "https://swaf.example",
public_websocket_url: "wss://swaf.example",
session: {
cookie: {
secure: true,
},
},
mail: {
secure: true,
allow_invalid_tls: false,
},
magic_link: {
validity_period: 900,
},
}

View File

@ -1,15 +0,0 @@
export default {
log_level: "DEBUG",
db_log_level: "ERROR",
public_url: "https://watch-my.stream",
public_websocket_url: "wss://watch-my.stream",
session: {
cookie: {
secure: true
}
},
mail: {
secure: true,
allow_invalid_tls: false
}
};

14
config/test.json5 Normal file
View File

@ -0,0 +1,14 @@
{
mysql: {
host: "localhost",
user: "root",
password: "",
database: "swaf_test",
create_database_automatically: true,
},
session: {
cookie: {
maxAge: 1000, // 1s
},
},
}

View File

@ -1,9 +0,0 @@
export default {
mysql: {
host: "localhost",
user: "root",
password: "",
database: "wms2_test",
create_database_automatically: true
}
};

View File

@ -1,4 +1,9 @@
module.exports = { module.exports = {
globals: {
'ts-jest': {
tsconfig: 'tsconfig.test.json',
}
},
transform: { transform: {
"^.+\\.ts$": "ts-jest" "^.+\\.ts$": "ts-jest"
}, },
@ -10,5 +15,5 @@ module.exports = {
testMatch: [ testMatch: [
'**/test/**/*.test.ts' '**/test/**/*.test.ts'
], ],
testEnvironment: 'node' testEnvironment: 'node',
}; };

View File

@ -1,58 +1,81 @@
{ {
"name": "wms-core", "name": "swaf",
"version": "0.2.7", "version": "0.23.0",
"description": "Node web framework", "description": "Structure Web Application Framework.",
"repository": "git@gitlab.com:ArisuOngaku/wms-core.git", "repository": "https://eternae.ink/arisu/swaf",
"author": "Alice Gaudon <alice@gaudon.pro>", "author": "Alice Gaudon <alice@gaudon.pro>",
"license": "MIT", "license": "MIT",
"readme": "README.md",
"publishConfig": { "publishConfig": {
"registry": "http://127.0.0.1:4873", "registry": "https://registry.npmjs.com",
"access": "restricted" "access": "public"
}, },
"main": "dist/index.js", "main": "dist/src/main.js",
"types": "dist/index.d.ts", "types": "dist/index.d.ts",
"scripts": { "scripts": {
"test": "jest --verbose --runInBand", "test": "jest --verbose --runInBand",
"build": "(test ! -d dist || rm -r dist) && tsc && cp package.json dist/ && mkdir dist/types && cp src/types/* dist/types/ && mv dist/src/* dist/ && rm -r dist/src", "clean": "(test ! -d dist || rm -r dist)",
"publish_to_local": "yarn test && yarn build && cd dist && yarn publish" "compile": "yarn clean && tsc",
"dev": "concurrently -k -n \"Typescript,Node,Webpack,Maildev\" -p \"[{name}]\" -c \"blue,green,red,yellow\" \"tsc --watch\" \"nodemon\" \"maildev\"",
"build": "yarn compile && cp -r package.json yarn.lock README.md config/ views/ dist/ && mkdir dist/types && cp src/types/* dist/types/",
"lint": "eslint . --ext .js,.jsx,.ts,.tsx",
"release": "yarn lint && yarn test && yarn build && cd dist && yarn publish"
}, },
"devDependencies": { "devDependencies": {
"@types/config": "^0.0.36", "@types/compression": "^1.7.0",
"@types/connect-flash": "^0.0.35", "@types/config": "^0.0.38",
"@types/connect-redis": "^0.0.13", "@types/connect-flash": "^0.0.36",
"@types/cookie": "^0.3.3", "@types/cookie": "^0.4.0",
"@types/cookie-parser": "^1.4.2", "@types/cookie-parser": "^1.4.2",
"@types/jest": "^25.2.1",
"@types/mjml": "^4.0.4",
"@types/on-finished": "^2.3.1",
"@types/uuid": "^7.0.2",
"jest": "^25.4.0",
"ts-jest": "^25.4.0",
"typescript": "^3.8.3"
},
"dependencies": {
"@types/express": "^4.17.6", "@types/express": "^4.17.6",
"@types/express-session": "^1.17.0", "@types/express-session": "^1.17.0",
"@types/formidable": "^1.0.31",
"@types/geoip-lite": "^1.1.31",
"@types/jest": "^26.0.4",
"@types/mjml": "^4.0.4",
"@types/mysql": "^2.15.10", "@types/mysql": "^2.15.10",
"@types/nanoid": "^2.1.0",
"@types/node-fetch": "^2.5.7",
"@types/nodemailer": "^6.4.0", "@types/nodemailer": "^6.4.0",
"@types/nunjucks": "^3.1.3", "@types/nunjucks": "^3.1.3",
"@types/on-finished": "^2.3.1",
"@types/redis": "^2.8.18", "@types/redis": "^2.8.18",
"@types/supertest": "^2.0.10",
"@types/uuid": "^8.0.0",
"@types/ws": "^7.2.4", "@types/ws": "^7.2.4",
"@typescript-eslint/eslint-plugin": "^4.2.0",
"@typescript-eslint/parser": "^4.2.0",
"concurrently": "^5.3.0",
"eslint": "^7.9.0",
"jest": "^26.1.0",
"maildev": "^1.1.0",
"node-fetch": "^2.6.0",
"nodemon": "^2.0.6",
"supertest": "^6.0.0",
"ts-jest": "^26.1.1",
"typescript": "^4.0.2"
},
"dependencies": {
"argon2": "^0.27.0",
"compression": "^1.7.4",
"config": "^3.3.1", "config": "^3.3.1",
"connect-flash": "^0.1.1", "connect-flash": "^0.1.1",
"connect-redis": "^4.0.4",
"cookie": "^0.4.1", "cookie": "^0.4.1",
"cookie-parser": "^1.4.5", "cookie-parser": "^1.4.5",
"express": "^4.17.1", "express": "^4.17.1",
"express-session": "^1.17.1", "express-session": "^1.17.1",
"formidable": "^1.2.2",
"geoip-lite": "^1.4.2",
"mjml": "^4.6.2", "mjml": "^4.6.2",
"mysql": "^2.18.1", "mysql": "^2.18.1",
"nanoid": "^3.1.20",
"nodemailer": "^6.4.6", "nodemailer": "^6.4.6",
"nunjucks": "^3.2.1", "nunjucks": "^3.2.1",
"on-finished": "^2.3.0", "on-finished": "^2.3.0",
"redis": "^3.0.2", "redis": "^3.0.2",
"ts-node": "^8.9.0", "ts-node": "^9.0.0",
"uuid": "^7.0.3", "tslog": "^3.0.1",
"uuid": "^8.0.0",
"ws": "^7.2.3" "ws": "^7.2.3"
} }
} }

View File

@ -1,45 +1,64 @@
import express, {NextFunction, Request, Response, Router} from 'express'; import express, {NextFunction, Request, Response, Router} from 'express';
import {BadRequestError, HttpError, NotFoundHttpError, ServerError, ServiceUnavailableHttpError} from "./HttpError"; import {BadRequestError, HttpError, NotFoundHttpError, ServerError, ServiceUnavailableHttpError} from "./HttpError";
import {lib} from "nunjucks"; import {lib} from "nunjucks";
import Logger from "./Logger";
import WebSocketListener from "./WebSocketListener"; import WebSocketListener from "./WebSocketListener";
import ApplicationComponent from "./ApplicationComponent"; import ApplicationComponent from "./ApplicationComponent";
import Controller from "./Controller"; import Controller from "./Controller";
import MysqlConnectionManager from "./db/MysqlConnectionManager"; import MysqlConnectionManager from "./db/MysqlConnectionManager";
import Migration from "./db/Migration"; import Migration, {MigrationType} from "./db/Migration";
import TemplateError = lib.TemplateError;
import {Type} from "./Utils"; import {Type} from "./Utils";
import LogRequestsComponent from "./components/LogRequestsComponent";
import {ValidationBag, ValidationError} from "./db/Validator";
import config from "config";
import * as fs from "fs";
import SecurityError from "./SecurityError";
import * as path from "path";
import CacheProvider from "./CacheProvider";
import RedisComponent from "./components/RedisComponent";
import Extendable from "./Extendable";
import {logger, loggingContextMiddleware} from "./Logger";
import TemplateError = lib.TemplateError;
export default abstract class Application { export default abstract class Application implements Extendable<ApplicationComponent | WebSocketListener<Application>> {
private readonly version: string; private readonly version: string;
private readonly ignoreCommandLine: boolean;
private readonly controllers: Controller[] = []; private readonly controllers: Controller[] = [];
private readonly webSocketListeners: { [p: string]: WebSocketListener } = {}; private readonly webSocketListeners: { [p: string]: WebSocketListener<Application> } = {};
private readonly components: ApplicationComponent<any>[] = []; private readonly components: ApplicationComponent[] = [];
private cacheProvider?: CacheProvider;
private ready: boolean = false; private ready: boolean = false;
protected constructor(version: string) { protected constructor(version: string, ignoreCommandLine: boolean = false) {
this.version = version; this.version = version;
this.ignoreCommandLine = ignoreCommandLine;
} }
protected abstract getMigrations(): Type<Migration>[]; protected abstract getMigrations(): MigrationType<Migration>[];
protected abstract async init(): Promise<void>; protected abstract init(): Promise<void>;
protected use(thing: Controller | WebSocketListener | ApplicationComponent<any>) { protected use(thing: Controller | WebSocketListener<this> | ApplicationComponent): void {
if (thing instanceof Controller) { if (thing instanceof Controller) {
thing.setApp(this);
this.controllers.push(thing); this.controllers.push(thing);
} else if (thing instanceof WebSocketListener) { } else if (thing instanceof WebSocketListener) {
const path = thing.path(); const path = thing.path();
this.webSocketListeners[path] = thing; this.webSocketListeners[path] = thing;
Logger.info(`Added websocket listener on ${path}`); thing.init(this);
logger.info(`Added websocket listener on ${path}`);
} else { } else {
thing.setApp(this);
this.components.push(thing); this.components.push(thing);
if (thing instanceof RedisComponent) {
this.cacheProvider = thing;
}
} }
} }
public async start(): Promise<void> { public async start(): Promise<void> {
Logger.info(`${this.constructor.name} v${this.version} - hi`); logger.info(`${config.get('app.name')} v${this.version} - hi`);
process.once('SIGINT', () => { process.once('SIGINT', () => {
this.stop().catch(console.error); this.stop().catch(console.error);
}); });
@ -47,30 +66,70 @@ export default abstract class Application {
// Register migrations // Register migrations
MysqlConnectionManager.registerMigrations(this.getMigrations()); MysqlConnectionManager.registerMigrations(this.getMigrations());
// Process command line
if (!this.ignoreCommandLine && await this.processCommandLine()) {
await this.stop();
return;
}
// Register all components and alike // Register all components and alike
await this.init(); await this.init();
// Security
if (process.env.NODE_ENV === 'production') {
await this.checkSecuritySettings();
}
// Init express // Init express
const app = express(); const app = express();
const router = express.Router({});
app.use(router);
// Error handler // Logging context
app.use((err: any, req: Request, res: Response, next: NextFunction) => { app.use(loggingContextMiddleware);
if (res.headersSent) {
return next(err); // Routers
const initRouter = express.Router();
const handleRouter = express.Router();
app.use(initRouter);
app.use(handleRouter);
// Error handlers
app.use((err: unknown, req: Request, res: Response, next: NextFunction) => {
if (res.headersSent) return next(err);
// Transform single validation errors into a validation bag for convenience
if (err instanceof ValidationError) {
const bag = new ValidationBag();
bag.addMessage(err);
err = bag;
} }
let errorID: string; if (err instanceof ValidationBag) {
const bag = err;
let logStr = `${req.method} ${req.originalUrl} - `; res.format({
if (err instanceof BadRequestError || err instanceof ServiceUnavailableHttpError) { json: () => {
logStr += `${err.errorCode} ${err.name}`; res.status(400);
errorID = Logger.silentError(err, logStr); res.json({
} else { status: 'error',
errorID = Logger.error(err, logStr + `500 Internal Error`, err); code: 400,
message: 'Invalid form data',
messages: bag.getMessages(),
});
},
text: () => {
res.status(400);
res.send('Error: ' + bag.getMessages());
},
html: () => {
req.flash('validation', bag.getMessages());
res.redirect(req.getPreviousUrl() || Controller.route('home'));
},
});
return;
} }
const errorId = LogRequestsComponent.logRequest(req, res, err, '500 Internal Error',
err instanceof BadRequestError || err instanceof ServiceUnavailableHttpError);
let httpError: HttpError; let httpError: HttpError;
if (err instanceof HttpError) { if (err instanceof HttpError) {
@ -78,7 +137,7 @@ export default abstract class Application {
} else if (err instanceof TemplateError && err.cause instanceof HttpError) { } else if (err instanceof TemplateError && err.cause instanceof HttpError) {
httpError = err.cause; httpError = err.cause;
} else { } else {
httpError = new ServerError('Internal server error.', err); httpError = new ServerError('Internal server error.', err instanceof Error ? err : undefined);
} }
res.status(httpError.errorCode); res.status(httpError.errorCode);
@ -88,7 +147,7 @@ export default abstract class Application {
error_code: httpError.errorCode, error_code: httpError.errorCode,
error_message: httpError.message, error_message: httpError.message,
error_instructions: httpError.instructions, error_instructions: httpError.instructions,
error_id: errorID, error_id: errorId,
}); });
}, },
json: () => { json: () => {
@ -97,63 +156,140 @@ export default abstract class Application {
code: httpError.errorCode, code: httpError.errorCode,
message: httpError.message, message: httpError.message,
instructions: httpError.instructions, instructions: httpError.instructions,
error_id: errorID, error_id: errorId,
}); });
}, },
default: () => { default: () => {
res.type('txt').send(`${httpError.errorCode} - ${httpError.message}\n\n${httpError.instructions}\n\nError ID: ${errorID}`); res.type('txt').send(`${httpError.errorCode} - ${httpError.message}\n\n${httpError.instructions}\n\nError ID: ${errorId}`);
} },
}); });
}); });
// Start all components // Start components
for (const component of this.components) { for (const component of this.components) {
await component.start(app, router); await component.start?.(app);
}
// Components routes
for (const component of this.components) {
if (component.init) {
component.setCurrentRouter(initRouter);
await component.init(initRouter);
}
if (component.handle) {
component.setCurrentRouter(handleRouter);
await component.handle(handleRouter);
}
component.setCurrentRouter(null);
} }
// Routes // Routes
this.routes(router); this.routes(initRouter, handleRouter);
this.ready = true; this.ready = true;
} }
async stop(): Promise<void> { protected async processCommandLine(): Promise<boolean> {
Logger.info('Stopping application...'); const args = process.argv;
for (let i = 2; i < args.length; i++) {
for (const component of this.components) { switch (args[i]) {
await component.stop(); case '--verbose':
logger.setSettings({minLevel: "trace"});
break;
case '--full-http-requests':
LogRequestsComponent.logFullHttpRequests();
break;
case 'migration':
await MysqlConnectionManager.migrationCommand(args.slice(i + 1));
return true;
default:
logger.warn('Unrecognized argument', args[i]);
return true;
}
} }
return false;
Logger.info(`${this.constructor.name} v${this.version} - bye`);
} }
private routes(rootRouter: Router) { private async checkSecuritySettings(): Promise<void> {
for (const controller of this.controllers) { // Check config file permissions
if (controller.hasGlobalHandlers()) { const configDir = 'config';
controller.setupGlobalHandlers(rootRouter); for (const file of fs.readdirSync(configDir)) {
const fullPath = path.resolve(configDir, file);
const stats = fs.lstatSync(fullPath);
if (stats.uid !== process.getuid())
throw new SecurityError(`${fullPath} is not owned by this process (${process.getuid()}).`);
Logger.info(`Registered global middlewares for controller ${controller.constructor.name}`); const mode = (stats.mode & parseInt('777', 8)).toString(8);
if (mode !== '400')
throw new SecurityError(`${fullPath} is ${mode}; should be 400.`);
}
// Check security fields
for (const component of this.components) {
await component.checkSecuritySettings?.();
}
}
public async stop(): Promise<void> {
logger.info('Stopping application...');
for (const component of this.components) {
await component.stop?.();
}
logger.info(`${this.constructor.name} v${this.version} - bye`);
}
private routes(initRouter: Router, handleRouter: Router) {
for (const controller of this.controllers) {
if (controller.hasGlobalMiddlewares()) {
controller.setupGlobalHandlers(handleRouter);
logger.info(`Registered global middlewares for controller ${controller.constructor.name}`);
} }
} }
for (const controller of this.controllers) { for (const controller of this.controllers) {
const router = express.Router(); const {mainRouter, fileUploadFormRouter} = controller.setupRoutes();
controller.setupRoutes(router); initRouter.use(controller.getRoutesPrefix(), fileUploadFormRouter);
rootRouter.use(controller.getRoutesPrefix(), router); handleRouter.use(controller.getRoutesPrefix(), mainRouter);
Logger.info(`> Registered routes for controller ${controller.constructor.name}`); logger.info(`> Registered routes for controller ${controller.constructor.name} at ${controller.getRoutesPrefix()}`);
} }
rootRouter.use((req: Request) => { handleRouter.use((req: Request) => {
throw new NotFoundHttpError('page', req.originalUrl); throw new NotFoundHttpError('page', req.originalUrl);
}); });
} }
public getWebSocketListeners(): { [p: string]: WebSocketListener } {
return this.webSocketListeners;
}
public isReady(): boolean { public isReady(): boolean {
return this.ready; return this.ready;
} }
public getVersion(): string {
return this.version;
}
public getWebSocketListeners(): { [p: string]: WebSocketListener<Application> } {
return this.webSocketListeners;
}
public getCache(): CacheProvider | null {
return this.cacheProvider || null;
}
public as<C extends ApplicationComponent | WebSocketListener<Application>>(type: Type<C>): C {
const module = this.components.find(component => component.constructor === type) ||
Object.values(this.webSocketListeners).find(listener => listener.constructor === type);
if (!module) throw new Error(`This app doesn't have a ${type.name} component.`);
return module as C;
}
public asOptional<C extends ApplicationComponent | WebSocketListener<Application>>(type: Type<C>): C | null {
const module = this.components.find(component => component.constructor === type) ||
Object.values(this.webSocketListeners).find(listener => listener.constructor === type);
return module ? module as C : null;
}
} }

View File

@ -1,22 +1,24 @@
import {Express, Router} from "express"; import {Express, Router} from "express";
import Logger from "./Logger"; import {logger} from "./Logger";
import {sleep} from "./Utils"; import {sleep} from "./Utils";
import Application from "./Application";
import config from "config";
import SecurityError from "./SecurityError";
import Middleware, {MiddlewareType} from "./Middleware";
export default abstract class ApplicationComponent<T> { export default abstract class ApplicationComponent {
private val?: T; private currentRouter?: Router;
private app?: Application;
public abstract async start(app: Express, router: Router): Promise<void>; public async checkSecuritySettings?(): Promise<void>;
public abstract async stop(): Promise<void>; public async start?(expressApp: Express): Promise<void>;
protected export(val: T) { public async init?(router: Router): Promise<void>;
this.val = val;
}
public import(): T { public async handle?(router: Router): Promise<void>;
if (!this.val) throw 'Cannot import if nothing was exported.';
return this.val; public async stop?(): Promise<void>;
}
protected async prepare(name: string, prepare: () => Promise<void>): Promise<void> { protected async prepare(name: string, prepare: () => Promise<void>): Promise<void> {
let err; let err;
@ -26,23 +28,55 @@ export default abstract class ApplicationComponent<T> {
err = null; err = null;
} catch (e) { } catch (e) {
err = e; err = e;
Logger.error(err, `${name} failed to prepare; retrying in 5s...`) logger.error(err, `${name} failed to prepare; retrying in 5s...`);
await sleep(5000); await sleep(5000);
} }
} while (err); } while (err);
Logger.info(`${name} ready!`); logger.info(`${name} ready!`);
} }
protected async close(thingName: string, thing: any, fn: Function): Promise<void> { protected async close(thingName: string, fn: (callback: (err?: Error | null) => void) => void): Promise<void> {
try { try {
await new Promise((resolve, reject) => fn.call(thing, (err: any) => { await new Promise<void>((resolve, reject) => fn((err?: Error | null) => {
if (err) reject(err); if (err) reject(err);
else resolve(); else resolve();
})); }));
Logger.info(`${thingName} closed.`); logger.info(`${thingName} closed.`);
} catch (e) { } catch (e) {
Logger.error(e, `An error occurred while closing the ${thingName}.`); logger.error(e, `An error occurred while closing the ${thingName}.`);
} }
} }
protected checkSecurityConfigField(field: string): void {
if (!config.has(field) || config.get<string>(field) === 'default') {
throw new SecurityError(`${field} field not configured.`);
}
}
protected use<M extends Middleware>(middleware: MiddlewareType<M>): void {
if (!this.currentRouter) throw new Error('Cannot call this method outside init() and handle().');
const instance = new middleware(this.getApp());
this.currentRouter.use(async (req, res, next) => {
try {
await instance.getRequestHandler()(req, res, next);
} catch (e) {
next(e);
}
});
}
public setCurrentRouter(router: Router | null): void {
this.currentRouter = router || undefined;
}
protected getApp(): Application {
if (!this.app) throw new Error('app field not initialized.');
return this.app;
}
public setApp(app: Application): void {
this.app = app;
}
} }

14
src/CacheProvider.ts Normal file
View File

@ -0,0 +1,14 @@
export default interface CacheProvider {
get<T extends string | undefined>(key: string, defaultValue?: T): Promise<T>;
has(key: string): Promise<boolean>;
forget(key: string): Promise<void>;
/**
* @param key
* @param value
* @param ttl in ms
*/
remember(key: string, value: string, ttl: number): Promise<void>;
}

View File

@ -1,50 +1,61 @@
import {RequestHandler, Router} from "express"; import express, {IRouter, RequestHandler, Router} from "express";
import {PathParams} from "express-serve-static-core"; import {PathParams} from "express-serve-static-core";
import config from "config"; import config from "config";
import Logger from "./Logger"; import {logger} from "./Logger";
import FileUploadMiddleware from "./FileUploadMiddleware";
import * as querystring from "querystring";
import {ParsedUrlQueryInput} from "querystring";
import Middleware, {MiddlewareType} from "./Middleware";
import Application from "./Application";
export default abstract class Controller { export default abstract class Controller {
private static readonly routes: { [p: string]: string } = {}; private static readonly routes: { [p: string]: string | undefined } = {};
public static route(route: string, params: RouteParams = [], absolute: boolean = false): string { public static route(
route: string,
params: RouteParams = [],
query: ParsedUrlQueryInput = {},
absolute: boolean = false,
): string {
let path = this.routes[route]; let path = this.routes[route];
if (path === undefined) throw new Error(`Unknown route for name ${route}.`); if (path === undefined) throw new Error(`Unknown route for name ${route}.`);
if (typeof params === 'string' || typeof params === 'number') { if (typeof params === 'string' || typeof params === 'number') {
path = path.replace(/:[a-zA-Z_-]+\??/, '' + params); path = path.replace(/:[a-zA-Z_-]+\??/g, '' + params);
} else if (Array.isArray(params)) { } else if (Array.isArray(params)) {
let i = 0; let i = 0;
for (const match of path.matchAll(/:[a-zA-Z_-]+\??/)) { for (const match of path.matchAll(/:[a-zA-Z_-]+(\(.*\))?\??/g)) {
if (match.length > 0) { if (match.length > 0) {
path = path.replace(match[0], typeof params[i] !== 'undefined' ? params[i] : ''); path = path.replace(match[0], typeof params[i] !== 'undefined' ? params[i] : '');
} }
i++; i++;
} }
path = path.replace(/\/+/, '/'); path = path.replace(/\/+/g, '/');
} else { } else {
for (const key in params) { for (const key of Object.keys(params)) {
if (params.hasOwnProperty(key)) { path = path.replace(new RegExp(`:${key}\\??`), params[key]);
path = path.replace(new RegExp(`:${key}\\??`), params[key]);
}
} }
} }
return `${absolute ? config.get<string>('public_url') : ''}${path}`; const queryStr = querystring.stringify(query);
return `${absolute ? config.get<string>('public_url') : ''}${path}` + (queryStr.length > 0 ? '?' + queryStr : '');
} }
private router?: Router; private readonly router: Router = express.Router();
private readonly fileUploadFormRouter: Router = express.Router();
private app?: Application;
public getGlobalHandlers(): RequestHandler[] { public getGlobalMiddlewares(): Middleware[] {
return []; return [];
} }
public hasGlobalHandlers(): boolean { public hasGlobalMiddlewares(): boolean {
return this.getGlobalHandlers().length > 0; return this.getGlobalMiddlewares().length > 0;
} }
public setupGlobalHandlers(router: Router): void { public setupGlobalHandlers(router: Router): void {
for (const globalHandler of this.getGlobalHandlers()) { for (const middleware of this.getGlobalMiddlewares()) {
router.use(this.wrap(globalHandler)); router.use(this.wrap(middleware.getRequestHandler()));
} }
} }
@ -54,36 +65,93 @@ export default abstract class Controller {
public abstract routes(): void; public abstract routes(): void;
public setupRoutes(router: Router): void { public setupRoutes(): { mainRouter: Router, fileUploadFormRouter: Router } {
this.router = router;
this.routes(); this.routes();
return {
mainRouter: this.router,
fileUploadFormRouter: this.fileUploadFormRouter,
};
} }
protected use(handler: RequestHandler) { protected use(handler: RequestHandler): void {
this.router?.use(this.wrap(handler)); this.router.use(this.wrap(handler));
logger.info('Installed anonymous middleware on ' + this.getRoutesPrefix());
} }
protected get(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: RequestHandler[]) { protected useMiddleware(...middlewares: MiddlewareType<Middleware>[]): void {
for (const middleware of middlewares) {
const instance = new middleware(this.getApp());
if (instance instanceof FileUploadMiddleware) {
this.fileUploadFormRouter.use(this.wrap(instance.getRequestHandler()));
} else {
this.router.use(this.wrap(instance.getRequestHandler()));
}
logger.info('Installed ' + middleware.name + ' on ' + this.getRoutesPrefix());
}
}
protected get(
path: PathParams,
handler: RequestHandler,
routeName?: string,
...middlewares: (MiddlewareType<Middleware>)[]
): void {
this.handle('get', path, handler, routeName, ...middlewares);
}
protected post(
path: PathParams,
handler: RequestHandler,
routeName?: string,
...middlewares: (MiddlewareType<Middleware>)[]
): void {
this.handle('post', path, handler, routeName, ...middlewares);
}
protected put(
path: PathParams,
handler: RequestHandler,
routeName?: string,
...middlewares: (MiddlewareType<Middleware>)[]
): void {
this.handle('put', path, handler, routeName, ...middlewares);
}
protected delete(
path: PathParams,
handler: RequestHandler,
routeName?: string,
...middlewares: (MiddlewareType<Middleware>)[]
): void {
this.handle('delete', path, handler, routeName, ...middlewares);
}
private handle(
action: Exclude<keyof IRouter, 'stack' | 'param' | 'route' | 'use'>,
path: PathParams,
handler: RequestHandler,
routeName?: string,
...middlewares: (MiddlewareType<Middleware>)[]
): void {
this.registerRoutes(path, handler, routeName); this.registerRoutes(path, handler, routeName);
for (const middleware of middlewares) { for (const middleware of middlewares) {
this.router?.get(path, this.wrap(middleware)); const instance = new middleware(this.getApp());
if (instance instanceof FileUploadMiddleware) {
this.fileUploadFormRouter[action](path, this.wrap(instance.getRequestHandler()));
} else {
this.router[action](path, this.wrap(instance.getRequestHandler()));
}
} }
this.router?.get(path, this.wrap(handler)); this.router[action](path, this.wrap(handler));
}
protected post(path: PathParams, handler: RequestHandler, routeName?: string, ...middlewares: RequestHandler[]) {
this.registerRoutes(path, handler, routeName);
for (const middleware of middlewares) {
this.router?.post(path, this.wrap(middleware));
}
this.router?.post(path, this.wrap(handler));
} }
private wrap(handler: RequestHandler): RequestHandler { private wrap(handler: RequestHandler): RequestHandler {
return (req, res, next) => { return async (req, res, next) => {
const promise = handler.call(this, req, res, next); try {
if (promise instanceof Promise) { await handler.call(this, req, res, next);
promise.catch(err => next(err)); } catch (e) {
next(e);
} }
}; };
} }
@ -108,13 +176,22 @@ export default abstract class Controller {
if (!Controller.routes[routeName]) { if (!Controller.routes[routeName]) {
if (typeof routePath === 'string') { if (typeof routePath === 'string') {
Logger.info(`Route ${routeName} has path ${routePath}`); logger.info(`Route ${routeName} has path ${routePath}`);
Controller.routes[routeName] = routePath; Controller.routes[routeName] = routePath;
} else { } else {
Logger.warn(`Cannot assign path to route ${routeName}.`); logger.warn(`Cannot assign path to route ${routeName}.`);
} }
} }
} }
protected getApp(): Application {
if (!this.app) throw new Error('Application not initialized.');
return this.app;
}
public setApp(app: Application): void {
this.app = app;
}
} }
export type RouteParams = { [p: string]: string } | string[] | string | number; export type RouteParams = { [p: string]: string } | string[] | string | number;

7
src/Extendable.ts Normal file
View File

@ -0,0 +1,7 @@
import {Type} from "./Utils";
export default interface Extendable<ComponentClass> {
as<C extends ComponentClass>(type: Type<C>): C;
asOptional<C extends ComponentClass>(type: Type<C>): C | null;
}

View File

@ -0,0 +1,35 @@
import {IncomingForm} from "formidable";
import Middleware from "./Middleware";
import {NextFunction, Request, Response} from "express";
import {FileError, ValidationBag} from "./db/Validator";
export default abstract class FileUploadMiddleware extends Middleware {
protected abstract makeForm(): IncomingForm;
protected abstract getDefaultField(): string;
public async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const form = this.makeForm();
try {
await new Promise<void>((resolve, reject) => {
form.parse(req, (err, fields, files) => {
if (err) {
reject(err);
return;
}
req.body = fields;
req.files = files;
resolve();
});
});
} catch (e) {
const bag = new ValidationBag();
const fileError = new FileError(e);
fileError.thingName = this.getDefaultField();
bag.addMessage(fileError);
next(bag);
return;
}
next();
}
}

View File

@ -3,12 +3,12 @@ import {WrappingError} from "./Utils";
export abstract class HttpError extends WrappingError { export abstract class HttpError extends WrappingError {
public readonly instructions: string; public readonly instructions: string;
constructor(message: string, instructions: string, cause?: Error) { protected constructor(message: string, instructions: string, cause?: Error) {
super(message, cause); super(message, cause);
this.instructions = instructions; this.instructions = instructions;
} }
get name(): string { public get name(): string {
return this.constructor.name; return this.constructor.name;
} }
@ -18,62 +18,83 @@ export abstract class HttpError extends WrappingError {
export class BadRequestError extends HttpError { export class BadRequestError extends HttpError {
public readonly url: string; public readonly url: string;
constructor(message: string, instructions: string, url: string, cause?: Error) { public constructor(message: string, instructions: string, url: string, cause?: Error) {
super(message, instructions, cause); super(message, instructions, cause);
this.url = url; this.url = url;
} }
get errorCode(): number { public get errorCode(): number {
return 400; return 400;
} }
} }
export class UnauthorizedHttpError extends BadRequestError {
public constructor(message: string, url: string, cause?: Error) {
super(message, '', url, cause);
}
public get errorCode(): number {
return 401;
}
}
export class ForbiddenHttpError extends BadRequestError { export class ForbiddenHttpError extends BadRequestError {
constructor(thing: string, url: string, cause?: Error) { public constructor(thing: string, url: string, cause?: Error) {
super( super(
`You don't have access to this ${thing}.`, `You don't have access to this ${thing}.`,
`${url} doesn't belong to *you*.`, `${url} doesn't belong to *you*.`,
url, url,
cause cause,
); );
} }
get errorCode(): number { public get errorCode(): number {
return 403; return 403;
} }
} }
export class NotFoundHttpError extends BadRequestError { export class NotFoundHttpError extends BadRequestError {
constructor(thing: string, url: string, cause?: Error) { public constructor(thing: string, url: string, cause?: Error) {
super( super(
`${thing.charAt(0).toUpperCase()}${thing.substr(1)} not found.`, `${thing.charAt(0).toUpperCase()}${thing.substr(1)} not found.`,
`${url} doesn't exist or was deleted.`, `${url} doesn't exist or was deleted.`,
url, url,
cause cause,
); );
} }
get errorCode(): number { public get errorCode(): number {
return 404; return 404;
} }
} }
export class TooManyRequestsHttpError extends BadRequestError {
public constructor(retryIn: number, jailName: string, cause?: Error) {
super(
`You're making too many requests!`,
`We need some rest. Please retry in ${Math.floor(retryIn / 1000)} seconds.`,
jailName,
cause,
);
}
public get errorCode(): number {
return 429;
}
}
export class ServerError extends HttpError { export class ServerError extends HttpError {
constructor(message: string, cause?: Error) { public constructor(message: string, cause?: Error) {
super(message, `Maybe you should contact us; see instructions below.`, cause); super(message, `Maybe you should contact us; see instructions below.`, cause);
} }
get errorCode(): number { public get errorCode(): number {
return 500; return 500;
} }
} }
export class ServiceUnavailableHttpError extends ServerError { export class ServiceUnavailableHttpError extends ServerError {
constructor(message: string, cause?: Error) { public get errorCode(): number {
super(message, cause);
}
get errorCode(): number {
return 503; return 503;
} }
} }

View File

@ -1,121 +1,40 @@
import config from "config"; import {Logger as TsLogger} from "tslog";
import {v4 as uuid} from "uuid"; import {AsyncLocalStorage} from "async_hooks";
import Log from "./models/Log"; import {RequestHandler} from "express";
import {nanoid} from "nanoid";
const LOG_LEVEL: LogLevelKeys = <LogLevelKeys>config.get<string>('log_level'); const requestIdStorage: AsyncLocalStorage<string> = new AsyncLocalStorage();
const DB_LOG_LEVEL: LogLevelKeys = <LogLevelKeys>config.get<string>('db_log_level');
export default class Logger { export const logger = new TsLogger({
public static silentError(error: Error, ...message: any[]): string { requestId: (): string => {
return this.log('ERROR', message, error, true) || ''; return requestIdStorage.getStore() as string;
} },
delimiter: '\t',
maskValuesOfKeys: [
'Authorization',
'password',
'password_confirmation',
'secret',
],
displayFunctionName: false,
displayFilePath: 'hidden',
});
public static error(error: Error, ...message: any[]): string { export const loggingContextMiddleware: RequestHandler = (req, res, next) => {
return this.log('ERROR', message, error) || ''; requestIdStorage.run(nanoid(8), () => {
} next();
});
};
public static warn(...message: any[]) { export const preventContextCorruptionMiddleware = (delegate: RequestHandler): RequestHandler => (
this.log('WARN', message); req,
} res,
next,
) => {
const data = requestIdStorage.getStore() as string;
public static info(...message: any[]) { delegate(req, res, (err?: Error | 'router') => {
this.log('INFO', message); requestIdStorage.enterWith(data);
} next(err);
});
public static debug(...message: any[]) { };
this.log('DEBUG', message);
}
public static dev(...message: any[]) {
this.log('DEV', message);
}
private static log(level: LogLevelKeys, message: any[], error?: Error, silent: boolean = false): string | null {
const levelIndex = LogLevel[level];
if (levelIndex <= LogLevel[LOG_LEVEL]) {
if (error) {
if (levelIndex > LogLevel.ERROR) this.warn(`Wrong log level ${level} with attached error.`);
} else {
if (levelIndex <= LogLevel.ERROR) this.warn(`No error attached with log level ${level}.`);
}
const computedMsg = message.map(v => {
if (typeof v === 'string') {
return v;
} else {
return JSON.stringify(v, (key: string, value: any) => {
if (value instanceof Object) {
if (value.type === 'Buffer') {
return `Buffer<${Buffer.from(value.data).toString('hex')}>`;
} else if (value !== v) {
return `[object Object]`;
}
}
if (typeof value === 'string' && value.length > 96) {
return value.substr(0, 96) + '...';
}
return value;
}, 4);
}
}).join(' ');
const log = new Log({});
log.setLevel(level);
log.message = computedMsg;
log.setError(error);
let logID = Buffer.alloc(16);
uuid({}, logID);
log.setLogID(logID);
let output = `[${level}] `;
let pad = output.length;
if (levelIndex <= LogLevel[DB_LOG_LEVEL]) output += `${log.getLogID()} - `;
output += computedMsg.replace(/\n/g, '\n' + ' '.repeat(pad));
switch (level) {
case "ERROR":
if (silent || !error) {
console.error(output);
} else {
console.error(output, error);
}
break;
case "WARN":
console.warn(output);
break;
case "INFO":
console.info(output);
break;
case "DEBUG":
case "DEV":
console.debug(output);
break;
}
if (levelIndex <= LogLevel[DB_LOG_LEVEL]) {
log.save().catch(err => {
if (!silent && err.message.indexOf('ECONNREFUSED') < 0) {
console.error({save_err: err, error});
}
});
}
return log.getLogID();
}
return null;
}
private constructor() {
}
}
export enum LogLevel {
ERROR,
WARN,
INFO,
DEBUG,
DEV,
}
export type LogLevelKeys = keyof typeof LogLevel;

24
src/Mails.ts Normal file
View File

@ -0,0 +1,24 @@
import config from "config";
import {MailTemplate} from "./mail/Mail";
export const MAGIC_LINK_MAIL = new MailTemplate(
'magic_link',
data => data.type === 'register' ?
'Registration' :
'Login magic link',
);
export const ACCOUNT_REVIEW_NOTICE_MAIL_TEMPLATE: MailTemplate = new MailTemplate(
'account_review_notice',
data => `Your account was ${data.approved ? 'approved' : 'rejected'}.`,
);
export const PENDING_ACCOUNT_REVIEW_MAIL_TEMPLATE: MailTemplate = new MailTemplate(
'pending_account_review',
() => `A new account is pending review on ${config.get<string>('app.name')}`,
);
export const ADD_EMAIL_MAIL_TEMPLATE: MailTemplate = new MailTemplate(
'add_email',
(data) => `Add ${data.email} address to your ${config.get<string>('app.name')} account.`,
);

33
src/Middleware.ts Normal file
View File

@ -0,0 +1,33 @@
import {RequestHandler} from "express";
import {NextFunction, Request, Response} from "express-serve-static-core";
import Application from "./Application";
import {Type} from "./Utils";
export default abstract class Middleware {
public constructor(
protected readonly app: Application,
) {
}
protected abstract handle(req: Request, res: Response, next: NextFunction): Promise<void>;
public getRequestHandler(): RequestHandler {
return async (req, res, next): Promise<void> => {
try {
if (req.middlewares.find(m => m.constructor === this.constructor)) {
next();
} else {
req.middlewares.push(this);
return await this.handle(req, res, next);
}
} catch (e) {
next(e);
}
};
}
}
export interface MiddlewareType<M extends Middleware> extends Type<M> {
new(app: Application): M;
}

View File

@ -6,7 +6,7 @@ export default class Pagination<T extends Model> {
public readonly perPage: number; public readonly perPage: number;
public readonly totalCount: number; public readonly totalCount: number;
constructor(models: T[], page: number, perPage: number, totalCount: number) { public constructor(models: T[], page: number, perPage: number, totalCount: number) {
this.models = models; this.models = models;
this.page = page; this.page = page;
this.perPage = perPage; this.perPage = perPage;

8
src/SecurityError.ts Normal file
View File

@ -0,0 +1,8 @@
export default class SecurityError implements Error {
public readonly name: string = 'SecurityError';
public readonly message: string;
public constructor(message: string) {
this.message = message;
}
}

122
src/TestApp.ts Normal file
View File

@ -0,0 +1,122 @@
import Application from "../src/Application";
import Migration, {MigrationType} from "../src/db/Migration";
import ExpressAppComponent from "../src/components/ExpressAppComponent";
import RedisComponent from "../src/components/RedisComponent";
import MysqlComponent from "../src/components/MysqlComponent";
import NunjucksComponent from "../src/components/NunjucksComponent";
import LogRequestsComponent from "../src/components/LogRequestsComponent";
import MailComponent from "../src/components/MailComponent";
import SessionComponent from "../src/components/SessionComponent";
import AuthComponent from "../src/auth/AuthComponent";
import FormHelperComponent from "../src/components/FormHelperComponent";
import ServeStaticDirectoryComponent from "../src/components/ServeStaticDirectoryComponent";
import {Express} from "express";
import MagicLinkAuthMethod from "../src/auth/magic_link/MagicLinkAuthMethod";
import PasswordAuthMethod from "../src/auth/password/PasswordAuthMethod";
import {MAGIC_LINK_MAIL} from "./Mails";
import CreateMigrationsTable from "./migrations/CreateMigrationsTable";
import CreateUsersAndUserEmailsTableMigration from "./auth/migrations/CreateUsersAndUserEmailsTableMigration";
import CreateMagicLinksTableMigration from "./auth/magic_link/CreateMagicLinksTableMigration";
import AuthController from "./auth/AuthController";
import MagicLinkWebSocketListener from "./auth/magic_link/MagicLinkWebSocketListener";
import MagicLinkController from "./auth/magic_link/MagicLinkController";
import AddPasswordToUsersMigration from "./auth/password/AddPasswordToUsersMigration";
import AddNameToUsersMigration from "./auth/migrations/AddNameToUsersMigration";
import CsrfProtectionComponent from "./components/CsrfProtectionComponent";
import MailController from "./mail/MailController";
import WebSocketServerComponent from "./components/WebSocketServerComponent";
import Controller from "./Controller";
import AccountController from "./auth/AccountController";
import MakeMagicLinksSessionNotUniqueMigration from "./auth/magic_link/MakeMagicLinksSessionNotUniqueMigration";
import AddUsedToMagicLinksMigration from "./auth/magic_link/AddUsedToMagicLinksMigration";
import packageJson = require('../package.json');
import PreviousUrlComponent from "./components/PreviousUrlComponent";
export const MIGRATIONS = [
CreateMigrationsTable,
CreateUsersAndUserEmailsTableMigration,
AddPasswordToUsersMigration,
CreateMagicLinksTableMigration,
AddNameToUsersMigration,
MakeMagicLinksSessionNotUniqueMigration,
AddUsedToMagicLinksMigration,
];
export default class TestApp extends Application {
private readonly addr: string;
private readonly port: number;
public constructor(addr: string, port: number) {
super(packageJson.version, true);
this.addr = addr;
this.port = port;
}
protected getMigrations(): MigrationType<Migration>[] {
return MIGRATIONS;
}
protected async init(): Promise<void> {
this.registerComponents();
this.registerWebSocketListeners();
this.registerControllers();
}
protected registerComponents(): void {
// Base
this.use(new ExpressAppComponent(this.addr, this.port));
this.use(new LogRequestsComponent());
// Static files
this.use(new ServeStaticDirectoryComponent('public'));
// Dynamic views and routes
this.use(new NunjucksComponent(['test/views', 'views']));
this.use(new PreviousUrlComponent());
// Services
this.use(new MysqlComponent());
this.use(new MailComponent());
// Session
this.use(new RedisComponent());
this.use(new SessionComponent(this.as(RedisComponent)));
// Utils
this.use(new FormHelperComponent());
// Middlewares
this.use(new CsrfProtectionComponent());
// Auth
this.use(new AuthComponent(this, new MagicLinkAuthMethod(this, MAGIC_LINK_MAIL), new PasswordAuthMethod(this)));
// WebSocket server
this.use(new WebSocketServerComponent(this, this.as(ExpressAppComponent), this.as(RedisComponent)));
}
protected registerWebSocketListeners(): void {
this.use(new MagicLinkWebSocketListener());
}
protected registerControllers(): void {
this.use(new MailController());
this.use(new AuthController());
this.use(new AccountController());
this.use(new MagicLinkController(this.as<MagicLinkWebSocketListener<this>>(MagicLinkWebSocketListener)));
// Special home controller
this.use(new class extends Controller {
public routes(): void {
this.get('/', (req, res) => {
res.render('home');
}, 'home');
}
}());
}
public getExpressApp(): Express {
return this.as(ExpressAppComponent).getExpressApp();
}
}

93
src/Throttler.ts Normal file
View File

@ -0,0 +1,93 @@
import {TooManyRequestsHttpError} from "./HttpError";
import {logger} from "./Logger";
export default class Throttler {
private static readonly throttles: Record<string, Throttle | undefined> = {};
/**
* Throttle function; will throw a TooManyRequestsHttpError when the threshold is reached.
*
* This throttle is adaptive: it will slowly decrease (linear) until it reaches 0 after {@param resetPeriod} ms.
* Threshold will hold for {@param holdPeriod} ms.
*
* @param action a unique action name (can be used multiple times, but it'll account for a single action).
* @param max how many times this action can be triggered per id.
* @param resetPeriod after how much time in ms the throttle will reach 0.
* @param id an identifier of who triggered the action.
* @param holdPeriod time in ms after each call before the threshold begins to decrease.
* @param jailPeriod time in ms for which the throttle will throw when it is triggered.
*/
public static throttle(
action: string,
max: number,
resetPeriod: number,
id: string,
holdPeriod: number = 100,
jailPeriod: number = 30 * 1000,
): void {
let throttle = this.throttles[action];
if (!throttle)
throttle = this.throttles[action] = new Throttle(action, max, resetPeriod, holdPeriod, jailPeriod);
throttle.trigger(id);
}
private constructor() {
// Disable constructor
}
}
class Throttle {
private readonly jailName: string;
private readonly max: number;
private readonly resetPeriod: number;
private readonly holdPeriod: number;
private readonly jailPeriod: number;
private readonly triggers: Record<string, {
count: number,
lastTrigger?: number,
jailed?: number;
} | undefined> = {};
public constructor(jailName: string, max: number, resetPeriod: number, holdPeriod: number, jailPeriod: number) {
this.jailName = jailName;
this.max = max;
this.resetPeriod = resetPeriod;
this.holdPeriod = holdPeriod;
this.jailPeriod = jailPeriod;
}
public trigger(id: string) {
let trigger = this.triggers[id];
if (!trigger) trigger = this.triggers[id] = {count: 0};
const currentDate = new Date().getTime();
if (trigger.jailed && currentDate - trigger.jailed < this.jailPeriod)
return this.throw(trigger.jailed + this.jailPeriod - currentDate);
if (trigger.lastTrigger) {
let timeDiff = currentDate - trigger.lastTrigger;
if (timeDiff > this.holdPeriod) {
timeDiff -= this.holdPeriod;
trigger.count = Math.floor(Math.min(trigger.count, (this.max + 1) * (1 - timeDiff / this.resetPeriod)));
}
}
trigger.count++;
trigger.lastTrigger = currentDate;
if (trigger.count > this.max) {
trigger.jailed = currentDate;
const unjailedIn = trigger.jailed + this.jailPeriod - currentDate;
logger.info(`Jail ${this.jailName} triggered by ${id} and will be unjailed in ${unjailedIn}ms.`);
return this.throw(unjailedIn);
}
}
protected throw(unjailedIn: number) {
throw new TooManyRequestsHttpError(unjailedIn, this.jailName);
}
}

View File

@ -1,5 +1,3 @@
import * as crypto from "crypto";
export async function sleep(ms: number): Promise<void> { export async function sleep(ms: number): Promise<void> {
return await new Promise(resolve => { return await new Promise(resolve => {
setTimeout(() => resolve(), ms); setTimeout(() => resolve(), ms);
@ -18,22 +16,31 @@ export abstract class WrappingError extends Error {
} }
} }
get name(): string { public get name(): string {
return this.constructor.name; return this.constructor.name;
} }
} }
export function cryptoRandomDictionary(size: number, dictionary: string): string { export type Type<T> = { new(...args: never[]): T };
const randomBytes = crypto.randomBytes(size);
const output = new Array(size);
for (let i = 0; i < size; i++) { export function bufferToUuid(buffer: Buffer): string {
output[i] = dictionary[Math.floor((randomBytes[i] / 255) * dictionary.length)]; const chars = buffer.toString('hex');
let out = '';
let i = 0;
for (const l of [8, 4, 4, 4, 12]) {
if (i > 0) out += '-';
out += chars.substr(i, l);
i += l;
} }
return out;
return output.join('');
} }
export interface Type<T> extends Function { export function getMethods<T extends { [p: string]: unknown }>(obj: T): string[] {
new(...args: any[]): T const properties = new Set<string>();
let currentObj: T | unknown = obj;
do {
Object.getOwnPropertyNames(currentObj).map(item => properties.add(item));
currentObj = Object.getPrototypeOf(currentObj);
} while (currentObj);
return [...properties.keys()].filter(item => typeof obj[item] === 'function');
} }

View File

@ -1,8 +1,24 @@
import WebSocket from "ws"; import WebSocket from "ws";
import {IncomingMessage} from "http"; import {IncomingMessage} from "http";
import Application from "./Application";
import {Session} from "express-session";
export default abstract class WebSocketListener<T extends Application> {
private app!: T;
public init(app: T): void {
this.app = app;
}
protected getApp(): T {
return this.app;
}
export default abstract class WebSocketListener {
public abstract path(): string; public abstract path(): string;
public abstract async handle(socket: WebSocket, request: IncomingMessage, session: Express.SessionData): Promise<void>; public abstract handle(
socket: WebSocket,
request: IncomingMessage,
session: Session | null,
): Promise<void>;
} }

View File

@ -0,0 +1,145 @@
import Controller from "../Controller";
import {RequireAuthMiddleware} from "./AuthComponent";
import {Request, Response} from "express";
import {BadRequestError, ForbiddenHttpError, NotFoundHttpError} from "../HttpError";
import config from "config";
import Validator, {EMAIL_REGEX, InvalidFormatValidationError} from "../db/Validator";
import UserPasswordComponent from "./password/UserPasswordComponent";
import User from "./models/User";
import ModelFactory from "../db/ModelFactory";
import UserEmail from "./models/UserEmail";
import MagicLinkController from "./magic_link/MagicLinkController";
import {MailTemplate} from "../mail/Mail";
import {ADD_EMAIL_MAIL_TEMPLATE} from "../Mails";
import AuthMagicLinkActionType from "./magic_link/AuthMagicLinkActionType";
export default class AccountController extends Controller {
private readonly addEmailMailTemplate: MailTemplate;
public constructor(addEmailMailTemplate: MailTemplate = ADD_EMAIL_MAIL_TEMPLATE) {
super();
this.addEmailMailTemplate = addEmailMailTemplate;
}
public getRoutesPrefix(): string {
return '/account';
}
public routes(): void {
this.get('/', this.getAccount, 'account', RequireAuthMiddleware);
if (ModelFactory.get(User).hasComponent(UserPasswordComponent)) {
this.post('/change-password', this.postChangePassword, 'change-password', RequireAuthMiddleware);
}
this.post('/add-email', this.addEmail, 'add-email', RequireAuthMiddleware);
this.post('/set-main-email', this.postSetMainEmail, 'set-main-email', RequireAuthMiddleware);
this.post('/remove-email', this.postRemoveEmail, 'remove-email', RequireAuthMiddleware);
}
protected async getAccount(req: Request, res: Response): Promise<void> {
const user = req.as(RequireAuthMiddleware).getUser();
res.render('auth/account', {
main_email: await user.mainEmail.get(),
emails: await user.emails.get(),
display_email_warning: config.get('app.display_email_warning'),
has_password: user.asOptional(UserPasswordComponent)?.hasPassword(),
});
}
protected async postChangePassword(req: Request, res: Response): Promise<void> {
const validationMap = {
'new_password': new Validator().defined(),
'new_password_confirmation': new Validator().sameAs('new_password', req.body.new_password),
};
await Validator.validate(validationMap, req.body);
const user = req.as(RequireAuthMiddleware).getUser();
const passwordComponent = user.as(UserPasswordComponent);
if (passwordComponent.hasPassword() && !await passwordComponent.verifyPassword(req.body.current_password)) {
req.flash('error', 'Invalid current password.');
res.redirect(Controller.route('account'));
return;
}
await passwordComponent.setPassword(req.body.new_password, 'new_password');
await user.save();
req.flash('success', 'Password changed successfully.');
res.redirect(Controller.route('account'));
}
protected async addEmail(req: Request, res: Response): Promise<void> {
await Validator.validate({
email: new Validator().defined().regexp(EMAIL_REGEX),
}, req.body);
const email = req.body.email;
// Existing email
if (await UserEmail.select().where('email', email).first()) {
const error = new InvalidFormatValidationError('You already have this email.');
error.thingName = 'email';
throw error;
}
await MagicLinkController.sendMagicLink(
this.getApp(),
req.getSession().id,
AuthMagicLinkActionType.ADD_EMAIL,
Controller.route('account'),
email,
this.addEmailMailTemplate,
{
email: email,
},
);
res.redirect(Controller.route('magic_link_lobby', undefined, {
redirect_uri: Controller.route('account'),
}));
}
protected async postSetMainEmail(req: Request, res: Response): Promise<void> {
if (!req.body.id)
throw new BadRequestError('Missing id field', 'Check form parameters.', req.url);
const user = req.as(RequireAuthMiddleware).getUser();
const userEmail = await UserEmail.getById(req.body.id);
if (!userEmail)
throw new NotFoundHttpError('email', req.url);
if (userEmail.user_id !== user.id)
throw new ForbiddenHttpError('email', req.url);
if (userEmail.id === user.main_email_id)
throw new BadRequestError('This address is already your main address',
'Try refreshing the account page.', req.url);
user.main_email_id = userEmail.id;
await user.save();
req.flash('success', 'This email was successfully set as your main address.');
res.redirect(Controller.route('account'));
}
protected async postRemoveEmail(req: Request, res: Response): Promise<void> {
if (!req.body.id)
throw new BadRequestError('Missing id field', 'Check form parameters.', req.url);
const user = req.as(RequireAuthMiddleware).getUser();
const userEmail = await UserEmail.getById(req.body.id);
if (!userEmail)
throw new NotFoundHttpError('email', req.url);
if (userEmail.user_id !== user.id)
throw new ForbiddenHttpError('email', req.url);
if (userEmail.id === user.main_email_id)
throw new BadRequestError('Cannot remove main email address', 'Try refreshing the account page.', req.url);
await userEmail.delete();
req.flash('success', 'This email was successfully removed from your account.');
res.redirect(Controller.route('account'));
}
}

136
src/auth/AuthComponent.ts Normal file
View File

@ -0,0 +1,136 @@
import ApplicationComponent from "../ApplicationComponent";
import {NextFunction, Request, Response} from "express";
import AuthGuard from "./AuthGuard";
import Controller from "../Controller";
import {ForbiddenHttpError} from "../HttpError";
import Middleware from "../Middleware";
import User from "./models/User";
import Application from "../Application";
import AuthMethod from "./AuthMethod";
import AuthProof from "./AuthProof";
export default class AuthComponent extends ApplicationComponent {
private readonly authGuard: AuthGuard;
public constructor(app: Application, ...authMethods: AuthMethod<AuthProof<User>>[]) {
super();
this.authGuard = new AuthGuard(app, ...authMethods);
}
public async init(): Promise<void> {
this.use(AuthMiddleware);
}
public getAuthGuard(): AuthGuard {
return this.authGuard;
}
}
export class AuthMiddleware extends Middleware {
private authGuard?: AuthGuard;
private user: User | null = null;
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
this.authGuard = this.app.as(AuthComponent).getAuthGuard();
const proofs = await this.authGuard.getProofsForSession(req.getSession());
if (proofs.length > 0) {
this.user = await proofs[0].getResource();
res.locals.user = this.user;
}
next();
}
public getUser(): User | null {
return this.user;
}
public getAuthGuard(): AuthGuard {
if (!this.authGuard) throw new Error('AuthGuard was not initialized.');
return this.authGuard;
}
}
export class RequireRequestAuthMiddleware extends Middleware {
private user?: User;
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const proofs = await req.as(AuthMiddleware).getAuthGuard().getProofsForRequest(req);
const user = await proofs[0]?.getResource();
if (user) {
this.user = user;
next();
return;
}
req.flash('error', `You must be logged in to access ${req.url}.`);
res.redirect(Controller.route('auth', undefined, {
redirect_uri: req.url,
}));
}
public getUser(): User {
if (!this.user) throw new Error('user not initialized.');
return this.user;
}
}
export class RequireAuthMiddleware extends Middleware {
private user?: User;
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const authGuard = req.as(AuthMiddleware).getAuthGuard();
// Via request
let proofs = await authGuard.getProofsForRequest(req);
let user = await proofs[0]?.getResource();
if (user) {
this.user = user;
next();
return;
}
// Via session
proofs = await authGuard.getProofsForSession(req.getSession());
user = await proofs[0]?.getResource();
if (user) {
this.user = user;
next();
return;
}
req.flash('error', `You must be logged in to access ${req.url}.`);
res.redirect(Controller.route('auth', undefined, {
redirect_uri: req.url,
}));
}
public getUser(): User {
if (!this.user) throw new Error('user not initialized.');
return this.user;
}
}
export class RequireGuestMiddleware extends Middleware {
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const proofs = await req.as(AuthMiddleware).getAuthGuard().getProofsForSession(req.getSession());
if (proofs.length > 0) {
res.redirect(Controller.route('home'));
return;
}
next();
}
}
export class RequireAdminMiddleware extends Middleware {
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
const user = req.as(AuthMiddleware).getUser();
if (!user || !user.is_admin) {
throw new ForbiddenHttpError('secret tool', req.url);
}
next();
}
}

129
src/auth/AuthController.ts Normal file
View File

@ -0,0 +1,129 @@
import Controller from "../Controller";
import {NextFunction, Request, Response} from "express";
import AuthComponent, {AuthMiddleware, RequireAuthMiddleware, RequireGuestMiddleware} from "./AuthComponent";
import {BadRequestError} from "../HttpError";
import ModelFactory from "../db/ModelFactory";
import User from "./models/User";
import UserPasswordComponent from "./password/UserPasswordComponent";
import UserNameComponent from "./models/UserNameComponent";
import {UnknownRelationValidationError} from "../db/Validator";
import AuthMethod from "./AuthMethod";
import AuthProof from "./AuthProof";
export default class AuthController extends Controller {
public getRoutesPrefix(): string {
return '/auth';
}
public routes(): void {
this.post('/logout', this.postLogout, 'logout', RequireAuthMiddleware);
this.use(async (req, res, next) => {
const authGuard = this.getApp().as(AuthComponent).getAuthGuard();
if (await authGuard.interruptAuth(req, res)) return;
next();
});
this.get('/', this.getAuth, 'auth', RequireGuestMiddleware);
this.post('/login', this.postLogin, 'login', RequireGuestMiddleware);
this.post('/register', this.postRegister, 'register', RequireGuestMiddleware);
}
protected async getAuth(req: Request, res: Response, _next: NextFunction): Promise<void> {
const authGuard = this.getApp().as(AuthComponent).getAuthGuard();
const userModelFactory = ModelFactory.get(User);
const hasUsername = userModelFactory.hasComponent(UserNameComponent);
res.render('auth/auth', {
auth_methods: authGuard.getAuthMethodNames(),
has_username: hasUsername,
register_with_password: hasUsername && userModelFactory.hasComponent(UserPasswordComponent),
});
}
protected async postLogin(req: Request, res: Response): Promise<void> {
return await this.handleAuth(req, res, false);
}
protected async postRegister(req: Request, res: Response): Promise<void> {
return await this.handleAuth(req, res, true);
}
protected async handleAuth(req: Request, res: Response, isRegistration: boolean): Promise<void> {
if (isRegistration && !req.body.auth_method) {
throw new BadRequestError('Cannot register without specifying desired auth_method.',
'Please specify auth_method.', req.url);
}
const authGuard = this.getApp().as(AuthComponent).getAuthGuard();
const identifier = req.body.identifier;
if (!identifier) throw new BadRequestError('Identifier not specified.', 'Please try again.', req.originalUrl);
// Get requested auth method
if (req.body.auth_method) {
const method = await authGuard.getAuthMethodByName(req.body.auth_method);
if (!method) {
throw new BadRequestError('Invalid auth method: ' + req.body.auth_method,
'Available methods are: ' + authGuard.getAuthMethodNames(), req.url);
}
// Register
if (isRegistration) return await method.attemptRegister(req, res, identifier);
const user = await method.findUserByIdentifier(identifier);
// Redirect to registration if user not found
if (!user) return await this.redirectToRegistration(req, res, identifier);
// Login
return await method.attemptLogin(req, res, user);
}
const methods = await authGuard.getAuthMethodsByIdentifier(identifier);
// Redirect to registration if user not found
if (methods.length === 0) return await this.redirectToRegistration(req, res, identifier);
// Choose best matching method
let user: User | null = null;
let method: AuthMethod<AuthProof<User>> | null = null;
let weight = -1;
for (const entry of methods) {
const methodWeight = entry.method.getWeightForRequest(req);
if (methodWeight > weight) {
user = entry.user;
method = entry.method;
weight = methodWeight;
}
}
if (!method || !user) ({method, user} = methods[0]); // Default to first method
// Login
return await method.attemptLogin(req, res, user);
}
protected async postLogout(req: Request, res: Response, _next: NextFunction): Promise<void> {
const userId = typeof req.body.user_id === 'string' ? parseInt(req.body.user_id) : null;
const proofs = await req.as(AuthMiddleware).getAuthGuard().getProofs(req);
for (const proof of proofs) {
if (userId === null || (await proof.getResource())?.id === userId) {
await proof.revoke();
}
}
req.flash('success', 'Successfully logged out.');
res.redirect(req.getIntendedUrl() || '/');
}
protected async redirectToRegistration(req: Request, res: Response, identifier: string): Promise<void> {
const error = new UnknownRelationValidationError(User.table, 'identifier');
error.thingName = 'identifier';
error.value = identifier;
throw error;
}
}

187
src/auth/AuthGuard.ts Normal file
View File

@ -0,0 +1,187 @@
import AuthProof from "./AuthProof";
import MysqlConnectionManager from "../db/MysqlConnectionManager";
import User from "./models/User";
import {Connection} from "mysql";
import {Request, Response} from "express";
import {PENDING_ACCOUNT_REVIEW_MAIL_TEMPLATE} from "../Mails";
import Mail from "../mail/Mail";
import Controller from "../Controller";
import config from "config";
import Application from "../Application";
import NunjucksComponent from "../components/NunjucksComponent";
import AuthMethod from "./AuthMethod";
import {Session, SessionData} from "express-session";
export default class AuthGuard {
private readonly authMethods: AuthMethod<AuthProof<User>>[];
public constructor(
private readonly app: Application,
...authMethods: AuthMethod<AuthProof<User>>[]
) {
this.authMethods = authMethods;
}
public async interruptAuth(req: Request, res: Response): Promise<boolean> {
for (const method of this.authMethods) {
if (method.interruptAuth && await method.interruptAuth(req, res)) return true;
}
return false;
}
public getAuthMethodByName(authMethodName: string): AuthMethod<AuthProof<User>> | null {
return this.authMethods.find(m => m.getName() === authMethodName) || null;
}
public getAuthMethodNames(): string[] {
return this.authMethods.map(m => m.getName());
}
public getRegistrationMethod(): AuthMethod<AuthProof<User>> {
return this.authMethods[0];
}
public async getAuthMethodsByIdentifier(
identifier: string,
): Promise<{ user: User, method: AuthMethod<AuthProof<User>> }[]> {
const methods = [];
for (const method of this.authMethods) {
const user = await method.findUserByIdentifier(identifier);
if (user) methods.push({user, method});
}
return methods;
}
public async getProofs(req: Request): Promise<AuthProof<User>[]> {
const proofs = [];
if (req.getSessionOptional()) {
proofs.push(...await this.getProofsForSession(req.session));
}
proofs.push(...await this.getProofsForRequest(req));
return proofs;
}
public async getProofsForSession(session: Session & Partial<SessionData>): Promise<AuthProof<User>[]> {
if (!session.isAuthenticated) return [];
const proofs = [];
for (const method of this.authMethods) {
if (method.getProofsForSession) {
const methodProofs = await method.getProofsForSession(session);
for (const proof of methodProofs) {
if (!await proof.isValid() || !await proof.isAuthorized()) {
await proof.revoke();
} else {
proofs.push(proof);
}
}
}
}
if (proofs.length === 0) {
session.isAuthenticated = false;
session.persistent = false;
}
return proofs;
}
public async getProofsForRequest(req: Request): Promise<AuthProof<User>[]> {
const proofs = [];
for (const method of this.authMethods) {
if (method.getProofsForRequest) {
const methodProofs = await method.getProofsForRequest(req);
for (const proof of methodProofs) {
if (!await proof.isValid() || !await proof.isAuthorized()) {
await proof.revoke();
} else {
proofs.push(proof);
}
}
}
}
return proofs;
}
public async authenticateOrRegister(
session: Session & Partial<SessionData>,
proof: AuthProof<User>,
persistSession: boolean,
onLogin?: (user: User) => Promise<void>,
beforeRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]>,
afterRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]>,
): Promise<User> {
if (!await proof.isValid()) throw new InvalidAuthProofError();
if (!await proof.isAuthorized()) throw new UnauthorizedAuthProofError();
let user = await proof.getResource();
// Register if user doesn't exist
if (!user) {
const callbacks: RegisterCallback[] = [];
user = await MysqlConnectionManager.wrapTransaction(async connection => {
const user = User.create({});
if (beforeRegister) {
(await beforeRegister(connection, user)).forEach(c => callbacks.push(c));
}
await user.save(connection, c => callbacks.push(c));
if (afterRegister) {
(await afterRegister(connection, user)).forEach(c => callbacks.push(c));
}
return user;
});
for (const callback of callbacks) {
await callback();
}
if (!user.isApproved()) {
await new Mail(this.app.as(NunjucksComponent).getEnvironment(), PENDING_ACCOUNT_REVIEW_MAIL_TEMPLATE, {
username: (await user.mainEmail.get())?.getOrFail('email'),
link: config.get<string>('public_url') + Controller.route('accounts-approval'),
}).send(config.get<string>('app.contact_email'));
}
}
// Don't login if user isn't approved
if (!user.isApproved()) {
throw new PendingApprovalAuthError();
}
// Login
session.isAuthenticated = true;
session.persistent = persistSession;
if (onLogin) await onLogin(user);
return user;
}
}
export class AuthError extends Error {
}
export class AuthProofError extends AuthError {
}
export class InvalidAuthProofError extends AuthProofError {
public constructor() {
super('Invalid auth proof.');
}
}
export class UnauthorizedAuthProofError extends AuthProofError {
public constructor() {
super('Unauthorized auth proof.');
}
}
export class PendingApprovalAuthError extends AuthError {
public constructor() {
super(`User is not approved.`);
}
}
export type RegisterCallback = () => Promise<void>;

35
src/auth/AuthMethod.ts Normal file
View File

@ -0,0 +1,35 @@
import User from "./models/User";
import AuthProof from "./AuthProof";
import {Request, Response} from "express";
import {Session} from "express-session";
export default interface AuthMethod<P extends AuthProof<User>> {
/**
* @return A unique name.
*/
getName(): string;
/**
* Used for automatic auth method detection. Won't affect forced auth method.
*
* @return {@code 0} if the request is not conform to this auth method, otherwise the exact count of matching
* fields.
*/
getWeightForRequest(req: Request): number;
findUserByIdentifier(identifier: string): Promise<User | null>;
getProofsForSession?(session: Session): Promise<P[]>;
getProofsForRequest?(req: Request): Promise<P[]>;
/**
* @return {@code true} if interrupted, {@code false} otherwise.
*/
interruptAuth?(req: Request, res: Response): Promise<boolean>;
attemptLogin(req: Request, res: Response, user: User): Promise<void>;
attemptRegister(req: Request, res: Response, identifier: string): Promise<void>;
}

41
src/auth/AuthProof.ts Normal file
View File

@ -0,0 +1,41 @@
/**
* This class is most commonly used for authentication. It can be more generically used to represent a verification
* state of whether a given resource is owned by a session.
*
* Any auth system should consider this auth proof valid if and only if both {@code isValid()} and
* {@code isAuthorized()} both return {@code true}.
*
* @type <R> The resource type this AuthProof authorizes.
*/
export default interface AuthProof<R> {
/**
* Is this auth proof valid in time (and context)?
*
* For example, it can return true for an initial short validity time period then false, and increase that time
* period if {@code isAuthorized()} returns true.
*/
isValid(): Promise<boolean>;
/**
* Was this proof authorized?
*
* Return true once the session is proven to own the associated resource.
*/
isAuthorized(): Promise<boolean>;
/**
* Retrieve the resource this auth proof is supposed to authorize.
* If this resource doesn't exist yet, return {@code null}.
*/
getResource(): Promise<R | null>;
/**
* Manually revokes this authentication proof. Once this method is called, all of the following must be true:
* - {@code isAuthorized} returns {@code false}
* - There is no way to re-authorize this proof (i.e. {@code isAuthorized} can never return {@code true} again)
*
* Additionally, this method should delete any stored data that could lead to restoration of this AuthProof
* instance.
*/
revoke(): Promise<void>;
}

View File

@ -0,0 +1,12 @@
import Migration from "../../db/Migration";
export default class AddUsedToMagicLinksMigration extends Migration {
public async install(): Promise<void> {
await this.query(`ALTER TABLE magic_links
ADD COLUMN used BOOLEAN NOT NULL`);
}
public async rollback(): Promise<void> {
await this.query('ALTER TABLE magic_links DROP COLUMN IF EXISTS used');
}
}

View File

@ -0,0 +1,5 @@
export default {
LOGIN: 'login',
REGISTER: 'register',
ADD_EMAIL: 'add_email',
};

View File

@ -0,0 +1,29 @@
import Migration from "../../db/Migration";
import ModelFactory from "../../db/ModelFactory";
import MagicLink from "../models/MagicLink";
export default class CreateMagicLinksTableMigration extends Migration {
public async install(): Promise<void> {
await this.query(`CREATE TABLE magic_links
(
id INT NOT NULL AUTO_INCREMENT,
session_id CHAR(32) UNIQUE NOT NULL,
email VARCHAR(254) NOT NULL,
token CHAR(96) NOT NULL,
action_type VARCHAR(64) NOT NULL,
original_url VARCHAR(1745) NOT NULL,
generated_at DATETIME NOT NULL,
authorized BOOLEAN NOT NULL,
PRIMARY KEY (id)
)`);
}
public async rollback(): Promise<void> {
await this.query('DROP TABLE magic_links');
}
public registerModels(): void {
ModelFactory.register(MagicLink);
}
}

View File

@ -0,0 +1,125 @@
import AuthMethod from "../AuthMethod";
import {Request, Response} from "express";
import User from "../models/User";
import UserEmail from "../models/UserEmail";
import MagicLink from "../models/MagicLink";
import {WhereTest} from "../../db/ModelQuery";
import Controller from "../../Controller";
import geoip from "geoip-lite";
import MagicLinkController from "./MagicLinkController";
import Application from "../../Application";
import {MailTemplate} from "../../mail/Mail";
import AuthMagicLinkActionType from "./AuthMagicLinkActionType";
import Validator, {EMAIL_REGEX} from "../../db/Validator";
import ModelFactory from "../../db/ModelFactory";
import UserNameComponent from "../models/UserNameComponent";
import {Session} from "express-session";
export default class MagicLinkAuthMethod implements AuthMethod<MagicLink> {
public constructor(
protected readonly app: Application,
protected readonly magicLinkMailTemplate: MailTemplate,
) {
}
public getName(): string {
return 'magic_link';
}
public getWeightForRequest(req: Request): number {
return !req.body.identifier || !EMAIL_REGEX.test(req.body.identifier) ?
0 :
1;
}
public async findUserByIdentifier(identifier: string): Promise<User | null> {
return (await UserEmail.select()
.with('user.mainEmail')
.where('email', identifier)
.first())?.user.getOrFail() || null;
}
public async getProofsForSession(session: Session): Promise<MagicLink[]> {
return await MagicLink.select()
.where('session_id', session.id)
.where('action_type', [AuthMagicLinkActionType.LOGIN, AuthMagicLinkActionType.REGISTER], WhereTest.IN)
.get();
}
public async interruptAuth(req: Request, res: Response): Promise<boolean> {
const pendingLink = await MagicLink.select()
.where('session_id', req.getSession().id)
.first();
if (pendingLink) {
if (await pendingLink.isValid()) {
res.redirect(Controller.route('magic_link_lobby', undefined, {
redirect_uri: req.getIntendedUrl() || pendingLink.original_url || undefined,
}));
return true;
} else {
await pendingLink.delete();
}
}
return false;
}
public async attemptLogin(req: Request, res: Response, user: User): Promise<void> {
const userEmail = user.mainEmail.getOrFail();
if (!userEmail) throw new Error('No main email for user ' + user.id);
await this.auth(req, res, false, userEmail.getOrFail('email'));
}
public async attemptRegister(req: Request, res: Response, identifier: string): Promise<void> {
const userEmail = UserEmail.create({
email: identifier,
main: true,
});
await userEmail.validate(true);
await this.auth(req, res, true, identifier);
}
private async auth(req: Request, res: Response, isRegistration: boolean, email: string): Promise<void> {
const geo = geoip.lookup(req.ip);
const actionType = isRegistration ? AuthMagicLinkActionType.REGISTER : AuthMagicLinkActionType.LOGIN;
if (isRegistration) {
const usernameValidator = new Validator();
if (ModelFactory.get(User).hasComponent(UserNameComponent)) usernameValidator.defined();
await Validator.validate({
email: new Validator().defined().unique(UserEmail, 'email'),
name: usernameValidator,
}, {
email: email,
name: req.body.name,
});
}
req.getSession().wantsSessionPersistence = !!req.body.persist_session || isRegistration;
await MagicLinkController.sendMagicLink(
this.app,
req.getSession().id,
actionType,
Controller.route('auth', undefined, {
redirect_uri: req.getIntendedUrl() || undefined,
}),
email,
this.magicLinkMailTemplate,
{
type: actionType,
ip: req.ip,
geo: geo ? `${geo.city}, ${geo.country}` : 'Unknown location',
},
{
username: req.body.name,
},
);
res.redirect(Controller.route('magic_link_lobby', undefined, {
redirect_uri: req.getIntendedUrl(),
}));
}
}

View File

@ -0,0 +1,258 @@
import Controller from "../../Controller";
import {Request, Response} from "express";
import MagicLinkWebSocketListener from "./MagicLinkWebSocketListener";
import {BadRequestError, NotFoundHttpError} from "../../HttpError";
import Throttler from "../../Throttler";
import Mail, {MailTemplate} from "../../mail/Mail";
import MagicLink from "../models/MagicLink";
import config from "config";
import Application from "../../Application";
import {ParsedUrlQueryInput} from "querystring";
import NunjucksComponent from "../../components/NunjucksComponent";
import User from "../models/User";
import AuthComponent, {AuthMiddleware} from "../AuthComponent";
import {AuthError, PendingApprovalAuthError, RegisterCallback} from "../AuthGuard";
import UserEmail from "../models/UserEmail";
import AuthMagicLinkActionType from "./AuthMagicLinkActionType";
import {QueryVariable} from "../../db/MysqlConnectionManager";
import UserNameComponent from "../models/UserNameComponent";
import MagicLinkUserNameComponent from "../models/MagicLinkUserNameComponent";
import {logger} from "../../Logger";
export default class MagicLinkController<A extends Application> extends Controller {
public static async sendMagicLink(
app: Application,
sessionId: string,
actionType: string,
original_url: string,
email: string,
mailTemplate: MailTemplate,
data: ParsedUrlQueryInput,
magicLinkData: Record<string, QueryVariable> = {},
): Promise<void> {
Throttler.throttle('magic_link', process.env.NODE_ENV === 'test' ? 10 : 2, MagicLink.validityPeriod(),
sessionId, 0, 0);
Throttler.throttle('magic_link', 1, MagicLink.validityPeriod(),
email, 0, 0);
const link = MagicLink.create(Object.assign(magicLinkData, {
session_id: sessionId,
action_type: actionType,
original_url: original_url,
}));
const token = await link.generateToken(email);
await link.save();
// Send email
await new Mail(app.as(NunjucksComponent).getEnvironment(), mailTemplate, Object.assign(data, {
link: `${config.get<string>('public_url')}${Controller.route('magic_link', undefined, {
id: link.id,
token: token,
})}`,
})).send(email);
}
public static async checkAndAuth(req: Request, res: Response, magicLink: MagicLink): Promise<User | null> {
const session = req.getSession();
if (magicLink.getOrFail('session_id') !== session.id) throw new BadOwnerMagicLink();
if (!await magicLink.isAuthorized()) throw new UnauthorizedMagicLink();
if (!await magicLink.isValid()) throw new InvalidMagicLink();
// Auth
try {
return await req.as(AuthMiddleware).getAuthGuard().authenticateOrRegister(
session, magicLink, !!session.wantsSessionPersistence, undefined, async (connection, user) => {
const userNameComponent = user.asOptional(UserNameComponent);
if (userNameComponent) userNameComponent.name = magicLink.as(MagicLinkUserNameComponent).username;
return [];
}, async (connection, user) => {
const callbacks: RegisterCallback[] = [];
const userEmail = UserEmail.create({
user_id: user.id,
email: magicLink.getOrFail('email'),
});
await userEmail.save(connection, c => callbacks.push(c));
user.main_email_id = userEmail.id;
await user.save(connection, c => callbacks.push(c));
return callbacks;
});
} catch (e) {
if (e instanceof PendingApprovalAuthError) {
res.format({
json: () => {
res.json({
'status': 'warning',
'message': `Your account is pending review. You'll receive an email once you're approved.`,
});
},
html: () => {
req.flash('warning', `Your account is pending review. You'll receive an email once you're approved.`);
res.redirect('/');
},
});
return null;
} else {
throw e;
}
}
}
protected readonly magicLinkWebsocketPath: string;
public constructor(magicLinkWebsocketListener: MagicLinkWebSocketListener<A>) {
super();
this.magicLinkWebsocketPath = magicLinkWebsocketListener.path();
}
public getRoutesPrefix(): string {
return '/magic';
}
public routes(): void {
this.get('/lobby', this.getLobby, 'magic_link_lobby');
this.get('/link', this.getMagicLink, 'magic_link');
}
protected async getLobby(req: Request, res: Response): Promise<void> {
const link = await MagicLink.select()
.where('session_id', req.getSession().id)
.sortBy('authorized')
.where('used', 0)
.first();
if (!link) {
throw new NotFoundHttpError('magic link', req.url);
}
if (!await link.isValid()) {
req.flash('error', 'This magic link has expired. Please try again.');
res.redirect(link.getOrFail('original_url'));
return;
}
if (await link.isAuthorized()) {
link.use();
await link.save();
await this.performAction(link, req, res);
return;
}
res.render('magic_link_lobby', {
email: link.getOrFail('email'),
type: link.getOrFail('action_type'),
validUntil: link.getExpirationDate().getTime(),
websocketUrl: config.get<string>('public_websocket_url') + this.magicLinkWebsocketPath,
});
}
protected async getMagicLink(req: Request, res: Response): Promise<void> {
const id = parseInt(<string>req.query.id);
const token = <string>req.query.token;
if (!id || !token)
throw new BadRequestError('Need parameters id, token.', 'Please try again.', req.originalUrl);
let success = true;
let err;
const magicLink = await MagicLink.getById<MagicLink>(id);
if (!magicLink) {
res.status(404);
err = `Couldn't find this magic link. Perhaps it has already expired.`;
success = false;
} else if (!await magicLink.isAuthorized()) {
err = await magicLink.verifyToken(token);
if (err === null) {
// Validation success, authenticate the user
magicLink.authorize();
await magicLink.save();
this.getApp().as<MagicLinkWebSocketListener<A>>(MagicLinkWebSocketListener)
.refreshMagicLink(magicLink.getOrFail('session_id'));
}
}
res.render('magic_link', {
magicLink: magicLink,
err: err,
success: success && err === null,
});
}
protected async performAction(magicLink: MagicLink, req: Request, res: Response): Promise<void> {
switch (magicLink.getOrFail('action_type')) {
case AuthMagicLinkActionType.LOGIN:
case AuthMagicLinkActionType.REGISTER: {
await MagicLinkController.checkAndAuth(req, res, magicLink);
const authGuard = this.getApp().as(AuthComponent).getAuthGuard();
const proofs = await authGuard.getProofsForSession(req.getSession());
const user = await proofs[0]?.getResource();
if (!res.headersSent && user) {
// Auth success
req.flash('success', `Authentication success. Welcome, ${user.name}!`);
res.redirect(req.getIntendedUrl() || Controller.route('home'));
}
break;
}
case AuthMagicLinkActionType.ADD_EMAIL: {
const session = req.getSessionOptional();
if (!session || magicLink.session_id !== session.id) throw new BadOwnerMagicLink();
await magicLink.delete();
const authGuard = this.getApp().as(AuthComponent).getAuthGuard();
const proofs = await authGuard.getProofsForSession(session);
const user = await proofs[0]?.getResource();
if (!user) return;
const email = await magicLink.getOrFail('email');
if (await UserEmail.select().with('user').where('email', email).first()) {
req.flash('error', 'An account already exists with this email address.' +
' Please first remove it there before adding it here.');
res.redirect(Controller.route('account'));
return;
}
const userEmail = UserEmail.create({
user_id: user.id,
email: email,
main: false,
});
await userEmail.save();
if (!user.main_email_id) {
user.main_email_id = userEmail.id;
await user.save();
}
req.flash('success', `Email address ${userEmail.email} successfully added.`);
res.redirect(Controller.route('account'));
break;
}
default:
logger.warn('Unknown magic link action type ' + magicLink.action_type);
break;
}
}
}
export class BadOwnerMagicLink extends AuthError {
public constructor() {
super(`This magic link doesn't belong to this session.`);
}
}
export class UnauthorizedMagicLink extends AuthError {
public constructor() {
super(`This magic link is unauthorized.`);
}
}
export class InvalidMagicLink extends AuthError {
public constructor() {
super(`This magic link is invalid.`);
}
}

View File

@ -0,0 +1,70 @@
import WebSocket from "ws";
import {IncomingMessage} from "http";
import WebSocketListener from "../../WebSocketListener";
import MagicLink from "../models/MagicLink";
import Application from "../../Application";
import {Session} from "express-session";
export default class MagicLinkWebSocketListener<A extends Application> extends WebSocketListener<A> {
private readonly connections: { [p: string]: (() => void)[] | undefined } = {};
public refreshMagicLink(sessionId: string): void {
const fs = this.connections[sessionId];
if (fs) {
fs.forEach(f => f());
}
}
public async handle(socket: WebSocket, request: IncomingMessage, session: Session | null): Promise<void> {
// Drop if requested without session
if (!session) {
socket.close(1002, 'Session is required for this request.');
return;
}
// Refuse any incoming data
socket.on('message', () => {
socket.close(1003);
});
// Get magic link
const magicLink = await MagicLink.select()
.where('session_id', session.id)
.sortBy('authorized')
.first();
// Refresh if immediately applicable
if (!magicLink || !await magicLink.isValid() || await magicLink.isAuthorized()) {
socket.send('refresh');
socket.close(1000);
return;
}
const validityTimeout = setTimeout(() => {
socket.send('refresh');
socket.close(1000);
}, magicLink.getExpirationDate().getTime() - new Date().getTime());
const f = () => {
clearTimeout(validityTimeout);
socket.send('refresh');
socket.close(1000);
};
socket.on('close', () => {
const connections = this.connections[session.id];
if (connections) {
this.connections[session.id] = connections.filter(f => f !== f);
if (connections.length === 0) delete this.connections[session.id];
}
});
let connections = this.connections[session.id];
if (!connections) connections = this.connections[session.id] = [];
connections.push(f);
}
public path(): string {
return '/magic-link';
}
}

View File

@ -0,0 +1,12 @@
import Migration from "../../db/Migration";
export default class MakeMagicLinksSessionNotUniqueMigration extends Migration {
public async install(): Promise<void> {
await this.query(`ALTER TABLE magic_links
DROP INDEX IF EXISTS session_id`);
}
public async rollback(): Promise<void> {
await this.query('ALTER TABLE magic_links ADD CONSTRAINT UNIQUE (session_id)');
}
}

View File

@ -0,0 +1,18 @@
import Migration from "../../db/Migration";
import ModelFactory from "../../db/ModelFactory";
import User from "../models/User";
import UserApprovedComponent from "../models/UserApprovedComponent";
export default class AddApprovedFieldToUsersTableMigration extends Migration {
public async install(): Promise<void> {
await this.query('ALTER TABLE users ADD COLUMN approved BOOLEAN NOT NULL DEFAULT 0');
}
public async rollback(): Promise<void> {
await this.query('ALTER TABLE users DROP COLUMN approved');
}
public registerModels(): void {
ModelFactory.get(User).addComponent(UserApprovedComponent);
}
}

View File

@ -0,0 +1,25 @@
import Migration from "../../db/Migration";
import ModelFactory from "../../db/ModelFactory";
import User from "../models/User";
import UserNameComponent from "../models/UserNameComponent";
import MagicLink from "../models/MagicLink";
import MagicLinkUserNameComponent from "../models/MagicLinkUserNameComponent";
export default class AddNameToUsersMigration extends Migration {
public async install(): Promise<void> {
await this.query(`ALTER TABLE users
ADD COLUMN name VARCHAR(64) UNIQUE NOT NULL`);
await this.query(`ALTER TABLE magic_links
ADD COLUMN username VARCHAR(64) DEFAULT NULL`);
}
public async rollback(): Promise<void> {
await this.query('ALTER TABLE users DROP COLUMN name');
await this.query('ALTER TABLE magic_links DROP COLUMN username');
}
public registerModels(): void {
ModelFactory.get(User).addComponent(UserNameComponent);
ModelFactory.get(MagicLink).addComponent(MagicLinkUserNameComponent);
}
}

View File

@ -0,0 +1,39 @@
import Migration from "../../db/Migration";
import ModelFactory from "../../db/ModelFactory";
import User from "../models/User";
import UserEmail from "../models/UserEmail";
export default class CreateUsersAndUserEmailsTableMigration extends Migration {
public async install(): Promise<void> {
await this.query(`CREATE TABLE users
(
id INT NOT NULL AUTO_INCREMENT,
main_email_id INT,
is_admin BOOLEAN NOT NULL DEFAULT false,
created_at DATETIME NOT NULL DEFAULT NOW(),
updated_at DATETIME NOT NULL DEFAULT NOW(),
PRIMARY KEY (id)
)`);
await this.query(`CREATE TABLE user_emails
(
id INT NOT NULL AUTO_INCREMENT,
user_id INT NOT NULL,
email VARCHAR(254) UNIQUE NOT NULL,
created_at DATETIME NOT NULL DEFAULT NOW(),
PRIMARY KEY (id),
FOREIGN KEY user_fk (user_id) REFERENCES users (id) ON DELETE CASCADE
)`);
await this.query(`ALTER TABLE users
ADD FOREIGN KEY main_user_email_fk (main_email_id) REFERENCES user_emails (id) ON DELETE SET NULL`);
}
public async rollback(): Promise<void> {
await this.query('DROP TABLE user_emails');
await this.query('DROP TABLE users');
}
public registerModels(): void {
ModelFactory.register(User);
ModelFactory.register(UserEmail);
}
}

View File

@ -0,0 +1,14 @@
import Migration from "../../db/Migration";
/**
* @deprecated - TODO may be remove at next major version >= 0.24, replace with DummyMigration.
*/
export default class DropNameFromUsers extends Migration {
public async install(): Promise<void> {
await this.query('ALTER TABLE users DROP COLUMN name');
}
public async rollback(): Promise<void> {
await this.query('ALTER TABLE users ADD COLUMN name VARCHAR(64)');
}
}

View File

@ -0,0 +1,27 @@
import Migration from "../../db/Migration";
/**
* @deprecated - TODO may be remove at next major version >= 0.24, replace with DummyMigration.
*/
export default class FixUserMainEmailRelation extends Migration {
public async install(): Promise<void> {
await this.query(`ALTER TABLE users
ADD COLUMN main_email_id INT,
ADD FOREIGN KEY main_user_email_fk (main_email_id) REFERENCES user_emails (id)`);
await this.query(`UPDATE users u LEFT JOIN user_emails ue ON u.id = ue.user_id
SET u.main_email_id=ue.id
WHERE ue.main = true`);
await this.query(`ALTER TABLE user_emails
DROP COLUMN main`);
}
public async rollback(): Promise<void> {
await this.query(`ALTER TABLE user_emails
ADD COLUMN main BOOLEAN DEFAULT false`);
await this.query(`UPDATE user_emails ue LEFT JOIN users u ON ue.id = u.main_email_id
SET ue.main = true`);
await this.query(`ALTER TABLE users
DROP FOREIGN KEY main_user_email_fk,
DROP COLUMN main_email_id`);
}
}

View File

@ -0,0 +1,106 @@
import crypto from "crypto";
import config from "config";
import Model from "../../db/Model";
import AuthProof from "../AuthProof";
import User from "./User";
import argon2 from "argon2";
import UserEmail from "./UserEmail";
import {EMAIL_REGEX} from "../../db/Validator";
export default class MagicLink extends Model implements AuthProof<User> {
public static validityPeriod(): number {
return config.get<number>('magic_link.validity_period') * 1000;
}
public readonly id?: number = undefined;
public readonly session_id?: string = undefined;
private email?: string = undefined;
private token?: string = undefined;
public readonly action_type?: string = undefined;
public readonly original_url?: string = undefined;
private generated_at?: Date = undefined;
private authorized: boolean = false;
private used: boolean = false;
protected init(): void {
this.setValidation('session_id').defined().length(32);
this.setValidation('email').defined().regexp(EMAIL_REGEX);
this.setValidation('token').defined().length(96);
this.setValidation('action_type').defined().maxLength(64);
this.setValidation('original_url').defined().maxLength(1745);
this.setValidation('authorized').defined();
this.setValidation('used').defined();
}
public async getResource(): Promise<User | null> {
const email = await UserEmail.select()
.with('user')
.where('email', await this.getOrFail('email'))
.first();
return email ? await email.user.get() : null;
}
public async revoke(): Promise<void> {
await this.delete();
}
public async isValid(): Promise<boolean> {
return await this.isAuthorized() ||
new Date().getTime() < this.getExpirationDate().getTime();
}
public async isAuthorized(): Promise<boolean> {
return this.authorized;
}
public authorize(): void {
this.authorized = true;
}
public isUsed(): boolean {
return this.used;
}
public use(): void {
this.used = true;
}
public async generateToken(email: string): Promise<string> {
const rawToken = crypto.randomBytes(48).toString('base64'); // Raw token length = 64
this.email = email;
this.generated_at = new Date();
this.token = await argon2.hash(rawToken, {
timeCost: 10,
memoryCost: 4096,
parallelism: 4,
hashLength: 32,
});
return rawToken;
}
/**
* @returns {@code null} if the token is valid, an error {@code string} otherwise.
*/
public async verifyToken(tokenGuess: string): Promise<string | null> {
// There is no token
if (this.token === undefined || this.generated_at === undefined)
return 'This token was not generated.';
// Token has expired
if (new Date().getTime() - this.generated_at.getTime() > MagicLink.validityPeriod())
return 'This token has expired.';
// Token is invalid
if (!await argon2.verify(this.token, tokenGuess))
return 'This token is invalid.';
return null;
}
public getExpirationDate(): Date {
if (!this.generated_at) return new Date();
return new Date(this.generated_at.getTime() + MagicLink.validityPeriod());
}
}

View File

@ -0,0 +1,12 @@
import ModelComponent from "../../db/ModelComponent";
import MagicLink from "./MagicLink";
import {USERNAME_REGEXP} from "./UserNameComponent";
import User from "./User";
export default class MagicLinkUserNameComponent extends ModelComponent<MagicLink> {
public readonly username?: string = undefined;
protected init(): void {
this.setValidation('username').acceptUndefined().between(3, 64).regexp(USERNAME_REGEXP).unique(User, 'name');
}
}

53
src/auth/models/User.ts Normal file
View File

@ -0,0 +1,53 @@
import Model from "../../db/Model";
import MysqlConnectionManager from "../../db/MysqlConnectionManager";
import AddApprovedFieldToUsersTableMigration from "../migrations/AddApprovedFieldToUsersTableMigration";
import config from "config";
import {ManyModelRelation} from "../../db/ModelRelation";
import UserEmail from "./UserEmail";
import UserApprovedComponent from "./UserApprovedComponent";
import UserNameComponent from "./UserNameComponent";
export default class User extends Model {
public static isApprovalMode(): boolean {
return config.get<boolean>('approval_mode') &&
MysqlConnectionManager.hasMigration(AddApprovedFieldToUsersTableMigration);
}
public readonly id?: number = undefined;
public main_email_id?: number = undefined;
public is_admin: boolean = false;
public created_at?: Date = undefined;
public updated_at?: Date = undefined;
public readonly emails = new ManyModelRelation(this, UserEmail, {
localKey: 'id',
foreignKey: 'user_id',
});
public readonly mainEmail = this.emails.cloneReduceToOne().constraint(q => q.where('id', this.main_email_id));
protected init(): void {
this.setValidation('name').acceptUndefined().between(3, 64);
this.setValidation('main_email_id').acceptUndefined().exists(UserEmail, 'id');
if (User.isApprovalMode()) {
this.setValidation('approved').defined();
}
this.setValidation('is_admin').defined();
}
public isApproved(): boolean {
return !User.isApprovalMode() || this.as(UserApprovedComponent).approved;
}
protected getPersonalInfoFields(): { name: string, value: string }[] {
const fields: { name: string, value: string }[] = [];
const nameComponent = this.asOptional(UserNameComponent);
if (nameComponent && nameComponent.name) {
fields.push({
name: 'Name',
value: nameComponent.name,
});
}
return fields;
}
}

View File

@ -0,0 +1,6 @@
import ModelComponent from "../../db/ModelComponent";
import User from "./User";
export default class UserApprovedComponent extends ModelComponent<User> {
public approved: boolean = false;
}

View File

@ -0,0 +1,22 @@
import User from "./User";
import Model from "../../db/Model";
import {OneModelRelation} from "../../db/ModelRelation";
import {EMAIL_REGEX} from "../../db/Validator";
export default class UserEmail extends Model {
public readonly id?: number = undefined;
public user_id?: number = undefined;
public readonly email?: string = undefined;
public created_at?: Date = undefined;
public readonly user = new OneModelRelation(this, User, {
localKey: 'user_id',
foreignKey: 'id',
});
protected init(): void {
this.setValidation('user_id').acceptUndefined().exists(User, 'id');
this.setValidation('email').defined().regexp(EMAIL_REGEX).unique(this);
this.setValidation('main').defined();
}
}

View File

@ -0,0 +1,12 @@
import ModelComponent from "../../db/ModelComponent";
import User from "../models/User";
export const USERNAME_REGEXP = /^[0-9a-z_-]+$/;
export default class UserNameComponent extends ModelComponent<User> {
public name?: string = undefined;
public init(): void {
this.setValidation('name').defined().between(3, 64).regexp(USERNAME_REGEXP).unique(this._model);
}
}

View File

@ -0,0 +1,20 @@
import Migration from "../../db/Migration";
import ModelFactory from "../../db/ModelFactory";
import User from "../models/User";
import UserPasswordComponent from "./UserPasswordComponent";
export default class AddPasswordToUsersMigration extends Migration {
public async install(): Promise<void> {
await this.query(`ALTER TABLE users
ADD COLUMN password VARCHAR(128) DEFAULT NULL`);
}
public async rollback(): Promise<void> {
await this.query(`ALTER TABLE users
DROP COLUMN password`);
}
public registerModels(): void {
ModelFactory.get(User).addComponent(UserPasswordComponent);
}
}

View File

@ -0,0 +1,139 @@
import AuthMethod from "../AuthMethod";
import PasswordAuthProof from "./PasswordAuthProof";
import User from "../models/User";
import {Request, Response} from "express";
import UserEmail from "../models/UserEmail";
import AuthComponent from "../AuthComponent";
import Application from "../../Application";
import Throttler from "../../Throttler";
import {AuthError, PendingApprovalAuthError, RegisterCallback} from "../AuthGuard";
import Validator, {InvalidFormatValidationError} from "../../db/Validator";
import Controller from "../../Controller";
import UserPasswordComponent from "./UserPasswordComponent";
import UserNameComponent, {USERNAME_REGEXP} from "../models/UserNameComponent";
import ModelFactory from "../../db/ModelFactory";
import {ServerError} from "../../HttpError";
import {Session} from "express-session";
export default class PasswordAuthMethod implements AuthMethod<PasswordAuthProof> {
public constructor(
protected readonly app: Application,
) {
}
public getName(): string {
return 'password';
}
public getWeightForRequest(req: Request): number {
return !req.body.identifier || !req.body.password || req.body.password.length === 0 ?
0 :
2;
}
public async findUserByIdentifier(identifier: string): Promise<User | null> {
const query = UserEmail.select()
.with('user')
.where('email', identifier);
const user = (await query
.first())?.user.getOrFail();
if (user) return user;
if (ModelFactory.get(User).hasComponent(UserNameComponent)) {
return await User.select().where('name', identifier).first();
}
return null;
}
public async getProofsForSession(session: Session): Promise<PasswordAuthProof[]> {
const proof = PasswordAuthProof.getProofForSession(session);
return proof ? [proof] : [];
}
public async attemptLogin(req: Request, res: Response, user: User): Promise<void> {
const passwordAuthProof = PasswordAuthProof.createProofForLogin(req.getSession());
passwordAuthProof.setResource(user);
await passwordAuthProof.authorize(req.body.password);
try {
await this.app.as(AuthComponent).getAuthGuard().authenticateOrRegister(
req.getSession(),
passwordAuthProof,
!!req.body.persist_session,
);
} catch (e) {
if (e instanceof AuthError) {
Throttler.throttle('login_failed_attempts_user', 3, 3 * 60 * 1000, // 3min
<string>user.getOrFail('name'), 1000, 60 * 1000); // 1min
Throttler.throttle('login_failed_attempts_ip', 50, 60 * 1000, // 1min
req.ip, 1000, 3600 * 1000); // 1h
if (e instanceof PendingApprovalAuthError) {
req.flash('error', 'Your account is still being reviewed.');
res.redirect(Controller.route('auth'));
return;
} else {
const err = new InvalidFormatValidationError('Invalid password.');
err.thingName = 'password';
throw err;
}
} else {
throw e;
}
}
req.flash('success', `Welcome, ${user.name}.`);
res.redirect(req.getIntendedUrl() || Controller.route('home'));
}
public async attemptRegister(req: Request, res: Response, identifier: string): Promise<void> {
if (!ModelFactory.get(User).hasComponent(UserNameComponent))
throw new ServerError('Cannot register with password without UserNameComponent.');
Throttler.throttle('register_password', 10, 30000, req.ip);
req.body.identifier = identifier;
await Validator.validate({
identifier: new Validator().defined().between(3, 64).regexp(USERNAME_REGEXP).unique(User, 'name'),
password: new Validator().defined().minLength(UserPasswordComponent.PASSWORD_MIN_LENGTH),
password_confirmation: new Validator().defined().sameAs('password', req.body.password),
terms: new Validator().defined(),
}, req.body);
const passwordAuthProof = PasswordAuthProof.createAuthorizedProofForRegistration(req.getSession());
try {
await this.app.as(AuthComponent).getAuthGuard().authenticateOrRegister(req.getSession(), passwordAuthProof,
true, undefined, async (connection, user) => {
const callbacks: RegisterCallback[] = [];
// Password
await user.as(UserPasswordComponent).setPassword(req.body.password);
// Username
user.as(UserNameComponent).name = req.body.identifier;
return callbacks;
}, async (connection, user) => {
passwordAuthProof.setResource(user);
return [];
});
} catch (e) {
if (e instanceof PendingApprovalAuthError) {
req.flash('info', `Your account was successfully created and is pending review from an administrator.`);
res.redirect(Controller.route('auth'));
return;
} else {
throw e;
}
}
const user = await passwordAuthProof.getResource();
req.flash('success', `Your account was successfully created! Welcome, ${user?.as(UserNameComponent).name}.`);
res.redirect(req.getIntendedUrl() || Controller.route('home'));
}
}

View File

@ -0,0 +1,89 @@
import AuthProof from "../AuthProof";
import User from "../models/User";
import UserPasswordComponent from "./UserPasswordComponent";
import {Session, SessionData} from "express-session";
export default class PasswordAuthProof implements AuthProof<User> {
public static getProofForSession(session: Session & Partial<SessionData>): PasswordAuthProof | null {
return session.authPasswordProof ? new PasswordAuthProof(session) : null;
}
public static createAuthorizedProofForRegistration(session: Session): PasswordAuthProof {
const proofForSession = new PasswordAuthProof(session);
proofForSession.authorized = true;
proofForSession.forRegistration = true;
proofForSession.save();
return proofForSession;
}
public static createProofForLogin(session: Session & Partial<SessionData>): PasswordAuthProof {
return new PasswordAuthProof(session);
}
private readonly session: Session & Partial<SessionData>;
private authorized: boolean;
private forRegistration: boolean = false;
private userId: number | null;
private userPassword: UserPasswordComponent | null = null;
private constructor(session: Session & Partial<SessionData>) {
this.session = session;
this.authorized = session.authPasswordProof?.authorized || false;
this.forRegistration = session.authPasswordProof?.forRegistration || false;
this.userId = session.authPasswordProof?.userId || null;
}
public async getResource(): Promise<User | null> {
if (typeof this.userId !== 'number') return null;
return await User.getById(this.userId);
}
public setResource(user: User): void {
this.userId = user.getOrFail('id');
this.save();
}
public async isAuthorized(): Promise<boolean> {
return this.authorized;
}
public async isValid(): Promise<boolean> {
return (this.forRegistration || Boolean(await this.getResource())) &&
await this.isAuthorized();
}
public async revoke(): Promise<void> {
this.session.authPasswordProof = undefined;
}
private async getUserPassword(): Promise<UserPasswordComponent | null> {
if (!this.userPassword) {
this.userPassword = (await User.getById(this.userId))?.as(UserPasswordComponent) || null;
}
return this.userPassword;
}
public async authorize(passwordGuess: string): Promise<boolean> {
const password = await this.getUserPassword();
if (!password || !await password.verifyPassword(passwordGuess)) return false;
this.authorized = true;
this.save();
return true;
}
private save() {
this.session.authPasswordProof = {
authorized: this.authorized,
forRegistration: this.forRegistration,
userId: this.userId,
};
}
}
export type PasswordAuthProofSessionData = {
authorized: boolean,
forRegistration: boolean,
userId: number | null,
};

View File

@ -0,0 +1,39 @@
import argon2, {argon2id} from "argon2";
import ModelComponent from "../../db/ModelComponent";
import User from "../models/User";
import Validator from "../../db/Validator";
export default class UserPasswordComponent extends ModelComponent<User> {
public static readonly PASSWORD_MIN_LENGTH = 12;
private password?: string = undefined;
public init(): void {
this.setValidation('password').acceptUndefined().maxLength(128);
}
public async setPassword(rawPassword: string, fieldName: string = 'password'): Promise<void> {
await new Validator<string>()
.defined()
.minLength(UserPasswordComponent.PASSWORD_MIN_LENGTH)
.maxLength(512)
.execute(fieldName, rawPassword, true);
this.password = await argon2.hash(rawPassword, {
timeCost: 10,
memoryCost: 65536,
parallelism: 4,
type: argon2id,
hashLength: 32,
});
}
public async verifyPassword(passwordGuess: string): Promise<boolean> {
if (!this.password || !passwordGuess) return false;
return await argon2.verify(this.password, passwordGuess);
}
public hasPassword(): boolean {
return typeof this.password === 'string';
}
}

View File

@ -0,0 +1,55 @@
import {Router} from "express";
import config from "config";
import * as child_process from "child_process";
import ApplicationComponent from "../ApplicationComponent";
import {ForbiddenHttpError} from "../HttpError";
import {logger} from "../Logger";
export default class AutoUpdateComponent extends ApplicationComponent {
private static async runCommand(command: string): Promise<void> {
logger.info(`> ${command}`);
logger.info(child_process.execSync(command).toString());
}
public async checkSecuritySettings(): Promise<void> {
this.checkSecurityConfigField('gitlab_webhook_token');
}
public async init(router: Router): Promise<void> {
router.post('/update/push.json', (req, res) => {
const token = req.header('X-Gitlab-Token');
if (!token || token !== config.get<string>('gitlab_webhook_token'))
throw new ForbiddenHttpError('Invalid token', req.url);
this.update(req.body).catch(logger.error);
res.json({
'status': 'ok',
});
});
}
private async update(params: { [p: string]: unknown }) {
logger.info('Update params:', params);
try {
logger.info('Starting auto update...');
// Fetch
await AutoUpdateComponent.runCommand(`git pull`);
// Install new dependencies
await AutoUpdateComponent.runCommand(`yarn install --production=false`);
// Process assets
await AutoUpdateComponent.runCommand(`yarn dist`);
// Stop app
await this.getApp().stop();
logger.info('Success!');
} catch (e) {
logger.error(e, 'An error occurred while running the auto update.');
}
}
}

View File

@ -1,54 +1,70 @@
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import {Express, Router} from "express"; import {Request, Router} from "express";
import crypto from "crypto"; import crypto from "crypto";
import {BadRequestError} from "../HttpError"; import {BadRequestError} from "../HttpError";
import {AuthMiddleware} from "../auth/AuthComponent";
import {Session, SessionData} from "express-session";
export default class CsrfProtectionComponent extends ApplicationComponent<void> { export default class CsrfProtectionComponent extends ApplicationComponent {
public async start(app: Express, router: Router): Promise<void> { private static readonly excluders: ((req: Request) => boolean)[] = [];
router.use((req, res, next) => {
if (!req.session) { public static getCsrfToken(session: Session & Partial<SessionData>): string {
throw new Error('Session is unavailable.'); if (typeof session.csrf !== 'string') {
session.csrf = crypto.randomBytes(64).toString('base64');
}
return session.csrf;
}
public static addExcluder(excluder: (req: Request) => boolean): void {
this.excluders.push(excluder);
}
public async handle(router: Router): Promise<void> {
router.use(async (req, res, next) => {
for (const excluder of CsrfProtectionComponent.excluders) {
if (excluder(req)) return next();
} }
res.locals.getCSRFToken = () => { const session = req.getSession();
if (typeof req.session!.csrf !== 'string') { res.locals.getCsrfToken = () => {
req.session!.csrf = crypto.randomBytes(64).toString('base64'); return CsrfProtectionComponent.getCsrfToken(session);
}
return req.session!.csrf;
}; };
if (!['GET', 'HEAD', 'OPTIONS'].find(s => s === req.method)) { if (!['GET', 'HEAD', 'OPTIONS'].find(s => s === req.method)) {
if (req.session.csrf === undefined) { try {
throw new InvalidCsrfTokenError(req.baseUrl, `You weren't assigned any CSRF token.`); if ((await req.as(AuthMiddleware).getAuthGuard().getProofsForRequest(req)).length === 0) {
} else if (req.body.csrf === undefined) { if (session.csrf === undefined) {
throw new InvalidCsrfTokenError(req.baseUrl, `You didn't provide any CSRF token.`); return next(new InvalidCsrfTokenError(req.baseUrl, `You weren't assigned any CSRF token.`));
} else if (req.session.csrf !== req.body.csrf) { } else if (req.body.csrf === undefined) {
throw new InvalidCsrfTokenError(req.baseUrl, `Tokens don't match.`); return next(new InvalidCsrfTokenError(req.baseUrl, `You didn't provide any CSRF token.`));
} else if (session.csrf !== req.body.csrf) {
return next(new InvalidCsrfTokenError(req.baseUrl, `Tokens don't match.`));
}
}
} catch (e) {
return next(e);
} }
} }
next(); next();
}); });
} }
public async stop(): Promise<void> {
}
} }
class InvalidCsrfTokenError extends BadRequestError { class InvalidCsrfTokenError extends BadRequestError {
constructor(url: string, details: string, cause?: Error) { public constructor(url: string, details: string, cause?: Error) {
super( super(
`Invalid CSRF token`, `Invalid CSRF token`,
`${details} We can't process this request. Please try again.`, `${details} We can't process this request. Please try again.`,
url, url,
cause cause,
); );
} }
get name(): string { public get name(): string {
return 'Invalid CSRF Token'; return 'Invalid CSRF Token';
} }
get errorCode(): number { public get errorCode(): number {
return 401; return 401;
} }
} }

View File

@ -1,35 +1,60 @@
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import express, {Express, Router} from "express"; import express, {Express, Router} from "express";
import Logger from "../Logger"; import {logger, preventContextCorruptionMiddleware} from "../Logger";
import {Server} from "http"; import {Server} from "http";
import compression from "compression";
import Middleware from "../Middleware";
import {Type} from "../Utils";
export default class ExpressAppComponent extends ApplicationComponent<void> { export default class ExpressAppComponent extends ApplicationComponent {
private readonly addr: string;
private readonly port: number; private readonly port: number;
private server?: Server; private server?: Server;
private expressApp?: Express;
constructor(port: number) { public constructor(addr: string, port: number) {
super(); super();
this.addr = addr;
this.port = port; this.port = port;
} }
public async start(app: Express, router: Router): Promise<void> { public async start(app: Express): Promise<void> {
this.server = app.listen(this.port, 'localhost', () => { this.server = app.listen(this.port, this.addr, () => {
Logger.info(`Web server running on localhost:${this.port}.`); logger.info(`Web server running on ${this.addr}:${this.port}.`);
}); });
router.use(express.json()); // Proxy
router.use(express.urlencoded()); app.set('trust proxy', 'loopback');
this.expressApp = app;
}
public async init(router: Router): Promise<void> {
router.use(preventContextCorruptionMiddleware(express.json({
type: req => req.headers['content-type']?.match(/^application\/(.+\+)?json$/),
})));
router.use(preventContextCorruptionMiddleware(express.urlencoded({
extended: true,
})));
// gzip
router.use(compression());
router.use((req, res, next) => { router.use((req, res, next) => {
req.models = {}; req.middlewares = [];
req.modelCollections = {}; req.as = <M extends Middleware>(type: Type<M>): M => {
const middleware = req.middlewares.find(m => m.constructor === type);
if (!middleware) throw new Error('Middleware ' + type.name + ' not present in this request.');
return middleware as M;
};
next(); next();
}); });
} }
public async stop(): Promise<void> { public async stop(): Promise<void> {
if (this.server) { const server = this.server;
await this.close('Webserver', this.server, this.server.close); if (server) {
await this.close('Webserver', callback => server.close(callback));
} }
} }
@ -37,4 +62,9 @@ export default class ExpressAppComponent extends ApplicationComponent<void> {
if (!this.server) throw 'Server was not initialized.'; if (!this.server) throw 'Server was not initialized.';
return this.server; return this.server;
} }
public getExpressApp(): Express {
if (!this.expressApp) throw new Error('Express app not initialized.');
return this.expressApp;
}
} }

View File

@ -1,16 +1,10 @@
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import {Express, Router} from "express"; import {Router} from "express";
export default class FormHelperComponent extends ApplicationComponent<void> { export default class FormHelperComponent extends ApplicationComponent {
public async start(app: Express, router: Router): Promise<void> { public async init(router: Router): Promise<void> {
router.use((req, res, next) => { router.use((req, res, next) => {
if (!req.session) { let _validation: unknown | null;
throw new Error('Session is unavailable.');
}
res.locals.query = req.query;
let _validation: any = null;
res.locals.validation = () => { res.locals.validation = () => {
if (!_validation) { if (!_validation) {
const v = req.flash('validation'); const v = req.flash('validation');
@ -18,9 +12,9 @@ export default class FormHelperComponent extends ApplicationComponent<void> {
} }
return _validation; return _validation;
} };
let _previousFormData: any = null; let _previousFormData: unknown | null = null;
res.locals.previousFormData = () => { res.locals.previousFormData = () => {
if (!_previousFormData) { if (!_previousFormData) {
const v = req.flash('previousFormData'); const v = req.flash('previousFormData');
@ -41,8 +35,4 @@ export default class FormHelperComponent extends ApplicationComponent<void> {
next(); next();
}); });
} }
public async stop(): Promise<void> {
}
} }

View File

@ -1,21 +1,87 @@
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import onFinished from "on-finished"; import onFinished from "on-finished";
import Logger from "../Logger"; import {logger} from "../Logger";
import {Express, Router} from "express"; import {Request, Response, Router} from "express";
import {HttpError} from "../HttpError";
export default class LogRequestsComponent extends ApplicationComponent<void> { export default class LogRequestsComponent extends ApplicationComponent {
public async start(app: Express, router: Router): Promise<void> { private static fullRequests: boolean = false;
public static logFullHttpRequests(): void {
this.fullRequests = true;
logger.info('Http requests will be logged with more details.');
}
public static logRequest(
req: Request,
res: Response,
err?: unknown,
additionalStr: string = '',
silent: boolean = false,
): string | undefined {
if (LogRequestsComponent.fullRequests) {
const requestObj = JSON.stringify({
ip: req.ip,
host: req.hostname,
method: req.method,
url: req.originalUrl,
headers: req.headers,
query: req.query,
params: req.params,
body: req.body,
files: req.files,
cookies: req.cookies,
sessionId: req.sessionID,
result: {
code: res.statusCode,
},
}, null, 4);
if (err) {
if (err instanceof Error) {
return logger.error(err, requestObj, err).requestId;
} else {
return logger.error(new Error(String(err)), requestObj).requestId;
}
} else {
logger.info(requestObj);
}
} else {
let codeDescription = '';
if (res.statusCode === 301) {
codeDescription = 'Permanent redirect to ' + res.getHeader('location');
} else if (res.statusCode === 302) {
codeDescription = 'Temporary redirect to ' + res.getHeader('location');
}
let logStr = `${req.ip} < ${req.method} ${req.originalUrl} - ${res.statusCode} ${codeDescription}`;
if (err) {
if (err instanceof Error) {
if (silent) {
if (err instanceof HttpError) logStr += ` ${err.errorCode}`;
logStr += ` ${err.name}`;
return logger.info(err.name, logStr).requestId;
} else {
return logger.error(err, logStr, additionalStr, err).requestId;
}
} else {
return logger.error(new Error(String(err)), logStr).requestId;
}
} else {
logger.info(logStr);
}
}
return '';
}
public async init(router: Router): Promise<void> {
router.use((req, res, next) => { router.use((req, res, next) => {
onFinished(res, (err) => { onFinished(res, (err) => {
if (!err) { if (!err) {
Logger.info(`${req.method} ${req.originalUrl} - ${res.statusCode}`); LogRequestsComponent.logRequest(req, res);
} }
}); });
next(); next();
}); });
} }
public async stop(): Promise<void> {
}
} }

View File

@ -1,9 +1,21 @@
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import {Express, Router} from "express"; import {Express} from "express";
import Mail from "../Mail"; import Mail from "../mail/Mail";
import config from "config";
import SecurityError from "../SecurityError";
export default class MailComponent extends ApplicationComponent<void> { export default class MailComponent extends ApplicationComponent {
public async start(app: Express, router: Router): Promise<void> {
public async checkSecuritySettings(): Promise<void> {
if (!config.get<boolean>('mail.secure')) {
throw new SecurityError('Cannot set mail.secure (starttls) to false');
}
if (config.get<boolean>('mail.allow_invalid_tls')) {
throw new SecurityError('Cannot set mail.allow_invalid_tls (ignore tls failure) to true');
}
}
public async start(_app: Express): Promise<void> {
await this.prepare('Mail connection', () => Mail.prepare()); await this.prepare('Mail connection', () => Mail.prepare());
} }

View File

@ -1,19 +1,20 @@
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import {Express, NextFunction, Request, Response, Router} from "express"; import {NextFunction, Request, Response, Router} from "express";
import {ServiceUnavailableHttpError} from "../HttpError"; import {ServiceUnavailableHttpError} from "../HttpError";
import Application from "../Application"; import Application from "../Application";
import config from "config";
export default class MaintenanceComponent extends ApplicationComponent<void> { export default class MaintenanceComponent extends ApplicationComponent {
private readonly application: Application; private readonly application: Application;
private readonly canServe: () => boolean; private readonly canServe: () => boolean;
constructor(application: Application, canServe: () => boolean) { public constructor(application: Application, canServe: () => boolean) {
super(); super();
this.application = application; this.application = application;
this.canServe = canServe; this.canServe = canServe;
} }
public async start(app: Express, router: Router): Promise<void> { public async handle(router: Router): Promise<void> {
router.use((req: Request, res: Response, next: NextFunction) => { router.use((req: Request, res: Response, next: NextFunction) => {
if (res.headersSent) { if (res.headersSent) {
return next(); return next();
@ -22,19 +23,15 @@ export default class MaintenanceComponent extends ApplicationComponent<void> {
if (!this.application.isReady()) { if (!this.application.isReady()) {
res.header({'Retry-After': 60}); res.header({'Retry-After': 60});
res.locals.refresh_after = 5; res.locals.refresh_after = 5;
throw new ServiceUnavailableHttpError('Watch My Stream is readying up. Please wait a few seconds...'); throw new ServiceUnavailableHttpError(`${config.get('app.name')} is readying up. Please wait a few seconds...`);
} }
if (!this.canServe()) { if (!this.canServe()) {
res.locals.refresh_after = 30; res.locals.refresh_after = 30;
throw new ServiceUnavailableHttpError('Watch My Stream is unavailable due to failure of dependent services.'); throw new ServiceUnavailableHttpError(`${config.get('app.name')} is unavailable due to failure of dependent services.`);
} }
next(); next();
}); });
} }
public async stop(): Promise<void> {
}
} }

View File

@ -1,9 +1,9 @@
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import {Express, Router} from "express"; import {Express} from "express";
import MysqlConnectionManager from "../db/MysqlConnectionManager"; import MysqlConnectionManager from "../db/MysqlConnectionManager";
export default class MysqlComponent extends ApplicationComponent<void> { export default class MysqlComponent extends ApplicationComponent {
public async start(app: Express, router: Router): Promise<void> { public async start(_app: Express): Promise<void> {
await this.prepare('Mysql connection', () => MysqlConnectionManager.prepare()); await this.prepare('Mysql connection', () => MysqlConnectionManager.prepare());
} }
@ -12,7 +12,7 @@ export default class MysqlComponent extends ApplicationComponent<void> {
} }
public canServe(): boolean { public canServe(): boolean {
return MysqlConnectionManager.pool !== undefined; return MysqlConnectionManager.isReady();
} }
} }

View File

@ -1,38 +1,95 @@
import nunjucks from "nunjucks"; import nunjucks, {Environment} from "nunjucks";
import config from "config"; import config from "config";
import {Express, Router} from "express"; import {Express, NextFunction, Request, Response, Router} from "express";
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import Controller from "../Controller"; import Controller, {RouteParams} from "../Controller";
import {ServerError} from "../HttpError"; import * as querystring from "querystring";
import {ParsedUrlQueryInput} from "querystring";
import * as util from "util";
import * as path from "path";
import * as fs from "fs";
import {logger} from "../Logger";
import Middleware from "../Middleware";
export default class NunjucksComponent extends ApplicationComponent<void> { export default class NunjucksComponent extends ApplicationComponent {
public async start(app: Express, router: Router): Promise<void> { private readonly viewsPath: string[];
const env = nunjucks.configure('views', { private environment?: Environment;
public constructor(viewsPath: string[] = ['views']) {
super();
this.viewsPath = viewsPath;
}
public async start(app: Express): Promise<void> {
let coreVersion = 'unknown';
const file = fs.existsSync(path.join(__dirname, '../../package.json')) ?
path.join(__dirname, '../../package.json') :
path.join(__dirname, '../package.json');
try {
coreVersion = JSON.parse(fs.readFileSync(file).toString()).version;
} catch (e) {
logger.warn('Couldn\'t determine coreVersion.', e);
}
const opts = {
autoescape: true, autoescape: true,
express: app,
noCache: !config.get('view.cache'), noCache: !config.get('view.cache'),
throwOnUndefined: true, throwOnUndefined: true,
}) };
.addGlobal('route', (route: string, params: { [p: string]: string } | [] = [], absolute: boolean = false) => { this.environment = new nunjucks.Environment([
const path = Controller.route(route, params, absolute); ...this.viewsPath.map(path => new nunjucks.FileSystemLoader(path, opts)),
if (path === null) throw new ServerError(`Route ${route} not found.`); new nunjucks.FileSystemLoader(path.join(__dirname, '../../../views'), opts),
return path; new nunjucks.FileSystemLoader(path.join(__dirname, '../views'), opts),
], opts)
.addGlobal('route', (
route: string,
params: RouteParams = [],
query: ParsedUrlQueryInput = {},
absolute: boolean = false,
): string => {
return Controller.route(route, params, query, absolute);
})
.addGlobal('app_version', this.getApp().getVersion())
.addGlobal('core_version', coreVersion)
.addGlobal('querystring', querystring)
.addGlobal('app', config.get('app'))
.addFilter('dump', (val) => {
return util.inspect(val);
}) })
.addGlobal('app_version', require('../package.json').version)
.addFilter('hex', (v: number) => { .addFilter('hex', (v: number) => {
return v.toString(16); return v.toString(16);
}); });
this.environment.express(app);
app.set('view engine', 'njk'); app.set('view engine', 'njk');
router.use((req, res, next) => {
req.env = env;
res.locals.url = req.url;
res.locals.params = () => req.params;
next();
});
} }
public async stop(): Promise<void> { public async init(_router: Router): Promise<void> {
this.use(NunjucksMiddleware);
} }
public getEnvironment(): Environment {
if (!this.environment) throw new Error('Environment not initialized.');
return this.environment;
}
}
export class NunjucksMiddleware extends Middleware {
private env?: Environment;
protected async handle(req: Request, res: Response, next: NextFunction): Promise<void> {
this.env = this.app.as(NunjucksComponent).getEnvironment();
res.locals.url = req.url;
res.locals.params = req.params;
res.locals.query = req.query;
res.locals.body = req.body;
next();
}
public getEnvironment(): Environment {
if (!this.env) throw new Error('Environment not initialized.');
return this.env;
}
} }

View File

@ -0,0 +1,54 @@
import ApplicationComponent from "../ApplicationComponent";
import {Router} from "express";
import onFinished from "on-finished";
import {logger} from "../Logger";
import SessionComponent from "./SessionComponent";
export default class PreviousUrlComponent extends ApplicationComponent {
public async handle(router: Router): Promise<void> {
router.use((req, res, next) => {
req.getPreviousUrl = () => {
let url = req.header('referer');
if (url) {
if (url.indexOf('://') >= 0) url = '/' + url.split('/').slice(3).join('/');
if (url !== req.originalUrl) return url;
}
if (this.getApp().asOptional(SessionComponent)) {
const session = req.getSessionOptional();
url = session?.previousUrl;
if (url && url !== req.originalUrl) return url;
}
return null;
};
res.locals.getPreviousUrl = req.getPreviousUrl;
req.getIntendedUrl = () => {
return req.query.redirect_uri?.toString() || null;
};
if (this.getApp().asOptional(SessionComponent)) {
const session = req.getSessionOptional();
if (session && req.method === 'GET') {
onFinished(res, (err) => {
if (err) return;
const contentType = res.getHeader('content-type');
if (res.statusCode === 200 &&
contentType && typeof contentType !== 'number' && contentType.indexOf('text/html') >= 0) {
session.previousUrl = req.originalUrl;
session.save((err) => {
if (err) logger.error(err, 'Error while saving session');
else logger.debug('Prev url set to', session.previousUrl);
});
}
});
}
}
next();
});
}
}

View File

@ -1,43 +0,0 @@
import ApplicationComponent from "../ApplicationComponent";
import {Express, Router} from "express";
import onFinished from "on-finished";
import Logger from "../Logger";
import {ServerError} from "../HttpError";
export default class RedirectBackComponent extends ApplicationComponent<void> {
public async start(app: Express, router: Router): Promise<void> {
router.use((req, res, next) => {
if (!req.session) {
throw new Error('Session is unavailable.');
}
onFinished(res, (err) => {
if (!err && res.statusCode === 200) {
req.session!.previousUrl = req.originalUrl;
Logger.debug('Prev url set to', req.session!.previousUrl);
req.session!.save((err) => {
if (err) {
Logger.error(err, 'Error while saving session');
}
});
}
});
res.redirectBack = (defaultUrl?: string) => {
if (req.session && typeof req.session.previousUrl === 'string') {
res.redirect(req.session.previousUrl);
} else if (typeof defaultUrl === 'string') {
res.redirect(defaultUrl);
} else {
throw new ServerError('There is no previous url and no default redirection url was provided.');
}
};
next();
});
}
public async stop(): Promise<void> {
}
}

View File

@ -1,31 +1,30 @@
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import {Express, Router} from "express"; import {Express} from "express";
import redis, {RedisClient} from "redis"; import redis, {RedisClient} from "redis";
import config from "config"; import config from "config";
import Logger from "../Logger"; import {logger} from "../Logger";
import session, {Store} from "express-session"; import session, {Store} from "express-session";
import connect_redis from "connect-redis"; import CacheProvider from "../CacheProvider";
const RedisStore = connect_redis(session); export default class RedisComponent extends ApplicationComponent implements CacheProvider {
export default class RedisComponent extends ApplicationComponent<void> {
private redisClient?: RedisClient; private redisClient?: RedisClient;
private store?: Store; private store?: Store;
public async start(app: Express, router: Router): Promise<void> { public async start(_app: Express): Promise<void> {
this.redisClient = redis.createClient(config.get('redis.port'), config.get('redis.host'), {}); this.redisClient = redis.createClient(config.get('redis.port'), config.get('redis.host'), {
this.redisClient.on('error', (err: any) => { password: config.has('redis.password') ? config.get<string>('redis.password') : undefined,
Logger.error(err, 'An error occurred with redis.');
}); });
this.store = new RedisStore({ this.redisClient.on('error', (err: Error) => {
client: this.redisClient, logger.error(err, 'An error occurred with redis.');
prefix: 'wms-sess:',
}); });
this.store = new RedisStore(this);
} }
public async stop(): Promise<void> { public async stop(): Promise<void> {
if (this.redisClient) { const redisClient = this.redisClient;
await this.close('Redis connection', this.redisClient, this.redisClient.quit); if (redisClient) {
await this.close('Redis connection', callback => redisClient.quit(callback));
} }
} }
@ -37,4 +36,112 @@ export default class RedisComponent extends ApplicationComponent<void> {
public canServe(): boolean { public canServe(): boolean {
return this.redisClient !== undefined && this.redisClient.connected; return this.redisClient !== undefined && this.redisClient.connected;
} }
public async get<T extends string | undefined>(key: string, defaultValue?: T): Promise<T> {
return await new Promise<T>((resolve, reject) => {
if (!this.redisClient) {
reject(`Redis client was not initialized.`);
return;
}
this.redisClient.get(key, (err, val) => {
if (err) {
reject(err);
return;
}
resolve((val || defaultValue || undefined) as T);
});
});
}
public async has(key: string): Promise<boolean> {
return await this.get(key) !== undefined;
}
public async forget(key: string): Promise<void> {
return await new Promise<void>((resolve, reject) => {
if (!this.redisClient) {
reject(`Redis client was not initialized.`);
return;
}
this.redisClient.del(key, (err) => {
if (err) {
reject(err);
return;
}
resolve();
});
});
}
public async remember(key: string, value: string, ttl: number): Promise<void> {
return await new Promise<void>((resolve, reject) => {
if (!this.redisClient) {
reject(`Redis client was not initialized.`);
return;
}
this.redisClient.psetex(key, ttl, value, (err) => {
if (err) return reject(err);
resolve();
});
});
}
public async persist(key: string, ttl: number): Promise<void> {
return await new Promise<void>((resolve, reject) => {
if (!this.redisClient) {
reject(`Redis client was not initialized.`);
return;
}
this.redisClient.pexpire(key, ttl, (err) => {
if (err) return reject(err);
resolve();
});
});
}
}
class RedisStore extends Store {
private readonly redisComponent: RedisComponent;
public constructor(redisComponent: RedisComponent) {
super();
this.redisComponent = redisComponent;
}
public get(sid: string, callback: (err?: Error, session?: (session.SessionData | null)) => void): void {
this.redisComponent.get(`-session:${sid}`)
.then(value => {
if (value) {
this.redisComponent.persist(`-session:${sid}`, config.get<number>('session.cookie.maxAge'))
.then(() => {
callback(undefined, JSON.parse(value));
})
.catch(callback);
} else {
callback(undefined, null);
}
})
.catch(callback);
}
public set(sid: string, session: session.SessionData, callback?: (err?: Error) => void): void {
this.redisComponent.remember(`-session:${sid}`, JSON.stringify(session), config.get<number>('session.cookie.maxAge'))
.then(() => {
if (callback) callback();
})
.catch(callback);
}
public destroy(sid: string, callback?: (err?: Error) => void): void {
this.redisComponent.forget(`-session:${sid}`)
.then(() => {
if (callback) callback();
})
.catch(callback);
}
} }

View File

@ -1,26 +1,24 @@
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import express, {Express, Router} from "express"; import express, {Router} from "express";
import {PathParams} from "express-serve-static-core"; import {PathParams} from "express-serve-static-core";
import * as path from "path";
export default class ServeStaticDirectoryComponent extends ApplicationComponent<void> { export default class ServeStaticDirectoryComponent extends ApplicationComponent {
private readonly root: string; private readonly root: string;
private readonly path?: PathParams; private readonly path?: PathParams;
constructor(root: string, routePath?: PathParams) { public constructor(root: string, routePath?: PathParams) {
super(); super();
this.root = root; this.root = path.join(__dirname, '../../../', root);
this.path = routePath; this.path = routePath;
} }
public async start(app: Express, router: Router): Promise<void> { public async init(router: Router): Promise<void> {
if (typeof this.path !== 'undefined') { if (this.path) {
router.use(this.path, express.static(this.root, {maxAge: 1000 * 3600 * 72})); router.use(this.path, express.static(this.root, {maxAge: 1000 * 3600 * 72}));
} else { } else {
router.use(express.static(this.root, {maxAge: 1000 * 3600 * 72})); router.use(express.static(this.root, {maxAge: 1000 * 3600 * 72}));
} }
} }
public async stop(): Promise<void> {
}
} }

View File

@ -3,23 +3,30 @@ import session from "express-session";
import config from "config"; import config from "config";
import RedisComponent from "./RedisComponent"; import RedisComponent from "./RedisComponent";
import flash from "connect-flash"; import flash from "connect-flash";
import {Express, Router} from "express"; import {Router} from "express";
import SecurityError from "../SecurityError";
export default class SessionComponent extends ApplicationComponent<void> { export default class SessionComponent extends ApplicationComponent {
private readonly storeComponent: RedisComponent; private readonly storeComponent: RedisComponent;
public constructor(storeComponent: RedisComponent) { public constructor(storeComponent: RedisComponent) {
super(); super();
this.storeComponent = storeComponent; this.storeComponent = storeComponent;
} }
public async start(app: Express, router: Router): Promise<void> { public async checkSecuritySettings(): Promise<void> {
this.checkSecurityConfigField('session.secret');
if (!config.get<boolean>('session.cookie.secure')) {
throw new SecurityError('Cannot set cookie secure field to false.');
}
}
public async init(router: Router): Promise<void> {
router.use(session({ router.use(session({
saveUninitialized: true, saveUninitialized: true,
secret: config.get('session.secret'), secret: config.get('session.secret'),
store: this.storeComponent.getStore(), store: this.storeComponent.getStore(),
resave: true, resave: false,
cookie: { cookie: {
httpOnly: true, httpOnly: true,
secure: config.get('session.cookie.secure'), secure: config.get('session.cookie.secure'),
@ -30,28 +37,61 @@ export default class SessionComponent extends ApplicationComponent<void> {
router.use(flash()); router.use(flash());
router.use((req, res, next) => { router.use((req, res, next) => {
if (!req.session) { // Request session getters
throw new Error('Session is unavailable.'); req.getSessionOptional = () => {
return req.session;
};
req.getSession = () => {
const session = req.getSessionOptional();
if (!session) throw new Error('Session not initialized.');
return session;
};
// Session persistence
const session = req.getSession();
if (session.persistent) {
session.cookie.maxAge = config.get('session.cookie.maxAge');
} else {
session.cookie.maxAge = session.cookie.expires = undefined;
} }
res.locals.session = req.session; // Views session local
res.locals.session = session;
let _flash: any = null; // Views flash function
res.locals.flash = () => { const _flash: FlashStorage = {};
if (!_flash) { res.locals.flash = (key?: string): FlashMessages | unknown[] => {
_flash = { if (key !== undefined) {
if (_flash[key] === undefined) _flash[key] = req.flash(key);
return _flash[key] || [];
}
if (_flash._messages === undefined) {
_flash._messages = {
info: req.flash('info'), info: req.flash('info'),
success: req.flash('success'), success: req.flash('success'),
warning: req.flash('warning'), warning: req.flash('warning'),
error: req.flash('error'), error: req.flash('error'),
}; };
} }
return _flash; return _flash._messages;
}; };
next(); next();
}); });
} }
public async stop(): Promise<void> {
}
} }
export type FlashMessages = {
[k: string]: unknown[] | undefined
};
export type DefaultFlashMessages = FlashMessages & {
info?: unknown[] | undefined;
success?: unknown[] | undefined;
warning?: unknown[] | undefined;
error?: unknown[] | undefined;
};
type FlashStorage = FlashMessages & {
_messages?: DefaultFlashMessages,
};

View File

@ -1,7 +1,7 @@
import ApplicationComponent from "../ApplicationComponent"; import ApplicationComponent from "../ApplicationComponent";
import {Express, Request, Router} from "express"; import {Express, Request} from "express";
import WebSocket, {Server as WebSocketServer} from "ws"; import WebSocket, {Server as WebSocketServer} from "ws";
import Logger from "../Logger"; import {logger} from "../Logger";
import cookie from "cookie"; import cookie from "cookie";
import cookieParser from "cookie-parser"; import cookieParser from "cookie-parser";
import config from "config"; import config from "config";
@ -9,29 +9,28 @@ import ExpressAppComponent from "./ExpressAppComponent";
import Application from "../Application"; import Application from "../Application";
import RedisComponent from "./RedisComponent"; import RedisComponent from "./RedisComponent";
import WebSocketListener from "../WebSocketListener"; import WebSocketListener from "../WebSocketListener";
import NunjucksComponent from "./NunjucksComponent";
export default class WebSocketServerComponent extends ApplicationComponent<void> { export default class WebSocketServerComponent extends ApplicationComponent {
private readonly application: Application;
private readonly expressAppComponent: ExpressAppComponent;
private readonly storeComponent: RedisComponent;
private wss?: WebSocket.Server; private wss?: WebSocket.Server;
constructor(application: Application, expressAppComponent: ExpressAppComponent, storeComponent: RedisComponent) { public constructor(
private readonly application: Application,
private readonly expressAppComponent: ExpressAppComponent,
private readonly storeComponent: RedisComponent,
private readonly nunjucksComponent?: NunjucksComponent,
) {
super(); super();
this.expressAppComponent = expressAppComponent;
this.application = application;
this.storeComponent = storeComponent;
} }
public async start(app: Express, router: Router): Promise<void> { public async start(_app: Express): Promise<void> {
const listeners: { [p: string]: WebSocketListener } = this.application.getWebSocketListeners(); const listeners: { [p: string]: WebSocketListener<Application> } = this.application.getWebSocketListeners();
this.wss = new WebSocketServer({ this.wss = new WebSocketServer({
server: this.expressAppComponent.getServer(), server: this.expressAppComponent.getServer(),
}, () => { }, () => {
Logger.info(`Websocket server started over webserver.`); logger.info(`Websocket server started over webserver.`);
}).on('error', (err) => { }).on('error', (err) => {
Logger.error(err, 'An error occurred in the websocket server.'); logger.error(err, 'An error occurred in the websocket server.');
}).on('connection', (socket, request) => { }).on('connection', (socket, request) => {
const listener = request.url ? listeners[request.url] : null; const listener = request.url ? listeners[request.url] : null;
@ -39,11 +38,13 @@ export default class WebSocketServerComponent extends ApplicationComponent<void>
socket.close(1002, `Path not found ${request.url}`); socket.close(1002, `Path not found ${request.url}`);
return; return;
} else if (!request.headers.cookie) { } else if (!request.headers.cookie) {
socket.close(1002, `Can't process request without cookies.`); listener.handle(socket, request, null).catch(err => {
logger.error(err, 'Error in websocket listener.');
});
return; return;
} }
Logger.debug(`Websocket on ${request.url}`); logger.debug(`Websocket on ${request.url}`);
const cookies = cookie.parse(request.headers.cookie); const cookies = cookie.parse(request.headers.cookie);
const sid = cookieParser.signedCookie(cookies['connect.sid'], config.get('session.secret')); const sid = cookieParser.signedCookie(cookies['connect.sid'], config.get('session.secret'));
@ -56,7 +57,7 @@ export default class WebSocketServerComponent extends ApplicationComponent<void>
const store = this.storeComponent.getStore(); const store = this.storeComponent.getStore();
store.get(sid, (err, session) => { store.get(sid, (err, session) => {
if (err || !session) { if (err || !session) {
Logger.error(err, 'Error while initializing session in websocket.'); logger.error(err, 'Error while initializing session in websocket.');
socket.close(1011); socket.close(1011);
return; return;
} }
@ -64,16 +65,22 @@ export default class WebSocketServerComponent extends ApplicationComponent<void>
session.id = sid; session.id = sid;
store.createSession(<Request>request, session); store.createSession(<Request>request, session);
listener.handle(socket, request, session).catch(err => { listener.handle(socket, request, (<Request>request).session).catch(err => {
Logger.error(err, 'Error in websocket listener.'); logger.error(err, 'Error in websocket listener.');
}); });
}); });
}); });
const env = this.nunjucksComponent?.getEnvironment();
if (env) {
env.addGlobal('websocketUrl', config.get('public_websocket_url'));
}
} }
public async stop(): Promise<void> { public async stop(): Promise<void> {
if (this.wss) { const wss = this.wss;
await this.close('WebSocket server', this.wss, this.wss.close); if (wss) {
await this.close('WebSocket server', callback => wss.close(callback));
} }
} }
} }

View File

@ -1,15 +1,39 @@
import {Connection} from "mysql";
import MysqlConnectionManager from "./MysqlConnectionManager";
import {Type} from "../Utils";
export default abstract class Migration { export default abstract class Migration {
public readonly version: number; public readonly version: number;
private currentConnection?: Connection;
constructor(version: number) { public constructor(version: number) {
this.version = version; this.version = version;
} }
async shouldRun(currentVersion: number): Promise<boolean> { public async shouldRun(currentVersion: number): Promise<boolean> {
return this.version > currentVersion; return this.version > currentVersion;
} }
abstract async install(): Promise<void>; public abstract install(): Promise<void>;
abstract async rollback(): Promise<void>; public abstract rollback(): Promise<void>;
public registerModels?(): void;
protected async query(queryString: string): Promise<void> {
await MysqlConnectionManager.query(queryString, undefined, this.getCurrentConnection());
}
protected getCurrentConnection(): Connection {
if (!this.currentConnection) throw new Error('No current connection set.');
return this.currentConnection;
}
public setCurrentConnection(connection: Connection | null): void {
this.currentConnection = connection || undefined;
}
}
export interface MigrationType<M extends Migration> extends Type<M> {
new(version: number): M;
} }

View File

@ -1,311 +1,245 @@
import MysqlConnectionManager, {query} from "./MysqlConnectionManager"; import MysqlConnectionManager from "./MysqlConnectionManager";
import Validator from "./Validator"; import Validator from "./Validator";
import {Connection} from "mysql"; import {Connection} from "mysql";
import Query from "./Query"; import ModelComponent from "./ModelComponent";
import {Type} from "../Utils";
import ModelFactory, {PrimaryKeyValue} from "./ModelFactory";
import ModelRelation from "./ModelRelation";
import ModelQuery, {ModelFieldData, ModelQueryResult, QueryFields} from "./ModelQuery";
import {Request} from "express"; import {Request} from "express";
import Pagination from "../Pagination"; import Extendable from "../Extendable";
export default abstract class Model { export default abstract class Model implements Extendable<ModelComponent<Model>> {
public static async getById<T extends Model>(id: number): Promise<T | null> { public static get table(): string {
const cachedModel = ModelCache.get(this.table, id); const single = this.name
if (cachedModel?.constructor === this) { .replace(/(?:^|\.?)([A-Z])/g, (x, y) => '_' + y.toLowerCase())
return <T>cachedModel; .replace(/^_/, '');
} return single + 's';
const models = await this.models<T>(this.select().where('id', id).first());
return models.length > 0 ? models[0] : null;
} }
public static async paginate<T extends Model>(request: Request, perPage: number = 20): Promise<T[]> { public static getPrimaryKeyFields(): string[] {
let page = request.params.page ? parseInt(request.params.page) : 1; return ['id'];
let query: Query = this.select().limit(perPage, (page - 1) * perPage).withTotalRowCount();
if (request.params.sortBy) {
const dir = request.params.sortDirection;
query = query.sortBy(request.params.sortBy, dir === 'ASC' || dir === 'DESC' ? dir : undefined);
} else {
query = query.sortBy('id');
}
const models = await this.models<T>(query);
// @ts-ignore
models.pagination = new Pagination(models, page, perPage, models.totalCount);
return models;
} }
protected static select(...fields: string[]): Query { public static create<M extends Model>(this: ModelType<M>, data: Pick<M, keyof M>): M {
return Query.select(this.table, ...fields); return ModelFactory.get(this).create(data, true);
} }
protected static update(data: { [key: string]: any }): Query { public static select<M extends Model>(this: ModelType<M>, ...fields: QueryFields): ModelQuery<M> {
return Query.update(this.table, data); return ModelFactory.get(this).select(...fields);
} }
protected static delete(): Query { public static update<M extends Model>(this: ModelType<M>, data: Pick<M, keyof M>): ModelQuery<M> {
return Query.delete(this.table); return ModelFactory.get(this).update(data);
} }
protected static async models<T extends Model>(query: Query): Promise<T[]> { public static delete<M extends Model>(this: ModelType<M>): ModelQuery<M> {
const results = await query.execute(); return ModelFactory.get(this).delete();
const models: T[] = []; }
const factory = this.getFactory<T>();
for (const result of results.results) { public static async getById<M extends Model>(this: ModelType<M>, ...id: PrimaryKeyValue[]): Promise<M | null> {
const cachedModel = ModelCache.get(this.table, result.id); return await ModelFactory.get(this).getById(...id);
if (cachedModel && cachedModel.constructor === this) { }
cachedModel.updateWithData(result);
models.push(<T>cachedModel); public static async paginate<M extends Model>(
} else { this: ModelType<M>,
models.push(factory(result)); request: Request,
perPage: number = 20,
query?: ModelQuery<M>,
): Promise<ModelQueryResult<M>> {
return await ModelFactory.get(this).paginate(request, perPage, query);
}
protected readonly _factory: ModelFactory<Model>;
private readonly _components: ModelComponent<this>[] = [];
private readonly _validators: { [K in keyof this]?: Validator<this[K]> | undefined } = {};
private _exists: boolean;
[key: string]: ModelFieldData;
public constructor(factory: ModelFactory<never>, isNew: boolean) {
if (!(factory instanceof ModelFactory)) throw new Error('Cannot instantiate model directly.');
this._factory = factory;
this.init?.();
this._exists = !isNew;
}
protected init?(): void;
protected setValidation<K extends keyof this>(propertyName: K): Validator<this[K]> {
const validator = new Validator<this[K]>();
this._validators[propertyName] = validator;
return validator;
}
public addComponent(modelComponent: ModelComponent<this>): void {
modelComponent.applyToModel();
this._components.push(modelComponent);
}
public as<C extends ModelComponent<Model>>(type: Type<C>): C {
for (const component of this._components) {
if (component instanceof type) {
return this as unknown as C;
} }
} }
// @ts-ignore
models.totalCount = results.foundRows; throw new Error(`Component ${type.name} was not initialized for this ${this.constructor.name}.`);
return models;
} }
public static async loadRelation<T extends Model>(models: T[], relation: string, model: Function, localField: string) { public asOptional<C extends ModelComponent<Model>>(type: Type<C>): C | null {
const loadMap: { [p: number]: (model: T) => void } = {}; for (const component of this._components) {
const ids = models.map(m => { if (component instanceof type) {
m.relations[relation] = null; return this as unknown as C;
if (m[localField]) loadMap[m[localField]] = v => m.relations[relation] = v; }
return m[localField];
}).filter(id => id);
for (const v of await (<any>model).models((<any>model).select().whereIn('id', ids))) {
loadMap[v.id!](v);
}
}
private static getFactory<T extends Model>(factory?: ModelFactory<T>): ModelFactory<T> {
if (factory === undefined) {
factory = (<any>this).FACTORY;
if (factory === undefined) factory = data => new (<any>this)(data);
}
return factory;
}
protected readonly properties: ModelProperty<any>[] = [];
private readonly relations: { [p: string]: (Model | null) } = {};
public id?: number;
[key: string]: any;
public constructor(data: any) {
this.defineProperty<number>('id', new Validator());
this.defineProperties();
this.updateWithData(data);
}
protected abstract defineProperties(): void;
protected defineProperty<T>(name: string, validator?: Validator<T> | RegExp) {
if (validator === undefined) validator = new Validator();
if (validator instanceof RegExp) {
const regexp = validator;
validator = new Validator().regexp(regexp);
} }
const prop = new ModelProperty<T>(name, validator); return null;
this.properties.push(prop);
Object.defineProperty(this, name, {
get: () => prop.value,
set: (value: T) => prop.value = value,
});
} }
private updateWithData(data: any) { public updateWithData(data: Pick<this, keyof this> | Record<string, unknown>): void {
this.id = data['id']; for (const property of this._properties) {
if (data[property] !== undefined) {
for (const prop of this.properties) { this[property] = data[property] as this[keyof this & string];
if (data[prop.name] !== undefined) {
this[prop.name] = data[prop.name];
} }
} }
} }
protected async beforeSave(exists: boolean, connection: Connection): Promise<void> { /**
} * Override this to automatically fill obvious missing data i.e. from relation or default value that are fetched
* asynchronously.
*/
protected async autoFill?(): Promise<void>;
protected async afterSave(): Promise<void> { protected async beforeSave?(connection: Connection): Promise<void>;
}
protected async afterSave?(): Promise<void>;
public async save(connection?: Connection, postHook?: (callback: () => Promise<void>) => void): Promise<void> { public async save(connection?: Connection, postHook?: (callback: () => Promise<void>) => void): Promise<void> {
if (connection && !postHook) throw new Error('If connection is provided, postHook must be provided too.');
await this.autoFill?.();
await this.validate(false, connection); await this.validate(false, connection);
const exists = await this.exists(); const needs_full_update = connection ?
let needs_full_update = false; await this.saveTransaction(connection) :
await MysqlConnectionManager.wrapTransaction(async connection => await this.saveTransaction(connection));
if (connection) {
needs_full_update = await this.saveTransaction(connection, exists, needs_full_update);
} else {
needs_full_update = await MysqlConnectionManager.wrapTransaction(async connection => this.saveTransaction(connection, exists, needs_full_update));
}
const callback = async () => { const callback = async () => {
if (needs_full_update) { if (needs_full_update) {
this.updateWithData((await (<Model><unknown>this.constructor).select().where('id', this.id!).first().execute()).results[0]); const query = this._factory.select();
for (const field of this._factory.getPrimaryKeyFields()) {
query.where(field, this[field]);
}
query.limit(1);
const result = await query.execute(connection);
this.updateWithData(result.results[0]);
} }
if (!exists) { await this.afterSave?.();
this.cache();
}
await this.afterSave();
}; };
if (connection) { if (postHook) {
postHook!(callback); postHook(callback);
} else { } else {
await callback(); await callback();
} }
} }
private async saveTransaction(connection: Connection, exists: boolean, needs_full_update: boolean): Promise<boolean> { private async saveTransaction(connection: Connection): Promise<boolean> {
// Before save // Before save
await this.beforeSave(exists, connection); await this.beforeSave?.(connection);
if (exists && this.hasOwnProperty('updated_at')) { if (!this.exists() && this.hasProperty('created_at')) {
this.created_at = new Date();
}
if (this.exists() && this.hasProperty('updated_at')) {
this.updated_at = new Date(); this.updated_at = new Date();
} }
const props = []; let needsFullUpdate = false;
const values = [];
if (exists) { const data: { [K in keyof this]?: this[K] } = {};
for (const prop of this.properties) { for (const property of this._properties) {
if (prop.value !== undefined) { const value = this[property];
props.push(prop.name + '=?');
values.push(prop.value);
} else {
needs_full_update = true;
}
}
values.push(this.id);
await query(`UPDATE ${this.table} SET ${props.join(',')} WHERE id=?`, values, connection);
} else {
const props_holders = [];
for (const prop of this.properties) {
if (prop.value !== undefined) {
props.push(prop.name);
props_holders.push('?');
values.push(prop.value);
} else {
needs_full_update = true;
}
}
const result = await query(`INSERT INTO ${this.table} (${props.join(', ')}) VALUES(${props_holders.join(', ')})`, values, connection);
this.id = result.other.insertId; if (value === undefined) needsFullUpdate = true;
else data[property] = value;
} }
return needs_full_update; if (this.exists()) {
} const query = this._factory.update(data);
for (const indexField of this._factory.getPrimaryKeyFields()) {
query.where(indexField, this[indexField]);
}
await query.execute(connection);
} else {
const query = this._factory.insert(data);
const result = await query.execute(connection);
public static get table(): string { if (this.hasProperty('id')) this.id = Number(result.other?.insertId);
return this.name this._exists = true;
.replace(/(?:^|\.?)([A-Z])/g, (x, y) => '_' + y.toLowerCase()) }
.replace(/^_/, '')
+ 's';
}
public get table(): string { return needsFullUpdate;
// @ts-ignore
return this.constructor.table;
}
public async exists(): Promise<boolean> {
if (!this.id) return false;
const result = await query(`SELECT 1 FROM ${this.table} WHERE id=? LIMIT 1`, [
this.id,
]);
return result.results.length > 0;
} }
public async delete(): Promise<void> { public async delete(): Promise<void> {
if (!(await this.exists())) throw new Error('This model instance doesn\'t exist in DB.'); if (!await this.exists()) throw new Error('This model instance doesn\'t exist in DB.');
await query(`DELETE FROM ${this.table} WHERE id=?`, [ const query = this._factory.delete();
this.id, for (const indexField of this._factory.getPrimaryKeyFields()) {
]); query.where(indexField, this[indexField]);
ModelCache.forget(this); }
this.id = undefined; await query.execute();
this._exists = false;
} }
public async validate(onlyFormat: boolean = false, connection?: Connection): Promise<void[]> { public async validate(onlyFormat: boolean = false, connection?: Connection): Promise<void[]> {
return await Promise.all(this.properties.map(prop => prop.validate(onlyFormat, connection))); return await Promise.all(this._properties.map(
prop => this._validators[prop]?.execute(prop, this[prop], onlyFormat, connection),
));
} }
private cache() { public exists(): boolean {
ModelCache.cache(this); return this._exists;
} }
protected relation<T extends Model>(name: string): T | null { public equals(model: this): boolean {
if (this.relations[name] === undefined) throw new Error('Model not loaded'); for (const field of this._factory.getPrimaryKeyFields()) {
return <T | null>this.relations[name]; if (this[field] !== model[field]) return false;
}
}
export interface ModelFactory<T extends Model> {
(data: any): T;
}
class ModelProperty<T> {
public readonly name: string;
private readonly validator: Validator<T>;
private val?: T;
constructor(name: string, validator: Validator<T>) {
this.name = name;
this.validator = validator;
}
public async validate(onlyFormat: boolean, connection?: Connection): Promise<void> {
return await this.validator.execute(this.name, this.value, onlyFormat, connection);
}
public get value(): T | undefined {
return this.val;
}
public set value(val: T | undefined) {
this.val = val;
}
}
export class ModelCache {
private static readonly caches: {
[key: string]: {
[key: number]: Model
} }
} = {}; return true;
public static cache(instance: Model) {
if (instance.id === undefined) throw new Error('Cannot cache an instance with an undefined id.');
let tableCache = this.caches[instance.table];
if (!tableCache) tableCache = this.caches[instance.table] = {};
if (!tableCache[instance.id]) tableCache[instance.id] = instance;
} }
public static forget(instance: Model) { public get table(): string {
if (instance.id === undefined) throw new Error('Cannot forget an instance with an undefined id.'); return this._factory.table;
let tableCache = this.caches[instance.table];
if (!tableCache) return;
if (tableCache[instance.id]) delete tableCache[instance.id];
} }
public static all(table: string): { private get _properties(): (keyof this & string)[] {
[key: number]: Model return Object.getOwnPropertyNames(this).filter(p => {
} | undefined { return !p.startsWith('_') &&
return this.caches[table]; typeof this[p] !== 'function' &&
!(this[p] instanceof ModelRelation);
});
} }
public static get(table: string, id: number): Model | undefined { private hasProperty(key: string | number | symbol): key is keyof this {
const tableCache = this.all(table); return typeof key === 'string' && this._properties.indexOf(key) >= 0;
if (!tableCache) return undefined; }
return tableCache[id];
public getOrFail<K extends keyof this & string>(k: K): NonNullable<this[K]> {
if (!this[k]) throw new Error(k + ' not initialized.');
return this[k] as NonNullable<this[K]>;
} }
} }
export const EMAIL_REGEX = /^[a-zA-Z0-9.!#$%&'*+\\/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+$/; export interface ModelType<M extends Model> extends Type<M> {
table: string;
new(factory: ModelFactory<never>, isNew: boolean): M;
getPrimaryKeyFields(): (keyof M & string)[];
select<M extends Model>(this: ModelType<M>, ...fields: QueryFields): ModelQuery<M>;
}

48
src/db/ModelComponent.ts Normal file
View File

@ -0,0 +1,48 @@
import Model from "./Model";
import Validator from "./Validator";
import {getMethods} from "../Utils";
import {ModelFieldData} from "./ModelQuery";
export default abstract class ModelComponent<M extends Model> {
protected readonly _model: M;
private readonly _validators: { [K in keyof this]?: Validator<this[K]> } = {};
[key: string]: ModelFieldData;
public constructor(model: M) {
this._model = model;
}
public applyToModel(): void {
this.init?.();
const model = this._model as Model;
for (const property of this._properties) {
if (!property.startsWith('_')) {
model[property] = this[property];
model['_validators'][property] = this._validators[property] as Validator<ModelFieldData> | undefined;
}
}
for (const method of getMethods(this)) {
if (!method.startsWith('_') &&
['init', 'setValidation'].indexOf(method) < 0 &&
model[method] === undefined) {
model[method] = this[method];
}
}
}
protected init?(): void;
protected setValidation<K extends keyof this>(propertyName: K): Validator<this[K]> {
const validator = new Validator<this[K]>();
this._validators[propertyName] = validator;
return validator;
}
private get _properties(): string[] {
return Object.getOwnPropertyNames(this).filter(p => !p.startsWith('_'));
}
}

104
src/db/ModelFactory.ts Normal file
View File

@ -0,0 +1,104 @@
import ModelComponent from "./ModelComponent";
import Model, {ModelType} from "./Model";
import ModelQuery, {ModelQueryResult, QueryFields} from "./ModelQuery";
import {Request} from "express";
export default class ModelFactory<M extends Model> {
private static readonly factories: { [modelType: string]: ModelFactory<Model> | undefined } = {};
public static register<M extends Model>(modelType: ModelType<M>): void {
if (this.factories[modelType.name]) throw new Error(`Factory for type ${modelType.name} already defined.`);
this.factories[modelType.name] = new ModelFactory<M>(modelType) as unknown as ModelFactory<Model>;
}
public static get<M extends Model>(modelType: ModelType<M>): ModelFactory<M> {
const factory = this.factories[modelType.name];
if (!factory) throw new Error(`No factory registered for ${modelType.name}.`);
return factory as unknown as ModelFactory<M>;
}
public static has<M extends Model>(modelType: ModelType<M>): boolean {
return !!this.factories[modelType.name];
}
private readonly modelType: ModelType<M>;
private readonly components: ModelComponentFactory<M>[] = [];
protected constructor(modelType: ModelType<M>) {
this.modelType = modelType;
}
public addComponent(modelComponentFactory: ModelComponentFactory<M>): void {
this.components.push(modelComponentFactory);
}
public hasComponent(modelComponentFactory: ModelComponentFactory<M>): boolean {
return !!this.components.find(c => c === modelComponentFactory);
}
public create(data: Pick<M, keyof M>, isNewModel: boolean): M {
const model = new this.modelType(this as unknown as ModelFactory<never>, isNewModel);
for (const component of this.components) {
model.addComponent(new component(model));
}
model.updateWithData(data);
return model;
}
public get table(): string {
return this.modelType.table;
}
public select(...fields: QueryFields): ModelQuery<M> {
return ModelQuery.select(this, ...fields);
}
public insert(data: Pick<M, keyof M>): ModelQuery<M> {
return ModelQuery.insert(this, data);
}
public update(data: Pick<M, keyof M>): ModelQuery<M> {
return ModelQuery.update(this, data);
}
public delete(): ModelQuery<M> {
return ModelQuery.delete(this);
}
public getPrimaryKeyFields(): (keyof M & string)[] {
return this.modelType.getPrimaryKeyFields();
}
public getPrimaryKey(modelData: Pick<M, keyof M>): Pick<M, keyof M>[keyof M & string][] {
return this.getPrimaryKeyFields().map(f => modelData[f]);
}
public getPrimaryKeyString(modelData: Pick<M, keyof M>): string {
return this.getPrimaryKey(modelData).join(',');
}
public async getById(...id: PrimaryKeyValue[]): Promise<M | null> {
let query = this.select();
const primaryKeyFields = this.getPrimaryKeyFields();
for (let i = 0; i < primaryKeyFields.length; i++) {
query = query.where(primaryKeyFields[i], id[i]);
}
return await query.first();
}
public async paginate(request: Request, perPage: number = 20, query?: ModelQuery<M>): Promise<ModelQueryResult<M>> {
const page = request.params.page ? parseInt(request.params.page) : 1;
if (!query) query = this.select();
if (request.params.sortBy) {
const dir = request.params.sortDirection;
query = query.sortBy(request.params.sortBy, dir === 'ASC' || dir === 'DESC' ? dir : undefined);
} else {
query = query.sortBy('id');
}
return await query.paginate(page, perPage);
}
}
export type ModelComponentFactory<M extends Model> = new (model: M) => ModelComponent<M>;
export type PrimaryKeyValue = string | number | boolean | null | undefined;

557
src/db/ModelQuery.ts Normal file
View File

@ -0,0 +1,557 @@
import {isQueryVariable, query, QueryResult, QueryVariable} from "./MysqlConnectionManager";
import {Connection} from "mysql";
import Model from "./Model";
import Pagination from "../Pagination";
import ModelRelation, {RelationDatabaseProperties} from "./ModelRelation";
import ModelFactory from "./ModelFactory";
export default class ModelQuery<M extends Model> implements WhereFieldConsumer<M> {
public static select<M extends Model>(factory: ModelFactory<M>, ...fields: QueryFields): ModelQuery<M> {
fields = fields.map(v => v === '' ? new SelectFieldValue('none', 1, true) : v);
return new ModelQuery(QueryType.SELECT, factory, fields.length > 0 ? fields : ['*']);
}
public static insert<M extends Model>(factory: ModelFactory<M>, data: Pick<M, keyof M>): ModelQuery<M> {
const fields = [];
for (const key of Object.keys(data)) {
fields.push(new FieldValue(key, data[key], false));
}
return new ModelQuery(QueryType.INSERT, factory, fields);
}
public static update<M extends Model>(factory: ModelFactory<M>, data: Pick<M, keyof M>): ModelQuery<M> {
const fields = [];
for (const key of Object.keys(data)) {
fields.push(new FieldValue(inputToFieldOrValue(key, factory.table), data[key], false));
}
return new ModelQuery(QueryType.UPDATE, factory, fields);
}
public static delete<M extends Model>(factory: ModelFactory<M>): ModelQuery<M> {
return new ModelQuery(QueryType.DELETE, factory);
}
private readonly type: QueryType;
private readonly factory: ModelFactory<M>;
private readonly table: string;
private readonly fields: QueryFields;
private _leftJoin?: string;
private _leftJoinAlias?: string;
private _leftJoinOn: WhereFieldValue[] = [];
private _where: (WhereFieldValue | WhereFieldValueGroup)[] = [];
private _limit?: number;
private _offset?: number;
private _sortBy?: string;
private _sortDirection?: 'ASC' | 'DESC';
private readonly relations: string[] = [];
private readonly subRelations: { [relation: string]: string[] | undefined } = {};
private _pivot?: string[];
private _union?: ModelQueryUnion;
private _recursiveRelation?: RelationDatabaseProperties;
private _reverseRecursiveRelation?: boolean;
private constructor(type: QueryType, factory: ModelFactory<M>, fields?: QueryFields) {
this.type = type;
this.factory = factory;
this.table = factory.table;
this.fields = fields || [];
}
public leftJoin(table: string, alias?: string): this {
this._leftJoin = table;
this._leftJoinAlias = alias;
return this;
}
public on(
field1: string,
field2: string,
test: WhereTest = WhereTest.EQ,
operator: WhereOperator = WhereOperator.AND,
): this {
this._leftJoinOn.push(new WhereFieldValue(
inputToFieldOrValue(field1), inputToFieldOrValue(field2), true, test, operator,
));
return this;
}
public where(
field: string,
value: ModelFieldData,
test: WhereTest = WhereTest.EQ,
operator: WhereOperator = WhereOperator.AND,
): this {
this._where.push(new WhereFieldValue(field, value, false, test, operator));
return this;
}
public groupWhere(
setter: (query: WhereFieldConsumer<M>) => void,
operator: WhereOperator = WhereOperator.AND,
): this {
this._where.push(new WhereFieldValueGroup(this.collectWheres(setter), operator));
return this;
}
private collectWheres(setter: (query: WhereFieldConsumer<M>) => void): (WhereFieldValue | WhereFieldValueGroup)[] {
// eslint-disable-next-line @typescript-eslint/no-this-alias
const query = this;
const wheres: (WhereFieldValue | WhereFieldValueGroup)[] = [];
setter({
where(
field: string,
value: ModelFieldData,
test: WhereTest = WhereTest.EQ,
operator: WhereOperator = WhereOperator.AND,
) {
wheres.push(new WhereFieldValue(field, value, false, test, operator));
return this;
},
groupWhere(
setter: (query: WhereFieldConsumer<M>) => void,
operator: WhereOperator = WhereOperator.AND,
) {
wheres.push(new WhereFieldValueGroup(query.collectWheres(setter), operator));
return this;
},
});
return wheres;
}
public limit(limit: number, offset: number = 0): this {
this._limit = limit;
this._offset = offset;
return this;
}
public sortBy(field: string, direction: SortDirection = 'ASC', raw: boolean = false): this {
this._sortBy = raw ? field : inputToFieldOrValue(field);
this._sortDirection = direction;
return this;
}
/**
* @param relations The relations field names to eagerload. To load nested relations, separate fields with '.'
* (i.e.: "author.roles.permissions" loads authors, their roles, and the permissions of these roles)
*/
public with(...relations: string[]): this {
relations.forEach(relation => {
const parts = relation.split('.');
if (this.relations.indexOf(parts[0]) < 0) this.relations.push(parts[0]);
if (parts.length > 1) {
if (!this.subRelations[parts[0]]) this.subRelations[parts[0]] = [];
this.subRelations[parts[0]]?.push(parts.slice(1).join('.'));
}
});
return this;
}
public pivot(...fields: string[]): this {
this._pivot = fields;
return this;
}
public union(
query: ModelQuery<Model>,
sortBy: string, direction: SortDirection = 'ASC',
raw: boolean = false,
limit?: number,
offset?: number,
): this {
if (this.type !== QueryType.SELECT) throw new Error('Union queries are only implemented with SELECT.');
this._union = {
query: query,
sortBy: raw ? sortBy : inputToFieldOrValue(sortBy),
direction: direction,
limit: limit,
offset: offset,
};
return this;
}
public recursive(relation: RelationDatabaseProperties, reverse: boolean): this {
if (this.type !== QueryType.SELECT) throw new Error('Recursive queries are only implemented with SELECT.');
this._recursiveRelation = relation;
this._reverseRecursiveRelation = reverse;
return this;
}
public toString(final: boolean = false): string {
let query = '';
if (this._pivot) this.fields.push(...this._pivot);
// Prevent wildcard and fields from conflicting
const fields = this.fields.map(f => {
const field = f.toString();
if (field.startsWith('(')) return f; // Skip sub-queries
return inputToFieldOrValue(field, this.table);
}).join(',');
let join = '';
if (this._leftJoin) {
join = ` LEFT JOIN \`${this._leftJoin}\``
+ (this._leftJoinAlias ? ` AS \`${this._leftJoinAlias}\`` : '')
+ ` ON ${this._leftJoinOn[0]}`;
for (let i = 1; i < this._leftJoinOn.length; i++) {
join += this._leftJoinOn[i].toString(false);
}
}
let where = '';
if (this._where.length > 0) {
where = ` WHERE ${this._where[0]}`;
for (let i = 1; i < this._where.length; i++) {
where += this._where[i].toString(false);
}
}
let limit = '';
if (typeof this._limit === 'number') {
limit = ` LIMIT ${this._limit}`;
if (typeof this._offset === 'number' && this._offset !== 0) {
limit += ` OFFSET ${this._offset}`;
}
}
let orderBy = '';
if (typeof this._sortBy === 'string') {
orderBy = ` ORDER BY ${this._sortBy} ${this._sortDirection}`;
}
const table = `\`${this.table}\``;
switch (this.type) {
case QueryType.SELECT:
if (this._recursiveRelation) {
const cteFields = fields.replace(RegExp(`${table}`, 'g'), 'o');
const idKey = this._reverseRecursiveRelation ?
this._recursiveRelation.foreignKey :
this._recursiveRelation.localKey;
const sortOrder = this._reverseRecursiveRelation ? 'DESC' : 'ASC';
query = `WITH RECURSIVE cte AS (`
+ `SELECT ${fields},1 AS __depth, CONCAT(\`${idKey}\`) AS __path FROM ${table}${where}`
+ ` UNION `
+ `SELECT ${cteFields},c.__depth + 1,CONCAT(c.__path,'/',o.\`${idKey}\`) AS __path FROM ${table} AS o, cte AS c WHERE o.\`${this._recursiveRelation.foreignKey}\`=c.\`${this._recursiveRelation.localKey}\``
+ `) SELECT * FROM cte${join}${orderBy || ` ORDER BY __path ${sortOrder}`}${limit}`;
} else {
query = `SELECT ${fields} FROM ${table}${join}${where}${orderBy}${limit}`;
}
if (this._union) {
const unionOrderBy = this._union.sortBy ? ` ORDER BY ${this._union.sortBy} ${this._union.direction}` : '';
const unionLimit = typeof this._union.limit === 'number' ? ` LIMIT ${this._union.limit}` : '';
const unionOffset = typeof this._union.offset === 'number' ? ` OFFSET ${this._union.offset}` : '';
query = `(${query}) UNION ${this._union.query.toString(false)}${unionOrderBy}${unionLimit}${unionOffset}`;
}
break;
case QueryType.INSERT: {
const insertFields = this.fields.filter(f => f instanceof FieldValue)
.map(f => f as FieldValue);
const insertFieldNames = insertFields.map(f => f.fieldName).join(',');
const insertFieldValues = insertFields.map(f => f.fieldValue).join(',');
query = `INSERT INTO ${table} (${insertFieldNames}) VALUES(${insertFieldValues})`;
break;
}
case QueryType.UPDATE:
query = `UPDATE ${table} SET ${fields}${where}${orderBy}${limit}`;
break;
case QueryType.DELETE:
query = `DELETE FROM ${table}${where}${orderBy}${limit}`;
break;
}
return final ? query : `(${query})`;
}
public build(): string {
return this.toString(true);
}
public get variables(): QueryVariable[] {
const variables: QueryVariable[] = [];
this.fields.filter(v => v instanceof FieldValue)
.flatMap(v => (v as FieldValue).variables)
.forEach(v => variables.push(v));
this._where.flatMap(v => this.getVariables(v))
.forEach(v => variables.push(v));
this._union?.query.variables.forEach(v => variables.push(v));
return variables;
}
private getVariables(where: WhereFieldValue | WhereFieldValueGroup): QueryVariable[] {
return where instanceof WhereFieldValueGroup ?
where.fields.flatMap(v => this.getVariables(v)) :
where.variables;
}
public async execute(connection?: Connection): Promise<QueryResult> {
return await query(this.build(), this.variables, connection);
}
public async get(connection?: Connection): Promise<ModelQueryResult<M>> {
const queryResult = await this.execute(connection);
const models: ModelQueryResult<M> = [];
models.originalData = [];
if (this._pivot) models.pivot = [];
// Eager loading init
const relationMap: { [p: string]: ModelRelation<Model, Model, Model | Model[] | null>[] } = {};
for (const relation of this.relations) {
relationMap[relation] = [];
}
for (const result of queryResult.results) {
const modelData: Record<string, unknown> = {};
for (const field of Object.keys(result)) {
modelData[field.split('.')[1] || field] = result[field];
}
const model = this.factory.create(modelData as Pick<M, keyof M>, false);
models.push(model);
models.originalData.push(modelData);
if (this._pivot && models.pivot) {
const pivotData: Record<string, unknown> = {};
for (const field of this._pivot) {
pivotData[field] = result[field.split('.')[1]];
}
models.pivot.push(pivotData);
}
// Eager loading init map
for (const relation of this.relations) {
if (model[relation] === undefined) throw new Error(`Relation ${relation} doesn't exist on ${model.constructor.name}.`);
if (!(model[relation] instanceof ModelRelation)) throw new Error(`Field ${relation} is not a relation on ${model.constructor.name}.`);
relationMap[relation].push(model[relation] as ModelRelation<Model, Model, Model | Model[] | null>);
}
}
// Eager loading execute
for (const relationName of this.relations) {
const relations = relationMap[relationName];
if (relations.length > 0) {
const allModels = await relations[0].eagerLoad(relations, this.subRelations[relationName]);
await Promise.all(relations.map(r => r.populate(allModels)));
}
}
return models;
}
public async paginate(page: number, perPage: number, connection?: Connection): Promise<ModelQueryResult<M>> {
this.limit(perPage, (page - 1) * perPage);
const result = await this.get(connection);
result.pagination = new Pagination<M>(result, page, perPage, await this.count(true, connection));
return result;
}
public async first(): Promise<M | null> {
const models = await this.limit(1).get();
return models.length > 0 ? models[0] : null;
}
public async count(removeLimit: boolean = false, connection?: Connection): Promise<number> {
if (removeLimit) {
this._limit = undefined;
this._offset = undefined;
}
this._sortBy = undefined;
this._sortDirection = undefined;
this.fields.splice(0, this.fields.length);
this.fields.push(new SelectFieldValue('_count', 'COUNT(*)', true));
const queryResult = await this.execute(connection);
return Number(queryResult.results[0]['_count']);
}
}
function inputToFieldOrValue(input: string, addTable?: string): string {
if (input.startsWith('`') || input.startsWith('"') || input.startsWith("'")) {
return input;
}
let parts = input.split('.');
if (addTable && parts.length === 1) parts = [addTable, input]; // Add table disambiguation
return parts.map(v => v === '*' ? v : `\`${v}\``).join('.');
}
export interface ModelQueryResult<M extends Model> extends Array<M> {
originalData?: Record<string, unknown>[];
pagination?: Pagination<M>;
pivot?: Record<string, unknown>[];
}
export enum QueryType {
SELECT,
INSERT,
UPDATE,
DELETE,
}
export enum WhereOperator {
AND = 'AND',
OR = 'OR',
}
export enum WhereTest {
EQ = '=',
NE = '!=',
GT = '>',
GE = '>=',
LT = '<',
LE = '<=',
IN = ' IN ',
}
class FieldValue {
protected readonly field: string;
protected value: ModelFieldData;
protected raw: boolean;
public constructor(field: string, value: ModelFieldData, raw: boolean) {
this.field = field;
this.value = value;
this.raw = raw;
}
public toString(first: boolean = true): string {
return `${first ? '' : ','}${this.fieldName}${this.test}${this.fieldValue}`;
}
protected get test(): string {
return '=';
}
public get variables(): QueryVariable[] {
if (this.value instanceof ModelQuery) return this.value.variables;
if (this.raw || this.value === null || this.value === undefined ||
typeof this.value === 'boolean') return [];
if (Array.isArray(this.value)) return this.value.map(value => {
if (!isQueryVariable(value)) value = value.toString();
return value;
}) as QueryVariable[];
let value = this.value;
if (!isQueryVariable(value)) value = value.toString();
return [value as QueryVariable];
}
public get fieldName(): string {
return inputToFieldOrValue(this.field);
}
public get fieldValue(): ModelFieldData {
let value: string;
if (this.value instanceof ModelQuery) {
value = this.value.toString(false);
} else if (this.value === null || this.value === undefined) {
value = 'null';
} else if (typeof this.value === 'boolean') {
value = String(this.value);
} else if (this.raw) {
value = this.value.toString();
} else {
value = Array.isArray(this.value) ?
`(${'?'.repeat(this.value.length).split('').join(',')})` :
'?';
}
return value;
}
}
export class SelectFieldValue extends FieldValue {
public toString(): string {
let value: string;
if (this.value instanceof ModelQuery) {
value = this.value.toString(true);
} else if (this.value === null || this.value === undefined) {
value = 'null';
} else if (typeof this.value === 'boolean') {
value = String(this.value);
} else {
value = this.raw ?
this.value.toString() :
'?';
}
return `(${value}) AS \`${this.field}\``;
}
}
class WhereFieldValue extends FieldValue {
private readonly _test: WhereTest;
private readonly operator: WhereOperator;
public constructor(field: string, value: ModelFieldData, raw: boolean, test: WhereTest, operator: WhereOperator) {
super(field, value, raw);
this._test = test;
this.operator = operator;
}
public toString(first: boolean = true): string {
return (!first ? ` ${this.operator} ` : '') + super.toString(true);
}
protected get test(): string {
if (this.value === null || this.value === undefined) {
if (this._test === WhereTest.EQ) {
return ' IS ';
} else if (this._test === WhereTest.NE) {
return ' IS NOT ';
}
}
return this._test;
}
}
class WhereFieldValueGroup {
public readonly fields: (WhereFieldValue | WhereFieldValueGroup)[];
public readonly operator: WhereOperator;
public constructor(fields: (WhereFieldValue | WhereFieldValueGroup)[], operator: WhereOperator) {
this.fields = fields;
this.operator = operator;
}
public toString(first: boolean = true): string {
let str = `${first ? '' : ` ${this.operator} `}(`;
let firstField = true;
for (const field of this.fields) {
str += field.toString(firstField);
firstField = false;
}
str += ')';
return str;
}
}
export interface WhereFieldConsumer<M extends Model> {
where(field: string, value: ModelFieldData, test?: WhereTest, operator?: WhereOperator): this;
groupWhere(setter: (query: WhereFieldConsumer<M>) => void, operator?: WhereOperator): this;
}
export type QueryFields = (string | SelectFieldValue | FieldValue)[];
export type SortDirection = 'ASC' | 'DESC';
type ModelQueryUnion = {
query: ModelQuery<Model>,
sortBy: string,
direction: SortDirection,
limit?: number,
offset?: number,
};
export type ModelFieldData =
| QueryVariable
| ModelQuery<Model>
| { toString(): string }
| (QueryVariable | { toString(): string })[];

268
src/db/ModelRelation.ts Normal file
View File

@ -0,0 +1,268 @@
import ModelQuery, {ModelFieldData, ModelQueryResult, WhereTest} from "./ModelQuery";
import Model, {ModelType} from "./Model";
import ModelFactory from "./ModelFactory";
export default abstract class ModelRelation<S extends Model, O extends Model, R extends O | O[] | null> {
protected readonly model: S;
protected readonly foreignModelType: ModelType<O>;
protected readonly dbProperties: RelationDatabaseProperties;
protected readonly queryModifiers: QueryModifier<O>[] = [];
protected readonly filters: ModelFilter<O>[] = [];
protected cachedModels?: O[];
protected constructor(model: S, foreignModelType: ModelType<O>, dbProperties: RelationDatabaseProperties) {
this.model = model;
this.foreignModelType = foreignModelType;
this.dbProperties = dbProperties;
}
public abstract clone(): ModelRelation<S, O, R>;
public constraint(queryModifier: QueryModifier<O>): this {
this.queryModifiers.push(queryModifier);
return this;
}
public filter(modelFilter: ModelFilter<O>): this {
this.filters.push(modelFilter);
return this;
}
protected makeQuery(): ModelQuery<O> {
const query = ModelFactory.get(this.foreignModelType).select();
for (const modifier of this.queryModifiers) modifier(query);
return query;
}
public getModelId(): ModelFieldData {
return this.model[this.dbProperties.localKey];
}
protected applyRegularConstraints(query: ModelQuery<O>): void {
query.where(this.dbProperties.foreignKey, this.getModelId());
}
public async get(): Promise<R> {
if (this.cachedModels === undefined) {
const query = this.makeQuery();
this.applyRegularConstraints(query);
this.cachedModels = await query.get();
}
let models = this.cachedModels;
for (const filter of this.filters) {
const newModels = [];
for (const model of models) {
if (await filter(model)) {
newModels.push(model);
}
}
models = newModels;
}
return this.collectionToOutput(models);
}
public getOrFail(): R {
if (this.cachedModels === undefined) throw new Error('Models were not fetched');
return this.collectionToOutput(this.cachedModels);
}
protected abstract collectionToOutput(models: O[]): R;
public async eagerLoad(
relations: ModelRelation<S, O, R>[],
subRelations: string[] = [],
): Promise<ModelQueryResult<O>> {
const ids = relations.map(r => r.getModelId())
.filter(id => id !== null && id !== undefined)
.reduce((array: ModelFieldData[], val) => array.indexOf(val) >= 0 ? array : [...array, val], []);
if (ids.length === 0) return [];
const query = this.makeQuery();
query.where(this.dbProperties.foreignKey, ids, WhereTest.IN);
query.with(...subRelations);
return await query.get();
}
public async populate(models: ModelQueryResult<O>): Promise<void> {
this.cachedModels = models.filter(m => m[this.dbProperties.foreignKey] === this.getModelId())
.reduce((array: O[], val) => array.find(v => v.equals(val)) ? array : [...array, val], []);
}
public async count(): Promise<number> {
const models = await this.get();
if (Array.isArray(models)) return models.length;
else return models !== null ? 1 : 0;
}
public async has(model: O): Promise<boolean> {
const models = await this.get();
if (models instanceof Array) {
return models.find(m => m.equals(model)) !== undefined;
} else {
return models !== null && models.equals(model);
}
}
}
export class OneModelRelation<S extends Model, O extends Model> extends ModelRelation<S, O, O | null> {
public constructor(model: S, foreignModelType: ModelType<O>, dbProperties: RelationDatabaseProperties) {
super(model, foreignModelType, dbProperties);
}
public clone(): OneModelRelation<S, O> {
return new OneModelRelation(this.model, this.foreignModelType, this.dbProperties);
}
protected collectionToOutput(models: O[]): O | null {
return models[0] || null;
}
public async set(model: O): Promise<void> {
(this.model as Model)[this.dbProperties.localKey] = model[this.dbProperties.foreignKey];
}
public async clear(): Promise<void> {
(this.model as Model)[this.dbProperties.localKey] = undefined;
}
}
export class ManyModelRelation<S extends Model, O extends Model> extends ModelRelation<S, O, O[]> {
protected readonly paginatedCache: {
[perPage: number]: {
[pageNumber: number]: ModelQueryResult<O> | undefined
} | undefined
} = {};
public constructor(model: S, foreignModelType: ModelType<O>, dbProperties: RelationDatabaseProperties) {
super(model, foreignModelType, dbProperties);
}
public clone(): ManyModelRelation<S, O> {
return new ManyModelRelation<S, O>(this.model, this.foreignModelType, this.dbProperties);
}
public cloneReduceToOne(): OneModelRelation<S, O> {
return new OneModelRelation<S, O>(this.model, this.foreignModelType, this.dbProperties);
}
protected collectionToOutput(models: O[]): O[] {
return models;
}
public async paginate(page: number, perPage: number): Promise<ModelQueryResult<O>> {
let cache = this.paginatedCache[perPage];
if (!cache) cache = this.paginatedCache[perPage] = {};
let cachePage = cache[page];
if (!cachePage) {
const query = this.makeQuery();
this.applyRegularConstraints(query);
cachePage = cache[page] = await query.paginate(page, perPage);
}
return cachePage;
}
}
export class ManyThroughModelRelation<S extends Model, O extends Model> extends ManyModelRelation<S, O> {
protected readonly dbProperties: PivotRelationDatabaseProperties;
public constructor(model: S, foreignModelType: ModelType<O>, dbProperties: PivotRelationDatabaseProperties) {
super(model, foreignModelType, dbProperties);
this.dbProperties = dbProperties;
this.constraint(query => query
.leftJoin(this.dbProperties.pivotTable, 'pivot')
.on(`pivot.${this.dbProperties.foreignPivotKey}`, `${this.foreignModelType.table}.${this.dbProperties.foreignKey}`),
);
}
public clone(): ManyThroughModelRelation<S, O> {
return new ManyThroughModelRelation<S, O>(this.model, this.foreignModelType, this.dbProperties);
}
public cloneReduceToOne(): OneModelRelation<S, O> {
throw new Error('Cannot reduce many through relation to one model.');
}
protected applyRegularConstraints(query: ModelQuery<O>): void {
query.where(`pivot.${this.dbProperties.localPivotKey}`, this.getModelId());
}
public async eagerLoad(
relations: ModelRelation<S, O, O[]>[],
subRelations: string[] = [],
): Promise<ModelQueryResult<O>> {
const ids = relations.map(r => r.getModelId())
.reduce((array: ModelFieldData[], val) => array.indexOf(val) >= 0 ? array : [...array, val], []);
if (ids.length === 0) return [];
const query = this.makeQuery();
query.where(`pivot.${this.dbProperties.localPivotKey}`, ids, WhereTest.IN);
query.pivot(`pivot.${this.dbProperties.localPivotKey}`, `pivot.${this.dbProperties.foreignPivotKey}`);
query.with(...subRelations);
return await query.get();
}
public async populate(models: ModelQueryResult<O>): Promise<void> {
if (!models.pivot) throw new Error('ModelQueryResult.pivot not loaded.');
const ids = models.pivot
.filter(p => p[`pivot.${this.dbProperties.localPivotKey}`] === this.getModelId())
.map(p => p[`pivot.${this.dbProperties.foreignPivotKey}`]);
this.cachedModels = models.filter(m => ids.indexOf(m[this.dbProperties.foreignKey]) >= 0)
.reduce((array: O[], val) => array.find(v => v.equals(val)) ? array : [...array, val], []);
}
}
export class RecursiveModelRelation<M extends Model> extends ManyModelRelation<M, M> {
private readonly reverse: boolean;
public constructor(
model: M,
foreignModelType: ModelType<M>,
dbProperties: RelationDatabaseProperties,
reverse: boolean = false,
) {
super(model, foreignModelType, dbProperties);
this.constraint(query => query.recursive(this.dbProperties, reverse));
this.reverse = reverse;
}
public clone(): RecursiveModelRelation<M> {
return new RecursiveModelRelation(this.model, this.foreignModelType, this.dbProperties);
}
public async populate(models: ModelQueryResult<M>): Promise<void> {
await super.populate(models);
const cachedModels = this.cachedModels;
if (cachedModels) {
let count;
do {
count = cachedModels.length;
cachedModels.push(...models.filter(model =>
!cachedModels.find(cached => cached.equals(model)) && cachedModels.find(cached => {
return cached[this.dbProperties.localKey] === model[this.dbProperties.foreignKey];
}),
).reduce((array: M[], val) => array.find(v => v.equals(val)) ? array : [...array, val], []));
} while (count !== cachedModels.length);
if (this.reverse) cachedModels.reverse();
}
}
}
export type QueryModifier<M extends Model> = (query: ModelQuery<M>) => ModelQuery<M>;
export type ModelFilter<O extends Model> = (model: O) => boolean | Promise<boolean>;
export type RelationDatabaseProperties = {
localKey: string;
foreignKey: string;
};
export type PivotRelationDatabaseProperties = RelationDatabaseProperties & {
pivotTable: string;
localPivotKey: string;
foreignPivotKey: string;
};

View File

@ -1,17 +1,21 @@
import mysql, {Connection, FieldInfo, Pool} from 'mysql'; import mysql, {Connection, FieldInfo, MysqlError, Pool, PoolConnection} from 'mysql';
import config from 'config'; import config from 'config';
import Migration from "./Migration"; import Migration, {MigrationType} from "./Migration";
import Logger from "../Logger"; import {logger} from "../Logger";
import {Type} from "../Utils"; import {Type} from "../Utils";
export interface QueryResult { export interface QueryResult {
readonly results: any[]; readonly results: Record<string, unknown>[];
readonly fields: FieldInfo[]; readonly fields: FieldInfo[];
readonly other?: any; readonly other?: Record<string, unknown>;
foundRows?: number; foundRows?: number;
} }
export async function query(queryString: string, values?: any, connection?: Connection): Promise<QueryResult> { export async function query(
queryString: string,
values?: QueryVariable[],
connection?: Connection,
): Promise<QueryResult> {
return await MysqlConnectionManager.query(queryString, values, connection); return await MysqlConnectionManager.query(queryString, values, connection);
} }
@ -21,7 +25,11 @@ export default class MysqlConnectionManager {
private static migrationsRegistered: boolean = false; private static migrationsRegistered: boolean = false;
private static readonly migrations: Migration[] = []; private static readonly migrations: Migration[] = [];
public static registerMigrations(migrations: Type<Migration>[]) { public static isReady(): boolean {
return this.databaseReady && this.currentPool !== undefined;
}
public static registerMigrations(migrations: MigrationType<Migration>[]): void {
if (!this.migrationsRegistered) { if (!this.migrationsRegistered) {
this.migrationsRegistered = true; this.migrationsRegistered = true;
migrations.forEach(m => this.registerMigration(v => new m(v))); migrations.forEach(m => this.registerMigration(v => new m(v)));
@ -32,30 +40,36 @@ export default class MysqlConnectionManager {
this.migrations.push(migration(this.migrations.length + 1)); this.migrations.push(migration(this.migrations.length + 1));
} }
public static async prepare() { public static hasMigration(migration: Type<Migration>): boolean {
for (const m of this.migrations) {
if (m.constructor === migration) return true;
}
return false;
}
public static async prepare(runMigrations: boolean = true): Promise<void> {
if (config.get('mysql.create_database_automatically') === true) { if (config.get('mysql.create_database_automatically') === true) {
const dbName = config.get('mysql.database'); const dbName = config.get('mysql.database');
Logger.info(`Creating database ${dbName}...`); logger.info(`Creating database ${dbName}...`);
const connection = mysql.createConnection({ const connection = mysql.createConnection({
host: config.get('mysql.host'), host: config.get('mysql.host'),
user: config.get('mysql.user'), user: config.get('mysql.user'),
password: config.get('mysql.password'), password: config.get('mysql.password'),
charset: 'utf8mb4',
}); });
await new Promise((resolve, reject) => { await new Promise<void>((resolve, reject) => {
connection.query(`CREATE DATABASE IF NOT EXISTS ${dbName}`, (error) => { connection.query(`CREATE DATABASE IF NOT EXISTS ${dbName}`, (error) => {
if (error !== null) { return error !== null ?
reject(error); reject(error) :
} else {
resolve(); resolve();
}
}); });
}); });
connection.end(); connection.end();
Logger.info(`Database ${dbName} created!`); logger.info(`Database ${dbName} created!`);
} }
this.databaseReady = true; this.databaseReady = true;
await this.handleMigrations(); if (runMigrations) await this.handleMigrations();
} }
public static get pool(): Pool { public static get pool(): Pool {
@ -72,26 +86,34 @@ export default class MysqlConnectionManager {
user: config.get('mysql.user'), user: config.get('mysql.user'),
password: config.get('mysql.password'), password: config.get('mysql.password'),
database: config.get('mysql.database'), database: config.get('mysql.database'),
charset: 'utf8mb4',
}); });
} }
public static async endPool(): Promise<void> { public static async endPool(): Promise<void> {
return new Promise(resolve => { return await new Promise(resolve => {
if (this.currentPool !== undefined) { if (this.currentPool === undefined) {
this.currentPool.end(() => { return resolve();
Logger.info('Mysql connection pool ended.');
resolve();
});
this.currentPool = undefined;
} else {
resolve();
} }
this.currentPool.end(() => {
logger.info('Mysql connection pool ended.');
resolve();
});
this.currentPool = undefined;
}); });
} }
public static async query(queryString: string, values?: any, connection?: Connection): Promise<QueryResult> { public static async query(
queryString: string,
values: QueryVariable[] = [],
connection?: Connection,
): Promise<QueryResult> {
return await new Promise<QueryResult>((resolve, reject) => { return await new Promise<QueryResult>((resolve, reject) => {
Logger.dev('Mysql query:', queryString, '; values:', values); logger.debug('SQL:', logger.settings.minLevel === 'trace' || logger.settings.minLevel === 'silly' ?
mysql.format(queryString, values) :
queryString);
(connection ? connection : this.pool).query(queryString, values, (error, results, fields) => { (connection ? connection : this.pool).query(queryString, values, (error, results, fields) => {
if (error !== null) { if (error !== null) {
reject(error); reject(error);
@ -101,7 +123,7 @@ export default class MysqlConnectionManager {
resolve({ resolve({
results: Array.isArray(results) ? results : [], results: Array.isArray(results) ? results : [],
fields: fields !== undefined ? fields : [], fields: fields !== undefined ? fields : [],
other: Array.isArray(results) ? null : results other: Array.isArray(results) ? null : results,
}); });
}); });
}); });
@ -109,70 +131,139 @@ export default class MysqlConnectionManager {
public static async wrapTransaction<T>(transaction: (connection: Connection) => Promise<T>): Promise<T> { public static async wrapTransaction<T>(transaction: (connection: Connection) => Promise<T>): Promise<T> {
return await new Promise<T>((resolve, reject) => { return await new Promise<T>((resolve, reject) => {
this.pool.getConnection((err, connection) => { this.pool.getConnection((err: MysqlError | undefined, connection: PoolConnection) => {
if (err) { if (err) {
reject(err); reject(err);
return; return;
} }
connection.beginTransaction((err) => { connection.beginTransaction((err?: MysqlError) => {
if (err) { if (err) {
reject(err); reject(err);
this.pool.releaseConnection(connection); connection.release();
return; return;
} }
transaction(connection).then(val => { transaction(connection).then(val => {
connection.commit((err) => { connection.commit((err?: MysqlError) => {
if (err) { if (err) {
this.rejectAndRollback(connection, err, reject); this.rejectAndRollback(connection, err, reject);
this.pool.releaseConnection(connection); connection.release();
return; return;
} }
this.pool.releaseConnection(connection); connection.release();
resolve(val); resolve(val);
}); });
}).catch(err => { }).catch(err => {
this.rejectAndRollback(connection, err, reject); this.rejectAndRollback(connection, err, reject);
this.pool.releaseConnection(connection); connection.release();
}); });
}); });
}); });
}); });
} }
private static rejectAndRollback(connection: Connection, err: any, reject: (err: any) => void) { private static rejectAndRollback(
connection.rollback((rollbackErr) => { connection: Connection,
if (rollbackErr) { err: MysqlError | undefined,
reject(err + '\n' + rollbackErr); reject: (err: unknown) => void,
} else { ) {
connection.rollback((rollbackErr?: MysqlError) => {
return rollbackErr ?
reject(err + '\n' + rollbackErr) :
reject(err); reject(err);
}
}); });
} }
private static async handleMigrations() { public static async getCurrentMigrationVersion(): Promise<number> {
let currentVersion = 0; let currentVersion = 0;
try { try {
const result = await query('SELECT id FROM migrations ORDER BY id DESC LIMIT 1'); const result = await query('SELECT id FROM migrations ORDER BY id DESC LIMIT 1');
currentVersion = result.results[0].id; currentVersion = Number(result.results[0]?.id);
} catch (e) { } catch (e) {
if (e.code === 'ECONNREFUSED' || e.code !== 'ER_NO_SUCH_TABLE') { if (e.code === 'ECONNREFUSED' || e.code !== 'ER_NO_SUCH_TABLE') {
throw new Error('Cannot run migrations: ' + e.code); throw new Error('Cannot run migrations: ' + e.code);
} }
} }
return currentVersion;
}
private static async handleMigrations() {
const currentVersion = await this.getCurrentMigrationVersion();
for (const migration of this.migrations) { for (const migration of this.migrations) {
if (await migration.shouldRun(currentVersion)) { if (await migration.shouldRun(currentVersion)) {
Logger.info('Running migration ', migration.version, migration.constructor.name); logger.info('Running migration ', migration.version, migration.constructor.name);
await migration.install(); await MysqlConnectionManager.wrapTransaction<void>(async c => {
await query('INSERT INTO migrations VALUES(?, ?, NOW())', [ migration.setCurrentConnection(c);
migration.version, await migration.install();
migration.constructor.name, migration.setCurrentConnection(null);
]); await query('INSERT INTO migrations VALUES(?, ?, NOW())', [
migration.version,
migration.constructor.name,
]);
});
} }
} }
for (const migration of this.migrations) {
migration.registerModels?.();
}
}
/**
* @param migrationId what migration to rollback. Use with caution. default=0 is for last registered migration.
*/
public static async rollbackMigration(migrationId: number = 0): Promise<void> {
migrationId--;
const migration = this.migrations[migrationId];
logger.info('Rolling back migration ', migration.version, migration.constructor.name);
await MysqlConnectionManager.wrapTransaction<void>(async c => {
migration.setCurrentConnection(c);
await migration.rollback();
migration.setCurrentConnection(null);
await query('DELETE FROM migrations WHERE id=?', [migration.version]);
});
}
public static async migrationCommand(args: string[]): Promise<void> {
try {
logger.info('Current migration:', await this.getCurrentMigrationVersion());
for (let i = 0; i < args.length; i++) {
if (args[i] === 'rollback') {
let migrationId = 0;
if (args.length > i + 1) {
migrationId = parseInt(args[i + 1]);
}
await this.prepare(false);
await this.rollbackMigration(migrationId);
return;
}
}
} finally {
await MysqlConnectionManager.endPool();
}
} }
} }
export type QueryVariable =
| boolean
| string
| number
| Date
| Buffer
| null
| undefined;
export function isQueryVariable(value: unknown): value is QueryVariable {
return typeof value === 'boolean' ||
typeof value === "string" ||
typeof value === 'number' ||
value instanceof Date ||
value instanceof Buffer ||
value === null ||
value === undefined;
}

View File

@ -1,214 +0,0 @@
import {query, QueryResult} from "./MysqlConnectionManager";
import {Connection} from "mysql";
export default class Query {
public static select(table: string, ...fields: string[]): Query {
return new Query(QueryType.SELECT, table, fields.length > 0 ? fields : ['*']);
}
public static update(table: string, data: {
[key: string]: any
}) {
const fields = [];
for (let key in data) {
if (data.hasOwnProperty(key)) {
fields.push(new UpdateFieldValue(key, data[key]));
}
}
return new Query(QueryType.UPDATE, table, fields);
}
public static delete(table: string) {
return new Query(QueryType.DELETE, table);
}
private readonly type: QueryType;
private readonly table: string;
private readonly fields: (string | SelectFieldValue | UpdateFieldValue)[];
private _where: WhereFieldValue[] = [];
private _limit?: number;
private _offset?: number;
private _sortBy?: string;
private _sortDirection?: 'ASC' | 'DESC';
private _foundRows: boolean = false;
private constructor(type: QueryType, table: string, fields?: (string | SelectFieldValue | UpdateFieldValue)[]) {
this.type = type;
this.table = table;
this.fields = fields || [];
}
public where(field: string, value: string | Date | Query | any, operator: WhereOperator = WhereOperator.AND, test: WhereTest = WhereTest.EQUALS): Query {
this._where.push(new WhereFieldValue(field, value, operator, test));
return this;
}
public whereNot(field: string, value: string | Date | Query | any, operator: WhereOperator = WhereOperator.AND): Query {
return this.where(field, value, operator, WhereTest.DIFFERENT);
}
public orWhere(field: string, value: string | Date | Query | any): Query {
return this.where(field, value, WhereOperator.OR);
}
public whereIn(field: string, value: any[]): Query {
return this.where(field, value, WhereOperator.AND, WhereTest.IN);
}
public limit(limit: number, offset: number = 0): Query {
this._limit = limit;
this._offset = offset;
return this;
}
public first(): Query {
return this.limit(1);
}
public sortBy(field: string, direction: 'ASC' | 'DESC' = 'ASC'): Query {
this._sortBy = field;
this._sortDirection = direction;
return this;
}
public withTotalRowCount(): Query {
this._foundRows = true;
return this;
}
public toString(final: boolean = false): string {
let query = '';
let fields = this.fields?.join(',');
let where = '';
if (this._where.length > 0) {
where = `WHERE ${this._where[0]}`;
for (let i = 1; i < this._where.length; i++) {
where += this._where[i].toString(false);
}
}
let limit = '';
if (typeof this._limit === 'number') {
limit = `LIMIT ${this._limit}`;
if (typeof this._offset === 'number' && this._offset !== 0) {
limit += ` OFFSET ${this._offset}`;
}
}
let orderBy = '';
if (typeof this._sortBy === 'string') {
orderBy = `ORDER BY ${this._sortBy} ${this._sortDirection}`;
}
switch (this.type) {
case QueryType.SELECT:
query = `SELECT ${this._foundRows ? 'SQL_CALC_FOUND_ROWS' : ''} ${fields} FROM ${this.table} ${where} ${orderBy} ${limit}`;
break;
case QueryType.UPDATE:
query = `UPDATE ${this.table} SET ${fields} ${where} ${orderBy} ${limit}`;
break;
case QueryType.DELETE:
query = `DELETE FROM ${this.table} ${where} ${orderBy} ${limit}`;
break;
}
return final ? query : `(${query})`;
}
public build(): string {
return this.toString(true);
}
public get variables(): any[] {
const variables: any[] = [];
this.fields?.filter(v => v instanceof FieldValue)
.flatMap(v => (<FieldValue>v).variables)
.forEach(v => variables.push(v));
this._where.flatMap(v => v.variables)
.forEach(v => variables.push(v));
return variables;
}
public isCacheable(): boolean {
return this.type === QueryType.SELECT && this.fields.length === 1 && this.fields[0] === '*';
}
public async execute(connection?: Connection): Promise<QueryResult> {
const queryResult = await query(this.build(), this.variables, connection);
if (this._foundRows) {
const foundRows = await query('SELECT FOUND_ROWS() as r', undefined, connection);
queryResult.foundRows = foundRows.results[0].r;
}
return queryResult;
}
}
export enum QueryType {
SELECT,
UPDATE,
DELETE,
}
enum WhereOperator {
AND = 'AND',
OR = 'OR',
}
enum WhereTest {
EQUALS = '=',
DIFFERENT = '!=',
IN = ' IN ',
}
class FieldValue {
protected readonly field: string;
protected value: any;
constructor(field: string, value: any) {
this.field = field;
this.value = value;
}
public toString(first: boolean = true): string {
return `${!first ? ',' : ''}${this.field}${this.test}${this.value instanceof Query ? this.value : (Array.isArray(this.value) ? '(?)' : '?')}`;
}
protected get test(): string {
return '=';
}
public get variables(): any[] {
return this.value instanceof Query ? this.value.variables : [this.value];
}
}
class SelectFieldValue extends FieldValue {
public toString(first: boolean = true): string {
return `(${this.value instanceof Query ? this.value : '?'}) AS ${this.field}`;
}
}
class UpdateFieldValue extends FieldValue {
}
class WhereFieldValue extends FieldValue {
private readonly operator: WhereOperator;
private readonly _test: WhereTest;
constructor(field: string, value: any, operator: WhereOperator, test: WhereTest) {
super(field, value);
this.operator = operator;
this._test = test;
}
public toString(first: boolean = true): string {
return (!first ? ` ${this.operator} ` : '') + super.toString(true);
}
protected get test(): string {
return this._test;
}
}

View File

@ -1,32 +1,76 @@
import Model from "./Model"; import Model, {ModelType} from "./Model";
import Query from "./Query"; import ModelQuery, {WhereTest} from "./ModelQuery";
import {Connection} from "mysql"; import {Connection} from "mysql";
import {ServerError} from "../HttpError";
export default class Validator<T> { export const EMAIL_REGEX = /^[a-zA-Z0-9.!#$%&'*+\\/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+$/;
private readonly steps: ValidationStep<T>[] = [];
export default class Validator<V> {
public static async validate(
validationMap: { [p: string]: Validator<unknown> },
body: { [p: string]: unknown },
): Promise<void> {
const bag = new ValidationBag();
for (const p of Object.keys(validationMap)) {
try {
await validationMap[p].execute(p, body[p], false);
} catch (e) {
if (e instanceof ValidationBag) {
bag.addBag(e);
} else throw e;
}
}
if (bag.hasMessages()) throw bag;
}
private readonly steps: ValidationStep<V>[] = [];
private readonly validationAttributes: string[] = []; private readonly validationAttributes: string[] = [];
private readonly rawValueToHuman?: (val: V) => string;
private _min?: number; private _min?: number;
private _max?: number; private _max?: number;
public constructor(rawValueToHuman?: (val: V) => string) {
this.rawValueToHuman = rawValueToHuman;
}
/** /**
* @param thingName The name of the thing to validate. * @param thingName The name of the thing to validate.
* @param value The value to verify. * @param value The value to verify.
* @param onlyFormat {@code true} to only validate format properties, {@code false} otherwise. * @param onlyFormat {@code true} to only validate format properties, {@code false} otherwise.
* @param connection A connection to use in case of wrapped transactions. * @param connection A connection to use in case of wrapped transactions.
*/ */
async execute(thingName: string, value: T | undefined, onlyFormat: boolean, connection?: Connection): Promise<void> { public async execute(
const bag = new ValidationBag(); thingName: string,
value: V | undefined,
onlyFormat: boolean,
connection?: Connection,
): Promise<void> {
const bag = new ValidationBag<V>();
for (const step of this.steps) { for (const step of this.steps) {
if (onlyFormat && !step.isFormat) continue; if (onlyFormat && !step.isFormat) continue;
const result = step.verifyStep(value, thingName, connection); let result;
if ((result === false || result instanceof Promise && (await result) === false) && step.throw) { try {
const error: ValidationError = step.throw(); result = step.verifyStep(value, thingName, connection);
if (result instanceof Promise) {
result = await result;
}
} catch (e) {
throw new ServerError(`An error occurred while validating ${thingName} with value "${value}".`, e);
}
if (result === false && step.throw) {
const error: ValidationError<V> = step.throw();
error.rawValueToHuman = this.rawValueToHuman;
error.thingName = thingName; error.thingName = thingName;
error.value = value; error.value = value;
bag.addMessage(error); bag.addMessage(error);
break;
} else if (step.interrupt !== undefined && step.interrupt(value)) { } else if (step.interrupt !== undefined && step.interrupt(value)) {
break; break;
} }
@ -37,7 +81,7 @@ export default class Validator<T> {
} }
} }
public defined(): Validator<T> { public defined(): Validator<V> {
this.validationAttributes.push('required'); this.validationAttributes.push('required');
this.addStep({ this.addStep({
@ -48,17 +92,21 @@ export default class Validator<T> {
return this; return this;
} }
public acceptUndefined(): Validator<T> { public acceptUndefined(alsoAcceptEmptyString: boolean = false): Validator<V> {
this.addStep({ this.addStep({
verifyStep: () => true, verifyStep: () => true,
throw: null, throw: null,
interrupt: val => val === undefined || val === null, interrupt: val => {
return val === undefined ||
val === null ||
alsoAcceptEmptyString && typeof val === 'string' && val.length === 0;
},
isFormat: true, isFormat: true,
}); });
return this; return this;
} }
public equals(other?: T): Validator<T> { public equals(other?: V): Validator<V> {
this.addStep({ this.addStep({
verifyStep: val => val === other, verifyStep: val => val === other,
throw: () => new BadValueValidationError(other), throw: () => new BadValueValidationError(other),
@ -67,7 +115,16 @@ export default class Validator<T> {
return this; return this;
} }
public regexp(regexp: RegExp): Validator<T> { public sameAs(otherName?: string, other?: V): Validator<V> {
this.addStep({
verifyStep: val => val === other,
throw: () => new DifferentThanError(otherName),
isFormat: true,
});
return this;
}
public regexp(regexp: RegExp): Validator<V> {
this.validationAttributes.push(`pattern="${regexp}"`); this.validationAttributes.push(`pattern="${regexp}"`);
this.addStep({ this.addStep({
verifyStep: val => regexp.test(<string><unknown>val), verifyStep: val => regexp.test(<string><unknown>val),
@ -77,24 +134,48 @@ export default class Validator<T> {
return this; return this;
} }
public length(length: number): Validator<T> { public length(length: number): Validator<V> {
this.addStep({ this.addStep({
verifyStep: val => (<any>val).length === length, verifyStep: val => isLenghtable(val) && val.length === length,
throw: () => new BadLengthValidationError(length), throw: () => new BadLengthValidationError(length),
isFormat: true, isFormat: true,
}); });
return this; return this;
} }
/**
* @param minLength included
*/
public minLength(minLength: number): Validator<V> {
this.addStep({
verifyStep: val => isLenghtable(val) && val.length >= minLength,
throw: () => new TooShortError(minLength),
isFormat: true,
});
return this;
}
/**
* @param maxLength included
*/
public maxLength(maxLength: number): Validator<V> {
this.addStep({
verifyStep: val => isLenghtable(val) && val.length <= maxLength,
throw: () => new TooLongError(maxLength),
isFormat: true,
});
return this;
}
/** /**
* @param minLength included * @param minLength included
* @param maxLength included * @param maxLength included
*/ */
public between(minLength: number, maxLength: number): Validator<T> { public between(minLength: number, maxLength: number): Validator<V> {
this.addStep({ this.addStep({
verifyStep: val => { verifyStep: val => {
const length = (<any>val).length; return isLenghtable(val) &&
return length >= minLength && length <= maxLength; val.length >= minLength && val.length <= maxLength;
}, },
throw: () => new BadLengthValidationError(minLength, maxLength), throw: () => new BadLengthValidationError(minLength, maxLength),
isFormat: true, isFormat: true,
@ -105,12 +186,12 @@ export default class Validator<T> {
/** /**
* @param min included * @param min included
*/ */
public min(min: number): Validator<T> { public min(min: number): Validator<V> {
this.validationAttributes.push(`min="${min}"`); this.validationAttributes.push(`min="${min}"`);
this._min = min; this._min = min;
this.addStep({ this.addStep({
verifyStep: val => { verifyStep: val => {
return (<any>val) >= min; return typeof val === 'number' && val >= min;
}, },
throw: () => new OutOfRangeValidationError(this._min, this._max), throw: () => new OutOfRangeValidationError(this._min, this._max),
isFormat: true, isFormat: true,
@ -121,12 +202,12 @@ export default class Validator<T> {
/** /**
* @param max included * @param max included
*/ */
public max(max: number): Validator<T> { public max(max: number): Validator<V> {
this.validationAttributes.push(`max="${max}"`); this.validationAttributes.push(`max="${max}"`);
this._max = max; this._max = max;
this.addStep({ this.addStep({
verifyStep: val => { verifyStep: val => {
return (<any>val) <= max; return typeof val === 'number' && val <= max;
}, },
throw: () => new OutOfRangeValidationError(this._min, this._max), throw: () => new OutOfRangeValidationError(this._min, this._max),
isFormat: true, isFormat: true,
@ -134,16 +215,23 @@ export default class Validator<T> {
return this; return this;
} }
public unique(model: Model, querySupplier?: () => Query): Validator<T> { public unique<M extends Model>(
model: M | ModelType<M>,
foreignKey?: string,
querySupplier?: () => ModelQuery<M>,
): Validator<V> {
this.addStep({ this.addStep({
verifyStep: async (val, thingName, c) => { verifyStep: async (val, thingName, c) => {
let query: Query; if (!foreignKey) foreignKey = thingName;
let query: ModelQuery<M>;
if (querySupplier) { if (querySupplier) {
query = querySupplier().where(thingName, val); query = querySupplier();
} else { } else {
query = (<any>model.constructor).select('1').where(thingName, val); query = (model instanceof Model ? <ModelType<M>>model.constructor : model).select('');
} }
if (typeof model.id === 'number') query = query.whereNot('id', model.id); query.where(foreignKey, val);
if (model instanceof Model && typeof model.id === 'number')
query = query.where('id', model.id, WhereTest.NE);
return (await query.execute(c)).results.length === 0; return (await query.execute(c)).results.length === 0;
}, },
throw: () => new AlreadyExistsValidationError(model.table), throw: () => new AlreadyExistsValidationError(model.table),
@ -152,16 +240,21 @@ export default class Validator<T> {
return this; return this;
} }
public exists(modelClass: Function, foreignKey?: string): Validator<T> { public exists(modelType: ModelType<Model>, foreignKey?: string): Validator<V> {
this.addStep({ this.addStep({
verifyStep: async (val, thingName, c) => (await (<any>modelClass).select('1').where(foreignKey !== undefined ? foreignKey : thingName, val).execute(c)).results.length >= 1, verifyStep: async (val, thingName, c) => {
throw: () => new UnknownRelationValidationError((<any>modelClass).table, foreignKey), const results = await modelType.select('')
.where(foreignKey !== undefined ? foreignKey : thingName, val)
.execute(c);
return results.results.length >= 1;
},
throw: () => new UnknownRelationValidationError(modelType.table, foreignKey),
isFormat: false, isFormat: false,
}); });
return this; return this;
} }
private addStep(step: ValidationStep<T>) { private addStep(step: ValidationStep<V>) {
this.steps.push(step); this.steps.push(step);
} }
@ -169,60 +262,72 @@ export default class Validator<T> {
return this.validationAttributes; return this.validationAttributes;
} }
public step(step: number): Validator<T> { public step(step: number): Validator<V> {
this.validationAttributes.push(`step="${step}"`); this.validationAttributes.push(`step="${step}"`);
return this; return this;
} }
} }
interface ValidationStep<T> { interface ValidationStep<V> {
interrupt?: (val?: T) => boolean; interrupt?: (val?: V) => boolean;
verifyStep(val: T | undefined, thingName: string, connection?: Connection): boolean | Promise<boolean>; verifyStep(val: V | undefined, thingName: string, connection?: Connection): boolean | Promise<boolean>;
throw: ((val?: T) => ValidationError) | null; throw: ((val?: V) => ValidationError<V>) | null;
readonly isFormat: boolean; readonly isFormat: boolean;
} }
export class ValidationBag extends Error { export class ValidationBag<V> extends Error {
private readonly messages: { [p: string]: any } = {}; private readonly errors: ValidationError<V>[] = [];
public addMessage(err: ValidationError) { public addMessage(err: ValidationError<V>): void {
if (!err.thingName) { if (!err.thingName) throw new Error('Null thing name');
throw new Error('Null thing name'); this.errors.push(err);
}
public addBag(otherBag: ValidationBag<V>): void {
for (const error of otherBag.errors) {
this.errors.push(error);
} }
this.messages[err.thingName] = {
name: err.name,
message: err.message,
value: err.value,
};
} }
public hasMessages(): boolean { public hasMessages(): boolean {
return Object.keys(this.messages).length > 0; return this.errors.length > 0;
} }
public getMessages(): { [p: string]: ValidationError } { public getMessages(): { [p: string]: ValidationError<V> } {
return this.messages; const messages: { [p: string]: ValidationError<V> } = {};
for (const err of this.errors) {
messages[err.thingName || 'unknown'] = {
name: err.name,
message: err.message,
value: err.value,
};
}
return messages;
}
public getErrors(): ValidationError<V>[] {
return this.errors;
} }
} }
export abstract class ValidationError extends Error { export class ValidationError<V> extends Error {
public rawValueToHuman?: (val: V) => string;
public thingName?: string; public thingName?: string;
public value?: any; public value?: V;
public get name(): string { public get name(): string {
return this.constructor.name; return this.constructor.name;
} }
} }
export class BadLengthValidationError extends ValidationError { export class BadLengthValidationError<V> extends ValidationError<V> {
private readonly expectedLength: number; private readonly expectedLength: number;
private readonly maxLength?: number; private readonly maxLength?: number;
constructor(expectedLength: number, maxLength?: number) { public constructor(expectedLength: number, maxLength?: number) {
super(); super();
this.expectedLength = expectedLength; this.expectedLength = expectedLength;
this.maxLength = maxLength; this.maxLength = maxLength;
@ -230,28 +335,73 @@ export class BadLengthValidationError extends ValidationError {
public get message(): string { public get message(): string {
return `${this.thingName} expected length: ${this.expectedLength}${this.maxLength !== undefined ? ` to ${this.maxLength}` : ''}; ` + return `${this.thingName} expected length: ${this.expectedLength}${this.maxLength !== undefined ? ` to ${this.maxLength}` : ''}; ` +
`actual length: ${this.value.length}.`; `actual length: ${isLenghtable(this.value) && this.value.length}.`;
} }
} }
export class BadValueValidationError extends ValidationError { export class TooShortError<V> extends ValidationError<V> {
private readonly expectedValue: any; private readonly minLength: number;
constructor(expectedValue: any) { public constructor(minLength: number) {
super();
this.minLength = minLength;
}
public get message(): string {
return `${this.thingName} must be at least ${this.minLength} characters.`;
}
}
export class TooLongError<V> extends ValidationError<V> {
private readonly maxLength: number;
public constructor(maxLength: number) {
super();
this.maxLength = maxLength;
}
public get message(): string {
return `${this.thingName} must be at most ${this.maxLength} characters.`;
}
}
export class BadValueValidationError<V> extends ValidationError<V> {
private readonly expectedValue: V;
public constructor(expectedValue: V) {
super(); super();
this.expectedValue = expectedValue; this.expectedValue = expectedValue;
} }
public get message(): string { public get message(): string {
return `Expected: ${this.expectedValue}; got: ${this.value}.` let expectedValue: string = String(this.expectedValue);
let actualValue: string = String(this.value);
if (this.rawValueToHuman && this.value) {
expectedValue = this.rawValueToHuman(this.expectedValue);
actualValue = this.rawValueToHuman(this.value);
}
return `Expected: ${expectedValue}; got: ${actualValue}.`;
} }
} }
export class OutOfRangeValidationError extends ValidationError { export class DifferentThanError<V> extends ValidationError<V> {
private readonly otherName?: string;
public constructor(otherName?: string) {
super();
this.otherName = otherName;
}
public get message(): string {
return `This should be the same as ${this.otherName}.`;
}
}
export class OutOfRangeValidationError<V> extends ValidationError<V> {
private readonly min?: number; private readonly min?: number;
private readonly max?: number; private readonly max?: number;
constructor(min?: number, max?: number) { public constructor(min?: number, max?: number) {
super(); super();
this.min = min; this.min = min;
this.max = max; this.max = max;
@ -263,40 +413,46 @@ export class OutOfRangeValidationError extends ValidationError {
} else if (this.max === undefined) { } else if (this.max === undefined) {
return `${this.thingName} must be at least ${this.min}`; return `${this.thingName} must be at least ${this.min}`;
} }
return `${this.thingName} must be between ${this.min} and ${this.max}.`; let min: string = String(this.min);
let max: string = String(this.max);
if (this.rawValueToHuman) {
min = this.rawValueToHuman(this.min as unknown as V);
max = this.rawValueToHuman(this.max as unknown as V);
}
return `${this.thingName} must be between ${min} and ${max}.`;
} }
} }
export class InvalidFormatValidationError extends ValidationError { export class InvalidFormatValidationError<V> extends ValidationError<V> {
public get message(): string { public get message(): string {
return `"${this.value}" is not a valid ${this.thingName}.`; return `"${this.value}" is not a valid ${this.thingName}.`;
} }
} }
export class UndefinedValueValidationError extends ValidationError { export class UndefinedValueValidationError<V> extends ValidationError<V> {
public get message(): string { public get message(): string {
return `${this.thingName} is required.`; return `${this.thingName} is required.`;
} }
} }
export class AlreadyExistsValidationError extends ValidationError { export class AlreadyExistsValidationError<V> extends ValidationError<V> {
private readonly table: string; private readonly table: string;
constructor(table: string) { public constructor(table: string) {
super(); super();
this.table = table; this.table = table;
} }
public get message(): string { public get message(): string {
return `${this.value} already exists in ${this.table}.${this.thingName}.`; return `${this.thingName} already exists in ${this.table}.`;
} }
} }
export class UnknownRelationValidationError extends ValidationError { export class UnknownRelationValidationError<V> extends ValidationError<V> {
private readonly table: string; private readonly table: string;
private readonly foreignKey?: string; private readonly foreignKey?: string;
constructor(table: string, foreignKey?: string) { public constructor(table: string, foreignKey?: string) {
super(); super();
this.table = table; this.table = table;
this.foreignKey = foreignKey; this.foreignKey = foreignKey;
@ -306,3 +462,25 @@ export class UnknownRelationValidationError extends ValidationError {
return `${this.thingName}=${this.value} relation was not found in ${this.table}${this.foreignKey !== undefined ? `.${this.foreignKey}` : ''}.`; return `${this.thingName}=${this.value} relation was not found in ${this.table}${this.foreignKey !== undefined ? `.${this.foreignKey}` : ''}.`;
} }
} }
export class FileError<V> extends ValidationError<V> {
private readonly _message: string;
public constructor(message: string) {
super();
this._message = message;
}
public get message(): string {
return `${this._message}`;
}
}
export type Lengthable = {
length: number,
};
export function isLenghtable(value: unknown): value is Lengthable {
return value !== undefined && value !== null &&
typeof (value as Lengthable).length === 'number';
}

View File

@ -0,0 +1,134 @@
import config from "config";
import Controller from "../Controller";
import User from "../auth/models/User";
import {Request, Response} from "express";
import {BadRequestError, NotFoundHttpError} from "../HttpError";
import Mail from "../mail/Mail";
import {ACCOUNT_REVIEW_NOTICE_MAIL_TEMPLATE} from "../Mails";
import UserEmail from "../auth/models/UserEmail";
import UserApprovedComponent from "../auth/models/UserApprovedComponent";
import {RequireAdminMiddleware, RequireAuthMiddleware} from "../auth/AuthComponent";
import NunjucksComponent from "../components/NunjucksComponent";
export default class BackendController extends Controller {
private static readonly menu: BackendMenuElement[] = [];
public static registerMenuElement(element: BackendMenuElement): void {
this.menu.push(element);
}
public constructor() {
super();
if (User.isApprovalMode()) {
BackendController.registerMenuElement({
getLink: async () => Controller.route('accounts-approval'),
getDisplayString: async () => {
const pendingUsersCount = (await User.select()
.where('approved', false)
.get()).length;
return `Accounts approval (${pendingUsersCount})`;
},
getDisplayIcon: async () => 'user-check',
});
}
}
public getRoutesPrefix(): string {
return '/backend';
}
public routes(): void {
this.useMiddleware(RequireAuthMiddleware, RequireAdminMiddleware);
this.get('/', this.getIndex, 'backend');
if (User.isApprovalMode()) {
this.get('/accounts-approval', this.getAccountApproval, 'accounts-approval');
this.post('/accounts-approval/approve', this.postApproveAccount, 'approve-account');
this.post('/accounts-approval/reject', this.postRejectAccount, 'reject-account');
}
}
protected async getIndex(req: Request, res: Response): Promise<void> {
res.render('backend/index', {
menu: await Promise.all(BackendController.menu.map(async m => ({
link: await m.getLink(),
display_string: await m.getDisplayString(),
display_icon: await m.getDisplayIcon(),
}))),
});
}
protected async getAccountApproval(req: Request, res: Response): Promise<void> {
const accounts = await User.select()
.where('approved', 0)
.with('mainEmail')
.get();
res.render('backend/accounts_approval', {
accounts: accounts,
});
}
protected async postApproveAccount(req: Request, res: Response): Promise<void> {
const {account, email} = await this.accountRequest(req);
account.as(UserApprovedComponent).approved = true;
await account.save();
if (email && email.email) {
await new Mail(this.getApp().as(NunjucksComponent).getEnvironment(), ACCOUNT_REVIEW_NOTICE_MAIL_TEMPLATE, {
approved: true,
link: config.get<string>('public_url') + Controller.route('auth'),
}).send(email.email);
}
req.flash('success', `Account successfully approved.`);
res.redirect(Controller.route('accounts-approval'));
}
protected async postRejectAccount(req: Request, res: Response): Promise<void> {
const {account, email} = await this.accountRequest(req);
await account.delete();
if (email && email.email) {
await new Mail(this.getApp().as(NunjucksComponent).getEnvironment(), ACCOUNT_REVIEW_NOTICE_MAIL_TEMPLATE, {
approved: false,
}).send(email.email);
}
req.flash('success', `Account successfully deleted.`);
res.redirect(Controller.route('accounts-approval'));
}
protected async accountRequest(req: Request): Promise<{
account: User,
email: UserEmail | null,
}> {
if (!req.body.user_id) throw new BadRequestError('Missing user_id field', 'Check your form', req.url);
const account = await User.select().where('id', req.body.user_id).with('mainEmail').first();
if (!account) throw new NotFoundHttpError('User', req.url);
const email = await account.mainEmail.get();
return {
account: account,
email: email,
};
}
}
export interface BackendMenuElement {
/**
* Returns the link of this menu element (usually using {@code Controller.route})
*/
getLink(): Promise<string>;
/**
* The string part of the link display
*/
getDisplayString(): Promise<string>;
/**
* An optional feather icon name
*/
getDisplayIcon(): Promise<string | null>;
}

View File

@ -1,19 +1,16 @@
import nodemailer, {SentMessageInfo, Transporter} from "nodemailer"; import nodemailer, {SentMessageInfo, Transporter} from "nodemailer";
import config from "config"; import config from "config";
import {Options} from "nodemailer/lib/mailer"; import {Options} from "nodemailer/lib/mailer";
import nunjucks from 'nunjucks'; import {Environment} from 'nunjucks';
import * as util from "util"; import * as util from "util";
import {WrappingError} from "./Utils"; import {WrappingError} from "../Utils";
import mjml2html from "mjml"; import mjml2html from "mjml";
import * as querystring from "querystring"; import {logger} from "../Logger";
import Logger from "./Logger"; import Controller from "../Controller";
import {ParsedUrlQueryInput} from "querystring";
export function mailRoute(template: string): string {
return `/mail/${template}`;
}
export default class Mail { export default class Mail {
private static transporter: Transporter; private static transporter?: Transporter;
private static getTransporter(): Transporter { private static getTransporter(): Transporter {
if (!this.transporter) throw new MailError('Mail system was not prepared.'); if (!this.transporter) throw new MailError('Mail system was not prepared.');
@ -24,14 +21,14 @@ export default class Mail {
const transporter = nodemailer.createTransport({ const transporter = nodemailer.createTransport({
host: config.get('mail.host'), host: config.get('mail.host'),
port: config.get('mail.port'), port: config.get('mail.port'),
secure: config.get('mail.secure'), requireTLS: config.get('mail.secure'), // STARTTLS
auth: { auth: {
user: config.get('mail.username'), user: config.get('mail.username'),
pass: config.get('mail.password'), pass: config.get('mail.password'),
}, },
tls: { tls: {
rejectUnauthorized: !config.get('mail.allow_invalid_tls') rejectUnauthorized: !config.get('mail.allow_invalid_tls'),
} },
}); });
try { try {
@ -41,16 +38,21 @@ export default class Mail {
throw new MailError('Connection to mail service unsuccessful.', e); throw new MailError('Connection to mail service unsuccessful.', e);
} }
Logger.info(`Mail ready to be distributed via ${config.get('mail.host')}:${config.get('mail.port')}`); logger.info(`Mail ready to be distributed via ${config.get('mail.host')}:${config.get('mail.port')}`);
} }
public static end() { public static end(): void {
this.transporter.close(); if (this.transporter) this.transporter.close();
} }
public static parse(template: string, data: any, textOnly: boolean): string { public static parse(
environment: Environment,
template: string,
data: { [p: string]: unknown },
textOnly: boolean,
): string {
data.text = textOnly; data.text = textOnly;
const nunjucksResult = nunjucks.render(template, data); const nunjucksResult = environment.render(template, data);
if (textOnly) return nunjucksResult; if (textOnly) return nunjucksResult;
const mjmlResult = mjml2html(nunjucksResult, {}); const mjmlResult = mjml2html(nunjucksResult, {});
@ -62,11 +64,13 @@ export default class Mail {
return mjmlResult.html; return mjmlResult.html;
} }
private readonly template: MailTemplate;
private readonly options: Options = {}; private readonly options: Options = {};
private readonly data: { [p: string]: any };
constructor(template: MailTemplate, data: { [p: string]: any } = {}) { public constructor(
private readonly environment: Environment,
private readonly template: MailTemplate,
private readonly data: ParsedUrlQueryInput = {},
) {
this.template = template; this.template = template;
this.data = data; this.data = data;
this.options.subject = this.template.getSubject(data); this.options.subject = this.template.getSubject(data);
@ -93,18 +97,26 @@ export default class Mail {
// Set options // Set options
this.options.to = destEmail; this.options.to = destEmail;
this.options.from = {
name: config.get('mail.from_name'),
address: config.get('mail.from'),
};
// Set data // Set data
this.data.mail_subject = this.options.subject; this.data.mail_subject = this.options.subject;
this.data.mail_to = this.options.to; this.data.mail_to = this.options.to;
this.data.mail_link = `${config.get<string>('public_url')}${mailRoute(this.template.template)}?${querystring.stringify(this.data)}`; this.data.mail_link = config.get<string>('public_url') +
Controller.route('mail', [this.template.template], this.data);
this.data.app = config.get('app');
// Log // Log
Logger.dev('Send mail', this.options); logger.debug('Send mail', this.options);
// Render email // Render email
this.options.html = Mail.parse('mails/' + this.template.template + '.mjml.njk', this.data, false); this.options.html = Mail.parse(this.environment, 'mails/' + this.template.template + '.mjml.njk',
this.options.text = Mail.parse('mails/' + this.template.template + '.mjml.njk', this.data, true); this.data, false);
this.options.text = Mail.parse(this.environment, 'mails/' + this.template.template + '.mjml.njk',
this.data, true);
// Send email // Send email
results.push(await Mail.getTransporter().sendMail(this.options)); results.push(await Mail.getTransporter().sendMail(this.options));
@ -116,9 +128,9 @@ export default class Mail {
export class MailTemplate { export class MailTemplate {
private readonly _template: string; private readonly _template: string;
private readonly subject: (data: any) => string; private readonly subject: (data: { [p: string]: unknown }) => string;
constructor(template: string, subject: (data: any) => string) { public constructor(template: string, subject: (data: { [p: string]: unknown }) => string) {
this._template = template; this._template = template;
this.subject = subject; this.subject = subject;
} }
@ -127,13 +139,13 @@ export class MailTemplate {
return this._template; return this._template;
} }
public getSubject(data: any): string { public getSubject(data: { [p: string]: unknown }): string {
return 'Watch My Stream - ' + this.subject(data); return `${config.get('app.name')} - ${this.subject(data)}`;
} }
} }
class MailError extends WrappingError { class MailError extends WrappingError {
constructor(message: string = 'An error occurred while sending mail.', cause?: Error) { public constructor(message: string = 'An error occurred while sending mail.', cause?: Error) {
super(message, cause); super(message, cause);
} }
} }

View File

@ -0,0 +1,16 @@
import {Request, Response} from "express";
import Controller from "../Controller";
import Mail from "./Mail";
import NunjucksComponent from "../components/NunjucksComponent";
export default class MailController extends Controller {
public routes(): void {
this.get("/mail/:template", this.getMail, 'mail');
}
protected async getMail(request: Request, response: Response): Promise<void> {
const template = request.params['template'];
response.send(Mail.parse(this.getApp().as(NunjucksComponent).getEnvironment(),
`mails/${template}.mjml.njk`, request.query, false));
}
}

20
src/main.ts Normal file
View File

@ -0,0 +1,20 @@
import {delimiter} from "path";
// Load config from specified path or default + swaf/config (default defaults)
process.env['NODE_CONFIG_DIR'] =
__dirname + '/../../node_modules/swaf/config/'
+ delimiter
+ (process.env['NODE_CONFIG_DIR'] || __dirname + '/../../config/');
import {logger} from "./Logger";
import TestApp from "./TestApp";
import config from "config";
(async () => {
logger.debug('Config path:', process.env['NODE_CONFIG_DIR']);
const app = new TestApp(config.get<string>('listen_addr'), config.get<number>('port'));
await app.start();
})().catch(err => {
logger.error(err);
});

View File

@ -1,25 +0,0 @@
import Migration from "../db/Migration";
import {query} from "../db/MysqlConnectionManager";
/**
* Must be the first migration
*/
export default class CreateLogsTable extends Migration {
async install(): Promise<void> {
await query('CREATE TABLE logs(' +
'id INT NOT NULL AUTO_INCREMENT,' +
'level TINYINT UNSIGNED NOT NULL,' +
'message TEXT NOT NULL,' +
'log_id BINARY(16),' +
'error_name VARCHAR(128),' +
'error_message VARCHAR(512),' +
'error_stack TEXT,' +
'created_at DATETIME NOT NULL DEFAULT NOW(),' +
'PRIMARY KEY (id)' +
')');
}
async rollback(): Promise<void> {
await query('DROP TABLE logs');
}
}

View File

@ -5,7 +5,7 @@ import {query} from "../db/MysqlConnectionManager";
* Must be the first migration * Must be the first migration
*/ */
export default class CreateMigrationsTable extends Migration { export default class CreateMigrationsTable extends Migration {
async shouldRun(currentVersion: number): Promise<boolean> { public async shouldRun(currentVersion: number): Promise<boolean> {
try { try {
await query('SELECT 1 FROM migrations LIMIT 1'); await query('SELECT 1 FROM migrations LIMIT 1');
} catch (e) { } catch (e) {
@ -17,16 +17,17 @@ export default class CreateMigrationsTable extends Migration {
return await super.shouldRun(currentVersion); return await super.shouldRun(currentVersion);
} }
async install(): Promise<void> { public async install(): Promise<void> {
await query('CREATE TABLE migrations(' + await this.query(`CREATE TABLE migrations
'id INT NOT NULL,' + (
'name VARCHAR(64) NOT NULL,' + id INT NOT NULL,
'migration_date DATE,' + name VARCHAR(64) NOT NULL,
'PRIMARY KEY (id)' + migration_date DATE,
')'); PRIMARY KEY (id)
)`);
} }
async rollback(): Promise<void> { public async rollback(): Promise<void> {
await query('DROP TABLE migrations'); await this.query('DROP TABLE migrations');
} }
} }

View File

@ -0,0 +1,11 @@
import Migration from "../db/Migration";
export default class DropLegacyLogsTable extends Migration {
public async install(): Promise<void> {
await this.query('DROP TABLE IF EXISTS logs');
}
public async rollback(): Promise<void> {
// Do nothing
}
}

View File

@ -0,0 +1,11 @@
import Migration from "../db/Migration";
export default class DummyMigration extends Migration {
public async install(): Promise<void> {
// Do nothing
}
public async rollback(): Promise<void> {
// Do nothing
}
}

View File

@ -1,69 +0,0 @@
import Model from "../db/Model";
import {LogLevel, LogLevelKeys} from "../Logger";
import Validator from "../db/Validator";
export default class Log extends Model {
private level?: number;
public message?: string;
private log_id?: Buffer;
private error_name?: string;
private error_message?: string;
private error_stack?: string;
private created_at?: Date;
protected defineProperties(): void {
this.defineProperty<number>('level', new Validator<number>().defined());
this.defineProperty<string>('message', new Validator<string>().defined().between(0, 65535));
this.defineProperty<Buffer>('log_id', new Validator<Buffer>().acceptUndefined().length(16));
this.defineProperty<string>('error_name', new Validator<string>().acceptUndefined().between(0, 128));
this.defineProperty<string>('error_message', new Validator<string>().acceptUndefined().between(0, 512));
this.defineProperty<string>('error_stack', new Validator<string>().acceptUndefined().between(0, 65535));
this.defineProperty<Date>('created_at', new Validator<Date>());
}
public getLevel(): LogLevelKeys {
if (typeof this.level !== 'number') return 'ERROR';
return <LogLevelKeys>LogLevel[this.level];
}
public setLevel(level: LogLevelKeys) {
this.level = LogLevel[level];
}
public getLogID(): string | null {
if (!this.log_id) return null;
const chars = this.log_id!.toString('hex');
let out = '';
let i = 0;
for (const l of [8, 4, 4, 4, 12]) {
if (i > 0) out += '-';
out += chars.substr(i, l);
i += l;
}
return out;
}
public setLogID(buffer: Buffer) {
this.log_id = buffer;
}
public getErrorName(): string {
return this.error_name || '';
}
public getErrorMessage(): string {
return this.error_message || '';
}
public getErrorStack(): string {
return this.error_stack || '';
}
public setError(error?: Error) {
if (!error) return;
this.error_name = error.name;
this.error_message = error.message;
this.error_stack = error.stack;
}
}

View File

@ -1,22 +1,54 @@
import {Environment} from "nunjucks"; import {Files} from "formidable";
import Model from "../db/Model"; import {Type} from "../Utils";
import Middleware from "../Middleware";
import {FlashMessages} from "../components/SessionComponent";
import {Session, SessionData} from "express-session";
import {PasswordAuthProofSessionData} from "../auth/password/PasswordAuthProof";
declare global { declare global {
namespace Express { namespace Express {
export interface Request { export interface Request {
env: Environment; getSession(): Session & Partial<SessionData>;
models: { [p: string]: Model | null };
modelCollections: { [p: string]: Model[] | null };
flash(): { [key: string]: string[] }; getSessionOptional(): Session & Partial<SessionData> | undefined;
flash(message: string): any;
flash(event: string, message: any): any; files: Files;
}
export interface Response {
redirectBack(defaultUrl?: string): any; middlewares: Middleware[];
as<M extends Middleware>(type: Type<M>): M;
flash(): FlashMessages;
flash(message: string): unknown[];
flash(event: string, message: unknown): void;
getPreviousUrl(): string | null;
getIntendedUrl(): string | null;
} }
} }
} }
declare module 'express-session' {
interface SessionData {
id?: string;
previousUrl?: string;
wantsSessionPersistence?: boolean;
persistent?: boolean;
isAuthenticated?: boolean;
authPasswordProof?: PasswordAuthProofSessionData;
csrf?: string;
}
}

1122
test/Authentication.test.ts Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,147 @@
import TestApp from "../src/TestApp";
import useApp from "./_app";
import Controller from "../src/Controller";
import supertest from "supertest";
import CsrfProtectionComponent from "../src/components/CsrfProtectionComponent";
import UserPasswordComponent from "../src/auth/password/UserPasswordComponent";
import {popEmail} from "./_mail_server";
import AuthComponent from "../src/auth/AuthComponent";
import Migration, {MigrationType} from "../src/db/Migration";
import AddNameToUsersMigration from "../src/auth/migrations/AddNameToUsersMigration";
import {followMagicLinkFromMail, testLogout} from "./_authentication_common";
import UserEmail from "../src/auth/models/UserEmail";
let app: TestApp;
useApp(async (addr, port) => {
return app = new class extends TestApp {
protected async init(): Promise<void> {
this.use(new class extends Controller {
public routes(): void {
this.get('/', (req, res) => {
res.render('home');
}, 'home');
this.get('/csrf', (req, res) => {
res.send(CsrfProtectionComponent.getCsrfToken(req.getSession()));
}, 'csrf');
this.get('/is-auth', async (req, res) => {
const proofs = await this.getApp().as(AuthComponent).getAuthGuard().getProofs(req);
if (proofs.length > 0) res.sendStatus(200);
else res.sendStatus(401);
}, 'is-auth');
}
}());
await super.init();
}
protected getMigrations(): MigrationType<Migration>[] {
return super.getMigrations().filter(m => m !== AddNameToUsersMigration);
}
}(addr, port);
});
let agent: supertest.SuperTest<supertest.Test>;
beforeAll(() => {
agent = supertest(app.getExpressApp());
});
describe('Register with username and password (password)', () => {
test('Must be disabled', async () => {
const res = await agent.get('/csrf').expect(200);
const cookies = res.get('Set-Cookie');
const csrf = res.text;
// Register user
await agent.post('/auth/register')
.set('Cookie', cookies)
.send({
csrf: csrf,
auth_method: 'password',
identifier: 'entrapta',
password: 'darla_is_cute',
password_confirmation: 'darla_is_cute',
terms: 'on',
})
.expect(500);
});
});
describe('Register with email (magic_link)', () => {
test('General case', async () => {
const res = await agent.get('/csrf').expect(200);
const cookies = res.get('Set-Cookie');
const csrf = res.text;
await agent.post('/auth/register')
.set('Cookie', cookies)
.send({
csrf: csrf,
auth_method: 'magic_link',
identifier: 'glimmer@example.org',
})
.expect(302)
.expect('Location', '/magic/lobby?redirect_uri=');
await followMagicLinkFromMail(agent, cookies);
await testLogout(agent, cookies, csrf);
// Verify saved user
const email = await UserEmail.select()
.with('user')
.where('email', 'glimmer@example.org')
.first();
const user = email?.user.getOrFail();
expect(user).toBeDefined();
expect(email).toBeDefined();
expect(email?.email).toStrictEqual('glimmer@example.org');
await expect(user?.as(UserPasswordComponent).verifyPassword('')).resolves.toStrictEqual(false);
});
test('Cannot register taken email', async () => {
const res = await agent.get('/csrf').expect(200);
const cookies = res.get('Set-Cookie');
const csrf = res.text;
await agent.post('/auth/register')
.set('Cookie', cookies)
.send({
csrf: csrf,
auth_method: 'magic_link',
identifier: 'bow@example.org',
name: 'bow',
})
.expect(302)
.expect('Location', '/magic/lobby?redirect_uri=');
await followMagicLinkFromMail(agent, cookies);
// Verify saved user
const userEmail = await UserEmail.select()
.with('user')
.where('email', 'bow@example.org')
.first();
const user = userEmail?.user.getOrFail();
expect(user).toBeDefined();
// Attempt register with another mail but same username
const res2 = await agent.get('/csrf').expect(200);
await agent.post('/auth/register')
.set('Cookie', res2.get('Set-Cookie'))
.send({
csrf: res2.text,
auth_method: 'magic_link',
identifier: 'bow@example.org',
name: 'bow2',
})
.expect(400);
expect(await popEmail()).toBeNull();
});
});

View File

@ -0,0 +1,92 @@
import useApp from "./_app";
import Controller from "../src/Controller";
import supertest from "supertest";
import TestApp from "../src/TestApp";
import CsrfProtectionComponent from "../src/components/CsrfProtectionComponent";
let app: TestApp;
useApp(async (addr, port) => {
return app = new class extends TestApp {
protected async init(): Promise<void> {
this.use(new class extends Controller {
public routes(): void {
this.get('/', (req, res) => {
res.send(CsrfProtectionComponent.getCsrfToken(req.getSession()));
}, 'csrf_test');
this.post('/', (req, res) => {
res.json({
status: 'ok',
});
}, 'csrf_test');
}
}());
await super.init();
}
}(addr, port);
});
describe('Test CSRF protection', () => {
let cookies: string[];
let csrf: string;
test('no csrf token should be in session at first', (done) => {
const agent = supertest(app.getExpressApp());
agent.post('/')
.expect(401)
.then(res => {
expect(res.text).toContain(`You weren't assigned any CSRF token.`);
cookies = res.get('Set-Cookie');
agent.get('/')
.set('Cookie', cookies)
.expect(200)
.then(res => {
csrf = res.text;
done();
}).catch(done.fail);
}).catch(done.fail);
});
test('sending no csrf token should fail', (done) => {
expect(cookies).toBeDefined();
const agent = supertest(app.getExpressApp());
agent.post('/')
.set('Cookie', cookies)
.expect(401)
.then((res) => {
expect(res.text).toContain(`You didn't provide any CSRF token.`);
done();
}).catch(done.fail);
});
test('sending an invalid csrf token should fail', (done) => {
expect(cookies).toBeDefined();
const agent = supertest(app.getExpressApp());
agent.post('/')
.set('Cookie', cookies)
.set('Content-Type', 'application/json')
.send({csrf: 'not_a_valid_csrf'})
.expect(401)
.then(res => {
expect(res.text).toContain(`Tokens don't match.`);
done();
}).catch(done.fail);
});
test('sending a valid csrf token should success', (done) => {
expect(cookies).toBeDefined();
const agent = supertest(app.getExpressApp());
agent.post('/')
.set('Cookie', cookies)
.set('Content-Type', 'application/json')
.send({csrf: csrf})
.expect(200)
.then(() => done())
.catch(done.fail);
});
});

View File

@ -1,40 +1,140 @@
import MysqlConnectionManager from "../src/db/MysqlConnectionManager"; import MysqlConnectionManager from "../src/db/MysqlConnectionManager";
import Model from "../src/db/Model"; import Model from "../src/db/Model";
import Validator from "../src/db/Validator"; import ModelFactory from "../src/db/ModelFactory";
import {MIGRATIONS} from "./_migrations"; import {ValidationBag} from "../src/db/Validator";
import {logger} from "../src/Logger";
import {ManyThroughModelRelation, OneModelRelation} from "../src/db/ModelRelation";
import {MIGRATIONS} from "../src/TestApp";
import config from "config";
class FakeDummyModel extends Model { class FakeDummyModel extends Model {
public name?: string; public id?: number = undefined;
public date?: Date; public name?: string = undefined;
public date_default?: Date; public date?: Date = undefined;
public date_default?: Date = undefined;
protected defineProperties(): void { protected init(): void {
this.defineProperty<string>('name', new Validator().acceptUndefined().between(3, 256)); this.setValidation('name').acceptUndefined().between(3, 256);
this.defineProperty<Date>('date', new Validator());
this.defineProperty<Date>('date_default', new Validator());
} }
} }
beforeAll(async (done) => { class Post extends Model {
MysqlConnectionManager.registerMigrations(MIGRATIONS); public id?: number = undefined;
await MysqlConnectionManager.prepare(); public author_id?: number = undefined;
done(); public content?: string = undefined;
});
afterAll(async (done) => { public readonly author = new OneModelRelation(this, Author, {
await MysqlConnectionManager.endPool(); localKey: 'author_id',
done(); foreignKey: 'id',
});
describe('Model', () => {
it('should have a proper table name', async () => {
expect(FakeDummyModel.table).toBe('fake_dummy_models');
expect(new FakeDummyModel({}).table).toBe('fake_dummy_models');
}); });
it('should insert and retrieve properly', async () => { protected init(): void {
await MysqlConnectionManager.query(`DROP TABLE IF EXISTS ${FakeDummyModel.table}`); this.setValidation('author_id').defined().exists(Author, 'id');
await MysqlConnectionManager.query(`CREATE TABLE ${FakeDummyModel.table}( }
}
class Author extends Model {
public id?: number = undefined;
public name?: string = undefined;
public readonly roles = new ManyThroughModelRelation(this, Role, {
localKey: 'id',
foreignKey: 'id',
pivotTable: 'author_role',
localPivotKey: 'author_id',
foreignPivotKey: 'role_id',
});
}
class Role extends Model {
public id?: number = undefined;
public name?: string = undefined;
public readonly permissions = new ManyThroughModelRelation(this, Permission, {
localKey: 'id',
foreignKey: 'id',
pivotTable: 'role_permission',
localPivotKey: 'role_id',
foreignPivotKey: 'permission_id',
});
}
class Permission extends Model {
public id?: number = undefined;
public name?: string = undefined;
}
class AuthorRole extends Model {
public static get table(): string {
return 'author_role';
}
public author_id?: number = undefined;
public role_id?: number = undefined;
protected init(): void {
this.setValidation('author_id').defined().exists(Author, 'id');
this.setValidation('role_id').defined().exists(Role, 'id');
}
}
class RolePermission extends Model {
public static get table(): string {
return 'role_permission';
}
public role_id?: number = undefined;
public permission_id?: number = undefined;
protected init(): void {
this.setValidation('role_id').defined().exists(Role, 'id');
this.setValidation('permission_id').defined().exists(Permission, 'id');
}
}
let fakeDummyModelModelFactory: ModelFactory<FakeDummyModel>;
let postFactory: ModelFactory<Post>;
let authorFactory: ModelFactory<Author>;
let roleFactory: ModelFactory<Role>;
let permissionFactory: ModelFactory<Permission>;
beforeAll(async () => {
await MysqlConnectionManager.prepare();
await MysqlConnectionManager.query('DROP DATABASE IF EXISTS ' + config.get<string>('mysql.database'));
await MysqlConnectionManager.endPool();
logger.setSettings({minLevel: "trace"});
MysqlConnectionManager.registerMigrations(MIGRATIONS);
ModelFactory.register(FakeDummyModel);
ModelFactory.register(Post);
ModelFactory.register(Author);
ModelFactory.register(Role);
ModelFactory.register(Permission);
ModelFactory.register(AuthorRole);
ModelFactory.register(RolePermission);
await MysqlConnectionManager.prepare();
// Create FakeDummyModel table
fakeDummyModelModelFactory = ModelFactory.get(FakeDummyModel);
postFactory = ModelFactory.get(Post);
authorFactory = ModelFactory.get(Author);
roleFactory = ModelFactory.get(Role);
permissionFactory = ModelFactory.get(Permission);
await MysqlConnectionManager.query(`DROP TABLE IF EXISTS author_role`);
await MysqlConnectionManager.query(`DROP TABLE IF EXISTS role_permission`);
for (const factory of [
fakeDummyModelModelFactory,
postFactory,
authorFactory,
roleFactory,
permissionFactory,
]) {
await MysqlConnectionManager.query(`DROP TABLE IF EXISTS ${factory.table}`);
}
await MysqlConnectionManager.query(`CREATE TABLE ${fakeDummyModelModelFactory.table}(
id INT NOT NULL AUTO_INCREMENT, id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(256), name VARCHAR(256),
date DATETIME, date DATETIME,
@ -42,23 +142,255 @@ describe('Model', () => {
PRIMARY KEY(id) PRIMARY KEY(id)
)`); )`);
await MysqlConnectionManager.query(`CREATE TABLE ${authorFactory.table}(
id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(64),
PRIMARY KEY(id)
)`);
await MysqlConnectionManager.query(`CREATE TABLE ${postFactory.table}(
id INT NOT NULL AUTO_INCREMENT,
author_id INT NOT NULL,
content VARCHAR(512),
PRIMARY KEY(id),
FOREIGN KEY post_author_fk (author_id) REFERENCES ${authorFactory.table} (id)
)`);
await MysqlConnectionManager.query(`CREATE TABLE ${roleFactory.table}(
id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(64),
PRIMARY KEY(id)
)`);
await MysqlConnectionManager.query(`CREATE TABLE ${permissionFactory.table}(
id INT NOT NULL AUTO_INCREMENT,
name VARCHAR(64),
PRIMARY KEY(id)
)`);
await MysqlConnectionManager.query(`CREATE TABLE author_role(
id INT NOT NULL AUTO_INCREMENT,
author_id INT NOT NULL,
role_id INT NOT NULL,
PRIMARY KEY(id),
FOREIGN KEY author_role_author_fk (author_id) REFERENCES ${authorFactory.table} (id),
FOREIGN KEY author_role_role_fk (role_id) REFERENCES ${roleFactory.table} (id)
)`);
await MysqlConnectionManager.query(`CREATE TABLE role_permission(
id INT NOT NULL AUTO_INCREMENT,
role_id INT NOT NULL,
permission_id INT NOT NULL,
PRIMARY KEY(id),
FOREIGN KEY role_permission_role_fk (role_id) REFERENCES ${roleFactory.table} (id),
FOREIGN KEY role_permission_permission_fk (permission_id) REFERENCES ${permissionFactory.table} (id)
)`);
/// SEED ///
// permissions
createPostPermission = Permission.create({name: 'create-post'});
await createPostPermission.save();
moderatePostPermission = Permission.create({name: 'moderate-post'});
await moderatePostPermission.save();
viewLogsPermission = Permission.create({name: 'view-logs'});
await viewLogsPermission.save();
// roles
guestRole = Role.create({name: 'guest'});
await guestRole.save();
await RolePermission.create({role_id: guestRole.id, permission_id: createPostPermission.id}).save();
moderatorRole = Role.create({name: 'moderator'});
await moderatorRole.save();
await RolePermission.create({role_id: moderatorRole.id, permission_id: createPostPermission.id}).save();
await RolePermission.create({role_id: moderatorRole.id, permission_id: moderatePostPermission.id}).save();
adminRole = Role.create({name: 'admin'});
await adminRole.save();
await RolePermission.create({role_id: adminRole.id, permission_id: createPostPermission.id}).save();
await RolePermission.create({role_id: adminRole.id, permission_id: moderatePostPermission.id}).save();
await RolePermission.create({role_id: adminRole.id, permission_id: viewLogsPermission.id}).save();
// authors
glimmerAuthor = Author.create({name: 'glimmer'});
await glimmerAuthor.save();
await AuthorRole.create({author_id: glimmerAuthor.id, role_id: guestRole.id}).save();
bowAuthor = Author.create({name: 'bow'});
await bowAuthor.save();
await AuthorRole.create({author_id: bowAuthor.id, role_id: moderatorRole.id}).save();
adoraAuthor = Author.create({name: 'adora'});
await adoraAuthor.save();
await AuthorRole.create({author_id: adoraAuthor.id, role_id: adminRole.id}).save();
// posts
post1 = Post.create({author_id: glimmerAuthor.id, content: 'I\'m the queen now and you\'ll do as I order.'});
await post1.save();
post2 = Post.create({author_id: adoraAuthor.id, content: 'But you\'re wrong!'});
await post2.save();
post3 = Post.create({author_id: bowAuthor.id, content: 'Come on guys, let\'s talk this through.'});
await post3.save();
});
afterAll(async () => {
await MysqlConnectionManager.endPool();
});
describe('Model', () => {
it('should construct properly', () => {
const date = new Date(888);
const model = fakeDummyModelModelFactory.create({
name: 'a_name',
date: date,
non_existing_property: 'dropped_value',
}, true);
expect(model.id).toBeUndefined();
expect(model.name).toBe('a_name');
expect(model.date).toBe(date);
expect(model.date_default).toBeUndefined();
expect(model.non_existing_property).toBeUndefined();
});
it('should have a proper table name', () => {
expect(fakeDummyModelModelFactory.table).toBe('fake_dummy_models');
expect(FakeDummyModel.table).toBe('fake_dummy_models');
expect(FakeDummyModel.create({}).table).toBe('fake_dummy_models');
});
it('should insert properly', async () => {
const date = new Date(569985); const date = new Date(569985);
let instance: FakeDummyModel | null = new FakeDummyModel({ const insertInstance: FakeDummyModel | null = fakeDummyModelModelFactory.create({
name: 'name1', name: 'name1',
date: date, date: date,
}); }, true);
await instance.save();
expect(instance.id).toBe(1);
expect(instance.name).toBe('name1');
expect(instance.date?.getTime()).toBeCloseTo(date.getTime(), -4);
expect(instance.date_default).toBeDefined();
instance = await FakeDummyModel.getById(1); // Insert
expect(instance).toBeDefined(); expect(insertInstance.exists()).toBeFalsy();
expect(instance!.id).toBe(1); await insertInstance.save();
expect(instance!.name).toBe('name1'); expect(insertInstance.exists()).toBeTruthy();
expect(instance!.date?.getTime()).toBeCloseTo(date.getTime(), -4);
expect(instance!.date_default).toBeDefined(); expect(insertInstance.id).toBe(1); // Auto id from insert
}, 15000); expect(insertInstance.name).toBe('name1');
expect(insertInstance.date?.getTime()).toBeCloseTo(date.getTime(), -4);
expect(insertInstance.date_default).toBeDefined();
// Check that row exists in DB
const retrievedInstance = await FakeDummyModel.getById(1);
expect(retrievedInstance).toBeDefined();
expect(retrievedInstance?.id).toBe(1);
expect(retrievedInstance?.name).toBe('name1');
expect(retrievedInstance?.date?.getTime()).toBeCloseTo(date.getTime(), -4);
expect(retrievedInstance?.date_default).toBeDefined();
const failingInsertModel = fakeDummyModelModelFactory.create({
name: 'a',
}, true);
await expect(failingInsertModel.save()).rejects.toBeInstanceOf(ValidationBag);
});
it('should update properly', async () => {
const insertModel = fakeDummyModelModelFactory.create({
name: 'update',
}, true);
await insertModel.save();
const preUpdatedModel = await FakeDummyModel.getById(insertModel.id);
expect(preUpdatedModel).not.toBeNull();
expect(preUpdatedModel?.name).toBe(insertModel.name);
// Update model
if (preUpdatedModel) {
preUpdatedModel.name = 'updated_name';
await preUpdatedModel.save();
}
const postUpdatedModel = await FakeDummyModel.getById(insertModel.id);
expect(postUpdatedModel).not.toBeNull();
expect(postUpdatedModel?.id).toBe(insertModel.id);
expect(postUpdatedModel?.name).not.toBe(insertModel.name);
expect(postUpdatedModel?.name).toBe(preUpdatedModel?.name);
});
it('should delete properly', async () => {
const insertModel = fakeDummyModelModelFactory.create({
name: 'delete',
}, true);
await insertModel.save();
const preDeleteModel = await FakeDummyModel.getById(insertModel.id);
expect(preDeleteModel).not.toBeNull();
await preDeleteModel?.delete();
const postDeleteModel = await FakeDummyModel.getById(insertModel.id);
expect(postDeleteModel).toBeNull();
});
});
let createPostPermission: Permission;
let moderatePostPermission: Permission;
let viewLogsPermission: Permission;
let guestRole: Role;
let moderatorRole: Role;
let adminRole: Role;
let glimmerAuthor: Author;
let bowAuthor: Author;
let adoraAuthor: Author;
let post1: Post;
let post2: Post;
let post3: Post;
describe('ModelRelation', () => {
test('Query and check relations', async () => {
const posts = await Post.select()
.with('author.roles.permissions')
.sortBy('id', 'ASC')
.get();
expect(posts.length).toBe(3);
async function testPost(
post: Post,
originalPost: Post,
expectedAuthor: Author,
expectedRoles: Role[],
expectedPermissions: Permission[],
) {
console.log('Testing post', post);
expect(post.id).toBe(originalPost.id);
expect(post.content).toBe(originalPost.content);
const actualAuthor = await post.author.get();
expect(actualAuthor).not.toBeNull();
expect(await post.author.has(expectedAuthor)).toBeTruthy();
expect(actualAuthor?.equals(expectedAuthor)).toBe(true);
const authorRoles = await actualAuthor?.roles.get() || [];
console.log('Roles:');
expect(authorRoles.map(r => r.id)).toStrictEqual(expectedRoles.map(r => r.id));
const authorPermissions = (await Promise.all(authorRoles.map(async r => await r.permissions.get())))
.flatMap(p => p);
console.log('Permissions:');
expect(authorPermissions.map(p => p.id)).toStrictEqual(expectedPermissions.map(p => p.id));
}
await testPost(posts[0], post1, glimmerAuthor,
[guestRole],
[createPostPermission]);
await testPost(posts[1], post2, adoraAuthor,
[adminRole],
[createPostPermission, moderatePostPermission, viewLogsPermission]);
await testPost(posts[2], post3, bowAuthor,
[moderatorRole],
[createPostPermission, moderatePostPermission]);
});
}); });

126
test/ModelQuery.test.ts Normal file
View File

@ -0,0 +1,126 @@
import ModelQuery, {SelectFieldValue, WhereOperator} from "../src/db/ModelQuery";
import ModelFactory from "../src/db/ModelFactory";
import Model from "../src/db/Model";
describe('Test ModelQuery', () => {
test('select', () => {
const query = ModelQuery.select({table: 'model'} as unknown as ModelFactory<Model>, 'f1', '"Test" as f2')
.where('f4', 'v4')
.where('f5', true)
.where('f6', null)
.where('f7', undefined);
expect(query.toString(true)).toBe('SELECT `model`.`f1`,"Test" as f2 FROM `model` WHERE `f4`=? AND `f5`=true AND `f6` IS null AND `f7` IS null');
expect(query.variables).toStrictEqual(['v4']);
});
test('order by', () => {
const query = ModelQuery.select({table: 'model'} as unknown as ModelFactory<Model>)
.sortBy('model.f2', 'ASC');
expect(query.toString(true)).toBe('SELECT `model`.* FROM `model` ORDER BY `model`.`f2` ASC');
const queryRaw = ModelQuery.select({table: 'model'} as unknown as ModelFactory<Model>)
.sortBy('coalesce(model.f1, model.f2)', 'ASC', true);
expect(queryRaw.toString(true)).toBe('SELECT `model`.* FROM `model` ORDER BY coalesce(model.f1, model.f2) ASC');
});
test('create (insert into)', () => {
const date = new Date();
const query = ModelQuery.insert(
{table: 'model'} as unknown as ModelFactory<Model>,
{
'boolean': true,
'null': null,
'undefined': undefined,
'string': 'string',
'date': date,
'sensitive': 'sensitive', // Reserved word
},
);
expect(query.toString(true)).toBe('INSERT INTO `model` (`boolean`,`null`,`undefined`,`string`,`date`,`sensitive`) VALUES(true,null,null,?,?,?)');
expect(query.variables).toStrictEqual([
'string',
date,
'sensitive',
]);
});
test('update', () => {
const date = new Date();
const query = ModelQuery.update({table: 'model'} as unknown as ModelFactory<Model>, {
'boolean': true,
'null': null,
'undefined': undefined,
'string': 'string',
'date': date,
'sensitive': 'sensitive', // Reserved word
}).where('f4', 'v4')
.where('f5', true)
.where('f6', null)
.where('f7', undefined);
expect(query.toString(true)).toBe('UPDATE `model` SET `model`.`boolean`=true,`model`.`null`=null,`model`.`undefined`=null,`model`.`string`=?,`model`.`date`=?,`model`.`sensitive`=? WHERE `f4`=? AND `f5`=true AND `f6` IS null AND `f7` IS null');
expect(query.variables).toStrictEqual([
'string',
date,
'sensitive',
'v4',
]);
});
test('function select', () => {
const query = ModelQuery.select(
{table: 'model'} as unknown as ModelFactory<Model>,
'f1',
new SelectFieldValue('_count', 'COUNT(*)', true),
);
expect(query.toString(true)).toBe('SELECT `model`.`f1`,(COUNT(*)) AS `_count` FROM `model`');
expect(query.variables).toStrictEqual([]);
});
test('pivot', () => {
const query = ModelQuery.select({table: 'model'} as unknown as ModelFactory<Model>, 'f1');
query.pivot('pivot.f2', 'f3');
expect(query.toString(true)).toBe('SELECT `model`.`f1`,`pivot`.`f2`,`model`.`f3` FROM `model`');
expect(query.variables).toStrictEqual([]);
});
test('groupWhere generates proper query', () => {
const query = ModelQuery.select({table: 'model'} as unknown as ModelFactory<Model>, '*');
query.where('f1', 'v1');
query.groupWhere(q => q.where('f2', 'v2').where('f3', 'v3')
.groupWhere(q => q.where('f4', 'v4'), WhereOperator.OR))
.where('f5', 'v5');
expect(query.toString(true)).toBe('SELECT `model`.* FROM `model` WHERE `f1`=? AND (`f2`=? AND `f3`=? OR (`f4`=?)) AND `f5`=?');
expect(query.variables).toStrictEqual(['v1', 'v2', 'v3', 'v4', 'v5']);
});
test('recursive queries', () => {
const query = ModelQuery.select({table: 'model'} as unknown as ModelFactory<Model>, '*');
query.where('f1', 'v1');
query.leftJoin('test').on('model.j1', 'test.j2');
query.recursive({localKey: 'local', foreignKey: 'foreign'}, false);
query.limit(8);
expect(query.toString(true)).toBe("WITH RECURSIVE cte AS (SELECT `model`.*,1 AS __depth, CONCAT(`local`) AS __path FROM `model` WHERE `f1`=? UNION SELECT o.*,c.__depth + 1,CONCAT(c.__path,'/',o.`local`) AS __path FROM `model` AS o, cte AS c WHERE o.`foreign`=c.`local`) SELECT * FROM cte LEFT JOIN `test` ON `model`.`j1`=`test`.`j2` ORDER BY __path ASC LIMIT 8");
expect(query.variables).toStrictEqual(['v1']);
const reversedQuery = ModelQuery.select({table: 'model'} as unknown as ModelFactory<Model>, '*');
reversedQuery.where('f1', 'v1');
reversedQuery.leftJoin('test').on('model.j1', 'test.j2');
reversedQuery.recursive({localKey: 'local', foreignKey: 'foreign'}, true);
expect(reversedQuery.toString(true)).toBe("WITH RECURSIVE cte AS (SELECT `model`.*,1 AS __depth, CONCAT(`foreign`) AS __path FROM `model` WHERE `f1`=? UNION SELECT o.*,c.__depth + 1,CONCAT(c.__path,'/',o.`foreign`) AS __path FROM `model` AS o, cte AS c WHERE o.`foreign`=c.`local`) SELECT * FROM cte LEFT JOIN `test` ON `model`.`j1`=`test`.`j2` ORDER BY __path DESC");
expect(reversedQuery.variables).toStrictEqual(['v1']);
});
test('union queries', () => {
const query = ModelQuery.select({table: 'model'} as unknown as ModelFactory<Model>, '*');
const query2 = ModelQuery.select({table: 'model2'} as unknown as ModelFactory<Model>, '*');
query2.where('f2', 'v2');
query.union(query2, 'model.f1', 'DESC', false, 8);
expect(query.toString(true)).toBe("(SELECT `model`.* FROM `model`) UNION (SELECT `model2`.* FROM `model2` WHERE `f2`=?) ORDER BY `model`.`f1` DESC LIMIT 8");
expect(query.variables).toStrictEqual(['v2']);
});
});

41
test/_app.ts Normal file
View File

@ -0,0 +1,41 @@
import Application from "../src/Application";
import {setupMailServer, teardownMailServer} from "./_mail_server";
import TestApp from "../src/TestApp";
import MysqlConnectionManager from "../src/db/MysqlConnectionManager";
import config from "config";
export default function useApp(appSupplier?: (addr: string, port: number) => Promise<TestApp>): void {
let app: Application;
beforeAll(async (done) => {
await MysqlConnectionManager.prepare();
await MysqlConnectionManager.query('DROP DATABASE IF EXISTS ' + config.get<string>('mysql.database'));
await MysqlConnectionManager.endPool();
await setupMailServer();
app = appSupplier ? await appSupplier('127.0.0.1', 8966) : new TestApp('127.0.0.1', 8966);
await app.start();
done();
});
afterAll(async (done) => {
const errors = [];
try {
await app.stop();
} catch (e) {
errors.push(e);
}
try {
await teardownMailServer();
} catch (e) {
errors.push(e);
}
if (errors.length > 0) throw errors;
done();
});
}

View File

@ -0,0 +1,39 @@
import {popEmail} from "./_mail_server";
import supertest from "supertest";
export async function followMagicLinkFromMail(
agent: supertest.SuperTest<supertest.Test>,
cookies: string[],
expectedRedirectUrl: string = '/',
): Promise<void> {
const mail: Record<string, unknown> | null = await popEmail();
expect(mail).not.toBeNull();
const query = (mail?.text as string).split('/magic/link?')[1].split('\n')[0];
expect(query).toBeDefined();
await agent.get('/magic/link?' + query)
.expect(200);
await agent.get('/magic/lobby')
.set('Cookie', cookies)
.expect(302)
.expect('Location', expectedRedirectUrl);
}
export async function testLogout(
agent: supertest.SuperTest<supertest.Test>,
cookies: string[],
csrf: string,
): Promise<void> {
// Authenticated
await agent.get('/is-auth').set('Cookie', cookies).expect(200);
// Logout
await agent.post('/auth/logout')
.set('Cookie', cookies)
.send({csrf: csrf})
.expect(302);
// Not authenticated
await agent.get('/is-auth').set('Cookie', cookies).expect(401);
}

38
test/_mail_server.ts Normal file
View File

@ -0,0 +1,38 @@
import MailDev, {Mail} from "maildev";
export const MAIL_SERVER = new MailDev({
ip: 'localhost',
});
export async function setupMailServer(): Promise<void> {
await new Promise<void>((resolve, reject) => MAIL_SERVER.listen((err?: Error) => {
if (err) reject(err);
else resolve();
}));
}
export async function teardownMailServer(): Promise<void> {
await new Promise<void>((resolve, reject) => MAIL_SERVER.close((err?: Error) => {
if (err) reject(err);
else resolve();
}));
}
export async function popEmail(): Promise<Record<string, unknown> | null> {
return await new Promise<Record<string, unknown> | null>((resolve, reject) => {
MAIL_SERVER.getAllEmail((err: Error | undefined, emails: Mail[]) => {
if (err) return reject(err);
if (emails.length === 0) return resolve(null);
const email = emails[0];
expect(email).toBeDefined();
expect(email.id).toBeDefined();
return resolve(new Promise<Record<string, unknown>>((resolve, reject) => {
MAIL_SERVER.deleteEmail(email.id as string, (err: Error | undefined) => {
if (err) return reject(err);
resolve(email as Record<string, unknown>);
});
}));
});
});
}

View File

@ -1,7 +0,0 @@
import CreateMigrationsTable from "../src/migrations/CreateMigrationsTable";
import CreateLogsTable from "../src/migrations/CreateLogsTable";
export const MIGRATIONS = [
CreateMigrationsTable,
CreateLogsTable,
];

216
test/types/maildev.d.ts vendored Normal file
View File

@ -0,0 +1,216 @@
// Type definitions for maildev 1.0.0-rc3
// Project: https://github.com/djfarrelly/maildev
// Definitions by: Cyril Schumacher <https://github.com/cyrilschumacher>
// Zak Barbuto <https://github.com/zbarbuto>
// Definitions: https://github.com/DefinitelyTyped/DefinitelyTyped
/// <reference types="node"/>
declare module 'maildev' {
import fs = require("fs");
/**
* Interface for {@link MailDev}.
*/
export default class MailDev {
/**
* Constructor.
*
* @public
* @param {MailDevOptions} options The options.
*/
public constructor(options: MailDevOptions);
/**
* Deletes a given email by identifier.
*
* @public
* @param {string} id The email identifier.
* @param {Function} callback The error callback.
*/
public deleteEmail(id: string, callback?: (error: Error) => void): void;
/**
* Deletes all email and their attachments.
*
* @public
* @param {Function} callback The error callback.
*/
public deleteAllEmail(callback?: (error: Error) => void): void;
/**
* Stops the SMTP server.
*
* @public
* @param {Function} callback The error callback.
*/
public close(callback?: (error: Error) => void): void;
/**
* Accepts e-mail identifier, returns email object.
*
* @public
* @param {string} id The e-mail identifier.
* @param {Function} callback The error callback.
*/
public getEmail(id: string, callback?: (error: Error) => void): void;
/**
* Returns a readable stream of the raw e-mail.
*
* @public
* @param {string} id The e-mail identifier.
*/
public getRawEmail(id: string, callback?: (error: Error, readStream: fs.ReadStream) => void): void;
/**
* Returns array of all e-mail.
* @public
*/
public getAllEmail(done: (error: Error, emails: Array<Record<string, unknown>>) => void): void;
/**
* Starts the SMTP server.
*
* @public
* @param {Function} callback The error callback.
*/
public listen(callback?: (error: Error) => void): void;
/**
* Event called when a new e-mail is received. Callback receives single mail object.
*
* @public
* @param {string} eventName The event name.
* @param {Function} email The email.
*/
public on(eventName: string, callback: (email: Record<string, unknown>) => void): void;
/**
* Relay the e-mail.
*
* @param {string} idOrMailObject The identifier or mail object.
* @param {Function} done The callback.
*/
public relayMail(idOrMailObject: string, done: (error: Error) => void): void;
}
/**
* Interface for {@link MailDev} options.
*/
export interface MailDevOptions {
/**
* IP Address to bind SMTP service to', '0.0.0.0'
*
* @type {string}
*/
ip?: string;
/**
* SMTP host for outgoing emails
*
* @type {string}
*/
outgoingHost?: string;
/**
* SMTP password for outgoing emails
*
* @type {string}
*/
outgoingPass?: string;
/**
* SMTP port for outgoing emails.
*
* @type {number}
*/
outgoingPort?: number;
/**
* SMTP user for outgoing emails
*
* @type {string}
*/
outgoingUser?: string;
/**
* Use SMTP SSL for outgoing emails
*
* @type {boolean}
*/
outgoingSecure?: boolean;
/**
* SMTP port to catch emails.
*
* @type {number}
*/
smtp?: number;
/**
* Port to use for web UI
*
* @type {number}
*/
web?: number;
/**
* IP Address to bind HTTP service to
*
* @type {string}
*/
webIp?: string;
/**
* Do not start web UI
*
* @type {boolean}
*/
disableWeb?: boolean;
/**
* Do not output console.log messages
*
* @type {boolean}
*/
silent?: boolean;
/**
* HTTP user for GUI
*
* @type {string}
*/
webUser?: string;
/**
* HTTP password for GUI
*
* @type {string}
*/
webPass?: string;
/**
* Open the Web GUI after startup
*
* @type {boolean}
*/
open?: boolean;
}
/**
* Interface for mail.
*/
export interface Mail {
/**
* Identifier.
*/
id?: string;
/**
* Client.
*/
envelope?: Record<string, unknown>;
}
}

Some files were not shown because too many files have changed in this diff Show More