Skip to content

Commit

Permalink
feat: tests working with new permission
Browse files Browse the repository at this point in the history
Signed-off-by: Léo-Paul HAUET <[email protected]>
  • Loading branch information
IC-1101asterisk committed Jun 30, 2023
1 parent f4633d6 commit 712bd38
Show file tree
Hide file tree
Showing 20 changed files with 217 additions and 215 deletions.
8 changes: 3 additions & 5 deletions backend/api/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,11 @@ def create_user(self):

if created:
user.set_password(self.password)
# for testing purpose most authentication are done without channel allowing to mock passing channel in
# header, this check is necessary to not break previous tests but irl a user cannot be created
# without a channel
user.save()
if self.channel:
UserChannel.objects.create(user=user, channel_name=self.channel, role=self.role)
user.save()
self.user = user

self.user = user

def request(self, **kwargs):
# create user
Expand Down
21 changes: 6 additions & 15 deletions backend/api/tests/views/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import functools
import pathlib
from typing import Final

import pytest
from django import conf
from rest_framework import test

from api.models import ComputeTask
Expand All @@ -12,22 +10,15 @@
from api.tests.common import AuthenticatedClient

_CHANNEL_NAME: Final[str] = "mychannel"
_EXTRA_HTTP_HEADERS: Final[dict[str, str]] = {"HTTP_SUBSTRA_CHANNEL_NAME": _CHANNEL_NAME}


@pytest.fixture(autouse=True)
def _set_settings(settings: conf.Settings, tmp_path: pathlib.Path):
settings.MEDIA_ROOT = tmp_path.resolve()
settings.LEDGER_CHANNELS = {_CHANNEL_NAME: {"chaincode": {"name": "mycc"}, "model_export_enabled": True}}


@pytest.fixture
def authenticated_client() -> test.APIClient:
client = AuthenticatedClient()

client.get = functools.partial(client.get, **_EXTRA_HTTP_HEADERS)
client.post = functools.partial(client.post, **_EXTRA_HTTP_HEADERS)
client.delete = functools.partial(client.delete, **_EXTRA_HTTP_HEADERS)
client.get = functools.partial(client.get)
client.post = functools.partial(client.post)
client.delete = functools.partial(client.delete)

return client

Expand All @@ -36,9 +27,9 @@ def authenticated_client() -> test.APIClient:
def authenticated_backend_client() -> test.APIClient:
client = AuthenticatedBackendClient()

client.get = functools.partial(client.get, **_EXTRA_HTTP_HEADERS)
client.post = functools.partial(client.post, **_EXTRA_HTTP_HEADERS)
client.put = functools.partial(client.put, **_EXTRA_HTTP_HEADERS)
client.get = functools.partial(client.get)
client.post = functools.partial(client.post)
client.put = functools.partial(client.put)

return client

Expand Down
6 changes: 3 additions & 3 deletions backend/api/tests/views/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_download_file_local_allowed(self):
self.assertIn("local-organization", metadata.permissions_process_authorized_ids)

with mock.patch("api.views.utils.get_owner", return_value="local-organization"):
response = self.client.get(self.function_url, **self.extra)
response = self.client.get(self.function_url)

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.headers["Content-Disposition"], f'attachment; filename="{self.function_filename}"')
Expand All @@ -72,7 +72,7 @@ def test_download_file_local_denied(self):
metadata.save()

with mock.patch("api.views.utils.get_owner", return_value="local-organization"):
response = self.client.get(self.function_url, **self.extra)
response = self.client.get(self.function_url)

self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

Expand All @@ -92,7 +92,7 @@ def test_download_file_remote_allowed(self):
body=self.function_content,
content_type="text/plain; charset=utf-8",
)
response = self.client.get(self.function_url, **self.extra)
response = self.client.get(self.function_url)
mocked_responses.assert_call_count(metadata.function_address, 1)

self.assertEqual(response.status_code, status.HTTP_200_OK)
Expand Down
26 changes: 11 additions & 15 deletions backend/api/tests/views/test_views_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from substrapp.tests.common import get_description_function
from substrapp.tests.common import get_sample_function
from users.models.token import ImplicitBearerToken
from users.models.user_channel import UserChannel

MEDIA_ROOT = tempfile.mkdtemp()

Expand Down Expand Up @@ -56,11 +57,12 @@ def setUpTestData(cls):
user, created = User.objects.get_or_create(username="foo")
if created:
user.set_password("bar")
UserChannel.objects.create(user=user, channel_name="mychannel", role=UserChannel.Role.USER)
user.save()
cls.user = user

