Skip to content

Commit

Permalink
Bug fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
debymf committed Oct 21, 2023
1 parent da1c112 commit d6fbeef
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 32 deletions.
2 changes: 1 addition & 1 deletion +llms/+utils/errorMessageCatalog.m
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
catalog("llms:assistantMustHaveTextNameAndArguments") = "Fields 'name' and 'arguments' must be text with one or more characters.";
catalog("llms:mustBeValidIndex") = "Value is larger than the number of elements in Messages ({1}).";
catalog("llms:stopSequencesMustHaveMax4Elements") = "Number of elements must not be larger than 4.";
catalog("llms:keyMustBeSpecified") = "API key not found as environment variable OPEN_API_KEY and not specified via ApiKey parameter.";
catalog("llms:keyMustBeSpecified") = "API key not found as environment variable OPENAI_API_KEY and not specified via ApiKey parameter.";
catalog("llms:mustHaveMessages") = "Value must contain at least one message in Messages.";
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, FunctionCall must not be specified.";
catalog("llms:mustBeMessagesOrTxt") = "Messages must be text with one or more characters or an openAIMessages objects.";
Expand Down
19 changes: 7 additions & 12 deletions openAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@
%
% SystemPrompt - System prompt.
%
% AvailableModels - List of available models.
%
% FunctionNames - Names of the functions that the model can
% request calls.

Expand Down Expand Up @@ -93,25 +91,18 @@
ApiKey
end

properties(Constant)
%AVAILABLEMODELS List of available models.
AvailableModels = ["gpt-4", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613",...
"gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k",...
"gpt-3.5-turbo-16k-0613"]
end

methods
function this = openAIChat(systemPrompt, nvp)
arguments
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
nvp.Functions (1,:) {mustBeA(nvp.Functions, "openAIFunction")} = openAIFunction.empty
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["gpt-4", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613",...
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["gpt-4", "gpt-4-0613", "gpt-4-32k", ...
"gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k",...
"gpt-3.5-turbo-16k-0613"])} = "gpt-3.5-turbo"
nvp.Temperature (1,1) {mustBeValidTemperature} = 1
nvp.TopProbabilityMass (1,1) {mustBeValidTopP} = 1
nvp.StopSequences (1,:) {mustBeValidStop} = {}
nvp.ApiKey (1,1) {mustBeNonzeroLengthText}
nvp.ApiKey {mustBeNonzeroLengthTextScalar}
nvp.PresencePenalty (1,1) {mustBeValidPenalty} = 0
nvp.FrequencyPenalty (1,1) {mustBeValidPenalty} = 0
end
Expand Down Expand Up @@ -249,6 +240,10 @@ function mustBeValidFunctionCall(this, functionCall)
end
end

function mustBeNonzeroLengthTextScalar(content)
mustBeNonzeroLengthText(content)
mustBeTextScalar(content)
end

function [functionsStruct, functionNames] = functionAsStruct(functions)
numFunctions = numel(functions);
Expand All @@ -268,7 +263,7 @@ function mustBeValidMsgs(value)
end
else
try
mustBeNonzeroLengthText(value);
mustBeNonzeroLengthTextScalar(value);
catch ME
error("llms:mustBeMessagesOrTxt", llms.utils.errorMessageCatalog.getMessage("llms:mustBeMessagesOrTxt"));
end
Expand Down
23 changes: 14 additions & 9 deletions openAIMessages.m
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@

arguments
this (1,1) openAIMessages
name (1,1) {mustBeNonzeroLengthText}
content (1,1) {mustBeNonzeroLengthText}
name {mustBeNonzeroLengthTextScalar}
content {mustBeNonzeroLengthTextScalar}
end

newMessage = struct("role", "system", "name", name, "content", content);
newMessage = struct("role", "system", "name", string(name), "content", string(content));
this.Messages{end+1} = newMessage;
end

Expand All @@ -62,10 +62,10 @@

arguments
this (1,1) openAIMessages
content (1,1) {mustBeNonzeroLengthText}
content {mustBeNonzeroLengthTextScalar}
end

newMessage = struct("role", "user", "content", content);
newMessage = struct("role", "user", "content", string(content));
this.Messages{end+1} = newMessage;
end

Expand All @@ -86,11 +86,11 @@

arguments
this (1,1) openAIMessages
name (1,1) {mustBeNonzeroLengthText}
content (1,1) {mustBeNonzeroLengthText}
name {mustBeNonzeroLengthTextScalar}
content {mustBeNonzeroLengthTextScalar}
end

newMessage = struct("role", "function", "name", name, "content", content);
newMessage = struct("role", "function", "name", string(name), "content", string(content));
this.Messages{end+1} = newMessage;
end

Expand Down Expand Up @@ -133,7 +133,7 @@
if isfield(messageStruct, "function_call")
funCall = messageStruct.function_call;
validateAssistantWithFunctionCall(funCall)
this = addAssistantMessage(this,funCall.name, funCall.arguments);
this = addAssistantMessage(this, funCall.name, funCall.arguments);
else
% Simple assistant response
validateRegularAssistant(messageStruct.content);
Expand Down Expand Up @@ -197,6 +197,11 @@
end
end

