diff --git a/.github/workflows/helm.yml b/.github/workflows/helm.yml index ae5263d73..580a2594d 100644 --- a/.github/workflows/helm.yml +++ b/.github/workflows/helm.yml @@ -12,7 +12,7 @@ on: paths: - "charts/**" -concurrency: +concurrency: group: "${{ github.workflow }}-${{ github.ref }}" cancel-in-progress: true @@ -57,8 +57,8 @@ jobs: - name: Checkout bitnami-labs/readme-generator-for-helm uses: actions/checkout@v3 with: - repository: 'bitnami-labs/readme-generator-for-helm' - ref: '2.5.0' + repository: "bitnami-labs/readme-generator-for-helm" + ref: "2.5.0" path: readme-generator-for-helm - name: Install readme-generator-for-helm dependencies @@ -70,18 +70,18 @@ jobs: - name: Execute readme-generator-for-helm run: readme-generator-for-helm/bin/index.js -r backend/charts/substra-backend/README.md -v backend/charts/substra-backend/values.yaml - + - name: Check diff run: | cd backend/ if [ -z "$(git status --porcelain)" ]; then exit 0 else - echo "There should be no change generated, please run 'make chart-doc' to update the chart README.md" + echo "There should be no change generated, please run 'make doc' in backend/charts/ to update the chart README.md" git diff exit 1 fi - + publish: name: Publish runs-on: ubuntu-latest @@ -91,12 +91,12 @@ jobs: - generate-chart-readme steps: - uses: actions/checkout@v3 - + - uses: azure/setup-helm@v3.5 with: version: "v3.12.0" id: install - + - name: Add dependencies repo run: | helm repo add bitnami https://charts.bitnami.com/bitnami @@ -112,7 +112,7 @@ jobs: uses: actions/checkout@v3 with: repository: Substra/charts - ref: 'main' + ref: "main" token: ${{ secrets.CHARTS_GITHUB_TOKEN }} path: substra-charts diff --git a/CHANGELOG.md b/CHANGELOG.md index 532da5821..2896e681d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Remove pagination on `get_performances` to remove limitation on 1000 first points ([#690](https://github.com/Substra/substra-backend/pull/690)) +### Added + +- New UserAwaitingApproval (base user with no channel) ([#680](https://github.com/Substra/substra-backend/pull/680)) + ## [0.39.0](https://github.com/Substra/substra-backend/releases/tag/0.39.0) 2023-06-27 ### Added @@ -30,7 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed -- references to `substra` cli commands in `localdev.md` ([#667](https://github.com/Substra/substra-backend/pull/667)) +- references to `substra` cli commands in `localdev.md` ([#667](https://github.com/Substra/substra-backend/pull/667)) ## [0.37.0](https://github.com/Substra/substra-backend/releases/tag/0.37.0) 2023-05-11 diff --git a/backend/api/tests/common.py b/backend/api/tests/common.py index dd2154c7c..f4334f9ff 100644 --- a/backend/api/tests/common.py +++ b/backend/api/tests/common.py @@ -25,9 +25,10 @@ def generate_jwt_auth_header(jwt): class AuthenticatedClient(APIClient): def __init__( self, + *, + channel="mychannel", enforce_csrf_checks=False, role=UserChannel.Role.USER, - channel=None, username="substra", password="p@sswr0d44", **defaults, @@ -41,14 +42,13 @@ def __init__( def create_user(self): user, created = User.objects.get_or_create(username=self.username) + if created: user.set_password(self.password) user.save() - # for testing purpose most authentication are done without channel allowing to mock passing channel in - # header, this check is necessary to not break previous tests but irl a user cannot be created - # without a channel if self.channel: UserChannel.objects.create(user=user, channel_name=self.channel, role=self.role) + self.user = user def request(self, **kwargs): diff --git a/backend/api/tests/views/conftest.py b/backend/api/tests/views/conftest.py index 662502c4b..22dcd7beb 100644 --- a/backend/api/tests/views/conftest.py +++ b/backend/api/tests/views/conftest.py @@ -1,9 +1,4 @@ -import functools -import pathlib -from typing import Final - import pytest -from django import conf from rest_framework import test from api.models import ComputeTask @@ -11,24 +6,11 @@ from api.tests.common import AuthenticatedBackendClient from api.tests.common import AuthenticatedClient -_CHANNEL_NAME: Final[str] = "mychannel" -_EXTRA_HTTP_HEADERS: Final[dict[str, str]] = {"HTTP_SUBSTRA_CHANNEL_NAME": _CHANNEL_NAME} - - -@pytest.fixture(autouse=True) -def _set_settings(settings: conf.Settings, tmp_path: pathlib.Path): - settings.MEDIA_ROOT = tmp_path.resolve() - settings.LEDGER_CHANNELS = {_CHANNEL_NAME: {"chaincode": {"name": "mycc"}, "model_export_enabled": True}} - @pytest.fixture def authenticated_client() -> test.APIClient: client = AuthenticatedClient() - client.get = functools.partial(client.get, **_EXTRA_HTTP_HEADERS) - client.post = functools.partial(client.post, **_EXTRA_HTTP_HEADERS) - client.delete = functools.partial(client.delete, **_EXTRA_HTTP_HEADERS) - return client @@ -36,10 +18,6 @@ def authenticated_client() -> test.APIClient: def authenticated_backend_client() -> test.APIClient: client = AuthenticatedBackendClient() - client.get = functools.partial(client.get, **_EXTRA_HTTP_HEADERS) - client.post = functools.partial(client.post, **_EXTRA_HTTP_HEADERS) - client.put = functools.partial(client.put, **_EXTRA_HTTP_HEADERS) - return client diff --git a/backend/api/tests/views/test_utils.py b/backend/api/tests/views/test_utils.py index 487b9afd2..2472b3770 100644 --- a/backend/api/tests/views/test_utils.py +++ b/backend/api/tests/views/test_utils.py @@ -56,7 +56,7 @@ def test_download_file_local_allowed(self): self.assertIn("local-organization", metadata.permissions_process_authorized_ids) with mock.patch("api.views.utils.get_owner", return_value="local-organization"): - response = self.client.get(self.function_url, **self.extra) + response = self.client.get(self.function_url) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.headers["Content-Disposition"], f'attachment; filename="{self.function_filename}"') @@ -72,7 +72,7 @@ def test_download_file_local_denied(self): metadata.save() with mock.patch("api.views.utils.get_owner", return_value="local-organization"): - response = self.client.get(self.function_url, **self.extra) + response = self.client.get(self.function_url) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -92,7 +92,7 @@ def test_download_file_remote_allowed(self): body=self.function_content, content_type="text/plain; charset=utf-8", ) - response = self.client.get(self.function_url, **self.extra) + response = self.client.get(self.function_url) mocked_responses.assert_call_count(metadata.function_address, 1) self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/backend/api/tests/views/test_views_authentication.py b/backend/api/tests/views/test_views_authentication.py index 65f29a931..95b15691a 100644 --- a/backend/api/tests/views/test_views_authentication.py +++ b/backend/api/tests/views/test_views_authentication.py @@ -17,6 +17,7 @@ from substrapp.tests.common import get_description_function from substrapp.tests.common import get_sample_function from users.models.token import ImplicitBearerToken +from users.models.user_channel import UserChannel MEDIA_ROOT = tempfile.mkdtemp() @@ -56,11 +57,12 @@ def setUpTestData(cls): user, created = User.objects.get_or_create(username="foo") if created: user.set_password("bar") + UserChannel.objects.create(user=user, channel_name="mychannel", role=UserChannel.Role.USER) user.save() cls.user = user def test_authentication_fail(self): - response = self.client.get(self.function_url, **self.extra) + response = self.client.get(self.function_url) self.assertEqual(status.HTTP_401_UNAUTHORIZED, response.status_code) @@ -80,7 +82,7 @@ def test_authentication_with_bad_settings_credentials_fail(self): authorization_header = generate_basic_auth_header("unauthorized_username", "unauthorized_password") self.client.credentials(HTTP_AUTHORIZATION=authorization_header) - response = self.client.get(self.function_url, **self.extra) + response = self.client.get(self.function_url) self.assertEqual(status.HTTP_401_UNAUTHORIZED, response.status_code) @@ -103,23 +105,23 @@ def test_authentication_with_organization_fail(self): for header in bad_authorization_headers: self.client.credentials(HTTP_AUTHORIZATION=header) - response = self.client.get(self.function_url, **self.extra) + response = self.client.get(self.function_url) self.assertEqual(status.HTTP_401_UNAUTHORIZED, response.status_code) def test_obtain_token(self): endpoint = "/api-token-auth/" # clean use - response = self.client.post(endpoint, {"username": "foo", "password": "baz"}, **self.extra) + response = self.client.post(endpoint, {"username": "foo", "password": "baz"}) self.assertEqual(response.status_code, 400) - response = self.client.post(endpoint, {"username": "foo", "password": "bar"}, **self.extra) + response = self.client.post(endpoint, {"username": "foo", "password": "bar"}) self.assertEqual(response.status_code, 200) token_old = response.json()["token"] self.assertTrue(token_old) # token should be updated after a second post - response = self.client.post(endpoint, {"username": "foo", "password": "bar"}, **self.extra) + response = self.client.post(endpoint, {"username": "foo", "password": "bar"}) self.assertEqual(response.status_code, 200) token = response.json()["token"] self.assertTrue(token) @@ -137,20 +139,14 @@ def test_obtain_token(self): self.client.credentials(HTTP_AUTHORIZATION=valid_auth_token_header) with mock.patch("api.views.utils.get_owner", return_value="foo"): - response = self.client.get(self.function_url, **self.extra) - + response = self.client.get(self.function_url) self.assertEqual(status.HTTP_200_OK, response.status_code) # usage with an existing token # the token should be ignored since the purpose of the view is to authenticate via user/password valid_auth_token_header = f"Token {token}" self.client.credentials(HTTP_AUTHORIZATION=valid_auth_token_header) - response = self.client.post(endpoint, {"username": "foo", "password": "bar"}, **self.extra) - self.assertEqual(response.status_code, 200) - - invalid_auth_token_header = "Token nope" - self.client.credentials(HTTP_AUTHORIZATION=invalid_auth_token_header) - response = self.client.post(endpoint, {"username": "foo", "password": "bar"}, **self.extra) + response = self.client.post(endpoint, {"username": "foo", "password": "bar"}) self.assertEqual(response.status_code, 200) @@ -175,7 +171,7 @@ def setUpTestData(cls): def _login(self): data = {"username": self.username, "password": self.password} - r = self.client.post(self.login_url, data, **self.extra) + r = self.client.post(self.login_url, data) return r.status_code, r diff --git a/backend/api/tests/views/test_views_compute_plan_graph.py b/backend/api/tests/views/test_views_compute_plan_graph.py index e865f43d9..e1cc58804 100644 --- a/backend/api/tests/views/test_views_compute_plan_graph.py +++ b/backend/api/tests/views/test_views_compute_plan_graph.py @@ -106,7 +106,7 @@ def setUp(self): if not os.path.exists(MEDIA_ROOT): os.makedirs(MEDIA_ROOT) - self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} + self.extra = {"HTTP_ACCEPT": "application/json;version=0.0"} self.base_url = "api:workflow_graph" def tearDown(self): @@ -236,9 +236,7 @@ def test_n_plus_one_queries_compute_graph(authenticated_client): url = reverse("api:workflow_graph", args=[compute_plan.key]) with utils.CaptureQueriesContext(connection) as queries: - authenticated_client.get( - url, {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} - ) + authenticated_client.get(url, {"HTTP_ACCEPT": "application/json;version=0.0"}) queries = len(queries.captured_queries) assert queries < 15 diff --git a/backend/api/tests/views/test_views_computeplan.py b/backend/api/tests/views/test_views_computeplan.py index 1e21a2ad3..bd1fb44c2 100644 --- a/backend/api/tests/views/test_views_computeplan.py +++ b/backend/api/tests/views/test_views_computeplan.py @@ -325,8 +325,8 @@ def test_computeplan_list_success(self): ) def test_computeplan_list_wrong_channel(self): - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"} - response = self.client.get(self.url, **extra) + self.client.channel = "yourchannel" + response = self.client.get(self.url, **self.extra) self.assertEqual(response.json(), {"count": 0, "next": None, "previous": None, "results": []}) @internal_server_error_on_exception() diff --git a/backend/api/tests/views/test_views_computetask.py b/backend/api/tests/views/test_views_computetask.py index 493c7e472..bb19e4992 100644 --- a/backend/api/tests/views/test_views_computetask.py +++ b/backend/api/tests/views/test_views_computetask.py @@ -35,7 +35,6 @@ class ComputeTaskViewTests(APITestCase): def setUp(self): if not os.path.exists(MEDIA_ROOT): os.makedirs(MEDIA_ROOT) - self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} self.logger = logging.getLogger("django.request") self.previous_level = self.logger.getEffectiveLevel() @@ -245,7 +244,7 @@ def mock_register_compute_task(orc_request): url = reverse("api:task-bulk_create") with mock.patch.object(OrchestratorClient, "register_tasks", side_effect=mock_register_compute_task): - response = self.client.post(url, data=data, format="json", **self.extra) + response = self.client.post(url, data=data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK, response.data) assert response.json()[0] == expected_response[0] @@ -430,7 +429,7 @@ def setUp(self): self.done_task_key = done_task.key def test_task_list_success(self): - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) # manually overriding duration for doing tasks as "now" is taken from db and not timezone.now(), # couldn't be properly mocked for task in response.json().get("results"): @@ -447,21 +446,21 @@ def test_task_list_success(self): ) def test_task_list_wrong_channel(self): - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"} - response = self.client.get(self.url, **extra) + self.client.channel = "yourchannel" + response = self.client.get(self.url) self.assertEqual(response.json(), {"count": 0, "next": None, "previous": None, "results": []}) @internal_server_error_on_exception() @mock.patch("api.views.computetask.ComputeTaskViewSet.list", side_effect=Exception("Unexpected error")) def test_task_list_fail(self, _): - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) def test_task_list_filter(self): """Filter task on key.""" key = self.list_expected_results[0]["key"] params = urlencode({"key": key}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": self.list_expected_results[:1]} ) @@ -470,7 +469,7 @@ def test_task_list_filter_and(self): """Filter task on key and owner.""" key, owner = self.list_expected_results[0]["key"], self.list_expected_results[0]["owner"] params = urlencode({"key": key, "owner": owner}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": self.list_expected_results[:1]} ) @@ -480,7 +479,7 @@ def test_task_list_filter_in(self): key_0 = self.list_expected_results[0]["key"] key_1 = self.list_expected_results[1]["key"] params = urlencode({"key": ",".join([key_0, key_1])}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 2, "next": None, "previous": None, "results": self.list_expected_results[:2]} ) @@ -500,7 +499,7 @@ def test_task_list_filter_by_status(self, t_status): """Filter task on status.""" filtered_train_tasks = [task for task in self.list_expected_results if task["status"] == t_status] params = urlencode({"status": t_status}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") if t_status != "STATUS_XXX": if t_status == ComputeTask.Status.STATUS_DOING: @@ -527,7 +526,7 @@ def test_task_list_filter_by_status_in(self, t_statuses): """Filter task on status.""" filtered_train_tasks = [task for task in self.list_expected_results if task["status"] in t_statuses] params = urlencode({"status": ",".join(t_statuses)}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") if "STATUS_XXX" not in t_statuses: if ComputeTask.Status.STATUS_DOING in t_statuses: @@ -555,7 +554,7 @@ def test_task_match(self): # this will be handled as 2 tokens, so items matching both XXXX and YYYYYYYYYYYY will be returned # this should be enough to guarantee that there will only be one matching task params = urlencode({"match": key[19:]}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertDictEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": self.list_expected_results[:1]} ) @@ -570,7 +569,7 @@ def test_task_match_and_filter(self): "match": key[19:], } ) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": self.list_expected_results[:1]} ) @@ -584,7 +583,7 @@ def test_task_match_and_filter(self): ) def test_task_list_pagination_success(self, _, page_size, page): params = urlencode({"page_size": page_size, "page": page}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") r = response.json() # manually overriding duration for doing tasks as "now" is taken from db and not timezone.now(), # couldn't be properly mocked @@ -598,7 +597,7 @@ def test_task_list_pagination_success(self, _, page_size, page): def test_task_cp_list_success(self): """List tasks for a specific compute plan (CPTaskViewSet).""" url = reverse("api:compute_plan_task-list", args=[self.compute_plan.key]) - response = self.client.get(url, **self.extra) + response = self.client.get(url) # manually overriding duration for doing tasks as "now" is taken from db and not timezone.now(), # couldn't be properly mocked for task in response.json().get("results"): @@ -621,7 +620,7 @@ def test_task_list_cross_assets_filters(self): ] for params in params_list: - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") # manually overriding duration for doing tasks as "now" is taken from db and not timezone.now(), # couldn't be properly mocked for task in response.json().get("results"): @@ -631,12 +630,12 @@ def test_task_list_cross_assets_filters(self): # filter on wrong key params = urlencode({"function_key": self.data_manager.key}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual(len(response.json().get("results")), 0) def test_task_list_ordering(self): params = urlencode({"ordering": "creation_date"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") # manually overriding duration for doing tasks as "now" is taken from db and not timezone.now(), # couldn't be properly mocked for task in response.json().get("results"): @@ -645,7 +644,7 @@ def test_task_list_ordering(self): self.assertEqual(response.json().get("results"), self.list_expected_results), params = urlencode({"ordering": "-creation_date"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") # manually overriding duration for doing tasks as "now" is taken from db and not timezone.now(), # couldn't be properly mocked for task in response.json().get("results"): @@ -655,26 +654,26 @@ def test_task_list_ordering(self): def test_task_retrieve(self): url = reverse("api:task-detail", args=[self.detail_expected_results["key"]]) - response = self.client.get(url, **self.extra) + response = self.client.get(url) self.assertEqual(response.json(), self.detail_expected_results) def test_task_retrieve_wrong_channel(self): url = reverse("api:task-detail", args=[self.detail_expected_results["key"]]) - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"} - response = self.client.get(url, **extra) + self.client.channel = "yourchannel" + response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) @internal_server_error_on_exception() @mock.patch("api.views.computetask.ComputeTaskViewSet.retrieve", side_effect=Exception("Unexpected error")) def test_task_retrieve_fail(self, _): url = reverse("api:task-detail", args=[self.detail_expected_results["key"]]) - response = self.client.get(url, **self.extra) + response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) def test_task_list_input_assets(self): url = reverse("api:task-input_assets", args=[self.done_task_key]) - response = self.client.get(url, **self.extra) + response = self.client.get(url) expected_results = [ { "identifier": "datasamples", @@ -696,33 +695,33 @@ def test_task_list_input_assets_filter(self): url = reverse("api:task-input_assets", args=[self.done_task_key]) # base response should contain a datamanager and a datasample - response = self.client.get(url, **self.extra) + response = self.client.get(url) data = response.json() assert data["count"] == 2 # single filter - response = self.client.get(url, data={"kind": "ASSET_DATA_MANAGER"}, **self.extra) + response = self.client.get(url, data={"kind": "ASSET_DATA_MANAGER"}) data = response.json() assert data["count"] == 1 # multi filter - response = self.client.get(url, data={"kind": "ASSET_DATA_MANAGER,ASSET_MODEL"}, **self.extra) + response = self.client.get(url, data={"kind": "ASSET_DATA_MANAGER,ASSET_MODEL"}) data = response.json() assert data["count"] == 1 # invalid filter - response = self.client.get(url, data={"kind": "foo"}, **self.extra) + response = self.client.get(url, data={"kind": "foo"}) data = response.json() assert data["count"] == 2 # invalid multi filter - response = self.client.get(url, data={"kind": "ASSET_DATA_MANAGER,foo"}, **self.extra) + response = self.client.get(url, data={"kind": "ASSET_DATA_MANAGER,foo"}) data = response.json() assert data["count"] == 2 def test_task_list_output_assets(self): url = reverse("api:task-output_assets", args=[self.done_task_key]) - response = self.client.get(url, **self.extra) + response = self.client.get(url) expected_results = [ { "identifier": "model", @@ -739,27 +738,27 @@ def test_task_list_output_assets_filter(self): url = reverse("api:task-output_assets", args=[self.done_task_key]) # base response should contain a model - response = self.client.get(url, **self.extra) + response = self.client.get(url) data = response.json() assert data["count"] == 1 # single filter - response = self.client.get(url, data={"kind": "ASSET_PERFORMANCE"}, **self.extra) + response = self.client.get(url, data={"kind": "ASSET_PERFORMANCE"}) data = response.json() assert data["count"] == 0 # multi filter - response = self.client.get(url, data={"kind": "ASSET_PERFORMANCE,ASSET_MODEL"}, **self.extra) + response = self.client.get(url, data={"kind": "ASSET_PERFORMANCE,ASSET_MODEL"}) data = response.json() assert data["count"] == 1 # invalid filter - response = self.client.get(url, data={"kind": "foo"}, **self.extra) + response = self.client.get(url, data={"kind": "foo"}) data = response.json() assert data["count"] == 1 # invalid multi filter - response = self.client.get(url, data={"kind": "ASSET_PERFORMANCE,foo"}, **self.extra) + response = self.client.get(url, data={"kind": "ASSET_PERFORMANCE,foo"}) data = response.json() assert data["count"] == 1 @@ -788,8 +787,8 @@ def test_n_plus_one_queries_compute_task_in_compute_plan(authenticated_client, c queries_for_10_tasks = len(queries_10.captured_queries) - assert abs(queries_for_60_tasks - queries_for_10_tasks) < 5 - assert queries_for_60_tasks < 17 + assert abs(queries_for_60_tasks - queries_for_10_tasks) < 6 + assert queries_for_60_tasks < 19 @pytest.mark.django_db @@ -816,7 +815,7 @@ def test_n_plus_one_queries_compute_task_detail(authenticated_client, create_com authenticated_client.get(url_10) queries_for_10_samples = len(queries_10.captured_queries) - assert abs(queries_for_4_samples - queries_for_10_samples) < 5 + assert abs(queries_for_4_samples - queries_for_10_samples) < 6 assert queries_for_4_samples < 20 @@ -844,5 +843,5 @@ def test_n_plus_one_queries_compute_task_list(authenticated_client, create_compu authenticated_client.get(url) queries_for_60_tasks = len(queries_60.captured_queries) - assert abs(queries_for_60_tasks - queries_for_10_tasks) < 5 + assert abs(queries_for_60_tasks - queries_for_10_tasks) < 6 assert queries_for_60_tasks < 15 diff --git a/backend/api/tests/views/test_views_datamanager.py b/backend/api/tests/views/test_views_datamanager.py index 9445270a9..5ac64d173 100644 --- a/backend/api/tests/views/test_views_datamanager.py +++ b/backend/api/tests/views/test_views_datamanager.py @@ -39,7 +39,7 @@ def setUp(self): if not os.path.exists(MEDIA_ROOT): os.makedirs(MEDIA_ROOT) self.url = reverse("api:data_manager-list") - self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} + self.extra = {"HTTP_ACCEPT": "application/json;version=0.0"} self.logger = logging.getLogger("django.request") self.previous_level = self.logger.getEffectiveLevel() @@ -195,8 +195,8 @@ def test_datamanager_list_success(self): ) def test_datamanager_list_wrong_channel(self): - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"} - response = self.client.get(self.url, **extra) + self.client.channel = "yourchannel" + response = self.client.get(self.url, **self.extra) self.assertEqual(response.json(), {"count": 0, "next": None, "previous": None, "results": []}) @internal_server_error_on_exception() @@ -502,8 +502,8 @@ def test_datamanager_retrieve_with_tasks(self): def test_datamanager_retrieve_wrong_channel(self): url = reverse("api:data_manager-detail", args=[self.expected_results[0]["key"]]) - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"} - response = self.client.get(url, **extra) + self.client.channel = "yourchannel" + response = self.client.get(url, **self.extra) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_datamanager_retrieve_storage_addresses_update(self): diff --git a/backend/api/tests/views/test_views_datasample.py b/backend/api/tests/views/test_views_datasample.py index c39dd1295..2f4e3bd63 100644 --- a/backend/api/tests/views/test_views_datasample.py +++ b/backend/api/tests/views/test_views_datasample.py @@ -61,7 +61,7 @@ def setUp(self): if not os.path.exists(MEDIA_ROOT): os.makedirs(MEDIA_ROOT) self.url = reverse("api:data_sample-list") - self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} + self.extra = {"HTTP_ACCEPT": "application/json;version=0.0"} self.logger = logging.getLogger("django.request") self.previous_level = self.logger.getEffectiveLevel() @@ -110,8 +110,8 @@ def test_datasample_retrieve(self): def test_datasample_retrieve_wrong_channel(self): url = reverse("api:data_sample-detail", args=[self.expected_results[0]["key"]]) - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"} - response = self.client.get(url, **extra) + self.client.channel = "yourchannel" + response = self.client.get(url, **self.extra) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) @internal_server_error_on_exception() @@ -134,8 +134,8 @@ def test_datasample_list_success(self): ) def test_datasample_list_wrong_channel(self): - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"} - response = self.client.get(self.url, **extra) + self.client.channel = "yourchannel" + response = self.client.get(self.url, **self.extra) self.assertEqual(response.json(), {"count": 0, "next": None, "previous": None, "results": []}) @internal_server_error_on_exception() diff --git a/backend/api/tests/views/test_views_function.py b/backend/api/tests/views/test_views_function.py index 57be631b2..2ef4b8029 100644 --- a/backend/api/tests/views/test_views_function.py +++ b/backend/api/tests/views/test_views_function.py @@ -38,7 +38,7 @@ class FunctionViewTests(APITestCase): def setUp(self): if not os.path.exists(MEDIA_ROOT): os.makedirs(MEDIA_ROOT) - self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} + self.logger = logging.getLogger("django.request") self.previous_level = self.logger.getEffectiveLevel() self.logger.setLevel(logging.ERROR) @@ -246,11 +246,11 @@ def tearDown(self): def test_function_list_empty(self): Function.objects.all().delete() - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual(response.json(), {"count": 0, "next": None, "previous": None, "results": []}) def test_function_list_success(self): - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual( response.json(), { @@ -262,14 +262,14 @@ def test_function_list_success(self): ) def test_function_list_wrong_channel(self): - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"} - response = self.client.get(self.url, **extra) + self.client.channel = "yourchannel" + response = self.client.get(self.url) self.assertEqual(response.json(), {"count": 0, "next": None, "previous": None, "results": []}) @internal_server_error_on_exception() @mock.patch("api.views.function.FunctionViewSet.list", side_effect=Exception("Unexpected error")) def test_function_list_fail(self, _): - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) def test_function_list_storage_addresses_update(self): @@ -278,7 +278,7 @@ def test_function_list_storage_addresses_update(self): function.function_address.replace("http://testserver", "http://remotetestserver") function.save() - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual(response.data["count"], len(self.expected_functions)) for result, function in zip(response.data["results"], self.expected_functions): for field in ("description", "function"): @@ -288,7 +288,7 @@ def test_function_list_filter(self): """Filter function on key.""" key = self.expected_functions[0]["key"] params = urlencode({"key": key}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": self.expected_functions[:1]} ) @@ -297,7 +297,7 @@ def test_function_list_filter_and(self): """Filter function on key and owner.""" key, owner = self.expected_functions[0]["key"], self.expected_functions[0]["owner"] params = urlencode({"key": key, "owner": owner}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": self.expected_functions[:1]} ) @@ -307,7 +307,7 @@ def test_function_list_filter_in(self): key_0 = self.expected_functions[0]["key"] key_1 = self.expected_functions[1]["key"] params = urlencode({"key": ",".join([key_0, key_1])}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 2, "next": None, "previous": None, "results": self.expected_functions[:2]} ) @@ -315,7 +315,7 @@ def test_function_list_filter_in(self): def test_function_match(self): """Match function on part of the name.""" params = urlencode({"match": "le fu"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": self.expected_functions[:1]} ) @@ -328,7 +328,7 @@ def test_function_match_and_filter(self): "match": "le fu", } ) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": self.expected_functions[:1]} ) @@ -342,16 +342,16 @@ def test_function_list_compute_plan_key_filter(self): # filter on compute_plan_key params = urlencode({"compute_plan_key": compute_plan.key}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual(response.json().get("results"), self.expected_functions[:2]) def test_function_list_ordering(self): params = urlencode({"ordering": "creation_date"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual(response.json().get("results"), self.expected_functions), params = urlencode({"ordering": "-creation_date"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual(response.json().get("results"), self.expected_functions[::-1]), @parameterized.expand( @@ -363,7 +363,7 @@ def test_function_list_ordering(self): ) def test_function_list_pagination_success(self, _, page_size, page): params = urlencode({"page_size": page_size, "page": page}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") r = response.json() self.assertEqual(r["count"], len(self.expected_functions)) offset = (page - 1) * page_size @@ -377,7 +377,7 @@ def test_function_cp_list_success(self): factory.create_computetask(compute_plan, self.functions[1]) url = reverse("api:compute_plan_function-list", args=[compute_plan.key]) - response = self.client.get(url, **self.extra) + response = self.client.get(url) self.assertEqual( response.json(), { @@ -400,19 +400,19 @@ def test_function_list_can_process(self): self.expected_functions[1]["permissions"]["process"]["authorized_ids"] = ["MyOrg1MSP", "MyOrg2MSP"] params = urlencode({"can_process": "MyOrg1MSP"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual(response.json().get("results"), self.expected_functions), params = urlencode({"can_process": "MyOrg2MSP"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual(response.json().get("results"), self.expected_functions[:2]), params = urlencode({"can_process": "MyOrg3MSP"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual(response.json().get("results"), [self.expected_functions[0]]), params = urlencode({"can_process": "MyOrg1MSP,MyOrg2MSP"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual(response.json().get("results"), self.expected_functions[:2]), @parameterized.expand( @@ -476,7 +476,7 @@ def mock_orc_response(data): } with mock.patch.object(OrchestratorClient, "register_function", side_effect=mock_orc_response): - response = self.client.post(self.url, data=data, format="multipart", **self.extra) + response = self.client.post(self.url, data=data, format="multipart") self.assertIsNotNone(response.data["key"]) self.assertEqual(response.status_code, status.HTTP_201_CREATED) # asset created in local db @@ -513,7 +513,7 @@ def test_file_size_limit(self): "description": open(description_path, "rb"), } - response = self.client.post(self.url, data=data, format="multipart", **self.extra) + response = self.client.post(self.url, data=data, format="multipart") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("File too large", response.data["message"][0]["file"]) @@ -555,7 +555,7 @@ def __init__(self) -> None: } with mock.patch.object(OrchestratorClient, "register_function", side_effect=MockOrcError()): - response = self.client.post(self.url, data=data, format="multipart", **self.extra) + response = self.client.post(self.url, data=data, format="multipart") # asset not created in local db self.assertEqual(Function.objects.count(), len(self.expected_functions)) # orc error code should be propagated @@ -569,13 +569,13 @@ def test_function_create_fail(self, _): def test_function_retrieve(self): url = reverse("api:function-detail", args=[self.expected_functions[0]["key"]]) - response = self.client.get(url, **self.extra) + response = self.client.get(url) self.assertEqual(response.json(), self.expected_functions[0]) def test_function_retrieve_wrong_channel(self): url = reverse("api:function-detail", args=[self.expected_functions[0]["key"]]) - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"} - response = self.client.get(url, **extra) + self.client.channel = "yourchannel" + response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_function_retrieve_storage_addresses_update(self): @@ -585,7 +585,7 @@ def test_function_retrieve_storage_addresses_update(self): function.save() url = reverse("api:function-detail", args=[self.expected_functions[0]["key"]]) - response = self.client.get(url, **self.extra) + response = self.client.get(url) for field in ("description", "function"): self.assertEqual( response.data[field]["storage_address"], self.expected_functions[0][field]["storage_address"] @@ -595,7 +595,7 @@ def test_function_retrieve_storage_addresses_update(self): @mock.patch("api.views.function.FunctionViewSet.retrieve", side_effect=Exception("Unexpected error")) def test_function_retrieve_fail(self, _): url = reverse("api:function-detail", args=[self.expected_functions[0]["key"]]) - response = self.client.get(url, **self.extra) + response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) def test_function_download_file(self): @@ -603,7 +603,7 @@ def test_function_download_file(self): function = factory.create_function(key=function_files.key) url = reverse("api:function-file", args=[function.key]) with mock.patch("api.views.utils.get_owner", return_value=function.owner): - response = self.client.get(url, **self.extra) + response = self.client.get(url) content = response.getvalue() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(content, function_files.file.read()) @@ -614,7 +614,7 @@ def test_function_download_description(self): function = factory.create_function(key=function_files.key) url = reverse("api:function-description", args=[function.key]) with mock.patch("api.views.utils.get_owner", return_value=function.owner): - response = self.client.get(url, **self.extra) + response = self.client.get(url) content = response.getvalue() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(content, function_files.description.read()) @@ -630,7 +630,7 @@ def test_function_update(self): function["name"] = data["name"] with mock.patch.object(OrchestratorClient, "update_function", side_effect=function): - response = self.client.put(url, data=data, format="json", **self.extra) + response = self.client.put(url, data=data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -638,5 +638,5 @@ def test_function_update(self): error.code = StatusCode.INTERNAL with mock.patch.object(OrchestratorClient, "update_function", side_effect=error): - response = self.client.put(url, data=data, format="json", **self.extra) + response = self.client.put(url, data=data, format="json") self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) diff --git a/backend/api/tests/views/test_views_info.py b/backend/api/tests/views/test_views_info.py index 58b7efeef..bc36efaa9 100644 --- a/backend/api/tests/views/test_views_info.py +++ b/backend/api/tests/views/test_views_info.py @@ -13,9 +13,6 @@ @override_settings(LEDGER_CHANNELS={"mychannel": {"chaincode": {"name": "mycc"}, "model_export_enabled": True}}) class InfoViewTests(APITestCase): url = "/info/" - extra = { - "HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", - } def test_anonymous(self): client = APIClient() @@ -35,7 +32,7 @@ def test_authenticated(self): with mock.patch.object( OrchestratorClient, "query_version", return_value=OrchestratorVersion(server="foo", chaincode="bar") ): - response = client.get(self.url, **self.extra) + response = client.get(self.url) self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/backend/api/tests/views/test_views_metadata.py b/backend/api/tests/views/test_views_metadata.py index a29a38177..fb409753c 100644 --- a/backend/api/tests/views/test_views_metadata.py +++ b/backend/api/tests/views/test_views_metadata.py @@ -28,7 +28,6 @@ def setUp(self): metadata=dict(One="case_sensitive", three="duplicate_three", four="test") ) - self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} self.url = reverse("api:compute_plan_metadata-list") # alphabetically ordered list @@ -39,9 +38,9 @@ def tearDown(self): def test_metadata_list_empty(self): ComputePlan.objects.all().delete() - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual(list(response.data), []) def test_metadata_list(self): - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual(list(response.data), self.expected_results) diff --git a/backend/api/tests/views/test_views_model.py b/backend/api/tests/views/test_views_model.py index 56910f167..cdf929feb 100644 --- a/backend/api/tests/views/test_views_model.py +++ b/backend/api/tests/views/test_views_model.py @@ -40,7 +40,6 @@ def setUp(self): if not os.path.exists(MEDIA_ROOT): os.makedirs(MEDIA_ROOT) self.model, self.model_filename = get_sample_model() - self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": CHANNEL, "HTTP_ACCEPT": "application/json;version=0.0"} self.logger = logging.getLogger("django.request") self.previous_level = self.logger.getEffectiveLevel() self.logger.setLevel(logging.ERROR) @@ -138,25 +137,25 @@ def tearDown(self): def test_model_list_empty(self): Model.objects.all().delete() - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual(response.json(), {"count": 0, "next": None, "previous": None, "results": []}) def test_model_list_success(self): - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual( response.json(), {"count": len(self.expected_results), "next": None, "previous": None, "results": self.expected_results}, ) def test_model_list_wrong_channel(self): - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"} - response = self.client.get(self.url, **extra) + self.client.channel = "yourchannel" + response = self.client.get(self.url) self.assertEqual(response.json(), {"count": 0, "next": None, "previous": None, "results": []}) @internal_server_error_on_exception() @mock.patch("api.views.model.ModelViewSet.list", side_effect=Exception("Unexpected error")) def test_model_list_fail(self, _): - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) def test_model_list_storage_addresses_update(self): @@ -164,7 +163,7 @@ def test_model_list_storage_addresses_update(self): model.model_address.replace("http://testserver", "http://remotetestserver") model.save() - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual(response.data["count"], len(self.expected_results)) for result, model in zip(response.data["results"], self.expected_results): self.assertEqual(result["address"]["storage_address"], model["address"]["storage_address"]) @@ -173,7 +172,9 @@ def test_model_list_filter(self): """Filter model on key.""" key = self.expected_results[0]["key"] params = urlencode({"key": key}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get( + f"{self.url}?{params}", + ) self.assertEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": self.expected_results[:1]} ) @@ -182,7 +183,9 @@ def test_model_list_filter_and(self): """Filter model on key and owner.""" key, owner = self.expected_results[0]["key"], self.expected_results[0]["owner"] params = urlencode({"key": key, "owner": owner}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get( + f"{self.url}?{params}", + ) self.assertEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": self.expected_results[:1]} ) @@ -192,7 +195,9 @@ def test_model_list_filter_in(self): key_0 = self.expected_results[0]["key"] key_1 = self.expected_results[1]["key"] params = urlencode({"key": ",".join([key_0, key_1])}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get( + f"{self.url}?{params}", + ) self.assertEqual( response.json(), {"count": 2, "next": None, "previous": None, "results": self.expected_results[:2]} ) @@ -206,7 +211,9 @@ def test_model_list_filter_in(self): ) def test_model_list_pagination_success(self, _, page_size, page): params = urlencode({"page_size": page_size, "page": page}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get( + f"{self.url}?{params}", + ) r = response.json() self.assertEqual(r["count"], len(self.expected_results)) offset = (page - 1) * page_size @@ -214,22 +221,28 @@ def test_model_list_pagination_success(self, _, page_size, page): def test_model_list_ordering(self): params = urlencode({"ordering": "creation_date"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get( + f"{self.url}?{params}", + ) self.assertEqual(response.json().get("results"), self.expected_results), params = urlencode({"ordering": "-creation_date"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get( + f"{self.url}?{params}", + ) self.assertEqual(response.json().get("results"), self.expected_results[::-1]) def test_model_retrieve(self): url = reverse("api:model-detail", args=[self.expected_results[0]["key"]]) - response = self.client.get(url, **self.extra) + response = self.client.get( + url, + ) self.assertEqual(response.json(), self.expected_results[0]) def test_model_retrieve_wrong_channel(self): url = reverse("api:model-detail", args=[self.expected_results[0]["key"]]) - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"} - response = self.client.get(url, **extra) + self.client.channel = "yourchannel" + response = self.client.get(url) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_model_retrieve_storage_addresses_update(self): @@ -238,7 +251,9 @@ def test_model_retrieve_storage_addresses_update(self): model.save() url = reverse("api:model-detail", args=[self.expected_results[0]["key"]]) - response = self.client.get(url, **self.extra) + response = self.client.get( + url, + ) self.assertEqual( response.data["address"]["storage_address"], self.expected_results[0]["address"]["storage_address"] ) @@ -247,7 +262,9 @@ def test_model_retrieve_storage_addresses_update(self): @mock.patch("api.views.model.ModelViewSet.retrieve", side_effect=Exception("Unexpected error")) def test_model_retrieve_fail(self, _): url = reverse("api:model-detail", args=[self.expected_results[0]["key"]]) - response = self.client.get(url, **self.extra) + response = self.client.get( + url, + ) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) def test_model_download_by_organization_for_worker(self): @@ -370,7 +387,9 @@ def test_model_download_file(self): with mock.patch("api.views.utils.get_owner", return_value=model.owner), mock.patch( "api.views.model.type", return_value=OrganizationUser ): - response = self.client.get(url, **self.extra) + response = self.client.get( + url, + ) content = response.getvalue() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(content, model_files.file.read()) @@ -381,7 +400,9 @@ def test_model_download_file_wrong_user(self): model = factory.create_model(self.train_task, key=model_files.key, owner="substra") url = reverse("api:model-file", args=[model.key]) with mock.patch("api.views.utils.get_owner", return_value=model.owner): - response = self.client.get(url, **self.extra) + response = self.client.get( + url, + ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_model_download_file_wrong_owner(self): @@ -391,7 +412,9 @@ def test_model_download_file_wrong_owner(self): with mock.patch("api.views.utils.get_owner", return_value=model.owner), mock.patch( "api.views.model.type", return_value=OrganizationUser ): - response = self.client.get(url, **self.extra) + response = self.client.get( + url, + ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_model_download_file_deleted(self): @@ -404,5 +427,7 @@ def test_model_download_file_deleted(self): with mock.patch("api.views.utils.get_owner", return_value=model.owner), mock.patch( "api.views.model.type", return_value=OrganizationUser ): - response = self.client.get(url, **self.extra) + response = self.client.get( + url, + ) self.assertEqual(response.status_code, status.HTTP_410_GONE) diff --git a/backend/api/tests/views/test_views_newsfeed.py b/backend/api/tests/views/test_views_newsfeed.py index ce7fe8858..e4ec13d87 100644 --- a/backend/api/tests/views/test_views_newsfeed.py +++ b/backend/api/tests/views/test_views_newsfeed.py @@ -24,15 +24,13 @@ class NewsFeedViewTests(APITestCase): def setUp(self): if not os.path.exists(MEDIA_ROOT): os.makedirs(MEDIA_ROOT) - - self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} self.url = reverse("api:news_feed-list") def tearDown(self): shutil.rmtree(MEDIA_ROOT, ignore_errors=True) def test_newsfeed_list_empty(self): - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual(response.json(), {"count": 0, "next": None, "previous": None, "results": []}) def test_newsfeed_list(self): @@ -164,7 +162,7 @@ def test_newsfeed_list(self): }, ] - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual( response.json(), {"count": len(expected_results), "next": None, "previous": None, "results": expected_results}, @@ -202,21 +200,21 @@ def test_newsfeed_filter_creation_date(self): }, ] - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual( response.json(), {"count": len(expected_results), "next": None, "previous": None, "results": expected_results}, ) params = urlencode({"timestamp_after": expected_results[2]["timestamp"]}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 3, "next": None, "previous": None, "results": expected_results[:3]}, ) params = urlencode({"timestamp_before": expected_results[0]["timestamp"]}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 2, "next": None, "previous": None, "results": expected_results[1:]}, @@ -225,7 +223,7 @@ def test_newsfeed_filter_creation_date(self): params = urlencode( {"timestamp_before": expected_results[1]["timestamp"], "timestamp_after": expected_results[2]["timestamp"]} ) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": [expected_results[2]]}, @@ -261,21 +259,21 @@ def test_newsfeed_filter_start_end_date(self): }, ] - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual( response.json(), {"count": len(expected_results), "next": None, "previous": None, "results": expected_results}, ) params = urlencode({"timestamp_after": expected_results[1]["timestamp"]}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 2, "next": None, "previous": None, "results": expected_results[:2]}, ) params = urlencode({"timestamp_before": expected_results[0]["timestamp"]}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 2, "next": None, "previous": None, "results": expected_results[1:]}, @@ -284,7 +282,7 @@ def test_newsfeed_filter_start_end_date(self): params = urlencode( {"timestamp_before": expected_results[0]["timestamp"], "timestamp_after": expected_results[1]["timestamp"]} ) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": [expected_results[1]]}, @@ -329,28 +327,28 @@ def test_newsfeed_filter_important_news_only(self): }, ] - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual( response.json(), {"count": len(expected_results), "next": None, "previous": None, "results": expected_results}, ) params = urlencode({"important_news_only": "true"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": 1, "next": None, "previous": None, "results": [expected_results[0]]}, ) params = urlencode({"important_news_only": "false"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": len(expected_results), "next": None, "previous": None, "results": expected_results}, ) params = urlencode({"important_news_only": "else"}) - response = self.client.get(f"{self.url}?{params}", **self.extra) + response = self.client.get(f"{self.url}?{params}") self.assertEqual( response.json(), {"count": len(expected_results), "next": None, "previous": None, "results": expected_results}, diff --git a/backend/api/tests/views/test_views_performance.py b/backend/api/tests/views/test_views_performance.py index 622f17172..5756a77b9 100644 --- a/backend/api/tests/views/test_views_performance.py +++ b/backend/api/tests/views/test_views_performance.py @@ -35,7 +35,6 @@ def setUp(self): self.data_sample = factory.create_datasample([self.data_manager]) self.compute_plan = factory.create_computeplan() - self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} self.url = reverse("api:compute_plan_perf-list", args=[self.compute_plan.key]) self.metric = factory.create_function( @@ -114,7 +113,7 @@ def tearDown(self): def test_performance_list_empty(self): Performance.objects.all().delete() - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual( response.json(), { @@ -127,7 +126,7 @@ def test_performance_list_empty(self): ) def test_performance_list(self): - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual( response.json(), { @@ -151,9 +150,8 @@ def setUp(self): factory.create_computeplan(status=ComputePlan.Status.PLAN_STATUS_DONE), ] - self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} self.url = reverse("api:performance-list") - self.export_extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "*/*"} + self.export_url = reverse("api:performance-export") self.metrics = [ @@ -287,18 +285,18 @@ def tearDown(self): shutil.rmtree(MEDIA_ROOT, ignore_errors=True) def test_performance_view(self): - response = self.client.get(self.url, **self.extra) + response = self.client.get(self.url) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_performance_export(self): - response = self.client.get(self.export_url, **self.export_extra) + response = self.client.get(self.export_url) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(list(response.streaming_content)), len(self.expected_results) + 1) def test_performance_export_with_metadata(self): metadata = "epochs,hidden_sizes,last_hidden_sizes" params = urlencode({"metadata_columns": metadata}) - response = self.client.get(f"{self.export_url}?{params}", **self.export_extra) + response = self.client.get(f"{self.export_url}?{params}") content_list = list(response.streaming_content) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertTrue(metadata in str(content_list[0])) @@ -308,7 +306,7 @@ def test_performance_export_filter(self): """Filter performance on cp key.""" key = self.compute_plans[0].key params = urlencode({"key": key}) - response = self.client.get(f"{self.export_url}?{params}", **self.export_extra) + response = self.client.get(f"{self.export_url}?{params}") content_list = list(response.streaming_content) self.assertEqual(len(content_list), 4) self.assertTrue(str(self.compute_plans[0].key) in str(content_list[1])) @@ -318,7 +316,7 @@ def test_performance_export_filter_in(self): key_0 = self.compute_plans[0].key key_1 = self.compute_plans[1].key params = urlencode({"key": ",".join([str(key_0), str(key_1)])}) - response = self.client.get(f"{self.export_url}?{params}", **self.export_extra) + response = self.client.get(f"{self.export_url}?{params}") content_list = list(response.streaming_content) self.assertEqual(len(content_list), len(self.expected_results) + 1) @@ -328,7 +326,7 @@ def test_performance_export_filter_and(self): key_1 = self.compute_plans[1].key status = ComputePlan.Status.PLAN_STATUS_DOING params = urlencode({"key": ",".join([str(key_0), str(key_1)]), "status": status}) - response = self.client.get(f"{self.export_url}?{params}", **self.export_extra) + response = self.client.get(f"{self.export_url}?{params}") content_list = list(response.streaming_content) self.assertEqual(len(content_list), 4) self.assertTrue(status in str(content_list[1])) @@ -352,5 +350,5 @@ def test_n_plus_one_queries_performance_list(authenticated_client, create_comput with utils.CaptureQueriesContext(connection) as query: authenticated_client.get(url) query_task_with_perf = len(query.captured_queries) - assert query_task_with_perf < 11 + assert query_task_with_perf < 12 assert query_task_with_perf - query_tasks_empty < 3 diff --git a/backend/api/tests/views/test_views_task_profiling.py b/backend/api/tests/views/test_views_task_profiling.py index 36230ed44..64fc94f14 100644 --- a/backend/api/tests/views/test_views_task_profiling.py +++ b/backend/api/tests/views/test_views_task_profiling.py @@ -43,21 +43,20 @@ def setUp(self) -> None: ] def test_task_profiling_list_success(self): - response = self.client.get(TASK_PROFILING_LIST_URL, **EXTRA) + response = self.client.get(TASK_PROFILING_LIST_URL) self.assertEqual( response.json(), {"count": len(self.expected_results), "next": None, "previous": None, "results": self.expected_results}, ) def test_task_profiling_list_wrong_channel(self): - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "yourchannel", "HTTP_ACCEPT": "application/json;version=0.0"} - response = self.client.get(TASK_PROFILING_LIST_URL, **extra) + self.client.channel = "yourchannel" + response = self.client.get(TASK_PROFILING_LIST_URL) self.assertEqual(response.json(), {"count": 0, "next": None, "previous": None, "results": []}) def test_task_profiling_retrieve_success(self): response = self.client.get( - reverse("api:task_profiling-detail", args=[self.expected_results[0]["compute_task_key"]]), - **EXTRA, + reverse("api:task_profiling-detail", args=[self.expected_results[0]["compute_task_key"]]) ) self.assertEqual(response.json(), self.expected_results[0]) @@ -66,9 +65,7 @@ def test_task_profiling_create_bad_client(self): cp = factory.create_computeplan() task = factory.create_computetask(compute_plan=cp, function=function) - response = self.client.post( - TASK_PROFILING_LIST_URL, {"compute_task_key": str(task.key), "channel": CHANNEL}, **EXTRA - ) + response = self.client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(task.key), "channel": CHANNEL}) self.assertEqual(response.status_code, 403) @@ -81,13 +78,11 @@ def test_task_profiling_create_success(self): cp = factory.create_computeplan() task = factory.create_computetask(compute_plan=cp, function=function) - response = self.client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(task.key)}, **EXTRA) + response = self.client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(task.key)}) self.assertEqual(response.status_code, 201) step_url = reverse("api:step-list", args=[str(task.key)]) - response = self.client.post( - step_url, {"step": "custom_step", "duration": datetime.timedelta(seconds=20)}, **EXTRA - ) + response = self.client.post(step_url, {"step": "custom_step", "duration": datetime.timedelta(seconds=20)}) self.assertEqual(response.status_code, 200) expected_result = [ @@ -109,10 +104,10 @@ def test_already_exist_task_profiling(self): cp = factory.create_computeplan() task = factory.create_computetask(compute_plan=cp, function=function) - response = self.client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(task.key)}, **EXTRA) + response = self.client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(task.key)}) self.assertEqual(response.status_code, status.HTTP_201_CREATED) - response = self.client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(task.key)}, **EXTRA) + response = self.client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(task.key)}) self.assertEqual(response.status_code, status.HTTP_409_CONFLICT) @@ -121,14 +116,10 @@ def test_already_exist_task_profiling(self): def test_task_profiling_post_duplicate(authenticated_backend_client, create_compute_plan, create_compute_task): compute_plan = create_compute_plan() compute_task = create_compute_task(compute_plan) - response = authenticated_backend_client.post( - TASK_PROFILING_LIST_URL, {"compute_task_key": str(compute_task.key)}, **EXTRA - ) + response = authenticated_backend_client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(compute_task.key)}) assert response.status_code == status.HTTP_201_CREATED - response = authenticated_backend_client.post( - TASK_PROFILING_LIST_URL, {"compute_task_key": str(compute_task.key)}, **EXTRA - ) + response = authenticated_backend_client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(compute_task.key)}) assert response.status_code == status.HTTP_409_CONFLICT @@ -138,7 +129,7 @@ def test_task_profiling_update_datetime(authenticated_backend_client, create_com compute_plan = create_compute_plan() compute_task = create_compute_task(compute_plan) - authenticated_backend_client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(compute_task.key)}, **EXTRA) + authenticated_backend_client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(compute_task.key)}) task_profiling = compute_task.task_profiling task_profiling.refresh_from_db() previous_datetime = task_profiling.creation_date @@ -159,13 +150,13 @@ def test_task_profiling_add_step_no_datetime_change( compute_plan = create_compute_plan() compute_task = create_compute_task(compute_plan) - authenticated_backend_client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(compute_task.key)}, **EXTRA) + authenticated_backend_client.post(TASK_PROFILING_LIST_URL, {"compute_task_key": str(compute_task.key)}) task_profiling = compute_task.task_profiling task_profiling.refresh_from_db() previous_datetime = task_profiling.creation_date step_url = reverse("api:step-list", args=[str(compute_task.key)]) - authenticated_backend_client.post(step_url, {"compute_task_key": str(compute_task.key)}, **EXTRA) + authenticated_backend_client.post(step_url, {"compute_task_key": str(compute_task.key)}) task_profiling.refresh_from_db() new_datetime = task_profiling.creation_date assert new_datetime == previous_datetime @@ -179,7 +170,6 @@ class TaskProfilingViewTestsOtherBackend(APITestCase): client_class = AuthenticatedBackendClient def setUp(self) -> None: - self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": CHANNEL, "HTTP_ACCEPT": "application/json;version=0.0"} self.url = reverse("api:task_profiling-list") def test_task_profiling_create_fail_other_backend(self): @@ -187,5 +177,5 @@ def test_task_profiling_create_fail_other_backend(self): cp = factory.create_computeplan() task = factory.create_computetask(compute_plan=cp, function=function) - response = self.client.post(self.url, {"compute_task_key": str(task.key)}, **self.extra) + response = self.client.post(self.url, {"compute_task_key": str(task.key)}) self.assertEqual(response.status_code, 403) diff --git a/backend/api/tests/views/test_views_token.py b/backend/api/tests/views/test_views_token.py index c60132f62..cf14f19cb 100644 --- a/backend/api/tests/views/test_views_token.py +++ b/backend/api/tests/views/test_views_token.py @@ -104,10 +104,9 @@ def test_delete_token_other_user(authenticated_client): @pytest.mark.django_db def test_token_creation_post(authenticated_client): authenticated_client.create_user() - extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} payload = {"expires_at": "2023-07-14T11:55:36.509Z", "note": "gfyqgbs"} url = "/api-token/" - response = authenticated_client.post(url, payload, **extra) + response = authenticated_client.post(url, payload) assert response.status_code == status.HTTP_200_OK tokens_count = BearerToken.objects.count() diff --git a/backend/api/views/model.py b/backend/api/views/model.py index 02d02a778..6a051da5a 100644 --- a/backend/api/views/model.py +++ b/backend/api/views/model.py @@ -13,7 +13,6 @@ from rest_framework.authentication import BasicAuthentication from rest_framework.decorators import action from rest_framework.filters import OrderingFilter -from rest_framework.permissions import IsAuthenticated from rest_framework.settings import api_settings from rest_framework.viewsets import GenericViewSet @@ -27,6 +26,7 @@ from api.views.utils import get_channel_name from api.views.utils import if_true from libs.pagination import DefaultPageNumberPagination +from libs.permissions import IsAuthorized from organization.authentication import OrganizationUser from substrapp.models import Model as ModelFiles from substrapp.utils import get_owner @@ -93,7 +93,7 @@ class ModelViewSet(mixins.CreateModelMixin, mixins.RetrieveModelMixin, mixins.Li filterset_class = ModelFilter authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES + [BasicAuthentication] - permission_classes = [IsAuthenticated, IsCurrentBackendOrReadOnly] + permission_classes = [IsAuthorized, IsCurrentBackendOrReadOnly] def get_queryset(self): return Model.objects.filter(channel=get_channel_name(self.request)) diff --git a/backend/api/views/task_profiling.py b/backend/api/views/task_profiling.py index 27992d22f..72fe2a15a 100644 --- a/backend/api/views/task_profiling.py +++ b/backend/api/views/task_profiling.py @@ -8,7 +8,6 @@ from rest_framework import mixins from rest_framework import status from rest_framework.authentication import BasicAuthentication -from rest_framework.permissions import IsAuthenticated from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings @@ -20,6 +19,7 @@ from api.views.utils import IsCurrentBackendOrReadOnly from api.views.utils import get_channel_name from libs.pagination import LargePageNumberPagination +from libs.permissions import IsAuthorized logger = structlog.get_logger(__name__) @@ -35,7 +35,7 @@ class TaskProfilingViewSet( serializer_class = TaskProfilingSerializer pagination_class = LargePageNumberPagination authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES + [BasicAuthentication] - permission_classes = [IsAuthenticated, IsCurrentBackendOrReadOnly] + permission_classes = [IsAuthorized, IsCurrentBackendOrReadOnly] def get_queryset(self) -> QuerySet[TaskProfiling]: return TaskProfiling.objects.filter(compute_task__channel=get_channel_name(self.request)) @@ -56,7 +56,7 @@ def perform_update(self, serializer): class TaskProfilingStepViewSet(mixins.CreateModelMixin, GenericViewSet): serializer_class = ProfilingStepSerializer authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES + [BasicAuthentication] - permission_classes = [IsAuthenticated, IsCurrentBackendOrReadOnly] + permission_classes = [IsAuthorized, IsCurrentBackendOrReadOnly] def create(self, request: Request, *args: Any, **kwargs: Any) -> Response: task_profile_pk = kwargs["task_profiling_pk"] diff --git a/backend/api/views/utils.py b/backend/api/views/utils.py index 935f40f66..a5cf266d5 100644 --- a/backend/api/views/utils.py +++ b/backend/api/views/utils.py @@ -9,7 +9,6 @@ from rest_framework.authentication import BasicAuthentication from rest_framework.permissions import SAFE_METHODS from rest_framework.permissions import BasePermission -from rest_framework.permissions import IsAuthenticated from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings @@ -17,6 +16,7 @@ from api.errors import AssetPermissionError from api.errors import BadRequestError +from libs.permissions import IsAuthorized from organization.authentication import OrganizationUser from substrapp.clients import organization as organization_client from substrapp.storages.minio import MinioStorage @@ -60,7 +60,7 @@ def set_headers(self, filelike): class PermissionMixin(object): authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES + [BasicAuthentication] - permission_classes = [IsAuthenticated] + permission_classes = [IsAuthorized] def check_access(self, channel_name: str, user, asset, is_proxied_request: bool) -> None: """Returns true if API consumer is allowed to access data. diff --git a/backend/backend/settings/deps/restframework.py b/backend/backend/settings/deps/restframework.py index 3425c11d0..542d7f686 100644 --- a/backend/backend/settings/deps/restframework.py +++ b/backend/backend/settings/deps/restframework.py @@ -13,7 +13,7 @@ "users.authentication.BearerTokenAuthentication", # Bearer token for SDK ], "DEFAULT_PERMISSION_CLASSES": [ - "rest_framework.permissions.IsAuthenticated", + "libs.permissions.IsAuthorized", ], "UNICODE_JSON": False, "ALLOWED_VERSIONS": ("0.0",), @@ -36,5 +36,5 @@ } SPECTACULAR_SETTINGS = { - "SERVE_PERMISSIONS": ["rest_framework.permissions.IsAuthenticated"], + "SERVE_PERMISSIONS": ["libs.permissions.IsAuthorized"], } diff --git a/backend/backend/settings/mods/oidc.py b/backend/backend/settings/mods/oidc.py index eaee681ba..f1e31a8dc 100644 --- a/backend/backend/settings/mods/oidc.py +++ b/backend/backend/settings/mods/oidc.py @@ -34,10 +34,13 @@ OIDC["USERS"]["APPEND_DOMAIN"] = to_bool(os.environ.get("OIDC_USERS_APPEND_DOMAIN", "false")) OIDC["USERS"]["DEFAULT_CHANNEL"] = os.environ.get("OIDC_USERS_DEFAULT_CHANNEL") - if not OIDC["USERS"]["DEFAULT_CHANNEL"]: - raise Exception("No default channel provided for OIDC users") - if OIDC["USERS"]["DEFAULT_CHANNEL"] not in ledger.LEDGER_CHANNELS: - raise Exception(f"Channel {OIDC['USERS']['DEFAULT_CHANNEL']} does not exist") + OIDC["USERS"]["MUST_BE_APPROVED"] = to_bool(os.environ.get("OIDC_USERS_MUST_BE_APPROVED", "false")) + if OIDC["USERS"]["DEFAULT_CHANNEL"] and OIDC["USERS"]["MUST_BE_APPROVED"]: + raise Exception("Both 'default channel' and 'user must be approved' options are activated") + if not (OIDC["USERS"]["DEFAULT_CHANNEL"] or OIDC["USERS"]["MUST_BE_APPROVED"]): + raise Exception( + "At least one option between 'default channel' and 'user must be approved' needs to be activated" + ) OIDC["USERS"]["LOGIN_VALIDITY_DURATION"] = int( os.environ.get("OIDC_USERS_LOGIN_VALIDITY_DURATION", 60 * 60) ) # seconds diff --git a/backend/backend/views.py b/backend/backend/views.py index db0bb7f69..3c3971358 100644 --- a/backend/backend/views.py +++ b/backend/backend/views.py @@ -5,12 +5,12 @@ from rest_framework import status from rest_framework.authtoken.views import ObtainAuthToken as DRFObtainAuthToken from rest_framework.permissions import AllowAny -from rest_framework.permissions import IsAuthenticated from rest_framework.throttling import AnonRateThrottle from rest_framework.views import APIView from api.views.utils import ApiResponse from api.views.utils import get_channel_name +from libs.permissions import IsAuthorized from libs.user_login_throttle import UserLoginThrottle from substrapp.orchestrator import get_orchestrator_client from substrapp.utils import get_owner @@ -46,7 +46,7 @@ class AuthenticatedBearerToken(DRFObtainAuthToken): get a Bearer token if you're already authenticated somehow """ - permission_classes = [IsAuthenticated] + permission_classes = [IsAuthorized] def post(self, request, *args, **kwargs): s = BearerTokenSerializer(data=request.data) @@ -60,7 +60,7 @@ class ActiveBearerTokens(APIView): list Bearer tokens for a user """ - permission_classes = [IsAuthenticated] + permission_classes = [IsAuthorized] def get(self, request, *args, **kwargs): tokens = [ diff --git a/backend/libs/permissions.py b/backend/libs/permissions.py new file mode 100644 index 000000000..c0e9a223a --- /dev/null +++ b/backend/libs/permissions.py @@ -0,0 +1,14 @@ +from rest_framework import permissions + +from organization.authentication import OrganizationUser + + +class IsAuthorized(permissions.BasePermission): + def has_permission(self, request, view): + return bool( + ( + request.user + and request.user.is_authenticated + and (hasattr(request.user, "channel") or isinstance(request.user, OrganizationUser)) + ) + ) diff --git a/backend/organization/tests/views/test_views_organization.py b/backend/organization/tests/views/test_views_organization.py index 8a07824c7..3dc077b83 100644 --- a/backend/organization/tests/views/test_views_organization.py +++ b/backend/organization/tests/views/test_views_organization.py @@ -22,8 +22,6 @@ def setUp(self): if not os.path.exists(MEDIA_ROOT): os.makedirs(MEDIA_ROOT) - self.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} - self.logger = logging.getLogger("django.request") self.previous_level = self.logger.getEffectiveLevel() self.logger.setLevel(logging.ERROR) @@ -41,7 +39,7 @@ def tearDown(self): def test_organization_list_success(self): url = reverse("organization:organization-list") with mock.patch("api.serializers.organization.get_owner", return_value="foo"): - response = self.client.get(url, **self.extra) + response = self.client.get(url) self.assertEqual( response.json(), [ diff --git a/backend/users/authentication.py b/backend/users/authentication.py index ac8495701..5e844e1f4 100644 --- a/backend/users/authentication.py +++ b/backend/users/authentication.py @@ -137,8 +137,8 @@ def create_user(self, claims): username = utils.oidc.generate_username(email, issuer, subject) user = self.UserModel.objects.create_user(username, email=email) - - UserChannel.objects.create(user=user, channel_name=settings.OIDC["USERS"]["DEFAULT_CHANNEL"]) + if settings.OIDC["USERS"]["DEFAULT_CHANNEL"]: + UserChannel.objects.create(user=user, channel_name=settings.OIDC["USERS"]["DEFAULT_CHANNEL"]) UserOidcInfo.objects.create( user=user, openid_issuer=issuer, openid_subject=subject, valid_until=_get_user_valid_until() ) diff --git a/backend/users/serializers/user.py b/backend/users/serializers/user.py index 3c43cd32d..951c5db03 100644 --- a/backend/users/serializers/user.py +++ b/backend/users/serializers/user.py @@ -56,3 +56,9 @@ def get_role(self, instance): def get_is_external_user(self, instance): return hasattr(instance, "oidc_info") + + +class UserAwaitingApprovalSerializer(serializers.ModelSerializer): + class Meta: + model = User + fields = ["username", "email"] diff --git a/backend/users/tests/test_user.py b/backend/users/tests/test_user.py index 115a38439..5dd46b358 100644 --- a/backend/users/tests/test_user.py +++ b/backend/users/tests/test_user.py @@ -105,7 +105,6 @@ def test_oidc_username_generation(self): class TestUserEndpoints: url = None - extra = None @pytest.fixture(autouse=True) def use_dummy_channels(self, settings): @@ -117,7 +116,6 @@ def setup_class(cls): usually contains tests). """ cls.url = reverse("user:users-list") - cls.extra = {"HTTP_SUBSTRA_CHANNEL_NAME": "mychannel", "HTTP_ACCEPT": "application/json;version=0.0"} cls.channel = "mychannel" @pytest.mark.django_db @@ -193,7 +191,9 @@ def test_user_create_role_unknown(self): @pytest.mark.django_db def test_retrieve_user(self): url = reverse("user:users-detail", args=["substra"]) - response = AuthenticatedClient(channel=self.channel).get(url, **self.extra) + response = AuthenticatedClient(channel=self.channel).get( + url, + ) assert response.status_code == status.HTTP_200_OK @pytest.mark.django_db @@ -322,7 +322,9 @@ def test_delete_user(self): @pytest.mark.django_db def test_list_users(self): - response = AuthenticatedClient(channel=self.channel).get(self.url, **self.extra) + response = AuthenticatedClient(channel=self.channel).get( + self.url, + ) assert response.status_code == status.HTTP_200_OK assert "substra" == response.json()["results"][0]["username"] @@ -338,21 +340,29 @@ def test_filter_user_role(self): assert response.status_code == status.HTTP_201_CREATED # list all users (no filter) - response = AuthenticatedClient(channel=self.channel).get(self.url, **self.extra) + response = AuthenticatedClient(channel=self.channel).get( + self.url, + ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["count"] == 2 assert {user["role"] for user in data["results"]} == {"USER", "ADMIN"} # list only admin users - response = AuthenticatedClient(channel=self.channel).get(self.url, data={"role": "ADMIN"}, **self.extra) + response = AuthenticatedClient(channel=self.channel).get( + self.url, + data={"role": "ADMIN"}, + ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["count"] == 1 assert data["results"][0]["role"] == "ADMIN" # list only non admin users - response = AuthenticatedClient(channel=self.channel).get(self.url, data={"role": "USER"}, **self.extra) + response = AuthenticatedClient(channel=self.channel).get( + self.url, + data={"role": "USER"}, + ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["count"] == 1 diff --git a/backend/users/urls.py b/backend/users/urls.py index da8106bdc..2877ef548 100644 --- a/backend/users/urls.py +++ b/backend/users/urls.py @@ -9,10 +9,12 @@ # Create a router and register our viewsets with it. from users.views import AuthenticationViewSet from users.views import UserViewSet +from users.views.user import UserAwaitingApprovalViewSet router = DefaultRouter() router.register(r"me", AuthenticationViewSet, basename="me") router.register(r"users", UserViewSet, basename="users") +router.register(r"users-awaiting-approval", UserAwaitingApprovalViewSet, basename="users-awaiting-approval") urlpatterns = [ path("", include(router.urls)), diff --git a/backend/users/views/user.py b/backend/users/views/user.py index 78a355bfa..0d26aeced 100644 --- a/backend/users/views/user.py +++ b/backend/users/views/user.py @@ -1,9 +1,11 @@ import datetime +import json from urllib.parse import unquote import jwt from django.conf import settings from django.contrib.auth import get_user_model +from django.contrib.auth.models import User from django.contrib.auth.password_validation import validate_password from django.core.exceptions import ValidationError as djangoValidationError from django.utils.encoding import force_str @@ -27,7 +29,9 @@ from api.views.utils import ApiResponse from api.views.utils import get_channel_name from libs.pagination import DefaultPageNumberPagination +from libs.permissions import IsAuthorized from users.models.user_channel import UserChannel +from users.serializers.user import UserAwaitingApprovalSerializer from users.serializers.user import UserSerializer @@ -88,6 +92,11 @@ def has_permission(self, request, view): return request.user.channel.role == UserChannel.Role.ADMIN +class IsAdmin(permissions.BasePermission): + def has_permission(self, request, view): + return request.user.channel.role == UserChannel.Role.ADMIN + + class IsSelf(permissions.BasePermission): def has_permission(self, request, view): user = view.get_object() @@ -119,7 +128,7 @@ class UserViewSet( ordering = ["username"] filter_backends = [OrderingFilter, MatchFilter, DjangoFilterBackend] lookup_field = "username" - permission_classes = [permissions.IsAuthenticated, IsAdminOrReadOnly] + permission_classes = [IsAuthorized, IsAdminOrReadOnly] search_fields = ["username"] filterset_class = UserFilter @@ -232,3 +241,49 @@ def generate_reset_password_token(self, request, *args, **kwargs): data = {"reset_password_token": jwt_token} return ApiResponse(data=data, status=status.HTTP_200_OK, headers=self.get_success_headers({})) + + +class UserAwaitingApprovalViewSet( + GenericViewSet, + mixins.ListModelMixin, +): + user_model = get_user_model() + permission_classes = [IsAdmin] + pagination_class = DefaultPageNumberPagination + serializer_class = UserAwaitingApprovalSerializer + ordering_fields = ["username"] + ordering = ["username"] + filter_backends = [OrderingFilter, MatchFilter, DjangoFilterBackend] + lookup_field = "username" + search_fields = ["username"] + filterset_class = UserFilter + + def get_queryset(self): + return self.user_model.objects.filter(channel=None).exclude(username="deleted") + + def delete(self, request, *args, **kwargs): + try: + user = User.objects.get(username=request.GET.get("username")) + user.delete() + return ApiResponse(data={"message": "User removed"}, status=status.HTTP_200_OK) + except User.DoesNotExist or User.MultipleObjectsReturned: + pass + return ApiResponse(data={"message": "User not found"}, status=status.HTTP_404_NOT_FOUND) + + def put(self, request, *args, **kwargs): + d = json.loads(request.body) + try: + user = User.objects.get(username=request.GET.get("username")) + except User.DoesNotExist: + return ApiResponse(data={"message": "User not found"}, status=status.HTTP_404_NOT_FOUND) + except User.MultipleObjectsReturned: + return ApiResponse( + data={"message": "Multiple instance of the same user found"}, status=status.HTTP_409_CONFLICT + ) + + channel_name = get_channel_name(request) + channel_name = get_channel_name(request) + role = _validate_role(d.get("role")) + UserChannel.objects.create(channel_name=channel_name, role=role, user=user) + data = UserSerializer(instance=user).data + return ApiResponse(data=data, status=status.HTTP_200_OK) diff --git a/charts/substra-backend/CHANGELOG.md b/charts/substra-backend/CHANGELOG.md index debd7a68e..aefbf89af 100644 --- a/charts/substra-backend/CHANGELOG.md +++ b/charts/substra-backend/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [Unreleased] + +### Added + +- New `requireApproval` field, that triggers the User Awaiting Approval functionality ([#680](https://github.com/Substra/substra-backend/pull/680)) + ## [22.5.2] - 2023-06-27 ### Changed diff --git a/charts/substra-backend/Chart.yaml b/charts/substra-backend/Chart.yaml index e3d976288..34aaa4d36 100644 --- a/charts/substra-backend/Chart.yaml +++ b/charts/substra-backend/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v2 name: substra-backend home: https://github.com/Substra -version: 22.5.2 +version: 22.6.0 appVersion: 0.39.0 kubeVersion: ">= 1.19.0-0" description: Main package for Substra diff --git a/charts/substra-backend/README.md b/charts/substra-backend/README.md index 2038725fe..9f82b6fa1 100644 --- a/charts/substra-backend/README.md +++ b/charts/substra-backend/README.md @@ -284,6 +284,7 @@ Else, you must strike a balance: longer durations are more convenient, but risk | `oidc.users.useRefreshToken` | Attempt to refresh user info in the background. | `true` | | `oidc.users.loginValidityDuration` | How long a user account is valid after an OIDC login, in seconds | `3600` | | `oidc.users.channel` | The channel to assign OIDC users to (mandatory) | `nil` | +| `oidc.users.requireApproval` | Activate the user approval. A user using the OIDC login for the first time will need approval from an admin. It is not compatible with default channel | `false` | | `oidc.users.appendDomain` | As usernames are assigned based on e-mail address, whether to suffix user names with the email domain (john.doe@example.com would then be `john-doe-example`) | `false` | ### Database connection settings diff --git a/charts/substra-backend/templates/configmap-oidc.yaml b/charts/substra-backend/templates/configmap-oidc.yaml index a45420fcf..53f3cd916 100644 --- a/charts/substra-backend/templates/configmap-oidc.yaml +++ b/charts/substra-backend/templates/configmap-oidc.yaml @@ -8,6 +8,7 @@ data: OIDC_ENABLED: {{ .Values.oidc.enabled | quote }} OIDC_USERS_APPEND_DOMAIN: {{ .Values.oidc.users.appendDomain | quote }} OIDC_USERS_DEFAULT_CHANNEL: {{ .Values.oidc.users.channel | default "" | quote }} + OIDC_USERS_MUST_BE_APPROVED: {{ .Values.oidc.users.requireApproval | default "" | quote }} OIDC_USERS_LOGIN_VALIDITY_DURATION: {{ .Values.oidc.users.loginValidityDuration | default "" | quote }} OIDC_USERS_USE_REFRESH_TOKEN: {{ .Values.oidc.users.useRefreshToken | quote }} OIDC_RP_SIGN_ALGO: {{ .Values.oidc.signAlgo | default "" | quote }} diff --git a/charts/substra-backend/values.yaml b/charts/substra-backend/values.yaml index 7b6ccb4b2..3db4dde82 100644 --- a/charts/substra-backend/values.yaml +++ b/charts/substra-backend/values.yaml @@ -12,7 +12,6 @@ organizationName: owkin ## DataSampleStorageInServerMedia: false - privateCa: ## @param privateCa.enabled Run the init container injecting the private CA certificate ## @@ -427,7 +426,6 @@ schedulerWorker: runAsGroup: 1001 fsGroup: 1001 - ## @section Celery task scheduler settings scheduler: ## @param scheduler.enabled Enable scheduler service @@ -654,7 +652,7 @@ addAccountOperator: ## @descriptionStart Uses the authorization code flow. ## ## By default, `oidc.users.useRefreshToken` is enabled. This makes sure the user still has an account at the identity provider, without damaging user experience. -## +## ## The way it works is that a OIDC user that spent more than `oidc.users.loginValidityDuration` since their last login must undergo a refresh to keep using their access tokens -- but these refreshes are done in the background if `oidc.users.useRefreshToken` is enabled (otherwise a new manual authorization is necessary). The identity provider must support `offline_access` and configuration discovery. ## ## With this option active, you can set `oidc.users.loginValidityDuration` to low values (minutes). @@ -666,10 +664,10 @@ oidc: ## @param oidc.enabled Whether to enable OIDC authentication ## enabled: false - + ## @param oidc.clientSecretName The name of a secret containing the keys `OIDC_RP_CLIENT_ID` and `OIDC_RP_CLIENT_SECRET` (client ID and secret, typically issued by the provider) clientSecretName: null - + provider: ## @param oidc.provider.url The identity provider URL (with scheme). url: null @@ -683,10 +681,10 @@ oidc: token: null ## @param oidc.provider.endpoints.user Typically https://provider/me user: null - + ## @param oidc.provider.jwksUri Typically https://provider/jwks. Only required for public-key-based signing algorithms. If not given, read from `/.well-known/openid-configuration` at startup. jwksUri: null - + ## @param oidc.signAlgo Either RS256 or HS256 signAlgo: RS256 users: @@ -696,6 +694,8 @@ oidc: loginValidityDuration: 3600 ## @param oidc.users.channel The channel to assign OIDC users to (mandatory) channel: null + ## @param oidc.users.requireApproval Activate the user approval. A user using the OIDC login for the first time will need approval from an admin. It is not compatible with default channel + requireApproval: false ## @param oidc.users.appendDomain As usernames are assigned based on e-mail address, whether to suffix user names with the email domain (john.doe@example.com would then be `john-doe-example`) appendDomain: false @@ -708,17 +708,16 @@ database: username: &psql-username postgres ## @param database.auth.password what password to use for connecting password: &psql-password postgres - + ## @param database.auth.credentialsSecretName An alternative to giving username and password; must have `DATABASE_USERNAME` and `DATABASE_PASSWORD` keys. ## credentialsSecretName: null - + ## @param database.host Hostname of the database to connect to (defaults to local) host: null ## @param database.port Port of an external database to connect to port: 5432 - ## @section PostgreSQL settings ## @descriptionStart ## Database included as a subchart used by default. diff --git a/docs/settings.md b/docs/settings.md index 8926ecbb6..f173b36cf 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -132,6 +132,7 @@ Accepted true values for `bool` are: `1`, `ON`, `On`, `on`, `T`, `t`, `TRUE`, `T | bool | `OIDC_USERS_APPEND_DOMAIN` | `false` | | | string | `OIDC_USERS_DEFAULT_CHANNEL` | nil | | | int | `OIDC_USERS_LOGIN_VALIDITY_DURATION` | `3600` (`60 * 60`) | seconds | +| bool | `OIDC_USERS_MUST_BE_APPROVED` | `false` | | | bool | `OIDC_USERS_USE_REFRESH_TOKEN` | `false` | | ## CORS settings