Skip to content

Commit

Permalink
get basic Azure connection working
Browse files Browse the repository at this point in the history
  • Loading branch information
ccreutzi committed May 27, 2024
1 parent 26b1272 commit 8bd236b
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 38 deletions.
16 changes: 10 additions & 6 deletions +llms/+internal/callAzureChatAPI.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [text, message, response] = callAzureChatAPI(resourceName, deploymentID, messages, functions, nvp)
function [text, message, response] = callAzureChatAPI(endpoint, deploymentID, messages, functions, nvp)
%callOpenAIChatAPI Calls the openAI chat completions API.
%
% MESSAGES and FUNCTIONS should be structs matching the json format
Expand Down Expand Up @@ -52,7 +52,7 @@
% Copyright 2023-2024 The MathWorks, Inc.

arguments
resourceName
endpoint
deploymentID
messages
functions
Expand All @@ -72,11 +72,11 @@
nvp.StreamFun = []
end

END_POINT = "https://" + resourceName + ".openai.azure.com/openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion;
URL = endpoint + "openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion;

parameters = buildParametersCall(messages, functions, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun);
[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, URL, nvp.TimeOut, nvp.StreamFun);

% If call errors, "choices" will not be part of response.Body.Data, instead
% we get response.Body.Data.error
Expand Down Expand Up @@ -108,9 +108,13 @@

parameters.stream = ~isempty(nvp.StreamFun);

parameters.tools = functions;
if ~isempty(functions)
parameters.tools = functions;
end

parameters.tool_choice = nvp.ToolChoice;
if ~isempty(nvp.ToolChoice)
parameters.tool_choice = nvp.ToolChoice;
end

if ~isempty(nvp.Seed)
parameters.seed = nvp.Seed;
Expand Down
5 changes: 3 additions & 2 deletions +llms/+internal/sendRequest.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function [response, streamedText] = sendRequest(parameters, token, endpoint, timeout, streamFun)
%sendRequest Sends a request to an ENDPOINT using PARAMETERS and
% api key TOKEN. TIMEOUT is the nubmer of seconds to wait for initial
% api key TOKEN. TIMEOUT is the number of seconds to wait for initial
% server connection. STREAMFUN is an optional callback function.

% Copyright 2023 The MathWorks, Inc.
Expand All @@ -16,7 +16,8 @@
% Define the headers for the API request

headers = [matlab.net.http.HeaderField('Content-Type', 'application/json')...
matlab.net.http.HeaderField('Authorization', "Bearer " + token)];
matlab.net.http.HeaderField('Authorization', "Bearer " + token)...
matlab.net.http.HeaderField('api-key',token)];

% Define the request message
request = matlab.net.http.RequestMessage('post',headers,parameters);
Expand Down
19 changes: 9 additions & 10 deletions azureChat.m
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
classdef(Sealed) azureChat < llms.internal.textGenerator
%azureChat Chat completion API from Azure.
%
% CHAT = azureChat(resourceName, deploymentID) creates an azureChat object with the
% resource name and deployment ID path parameters required by Azure to establish the connection.
% CHAT = azureChat(endpoint, deploymentID) creates an azureChat object with the
% endpoint and deployment ID path parameters required by Azure to establish the connection.
%
% CHAT = azureChat(systemPrompt) creates an azureChatobject with the
% specified system prompt.
Expand Down Expand Up @@ -74,16 +74,15 @@
% Copyright 2023-2024 The MathWorks, Inc.

properties(SetAccess=private)
ResourceName
DeploymentID
APIVersion
Endpoint (1,1) string
DeploymentID (1,1) string
APIVersion (1,1) string
end


methods
function this = azureChat(resourceName, deploymentID, systemPrompt, nvp)
function this = azureChat(endpoint, deploymentID, systemPrompt, nvp)
arguments
resourceName {mustBeTextScalar}
endpoint {mustBeTextScalar}
deploymentID {mustBeTextScalar}
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
Expand Down Expand Up @@ -123,7 +122,7 @@
end
end

this.ResourceName = resourceName;
this.Endpoint = endpoint;
this.DeploymentID = deploymentID;
this.APIVersion = nvp.APIVersion;
this.ResponseFormat = nvp.ResponseFormat;
Expand Down Expand Up @@ -181,7 +180,7 @@
end

