From 79a0d5289cb9e15a0a34df646624e802e89dd20f Mon Sep 17 00:00:00 2001 From: Ajay Singh Date: Wed, 30 Oct 2024 21:01:13 +0000 Subject: [PATCH] Fix: Disable refresh token for inactive user. (#814) * Fix: Disable refresh token for inactive user. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add check in serailizer instead of blacklist mixin. * Update tests. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unused import. * Correct error message * Add test for deleted users. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- rest_framework_simplejwt/serializers.py | 18 ++++++++- tests/test_serializers.py | 49 +++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/rest_framework_simplejwt/serializers.py b/rest_framework_simplejwt/serializers.py index 3dc318687..51138f8bf 100644 --- a/rest_framework_simplejwt/serializers.py +++ b/rest_framework_simplejwt/serializers.py @@ -5,7 +5,7 @@ from django.contrib.auth.models import AbstractBaseUser, update_last_login from django.utils.translation import gettext_lazy as _ from rest_framework import exceptions, serializers -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import AuthenticationFailed, ValidationError from .models import TokenUser from .settings import api_settings @@ -104,9 +104,25 @@ class TokenRefreshSerializer(serializers.Serializer): access = serializers.CharField(read_only=True) token_class = RefreshToken + default_error_messages = { + "no_active_account": _("No active account found for the given token.") + } + def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]: refresh = self.token_class(attrs["refresh"]) + user_id = refresh.payload.get(api_settings.USER_ID_CLAIM, None) + if user_id and ( + user := get_user_model().objects.get( + **{api_settings.USER_ID_FIELD: user_id} + ) + ): + if not api_settings.USER_AUTHENTICATION_RULE(user): + raise AuthenticationFailed( + self.error_messages["no_active_account"], + "no_active_account", + ) + data = {"access": str(refresh.access_token)} if api_settings.ROTATE_REFRESH_TOKENS: diff --git a/tests/test_serializers.py b/tests/test_serializers.py index 322d1cd9d..6ec6a8083 100644 --- a/tests/test_serializers.py +++ b/tests/test_serializers.py @@ -4,6 +4,7 @@ from django.conf import settings from django.contrib.auth import get_user_model +from django.core import exceptions as django_exceptions from django.test import TestCase from rest_framework import exceptions as drf_exceptions @@ -247,6 +248,54 @@ def test_it_should_update_token_exp_claim_if_everything_ok(self): class TestTokenRefreshSerializer(TestCase): + def setUp(self): + self.username = "test_user" + self.password = "test_password" + + self.user = User.objects.create_user( + username=self.username, + password=self.password, + ) + + def test_it_should_raise_error_for_deleted_users(self): + refresh = RefreshToken.for_user(self.user) + self.user.delete() + + s = TokenRefreshSerializer(data={"refresh": str(refresh)}) + + with self.assertRaises(django_exceptions.ObjectDoesNotExist) as e: + s.is_valid() + + self.assertIn("does not exist", str(e.exception)) + + def test_it_should_raise_error_for_inactive_users(self): + refresh = RefreshToken.for_user(self.user) + self.user.is_active = False + self.user.save() + + s = TokenRefreshSerializer(data={"refresh": str(refresh)}) + + with self.assertRaises(drf_exceptions.AuthenticationFailed) as e: + s.is_valid() + + self.assertIn("No active account", e.exception.args[0]) + + def test_it_should_return_access_token_for_active_users(self): + refresh = RefreshToken.for_user(self.user) + + s = TokenRefreshSerializer(data={"refresh": str(refresh)}) + + now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2 + with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow: + fake_aware_utcnow.return_value = now + s.is_valid() + + access = AccessToken(s.validated_data["access"]) + + self.assertEqual( + access["exp"], datetime_to_epoch(now + api_settings.ACCESS_TOKEN_LIFETIME) + ) + def test_it_should_raise_token_error_if_token_invalid(self): token = RefreshToken() del token["exp"]