diff --git a/changelog.d/372.feature b/changelog.d/372.feature new file mode 100644 index 00000000..7cc2cd4d --- /dev/null +++ b/changelog.d/372.feature @@ -0,0 +1 @@ +FCM v1: use async version of google-auth and add HTTP proxy support. diff --git a/pyproject.toml b/pyproject.toml index a4351ffd..2b6955e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ dependencies = [ "attrs>=19.2.0", "cryptography>=2.6.1", "idna>=2.8", - "google-auth>=2.27.0", + "google-auth[aiohttp]>=2.27.0", "jaeger-client>=4.0.0", "matrix-common==1.3.0", "opentracing>=2.2.0", @@ -104,6 +104,7 @@ dev = [ "mypy-zope==1.0.1", "towncrier", "tox", + "google-auth-stubs==0.2.0", "types-opentracing>=2.4.2", "types-pyOpenSSL", "types-PyYAML", diff --git a/sygnal/gcmpushkin.py b/sygnal/gcmpushkin.py index 621c2fd7..010dd591 100644 --- a/sygnal/gcmpushkin.py +++ b/sygnal/gcmpushkin.py @@ -14,18 +14,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import json import logging +import os import time from enum import Enum from io import BytesIO from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple -import google.auth.transport.requests -from google.oauth2 import service_account +# We are using an unstable async google-auth API, but it's there since 3+ years +# https://github.com/googleapis/google-auth-library-python/issues/613 +import aiohttp +import google.auth.transport._aiohttp_requests +from google.auth._default_async import load_credentials_from_file +from google.oauth2._credentials_async import Credentials from opentracing import Span, logs, tags from prometheus_client import Counter, Gauge, Histogram -from twisted.internet.defer import DeferredSemaphore +from twisted.internet.defer import Deferred, DeferredSemaphore from twisted.web.client import FileBodyProducer, HTTPConnectionPool, readBody from twisted.web.http_headers import Headers from twisted.web.iweb import IResponse @@ -180,10 +186,33 @@ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]) -> None: "Must configure `project_id` when using FCM api v1", ) - self.service_account_file = self.get_config("service_account_file", str) - if self.api_version is APIVersion.V1 and not self.service_account_file: - raise PushkinSetupException( - "Must configure `service_account_file` when using FCM api v1", + self.credentials: Optional[Credentials] = None + + if self.api_version is APIVersion.V1: + self.service_account_file = self.get_config("service_account_file", str) + if not self.service_account_file: + raise PushkinSetupException( + "Must configure `service_account_file` when using FCM api v1", + ) + try: + self.credentials, _ = load_credentials_from_file( + str(self.service_account_file), + scopes=AUTH_SCOPES, + ) + except google.auth.exceptions.DefaultCredentialsError as e: + raise PushkinSetupException( + f"`service_account_file` must be valid: {str(e)}", + ) + + session = None + if proxy_url: + # `ClientSession` can't directly take the proxy URL, so we need to + # set the usual env var and use `trust_env=True` + os.environ["HTTPS_PROXY"] = proxy_url + session = aiohttp.ClientSession(trust_env=True, auto_decompress=False) + + self.google_auth_request = google.auth.transport._aiohttp_requests.Request( + session=session ) # Use the fcm_options config dictionary as a foundation for the body; @@ -464,21 +493,26 @@ def _handle_v1_response( f"Unknown GCM response code {response.code}" ) - def _get_access_token(self) -> str: - """Retrieve a valid access token that can be used to authorize requests. + async def _get_auth_header(self) -> str: + """Retrieve the auth header that can be used to authorize requests. - :return: Access token. + :return: Needed content of the `Authorization` header """ - # TODO: Should we use the environment variable approach instead? - # export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json - # credentials, project = google.auth.default(scopes=AUTH_SCOPES) - credentials = service_account.Credentials.from_service_account_file( - str(self.service_account_file), - scopes=AUTH_SCOPES, - ) - request = google.auth.transport.requests.Request() - credentials.refresh(request) - return credentials.token + if self.api_version is APIVersion.Legacy: + return "key=%s" % (self.api_key,) + else: + assert self.credentials is not None + await self._refresh_credentials() + return "Bearer %s" % self.credentials.token + + async def _refresh_credentials(self) -> None: + assert self.credentials is not None + if not self.credentials.valid: + await Deferred.fromFuture( + asyncio.ensure_future( + self.credentials.refresh(self.google_auth_request) + ) + ) async def _dispatch_notification_unlimited( self, n: Notification, device: Device, context: NotificationContext @@ -532,10 +566,7 @@ async def _dispatch_notification_unlimited( "Content-Type": ["application/json"], } - if self.api_version == APIVersion.Legacy: - headers["Authorization"] = ["key=%s" % (self.api_key,)] - elif self.api_version is APIVersion.V1: - headers["Authorization"] = ["Bearer %s" % (self._get_access_token(),)] + headers["Authorization"] = [await self._get_auth_header()] body = self.base_request_body.copy() body["data"] = data diff --git a/tests/test_gcm.py b/tests/test_gcm.py index a5454937..60c3d960 100644 --- a/tests/test_gcm.py +++ b/tests/test_gcm.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import tempfile from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Tuple from unittest.mock import MagicMock -from sygnal.gcmpushkin import GcmPushkin +from sygnal.gcmpushkin import APIVersion, GcmPushkin from tests import testutils from tests.testutils import DummyResponse @@ -79,6 +80,21 @@ } +class TestCredentials: + def __init__(self) -> None: + self.valid = False + + @property + def token(self) -> str: + if self.valid: + return "myaccesstoken" + else: + raise Exception() + + async def refresh(self, request: Any) -> None: + self.valid = True + + class TestGcmPushkin(GcmPushkin): """ A GCM pushkin with the ability to make HTTP requests removed and instead @@ -92,6 +108,8 @@ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]): self.last_request_body: Dict[str, Any] = {} self.last_request_headers: Dict[AnyStr, List[AnyStr]] = {} # type: ignore[valid-type] self.num_requests = 0 + if self.api_version is APIVersion.V1: + self.credentials = TestCredentials() # type: ignore[assignment] def preload_with_response( self, code: int, response_payload: Dict[str, Any] @@ -110,8 +128,27 @@ async def _perform_http_request( # type: ignore[override] self.num_requests += 1 return self.preloaded_response, json.dumps(self.preloaded_response_payload) - def _get_access_token(self) -> str: - return "token" + async def _refresh_credentials(self) -> None: + assert self.credentials is not None + if not self.credentials.valid: + await self.credentials.refresh(self.google_auth_request) + + +FAKE_SERVICE_ACCOUNT_FILE = b""" +{ + "type": "service_account", + "project_id": "project_id", + "private_key_id": "private_key_id", + "private_key": "-----BEGIN PRIVATE KEY-----\\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC0PwE6TeTHjD5R\\nY2nOw1rsTgQZ38LCR2CLtx36n+LUkgej/9b+fwC88oKIqJKjUwn43JEOhf4rbA/a\\nqo4jVoLgv754G5+7Glfarr3/rqg+AVT75x6J5DRvhIYpDXwMIUqLAAbfk3TTFNJn\\n2ctrkBF2ZP9p3mzZ3NRjU63Wbf3LBpRqs8jdFEQu8JAecG8VKV1mboJIXG3hwqFN\\nJmcpC/+sWaxB5iMgSqy0w/rGFs6ZbZF6D10XYvf40lEEk9jQIovT+QD4+6GTlroT\\nbOk8uIwxFQcwMFpXj4MktqVNSNyiuuttptIvBWcMWHlaabXrR89vqUFe1g1Jx4GL\\nCF89RrcLAgMBAAECggEAPUYZ3b8zId78JGDeTEq+8wwGeuFFbRQkrvpeN5/41Xib\\nHlZPuQ5lqtXqKBjeWKVXA4G/0icc45gFv7kxPrQfI9YrItuJLmrjKNU0g+HVEdcU\\nE9pa2Fd6t9peXUBXRixfEee9bm3LTiKK8IDqlTNRrGTjKxNQ/7MBhI6izv1vRH/x\\n8i0o1xxNdqstHZ9wBFKYO9w8UQjtfzckkBNDLkaJ/WN0BoRubmUiV1+KwAyyBr6O\\nRnnZ9Tvy8VraSNSdJhX36ai36y18/sT6PWOp99zHYuDyz89KIz1la/fT9eSoR0Jy\\nYePmTEi+9pWhvtpAkqJkRxe5IDz71JVsQ07KoVfzaQKBgQDzKKUd/0ujhv/B9MQf\\nHcwSeWu/XnQ4hlcwz8dTWQjBV8gv9l4yBj9Pra62rg/tQ7b5XKMt6lv/tWs1IpdA\\neMsySY4972VPrmggKXgCnyKckDUYydNtHAIj9buo6AV8rONaneYnGv5wpSsf3q2c\\nOZrkamRgbBkI+B2mZ2obH1oVlQKBgQC9w9HkrDMvZ5L/ilZmpsvoHNFlQwmDgNlN\\n0ej5QGID5rljRM3CcLNHdyQiKqvLA9MCpPEXb2vVJPdmquD12A7a9s0OwxB/dtOD\\nykofcTY0ZHEM1HEyYJGmdK4FvZuNU4o2/D268dePjtj1Xw3c5fs0bcDiGQMtjWlz\\n5hjBzMsyHwKBgGjrIsPcwlBfEcAo0u7yNnnKNnmuUcuJ+9kt7j3Cbwqty80WKvK+\\ny1agBIECfhDMZQkXtbk8JFIjf4y/zi+db1/VaTDEORy2jmtCOWw4KgEQIDj/7OBp\\nc2r8vupUovl2x+rzsrkw5pTIT+FCffqoyHLCjWkle2/pTzHb8Waekoo5AoGAbELk\\nYy5uwTO45Hr60fOEzzZpq/iz28dNshz4agL2KD2gNGcTcEO1tCbfgXKQsfDLmG2b\\ncgBKJ77AOl1wnDEYQIme8TYOGnojL8Pfx9Jh10AaUvR8Y/49+hYFFhdXQCiR6M69\\nNQM2NJuNYWdKVGUMjJu0+AjHDFzp9YonQ6Ffp4cCgYEAmVALALCjU9GjJymgJ0lx\\nD9LccVHMwf9NmR/sMg0XNePRbCEcMDHKdtVJ1zPGS5txuxY3sRb/tDpv7TfuitrU\\nAw0/2ooMzunaoF/HXo+C/+t+pfuqPqLK4sCCyezUlMfCcaPdwXN2FmbgsaFHfe7I\\n7sGEnS/d8wEgydMiptJEf9s=\\n-----END PRIVATE KEY-----\\n", + "client_email": "firebase-adminsdk@project_id.iam.gserviceaccount.com", + "client_id": "client_id", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk%40project_id.iam.gserviceaccount.com", + "universe_domain": "googleapis.com" +} +""" class GcmTestCase(testutils.TestCase): @@ -128,11 +165,14 @@ def config_setup(self, config: Dict[str, Any]) -> None: "api_key": "kii", "fcm_options": {"content_available": True, "mutable_content": True}, } + self.service_account_file = tempfile.NamedTemporaryFile() + self.service_account_file.write(FAKE_SERVICE_ACCOUNT_FILE) + self.service_account_file.flush() config["apps"]["com.example.gcm.apiv1"] = { "type": "tests.test_gcm.TestGcmPushkin", "api_version": "v1", "project_id": "example_project", - "service_account_file": "/path/to/file.json", + "service_account_file": self.service_account_file.name, "fcm_options": { "apns": { "payload": { @@ -146,6 +186,9 @@ def config_setup(self, config: Dict[str, Any]) -> None: }, } + def tearDown(self) -> None: + self.service_account_file.close() + def get_test_pushkin(self, name: str) -> TestGcmPushkin: pushkin = self.sygnal.pushkins[name] assert isinstance(pushkin, TestGcmPushkin) @@ -260,6 +303,10 @@ def test_expected_api_v1(self) -> None: ) self.assertEqual(resp, {"rejected": []}) + assert notification_req[3] is not None + self.assertEqual( + notification_req[3].get("Authorization"), ["Bearer myaccesstoken"] + ) def test_expected_with_default_payload(self) -> None: """