Skip to content

Commit

Permalink
Merge pull request #8 from matlab-deep-learning/AzureAPI
Browse files Browse the repository at this point in the history
Adding support for the Azure API
  • Loading branch information
ccreutzi committed Jun 24, 2024
2 parents 38edd99 + 1ac24ff commit 05c861b
Show file tree
Hide file tree
Showing 50 changed files with 3,671 additions and 1,397 deletions.
12 changes: 12 additions & 0 deletions +llms/+azure/apiVersions.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
function versions = apiVersions
%APIVERSIONS - supported Azure API versions
%   VERSIONS = APIVERSIONS returns the list of Azure OpenAI API versions
%   supported by this package, newest first.

% Copyright 2024 The MathWorks, Inc.
versions = [...
    "2024-05-01-preview", ...
    "2024-04-01-preview", ...
    "2024-03-01-preview", ...
    "2024-02-01", ...
    "2023-05-15", ...
    ];
end
130 changes: 130 additions & 0 deletions +llms/+internal/callAzureChatAPI.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
function [text, message, response] = callAzureChatAPI(endpoint, deploymentID, messages, functions, nvp)
% This function is undocumented and will change in a future release

%callAzureChatAPI Calls the openAI chat completions API on Azure.
%
% MESSAGES and FUNCTIONS should be structs matching the json format
% required by the OpenAI Chat Completions API.
% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
%
% More details on the parameters: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/chatgpt
%
% Example
%
% % Create messages struct
% messages = {struct("role", "system",...
%     "content", "You are a helpful assistant");
%     struct("role", "user", ...
%     "content", "What is the edit distance between hi and hello?")};
%
% % Create functions struct
% functions = {struct("name", "editDistance", ...
%     "description", "Find edit distance between two strings or documents.", ...
%     "parameters", struct( ...
%     "type", "object", ...
%     "properties", struct(...
%         "str1", struct(...
%             "description", "Source string.", ...
%             "type", "string"),...
%         "str2", struct(...
%             "description", "Target string.", ...
%             "type", "string")),...
%     "required", ["str1", "str2"]))};
%
% % Define your API key
% apiKey = "your-api-key-here"
%
% % Send a request
% [text, message] = llms.internal.callAzureChatAPI(endpoint, deploymentID, ...
%     messages, functions, APIVersion="2024-02-01", APIKey=apiKey)

% Copyright 2023-2024 The MathWorks, Inc.

arguments
    endpoint
    deploymentID
    messages
    functions
    nvp.ToolChoice
    nvp.APIVersion
    nvp.Temperature
    nvp.TopP
    nvp.NumCompletions
    nvp.StopSequences
    nvp.MaxNumTokens
    nvp.PresencePenalty
    nvp.FrequencyPenalty
    nvp.ResponseFormat
    nvp.Seed
    nvp.APIKey
    nvp.TimeOut
    nvp.StreamFun
end

URL = endpoint + "openai/deployments/" + deploymentID + "/chat/completions?api-version=" + nvp.APIVersion;

parameters = buildParametersCall(messages, functions, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,nvp.APIKey, URL, nvp.TimeOut, nvp.StreamFun);

% If call errors, "choices" will not be part of response.Body.Data, instead
% we get response.Body.Data.error
if response.StatusCode=="OK"
    % Outputs the first generation
    if isempty(nvp.StreamFun)
        message = response.Body.Data.choices(1).message;
    else
        message = struct("role", "assistant", ...
            "content", streamedText);
    end
    % A tool-call response carries a "tool_calls" field and may have no
    % content. ("tool_choice" is a request-side option and never appears
    % on a response message, so checking for it here was a no-op.)
    if isfield(message, "tool_calls")
        text = "";
    else
        text = string(message.content);
    end
else
    text = "";
    message = struct();
end
end

function parameters = buildParametersCall(messages, functions, nvp)
% Assemble the request-body struct expected by the chat completions API
% from MESSAGES, FUNCTIONS and the options captured in NVP.

parameters = struct();
parameters.messages = messages;
parameters.stream = ~isempty(nvp.StreamFun);

% Optional top-level fields are added only when the caller set them.
if ~isempty(functions)
    parameters.tools = functions;
end
if ~isempty(nvp.ToolChoice)
    parameters.tool_choice = nvp.ToolChoice;
end
if ~isempty(nvp.Seed)
    parameters.seed = nvp.Seed;
end

% The remaining options translate one-to-one through the name dictionary.
nameMap = mapNVPToParameters;
for optName = keys(nameMap).'
    if isfield(nvp, optName)
        parameters.(nameMap(optName)) = nvp.(optName);
    end
end
end