function mustBeNonzeroLengthTextScalar(content)
mustBeNonzeroLengthText(content)
mustBeTextScalar(content)
end

function validateRegularAssistant(content)
try
mustBeNonzeroLengthText(content)
Expand Down
7 changes: 4 additions & 3 deletions tests/topenAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ function saveEnvVar(testCase)
function generateAcceptsSingleStringAsInput(testCase)
chat = openAIChat(ApiKey="this-is-not-a-real-key");
testCase.verifyWarningFree(@()generate(chat,"This is okay"));
chat = openAIChat(ApiKey='this-is-not-a-real-key');
testCase.verifyWarningFree(@()generate(chat,"This is okay"));
end

function generateAcceptsMessagesAsInput(testCase)
Expand Down Expand Up @@ -307,7 +309,7 @@ function assignValueToProperty(property, value)
...
"InvalidApiKeySize",struct( ...
"Input",{{ "ApiKey" ["abc" "abc"] }},...
"Error","MATLAB:validation:IncompatibleSize"));
"Error","MATLAB:validators:mustBeTextScalar"));
end

function invalidGenerateInput = iGetInvalidGenerateInput
Expand Down Expand Up @@ -354,5 +356,4 @@ function assignValueToProperty(property, value)
"InvalidFunctionCallSize",struct( ...
"Input",{{ validMessages "FunctionCall" ["validfunction", "validfunction"] }},...
"Error","MATLAB:validators:mustBeTextScalar"));
end

end
24 changes: 17 additions & 7 deletions tests/topenAIMessages.m
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
InvalidInputsSystemPrompt = iGetInvalidInputsSystemPrompt;
InvalidInputsResponseMessage = iGetInvalidInputsResponseMessage;
InvalidRemoveMessage = iGetInvalidRemoveMessage;
ValidTextInput = {"This is okay"; 'this is ok'};
end

methods(Test)
Expand All @@ -17,6 +18,15 @@ function constructorStartsWithEmptyMessages(testCase)
testCase.verifyTrue(isempty(msgs.Messages));
end

function differentInputTextAccepted(testCase, ValidTextInput)
msgs = openAIMessages;
testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput));
testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput));
testCase.verifyWarningFree(@()addUserMessage(msgs, ValidTextInput));
testCase.verifyWarningFree(@()addFunctionMessage(msgs, ValidTextInput, ValidTextInput));
end


function systemMessageIsAdded(testCase)
prompt = "Here is a system prompt";
name = "example";
Expand Down Expand Up @@ -56,7 +66,7 @@ function assistantFunctionCallMessageIsAdded(testCase)
msgs = openAIMessages;
args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}";
funCall = struct("name", functionName, "arguments", args);
functionCallPrompt = struct("role", "assistant", "content", [], "function_call", funCall);
functionCallPrompt = struct("role", "assistant", "content", "", "function_call", funCall);
msgs = addResponseMessage(msgs, functionCallPrompt);
testCase.verifyEqual(msgs.Messages{1}, functionCallPrompt);
end
Expand All @@ -65,7 +75,7 @@ function assistantFunctionCallMessageWithoutArgsIsAdded(testCase)
functionName = "functionName";
msgs = openAIMessages;
funCall = struct("name", functionName, "arguments", "{}");
functionCallPrompt = struct("role", "assistant", "content", [], "function_call", funCall);
functionCallPrompt = struct("role", "assistant", "content", "", "function_call", funCall);
msgs = addResponseMessage(msgs, functionCallPrompt);
testCase.verifyEqual(msgs.Messages{1}, functionCallPrompt);
end
Expand Down Expand Up @@ -145,11 +155,11 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage)
...
"NonScalarInputName", ...
struct("Input", {{["name1" "name2"], "content"}}, ...
"Error", "MATLAB:validation:IncompatibleSize"),...
"Error", "MATLAB:validators:mustBeTextScalar"),...
...
"NonScalarInputContent", ...
struct("Input", {{"name", ["content1", "content2"]}}, ...
"Error", "MATLAB:validation:IncompatibleSize"));
"Error", "MATLAB:validators:mustBeTextScalar"));
end

function invalidInputsUserPrompt = iGetInvalidInputsUserPrompt
Expand All @@ -160,7 +170,7 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage)
...
"NonScalarInput", ...
struct("Input", {{["prompt1" "prompt2"]}}, ...
"Error", "MATLAB:validation:IncompatibleSize"), ...
"Error", "MATLAB:validators:mustBeTextScalar"), ...
...
"EmptyInput", ...
struct("Input", {{""}}, ...
Expand All @@ -187,11 +197,11 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage)
...
"NonScalarInputName", ...
struct("Input", {{["name1" "name2"], "content"}}, ...
"Error", "MATLAB:validation:IncompatibleSize"),...
"Error", "MATLAB:validators:mustBeTextScalar"),...
...
"NonScalarInputContent", ...
struct("Input", {{"name", ["content1", "content2"]}}, ...
"Error", "MATLAB:validation:IncompatibleSize"));
"Error", "MATLAB:validators:mustBeTextScalar"));
end

function invalidRemoveMessage = iGetInvalidRemoveMessage
Expand Down

0 comments on commit d6fbeef

Please sign in to comment.