swaf/src/components/CsrfProtectionComponent.ts

71 lines
2.5 KiB
TypeScript

import ApplicationComponent from "../ApplicationComponent";
import {Request, Router} from "express";
import crypto from "crypto";
import {BadRequestError} from "../HttpError";
import {AuthMiddleware} from "../auth/AuthComponent";
import {Session, SessionData} from "express-session";
/**
 * Synchronizer-token CSRF protection for state-changing requests.
 *
 * For every request not matched by a registered excluder, this component exposes
 * `getCsrfToken()` to views through `res.locals`. A request must submit a `csrf` field in
 * its body equal to the token stored in its session when both of these hold:
 * - its method is not GET, HEAD or OPTIONS;
 * - `getProofsForRequest` returns no auth proofs for it.
 *
 * Otherwise an InvalidCsrfTokenError is forwarded to `next`.
 */
export default class CsrfProtectionComponent extends ApplicationComponent {
    /** Predicates registered via addExcluder; a request matching any of them bypasses this component entirely. */
    private static readonly excluders: ((req: Request) => boolean)[] = [];

    /** Methods that never require a token. Hoisted so it is not rebuilt on every request. */
    private static readonly safeMethods: ReadonlySet<string> = new Set(['GET', 'HEAD', 'OPTIONS']);

    /**
     * Returns the session's CSRF token.
     *
     * If the session has no string token yet, one is generated (64 random bytes, base64-encoded)
     * and stored first.
     *
     * @param session the session to read the token from and, if needed, store it into.
     * @returns the session's CSRF token.
     */
    public static getCsrfToken(session: Session & Partial<SessionData>): string {
        if (typeof session.csrf !== 'string') {
            session.csrf = crypto.randomBytes(64).toString('base64');
        }
        return session.csrf;
    }

    /**
     * Registers a predicate. Requests for which it returns true skip this component completely:
     * no token check, and no `getCsrfToken` in `res.locals`.
     *
     * @param excluder predicate evaluated against each incoming request.
     */
    public static addExcluder(excluder: (req: Request) => boolean): void {
        this.excluders.push(excluder);
    }

    /**
     * Installs the CSRF middleware on the given router.
     *
     * @param router router the middleware is mounted on.
     */
    public async handle(router: Router): Promise<void> {
        router.use(async (req, res, next) => {
            let failure: unknown;
            try {
                failure = await CsrfProtectionComponent.check(req, res.locals);
            } catch (e) {
                // Forward every failure, including one thrown by getSession() or an excluder.
                // Otherwise it would become a rejected promise that Express 4 never observes,
                // leaving the request hanging.
                failure = e;
            }
            // next(undefined) continues the chain normally.
            next(failure);
        });
    }

    /**
     * Performs the CSRF check for one request.
     *
     * @param req the incoming request.
     * @param locals the response locals; receives `getCsrfToken` unless the request is excluded.
     * @returns `undefined` if the request may proceed, otherwise the error to forward to `next`.
     */
    private static async check(req: Request, locals: Record<string, unknown>): Promise<InvalidCsrfTokenError | undefined> {
        if (CsrfProtectionComponent.excluders.some(excluder => excluder(req))) return undefined;

        const session = req.getSession();
        locals.getCsrfToken = () => CsrfProtectionComponent.getCsrfToken(session);

        if (CsrfProtectionComponent.safeMethods.has(req.method)) return undefined;

        // NOTE(review): requests carrying any auth proof skip the token check entirely. Confirm that
        // getProofsForRequest never returns session/cookie-based proofs. If it does, cookie-authenticated
        // users (the ones CSRF targets) are left unprotected.
        const proofs = await req.as(AuthMiddleware).getAuthGuard().getProofsForRequest(req);
        if (proofs.length > 0) return undefined;

        const expected = session.csrf;
        if (typeof expected !== 'string') {
            return new InvalidCsrfTokenError(req.baseUrl, `You weren't assigned any CSRF token.`);
        }

        // req.body is undefined when no body parser handled the request. Treat that as
        // "no token submitted" instead of throwing a TypeError.
        const submitted: unknown = req.body ? req.body.csrf : undefined;
        if (submitted === undefined) {
            return new InvalidCsrfTokenError(req.baseUrl, `You didn't provide any CSRF token.`);
        }
        // A non-string value (e.g. an array from repeated form fields) can never match.
        if (typeof submitted !== 'string' || !CsrfProtectionComponent.tokensMatch(expected, submitted)) {
            return new InvalidCsrfTokenError(req.baseUrl, `Tokens don't match.`);
        }
        return undefined;
    }

    /**
     * Compares two tokens in constant time, so response timing does not reveal how much of a guess matched.
     *
     * @param expected the token stored in the session.
     * @param submitted the token supplied by the client.
     * @returns true when both tokens are byte-for-byte identical.
     */
    private static tokensMatch(expected: string, submitted: string): boolean {
        const expectedBytes = Buffer.from(expected, 'utf8');
        const submittedBytes = Buffer.from(submitted, 'utf8');
        // timingSafeEqual throws on differing lengths. The token length is not secret,
        // so checking it first leaks nothing useful.
        return expectedBytes.length === submittedBytes.length
            && crypto.timingSafeEqual(expectedBytes, submittedBytes);
    }
}
/**
 * Error forwarded when a state-changing request fails CSRF validation. The possible reasons are:
 * no token in the session, no token submitted, or the tokens differ.
 *
 * NOTE(review): errorCode reports 401 even though this class extends BadRequestError.
 * CSRF failures are more commonly reported as 403; confirm which status clients expect
 * before changing it.
 */
class InvalidCsrfTokenError extends BadRequestError {
    /**
     * @param url request base URL, passed through to BadRequestError.
     * @param details the specific failure reason, prefixed to a generic retry hint.
     * @param cause optional underlying error, passed through to BadRequestError.
     */
    public constructor(url: string, details: string, cause?: Error) {
        super('Invalid CSRF token', `${details} We can't process this request. Please try again.`, url, cause);
    }

    /** Error code exposed by this error (presumably used as the HTTP status — confirm in BadRequestError). */
    public get errorCode(): number {
        return 401;
    }

    /** Display name of this error. */
    public get name(): string {
        return 'Invalid CSRF Token';
    }
}