Skip to content

Commit

Permalink
Add Azure Search user-agent, test OpenAI client header
Browse files Browse the repository at this point in the history
* Adding test to verify the user-agent header in the Azure OpenAi chat client
* Adding the user-agent header to the search client in Azure vector store
  • Loading branch information
sobychacko authored and Mark Pollack committed Sep 19, 2024
1 parent 035036c commit e644bf7
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.autoconfigure.vectorstore.azure;

import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.util.ClientOptions;
import com.azure.search.documents.indexes.SearchIndexClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;

Expand Down Expand Up @@ -47,11 +48,16 @@
@ConditionalOnProperty(prefix = "spring.ai.vectorstore.azure", value = { "url", "api-key", "index-name" })
public class AzureVectorStoreAutoConfiguration {

private final static String APPLICATION_ID = "spring-ai";

@Bean
@ConditionalOnMissingBean
public SearchIndexClient searchIndexClient(AzureVectorStoreProperties properties) {
ClientOptions clientOptions = new ClientOptions();
clientOptions.setApplicationId(APPLICATION_ID);
return new SearchIndexClientBuilder().endpoint(properties.getUrl())
.credential(new AzureKeyCredential(properties.getApiKey()))
.clientOptions(clientOptions)
.buildClient();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,8 +33,20 @@
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.util.ReflectionUtils;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.implementation.OpenAIClientImpl;
import com.azure.core.http.HttpHeader;
import com.azure.core.http.HttpHeaderName;
import com.azure.core.http.HttpMethod;
import com.azure.core.http.HttpPipeline;
import com.azure.core.http.HttpRequest;
import com.azure.core.http.HttpResponse;
import reactor.core.publisher.Flux;

import java.lang.reflect.Field;
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand All @@ -44,11 +56,12 @@
/**
* @author Christian Tzolov
* @author Piotr Olaszewski
* @author Soby Chacko
* @since 0.8.0
*/
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
public class AzureOpenAiAutoConfigurationIT {
class AzureOpenAiAutoConfigurationIT {

private static String CHAT_MODEL_NAME = "gpt-4o";

Expand Down Expand Up @@ -79,7 +92,7 @@ public class AzureOpenAiAutoConfigurationIT {
"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");

@Test
public void chatCompletion() {
void chatCompletion() {
contextRunner.run(context -> {
AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class);
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage, systemMessage)));
Expand All @@ -88,7 +101,26 @@ public void chatCompletion() {
}

@Test
public void chatCompletionStreaming() {
void httpRequestContainsUserAgentHeader() {
contextRunner.run(context -> {
OpenAIClient openAIClient = context.getBean(OpenAIClient.class);
Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient");
assertThat(serviceClientField).isNotNull();
ReflectionUtils.makeAccessible(serviceClientField);
OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient);
assertThat(oaci).isNotNull();
HttpPipeline httpPipeline = oaci.getHttpPipeline();
HttpResponse httpResponse = httpPipeline
.send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL()))
.block();
assertThat(httpResponse).isNotNull();
HttpHeader httpHeader = httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT);
assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue();
});
}

@Test
void chatCompletionStreaming() {
contextRunner.run(context -> {

AzureOpenAiChatModel chatModel = context.getBean(AzureOpenAiChatModel.class);
Expand Down Expand Up @@ -140,7 +172,7 @@ void transcribe() {
}

@Test
public void chatActivation() {
void chatActivation() {

// Disable the chat auto-configuration.
contextRunner.withPropertyValues("spring.ai.azure.openai.chat.enabled=false").run(context -> {
Expand All @@ -159,7 +191,7 @@ public void chatActivation() {
}

@Test
public void embeddingActivation() {
void embeddingActivation() {

// Disable the embedding auto-configuration.
contextRunner.withPropertyValues("spring.ai.azure.openai.embedding.enabled=false").run(context -> {
Expand All @@ -178,7 +210,7 @@ public void embeddingActivation() {
}

@Test
public void audioTranscriptionActivation() {
void audioTranscriptionActivation() {

// Disable the transcription auto-configuration.
contextRunner.withPropertyValues("spring.ai.azure.openai.audio.transcription.enabled=false").run(context -> {
Expand Down

0 comments on commit e644bf7

Please sign in to comment.