Skip to content

Commit

Permalink
Merge pull request #49 from nebari-dev/misc-fixes
Browse files Browse the repository at this point in the history
Misc fixes for Nebari Integration
  • Loading branch information
aktech authored Jan 8, 2024
2 parents 4f4d9f9 + ed6bab4 commit 7c666aa
Show file tree
Hide file tree
Showing 21 changed files with 178 additions and 119 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pip install -e .
Set the following environment variable:

```bash
export JWT_SECRET_KEY=$(openssl rand -hex 32)
export JHUB_APP_JWT_SECRET_KEY=$(openssl rand -hex 32)
```


Expand Down
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ dependencies:
- streamlit
- tornado>=5.1
- traitlets
- python-slugify
- pip:
- gradio
15 changes: 12 additions & 3 deletions jhub_apps/configuration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from base64 import b64encode
from secrets import token_bytes

Expand Down Expand Up @@ -62,7 +61,11 @@ def install_jhub_apps(c, spawner_to_subclass):
"JHUB_APP_TITLE": c.JAppsConfig.app_title,
"JHUB_APP_ICON": c.JAppsConfig.app_icon,
"JHUB_JUPYTERHUB_CONFIG": c.JAppsConfig.jupyterhub_config_path,
"JWT_SECRET_KEY": os.environ["JWT_SECRET_KEY"],
"JHUB_APP_JWT_SECRET_KEY": _create_token_for_service(),

# Temp environment variables for Nebari Deployment
"PROXY_API_SERVICE_PORT": "*",
"HUB_SERVICE_PORT": "*",
},
"oauth_redirect_uri": oauth_redirect_uri,
"display": False,
Expand All @@ -84,14 +87,20 @@ def install_jhub_apps(c, spawner_to_subclass):
"admin:servers", # start/stop servers
"admin:server_state", # start/stop servers
"admin:server_state", # start/stop servers
"admin:auth_state",
"access:services",
"list:services",
"read:services", # read service models
],
},
{
"name": "user",
# grant all users access to services
"scopes": ["self", "access:services"],
"scopes": [
"self",
"access:services",
"admin:auth_state"
],
},
]

Expand Down
5 changes: 5 additions & 0 deletions jhub_apps/hub_client/hub_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,8 @@ def delete_server(self, username, server_name, remove=False):
r = requests.delete(API_URL + url, headers=self._headers(), json=params)
r.raise_for_status()
return r.status_code

def get_services(self):
r = requests.get(API_URL + "/services", headers=self._headers())
r.raise_for_status()
return r.json()
2 changes: 1 addition & 1 deletion jhub_apps/service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

from .service import router
from .routes import router

### When managed by Jupyterhub, the actual endpoints
### will be served out prefixed by /services/:name.
Expand Down
18 changes: 14 additions & 4 deletions jhub_apps/service/auth.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,44 @@
import logging
import os
import typing
from datetime import timedelta, datetime

import jwt
from fastapi import HTTPException, status

logger = logging.getLogger(__name__)


def create_access_token(data: dict, expires_delta: typing.Optional[timedelta] = None):
logger.info("Creating access token")
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
secret_key = os.environ["JWT_SECRET_KEY"]
secret_key = os.environ["JHUB_APP_JWT_SECRET_KEY"]
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm="HS256")
return encoded_jwt


def get_jhub_token_from_jwt_token(token):
logger.info("Trying to get JHub Apps token from JWT Token")
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
detail={
"msg": "Could not validate credentials"
},
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, os.environ["JWT_SECRET_KEY"], algorithms=["HS256"])
payload = jwt.decode(token, os.environ["JHUB_APP_JWT_SECRET_KEY"], algorithms=["HS256"])
access_token_data: dict = payload.get("sub")
if access_token_data is None:
raise credentials_exception
except jwt.PyJWTError:
except jwt.PyJWTError as e:
logger.warning("Authentication failed for token")
logger.exception(e)
raise credentials_exception
logger.info("Fetched access token from JWT Token")
return access_token_data["access_token"]
1 change: 1 addition & 0 deletions jhub_apps/service/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class User(BaseModel):
last_activity: Optional[datetime] = None
servers: Optional[Dict[str, Server]] = None
scopes: List[str]
auth_state: Optional[Dict] = None


