Implement MinP for ollamaChat
ccreutzi committed Aug 14, 2024
1 parent 150d9c1 commit a7e6170
Showing 3 changed files with 48 additions and 3 deletions.
2 changes: 2 additions & 0 deletions +llms/+internal/callOllamaChatAPI.m
@@ -29,6 +29,7 @@
messages
nvp.Temperature
nvp.TopP
nvp.MinP
nvp.TopK
nvp.TailFreeSamplingZ
nvp.StopSequences
@@ -103,6 +104,7 @@
dict = dictionary();
dict("Temperature") = "temperature";
dict("TopP") = "top_p";
dict("MinP") = "min_p";
dict("TopK") = "top_k";
dict("TailFreeSamplingZ") = "tfs_z";
dict("StopSequences") = "stop";
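For context, the dictionary above renames the MATLAB-style option names into the field names Ollama's /api/chat endpoint expects inside its options object. A minimal sketch of the resulting request body after the mapping, with placeholder values rather than package defaults:

% Sketch: shape of the Ollama request body once the name mapping above
% has been applied. Values are illustrative placeholders.
body.model = 'mistral';
body.messages = {struct('role','user','content','Hello')};
body.options.temperature = 0.7;
body.options.top_p = 0.9;
body.options.min_p = 0.05;   % the newly mapped MinP parameter
body.options.top_k = 40;
jsonencode(body)
% -> {"model":"mistral","messages":[{"role":"user","content":"Hello"}],
%    "options":{"temperature":0.7,"top_p":0.9,"min_p":0.05,"top_k":40}}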
11 changes: 10 additions & 1 deletion ollamaChat.m
@@ -23,6 +23,12 @@
% words can appear in any particular place.
% This is also known as top-p sampling.
%
% MinP - Minimum probability ratio for controlling the
% diversity of the output. Default value is 0;
% higher values exclude any token whose probability
% is less than MinP times that of the most likely
% token.
% This is also known as min-p sampling.
%
% TopK - Maximum number of most likely tokens that are
% considered for output. Default is Inf, allowing
% all tokens. Smaller values reduce diversity in
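The min-p rule described in the comment above can be stated concretely: a token survives sampling only if its probability is at least MinP times the probability of the most likely token. An illustrative MATLAB sketch of that filtering step, using toy numbers rather than code from this repository:

% Illustrative min-p filtering over a toy next-token distribution.
p = [0.5 0.3 0.15 0.05];               % token probabilities, sum to 1
minP = 0.4;                            % MinP parameter
keep = p >= minP * max(p);             % threshold is 0.4*0.5 = 0.2
pFiltered = p .* keep / sum(p(keep));  % renormalize the survivors
% keep      -> [1 1 0 0]
% pFiltered -> [0.625 0.375 0 0]

With MinP=1 the threshold equals the top probability itself, so only tokens tied with the most likely one survive; this is why the extremeMinP test further down expects a deterministic response.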
@@ -67,6 +73,7 @@
Model (1,1) string
Endpoint (1,1) string
TopK (1,1) {mustBeReal,mustBePositive} = Inf
MinP (1,1) {llms.utils.mustBeValidTopP} = 0
TailFreeSamplingZ (1,1) {mustBeReal} = 1
end

@@ -77,6 +84,7 @@
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
nvp.Temperature {llms.utils.mustBeValidTemperature} = 1
nvp.TopP {llms.utils.mustBeValidTopP} = 1
nvp.MinP {llms.utils.mustBeValidTopP} = 0
nvp.TopK (1,1) {mustBeReal,mustBePositive} = Inf
nvp.StopSequences {llms.utils.mustBeValidStop} = {}
nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
@@ -103,6 +111,7 @@
this.ResponseFormat = nvp.ResponseFormat;
this.Temperature = nvp.Temperature;
this.TopP = nvp.TopP;
this.MinP = nvp.MinP;
this.TopK = nvp.TopK;
this.TailFreeSamplingZ = nvp.TailFreeSamplingZ;
this.StopSequences = nvp.StopSequences;
@@ -146,7 +155,7 @@
[text, message, response] = llms.internal.callOllamaChatAPI(...
this.Model, messagesStruct, ...
Temperature=this.Temperature, ...
- TopP=this.TopP, TopK=this.TopK,...
+ TopP=this.TopP, MinP=this.MinP, TopK=this.TopK,...
TailFreeSamplingZ=this.TailFreeSamplingZ,...
StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
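Taken together, the new name-value pair is used like the existing sampling controls. A usage sketch, assuming a running Ollama server with the "mistral" model pulled locally (model name and prompt are arbitrary):

% Usage sketch: min-p sampling alongside the existing controls.
chat = ollamaChat("mistral", MinP=0.05, Temperature=0.8);
response = generate(chat, "Name three MATLAB plotting functions.");
disp(response)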
38 changes: 36 additions & 2 deletions tests/tollamaChat.m
@@ -50,7 +50,7 @@ function extremeTopK(testCase)
%% This should work, and it does on some computers. On others, Ollama
%% receives the parameter, but either Ollama or llama.cpp fails to
%% honor it correctly.
- testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");
+ % testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");

% setting top-k to k=1 leaves no random choice,
% so we expect to get a fixed response.
@@ -61,11 +61,27 @@ function extremeTopK(testCase)
testCase.verifyEqual(response1,response2);
end

function extremeMinP(testCase)
%% This should work, and it does on some computers. On others, Ollama
%% receives the parameter, but either Ollama or llama.cpp fails to
%% honor it correctly.
% testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");

% setting min-p to p=1 means only tokens at least as likely as
% the most likely one can be chosen; almost always that is a
% single token, so we expect to get a fixed response.
chat = ollamaChat("mistral",MinP=1);
prompt = "Min-p sampling with p=1 returns a definite answer.";
response1 = generate(chat,prompt);
response2 = generate(chat,prompt);
testCase.verifyEqual(response1,response2);
end

function extremeTfsZ(testCase)
%% This should work, and it does on some computers. On others, Ollama
%% receives the parameter, but either Ollama or llama.cpp fails to
%% honor it correctly.
- testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");
+ % testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably");

% setting tfs_z to z=0 leaves no random choice, but degrades to
% greedy sampling, so we expect to get a fixed response.
@@ -235,6 +251,16 @@ function queryModels(testCase)
"Value", -20, ...
"Error", "MATLAB:expectedNonnegative"), ...
...
"MinPTooLarge", struct( ...
"Property", "MinP", ...
"Value", 20, ...
"Error", "MATLAB:notLessEqual"), ...
...
"MinPTooSmall", struct( ...
"Property", "MinP", ...
"Value", -20, ...
"Error", "MATLAB:expectedNonnegative"), ...
...
"WrongTypeStopSequences", struct( ...
"Property", "StopSequences", ...
"Value", 123, ...
@@ -329,6 +355,14 @@ function queryModels(testCase)
"Input",{{ "TopP" -20 }},...
"Error","MATLAB:expectedNonnegative"),...I
...
"MinPTooLarge",struct( ...
"Input",{{ "MinP" 20 }},...
"Error","MATLAB:notLessEqual"),...
...
"MinPTooSmall",struct( ...
"Input",{{ "MinP" -20 }},...
"Error","MATLAB:expectedNonnegative"),...I
...
"WrongTypeStopSequences",struct( ...
"Input",{{ "StopSequences" 123}},...
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
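The expected identifiers in these negative tests come from standard MATLAB validators: mustBeNonnegative throws MATLAB:expectedNonnegative and mustBeLessThanOrEqual throws MATLAB:notLessEqual. The shared llms.utils.mustBeValidTopP validator, reused here for MinP, is therefore consistent with a sketch like the following, reconstructed from the error IDs rather than copied from the repository:

function mustBeValidTopP(value)
    % Sketch consistent with the error identifiers tested above.
    mustBeNonnegative(value)           % throws MATLAB:expectedNonnegative
    mustBeLessThanOrEqual(value, 1)    % throws MATLAB:notLessEqual
end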
