import AuthProof from "./AuthProof";
import MysqlConnectionManager from "../db/MysqlConnectionManager";
import User from "./models/User";
import {Connection} from "mysql";
import {Request, Response} from "express";
import {PENDING_ACCOUNT_REVIEW_MAIL_TEMPLATE} from "../Mails";
import Mail from "../mail/Mail";
import Controller from "../Controller";
import config from "config";
import Application from "../Application";
import NunjucksComponent from "../components/NunjucksComponent";
import AuthMethod from "./AuthMethod";
import {Session, SessionData} from "express-session";
import UserNameComponent from "./models/UserNameComponent";

/**
 * Coordinates a list of {@link AuthMethod}s: collects and validates auth proofs from sessions and
 * requests, and logs users in (registering them first when they don't exist yet).
 *
 * The order of the methods passed to the constructor matters: the first one is the registration
 * method (see {@link AuthGuard.getRegistrationMethod}).
 */
export default class AuthGuard {
    private readonly authMethods: AuthMethod<AuthProof<User>>[];

    public constructor(
        private readonly app: Application,
        ...authMethods: AuthMethod<AuthProof<User>>[]
    ) {
        this.authMethods = authMethods;
    }

    /**
     * Gives each auth method, in order, a chance to interrupt the authentication flow through its
     * optional `interruptAuth` hook.
     *
     * @returns `true` as soon as one method reports an interruption, `false` if none did.
     */
    public async interruptAuth(req: Request, res: Response): Promise<boolean> {
        for (const method of this.authMethods) {
            if (method.interruptAuth && await method.interruptAuth(req, res)) return true;
        }
        return false;
    }

    /**
     * @returns the first auth method whose `getName()` equals `authMethodName`, or `null` if none matches.
     */
    public getAuthMethodByName(authMethodName: string): AuthMethod<AuthProof<User>> | null {
        return this.authMethods.find(m => m.getName() === authMethodName) ?? null;
    }

    /**
     * @returns the names of all auth methods, in constructor order.
     */
    public getAuthMethodNames(): string[] {
        return this.authMethods.map(m => m.getName());
    }

    /**
     * @returns the auth method used to register new users: the first method given to the constructor.
     * @throws Error if this guard was constructed without any auth method.
     */
    public getRegistrationMethod(): AuthMethod<AuthProof<User>> {
        const method = this.authMethods[0];
        // Fail loudly instead of handing callers `undefined` behind a non-null return type.
        if (!method) throw new Error('AuthGuard has no auth method to register users with.');
        return method;
    }

    /**
     * Asks every auth method (sequentially, in constructor order) for a user matching `identifier`.
     *
     * @returns one `{user, method}` pair per method that found a user.
     */
    public async getAuthMethodsByIdentifier(
        identifier: string,
    ): Promise<{ user: User, method: AuthMethod<AuthProof<User>> }[]> {
        const methods: { user: User, method: AuthMethod<AuthProof<User>> }[] = [];
        for (const method of this.authMethods) {
            const user = await method.findUserByIdentifier(identifier);
            if (user) methods.push({user, method});
        }
        return methods;
    }

    /**
     * Collects the valid proofs attached to the session (only when the request has one), followed by
     * the valid proofs carried by the request itself.
     */
    public async getProofs(req: Request): Promise<AuthProof<User>[]> {
        const proofs: AuthProof<User>[] = [];
        if (req.getSessionOptional()) {
            proofs.push(...await this.getProofsForSession(req.session));
        }
        proofs.push(...await this.getProofsForRequest(req));
        return proofs;
    }

    /**
     * Collects the valid and authorized proofs every auth method finds in `session`; the others are revoked.
     *
     * Returns `[]` immediately for an unauthenticated session. If an authenticated session yields no
     * usable proof, it is marked as neither authenticated nor persistent.
     */
    public async getProofsForSession(session: Session & Partial<SessionData>): Promise<AuthProof<User>[]> {
        if (!session.isAuthenticated) return [];

        const proofs: AuthProof<User>[] = [];
        for (const method of this.authMethods) {
            if (method.getProofsForSession) {
                proofs.push(...await this.keepValidProofs(await method.getProofsForSession(session)));
            }
        }

        if (proofs.length === 0) {
            session.isAuthenticated = false;
            session.persistent = false;
        }

        return proofs;
    }

    /**
     * Collects the valid and authorized proofs every auth method finds in `req`; the others are revoked.
     */
    public async getProofsForRequest(req: Request): Promise<AuthProof<User>[]> {
        const proofs: AuthProof<User>[] = [];
        for (const method of this.authMethods) {
            if (method.getProofsForRequest) {
                proofs.push(...await this.keepValidProofs(await method.getProofsForRequest(req)));
            }
        }
        return proofs;
    }

