// swaf/src/auth/AuthGuard.ts
// 2021-04-22 18:01:13 +02:00
//
// 196 lines
// 6.6 KiB
// TypeScript

import AuthProof from "./AuthProof";
import MysqlConnectionManager from "../db/MysqlConnectionManager";
import User from "./models/User";
import {Connection} from "mysql";
import {Request, Response} from "express";
import {PENDING_ACCOUNT_REVIEW_MAIL_TEMPLATE} from "../Mails";
import Mail from "../mail/Mail";
import Controller from "../Controller";
import config from "config";
import Application from "../Application";
import NunjucksComponent from "../components/NunjucksComponent";
import AuthMethod from "./AuthMethod";
import {Session, SessionData} from "express-session";
import UserNameComponent from "./models/UserNameComponent";
/**
 * Holds the configured auth methods, collects and validates the auth proofs attached to a session or a
 * request, and logs a user in (registering them first when the proof does not resolve to an existing user).
 */
export default class AuthGuard {
    private readonly authMethods: AuthMethod<AuthProof<User>>[];

    /**
     * @param app Application instance; used to reach the Nunjucks environment when sending the account
     *            review mail.
     * @param authMethods Auth methods, in the order they will be queried. The first one is the one returned
     *                    by {@link getRegistrationMethod}.
     */
    public constructor(
        private readonly app: Application,
        ...authMethods: AuthMethod<AuthProof<User>>[]
    ) {
        this.authMethods = authMethods;
    }

    /**
     * Gives every auth method that implements `interruptAuth` a chance to interrupt the current flow.
     *
     * @returns `true` as soon as one method's `interruptAuth` resolves to a truthy value (later methods are
     *          not consulted), `false` when none does.
     */
    public async interruptAuth(req: Request, res: Response): Promise<boolean> {
        for (const method of this.authMethods) {
            if (method.interruptAuth && await method.interruptAuth(req, res)) return true;
        }
        return false;
    }

    /**
     * @returns The first auth method whose `getName()` equals `authMethodName`, or `null` when there is none.
     */
    public getAuthMethodByName(authMethodName: string): AuthMethod<AuthProof<User>> | null {
        return this.authMethods.find(method => method.getName() === authMethodName) ?? null;
    }

    /**
     * @returns The names of all configured auth methods, in configuration order.
     */
    public getAuthMethodNames(): string[] {
        return this.authMethods.map(method => method.getName());
    }

    /**
     * @returns The first configured auth method.
     *          NOTE(review): this is `undefined` at runtime when the guard was built without any auth
     *          method, despite the declared return type.
     */
    public getRegistrationMethod(): AuthMethod<AuthProof<User>> {
        return this.authMethods[0];
    }

    /**
     * Queries every auth method, one after the other, for a user matching `identifier`.
     *
     * @returns One `{user, method}` pair per auth method that found a user, in configuration order.
     */
    public async getAuthMethodsByIdentifier(
        identifier: string,
    ): Promise<{ user: User, method: AuthMethod<AuthProof<User>> }[]> {
        const matches: { user: User, method: AuthMethod<AuthProof<User>> }[] = [];
        for (const method of this.authMethods) {
            const user = await method.findUserByIdentifier(identifier);
            if (user) matches.push({user, method});
        }
        return matches;
    }

    /**
     * Collects the usable proofs of a request: session proofs first (only when `req.getSessionOptional()`
     * is truthy), then proofs carried by the request itself.
     */
    public async getProofs(req: Request): Promise<AuthProof<User>[]> {
        const proofs: AuthProof<User>[] = [];
        if (req.getSessionOptional()) {
            proofs.push(...await this.getProofsForSession(req.session));
        }
        proofs.push(...await this.getProofsForRequest(req));
        return proofs;
    }

    /**
     * Collects the usable proofs attached to `session`.
     *
     * Returns an empty list without querying any auth method when the session is not flagged as
     * authenticated. When the session is flagged but no usable proof remains, the `isAuthenticated` and
     * `persistent` flags of the session are reset to `false`.
     */
    public async getProofsForSession(session: Session & Partial<SessionData>): Promise<AuthProof<User>[]> {
        if (!session.isAuthenticated) return [];

        const proofs: AuthProof<User>[] = [];
        for (const method of this.authMethods) {
            if (!method.getProofsForSession) continue;
            proofs.push(...await this.collectUsableProofs(await method.getProofsForSession(session)));
        }

        if (proofs.length === 0) {
            session.isAuthenticated = false;
            session.persistent = false;
        }
        return proofs;
    }

    /**
     * Collects the usable proofs carried by the request itself, from every auth method that implements
     * `getProofsForRequest`.
     */
    public async getProofsForRequest(req: Request): Promise<AuthProof<User>[]> {
        const proofs: AuthProof<User>[] = [];
        for (const method of this.authMethods) {
            if (!method.getProofsForRequest) continue;
            proofs.push(...await this.collectUsableProofs(await method.getProofsForRequest(req)));
        }
        return proofs;
    }

