From bbd4437f7567f4de3966d0babd6e2f7e49beb4d9 Mon Sep 17 00:00:00 2001 From: Yuhuai Liu Date: Mon, 15 Apr 2024 13:20:17 -0400 Subject: [PATCH] conditional callback --- addons/base/tests/views.py | 25 +++++++++++++++++++ website/oauth/views.py | 45 ++++++++++++++++++++++++---------- website/settings/defaults.py | 1 + website/settings/local-dist.py | 1 + 4 files changed, 59 insertions(+), 13 deletions(-) diff --git a/addons/base/tests/views.py b/addons/base/tests/views.py index 9a78ca49cdf..f1e88d0dbd5 100644 --- a/addons/base/tests/views.py +++ b/addons/base/tests/views.py @@ -3,13 +3,21 @@ from nose.tools import * # noqa import responses from rest_framework import status as http_status +from waffle.testutils import override_flag +from urllib.parse import ( + urlencode, + urlparse, + urlunparse, +) from addons.base.tests.base import OAuthAddonTestCaseMixin from framework.auth import Auth from framework.exceptions import HTTPError from osf_tests.factories import AuthUserFactory, ProjectFactory from osf.utils import permissions +from osf.features import ENABLE_GV from website.util import api_url_for, web_url_for +from website.settings import GRAVYVALET_URL class OAuthAddonAuthViewsTestCaseMixin(OAuthAddonTestCaseMixin): @@ -54,6 +62,23 @@ def test_oauth_finish(self): name, args, kwargs = mock_callback.mock_calls[0] assert_equal(kwargs['user']._id, self.user._id) + @mock.patch('website.oauth.views.requests.get') + def test_oauth_finish_enable_gv(self, mock_requests_get): + url = web_url_for( + 'oauth_callback', + service_name=self.ADDON_SHORT_NAME + ) + query_params = { + 'code': 'somecode', + 'state': 'somestatetoken', + } + with override_flag(ENABLE_GV, active=True): + request_url = urlunparse(urlparse(url)._replace(query=urlencode(query_params))) + res = self.app.get(request_url, auth=self.user.auth) + mock_requests_get.assert_called_with( + urlunparse(urlparse(GRAVYVALET_URL)._replace(path='callback', query=urlencode(query_params))) + ) + def test_delete_external_account(self): url = api_url_for( 'oauth_disconnect', diff --git a/website/oauth/views.py b/website/oauth/views.py index 401b05ab3c4..39fcd19fc46 100644 --- a/website/oauth/views.py +++ b/website/oauth/views.py @@ -1,14 +1,23 @@ # -*- coding: utf-8 -*- from rest_framework import status as http_status +import waffle +import requests +from urllib.parse import ( + urlencode, + urlparse, + urlunparse, +) -from flask import redirect +from flask import redirect, request from framework.auth.decorators import must_be_logged_in from framework.exceptions import HTTPError from osf.models import ExternalAccount +from osf import features from website.oauth.utils import get_service from website.oauth.signals import oauth_complete +from website.settings import GRAVYVALET_URL @must_be_logged_in def oauth_disconnect(external_account_id, auth): @@ -41,17 +50,27 @@ def oauth_connect(service_name, auth): @must_be_logged_in def oauth_callback(service_name, auth): - user = auth.user - provider = get_service(service_name) - - # Retrieve permanent credentials from provider - if not provider.auth_callback(user=user): - return {} - - if provider.account and not user.external_accounts.filter(id=provider.account.id).exists(): - user.external_accounts.add(provider.account) - user.save() - - oauth_complete.send(provider, account=provider.account, user=user) + if waffle.flag_is_active(request, features.ENABLE_GV): + code = request.args.get('code') + state = request.args.get('state') + query_params = { + 'code': code, + 'state': state, + } + gv_url = urlunparse(urlparse(GRAVYVALET_URL)._replace(path='callback', query=urlencode(query_params))) + requests.get(gv_url) + else: + user = auth.user + provider = get_service(service_name) + + # Retrieve permanent credentials from provider + if not provider.auth_callback(user=user): + return {} + + if provider.account and not user.external_accounts.filter(id=provider.account.id).exists(): + user.external_accounts.add(provider.account) + user.save() + + oauth_complete.send(provider, account=provider.account, user=user) return {} diff --git a/website/settings/defaults.py b/website/settings/defaults.py index 882a5713592..ce44070357b 100644 --- a/website/settings/defaults.py +++ b/website/settings/defaults.py @@ -320,6 +320,7 @@ def parent_dir(path): DEFAULT_HMAC_ALGORITHM = hashlib.sha256 WATERBUTLER_URL = 'http://localhost:7777' WATERBUTLER_INTERNAL_URL = WATERBUTLER_URL +GRAVYVALET_URL = 'https://localhost:8004' #################### # Identifiers # diff --git a/website/settings/local-dist.py b/website/settings/local-dist.py index 85a9b420db5..1c24eba81b1 100644 --- a/website/settings/local-dist.py +++ b/website/settings/local-dist.py @@ -23,6 +23,7 @@ #WATERBUTLER_URL = 'http://localhost:7777' #WATERBUTLER_INTERNAL_URL = WATERBUTLER_URL +#GRAVYVALET_URL = 'https://localhost:8004' PREPRINT_PROVIDER_DOMAINS = { 'enabled': False,