# https://stackoverflow.com/questions/64501193/fastapi-how-to-use-httpexception-in-responses
Expand Down
22 changes: 18 additions & 4 deletions jhub_apps/service/service.py → jhub_apps/service/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ async def get_token(code: str):
"Callback function for OAuth2AuthorizationCodeBearer scheme"
# The only thing we need in this form post is the code
# Everything else we can hardcode / pull from env
logger.info(f"Getting token for code {code}")
async with get_client() as client:
redirect_uri = (
os.environ["PUBLIC_HOST"] + os.environ["JUPYTERHUB_OAUTH_CALLBACK_URL"],
Expand All @@ -60,12 +61,13 @@ async def get_token(code: str):
)
### resp.json() is {'access_token': <token>, 'token_type': 'Bearer'}
response = RedirectResponse(os.environ["PUBLIC_HOST"] + "/hub/home", status_code=302)
response.set_cookie(key="access_token",value=access_token, httponly=True)
response.set_cookie(key="access_token", value=access_token, httponly=True)
return response


@router.get("/jhub-login", description="Login via OAuth2")
async def login(request: Request):
logger.info(f"Logging in: {request}")
authorization_url = os.environ["PUBLIC_HOST"] + "/hub/api/oauth2/authorize?response_type=code&client_id=service-japps"
return RedirectResponse(authorization_url, status_code=302)

Expand Down Expand Up @@ -226,20 +228,32 @@ async def get_frameworks(user: User = Depends(get_current_user)):

@router.get("/conda-environments/", description="Get all conda environments")
async def conda_environments(user: User = Depends(get_current_user)):
logging.info("Getting conda environments")
logging.info(f"Getting conda environments for user: {user}")
config = get_jupyterhub_config()
conda_envs = get_conda_envs(config)
logger.info(f"Found conda environments: {conda_envs}")
return conda_envs


@router.get("/spawner-profiles/", description="Get all spawner profiles")
async def spawner_profiles(user: User = Depends(get_current_user)):
logging.info("Getting spawner profiles")
hclient = HubClient()
user_from_service = hclient.get_user(user.name)
auth_state = user_from_service.get("auth_state")
logging.info(f"Getting spawner profiles for user: {user.name}")
config = get_jupyterhub_config()
spawner_profiles_ = get_spawner_profiles(config)
spawner_profiles_ = await get_spawner_profiles(config, auth_state=auth_state)
logger.info(f"Loaded spawner profiles: {spawner_profiles_}")
return spawner_profiles_


@router.get("/services/", description="Get all services")
async def hub_services(user: User = Depends(get_current_user)):
logging.info(f"Getting hub services for user: {user}")
hub_client = HubClient()
return hub_client.get_services()


@router.get(
"/status",
)
Expand Down
51 changes: 45 additions & 6 deletions jhub_apps/service/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import base64
import logging
import os
from unittest.mock import Mock

import requests
from jupyterhub.app import JupyterHub
from traitlets.config import LazyConfigValue

from jhub_apps.spawner.types import FrameworkConf, FRAMEWORKS_MAPPING
from slugify import slugify


logger = logging.getLogger(__name__)
Expand All @@ -15,11 +17,10 @@
def get_jupyterhub_config():
hub = JupyterHub()
jhub_config_file_path = os.environ["JHUB_JUPYTERHUB_CONFIG"]
print(f"Getting JHub config from file: {jhub_config_file_path}")
logger.info(f"Getting JHub config from file: {jhub_config_file_path}")
hub.load_config_file(jhub_config_file_path)
config = hub.config
print(f"JHub config from file: {config}")
print(f"JApps config: {config.JAppsConfig}")
logger.info(f"JApps config: {config.JAppsConfig}")
return config


