diff --git a/src/auth/AuthController.ts b/src/auth/AuthController.ts index 6cad640..113a646 100644 --- a/src/auth/AuthController.ts +++ b/src/auth/AuthController.ts @@ -7,6 +7,8 @@ import User from "./models/User"; import UserPasswordComponent from "./password/UserPasswordComponent"; import UserNameComponent from "./models/UserNameComponent"; import {UnknownRelationValidationError} from "../db/Validator"; +import AuthMethod from "./AuthMethod"; +import AuthProof from "./AuthProof"; export default class AuthController extends Controller { public getRoutesPrefix(): string { @@ -83,8 +85,23 @@ export default class AuthController extends Controller { // Redirect to registration if user not found if (methods.length === 0) return await this.redirectToRegistration(req, res, identifier); + // Choose best matching method + let user: User | null = null; + let method: AuthMethod> | null = null; + let weight = -1; + + for (const entry of methods) { + const methodWeight = entry.method.getWeightForRequest(req); + if (methodWeight > weight) { + user = entry.user; + method = entry.method; + weight = methodWeight; + } + } + + if (!method || !user) ({method, user} = methods[0]); // Default to first method + // Login - const {user, method} = methods[0]; return await method.attemptLogin(req, res, user); } diff --git a/src/auth/AuthMethod.ts b/src/auth/AuthMethod.ts index cc25101..a283f5b 100644 --- a/src/auth/AuthMethod.ts +++ b/src/auth/AuthMethod.ts @@ -9,6 +9,14 @@ export default interface AuthMethod

> { */ getName(): string; + /** + * Used for automatic auth method detection. Won't affect forced auth method. + * + * @return {@code 0} if the request is not conform to this auth method, otherwise the exact count of matching + * fields. + */ + getWeightForRequest(req: Request): number; + findUserByIdentifier(identifier: string): Promise; getProofsForSession?(session: Express.Session): Promise; diff --git a/src/auth/magic_link/MagicLinkAuthMethod.ts b/src/auth/magic_link/MagicLinkAuthMethod.ts index d321c7a..7580abb 100644 --- a/src/auth/magic_link/MagicLinkAuthMethod.ts +++ b/src/auth/magic_link/MagicLinkAuthMethod.ts @@ -11,7 +11,7 @@ import RedirectBackComponent from "../../components/RedirectBackComponent"; import Application from "../../Application"; import {MailTemplate} from "../../mail/Mail"; import AuthMagicLinkActionType from "./AuthMagicLinkActionType"; -import Validator from "../../db/Validator"; +import Validator, {EMAIL_REGEX} from "../../db/Validator"; import ModelFactory from "../../db/ModelFactory"; import UserNameComponent from "../models/UserNameComponent"; @@ -26,6 +26,12 @@ export default class MagicLinkAuthMethod implements AuthMethod { return 'magic_link'; } + public getWeightForRequest(req: Request): number { + return !req.body.identifier || !EMAIL_REGEX.test(req.body.identifier) ? + 0 : + 1; + } + public async findUserByIdentifier(identifier: string): Promise { return (await UserEmail.select() .with('user.mainEmail') diff --git a/src/auth/password/PasswordAuthMethod.ts b/src/auth/password/PasswordAuthMethod.ts index 03ff0cf..78c55fb 100644 --- a/src/auth/password/PasswordAuthMethod.ts +++ b/src/auth/password/PasswordAuthMethod.ts @@ -24,6 +24,12 @@ export default class PasswordAuthMethod implements AuthMethod return 'password'; } + public getWeightForRequest(req: Request): number { + return !req.body.identifier || !req.body.password || req.body.password.length === 0 ? + 0 : + 2; + } + public async findUserByIdentifier(identifier: string): Promise { const query = UserEmail.select() .with('user')