function dict = mapNVPToParameters()
% Map user-facing option names onto the field names the chat API expects.
nvpNames = ["Temperature"; "TopP"; "NumCompletions"; "StopSequences"; ...
    "MaxNumTokens"; "PresencePenalty"; "FrequencyPenalty"];
apiNames = ["temperature"; "top_p"; "n"; "stop"; ...
    "max_tokens"; "presence_penalty"; "frequency_penalty"];
dict = dictionary(nvpNames, apiNames);
end
106 changes: 106 additions & 0 deletions +llms/+internal/callOllamaChatAPI.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
function [text, message, response] = callOllamaChatAPI(model, messages, nvp)
% This function is undocumented and will change in a future release

%callOllamaChatAPI Calls the Ollama® chat completions API.
%
% MESSAGES should be structs matching the json format required by the
% Ollama Chat Completions API.
% Ref: https://github.com/ollama/ollama/blob/main/docs/api.md
%
% More details on the parameters: https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
%
% Example
%
% model = "mistral";
%
% % Create messages struct
% messages = {struct("role", "system",...
%     "content", "You are a helpful assistant");
%     struct("role", "user", ...
%     "content", "What is the edit distance between hi and hello?")};
%
% % Send a request
% [text, message] = llms.internal.callOllamaChatAPI(model, messages)

% Copyright 2023-2024 The MathWorks, Inc.

arguments
    model
    messages
    nvp.Temperature
    nvp.TopP
    nvp.TopK
    nvp.TailFreeSamplingZ
    nvp.StopSequences
    nvp.MaxNumTokens
    nvp.ResponseFormat
    nvp.Seed
    nvp.TimeOut
    nvp.StreamFun
end

URL = "http://localhost:11434/api/chat";

% The JSON for StopSequences must be an array, never "stop": "foo".
% Duplicating a scalar guarantees jsonencode emits an array; Ollama
% treats the repeated entry as equivalent to listing it once.
if isscalar(nvp.StopSequences)
    nvp.StopSequences = [nvp.StopSequences, nvp.StopSequences];
end

parameters = buildParametersCall(model, messages, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,[],URL,nvp.TimeOut,nvp.StreamFun);

% On error, response.Body.Data carries an "error" field instead of the
% usual payload; return empty outputs in that case.
if response.StatusCode ~= "OK"
    text = "";
    message = struct();
    return
end

if isempty(nvp.StreamFun)
    % Non-streaming: the full message is in the response body.
    message = response.Body.Data.message;
else
    % Streaming: reassemble the message from the accumulated stream text.
    message = struct("role", "assistant", "content", streamedText);
end
text = string(message.content);
end

function parameters = buildParametersCall(model, messages, nvp)
% Assemble the request body for the Ollama chat endpoint from MODEL,
% MESSAGES and the generation options captured in NVP.

parameters = struct();
parameters.model = model;
parameters.messages = messages;
parameters.stream = ~isempty(nvp.StreamFun);

% Generation controls go into a nested "options" object.
opts = struct;
if ~isempty(nvp.Seed)
    opts.seed = nvp.Seed;
end

% Copy each remaining option that was actually set. Inf values are
% skipped (presumably "use the server default", e.g. no token limit --
% confirm against the calling chat class).
nameMap = mapNVPToParameters;
for optName = keys(nameMap).'
    if isfield(nvp, optName) && ~isempty(nvp.(optName)) && ~isequaln(nvp.(optName),Inf)
        opts.(nameMap(optName)) = nvp.(optName);
    end
end

parameters.options = opts;
end

function dict = mapNVPToParameters()
% Translate user-facing option names into Ollama "options" field names.
nvpNames = ["Temperature"; "TopP"; "TopK"; ...
    "TailFreeSamplingZ"; "StopSequences"; "MaxNumTokens"];
apiNames = ["temperature"; "top_p"; "top_k"; ...
    "tfs_z"; "stop"; "num_predict"];
dict = dictionary(nvpNames, apiNames);
end
51 changes: 19 additions & 32 deletions +llms/+internal/callOpenAIChatAPI.m
Original file line number Diff line number Diff line change
@@ -1,25 +1,12 @@
function [text, message, response] = callOpenAIChatAPI(messages, functions, nvp)
% This function is undocumented and will change in a future release

