Skip to content

Commit

Permalink
parameterize getApiKeyFromNvpOrEnv, allowing different env variables …
Browse files Browse the repository at this point in the history
…for API keys
  • Loading branch information
ccreutzi committed May 27, 2024
1 parent ccd6961 commit 26b1272
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 22 deletions.
22 changes: 11 additions & 11 deletions +llms/+internal/getApiKeyFromNvpOrEnv.m
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
function key = getApiKeyFromNvpOrEnv(nvp)
function key = getApiKeyFromNvpOrEnv(nvp,envVarName)
% This function is undocumented and will change in a future release

%getApiKeyFromNvpOrEnv Retrieves an API key from a Name-Value Pair struct or environment variable.
%
% This function takes a struct nvp containing name-value pairs and checks
% if it contains a field called "ApiKey". If the field is not found,
% the function attempts to retrieve the API key from an environment
% variable called "OPENAI_API_KEY". If both methods fail, the function
% throws an error.
% This function takes a struct nvp containing name-value pairs and checks if
% it contains a field called "ApiKey". If the field is not found, the
% function attempts to retrieve the API key from an environment variable
% whose name is given as the second argument. If both methods fail, the
% function throws an error.

% Copyright 2023 The MathWorks, Inc.
% Copyright 2023-2024 The MathWorks, Inc.

if isfield(nvp, "ApiKey")
key = nvp.ApiKey;
else
if isenv("OPENAI_API_KEY")
key = getenv("OPENAI_API_KEY");
if isenv(envVarName)
key = getenv(envVarName);
else
error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified"));
error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified", envVarName));
end
end
end
end
4 changes: 2 additions & 2 deletions +llms/+utils/errorMessageCatalog.m
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
catalog("llms:assistantMustHaveTextNameAndArguments") = "Fields 'name' and 'arguments' must be text with one or more characters.";
catalog("llms:mustBeValidIndex") = "Value is larger than the number of elements in Messages ({1}).";
catalog("llms:stopSequencesMustHaveMax4Elements") = "Number of elements must not be larger than 4.";
catalog("llms:keyMustBeSpecified") = "API key not found as environment variable OPENAI_API_KEY and not specified via ApiKey parameter.";
catalog("llms:keyMustBeSpecified") = "API key not found as environment variable {1} and not specified via ApiKey parameter.";
catalog("llms:mustHaveMessages") = "Value must contain at least one message in Messages.";
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, ToolChoice must not be specified.";
catalog("llms:mustBeMessagesOrTxt") = "Messages must be text with one or more characters or an openAIMessages objects.";
Expand All @@ -58,4 +58,4 @@
catalog("llms:invalidOptionsForAzureBackEnd") = "The parameter Model Name is not compatible with Azure.";
catalog("llms:apiReturnedError") = "OpenAI API Error: {1}";
catalog("llms:dimensionsMustBeSmallerThan") = "Dimensions must be less than or equal to {1}.";
end
end
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ jobs:
- name: Run tests and generate artifacts
env:
OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
AZURE_OPENAI_DEPLOYMENT: ${{ secrets.AZURE_DEPLOYMENT }}
AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_ENDPOINT }}
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_KEY }}
uses: matlab-actions/run-tests@v2
with:
test-results-junit: test-results/results.xml
Expand Down
2 changes: 1 addition & 1 deletion azureChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
this.StopSequences = nvp.StopSequences;
this.PresencePenalty = nvp.PresencePenalty;
this.FrequencyPenalty = nvp.FrequencyPenalty;
this.ApiKey = llms.internal.getApiKeyFromNvpOrEnv(nvp);
this.ApiKey = llms.internal.getApiKeyFromNvpOrEnv(nvp,"AZURE_OPENAI_API_KEY");
this.TimeOut = nvp.TimeOut;
end

Expand Down
2 changes: 1 addition & 1 deletion extractOpenAIEmbeddings.m
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

END_POINT = "https://api.openai.com/v1/embeddings";

key = llms.internal.getApiKeyFromNvpOrEnv(nvp);
key = llms.internal.getApiKeyFromNvpOrEnv(nvp,"OPENAI_API_KEY");

parameters = struct("input",text,"model",nvp.ModelName);

Expand Down
2 changes: 1 addition & 1 deletion openAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@

this.PresencePenalty = nvp.PresencePenalty;
this.FrequencyPenalty = nvp.FrequencyPenalty;
this.ApiKey = llms.internal.getApiKeyFromNvpOrEnv(nvp);
this.ApiKey = llms.internal.getApiKeyFromNvpOrEnv(nvp,"OPENAI_API_KEY");
this.TimeOut = nvp.TimeOut;
end

Expand Down
2 changes: 1 addition & 1 deletion openAIImages.m
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
end

this.ModelName = nvp.ModelName;
this.ApiKey = llms.internal.getApiKeyFromNvpOrEnv(nvp);
this.ApiKey = llms.internal.getApiKeyFromNvpOrEnv(nvp,"OPENAI_API_KEY");
this.TimeOut = nvp.TimeOut;
end

Expand Down
10 changes: 5 additions & 5 deletions tests/tazureChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
methods (TestClassSetup)
function saveEnvVar(testCase)
% Ensures key is not in environment variable for tests
openAIEnvVar = "OPENAI_API_KEY";
if isenv(openAIEnvVar)
key = getenv(openAIEnvVar);
unsetenv(openAIEnvVar);
testCase.addTeardown(@(x) setenv(openAIEnvVar, x), key);
azureKeyVar = "AZURE_OPENAI_API_KEY";
if isenv(azureKeyVar)
key = getenv(azureKeyVar);
unsetenv(azureKeyVar);
testCase.addTeardown(@(x) setenv(azureKeyVar, x), key);
end
end
end
Expand Down

0 comments on commit 26b1272

Please sign in to comment.