Skip to content

Commit

Permalink
Merge branch 'streaming' into 'main'
Browse files Browse the repository at this point in the history
Streaming

See merge request dferreir/llms-with-matlab!12
  • Loading branch information
debymf committed Dec 22, 2023
2 parents 744e646 + b7d1a73 commit c1815ad
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 9 deletions.
14 changes: 12 additions & 2 deletions +llms/+internal/callOpenAIChatAPI.m
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
% - FrequencyPenalty (frequence_penalty)
% - ApiKey
% - TimeOut
% - StreamFun
% More details on the parameters: https://platform.openai.com/docs/api-reference/chat/create
%
% Example
Expand Down Expand Up @@ -65,19 +66,25 @@
nvp.FrequencyPenalty = 0
nvp.ApiKey = ""
nvp.TimeOut = 10
nvp.StreamFun = []
end

END_POINT = "https://api.openai.com/v1/chat/completions";

parameters = buildParametersCall(messages, functions, nvp);

response = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut);
[response, streamedText] = llms.internal.sendRequest(parameters,nvp.ApiKey, END_POINT, nvp.TimeOut, nvp.StreamFun);

% If call errors, "choices" will not be part of response.Body.Data, instead
% we get response.Body.Data.error
if response.StatusCode=="OK"
% Outputs the first generation
message = response.Body.Data.choices(1).message;
if isempty(nvp.StreamFun)
message = response.Body.Data.choices(1).message;
else
message = struct("role", "assistant", ...
"content", streamedText);
end
if isfield(message, "function_call")
text = "";
else
Expand All @@ -95,6 +102,9 @@

parameters = struct();
parameters.messages = messages;

parameters.stream = ~isempty(nvp.StreamFun);

if ~isempty(functions)
parameters.functions = functions;
end
Expand Down
20 changes: 14 additions & 6 deletions +llms/+internal/sendRequest.m
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
function response = sendRequest(parameters, token, endpoint, timeout)
% This function is undocumented and will change in a future release

function [response, streamedText] = sendRequest(parameters, token, endpoint, timeout, streamFun)
%sendRequest Sends a request to an ENDPOINT using PARAMETERS and
% api key TOKEN. TIMEOUT is the nubmer of seconds to wait for initial
% server connection.
% server connection. STREAMFUN is an optional callback function.

% Copyright 2023 The MathWorks, Inc.

Expand All @@ -12,6 +10,7 @@
token
endpoint
timeout
streamFun
end

% Define the headers for the API request
Expand All @@ -24,9 +23,18 @@

% Create a HTTPOptions object;
httpOpts = matlab.net.http.HTTPOptions;
% Set the ConnectTimeout option

% Set the ConnectTimeout option
httpOpts.ConnectTimeout = timeout;

% Send the request and store the response
response = send(request, matlab.net.URI(endpoint),httpOpts);
if isempty(streamFun)
response = send(request, matlab.net.URI(endpoint),httpOpts);
streamedText = "";
else
% User defined a stream callback function
consumer = llms.stream.responseStreamer(streamFun);
response = send(request, matlab.net.URI(endpoint),httpOpts,consumer);
streamedText = consumer.ResponseText;
end
end
51 changes: 51 additions & 0 deletions +llms/+stream/responseStreamer.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
classdef responseStreamer < matlab.net.http.io.StringConsumer
%responseStreamer Responsible for obtaining the streaming results from the
%API

% Copyright 2023 The MathWorks, Inc.

properties
ResponseText
StreamFun
end

methods
function this = responseStreamer(streamFun)
this.StreamFun = streamFun;
end
end

methods (Access=protected)
function length = start(this)
if this.Response.StatusCode ~= matlab.net.http.StatusCode.OK
length = 0;
else
length = this.start@matlab.net.http.io.StringConsumer;
end
end
end

methods
function [len,stop] = putData(this, data)
[len,stop] = this.putData@matlab.net.http.io.StringConsumer(data);

% Extract out the response text from the message
str = native2unicode(data','UTF-8');
str = split(str,newline);
str = str(strlength(str)>0);
str = erase(str,"data: ");

for i = 1:length(str)
json = jsondecode(str{i});
if strcmp(json.choices.finish_reason,'stop')
stop = true;
return
else
txt = json.choices.delta.content;
this.StreamFun(txt);
this.ResponseText = [this.ResponseText txt];
end
end
end
end
end
13 changes: 12 additions & 1 deletion openAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
% FrequencyPenalty - Penalty value for using a token that is frequent
% in the training data. Default value is 0.
%
% StreamFun - Function to callback when streaming the
% result
%
% openAIChat Functions:
% openAIChat - Chat completion API from OpenAI.
% generate - Generate a response using the openAIChat instance.
Expand Down Expand Up @@ -95,6 +98,7 @@
Functions
FunctionsStruct
ApiKey
StreamFun
end

methods
Expand All @@ -112,6 +116,13 @@
nvp.PresencePenalty {mustBeValidPenalty} = 0
nvp.FrequencyPenalty {mustBeValidPenalty} = 0
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
nvp.StreamFun (1,1) {mustBeA(nvp.StreamFun,'function_handle')}
end

if isfield(nvp,"StreamFun")
this.StreamFun = nvp.StreamFun;
else
this.StreamFun = [];
end

if ~isempty(nvp.Functions)
Expand Down Expand Up @@ -182,7 +193,7 @@
TopProbabilityMass=this.TopProbabilityMass, NumCompletions=nvp.NumCompletions,...
StopSequences=this.StopSequences, MaxNumTokens=nvp.MaxNumTokens, ...
PresencePenalty=this.PresencePenalty, FrequencyPenalty=this.FrequencyPenalty, ...
ApiKey=this.ApiKey,TimeOut=this.TimeOut);
ApiKey=this.ApiKey,TimeOut=this.TimeOut, StreamFun=this.StreamFun);
end

function this = set.Temperature(this, temperature)
Expand Down
8 changes: 8 additions & 0 deletions tests/topenAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ function assignValueToProperty(property, value)
function invalidConstructorInput = iGetInvalidConstructorInput
validFunction = openAIFunction("funName");
invalidConstructorInput = struct( ...
"InvalidStreamFunType", struct( ...
"Input",{{"StreamFun", "2" }},...
"Error", "MATLAB:validators:mustBeA"), ...
...
"InvalidStreamFunSize", struct( ...
"Input",{{"StreamFun", [1 1 1] }},...
"Error", "MATLAB:validation:IncompatibleSize"), ...
...
"InvalidTimeOutType", struct( ...
"Input",{{"TimeOut", "2" }},...
"Error", "MATLAB:validators:mustBeReal"), ...
Expand Down

0 comments on commit c1815ad

Please sign in to comment.