Skip to content

Commit

Permalink
Add support for passing admin_emails via configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
hellais committed Mar 15, 2024
1 parent de32ea5 commit 31eb88a
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 92 deletions.
5 changes: 5 additions & 0 deletions ooniapi/common/src/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ class Settings(BaseSettings):
session_expiry_days: int = 10
login_expiry_days: int = 10

admin_emails: List[str] = [
"[email protected]",
"[email protected]",
]

aws_region: str = ""
aws_access_key_id: str = ""
aws_secret_access_key: str = ""
Expand Down
1 change: 0 additions & 1 deletion ooniapi/services/ooniauth/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ classifiers = [
]
dependencies = [
"fastapi ~= 0.108.0",
"clickhouse-driver ~= 0.2.6",
"sqlalchemy ~= 2.0.27",
"ujson ~= 5.9.0",
"python-dateutil ~= 2.8.2",
Expand Down
7 changes: 0 additions & 7 deletions ooniapi/services/ooniauth/src/ooniauth/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Annotated

from clickhouse_driver import Client as ClickhouseClient
import boto3

from fastapi import Depends
Expand All @@ -9,12 +8,6 @@
from .common.config import Settings


def get_clickhouse_client(
settings: Annotated[Settings, Depends(get_settings)]
) -> ClickhouseClient:
return ClickhouseClient.from_url(settings.clickhouse_url)


def get_ses_client(settings: Annotated[Settings, Depends(get_settings)]):
return boto3.client(
"ses",
Expand Down
21 changes: 8 additions & 13 deletions ooniapi/services/ooniauth/src/ooniauth/routers/v1.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
"""
OONIRun link management
https://github.com/ooni/spec/blob/master/backends/bk-005-ooni-run-v2.md
"""

from datetime import datetime, timedelta, timezone
from typing import Optional
from urllib.parse import urlparse, urlencode, urlunsplit
Expand All @@ -16,7 +10,7 @@
from pydantic import EmailStr
from typing_extensions import Annotated

from ..dependencies import get_clickhouse_client, get_ses_client
from ..dependencies import get_ses_client

from ..utils import (
create_session_token,
Expand Down Expand Up @@ -127,7 +121,6 @@ async def user_login(
Query(alias="k", description="JWT token with aud=register"),
],
settings: Settings = Depends(get_settings),
db: Settings = Depends(get_clickhouse_client),
):
"""Auth Services: login using a registration/login link"""
try:
Expand All @@ -146,21 +139,23 @@ async def user_login(
log.info("user login successful")

# Store account role in token to prevent frequent DB lookups
role = get_account_role(db=db, account_id=dec["account_id"]) or "user"
email_address = dec["email_address"]
role = get_account_role(
admin_emails=settings.admin_emails, email_address=email_address
)
redirect_to = dec.get("redirect_to", "")
email = dec["email_address"]

token = create_session_token(
key=settings.jwt_encryption_key,
account_id=dec["account_id"],
email_address=email_address,
role=role,
session_expiry_days=settings.session_expiry_days,
login_expiry_days=settings.login_expiry_days,
)
return SessionTokenCreate(
bearer=token,
redirect_to=redirect_to,
email_address=email,
email_address=email_address,
)


Expand All @@ -187,7 +182,7 @@ async def user_refresh_token(

newtoken = create_session_token(
key=settings.jwt_encryption_key,
account_id=tok["account_id"],
email_address=tok["email_address"],
role=tok["role"],
session_expiry_days=settings.session_expiry_days,
login_expiry_days=settings.login_expiry_days,
Expand Down
38 changes: 12 additions & 26 deletions ooniapi/services/ooniauth/src/ooniauth/routers/v2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
"""
OONIRun link management
https://github.com/ooni/spec/blob/master/backends/bk-005-ooni-run-v2.md
"""

from datetime import datetime, timedelta, timezone
from typing import Optional
from typing import List, Optional
from urllib.parse import urlparse
import logging

Expand All @@ -16,7 +10,7 @@
from pydantic import EmailStr
from typing_extensions import Annotated

from ..dependencies import get_clickhouse_client, get_ses_client
from ..dependencies import get_ses_client

from ..utils import (
create_session_token,
Expand All @@ -26,7 +20,7 @@
format_login_url,
VALID_REDIRECT_TO_FQDN,
)
from ..common.dependencies import get_settings, role_required
from ..common.dependencies import get_settings
from ..common.config import Settings
from ..common.routers import BaseModel
from ..common.utils import (
Expand Down Expand Up @@ -113,7 +107,6 @@ async def create_user_login(


class UserSession(BaseModel):
account_id: str
session_token: str
redirect_to: str
email_address: str
Expand All @@ -123,18 +116,17 @@ class UserSession(BaseModel):


def maybe_get_user_session_from_header(
db, authorization_header: str, jwt_encryption_key: str
authorization_header: str, jwt_encryption_key: str, admin_emails: List[str]
) -> Optional[UserSession]:
token = get_client_token(
authorization=authorization_header, jwt_encryption_key=jwt_encryption_key
)
if token is None:
return None

account_id = token["account_id"]
role = get_account_role(account_id=account_id, db=db) or "user"
login_time = datetime.fromtimestamp(token["login_time"])
email_address = token["email_address"]
role = get_account_role(admin_emails=admin_emails, email_address=email_address)
login_time = datetime.fromtimestamp(token["login_time"])
redirect_to = ""

return UserSession(
Expand All @@ -143,29 +135,27 @@ def maybe_get_user_session_from_header(
email_address=email_address,
role=role,
login_time=login_time,
account_id=account_id,
is_logged_in=True,
)


def get_user_session_from_login_token(
login_token: str, db, jwt_encryption_key: str
login_token: str, jwt_encryption_key: str, admin_emails: List[str]
) -> UserSession:
try:
d = decode_jwt(
token=login_token,
key=jwt_encryption_key,
audience="register",
)
account_id = d["account_id"]
role = get_account_role(db=db, account_id=account_id) or "user"
email_address = d["email_address"]
role = get_account_role(admin_emails=admin_emails, email_address=email_address)
return UserSession(
session_token="",
redirect_to=d["redirect_to"],
email_address=d["email_address"],
role=role,
login_time=datetime.now(timezone.utc),
account_id=account_id,
)
except (
jwt.exceptions.MissingRequiredClaimError,
Expand All @@ -188,19 +178,18 @@ async def create_user_session(
req: Optional[CreateUserSession] = None,
authorization: str = Header("authorization"),
settings: Settings = Depends(get_settings),
db: Settings = Depends(get_clickhouse_client),
):
"""Auth Services: login using a registration/login link"""
if req and req.login_token:
user_session = get_user_session_from_login_token(
login_token=req.login_token,
db=db,
admin_emails=settings.admin_emails,
jwt_encryption_key=settings.jwt_encryption_key,
)
else:
user_session = maybe_get_user_session_from_header(
authorization_header=authorization,
db=db,
admin_emails=settings.admin_emails,
jwt_encryption_key=settings.jwt_encryption_key,
)

Expand All @@ -210,7 +199,6 @@ async def create_user_session(
assert user_session.login_time
user_session.session_token = create_session_token(
key=settings.jwt_encryption_key,
account_id=user_session.account_id,
role=user_session.role,
session_expiry_days=settings.session_expiry_days,
login_expiry_days=settings.login_expiry_days,
Expand All @@ -225,11 +213,10 @@ async def create_user_session(
async def get_user_session(
authorization: str = Header("authorization"),
settings: Settings = Depends(get_settings),
db: Settings = Depends(get_clickhouse_client),
):
user_session = maybe_get_user_session_from_header(
authorization_header=authorization,
db=db,
admin_emails=settings.admin_emails,
jwt_encryption_key=settings.jwt_encryption_key,
)
if not user_session:
Expand All @@ -239,7 +226,6 @@ async def get_user_session(
email_address="",
role="",
login_time=None,
account_id="",
is_logged_in=False,
)
return user_session
18 changes: 7 additions & 11 deletions ooniapi/services/ooniauth/src/ooniauth/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import hashlib
import time
from typing import Optional
from typing import List, Optional
from textwrap import dedent
from urllib.parse import urlencode, urlparse, urlunsplit

import sqlalchemy as sa

from .common.utils import create_jwt, query_click_one_row
from .common.utils import create_jwt

VALID_REDIRECT_TO_FQDN = (
"explorer.ooni.org",
Expand All @@ -26,12 +26,11 @@ def format_login_url(redirect_to: str, registration_token: str) -> str:

def create_session_token(
key: str,
account_id: str,
email_address: str,
role: str,
session_expiry_days: int,
login_expiry_days: int,
login_time: Optional[int] = None,
email_address: Optional[str] = None,
) -> str:
now = int(time.time())
session_exp = now + session_expiry_days * 86400
Expand All @@ -44,20 +43,17 @@ def create_session_token(
"iat": now,
"exp": exp,
"aud": "user_auth",
"account_id": account_id,
"login_time": login_time,
"role": role,
"email_address": email_address,
}
return create_jwt(payload=payload, key=key)


def get_account_role(db, account_id: str) -> Optional[str]:
"""Get account role from database, or None"""
query = "SELECT role FROM accounts WHERE account_id = :account_id"
query_params = dict(account_id=account_id)
r = query_click_one_row(db, sa.text(query), query_params)
return r["role"] if r else None
def get_account_role(admin_emails: List[str], email_address: str) -> str:
if email_address in admin_emails:
return "admin"
return "user"


def hash_email_address(email_address: str, key: str) -> str:
Expand Down
16 changes: 6 additions & 10 deletions ooniapi/services/ooniauth/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from unittest.mock import MagicMock
import pytest

import time
import jwt

from fastapi.testclient import TestClient

from ooniauth.common.config import Settings
from ooniauth.common.dependencies import get_settings
from ooniauth.dependencies import get_ses_client, get_clickhouse_client
from ooniauth.dependencies import get_ses_client
from ooniauth.utils import hash_email_address
from ooniauth.main import app

Expand All @@ -23,7 +20,7 @@ def override_get_settings():
@pytest.fixture
def client_with_bad_settings():
app.dependency_overrides[get_settings] = make_override_get_settings(
postgresql_url="postgresql://bad:bad@localhost/bad"
postgresql_url="postgresql://bad:bad@localhost/bad",
)

client = TestClient(app)
Expand All @@ -32,12 +29,13 @@ def client_with_bad_settings():

@pytest.fixture
def user_email():
return "[email protected]"
# NSA shall never be an admin user, lol
return "[email protected]"


@pytest.fixture
def admin_email():
return "dev+adminaccount@ooni.org"
return "admin@ooni.org"


@pytest.fixture
Expand Down Expand Up @@ -95,6 +93,7 @@ def client(
email_source_address=email_source_address,
account_id_hashing_key=account_id_hashing_key,
aws_access_key_id="ITSCHANGED",
admin_emails=[admin_email],
aws_secret_access_key="ITSCHANGED",
)
mock_clickhouse = MagicMock()
Expand All @@ -112,8 +111,5 @@ def mock_execute(query, query_params, with_column_types, settings):

return [("user",)], [("role", "String")]

mock_clickhouse.execute = mock_execute
app.dependency_overrides[get_clickhouse_client] = lambda: mock_clickhouse

client = TestClient(app)
yield client
12 changes: 0 additions & 12 deletions ooniapi/services/ooniauth/tests/test_auth_v1.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
"""
Integration test for Auth API
Warning: this test runs against a real database and SMTP
Lint using:
black -t py37 -l 100 --fast ooniapi/tests/integ/test_probe_services.py
Test using:
pytest-3 -s --show-capture=no ooniapi/tests/integ/test_integration_auth.py
"""

from urllib.parse import parse_qs, urlparse
from ooniauth.common.utils import decode_jwt
from ooniauth.main import app
Expand Down
12 changes: 0 additions & 12 deletions ooniapi/services/ooniauth/tests/test_auth_v2.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
"""
Integration test for Auth API
Warning: this test runs against a real database and SMTP
Lint using:
black -t py37 -l 100 --fast ooniapi/tests/integ/test_probe_services.py
Test using:
pytest-3 -s --show-capture=no ooniapi/tests/integ/test_integration_auth.py
"""

from urllib.parse import parse_qs, urlparse
from ooniauth.common.utils import decode_jwt
from ooniauth.main import app
Expand Down

0 comments on commit 31eb88a

Please sign in to comment.