From 626ddc64af7b2f368c2d35827116e52117d387b9 Mon Sep 17 00:00:00 2001 From: Christopher Creutzig Date: Fri, 17 May 2024 14:28:17 +0100 Subject: [PATCH] Move model capability verification out of openAIChat.m, for maintainability --- +llms/+openai/models.m | 12 ++ +llms/+openai/validateMessageSupported.m | 13 ++ +llms/+openai/validateResponseFormat.m | 16 +++ openAIChat.m | 31 ++--- tests/topenAIChat.m | 169 ++++++++++++++++++++++- 5 files changed, 217 insertions(+), 24 deletions(-) create mode 100644 +llms/+openai/models.m create mode 100644 +llms/+openai/validateMessageSupported.m create mode 100644 +llms/+openai/validateResponseFormat.m diff --git a/+llms/+openai/models.m b/+llms/+openai/models.m new file mode 100644 index 0000000..8bee9b2 --- /dev/null +++ b/+llms/+openai/models.m @@ -0,0 +1,12 @@ +function models = models +%MODELS - supported OpenAI models + +% Copyright 2024 The MathWorks, Inc. + models = [... + "gpt-4o","gpt-4o-2024-05-13",... + "gpt-4-turbo","gpt-4-turbo-2024-04-09",... + "gpt-4","gpt-4-0613", ... + "gpt-3.5-turbo","gpt-3.5-turbo-0125", ... + "gpt-3.5-turbo-1106",... + ]; +end diff --git a/+llms/+openai/validateMessageSupported.m b/+llms/+openai/validateMessageSupported.m new file mode 100644 index 0000000..bc091c9 --- /dev/null +++ b/+llms/+openai/validateMessageSupported.m @@ -0,0 +1,13 @@ +function validateMessageSupported(message, model) +%validateMessageSupported - check that message is supported by model + +% Copyright 2024 The MathWorks, Inc. + + % only certain models accept image inputs + if iscell(message.content) && any(cellfun(@(x) isfield(x,"image_url"), message.content)) + if ~ismember(model,["gpt-4-turbo","gpt-4-turbo-2024-04-09","gpt-4o","gpt-4o-2024-05-13"]) + error("llms:invalidContentTypeForModel", ... 
+ llms.utils.errorMessageCatalog.getMessage("llms:invalidContentTypeForModel", "Image content", model)); + end + end +end diff --git a/+llms/+openai/validateResponseFormat.m b/+llms/+openai/validateResponseFormat.m new file mode 100644 index 0000000..814fe96 --- /dev/null +++ b/+llms/+openai/validateResponseFormat.m @@ -0,0 +1,16 @@ +function validateResponseFormat(format,model) +%validateResponseFormat - validate requested response format is available for selected model +% Not all OpenAI models support JSON output + +% Copyright 2024 The MathWorks, Inc. + + if format == "json" + if ismember(model,["gpt-4","gpt-4-0613"]) + error("llms:invalidOptionAndValueForModel", ... + llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", "ResponseFormat", "json", model)); + else + warning("llms:warningJsonInstruction", ... + llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction")) + end + end +end diff --git a/openAIChat.m b/openAIChat.m index c547fb9..23b4add 100644 --- a/openAIChat.m +++ b/openAIChat.m @@ -114,13 +114,7 @@ arguments systemPrompt {llms.utils.mustBeTextOrEmpty} = [] nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty - nvp.ModelName (1,1) string {mustBeMember(nvp.ModelName,[... - "gpt-4o","gpt-4o-2024-05-13",... - "gpt-4-turbo","gpt-4-turbo-2024-04-09",... - "gpt-4","gpt-4-0613", ... - "gpt-3.5-turbo","gpt-3.5-turbo-0125", ... - "gpt-3.5-turbo-1106",... - ])} = "gpt-3.5-turbo" + nvp.ModelName (1,1) string {mustBeModel} = "gpt-3.5-turbo" nvp.Temperature {mustBeValidTemperature} = 1 nvp.TopProbabilityMass {mustBeValidTopP} = 1 nvp.StopSequences {mustBeValidStop} = {} @@ -160,16 +154,8 @@ this.StopSequences = nvp.StopSequences; % ResponseFormat is only supported in the latest models only - if nvp.ResponseFormat == "json" - if ismember(this.ModelName,["gpt-4","gpt-4-0613"]) - error("llms:invalidOptionAndValueForModel", ... 
- llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", "ResponseFormat", "json", this.ModelName)); - else - warning("llms:warningJsonInstruction", ... - llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction")) - end - - end + llms.openai.validateResponseFormat(nvp.ResponseFormat, this.ModelName); + this.ResponseFormat = nvp.ResponseFormat; this.PresencePenalty = nvp.PresencePenalty; this.FrequencyPenalty = nvp.FrequencyPenalty; @@ -219,12 +205,7 @@ messagesStruct = messages.Messages; end - if iscell(messagesStruct{end}.content) && any(cellfun(@(x) isfield(x,"image_url"), messagesStruct{end}.content)) - if ~ismember(this.ModelName,["gpt-4-turbo","gpt-4-turbo-2024-04-09","gpt-4o","gpt-4o-2024-05-13"]) - error("llms:invalidContentTypeForModel", ... - llms.utils.errorMessageCatalog.getMessage("llms:invalidContentTypeForModel", "Image content", this.ModelName)); - end - end + llms.openai.validateMessageSupported(messagesStruct{end}, this.ModelName); if ~isempty(this.SystemPrompt) messagesStruct = horzcat(this.SystemPrompt, messagesStruct); @@ -334,3 +315,7 @@ function mustBeIntegerOrEmpty(value) mustBeInteger(value) end end + +function mustBeModel(model) + mustBeMember(model,llms.openai.models); +end diff --git a/tests/topenAIChat.m b/tests/topenAIChat.m index 1d43df8..22fa9bf 100644 --- a/tests/topenAIChat.m +++ b/tests/topenAIChat.m @@ -16,6 +16,7 @@ function saveEnvVar(testCase) end properties(TestParameter) + ValidConstructorInput = iGetValidConstructorInput(); InvalidConstructorInput = iGetInvalidConstructorInput(); InvalidGenerateInput = iGetInvalidGenerateInput(); InvalidValuesSetters = iGetInvalidValuesSetters(); @@ -65,6 +66,21 @@ function constructChatWithAllNVP(testCase) testCase.verifyEqual(chat.PresencePenalty, presenceP); end + function validConstructorCalls(testCase,ValidConstructorInput) + if isempty(ValidConstructorInput.ExpectedWarning) + chat = testCase.verifyWarningFree(... 
+ @() openAIChat(ValidConstructorInput.Input{:})); + else + chat = testCase.verifyWarning(... + @() openAIChat(ValidConstructorInput.Input{:}), ... + ValidConstructorInput.ExpectedWarning); + end + properties = ValidConstructorInput.VerifyProperties; + for prop=string(fieldnames(properties)).' + testCase.verifyEqual(chat.(prop),properties.(prop),"Property " + prop); + end + end + function verySmallTimeOutErrors(testCase) chat = openAIChat(TimeOut=0.0001, ApiKey="false-key"); @@ -126,7 +142,6 @@ function noStopSequencesNoMaxNumTokens(testCase) end function createOpenAIChatWithStreamFunc(testCase) - function seen = sf(str) persistent data; if isempty(data) @@ -275,6 +290,158 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) "Error", "MATLAB:notGreaterEqual")); end +function validConstructorInput = iGetValidConstructorInput() +% while it is valid to provide the key via an environment variable, +% this test set does not use that, for easier setup +validFunction = openAIFunction("funName"); +validConstructorInput = struct( ... + "JustKey", struct( ... + "Input",{{"ApiKey","this-is-not-a-real-key"}}, ... + "ExpectedWarning", '', ... + "VerifyProperties", struct( ... + "Temperature", {1}, ... + "TopProbabilityMass", {1}, ... + "StopSequences", {{}}, ... + "PresencePenalty", {0}, ... + "FrequencyPenalty", {0}, ... + "TimeOut", {10}, ... + "FunctionNames", {[]}, ... + "ModelName", {"gpt-3.5-turbo"}, ... + "SystemPrompt", {[]}, ... + "ResponseFormat", {"text"} ... + ) ... + ), ... + "SystemPrompt", struct( ... + "Input",{{"system prompt","ApiKey","this-is-not-a-real-key"}}, ... + "ExpectedWarning", '', ... + "VerifyProperties", struct( ... + "Temperature", {1}, ... + "TopProbabilityMass", {1}, ... + "StopSequences", {{}}, ... + "PresencePenalty", {0}, ... + "FrequencyPenalty", {0}, ... + "TimeOut", {10}, ... + "FunctionNames", {[]}, ... + "ModelName", {"gpt-3.5-turbo"}, ... + "SystemPrompt", {{struct("role","system","content","system prompt")}}, ... 
+ "ResponseFormat", {"text"} ... + ) ... + ), ... + "Temperature", struct( ... + "Input",{{"ApiKey","this-is-not-a-real-key","Temperature",2}}, ... + "ExpectedWarning", '', ... + "VerifyProperties", struct( ... + "Temperature", {2}, ... + "TopProbabilityMass", {1}, ... + "StopSequences", {{}}, ... + "PresencePenalty", {0}, ... + "FrequencyPenalty", {0}, ... + "TimeOut", {10}, ... + "FunctionNames", {[]}, ... + "ModelName", {"gpt-3.5-turbo"}, ... + "SystemPrompt", {[]}, ... + "ResponseFormat", {"text"} ... + ) ... + ), ... + "TopProbabilityMass", struct( ... + "Input",{{"ApiKey","this-is-not-a-real-key","TopProbabilityMass",0.2}}, ... + "ExpectedWarning", '', ... + "VerifyProperties", struct( ... + "Temperature", {1}, ... + "TopProbabilityMass", {0.2}, ... + "StopSequences", {{}}, ... + "PresencePenalty", {0}, ... + "FrequencyPenalty", {0}, ... + "TimeOut", {10}, ... + "FunctionNames", {[]}, ... + "ModelName", {"gpt-3.5-turbo"}, ... + "SystemPrompt", {[]}, ... + "ResponseFormat", {"text"} ... + ) ... + ), ... + "StopSequences", struct( ... + "Input",{{"ApiKey","this-is-not-a-real-key","StopSequences",["foo","bar"]}}, ... + "ExpectedWarning", '', ... + "VerifyProperties", struct( ... + "Temperature", {1}, ... + "TopProbabilityMass", {1}, ... + "StopSequences", {["foo","bar"]}, ... + "PresencePenalty", {0}, ... + "FrequencyPenalty", {0}, ... + "TimeOut", {10}, ... + "FunctionNames", {[]}, ... + "ModelName", {"gpt-3.5-turbo"}, ... + "SystemPrompt", {[]}, ... + "ResponseFormat", {"text"} ... + ) ... + ), ... + "PresencePenalty", struct( ... + "Input",{{"ApiKey","this-is-not-a-real-key","PresencePenalty",0.1}}, ... + "ExpectedWarning", '', ... + "VerifyProperties", struct( ... + "Temperature", {1}, ... + "TopProbabilityMass", {1}, ... + "StopSequences", {{}}, ... + "PresencePenalty", {0.1}, ... + "FrequencyPenalty", {0}, ... + "TimeOut", {10}, ... + "FunctionNames", {[]}, ... + "ModelName", {"gpt-3.5-turbo"}, ... + "SystemPrompt", {[]}, ... 
+ "ResponseFormat", {"text"} ... + ) ... + ), ... + "FrequencyPenalty", struct( ... + "Input",{{"ApiKey","this-is-not-a-real-key","FrequencyPenalty",0.1}}, ... + "ExpectedWarning", '', ... + "VerifyProperties", struct( ... + "Temperature", {1}, ... + "TopProbabilityMass", {1}, ... + "StopSequences", {{}}, ... + "PresencePenalty", {0}, ... + "FrequencyPenalty", {0.1}, ... + "TimeOut", {10}, ... + "FunctionNames", {[]}, ... + "ModelName", {"gpt-3.5-turbo"}, ... + "SystemPrompt", {[]}, ... + "ResponseFormat", {"text"} ... + ) ... + ), ... + "TimeOut", struct( ... + "Input",{{"ApiKey","this-is-not-a-real-key","TimeOut",0.1}}, ... + "ExpectedWarning", '', ... + "VerifyProperties", struct( ... + "Temperature", {1}, ... + "TopProbabilityMass", {1}, ... + "StopSequences", {{}}, ... + "PresencePenalty", {0}, ... + "FrequencyPenalty", {0}, ... + "TimeOut", {0.1}, ... + "FunctionNames", {[]}, ... + "ModelName", {"gpt-3.5-turbo"}, ... + "SystemPrompt", {[]}, ... + "ResponseFormat", {"text"} ... + ) ... + ), ... + "ResponseFormat", struct( ... + "Input",{{"ApiKey","this-is-not-a-real-key","ResponseFormat","json"}}, ... + "ExpectedWarning", "llms:warningJsonInstruction", ... + "VerifyProperties", struct( ... + "Temperature", {1}, ... + "TopProbabilityMass", {1}, ... + "StopSequences", {{}}, ... + "PresencePenalty", {0}, ... + "FrequencyPenalty", {0}, ... + "TimeOut", {10}, ... + "FunctionNames", {[]}, ... + "ModelName", {"gpt-3.5-turbo"}, ... + "SystemPrompt", {[]}, ... + "ResponseFormat", {"json"} ... + ) ... + ) ... + ); +end + function invalidConstructorInput = iGetInvalidConstructorInput() validFunction = openAIFunction("funName"); invalidConstructorInput = struct( ...