Skip to content

Commit

Permalink
Merge pull request #46 from matlab-deep-learning/Azure_APIVersion_tests
Browse files Browse the repository at this point in the history
Azure api version tests
  • Loading branch information
vpapanasta authored Jun 27, 2024
2 parents 05c861b + fbf1372 commit 5882528
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
6 changes: 3 additions & 3 deletions azureChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@
function this = azureChat(systemPrompt, nvp)
arguments
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
nvp.Endpoint {mustBeNonzeroLengthTextScalar}
nvp.Deployment {mustBeNonzeroLengthTextScalar}
nvp.Endpoint (1,1) string {mustBeNonzeroLengthTextScalar}
nvp.Deployment (1,1) string {mustBeNonzeroLengthTextScalar}
nvp.APIKey {mustBeNonzeroLengthTextScalar}
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
nvp.APIVersion (1,1) {mustBeAPIVersion} = "2024-02-01"
nvp.APIVersion (1,1) string {mustBeAPIVersion} = "2024-02-01"
nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
nvp.TopP {llms.utils.mustBeValidTopP} = 1
nvp.StopSequences {llms.utils.mustBeValidStop} = {}
Expand Down
18 changes: 18 additions & 0 deletions tests/tazureChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
InvalidGenerateInput = iGetInvalidGenerateInput;
InvalidValuesSetters = iGetInvalidValuesSetters;
StringInputs = struct('string',{"hi"},'char',{'hi'},'cellstr',{{'hi'}});
APIVersions = iGetAPIVersions();
end

methods(Test)
Expand Down Expand Up @@ -150,6 +151,19 @@ function keyNotFound(testCase)
unsetenv("AZURE_OPENAI_API_KEY");
testCase.verifyError(@()azureChat, "llms:keyMustBeSpecified");
end

function canUseAPIVersions(testCase, APIVersions)
% Test that we can use different APIVersion value to call
% azureChat.generate

testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT.");
chat = azureChat("APIVersion", APIVersions);

response = testCase.verifyWarningFree(@() generate(chat,"How similar is the DNA of a cat and a tiger?"));
testCase.verifyClass(response,'string');
testCase.verifyGreaterThan(strlength(response),0);

end
end
end

Expand Down Expand Up @@ -446,3 +460,7 @@ function keyNotFound(testCase)
"Input",{{ validMessages "ToolChoice" ["validfunction", "validfunction"] }},...
"Error","MATLAB:validators:mustBeTextScalar"));
end

function apiVersions = iGetAPIVersions()
apiVersions = cellstr(llms.azure.apiVersions);
end

0 comments on commit 5882528

Please sign in to comment.