%callOpenAIChatAPI Calls the openAI chat completions API.
%
% MESSAGES and FUNCTIONS should be structs matching the json format
% required by the OpenAI Chat Completions API.
% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
%
% Currently, the supported NVP are, including the equivalent name in the API:
% - ToolChoice (tool_choice)
% - ModelName (model)
% - Temperature (temperature)
% - TopProbabilityMass (top_p)
% - NumCompletions (n)
% - StopSequences (stop)
% - MaxNumTokens (max_tokens)
% - PresencePenalty (presence_penalty)
% - FrequencyPenalty (frequence_penalty)
% - ResponseFormat (response_format)
% - Seed (seed)
% - ApiKey
% - TimeOut
% - StreamFun
% More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create
%
% Example
Expand Down Expand Up @@ -48,34 +35,34 @@
% apiKey = "your-api-key-here"
%
% % Send a request
% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, ApiKey=apiKey)
% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, APIKey=apiKey)

% Copyright 2023-2024 The MathWorks, Inc.

arguments
messages
functions
nvp.ToolChoice = []
nvp.ModelName = "gpt-3.5-turbo"
nvp.Temperature = 1
nvp.TopProbabilityMass = 1
nvp.NumCompletions = 1
nvp.StopSequences = []
nvp.MaxNumTokens = inf
nvp.PresencePenalty = 0
nvp.FrequencyPenalty = 0
nvp.ResponseFormat = "text"
nvp.Seed = []
nvp.ApiKey = ""
nvp.TimeOut = 10
nvp.StreamFun = []
nvp.ToolChoice
nvp.ModelName
nvp.Temperature
nvp.TopP
nvp.NumCompletions
nvp.StopSequences
nvp.MaxNumTokens
nvp.PresencePenalty
nvp.FrequencyPenalty
nvp.ResponseFormat
nvp.Seed
nvp.APIKey
nvp.TimeOut
nvp.StreamFun
end

END_POINT = "https://api.openai.com/v1/chat/completions";

parameters = buildParametersCall(messages, functions, nvp);

[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun);
[response, streamedText] = llms.internal.sendRequest(parameters,nvp.APIKey, END_POINT, nvp.TimeOut, nvp.StreamFun);

% If call errors, "choices" will not be part of response.Body.Data, instead
% we get response.Body.Data.error
Expand Down Expand Up @@ -160,7 +147,7 @@
function dict = mapNVPToParameters()
dict = dictionary();
dict("Temperature") = "temperature";
dict("TopProbabilityMass") = "top_p";
dict("TopP") = "top_p";
dict("NumCompletions") = "n";
dict("StopSequences") = "stop";
dict("MaxNumTokens") = "max_tokens";
Expand Down
26 changes: 13 additions & 13 deletions +llms/+internal/getApiKeyFromNvpOrEnv.m
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
function key = getApiKeyFromNvpOrEnv(nvp)
function key = getApiKeyFromNvpOrEnv(nvp,envVarName)
% This function is undocumented and will change in a future release

%getApiKeyFromNvpOrEnv Retrieves an API key from a Name-Value Pair struct or environment variable.
%
% This function takes a struct nvp containing name-value pairs and checks
% if it contains a field called "ApiKey". If the field is not found,
% the function attempts to retrieve the API key from an environment
% variable called "OPENAI_API_KEY". If both methods fail, the function
% throws an error.
% This function takes a struct nvp containing name-value pairs and checks if
% it contains a field called "APIKey". If the field is not found, the
% function attempts to retrieve the API key from an environment variable
% whose name is given as the second argument. If both methods fail, the
% function throws an error.

% Copyright 2023 The MathWorks, Inc.
% Copyright 2023-2024 The MathWorks, Inc.

if isfield(nvp, "ApiKey")
key = nvp.ApiKey;
if isfield(nvp, "APIKey")
key = nvp.APIKey;
else
if isenv("OPENAI_API_KEY")
key = getenv("OPENAI_API_KEY");
if isenv(envVarName)
key = getenv(envVarName);
else
error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified"));
error("llms:keyMustBeSpecified", llms.utils.errorMessageCatalog.getMessage("llms:keyMustBeSpecified", envVarName));
end
end
end
end
12 changes: 12 additions & 0 deletions +llms/+internal/gptPenalties.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
classdef (Abstract) gptPenalties
% This class is undocumented and will change in a future release

%gptPenalties Abstract mixin declaring the penalty options shared by
%   GPT-style chat classes. Concrete chat classes inherit these two
%   properties. Values are checked by llms.utils.mustBeValidPenalty;
%   the accepted range is not visible here -- presumably [-2, 2] per the
%   OpenAI API; confirm against the validator.

% Copyright 2024 The MathWorks, Inc.
properties
%PRESENCEPENALTY Penalty for using a token in the response that has already been used.
PresencePenalty {llms.utils.mustBeValidPenalty} = 0

%FREQUENCYPENALTY Penalty for using a token that is frequent in the training data.
FrequencyPenalty {llms.utils.mustBeValidPenalty} = 0
end
end
Loading

0 comments on commit 05c861b

Please sign in to comment.