Skip to content

Commit

Permalink
Addressing comments by ccreutzi
Browse files Browse the repository at this point in the history
  • Loading branch information
toshiakit committed Jan 26, 2024
1 parent b3de243 commit 3ec2414
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 11 deletions.
6 changes: 3 additions & 3 deletions +llms/+internal/callOpenAIChatAPI.m
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@
nvpOptions(ismember(nvpOptions,["MaxNumTokens","StopSequences"])) = [];
end

for i=1:length(nvpOptions)
if isfield(nvp, nvpOptions(i))
parameters.(dict(nvpOptions(i))) = nvp.(nvpOptions(i));
for opt = nvpOptions.'
if isfield(nvp, opt)
parameters.(dict(opt)) = nvp.(opt);
end
end
end
Expand Down
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Large Language Models (LLMs) with MATLAB® [![Open in MATLAB Online](https://www.mathworks.com/images/responsive/global/open-in-matlab-online.svg)](https://matlab.mathworks.com/open/github/v1?repo=matlab-deep-learning/llms-with-matlab)

This repository contains example code to demonstrate how to connect MATLAB to the OpenAI™ Chat Completions API (which powers ChatGPT™) as well as OpenAI Images API (which powers DALL-E™). This allows you to leverage the natural language processing capabilities of large language models directly within your MATLAB environment.
This repository contains example code to demonstrate how to connect MATLAB to the OpenAI™ Chat Completions API (which powers ChatGPT™) as well as OpenAI Images API (which powers DALL·E™). This allows you to leverage the natural language processing capabilities of large language models directly within your MATLAB environment.

The functionality shown here serves as an interface to the ChatGPT and DALL-E APIs. To start using the OpenAI APIs, you first need to obtain the OpenAI API keys. You are responsible for any fees OpenAI may charge for the use of their APIs. You should be familiar with the limitations and risks associated with using this technology, and you agree that you shall be solely responsible for full compliance with any terms that may apply to your use of the OpenAI APIs.
The functionality shown here serves as an interface to the ChatGPT and DALL·E APIs. To start using the OpenAI APIs, you first need to obtain OpenAI API keys. You are responsible for any fees OpenAI may charge for the use of their APIs. You should be familiar with the limitations and risks associated with using this technology, and you agree that you shall be solely responsible for full compliance with any terms that may apply to your use of the OpenAI APIs.

Some of the current LLMs supported are:
- gpt-3.5-turbo, gpt-3.5-turbo-1106
Expand Down Expand Up @@ -127,6 +127,7 @@ You can specify the streaming function when you create the chat assistant. Th
sf = @(x)fprintf("%s",x);
chat = openAIChat(StreamFun=sf);
txt = generate(chat,"What is Model-Based Design and how is it related to Digital Twin?")
% Should stream the response token by token
```
### Calling MATLAB functions with the API
Expand Down Expand Up @@ -280,6 +281,7 @@ image_path = "peppers.png";
messages = openAIMessages;
messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
[txt,response] = generate(chat,messages);
% Should output the description of the image
```

### Obtaining embeddings
Expand Down Expand Up @@ -308,6 +310,7 @@ mdl = openAIImages(ModelName="dall-e-3");
images = generate(mdl,"Create a 3D avatar of a whimsical sushi on the beach. He is decorated with various sushi elements and is playfully interacting with the beach environment.");
figure
imshow(images{1})
% Should output an image based on the prompt
```

## Examples
Expand Down
23 changes: 23 additions & 0 deletions openAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,21 @@

if isfield(nvp,"StreamFun")
this.StreamFun = nvp.StreamFun;
if strcmp(nvp.ModelName,'gpt-4-vision-preview')
error("llms:invalidOptionAndValueForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "StreamFun", nvp.ModelName));
end
else
this.StreamFun = [];
end

if ~isempty(nvp.Tools)
this.Tools = nvp.Tools;
[this.FunctionsStruct, this.FunctionNames] = functionAsStruct(nvp.Tools);
if strcmp(nvp.ModelName,'gpt-4-vision-preview')
error("llms:invalidOptionAndValueForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "Tools", nvp.ModelName));
end
else
this.Tools = [];
this.FunctionsStruct = [];
Expand All @@ -155,6 +163,11 @@
this.Temperature = nvp.Temperature;
this.TopProbabilityMass = nvp.TopProbabilityMass;
this.StopSequences = nvp.StopSequences;
if ~isempty(nvp.StopSequences) && strcmp(nvp.ModelName,'gpt-4-vision-preview')
error("llms:invalidOptionAndValueForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "StopSequences", nvp.ModelName));
end


% ResponseFormat is only supported in the latest models
if (nvp.ResponseFormat == "json")
Expand Down Expand Up @@ -208,7 +221,17 @@
nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
end

if nvp.MaxNumTokens ~= Inf && strcmp(this.ModelName,'gpt-4-vision-preview')
error("llms:invalidOptionAndValueForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "MaxNumTokens", this.ModelName));
end

toolChoice = convertToolChoice(this, nvp.ToolChoice);
if ~isempty(nvp.ToolChoice) && strcmp(this.ModelName,'gpt-4-vision-preview')
error("llms:invalidOptionAndValueForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "ToolChoice", this.ModelName));
end

if isstring(messages) && isscalar(messages)
messagesStruct = {struct("role", "user", "content", messages)};
else
Expand Down
2 changes: 1 addition & 1 deletion openAIMessages.m
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ function validateRegularAssistant(content)
end

function validateAssistantWithToolCalls(toolCallStruct)
if ~isstruct(toolCallStruct)||~isfield(toolCallStruct, "id")||~isfield(toolCallStruct, "function")
if ~(isstruct(toolCallStruct) && isfield(toolCallStruct, "id") && isfield(toolCallStruct, "function"))
error("llms:mustBeAssistantWithIdAndFunction", ...
llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithIdAndFunction"))
else
Expand Down
4 changes: 2 additions & 2 deletions tests/topenAIImages.m
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ function saveEnvVar(testCase)
openAIEnvVar = "OPENAI_API_KEY";
if isenv(openAIEnvVar)
key = getenv(openAIEnvVar);
testCase.addTeardown(@() setenv(openAIEnvVar, key));
unsetenv(openAIEnvVar);
testCase.addTeardown(@(x) setenv(openAIEnvVar, x), key);
end
end
end
Expand Down Expand Up @@ -223,7 +223,7 @@ function invalidInputsVariation(testCase, InvalidVariationInput)

function invalidEditInput = iGetInvalidEditInput
validImage = string(which("peppers.png"));
nonPNGImage = which("corn.tif");
nonPNGImage = string(which("corn.tif"));
invalidEditInput = struct( ...
"EmptyImage",struct( ...
"Input",{{ [], "prompt" }},...
Expand Down
5 changes: 2 additions & 3 deletions tests/topenAIMessages.m
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,8 @@ function assistantToolCallMessageWithoutArgsIsAdded(testCase)
functionName = "functionName";
funCall = struct("name", functionName, "arguments", "{}");
toolCall = struct("id", "123", "type", "function", "function", funCall);
toolCallPrompt = struct("role", "assistant", "content", "","tool_calls", []);
% tool_calls is an array of struct in API response
toolCallPrompt.tool_calls = toolCall;
toolCallPrompt = struct("role", "assistant", "content", "","tool_calls", toolCall);
msgs = addResponseMessage(msgs, toolCallPrompt);
% to include in msgs, tool_calls must be a cell
testCase.verifyEqual(fieldnames(msgs.Messages{1}), fieldnames(toolCallPrompt));
Expand All @@ -112,8 +111,8 @@ function assistantParallelToolCallMessageIsAdded(testCase)
args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}";
funCall = struct("name", functionName, "arguments", args);
toolCall = struct("id", "123", "type", "function", "function", funCall);
toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []);
% tool_calls is an array of struct in API response
toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", toolCall);
toolCallPrompt.tool_calls = [toolCall,toolCall,toolCall];
msgs = addResponseMessage(msgs, toolCallPrompt);
testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt);
Expand Down

0 comments on commit 3ec2414

Please sign in to comment.