From 906fa572163a94fccaaf451577bf8dbd32c1af44 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Mon, 24 Apr 2023 13:05:52 +0300 Subject: [PATCH] Add JWT support for user-interactive flows --- synapse/api/constants.py | 1 + synapse/handlers/auth.py | 4 ++ synapse/handlers/ui_auth/checkers.py | 83 ++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 5a943b02f34a..85811ad30cca 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -84,6 +84,7 @@ class LoginType: SSO: Final = "m.login.sso" DUMMY: Final = "m.login.dummy" REGISTRATION_TOKEN: Final = "m.login.registration_token" + JWT: Final = "org.matrix.login.jwt" # This is used in the `type` parameter for /register when called by diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 308e38edea2d..c11c717c2633 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -413,6 +413,10 @@ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: ): ui_auth_types.add(LoginType.SSO) + # if JWT is enabled, allow user to re-authenticate with one + if self.hs.config.jwt.jwt_enabled: + ui_auth_types.add(LoginType.JWT) + return ui_auth_types def get_enabled_auth_types(self) -> Iterable[str]: diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 78a75bfed6ee..c8e10b8090bc 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -20,6 +20,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import Codes, LoginError, SynapseError +from synapse.types import UserID from synapse.util import json_decoder if TYPE_CHECKING: @@ -314,6 +315,87 @@ async def check_auth(self, authdict: dict, clientip: str) -> Any: ) +class JwtAuthChecker(UserInteractiveAuthChecker): + AUTH_TYPE = LoginType.JWT + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self.hs = hs + + def is_enabled(self) -> bool: + return bool(self.hs.config.jwt.jwt_enabled) + + async def check_auth(self, authdict: dict, clientip: str) -> Any: + token = authdict.get("token", None) + if token is None: + raise LoginError( + 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN + ) + + from authlib.jose import JsonWebToken, JWTClaims + from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError + + jwt = JsonWebToken([self.hs.config.jwt.jwt_algorithm]) + claim_options = {} + if self.hs.config.jwt.jwt_issuer is not None: + claim_options["iss"] = { + "value": self.hs.config.jwt.jwt_issuer, + "essential": True, + } + if self.hs.config.jwt.jwt_audiences is not None: + claim_options["aud"] = { + "values": self.hs.config.jwt.jwt_audiences, + "essential": True, + } + + try: + claims = jwt.decode( + token, + key=self.hs.config.jwt.jwt_secret, + claims_cls=JWTClaims, + claims_options=claim_options, + ) + except BadSignatureError: + # We handle this case separately to provide a better error message + raise LoginError( + 403, + "JWT validation failed: Signature verification failed", + errcode=Codes.FORBIDDEN, + ) + except JoseError as e: + # A JWT error occurred, return some info back to the client. + raise LoginError( + 403, + "JWT validation failed: %s" % (str(e),), + errcode=Codes.FORBIDDEN, + ) + + try: + claims.validate(leeway=120) # allows 2 min of clock skew + + # Enforce the old behavior which is rolled out in productive + # servers: if the JWT contains an 'aud' claim but none is + # configured, the login attempt will fail + if claims.get("aud") is not None: + if ( + self.hs.config.jwt.jwt_audiences is None + or len(self.hs.config.jwt.jwt_audiences) == 0 + ): + raise InvalidClaimError("aud") + except JoseError as e: + raise LoginError( + 403, + "JWT validation failed: %s" % (str(e),), + errcode=Codes.FORBIDDEN, + ) + + user = claims.get(self.hs.config.jwt.jwt_subject_claim, None) + if user is None: + raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) + + return UserID(user, self.hs.hostname).to_string() + + INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [ DummyAuthChecker, TermsAuthChecker, @@ -321,5 +403,6 @@ async def check_auth(self, authdict: dict, clientip: str) -> Any: EmailIdentityAuthChecker, MsisdnAuthChecker, RegistrationTokenAuthChecker, + JwtAuthChecker, ] """A list of UserInteractiveAuthChecker classes"""