Skip to content

Commit

Permalink
Update: support query params in equest_get_endpoint; Add: Azure custo…
Browse files Browse the repository at this point in the history
…m speech model query filters
  • Loading branch information
zhongdongy committed Mar 26, 2023
1 parent dac9857 commit d62dfc8
Show file tree
Hide file tree
Showing 8 changed files with 318 additions and 67 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion rust-ai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ log = "0.4.17"
log4rs = "1.2.0"
serde_with = "2.3.1"
isolang = { version = "2.2.0", features = ["serde"] }
lazy_static = "1.4.0"
lazy_static = "1.4.0"
urlencoding = "2.1.2"
223 changes: 216 additions & 7 deletions rust-ai/src/azure/apis/speech.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ use reqwest::header::HeaderMap;
use crate::azure::{
endpoint::{request_get_endpoint, request_post_endpoint_ssml, SpeechServiceEndpoint},
types::{
common::{MicrosoftOutputFormat, ResponseExpectation, ResponseType, ServiceHealthResponse},
common::{MicrosoftOutputFormat, ResponseExpectation, ResponseType},
speech::ServiceHealthResponse,
tts::Voice,
SSML,
},
Expand Down Expand Up @@ -90,7 +91,7 @@ impl Speech {
///
/// Source: <https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/rest-text-to-speech>
pub async fn voice_list() -> Result<Vec<Voice>, Box<dyn std::error::Error>> {
let text = request_get_endpoint(&SpeechServiceEndpoint::Get_List_of_Voices).await?;
let text = request_get_endpoint(&SpeechServiceEndpoint::Get_List_of_Voices, None).await?;
match serde_json::from_str::<Vec<Voice>>(&text) {
Ok(voices) => Ok(voices),
Err(e) => {
Expand All @@ -100,13 +101,16 @@ impl Speech {
}
}

/// Health status provides insights about the overall health of the service
/// Health status provides insights about the overall health of the service
/// and sub-components.
///
///
/// V3.1 API supported only.
pub async fn health_check() -> Result<ServiceHealthResponse, Box<dyn std::error::Error>> {
let text =
request_get_endpoint(&SpeechServiceEndpoint::Get_Speech_to_Text_Health_Status_v3_1).await?;
let text = request_get_endpoint(
&SpeechServiceEndpoint::Get_Speech_to_Text_Health_Status_v3_1,
None,
)
.await?;

match serde_json::from_str::<ServiceHealthResponse>(&text) {
Ok(status) => Ok(status),
Expand All @@ -127,7 +131,7 @@ impl Speech {
let mut headers = HeaderMap::new();
headers.insert("X-Microsoft-OutputFormat", self.output_format.into());
match request_post_endpoint_ssml(
&SpeechServiceEndpoint::Convert_Text_to_Speech_v1,
&SpeechServiceEndpoint::Post_Text_to_Speech_v1,
self.ssml,
ResponseExpectation::Bytes,
headers,
Expand All @@ -145,3 +149,208 @@ impl Speech {
Ok(self.text_to_speech().await?)
}
}

/// TODO: remove `allow(dead_code)` when `models()` implemented.
#[allow(dead_code)]
pub struct SpeechModel {
model_id: Option<String>,

skip: Option<usize>,
top: Option<usize>,
filter: Option<FilterOperator>,
}

impl Default for SpeechModel {
fn default() -> Self {
Self {
model_id: None,
skip: None,
top: None,
filter: None,
}
}
}

impl SpeechModel {
pub fn skip(self, skip: usize) -> Self {
Self {
skip: Some(skip),
..self
}
}
pub fn top(self, top: usize) -> Self {
Self {
top: Some(top),
..self
}
}
pub fn filter(self, filter: FilterOperator) -> Self {
Self {
filter: Some(filter),
..self
}
}

pub fn id(self, id: String) -> Self {
Self {
model_id: Some(id),
..self
}
}

/// [Custom Speech]
/// Gets the list of custom models for the authenticated subscription.
///
/// TODO: implement this.
pub async fn models(self) -> Result<(), Box<dyn std::error::Error>> {
todo!("Test with custom models");
// let mut params = HashMap::<String, String>::new();

// if let Some(skip) = self.skip {
// params.insert("skip".into(), skip.to_string());
// }
// if let Some(top) = self.top {
// params.insert("top".into(), top.to_string());
// }
// if let Some(filter) = self.filter {
// params.insert("filter".into(), filter.to_string());
// }

// let text = request_get_endpoint(
// &SpeechServiceEndpoint::Get_List_of_Models_v3_1,
// Some(params),
// )
// .await?;

// println!("{}", text);

// match serde_json::from_str::<ServiceHealthResponse>(&text) {
// Ok(status) => Ok(status),
// Err(e) => {
// warn!(target: "azure", "Error parsing response: {:?}", e);
// Err("Unable to parse health status of speech cognitive services, check log for details".into())
// }
// }

// Ok(())
}
}

#[derive(Clone, Debug)]
pub enum FilterField {
DisplayName,
Description,
CreatedDateTime,
LastActionDateTime,
Status,
Locale,
}

impl Into<String> for FilterField {
fn into(self) -> String {
(match self {
Self::DisplayName => "displayName",
Self::Description => "description",
Self::CreatedDateTime => "createdDateTime",
Self::LastActionDateTime => "lastActionDateTime",
Self::Status => "status",
Self::Locale => "locale",
})
.into()
}
}

#[derive(Clone, Debug)]
pub enum FilterOperator {
Eq(FilterField, String),
Ne(FilterField, String),
Gt(FilterField, String),
Ge(FilterField, String),
Lt(FilterField, String),
Le(FilterField, String),
And(Box<FilterOperator>, Box<FilterOperator>),
Or(Box<FilterOperator>, Box<FilterOperator>),
Not(Box<FilterOperator>),
}
impl FilterOperator {
pub fn and(self, op: FilterOperator) -> Self {
Self::And(Box::new(self), Box::new(op))
}
pub fn or(self, op: FilterOperator) -> Self {
Self::Or(Box::new(self), Box::new(op))
}
pub fn not(self) -> Self {
Self::Not(Box::new(self))
}

fn str(self, not: bool) -> String {
match self {
Self::And(a, b) => {
if not {
format!("{} or {}", a.str(true), b.str(true))
} else {
format!("{} and {}", a.str(false), b.str(false))
}
}

Self::Or(a, b) => {
if not {
format!("{} and {}", a.str(true), b.str(true))
} else {
format!("{} or {}", a.str(false), b.str(false))
}
}

Self::Not(a) => format!("{}", a.str(!not)),

Self::Eq(field, value) => format!(
"{} {} '{}'",
Into::<String>::into(field),
if not { "ne" } else { "eq" },
Into::<String>::into(value)
),
Self::Ne(field, value) => format!(
"{} {} '{}'",
Into::<String>::into(field),
if not { "eq" } else { "ne" },
Into::<String>::into(value)
),
Self::Gt(field, value) => format!(
"{} {} '{}'",
Into::<String>::into(field),
if not { "le" } else { "gt" },
Into::<String>::into(value)
),
Self::Ge(field, value) => format!(
"{} {} '{}'",
Into::<String>::into(field),
if not { "lt" } else { "ge" },
Into::<String>::into(value)
),
Self::Lt(field, value) => format!(
"{} {} '{}'",
Into::<String>::into(field),
if not { "ge" } else { "lt" },
Into::<String>::into(value)
),
Self::Le(field, value) => format!(
"{} {} '{}'",
Into::<String>::into(field),
if not { "gt" } else { "le" },
Into::<String>::into(value)
),
}
}
}

impl Into<String> for FilterOperator {
fn into(self) -> String {
self.str(false)
}
}

impl ToString for FilterOperator {
fn to_string(&self) -> String {
Into::<String>::into(self.clone())
}
}
29 changes: 24 additions & 5 deletions rust-ai/src/azure/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use log::{debug, error};
use reqwest::{header::HeaderMap, Client};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use urlencoding::encode;

use crate::utils::config::Config;

Expand All @@ -13,8 +15,9 @@ use super::{
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum SpeechServiceEndpoint {
Get_List_of_Voices,
Convert_Text_to_Speech_v1,
Get_Speech_to_Text_Health_Status_v3_1
Post_Text_to_Speech_v1,
Get_Speech_to_Text_Health_Status_v3_1,
Get_List_of_Models_v3_1,
}

impl SpeechServiceEndpoint {
Expand All @@ -25,27 +28,43 @@ impl SpeechServiceEndpoint {
region
),

Self::Convert_Text_to_Speech_v1 => format!(
Self::Post_Text_to_Speech_v1 => format!(
"https://{}.tts.speech.microsoft.com/cognitiveservices/v1",
region
),


Self::Get_Speech_to_Text_Health_Status_v3_1 => format!(
"https://{}.cognitiveservices.azure.com/speechtotext/v3.1/healthstatus",
region
),

Self::Get_List_of_Models_v3_1 => format!(
"https://{}.cognitiveservices.azure.com/speechtotext/v3.1/models",
region
),
}
}
}

pub async fn request_get_endpoint(
endpoint: &SpeechServiceEndpoint,
params: Option<HashMap<String, String>>,
) -> Result<String, Box<dyn std::error::Error>> {
let config = Config::load().unwrap();
let region = config.azure.speech.region;

let url = endpoint.build(&region);
let mut url = endpoint.build(&region);

if let Some(params) = params {
let combined = params
.iter()
.map(|(k, v)| format!("{}={}", encode(k), encode(v)))
.collect::<Vec<String>>()
.join("&");
url.push_str(&format!("?{}", combined));
}

println!("URL={}", url);

let client = Client::new();
let mut req = client.get(url);
Expand Down
6 changes: 3 additions & 3 deletions rust-ai/src/azure/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ pub mod apis;
/// Azure types definition
pub mod types;

pub use apis::speech::Speech;
pub use types::MicrosoftOutputFormat;
pub use apis::speech::{FilterField, FilterOperator, Speech, SpeechModel};
pub use types::ssml;
pub use types::Gender;
pub use types::Locale;
pub use types::MicrosoftOutputFormat;
pub use types::VoiceName;
pub use types::SSML;
pub use types::ssml;
Loading

0 comments on commit d62dfc8

Please sign in to comment.