toolChoice = convertToolChoice(this, nvp.ToolChoice);
[text, message, response] = llms.internal.callAzureChatAPI(this.ResourceName, ...
[text, message, response] = llms.internal.callAzureChatAPI(this.Endpoint, ...
this.DeploymentID, messagesStruct, this.FunctionsStruct, ...
ToolChoice=toolChoice, APIVersion = this.APIVersion, Temperature=this.Temperature, ...
TopProbabilityMass=this.TopProbabilityMass, NumCompletions=nvp.NumCompletions,...
Expand Down
39 changes: 19 additions & 20 deletions tests/tazureChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,6 @@

% Copyright 2024 The MathWorks, Inc.

methods (TestClassSetup)
function saveEnvVar(testCase)
% Ensures key is not in environment variable for tests
azureKeyVar = "AZURE_OPENAI_API_KEY";
if isenv(azureKeyVar)
key = getenv(azureKeyVar);
unsetenv(azureKeyVar);
testCase.addTeardown(@(x) setenv(azureKeyVar, x), key);
end
end
end

properties(TestParameter)
InvalidConstructorInput = iGetInvalidConstructorInput;
InvalidGenerateInput = iGetInvalidGenerateInput;
Expand All @@ -24,11 +12,14 @@ function saveEnvVar(testCase)
methods(Test)
% Test methods
function keyNotFound(testCase)
testCase.verifyError(@()azureChat("My_resource", "Deployment"), "llms:keyMustBeSpecified");
import matlab.unittest.fixtures.EnvironmentVariableFixture
testCase.applyFixture(EnvironmentVariableFixture("AZURE_OPENAI_API_KEY","dummy"));
unsetenv("AZURE_OPENAI_API_KEY");
testCase.verifyError(@()azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT")), "llms:keyMustBeSpecified");
end

function constructChatWithAllNVP(testCase)
resourceName = "resource";
endpoint = getenv("AZURE_OPENAI_ENDPOINT");
deploymentID = "hello";
functions = openAIFunction("funName");
temperature = 0;
Expand All @@ -39,7 +30,7 @@ function constructChatWithAllNVP(testCase)
frequenceP = 2;
systemPrompt = "This is a system prompt";
timeout = 3;
chat = azureChat(resourceName, deploymentID, systemPrompt, Tools=functions, ...
chat = azureChat(endpoint, deploymentID, systemPrompt, Tools=functions, ...
Temperature=temperature, TopProbabilityMass=topP, StopSequences=stop, ApiKey=apiKey,...
FrequencyPenalty=frequenceP, PresencePenalty=presenceP, TimeOut=timeout);
testCase.verifyEqual(chat.Temperature, temperature);
Expand All @@ -49,28 +40,36 @@ function constructChatWithAllNVP(testCase)
testCase.verifyEqual(chat.PresencePenalty, presenceP);
end

function doGenerate(testCase)
testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"));
response = testCase.verifyWarningFree(@() generate(chat,"hi"));
testCase.verifyClass(response,'string');
testCase.verifyGreaterThan(strlength(response),0);
end

function verySmallTimeOutErrors(testCase)
chat = azureChat("My_resource", "Deployment", TimeOut=0.0001, ApiKey="false-key");
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), TimeOut=0.0001, ApiKey="false-key");
testCase.verifyError(@()generate(chat, "hi"), "MATLAB:webservices:Timeout")
end

function errorsWhenPassingToolChoiceWithEmptyTools(testCase)
chat = azureChat("My_resource", "Deployment", ApiKey="this-is-not-a-real-key");
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), ApiKey="this-is-not-a-real-key");
testCase.verifyError(@()generate(chat,"input", ToolChoice="bla"), "llms:mustSetFunctionsForCall");
end

function invalidInputsConstructor(testCase, InvalidConstructorInput)
testCase.verifyError(@()azureChat("My_resource", "Deployment", InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);
testCase.verifyError(@()azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);
end

function invalidInputsGenerate(testCase, InvalidGenerateInput)
f = openAIFunction("validfunction");
chat = azureChat("My_resource", "Deployment", Tools=f, ApiKey="this-is-not-a-real-key");
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), Tools=f, ApiKey="this-is-not-a-real-key");
testCase.verifyError(@()generate(chat,InvalidGenerateInput.Input{:}), InvalidGenerateInput.Error);
end

function invalidSetters(testCase, InvalidValuesSetters)
chat = azureChat("My_resource", "Deployment", ApiKey="this-is-not-a-real-key");
chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), ApiKey="this-is-not-a-real-key");
function assignValueToProperty(property, value)
chat.(property) = value;
end
Expand Down

0 comments on commit 8bd236b

Please sign in to comment.