Skip to content

Commit

Permalink
Adding new features and examples
Browse files Browse the repository at this point in the history
Adding support and examples for
JSON mode
Parallel Function Calling
GPT4 with Vision
DALL·E
  • Loading branch information
toshiakit committed Jan 5, 2024
1 parent a4baf34 commit 3b83a23
Show file tree
Hide file tree
Showing 13 changed files with 497 additions and 29 deletions.
32 changes: 25 additions & 7 deletions +llms/+internal/callOpenAIChatAPI.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
%
% Currently, the supported NVP are, including the equivalent name in the API:
% - FunctionCall (function_call)
% - ToolChoice (tool_choice)
% - ModelName (model)
% - Temperature (temperature)
% - TopProbabilityMass (top_p)
Expand All @@ -17,6 +17,8 @@
% - MaxNumTokens (max_tokens)
% - PresencePenalty (presence_penalty)
% - FrequencyPenalty (frequence_penalty)
% - ResponseFormat (response_format)
% - Seed (seed)
% - ApiKey
% - TimeOut
% - StreamFun
Expand Down Expand Up @@ -55,7 +57,7 @@
arguments
messages
functions
nvp.FunctionCall = []
nvp.ToolChoice = []
nvp.ModelName = "gpt-3.5-turbo"
nvp.Temperature = 1
nvp.TopProbabilityMass = 1
Expand All @@ -64,6 +66,8 @@
nvp.MaxNumTokens = inf
nvp.PresencePenalty = 0
nvp.FrequencyPenalty = 0
nvp.ResponseFormat = "text"
nvp.Seed = []
nvp.ApiKey = ""
nvp.TimeOut = 10
nvp.StreamFun = []
Expand All @@ -85,7 +89,7 @@
message = struct("role", "assistant", ...
"content", streamedText);
end
if isfield(message, "function_call")
if isfield(message, "tool_choice")
text = "";
else
text = string(message.content);
Expand All @@ -105,19 +109,33 @@

parameters.stream = ~isempty(nvp.StreamFun);

if ~isempty(functions)
parameters.functions = functions;
if ~isempty(functions) && ~strcmp(nvp.ModelName,'gpt-4-vision-preview')
parameters.tools = functions;
end

if ~isempty(nvp.FunctionCall)
parameters.function_call = nvp.FunctionCall;
if ~isempty(nvp.ToolChoice) && ~strcmp(nvp.ModelName,'gpt-4-vision-preview')
parameters.tool_choice = nvp.ToolChoice;
end

if ismember(nvp.ModelName,["gpt-3.5-turbo-1106","gpt-4-1106-preview"])
if strcmp(nvp.ResponseFormat,"json")
parameters.response_format = struct('type','json_object');
end
end

if ~isempty(nvp.Seed)
parameters.seed = nvp.Seed;
end

parameters.model = nvp.ModelName;

dict = mapNVPToParameters;

nvpOptions = keys(dict);
if strcmp(nvp.ModelName,'gpt-4-vision-preview')
nvpOptions(ismember(nvpOptions,["MaxNumTokens","StopSequences"])) = [];
end

for i=1:length(nvpOptions)
if isfield(nvp, nvpOptions(i))
parameters.(dict(nvpOptions(i))) = nvp.(nvpOptions(i));
Expand Down
28 changes: 28 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
*.fig binary
*.mat binary
*.mdl binary diff merge=mlAutoMerge
*.mdlp binary
*.mexa64 binary
*.mexw64 binary
*.mexmaci64 binary
*.mlapp binary
*.mldatx binary
*.mlproj binary
*.mlx binary
*.p binary
*.sfx binary
*.sldd binary
*.slreqx binary merge=mlAutoMerge
*.slmx binary merge=mlAutoMerge
*.sltx binary
*.slxc binary
*.slx binary merge=mlAutoMerge
*.slxp binary

