diff --git a/+llms/+utils/errorMessageCatalog.m b/+llms/+utils/errorMessageCatalog.m index ad7ddf9..0908106 100644 --- a/+llms/+utils/errorMessageCatalog.m +++ b/+llms/+utils/errorMessageCatalog.m @@ -45,7 +45,7 @@ catalog("llms:assistantMustHaveTextNameAndArguments") = "Fields 'name' and 'arguments' must be text with one or more characters."; catalog("llms:mustBeValidIndex") = "Value is larger than the number of elements in Messages ({1})."; catalog("llms:stopSequencesMustHaveMax4Elements") = "Number of elements must not be larger than 4."; -catalog("llms:keyMustBeSpecified") = "API key not found as environment variable OPEN_API_KEY and not specified via ApiKey parameter."; +catalog("llms:keyMustBeSpecified") = "API key not found as environment variable OPENAI_API_KEY and not specified via ApiKey parameter."; catalog("llms:mustHaveMessages") = "Value must contain at least one message in Messages."; catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, FunctionCall must not be specified."; catalog("llms:mustBeMessagesOrTxt") = "Messages must be text with one or more characters or an openAIMessages objects."; diff --git a/openAIChat.m b/openAIChat.m index f62f85e..000fa62 100644 --- a/openAIChat.m +++ b/openAIChat.m @@ -52,8 +52,6 @@ % % SystemPrompt - System prompt. % -% AvailableModels - List of available models. -% % FunctionNames - Names of the functions that the model can % request calls. @@ -93,25 +91,18 @@ ApiKey end - properties(Constant) - %AVAILABLEMODELS List of available models. - AvailableModels = ["gpt-4", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613",... - "gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k",... - "gpt-3.5-turbo-16k-0613"] - end - methods function this = openAIChat(systemPrompt, nvp) arguments systemPrompt {llms.utils.mustBeTextOrEmpty} = [] nvp.Functions (1,:) {mustBeA(nvp.Functions, "openAIFunction")} = openAIFunction.empty - nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["gpt-4", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613",... + nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["gpt-4", "gpt-4-0613", "gpt-4-32k", ... "gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k",... "gpt-3.5-turbo-16k-0613"])} = "gpt-3.5-turbo" nvp.Temperature (1,1) {mustBeValidTemperature} = 1 nvp.TopProbabilityMass (1,1) {mustBeValidTopP} = 1 nvp.StopSequences (1,:) {mustBeValidStop} = {} - nvp.ApiKey (1,1) {mustBeNonzeroLengthText} + nvp.ApiKey {mustBeNonzeroLengthTextScalar} nvp.PresencePenalty (1,1) {mustBeValidPenalty} = 0 nvp.FrequencyPenalty (1,1) {mustBeValidPenalty} = 0 end @@ -249,6 +240,10 @@ function mustBeValidFunctionCall(this, functionCall) end end +function mustBeNonzeroLengthTextScalar(content) +mustBeNonzeroLengthText(content) +mustBeTextScalar(content) +end function [functionsStruct, functionNames] = functionAsStruct(functions) numFunctions = numel(functions); @@ -268,7 +263,7 @@ function mustBeValidMsgs(value) end else try - mustBeNonzeroLengthText(value); + mustBeNonzeroLengthTextScalar(value); catch ME error("llms:mustBeMessagesOrTxt", llms.utils.errorMessageCatalog.getMessage("llms:mustBeMessagesOrTxt")); end diff --git a/openAIMessages.m b/openAIMessages.m index 8662c9d..e897f8a 100644 --- a/openAIMessages.m +++ b/openAIMessages.m @@ -39,11 +39,11 @@ arguments this (1,1) openAIMessages - name (1,1) {mustBeNonzeroLengthText} - content (1,1) {mustBeNonzeroLengthText} + name {mustBeNonzeroLengthTextScalar} + content {mustBeNonzeroLengthTextScalar} end - newMessage = struct("role", "system", "name", name, "content", content); + newMessage = struct("role", "system", "name", string(name), "content", string(content)); this.Messages{end+1} = newMessage; end @@ -62,10 +62,10 @@ arguments this (1,1) openAIMessages - content (1,1) {mustBeNonzeroLengthText} + content {mustBeNonzeroLengthTextScalar} end - newMessage = struct("role", "user", "content", content); + newMessage = struct("role", "user", "content", string(content)); this.Messages{end+1} = newMessage; end @@ -86,11 +86,11 @@ arguments this (1,1) openAIMessages - name (1,1) {mustBeNonzeroLengthText} - content (1,1) {mustBeNonzeroLengthText} + name {mustBeNonzeroLengthTextScalar} + content {mustBeNonzeroLengthTextScalar} end - newMessage = struct("role", "function", "name", name, "content", content); + newMessage = struct("role", "function", "name", string(name), "content", string(content)); this.Messages{end+1} = newMessage; end @@ -133,7 +133,7 @@ if isfield(messageStruct, "function_call") funCall = messageStruct.function_call; validateAssistantWithFunctionCall(funCall) - this = addAssistantMessage(this,funCall.name, funCall.arguments); + this = addAssistantMessage(this, funCall.name, funCall.arguments); else % Simple assistant response validateRegularAssistant(messageStruct.content); @@ -197,6 +197,11 @@ end end +function mustBeNonzeroLengthTextScalar(content) +mustBeNonzeroLengthText(content) +mustBeTextScalar(content) +end + function validateRegularAssistant(content) try mustBeNonzeroLengthText(content) diff --git a/tests/topenAIChat.m b/tests/topenAIChat.m index 2f17215..5d0732a 100644 --- a/tests/topenAIChat.m +++ b/tests/topenAIChat.m @@ -26,6 +26,8 @@ function saveEnvVar(testCase) function generateAcceptsSingleStringAsInput(testCase) chat = openAIChat(ApiKey="this-is-not-a-real-key"); testCase.verifyWarningFree(@()generate(chat,"This is okay")); + chat = openAIChat(ApiKey='this-is-not-a-real-key'); + testCase.verifyWarningFree(@()generate(chat,"This is okay")); end function generateAcceptsMessagesAsInput(testCase) @@ -307,7 +309,7 @@ function assignValueToProperty(property, value) ... "InvalidApiKeySize",struct( ... "Input",{{ "ApiKey" ["abc" "abc"] }},... - "Error","MATLAB:validation:IncompatibleSize")); + "Error","MATLAB:validators:mustBeTextScalar")); end function invalidGenerateInput = iGetInvalidGenerateInput @@ -354,5 +356,4 @@ function assignValueToProperty(property, value) "InvalidFunctionCallSize",struct( ... "Input",{{ validMessages "FunctionCall" ["validfunction", "validfunction"] }},... "Error","MATLAB:validators:mustBeTextScalar")); -end - +end \ No newline at end of file diff --git a/tests/topenAIMessages.m b/tests/topenAIMessages.m index e465b03..39db370 100644 --- a/tests/topenAIMessages.m +++ b/tests/topenAIMessages.m @@ -9,6 +9,7 @@ InvalidInputsSystemPrompt = iGetInvalidInputsSystemPrompt; InvalidInputsResponseMessage = iGetInvalidInputsResponseMessage; InvalidRemoveMessage = iGetInvalidRemoveMessage; + ValidTextInput = {"This is okay"; 'this is ok'}; end methods(Test) @@ -17,6 +18,15 @@ function constructorStartsWithEmptyMessages(testCase) testCase.verifyTrue(isempty(msgs.Messages)); end + function differentInputTextAccepted(testCase, ValidTextInput) + msgs = openAIMessages; + testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput)); + testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput)); + testCase.verifyWarningFree(@()addUserMessage(msgs, ValidTextInput)); + testCase.verifyWarningFree(@()addFunctionMessage(msgs, ValidTextInput, ValidTextInput)); + end + + function systemMessageIsAdded(testCase) prompt = "Here is a system prompt"; name = "example"; @@ -56,7 +66,7 @@ function assistantFunctionCallMessageIsAdded(testCase) msgs = openAIMessages; args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}"; funCall = struct("name", functionName, "arguments", args); - functionCallPrompt = struct("role", "assistant", "content", [], "function_call", funCall); + functionCallPrompt = struct("role", "assistant", "content", "", "function_call", funCall); msgs = addResponseMessage(msgs, functionCallPrompt); testCase.verifyEqual(msgs.Messages{1}, functionCallPrompt); end @@ -65,7 +75,7 @@ function assistantFunctionCallMessageWithoutArgsIsAdded(testCase) functionName = "functionName"; msgs = openAIMessages; funCall = struct("name", functionName, "arguments", "{}"); - functionCallPrompt = struct("role", "assistant", "content", [], "function_call", funCall); + functionCallPrompt = struct("role", "assistant", "content", "", "function_call", funCall); msgs = addResponseMessage(msgs, functionCallPrompt); testCase.verifyEqual(msgs.Messages{1}, functionCallPrompt); end @@ -145,11 +155,11 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage) ... "NonScalarInputName", ... struct("Input", {{["name1" "name2"], "content"}}, ... - "Error", "MATLAB:validation:IncompatibleSize"),... + "Error", "MATLAB:validators:mustBeTextScalar"),... ... "NonScalarInputContent", ... struct("Input", {{"name", ["content1", "content2"]}}, ... - "Error", "MATLAB:validation:IncompatibleSize")); + "Error", "MATLAB:validators:mustBeTextScalar")); end function invalidInputsUserPrompt = iGetInvalidInputsUserPrompt @@ -160,7 +170,7 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage) ... "NonScalarInput", ... struct("Input", {{["prompt1" "prompt2"]}}, ... - "Error", "MATLAB:validation:IncompatibleSize"), ... + "Error", "MATLAB:validators:mustBeTextScalar"), ... ... "EmptyInput", ... struct("Input", {{""}}, ... @@ -187,11 +197,11 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage) ... "NonScalarInputName", ... struct("Input", {{["name1" "name2"], "content"}}, ... - "Error", "MATLAB:validation:IncompatibleSize"),... + "Error", "MATLAB:validators:mustBeTextScalar"),... ... "NonScalarInputContent", ... struct("Input", {{"name", ["content1", "content2"]}}, ... - "Error", "MATLAB:validation:IncompatibleSize")); + "Error", "MATLAB:validators:mustBeTextScalar")); end function invalidRemoveMessage = iGetInvalidRemoveMessage