def test_authentication_fail(self):
response = self.client.get(self.function_url, **self.extra)
response = self.client.get(self.function_url)

self.assertEqual(status.HTTP_401_UNAUTHORIZED, response.status_code)

Expand All @@ -80,7 +82,7 @@ def test_authentication_with_bad_settings_credentials_fail(self):
authorization_header = generate_basic_auth_header("unauthorized_username", "unauthorized_password")

self.client.credentials(HTTP_AUTHORIZATION=authorization_header)
response = self.client.get(self.function_url, **self.extra)
response = self.client.get(self.function_url)

self.assertEqual(status.HTTP_401_UNAUTHORIZED, response.status_code)

Expand All @@ -103,23 +105,23 @@ def test_authentication_with_organization_fail(self):

for header in bad_authorization_headers:
self.client.credentials(HTTP_AUTHORIZATION=header)
response = self.client.get(self.function_url, **self.extra)
response = self.client.get(self.function_url)

self.assertEqual(status.HTTP_401_UNAUTHORIZED, response.status_code)

def test_obtain_token(self):
endpoint = "/api-token-auth/"
# clean use
response = self.client.post(endpoint, {"username": "foo", "password": "baz"}, **self.extra)
response = self.client.post(endpoint, {"username": "foo", "password": "baz"})
self.assertEqual(response.status_code, 400)

response = self.client.post(endpoint, {"username": "foo", "password": "bar"}, **self.extra)
response = self.client.post(endpoint, {"username": "foo", "password": "bar"})
self.assertEqual(response.status_code, 200)
token_old = response.json()["token"]
self.assertTrue(token_old)

# token should be updated after a second post
response = self.client.post(endpoint, {"username": "foo", "password": "bar"}, **self.extra)
response = self.client.post(endpoint, {"username": "foo", "password": "bar"})
self.assertEqual(response.status_code, 200)
token = response.json()["token"]
self.assertTrue(token)
Expand All @@ -137,20 +139,14 @@ def test_obtain_token(self):
self.client.credentials(HTTP_AUTHORIZATION=valid_auth_token_header)

with mock.patch("api.views.utils.get_owner", return_value="foo"):
response = self.client.get(self.function_url, **self.extra)

response = self.client.get(self.function_url)
self.assertEqual(status.HTTP_200_OK, response.status_code)

# usage with an existing token
# the token should be ignored since the purpose of the view is to authenticate via user/password
valid_auth_token_header = f"Token {token}"
self.client.credentials(HTTP_AUTHORIZATION=valid_auth_token_header)
response = self.client.post(endpoint, {"username": "foo", "password": "bar"}, **self.extra)
self.assertEqual(response.status_code, 200)

invalid_auth_token_header = "Token nope"
self.client.credentials(HTTP_AUTHORIZATION=invalid_auth_token_header)
response = self.client.post(endpoint, {"username": "foo", "password": "bar"}, **self.extra)
response = self.client.post(endpoint, {"username": "foo", "password": "bar"})
self.assertEqual(response.status_code, 200)


Expand All @@ -175,7 +171,7 @@ def setUpTestData(cls):

def _login(self):
data = {"username": self.username, "password": self.password}
r = self.client.post(self.login_url, data, **self.extra)
r = self.client.post(self.login_url, data)

return r.status_code, r

Expand Down
6 changes: 2 additions & 4 deletions backend/api/tests/views/test_views_compute_plan_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def setUp(self):
if not os.path.exists(MEDIA_ROOT):
os.makedirs(MEDIA_ROOT)

self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"}
self.extra = {"HTTP_ACCEPT": "application/json;version=0.0"}
self.base_url = "api:workflow_graph"

def tearDown(self):
Expand Down Expand Up @@ -236,9 +236,7 @@ def test_n_plus_one_queries_compute_graph(authenticated_client):
url = reverse("api:workflow_graph", args=[compute_plan.key])

with utils.CaptureQueriesContext(connection) as queries:
authenticated_client.get(
url, {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"}
)
authenticated_client.get(url, {"HTTP_ACCEPT": "application/json;version=0.0"})
queries = len(queries.captured_queries)

assert queries < 15
3 changes: 2 additions & 1 deletion backend/api/tests/views/test_views_computeplan.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ def test_computeplan_list_success(self):
)

def test_computeplan_list_wrong_channel(self):
extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"}
extra = {"HTTP_ACCEPT": "application/json;version=0.0"}
self.client.channel = "yourchannel"
response = self.client.get(self.url, **extra)
self.assertEqual(response.json(), {"count": 0, "next": None, "previous": None, "results": []})

Expand Down
Loading

0 comments on commit 712bd38

Please sign in to comment.