Skip to content

Commit

Permalink
Merge branch 'development_fl_server' of github.com:threefoldtech/rfs …
Browse files Browse the repository at this point in the history
…into development_support_conversion_progress
  • Loading branch information
rawdaGastan committed Sep 18, 2024
2 parents 2523b5c + 2efda5e commit 225fb8a
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 70 deletions.
39 changes: 21 additions & 18 deletions fl-server/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::sync::Arc;

use axum::{
extract::{Json, Request, State},
http::{self, StatusCode},
middleware::Next,
response::IntoResponse,
Extension,
};
use axum_macros::debug_handler;
use chrono::{Duration, Utc};
Expand All @@ -16,12 +17,6 @@ use crate::{
response::{ResponseError, ResponseResult},
};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub username: String,
pub password: String,
}

#[derive(Serialize, Deserialize)]
pub struct Claims {
pub exp: usize, // Expiry time of the token
Expand Down Expand Up @@ -52,10 +47,10 @@ pub struct SignInResponse {
)]
#[debug_handler]
pub async fn sign_in_handler(
Extension(cfg): Extension<config::Config>,
State(state): State<Arc<config::AppState>>,
Json(user_data): Json<SignInBody>,
) -> impl IntoResponse {
let user = match get_user_by_username(&cfg.users, &user_data.username) {
let user = match state.db.get_user_by_username(&user_data.username) {
Some(user) => user,
None => {
return Err(ResponseError::Unauthorized(
Expand All @@ -70,18 +65,18 @@ pub async fn sign_in_handler(
));
}

let token = encode_jwt(user.username.clone(), cfg.jwt_secret, cfg.jwt_expire_hours)
.map_err(|_| ResponseError::InternalServerError)?;
let token = encode_jwt(
user.username.clone(),
state.config.jwt_secret.clone(),
state.config.jwt_expire_hours,
)
.map_err(|_| ResponseError::InternalServerError)?;

Ok(ResponseResult::SignedIn(SignInResponse {
access_token: token,
}))
}

pub fn get_user_by_username<'a>(users: &'a [User], username: &str) -> Option<&'a User> {
users.iter().find(|u| u.username == username)
}

pub fn encode_jwt(
username: String,
jwt_secret: String,
Expand Down Expand Up @@ -111,7 +106,7 @@ pub fn decode_jwt(jwt_token: String, jwt_secret: String) -> Result<TokenData<Cla
}

pub async fn authorize(
State(cfg): State<config::Config>,
State(state): State<Arc<config::AppState>>,
mut req: Request,
next: Next,
) -> impl IntoResponse {
Expand All @@ -128,7 +123,15 @@ pub async fn authorize(

let mut header = auth_header.split_whitespace();
let (_, token) = (header.next(), header.next());
let token_data = match decode_jwt(token.unwrap().to_string(), cfg.jwt_secret) {
let token_str = match token {
Some(t) => t.to_string(),
None => {
log::error!("failed to get token string");
return Err(ResponseError::InternalServerError);
}
};

let token_data = match decode_jwt(token_str, state.config.jwt_secret.clone()) {
Ok(data) => data,
Err(_) => {
return Err(ResponseError::Forbidden(
Expand All @@ -137,7 +140,7 @@ pub async fn authorize(
}
};

let current_user = match get_user_by_username(&cfg.users, &token_data.claims.username) {
let current_user = match state.db.get_user_by_username(&token_data.claims.username) {
Some(user) => user,
None => {
return Err(ResponseError::Unauthorized(
Expand Down
19 changes: 14 additions & 5 deletions fl-server/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, fs, path::PathBuf, sync::Mutex};
use std::{
collections::HashMap,
fs,
path::PathBuf,
sync::{Arc, Mutex},
};
use utoipa::ToSchema;

use crate::{auth, handlers};
use crate::{
db::{User, DB},
handlers,
};

#[derive(Debug, ToSchema, Serialize, Clone)]
pub struct Job {
pub id: String,
}

#[derive(Debug, ToSchema)]
#[derive(ToSchema)]
pub struct AppState {
pub jobs_state: Mutex<HashMap<String, handlers::FlistState>>,
pub flists_progress: Mutex<HashMap<PathBuf, f32>>,
pub db: Arc<dyn DB>,
pub config: Config,
}

#[derive(Debug, Default, Clone, Deserialize)]
Expand All @@ -25,8 +35,7 @@ pub struct Config {

pub jwt_secret: String,
pub jwt_expire_hours: i64,

pub users: Vec<auth::User>,
pub users: Vec<User>,
}

/// Parse the config file into Config struct.
Expand Down
31 changes: 31 additions & 0 deletions fl-server/src/db.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub username: String,
pub password: String,
}

pub trait DB: Send + Sync {
fn get_user_by_username(&self, username: &str) -> Option<&User>;
}

#[derive(Debug, ToSchema)]
pub struct VecDB {
users: Vec<User>,
}

impl VecDB {
pub fn new(users: &[User]) -> Self {
Self {
users: users.to_vec(),
}
}
}

impl DB for VecDB {
fn get_user_by_username(&self, username: &str) -> Option<&User> {
self.users.iter().find(|u| u.username == username)
}
}
72 changes: 41 additions & 31 deletions fl-server/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ pub async fn health_check_handler() -> ResponseResult {
#[debug_handler]
pub async fn create_flist_handler(
State(state): State<Arc<config::AppState>>,
Extension(cfg): Extension<config::Config>,
Extension(username): Extension<String>,
Json(body): Json<FlistBody>,
) -> impl IntoResponse {
let cfg = state.config.clone();
let credentials = Some(DockerCredentials {
username: body.username,
password: body.password,
Expand Down Expand Up @@ -151,22 +151,30 @@ pub async fn create_flist_handler(
};
let current_job = job.clone();

state.jobs_state.lock().unwrap().insert(
job.id.clone(),
FlistState::Accepted(format!("flist '{}' is accepted", fl_name)),
);
state
.flists_progress
.jobs_state
.lock()
.unwrap()
.insert(fl_path.clone(), 0.0);

tokio::spawn(async move {
state.jobs_state.lock().unwrap().insert(
.expect("failed to lock state")
.insert(
job.id.clone(),
FlistState::Started(format!("flist '{}' is started", fl_name)),
FlistState::Accepted(format!("flist '{}' is accepted", &fl_name)),
);

let flist_download_url = std::path::Path::new(&format!("{}:{}", cfg.host, cfg.port))
.join(cfg.flist_dir)
.join(username)
.join(&fl_name);

tokio::spawn(async move {
state
.jobs_state
.lock()
.expect("failed to lock state")
.insert(
job.id.clone(),
FlistState::Started(format!("flist '{}' is started", fl_name)),
);

let container_name = Uuid::new_v4().to_string();
let docker_tmp_dir = tempdir::TempDir::new(&container_name).unwrap();
let docker_tmp_dir_path = docker_tmp_dir.path().to_owned();
Expand Down Expand Up @@ -224,18 +232,22 @@ pub async fn create_flist_handler(
state
.jobs_state
.lock()
.unwrap()
.expect("failed to lock state")
.insert(job.id.clone(), FlistState::Failed);
return;
}

state.jobs_state.lock().unwrap().insert(
job.id.clone(),
FlistState::Created(format!(
"flist {}:{}/{:?} is created successfully",
cfg.host, cfg.port, fl_path
)),
);
state
.jobs_state
.lock()
.expect("failed to lock state")
.insert(
job.id.clone(),
FlistState::Created(format!(
"flist {:?} is created successfully",
flist_download_url
)),
);
state.flists_progress.lock().unwrap().insert(fl_path, 100.0);
});

Expand Down Expand Up @@ -264,7 +276,7 @@ pub async fn get_flist_state_handler(
if !&state
.jobs_state
.lock()
.unwrap()
.expect("failed to lock state")
.contains_key(&flist_job_id.clone())
{
return Err(ResponseError::NotFound("flist doesn't exist".to_string()));
Expand All @@ -273,9 +285,9 @@ pub async fn get_flist_state_handler(
let res_state = state
.jobs_state
.lock()
.unwrap()
.expect("failed to lock state")
.get(&flist_job_id.clone())
.unwrap()
.expect("failed to get from state")
.to_owned();

match res_state {
Expand All @@ -286,7 +298,7 @@ pub async fn get_flist_state_handler(
state
.jobs_state
.lock()
.unwrap()
.expect("failed to lock state")
.remove(&flist_job_id.clone());

Ok(ResponseResult::FlistState(res_state))
Expand All @@ -295,10 +307,10 @@ pub async fn get_flist_state_handler(
state
.jobs_state
.lock()
.unwrap()
.expect("failed to lock state")
.remove(&flist_job_id.clone());

return Err(ResponseError::InternalServerError);
Err(ResponseError::InternalServerError)
}
}
}
Expand All @@ -314,13 +326,11 @@ pub async fn get_flist_state_handler(
)
)]
#[debug_handler]
pub async fn list_flists_handler(
Extension(cfg): Extension<config::Config>,
State(state): State<Arc<config::AppState>>,
) -> impl IntoResponse {
pub async fn list_flists_handler(State(state): State<Arc<config::AppState>>) -> impl IntoResponse {
let mut flists: HashMap<String, Vec<FileInfo>> = HashMap::new();

let rs = visit_dir_one_level(&cfg.flist_dir, &state).await;
let rs: Result<Vec<FileInfo>, std::io::Error> =
visit_dir_one_level(&state.config.flist_dir, &state).await;
match rs {
Ok(files) => {
for file in files {
Expand Down
18 changes: 11 additions & 7 deletions fl-server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod auth;
mod config;
mod db;
mod handlers;
mod response;
mod serve_flists;
Expand All @@ -26,7 +27,7 @@ use std::{
};
use tokio::{runtime::Builder, signal};
use tower::ServiceBuilder;
use tower_http::{add_extension::AddExtensionLayer, cors::CorsLayer};
use tower_http::cors::CorsLayer;
use tower_http::{cors::Any, trace::TraceLayer};

use utoipa::OpenApi;
Expand Down Expand Up @@ -71,9 +72,13 @@ async fn app() -> Result<()> {
.await
.context("failed to parse config file")?;

let db = Arc::new(db::VecDB::new(&config.users));

let app_state = Arc::new(config::AppState {
jobs_state: Mutex::new(HashMap::new()),
flists_progress: Mutex::new(HashMap::new()),
db,
config,
});

let cors = CorsLayer::new()
Expand All @@ -87,14 +92,14 @@ async fn app() -> Result<()> {
.route(
"/v1/api/fl",
post(handlers::create_flist_handler).layer(middleware::from_fn_with_state(
config.clone(),
app_state.clone(),
auth::authorize,
)),
)
.route(
"/v1/api/fl/:job_id",
get(handlers::get_flist_state_handler).layer(middleware::from_fn_with_state(
config.clone(),
app_state.clone(),
auth::authorize,
)),
)
Expand All @@ -115,19 +120,18 @@ async fn app() -> Result<()> {
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http()),
)
.layer(AddExtensionLayer::new(config.clone()))
.with_state(Arc::clone(&app_state))
.layer(cors);

let address = format!("{}:{}", config.host, config.port);
let address = format!("{}:{}", app_state.config.host, app_state.config.port);
let listener = tokio::net::TcpListener::bind(address)
.await
.context("failed to bind address")?;

log::info!(
"🚀 Server started successfully at {}:{}",
config.host,
config.port
app_state.config.host,
app_state.config.port
);

axum::serve(listener, app)
Expand Down
Loading

0 comments on commit 225fb8a

Please sign in to comment.