    /**
     * Logs in the user designated by `proof`, registering a new user first when the proof has no resource.
     *
     * @param session Session to flag as authenticated on success.
     * @param proof Proof to authenticate with.
     * @param persistSession Value stored in `session.persistent` on success.
     * @param onLogin Awaited with the user right after the session has been flagged.
     * @param beforeRegister Awaited inside the registration transaction, before the new user is saved.
     * @param afterRegister Awaited inside the registration transaction, after the new user is saved.
     * @returns The logged in user.
     * @throws InvalidAuthProofError When `proof.isValid()` resolves to a falsy value.
     * @throws UnauthorizedAuthProofError When `proof.isAuthorized()` resolves to a falsy value.
     * @throws PendingApprovalAuthError When the (possibly just registered) user is not approved.
     */
    public async authenticateOrRegister(
        session: Session & Partial<SessionData>,
        proof: AuthProof<User>,
        persistSession: boolean,
        onLogin?: (user: User) => Promise<void>,
        beforeRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]>,
        afterRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]>,
    ): Promise<User> {
        if (!await proof.isValid()) throw new InvalidAuthProofError();
        if (!await proof.isAuthorized()) throw new UnauthorizedAuthProofError();

        let user = await proof.getResource();

        // Revoke proof early if user is not approved (or is about to be registered in approval mode)
        if (user && !user.isApproved() || !user && User.isApprovalMode()) {
            await proof.revoke();
        }

        // Register if user doesn't exist
        if (!user) {
            user = await this.registerNewUser(beforeRegister, afterRegister);
        }

        // Don't login if user isn't approved
        if (!user.isApproved()) {
            throw new PendingApprovalAuthError();
        }

        // Login
        session.isAuthenticated = true;
        session.persistent = persistSession;
        if (onLogin) await onLogin(user);
        return user;
    }

    /**
     * Keeps the proofs that are both valid and authorized, and revokes every other one.
     * Proofs are checked sequentially; `isAuthorized()` is not called on a proof that is not valid.
     */
    private async collectUsableProofs(candidates: Iterable<AuthProof<User>>): Promise<AuthProof<User>[]> {
        const usable: AuthProof<User>[] = [];
        for (const proof of candidates) {
            if (await proof.isValid() && await proof.isAuthorized()) {
                usable.push(proof);
            } else {
                await proof.revoke();
            }
        }
        return usable;
    }

    /**
     * Creates and saves a new user inside a transaction, then runs the collected register callbacks and,
     * in approval mode, sends the pending account review mail to the configured contact address.
     */
    private async registerNewUser(
        beforeRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]>,
        afterRegister?: (connection: Connection, user: User) => Promise<RegisterCallback[]>,
    ): Promise<User> {
        const callbacks: RegisterCallback[] = [];

        const user = await MysqlConnectionManager.wrapTransaction(async connection => {
            const newUser = User.create({});
            if (beforeRegister) {
                (await beforeRegister(connection, newUser)).forEach(c => callbacks.push(c));
            }
            await newUser.save(connection, c => callbacks.push(c));
            if (afterRegister) {
                (await afterRegister(connection, newUser)).forEach(c => callbacks.push(c));
            }
            return newUser;
        });

        // Callbacks only run once wrapTransaction has resolved, one after the other.
        for (const callback of callbacks) {
            await callback();
        }

        if (User.isApprovalMode()) {
            await new Mail(this.app.as(NunjucksComponent).getEnvironment(), PENDING_ACCOUNT_REVIEW_MAIL_TEMPLATE, {
                username: user.asOptional(UserNameComponent)?.getName() ||
                    (await user.mainEmail.get())?.getOrFail('email') ||
                    'Could not find an identifier',
                link: config.get<string>('public_url') + Controller.route('accounts-approval'),
            }).send(config.get<string>('app.contact_email'));
        }

        return user;
    }
}
/**
 * Base class of the authentication errors.
 *
 * `name` is set to the name of the class being instantiated, so that `err.name` and `String(err)`
 * identify the concrete error instead of the generic `Error`.
 */
export class AuthError extends Error {
    /**
     * @param message Description of the error, forwarded to `Error`.
     */
    public constructor(message?: string) {
        super(message);
        // Without this, instances inherit `name === 'Error'` from Error.prototype. `new.target` is the
        // constructor that was actually invoked with `new`, so subclasses get their own name as well.
        this.name = new.target.name;
    }
}
/**
 * Base class of the errors caused by the auth proof itself rather than by the state of the user.
 */
export class AuthProofError extends AuthError {}
/**
 * Thrown by `AuthGuard.authenticateOrRegister` when the proof's `isValid()` check fails.
 */
export class InvalidAuthProofError extends AuthProofError {
    public constructor() {
        super('Invalid auth proof.');
    }
}
/**
 * Thrown by `AuthGuard.authenticateOrRegister` when the proof is valid but its `isAuthorized()` check fails.
 */
export class UnauthorizedAuthProofError extends AuthProofError {
    public constructor() {
        super('Unauthorized auth proof.');
    }
}
/**
 * Thrown by `AuthGuard.authenticateOrRegister` when the user exists (or was just registered) but
 * `isApproved()` reports that they are not approved; the session is left unauthenticated in that case.
 */
export class PendingApprovalAuthError extends AuthError {
    public constructor() {
        super(`User is not approved.`);
    }
}
/**
 * Deferred step of a user registration.
 *
 * `AuthGuard.authenticateOrRegister` collects these from `beforeRegister`, from `user.save()` and from
 * `afterRegister` while inside the registration transaction, then awaits them one by one once
 * `wrapTransaction` has resolved.
 */
export type RegisterCallback = () => Promise<void>;