diff --git a/azureChat.m b/azureChat.m index 313c0f4..05fd467 100644 --- a/azureChat.m +++ b/azureChat.m @@ -174,6 +174,43 @@ % Seed - An integer value to use to obtain % reproducible responses % + % Temperature - Temperature value for controlling the randomness + % of the output. The default value is CHAT.Temperature; + % higher values increase the randomness (in some sense, + % the “creativity”) of outputs, lower values + % reduce it. Setting Temperature=0 removes + % randomness from the output altogether. + % + % TopP - Top probability mass value for controlling the + % diversity of the output. Default value is CHAT.TopP; + % lower values imply that only the more likely + % words can appear in any particular place. + % This is also known as top-p sampling. + % + % StopSequences - Vector of strings that when encountered, will + % stop the generation of tokens. Default + % value is CHAT.StopSequences. + % Example: ["The end.", "And that's all she wrote."] + % + % ResponseFormat - The format of response the model returns. + % Default value is CHAT.ResponseFormat. + % "text" | "json" + % + % PresencePenalty - Penalty value for using a token in the response + % that has already been used. Default value is + % CHAT.PresencePenalty. + % Higher values reduce repetition of words in the output. + % + % FrequencyPenalty - Penalty value for using a token that is frequent + % in the output. Default value is CHAT.FrequencyPenalty. + % Higher values reduce repetition of words in the output. + % + % StreamFun - Function to callback when streaming the result. + % Default value is CHAT.StreamFun. + % + % TimeOut - Connection Timeout in seconds. Default value is CHAT.TimeOut. + % + % % Currently, GPT-4 Turbo with vision does not support the message.name % parameter, functions/tools, response_format parameter, stop % sequences, and max_tokens @@ -181,6 +218,15 @@ arguments this (1,1) azureChat messages {mustBeValidMsgs} + nvp.Temperature {llms.utils.mustBeValidTemperature} = this.Temperature + nvp.TopP {llms.utils.mustBeValidProbability} = this.TopP + nvp.StopSequences {llms.utils.mustBeValidStop} = this.StopSequences + nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = this.ResponseFormat + nvp.APIKey {mustBeNonzeroLengthTextScalar} = this.APIKey + nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = this.PresencePenalty + nvp.FrequencyPenalty {llms.utils.mustBeValidPenalty} = this.FrequencyPenalty + nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = this.TimeOut + nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')} nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1 nvp.MaxNumTokens (1,1) {mustBePositive} = inf nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = [] @@ -199,15 +245,22 @@ end toolChoice = convertToolChoice(this, nvp.ToolChoice); + + if isfield(nvp,"StreamFun") + streamFun = nvp.StreamFun; + else + streamFun = this.StreamFun; + end + try [text, message, response] = llms.internal.callAzureChatAPI(this.Endpoint, ... this.DeploymentID, messagesStruct, this.FunctionsStruct, ... - ToolChoice=toolChoice, APIVersion = this.APIVersion, Temperature=this.Temperature, ... - TopP=this.TopP, NumCompletions=nvp.NumCompletions,... - StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ... - PresencePenalty=this.PresencePenalty, FrequencyPenalty=this.FrequencyPenalty, ... - ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ... 
- APIKey=this.APIKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun); + ToolChoice=toolChoice, APIVersion = this.APIVersion, Temperature=nvp.Temperature, ... + TopP=nvp.TopP, NumCompletions=nvp.NumCompletions,... + StopSequences=nvp.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ... + PresencePenalty=nvp.PresencePenalty, FrequencyPenalty=nvp.FrequencyPenalty, ... + ResponseFormat=nvp.ResponseFormat,Seed=nvp.Seed, ... + APIKey=nvp.APIKey,TimeOut=nvp.TimeOut,StreamFun=streamFun); catch ME if ismember(ME.identifier,... ["MATLAB:webservices:UnknownHost","MATLAB:webservices:Timeout"]) diff --git a/functionSignatures.json b/functionSignatures.json index fcafc31..a380251 100644 --- a/functionSignatures.json +++ b/functionSignatures.json @@ -31,7 +31,17 @@ {"name":"NumCompletions","kind":"namevalue","type":["numeric","scalar","integer","positive"]}, {"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]}, {"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",this.FunctionNames]"}, - {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]} + {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]}, + {"name":"ModelName","kind":"namevalue","type":"choices=llms.openai.models"}, + {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]}, + {"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]}, + {"name":"StopSequences","kind":"namevalue","type":["string","vector"]}, + {"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"}, + {"name":"APIKey","kind":"namevalue","type":["string","scalar"]}, + {"name":"PresencePenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]}, + {"name":"FrequencyPenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]}, + {"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]}, + {"name":"StreamFun","kind":"namevalue","type":"function_handle"} ], "outputs": [ @@ -73,7 +83,16 @@ {"name":"NumCompletions","kind":"namevalue","type":["numeric","scalar","integer","positive"]}, {"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]}, {"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",this.FunctionNames]"}, - {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]} + {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]}, + {"name":"APIKey","kind":"namevalue","type":["string","scalar"]}, + {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]}, + {"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]}, + {"name":"StopSequences","kind":"namevalue","type":["string","vector"]}, + {"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"}, + {"name":"PresencePenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]}, + {"name":"FrequencyPenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]}, + {"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]}, + {"name":"StreamFun","kind":"namevalue","type":"function_handle"} ], "outputs": [ @@ -90,12 +109,14 @@ {"name":"systemPrompt","kind":"ordered","type":["string","scalar"]}, {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]}, {"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]}, + {"name":"MinP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]}, 
{"name":"TopK","kind":"namevalue","type":["numeric","scalar","integer",">=1"]}, {"name":"StopSequences","kind":"namevalue","type":["string","vector"]}, {"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"}, {"name":"TailFreeSamplingZ","kind":"namevalue","type":["numeric","scalar","real"]}, {"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]}, - {"name":"StreamFun","kind":"namevalue","type":"function_handle"} + {"name":"StreamFun","kind":"namevalue","type":"function_handle"}, + {"name":"Endpoint","kind":"namevalue","type":["string","scalar"]} ], "outputs": [ @@ -109,7 +130,18 @@ {"name":"this","kind":"required","type":["ollamaChat","scalar"]}, {"name":"messages","kind":"required","type":[["messageHistory","row"],["string","scalar"]]}, {"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]}, - {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]} + {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]}, + {"name":"Model","kind":"namevalue","type":"choices=ollamaChat.models"}, + {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]}, + {"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]}, + {"name":"MinP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]}, + {"name":"TopK","kind":"namevalue","type":["numeric","scalar","integer",">=1"]}, + {"name":"StopSequences","kind":"namevalue","type":["string","vector"]}, + {"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"}, + {"name":"TailFreeSamplingZ","kind":"namevalue","type":["numeric","scalar","real"]}, + {"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]}, + {"name":"StreamFun","kind":"namevalue","type":"function_handle"}, + {"name":"Endpoint","kind":"namevalue","type":["string","scalar"]} ], "outputs": [ diff --git a/ollamaChat.m b/ollamaChat.m index 031b7f3..a6d7419 100644 --- a/ollamaChat.m +++ b/ollamaChat.m @@ -47,7 +47,6 @@ % value is empty. % Example: ["The end.", "And that's all she wrote."] % -% % ResponseFormat - The format of response the model returns. % "text" (default) | "json" % @@ -128,17 +127,79 @@ % [TEXT, MESSAGE, RESPONSE] = generate(__, Name=Value) specifies additional options % using one or more name-value arguments: % - % MaxNumTokens - Maximum number of tokens in the generated response. - % Default value is inf. + % MaxNumTokens - Maximum number of tokens in the generated response. + % Default value is inf. + % + % Seed - An integer value to use to obtain + % reproducible responses + % + % Model - Model name (as expected by Ollama server). + % Default value is CHAT.Model. + % + % Temperature - Temperature value for controlling the randomness + % of the output. Default value is CHAT.Temperature. + % Higher values increase the randomness (in some + % sense, the “creativity”) of outputs, lower + % values reduce it. Setting Temperature=0 removes + % randomness from the output altogether. + % + % TopP - Top probability mass value for controlling the + % diversity of the output. Default value is CHAT.TopP; + % lower values imply that only the more likely + % words can appear in any particular place. + % This is also known as top-p sampling. + % + % MinP - Minimum probability ratio for controlling the + % diversity of the output. Default value is CHAT.MinP; + % higher values imply that only the more likely + % words can appear in any particular place. + % This is also known as min-p sampling. 
+ %
+ % TopK - Maximum number of most likely tokens that are
+ % considered for output. Default is CHAT.TopK.
+ % Smaller values reduce diversity in the output.
+ %
+ % TailFreeSamplingZ - Reduce the use of less probable tokens, based on
+ % the second-order differences of ordered
+ % probabilities.
+ % Default value is CHAT.TailFreeSamplingZ.
+ % Lower values reduce diversity, with
+ % some authors recommending values
+ % around 0.95. Tail-free sampling is
+ % slower than using TopP or TopK.
+ %
+ % StopSequences - Vector of strings that when encountered, will
+ % stop the generation of tokens. Default
+ % value is CHAT.StopSequences.
+ % Example: ["The end.", "And that's all she wrote."]
+ %
+ % ResponseFormat - The format of response the model returns.
+ % The default value is CHAT.ResponseFormat.
+ % "text" | "json"
+ %
+ % StreamFun - Function to callback when streaming the
+ % result. The default value is CHAT.StreamFun.
+ %
+ % TimeOut - Connection Timeout in seconds. Default is CHAT.TimeOut.
 %
- % Seed - An integer value to use to obtain
- % reproducible responses
 arguments
 this (1,1) ollamaChat
- messages {mustBeValidMsgs}
+ messages {mustBeValidMsgs}
+ nvp.Model {mustBeTextScalar} = this.Model
+ nvp.Temperature {llms.utils.mustBeValidTemperature} = this.Temperature
+ nvp.TopP {llms.utils.mustBeValidProbability} = this.TopP
+ nvp.MinP {llms.utils.mustBeValidProbability} = this.MinP
+ nvp.TopK (1,1) {mustBeReal,mustBePositive} = this.TopK
+ nvp.StopSequences {llms.utils.mustBeValidStop} = this.StopSequences
+ nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = this.ResponseFormat
+ nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = this.TimeOut
+ nvp.TailFreeSamplingZ (1,1) {mustBeReal} = this.TailFreeSamplingZ
+ nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
+ nvp.Endpoint (1,1) string = this.Endpoint
 nvp.MaxNumTokens (1,1) {mustBePositive} = inf
- nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
+ nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
 end
 messages = convertCharsToStrings(messages);
@@ -152,15 +213,21 @@
 messagesStruct = horzcat(this.SystemPrompt, messagesStruct);
 end
+ if isfield(nvp,"StreamFun")
+ streamFun = nvp.StreamFun;
+ else
+ streamFun = this.StreamFun;
+ end
+
 [text, message, response] = llms.internal.callOllamaChatAPI(...
- this.Model, messagesStruct, ...
- Temperature=this.Temperature, ...
- TopP=this.TopP, MinP=this.MinP, TopK=this.TopK,...
- TailFreeSamplingZ=this.TailFreeSamplingZ,...
- StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
- ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
- TimeOut=this.TimeOut, StreamFun=this.StreamFun, ...
- Endpoint=this.Endpoint);
+ nvp.Model, messagesStruct, ...
+ Temperature=nvp.Temperature, ...
+ TopP=nvp.TopP, MinP=nvp.MinP, TopK=nvp.TopK,...
+ TailFreeSamplingZ=nvp.TailFreeSamplingZ,...
+ StopSequences=nvp.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
+ ResponseFormat=nvp.ResponseFormat,Seed=nvp.Seed, ...
+ TimeOut=nvp.TimeOut, StreamFun=streamFun, ...
+ Endpoint=nvp.Endpoint);
 if isfield(response.Body.Data,"error")
 err = response.Body.Data.error;
diff --git a/openAIChat.m b/openAIChat.m
index 46bf0d6..89cf932 100644
--- a/openAIChat.m
+++ b/openAIChat.m
@@ -162,17 +162,67 @@
 % Seed - An integer value to use to obtain
 % reproducible responses
 %
+ % ModelName - Name of the model to use for chat completions.
+ % The default value is CHAT.ModelName.
+ %
+ % Temperature - Temperature value for controlling the randomness
+ % of the output. Default value is CHAT.Temperature;
+ % higher values increase the randomness (in some sense,
+ % the “creativity”) of outputs, lower values
+ % reduce it. Setting Temperature=0 removes
+ % randomness from the output altogether.
+ %
+ % TopP - Top probability mass value for controlling the
+ % diversity of the output. Default value is CHAT.TopP;
+ % lower values imply that only the more likely
+ % words can appear in any particular place.
+ % This is also known as top-p sampling.
+ %
+ % StopSequences - Vector of strings that when encountered, will
+ % stop the generation of tokens. Default
+ % value is CHAT.StopSequences.
+ % Example: ["The end.", "And that's all she wrote."]
+ %
+ % PresencePenalty - Penalty value for using a token in the response
+ % that has already been used. Default value is
+ % CHAT.PresencePenalty. Higher values reduce repetition
+ % of words in the output.
+ %
+ % FrequencyPenalty - Penalty value for using a token that is frequent
+ % in the output. Default value is CHAT.FrequencyPenalty.
+ % Higher values reduce repetition of words in the output.
+ %
+ % TimeOut - Connection Timeout in seconds.
+ % Default value is CHAT.TimeOut.
+ %
+ % StreamFun - Function to callback when streaming the
+ % result. Default value is CHAT.StreamFun.
+ %
+ % ResponseFormat - The format of response the model returns.
+ % Default value is CHAT.ResponseFormat.
+ % "text" | "json"
+ %
 % Currently, GPT-4 Turbo with vision does not support the message.name
 % parameter, functions/tools, response_format parameter, and stop
 % sequences. It also has a low MaxNumTokens default, which can be overridden.
 arguments
 this (1,1) openAIChat
- messages {mustBeValidMsgs}
+ messages {mustBeValidMsgs}
+ nvp.ModelName (1,1) string {mustBeModel} = this.ModelName
+ nvp.Temperature {llms.utils.mustBeValidTemperature} = this.Temperature
+ nvp.TopP {llms.utils.mustBeValidProbability} = this.TopP
+ nvp.StopSequences {llms.utils.mustBeValidStop} = this.StopSequences
+ nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = this.ResponseFormat
+ nvp.APIKey {mustBeNonzeroLengthTextScalar} = this.APIKey
+ nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = this.PresencePenalty
+ nvp.FrequencyPenalty {llms.utils.mustBeValidPenalty} = this.FrequencyPenalty
+ nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = this.TimeOut
+ nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
 nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1
 nvp.MaxNumTokens (1,1) {mustBePositive} = inf
- nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = []
- nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
+ nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = []
+ nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
 end
 toolChoice = convertToolChoice(this, nvp.ToolChoice);
@@ -190,13 +240,19 @@
 messagesStruct = horzcat(this.SystemPrompt, messagesStruct);
 end
+ if isfield(nvp,"StreamFun")
+ streamFun = nvp.StreamFun;
+ else
+ streamFun = this.StreamFun;
+ end
+
 [text, message, response] = llms.internal.callOpenAIChatAPI(messagesStruct, this.FunctionsStruct,...
- ModelName=this.ModelName, ToolChoice=toolChoice, Temperature=this.Temperature, ...
- TopP=this.TopP, NumCompletions=nvp.NumCompletions,...
- StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
- PresencePenalty=this.PresencePenalty, FrequencyPenalty=this.FrequencyPenalty, ...
- ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ... - APIKey=this.APIKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun); + ModelName=nvp.ModelName, ToolChoice=toolChoice, Temperature=nvp.Temperature, ... + TopP=nvp.TopP, NumCompletions=nvp.NumCompletions,... + StopSequences=nvp.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ... + PresencePenalty=nvp.PresencePenalty, FrequencyPenalty=nvp.FrequencyPenalty, ... + ResponseFormat=nvp.ResponseFormat,Seed=nvp.Seed, ... + APIKey=nvp.APIKey,TimeOut=nvp.TimeOut, StreamFun=streamFun); if isfield(response.Body.Data,"error") err = response.Body.Data.error.message; diff --git a/tests/tazureChat.m b/tests/tazureChat.m index a16a5da..704619d 100644 --- a/tests/tazureChat.m +++ b/tests/tazureChat.m @@ -76,6 +76,13 @@ function generateWithMultipleImages(testCase) testCase.verifyThat(text,ContainsSubstring("same") | ContainsSubstring("identical")); end + function generateOverridesProperties(testCase) + import matlab.unittest.constraints.EndsWithSubstring + chat = azureChat; + text = generate(chat, "Please count from 1 to 10.", Temperature = 0, StopSequences = "4"); + testCase.verifyThat(text, EndsWithSubstring("3, ")); + end + function doReturnErrors(testCase) testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT."); chat = azureChat; diff --git a/tests/tollamaChat.m b/tests/tollamaChat.m index bdde8b9..242a3b8 100644 --- a/tests/tollamaChat.m +++ b/tests/tollamaChat.m @@ -46,6 +46,13 @@ function doGenerateUsingSystemPrompt(testCase) testCase.verifyGreaterThan(strlength(response),0); end + function generateOverridesProperties(testCase) + import matlab.unittest.constraints.EndsWithSubstring + chat = ollamaChat("mistral"); + text = generate(chat, "Please count from 1 to 10.", Temperature = 0, StopSequences = "4"); + testCase.verifyThat(text, EndsWithSubstring("3, ")); + end + function extremeTopK(testCase) %% This should work, and it does on some computers. On others, Ollama %% receives the parameter, but either Ollama or llama.cpp fails to diff --git a/tests/topenAIChat.m b/tests/topenAIChat.m index ad3f69e..216c40d 100644 --- a/tests/topenAIChat.m +++ b/tests/topenAIChat.m @@ -85,7 +85,9 @@ function settingToolChoiceWithNone(testCase) function fixedSeedFixesResult(testCase) % Seed is "beta" in OpenAI documentation - % and not reliable in gpt-4o-mini at this time. + % and not reliable at this time. + testCase.assumeTrue(false,"disabled since the server is unreliable in honoring the Seed parameter"); + chat = openAIChat(ModelName="gpt-3.5-turbo"); result1 = generate(chat,"This is okay", "Seed", 2); @@ -202,6 +204,13 @@ function generateWithMultipleImages(testCase) testCase.verifyThat(text,ContainsSubstring("same") | ContainsSubstring("identical")); end + function generateOverridesProperties(testCase) + import matlab.unittest.constraints.EndsWithSubstring + chat = openAIChat; + text = generate(chat, "Please count from 1 to 10.", Temperature = 0, StopSequences = "4"); + testCase.verifyThat(text, EndsWithSubstring("3, ")); + end + function invalidInputsGenerate(testCase, InvalidGenerateInput) f = openAIFunction("validfunction"); chat = openAIChat(Tools=f, APIKey="this-is-not-a-real-key");
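
Usage note (not part of the patch): a minimal sketch of the per-call overrides this change introduces, shown with openAIChat; the generate methods of azureChat and ollamaChat accept the analogous name-value arguments. The printTokens callback below is a hypothetical helper for illustration, not part of the repository.

% Minimal sketch: new generate() name-value overrides. The chat object's own
% properties are left unchanged; only this call uses the overridden values.
chat = openAIChat("You are a terse assistant.");

% Override sampling and stopping behavior for one call, as in the new
% generateOverridesProperties tests: with StopSequences="4", the reply to a
% count from 1 to 10 should stop around "3, ".
text = generate(chat, "Please count from 1 to 10.", ...
    Temperature=0, StopSequences="4");

% Per-call streaming: StreamFun applies to this generate() call only.
printTokens = @(token) fprintf("%s", token);   % hypothetical helper
[text, message, response] = generate(chat, "Tell me a joke.", ...
    StreamFun=printTokens, MaxNumTokens=200);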