Commit
Merge pull request #77 from matlab-deep-learning/minp
Add min-p sampling for `ollamaChat`
ccreutzi committed Aug 20, 2024
2 parents 150d9c1 + b0023dc · commit d127953
Showing 7 changed files with 51 additions and 6 deletions.
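The commit threads a new MinP name-value argument through ollamaChat, maps it to the min_p option of the Ollama REST API, and renames the shared probability validator so that TopP and MinP can use the same check. A usage sketch, assuming a local Ollama server with the mistral model pulled (this example is not part of the diff):

    % Keep only tokens whose probability is at least 5% of the most
    % likely token's probability.
    chat = ollamaChat("mistral", MinP=0.05);
    text = generate(chat, "Why is the sky blue?");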
2 changes: 2 additions & 0 deletions +llms/+internal/callOllamaChatAPI.m
@@ -29,6 +29,7 @@
     messages
     nvp.Temperature
     nvp.TopP
+    nvp.MinP
     nvp.TopK
     nvp.TailFreeSamplingZ
     nvp.StopSequences
@@ -103,6 +104,7 @@
 dict = dictionary();
 dict("Temperature") = "temperature";
 dict("TopP") = "top_p";
+dict("MinP") = "min_p";
 dict("TopK") = "top_k";
 dict("TailFreeSamplingZ") = "tfs_z";
 dict("StopSequences") = "stop";
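The dictionary translates MATLAB-style argument names into the option names the Ollama REST API expects. A minimal sketch of how such a mapping could be applied when assembling the request options; the loop and variable names are illustrative, not taken from the file:

    % Illustrative only: copy supplied name-value arguments into the API
    % options struct under their Ollama names.
    options = struct();
    for matlabName = keys(dict)'          % keys(dict) is a string column vector
        if isfield(nvp, matlabName)
            options.(dict(matlabName)) = nvp.(matlabName);
        end
    end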
2 changes: 1 addition & 1 deletion +llms/+internal/textGenerator.m
@@ -8,7 +8,7 @@
     Temperature {llms.utils.mustBeValidTemperature} = 1

     %TopP Top probability mass to consider for generation.
-    TopP {llms.utils.mustBeValidTopP} = 1
+    TopP {llms.utils.mustBeValidProbability} = 1

     %StopSequences Sequences to stop the generation of tokens.
     StopSequences {llms.utils.mustBeValidStop} = {}
2 changes: 1 addition & 1 deletion +llms/+utils/mustBeValidTopP.m → +llms/+utils/mustBeValidProbability.m
@@ -1,4 +1,4 @@
-function mustBeValidTopP(value)
+function mustBeValidProbability(value)
 % This function is undocumented and will change in a future release

 % Copyright 2024 The MathWorks, Inc.
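The validator body is unchanged by the rename and not shown here. A plausible sketch, inferred from the error identifiers ("MATLAB:expectedNonnegative", "MATLAB:notLessEqual") that the tests below expect; the shipped file may differ:

    function mustBeValidProbability(value)
    % Sketch: validateattributes raises MATLAB:expectedNonnegative for
    % negative values and MATLAB:notLessEqual for values above 1.
    validateattributes(value, {'numeric'}, {'nonnegative', '<=', 1})
    end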
2 changes: 1 addition & 1 deletion azureChat.m
@@ -109,7 +109,7 @@
     nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
     nvp.APIVersion (1,1) string {mustBeAPIVersion} = "2024-02-01"
     nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
-    nvp.TopP {llms.utils.mustBeValidTopP} = 1
+    nvp.TopP {llms.utils.mustBeValidProbability} = 1
     nvp.StopSequences {llms.utils.mustBeValidStop} = {}
     nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
     nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = 0
13 changes: 11 additions & 2 deletions ollamaChat.m
@@ -23,6 +23,12 @@
 %   TopP - ... words can appear in any particular place.
 %          This is also known as top-p sampling.
 %
+%   MinP - Minimum probability ratio for controlling the
+%          diversity of the output. Default value is 0;
+%          higher values imply that only the more likely
+%          words can appear in any particular place.
+%          This is also known as min-p sampling.
+%
 %   TopK - Maximum number of most likely tokens that are
 %          considered for output. Default is Inf, allowing
 %          all tokens. Smaller values reduce diversity in
@@ -67,6 +73,7 @@
     Model (1,1) string
     Endpoint (1,1) string
     TopK (1,1) {mustBeReal,mustBePositive} = Inf
+    MinP (1,1) {llms.utils.mustBeValidProbability} = 0
     TailFreeSamplingZ (1,1) {mustBeReal} = 1
 end

@@ -76,7 +83,8 @@
     modelName {mustBeTextScalar}
     systemPrompt {llms.utils.mustBeTextOrEmpty} = []
     nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
-    nvp.TopP {llms.utils.mustBeValidTopP} = 1
+    nvp.TopP {llms.utils.mustBeValidProbability} = 1
+    nvp.MinP {llms.utils.mustBeValidProbability} = 0
     nvp.TopK (1,1) {mustBeReal,mustBePositive} = Inf
     nvp.StopSequences {llms.utils.mustBeValidStop} = {}
     nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
@@ -103,6 +111,7 @@
     this.ResponseFormat = nvp.ResponseFormat;
     this.Temperature = nvp.Temperature;
     this.TopP = nvp.TopP;
+    this.MinP = nvp.MinP;
     this.TopK = nvp.TopK;
     this.TailFreeSamplingZ = nvp.TailFreeSamplingZ;
     this.StopSequences = nvp.StopSequences;
@@ -146,7 +155,7 @@
     [text, message, response] = llms.internal.callOllamaChatAPI(...
         this.Model, messagesStruct, ...
         Temperature=this.Temperature, ...
-        TopP=this.TopP, TopK=this.TopK,...
+        TopP=this.TopP, MinP=this.MinP, TopK=this.TopK,...
         TailFreeSamplingZ=this.TailFreeSamplingZ,...
         StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
         ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
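Min-p sampling filters the next-token distribution relative to the most likely token: only tokens whose probability is at least MinP times the top probability remain candidates, so MinP=1 restricts sampling to tokens tied with the most likely one. An illustrative sketch of the filter, not code from this repository:

    p = [0.5 0.3 0.15 0.04 0.01];        % example next-token probabilities
    minP = 0.1;                          % MinP parameter
    keep = p >= minP * max(p);           % threshold is 0.1 * 0.5 = 0.05 here
    pFiltered = p(keep) ./ sum(p(keep)); % renormalize the surviving tokens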
2 changes: 1 addition & 1 deletion openAIChat.m
@@ -94,7 +94,7 @@
     nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
     nvp.ModelName (1,1) string {mustBeModel} = "gpt-4o-mini"
     nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
-    nvp.TopP {llms.utils.mustBeValidTopP} = 1
+    nvp.TopP {llms.utils.mustBeValidProbability} = 1
     nvp.StopSequences {llms.utils.mustBeValidStop} = {}
     nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
     nvp.APIKey {mustBeNonzeroLengthTextScalar}
34 changes: 34 additions & 0 deletions tests/tollamaChat.m
@@ -61,6 +61,22 @@ function extremeTopK(testCase)
     testCase.verifyEqual(response1,response2);
 end

+function extremeMinP(testCase)
+    %% This should work, and it does on some computers. On others, Ollama
+    %% receives the parameter, but either Ollama or llama.cpp fails to
+    %% honor it correctly.
+    testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");
+
+    % Setting min-p to p=1 means only tokens with the same logit as
+    % the most likely one can be chosen, which will almost certainly
+    % only ever be one, so we expect to get a fixed response.
+    chat = ollamaChat("mistral",MinP=1);
+    prompt = "Min-p sampling with p=1 returns a definite answer.";
+    response1 = generate(chat,prompt);
+    response2 = generate(chat,prompt);
+    testCase.verifyEqual(response1,response2);
+end
+
 function extremeTfsZ(testCase)
     %% This should work, and it does on some computers. On others, Ollama
     %% receives the parameter, but either Ollama or llama.cpp fails to
@@ -235,6 +251,16 @@ function queryModels(testCase)
     "Value", -20, ...
     "Error", "MATLAB:expectedNonnegative"), ...
     ...
+    "MinPTooLarge", struct( ...
+        "Property", "MinP", ...
+        "Value", 20, ...
+        "Error", "MATLAB:notLessEqual"), ...
+    ...
+    "MinPTooSmall", struct( ...
+        "Property", "MinP", ...
+        "Value", -20, ...
+        "Error", "MATLAB:expectedNonnegative"), ...
+    ...
     "WrongTypeStopSequences", struct( ...
         "Property", "StopSequences", ...
         "Value", 123, ...
@@ -329,6 +355,14 @@ function queryModels(testCase)
     "Input",{{ "TopP" -20 }},...
     "Error","MATLAB:expectedNonnegative"),...
     ...
+    "MinPTooLarge",struct( ...
+        "Input",{{ "MinP" 20 }},...
+        "Error","MATLAB:notLessEqual"),...
+    ...
+    "MinPTooSmall",struct( ...
+        "Input",{{ "MinP" -20 }},...
+        "Error","MATLAB:expectedNonnegative"),...
+    ...
     "WrongTypeStopSequences",struct( ...
         "Input",{{ "StopSequences" 123}},...
         "Error","MATLAB:validators:mustBeNonzeroLengthText"),...
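Because MinP is validated both as a constructor argument and as a class property, out-of-range values error before any request reaches the server. For example, with the identifiers exercised above:

    chat = ollamaChat("mistral", MinP=2);    % errors with MATLAB:notLessEqual
    chat = ollamaChat("mistral", MinP=-0.1); % errors with MATLAB:expectedNonnegative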
