Skip to content

Commit

Permalink
conditional callback
Browse files Browse the repository at this point in the history
  • Loading branch information
adlius committed Apr 15, 2024
1 parent 3448bba commit bbd4437
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 13 deletions.
25 changes: 25 additions & 0 deletions addons/base/tests/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
from nose.tools import * # noqa
import responses
from rest_framework import status as http_status
from waffle.testutils import override_flag
from urllib.parse import (
urlencode,
urlparse,
urlunparse,
)

from addons.base.tests.base import OAuthAddonTestCaseMixin
from framework.auth import Auth
from framework.exceptions import HTTPError
from osf_tests.factories import AuthUserFactory, ProjectFactory
from osf.utils import permissions
from osf.features import ENABLE_GV
from website.util import api_url_for, web_url_for
from website.settings import GRAVYVALET_URL


class OAuthAddonAuthViewsTestCaseMixin(OAuthAddonTestCaseMixin):
Expand Down Expand Up @@ -54,6 +62,23 @@ def test_oauth_finish(self):
name, args, kwargs = mock_callback.mock_calls[0]
assert_equal(kwargs['user']._id, self.user._id)

@mock.patch('website.oauth.views.requests.get')
def test_oauth_finish_enable_gv(self, mock_requests_get):
url = web_url_for(
'oauth_callback',
service_name=self.ADDON_SHORT_NAME
)
query_params = {
'code': 'somecode',
'state': 'somestatetoken',
}
with override_flag(ENABLE_GV, active=True):
request_url = urlunparse(urlparse(url)._replace(query=urlencode(query_params)))
res = self.app.get(request_url, auth=self.user.auth)
mock_requests_get.assert_called_with(
urlunparse(urlparse(GRAVYVALET_URL)._replace(path='callback', query=urlencode(query_params)))
)

def test_delete_external_account(self):
url = api_url_for(
'oauth_disconnect',
Expand Down
45 changes: 32 additions & 13 deletions website/oauth/views.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
# -*- coding: utf-8 -*-

from rest_framework import status as http_status
import waffle
import requests
from urllib.parse import (
urlencode,
urlparse,
urlunparse,
)

from flask import redirect
from flask import redirect, request

from framework.auth.decorators import must_be_logged_in
from framework.exceptions import HTTPError
from osf.models import ExternalAccount
from osf import features
from website.oauth.utils import get_service
from website.oauth.signals import oauth_complete
from website.settings import GRAVYVALET_URL

@must_be_logged_in
def oauth_disconnect(external_account_id, auth):
Expand Down Expand Up @@ -41,17 +50,27 @@ def oauth_connect(service_name, auth):

@must_be_logged_in
def oauth_callback(service_name, auth):
user = auth.user
provider = get_service(service_name)

# Retrieve permanent credentials from provider
if not provider.auth_callback(user=user):
return {}

if provider.account and not user.external_accounts.filter(id=provider.account.id).exists():
user.external_accounts.add(provider.account)
user.save()

oauth_complete.send(provider, account=provider.account, user=user)
if waffle.flag_is_active(request, features.ENABLE_GV):
code = request.args.get('code')
state = request.args.get('state')
query_params = {
'code': code,
'state': state,
}
gv_url = urlunparse(urlparse(GRAVYVALET_URL)._replace(path='callback', query=urlencode(query_params)))
requests.get(gv_url)
else:
user = auth.user
provider = get_service(service_name)

# Retrieve permanent credentials from provider
if not provider.auth_callback(user=user):
return {}

if provider.account and not user.external_accounts.filter(id=provider.account.id).exists():
user.external_accounts.add(provider.account)
user.save()

oauth_complete.send(provider, account=provider.account, user=user)

return {}
1 change: 1 addition & 0 deletions website/settings/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ def parent_dir(path):
DEFAULT_HMAC_ALGORITHM = hashlib.sha256
WATERBUTLER_URL = 'http://localhost:7777'
WATERBUTLER_INTERNAL_URL = WATERBUTLER_URL
GRAVYVALET_URL = 'https://localhost:8004'

####################
# Identifiers #
Expand Down
1 change: 1 addition & 0 deletions website/settings/local-dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#WATERBUTLER_URL = 'http://localhost:7777'
#WATERBUTLER_INTERNAL_URL = WATERBUTLER_URL
#GRAVYVALET_URL = 'https://localhost:8004'

PREPRINT_PROVIDER_DOMAINS = {
'enabled': False,
Expand Down

0 comments on commit bbd4437

Please sign in to comment.