From 8bd236b5f48b00df79edf1b3bf7bfa17db08e095 Mon Sep 17 00:00:00 2001 From: Christopher Creutzig Date: Mon, 27 May 2024 08:22:11 +0100 Subject: [PATCH] get basic Azure connection working --- +llms/+internal/callAzureChatAPI.m | 16 +++++++----- +llms/+internal/sendRequest.m | 5 ++-- azureChat.m | 19 +++++++-------- tests/tazureChat.m | 39 +++++++++++++++--------------- 4 files changed, 41 insertions(+), 38 deletions(-) diff --git a/+llms/+internal/callAzureChatAPI.m b/+llms/+internal/callAzureChatAPI.m index 856615a..e2aef9c 100644 --- a/+llms/+internal/callAzureChatAPI.m +++ b/+llms/+internal/callAzureChatAPI.m @@ -1,4 +1,4 @@ -function [text, message, response] = callAzureChatAPI(resourceName, deploymentID, messages, functions, nvp) +function [text, message, response] = callAzureChatAPI(endpoint, deploymentID, messages, functions, nvp) %callOpenAIChatAPI Calls the openAI chat completions API. % % MESSAGES and FUNCTIONS should be structs matching the json format @@ -52,7 +52,7 @@ % Copyright 2023-2024 The MathWorks, Inc. arguments - resourceName + endpoint deploymentID messages functions @@ -72,11 +72,11 @@ nvp.StreamFun = [] end -END_POINT = "https://" + resourceName + ".openai.azure.com/openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion; +URL = endpoint + "openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion; parameters = buildParametersCall(messages, functions, nvp); -[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun); +[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, URL, nvp.TimeOut, nvp.StreamFun); % If call errors, "choices" will not be part of response.Body.Data, instead % we get response.Body.Data.error @@ -108,9 +108,13 @@ parameters.stream = ~isempty(nvp.StreamFun); -parameters.tools = functions; +if ~isempty(functions) + parameters.tools = functions; +end -parameters.tool_choice = nvp.ToolChoice; +if ~isempty(nvp.ToolChoice) + parameters.tool_choice = nvp.ToolChoice; +end if ~isempty(nvp.Seed) parameters.seed = nvp.Seed; diff --git a/+llms/+internal/sendRequest.m b/+llms/+internal/sendRequest.m index 631c2dc..5832bbb 100644 --- a/+llms/+internal/sendRequest.m +++ b/+llms/+internal/sendRequest.m @@ -1,6 +1,6 @@ function [response, streamedText] = sendRequest(parameters, token, endpoint, timeout, streamFun) %sendRequest Sends a request to an ENDPOINT using PARAMETERS and -% api key TOKEN. TIMEOUT is the nubmer of seconds to wait for initial +% api key TOKEN. TIMEOUT is the number of seconds to wait for initial % server connection. STREAMFUN is an optional callback function. % Copyright 2023 The MathWorks, Inc. @@ -16,7 +16,8 @@ % Define the headers for the API request headers = [matlab.net.http.HeaderField('Content-Type', 'application/json')... - matlab.net.http.HeaderField('Authorization', "Bearer " + token)]; + matlab.net.http.HeaderField('Authorization', "Bearer " + token)... + matlab.net.http.HeaderField('api-key',token)]; % Define the request message request = matlab.net.http.RequestMessage('post',headers,parameters); diff --git a/azureChat.m b/azureChat.m index e0374ec..8c6acc1 100644 --- a/azureChat.m +++ b/azureChat.m @@ -1,8 +1,8 @@ classdef(Sealed) azureChat < llms.internal.textGenerator %azureChat Chat completion API from Azure. % -% CHAT = azureChat(resourceName, deploymentID) creates an azureChat object with the -% resource name and deployment ID path parameters required by Azure to establish the connection. +% CHAT = azureChat(endpoint, deploymentID) creates an azureChat object with the +% endpoint and deployment ID path parameters required by Azure to establish the connection. % % CHAT = azureChat(systemPrompt) creates an azureChatobject with the % specified system prompt. @@ -74,16 +74,15 @@ % Copyright 2023-2024 The MathWorks, Inc. properties(SetAccess=private) - ResourceName - DeploymentID - APIVersion + Endpoint (1,1) string + DeploymentID (1,1) string + APIVersion (1,1) string end - methods - function this = azureChat(resourceName, deploymentID, systemPrompt, nvp) + function this = azureChat(endpoint, deploymentID, systemPrompt, nvp) arguments - resourceName {mustBeTextScalar} + endpoint {mustBeTextScalar} deploymentID {mustBeTextScalar} systemPrompt {llms.utils.mustBeTextOrEmpty} = [] nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty @@ -123,7 +122,7 @@ end end - this.ResourceName = resourceName; + this.Endpoint = endpoint; this.DeploymentID = deploymentID; this.APIVersion = nvp.APIVersion; this.ResponseFormat = nvp.ResponseFormat; @@ -181,7 +180,7 @@ end toolChoice = convertToolChoice(this, nvp.ToolChoice); - [text, message, response] = llms.internal.callAzureChatAPI(this.ResourceName, ... + [text, message, response] = llms.internal.callAzureChatAPI(this.Endpoint, ... this.DeploymentID, messagesStruct, this.FunctionsStruct, ... ToolChoice=toolChoice, APIVersion = this.APIVersion, Temperature=this.Temperature, ... TopProbabilityMass=this.TopProbabilityMass, NumCompletions=nvp.NumCompletions,... diff --git a/tests/tazureChat.m b/tests/tazureChat.m index ea7d050..f597ec6 100644 --- a/tests/tazureChat.m +++ b/tests/tazureChat.m @@ -3,18 +3,6 @@ % Copyright 2024 The MathWorks, Inc. - methods (TestClassSetup) - function saveEnvVar(testCase) - % Ensures key is not in environment variable for tests - azureKeyVar = "AZURE_OPENAI_API_KEY"; - if isenv(azureKeyVar) - key = getenv(azureKeyVar); - unsetenv(azureKeyVar); - testCase.addTeardown(@(x) setenv(azureKeyVar, x), key); - end - end - end - properties(TestParameter) InvalidConstructorInput = iGetInvalidConstructorInput; InvalidGenerateInput = iGetInvalidGenerateInput; @@ -24,11 +12,14 @@ function saveEnvVar(testCase) methods(Test) % Test methods function keyNotFound(testCase) - testCase.verifyError(@()azureChat("My_resource", "Deployment"), "llms:keyMustBeSpecified"); + import matlab.unittest.fixtures.EnvironmentVariableFixture + testCase.applyFixture(EnvironmentVariableFixture("AZURE_OPENAI_API_KEY","dummy")); + unsetenv("AZURE_OPENAI_API_KEY"); + testCase.verifyError(@()azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT")), "llms:keyMustBeSpecified"); end function constructChatWithAllNVP(testCase) - resourceName = "resource"; + endpoint = getenv("AZURE_OPENAI_ENDPOINT"); deploymentID = "hello"; functions = openAIFunction("funName"); temperature = 0; @@ -39,7 +30,7 @@ function constructChatWithAllNVP(testCase) frequenceP = 2; systemPrompt = "This is a system prompt"; timeout = 3; - chat = azureChat(resourceName, deploymentID, systemPrompt, Tools=functions, ... + chat = azureChat(endpoint, deploymentID, systemPrompt, Tools=functions, ... Temperature=temperature, TopProbabilityMass=topP, StopSequences=stop, ApiKey=apiKey,... FrequencyPenalty=frequenceP, PresencePenalty=presenceP, TimeOut=timeout); testCase.verifyEqual(chat.Temperature, temperature); @@ -49,28 +40,36 @@ function constructChatWithAllNVP(testCase) testCase.verifyEqual(chat.PresencePenalty, presenceP); end + function doGenerate(testCase) + testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT."); + chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT")); + response = testCase.verifyWarningFree(@() generate(chat,"hi")); + testCase.verifyClass(response,'string'); + testCase.verifyGreaterThan(strlength(response),0); + end + function verySmallTimeOutErrors(testCase) - chat = azureChat("My_resource", "Deployment", TimeOut=0.0001, ApiKey="false-key"); + chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), TimeOut=0.0001, ApiKey="false-key"); testCase.verifyError(@()generate(chat, "hi"), "MATLAB:webservices:Timeout") end function errorsWhenPassingToolChoiceWithEmptyTools(testCase) - chat = azureChat("My_resource", "Deployment", ApiKey="this-is-not-a-real-key"); + chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), ApiKey="this-is-not-a-real-key"); testCase.verifyError(@()generate(chat,"input", ToolChoice="bla"), "llms:mustSetFunctionsForCall"); end function invalidInputsConstructor(testCase, InvalidConstructorInput) - testCase.verifyError(@()azureChat("My_resource", "Deployment", InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error); + testCase.verifyError(@()azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error); end function invalidInputsGenerate(testCase, InvalidGenerateInput) f = openAIFunction("validfunction"); - chat = azureChat("My_resource", "Deployment", Tools=f, ApiKey="this-is-not-a-real-key"); + chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), Tools=f, ApiKey="this-is-not-a-real-key"); testCase.verifyError(@()generate(chat,InvalidGenerateInput.Input{:}), InvalidGenerateInput.Error); end function invalidSetters(testCase, InvalidValuesSetters) - chat = azureChat("My_resource", "Deployment", ApiKey="this-is-not-a-real-key"); + chat = azureChat(getenv("AZURE_OPENAI_ENDPOINT"), getenv("AZURE_OPENAI_DEPLOYMENT"), ApiKey="this-is-not-a-real-key"); function assignValueToProperty(property, value) chat.(property) = value; end