Skip to content

Commit

Permalink
Fix: Disable refresh token for inactive user. (#814)
Browse files Browse the repository at this point in the history
* Fix: Disable refresh token for inactive user.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add check in serailizer instead of blacklist mixin.

* Update tests.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unused import.

* Correct error message

* Add test for deleted users.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ajay09 and pre-commit-ci[bot] authored Oct 30, 2024
1 parent d66d246 commit 79a0d52
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
18 changes: 17 additions & 1 deletion rest_framework_simplejwt/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.contrib.auth.models import AbstractBaseUser, update_last_login
from django.utils.translation import gettext_lazy as _
from rest_framework import exceptions, serializers
from rest_framework.exceptions import ValidationError
from rest_framework.exceptions import AuthenticationFailed, ValidationError

from .models import TokenUser
from .settings import api_settings
Expand Down Expand Up @@ -104,9 +104,25 @@ class TokenRefreshSerializer(serializers.Serializer):
access = serializers.CharField(read_only=True)
token_class = RefreshToken

default_error_messages = {
"no_active_account": _("No active account found for the given token.")
}

def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]:
refresh = self.token_class(attrs["refresh"])

user_id = refresh.payload.get(api_settings.USER_ID_CLAIM, None)
if user_id and (
user := get_user_model().objects.get(
**{api_settings.USER_ID_FIELD: user_id}
)
):
if not api_settings.USER_AUTHENTICATION_RULE(user):
raise AuthenticationFailed(
self.error_messages["no_active_account"],
"no_active_account",
)

data = {"access": str(refresh.access_token)}

if api_settings.ROTATE_REFRESH_TOKENS:
Expand Down
49 changes: 49 additions & 0 deletions tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from django.conf import settings
from django.contrib.auth import get_user_model
from django.core import exceptions as django_exceptions
from django.test import TestCase
from rest_framework import exceptions as drf_exceptions

Expand Down Expand Up @@ -247,6 +248,54 @@ def test_it_should_update_token_exp_claim_if_everything_ok(self):


class TestTokenRefreshSerializer(TestCase):
def setUp(self):
self.username = "test_user"
self.password = "test_password"

self.user = User.objects.create_user(
username=self.username,
password=self.password,
)

def test_it_should_raise_error_for_deleted_users(self):
refresh = RefreshToken.for_user(self.user)
self.user.delete()

s = TokenRefreshSerializer(data={"refresh": str(refresh)})

with self.assertRaises(django_exceptions.ObjectDoesNotExist) as e:
s.is_valid()

self.assertIn("does not exist", str(e.exception))

def test_it_should_raise_error_for_inactive_users(self):
refresh = RefreshToken.for_user(self.user)
self.user.is_active = False
self.user.save()

s = TokenRefreshSerializer(data={"refresh": str(refresh)})

with self.assertRaises(drf_exceptions.AuthenticationFailed) as e:
s.is_valid()

self.assertIn("No active account", e.exception.args[0])

def test_it_should_return_access_token_for_active_users(self):
refresh = RefreshToken.for_user(self.user)

s = TokenRefreshSerializer(data={"refresh": str(refresh)})

now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
fake_aware_utcnow.return_value = now
s.is_valid()

access = AccessToken(s.validated_data["access"])

self.assertEqual(
access["exp"], datetime_to_epoch(now + api_settings.ACCESS_TOKEN_LIFETIME)
)

def test_it_should_raise_token_error_if_token_invalid(self):
token = RefreshToken()
del token["exp"]
Expand Down

0 comments on commit 79a0d52

Please sign in to comment.