From 03d5520b67e34e5267dac603d4fcc8846e172452 Mon Sep 17 00:00:00 2001 From: Marcos Prieto Date: Thu, 26 Sep 2024 18:35:34 +0200 Subject: [PATCH] Take an explicit application instance on the OAuth1 client --- lms/services/__init__.py | 2 +- lms/services/lti_grading/_v11.py | 10 ++++-- lms/services/lti_grading/factory.py | 1 + lms/services/oauth1.py | 19 ++++++------ lms/views/lti/deep_linking.py | 9 +++--- tests/unit/lms/services/__init___test.py | 2 -- .../lms/services/lti_grading/_v11_test.py | 6 ++-- .../lms/services/lti_grading/factory_test.py | 5 ++- tests/unit/lms/services/oauth1_test.py | 31 ++++++++++++------- tests/unit/lms/views/lti/deep_linking_test.py | 3 +- 10 files changed, 53 insertions(+), 35 deletions(-) diff --git a/lms/services/__init__.py b/lms/services/__init__.py index 3bae65e4cf..6f735b4450 100644 --- a/lms/services/__init__.py +++ b/lms/services/__init__.py @@ -82,7 +82,7 @@ def includeme(config): "lms.services.group_info.GroupInfoService", name="group_info" ) config.register_service_factory("lms.services.lti_h.LTIHService", name="lti_h") - config.register_service_factory("lms.services.oauth1.OAuth1Service", name="oauth1") + config.register_service_factory("lms.services.oauth1.factory", name="oauth1") config.register_service_factory( "lms.services.course.course_service_factory", name="course" ) diff --git a/lms/services/lti_grading/_v11.py b/lms/services/lti_grading/_v11.py index 226827490a..eb74c0fc8b 100644 --- a/lms/services/lti_grading/_v11.py +++ b/lms/services/lti_grading/_v11.py @@ -2,6 +2,7 @@ import xmltodict +from lms.models import ApplicationInstance from lms.services.exceptions import ExternalRequestError, StudentNotInCourse from lms.services.http import HTTPService from lms.services.lti_grading.interface import GradingResult, LTIGradingService @@ -11,11 +12,16 @@ class LTI11GradingService(LTIGradingService): # See: LTI1.1 Outcomes https://www.imsglobal.org/specs/ltiomv1p0/specification def __init__( - self, line_item_url, http_service: HTTPService, oauth1_service: OAuth1Service + self, + line_item_url, + http_service: HTTPService, + oauth1_service: OAuth1Service, + application_instance: ApplicationInstance, ): super().__init__(line_item_url, None) self.http_service = http_service self.oauth1_service = oauth1_service + self.application_instance = application_instance def read_result(self, grading_id) -> GradingResult: result = GradingResult(score=None, comment=None) @@ -73,7 +79,7 @@ def _send_request(self, request_body) -> dict: url=self.line_item_url, data=xml_body, headers={"Content-Type": "application/xml"}, - auth=self.oauth1_service.get_client(), + auth=self.oauth1_service.get_client(self.application_instance), ) except ExternalRequestError as err: err.message = "Error calling LTI Outcomes service" diff --git a/lms/services/lti_grading/factory.py b/lms/services/lti_grading/factory.py index 2ff5e538b9..4373751929 100644 --- a/lms/services/lti_grading/factory.py +++ b/lms/services/lti_grading/factory.py @@ -23,4 +23,5 @@ def service_factory(_context, request): line_item_url=request.parsed_params.get("lis_outcome_service_url"), http_service=request.find_service(name="http"), oauth1_service=request.find_service(name="oauth1"), + application_instance=request.lti_user.application_instance, ) diff --git a/lms/services/oauth1.py b/lms/services/oauth1.py index 3e2c754b84..3c50a3afc2 100644 --- a/lms/services/oauth1.py +++ b/lms/services/oauth1.py @@ -8,22 +8,19 @@ from oauthlib.oauth1.rfc5849 import signature from requests_oauthlib import OAuth1 +from lms.models import ApplicationInstance + class OAuth1Service: """Provides OAuth1 convenience functions.""" - def __init__(self, _context, request): - self._request = request - - def get_client(self) -> OAuth1: + def get_client(self, application_instance: ApplicationInstance) -> OAuth1: """ - Get an OAUth1 client that can be used to sign HTTP requests. + Get an OAUth1 client that can be used to sign HTTP requests for `application_instance`. To sign a request with the client pass it as the `auth` parameter to `requests.post()`. """ - application_instance = self._request.lti_user.application_instance - return OAuth1( client_key=application_instance.consumer_key, client_secret=application_instance.shared_secret, @@ -34,14 +31,12 @@ def get_client(self) -> OAuth1: force_include_body=True, ) - def sign(self, url: str, method: str, data: dict) -> dict: + def sign(self, application_instance, url: str, method: str, data: dict) -> dict: """ Sign data following the oauth1 spec. Useful when not using these values for a HTTP requests with the client from get_client. """ - application_instance = self._request.lti_user.application_instance - client_key = application_instance.consumer_key # Secret and token need to joined by "&". # We don't have a token but the trailing `&` is required @@ -81,3 +76,7 @@ def sign(self, url: str, method: str, data: dict) -> dict: payload["oauth_signature"] = digest return payload + + +def factory(_context, _request): + return OAuth1Service() diff --git a/lms/views/lti/deep_linking.py b/lms/views/lti/deep_linking.py index 73d00551d6..f8f77e91fe 100644 --- a/lms/views/lti/deep_linking.py +++ b/lms/views/lti/deep_linking.py @@ -239,8 +239,6 @@ def file_picker_to_form_fields_v11(self): if title := assignment_configuration.get("title"): content_item["title"] = title - oauth1_service = self.request.find_service(name="oauth1") - payload = { "content_items": json.dumps( { @@ -256,8 +254,11 @@ def file_picker_to_form_fields_v11(self): # An opaque value which should be returned by the TP in its response. payload["data"] = data - return oauth1_service.sign( - self.request.parsed_params["content_item_return_url"], "post", payload + return self.request.find_service(name="oauth1").sign( + self.request.lti_user.application_instance, + self.request.parsed_params["content_item_return_url"], + "post", + payload, ) @staticmethod diff --git a/tests/unit/lms/services/__init___test.py b/tests/unit/lms/services/__init___test.py index 87045ea3ae..9bd12f3fee 100644 --- a/tests/unit/lms/services/__init___test.py +++ b/tests/unit/lms/services/__init___test.py @@ -8,7 +8,6 @@ from lms.services.h_api import HAPI from lms.services.launch_verifier import LaunchVerifier from lms.services.lti_h import LTIHService -from lms.services.oauth1 import OAuth1Service class TestIncludeme: @@ -20,7 +19,6 @@ class TestIncludeme: ("grading_info", GradingInfoService), ("group_info", GroupInfoService), ("lti_h", LTIHService), - ("oauth1", OAuth1Service), ), ) def test_it_has_the_expected_service_by_name( diff --git a/tests/unit/lms/services/lti_grading/_v11_test.py b/tests/unit/lms/services/lti_grading/_v11_test.py index c15648b191..5261645bec 100644 --- a/tests/unit/lms/services/lti_grading/_v11_test.py +++ b/tests/unit/lms/services/lti_grading/_v11_test.py @@ -187,8 +187,10 @@ def svc_method(self, svc, request): return getattr(svc, request.param) @pytest.fixture - def svc(self, oauth1_service, http_service): - return LTI11GradingService(sentinel.service_url, http_service, oauth1_service) + def svc(self, oauth1_service, http_service, application_instance): + return LTI11GradingService( + sentinel.service_url, http_service, oauth1_service, application_instance + ) class GradingResponse(dict): diff --git a/tests/unit/lms/services/lti_grading/factory_test.py b/tests/unit/lms/services/lti_grading/factory_test.py index 97a7ef0954..5690e82873 100644 --- a/tests/unit/lms/services/lti_grading/factory_test.py +++ b/tests/unit/lms/services/lti_grading/factory_test.py @@ -14,7 +14,10 @@ def test_v11( svc = service_factory(sentinel.context, pyramid_request) LTI11GradingService.assert_called_once_with( - sentinel.grading_url, http_service, oauth1_service + sentinel.grading_url, + http_service, + oauth1_service, + pyramid_request.lti_user.application_instance, ) assert svc == LTI11GradingService.return_value diff --git a/tests/unit/lms/services/oauth1_test.py b/tests/unit/lms/services/oauth1_test.py index 7fefab9efe..39c947ca34 100644 --- a/tests/unit/lms/services/oauth1_test.py +++ b/tests/unit/lms/services/oauth1_test.py @@ -1,17 +1,17 @@ import json -from unittest import mock +from unittest.mock import sentinel import pytest from requests import Request -from lms.services.oauth1 import OAuth1Service +from lms.services.oauth1 import OAuth1Service, factory pytestmark = pytest.mark.usefixtures("application_instance_service") class TestOAuth1Service: def test_we_configure_OAuth1_correctly(self, service, OAuth1, application_instance): - service.get_client() + service.get_client(application_instance) OAuth1.assert_called_once_with( client_key=application_instance.consumer_key, @@ -26,7 +26,7 @@ def test_we_can_be_used_to_sign_a_request(self, service, application_instance): "POST", url="http://example.com", data={"param": "value"}, - auth=service.get_client(), + auth=service.get_client(application_instance), ) prepared_request = request.prepare() @@ -133,7 +133,7 @@ def test_sign( uuid.uuid4.return_value.hex = nonce datetime.now.return_value.timestamp.return_value = timestamp - result = service.sign(url, method, data) + result = service.sign(application_instance, url, method, data) assert result["oauth_signature_method"] == "HMAC-SHA1" assert result["oauth_nonce"] == nonce @@ -142,13 +142,8 @@ def test_sign( assert result["oauth_signature"] == signature @pytest.fixture - def service(self, context, pyramid_request): - return OAuth1Service(context, pyramid_request) - - @pytest.fixture - def context(self): - # We don't use context, so it doesn't matter what it is - return mock.sentinel.context + def service(self): + return OAuth1Service() @pytest.fixture def uuid(self, patch): @@ -161,3 +156,15 @@ def datetime(self, patch): @pytest.fixture def OAuth1(self, patch): return patch("lms.services.oauth1.OAuth1") + + +class TestFactory: + def test_it(self, pyramid_request, OAuth1Service): + service = factory(sentinel.context, pyramid_request) + + OAuth1Service.assert_called_once_with() + assert service == OAuth1Service.return_value + + @pytest.fixture + def OAuth1Service(self, patch): + return patch("lms.services.oauth1.OAuth1Service") diff --git a/tests/unit/lms/views/lti/deep_linking_test.py b/tests/unit/lms/views/lti/deep_linking_test.py index b6fd5b5c7b..53286dba2e 100644 --- a/tests/unit/lms/views/lti/deep_linking_test.py +++ b/tests/unit/lms/views/lti/deep_linking_test.py @@ -209,6 +209,7 @@ def test_it_for_v11( title, opaque_data_lti11, oauth1_service, + application_instance, ): misc_plugin.get_deeplinking_launch_url.return_value = "LAUNCH_URL" pyramid_request.parsed_params["opaque_data_lti11"] = opaque_data_lti11 @@ -252,7 +253,7 @@ def test_it_for_v11( expected_fields["data"] = opaque_data_lti11 oauth1_service.sign.assert_called_once_with( - sentinel.return_url, "post", expected_fields + application_instance, sentinel.return_url, "post", expected_fields ) assert fields == oauth1_service.sign.return_value