diff --git a/+llms/+azure/apiVersions.m b/+llms/+azure/apiVersions.m new file mode 100644 index 0000000..b32039a --- /dev/null +++ b/+llms/+azure/apiVersions.m @@ -0,0 +1,12 @@ +function versions = apiVersions +%VERSIONS - supported Azure API versions + +% Copyright 2024 The MathWorks, Inc. + versions = [... + "2024-05-01-preview", ... + "2024-04-01-preview", ... + "2024-03-01-preview", ... + "2024-02-01", ... + "2023-05-15", ... + ]; +end diff --git a/+llms/+internal/callAzureChatAPI.m b/+llms/+internal/callAzureChatAPI.m new file mode 100644 index 0000000..bb73053 --- /dev/null +++ b/+llms/+internal/callAzureChatAPI.m @@ -0,0 +1,130 @@ +function [text, message, response] = callAzureChatAPI(endpoint, deploymentID, messages, functions, nvp) +% This function is undocumented and will change in a future release +%callAzureChatAPI Calls the OpenAI chat completions API on Azure. +% +% MESSAGES and FUNCTIONS should be structs matching the json format +% required by the OpenAI Chat Completions API. +% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api +% +% More details on the parameters: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/chatgpt +% +% Example +% +% % Create messages struct +% messages = {struct("role", "system",... +% "content", "You are a helpful assistant"); +% struct("role", "user", ... +% "content", "What is the edit distance between hi and hello?")}; +% +% % Create functions struct +% functions = {struct("name", "editDistance", ... +% "description", "Find edit distance between two strings or documents.", ... +% "parameters", struct( ... +% "type", "object", ... +% "properties", struct(... +% "str1", struct(... +% "description", "Source string.", ... +% "type", "string"),... +% "str2", struct(... +% "description", "Target string.", ... +% "type", "string")),... +% "required", ["str1", "str2"]))}; +% +% % Define your endpoint, deployment, and API key (placeholder values) +% endpoint = "https://your-resource.openai.azure.com/" +% deploymentID = "your-deployment-name" +% apiKey = "your-api-key-here" +% +% % Send a request +% [text, message] = llms.internal.callAzureChatAPI(endpoint, deploymentID, messages, functions, APIKey=apiKey) % Copyright 2023-2024 The MathWorks, Inc. arguments endpoint deploymentID messages functions nvp.ToolChoice nvp.APIVersion nvp.Temperature nvp.TopP nvp.NumCompletions nvp.StopSequences nvp.MaxNumTokens nvp.PresencePenalty nvp.FrequencyPenalty nvp.ResponseFormat nvp.Seed nvp.APIKey nvp.TimeOut nvp.StreamFun end URL = endpoint + "openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion; parameters = buildParametersCall(messages, functions, nvp); [response, streamedText] = llms.internal.sendRequest(parameters,nvp.APIKey, URL, nvp.TimeOut, nvp.StreamFun); % If the call errors, "choices" will not be part of response.Body.Data; instead % we get response.Body.Data.error if response.StatusCode=="OK" % Outputs the first generation if isempty(nvp.StreamFun) message = response.Body.Data.choices(1).message; else message = struct("role", "assistant", ... "content", streamedText); end if isfield(message, "tool_calls") text = ""; else text = string(message.content); end else text = ""; message = struct(); end end function parameters = buildParametersCall(messages, functions, nvp) % Builds a struct in the format that is expected by the API, combining % MESSAGES, FUNCTIONS and parameters in NVP.
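+% For a typical call, the assembled struct (which is JSON-encoded into the
+% request body) resembles the following -- illustrative values only:
+%   parameters.messages    = {...};   % chat history
+%   parameters.stream      = false;   % true when a StreamFun is supplied
+%   parameters.tools       = {...};   % present only when FUNCTIONS is nonempty
+%   parameters.temperature = 1;       % mapped from nvp.Temperature below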
+ +parameters = struct(); +parameters.messages = messages; + +parameters.stream = ~isempty(nvp.StreamFun); + +if ~isempty(functions) + parameters.tools = functions; +end + +if ~isempty(nvp.ToolChoice) + parameters.tool_choice = nvp.ToolChoice; +end + +if ~isempty(nvp.Seed) + parameters.seed = nvp.Seed; +end + +dict = mapNVPToParameters; + +nvpOptions = keys(dict); +for opt = nvpOptions.' + if isfield(nvp, opt) + parameters.(dict(opt)) = nvp.(opt); + end +end +end + +function dict = mapNVPToParameters() +dict = dictionary(); +dict("Temperature") = "temperature"; +dict("TopP") = "top_p"; +dict("NumCompletions") = "n"; +dict("StopSequences") = "stop"; +dict("MaxNumTokens") = "max_tokens"; +dict("PresencePenalty") = "presence_penalty"; +dict("FrequencyPenalty") = "frequency_penalty"; +end \ No newline at end of file diff --git a/+llms/+internal/callOllamaChatAPI.m b/+llms/+internal/callOllamaChatAPI.m new file mode 100644 index 0000000..a7e6436 --- /dev/null +++ b/+llms/+internal/callOllamaChatAPI.m @@ -0,0 +1,106 @@ +function [text, message, response] = callOllamaChatAPI(model, messages, nvp) +% This function is undocumented and will change in a future release + +%callOllamaChatAPI Calls the Ollama® chat completions API. +% +% MESSAGES and FUNCTIONS should be structs matching the json format +% required by the Ollama Chat Completions API. +% Ref: https://github.com/ollama/ollama/blob/main/docs/api.md +% +% More details on the parameters: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values +% +% Example +% +% model = "mistral"; +% +% % Create messages struct +% messages = {struct("role", "system",... +% "content", "You are a helpful assistant"); +% struct("role", "user", ... +% "content", "What is the edit distance between hi and hello?")}; +% +% % Send a request +% [text, message] = llms.internal.callOllamaChatAPI(model, messages) + +% Copyright 2023-2024 The MathWorks, Inc. + +arguments + model + messages + nvp.Temperature + nvp.TopP + nvp.TopK + nvp.TailFreeSamplingZ + nvp.StopSequences + nvp.MaxNumTokens + nvp.ResponseFormat + nvp.Seed + nvp.TimeOut + nvp.StreamFun +end + +URL = "http://localhost:11434/api/chat"; + +% The JSON for StopSequences must have an array, and cannot say "stop": "foo". +% The easiest way to ensure that is to never pass in a scalar … +if isscalar(nvp.StopSequences) + nvp.StopSequences = [nvp.StopSequences, nvp.StopSequences]; +end + +parameters = buildParametersCall(model, messages, nvp); + +[response, streamedText] = llms.internal.sendRequest(parameters,[],URL,nvp.TimeOut,nvp.StreamFun); + +% If call errors, "choices" will not be part of response.Body.Data, instead +% we get response.Body.Data.error +if response.StatusCode=="OK" + % Outputs the first generation + if isempty(nvp.StreamFun) + message = response.Body.Data.message; + else + message = struct("role", "assistant", ... + "content", streamedText); + end + text = string(message.content); +else + text = ""; + message = struct(); +end +end + +function parameters = buildParametersCall(model, messages, nvp) +% Builds a struct in the format that is expected by the API, combining +% MESSAGES, FUNCTIONS and parameters in NVP. + +parameters = struct(); +parameters.model = model; +parameters.messages = messages; + +parameters.stream = ~isempty(nvp.StreamFun); + +options = struct; +if ~isempty(nvp.Seed) + options.seed = nvp.Seed; +end + +dict = mapNVPToParameters; + +nvpOptions = keys(dict); +for opt = nvpOptions.' 
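+    % Forward only the options the caller actually set: empty values and an
+    % Inf MaxNumTokens stand for the server-side defaults and are omitted.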
+ if isfield(nvp, opt) && ~isempty(nvp.(opt)) && ~isequaln(nvp.(opt),Inf) + options.(dict(opt)) = nvp.(opt); + end +end + +parameters.options = options; +end + +function dict = mapNVPToParameters() +dict = dictionary(); +dict("Temperature") = "temperature"; +dict("TopP") = "top_p"; +dict("TopK") = "top_k"; +dict("TailFreeSamplingZ") = "tfs_z"; +dict("StopSequences") = "stop"; +dict("MaxNumTokens") = "num_predict"; +end diff --git a/+llms/+internal/callOpenAIChatAPI.m b/+llms/+internal/callOpenAIChatAPI.m index 6226653..8d58fd4 100644 --- a/+llms/+internal/callOpenAIChatAPI.m +++ b/+llms/+internal/callOpenAIChatAPI.m @@ -1,25 +1,12 @@ function [text, message, response] = callOpenAIChatAPI(messages, functions, nvp) +% This function is undocumented and will change in a future release + %callOpenAIChatAPI Calls the openAI chat completions API. % % MESSAGES and FUNCTIONS should be structs matching the json format % required by the OpenAI Chat Completions API. % Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api % -% Currently, the supported NVP are, including the equivalent name in the API: -% - ToolChoice (tool_choice) -% - ModelName (model) -% - Temperature (temperature) -% - TopProbabilityMass (top_p) -% - NumCompletions (n) -% - StopSequences (stop) -% - MaxNumTokens (max_tokens) -% - PresencePenalty (presence_penalty) -% - FrequencyPenalty (frequence_penalty) -% - ResponseFormat (response_format) -% - Seed (seed) -% - ApiKey -% - TimeOut -% - StreamFun % More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create % % Example @@ -48,34 +35,34 @@ % apiKey = "your-api-key-here" % % % Send a request -% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, ApiKey=apiKey) +% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, APIKey=apiKey) % Copyright 2023-2024 The MathWorks, Inc. 
arguments messages functions - nvp.ToolChoice = [] - nvp.ModelName = "gpt-3.5-turbo" - nvp.Temperature = 1 - nvp.TopProbabilityMass = 1 - nvp.NumCompletions = 1 - nvp.StopSequences = [] - nvp.MaxNumTokens = inf - nvp.PresencePenalty = 0 - nvp.FrequencyPenalty = 0 - nvp.ResponseFormat = "text" - nvp.Seed = [] - nvp.ApiKey = "" - nvp.TimeOut = 10 - nvp.StreamFun = [] + nvp.ToolChoice + nvp.ModelName + nvp.Temperature + nvp.TopP + nvp.NumCompletions + nvp.StopSequences + nvp.MaxNumTokens + nvp.PresencePenalty + nvp.FrequencyPenalty + nvp.ResponseFormat + nvp.Seed + nvp.APIKey + nvp.TimeOut + nvp.StreamFun end END_POINT = "https://api.openai.com/v1/chat/completions"; parameters = buildParametersCall(messages, functions, nvp); -[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun); +[response, streamedText] = llms.internal.sendRequest(parameters,nvp.APIKey, END_POINT, nvp.TimeOut, nvp.StreamFun); % If call errors, "choices" will not be part of response.Body.Data, instead % we get response.Body.Data.error @@ -160,7 +147,7 @@ function dict = mapNVPToParameters() dict = dictionary(); dict("Temperature") = "temperature"; -dict("TopProbabilityMass") = "top_p"; +dict("TopP") = "top_p"; dict("NumCompletions") = "n"; dict("StopSequences") = "stop"; dict("MaxNumTokens") = "max_tokens"; diff --git a/+llms/+internal/getApiKeyFromNvpOrEnv.m b/+llms/+internal/getApiKeyFromNvpOrEnv.m index 1b55a92..42cec81 100644 --- a/+llms/+internal/getApiKeyFromNvpOrEnv.m +++ b/+llms/+internal/getApiKeyFromNvpOrEnv.m @@ -1,23 +1,23 @@ -function key = getApiKeyFromNvpOrEnv(nvp) +function key = getApiKeyFromNvpOrEnv(nvp,envVarName) % This function is undocumented and will change in a future release %getApiKeyFromNvpOrEnv Retrieves an API key from a Name-Value Pair struct or environment variable. % -% This function takes a struct nvp containing name-value pairs and checks -% if it contains a field called "ApiKey". If the field is not found, -% the function attempts to retrieve the API key from an environment -% variable called "OPENAI_API_KEY". If both methods fail, the function -% throws an error. +% This function takes a struct nvp containing name-value pairs and checks if +% it contains a field called "APIKey". If the field is not found, the +% function attempts to retrieve the API key from an environment variable +% whose name is given as the second argument. If both methods fail, the +% function throws an error. -% Copyright 2023 The MathWorks, Inc. +% Copyright 2023-2024 The MathWorks, Inc. - if isfield(nvp, "ApiKey") - key = nvp.ApiKey; + if isfield(nvp, "APIKey") + key = nvp.APIKey; else - if isenv("OPENAI_API_KEY") - key = getenv("OPENAI_API_KEY"); + if isenv(envVarName) + key = getenv(envVarName); else - error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified")); + error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified", envVarName)); end end -end \ No newline at end of file +end diff --git a/+llms/+internal/gptPenalties.m b/+llms/+internal/gptPenalties.m new file mode 100644 index 0000000..fc02cd8 --- /dev/null +++ b/+llms/+internal/gptPenalties.m @@ -0,0 +1,12 @@ +classdef (Abstract) gptPenalties + % This class is undocumented and will change in a future release + + % Copyright 2024 The MathWorks, Inc. + properties + %PRESENCEPENALTY Penalty for using a token in the response that has already been used. 
+ PresencePenalty {llms.utils.mustBeValidPenalty} = 0 + + %FREQUENCYPENALTY Penalty for using a token that is frequent in the training data. + FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0 + end +end diff --git a/+llms/+internal/hasTools.m b/+llms/+internal/hasTools.m new file mode 100644 index 0000000..d7a346c --- /dev/null +++ b/+llms/+internal/hasTools.m @@ -0,0 +1,15 @@ +classdef (Abstract) hasTools + % This class is undocumented and will change in a future release + + % Copyright 2023-2024 The MathWorks, Inc. + + properties (SetAccess=protected) + %FunctionNames Names of the functions that the model can request calls + FunctionNames + end + + properties (Access=protected) + Tools + FunctionsStruct + end +end diff --git a/+llms/+internal/needsAPIKey.m b/+llms/+internal/needsAPIKey.m new file mode 100644 index 0000000..5a54bce --- /dev/null +++ b/+llms/+internal/needsAPIKey.m @@ -0,0 +1,9 @@ +classdef (Abstract) needsAPIKey + % This class is undocumented and will change in a future release + + % Copyright 2023-2024 The MathWorks, Inc. + + properties (Access=protected) + APIKey + end +end diff --git a/+llms/+internal/sendRequest.m b/+llms/+internal/sendRequest.m index 631c2dc..2230808 100644 --- a/+llms/+internal/sendRequest.m +++ b/+llms/+internal/sendRequest.m @@ -1,9 +1,11 @@ function [response, streamedText] = sendRequest(parameters, token, endpoint, timeout, streamFun) +% This function is undocumented and will change in a future release + %sendRequest Sends a request to an ENDPOINT using PARAMETERS and -% api key TOKEN. TIMEOUT is the nubmer of seconds to wait for initial +% api key TOKEN. TIMEOUT is the number of seconds to wait for initial % server connection. STREAMFUN is an optional callback function. -% Copyright 2023 The MathWorks, Inc. +% Copyright 2023-2024 The MathWorks, Inc. arguments parameters @@ -15,17 +17,20 @@ % Define the headers for the API request -headers = [matlab.net.http.HeaderField('Content-Type', 'application/json')... - matlab.net.http.HeaderField('Authorization', "Bearer " + token)]; +headers = matlab.net.http.HeaderField('Content-Type', 'application/json'); +if ~isempty(token) + headers = [headers ... + matlab.net.http.HeaderField('Authorization', "Bearer " + token)... + matlab.net.http.HeaderField('api-key',token)]; +end % Define the request message request = matlab.net.http.RequestMessage('post',headers,parameters); -% Create a HTTPOptions object; +% set the timeout httpOpts = matlab.net.http.HTTPOptions; - -% Set the ConnectTimeout option httpOpts.ConnectTimeout = timeout; +httpOpts.ResponseTimeout = timeout; % Send the request and store the response if isempty(streamFun) diff --git a/+llms/+internal/textGenerator.m b/+llms/+internal/textGenerator.m new file mode 100644 index 0000000..f6cb167 --- /dev/null +++ b/+llms/+internal/textGenerator.m @@ -0,0 +1,31 @@ +classdef (Abstract) textGenerator + % This class is undocumented and will change in a future release + + % Copyright 2023-2024 The MathWorks, Inc. + + properties + %Temperature Temperature of generation. + Temperature {llms.utils.mustBeValidTemperature} = 1 + + %TopP Top probability mass to consider for generation. + TopP {llms.utils.mustBeValidTopP} = 1 + + %StopSequences Sequences to stop the generation of tokens. + StopSequences {llms.utils.mustBeValidStop} = {} + end + + properties (SetAccess=protected) + %TimeOut Connection timeout in seconds (default 10 secs) + TimeOut + + %SystemPrompt System prompt. 
+ SystemPrompt = [] + + %ResponseFormat Response format, "text" or "json" + ResponseFormat + end + + properties (Access=protected) + StreamFun + end +end diff --git a/+llms/+openai/models.m b/+llms/+openai/models.m index 8bee9b2..96f9c17 100644 --- a/+llms/+openai/models.m +++ b/+llms/+openai/models.m @@ -2,7 +2,7 @@ %MODELS - supported OpenAI models % Copyright 2024 The MathWorks, Inc. - models = [... + models = [... "gpt-4o","gpt-4o-2024-05-13",... "gpt-4-turbo","gpt-4-turbo-2024-04-09",... "gpt-4","gpt-4-0613", ... diff --git a/+llms/+stream/responseStreamer.m b/+llms/+stream/responseStreamer.m index 8db9ff3..b13048d 100644 --- a/+llms/+stream/responseStreamer.m +++ b/+llms/+stream/responseStreamer.m @@ -1,4 +1,4 @@ -classdef responseStreamer < matlab.net.http.io.StringConsumer +classdef responseStreamer < matlab.net.http.io.BinaryConsumer %responseStreamer Responsible for obtaining the streaming results from the %API @@ -7,6 +7,7 @@ properties ResponseText StreamFun + Incomplete = "" end methods @@ -20,17 +21,24 @@ if this.Response.StatusCode ~= matlab.net.http.StatusCode.OK length = 0; else - length = this.start@matlab.net.http.io.StringConsumer; + length = this.start@matlab.net.http.io.BinaryConsumer; end end end - + methods function [len,stop] = putData(this, data) - [len,stop] = this.putData@matlab.net.http.io.StringConsumer(data); - + [len,stop] = this.putData@matlab.net.http.io.BinaryConsumer(data); + stop = doPutData(this, data, stop); + end + end + + methods (Access=?tresponseStreamer) + function stop = doPutData(this, data, stop) % Extract out the response text from the message str = native2unicode(data','UTF-8'); + str = this.Incomplete + string(str); + this.Incomplete = ""; str = split(str,newline); str = str(strlength(str)>0); str = erase(str,"data: "); @@ -43,35 +51,54 @@ try json = jsondecode(str{i}); catch ME - errID = 'llms:stream:responseStreamer:InvalidInput'; - msg = "Input does not have the expected json format. " + str{i}; - ME = MException(errID,msg); - throw(ME) + if i == length(str) + this.Incomplete = str{i}; + return; + end + error("llms:stream:responseStreamer:InvalidInput", ... + llms.utils.errorMessageCatalog.getMessage(... + "llms:stream:responseStreamer:InvalidInput", str{i})); end - if ischar(json.choices.finish_reason) && ismember(json.choices.finish_reason,["stop","tool_calls"]) - stop = true; - return - else - if isfield(json.choices.delta,"tool_calls") - if isfield(json.choices.delta.tool_calls,"id") - id = json.choices.delta.tool_calls.id; - type = json.choices.delta.tool_calls.type; - fcn = json.choices.delta.tool_calls.function; - s = struct('id',id,'type',type,'function',fcn); - txt = jsonencode(s); + if isfield(json,'choices') + if isempty(json.choices) + continue; + end + if isfield(json.choices,'finish_reason') && ... + ischar(json.choices.finish_reason) && ismember(json.choices.finish_reason,["stop","tool_calls"]) + stop = true; + return + else + if isfield(json.choices,"delta") && ... 
+ isfield(json.choices.delta,"tool_calls") + if isfield(json.choices.delta.tool_calls,"id") + id = json.choices.delta.tool_calls.id; + type = json.choices.delta.tool_calls.type; + fcn = json.choices.delta.tool_calls.function; + s = struct('id',id,'type',type,'function',fcn); + txt = jsonencode(s); + else + s = jsondecode(this.ResponseText); + args = json.choices.delta.tool_calls.function.arguments; + s.function.arguments = [s.function.arguments args]; + txt = jsonencode(s); + end + this.StreamFun(''); + this.ResponseText = txt; else - s = jsondecode(this.ResponseText); - args = json.choices.delta.tool_calls.function.arguments; - s.function.arguments = [s.function.arguments args]; - txt = jsonencode(s); + txt = json.choices.delta.content; + this.StreamFun(txt); + this.ResponseText = [this.ResponseText txt]; end - this.StreamFun(''); - this.ResponseText = txt; - else - txt = json.choices.delta.content; + end + else + txt = json.message.content; + if strlength(txt) > 0 this.StreamFun(txt); this.ResponseText = [this.ResponseText txt]; end + if isfield(json,"done") + stop = json.done; + end end end end diff --git a/+llms/+utils/errorMessageCatalog.m b/+llms/+utils/errorMessageCatalog.m index f3d57c7..6915120 100644 --- a/+llms/+utils/errorMessageCatalog.m +++ b/+llms/+utils/errorMessageCatalog.m @@ -43,10 +43,12 @@ catalog("llms:assistantMustHaveTextNameAndArguments") = "Fields 'name' and 'arguments' must be text with one or more characters."; catalog("llms:mustBeValidIndex") = "Value is larger than the number of elements in Messages ({1})."; catalog("llms:stopSequencesMustHaveMax4Elements") = "Number of elements must not be larger than 4."; -catalog("llms:keyMustBeSpecified") = "API key not found as environment variable OPENAI_API_KEY and not specified via ApiKey parameter."; +catalog("llms:endpointMustBeSpecified") = "Unable to find endpoint. Either set environment variable AZURE_OPENAI_ENDPOINT or specify name-value argument ""Endpoint""."; +catalog("llms:deploymentMustBeSpecified") = "Unable to find deployment name. Either set environment variable AZURE_OPENAI_DEPLOYMENT or specify name-value argument ""Deployment""."; +catalog("llms:keyMustBeSpecified") = "Unable to find API key. 
Either set environment variable {1} or specify name-value argument ""APIKey"".";
catalog("llms:mustHaveMessages") = "Value must contain at least one message in Messages."; catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, ToolChoice must not be specified."; -catalog("llms:mustBeMessagesOrTxt") = "Messages must be text with one or more characters or an openAIMessages objects."; +catalog("llms:mustBeMessagesOrTxt") = "Messages must be text with one or more characters or a messageHistory object."; catalog("llms:invalidOptionAndValueForModel") = "'{1}' with value '{2}' is not supported for ModelName '{3}'"; catalog("llms:invalidOptionForModel") = "{1} is not supported for ModelName '{2}'"; catalog("llms:invalidContentTypeForModel") = "{1} is not supported for ModelName '{2}'"; @@ -54,6 +56,7 @@ catalog("llms:promptLimitCharacter") = "Prompt must have a maximum length of {1} characters for ModelName '{2}'"; catalog("llms:pngExpected") = "Argument must be a PNG image."; catalog("llms:warningJsonInstruction") = "When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message."; -catalog("llms:apiReturnedError") = "OpenAI API Error: {1}"; +catalog("llms:apiReturnedError") = "Server error: ""{1}"""; catalog("llms:dimensionsMustBeSmallerThan") = "Dimensions must be less than or equal to {1}."; +catalog("llms:stream:responseStreamer:InvalidInput") = "Input does not have the expected json format, got ""{1}""."; end diff --git a/+llms/+utils/isUnique.m b/+llms/+utils/isUnique.m index cdd7960..423c709 100644 --- a/+llms/+utils/isUnique.m +++ b/+llms/+utils/isUnique.m @@ -6,4 +6,3 @@ % Copyright 2023 The MathWorks, Inc. tf = numel(values)==numel(unique(values)); end - diff --git a/+llms/+utils/mustBeNonzeroLengthTextScalar.m b/+llms/+utils/mustBeNonzeroLengthTextScalar.m index 5fec18e..33f6cb7 100644 --- a/+llms/+utils/mustBeNonzeroLengthTextScalar.m +++ b/+llms/+utils/mustBeNonzeroLengthTextScalar.m @@ -1,4 +1,12 @@ function mustBeNonzeroLengthTextScalar(content) +% This function is undocumented and will change in a future release + +% Simple function to check that the value is a nonzero-length text scalar + +% Copyright 2024 The MathWorks, Inc. mustBeNonzeroLengthText(content) +if iscellstr(content) + content = string(content); +end mustBeTextScalar(content) -end \ No newline at end of file +end diff --git a/+llms/+utils/mustBeTextOrEmpty.m b/+llms/+utils/mustBeTextOrEmpty.m index 766007a..f3e2c8a 100644 --- a/+llms/+utils/mustBeTextOrEmpty.m +++ b/+llms/+utils/mustBeTextOrEmpty.m @@ -7,4 +7,4 @@ function mustBeTextOrEmpty(value) if ~isempty(value) mustBeTextScalar(value) end -end \ No newline at end of file +end diff --git a/+llms/+utils/mustBeValidPenalty.m b/+llms/+utils/mustBeValidPenalty.m new file mode 100644 index 0000000..f18cd40 --- /dev/null +++ b/+llms/+utils/mustBeValidPenalty.m @@ -0,0 +1,6 @@ +function mustBeValidPenalty(value) +% This function is undocumented and will change in a future release + +% Copyright 2024 The MathWorks, Inc. + validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonsparse', '<=', 2, '>=', -2}) +end diff --git a/+llms/+utils/mustBeValidStop.m b/+llms/+utils/mustBeValidStop.m new file mode 100644 index 0000000..f3862c7 --- /dev/null +++ b/+llms/+utils/mustBeValidStop.m @@ -0,0 +1,13 @@ +function mustBeValidStop(value) +% This function is undocumented and will change in a future release + +% Copyright 2024 The MathWorks, Inc.
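+% Accepts, for example, "END" or ["first" "second"]; an empty value passes,
+% and more than 4 sequences errors below (an OpenAI API limit).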
+ if ~isempty(value) + mustBeVector(value); + mustBeNonzeroLengthText(value); + % This restriction is set by the OpenAI API + if numel(value)>4 + error("llms:stopSequencesMustHaveMax4Elements", llms.utils.errorMessageCatalog.getMessage("llms:stopSequencesMustHaveMax4Elements")); + end + end +end diff --git a/+llms/+utils/mustBeValidTemperature.m b/+llms/+utils/mustBeValidTemperature.m new file mode 100644 index 0000000..976370f --- /dev/null +++ b/+llms/+utils/mustBeValidTemperature.m @@ -0,0 +1,6 @@ +function mustBeValidTemperature(value) +% This function is undocumented and will change in a future release + +% Copyright 2024 The MathWorks, Inc. + validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 2}) +end diff --git a/+llms/+utils/mustBeValidTopP.m b/+llms/+utils/mustBeValidTopP.m new file mode 100644 index 0000000..ed2bbd6 --- /dev/null +++ b/+llms/+utils/mustBeValidTopP.m @@ -0,0 +1,6 @@ +function mustBeValidTopP(value) +% This function is undocumented and will change in a future release + +% Copyright 2024 The MathWorks, Inc. + validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 1}) +end diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 83e4508..e5f5348 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,6 +7,26 @@ jobs: steps: - name: Check out repository uses: actions/checkout@v4 + - name: Install Ollama + run: | + curl -fsSL https://ollama.com/install.sh | sudo -E sh + - name: Start serving + run: | + # Run in the background; there is no way to daemonise at the moment + ollama serve & + + # A short pause is required before the HTTP port is opened + sleep 5 + + # This endpoint blocks until ready + time curl -i http://localhost:11434 + + # For debugging, record Ollama version + ollama --version + + - name: Pull mistral model + run: | + ollama pull mistral - name: Set up MATLAB uses: matlab-actions/setup-matlab@v2 with: @@ -15,6 +35,10 @@ - name: Run tests and generate artifacts env: OPENAI_KEY: ${{ secrets.OPENAI_KEY }} + OPENAI_API_KEY: ${{ secrets.OPENAI_KEY }} + AZURE_OPENAI_DEPLOYMENT: ${{ secrets.AZURE_DEPLOYMENT }} + AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_ENDPOINT }} + AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_KEY }} uses: matlab-actions/run-tests@v2 with: test-results-junit: test-results/results.xml diff --git a/.gitignore b/.gitignore index 15e5229..8068a93 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ startup.m papers_to_read.csv data/* +examples/data/* +._* diff --git a/README.md b/README.md index 678961a..ec29399 100644 --- a/README.md +++ b/README.md @@ -2,18 +2,7 @@ [![Open in MATLAB Online](https://www.mathworks.com/images/responsive/global/open-in-matlab-online.svg)](https://matlab.mathworks.com/open/github/v1?repo=matlab-deep-learning/llms-with-matlab) [![View Large Language Models (LLMs) with MATLAB on File Exchange](https://www.mathworks.com/matlabcentral/images/matlab-file-exchange.svg)](https://www.mathworks.com/matlabcentral/fileexchange/163796-large-language-models-llms-with-matlab) -This repository contains example code to demonstrate how to connect MATLAB to the OpenAI™ Chat Completions API (which powers ChatGPT™) as well as OpenAI Images API (which powers DALL·E™). This allows you to leverage the natural language processing capabilities of large language models directly within your MATLAB environment. -The functionality shown here serves as an interface to the ChatGPT and DALL·E APIs.
To start using the OpenAI APIs, you first need to obtain OpenAI API keys. You are responsible for any fees OpenAI may charge for the use of their APIs. You should be familiar with the limitations and risks associated with using this technology, and you agree that you shall be solely responsible for full compliance with any terms that may apply to your use of the OpenAI APIs. - -Some of the current LLMs supported are: -- gpt-3.5-turbo, gpt-3.5-turbo-1106, gpt-3.5-turbo-0125 -- gpt-4o, gpt-4o-2024-05-13 (GPT-4 Omni) -- gpt-4-turbo, gpt-4-turbo-2024-04-09 (GPT-4 Turbo with Vision) -- gpt-4, gpt-4-0613 -- dall-e-2, dall-e-3 - -For details on the specification of each model, check the official [OpenAI documentation](https://platform.openai.com/docs/models). +This repository contains code to connect MATLAB to the [OpenAI™ Chat Completions API](https://platform.openai.com/docs/guides/text-generation/chat-completions-api) (which powers ChatGPT™), OpenAI Images API (which powers DALL·E™), [Azure® OpenAI Service](https://learn.microsoft.com/en-us/azure/ai-services/openai/), and local [Ollama®](https://ollama.com/) models. This allows you to leverage the natural language processing capabilities of large language models directly within your MATLAB environment. ## Requirements @@ -24,10 +13,19 @@ For details on the specification of each model, check the official [OpenAI docum ### 3rd Party Products: -- An active OpenAI API subscription and API key. +- For OpenAI connections: An active OpenAI API subscription and API key. +- For Azure OpenAI Services: An active Azure subscription with OpenAI access, deployment, and API key. +- For Ollama: A local Ollama installation. Currently, only connections on `localhost` are supported, i.e., Ollama and MATLAB must run on the same machine. ## Setup +See these pages for instructions specific to the 3rd party product selected: + +* [OpenAI](doc/OpenAI.md) +* [Azure](doc/Azure.md) +* [Ollama](doc/Ollama.md) + + ### MATLAB Online To use this repository with MATLAB Online, click [![Open in MATLAB Online](https://www.mathworks.com/images/responsive/global/open-in-matlab-online.svg)](https://matlab.mathworks.com/open/github/v1?repo=matlab-deep-learning/llms-with-matlab) @@ -51,284 +49,6 @@ To use this repository with a local installation of MATLAB, first clone the repo addpath('path/to/llms-with-matlab'); ``` -### Setting up your API key - -Set up your OpenAI API key. Create a `.env` file in the project root directory with the following content. - -``` -OPENAI_API_KEY= -``` - -Then load your `.env` file as follows: - -```matlab -loadenv(".env") -``` - -## Getting Started with Chat Completion API - -To get started, you can either create an `openAIChat` object and use its methods or use it in a more complex setup, as needed. - -### Simple call without preserving chat history - -In some situations, you will want to use chat completion models without preserving chat history. For example, when you want to perform independent queries in a programmatic way. - -Here's a simple example of how to use the `openAIChat` for sentiment analysis: - -```matlab -% Initialize the OpenAI Chat object, passing a system prompt - -% The system prompt tells the assistant how to behave, in this case, as a sentiment analyzer -systemPrompt = "You are a sentiment analyser. You will look at a sentence and output"+... - " a single word that classifies that sentence as either 'positive' or 'negative'."+.... - "Examples: \n"+... - "The project was a complete failure. \n"+... - "negative \n\n"+... 
- "The team successfully completed the project ahead of schedule."+... - "positive \n\n"+... - "His attitude was terribly discouraging to the team. \n"+... - "negative \n\n"; - -chat = openAIChat(systemPrompt); - -% Generate a response, passing a new sentence for classification -txt = generate(chat,"The team is feeling very motivated") -% Should output "positive" -``` - -### Creating a chat system - -If you want to create a chat system, you will have to create a history of the conversation and pass that to the `generate` function. - -To start a conversation history, create a `openAIMessages` object: - -```matlab -history = openAIMessages; -``` - -Then create the chat assistant: - -```matlab -chat = openAIChat("You are a helpful AI assistant."); -``` - -Add a user message to the history and pass it to `generate` - -```matlab -history = addUserMessage(history,"What is an eigenvalue?"); -[txt, response] = generate(chat, history) -``` - -The output `txt` will contain the answer and `response` will contain the full response, which you need to include in the history as follows -```matlab -history = addResponseMessage(history, response); -``` - -You can keep interacting with the API and since we are saving the history, it will know about previous interactions. -```matlab -history = addUserMessage(history,"Generate MATLAB code that computes that"); -[txt, response] = generate(chat,history); -% Will generate code to compute the eigenvalue -``` - -### Streaming the response - -Streaming allows you to start receiving the output from the API as it is generated token by token, rather than wait for the entire completion to be generated. You can specifying the streaming function when you create the chat assistant. In this example, the streaming function will print the response to the command window. -```matlab -% streaming function -sf = @(x)fprintf("%s",x); -chat = openAIChat(StreamFun=sf); -txt = generate(chat,"What is Model-Based Design and how is it related to Digital Twin?") -% Should stream the response token by token -``` - -### Calling MATLAB functions with the API - -Optionally, `Tools=functions` can be used to provide function specifications to the API. The purpose of this is to enable models to generate function arguments which adhere to the provided specifications. -Note that the API is not able to directly call any function, so you should call the function and pass the values to the API directly. This process can be automated as shown in [AnalyzeScientificPapersUsingFunctionCalls.mlx](/examples/AnalyzeScientificPapersUsingFunctionCalls.mlx), but it's important to consider that ChatGPT can hallucinate function names, so avoid executing any arbitrary generated functions and only allow the execution of functions that you have defined. 
- -For example, if you want to use the API for mathematical operations such as `sind`, instead of letting the model generate the result and risk running into hallucinations, you can give the model direct access to the function as follows: - - -```matlab -f = openAIFunction("sind","Sine of argument in degrees"); -f = addParameter(f,"x",type="number",description="Angle in degrees."); -chat = openAIChat("You are a helpful assistant.",Tools=f); -``` - -When the model identifies that it could use the defined functions to answer a query, it will return a `tool_calls` request, instead of directly generating the response: - -```matlab -messages = openAIMessages; -messages = addUserMessage(messages, "What is the sine of 30?"); -[txt, response] = generate(chat, messages); -messages = addResponseMessage(messages, response); -``` - -The variable `response` should contain a request for a function call. -```bash ->> response - -response = - - struct with fields: - - role: 'assistant' - content: [] - tool_calls: [1×1 struct] - ->> response.tool_calls - -ans = - - struct with fields: - - id: 'call_wDpCLqtLhXiuRpKFw71gXzdy' - type: 'function' - function: [1×1 struct] - ->> response.tool_calls.function - -ans = - - struct with fields: - - name: 'sind' - arguments: '{↵ "x": 30↵}' -``` - -You can then call the function `sind` with the specified argument and return the value to the API add a function message to the history: - -```matlab -% Arguments are returned as json, so you need to decode it first -id = string(response.tool_calls.id); -func = string(response.tool_calls.function.name); -if func == "sind" - args = jsondecode(response.tool_calls.function.arguments); - result = sind(args.x); - messages = addToolMessage(messages,id,func,"x="+result); - [txt, response] = generate(chat, messages); -else - % handle calls to unknown functions -end -``` - -The model then will use the function result to generate a more precise response: - -```shell ->> txt - -txt = - - "The sine of 30 degrees is approximately 0.5." -``` - -### Extracting structured information with the API - -Another useful application for defining functions is extract structured information from some text. You can just pass a function with the output format that you would like the model to output and the information you want to extract. For example, consider the following piece of text: - -```matlab -patientReport = "Patient John Doe, a 45-year-old male, presented " + ... - "with a two-week history of persistent cough and fatigue. " + ... - "Chest X-ray revealed an abnormal shadow in the right lung." + ... - " A CT scan confirmed a 3cm mass in the right upper lobe," + ... - " suggestive of lung cancer. The patient has been referred " + ... - "for biopsy to confirm the diagnosis."; -``` - -If you want to extract information from this text, you can define a function as follows: -```matlab -f = openAIFunction("extractPatientData","Extracts data about a patient from a record"); -f = addParameter(f,"patientName",type="string",description="Name of the patient"); -f = addParameter(f,"patientAge",type="number",description="Age of the patient"); -f = addParameter(f,"patientSymptoms",type="string",description="Symptoms that the patient is having."); -``` - -Note that this function does not need to exist, since it will only be used to extract the Name, Age and Symptoms of the patient and it does not need to be called: - -```matlab -chat = openAIChat("You are helpful assistant that reads patient records and extracts information", ... 
- Tools=f); -messages = openAIMessages; -messages = addUserMessage(messages,"Extract the information from the report:" + newline + patientReport); -[txt, response] = generate(chat, messages); -``` - -The model should return the extracted information as a function call: -```shell ->> response - -response = - - struct with fields: - - role: 'assistant' - content: [] - tool_call: [1×1 struct] - ->> response.tool_calls - -ans = - - struct with fields: - - id: 'call_4VRtN7jb3pTPosMSb4ZaLoWP' - type: 'function' - function: [1×1 struct] - ->> response.tool_calls.function - -ans = - - struct with fields: - - name: 'extractPatientData' - arguments: '{↵ "patientName": "John Doe",↵ "patientAge": 45,↵ "patientSymptoms": "persistent cough, fatigue"↵}' -``` - -You can extract the arguments and write the data to a table, for example. - -### Understand the content of an image - -You can use gpt-4-turbo to experiment with image understanding. -```matlab -chat = openAIChat("You are an AI assistant.", ModelName="gpt-4-turbo"); -image_path = "peppers.png"; -messages = openAIMessages; -messages = addUserMessageWithImages(messages,"What is in the image?",image_path); -[txt,response] = generate(chat,messages,MaxNumTokens=4096); -% Should output the description of the image -``` - -### Obtaining embeddings - -You can extract embeddings from your text with OpenAI using the function `extractOpenAIEmbeddings` as follows: -```matlab -exampleText = "Here is an example!"; -emb = extractOpenAIEmbeddings(exampleText); -``` - -The resulting embedding is a vector that captures the semantics of your text and can be used on tasks such as retrieval augmented generation and clustering. - -```matlab ->> size(emb) - -ans = - - 1 1536 -``` -## Getting Started with Images API - -To get started, you can either create an `openAIImages` object and use its methods or use it in a more complex setup, as needed. - -```matlab -mdl = openAIImages(ModelName="dall-e-3"); -images = generate(mdl,"Create a 3D avatar of a whimsical sushi on the beach. He is decorated with various sushi elements and is playfully interacting with the beach environment."); -figure -imshow(images{1}) -% Should output an image based on the prompt -``` - ## Examples To learn how to use this in your workflows, see [Examples](/examples/). @@ -345,7 +65,7 @@ To learn how to use this in your workflows, see [Examples](/examples/). ## License -The license is available in the license.txt file in this GitHub repository. +The license is available in the [license.txt](license.txt) file in this GitHub repository. ## Community Support [MATLAB Central](https://www.mathworks.com/matlabcentral) diff --git a/azureChat.m b/azureChat.m new file mode 100644 index 0000000..fccf7da --- /dev/null +++ b/azureChat.m @@ -0,0 +1,320 @@ +classdef(Sealed) azureChat < llms.internal.textGenerator & ... + llms.internal.gptPenalties & llms.internal.hasTools & llms.internal.needsAPIKey +%azureChat Chat completion API from Azure. +% +% CHAT = azureChat creates an azureChat object, with the parameters needed +% to connect to Azure taken from the environment. +% +% CHAT = azureChat(systemPrompt) creates an azureChat object with the +% specified system prompt. +% +% CHAT = azureChat(__,Name=Value) specifies additional options +% using one or more name-value arguments: +% +% Endpoint - The endpoint as defined in the Azure OpenAI Services +% interface. Needs to be specified or stored in the +% environment variable AZURE_OPENAI_ENDPOINT. 
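+%                    Example (hypothetical): "https://my-resource.openai.azure.com/"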
+% +% Deployment - The deployment as defined in the Azure OpenAI Services +% interface. Needs to be specified or stored in the +% environment variable AZURE_OPENAI_DEPLOYMENT. +% +% APIKey - The API key for accessing the Azure OpenAI Chat API. +% Needs to be specified or stored in the +% environment variable AZURE_OPENAI_API_KEY. +% +% Temperature - Temperature value for controlling the randomness +% of the output. Default value is 1; higher values +% increase the randomness (in some sense, +% the “creativity”) of outputs, lower values +% reduce it. Setting Temperature=0 removes +% randomness from the output altogether. +% +% TopP - Top probability mass value for controlling the +% diversity of the output. Default value is 1; +% lower values imply that only the more likely +% words can appear in any particular place. +% This is also known as top-p sampling. +% +% StopSequences - Vector of strings that when encountered, will +% stop the generation of tokens. Default +% value is empty. +% Example: ["The end.", "And that's all she wrote."] +% +% ResponseFormat - The format of response the model returns. +% "text" (default) | "json" +% +% PresencePenalty - Penalty value for using a token in the response +% that has already been used. Default value is 0. +% Higher values reduce repetition of words in the output. +% +% FrequencyPenalty - Penalty value for using a token that is frequent +% in the output. Default value is 0. +% Higher values reduce repetition of words in the output. +% +% StreamFun - Function to call back when streaming the result +% +% TimeOut - Connection Timeout in seconds. Default value is 10. +% +% Tools - A list of tools the model can call. +% +% APIVersion - The API version to use for this model. +% "2024-02-01" (default) | "2023-05-15" | "2024-05-01-preview" | ... +% "2024-04-01-preview" | "2024-03-01-preview" +% +% +% +% azureChat Functions: +% azureChat - Chat completion API from Azure. +% generate - Generate a response using the azureChat instance. +% +% azureChat Properties: +% Temperature - Temperature of generation. +% +% TopP - Top probability mass to consider for generation. +% +% StopSequences - Sequences to stop the generation of tokens. +% +% PresencePenalty - Penalty for using a token in the +% response that has already been used. +% +% FrequencyPenalty - Penalty for using a token that is +% frequent in the training data. +% +% SystemPrompt - System prompt. +% +% FunctionNames - Names of the functions that the model can +% request to call. +% +% ResponseFormat - Specifies the response format, "text" or "json". +% +% TimeOut - Connection Timeout in seconds. +% % Copyright 2023-2024 The MathWorks, Inc.
+ + properties(SetAccess=private) + Endpoint (1,1) string + DeploymentID (1,1) string + APIVersion (1,1) string + end + + methods + function this = azureChat(systemPrompt, nvp) + arguments + systemPrompt {llms.utils.mustBeTextOrEmpty} = [] + nvp.Endpoint {mustBeNonzeroLengthTextScalar} + nvp.Deployment {mustBeNonzeroLengthTextScalar} + nvp.APIKey {mustBeNonzeroLengthTextScalar} + nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty + nvp.APIVersion (1,1) {mustBeAPIVersion} = "2024-02-01" + nvp.Temperature {llms.utils.mustBeValidTemperature} = 1 + nvp.TopP {llms.utils.mustBeValidTopP} = 1 + nvp.StopSequences {llms.utils.mustBeValidStop} = {} + nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text" + nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = 0 + nvp.FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0 + nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10 + nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')} + end + + if isfield(nvp,"StreamFun") + this.StreamFun = nvp.StreamFun; + else + this.StreamFun = []; + end + + if isempty(nvp.Tools) + this.Tools = []; + this.FunctionsStruct = []; + this.FunctionNames = []; + else + this.Tools = nvp.Tools; + [this.FunctionsStruct, this.FunctionNames] = functionAsStruct(nvp.Tools); + end + + if ~isempty(systemPrompt) + systemPrompt = string(systemPrompt); + if ~(strlength(systemPrompt)==0) + this.SystemPrompt = {struct("role", "system", "content", systemPrompt)}; + end + end + + this.Endpoint = getEndpoint(nvp); + this.DeploymentID = getDeployment(nvp); + this.APIKey = llms.internal.getApiKeyFromNvpOrEnv(nvp,"AZURE_OPENAI_API_KEY"); + this.APIVersion = nvp.APIVersion; + this.ResponseFormat = nvp.ResponseFormat; + this.Temperature = nvp.Temperature; + this.TopP = nvp.TopP; + this.StopSequences = nvp.StopSequences; + this.PresencePenalty = nvp.PresencePenalty; + this.FrequencyPenalty = nvp.FrequencyPenalty; + this.TimeOut = nvp.TimeOut; + end + + function [text, message, response] = generate(this, messages, nvp) + %generate Generate a response using the azureChat instance. + % + % [TEXT, MESSAGE, RESPONSE] = generate(CHAT, MESSAGES) generates a response + % with the specified MESSAGES. + % + % [TEXT, MESSAGE, RESPONSE] = generate(__, Name=Value) specifies additional options + % using one or more name-value arguments: + % + % NumCompletions - Number of completions to generate. + % Default value is 1. + % + % MaxNumTokens - Maximum number of tokens in the generated response. + % Default value is inf. + % + % ToolChoice - Function to execute. 'none', 'auto', + % or specify the function to call. 
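+ % Example (hypothetical): ToolChoice="sind" asks the model to call the sind tool.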
+ % + % Seed - An integer value to use to obtain + % reproducible responses + % + % Currently, GPT-4 Turbo with vision does not support the message.name + % parameter, functions/tools, response_format parameter, stop + % sequences, and max_tokens + + arguments + this (1,1) azureChat + messages {mustBeValidMsgs} + nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1 + nvp.MaxNumTokens (1,1) {mustBePositive} = inf + nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = [] + nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = [] + end + + messages = convertCharsToStrings(messages); + if isstring(messages) && isscalar(messages) + messagesStruct = {struct("role", "user", "content", messages)}; + else + messagesStruct = messages.Messages; + end + + if ~isempty(this.SystemPrompt) + messagesStruct = horzcat(this.SystemPrompt, messagesStruct); + end + + toolChoice = convertToolChoice(this, nvp.ToolChoice); + try + [text, message, response] = llms.internal.callAzureChatAPI(this.Endpoint, ... + this.DeploymentID, messagesStruct, this.FunctionsStruct, ... + ToolChoice=toolChoice, APIVersion = this.APIVersion, Temperature=this.Temperature, ... + TopP=this.TopP, NumCompletions=nvp.NumCompletions,... + StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ... + PresencePenalty=this.PresencePenalty, FrequencyPenalty=this.FrequencyPenalty, ... + ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ... + APIKey=this.APIKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun); + catch ME + if ismember(ME.identifier,... + ["MATLAB:webservices:UnknownHost","MATLAB:webservices:Timeout"]) + % throw(ME)would still print a long stack trace, from + % ME.cause.stack. We cannot change ME.cause, so we + % throw a new error: + error(ME.identifier,ME.message); + end + rethrow(ME); + end + + if isfield(response.Body.Data,"error") + err = response.Body.Data.error.message; + error("llms:apiReturnedError",llms.utils.errorMessageCatalog.getMessage("llms:apiReturnedError",err)); + end + end + end + + methods(Hidden) + function mustBeValidFunctionCall(this, functionCall) + if ~isempty(functionCall) + mustBeTextScalar(functionCall); + if isempty(this.FunctionNames) + error("llms:mustSetFunctionsForCall", llms.utils.errorMessageCatalog.getMessage("llms:mustSetFunctionsForCall")); + end + mustBeMember(functionCall, ["none","auto", this.FunctionNames]); + end + end + + function toolChoice = convertToolChoice(this, toolChoice) + % if toolChoice is empty + if isempty(toolChoice) + % if Tools is not empty, the default is 'auto'. + if ~isempty(this.Tools) + toolChoice = "auto"; + end + elseif toolChoice ~= "auto" + % if toolChoice is not empty, then it must be in the format + % {"type": "function", "function": {"name": "my_function"}} + toolChoice = struct("type","function","function",struct("name",toolChoice)); + end + + end + end +end + +function mustBeNonzeroLengthTextScalar(content) +mustBeNonzeroLengthText(content) +mustBeTextScalar(content) +end + +function [functionsStruct, functionNames] = functionAsStruct(functions) +numFunctions = numel(functions); +functionsStruct = cell(1, numFunctions); +functionNames = strings(1, numFunctions); + +for i = 1:numFunctions + functionsStruct{i} = struct('type','function', ... 
+ 'function',encodeStruct(functions(i))) ; + functionNames(i) = functions(i).FunctionName; +end +end + +function mustBeValidMsgs(value) +if isa(value, "messageHistory") + if numel(value.Messages) == 0 + error("llms:mustHaveMessages", llms.utils.errorMessageCatalog.getMessage("llms:mustHaveMessages")); + end +else + try + llms.utils.mustBeNonzeroLengthTextScalar(value); + catch ME + error("llms:mustBeMessagesOrTxt", llms.utils.errorMessageCatalog.getMessage("llms:mustBeMessagesOrTxt")); + end +end +end + +function mustBeIntegerOrEmpty(value) + if ~isempty(value) + mustBeInteger(value) + end +end + +function mustBeAPIVersion(model) + mustBeMember(model,llms.azure.apiVersions); +end + +function endpoint = getEndpoint(nvp) + if isfield(nvp, "Endpoint") + endpoint = nvp.Endpoint; + else + if isenv("AZURE_OPENAI_ENDPOINT") + endpoint = getenv("AZURE_OPENAI_ENDPOINT"); + else + error("llms:endpointMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:endpointMustBeSpecified")); + end + end +end + +function deployment = getDeployment(nvp) + if isfield(nvp, "Deployment") + deployment = nvp.Deployment; + else + if isenv("AZURE_OPENAI_DEPLOYMENT") + deployment = getenv("AZURE_OPENAI_DEPLOYMENT"); + else + error("llms:deploymentMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:deploymentMustBeSpecified")); + end + end +end diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..570b6df --- /dev/null +++ b/codecov.yml @@ -0,0 +1,3 @@ +ignore: + - "tests" + - "**/errorMessageCatalog.m" diff --git a/doc/Azure.md b/doc/Azure.md new file mode 100644 index 0000000..d5af221 --- /dev/null +++ b/doc/Azure.md @@ -0,0 +1,261 @@ +# Connecting to Azure® OpenAI Service + +This repository contains code to connect MATLAB to the [Azure® OpenAI Service](https://learn.microsoft.com/en-us/azure/ai-services/openai/). + +To use Azure OpenAI Services, you need to create a model deployment on your Azure account and obtain one of the keys for it. You are responsible for any fees Azure may charge for the use of their APIs. You should be familiar with the limitations and risks associated with using this technology, and you agree that you shall be solely responsible for full compliance with any terms that may apply to your use of the Azure APIs. + +Some of the [current LLMs supported on Azure](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models) are: +- GPT-4o (GPT-4 Omni) +- GPT-4 Turbo +- GPT-4 +- GPT-3.5 + + +## Setting up your Azure OpenAI Services API key + +Set up your [endpoint and deployment and retrieve one of the API keys](https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart?tabs=command-line%2Cpython-new&pivots=rest-api#retrieve-key-and-endpoint). Create a `.env` file in the project root directory with the following content. + +``` +AZURE_OPENAI_ENDPOINT= +AZURE_OPENAI_DEPLOYMENT= +AZURE_OPENAI_API_KEY= +``` + +You can use either `KEY1` or `KEY2` from the Azure configuration website. + +Then load your `.env` file as follows: + +```matlab +loadenv(".env") +``` + +## Establishing a connection to Chat Completions API using Azure + +To connect MATLAB to Chat Completions API via Azure, you will have to create an `azureChat` object. See [the Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart) for details on the setup required and where to find your key, endpoint, and deployment name. 
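+If you prefer not to rely on environment variables, you can pass the connection details explicitly; the following is a sketch, with placeholder resource and deployment names:
+
+```matlab
+chat = azureChat("You are a helpful AI assistant", ...
+    Endpoint="https://my-resource.openai.azure.com/", ...
+    Deployment="my-gpt4o-deployment", ...
+    APIKey=getenv("AZURE_OPENAI_API_KEY"));
+```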
As explained above, the endpoint, deployment, and key should be in the environment variables `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT`, and `AZURE_OPENAI_API_KEY`, or provided as `Endpoint=…`, `Deployment=…`, and `APIKey=…` in the `azureChat` call, as sketched above. + +In order to create the chat assistant, use the `azureChat` function, optionally providing a system prompt: +```matlab +chat = azureChat("You are a helpful AI assistant"); +``` + +The `azureChat` object also allows you to specify additional options. Call `help azureChat` for more information. +Compared to `openAIChat`, the `ModelName` option is not available, because the LLM is already determined by the deployment specified when creating the chat assistant. + +## Simple call without preserving chat history + +In some situations, you will want to use chat completion models without preserving chat history. For example, when you want to perform independent queries in a programmatic way. + +Here's a simple example of how to use `azureChat` for sentiment analysis, initialized with a few-shot prompt: + +```matlab +% Initialize the Azure Chat object, passing a system prompt + +% The system prompt tells the assistant how to behave, in this case, as a sentiment analyzer +systemPrompt = "You are a sentiment analyser. You will look at a sentence and output"+... + " a single word that classifies that sentence as either 'positive' or 'negative'."+.... + newline + ... + "Examples:" + newline +... + "The project was a complete failure." + newline +... + "negative" + newline + newline +... + "The team successfully completed the project ahead of schedule." + newline +... + "positive" + newline + newline +... + "His attitude was terribly discouraging to the team." + newline +... + "negative" + newline + newline; + +chat = azureChat(systemPrompt); + +% Generate a response, passing a new sentence for classification +txt = generate(chat,"The team is feeling very motivated") +% Should output "positive" +``` + +## Creating a chat system + +If you want to create a chat system, you will have to create a history of the conversation and pass that to the `generate` function. + +To start a conversation history, create a `messageHistory` object: + +```matlab +history = messageHistory; +``` + +Then create the chat assistant: + +```matlab +chat = azureChat; +``` + +Add a user message to the history and pass it to `generate`: + +```matlab +history = addUserMessage(history,"What is an eigenvalue?"); +[txt, response] = generate(chat, history) +``` + +The output `txt` will contain the answer and `response` will contain the full response, which you need to include in the history as follows: +```matlab +history = addResponseMessage(history, response); +``` + +You can keep interacting with the API, and since we are saving the history, it will know about previous interactions. +```matlab +history = addUserMessage(history,"Generate MATLAB code that computes that"); +[txt, response] = generate(chat,history); +% Will generate code to compute the eigenvalue +``` + +## Streaming the response + +Streaming allows you to start receiving the output from the API as it is generated token by token, rather than waiting for the entire completion to be generated. You can specify the streaming function when you create the chat assistant. In this example, the streaming function will print the response to the command window.
+```matlab +% streaming function +sf = @(x) fprintf("%s",x); +chat = azureChat(StreamFun=sf); +txt = generate(chat,"What is Model-Based Design and how is it related to Digital Twin?") +% Should stream the response token by token +``` + +## Calling MATLAB functions with the API + +Optionally, `Tools=functions` can be used to provide function specifications to the API. The purpose of this is to enable models to generate function arguments which adhere to the provided specifications. +Note that the API is not able to directly call any function, so you should call the function and pass the values to the API directly. This process can be automated as shown in [AnalyzeScientificPapersUsingFunctionCalls.mlx](/examples/AnalyzeScientificPapersUsingFunctionCalls.mlx), but it's important to consider that ChatGPT can hallucinate function names, so avoid executing any arbitrary generated functions and only allow the execution of functions that you have defined. + +For example, if you want to use the API for mathematical operations such as `sind`, instead of letting the model generate the result and risk running into hallucinations, you can give the model direct access to the function as follows: + +```matlab +f = openAIFunction("sind","Sine of argument in degrees"); +f = addParameter(f,"x",type="number",description="Angle in degrees."); +chat = azureChat("You are a helpful assistant.",Tools=f); +``` + +When the model identifies that it could use the defined functions to answer a query, it will return a `tool_calls` request, instead of directly generating the response: + +```matlab +messages = messageHistory; +messages = addUserMessage(messages, "What is the sine of 30?"); +[txt, response] = generate(chat, messages); +messages = addResponseMessage(messages, response); +``` + +The variable `response` should contain a request for a function call. +```bash +>> response + +response = + + struct with fields: + + role: 'assistant' + content: [] + tool_calls: [1×1 struct] + +>> response.tool_calls + +ans = + + struct with fields: + + id: 'call_wDpCLqtLhXiuRpKFw71gXzdy' + type: 'function' + function: [1×1 struct] + +>> response.tool_calls.function + +ans = + + struct with fields: + + name: 'sind' + arguments: '{↵ "x": 30↵}' +``` + +You can then call the function `sind` with the specified argument and return the value to the API by adding a function message to the history: + +```matlab +% Arguments are returned as json, so you need to decode it first +id = string(response.tool_calls.id); +func = string(response.tool_calls.function.name); +if func == "sind" + args = jsondecode(response.tool_calls.function.arguments); + result = sind(args.x); + messages = addToolMessage(messages,id,func,"x="+result); + [txt, response] = generate(chat, messages); +else + % handle calls to unknown functions +end +``` + +The model will then use the function result to generate a more precise response: + +```shell +>> txt + +txt = + + "The sine of 30 degrees is approximately 0.5." +``` + +## Extracting structured information with the API + +Another useful application for defining functions is to extract structured information from some text. You can just pass a function with the output format that you would like the model to produce and the information you want to extract. For example, consider the following piece of text: + +```matlab +patientReport = "Patient John Doe, a 45-year-old male, presented " + ... + "with a two-week history of persistent cough and fatigue. " + ... + "Chest X-ray revealed an abnormal shadow in the right lung." + ...
+ " A CT scan confirmed a 3cm mass in the right upper lobe," + ... + " suggestive of lung cancer. The patient has been referred " + ... + "for biopsy to confirm the diagnosis."; +``` + +If you want to extract information from this text, you can define a function as follows: +```matlab +f = openAIFunction("extractPatientData","Extracts data about a patient from a record"); +f = addParameter(f,"patientName",type="string",description="Name of the patient"); +f = addParameter(f,"patientAge",type="number",description="Age of the patient"); +f = addParameter(f,"patientSymptoms",type="string",description="Symptoms that the patient is having."); +``` + +Note that this function does not need to exist, since it will only be used to extract the Name, Age and Symptoms of the patient and it does not need to be called: + +```matlab +chat = azureChat("You are helpful assistant that reads patient records and extracts information", ... + Tools=f); +messages = messageHistory; +messages = addUserMessage(messages,"Extract the information from the report:" + newline + patientReport); +[txt, response] = generate(chat, messages); +``` + +The model should return the extracted information as a function call: +```shell +>> response + +response = + + struct with fields: + + role: 'assistant' + content: [] + tool_call: [1×1 struct] + +>> response.tool_calls + +ans = + + struct with fields: + + id: 'call_4VRtN7jb3pTPosMSb4ZaLoWP' + type: 'function' + function: [1×1 struct] + +>> response.tool_calls.function + +ans = + + struct with fields: + + name: 'extractPatientData' + arguments: '{↵ "patientName": "John Doe",↵ "patientAge": 45,↵ "patientSymptoms": "persistent cough, fatigue"↵}' +``` + +You can extract the arguments and write the data to a table, for example. diff --git a/doc/Ollama.md b/doc/Ollama.md new file mode 100644 index 0000000..d6e0bc3 --- /dev/null +++ b/doc/Ollama.md @@ -0,0 +1,97 @@ +# Ollama + +This repository contains code to connect MATLAB to a local [Ollama®](https://ollama.com) server, running large language models (LLMs). + +To use local models with Ollama, you will need to install and start an Ollama server, and “pull” models into it. Please follow the Ollama documentation for details. You should be familiar with the limitations and risks associated with using this technology, and you agree that you shall be solely responsible for full compliance with any terms that may apply to your use of any specific model. + +Some of the [LLMs currently supported out of the box on Ollama](https://ollama.com/library) are: +- llama2, llama2-uncensored, llama3, codellama +- phi3 +- aya +- mistral (v0.1, v0.2, v0.3) +- mixtral +- gemma, codegemma +- command-r + +## Establishing a connection to local LLMs using Ollama + +To create the chat assistant, call `ollamaChat` and specify the LLM you want to use: +```matlab +chat = ollamaChat("mistral"); +``` + +`ollamaChat` has additional options, please run `help ollamaChat` for details. + +## Simple call without preserving chat history + +In some situations, you will want to use chat completion models without preserving chat history. For example, when you want to perform independent queries in a programmatic way. + +Here's a simple example of how to use the `ollamaChat` for sentiment analysis, initialized with a few-shot prompt: + +```matlab +% Initialize the Ollama Chat object, passing a system prompt + +% The system prompt tells the assistant how to behave, in this case, as a sentiment analyzer +systemPrompt = "You are a sentiment analyser. 
+
+## Simple call without preserving chat history
+
+In some situations, you will want to use chat completion models without preserving chat history, for example when you want to perform independent queries in a programmatic way.
+
+Here's a simple example of how to use `ollamaChat` for sentiment analysis, initialized with a few-shot prompt:
+
+```matlab
+% Initialize the Ollama Chat object, passing a system prompt
+
+% The system prompt tells the assistant how to behave, in this case, as a sentiment analyzer
+systemPrompt = "You are a sentiment analyser. You will look at a sentence and output"+...
+    " a single word that classifies that sentence as either 'positive' or 'negative'."+...
+    newline + ...
+    "Examples:" + newline +...
+    "The project was a complete failure." + newline +...
+    "negative" + newline + newline +...
+    "The team successfully completed the project ahead of schedule." + newline +...
+    "positive" + newline + newline +...
+    "His attitude was terribly discouraging to the team." + newline +...
+    "negative" + newline + newline;
+
+chat = ollamaChat("phi3",systemPrompt);
+
+% Generate a response, passing a new sentence for classification
+txt = generate(chat,"The team is feeling very motivated")
+% Should output "positive"
+```
+
+## Creating a chat system
+
+If you want to create a chat system, you will have to create a history of the conversation and pass that to the `generate` function.
+
+To start a conversation history, create a `messageHistory` object:
+
+```matlab
+history = messageHistory;
+```
+
+Then create the chat assistant:
+
+```matlab
+chat = ollamaChat("mistral");
+```
+
+Add a user message to the history and pass it to `generate`:
+
+```matlab
+history = addUserMessage(history,"What is an eigenvalue?");
+[txt, response] = generate(chat, history)
+```
+
+The output `txt` will contain the answer and `response` will contain the full response, which you need to include in the history as follows:
+```matlab
+history = addResponseMessage(history, response);
+```
+
+You can keep interacting with the model; since we are saving the history, it will know about previous interactions.
+```matlab
+history = addUserMessage(history,"Generate MATLAB code that computes that");
+[txt, response] = generate(chat,history);
+% Will generate code to compute the eigenvalue
+```
+
+## Streaming the response
+
+Streaming allows you to start receiving the output from the API as it is generated token by token, rather than waiting for the entire completion to be generated. You can specify the streaming function when you create the chat assistant. In this example, the streaming function will print the response to the command window.
+```matlab
+% streaming function
+sf = @(x) fprintf("%s",x);
+chat = ollamaChat("mistral", StreamFun=sf);
+txt = generate(chat,"What is Model-Based Design and how is it related to Digital Twin?");
+% Should stream the response token by token
+```
diff --git a/doc/OpenAI.md b/doc/OpenAI.md
new file mode 100644
index 0000000..76bd834
--- /dev/null
+++ b/doc/OpenAI.md
@@ -0,0 +1,290 @@
+# OpenAI™
+
+Several functions in this repository connect MATLAB to the [OpenAI™ Chat Completions API](https://platform.openai.com/docs/guides/text-generation/chat-completions-api) (which powers ChatGPT™) and the [OpenAI Images API](https://platform.openai.com/docs/guides/images/image-generation-beta) (which powers DALL·E™).
+
+To start using the OpenAI APIs, you first need to obtain OpenAI API keys. You are responsible for any fees OpenAI may charge for the use of their APIs. You should be familiar with the limitations and risks associated with using this technology, and you agree that you shall be solely responsible for full compliance with any terms that may apply to your use of the OpenAI APIs.
+
+Some of the current LLMs supported on OpenAI are:
+- gpt-3.5-turbo, gpt-3.5-turbo-1106, gpt-3.5-turbo-0125
+- gpt-4o, gpt-4o-2024-05-13 (GPT-4 Omni)
+- gpt-4-turbo, gpt-4-turbo-2024-04-09 (GPT-4 Turbo with Vision)
+- gpt-4, gpt-4-0613
+- dall-e-2, dall-e-3
+
+For details on the specification of each model, check the official [OpenAI documentation](https://platform.openai.com/docs/models).
+
+
+## Setting up your OpenAI API key
+
+Set up your [OpenAI API key](https://platform.openai.com/account/api-keys). Create a `.env` file in the project root directory with the following content:
+
+```
+OPENAI_API_KEY=
+```
+
+Then load your `.env` file as follows:
+
+```matlab
+loadenv(".env")
+```
+
+## Simple call without preserving chat history
+
+In some situations, you will want to use chat completion models without preserving chat history, for example when you want to perform independent queries in a programmatic way.
+
+Here's a simple example of how to use `openAIChat` for sentiment analysis, initialized with a few-shot prompt:
+
+```matlab
+% Initialize the OpenAI Chat object, passing a system prompt
+
+% The system prompt tells the assistant how to behave, in this case, as a sentiment analyzer
+systemPrompt = "You are a sentiment analyser. You will look at a sentence and output"+...
+    " a single word that classifies that sentence as either 'positive' or 'negative'."+...
+    newline + ...
+    "Examples:" + newline +...
+    "The project was a complete failure." + newline +...
+    "negative" + newline + newline +...
+    "The team successfully completed the project ahead of schedule." + newline +...
+    "positive" + newline + newline +...
+    "His attitude was terribly discouraging to the team." + newline +...
+    "negative" + newline + newline;
+
+chat = openAIChat(systemPrompt);
+
+% Generate a response, passing a new sentence for classification
+txt = generate(chat,"The team is feeling very motivated")
+% Should output "positive"
+```
+
+## Creating a chat system
+
+If you want to create a chat system, you will have to create a history of the conversation and pass that to the `generate` function.
+
+To start a conversation history, create a `messageHistory` object:
+
+```matlab
+history = messageHistory;
+```
+
+Then create the chat assistant:
+
+```matlab
+chat = openAIChat("You are a helpful AI assistant.");
+```
+
+Add a user message to the history and pass it to `generate`:
+
+```matlab
+history = addUserMessage(history,"What is an eigenvalue?");
+[txt, response] = generate(chat, history)
+```
+
+The output `txt` will contain the answer and `response` will contain the full response, which you need to include in the history as follows:
+```matlab
+history = addResponseMessage(history, response);
+```
+
+You can keep interacting with the API and, since we are saving the history, it will know about previous interactions.
+```matlab
+history = addUserMessage(history,"Generate MATLAB code that computes that");
+[txt, response] = generate(chat,history);
+% Will generate code to compute the eigenvalue
+```
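+
+If the conversation grows too long for the model's context window, you can drop older messages from the history with `removeMessage`. A minimal sketch, assuming the `history` object from above:
+
+```matlab
+% Remove the oldest message (index 1) to keep the history short
+history = removeMessage(history, 1);
+```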
+
+## Streaming the response
+
+Streaming allows you to start receiving the output from the API as it is generated token by token, rather than waiting for the entire completion to be generated. You can specify the streaming function when you create the chat assistant. In this example, the streaming function will print the response to the command window.
+```matlab
+% streaming function
+sf = @(x) fprintf("%s",x);
+chat = openAIChat(StreamFun=sf);
+txt = generate(chat,"What is Model-Based Design and how is it related to Digital Twin?")
+% Should stream the response token by token
+```
+
+## Calling MATLAB functions with the API
+
+Optionally, `Tools=functions` can be used to provide function specifications to the API. The purpose of this is to enable models to generate function arguments that adhere to the provided specifications.
+Note that the API cannot call any function directly; instead, the model returns the name and arguments of the function it wants to call, and you call the function yourself and pass the result back. This process can be automated as shown in [AnalyzeScientificPapersUsingFunctionCalls.mlx](/examples/AnalyzeScientificPapersUsingFunctionCalls.mlx), but it's important to consider that ChatGPT can hallucinate function names, so avoid executing arbitrary generated functions and only allow the execution of functions that you have defined.
+
+For example, if you want to use the API for mathematical operations such as `sind`, instead of letting the model generate the result and risk running into hallucinations, you can give the model direct access to the function as follows:
+
+```matlab
+f = openAIFunction("sind","Sine of argument in degrees");
+f = addParameter(f,"x",type="number",description="Angle in degrees.");
+chat = openAIChat("You are a helpful assistant.",Tools=f);
+```
+
+When the model identifies that it could use the defined functions to answer a query, it will return a `tool_calls` request instead of directly generating the response:
+
+```matlab
+messages = messageHistory;
+messages = addUserMessage(messages, "What is the sine of 30?");
+[txt, response] = generate(chat, messages);
+messages = addResponseMessage(messages, response);
+```
+
+The variable `response` should contain a request for a function call.
+```bash
+>> response
+
+response = 
+
+  struct with fields:
+
+          role: 'assistant'
+       content: []
+    tool_calls: [1×1 struct]
+
+>> response.tool_calls
+
+ans = 
+
+  struct with fields:
+
+          id: 'call_wDpCLqtLhXiuRpKFw71gXzdy'
+        type: 'function'
+    function: [1×1 struct]
+
+>> response.tool_calls.function
+
+ans = 
+
+  struct with fields:
+
+         name: 'sind'
+    arguments: '{↵  "x": 30↵}'
+```
+
+You can then call the function `sind` with the specified argument, return the value to the API, and add a tool message to the history:
+
+```matlab
+% Arguments are returned as JSON, so you need to decode them first
+id = string(response.tool_calls.id);
+func = string(response.tool_calls.function.name);
+if func == "sind"
+    args = jsondecode(response.tool_calls.function.arguments);
+    result = sind(args.x);
+    messages = addToolMessage(messages,id,func,"x="+result);
+    [txt, response] = generate(chat, messages);
+else
+    % handle calls to unknown functions
+end
+```
+
+The model will then use the function result to generate a more precise response:
+
+```shell
+>> txt
+
+txt = 
+
+    "The sine of 30 degrees is approximately 0.5."
+```
+
+## Extracting structured information with the API
+
+Another useful application of function definitions is extracting structured information from text. Pass a function whose parameters describe the fields you want to extract and the format you want them in; the model then returns those fields as function arguments. For example, consider the following piece of text:
+
+```matlab
+patientReport = "Patient John Doe, a 45-year-old male, presented " + ...
+    "with a two-week history of persistent cough and fatigue. " + ...
+    "Chest X-ray revealed an abnormal shadow in the right lung." + ...
+    " A CT scan confirmed a 3cm mass in the right upper lobe," + ...
+    " suggestive of lung cancer. The patient has been referred " + ...
+    "for biopsy to confirm the diagnosis.";
+```
+
+If you want to extract information from this text, you can define a function as follows:
+```matlab
+f = openAIFunction("extractPatientData","Extracts data about a patient from a record");
+f = addParameter(f,"patientName",type="string",description="Name of the patient");
+f = addParameter(f,"patientAge",type="number",description="Age of the patient");
+f = addParameter(f,"patientSymptoms",type="string",description="Symptoms that the patient is having.");
+```
+
+Note that this function does not need to exist; it is only used to describe the Name, Age and Symptoms of the patient, and it is never called:
+
+```matlab
+chat = openAIChat("You are a helpful assistant that reads patient records and extracts information", ...
+    Tools=f);
+messages = messageHistory;
+messages = addUserMessage(messages,"Extract the information from the report:" + newline + patientReport);
+[txt, response] = generate(chat, messages);
+```
+
+The model should return the extracted information as a function call:
+```shell
+>> response
+
+response = 
+
+  struct with fields:
+
+          role: 'assistant'
+       content: []
+    tool_calls: [1×1 struct]
+
+>> response.tool_calls
+
+ans = 
+
+  struct with fields:
+
+          id: 'call_4VRtN7jb3pTPosMSb4ZaLoWP'
+        type: 'function'
+    function: [1×1 struct]
+
+>> response.tool_calls.function
+
+ans = 
+
+  struct with fields:
+
+         name: 'extractPatientData'
+    arguments: '{↵  "patientName": "John Doe",↵  "patientAge": 45,↵  "patientSymptoms": "persistent cough, fatigue"↵}'
+```
+
+You can extract the arguments and write the data to a table, for example.
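+
+As with any tool call, the arguments arrive as JSON. A minimal sketch of turning the response above into a table:
+
+```matlab
+% Decode the extracted fields and tabulate them
+data = jsondecode(response.tool_calls.function.arguments);
+patients = struct2table(data)
+```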
+
+## Understanding the content of an image
+
+You can use gpt-4-turbo to experiment with image understanding.
+```matlab
+chat = openAIChat("You are an AI assistant.", ModelName="gpt-4-turbo");
+image_path = "peppers.png";
+messages = messageHistory;
+messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
+[txt,response] = generate(chat,messages,MaxNumTokens=4096);
+% Should output the description of the image
+```
+
+## Obtaining embeddings
+
+You can extract embeddings from your text with OpenAI using the function `extractOpenAIEmbeddings` as follows:
+```matlab
+exampleText = "Here is an example!";
+emb = extractOpenAIEmbeddings(exampleText);
+```
+
+The resulting embedding is a vector that captures the semantics of your text and can be used for tasks such as retrieval-augmented generation and clustering.
+
+```shell
+>> size(emb)
+
+ans = 
+
+           1        1536
+```
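+
+For instance, the semantic similarity of two texts can be scored with the cosine similarity of their embeddings. A minimal sketch (the two example sentences are made up for illustration):
+
+```matlab
+emb1 = extractOpenAIEmbeddings("The cat sat on the mat.");
+emb2 = extractOpenAIEmbeddings("A feline rested on the rug.");
+% Cosine similarity: values close to 1 indicate similar meaning
+similarity = dot(emb1,emb2)/(norm(emb1)*norm(emb2))
+```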
+
+## Getting Started with Images API
+
+To get started, create an `openAIImages` object and use its methods, either on their own or as part of a more complex setup, as needed.
+
+```matlab
+mdl = openAIImages(ModelName="dall-e-3");
+images = generate(mdl,"Create a 3D avatar of a whimsical sushi on the beach. He is decorated with various sushi elements and is playfully interacting with the beach environment.");
+figure
+imshow(images{1})
+% Should output an image based on the prompt
+```
+
diff --git a/examples/AnalyzeTextDataUsingParallelFunctionCallwithChatGPT.mlx b/examples/AnalyzeTextDataUsingParallelFunctionCallwithChatGPT.mlx
index 63d6d76..8cfd809 100644
Binary files a/examples/AnalyzeTextDataUsingParallelFunctionCallwithChatGPT.mlx and b/examples/AnalyzeTextDataUsingParallelFunctionCallwithChatGPT.mlx differ
diff --git a/examples/CreateSimpleChatBot.mlx b/examples/CreateSimpleChatBot.mlx
index b83f547..b3eb66b 100644
Binary files a/examples/CreateSimpleChatBot.mlx and b/examples/CreateSimpleChatBot.mlx differ
diff --git a/examples/CreateSimpleOllamaChatBot.mlx b/examples/CreateSimpleOllamaChatBot.mlx
new file mode 100644
index 0000000..64f9b31
Binary files /dev/null and b/examples/CreateSimpleOllamaChatBot.mlx differ
diff --git a/examples/DescribeImagesUsingChatGPT.mlx b/examples/DescribeImagesUsingChatGPT.mlx
index b55f4a2..e6434f2 100644
Binary files a/examples/DescribeImagesUsingChatGPT.mlx and b/examples/DescribeImagesUsingChatGPT.mlx differ
diff --git a/examples/ProcessGeneratedTextInRealTimeByUsingOllamaInStreamingMode.mlx b/examples/ProcessGeneratedTextInRealTimeByUsingOllamaInStreamingMode.mlx
new file mode 100644
index 0000000..6725f6e
Binary files /dev/null and b/examples/ProcessGeneratedTextInRealTimeByUsingOllamaInStreamingMode.mlx differ
diff --git a/examples/UsingDALLEToGenerateImages.mlx b/examples/UsingDALLEToGenerateImages.mlx
index 5c04990..5fe2397 100644
Binary files a/examples/UsingDALLEToGenerateImages.mlx and b/examples/UsingDALLEToGenerateImages.mlx differ
diff --git a/extractOpenAIEmbeddings.m b/extractOpenAIEmbeddings.m
index 4be564c..6813e0a 100644
--- a/extractOpenAIEmbeddings.m
+++ b/extractOpenAIEmbeddings.m
@@ -9,7 +9,7 @@
 %
 %   'ModelName'       - The ID of the model to use.
 %
-%   'ApiKey'          - OpenAI API token. It can also be specified by
+%   'APIKey'          - OpenAI API token.
It can also be specified by % setting the environment variable OPENAI_API_KEY % % 'TimeOut' - Connection Timeout in seconds (default: 10 secs) @@ -28,12 +28,12 @@ "text-embedding-3-large", "text-embedding-3-small"])} = "text-embedding-ada-002" nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10 nvp.Dimensions (1,1) {mustBeInteger,mustBePositive} - nvp.ApiKey {llms.utils.mustBeNonzeroLengthTextScalar} + nvp.APIKey {llms.utils.mustBeNonzeroLengthTextScalar} end END_POINT = "https://api.openai.com/v1/embeddings"; -key = llms.internal.getApiKeyFromNvpOrEnv(nvp); +key = llms.internal.getApiKeyFromNvpOrEnv(nvp,"OPENAI_API_KEY"); parameters = struct("input",text,"model",nvp.ModelName); diff --git a/functionSignatures.json b/functionSignatures.json index a786406..fcafc31 100644 --- a/functionSignatures.json +++ b/functionSignatures.json @@ -1,43 +1,121 @@ { - "_schemaVersion": "1.0.0", - "openAIChat.openAIChat": - { - "inputs": - [ - {"name":"systemPrompt","kind":"ordered","type":["string","scalar"]}, - {"name":"Tools","kind":"namevalue","type":"openAIFunction"}, - {"name":"ModelName","kind":"namevalue","type":"choices=llms.openai.models"}, - {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]}, - {"name":"TopProbabilityMass","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]}, - {"name":"StopSequences","kind":"namevalue","type":["string","vector"]}, - {"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"}, - {"name":"ApiKey","kind":"namevalue","type":["string","scalar"]}, - {"name":"PresencePenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]}, - {"name":"FrequencyPenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]}, - {"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]}, - {"name":"StreamFun","kind":"namevalue","type":"function_handle"} - ], - "outputs": - [ - {"name":"this","type":"openAIChat"} - ] - }, - "openAIChat.generate": - { - "inputs": - [ - {"name":"this","kind":"required","type":["openAIChat","scalar"]}, - {"name":"messages","kind":"required","type":[["openAIMessages","row"],["string","scalar"]]}, - {"name":"NumCompletions","kind":"namevalue","type":["numeric","scalar","integer","positive"]}, - {"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]}, - {"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",this.FunctionNames]"}, - {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]} - ], - "outputs": - [ - {"name":"text","type":"string"}, - {"name":"message","type":"struct"}, - {"name":"response","type":"matlab.net.http.ResponseMessage"} - ] - } + "_schemaVersion": "1.0.0", + "openAIChat.openAIChat": + { + "inputs": + [ + {"name":"systemPrompt","kind":"ordered","type":["string","scalar"]}, + {"name":"Tools","kind":"namevalue","type":"openAIFunction"}, + {"name":"ModelName","kind":"namevalue","type":"choices=llms.openai.models"}, + {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]}, + {"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]}, + {"name":"StopSequences","kind":"namevalue","type":["string","vector"]}, + {"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"}, + {"name":"APIKey","kind":"namevalue","type":["string","scalar"]}, + {"name":"PresencePenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]}, + {"name":"FrequencyPenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]}, + 
{"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]}, + {"name":"StreamFun","kind":"namevalue","type":"function_handle"} + ], + "outputs": + [ + {"name":"this","type":"openAIChat"} + ] + }, + "openAIChat.generate": + { + "inputs": + [ + {"name":"this","kind":"required","type":["openAIChat","scalar"]}, + {"name":"messages","kind":"required","type":[["messageHistory","row"],["string","scalar"]]}, + {"name":"NumCompletions","kind":"namevalue","type":["numeric","scalar","integer","positive"]}, + {"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]}, + {"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",this.FunctionNames]"}, + {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]} + ], + "outputs": + [ + {"name":"text","type":"string"}, + {"name":"message","type":"struct"}, + {"name":"response","type":"matlab.net.http.ResponseMessage"} + ] + }, + "azureChat.azureChat": + { + "inputs": + [ + {"name":"systemPrompt","kind":"ordered","type":["string","scalar"]}, + {"name":"Endpoint","kind":"namevalue","type":["string","scalar"]}, + {"name":"Deployment","kind":"namevalue","type":["string","scalar"]}, + {"name":"APIKey","kind":"namevalue","type":["string","scalar"]}, + {"name":"Tools","kind":"namevalue","type":"openAIFunction"}, + {"name":"APIVersion","kind":"namevalue","type":"choices=llms.azure.apiVersions"}, + {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]}, + {"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]}, + {"name":"StopSequences","kind":"namevalue","type":["string","vector"]}, + {"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"}, + {"name":"PresencePenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]}, + {"name":"FrequencyPenalty","kind":"namevalue","type":["numeric","scalar","<=2",">=-2"]}, + {"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]}, + {"name":"StreamFun","kind":"namevalue","type":"function_handle"} + ], + "outputs": + [ + {"name":"this","type":"azureChat"} + ] + }, + "azureChat.generate": + { + "inputs": + [ + {"name":"this","kind":"required","type":["azureChat","scalar"]}, + {"name":"messages","kind":"required","type":[["messageHistory","row"],["string","scalar"]]}, + {"name":"NumCompletions","kind":"namevalue","type":["numeric","scalar","integer","positive"]}, + {"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]}, + {"name":"ToolChoice","kind":"namevalue","type":"choices=[\"none\",\"auto\",this.FunctionNames]"}, + {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]} + ], + "outputs": + [ + {"name":"text","type":"string"}, + {"name":"message","type":"struct"}, + {"name":"response","type":"matlab.net.http.ResponseMessage"} + ] + }, + "ollamaChat.ollamaChat": + { + "inputs": + [ + {"name":"model","kind":"positional","type":"choices=ollamaChat.models"}, + {"name":"systemPrompt","kind":"ordered","type":["string","scalar"]}, + {"name":"Temperature","kind":"namevalue","type":["numeric","scalar",">=0","<=2"]}, + {"name":"TopP","kind":"namevalue","type":["numeric","scalar",">=0","<=1"]}, + {"name":"TopK","kind":"namevalue","type":["numeric","scalar","integer",">=1"]}, + {"name":"StopSequences","kind":"namevalue","type":["string","vector"]}, + {"name":"ResponseFormat","kind":"namevalue","type":"choices={'text','json'}"}, + {"name":"TailFreeSamplingZ","kind":"namevalue","type":["numeric","scalar","real"]}, + 
{"name":"TimeOut","kind":"namevalue","type":["numeric","scalar","real","positive"]}, + {"name":"StreamFun","kind":"namevalue","type":"function_handle"} + ], + "outputs": + [ + {"name":"this","type":"ollamaChat"} + ] + }, + "ollamaChat.generate": + { + "inputs": + [ + {"name":"this","kind":"required","type":["ollamaChat","scalar"]}, + {"name":"messages","kind":"required","type":[["messageHistory","row"],["string","scalar"]]}, + {"name":"MaxNumTokens","kind":"namevalue","type":["numeric","scalar","positive"]}, + {"name":"Seed","kind":"namevalue","type":["numeric","integer","scalar"]} + ], + "outputs": + [ + {"name":"text","type":"string"}, + {"name":"message","type":"struct"}, + {"name":"response","type":"matlab.net.http.ResponseMessage"} + ] + } } diff --git a/messageHistory.m b/messageHistory.m new file mode 100644 index 0000000..74c5e63 --- /dev/null +++ b/messageHistory.m @@ -0,0 +1,323 @@ +classdef (Sealed) messageHistory + %messageHistory - Create an object to manage and store messages in a conversation. + % messages = messageHistory creates a messageHistory object. + % + % messageHistory functions: + % addSystemMessage - Add system message. + % addUserMessage - Add user message. + % addUserMessageWithImages - Add user message with images for + % GPT-4 Turbo with Vision. + % addToolMessage - Add a tool message. + % addResponseMessage - Add a response message. + % removeMessage - Remove message from history. + % + % messageHistory properties: + % Messages - Messages in the conversation history. + + % Copyright 2023-2024 The MathWorks, Inc. + + properties(SetAccess=private) + %MESSAGES - Messages in the conversation history. + Messages = {} + end + + methods + function this = addSystemMessage(this, name, content) + %addSystemMessage Add system message. + % + % MESSAGES = addSystemMessage(MESSAGES, NAME, CONTENT) adds a system + % message with the specified name and content. NAME and CONTENT + % must be text scalars. + % + % Example: + % % Create messages object + % messages = messageHistory; + % + % % Add system messages to provide examples of the conversation + % messages = addSystemMessage(messages, "example_user", "Hello, how are you?"); + % messages = addSystemMessage(messages, "example_assistant", "Olá, como vai?"); + % messages = addSystemMessage(messages, "example_user", "The sky is beautiful today"); + % messages = addSystemMessage(messages, "example_assistant", "O céu está lindo hoje."); + + arguments + this (1,1) messageHistory + name {mustBeNonzeroLengthTextScalar} + content {mustBeNonzeroLengthTextScalar} + end + + newMessage = struct("role", "system", "name", string(name), "content", string(content)); + this.Messages{end+1} = newMessage; + end + + function this = addUserMessage(this, content) + %addUserMessage Add user message. + % + % MESSAGES = addUserMessage(MESSAGES, CONTENT) adds a user message + % with the specified content to MESSAGES. CONTENT must be a text scalar. 
+            %
+            % Example:
+            % % Create messages object
+            % messages = messageHistory;
+            %
+            % % Add user message
+            % messages = addUserMessage(messages, "Where is Natick located?");
+
+            arguments
+                this (1,1) messageHistory
+                content {mustBeNonzeroLengthTextScalar}
+            end
+
+            newMessage = struct("role", "user", "content", string(content));
+            this.Messages{end+1} = newMessage;
+        end
+
+        function this = addUserMessageWithImages(this, content, images, nvp)
+            %addUserMessageWithImages Add user message with images
+            %
+            % MESSAGES = addUserMessageWithImages(MESSAGES, CONTENT, IMAGES)
+            % adds a user message with the specified content and images
+            % to MESSAGES. CONTENT must be a text scalar. IMAGES must be
+            % a string array of image URLs or file paths.
+            %
+            % messages = addUserMessageWithImages(__,Detail="low");
+            % specifies how the model should process the images using the
+            % "Detail" parameter. The default is "auto".
+            % - When set to "low", the model scales the image to 512x512
+            % - When set to "high", the model scales the image to 512x512
+            % and also creates detailed 512x512 crops of the image
+            % - When set to "auto", the model chooses which mode to use
+            % depending on the input image.
+            %
+            % Example:
+            %
+            % % Create a chat with GPT-4 Turbo with Vision
+            % chat = openAIChat("You are an AI assistant.", ModelName="gpt-4-vision-preview");
+            %
+            % % Create messages object
+            % messages = messageHistory;
+            %
+            % % Add user message with an image
+            % content = "What is in this picture?"
+            % images = "peppers.png"
+            % messages = addUserMessageWithImages(messages, content, images);
+            %
+            % % Generate a response
+            % [text, response] = generate(chat, messages, MaxNumTokens=300);
+
+            arguments
+                this (1,1) messageHistory
+                content {mustBeNonzeroLengthTextScalar}
+                images (1,:) {mustBeNonzeroLengthText}
+                nvp.Detail string {mustBeMember(nvp.Detail,["low","high","auto"])} = "auto"
+            end
+
+            newMessage = struct("role", "user", "content", []);
+            newMessage.content = {struct("type","text","text",string(content))};
+            for img = images(:).'
+                if startsWith(img,("https://"|"http://"))
+                    s = struct( ...
+                        "type","image_url", ...
+                        "image_url",struct("url",img));
+                else
+                    [~,~,ext] = fileparts(img);
+                    MIMEType = "data:image/" + erase(ext,".") + ";base64,";
+                    % Base64 encode the image using the given MIME type
+                    fid = fopen(img);
+                    im = fread(fid,'*uint8');
+                    fclose(fid);
+                    b64 = matlab.net.base64encode(im);
+                    s = struct( ...
+                        "type","image_url", ...
+                        "image_url",struct("url",MIMEType + b64));
+                end
+
+                s.image_url.detail = nvp.Detail;
+
+                newMessage.content{end+1} = s;
+            end
+            % Append the completed message (text plus all images) to the
+            % history once, after the loop over images
+            this.Messages{end+1} = newMessage;
+        end
+
+        function this = addToolMessage(this, id, name, content)
+            %addToolMessage Add Tool message.
+            %
+            % MESSAGES = addToolMessage(MESSAGES, ID, NAME, CONTENT)
+            % adds a tool message with the specified id, name and content.
+            % ID, NAME and CONTENT must be text scalars.
+            %
+            % Example:
+            % % Create messages object
+            % messages = messageHistory;
+            %
+            % % Add function message, containing the result of
+            % % calling strcat("Hello", " World")
+            % messages = addToolMessage(messages, "call_123", "strcat", "Hello World");
+
+            arguments
+                this (1,1) messageHistory
+                id {mustBeNonzeroLengthTextScalar}
+                name {mustBeNonzeroLengthTextScalar}
+                content {mustBeNonzeroLengthTextScalar}
+            end
+
+            newMessage = struct("tool_call_id", id, "role", "tool", ...
+ "name", string(name), "content", string(content)); + this.Messages{end+1} = newMessage; + end + + function this = addResponseMessage(this, messageStruct) + %addResponseMessage Add response message. + % + % MESSAGES = addResponseMessage(MESSAGES, messageStruct) adds a response + % message with the specified messageStruct. The input + % messageStruct should be a struct with field 'role' and + % value 'assistant' and with field 'content'. This response + % can be obtained from calling the GENERATE function. + % + % Example: + % + % % Create a chat object + % chat = openAIChat("You are a helpful AI Assistant."); + % + % % Create messages object + % messages = messageHistory; + % + % % Add user message + % messages = addUserMessage(messages, "What is the capital of England?"); + % + % % Generate a response + % [text, response] = generate(chat, messages); + % + % % Add response to history + % messages = addResponseMessage(messages, response); + + arguments + this (1,1) messageHistory + messageStruct (1,1) struct + end + + if ~isfield(messageStruct, "role")||~isequal(messageStruct.role, "assistant")||~isfield(messageStruct, "content") + error("llms:mustBeAssistantCall",llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantCall")); + end + + % Assistant is asking for function call + if isfield(messageStruct, "tool_calls") + toolCalls = messageStruct.tool_calls; + validateAssistantWithToolCalls(toolCalls) + this = addAssistantMessage(this, messageStruct.content, toolCalls); + else + % Simple assistant response + validateRegularAssistant(messageStruct.content); + this = addAssistantMessage(this,messageStruct.content); + end + end + + function this = removeMessage(this, idx) + %removeMessage Remove message. + % + % MESSAGES = removeMessage(MESSAGES, IDX) removes a message at the specified + % index from MESSAGES. IDX must be a positive integer. + % + % Example: + % + % % Create messages object + % messages = messageHistory; + % + % % Add user messages + % messages = addUserMessage(messages, "What is the capital of England?"); + % messages = addUserMessage(messages, "What is the capital of Italy?"); + % + % % Remove the first message + % messages = removeMessage(messages,1); + + arguments + this (1,1) messageHistory + idx (1,1) {mustBeInteger, mustBePositive} + end + if idx>numel(this.Messages) + error("llms:mustBeValidIndex",llms.utils.errorMessageCatalog.getMessage("llms:mustBeValidIndex", string(numel(this.Messages)))); + end + this.Messages(idx) = []; + end + end + + methods(Access=private) + + function this = addAssistantMessage(this, content, toolCalls) + arguments + this (1,1) messageHistory + content string + toolCalls struct = [] + end + + if isempty(toolCalls) + % Default assistant response + newMessage = struct("role", "assistant", "content", content); + else + % tool_calls message + toolsStruct = repmat(struct("id",[],"type",[],"function",[]),size(toolCalls)); + for i = 1:numel(toolCalls) + toolsStruct(i).id = toolCalls(i).id; + toolsStruct(i).type = toolCalls(i).type; + toolsStruct(i).function = struct( ... + "name", toolCalls(i).function.name, ... 
+ "arguments", toolCalls(i).function.arguments); + end + if numel(toolsStruct) > 1 + newMessage = struct("role", "assistant", "content", content, "tool_calls", toolsStruct); + else + newMessage = struct("role", "assistant", "content", content, "tool_calls", []); + newMessage.tool_calls = {toolsStruct}; + end + end + + if isempty(this.Messages) + this.Messages = {newMessage}; + else + this.Messages{end+1} = newMessage; + end + end + end +end + +function mustBeNonzeroLengthTextScalar(content) +mustBeNonzeroLengthText(content) +mustBeTextScalar(content) +end + +function validateRegularAssistant(content) +try + mustBeNonzeroLengthText(content) + mustBeTextScalar(content) +catch ME + error("llms:mustBeAssistantWithContent",llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithContent")) +end +end + +function validateAssistantWithToolCalls(toolCallStruct) +if ~(isstruct(toolCallStruct) && isfield(toolCallStruct, "id") && isfield(toolCallStruct, "function")) + error("llms:mustBeAssistantWithIdAndFunction", ... + llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithIdAndFunction")) +else + functionCallStruct = [toolCallStruct.function]; +end + +if ~isfield(functionCallStruct, "name")||~isfield(functionCallStruct, "arguments") + error("llms:mustBeAssistantWithNameAndArguments", ... + llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithNameAndArguments")) +end + +try + for i = 1:numel(functionCallStruct) + mustBeNonzeroLengthText(functionCallStruct(i).name) + mustBeTextScalar(functionCallStruct(i).name) + mustBeNonzeroLengthText(functionCallStruct(i).arguments) + mustBeTextScalar(functionCallStruct(i).arguments) + end +catch ME + error("llms:assistantMustHaveTextNameAndArguments", ... + llms.utils.errorMessageCatalog.getMessage("llms:assistantMustHaveTextNameAndArguments")) +end +end diff --git a/ollamaChat.m b/ollamaChat.m new file mode 100644 index 0000000..2538038 --- /dev/null +++ b/ollamaChat.m @@ -0,0 +1,206 @@ +classdef (Sealed) ollamaChat < llms.internal.textGenerator +%ollamaChat Chat completion API from Ollama®. +% +% CHAT = ollamaChat(modelName) creates an ollamaChat object for the given model. +% +% CHAT = ollamaChat(__,systemPrompt) creates an ollamaChat object with the +% specified system prompt. +% +% CHAT = ollamaChat(__,Name=Value) specifies additional options +% using one or more name-value arguments: +% +% Temperature - Temperature value for controlling the randomness +% of the output. Default value depends on the model; +% if not specified in the model, defaults to 0.8. +% Higher values increase the randomness (in some +% sense, the “creativity”) of outputs, lower +% values reduce it. Setting Temperature=0 removes +% randomness from the output altogether. +% +% TopP - Top probability mass value for controlling the +% diversity of the output. Default value is 1; +% lower values imply that only the more likely +% words can appear in any particular place. +% This is also known as top-p sampling. +% +% TopK - Maximum number of most likely tokens that are +% considered for output. Default is Inf, allowing +% all tokens. Smaller values reduce diversity in +% the output. +% +% TailFreeSamplingZ - Reduce the use of less probable tokens, based on +% the second-order differences of ordered +% probabilities. Default value is 1, disabling +% tail-free sampling. Lower values reduce +% diversity, with some authors recommending +% values around 0.95. Tail-free sampling is +% slower than using TopP or TopK. 
+% +% StopSequences - Vector of strings that when encountered, will +% stop the generation of tokens. Default +% value is empty. +% Example: ["The end.", "And that's all she wrote."] +% +% +% ResponseFormat - The format of response the model returns. +% "text" (default) | "json" +% +% StreamFun - Function to callback when streaming the +% result. +% +% TimeOut - Connection Timeout in seconds. Default is 120. +% +% +% +% ollamaChat Functions: +% ollamaChat - Chat completion API using Ollama server. +% generate - Generate a response using the ollamaChat instance. +% +% ollamaChat Properties, in addition to the name-value pairs above: +% Model - Model name (as expected by Ollama server). +% +% SystemPrompt - System prompt. + +% Copyright 2024 The MathWorks, Inc. + + properties + Model (1,1) string + TopK (1,1) {mustBeReal,mustBePositive} = Inf + TailFreeSamplingZ (1,1) {mustBeReal} = 1 + end + + methods + function this = ollamaChat(modelName, systemPrompt, nvp) + arguments + modelName {mustBeTextScalar} + systemPrompt {llms.utils.mustBeTextOrEmpty} = [] + nvp.Temperature {llms.utils.mustBeValidTemperature} = 1 + nvp.TopP {llms.utils.mustBeValidTopP} = 1 + nvp.TopK (1,1) {mustBeReal,mustBePositive} = Inf + nvp.StopSequences {llms.utils.mustBeValidStop} = {} + nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text" + nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 120 + nvp.TailFreeSamplingZ (1,1) {mustBeReal} = 1 + nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')} + end + + if isfield(nvp,"StreamFun") + this.StreamFun = nvp.StreamFun; + else + this.StreamFun = []; + end + + if ~isempty(systemPrompt) + systemPrompt = string(systemPrompt); + if ~(strlength(systemPrompt)==0) + this.SystemPrompt = {struct("role", "system", "content", systemPrompt)}; + end + end + + this.Model = modelName; + this.ResponseFormat = nvp.ResponseFormat; + this.Temperature = nvp.Temperature; + this.TopP = nvp.TopP; + this.TopK = nvp.TopK; + this.TailFreeSamplingZ = nvp.TailFreeSamplingZ; + this.StopSequences = nvp.StopSequences; + this.TimeOut = nvp.TimeOut; + end + + function [text, message, response] = generate(this, messages, nvp) + %generate Generate a response using the ollamaChat instance. + % + % [TEXT, MESSAGE, RESPONSE] = generate(CHAT, MESSAGES) generates a response + % with the specified MESSAGES. + % + % [TEXT, MESSAGE, RESPONSE] = generate(__, Name=Value) specifies additional options + % using one or more name-value arguments: + % + % MaxNumTokens - Maximum number of tokens in the generated response. + % Default value is inf. + % + % Seed - An integer value to use to obtain + % reproducible responses + + arguments + this (1,1) ollamaChat + messages {mustBeValidMsgs} + nvp.MaxNumTokens (1,1) {mustBePositive} = inf + nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = [] + end + + messages = convertCharsToStrings(messages); + if isstring(messages) && isscalar(messages) + messagesStruct = {struct("role", "user", "content", messages)}; + else + messagesStruct = messages.Messages; + end + + if ~isempty(this.SystemPrompt) + messagesStruct = horzcat(this.SystemPrompt, messagesStruct); + end + + [text, message, response] = llms.internal.callOllamaChatAPI(... + this.Model, messagesStruct, ... + Temperature=this.Temperature, ... + TopP=this.TopP, TopK=this.TopK,... + TailFreeSamplingZ=this.TailFreeSamplingZ,... + StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ... + ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ... 
+ TimeOut=this.TimeOut, StreamFun=this.StreamFun); + + if isfield(response.Body.Data,"error") + err = response.Body.Data.error; + error("llms:apiReturnedError",llms.utils.errorMessageCatalog.getMessage("llms:apiReturnedError",err)); + end + end + end + + methods(Static) + function mdls = models + %ollamaChat.models - return models available on Ollama server + % MDLS = ollamaChat.models returns a string vector MDLS + % listing the models available on the local Ollama server. + % + % These names can be used in the ollamaChat constructor. + % For names with a colon, such as "phi:latest", it is + % possible to only use the part before the colon, i.e., + % "phi". + endpoint = "http://localhost:11434/api/tags"; + response = webread(endpoint); + mdls = string({response.models.name}).'; + baseMdls = unique(extractBefore(mdls,":latest")); + % remove all those "mistral:latest", iff those are the only + % model entries pointing at some model + for base=baseMdls.' + found = startsWith(mdls,base+":"); + if nnz(found) == 1 + mdls(found) = []; + end + end + mdls = unique([mdls(:); baseMdls]); + mdls(strlength(mdls) < 1) = []; + mdls(ismissing(mdls)) = []; + end + end +end + +function mustBeValidMsgs(value) +if isa(value, "messageHistory") + if numel(value.Messages) == 0 + error("llms:mustHaveMessages", llms.utils.errorMessageCatalog.getMessage("llms:mustHaveMessages")); + end +else + try + llms.utils.mustBeNonzeroLengthTextScalar(value); + catch ME + error("llms:mustBeMessagesOrTxt", llms.utils.errorMessageCatalog.getMessage("llms:mustBeMessagesOrTxt")); + end +end +end + +function mustBeIntegerOrEmpty(value) + if ~isempty(value) + mustBeInteger(value) + end +end diff --git a/openAIChat.m b/openAIChat.m index c0d045e..6a46cce 100644 --- a/openAIChat.m +++ b/openAIChat.m @@ -1,35 +1,48 @@ -classdef(Sealed) openAIChat +classdef(Sealed) openAIChat < llms.internal.textGenerator & ... + llms.internal.gptPenalties & llms.internal.hasTools & llms.internal.needsAPIKey %openAIChat Chat completion API from OpenAI. % % CHAT = openAIChat(systemPrompt) creates an openAIChat object with the % specified system prompt. % -% CHAT = openAIChat(systemPrompt,ApiKey=key) uses the specified API key +% CHAT = openAIChat(systemPrompt,APIKey=key) uses the specified API key % % CHAT = openAIChat(systemPrompt,Name=Value) specifies additional options % using one or more name-value arguments: % -% Tools - Array of openAIFunction objects representing -% custom functions to be used during chat completions. -% % ModelName - Name of the model to use for chat completions. % The default value is "gpt-3.5-turbo". % % Temperature - Temperature value for controlling the randomness -% of the output. Default value is 1. +% of the output. Default value is 1; higher values +% increase the randomness (in some sense, +% the “creativity”) of outputs, lower values +% reduce it. Setting Temperature=0 removes +% randomness from the output altogether. +% +% TopP - Top probability mass value for controlling the +% diversity of the output. Default value is 1; +% lower values imply that only the more likely +% words can appear in any particular place. +% This is also known as top-p sampling. % -% TopProbabilityMass - Top probability mass value for controlling the -% diversity of the output. Default value is 1. +% Tools - Array of openAIFunction objects representing +% custom functions to be used during chat completions. % % StopSequences - Vector of strings that when encountered, will % stop the generation of tokens. Default % value is empty. 
+% Example: ["The end.", "And that's all she wrote."] % % PresencePenalty - Penalty value for using a token in the response % that has already been used. Default value is 0. +% Higher values reduce repetition of words in the output. % % FrequencyPenalty - Penalty value for using a token that is frequent -% in the training data. Default value is 0. +% in the output. Default value is 0. +% Higher values reduce repetition of words in the output. +% +% TimeOut - Connection Timeout in seconds. Default value is 10. % % StreamFun - Function to callback when streaming the % result @@ -46,7 +59,7 @@ % % Temperature - Temperature of generation. % -% TopProbabilityMass - Top probability mass to consider for generation. +% TopP - Top probability mass to consider for generation. % % StopSequences - Sequences to stop the generation of tokens. % @@ -61,67 +74,32 @@ % FunctionNames - Names of the functions that the model can % request calls. % -% ResponseFormat - Specifies the response format, text or json +% ResponseFormat - Specifies the response format, "text" or "json". % -% TimeOut - Connection Timeout in seconds (default: 10 secs) +% TimeOut - Connection Timeout in seconds. % % Copyright 2023-2024 The MathWorks, Inc. - properties - %TEMPERATURE Temperature of generation. - Temperature {mustBeValidTemperature} = 1 - - %TOPPROBABILITYMASS Top probability mass to consider for generation. - TopProbabilityMass {mustBeValidTopP} = 1 - - %STOPSEQUENCES Sequences to stop the generation of tokens. - StopSequences {mustBeValidStop} = {} - - %PRESENCEPENALTY Penalty for using a token in the response that has already been used. - PresencePenalty {mustBeValidPenalty} = 0 - - %FREQUENCYPENALTY Penalty for using a token that is frequent in the training data. - FrequencyPenalty {mustBeValidPenalty} = 0 - end - - properties(SetAccess=private) - %TIMEOUT Connection timeout in seconds (default 10 secs) - TimeOut - - %FUNCTIONNAMES Names of the functions that the model can request calls - FunctionNames - + properties(SetAccess=private) %MODELNAME Model name. ModelName - - %SYSTEMPROMPT System prompt. 
- SystemPrompt = [] - - %RESPONSEFORMAT Response format, "text" or "json" - ResponseFormat end - properties(Access=private) - Tools - FunctionsStruct - ApiKey - StreamFun - end methods - function this = openAIChat(systemPrompt, nvp) + function this = openAIChat(systemPrompt, nvp) arguments systemPrompt {llms.utils.mustBeTextOrEmpty} = [] nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty nvp.ModelName (1,1) string {mustBeModel} = "gpt-3.5-turbo" - nvp.Temperature {mustBeValidTemperature} = 1 - nvp.TopProbabilityMass {mustBeValidTopP} = 1 - nvp.StopSequences {mustBeValidStop} = {} + nvp.Temperature {llms.utils.mustBeValidTemperature} = 1 + nvp.TopP {llms.utils.mustBeValidTopP} = 1 + nvp.StopSequences {llms.utils.mustBeValidStop} = {} nvp.ResponseFormat (1,1) string {mustBeMember(nvp.ResponseFormat,["text","json"])} = "text" - nvp.ApiKey {mustBeNonzeroLengthTextScalar} - nvp.PresencePenalty {mustBeValidPenalty} = 0 - nvp.FrequencyPenalty {mustBeValidPenalty} = 0 + nvp.APIKey {mustBeNonzeroLengthTextScalar} + nvp.PresencePenalty {llms.utils.mustBeValidPenalty} = 0 + nvp.FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0 nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10 nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')} end @@ -140,7 +118,7 @@ this.Tools = nvp.Tools; [this.FunctionsStruct, this.FunctionNames] = functionAsStruct(nvp.Tools); end - + if ~isempty(systemPrompt) systemPrompt = string(systemPrompt); if systemPrompt ~= "" @@ -150,7 +128,7 @@ this.ModelName = nvp.ModelName; this.Temperature = nvp.Temperature; - this.TopProbabilityMass = nvp.TopProbabilityMass; + this.TopP = nvp.TopP; this.StopSequences = nvp.StopSequences; % ResponseFormat is only supported in the latest models only @@ -159,7 +137,7 @@ this.PresencePenalty = nvp.PresencePenalty; this.FrequencyPenalty = nvp.FrequencyPenalty; - this.ApiKey = llms.internal.getApiKeyFromNvpOrEnv(nvp); + this.APIKey = llms.internal.getApiKeyFromNvpOrEnv(nvp,"OPENAI_API_KEY"); this.TimeOut = nvp.TimeOut; end @@ -178,19 +156,19 @@ % MaxNumTokens - Maximum number of tokens in the generated response. % Default value is inf. % - % ToolChoice - Function to execute. 'none', 'auto', + % ToolChoice - Function to execute. 'none', 'auto', % or specify the function to call. % % Seed - An integer value to use to obtain % reproducible responses - % - % Currently, GPT-4 Turbo with vision does not support the message.name + % + % Currently, GPT-4 Turbo with vision does not support the message.name % parameter, functions/tools, response_format parameter, and stop % sequences. It also has a low MaxNumTokens default, which can be overridden. arguments this (1,1) openAIChat - messages (1,1) {mustBeValidMsgs} + messages {mustBeValidMsgs} nvp.NumCompletions (1,1) {mustBePositive, mustBeInteger} = 1 nvp.MaxNumTokens (1,1) {mustBePositive} = inf nvp.ToolChoice {mustBeValidFunctionCall(this, nvp.ToolChoice)} = [] @@ -199,8 +177,9 @@ toolChoice = convertToolChoice(this, nvp.ToolChoice); + messages = convertCharsToStrings(messages); if isstring(messages) && isscalar(messages) - messagesStruct = {struct("role", "user", "content", messages)}; + messagesStruct = {struct("role", "user", "content", messages)}; else messagesStruct = messages.Messages; end @@ -210,21 +189,19 @@ if ~isempty(this.SystemPrompt) messagesStruct = horzcat(this.SystemPrompt, messagesStruct); end - + [text, message, response] = llms.internal.callOpenAIChatAPI(messagesStruct, this.FunctionsStruct,... 
ModelName=this.ModelName, ToolChoice=toolChoice, Temperature=this.Temperature, ... - TopProbabilityMass=this.TopProbabilityMass, NumCompletions=nvp.NumCompletions,... + TopP=this.TopP, NumCompletions=nvp.NumCompletions,... StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ... PresencePenalty=this.PresencePenalty, FrequencyPenalty=this.FrequencyPenalty, ... ResponseFormat=this.ResponseFormat,Seed=nvp.Seed, ... - ApiKey=this.ApiKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun); + APIKey=this.APIKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun); if isfield(response.Body.Data,"error") err = response.Body.Data.error.message; - text = llms.utils.errorMessageCatalog.getMessage("llms:apiReturnedError",err); - message = struct("role","assistant","content",text); + error("llms:apiReturnedError",llms.utils.errorMessageCatalog.getMessage("llms:apiReturnedError",err)); end - end end @@ -243,7 +220,7 @@ function mustBeValidFunctionCall(this, functionCall) % if toolChoice is empty if isempty(toolChoice) % if Tools is not empty, the default is 'auto'. - if ~isempty(this.Tools) + if ~isempty(this.Tools) toolChoice = "auto"; end elseif ~ismember(toolChoice,["auto","none"]) @@ -274,12 +251,12 @@ function mustBeNonzeroLengthTextScalar(content) end function mustBeValidMsgs(value) -if isa(value, "openAIMessages") - if numel(value.Messages) == 0 +if isa(value, "messageHistory") + if numel(value.Messages) == 0 error("llms:mustHaveMessages", llms.utils.errorMessageCatalog.getMessage("llms:mustHaveMessages")); end else - try + try llms.utils.mustBeNonzeroLengthTextScalar(value); catch ME error("llms:mustBeMessagesOrTxt", llms.utils.errorMessageCatalog.getMessage("llms:mustBeMessagesOrTxt")); @@ -287,29 +264,6 @@ function mustBeValidMsgs(value) end end -function mustBeValidPenalty(value) -validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonsparse', '<=', 2, '>=', -2}) -end - -function mustBeValidTopP(value) -validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 1}) -end - -function mustBeValidTemperature(value) -validateattributes(value, {'numeric'}, {'real', 'scalar', 'nonnegative', 'nonsparse', '<=', 2}) -end - -function mustBeValidStop(value) -if ~isempty(value) - mustBeVector(value); - mustBeNonzeroLengthText(value); - % This restriction is set by the OpenAI API - if numel(value)>4 - error("llms:stopSequencesMustHaveMax4Elements", llms.utils.errorMessageCatalog.getMessage("llms:stopSequencesMustHaveMax4Elements")); - end -end -end - function mustBeIntegerOrEmpty(value) if ~isempty(value) mustBeInteger(value) diff --git a/openAIImages.m b/openAIImages.m index 91229db..c963c38 100644 --- a/openAIImages.m +++ b/openAIImages.m @@ -1,4 +1,4 @@ -classdef openAIImages +classdef openAIImages < llms.internal.needsAPIKey %openAIImages Connect to Images API from OpenAI. % % MDL = openAIImages creates an openAIImages object with dall-e-2 @@ -8,7 +8,7 @@ % ModelName - Name of the model to use for image generation. % "dall-e-2" (default) or "dall-e-3". % -% MDL = openAIImages(ModelName, ApiKey=key) uses the specified API key +% MDL = openAIImages(ModelName, APIKey=key) uses the specified API key % % MDL = openAIImages(__, Name=Value) specifies additional options % using one or more name-value arguments: @@ -30,7 +30,7 @@ % Copyright 2024 The MathWorks, Inc. - properties(SetAccess=private) + properties(SetAccess=private) %ModelName Model name. 
ModelName @@ -38,50 +38,46 @@ TimeOut end - properties (Access=private) - ApiKey - end - methods function this = openAIImages(nvp) arguments nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["dall-e-2", "dall-e-3"])} = "dall-e-2" - nvp.ApiKey {mustBeNonzeroLengthTextScalar} + nvp.APIKey {mustBeNonzeroLengthTextScalar} nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10 end this.ModelName = nvp.ModelName; - this.ApiKey = llms.internal.getApiKeyFromNvpOrEnv(nvp); + this.APIKey = llms.internal.getApiKeyFromNvpOrEnv(nvp,"OPENAI_API_KEY"); this.TimeOut = nvp.TimeOut; end function [images, response] = generate(this,prompt,nvp) %generate Generate images using the openAIImages instance - % + % % [IMAGES, RESPONSE] = generate(MDL, PROMPT) generates images % with the specified prompt. The PROMPT should be a text description - % of the desired image(s). + % of the desired image(s). % % [IMAGES, RESPONSE] = generate(__, Name=Value) specifies % additional options. - % - % NumImages - Number of images to generate. - % Default value is 1. - % For "dall-e-3" only 1 output is supported. % - % Size - Size of the generated images. + % NumImages - Number of images to generate. + % Default value is 1. + % For "dall-e-3" only 1 output is supported. + % + % Size - Size of the generated images. % Defaults to 1024x1024 - % "dall-e-2" supports 256x256, + % "dall-e-2" supports 256x256, % 512x512, or 1024x1024. - % "dall-e-3" supports 1024x1024, + % "dall-e-3" supports 1024x1024, % 1792x1024, or 1024x1792 % - % Quality - Quality of the images to generate. - % "standard" (default) or "hd". + % Quality - Quality of the images to generate. + % "standard" (default) or "hd". % Only "dall-e-3" supports this parameter. % - % Style - The style of the generated images. - % "vivid" (default) or "natural". + % Style - The style of the generated images. + % "vivid" (default) or "natural". % Only "dall-e-3" supports this parameter. arguments @@ -108,12 +104,12 @@ if this.ModelName=="dall-e-2" % dall-e-3 only params - if isfield(nvp, "Quality") + if isfield(nvp, "Quality") error("llms:invalidOptionForModel", ... llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", ... "Quality", this.ModelName)); end - if isfield(nvp, "Style") + if isfield(nvp, "Style") error("llms:invalidOptionForModel", ... llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", ... "Style", this.ModelName)); @@ -143,17 +139,17 @@ function [images, response] = edit(this,imagePath,prompt,nvp) %edit Generate an edited or extended image from a given image and prompt - % - % [IMAGES, RESPONSE] = edit(MDL, IMAGEPATH, PROMPT) + % + % [IMAGES, RESPONSE] = edit(MDL, IMAGEPATH, PROMPT) % generates new images from an original image and prompt % - % imagePath - The path to the source image file. - % Must be a valid PNG file, less than 4MB, - % and square. If mask is not provided, - % image must have transparency, which + % imagePath - The path to the source image file. + % Must be a valid PNG file, less than 4MB, + % and square. If mask is not provided, + % image must have transparency, which % will be used as the mask. % - % prompt - A text description of the desired image(s). + % prompt - A text description of the desired image(s). % The maximum length: 1000 characters % % [IMAGES, RESPONSE] = edit(__, Name=Value) specifies @@ -161,16 +157,16 @@ % % MaskImagePath - The path to the image file whose % fully transparent area indicates - % where the source image should be edited. 
- % Must be a valid PNG file, less than 4MB, + % where the source image should be edited. + % Must be a valid PNG file, less than 4MB, % and have the same dimensions as - % source image. - % - % NumImages - Number of images to generate. - % Default value is 1. The max is 10. + % source image. + % + % NumImages - Number of images to generate. + % Default value is 1. The max is 10. % - % Size - Size of the generated images. - % Must be one of 256x256, 512x512, or + % Size - Size of the generated images. + % Must be one of 256x256, 512x512, or % 1024x1024 (default) arguments @@ -184,7 +180,7 @@ ["256x256", "512x512","1024x1024"]), ... mustBeValidSize(this,nvp.Size)} = "1024x1024" end - + % For now, this is only supported for "dall-e-2" if this.ModelName~="dall-e-2" error("llms:functionNotAvailableForModel", ... @@ -218,22 +214,22 @@ function [images, response] = createVariation(this,imagePath,nvp) %createVariation Generate variations from a given image - % + % % [IMAGES, RESPONSE] = createVariation(MDL, IMAGEPATH) generates new images % from an original image % - % imagePath - The path to the source image file. - % Must be a valid PNG file, less than 4MB, + % imagePath - The path to the source image file. + % Must be a valid PNG file, less than 4MB, % and square. % % [IMAGES, RESPONSE] = createVariation(__, Name=Value) specifies % additional options. - % - % NumImages - Number of images to generate. - % Default value is 1. The max is 10. % - % Size - Size of the generated images. - % Must be one of "256x256", "512x512", or + % NumImages - Number of images to generate. + % Default value is 1. The max is 10. + % + % Size - Size of the generated images. + % Must be one of "256x256", "512x512", or % "1024x1024" (default) arguments @@ -269,7 +265,7 @@ function response = sendRequest(this, endpoint, body) %sendRequest send request to the given endpoint, return response - headers = matlab.net.http.HeaderField('Authorization', "Bearer " + this.ApiKey); + headers = matlab.net.http.HeaderField('Authorization', "Bearer " + this.APIKey); if isa(body,'struct') headers(2) = matlab.net.http.HeaderField('Content-Type', 'application/json'); end diff --git a/openAIMessages.m b/openAIMessages.m index 4aff1af..4008045 100644 --- a/openAIMessages.m +++ b/openAIMessages.m @@ -1,323 +1,10 @@ -classdef (Sealed) openAIMessages - %openAIMessages - Create an object to manage and store messages in a conversation. - % messages = openAIMessages creates an openAIMessages object. - % - % openAIMessages functions: - % addSystemMessage - Add system message. - % addUserMessage - Add user message. - % addUserMessageWithImages - Add user message with images for - % GPT-4 Turbo with Vision. - % addToolMessage - Add a tool message. - % addResponseMessage - Add a response message. - % removeMessage - Remove message from history. - % - % openAIMessages properties: - % Messages - Messages in the conversation history. +function msgs = openAIMessages +%openAIMessages - backward compatibility function +% +% This function only exists for backward compatibility and will be removed +% at some time in the future. Please use messageHistory instead. - % Copyright 2023-2024 The MathWorks, Inc. +% Copyright 2024 The MathWorks, Inc. - properties(SetAccess=private) - %MESSAGES - Messages in the conversation history. - Messages = {} - end - - methods - function this = addSystemMessage(this, name, content) - %addSystemMessage Add system message. 
- % - % MESSAGES = addSystemMessage(MESSAGES, NAME, CONTENT) adds a system - % message with the specified name and content. NAME and CONTENT - % must be text scalars. - % - % Example: - % % Create messages object - % messages = openAIMessages; - % - % % Add system messages to provide examples of the conversation - % messages = addSystemMessage(messages, "example_user", "Hello, how are you?"); - % messages = addSystemMessage(messages, "example_assistant", "Olá, como vai?"); - % messages = addSystemMessage(messages, "example_user", "The sky is beautiful today"); - % messages = addSystemMessage(messages, "example_assistant", "O céu está lindo hoje."); - - arguments - this (1,1) openAIMessages - name {mustBeNonzeroLengthTextScalar} - content {mustBeNonzeroLengthTextScalar} - end - - newMessage = struct("role", "system", "name", string(name), "content", string(content)); - this.Messages{end+1} = newMessage; - end - - function this = addUserMessage(this, content) - %addUserMessage Add user message. - % - % MESSAGES = addUserMessage(MESSAGES, CONTENT) adds a user message - % with the specified content to MESSAGES. CONTENT must be a text scalar. - % - % Example: - % % Create messages object - % messages = openAIMessages; - % - % % Add user message - % messages = addUserMessage(messages, "Where is Natick located?"); - - arguments - this (1,1) openAIMessages - content {mustBeNonzeroLengthTextScalar} - end - - newMessage = struct("role", "user", "content", string(content)); - this.Messages{end+1} = newMessage; - end - - function this = addUserMessageWithImages(this, content, images, nvp) - %addUserMessageWithImages Add user message with images - % - % MESSAGES = addUserMessageWithImages(MESSAGES, CONTENT, IMAGES) - % adds a user message with the specified content and images - % to MESSAGES. CONTENT must be a text scalar. IMAGES must be - % a string array of image URLs or file paths. - % - % messages = addUserMessageWithImages(__,Detail="low"); - % specify how the model should process the images using - % "Detail" parameter. The default is "auto". - % - When set to "low", the model scales the image to 512x512 - % - When set to "high", the model scales the image to 512x512 - % and also creates detailed 512x512 crops of the image - % - When set to "auto", the models chooses which mode to use - % depending on the input image. - % - % Example: - % - % % Create a chat with GPT-4 Turbo with Vision - % chat = openAIChat("You are an AI assistant.", ModelName="gpt-4-vision-preview"); - % - % % Create messages object - % messages = openAIMessages; - % - % % Add user message with an image - % content = "What is in this picture?" - % images = "peppers.png" - % messages = addUserMessageWithImages(messages, content, images); - % - % % Generate a response - % [text, response] = generate(chat, messages, MaxNumTokens=300); - - arguments - this (1,1) openAIMessages - content {mustBeNonzeroLengthTextScalar} - images (1,:) {mustBeNonzeroLengthText} - nvp.Detail string {mustBeMember(nvp.Detail,["low","high","auto"])} = "auto" - end - - newMessage = struct("role", "user", "content", []); - newMessage.content = {struct("type","text","text",string(content))}; - for img = images(:).' - if startsWith(img,("https://"|"http://")) - s = struct( ... - "type","image_url", ... 
- "image_url",struct("url",img)); - else - [~,~,ext] = fileparts(img); - MIMEType = "data:image/" + erase(ext,".") + ";base64,"; - % Base64 encode the image using the given MIME type - fid = fopen(img); - im = fread(fid,'*uint8'); - fclose(fid); - b64 = matlab.net.base64encode(im); - s = struct( ... - "type","image_url", ... - "image_url",struct("url",MIMEType + b64)); - end - - s.image_url.detail = nvp.Detail; - - newMessage.content{end+1} = s; - this.Messages{end+1} = newMessage; - end - - end - - function this = addToolMessage(this, id, name, content) - %addToolMessage Add Tool message. - % - % MESSAGES = addToolMessage(MESSAGES, ID, NAME, CONTENT) - % adds a tool message with the specified id, name and content. - % ID, NAME and CONTENT must be text scalars. - % - % Example: - % % Create messages object - % messages = openAIMessages; - % - % % Add function message, containing the result of - % % calling strcat("Hello", " World") - % messages = addToolMessage(messages, "call_123", "strcat", "Hello World"); - - arguments - this (1,1) openAIMessages - id {mustBeNonzeroLengthTextScalar} - name {mustBeNonzeroLengthTextScalar} - content {mustBeNonzeroLengthTextScalar} - - end - - newMessage = struct("tool_call_id", id, "role", "tool", ... - "name", string(name), "content", string(content)); - this.Messages{end+1} = newMessage; - end - - function this = addResponseMessage(this, messageStruct) - %addResponseMessage Add response message. - % - % MESSAGES = addResponseMessage(MESSAGES, messageStruct) adds a response - % message with the specified messageStruct. The input - % messageStruct should be a struct with field 'role' and - % value 'assistant' and with field 'content'. This response - % can be obtained from calling the GENERATE function. - % - % Example: - % - % % Create a chat object - % chat = openAIChat("You are a helpful AI Assistant."); - % - % % Create messages object - % messages = openAIMessages; - % - % % Add user message - % messages = addUserMessage(messages, "What is the capital of England?"); - % - % % Generate a response - % [text, response] = generate(chat, messages); - % - % % Add response to history - % messages = addResponseMessage(messages, response); - - arguments - this (1,1) openAIMessages - messageStruct (1,1) struct - end - - if ~isfield(messageStruct, "role")||~isequal(messageStruct.role, "assistant")||~isfield(messageStruct, "content") - error("llms:mustBeAssistantCall",llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantCall")); - end - - % Assistant is asking for function call - if isfield(messageStruct, "tool_calls") - toolCalls = messageStruct.tool_calls; - validateAssistantWithToolCalls(toolCalls) - this = addAssistantMessage(this, messageStruct.content, toolCalls); - else - % Simple assistant response - validateRegularAssistant(messageStruct.content); - this = addAssistantMessage(this,messageStruct.content); - end - end - - function this = removeMessage(this, idx) - %removeMessage Remove message. - % - % MESSAGES = removeMessage(MESSAGES, IDX) removes a message at the specified - % index from MESSAGES. IDX must be a positive integer. 
- % - % Example: - % - % % Create messages object - % messages = openAIMessages; - % - % % Add user messages - % messages = addUserMessage(messages, "What is the capital of England?"); - % messages = addUserMessage(messages, "What is the capital of Italy?"); - % - % % Remove the first message - % messages = removeMessage(messages,1); - - arguments - this (1,1) openAIMessages - idx (1,1) {mustBeInteger, mustBePositive} - end - if idx>numel(this.Messages) - error("llms:mustBeValidIndex",llms.utils.errorMessageCatalog.getMessage("llms:mustBeValidIndex", string(numel(this.Messages)))); - end - this.Messages(idx) = []; - end - end - - methods(Access=private) - - function this = addAssistantMessage(this, content, toolCalls) - arguments - this (1,1) openAIMessages - content string - toolCalls struct = [] - end - - if isempty(toolCalls) - % Default assistant response - newMessage = struct("role", "assistant", "content", content); - else - % tool_calls message - toolsStruct = repmat(struct("id",[],"type",[],"function",[]),size(toolCalls)); - for i = 1:numel(toolCalls) - toolsStruct(i).id = toolCalls(i).id; - toolsStruct(i).type = toolCalls(i).type; - toolsStruct(i).function = struct( ... - "name", toolCalls(i).function.name, ... - "arguments", toolCalls(i).function.arguments); - end - if numel(toolsStruct) > 1 - newMessage = struct("role", "assistant", "content", content, "tool_calls", toolsStruct); - else - newMessage = struct("role", "assistant", "content", content, "tool_calls", []); - newMessage.tool_calls = {toolsStruct}; - end - end - - if isempty(this.Messages) - this.Messages = {newMessage}; - else - this.Messages{end+1} = newMessage; - end - end - end -end - -function mustBeNonzeroLengthTextScalar(content) -mustBeNonzeroLengthText(content) -mustBeTextScalar(content) -end - -function validateRegularAssistant(content) -try - mustBeNonzeroLengthText(content) - mustBeTextScalar(content) -catch ME - error("llms:mustBeAssistantWithContent",llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithContent")) -end -end - -function validateAssistantWithToolCalls(toolCallStruct) -if ~(isstruct(toolCallStruct) && isfield(toolCallStruct, "id") && isfield(toolCallStruct, "function")) - error("llms:mustBeAssistantWithIdAndFunction", ... - llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithIdAndFunction")) -else - functionCallStruct = [toolCallStruct.function]; -end - -if ~isfield(functionCallStruct, "name")||~isfield(functionCallStruct, "arguments") - error("llms:mustBeAssistantWithNameAndArguments", ... - llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithNameAndArguments")) -end - -try - for i = 1:numel(functionCallStruct) - mustBeNonzeroLengthText(functionCallStruct(i).name) - mustBeTextScalar(functionCallStruct(i).name) - mustBeNonzeroLengthText(functionCallStruct(i).arguments) - mustBeTextScalar(functionCallStruct(i).arguments) - end -catch ME - error("llms:assistantMustHaveTextNameAndArguments", ... - llms.utils.errorMessageCatalog.getMessage("llms:assistantMustHaveTextNameAndArguments")) -end +msgs = messageHistory; end diff --git a/tests/tazureChat.m b/tests/tazureChat.m new file mode 100644 index 0000000..333b1c7 --- /dev/null +++ b/tests/tazureChat.m @@ -0,0 +1,448 @@ +classdef tazureChat < matlab.unittest.TestCase +% Tests for azureChat + +% Copyright 2024 The MathWorks, Inc. 
+ + properties(TestParameter) + InvalidConstructorInput = iGetInvalidConstructorInput; + InvalidGenerateInput = iGetInvalidGenerateInput; + InvalidValuesSetters = iGetInvalidValuesSetters; + StringInputs = struct('string',{"hi"},'char',{'hi'},'cellstr',{{'hi'}}); + end + + methods(Test) + function constructChatWithAllNVP(testCase) + deploymentID = "hello"; + functions = openAIFunction("funName"); + temperature = 0; + topP = 1; + stop = ["[END]", "."]; + apiKey = "this-is-not-a-real-key"; + presenceP = -2; + frequenceP = 2; + systemPrompt = "This is a system prompt"; + timeout = 3; + chat = azureChat(systemPrompt, Deployment=deploymentID, Tools=functions, ... + Temperature=temperature, TopP=topP, StopSequences=stop, APIKey=apiKey,... + FrequencyPenalty=frequenceP, PresencePenalty=presenceP, TimeOut=timeout); + testCase.verifyEqual(chat.Temperature, temperature); + testCase.verifyEqual(chat.TopP, topP); + testCase.verifyEqual(chat.StopSequences, stop); + testCase.verifyEqual(chat.FrequencyPenalty, frequenceP); + testCase.verifyEqual(chat.PresencePenalty, presenceP); + end + + function doGenerate(testCase,StringInputs) + testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT."); + chat = azureChat; + response = testCase.verifyWarningFree(@() generate(chat,StringInputs)); + testCase.verifyClass(response,'string'); + testCase.verifyGreaterThan(strlength(response),0); + end + + function generateMultipleResponses(testCase) + chat = azureChat; + [~,~,response] = generate(chat,"What is a cat?",NumCompletions=3); + testCase.verifySize(response.Body.Data.choices,[3,1]); + end + + + function doReturnErrors(testCase) + testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT."); + chat = azureChat; + % This input is considerably longer than accepted as input for + % GPT-3.5 (16385 tokens) + wayTooLong = string(repmat('a ',1,20000)); + testCase.verifyError(@() generate(chat,wayTooLong), "llms:apiReturnedError"); + end + + function seedFixesResult(testCase) + testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT."); + chat = azureChat; + response1 = generate(chat,"hi",Seed=1234); + response2 = generate(chat,"hi",Seed=1234); + testCase.verifyEqual(response1,response2); + end + + function createAzureChatWithStreamFunc(testCase) + testCase.assumeTrue(isenv("AZURE_OPENAI_API_KEY"),"end-to-end test requires environment variables AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_DEPLOYMENT."); + function seen = sf(str) + persistent data; + if isempty(data) + data = strings(1, 0); + end + % Append streamed text to an empty string array of length 1 + data = [data, str]; + seen = data; + end + chat = azureChat(StreamFun=@sf); + + testCase.verifyWarningFree(@()generate(chat, "Hello world.")); + % Checking that persistent data, which is still stored in + % memory, is greater than 1. This would mean that the stream + % function has been called and streamed some text. 
+ testCase.verifyGreaterThan(numel(sf("")), 1); + end + + function generateWithTools(testCase) + import matlab.unittest.constraints.HasField + + f = openAIFunction("getCurrentWeather", "Get the current weather in a given location"); + f = addParameter(f, "location", type="string", description="The city and country, optionally state. E.g., San Francisco, CA, USA"); + f = addParameter(f, "unit", type="string", enum=["Kelvin","Celsius"], RequiredParameter=false); + + chat = azureChat(Tools=f); + + prompt = "What's the weather like in San Francisco, Tokyo, and Paris?"; + [~, response] = generate(chat, prompt, ToolChoice="getCurrentWeather"); + + testCase.assertThat(response, HasField("tool_calls")); + testCase.assertEqual(response.tool_calls.type,'function'); + testCase.assertEqual(response.tool_calls.function.name,'getCurrentWeather'); + data = testCase.verifyWarningFree( ... + @() jsondecode(response.tool_calls.function.arguments)); + testCase.verifyThat(data,HasField("location")); + end + + function errorsWhenPassingToolChoiceWithEmptyTools(testCase) + chat = azureChat(APIKey="this-is-not-a-real-key"); + testCase.verifyError(@()generate(chat,"input", ToolChoice="bla"), "llms:mustSetFunctionsForCall"); + end + + function shortErrorForBadEndpoint(testCase) + chat = azureChat(Endpoint="https://nobodyhere.whatever/"); + caught = false; + try + generate(chat,"input"); + catch ME + caught = ME; + end + testCase.assertClass(caught,"MException"); + testCase.verifyEqual(caught.identifier,'MATLAB:webservices:UnknownHost'); + testCase.verifyEmpty(caught.cause); + end + + function invalidInputsConstructor(testCase, InvalidConstructorInput) + testCase.verifyError(@()azureChat(InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error); + end + + function invalidInputsGenerate(testCase, InvalidGenerateInput) + f = openAIFunction("validfunction"); + chat = azureChat(Tools=f, APIKey="this-is-not-a-real-key"); + testCase.verifyError(@()generate(chat,InvalidGenerateInput.Input{:}), InvalidGenerateInput.Error); + end + + function invalidSetters(testCase, InvalidValuesSetters) + chat = azureChat(APIKey="this-is-not-a-real-key"); + function assignValueToProperty(property, value) + chat.(property) = value; + end + + testCase.verifyError(@()assignValueToProperty(InvalidValuesSetters.Property,InvalidValuesSetters.Value), InvalidValuesSetters.Error); + end + + function keyNotFound(testCase) + % to verify the error, we need to unset the environment variable + % AZURE_OPENAI_API_KEY, if given. Use a fixture to restore the + % value on leaving the test point: + import matlab.unittest.fixtures.EnvironmentVariableFixture + testCase.applyFixture(EnvironmentVariableFixture("AZURE_OPENAI_API_KEY","dummy")); + unsetenv("AZURE_OPENAI_API_KEY"); + testCase.verifyError(@()azureChat, "llms:keyMustBeSpecified"); + end + end +end + +function invalidValuesSetters = iGetInvalidValuesSetters + +invalidValuesSetters = struct( ... + "InvalidTemperatureType", struct( ... + "Property", "Temperature", ... + "Value", "2", ... + "Error", "MATLAB:invalidType"), ... + ... + "InvalidTemperatureSize", struct( ... + "Property", "Temperature", ... + "Value", [1 1 1], ... + "Error", "MATLAB:expectedScalar"), ... + ... + "TemperatureTooLarge", struct( ... + "Property", "Temperature", ... + "Value", 20, ... + "Error", "MATLAB:notLessEqual"), ... + ... + "TemperatureTooSmall", struct( ... + "Property", "Temperature", ... + "Value", -20, ... + "Error", "MATLAB:expectedNonnegative"), ... + ... + "InvalidTopPType", struct( ... 
+ "Property", "TopP", ... + "Value", "2", ... + "Error", "MATLAB:invalidType"), ... + ... + "InvalidTopPSize", struct( ... + "Property", "TopP", ... + "Value", [1 1 1], ... + "Error", "MATLAB:expectedScalar"), ... + ... + "TopPTooLarge", struct( ... + "Property", "TopP", ... + "Value", 20, ... + "Error", "MATLAB:notLessEqual"), ... + ... + "TopPTooSmall", struct( ... + "Property", "TopP", ... + "Value", -20, ... + "Error", "MATLAB:expectedNonnegative"), ... + ... + "WrongTypeStopSequences", struct( ... + "Property", "StopSequences", ... + "Value", 123, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "WrongSizeStopNonVector", struct( ... + "Property", "StopSequences", ... + "Value", repmat("stop", 4), ... + "Error", "MATLAB:validators:mustBeVector"), ... + ... + "EmptyStopSequences", struct( ... + "Property", "StopSequences", ... + "Value", "", ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "WrongSizeStopSequences", struct( ... + "Property", "StopSequences", ... + "Value", ["1" "2" "3" "4" "5"], ... + "Error", "llms:stopSequencesMustHaveMax4Elements"), ... + ... + "InvalidPresencePenalty", struct( ... + "Property", "PresencePenalty", ... + "Value", "2", ... + "Error", "MATLAB:invalidType"), ... + ... + "InvalidPresencePenaltySize", struct( ... + "Property", "PresencePenalty", ... + "Value", [1 1 1], ... + "Error", "MATLAB:expectedScalar"), ... + ... + "PresencePenaltyTooLarge", struct( ... + "Property", "PresencePenalty", ... + "Value", 20, ... + "Error", "MATLAB:notLessEqual"), ... + ... + "PresencePenaltyTooSmall", struct( ... + "Property", "PresencePenalty", ... + "Value", -20, ... + "Error", "MATLAB:notGreaterEqual"), ... + ... + "InvalidFrequencyPenalty", struct( ... + "Property", "FrequencyPenalty", ... + "Value", "2", ... + "Error", "MATLAB:invalidType"), ... + ... + "InvalidFrequencyPenaltySize", struct( ... + "Property", "FrequencyPenalty", ... + "Value", [1 1 1], ... + "Error", "MATLAB:expectedScalar"), ... + ... + "FrequencyPenaltyTooLarge", struct( ... + "Property", "FrequencyPenalty", ... + "Value", 20, ... + "Error", "MATLAB:notLessEqual"), ... + ... + "FrequencyPenaltyTooSmall", struct( ... + "Property", "FrequencyPenalty", ... + "Value", -20, ... + "Error", "MATLAB:notGreaterEqual")); +end + +function invalidConstructorInput = iGetInvalidConstructorInput +validFunction = openAIFunction("funName"); +invalidConstructorInput = struct( ... + "InvalidResponseFormatValue", struct( ... + "Input",{{"ResponseFormat", "foo" }},... + "Error", "MATLAB:validators:mustBeMember"), ... + ... + "InvalidResponseFormatSize", struct( ... + "Input",{{"ResponseFormat", ["text" "text"] }},... + "Error", "MATLAB:validation:IncompatibleSize"), ... + ... + "InvalidStreamFunType", struct( ... + "Input",{{"StreamFun", "2" }},... + "Error", "MATLAB:validators:mustBeA"), ... + ... + "InvalidStreamFunSize", struct( ... + "Input",{{"StreamFun", [1 1 1] }},... + "Error", "MATLAB:validation:IncompatibleSize"), ... + ... + "InvalidTimeOutType", struct( ... + "Input",{{"TimeOut", "2" }},... + "Error", "MATLAB:validators:mustBeReal"), ... + ... + "InvalidTimeOutSize", struct( ... + "Input",{{"TimeOut", [1 1 1] }},... + "Error", "MATLAB:validation:IncompatibleSize"), ... + ... + "WrongTypeSystemPrompt",struct( ... + "Input",{{ 123 }},... + "Error","MATLAB:validators:mustBeTextScalar"),... + ... + "WrongSizeSystemPrompt",struct( ... + "Input",{{ ["test"; "test"] }},... + "Error","MATLAB:validators:mustBeTextScalar"),... + ... + "InvalidToolsType",struct( ... 
+ "Input",{{"Tools", "a" }},... + "Error","MATLAB:validators:mustBeA"),... + ... + "InvalidToolsSize",struct( ... + "Input",{{"Tools", repmat(validFunction, 2, 2) }},... + "Error","MATLAB:validation:IncompatibleSize"),... + ... + "InvalidAPIVersionType",struct( ... + "Input",{{"APIVersion", 0}},... + "Error","MATLAB:validators:mustBeMember"),... + ... + "InvalidAPIVersionSize",struct( ... + "Input",{{"APIVersion", ["2023-05-15", "2023-05-15"]}},... + "Error","MATLAB:validation:IncompatibleSize"),... + ... + "InvalidAPIVersionOption",struct( ... + "Input",{{ "APIVersion", "gpt" }},... + "Error","MATLAB:validators:mustBeMember"),... + ... + "InvalidTemperatureType",struct( ... + "Input",{{ "Temperature" "2" }},... + "Error","MATLAB:invalidType"),... + ... + "InvalidTemperatureSize",struct( ... + "Input",{{ "Temperature" [1 1 1] }},... + "Error","MATLAB:expectedScalar"),... + ... + "TemperatureTooLarge",struct( ... + "Input",{{ "Temperature" 20 }},... + "Error","MATLAB:notLessEqual"),... + ... + "TemperatureTooSmall",struct( ... + "Input",{{ "Temperature" -20 }},... + "Error","MATLAB:expectedNonnegative"),... + ... + "InvalidTopPType",struct( ... + "Input",{{ "TopP" "2" }},... + "Error","MATLAB:invalidType"),... + ... + "InvalidTopPSize",struct( ... + "Input",{{ "TopP" [1 1 1] }},... + "Error","MATLAB:expectedScalar"),... + ... + "TopPTooLarge",struct( ... + "Input",{{ "TopP" 20 }},... + "Error","MATLAB:notLessEqual"),... + ... + "TopPTooSmall",struct( ... + "Input",{{ "TopP" -20 }},... + "Error","MATLAB:expectedNonnegative"),... + ... + "WrongTypeStopSequences",struct( ... + "Input",{{ "StopSequences" 123}},... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... + ... + "WrongSizeStopNonVector",struct( ... + "Input",{{ "StopSequences" repmat("stop", 4) }},... + "Error","MATLAB:validators:mustBeVector"),... + ... + "EmptyStopSequences",struct( ... + "Input",{{ "StopSequences" ""}},... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... + ... + "WrongSizeStopSequences",struct( ... + "Input",{{ "StopSequences" ["1" "2" "3" "4" "5"]}},... + "Error","llms:stopSequencesMustHaveMax4Elements"),... + ... + "InvalidPresencePenalty",struct( ... + "Input",{{ "PresencePenalty" "2" }},... + "Error","MATLAB:invalidType"),... + ... + "InvalidPresencePenaltySize",struct( ... + "Input",{{ "PresencePenalty" [1 1 1] }},... + "Error","MATLAB:expectedScalar"),... + ... + "PresencePenaltyTooLarge",struct( ... + "Input",{{ "PresencePenalty" 20 }},... + "Error","MATLAB:notLessEqual"),... + ... + "PresencePenaltyTooSmall",struct( ... + "Input",{{ "PresencePenalty" -20 }},... + "Error","MATLAB:notGreaterEqual"),... + ... + "InvalidFrequencyPenalty",struct( ... + "Input",{{ "FrequencyPenalty" "2" }},... + "Error","MATLAB:invalidType"),... + ... + "InvalidFrequencyPenaltySize",struct( ... + "Input",{{ "FrequencyPenalty" [1 1 1] }},... + "Error","MATLAB:expectedScalar"),... + ... + "FrequencyPenaltyTooLarge",struct( ... + "Input",{{ "FrequencyPenalty" 20 }},... + "Error","MATLAB:notLessEqual"),... + ... + "FrequencyPenaltyTooSmall",struct( ... + "Input",{{ "FrequencyPenalty" -20 }},... + "Error","MATLAB:notGreaterEqual"),... + ... + "InvalidApiKeyType",struct( ... + "Input",{{ "APIKey" 123 }},... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... + ... + "InvalidApiKeySize",struct( ... + "Input",{{ "APIKey" ["abc" "abc"] }},... 
+ "Error","MATLAB:validators:mustBeTextScalar")); +end + +function invalidGenerateInput = iGetInvalidGenerateInput +emptyMessages = messageHistory; +validMessages = addUserMessage(emptyMessages,"Who invented the telephone?"); + +invalidGenerateInput = struct( ... + "EmptyInput",struct( ... + "Input",{{ [] }},... + "Error","llms:mustBeMessagesOrTxt"),... + ... + "InvalidInputType",struct( ... + "Input",{{ 123 }},... + "Error","llms:mustBeMessagesOrTxt"),... + ... + "EmptyMessages",struct( ... + "Input",{{ emptyMessages }},... + "Error","llms:mustHaveMessages"),... + ... + "InvalidMaxNumTokensType",struct( ... + "Input",{{ validMessages "MaxNumTokens" "2" }},... + "Error","MATLAB:validators:mustBeNumericOrLogical"),... + ... + "InvalidMaxNumTokensValue",struct( ... + "Input",{{ validMessages "MaxNumTokens" 0 }},... + "Error","MATLAB:validators:mustBePositive"),... + ... + "InvalidNumCompletionsType",struct( ... + "Input",{{ validMessages "NumCompletions" "2" }},... + "Error","MATLAB:validators:mustBeNumericOrLogical"),... + ... + "InvalidNumCompletionsValue",struct( ... + "Input",{{ validMessages "NumCompletions" 0 }},... + "Error","MATLAB:validators:mustBePositive"), ... + ... + "InvalidToolChoiceValue",struct( ... + "Input",{{ validMessages "ToolChoice" "functionDoesNotExist" }},... + "Error","MATLAB:validators:mustBeMember"),... + ... + "InvalidToolChoiceType",struct( ... + "Input",{{ validMessages "ToolChoice" 0 }},... + "Error","MATLAB:validators:mustBeTextScalar"),... + ... + "InvalidToolChoiceSize",struct( ... + "Input",{{ validMessages "ToolChoice" ["validfunction", "validfunction"] }},... + "Error","MATLAB:validators:mustBeTextScalar")); +end diff --git a/tests/texampleTests.m b/tests/texampleTests.m index 79b0420..de184d3 100644 --- a/tests/texampleTests.m +++ b/tests/texampleTests.m @@ -3,6 +3,12 @@ % Copyright 2024 The MathWorks, Inc. + + properties(TestParameter) + ChatBotExample = {"CreateSimpleChatBot", "CreateSimpleOllamaChatBot"}; + end + + methods (TestClassSetup) function setUpAndTearDowns(testCase) import matlab.unittest.fixtures.CurrentFolderFixture @@ -29,44 +35,16 @@ function testAnalyzeScientificPapersUsingFunctionCalls(~) AnalyzeScientificPapersUsingFunctionCalls; end - function testProcessGeneratedTextinRealTimebyUsingChatGPTinStreamingMode(~) - ProcessGeneratedTextinRealTimebyUsingChatGPTinStreamingMode; - end - - function testUsingDALLEToGenerateImages(~) - UsingDALLEToGenerateImages; - end - - function testInformationRetrievalUsingOpenAIDocumentEmbedding(~) - InformationRetrievalUsingOpenAIDocumentEmbedding; - end - - function testDescribeImagesUsingChatGPT(~) - DescribeImagesUsingChatGPT; - end - - function testSummarizeLargeDocumentsUsingChatGPTandMATLAB(~) - SummarizeLargeDocumentsUsingChatGPTandMATLAB; + function testAnalyzeSentimentinTextUsingChatGPTinJSONMode(testCase) + testCase.verifyWarning(@AnalyzeSentimentinTextUsingChatGPTinJSONMode,... + "llms:warningJsonInstruction"); end function testAnalyzeTextDataUsingParallelFunctionCallwithChatGPT(~) AnalyzeTextDataUsingParallelFunctionCallwithChatGPT; end - function testRetrievalAugmentedGenerationUsingChatGPTandMATLAB(~) - RetrievalAugmentedGenerationUsingChatGPTandMATLAB; - end - - function testUsingDALLEToEditImages(~) - UsingDALLEToEditImages; - end - - function testAnalyzeSentimentinTextUsingChatGPTinJSONMode(testCase) - testCase.verifyWarning(@AnalyzeSentimentinTextUsingChatGPTinJSONMode,... 
- "llms:warningJsonInstruction"); - end - - function testCreateSimpleChatBot(testCase) + function testCreateSimpleChatBot(testCase,ChatBotExample) % set up a fake input command, returning canned user prompts count = 0; prompts = [ @@ -101,11 +79,43 @@ function testCreateSimpleChatBot(testCase) numWordsResponse = []; %#ok % Run the example - CreateSimpleChatBot; + eval(ChatBotExample); testCase.verifyEqual(count,find(prompts=="end",1)); testCase.verifySize(messages.Messages,[1 2*(count-1)]); end + + function testDescribeImagesUsingChatGPT(~) + DescribeImagesUsingChatGPT; + end + + function testInformationRetrievalUsingOpenAIDocumentEmbedding(~) + InformationRetrievalUsingOpenAIDocumentEmbedding; + end + + function testProcessGeneratedTextinRealTimebyUsingChatGPTinStreamingMode(~) + ProcessGeneratedTextinRealTimebyUsingChatGPTinStreamingMode; + end + + function testProcessGeneratedTextInRealTimeByUsingOllamaInStreamingMode(~) + ProcessGeneratedTextInRealTimeByUsingOllamaInStreamingMode; + end + + function testRetrievalAugmentedGenerationUsingChatGPTandMATLAB(~) + RetrievalAugmentedGenerationUsingChatGPTandMATLAB; + end + + function testSummarizeLargeDocumentsUsingChatGPTandMATLAB(~) + SummarizeLargeDocumentsUsingChatGPTandMATLAB; + end + + function testUsingDALLEToEditImages(~) + UsingDALLEToEditImages; + end + + function testUsingDALLEToGenerateImages(~) + UsingDALLEToGenerateImages; + end end end diff --git a/tests/textractOpenAIEmbeddings.m b/tests/textractOpenAIEmbeddings.m index fca5bc8..f5352aa 100644 --- a/tests/textractOpenAIEmbeddings.m +++ b/tests/textractOpenAIEmbeddings.m @@ -23,9 +23,9 @@ function saveEnvVar(testCase) methods(Test) % Test methods function embedsDifferentStringTypes(testCase) - testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla", ApiKey="this-is-not-a-real-key")); - testCase.verifyWarningFree(@()extractOpenAIEmbeddings('bla', ApiKey="this-is-not-a-real-key")); - testCase.verifyWarningFree(@()extractOpenAIEmbeddings({'bla'}, ApiKey="this-is-not-a-real-key")); + testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla", APIKey="this-is-not-a-real-key")); + testCase.verifyWarningFree(@()extractOpenAIEmbeddings('bla', APIKey="this-is-not-a-real-key")); + testCase.verifyWarningFree(@()extractOpenAIEmbeddings({'bla'}, APIKey="this-is-not-a-real-key")); end function keyNotFound(testCase) @@ -36,29 +36,25 @@ function validCombinationOfModelAndDimension(testCase, ValidDimensionsModelCombi testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla", ... Dimensions=ValidDimensionsModelCombinations.Dimensions,... ModelName=ValidDimensionsModelCombinations.ModelName, ... - ApiKey="not-real")); + APIKey="not-real")); end function embedStringWithSuccessfulOpenAICall(testCase) testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla", ... - ApiKey=getenv("OPENAI_KEY"))); + APIKey=getenv("OPENAI_KEY"))); end function invalidCombinationOfModelAndDimension(testCase) testCase.verifyError(@()extractOpenAIEmbeddings("bla", ... Dimensions=10,... ModelName="text-embedding-ada-002", ... - ApiKey="not-real"), ... + APIKey="not-real"), ... "llms:invalidOptionForModel") end function useAllNVP(testCase) testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla", ModelName="text-embedding-ada-002", ... 
- ApiKey="this-is-not-a-real-key", TimeOut=10)); - end - - function verySmallTimeOutErrors(testCase) - testCase.verifyError(@()extractOpenAIEmbeddings("bla", TimeOut=0.0001, ApiKey="false-key"), "MATLAB:webservices:Timeout") + APIKey="this-is-not-a-real-key", TimeOut=10)); end function testInvalidInputs(testCase, InvalidInput) @@ -111,12 +107,12 @@ function testInvalidInputs(testCase, InvalidInput) ... "LargeDimensionValueForModelLarge",struct( ... "Input",{{"bla", "ModelName", "text-embedding-3-large", ... - "Dimensions", 3073, "ApiKey", "fake-key" }},... + "Dimensions", 3073, "APIKey", "fake-key" }},... "Error","llms:dimensionsMustBeSmallerThan"),... ... "LargeDimensionValueForModelSmall",struct( ... "Input",{{"bla", "ModelName", "text-embedding-3-small", ... - "Dimensions", 1537, "ApiKey", "fake-key" }},... + "Dimensions", 1537, "APIKey", "fake-key" }},... "Error","llms:dimensionsMustBeSmallerThan"),... ... "InvalidDimensionSize",struct( ... @@ -124,11 +120,11 @@ function testInvalidInputs(testCase, InvalidInput) "Error","MATLAB:validation:IncompatibleSize"),... ... "InvalidApiKeyType",struct( ... - "Input",{{"bla", "ApiKey" 123 }},... + "Input",{{"bla", "APIKey" 123 }},... "Error","MATLAB:validators:mustBeNonzeroLengthText"),... ... "InvalidApiKeySize",struct( ... - "Input",{{"bla", "ApiKey" ["abc" "abc"] }},... + "Input",{{"bla", "APIKey" ["abc" "abc"] }},... "Error","MATLAB:validators:mustBeTextScalar")); end diff --git a/tests/tmessageHistory.m b/tests/tmessageHistory.m new file mode 100644 index 0000000..da00d93 --- /dev/null +++ b/tests/tmessageHistory.m @@ -0,0 +1,359 @@ +classdef tmessageHistory < matlab.unittest.TestCase +% Tests for messageHistory + +% Copyright 2023-2024 The MathWorks, Inc. + + properties(TestParameter) + InvalidInputsUserPrompt = iGetInvalidInputsUserPrompt(); + InvalidInputsUserImagesPrompt = iGetInvalidInputsUserImagesPrompt(); + InvalidInputsFunctionPrompt = iGetInvalidFunctionPrompt(); + InvalidInputsSystemPrompt = iGetInvalidInputsSystemPrompt(); + InvalidInputsResponseMessage = iGetInvalidInputsResponseMessage(); + InvalidRemoveMessage = iGetInvalidRemoveMessage(); + InvalidFuncCallsCases = iGetInvalidFuncCallsCases() + ValidTextInput = {"This is okay"; 'this is ok'}; + end + + methods(Test) + function constructorStartsWithEmptyMessages(testCase) + msgs = messageHistory; + testCase.verifyTrue(isempty(msgs.Messages)); + end + + function differentInputTextAccepted(testCase, ValidTextInput) + msgs = messageHistory; + testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput)); + testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput)); + testCase.verifyWarningFree(@()addUserMessage(msgs, ValidTextInput)); + testCase.verifyWarningFree(@()addToolMessage(msgs, ValidTextInput, ValidTextInput, ValidTextInput)); + end + + function systemMessageIsAdded(testCase) + prompt = "Here is a system prompt"; + name = "example"; + msgs = messageHistory; + systemPrompt = struct("role", "system", "name", name, "content", prompt); + msgs = addSystemMessage(msgs, name, prompt); + testCase.verifyEqual(msgs.Messages{1}, systemPrompt); + end + + function userMessageIsAdded(testCase) + prompt = "Here is a user prompt"; + msgs = messageHistory; + userPrompt = struct("role", "user", "content", prompt); + msgs = addUserMessage(msgs, prompt); + testCase.verifyEqual(msgs.Messages{1}, userPrompt); + end + + function userImageMessageIsAddedWithLocalImg(testCase) + prompt = "Here is a user prompt"; + msgs = messageHistory; + img = 
"peppers.png"; + testCase.verifyWarningFree(@()addUserMessageWithImages(msgs, prompt, img)); + end + + function userImageMessageIsAddedWithRemoteImg(testCase) + prompt = "Here is a user prompt"; + msgs = messageHistory; + img = "https://www.mathworks.com/help/examples/matlab/win64/DisplayGrayscaleRGBIndexedOrBinaryImageExample_04.png"; + testCase.verifyWarningFree(@()addUserMessageWithImages(msgs, prompt, img)); + end + + function toolMessageIsAdded(testCase) + prompt = "20"; + name = "sin"; + id = "123"; + msgs = messageHistory; + systemPrompt = struct("tool_call_id", id, "role", "tool", "name", name, "content", prompt); + msgs = addToolMessage(msgs, id, name, prompt); + testCase.verifyEqual(msgs.Messages{1}, systemPrompt); + end + + function assistantMessageIsAdded(testCase) + prompt = "Here is an assistant prompt"; + msgs = messageHistory; + assistantPrompt = struct("role", "assistant", "content", prompt); + msgs = addResponseMessage(msgs, assistantPrompt); + testCase.verifyEqual(msgs.Messages{1}, assistantPrompt); + end + + function assistantToolCallMessageIsAdded(testCase) + msgs = messageHistory; + functionName = "functionName"; + args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}"; + funCall = struct("name", functionName, "arguments", args); + toolCall = struct("id", "123", "type", "function", "function", funCall); + toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []); + % tool_calls is an array of struct in API response + toolCallPrompt.tool_calls = toolCall; + msgs = addResponseMessage(msgs, toolCallPrompt); + % to include in msgs, tool_calls must be a cell + testCase.verifyEqual(fieldnames(msgs.Messages{1}), fieldnames(toolCallPrompt)); + testCase.verifyEqual(msgs.Messages{1}.tool_calls{1}, toolCallPrompt.tool_calls); + end + + function errorsAssistantWithWithoutToolCallId(testCase) + msgs = messageHistory; + functionName = "functionName"; + args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}"; + funCall = struct("name", functionName, "arguments", args); + toolCall = struct("type", "function", "function", funCall); + toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []); + % tool_calls is an array of struct in API response + toolCallPrompt.tool_calls = toolCall; + + testCase.verifyError(@()addResponseMessage(msgs, toolCallPrompt), "llms:mustBeAssistantWithIdAndFunction"); + end + + function errorsAssistantWithToolCallsWithoutNameOrArgs(testCase, InvalidFuncCallsCases) + msgs = messageHistory; + funCall = InvalidFuncCallsCases.FunCallStruct; + toolCall = struct("id", "123", "type", "function", "function", funCall); + toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []); + % tool_calls is an array of struct in API response + toolCallPrompt.tool_calls = toolCall; + + testCase.verifyError(@()addResponseMessage(msgs, toolCallPrompt), InvalidFuncCallsCases.Error); + end + + function errorsAssistantWithWithNonTextNameAndArguments(testCase) + msgs = messageHistory; + funCall = struct("name", 1, "arguments", 2); + toolCall = struct("id", "123", "type", "function", "function", funCall); + toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []); + % tool_calls is an array of struct in API response + toolCallPrompt.tool_calls = toolCall; + + testCase.verifyError(@()addResponseMessage(msgs, toolCallPrompt), "llms:assistantMustHaveTextNameAndArguments"); + end + + function assistantToolCallMessageWithoutArgsIsAdded(testCase) + msgs = messageHistory; + functionName = "functionName"; + funCall = 
struct("name", functionName, "arguments", "{}"); + toolCall = struct("id", "123", "type", "function", "function", funCall); + % tool_calls is an array of struct in API response + toolCallPrompt = struct("role", "assistant", "content", "","tool_calls", toolCall); + msgs = addResponseMessage(msgs, toolCallPrompt); + % to include in msgs, tool_calls must be a cell + testCase.verifyEqual(fieldnames(msgs.Messages{1}), fieldnames(toolCallPrompt)); + testCase.verifyEqual(msgs.Messages{1}.tool_calls{1}, toolCallPrompt.tool_calls); + end + + function assistantParallelToolCallMessageIsAdded(testCase) + msgs = messageHistory; + functionName = "functionName"; + args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}"; + funCall = struct("name", functionName, "arguments", args); + toolCall = struct("id", "123", "type", "function", "function", funCall); + % tool_calls is an array of struct in API response + toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", toolCall); + toolCallPrompt.tool_calls = [toolCall,toolCall,toolCall]; + msgs = addResponseMessage(msgs, toolCallPrompt); + testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt); + end + + function messageGetsRemoved(testCase) + msgs = messageHistory; + idx = 2; + + msgs = addSystemMessage(msgs, "name", "content"); + msgs = addUserMessage(msgs, "content"); + msgs = addToolMessage(msgs, "123", "name", "content"); + sizeMsgs = length(msgs.Messages); + % Message exists before removal + msgToBeRemoved = msgs.Messages{idx}; + testCase.verifyTrue(any(cellfun(@(c) isequal(c, msgToBeRemoved), msgs.Messages))); + + msgs = removeMessage(msgs, idx); + testCase.verifyFalse(any(cellfun(@(c) isequal(c, msgToBeRemoved), msgs.Messages))); + testCase.verifyEqual(length(msgs.Messages), sizeMsgs-1); + end + + function removalIdxCantBeLargerThanNumElements(testCase) + msgs = messageHistory; + + msgs = addSystemMessage(msgs, "name", "content"); + msgs = addUserMessage(msgs, "content"); + msgs = addToolMessage(msgs, "123", "name", "content"); + sizeMsgs = length(msgs.Messages); + + testCase.verifyError(@()removeMessage(msgs, sizeMsgs+1), "llms:mustBeValidIndex"); + end + + function invalidInputsSystemPrompt(testCase, InvalidInputsSystemPrompt) + msgs = messageHistory; + testCase.verifyError(@()addSystemMessage(msgs,InvalidInputsSystemPrompt.Input{:}), InvalidInputsSystemPrompt.Error); + end + + function invalidInputsUserPrompt(testCase, InvalidInputsUserPrompt) + msgs = messageHistory; + testCase.verifyError(@()addUserMessage(msgs,InvalidInputsUserPrompt.Input{:}), InvalidInputsUserPrompt.Error); + end + + function invalidInputsUserImagesPrompt(testCase, InvalidInputsUserImagesPrompt) + msgs = messageHistory; + testCase.verifyError(@()addUserMessageWithImages(msgs,InvalidInputsUserImagesPrompt.Input{:}), InvalidInputsUserImagesPrompt.Error); + end + + function invalidInputsFunctionPrompt(testCase, InvalidInputsFunctionPrompt) + msgs = messageHistory; + testCase.verifyError(@()addToolMessage(msgs,InvalidInputsFunctionPrompt.Input{:}), InvalidInputsFunctionPrompt.Error); + end + + function invalidInputsRemove(testCase, InvalidRemoveMessage) + msgs = messageHistory; + testCase.verifyError(@()removeMessage(msgs,InvalidRemoveMessage.Input{:}), InvalidRemoveMessage.Error); + end + + function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage) + msgs = messageHistory; + testCase.verifyError(@()addResponseMessage(msgs,InvalidInputsResponseMessage.Input{:}), InvalidInputsResponseMessage.Error); + end + end +end + +function 
invalidInputsSystemPrompt = iGetInvalidInputsSystemPrompt() +invalidInputsSystemPrompt = struct( ... + "NonStringInputName", ... + struct("Input", {{123, "content"}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "NonStringInputContent", ... + struct("Input", {{"name", 123}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "EmptyName", ... + struct("Input", {{"", "content"}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "EmptyContent", ... + struct("Input", {{"name", ""}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "NonScalarInputName", ... + struct("Input", {{["name1" "name2"], "content"}}, ... + "Error", "MATLAB:validators:mustBeTextScalar"),... + ... + "NonScalarInputContent", ... + struct("Input", {{"name", ["content1", "content2"]}}, ... + "Error", "MATLAB:validators:mustBeTextScalar")); +end + +function invalidInputsUserPrompt = iGetInvalidInputsUserPrompt() +invalidInputsUserPrompt = struct( ... + "NonStringInput", ... + struct("Input", {{123}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "NonScalarInput", ... + struct("Input", {{["prompt1" "prompt2"]}}, ... + "Error", "MATLAB:validators:mustBeTextScalar"), ... + ... + "EmptyInput", ... + struct("Input", {{""}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText")); +end + +function invalidInputsUserImagesPrompt = iGetInvalidInputsUserImagesPrompt() +invalidInputsUserImagesPrompt = struct( ... + "NonStringInput", ... + struct("Input", {{123, "peppers.png"}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "NonScalarInput", ... + struct("Input", {{["prompt1" "prompt2"], "peppers.png"}}, ... + "Error", "MATLAB:validators:mustBeTextScalar"), ... + ... + "EmptyInput", ... + struct("Input", {{"", "peppers.png"}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "NonTextImage", ... + struct("Input", {{"prompt", 123}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"),... + ... + "EmptyImageName", ... + struct("Input", {{"prompt", ""}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"),... + ... + "InvalidDetail", ... + struct("Input", {{"prompt", "peppers.png", "Detail", "invalid"}}, ... + "Error", "MATLAB:validators:mustBeMember")); +end + +function invalidFunctionPrompt = iGetInvalidFunctionPrompt() +invalidFunctionPrompt = struct( ... + "NonStringInputName", ... + struct("Input", {{"123", 123, "content"}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "NonStringInputContent", ... + struct("Input", {{"123", "name", 123}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "EmptyName", ... + struct("Input", {{"123", "", "content"}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "EmptyContent", ... + struct("Input", {{"123", "name", ""}}, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "NonScalarInputName", ... + struct("Input", {{"123", ["name1" "name2"], "content"}}, ... + "Error", "MATLAB:validators:mustBeTextScalar"),... + ... + "NonScalarInputContent", ... + struct("Input", {{"123","name", ["content1", "content2"]}}, ... + "Error", "MATLAB:validators:mustBeTextScalar")); +end + +function invalidRemoveMessage = iGetInvalidRemoveMessage() +invalidRemoveMessage = struct( ... + "NonInteger", ... + struct("Input", {{0.5}}, ... + "Error", "MATLAB:validators:mustBeInteger"), ... + ... + "NonPositive", ...
+ struct("Input", {{0}}, ... + "Error", "MATLAB:validators:mustBePositive"), ... + ... + "NonScalarInput", ... + struct("Input", {{[1 2]}}, ... + "Error", "MATLAB:validation:IncompatibleSize")); +end + +function invalidInputsResponseMessage = iGetInvalidInputsResponseMessage() +invalidInputsResponseMessage = struct( ... + "NonStructInput", ... + struct("Input", {{123}},... + "Error", "MATLAB:validation:UnableToConvert"),... + ... + "NonExistentRole", ... + struct("Input", {{struct("role", "123", "content", "123")}},... + "Error", "llms:mustBeAssistantCall"),... + ... + "NonExistentContent", ... + struct("Input", {{struct("role", "assistant")}},... + "Error", "llms:mustBeAssistantCall"),... + ... + "EmptyContent", ... + struct("Input", {{struct("role", "assistant", "content", "")}},... + "Error", "llms:mustBeAssistantWithContent"),... + ... + "NonScalarContent", ... + struct("Input", {{struct("role", "assistant", "content", ["a", "b"])}},... + "Error", "llms:mustBeAssistantWithContent")); +end + +function invalidFuncCallsCases = iGetInvalidFuncCallsCases() +invalidFuncCallsCases = struct( ... + "NoArguments", ... + struct("FunCallStruct", struct("name", "functionName"),... + "Error", "llms:mustBeAssistantWithNameAndArguments"),... + ... + "NoName", ... + struct("FunCallStruct", struct("arguments", "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}"), ... + "Error", "llms:mustBeAssistantWithNameAndArguments")); +end \ No newline at end of file diff --git a/tests/tollamaChat.m b/tests/tollamaChat.m new file mode 100644 index 0000000..bdbadff --- /dev/null +++ b/tests/tollamaChat.m @@ -0,0 +1,329 @@ +classdef tollamaChat < matlab.unittest.TestCase +% Tests for ollamaChat + +% Copyright 2024 The MathWorks, Inc. + + properties(TestParameter) + InvalidConstructorInput = iGetInvalidConstructorInput; + InvalidGenerateInput = iGetInvalidGenerateInput; + InvalidValuesSetters = iGetInvalidValuesSetters; + ValidValuesSetters = iGetValidValuesSetters; + StringInputs = struct('string',{"hi"},'char',{'hi'},'cellstr',{{'hi'}}); + end + + methods(Test) + function simpleConstruction(testCase) + bot = ollamaChat("mistral"); + testCase.verifyClass(bot,"ollamaChat"); + end + + function constructChatWithAllNVP(testCase) + temperature = 0; + topP = 1; + stop = ["[END]", "."]; + systemPrompt = "This is a system prompt"; + timeout = 3; + model = "mistral"; + chat = ollamaChat(model, systemPrompt, ... + Temperature=temperature, TopP=topP, StopSequences=stop,... + TimeOut=timeout); + testCase.verifyEqual(chat.Temperature, temperature); + testCase.verifyEqual(chat.TopP, topP); + testCase.verifyEqual(chat.StopSequences, stop); + end + + function doGenerate(testCase,StringInputs) + chat = ollamaChat("mistral"); + response = testCase.verifyWarningFree(@() generate(chat,StringInputs)); + testCase.verifyClass(response,'string'); + testCase.verifyGreaterThan(strlength(response),0); + end + + function extremeTopK(testCase) + % setting top-k to k=1 leaves no random choice, + % so we expect to get a fixed response. + chat = ollamaChat("mistral",TopK=1); + prompt = "Top-k sampling with k=1 returns a definite answer."; + response1 = generate(chat,prompt); + response2 = generate(chat,prompt); + testCase.verifyEqual(response1,response2); + end + + function extremeTfsZ(testCase) + %% This should work, and it does on some computers. On others, Ollama + %% receives the parameter, but either Ollama or llama.cpp fails to + %% honor it correctly. 
+ testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably"); + + % setting tfs_z to z=0 leaves no random choice, but degrades to + % greedy sampling, so we expect to get a fixed response. + chat = ollamaChat("mistral",TailFreeSamplingZ=0); + prompt = "Sampling with tfs_z=0 returns a definite answer."; + response1 = generate(chat,prompt); + response2 = generate(chat,prompt); + testCase.verifyEqual(response1,response2); + end + + function stopSequences(testCase) + chat = ollamaChat("mistral",TopK=1); + prompt = "Top-k sampling with k=1 returns a definite answer."; + response1 = generate(chat,prompt); + chat.StopSequences = "1"; + response2 = generate(chat,prompt); + + testCase.verifyEqual(response2, extractBefore(response1,"1")); + end + + function seedFixesResult(testCase) + %% This should work, and it does on some computers. On others, Ollama + %% receives the parameter, but either Ollama or llama.cpp fails to + %% honor it correctly. + testCase.assumeTrue(false,"disabled due to Ollama/llama.cpp not honoring parameter reliably"); + + chat = ollamaChat("mistral"); + response1 = generate(chat,"hi",Seed=1234); + response2 = generate(chat,"hi",Seed=1234); + testCase.verifyEqual(response1,response2); + end + + + function streamFunc(testCase) + function seen = sf(str) + persistent data; + if isempty(data) + data = strings(1, 0); + end + % Append streamed text to an empty string array of length 1 + data = [data, str]; + seen = data; + end + chat = ollamaChat("mistral", StreamFun=@sf); + + testCase.verifyWarningFree(@()generate(chat, "Hello world.")); + % Checking that persistent data, which is still stored in + % memory, is greater than 1. This would mean that the stream + % function has been called and streamed some text. + testCase.verifyGreaterThan(numel(sf("")), 1); + end + + function doReturnErrors(testCase) + testCase.assumeFalse( ... + any(startsWith(ollamaChat.models,"abcdefghijklmnop")), ... + "We want a model name that does not exist on this server"); + chat = ollamaChat("abcdefghijklmnop"); + testCase.verifyError(@() generate(chat,"hi!"), "llms:apiReturnedError"); + end + + + function invalidInputsConstructor(testCase, InvalidConstructorInput) + testCase.verifyError(@() ollamaChat("mistral", InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error); + end + + function invalidInputsGenerate(testCase, InvalidGenerateInput) + chat = ollamaChat("mistral"); + testCase.verifyError(@() generate(chat,InvalidGenerateInput.Input{:}), InvalidGenerateInput.Error); + end + + function invalidSetters(testCase, InvalidValuesSetters) + chat = ollamaChat("mistral"); + function assignValueToProperty(property, value) + chat.(property) = value; + end + + testCase.verifyError(@() assignValueToProperty(InvalidValuesSetters.Property,InvalidValuesSetters.Value), InvalidValuesSetters.Error); + end + + function validSetters(testCase, ValidValuesSetters) + chat = ollamaChat("mistral"); + function assignValueToProperty(property, value) + chat.(property) = value; + end + + testCase.verifyWarningFree(@() assignValueToProperty(ValidValuesSetters.Property,ValidValuesSetters.Value)); + end + + function queryModels(testCase) + % our test setup has at least mistral loaded + models = ollamaChat.models; + testCase.verifyClass(models,"string"); + testCase.verifyThat(models, ... + matlab.unittest.constraints.IsSupersetOf("mistral")); + end + end +end + +function invalidValuesSetters = iGetInvalidValuesSetters + +invalidValuesSetters = struct( ... 
+ "InvalidTemperatureType", struct( ... + "Property", "Temperature", ... + "Value", "2", ... + "Error", "MATLAB:invalidType"), ... + ... + "InvalidTemperatureSize", struct( ... + "Property", "Temperature", ... + "Value", [1 1 1], ... + "Error", "MATLAB:expectedScalar"), ... + ... + "TemperatureTooLarge", struct( ... + "Property", "Temperature", ... + "Value", 20, ... + "Error", "MATLAB:notLessEqual"), ... + ... + "TemperatureTooSmall", struct( ... + "Property", "Temperature", ... + "Value", -20, ... + "Error", "MATLAB:expectedNonnegative"), ... + ... + "InvalidTopPType", struct( ... + "Property", "TopP", ... + "Value", "2", ... + "Error", "MATLAB:invalidType"), ... + ... + "InvalidTopPSize", struct( ... + "Property", "TopP", ... + "Value", [1 1 1], ... + "Error", "MATLAB:expectedScalar"), ... + ... + "TopPTooLarge", struct( ... + "Property", "TopP", ... + "Value", 20, ... + "Error", "MATLAB:notLessEqual"), ... + ... + "TopPTooSmall", struct( ... + "Property", "TopP", ... + "Value", -20, ... + "Error", "MATLAB:expectedNonnegative"), ... + ... + "WrongTypeStopSequences", struct( ... + "Property", "StopSequences", ... + "Value", 123, ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "WrongSizeStopNonVector", struct( ... + "Property", "StopSequences", ... + "Value", repmat("stop", 4), ... + "Error", "MATLAB:validators:mustBeVector"), ... + ... + "EmptyStopSequences", struct( ... + "Property", "StopSequences", ... + "Value", "", ... + "Error", "MATLAB:validators:mustBeNonzeroLengthText")); +end + +function validSetters = iGetValidValuesSetters +validSetters = struct(... + "SmallTopNum", struct( ... + "Property", "TopK", ... + "Value", 2)); + % Currently disabled because it requires some code reorganization + % and we have higher priorities ... + % "ManyStopSequences", struct( ... + % "Property", "StopSequences", ... + % "Value", ["1" "2" "3" "4" "5"])); +end + +function invalidConstructorInput = iGetInvalidConstructorInput +invalidConstructorInput = struct( ... + "InvalidResponseFormatValue", struct( ... + "Input",{{"ResponseFormat", "foo" }},... + "Error", "MATLAB:validators:mustBeMember"), ... + ... + "InvalidResponseFormatSize", struct( ... + "Input",{{"ResponseFormat", ["text" "text"] }},... + "Error", "MATLAB:validation:IncompatibleSize"), ... + ... + "InvalidStreamFunType", struct( ... + "Input",{{"StreamFun", "2" }},... + "Error", "MATLAB:validators:mustBeA"), ... + ... + "InvalidStreamFunSize", struct( ... + "Input",{{"StreamFun", [1 1 1] }},... + "Error", "MATLAB:validation:IncompatibleSize"), ... + ... + "InvalidTimeOutType", struct( ... + "Input",{{"TimeOut", "2" }},... + "Error", "MATLAB:validators:mustBeReal"), ... + ... + "InvalidTimeOutSize", struct( ... + "Input",{{"TimeOut", [1 1 1] }},... + "Error", "MATLAB:validation:IncompatibleSize"), ... + ... + "WrongTypeSystemPrompt",struct( ... + "Input",{{ 123 }},... + "Error","MATLAB:validators:mustBeTextScalar"),... + ... + "WrongSizeSystemPrompt",struct( ... + "Input",{{ ["test"; "test"] }},... + "Error","MATLAB:validators:mustBeTextScalar"),... + ... + "InvalidTemperatureType",struct( ... + "Input",{{ "Temperature" "2" }},... + "Error","MATLAB:invalidType"),... + ... + "InvalidTemperatureSize",struct( ... + "Input",{{ "Temperature" [1 1 1] }},... + "Error","MATLAB:expectedScalar"),... + ... + "TemperatureTooLarge",struct( ... + "Input",{{ "Temperature" 20 }},... + "Error","MATLAB:notLessEqual"),... + ... + "TemperatureTooSmall",struct( ... + "Input",{{ "Temperature" -20 }},... 
+ "Error","MATLAB:expectedNonnegative"),... + ... + "InvalidTopPType",struct( ... + "Input",{{ "TopP" "2" }},... + "Error","MATLAB:invalidType"),... + ... + "InvalidTopPSize",struct( ... + "Input",{{ "TopP" [1 1 1] }},... + "Error","MATLAB:expectedScalar"),... + ... + "TopPTooLarge",struct( ... + "Input",{{ "TopP" 20 }},... + "Error","MATLAB:notLessEqual"),... + ... + "TopPTooSmall",struct( ... + "Input",{{ "TopP" -20 }},... + "Error","MATLAB:expectedNonnegative"),...I + ... + "WrongTypeStopSequences",struct( ... + "Input",{{ "StopSequences" 123}},... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... + ... + "WrongSizeStopNonVector",struct( ... + "Input",{{ "StopSequences" repmat("stop", 4) }},... + "Error","MATLAB:validators:mustBeVector"),... + ... + "EmptyStopSequences",struct( ... + "Input",{{ "StopSequences" ""}},... + "Error","MATLAB:validators:mustBeNonzeroLengthText")); +end + +function invalidGenerateInput = iGetInvalidGenerateInput +emptyMessages = messageHistory; +validMessages = addUserMessage(emptyMessages,"Who invented the telephone?"); + +invalidGenerateInput = struct( ... + "EmptyInput",struct( ... + "Input",{{ [] }},... + "Error","llms:mustBeMessagesOrTxt"),... + ... + "InvalidInputType",struct( ... + "Input",{{ 123 }},... + "Error","llms:mustBeMessagesOrTxt"),... + ... + "EmptyMessages",struct( ... + "Input",{{ emptyMessages }},... + "Error","llms:mustHaveMessages"),... + ... + "InvalidMaxNumTokensType",struct( ... + "Input",{{ validMessages "MaxNumTokens" "2" }},... + "Error","MATLAB:validators:mustBeNumericOrLogical"),... + ... + "InvalidMaxNumTokensValue",struct( ... + "Input",{{ validMessages "MaxNumTokens" 0 }},... + "Error","MATLAB:validators:mustBePositive")); +end diff --git a/tests/topenAIChat.m b/tests/topenAIChat.m index 22fa9bf..7112b50 100644 --- a/tests/topenAIChat.m +++ b/tests/topenAIChat.m @@ -3,46 +3,35 @@ % Copyright 2023-2024 The MathWorks, Inc. 
- methods (TestClassSetup) - function saveEnvVar(testCase) - % Ensures key is not in environment variable for tests - openAIEnvVar = "OPENAI_API_KEY"; - if isenv(openAIEnvVar) - key = getenv(openAIEnvVar); - unsetenv(openAIEnvVar); - testCase.addTeardown(@(x) setenv(openAIEnvVar, x), key); - end - end - end - properties(TestParameter) ValidConstructorInput = iGetValidConstructorInput(); InvalidConstructorInput = iGetInvalidConstructorInput(); - InvalidGenerateInput = iGetInvalidGenerateInput(); - InvalidValuesSetters = iGetInvalidValuesSetters(); + InvalidGenerateInput = iGetInvalidGenerateInput(); + InvalidValuesSetters = iGetInvalidValuesSetters(); + StringInputs = struct('string',{"hi"},'char',{'hi'},'cellstr',{{'hi'}}); end methods(Test) % Test methods - function generateAcceptsSingleStringAsInput(testCase) - chat = openAIChat(ApiKey="this-is-not-a-real-key"); - testCase.verifyWarningFree(@()generate(chat,"This is okay")); - chat = openAIChat(ApiKey='this-is-not-a-real-key'); - testCase.verifyWarningFree(@()generate(chat,"This is okay")); + function generateAcceptsSingleStringAsInput(testCase,StringInputs) + chat = openAIChat; + testCase.verifyWarningFree(@()generate(chat,StringInputs)); + end + + function generateMultipleResponses(testCase) + chat = openAIChat; + [~,~,response] = generate(chat,"What is a cat?",NumCompletions=3); + testCase.verifySize(response.Body.Data.choices,[3,1]); end function generateAcceptsMessagesAsInput(testCase) - chat = openAIChat(ApiKey="this-is-not-a-real-key"); - messages = openAIMessages; + chat = openAIChat; + messages = messageHistory; messages = addUserMessage(messages, "This should be okay."); testCase.verifyWarningFree(@()generate(chat,messages)); end - function keyNotFound(testCase) - testCase.verifyError(@()openAIChat, "llms:keyMustBeSpecified"); - end - function constructChatWithAllNVP(testCase) functions = openAIFunction("funName"); modelName = "gpt-3.5-turbo"; @@ -55,12 +44,12 @@ function constructChatWithAllNVP(testCase) systemPrompt = "This is a system prompt"; timeout = 3; chat = openAIChat(systemPrompt, Tools=functions, ModelName=modelName, ... - Temperature=temperature, TopProbabilityMass=topP, StopSequences=stop, ApiKey=apiKey,... + Temperature=temperature, TopP=topP, StopSequences=stop, APIKey=apiKey,... 
FrequencyPenalty=frequenceP, PresencePenalty=presenceP, TimeOut=timeout); testCase.verifyEqual(chat.ModelName, modelName); testCase.verifyEqual(chat.Temperature, temperature); - testCase.verifyEqual(chat.TopProbabilityMass, topP); + testCase.verifyEqual(chat.TopP, topP); testCase.verifyEqual(chat.StopSequences, stop); testCase.verifyEqual(chat.FrequencyPenalty, frequenceP); testCase.verifyEqual(chat.PresencePenalty, presenceP); @@ -81,44 +70,116 @@ function validConstructorCalls(testCase,ValidConstructorInput) end end - function verySmallTimeOutErrors(testCase) - chat = openAIChat(TimeOut=0.0001, ApiKey="false-key"); - - testCase.verifyError(@()generate(chat, "hi"), "MATLAB:webservices:Timeout") - end - function errorsWhenPassingToolChoiceWithEmptyTools(testCase) - chat = openAIChat(ApiKey="this-is-not-a-real-key"); + chat = openAIChat(APIKey="this-is-not-a-real-key"); testCase.verifyError(@()generate(chat,"input", ToolChoice="bla"), "llms:mustSetFunctionsForCall"); end function settingToolChoiceWithNone(testCase) functions = openAIFunction("funName"); - chat = openAIChat(ApiKey="this-is-not-a-real-key",Tools=functions); + chat = openAIChat(Tools=functions); testCase.verifyWarningFree(@()generate(chat,"This is okay","ToolChoice","none")); end - function settingSeedToInteger(testCase) - chat = openAIChat(ApiKey="this-is-not-a-real-key"); + function fixedSeedFixesResult(testCase) + chat = openAIChat; - testCase.verifyWarningFree(@()generate(chat,"This is okay", "Seed", 2)); + result1 = generate(chat,"This is okay", "Seed", 2); + result2 = generate(chat,"This is okay", "Seed", 2); + testCase.verifyEqual(result1,result2); end function invalidInputsConstructor(testCase, InvalidConstructorInput) testCase.verifyError(@()openAIChat(InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error); end + function generateWithToolsAndStreamFunc(testCase) + import matlab.unittest.constraints.HasField + + f = openAIFunction("writePaperDetails", "Function to write paper details to a table."); + f = addParameter(f, "name", type="string", description="Name of the paper."); + f = addParameter(f, "url", type="string", description="URL containing the paper."); + f = addParameter(f, "explanation", type="string", description="Explanation on why the paper is related to the given topic."); + + paperExtractor = openAIChat( ... + "You are an expert in extracting information from a paper.", ... + APIKey=getenv("OPENAI_KEY"), Tools=f, StreamFun=@(s) s); + + input = join([ + " http://arxiv.org/abs/2406.04344v1" + " 2024-06-06T17:59:56Z" + " 2024-06-06T17:59:56Z" + " Verbalized Machine Learning: Revisiting Machine Learning with Language" + " Models" + " Motivated by the large progress made by large language models (LLMs), we" + "introduce the framework of verbalized machine learning (VML). In contrast to" + "conventional machine learning models that are typically optimized over a" + "continuous parameter space, VML constrains the parameter space to be" + "human-interpretable natural language. Such a constraint leads to a new" + "perspective of function approximation, where an LLM with a text prompt can be" + "viewed as a function parameterized by the text prompt. Guided by this" + "perspective, we revisit classical machine learning problems, such as regression" + "and classification, and find that these problems can be solved by an" + "LLM-parameterized learner and optimizer. 
The major advantages of VML include" + "(1) easy encoding of inductive bias: prior knowledge about the problem and" + "hypothesis class can be encoded in natural language and fed into the" + "LLM-parameterized learner; (2) automatic model class selection: the optimizer" + "can automatically select a concrete model class based on data and verbalized" + "prior knowledge, and it can update the model class during training; and (3)" + "interpretable learner updates: the LLM-parameterized optimizer can provide" + "explanations for why each learner update is performed. We conduct several" + "studies to empirically evaluate the effectiveness of VML, and hope that VML can" + "serve as a stepping stone to stronger interpretability and trustworthiness in" + "ML." + "" + " " + " Tim Z. Xiao" + " " + " " + " Robert Bamler" + " " + " " + " Bernhard Schölkopf" + " " + " " + " Weiyang Liu" + " " + " Technical Report v1 (92 pages, 15 figures)" + " " + " " + " " + " " + " " + " " + ], newline); + + topic = "Large Language Models"; + + prompt = "Given the following paper:" + newline + string(input)+ newline +... + "Given the topic: "+ topic + newline + "Write the details to a table."; + [~, response] = generate(paperExtractor, prompt); + + testCase.assertThat(response, HasField("tool_calls")); + testCase.verifyEqual(response.tool_calls.type,'function'); + testCase.verifyEqual(response.tool_calls.function.name,'writePaperDetails'); + data = testCase.verifyWarningFree( ... + @() jsondecode(response.tool_calls.function.arguments)); + testCase.verifyThat(data,HasField("name")); + testCase.verifyThat(data,HasField("url")); + testCase.verifyThat(data,HasField("explanation")); + end + function invalidInputsGenerate(testCase, InvalidGenerateInput) f = openAIFunction("validfunction"); - chat = openAIChat(Tools=f, ApiKey="this-is-not-a-real-key"); + chat = openAIChat(Tools=f, APIKey="this-is-not-a-real-key"); testCase.verifyError(@()generate(chat,InvalidGenerateInput.Input{:}), InvalidGenerateInput.Error); end function invalidSetters(testCase, InvalidValuesSetters) - chat = openAIChat(ApiKey="this-is-not-a-real-key"); + chat = openAIChat(APIKey="this-is-not-a-real-key"); function assignValueToProperty(property, value) chat.(property) = value; end @@ -127,18 +188,20 @@ function assignValueToProperty(property, value) end function invalidGenerateInputforModel(testCase) - chat = openAIChat(ApiKey="this-is-not-a-real-key"); + chat = openAIChat(APIKey="this-is-not-a-real-key"); image_path = "peppers.png"; - emptyMessages = openAIMessages; + emptyMessages = messageHistory; inValidMessages = addUserMessageWithImages(emptyMessages,"What is in the image?",image_path); testCase.verifyError(@()generate(chat,inValidMessages), "llms:invalidContentTypeForModel") end - function noStopSequencesNoMaxNumTokens(testCase) - chat = openAIChat(ApiKey="this-is-not-a-real-key"); - - testCase.verifyWarningFree(@()generate(chat,"This is okay")); + function doReturnErrors(testCase) + chat = openAIChat; + % This input is considerably longer than accepted as input for + % GPT-3.5 (16385 tokens) + wayTooLong = string(repmat('a ',1,20000)); + testCase.verifyError(@() generate(chat,wayTooLong), "llms:apiReturnedError"); end function createOpenAIChatWithStreamFunc(testCase) @@ -151,7 +214,7 @@ function createOpenAIChatWithStreamFunc(testCase) data = [data, str]; seen = data; end - chat = openAIChat(ApiKey=getenv("OPENAI_KEY"), StreamFun=@sf); + chat = openAIChat(APIKey=getenv("OPENAI_KEY"), StreamFun=@sf); testCase.verifyWarningFree(@()generate(chat, 
"Hello world.")); % Checking that persistent data, which is still stored in @@ -162,7 +225,7 @@ function createOpenAIChatWithStreamFunc(testCase) function warningJSONResponseFormatGPT35(testCase) chat = @() openAIChat("You are a useful assistant", ... - ApiKey="this-is-not-a-real-key", ... + APIKey="this-is-not-a-real-key", ... ResponseFormat="json", ... ModelName="gpt-3.5-turbo"); @@ -171,18 +234,27 @@ function warningJSONResponseFormatGPT35(testCase) function createOpenAIChatWithOpenAIKey(testCase) chat = openAIChat("You are a useful assistant", ... - ApiKey=getenv("OPENAI_KEY")); + APIKey=getenv("OPENAI_KEY")); testCase.verifyWarningFree(@()generate(chat, "Hello world.")); end function createOpenAIChatWithOpenAIKeyLatestModel(testCase) chat = openAIChat("You are a useful assistant", ... - ApiKey=getenv("OPENAI_KEY"), ModelName="gpt-4o"); + APIKey=getenv("OPENAI_KEY"), ModelName="gpt-4o"); testCase.verifyWarningFree(@()generate(chat, "Hello world.")); end + function keyNotFound(testCase) + % to verify the error, we need to unset the environment variable + % OPENAI_API_KEY, if given. Use a fixture to restore the + % value on leaving the test point: + import matlab.unittest.fixtures.EnvironmentVariableFixture + testCase.applyFixture(EnvironmentVariableFixture("OPENAI_API_KEY","dummy")); + unsetenv("OPENAI_API_KEY"); + testCase.verifyError(@()openAIChat, "llms:keyMustBeSpecified"); + end end end @@ -209,23 +281,23 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) "Value", -20, ... "Error", "MATLAB:expectedNonnegative"), ... ... - "InvalidTopProbabilityMassType", struct( ... - "Property", "TopProbabilityMass", ... + "InvalidTopPType", struct( ... + "Property", "TopP", ... "Value", "2", ... "Error", "MATLAB:invalidType"), ... ... - "InvalidTopProbabilityMassSize", struct( ... - "Property", "TopProbabilityMass", ... + "InvalidTopPSize", struct( ... + "Property", "TopP", ... "Value", [1 1 1], ... "Error", "MATLAB:expectedScalar"), ... ... - "TopProbabilityMassTooLarge", struct( ... - "Property", "TopProbabilityMass", ... + "TopPTooLarge", struct( ... + "Property", "TopP", ... "Value", 20, ... "Error", "MATLAB:notLessEqual"), ... ... - "TopProbabilityMassTooSmall", struct( ... - "Property", "TopProbabilityMass", ... + "TopPTooSmall", struct( ... + "Property", "TopP", ... "Value", -20, ... "Error", "MATLAB:expectedNonnegative"), ... ... @@ -296,11 +368,11 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) validFunction = openAIFunction("funName"); validConstructorInput = struct( ... "JustKey", struct( ... - "Input",{{"ApiKey","this-is-not-a-real-key"}}, ... + "Input",{{"APIKey","this-is-not-a-real-key"}}, ... "ExpectedWarning", '', ... "VerifyProperties", struct( ... "Temperature", {1}, ... - "TopProbabilityMass", {1}, ... + "TopP", {1}, ... "StopSequences", {{}}, ... "PresencePenalty", {0}, ... "FrequencyPenalty", {0}, ... @@ -312,11 +384,11 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) ) ... ), ... "SystemPrompt", struct( ... - "Input",{{"system prompt","ApiKey","this-is-not-a-real-key"}}, ... + "Input",{{"system prompt","APIKey","this-is-not-a-real-key"}}, ... "ExpectedWarning", '', ... "VerifyProperties", struct( ... "Temperature", {1}, ... - "TopProbabilityMass", {1}, ... + "TopP", {1}, ... "StopSequences", {{}}, ... "PresencePenalty", {0}, ... "FrequencyPenalty", {0}, ... @@ -328,11 +400,11 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) ) ... ), ... "Temperature", struct( ... 
- "Input",{{"ApiKey","this-is-not-a-real-key","Temperature",2}}, ... + "Input",{{"APIKey","this-is-not-a-real-key","Temperature",2}}, ... "ExpectedWarning", '', ... "VerifyProperties", struct( ... "Temperature", {2}, ... - "TopProbabilityMass", {1}, ... + "TopP", {1}, ... "StopSequences", {{}}, ... "PresencePenalty", {0}, ... "FrequencyPenalty", {0}, ... @@ -343,12 +415,12 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) "ResponseFormat", {"text"} ... ) ... ), ... - "TopProbabilityMass", struct( ... - "Input",{{"ApiKey","this-is-not-a-real-key","TopProbabilityMass",0.2}}, ... + "TopP", struct( ... + "Input",{{"APIKey","this-is-not-a-real-key","TopP",0.2}}, ... "ExpectedWarning", '', ... "VerifyProperties", struct( ... "Temperature", {1}, ... - "TopProbabilityMass", {0.2}, ... + "TopP", {0.2}, ... "StopSequences", {{}}, ... "PresencePenalty", {0}, ... "FrequencyPenalty", {0}, ... @@ -360,11 +432,11 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) ) ... ), ... "StopSequences", struct( ... - "Input",{{"ApiKey","this-is-not-a-real-key","StopSequences",["foo","bar"]}}, ... + "Input",{{"APIKey","this-is-not-a-real-key","StopSequences",["foo","bar"]}}, ... "ExpectedWarning", '', ... "VerifyProperties", struct( ... "Temperature", {1}, ... - "TopProbabilityMass", {1}, ... + "TopP", {1}, ... "StopSequences", {["foo","bar"]}, ... "PresencePenalty", {0}, ... "FrequencyPenalty", {0}, ... @@ -376,11 +448,11 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) ) ... ), ... "PresencePenalty", struct( ... - "Input",{{"ApiKey","this-is-not-a-real-key","PresencePenalty",0.1}}, ... + "Input",{{"APIKey","this-is-not-a-real-key","PresencePenalty",0.1}}, ... "ExpectedWarning", '', ... "VerifyProperties", struct( ... "Temperature", {1}, ... - "TopProbabilityMass", {1}, ... + "TopP", {1}, ... "StopSequences", {{}}, ... "PresencePenalty", {0.1}, ... "FrequencyPenalty", {0}, ... @@ -392,11 +464,11 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) ) ... ), ... "FrequencyPenalty", struct( ... - "Input",{{"ApiKey","this-is-not-a-real-key","FrequencyPenalty",0.1}}, ... + "Input",{{"APIKey","this-is-not-a-real-key","FrequencyPenalty",0.1}}, ... "ExpectedWarning", '', ... "VerifyProperties", struct( ... "Temperature", {1}, ... - "TopProbabilityMass", {1}, ... + "TopP", {1}, ... "StopSequences", {{}}, ... "PresencePenalty", {0}, ... "FrequencyPenalty", {0.1}, ... @@ -408,11 +480,11 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) ) ... ), ... "TimeOut", struct( ... - "Input",{{"ApiKey","this-is-not-a-real-key","TimeOut",0.1}}, ... + "Input",{{"APIKey","this-is-not-a-real-key","TimeOut",0.1}}, ... "ExpectedWarning", '', ... "VerifyProperties", struct( ... "Temperature", {1}, ... - "TopProbabilityMass", {1}, ... + "TopP", {1}, ... "StopSequences", {{}}, ... "PresencePenalty", {0}, ... "FrequencyPenalty", {0}, ... @@ -424,11 +496,11 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) ) ... ), ... "ResponseFormat", struct( ... - "Input",{{"ApiKey","this-is-not-a-real-key","ResponseFormat","json"}}, ... + "Input",{{"APIKey","this-is-not-a-real-key","ResponseFormat","json"}}, ... "ExpectedWarning", "llms:warningJsonInstruction", ... "VerifyProperties", struct( ... "Temperature", {1}, ... - "TopProbabilityMass", {1}, ... + "TopP", {1}, ... "StopSequences", {{}}, ... "PresencePenalty", {0}, ... "FrequencyPenalty", {0}, ... @@ -454,7 +526,7 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) "Error", "MATLAB:validation:IncompatibleSize"), ... ... 
"InvalidResponseFormatModelCombination", struct( ... - "Input", {{"ApiKey", "this-is-not-a-real-key", "ModelName", "gpt-4", "ResponseFormat", "json"}}, ... + "Input", {{"APIKey", "this-is-not-a-real-key", "ModelName", "gpt-4", "ResponseFormat", "json"}}, ... "Error", "llms:invalidOptionAndValueForModel"), ... ... "InvalidStreamFunType", struct( ... @@ -517,20 +589,20 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) "Input",{{ "Temperature" -20 }},... "Error","MATLAB:expectedNonnegative"),... ... - "InvalidTopProbabilityMassType",struct( ... - "Input",{{ "TopProbabilityMass" "2" }},... + "InvalidTopPType",struct( ... + "Input",{{ "TopP" "2" }},... "Error","MATLAB:invalidType"),... ... - "InvalidTopProbabilityMassSize",struct( ... - "Input",{{ "TopProbabilityMass" [1 1 1] }},... + "InvalidTopPSize",struct( ... + "Input",{{ "TopP" [1 1 1] }},... "Error","MATLAB:expectedScalar"),... ... - "TopProbabilityMassTooLarge",struct( ... - "Input",{{ "TopProbabilityMass" 20 }},... + "TopPTooLarge",struct( ... + "Input",{{ "TopP" 20 }},... "Error","MATLAB:notLessEqual"),... ... - "TopProbabilityMassTooSmall",struct( ... - "Input",{{ "TopProbabilityMass" -20 }},... + "TopPTooSmall",struct( ... + "Input",{{ "TopP" -20 }},... "Error","MATLAB:expectedNonnegative"),... ... "WrongTypeStopSequences",struct( ... @@ -582,22 +654,22 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase) "Error","MATLAB:notGreaterEqual"),... ... "InvalidApiKeyType",struct( ... - "Input",{{ "ApiKey" 123 }},... + "Input",{{ "APIKey" 123 }},... "Error","MATLAB:validators:mustBeNonzeroLengthText"),... ... "InvalidApiKeySize",struct( ... - "Input",{{ "ApiKey" ["abc" "abc"] }},... + "Input",{{ "APIKey" ["abc" "abc"] }},... "Error","MATLAB:validators:mustBeTextScalar")); end function invalidGenerateInput = iGetInvalidGenerateInput() -emptyMessages = openAIMessages; +emptyMessages = messageHistory; validMessages = addUserMessage(emptyMessages,"Who invented the telephone?"); invalidGenerateInput = struct( ... "EmptyInput",struct( ... "Input",{{ [] }},... - "Error","MATLAB:validation:IncompatibleSize"),... + "Error","llms:mustBeMessagesOrTxt"),... ... "InvalidInputType",struct( ... "Input",{{ 123 }},... 
diff --git a/tests/topenAIImages.m b/tests/topenAIImages.m index eb7fdb5..13f9466 100644 --- a/tests/topenAIImages.m +++ b/tests/topenAIImages.m @@ -25,9 +25,9 @@ function saveEnvVar(testCase) methods(Test) % Test methods function generateAcceptsSingleStringAsInput(testCase) - mdl = openAIImages(ApiKey="this-is-not-a-real-key"); + mdl = openAIImages(APIKey="this-is-not-a-real-key"); testCase.verifyWarningFree(@()generate(mdl,"This is okay")); - mdl = openAIImages(ApiKey='this-is-not-a-real-key'); + mdl = openAIImages(APIKey='this-is-not-a-real-key'); testCase.verifyWarningFree(@()generate(mdl,'This is okay')); end @@ -36,16 +36,16 @@ function keyNotFound(testCase) end function promptSizeLimit(testCase) - mdl = openAIImages(ApiKey="this-is-not-a-real-key", Model="dall-e-2"); + mdl = openAIImages(APIKey="this-is-not-a-real-key", Model="dall-e-2"); testCase.verifyError(@()generate(mdl, repmat('c', 1, 1001)), "llms:promptLimitCharacter") testCase.verifyError(@()edit(mdl, which("peppers.png"), repmat('c', 1, 1001)), "llms:promptLimitCharacter") - mdl = openAIImages(ApiKey="this-is-not-a-real-key", Model="dall-e-3"); + mdl = openAIImages(APIKey="this-is-not-a-real-key", Model="dall-e-3"); testCase.verifyError(@()generate(mdl, repmat('c', 1, 4001)), "llms:promptLimitCharacter") end function invalidOptionsGenerate(testCase) - mdl = openAIImages(ApiKey="this-is-not-a-real-key", Model="dall-e-2"); + mdl = openAIImages(APIKey="this-is-not-a-real-key", Model="dall-e-2"); testCase.verifyError(@()generate(mdl, "cat", Quality="hd"), "llms:invalidOptionForModel") testCase.verifyError(@()generate(mdl, "cat", Style="natural"), "llms:invalidOptionForModel") testCase.verifyError(@()generate(mdl, "cat", Size="1024x1792"), "MATLAB:validators:mustBeMember") @@ -56,19 +56,19 @@ function invalidOptionsGenerate(testCase) function invalidModelEdit(testCase) validImage = string(which("peppers.png")); - mdl = openAIImages(ApiKey="this-is-not-a-real-key", Model="dall-e-3"); + mdl = openAIImages(APIKey="this-is-not-a-real-key", Model="dall-e-3"); testCase.verifyError(@()edit(mdl, validImage, "cat"), "llms:functionNotAvailableForModel") end function invalidModelVariation(testCase) validImage = string(which("peppers.png")); - mdl = openAIImages(ApiKey="this-is-not-a-real-key", Model="dall-e-3"); + mdl = openAIImages(APIKey="this-is-not-a-real-key", Model="dall-e-3"); testCase.verifyError(@()createVariation(mdl, validImage), ... "llms:functionNotAvailableForModel") end function generateWithAllNVP(testCase) - mdl = openAIImages(ApiKey="this-is-not-a-real-key", Model="dall-e-3"); + mdl = openAIImages(APIKey="this-is-not-a-real-key", Model="dall-e-3"); testCase.verifyWarningFree(@()generate(mdl, ... "prompt", ... Quality="hd", ... @@ -79,7 +79,7 @@ function generateWithAllNVP(testCase) function editWithAllNVP(testCase) validImage = string(which("peppers.png")); - mdl = openAIImages(ApiKey="this-is-not-a-real-key", Model="dall-e-2"); + mdl = openAIImages(APIKey="this-is-not-a-real-key", Model="dall-e-2"); testCase.verifyWarningFree(@()edit(mdl, ... validImage,... "prompt", ... @@ -90,7 +90,7 @@ function editWithAllNVP(testCase) function variationWithAllNVP(testCase) validImage = string(which("peppers.png")); - mdl = openAIImages(ApiKey="this-is-not-a-real-key", Model="dall-e-2"); + mdl = openAIImages(APIKey="this-is-not-a-real-key", Model="dall-e-2"); testCase.verifyWarningFree(@()createVariation(mdl, ... validImage,... Size="512x512",... 
@@ -101,12 +101,12 @@ function constructModelWithAllNVP(testCase) modelName = "dall-e-2"; apiKey = "this-is-not-a-real-key"; timeout = 3; - mdl = openAIImages(ModelName=modelName, ApiKey=apiKey, TimeOut=timeout); + mdl = openAIImages(ModelName=modelName, APIKey=apiKey, TimeOut=timeout); testCase.verifyEqual(mdl.ModelName, modelName); end function fakePNGImage(testCase) - mdl = openAIImages(ApiKey="this-is-not-a-real-key"); + mdl = openAIImages(APIKey="this-is-not-a-real-key"); fakePng = fullfile("test_files", "solar.png"); testCase.verifyError(@()edit(mdl,fakePng, "bla"), "llms:pngExpected"); end @@ -116,22 +116,22 @@ function invalidInputsConstructor(testCase, InvalidConstructorInput) end function invalidInputsGenerate(testCase, InvalidGenerateInput) - mdl = openAIImages(ApiKey="this-is-not-a-real-key"); + mdl = openAIImages(APIKey="this-is-not-a-real-key"); testCase.verifyError(@()generate(mdl,InvalidGenerateInput.Input{:}), InvalidGenerateInput.Error); end function invalidInputsEdit(testCase, InvalidEditInput) - mdl = openAIImages(ApiKey="this-is-not-a-real-key"); + mdl = openAIImages(APIKey="this-is-not-a-real-key"); testCase.verifyError(@()edit(mdl,InvalidEditInput.Input{:}), InvalidEditInput.Error); end function invalidInputsVariation(testCase, InvalidVariationInput) - mdl = openAIImages(ApiKey="this-is-not-a-real-key"); + mdl = openAIImages(APIKey="this-is-not-a-real-key"); testCase.verifyError(@()createVariation(mdl,InvalidVariationInput.Input{:}), InvalidVariationInput.Error); end function testThatImageIsReturned(testCase) - mdl = openAIImages(ApiKey=getenv("OPENAI_KEY")); + mdl = openAIImages(APIKey=getenv("OPENAI_KEY")); [images, response] = generate(mdl, ... "Create a 3D avatar of a whimsical sushi on the beach. " + ... @@ -163,11 +163,11 @@ function testThatImageIsReturned(testCase) "Error","MATLAB:validators:mustBeMember"),... ... "InvalidApiKeyType",struct( ... - "Input",{{ "ApiKey" 123 }},... + "Input",{{ "APIKey" 123 }},... "Error","MATLAB:validators:mustBeNonzeroLengthText"),... ... "InvalidApiKeySize",struct( ... - "Input",{{ "ApiKey" ["abc" "abc"] }},... + "Input",{{ "APIKey" ["abc" "abc"] }},... "Error","MATLAB:validators:mustBeTextScalar")); end diff --git a/tests/topenAIMessages.m b/tests/topenAIMessages.m index 2aea3e4..57ac83d 100644 --- a/tests/topenAIMessages.m +++ b/tests/topenAIMessages.m @@ -1,360 +1,11 @@ classdef topenAIMessages < matlab.unittest.TestCase -% Tests for openAIMessages +% Tests for openAIMessages backward compatibility function % Copyright 2023-2024 The MathWorks, Inc. 
- properties(TestParameter) - InvalidInputsUserPrompt = iGetInvalidInputsUserPrompt(); - InvalidInputsUserImagesPrompt = iGetInvalidInputsUserImagesPrompt(); - InvalidInputsFunctionPrompt = iGetInvalidFunctionPrompt(); - InvalidInputsSystemPrompt = iGetInvalidInputsSystemPrompt(); - InvalidInputsResponseMessage = iGetInvalidInputsResponseMessage(); - InvalidRemoveMessage = iGetInvalidRemoveMessage(); - InvalidFuncCallsCases = iGetInvalidFuncCallsCases() - ValidTextInput = {"This is okay"; 'this is ok'}; +methods(Test) + function returnsMessageHistory(testCase) + testCase.verifyClass(openAIMessages,"messageHistory"); end - - methods(Test) - function constructorStartsWithEmptyMessages(testCase) - msgs = openAIMessages; - testCase.verifyTrue(isempty(msgs.Messages)); - end - - function differentInputTextAccepted(testCase, ValidTextInput) - msgs = openAIMessages; - testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput)); - testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput)); - testCase.verifyWarningFree(@()addUserMessage(msgs, ValidTextInput)); - testCase.verifyWarningFree(@()addToolMessage(msgs, ValidTextInput, ValidTextInput, ValidTextInput)); - end - - - function systemMessageIsAdded(testCase) - prompt = "Here is a system prompt"; - name = "example"; - msgs = openAIMessages; - systemPrompt = struct("role", "system", "name", name, "content", prompt); - msgs = addSystemMessage(msgs, name, prompt); - testCase.verifyEqual(msgs.Messages{1}, systemPrompt); - end - - function userMessageIsAdded(testCase) - prompt = "Here is a user prompt"; - msgs = openAIMessages; - userPrompt = struct("role", "user", "content", prompt); - msgs = addUserMessage(msgs, prompt); - testCase.verifyEqual(msgs.Messages{1}, userPrompt); - end - - function userImageMessageIsAddedWithLocalImg(testCase) - prompt = "Here is a user prompt"; - msgs = openAIMessages; - img = "peppers.png"; - testCase.verifyWarningFree(@()addUserMessageWithImages(msgs, prompt, img)); - end - - function userImageMessageIsAddedWithRemoteImg(testCase) - prompt = "Here is a user prompt"; - msgs = openAIMessages; - img = "https://www.mathworks.com/help/examples/matlab/win64/DisplayGrayscaleRGBIndexedOrBinaryImageExample_04.png"; - testCase.verifyWarningFree(@()addUserMessageWithImages(msgs, prompt, img)); - end - - function toolMessageIsAdded(testCase) - prompt = "20"; - name = "sin"; - id = "123"; - msgs = openAIMessages; - systemPrompt = struct("tool_call_id", id, "role", "tool", "name", name, "content", prompt); - msgs = addToolMessage(msgs, id, name, prompt); - testCase.verifyEqual(msgs.Messages{1}, systemPrompt); - end - - function assistantMessageIsAdded(testCase) - prompt = "Here is an assistant prompt"; - msgs = openAIMessages; - assistantPrompt = struct("role", "assistant", "content", prompt); - msgs = addResponseMessage(msgs, assistantPrompt); - testCase.verifyEqual(msgs.Messages{1}, assistantPrompt); - end - - function assistantToolCallMessageIsAdded(testCase) - msgs = openAIMessages; - functionName = "functionName"; - args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}"; - funCall = struct("name", functionName, "arguments", args); - toolCall = struct("id", "123", "type", "function", "function", funCall); - toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []); - % tool_calls is an array of struct in API response - toolCallPrompt.tool_calls = toolCall; - msgs = addResponseMessage(msgs, toolCallPrompt); - % to include in msgs, tool_calls must be a cell - 
testCase.verifyEqual(fieldnames(msgs.Messages{1}), fieldnames(toolCallPrompt)); - testCase.verifyEqual(msgs.Messages{1}.tool_calls{1}, toolCallPrompt.tool_calls); - end - - function errorsAssistantWithWithoutToolCallId(testCase) - msgs = openAIMessages; - functionName = "functionName"; - args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}"; - funCall = struct("name", functionName, "arguments", args); - toolCall = struct("type", "function", "function", funCall); - toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []); - % tool_calls is an array of struct in API response - toolCallPrompt.tool_calls = toolCall; - - testCase.verifyError(@()addResponseMessage(msgs, toolCallPrompt), "llms:mustBeAssistantWithIdAndFunction"); - end - - function errorsAssistantWithToolCallsWithoutNameOrArgs(testCase, InvalidFuncCallsCases) - msgs = openAIMessages; - funCall = InvalidFuncCallsCases.FunCallStruct; - toolCall = struct("id", "123", "type", "function", "function", funCall); - toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []); - % tool_calls is an array of struct in API response - toolCallPrompt.tool_calls = toolCall; - - testCase.verifyError(@()addResponseMessage(msgs, toolCallPrompt), InvalidFuncCallsCases.Error); - end - - function errorsAssistantWithWithNonTextNameAndArguments(testCase) - msgs = openAIMessages; - funCall = struct("name", 1, "arguments", 2); - toolCall = struct("id", "123", "type", "function", "function", funCall); - toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []); - % tool_calls is an array of struct in API response - toolCallPrompt.tool_calls = toolCall; - - testCase.verifyError(@()addResponseMessage(msgs, toolCallPrompt), "llms:assistantMustHaveTextNameAndArguments"); - end - - function assistantToolCallMessageWithoutArgsIsAdded(testCase) - msgs = openAIMessages; - functionName = "functionName"; - funCall = struct("name", functionName, "arguments", "{}"); - toolCall = struct("id", "123", "type", "function", "function", funCall); - % tool_calls is an array of struct in API response - toolCallPrompt = struct("role", "assistant", "content", "","tool_calls", toolCall); - msgs = addResponseMessage(msgs, toolCallPrompt); - % to include in msgs, tool_calls must be a cell - testCase.verifyEqual(fieldnames(msgs.Messages{1}), fieldnames(toolCallPrompt)); - testCase.verifyEqual(msgs.Messages{1}.tool_calls{1}, toolCallPrompt.tool_calls); - end - - function assistantParallelToolCallMessageIsAdded(testCase) - msgs = openAIMessages; - functionName = "functionName"; - args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}"; - funCall = struct("name", functionName, "arguments", args); - toolCall = struct("id", "123", "type", "function", "function", funCall); - % tool_calls is an array of struct in API response - toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", toolCall); - toolCallPrompt.tool_calls = [toolCall,toolCall,toolCall]; - msgs = addResponseMessage(msgs, toolCallPrompt); - testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt); - end - - function messageGetsRemoved(testCase) - msgs = openAIMessages; - idx = 2; - - msgs = addSystemMessage(msgs, "name", "content"); - msgs = addUserMessage(msgs, "content"); - msgs = addToolMessage(msgs, "123", "name", "content"); - sizeMsgs = length(msgs.Messages); - % Message exists before removal - msgToBeRemoved = msgs.Messages{idx}; - testCase.verifyTrue(any(cellfun(@(c) isequal(c, msgToBeRemoved), msgs.Messages))); - - msgs = removeMessage(msgs, idx); 
- testCase.verifyFalse(any(cellfun(@(c) isequal(c, msgToBeRemoved), msgs.Messages))); - testCase.verifyEqual(length(msgs.Messages), sizeMsgs-1); - end - - function removalIdxCantBeLargerThanNumElements(testCase) - msgs = openAIMessages; - - msgs = addSystemMessage(msgs, "name", "content"); - msgs = addUserMessage(msgs, "content"); - msgs = addToolMessage(msgs, "123", "name", "content"); - sizeMsgs = length(msgs.Messages); - - testCase.verifyError(@()removeMessage(msgs, sizeMsgs+1), "llms:mustBeValidIndex"); - end - - function invalidInputsSystemPrompt(testCase, InvalidInputsSystemPrompt) - msgs = openAIMessages; - testCase.verifyError(@()addSystemMessage(msgs,InvalidInputsSystemPrompt.Input{:}), InvalidInputsSystemPrompt.Error); - end - - function invalidInputsUserPrompt(testCase, InvalidInputsUserPrompt) - msgs = openAIMessages; - testCase.verifyError(@()addUserMessage(msgs,InvalidInputsUserPrompt.Input{:}), InvalidInputsUserPrompt.Error); - end - - function invalidInputsUserImagesPrompt(testCase, InvalidInputsUserImagesPrompt) - msgs = openAIMessages; - testCase.verifyError(@()addUserMessageWithImages(msgs,InvalidInputsUserImagesPrompt.Input{:}), InvalidInputsUserImagesPrompt.Error); - end - - function invalidInputsFunctionPrompt(testCase, InvalidInputsFunctionPrompt) - msgs = openAIMessages; - testCase.verifyError(@()addToolMessage(msgs,InvalidInputsFunctionPrompt.Input{:}), InvalidInputsFunctionPrompt.Error); - end - - function invalidInputsRemove(testCase, InvalidRemoveMessage) - msgs = openAIMessages; - testCase.verifyError(@()removeMessage(msgs,InvalidRemoveMessage.Input{:}), InvalidRemoveMessage.Error); - end - - function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage) - msgs = openAIMessages; - testCase.verifyError(@()addResponseMessage(msgs,InvalidInputsResponseMessage.Input{:}), InvalidInputsResponseMessage.Error); - end - end end - -function invalidInputsSystemPrompt = iGetInvalidInputsSystemPrompt() -invalidInputsSystemPrompt = struct( ... - "NonStringInputName", ... - struct("Input", {{123, "content"}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... - ... - "NonStringInputContent", ... - struct("Input", {{"name", 123}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... - ... - "EmptytName", ... - struct("Input", {{"", "content"}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... - ... - "EmptytContent", ... - struct("Input", {{"name", ""}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... - ... - "NonScalarInputName", ... - struct("Input", {{["name1" "name2"], "content"}}, ... - "Error", "MATLAB:validators:mustBeTextScalar"),... - ... - "NonScalarInputContent", ... - struct("Input", {{"name", ["content1", "content2"]}}, ... - "Error", "MATLAB:validators:mustBeTextScalar")); -end - -function invalidInputsUserPrompt = iGetInvalidInputsUserPrompt() -invalidInputsUserPrompt = struct( ... - "NonStringInput", ... - struct("Input", {{123}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... - ... - "NonScalarInput", ... - struct("Input", {{["prompt1" "prompt2"]}}, ... - "Error", "MATLAB:validators:mustBeTextScalar"), ... - ... - "EmptyInput", ... - struct("Input", {{""}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText")); end - -function invalidInputsUserImagesPrompt = iGetInvalidInputsUserImagesPrompt() -invalidInputsUserImagesPrompt = struct( ... - "NonStringInput", ... - struct("Input", {{123, "peppers.png"}}, ... 
- "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... - ... - "NonScalarInput", ... - struct("Input", {{["prompt1" "prompt2"], "peppers.png"}}, ... - "Error", "MATLAB:validators:mustBeTextScalar"), ... - ... - "EmptyInput", ... - struct("Input", {{"", "peppers.png"}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... - ... - "NonTextImage", ... - struct("Input", {{"prompt", 123}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText"),... - ... - "EmptyImageName", ... - struct("Input", {{"prompt", 123}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText"),... - ... - "InvalidDetail", ... - struct("Input", {{"prompt", "peppers.png", "Detail", "invalid"}}, ... - "Error", "MATLAB:validators:mustBeMember")); -end - -function invalidFunctionPrompt = iGetInvalidFunctionPrompt() -invalidFunctionPrompt = struct( ... - "NonStringInputName", ... - struct("Input", {{"123", 123, "content"}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... - ... - "NonStringInputContent", ... - struct("Input", {{"123", "name", 123}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... - ... - "EmptytName", ... - struct("Input", {{"123", "", "content"}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... - ... - "EmptytContent", ... - struct("Input", {{"123", "name", ""}}, ... - "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... - ... - "NonScalarInputName", ... - struct("Input", {{"123", ["name1" "name2"], "content"}}, ... - "Error", "MATLAB:validators:mustBeTextScalar"),... - ... - "NonScalarInputContent", ... - struct("Input", {{"123","name", ["content1", "content2"]}}, ... - "Error", "MATLAB:validators:mustBeTextScalar")); -end - -function invalidRemoveMessage = iGetInvalidRemoveMessage() -invalidRemoveMessage = struct( ... - "NonInteger", ... - struct("Input", {{0.5}}, ... - "Error", "MATLAB:validators:mustBeInteger"), ... - ... - "NonPositive", ... - struct("Input", {{0}}, ... - "Error", "MATLAB:validators:mustBePositive"), ... - ... - "NonScalarInput", ... - struct("Input", {{[1 2]}}, ... - "Error", "MATLAB:validation:IncompatibleSize")); -end - -function invalidInputsResponseMessage = iGetInvalidInputsResponseMessage() -invalidInputsResponseMessage = struct( ... - "NonStructInput", ... - struct("Input", {{123}},... - "Error", "MATLAB:validation:UnableToConvert"),... - ... - "NonExistentRole", ... - struct("Input", {{struct("role", "123", "content", "123")}},... - "Error", "llms:mustBeAssistantCall"),... - ... - "NonExistentContent", ... - struct("Input", {{struct("role", "assistant")}},... - "Error", "llms:mustBeAssistantCall"),... - ... - "EmptyContent", ... - struct("Input", {{struct("role", "assistant", "content", "")}},... - "Error", "llms:mustBeAssistantWithContent"),... - ... - "NonScalarContent", ... - struct("Input", {{struct("role", "assistant", "content", ["a", "b"])}},... - "Error", "llms:mustBeAssistantWithContent")); -end - -function invalidFuncCallsCases = iGetInvalidFuncCallsCases() -invalidFuncCallsCases = struct( ... - "NoArguments", ... - struct("FunCallStruct", struct("name", "functionName"),... - "Error", "llms:mustBeAssistantWithNameAndArguments"),... - ... - "NoName", ... - struct("FunCallStruct", struct("arguments", "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}"), ... 
- "Error", "llms:mustBeAssistantWithNameAndArguments")); -end \ No newline at end of file diff --git a/tests/tresponseStreamer.m b/tests/tresponseStreamer.m new file mode 100644 index 0000000..fc8450a --- /dev/null +++ b/tests/tresponseStreamer.m @@ -0,0 +1,73 @@ +classdef tresponseStreamer < matlab.unittest.TestCase +% Tests for llms.stream.reponseStreamer +% +% This test file contains unit tests, with a specific focus on edge cases that +% are hard to trigger in end-to-end tests. + +% Copyright 2024 The MathWorks, Inc. + + methods (Test) + function singleResponse(testCase) + s = tracingStreamer; + inp = 'data: {"choices":[{"content_filter_results":{},"delta":{"content":"foo","role":"assistant"}}]}'; + inp = [inp newline 'data: [DONE]']; + inp = unicode2native(inp,"UTF-8").'; + testCase.verifyTrue(s.doPutData(inp,false)); + testCase.verifyEqual(s.StreamFun(),"foo"); + end + + function skipEmpty(testCase) + s = tracingStreamer; + inp = [... + 'data: {"choices":[{"content_filter_results":{},"delta":{"content":"foo","role":"assistant"}}]}' newline ... + 'data: {"choices":[]}' newline ... + 'data: [DONE]']; + inp = unicode2native(inp,"UTF-8").'; + testCase.verifyTrue(s.doPutData(inp,false)); + testCase.verifyEqual(s.StreamFun(),"foo"); + end + + function splitResponse(testCase) + % it can happen that the server sends packets split in the + % middle of a JSON object. Hard to trigger on purpose. + s = tracingStreamer; + inp = 'data: {"choices":[{"content_filter_results":{},"delta":{"content":"foo","role":"assistant"}}]}'; + inp = unicode2native(inp,"UTF-8").'; + testCase.verifyFalse(s.doPutData(inp(1:42),false)); + testCase.verifyFalse(s.doPutData(inp(43:end),false)); + testCase.verifyEqual(s.StreamFun(),"foo"); + end + + function ollamaFormat(testCase) + s = tracingStreamer; + inp = '{"model":"mistral","created_at":"2024-06-07T07:43:30.658793Z","message":{"role":"assistant","content":" Hello"},"done":false}'; + inp = unicode2native(inp,"UTF-8").'; + testCase.verifyFalse(s.doPutData(inp,false)); + inp = '{"model":"mistral","created_at":"2024-06-07T07:43:30.658793Z","message":{"role":"assistant","content":" World"},"done":true}'; + inp = unicode2native(inp,"UTF-8").'; + testCase.verifyTrue(s.doPutData(inp,false)); + testCase.verifyEqual(s.StreamFun(),[" Hello"," World"]); + end + + function badJSON(testCase) + s = tracingStreamer; + inp = 'data: {"choices":[{"content_filter_results":{};"delta":{"content":"foo","role":"assistant"}}]}'; + inp = [inp newline inp]; + inp = unicode2native(inp,"UTF-8").'; + testCase.verifyError(@() s.doPutData(inp,false),'llms:stream:responseStreamer:InvalidInput'); + testCase.verifyEmpty(s.StreamFun()); + end + end +end + +function s = tracingStreamer + data = strings(1, 0); + function seen = sf(str) + % Append streamed text to an empty string array of length 1 + if nargin > 0 + data = [data, str]; + end + seen = data; + end + s = llms.stream.responseStreamer(@sf); +end