## Other common binary file types
*.docx binary
*.exe binary
*.jpg binary
*.pdf binary
*.png binary
*.xlsx binary
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.env
*.asv
Binary file added examples/Example_DALL·E.mlx
Binary file not shown.
Binary file added examples/Example_GPT4_Vision.mlx
Binary file not shown.
Binary file added examples/Example_JSON_Mode.mlx
Binary file not shown.
Binary file added examples/Example_Parallel_Function_Calls.mlx
Binary file not shown.
Binary file added examples/Example_Streaming.mlx
Binary file not shown.
Binary file added examples/images/bear.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/bear_mask.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
94 changes: 72 additions & 22 deletions openAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
% CHAT = openAIChat(systemPrompt,Name=Value) specifies additional options
% using one or more name-value arguments:
%
% Functions - Array of openAIFunction objects representing
% Tools - Array of openAIFunction objects representing
% custom functions to be used during chat completions.
%
% ModelName - Name of the model to use for chat completions.
Expand All @@ -28,12 +28,15 @@
% PresencePenalty - Penalty value for using a token in the response
% that has already been used. Default value is 0.
%
% FrequencyPenalty - Penalty value for using a token that is frequent
% FrequencyPenalty - Penalty value for using a token that is frequent
% in the training data. Default value is 0.
%
% StreamFun - Function to callback when streaming the
% result
%
% ResponseFormat - The format of response the model returns.
% Default is text, or json.
%
% openAIChat Functions:
% openAIChat - Chat completion API from OpenAI.
% generate - Generate a response using the openAIChat instance.
Expand All @@ -58,6 +61,8 @@
% FunctionNames - Names of the functions that the model can
% request calls.
%
% ResponseFormat - Specifies the response format, text or json
%
% TimeOut - Connection Timeout in seconds (default: 10 secs)
%

Expand All @@ -78,6 +83,10 @@

%FREQUENCYPENALTY Penalty for using a token that is frequent in the training data.
FrequencyPenalty

%RESPONSEFORMAT Response format, text or json
ResponseFormat

end

properties(SetAccess=private)
Expand All @@ -95,7 +104,7 @@
end

