From a7e617064b80a7a9aa42715c3220b1b87f5c588a Mon Sep 17 00:00:00 2001
From: Christopher Creutzig <89011131+ccreutzi@users.noreply.github.com>
Date: Wed, 14 Aug 2024 10:00:49 +0200
Subject: [PATCH 1/3] Implement `MinP` for `ollamaChat`

---
 +llms/+internal/callOllamaChatAPI.m |  2 ++
 ollamaChat.m                        | 11 ++++++++-
 tests/tollamaChat.m                 | 38 +++++++++++++++++++++++++++--
 3 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/+llms/+internal/callOllamaChatAPI.m b/+llms/+internal/callOllamaChatAPI.m
index 0bad15f..ce81780 100644
--- a/+llms/+internal/callOllamaChatAPI.m
+++ b/+llms/+internal/callOllamaChatAPI.m
@@ -29,6 +29,7 @@
     messages
     nvp.Temperature
     nvp.TopP
+    nvp.MinP
     nvp.TopK
     nvp.TailFreeSamplingZ
     nvp.StopSequences
@@ -103,6 +104,7 @@
 dict = dictionary();
 dict("Temperature") = "temperature";
 dict("TopP") = "top_p";
+dict("MinP") = "min_p";
 dict("TopK") = "top_k";
 dict("TailFreeSamplingZ") = "tfs_z";
 dict("StopSequences") = "stop";
diff --git a/ollamaChat.m b/ollamaChat.m
index 6d9e5a0..df736d9 100644
--- a/ollamaChat.m
+++ b/ollamaChat.m
@@ -23,6 +23,12 @@
 %                       words can appear in any particular place.
 %                       This is also known as top-p sampling.
 %
+%   MinP              - Minimum probability ratio for controlling the
+%                       diversity of the output. Default value is 0;
+%                       higher values imply that only the more likely
+%                       words can appear in any particular place.
+%                       This is also known as min-p sampling.
+%
 %   TopK              - Maximum number of most likely tokens that are
 %                       considered for output. Default is Inf, allowing
 %                       all tokens. Smaller values reduce diversity in
@@ -67,6 +73,7 @@
         Model (1,1) string
         Endpoint (1,1) string
         TopK (1,1) {mustBeReal,mustBePositive} = Inf
+        MinP (1,1) {llms.utils.mustBeValidTopP} = 0
         TailFreeSamplingZ (1,1) {mustBeReal} = 1
     end

@@ -77,6 +84,7 @@
             systemPrompt {llms.utils.mustBeTextOrEmpty} = []
             nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
             nvp.TopP {llms.utils.mustBeValidTopP} = 1
+            nvp.MinP {llms.utils.mustBeValidTopP} = 0
             nvp.TopK (1,1) {mustBeReal,mustBePositive} = Inf
             nvp.StopSequences {llms.utils.mustBeValidStop} = {}
             nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
@@ -103,6 +111,7 @@
             this.ResponseFormat = nvp.ResponseFormat;
             this.Temperature = nvp.Temperature;
             this.TopP = nvp.TopP;
+            this.MinP = nvp.MinP;
             this.TopK = nvp.TopK;
             this.TailFreeSamplingZ = nvp.TailFreeSamplingZ;
             this.StopSequences = nvp.StopSequences;
@@ -146,7 +155,7 @@
             [text, message, response] = llms.internal.callOllamaChatAPI(...
                 this.Model, messagesStruct, ...
                 Temperature=this.Temperature, ...
-                TopP=this.TopP, TopK=this.TopK,...
+                TopP=this.TopP, MinP=this.MinP, TopK=this.TopK,...
                 TailFreeSamplingZ=this.TailFreeSamplingZ,...
                 StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
                 ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
diff --git a/tests/tollamaChat.m b/tests/tollamaChat.m
index 4320774..342e7df 100644
--- a/tests/tollamaChat.m
+++ b/tests/tollamaChat.m
@@ -50,7 +50,7 @@ function extremeTopK(testCase)
             %% This should work, and it does on some computers. On others, Ollama
             %% receives the parameter, but either Ollama or llama.cpp fails to
             %% honor it correctly.
-            testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");
+            % testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");

             % setting top-k to k=1 leaves no random choice,
             % so we expect to get a fixed response.
@@ -61,11 +61,27 @@ function extremeTopK(testCase)
             testCase.verifyEqual(response1,response2);
         end

+        function extremeMinP(testCase)
+            %% This should work, and it does on some computers. On others, Ollama
+            %% receives the parameter, but either Ollama or llama.cpp fails to
+            %% honor it correctly.
+            % testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");
+
+            % setting min-p to p=1 means only tokens with the same logit as
+            % the most likely one can be chosen, which will almost certainly
+            % only ever be one, so we expect to get a fixed response.
+            chat = ollamaChat("mistral",MinP=1);
+            prompt = "Min-p sampling with p=1 returns a definite answer.";
+            response1 = generate(chat,prompt);
+            response2 = generate(chat,prompt);
+            testCase.verifyEqual(response1,response2);
+        end
+
         function extremeTfsZ(testCase)
             %% This should work, and it does on some computers. On others, Ollama
             %% receives the parameter, but either Ollama or llama.cpp fails to
             %% honor it correctly.
-            testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");
+            % testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");

             % setting tfs_z to z=0 leaves no random choice, but degrades to
             % greedy sampling, so we expect to get a fixed response.
@@ -235,6 +251,16 @@ function queryModels(testCase)
                 "Value", -20, ...
                 "Error", "MATLAB:expectedNonnegative"), ...
             ...
+            "MinPTooLarge", struct( ...
+                "Property", "MinP", ...
+                "Value", 20, ...
+                "Error", "MATLAB:notLessEqual"), ...
+            ...
+            "MinPTooSmall", struct( ...
+                "Property", "MinP", ...
+                "Value", -20, ...
+                "Error", "MATLAB:expectedNonnegative"), ...
+            ...
             "WrongTypeStopSequences", struct( ...
                 "Property", "StopSequences", ...
                 "Value", 123, ...
@@ -329,6 +355,14 @@ function queryModels(testCase)
                 "Input",{{ "TopP" -20 }},...
                 "Error","MATLAB:expectedNonnegative"),...
             ...
+            "MinPTooLarge",struct( ...
+                "Input",{{ "MinP" 20 }},...
+                "Error","MATLAB:notLessEqual"),...
+            ...
+            "MinPTooSmall",struct( ...
+                "Input",{{ "MinP" -20 }},...
+                "Error","MATLAB:expectedNonnegative"),...
+            ...
             "WrongTypeStopSequences",struct( ...
                 "Input",{{ "StopSequences" 123}},...
                 "Error","MATLAB:validators:mustBeNonzeroLengthText"),...

From 08160b7eeb369d4c645947d665b087a4aaf6635b Mon Sep 17 00:00:00 2001
From: Christopher Creutzig <89011131+ccreutzi@users.noreply.github.com>
Date: Wed, 14 Aug 2024 11:17:45 +0200
Subject: [PATCH 2/3] Rename `mustBeValidTopP` to `mustBeValidProbability`

We're now using this validator for more than just `TopP`, and a new name
is in order.
---
 +llms/+internal/textGenerator.m                            | 2 +-
 .../+utils/{mustBeValidTopP.m => mustBeValidProbability.m} | 2 +-
 azureChat.m                                                | 2 +-
 ollamaChat.m                                               | 6 +++---
 openAIChat.m                                               | 2 +-
 5 files changed, 7 insertions(+), 7 deletions(-)
 rename +llms/+utils/{mustBeValidTopP.m => mustBeValidProbability.m} (84%)

diff --git a/+llms/+internal/textGenerator.m b/+llms/+internal/textGenerator.m
index 204e516..b80589b 100644
--- a/+llms/+internal/textGenerator.m
+++ b/+llms/+internal/textGenerator.m
@@ -8,7 +8,7 @@
         Temperature {llms.utils.mustBeValidTemperature} = 1

         %TopP   Top probability mass to consider for generation.
-        TopP {llms.utils.mustBeValidTopP} = 1
+        TopP {llms.utils.mustBeValidProbability} = 1

         %StopSequences   Sequences to stop the generation of tokens.
         StopSequences {llms.utils.mustBeValidStop} = {}
diff --git a/+llms/+utils/mustBeValidTopP.m b/+llms/+utils/mustBeValidProbability.m
similarity index 84%
rename from +llms/+utils/mustBeValidTopP.m
rename to +llms/+utils/mustBeValidProbability.m
index ed2bbd6..13e9a14 100644
--- a/+llms/+utils/mustBeValidTopP.m
+++ b/+llms/+utils/mustBeValidProbability.m
@@ -1,4 +1,4 @@
-function mustBeValidTopP(value)
+function mustBeValidProbability(value)
 % This function is undocumented and will change in a future release

 % Copyright 2024 The MathWorks, Inc.
diff --git a/azureChat.m b/azureChat.m
index b4b8df5..313c0f4 100644
--- a/azureChat.m
+++ b/azureChat.m
@@ -109,7 +109,7 @@
             nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
             nvp.APIVersion (1,1) string {mustBeAPIVersion} = "2024-02-01"
             nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
-            nvp.TopP {llms.utils.mustBeValidTopP} = 1
+            nvp.TopP {llms.utils.mustBeValidProbability} = 1
             nvp.StopSequences {llms.utils.mustBeValidStop} = {}
             nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
             nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = 0
diff --git a/ollamaChat.m b/ollamaChat.m
index df736d9..031b7f3 100644
--- a/ollamaChat.m
+++ b/ollamaChat.m
@@ -73,7 +73,7 @@
         Model (1,1) string
         Endpoint (1,1) string
         TopK (1,1) {mustBeReal,mustBePositive} = Inf
-        MinP (1,1) {llms.utils.mustBeValidTopP} = 0
+        MinP (1,1) {llms.utils.mustBeValidProbability} = 0
         TailFreeSamplingZ (1,1) {mustBeReal} = 1
     end

@@ -83,8 +83,8 @@
             modelName {mustBeTextScalar}
             systemPrompt {llms.utils.mustBeTextOrEmpty} = []
             nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
-            nvp.TopP {llms.utils.mustBeValidTopP} = 1
-            nvp.MinP {llms.utils.mustBeValidTopP} = 0
+            nvp.TopP {llms.utils.mustBeValidProbability} = 1
+            nvp.MinP {llms.utils.mustBeValidProbability} = 0
             nvp.TopK (1,1) {mustBeReal,mustBePositive} = Inf
             nvp.StopSequences {llms.utils.mustBeValidStop} = {}
             nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
diff --git a/openAIChat.m b/openAIChat.m
index cbd2440..46bf0d6 100644
--- a/openAIChat.m
+++ b/openAIChat.m
@@ -94,7 +94,7 @@
             nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
             nvp.ModelName (1,1) string {mustBeModel} = "gpt-4o-mini"
             nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
-            nvp.TopP {llms.utils.mustBeValidTopP} = 1
+            nvp.TopP {llms.utils.mustBeValidProbability} = 1
             nvp.StopSequences {llms.utils.mustBeValidStop} = {}
             nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
             nvp.APIKey {mustBeNonzeroLengthTextScalar}

From b0023dc858d9e2091b3ad29d2c5abc084ac4552f Mon Sep 17 00:00:00 2001
From: Christopher Creutzig <89011131+ccreutzi@users.noreply.github.com>
Date: Tue, 20 Aug 2024 10:23:37 +0200
Subject: [PATCH 3/3] tests still unreliable with Ollama version in GitHub CI
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

These tests should work and do work locally. But they fail in GitHub CI –
for an unknown reason that almost certainly is in Ollama, not in our code.
---
 tests/tollamaChat.m | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/tollamaChat.m b/tests/tollamaChat.m
index 342e7df..bdde8b9 100644
--- a/tests/tollamaChat.m
+++ b/tests/tollamaChat.m
@@ -50,7 +50,7 @@ function extremeTopK(testCase)
             %% This should work, and it does on some computers. On others, Ollama
             %% receives the parameter, but either Ollama or llama.cpp fails to
             %% honor it correctly.
-            % testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");
+            testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");

             % setting top-k to k=1 leaves no random choice,
             % so we expect to get a fixed response.
@@ -65,7 +65,7 @@ function extremeMinP(testCase)
             %% This should work, and it does on some computers. On others, Ollama
             %% receives the parameter, but either Ollama or llama.cpp fails to
             %% honor it correctly.
-            % testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");
+            testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");

             % setting min-p to p=1 means only tokens with the same logit as
             % the most likely one can be chosen, which will almost certainly
@@ -81,7 +81,7 @@ function extremeTfsZ(testCase)
             %% This should work, and it does on some computers. On others, Ollama
             %% receives the parameter, but either Ollama or llama.cpp fails to
             %% honor it correctly.
-            % testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");
+            testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");

             % setting tfs_z to z=0 leaves no random choice, but degrades to
             % greedy sampling, so we expect to get a fixed response.
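
For reference, a minimal usage sketch of the new MinP option once these patches are applied. This is not part of the patches themselves; it assumes a local Ollama server is running and that the "mistral" model used by the test suite has been pulled, and the generated text will of course vary between runs and machines.

    % Moderate min-p sampling: tokens whose probability falls below 5% of the
    % most likely token's probability are excluded at each sampling step.
    chat = ollamaChat("mistral", MinP=0.05);
    response = generate(chat, "Why is the sky blue?");
    disp(response)

    % Degenerate setting exercised in tests/tollamaChat.m: MinP=1 keeps only
    % tokens tied with the most likely one, so repeated calls should agree
    % (when Ollama/llama.cpp honor the min_p parameter).
    chat = ollamaChat("mistral", MinP=1);
    response1 = generate(chat, "Min-p sampling with p=1 returns a definite answer.");
    response2 = generate(chat, "Min-p sampling with p=1 returns a definite answer.");
    isequal(response1, response2)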