From 2673448000cf5e404529e4794eebffc00d9bc9fd Mon Sep 17 00:00:00 2001 From: Daven Quinn Date: Thu, 17 Oct 2024 23:12:50 -0500 Subject: [PATCH] Added localhost token generation for development --- api/routes/security.py | 137 +++++++++++++++++++++++------------------ 1 file changed, 77 insertions(+), 60 deletions(-) diff --git a/api/routes/security.py b/api/routes/security.py index 4abcac2..88513dd 100644 --- a/api/routes/security.py +++ b/api/routes/security.py @@ -30,7 +30,7 @@ ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 1440 # 24 hours GROUP_TOKEN_LENGTH = 32 -GROUP_TOKEN_SALT = b'$2b$12$yQrslvQGWDFjwmDBMURAUe' # Hardcode salt so hashes are consistent +GROUP_TOKEN_SALT = b"$2b$12$yQrslvQGWDFjwmDBMURAUe" # Hardcode salt so hashes are consistent class Token(BaseModel): @@ -59,10 +59,12 @@ class GroupTokenRequest(BaseModel): expiration: int group_id: int + access_token_key = "access_token" # Coming soon # refresh_token_key = "refresh_token" + class OAuth2AuthorizationCodeBearerWithCookie(OAuth2AuthorizationCodeBearer): """Tweak FastAPI's OAuth2AuthorizationCodeBearer to use a cookie instead of a header""" @@ -78,9 +80,7 @@ async def __call__(self, request: Request) -> Optional[str]: raise HTTPException( status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated", - headers={ - "WWW-Authenticate": "Bearer" - }, + headers={"WWW-Authenticate": "Bearer"}, ) else: return None # pragma: nocover @@ -88,9 +88,7 @@ async def __call__(self, request: Request) -> Optional[str]: oauth2_scheme = OAuth2AuthorizationCodeBearerWithCookie( - authorizationUrl='/security/login', - tokenUrl="/security/callback", - auto_error=False + authorizationUrl="/security/login", tokenUrl="/security/callback", auto_error=False ) http_bearer = HTTPBearer(auto_error=False) @@ -98,23 +96,20 @@ async def __call__(self, request: Request) -> Optional[str]: router = APIRouter( prefix="/security", tags=["security"], - responses={ - 404: { - "description": "Not found" - } - }, + responses={404: {"description": "Not found"}}, ) async def get_groups_from_header_token( - header_token: Annotated[HTTPAuthorizationCredentials, Depends(http_bearer)]) -> int | None: + header_token: Annotated[HTTPAuthorizationCredentials, Depends(http_bearer)] +) -> int | None: """Get the groups from the bearer token in the header""" if header_token is None: return None token_hash = bcrypt.hashpw(header_token.credentials.encode(), GROUP_TOKEN_SALT) - token_hash_string = token_hash.decode('utf-8') + token_hash_string = token_hash.decode("utf-8") engine = db.get_engine() async_session = db.get_async_session(engine) @@ -134,10 +129,7 @@ async def get_user(sub: str) -> schemas.User | None: async_session = db.get_async_session(engine) async with async_session() as session: - stmt = ( - select(schemas.User) - .where(schemas.User.sub == sub) - ) + stmt = select(schemas.User).where(schemas.User.sub == sub) user = await session.scalar(stmt) @@ -167,7 +159,9 @@ async def get_user_token_from_cookie(token: Annotated[str | None, Depends(oauth2 return None try: - payload = jwt.decode(token, os.environ['SECRET_KEY'], algorithms=[os.environ['JWT_ENCRYPTION_ALGORITHM']]) + payload = jwt.decode( + token, os.environ["SECRET_KEY"], algorithms=[os.environ["JWT_ENCRYPTION_ALGORITHM"]] + ) sub: str = payload.get("sub") groups = payload.get("groups", []) token_data = TokenData(sub=sub, groups=groups) @@ -178,8 +172,8 @@ async def get_user_token_from_cookie(token: Annotated[str | None, Depends(oauth2 async def get_groups( - user_token_data: TokenData | None = Depends(get_user_token_from_cookie), - header_token: int | None = Depends(get_groups_from_header_token) + user_token_data: TokenData | None = Depends(get_user_token_from_cookie), + header_token: int | None = Depends(get_groups_from_header_token), ) -> list[int]: """Get the groups from both the cookies and header""" @@ -196,7 +190,7 @@ async def get_groups( async def has_access(groups: list[int] = Depends(get_groups)) -> bool: """Check if the user has access to the group""" - if 'ENVIRONMENT' in os.environ and os.environ['ENVIRONMENT'] == 'development': + if "ENVIRONMENT" in os.environ and os.environ["ENVIRONMENT"] == "development": return True return 1 in groups @@ -210,10 +204,10 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None): expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - to_encode.update({ - "exp": expire - }) - encoded_jwt = jwt.encode(to_encode, os.environ['SECRET_KEY'], algorithm=os.environ['JWT_ENCRYPTION_ALGORITHM']) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode( + to_encode, os.environ["SECRET_KEY"], algorithm=os.environ["JWT_ENCRYPTION_ALGORITHM"] + ) return encoded_jwt @@ -222,63 +216,70 @@ async def redirect_authorization(return_url: str = None): """Redirect to the authorization URL with the appropriate parameters""" params = { - 'scope': "openid profile email", - 'client_id': os.environ['OAUTH_CLIENT_ID'], - 'response_type': "code", - 'redirect_uri': os.environ['REDIRECT_URI'] + "scope": "openid profile email", + "client_id": os.environ["OAUTH_CLIENT_ID"], + "response_type": "code", + "redirect_uri": os.environ["REDIRECT_URI"], } if return_url is not None: - params['state'] = return_url + params["state"] = return_url - return RedirectResponse(os.environ['OAUTH_AUTHORIZATION_URL'] + "?" + urllib.parse.urlencode(params)) + return RedirectResponse( + os.environ["OAUTH_AUTHORIZATION_URL"] + "?" + urllib.parse.urlencode(params) + ) @router.get("/callback") async def redirect_callback(code: str, state: Optional[str] = None): """Exchange the code for a token and redirect to the state URL""" - uri = os.environ['REDIRECT_URI'] + uri = os.environ["REDIRECT_URI"] data = { - 'grant_type': 'authorization_code', - 'client_id': os.environ['OAUTH_CLIENT_ID'], - 'client_secret': os.environ['OAUTH_CLIENT_SECRET'], - 'code': code, - 'redirect_uri': uri + "grant_type": "authorization_code", + "client_id": os.environ["OAUTH_CLIENT_ID"], + "client_secret": os.environ["OAUTH_CLIENT_SECRET"], + "code": code, + "redirect_uri": uri, } # Get the domain for the redirect URL parsed_url = urllib.parse.urlparse(uri) domain = parsed_url.netloc - async with aiohttp.ClientSession() as session: - async with session.post(os.environ['OAUTH_TOKEN_URL'], data=data) as token_response: + async with session.post(os.environ["OAUTH_TOKEN_URL"], data=data) as token_response: if token_response.status != 200: - raise HTTPException(status_code=400, detail=f"Invalid code: {await token_response.text()} ") + raise HTTPException( + status_code=400, detail=f"Invalid code: {await token_response.text()} " + ) response_data = await token_response.json() - async with session.post(os.environ['OAUTH_USERINFO_URL'], data=response_data) as user_response: + async with session.post( + os.environ["OAUTH_USERINFO_URL"], data=response_data + ) as user_response: if user_response.status != 200: - raise HTTPException(status_code=400, - detail=f"Couldn't get user information: {await user_response.text()} ") + raise HTTPException( + status_code=400, + detail=f"Couldn't get user information: {await user_response.text()} ", + ) user_data = await user_response.json() - user = await get_user(user_data['sub']) + user = await get_user(user_data["sub"]) if user is None: - given_name = user_data.get('given_name') if user_data.get('given_name') else "" - family_name = user_data.get('family_name') if user_data.get('family_name') else "" + given_name = user_data.get("given_name") if user_data.get("given_name") else "" + family_name = ( + user_data.get("family_name") if user_data.get("family_name") else "" + ) user = await create_user( - user_data['sub'], - f"{given_name} {family_name}", - user_data.get('email', '') + user_data["sub"], f"{given_name} {family_name}", user_data.get("email", "") ) names = [group.name for group in user.groups] @@ -293,40 +294,56 @@ async def redirect_callback(code: str, state: Optional[str] = None): "sub": user.sub, "role": role, # For PostgREST "groups": [group.id for group in user.groups], - "group_names": names + "group_names": names, } ) response = RedirectResponse(state if state else "/") redirect_domain = urllib.parse.urlparse(state).netloc + details = dict( + key=access_token_key, + value=f"Bearer {access_token}", + httponly=True, + samesite="lax", + ) + # Set a cookie for the API domain - response.set_cookie(key=access_token_key, value=f"Bearer {access_token}", httponly=True, samesite="lax", - domain=domain) + response.set_cookie(**details, domain=domain) + if "localhost" in redirect_domain: + # Set a cookie for the localhost redirect. + # We may want to limit this to the development environment in the future. + response.set_cookie(**details, domain=redirect_domain) return response @router.post("/token", response_model=AccessToken) -async def create_group_token(group_token_request: GroupTokenRequest, - user_token: TokenData = Depends(get_user_token_from_cookie)): +async def create_group_token( + group_token_request: GroupTokenRequest, + user_token: TokenData = Depends(get_user_token_from_cookie), +): """Get an access token for the current user""" if group_token_request.group_id not in user_token.groups: - raise HTTPException(status_code=401, - detail=f"User cannot create tokens for group {group_token_request.group_id}") + raise HTTPException( + status_code=401, + detail=f"User cannot create tokens for group {group_token_request.group_id}", + ) engine = db.get_engine() - token = ''.join(secrets.choice(string.ascii_letters + string.digits) for i in range(GROUP_TOKEN_LENGTH)) + token = "".join( + secrets.choice(string.ascii_letters + string.digits) for i in range(GROUP_TOKEN_LENGTH) + ) token_hash = bcrypt.hashpw(token.encode("utf-8"), GROUP_TOKEN_SALT) - token_hash_string = token_hash.decode('utf-8') + token_hash_string = token_hash.decode("utf-8") await db.insert_access_token( engine=engine, token=token_hash_string, group_id=group_token_request.group_id, - expiration=datetime.fromtimestamp(group_token_request.expiration) + expiration=datetime.fromtimestamp(group_token_request.expiration), ) return AccessToken(group=group_token_request.group_id, token=token)