Skip to content

Commit

Permalink
Add custom header support for Azure OpenAI
Browse files Browse the repository at this point in the history
- Adds configuration properties to allow custom header specification
- Implements mechanism to apply custom headers to Azure OpenAI requests
- Enhances flexibility for users to customize API interactions

These changes allow users to add necessary headers for authentication,
tracking, or other purposes when interacting with Azure OpenAI services.

Resolves spring-projects#1284
  • Loading branch information
sobychacko authored and Mark Pollack committed Sep 19, 2024
1 parent e644bf7 commit c67442d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package org.springframework.ai.autoconfigure.azure.openai;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
Expand All @@ -40,6 +42,7 @@
import com.azure.core.credential.KeyCredential;
import com.azure.core.credential.TokenCredential;
import com.azure.core.util.ClientOptions;
import com.azure.core.util.Header;

/**
* @author Piotr Olaszewski
Expand All @@ -57,14 +60,19 @@ public class AzureOpenAiAutoConfiguration {
@Bean
@ConditionalOnMissingBean({ OpenAIClient.class, TokenCredential.class })
public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionProperties) {

if (StringUtils.hasText(connectionProperties.getApiKey())) {

Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty");

Map<String, String> customHeaders = connectionProperties.getCustomHeaders();
List<Header> headers = customHeaders.entrySet()
.stream()
.map(entry -> new Header(entry.getKey(), entry.getValue()))
.collect(Collectors.toList());
ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers);
return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint())
.credential(new AzureKeyCredential(connectionProperties.getApiKey()))
.clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID))
.clientOptions(clientOptions)
.buildClient();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,8 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.azure.openai;

import java.util.HashMap;
import java.util.Map;

import org.springframework.boot.context.properties.ConfigurationProperties;

@ConfigurationProperties(AzureOpenAiConnectionProperties.CONFIG_PREFIX)
Expand All @@ -40,6 +44,8 @@ public class AzureOpenAiConnectionProperties {
*/
private String endpoint;

private Map<String, String> customHeaders = new HashMap<>();

public String getEndpoint() {
return this.endpoint;
}
Expand All @@ -64,4 +70,12 @@ public void setOpenAiApiKey(String openAiApiKey) {
this.openAiApiKey = openAiApiKey;
}

public Map<String, String> getCustomHeaders() {
return customHeaders;
}

public void setCustomHeaders(Map<String, String> customHeaders) {
this.customHeaders = customHeaders;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
*/
package org.springframework.ai.autoconfigure.azure;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.implementation.OpenAIClientImpl;
import com.azure.core.http.*;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
Expand All @@ -34,15 +37,6 @@
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.util.ReflectionUtils;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.implementation.OpenAIClientImpl;
import com.azure.core.http.HttpHeader;
import com.azure.core.http.HttpHeaderName;
import com.azure.core.http.HttpMethod;
import com.azure.core.http.HttpPipeline;
import com.azure.core.http.HttpRequest;
import com.azure.core.http.HttpResponse;
import reactor.core.publisher.Flux;

import java.lang.reflect.Field;
Expand Down

0 comments on commit c67442d

Please sign in to comment.