Skip to content

Commit

Permalink
Take an explicit application instance on the OAuth1 client
Browse files Browse the repository at this point in the history
  • Loading branch information
marcospri committed Sep 30, 2024
1 parent 649972a commit 03d5520
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 35 deletions.
2 changes: 1 addition & 1 deletion lms/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def includeme(config):
"lms.services.group_info.GroupInfoService", name="group_info"
)
config.register_service_factory("lms.services.lti_h.LTIHService", name="lti_h")
config.register_service_factory("lms.services.oauth1.OAuth1Service", name="oauth1")
config.register_service_factory("lms.services.oauth1.factory", name="oauth1")
config.register_service_factory(
"lms.services.course.course_service_factory", name="course"
)
Expand Down
10 changes: 8 additions & 2 deletions lms/services/lti_grading/_v11.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import xmltodict

from lms.models import ApplicationInstance
from lms.services.exceptions import ExternalRequestError, StudentNotInCourse
from lms.services.http import HTTPService
from lms.services.lti_grading.interface import GradingResult, LTIGradingService
Expand All @@ -11,11 +12,16 @@
class LTI11GradingService(LTIGradingService):
# See: LTI1.1 Outcomes https://www.imsglobal.org/specs/ltiomv1p0/specification
def __init__(
self, line_item_url, http_service: HTTPService, oauth1_service: OAuth1Service
self,
line_item_url,
http_service: HTTPService,
oauth1_service: OAuth1Service,
application_instance: ApplicationInstance,
):
super().__init__(line_item_url, None)
self.http_service = http_service
self.oauth1_service = oauth1_service
self.application_instance = application_instance

def read_result(self, grading_id) -> GradingResult:
result = GradingResult(score=None, comment=None)
Expand Down Expand Up @@ -73,7 +79,7 @@ def _send_request(self, request_body) -> dict:
url=self.line_item_url,
data=xml_body,
headers={"Content-Type": "application/xml"},
auth=self.oauth1_service.get_client(),
auth=self.oauth1_service.get_client(self.application_instance),
)
except ExternalRequestError as err:
err.message = "Error calling LTI Outcomes service"
Expand Down
1 change: 1 addition & 0 deletions lms/services/lti_grading/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ def service_factory(_context, request):
line_item_url=request.parsed_params.get("lis_outcome_service_url"),
http_service=request.find_service(name="http"),
oauth1_service=request.find_service(name="oauth1"),
application_instance=request.lti_user.application_instance,
)
19 changes: 9 additions & 10 deletions lms/services/oauth1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,19 @@
from oauthlib.oauth1.rfc5849 import signature
from requests_oauthlib import OAuth1

from lms.models import ApplicationInstance


class OAuth1Service:
"""Provides OAuth1 convenience functions."""

def __init__(self, _context, request):
self._request = request

def get_client(self) -> OAuth1:
def get_client(self, application_instance: ApplicationInstance) -> OAuth1:
"""
Get an OAUth1 client that can be used to sign HTTP requests.
Get an OAUth1 client that can be used to sign HTTP requests for `application_instance`.
To sign a request with the client pass it as the `auth` parameter to
`requests.post()`.
"""
application_instance = self._request.lti_user.application_instance

return OAuth1(
client_key=application_instance.consumer_key,
client_secret=application_instance.shared_secret,
Expand All @@ -34,14 +31,12 @@ def get_client(self) -> OAuth1:
force_include_body=True,
)

def sign(self, url: str, method: str, data: dict) -> dict:
def sign(self, application_instance, url: str, method: str, data: dict) -> dict:
"""
Sign data following the oauth1 spec.
Useful when not using these values for a HTTP requests with the client from get_client.
"""
application_instance = self._request.lti_user.application_instance

client_key = application_instance.consumer_key
# Secret and token need to joined by "&".
# We don't have a token but the trailing `&` is required
Expand Down Expand Up @@ -81,3 +76,7 @@ def sign(self, url: str, method: str, data: dict) -> dict:

payload["oauth_signature"] = digest
return payload


def factory(_context, _request):
return OAuth1Service()
9 changes: 5 additions & 4 deletions lms/views/lti/deep_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,6 @@ def file_picker_to_form_fields_v11(self):
if title := assignment_configuration.get("title"):
content_item["title"] = title

oauth1_service = self.request.find_service(name="oauth1")

