Skip to content

Commit

Permalink
feat: implement method to return a user's transit agency
Browse files Browse the repository at this point in the history
  • Loading branch information
angela-tran committed Aug 16, 2024
1 parent e00240c commit d4f04fc
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
12 changes: 11 additions & 1 deletion benefits/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from django.conf import settings
from django.core.exceptions import ValidationError
from django.contrib.auth.models import Group
from django.contrib.auth.models import Group, User
from django.db import models
from django.urls import reverse

Expand Down Expand Up @@ -435,3 +435,13 @@ def all_active():
"""Get all TransitAgency instances marked active."""
logger.debug(f"Get all active {TransitAgency.__name__}")
return TransitAgency.objects.filter(active=True)

@staticmethod
def for_user(user: User):
group = user.groups.first()

if group is not None:
# TransitAgency to Group is one-to-one, so there will be either 0 or 1 returned
return TransitAgency.objects.filter(group=group).first()
else:
return None
16 changes: 16 additions & 0 deletions tests/pytest/core/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from django.conf import settings
from django.contrib.auth.models import Group, User
from django.core.exceptions import ValidationError

import pytest
Expand Down Expand Up @@ -464,3 +465,18 @@ def test_TransitAgency_all_active(model_TransitAgency):
assert len(result) > 0
assert model_TransitAgency in result
assert inactive_agency not in result


@pytest.mark.django_db
def test_TransitAgency_for_user(model_TransitAgency):
group = Group.objects.create(name="test_group")

agency_for_user = TransitAgency.by_id(model_TransitAgency.id)
agency_for_user.pk = None
agency_for_user.group = group
agency_for_user.save()

user = User.objects.create_user(username="test_user", email="[email protected]", password="test", is_staff=True)
user.groups.add(group)

assert TransitAgency.for_user(user) == agency_for_user

0 comments on commit d4f04fc

Please sign in to comment.