    /**
     * Validates an already-authenticated proof's resource and logs it in, registering a new user first
     * when the proof doesn't resolve to one.
     *
     * @param session session to mark as authenticated on success.
     * @param proof proof to check; it is revoked early when the login will not go through for lack of approval.
     * @param persistSession value stored in `session.persistent` on successful login.
     * @param onLogin called with the user after the session has been marked as authenticated.
     * @param beforeRegister called inside the registration transaction before the new user is saved; the
     *                       returned callbacks are run after the transaction has resolved.
     * @param afterRegister called inside the registration transaction after the new user is saved; the
     *                      returned callbacks are run after the transaction has resolved.
     * @returns the logged-in user.
     * @throws InvalidAuthProofError if the proof is not valid.
     * @throws UnauthorizedAuthProofError if the proof is not authorized.
     * @throws PendingApprovalAuthError if the (possibly just registered) user is not approved.
     */
    public async authenticateOrRegister(
        session: Session & Partial<SessionData>,
        proof: AuthProof<User>,
        persistSession: boolean,
        onLogin?: (user: User) => Promise<void>,
        beforeRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]>,
        afterRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]>,
    ): Promise<User> {
        if (!await proof.isValid()) throw new InvalidAuthProofError();
        if (!await proof.isAuthorized()) throw new UnauthorizedAuthProofError();

        let user = await proof.getResource();

        // Revoke the proof early if the existing user is not approved, or if a new user is about to be
        // registered while approval mode is on.
        if ((user && !user.isApproved()) || (!user && User.isApprovalMode())) {
            await proof.revoke();
        }

        // Register if the user doesn't exist yet
        if (!user) {
            user = await this.registerUser(beforeRegister, afterRegister);
        }

        // Don't log in a user that isn't approved
        if (!user.isApproved()) {
            throw new PendingApprovalAuthError();
        }

        // Log in
        session.isAuthenticated = true;
        session.persistent = persistSession;
        if (onLogin) await onLogin(user);

        return user;
    }

    /**
     * Keeps the proofs that are both valid and authorized (in order) and revokes every other one.
     * Checks and revocations run one proof at a time, in order.
     */
    private async keepValidProofs(candidates: AuthProof<User>[]): Promise<AuthProof<User>[]> {
        const kept: AuthProof<User>[] = [];
        for (const proof of candidates) {
            // Same short-circuit as before: isAuthorized() is only consulted for valid proofs.
            if (await proof.isValid() && await proof.isAuthorized()) {
                kept.push(proof);
            } else {
                await proof.revoke();
            }
        }
        return kept;
    }

    /**
     * Creates and saves a new user inside one transaction, runs the callbacks collected during that
     * transaction once it has resolved, and, in approval mode, emails the contact address about it.
     */
    private async registerUser(
        beforeRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]>,
        afterRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]>,
    ): Promise<User> {
        const callbacks: RegisterCallback[] = [];

        const user = await MysqlConnectionManager.wrapTransaction(async connection => {
            const newUser = User.create({});
            if (beforeRegister) callbacks.push(...await beforeRegister(connection, newUser));
            await newUser.save(connection, c => callbacks.push(c));
            if (afterRegister) callbacks.push(...await afterRegister(connection, newUser));
            return newUser;
        });

        // Only reached once wrapTransaction has resolved; callbacks run sequentially, in collection order.
        for (const callback of callbacks) {
            await callback();
        }

        if (User.isApprovalMode()) {
            await this.sendPendingApprovalMail(user);
        }

        return user;
    }

    /**
     * Emails the configured `app.contact_email` address that `user` is waiting for account review,
     * with a link to the accounts approval page.
     */
    private async sendPendingApprovalMail(user: User): Promise<void> {
        // Main email is only looked up when the user has no (non-empty) name.
        const username = user.asOptional(UserNameComponent)?.getName()
            || (await user.mainEmail.get())?.getOrFail('email')
            || 'Could not find an identifier';

        await new Mail(this.app.as(NunjucksComponent).getEnvironment(), PENDING_ACCOUNT_REVIEW_MAIL_TEMPLATE, {
            username,
            link: config.get<string>('public_url') + Controller.route('accounts-approval'),
        }).send(config.get<string>('app.contact_email'));
    }
}

/** Base class of every error thrown by the authentication flow. */
export class AuthError extends Error {
    public constructor(message?: string) {
        super(message);
        // Keep `instanceof` working when compiled to ES5, where extending Error breaks the prototype chain.
        Object.setPrototypeOf(this, new.target.prototype);
        // Report the concrete subclass (e.g. "InvalidAuthProofError") in stack traces and logs.
        this.name = new.target.name;
    }
}

/** Base class of errors caused by the auth proof itself. */
export class AuthProofError extends AuthError {
}

/** Thrown when an auth proof fails its validity check. */
export class InvalidAuthProofError extends AuthProofError {
    public constructor() {
        super('Invalid auth proof.');
    }
}

/** Thrown when a valid auth proof is not authorized. */
export class UnauthorizedAuthProofError extends AuthProofError {
    public constructor() {
        super('Unauthorized auth proof.');
    }
}

/** Thrown when the authenticated user has not been approved yet. */
export class PendingApprovalAuthError extends AuthError {
    public constructor() {
        super(`User is not approved.`);
    }
}

/** Deferred action collected during registration and run after the registration transaction has resolved. */
export type RegisterCallback = () => Promise<void>;