Let generate temporarily override model settings (#2)
ccreutzi authored and GitHub Enterprise committed Sep 2, 2024
1 parent 9bc71a5 commit bfef6e3
Showing 7 changed files with 266 additions and 35 deletions.
65 changes: 59 additions & 6 deletions azureChat.m
@@ -174,13 +174,59 @@
% Seed - An integer value to use to obtain
%                        reproducible responses.
%
% Temperature - Temperature value for controlling the randomness
% of the output. The default value is CHAT.Temperature;
% higher values increase the randomness (in some sense,
% the “creativity”) of outputs, lower values
% reduce it. Setting Temperature=0 removes
% randomness from the output altogether.
%
% TopP - Top probability mass value for controlling the
% diversity of the output. Default value is CHAT.TopP;
% lower values imply that only the more likely
% words can appear in any particular place.
% This is also known as top-p sampling.
%
%       StopSequences    - Vector of strings that, when encountered, will
%                          stop the generation of tokens. Default
%                          value is CHAT.StopSequences.
% Example: ["The end.", "And that's all she wrote."]
%
% ResponseFormat - The format of response the model returns.
% Default value is CHAT.ResponseFormat.
% "text" | "json"
%
% PresencePenalty - Penalty value for using a token in the response
% that has already been used. Default value is
% CHAT.PresencePenalty.
% Higher values reduce repetition of words in the output.
%
% FrequencyPenalty - Penalty value for using a token that is frequent
% in the output. Default value is CHAT.FrequencyPenalty.
% Higher values reduce repetition of words in the output.
%
%       StreamFun        - Function to call back when streaming the
%                          result. Default value is CHAT.StreamFun.
%
%       TimeOut          - Connection timeout in seconds. Default value is CHAT.TimeOut.
%
% Currently, GPT-4 Turbo with vision does not support the message.name
% parameter, functions/tools, the response_format parameter, stop
% sequences, or max_tokens.

arguments
this (1,1) azureChat
messages {mustBeValidMsgs}
nvp.Temperature {llms.utils.mustBeValidTemperature} = this.Temperature
nvp.TopP {llms.utils.mustBeValidProbability} = this.TopP
nvp.StopSequences {llms.utils.mustBeValidStop} = this.StopSequences
nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = this.ResponseFormat
nvp.APIKey {mustBeNonzeroLengthTextScalar} = this.APIKey
nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = this.PresencePenalty
nvp.FrequencyPenalty {llms.utils.mustBeValidPenalty} = this.FrequencyPenalty
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = this.TimeOut
nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1
nvp.MaxNumTokens (1,1) {mustBePositive} = inf
nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = []
@@ -199,15 +245,22 @@
end

toolChoice = convertToolChoice(this, nvp.ToolChoice);

if isfield(nvp,"StreamFun")
streamFun = nvp.StreamFun;
else
streamFun = this.StreamFun;
end

try
[text, message, response] = llms.internal.callAzureChatAPI(this.Endpoint, ...
this.DeploymentID, messagesStruct, this.FunctionsStruct, ...
ToolChoice=toolChoice, APIVersion = this.APIVersion, Temperature=this.Temperature, ...
TopP=this.TopP, NumCompletions=nvp.NumCompletions,...
StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
PresencePenalty=this.PresencePenalty, FrequencyPenalty=this.FrequencyPenalty, ...
ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
APIKey=this.APIKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun);
ToolChoice=toolChoice, APIVersion = this.APIVersion, Temperature=nvp.Temperature, ...
TopP=nvp.TopP, NumCompletions=nvp.NumCompletions,...
StopSequences=nvp.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
PresencePenalty=nvp.PresencePenalty, FrequencyPenalty=nvp.FrequencyPenalty, ...
ResponseFormat=nvp.ResponseFormat,Seed=nvp.Seed, ...
APIKey=nvp.APIKey,TimeOut=nvp.TimeOut,StreamFun=streamFun);
catch ME
if ismember(ME.identifier,...
["MATLAB:webservices:UnknownHost","MATLAB:webservices:Timeout"])
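With this change, name-value arguments to generate override the chat object's stored settings for that one call, without modifying the object itself. A minimal usage sketch (assumptions: the Azure endpoint, deployment ID, and API key are configured, for example through environment variables, and the prompts are placeholders):

chat = azureChat("You are a helpful assistant");

% Override the stored sampling settings for this call only;
% chat.Temperature and chat.TopP are unchanged afterwards.
txt = generate(chat, "Summarize this commit in one sentence.", ...
    Temperature=0, TopP=0.5);

% Streaming can also be enabled per call via StreamFun.
txt = generate(chat, "Write a haiku.", ...
    StreamFun=@(token) fprintf("%s", token));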
40 changes: 36 additions & 4 deletions functionSignatures.json
@@ -31,7 +31,17 @@
{"name":"NumCompletions","kind":"namevalue","type":["numeric","scalar","integer","positive"]},
{"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]},
{"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",this.FunctionNames]"},
{"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]}
{"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]},
{"name":"ModelName","kind":"namevalue","type":"choices=llms.openai.models"},
{"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]},
{"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
{"name":"StopSequences","kind":"namevalue","type":["string","vector"]},
{"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"},
{"name":"APIKey","kind":"namevalue","type":["string","scalar"]},
{"name":"PresencePenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
{"name":"FrequencyPenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
{"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]},
{"name":"StreamFun","kind":"namevalue","type":"function_handle"}
],
"outputs":
[
@@ -73,7 +83,16 @@
{"name":"NumCompletions","kind":"namevalue","type":["numeric","scalar","integer","positive"]},
{"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]},
{"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",this.FunctionNames]"},
{"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]}
{"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]},
{"name":"APIKey","kind":"namevalue","type":["string","scalar"]},
{"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]},
{"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
{"name":"StopSequences","kind":"namevalue","type":["string","vector"]},
{"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"},
{"name":"PresencePenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
{"name":"FrequencyPenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]},
{"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]},
{"name":"StreamFun","kind":"namevalue","type":"function_handle"}
],
"outputs":
[
@@ -90,12 +109,14 @@
{"name":"systemPrompt","kind":"ordered","type":["string","scalar"]},
{"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]},
{"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
{"name":"MinP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
{"name":"TopK","kind":"namevalue","type":["numeric","scalar","integer",">=1"]},
{"name":"StopSequences","kind":"namevalue","type":["string","vector"]},
{"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"},
{"name":"TailFreeSamplingZ","kind":"namevalue","type":["numeric","scalar","real"]},
{"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]},
{"name":"StreamFun","kind":"namevalue","type":"function_handle"}
{"name":"StreamFun","kind":"namevalue","type":"function_handle"},
{"name":"Endpoint","kind":"namevalue","type":["string","scalar"]}
],
"outputs":
[
@@ -109,7 +130,18 @@
{"name":"this","kind":"required","type":["ollamaChat","scalar"]},
{"name":"messages","kind":"required","type":[["messageHistory","row"],["string","scalar"]]},
{"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]},
{"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]}
{"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]},
{"name":"Model","kind":"namevalue","type":"choices=ollamaChat.models"},
{"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]},
{"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
{"name":"MinP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]},
{"name":"TopK","kind":"namevalue","type":["numeric","scalar","integer",">=1"]},
{"name":"StopSequences","kind":"namevalue","type":["string","vector"]},
{"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"},
{"name":"TailFreeSamplingZ","kind":"namevalue","type":["numeric","scalar","real"]},
{"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]},
{"name":"StreamFun","kind":"namevalue","type":"function_handle"},
{"name":"Endpoint","kind":"namevalue","type":["string","scalar"]}
],
"outputs":
[
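The functionSignatures.json entries above drive tab completion for the new name-value arguments in MATLAB. As a quick check that the edited file is still well formed, MATLAB's stock validator can be run from the repository root (a sketch; the file name is as in this repository):

validateFunctionSignaturesJSON("functionSignatures.json")

This flags JSON syntax errors and schema violations, so a stray comma in the entries above surfaces before users notice broken completions.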
97 changes: 82 additions & 15 deletions ollamaChat.m
@@ -47,7 +47,6 @@
% value is empty.
% Example: ["The end.", "And that's all she wrote."]
%
%
% ResponseFormat - The format of response the model returns.
% "text" (default) | "json"
%
@@ -128,17 +127,79 @@
% [TEXT, MESSAGE, RESPONSE] = generate(__, Name=Value) specifies additional options
% using one or more name-value arguments:
%
%       MaxNumTokens     - Maximum number of tokens in the generated response.
%                          Default value is inf.
%
% Seed - An integer value to use to obtain
%                          reproducible responses.
%
% Model - Model name (as expected by Ollama server).
% Default value is CHAT.Model.
%
% Temperature - Temperature value for controlling the randomness
% of the output. Default value is CHAT.Temperature.
% Higher values increase the randomness (in some
% sense, the “creativity”) of outputs, lower
% values reduce it. Setting Temperature=0 removes
% randomness from the output altogether.
%
% TopP - Top probability mass value for controlling the
% diversity of the output. Default value is CHAT.TopP;
% lower values imply that only the more likely
% words can appear in any particular place.
% This is also known as top-p sampling.
%
% MinP - Minimum probability ratio for controlling the
% diversity of the output. Default value is CHAT.MinP;
% higher values imply that only the more likely
% words can appear in any particular place.
% This is also known as min-p sampling.
%
% TopK - Maximum number of most likely tokens that are
% considered for output. Default is CHAT.TopK.
% Smaller values reduce diversity in the output.
%
% TailFreeSamplingZ - Reduce the use of less probable tokens, based on
% the second-order differences of ordered
% probabilities.
% Default value is CHAT.TailFreeSamplingZ.
% Lower values reduce diversity, with
% some authors recommending values
% around 0.95. Tail-free sampling is
% slower than using TopP or TopK.
%
%       StopSequences    - Vector of strings that, when encountered, will
% stop the generation of tokens. Default
% value is CHAT.StopSequences.
% Example: ["The end.", "And that's all she wrote."]
%
%       ResponseFormat   - The format of response the model returns.
%                          The default value is CHAT.ResponseFormat.
%                          "text" | "json"
%
%       StreamFun        - Function to call back when streaming the
%                          result. The default value is CHAT.StreamFun.
%
%       TimeOut          - Connection timeout in seconds. Default is CHAT.TimeOut.
%

arguments
this (1,1) ollamaChat
messages {mustBeValidMsgs}
nvp.Model {mustBeTextScalar} = this.Model
nvp.Temperature {llms.utils.mustBeValidTemperature} = this.Temperature
nvp.TopP {llms.utils.mustBeValidProbability} = this.TopP
nvp.MinP {llms.utils.mustBeValidProbability} = this.MinP
nvp.TopK (1,1) {mustBeReal,mustBePositive} = this.TopK
nvp.StopSequences {llms.utils.mustBeValidStop} = this.StopSequences
nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = this.ResponseFormat
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = this.TimeOut
nvp.TailFreeSamplingZ (1,1) {mustBeReal} = this.TailFreeSamplingZ
nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
nvp.Endpoint (1,1) string = this.Endpoint
nvp.MaxNumTokens (1,1) {mustBePositive} = inf
nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
end

messages = convertCharsToStrings(messages);
@@ -152,15 +213,21 @@
messagesStruct = horzcat(this.SystemPrompt, messagesStruct);
end

if isfield(nvp,"StreamFun")
streamFun = nvp.StreamFun;
else
streamFun = this.StreamFun;
end

[text, message, response] = llms.internal.callOllamaChatAPI(...
this.Model, messagesStruct, ...
Temperature=this.Temperature, ...
TopP=this.TopP, MinP=this.MinP, TopK=this.TopK,...
TailFreeSamplingZ=this.TailFreeSamplingZ,...
StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
TimeOut=this.TimeOut, StreamFun=this.StreamFun, ...
Endpoint=this.Endpoint);
nvp.Model, messagesStruct, ...
Temperature=nvp.Temperature, ...
TopP=nvp.TopP, MinP=nvp.MinP, TopK=nvp.TopK,...
TailFreeSamplingZ=nvp.TailFreeSamplingZ,...
StopSequences=nvp.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
ResponseFormat=nvp.ResponseFormat,Seed=nvp.Seed, ...
TimeOut=nvp.TimeOut, StreamFun=streamFun, ...
Endpoint=nvp.Endpoint);

if isfield(response.Body.Data,"error")
err = response.Body.Data.error;
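As with azureChat, generate for ollamaChat now threads per-call overrides, including Model and Endpoint, through to the server. A hedged sketch (the model names and alternate endpoint are assumptions; use whatever models your Ollama server has pulled):

chat = ollamaChat("mistral");

% Try a different model and tighter sampling for this call only;
% the chat object itself keeps Model="mistral".
[txt, message, response] = generate(chat, "Why is the sky blue?", ...
    Model="llama3", Temperature=0.2, TopK=10);

% Point the same object at a different Ollama server for one call.
txt = generate(chat, "Hello!", Endpoint="localhost:11435");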