Expand All @@ -30,24 +31,62 @@ def get_conda_envs(config):
elif isinstance(config.JAppsConfig.conda_envs, LazyConfigValue):
return []
elif callable(config.JAppsConfig.conda_envs):
return config.JAppsConfig.conda_envs()
try:
logger.info("JAppsConfig.conda_envs is a callable, calling now..")
return config.JAppsConfig.conda_envs()
except Exception as e:
logger.exception(e)
return []
else:
raise ValueError(
f"Invalid value for config.JAppsConfig.conda_envs: {config.JAppsConfig.conda_envs}"
)


def get_spawner_profiles(config):
def get_fake_spawner_object(auth_state):
fake_spawner = Mock()

async def get_auth_state():
return auth_state

fake_spawner.user.get_auth_state = get_auth_state
fake_spawner.log = logger
return fake_spawner


def _slugify_profile_list(profile_list):
# This is replicating the following:
# https://github.com/jupyterhub/kubespawner/blob/a4b9b190f0335406c33c6de11b5d1b687842dd89/kubespawner/spawner.py#L3279
# Since we are not inside spawner yet, the profiles might not be slugified yet
if not profile_list:
# empty profile lists are just returned
return profile_list

for profile in profile_list:
# generate missing slug fields from display_name
if 'slug' not in profile:
profile['slug'] = slugify(profile['display_name'])
return profile_list


async def get_spawner_profiles(config, auth_state=None):
"""This will extract spawner profiles from the JupyterHub config
If the Spawner is KubeSpawner
# See: https://jupyterhub-kubespawner.readthedocs.io/en/latest/spawner.html#kubespawner.KubeSpawner.profile_list
"""
profile_list = config.KubeSpawner.profile_list
if isinstance(profile_list, list):
return config.KubeSpawner.profile_list
elif isinstance(profile_list, LazyConfigValue):
return []
elif callable(profile_list):
return profile_list()
try:
logger.info("config.KubeSpawner.profile_list is a callable, calling now..")
profile_list = await profile_list(get_fake_spawner_object(auth_state))
return _slugify_profile_list(profile_list)
except Exception as e:
logger.exception(e)
return []
else:
raise ValueError(
f"Invalid value for config.KubeSpawner.profile_list: {profile_list}"
Expand Down
21 changes: 20 additions & 1 deletion jhub_apps/spawner/spawner_creation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from jhub_apps.spawner.utils import get_origin_host
from jhub_apps.spawner.command import (
EXAMPLES_DIR,
Expand All @@ -10,11 +12,24 @@
from jhub_apps.spawner.types import Framework


logger = logging.getLogger(__name__)


def subclass_spawner(base_spawner):
# TODO: Find a better way to do this
class JHubSpawner(base_spawner):

async def _get_user_auth_state(self):
try:
auth_state = await self.user.get_auth_state()
return auth_state
except Exception as e:
logger.exception(e)

def get_args(self):
"""Return arguments to pass to the notebook server"""
# logger.info("Getting Spawner args")
# self._get_user_auth_state()
argv = super().get_args()
if self.user_options.get("argv"):
argv.extend(self.user_options["argv"])
Expand Down Expand Up @@ -56,6 +71,8 @@ def get_args(self):
return argv

def get_env(self):
# logger.info("Getting spawner environments")
# await self._get_user_auth_state()
env = super().get_env()
if self.user_options.get("env"):
env.update(self.user_options["env"])
Expand All @@ -75,6 +92,8 @@ def get_env(self):
return env

async def start(self):
logger.info("Starting spawner process")
await self._get_user_auth_state()
framework = self.user_options.get("framework")
if (
self.user_options.get("jhub_app")
Expand All @@ -93,7 +112,7 @@ async def start(self):
"-m",
"jupyterhub.singleuser",
]
print(f"Final Spawner Command: {self.cmd}")
logger.info(f"Final Spawner Command: {self.cmd}")
return await super().start()

def _expand_user_vars(self, string):
Expand Down
Loading

0 comments on commit 7c666aa

Please sign in to comment.