payload = {
"content_items": json.dumps(
{
Expand All @@ -256,8 +254,11 @@ def file_picker_to_form_fields_v11(self):
# An opaque value which should be returned by the TP in its response.
payload["data"] = data

return oauth1_service.sign(
self.request.parsed_params["content_item_return_url"], "post", payload
return self.request.find_service(name="oauth1").sign(
self.request.lti_user.application_instance,
self.request.parsed_params["content_item_return_url"],
"post",
payload,
)

@staticmethod
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/lms/services/__init___test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from lms.services.h_api import HAPI
from lms.services.launch_verifier import LaunchVerifier
from lms.services.lti_h import LTIHService
from lms.services.oauth1 import OAuth1Service


class TestIncludeme:
Expand All @@ -20,7 +19,6 @@ class TestIncludeme:
("grading_info", GradingInfoService),
("group_info", GroupInfoService),
("lti_h", LTIHService),
("oauth1", OAuth1Service),
),
)
def test_it_has_the_expected_service_by_name(
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/lms/services/lti_grading/_v11_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,10 @@ def svc_method(self, svc, request):
return getattr(svc, request.param)

@pytest.fixture
def svc(self, oauth1_service, http_service):
return LTI11GradingService(sentinel.service_url, http_service, oauth1_service)
def svc(self, oauth1_service, http_service, application_instance):
return LTI11GradingService(
sentinel.service_url, http_service, oauth1_service, application_instance
)


class GradingResponse(dict):
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/lms/services/lti_grading/factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ def test_v11(
svc = service_factory(sentinel.context, pyramid_request)

LTI11GradingService.assert_called_once_with(
sentinel.grading_url, http_service, oauth1_service
sentinel.grading_url,
http_service,
oauth1_service,
pyramid_request.lti_user.application_instance,
)
assert svc == LTI11GradingService.return_value

Expand Down
31 changes: 19 additions & 12 deletions tests/unit/lms/services/oauth1_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import json
from unittest import mock
from unittest.mock import sentinel

import pytest
from requests import Request

from lms.services.oauth1 import OAuth1Service
from lms.services.oauth1 import OAuth1Service, factory

pytestmark = pytest.mark.usefixtures("application_instance_service")


class TestOAuth1Service:
def test_we_configure_OAuth1_correctly(self, service, OAuth1, application_instance):
service.get_client()
service.get_client(application_instance)

OAuth1.assert_called_once_with(
client_key=application_instance.consumer_key,
Expand All @@ -26,7 +26,7 @@ def test_we_can_be_used_to_sign_a_request(self, service, application_instance):
"POST",
url="http://example.com",
data={"param": "value"},
auth=service.get_client(),
auth=service.get_client(application_instance),
)

prepared_request = request.prepare()
Expand Down Expand Up @@ -133,7 +133,7 @@ def test_sign(
uuid.uuid4.return_value.hex = nonce
datetime.now.return_value.timestamp.return_value = timestamp

result = service.sign(url, method, data)
result = service.sign(application_instance, url, method, data)

assert result["oauth_signature_method"] == "HMAC-SHA1"
assert result["oauth_nonce"] == nonce
Expand All @@ -142,13 +142,8 @@ def test_sign(
assert result["oauth_signature"] == signature

@pytest.fixture
def service(self, context, pyramid_request):
return OAuth1Service(context, pyramid_request)

@pytest.fixture
def context(self):
# We don't use context, so it doesn't matter what it is
return mock.sentinel.context
def service(self):
return OAuth1Service()

@pytest.fixture
def uuid(self, patch):
Expand All @@ -161,3 +156,15 @@ def datetime(self, patch):
@pytest.fixture
def OAuth1(self, patch):
return patch("lms.services.oauth1.OAuth1")


class TestFactory:
def test_it(self, pyramid_request, OAuth1Service):
service = factory(sentinel.context, pyramid_request)

OAuth1Service.assert_called_once_with()
assert service == OAuth1Service.return_value

@pytest.fixture
def OAuth1Service(self, patch):
return patch("lms.services.oauth1.OAuth1Service")
3 changes: 2 additions & 1 deletion tests/unit/lms/views/lti/deep_linking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def test_it_for_v11(
title,
opaque_data_lti11,
oauth1_service,
application_instance,
):
misc_plugin.get_deeplinking_launch_url.return_value = "LAUNCH_URL"
pyramid_request.parsed_params["opaque_data_lti11"] = opaque_data_lti11
Expand Down Expand Up @@ -252,7 +253,7 @@ def test_it_for_v11(
expected_fields["data"] = opaque_data_lti11

oauth1_service.sign.assert_called_once_with(
sentinel.return_url, "post", expected_fields
application_instance, sentinel.return_url, "post", expected_fields
)
assert fields == oauth1_service.sign.return_value

Expand Down

0 comments on commit 03d5520

Please sign in to comment.