Skip to content

Commit

Permalink
support workload identity auth for Azure
Browse files Browse the repository at this point in the history
  • Loading branch information
wcy-fdu committed Mar 18, 2024
1 parent 3f825ee commit de46128
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 1 deletion.
18 changes: 18 additions & 0 deletions src/azure/storage/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,22 @@ pub struct Config {
///
/// This is part of use AAD(Azure Active Directory) authenticate on Azure VM
pub endpoint: Option<String>,
/// `azure_federated_token` value will be loaded from:
///
/// - this field if it's `is_some`
/// - env value: [`AZURE_FEDERATED_TOKEN`]
/// - profile config: `azure_federated_token_file`
pub azure_federated_token: Option<String>,
/// `azure_federated_token_file` value will be loaded from:
///
/// - this field if it's `is_some`
/// - env value: [`AZURE_FEDERATED_TOKEN_FILE`]
/// - profile config: `azure_federated_token_file`
pub azure_federated_token_file: Option<String>,
/// `azure_tenant_id_env_key` value will be loaded from:
///
/// - this field if it's `is_some`
/// - env value: [`AZURE_TENANT_ID_ENV_KEY`]
/// - profile config: `azure_tenant_id_env_key`
pub azure_tenant_id_env_key: Option<String>,
}
18 changes: 17 additions & 1 deletion src/azure/storage/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use std::sync::Mutex;

use anyhow::Result;

use super::config::Config;
use super::credential::Credential;
use super::imds_credential;
use super::{config::Config, workload_identity_credential};

/// Loader will load credential from different methods.
#[cfg_attr(test, derive(Debug))]
Expand Down Expand Up @@ -45,6 +45,10 @@ impl Loader {
return Ok(Some(cred));
}

if let Some(cred) = self.load_via_workload_identity().await? {
return Ok(Some(cred));
}

// try to load credential using AAD(Azure Active Directory) authenticate on Azure VM
// we may get an error if not running on Azure VM
// see https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal,http#using-the-rest-protocol
Expand Down Expand Up @@ -72,4 +76,16 @@ impl Loader {

Ok(cred)
}

async fn load_via_workload_identity(&self) -> Result<Option<Credential>> {
let workload_identity_token = workload_identity_credential::get_workload_identity_token(
"https://storage.azure.com/",
&self.config,
)
.await?;
match workload_identity_token {
Some(token) => Ok(Some(Credential::BearerToken(token.access_token))),
None => Ok(None),
}
}
}
2 changes: 2 additions & 0 deletions src/azure/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ pub use credential::Credential as AzureStorageCredential;

mod imds_credential;

mod workload_identity_credential;

mod loader;

pub use loader::Loader as AzureStorageLoader;
Expand Down
85 changes: 85 additions & 0 deletions src/azure/storage/workload_identity_credential.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use std::str;

use http::HeaderValue;
use http::Method;
use http::Request;
use reqwest::Client;
use reqwest::Url;
use serde::Deserialize;
use std::fs;

use super::config::Config;

const MSI_API_VERSION: &str = "2019-08-01";
const MSI_ENDPOINT: &str = "http://169.254.169.254/metadata/identity/oauth2/token";

/// Gets an access token for the specified resource and configuration.
///
/// See <https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal,http#using-the-rest-protocol>
pub async fn get_workload_identity_token(
resource: &str,
config: &Config,
) -> anyhow::Result<Option<AccessToken>> {
let token = match (
&config.azure_federated_token,
&config.azure_federated_token_file,
) {
(Some(token), Some(_)) | (Some(token), None) => token.clone(),
(None, Some(token_file)) => {
let token = fs::read_to_string(token_file)?;
token
}
_ => return Ok(None),
};
let tenant_id = if let Some(tenant_id) = &config.azure_tenant_id_env_key {
tenant_id
} else {
return Ok(None);
};
let client_id = if let Some(client_id) = &config.client_id {
client_id
} else {
return Ok(None);
};

let endpoint = config.endpoint.as_deref().unwrap_or(MSI_ENDPOINT);

let mut query_items = vec![("api-version", MSI_API_VERSION), ("resource", resource)];
query_items.push(("token", &token));
query_items.push(("tenant_id", &tenant_id));
query_items.push(("client_id", &client_id));

let url = Url::parse_with_params(endpoint, &query_items)?;
let mut req = Request::builder()
.method(Method::GET)
.uri(url.to_string())
.body("")?;

req.headers_mut()
.insert("metadata", HeaderValue::from_static("true"));

if let Some(secret) = &config.msi_secret {
req.headers_mut()
.insert("x-identity-header", HeaderValue::from_str(secret)?);
};

let res = Client::new().execute(req.try_into()?).await?;
let rsp_status = res.status();
let rsp_body = res.text().await?;
;

Check warning on line 69 in src/azure/storage/workload_identity_credential.rs

View workflow job for this annotation

GitHub Actions / build_single_feature (services-azblob)

unnecessary trailing semicolon

Check warning on line 69 in src/azure/storage/workload_identity_credential.rs

View workflow job for this annotation

GitHub Actions / build (ubuntu-latest)

unnecessary trailing semicolon

Check warning on line 69 in src/azure/storage/workload_identity_credential.rs

View workflow job for this annotation

GitHub Actions / build_all_features (ubuntu-latest)

unnecessary trailing semicolon

Check warning on line 69 in src/azure/storage/workload_identity_credential.rs

View workflow job for this annotation

GitHub Actions / build (macos-11)

unnecessary trailing semicolon

Check warning on line 69 in src/azure/storage/workload_identity_credential.rs

View workflow job for this annotation

GitHub Actions / build (windows-latest)

unnecessary trailing semicolon

Check warning on line 69 in src/azure/storage/workload_identity_credential.rs

View workflow job for this annotation

GitHub Actions / build_all_features (macos-11)

unnecessary trailing semicolon

Check warning on line 69 in src/azure/storage/workload_identity_credential.rs

View workflow job for this annotation

GitHub Actions / build_all_features (windows-latest)

unnecessary trailing semicolon

if !rsp_status.is_success() {
return Err(anyhow::anyhow!("Failed to get token from working identity credential"));
}

let token: AccessToken = serde_json::from_str(&rsp_body)?;
Ok(Some(token))
}

#[derive(Debug, Clone, Deserialize)]
#[allow(unused)]
pub struct AccessToken {
pub access_token: String,
pub expires_on: String,

}

0 comments on commit de46128

Please sign in to comment.