properties(Access=private)
Functions
Tools
FunctionsStruct
ApiKey
StreamFun
Expand All @@ -105,14 +114,16 @@
function this = openAIChat(systemPrompt, nvp)
arguments
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
nvp.Functions (1,:) {mustBeA(nvp.Functions, "openAIFunction")} = openAIFunction.empty
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["gpt-4", "gpt-4-0613", "gpt-4-32k", ...
"gpt-3.5-turbo", "gpt-3.5-turbo-16k",...
"gpt-4-1106-preview","gpt-3.5-turbo-1106"])} = "gpt-3.5-turbo"
"gpt-4-1106-preview","gpt-3.5-turbo-1106", ...
"gpt-4-vision-preview"])} = "gpt-3.5-turbo"
nvp.Temperature {mustBeValidTemperature} = 1
nvp.TopProbabilityMass {mustBeValidTopP} = 1
nvp.StopSequences {mustBeValidStop} = {}
nvp.ApiKey {llms.utils.mustBeNonzeroLengthTextScalar}
nvp.ResponseFormat (1,1) {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text"
nvp.ApiKey {mustBeNonzeroLengthTextScalar}
nvp.PresencePenalty {mustBeValidPenalty} = 0
nvp.FrequencyPenalty {mustBeValidPenalty} = 0
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
Expand All @@ -125,11 +136,11 @@
this.StreamFun = [];
end

if ~isempty(nvp.Functions)
this.Functions = nvp.Functions;
[this.FunctionsStruct, this.FunctionNames] = functionAsStruct(nvp.Functions);
if ~isempty(nvp.Tools)
this.Tools = nvp.Tools;
[this.FunctionsStruct, this.FunctionNames] = functionAsStruct(nvp.Tools);
else
this.Functions = [];
this.Tools = [];
this.FunctionsStruct = [];
this.FunctionNames = [];
end
Expand All @@ -145,6 +156,18 @@
this.Temperature = nvp.Temperature;
this.TopProbabilityMass = nvp.TopProbabilityMass;
this.StopSequences = nvp.StopSequences;
% Response Format is supported in the latest models only
if strcmp(nvp.ResponseFormat,"json")
if ismember(this.ModelName,["gpt-3.5-turbo-1106","gpt-4-1106-preview"])
if contains(this.SystemPrompt{1}.content,"designed to output to JSON")
this.ResponseFormat = nvp.ResponseFormat;
else
error("To get JSON output, add 'designed to output to JSON' to the system prompt.")
end
else
mustBeMember(this.ModelName,["gpt-3.5-turbo-1106","gpt-4-1106-preview"])
end
end
this.PresencePenalty = nvp.PresencePenalty;
this.FrequencyPenalty = nvp.FrequencyPenalty;
this.ApiKey = llms.internal.getApiKeyFromNvpOrEnv(nvp);
Expand All @@ -166,18 +189,26 @@
% MaxNumTokens - Maximum number of tokens in the generated response.
% Default value is inf.
%
% FunctionCall - Function call to execute before generating the
% response, specified as a string array. Default value is empty.

% ToolChoice - Function to execute. 'none', 'auto',
% or specify the function to call.
%
% Seed - An integer value to use to obtain
% reproducible responses
%
% Currently, GPT-4 Turbo with vision does not support the message.name
% parameter, functions/tools, response_format parameter, stop
% sequences, and max_tokens

arguments
this (1,1) openAIChat
messages (1,1) {mustBeValidMsgs}
nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1
nvp.MaxNumTokens (1,1) {mustBePositive} = inf
nvp.FunctionCall {mustBeValidFunctionCall(this, nvp.FunctionCall)} = []
nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = []
nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
end

functionCall = convertFunctionCall(this, nvp.FunctionCall);
toolChoice = convertToolChoice(this, nvp.ToolChoice);
if isstring(messages) && isscalar(messages)
messagesStruct = {struct("role", "user", "content", messages)};
else
Expand All @@ -189,10 +220,11 @@
end

[text, message, response] = llms.internal.callOpenAIChatAPI(messagesStruct, this.FunctionsStruct,...
ModelName=this.ModelName, FunctionCall=functionCall, Temperature=this.Temperature, ...
ModelName=this.ModelName, ToolChoice=toolChoice, Temperature=this.Temperature, ...
TopProbabilityMass=this.TopProbabilityMass, NumCompletions=nvp.NumCompletions,...
StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
PresencePenalty=this.PresencePenalty, FrequencyPenalty=this.FrequencyPenalty, ...
ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ...
ApiKey=this.ApiKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun);
end

Expand Down Expand Up @@ -254,24 +286,36 @@ function mustBeValidFunctionCall(this, functionCall)
end
end

function functionCall = convertFunctionCall(this, functionCall)
% If functionCall is not empty, then it must be in
% the format {"name", functionCall}
if ~isempty(functionCall)&&ismember(functionCall, this.FunctionNames)
functionCall = struct("name", functionCall);
function toolChoice = convertToolChoice(this, toolChoice)
% if toolChoice is empty
if isempty(toolChoice)
% if Tools is not empty, the default is 'auto'.
if ~isempty(this.Tools)
toolChoice = "auto";
end
else
% if toolChoice is not empty, then it must be in the format
% {"type": "function", "function": {"name": "my_function"}}
toolChoice = struct("type","function","function",struct("name",toolChoice));
end

end
end
end

function mustBeNonzeroLengthTextScalar(content)
mustBeNonzeroLengthText(content)
mustBeTextScalar(content)
end

function [functionsStruct, functionNames] = functionAsStruct(functions)
numFunctions = numel(functions);
functionsStruct = cell(1, numFunctions);
functionNames = strings(1, numFunctions);

for i = 1:numFunctions
functionsStruct{i} = encodeStruct(functions(i));
functionsStruct{i} = struct('type','function', ...
'function',encodeStruct(functions(i))) ;
functionNames(i) = functions(i).FunctionName;
end
end
Expand Down Expand Up @@ -311,4 +355,10 @@ function mustBeValidStop(value)
error("llms:stopSequencesMustHaveMax4Elements", llms.utils.errorMessageCatalog.getMessage("llms:stopSequencesMustHaveMax4Elements"));
end
end
end

function mustBeIntegerOrEmpty(value)
if ~isempty(value)
mustBeInteger(value)
end
end
37 changes: 37 additions & 0 deletions openAIMessages.m
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
% openAIMessages functions:
% addSystemMessage - Add system message.
% addUserMessage - Add user message.
% addUserMessageWithImages - Add user message with images for
% GPT-4 Vision.
% addFunctionMessage - Add a function message.
% addResponseMessage - Add a response message.
% removeMessage - Remove message from history.
Expand Down Expand Up @@ -69,6 +71,41 @@
this.Messages{end+1} = newMessage;
end

function this = addUserMessageWithImages(this, prompt, images)
%addUserMessageWithImages Add user message with images.

arguments
this (1,1) openAIMessages
prompt {mustBeNonzeroLengthTextScalar}
images (1,:) cell {mustBeNonempty}
end

newMessage = struct("role", "user", "content", []);
newMessage.content = {struct("type","text","text",string(prompt))};
for ii = 1:numel(images)
if startsWith(images{ii},("https://"|"http://"))
s = struct( ...
"type","image_url", ...
"image_url",struct("url",images{ii}));
newMessage.content{end+1} = s;
else
[~,~,ext] = fileparts(images{ii});
MIMEType = "data:image/" + erase(ext,".") + ";base64,";
% Base64 encode the image using the given MIME type
fid = fopen(images{ii});
V = fread(fid,'*uint8');
fclose(fid);
b64 = matlab.net.base64encode(V);
s = struct( ...
"type","image_url", ...
"image_url",struct("url",MIMEType + b64));
newMessage.content{end+1} = s;
end
this.Messages{end+1} = newMessage;
end

end

function this = addFunctionMessage(this, name, content)
%addFunctionMessage Add function message.
%
Expand Down
Loading

0 comments on commit 3b83a23

Please sign in to comment.