Skip to content

Commit

Permalink
Merge pull request #19 from matlab-deep-learning/fix-embedding-bugs
Browse files Browse the repository at this point in the history
Fix argument validation for extractOpenAIEmbeddings
  • Loading branch information
debymf committed Apr 20, 2024
2 parents d3e7389 + 48a7073 commit 69ca8b6
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 5 deletions.
1 change: 1 addition & 0 deletions +llms/+utils/errorMessageCatalog.m
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,5 @@
catalog("llms:pngExpected") = "Argument must be a PNG image.";
catalog("llms:warningJsonInstruction") = "When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message.";
catalog("llms:apiReturnedError") = "OpenAI API Error: {1}";
catalog("llms:dimensionsMustBeSmallerThan") = "Dimensions must be less than or equal to {1}.";
end
20 changes: 17 additions & 3 deletions extractOpenAIEmbeddings.m
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
% [emb, response] = EXTRACTOPENAIEMBEDDINGS(...) also returns the full
% response from the OpenAI API call.
%
% Copyright 2023 The MathWorks, Inc.
% Copyright 2023-2024 The MathWorks, Inc.

arguments
text (1,:) {mustBeText}
text (1,:) {mustBeNonzeroLengthText}
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["text-embedding-ada-002", ...
"text-embedding-3-large", "text-embedding-3-small"])} = "text-embedding-ada-002"
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
nvp.Dimensions (1,1) {mustBeInteger}
nvp.Dimensions (1,1) {mustBeInteger,mustBePositive}
nvp.ApiKey {llms.utils.mustBeNonzeroLengthTextScalar}
end

Expand All @@ -42,6 +42,7 @@
error("llms:invalidOptionForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "Dimensions", nvp.ModelName));
end
mustBeCorrectDimensions(nvp.Dimensions,nvp.ModelName);
parameters.dimensions = nvp.Dimensions;
end

Expand All @@ -53,4 +54,17 @@
emb = emb';
else
emb = [];
end
end

function mustBeCorrectDimensions(dimensions,modelName)
model2dim = ....
dictionary(["text-embedding-3-large", "text-embedding-3-small"], ...
[3072,1536]);

if dimensions>model2dim(modelName)
error("llms:dimensionsMustBeSmallerThan", ...
llms.utils.errorMessageCatalog.getMessage("llms:dimensionsMustBeSmallerThan", ...
string(model2dim(modelName))));
end
end
26 changes: 24 additions & 2 deletions tests/textractOpenAIEmbeddings.m
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
classdef textractOpenAIEmbeddings < matlab.unittest.TestCase
% Tests for extractOpenAIEmbeddings

% Copyright 2023 The MathWorks, Inc.
% Copyright 2023-2024 The MathWorks, Inc.

methods (TestClassSetup)
function saveEnvVar(testCase)
Expand Down Expand Up @@ -56,6 +56,14 @@ function testInvalidInputs(testCase, InvalidInput)

function invalidInput = iGetInvalidInput
invalidInput = struct( ...
"InvalidEmptyText", struct( ...
"Input",{{ "" }},...
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
...
"InvalidEmptyTextArray", struct( ...
"Input",{{ ["", ""] }},...
"Error", "MATLAB:validators:mustBeNonzeroLengthText"), ...
...
"InvalidTimeOutType", struct( ...
"Input",{{ "bla", "TimeOut", "2" }},...
"Error", "MATLAB:validators:mustBeReal"), ...
Expand All @@ -66,7 +74,7 @@ function testInvalidInputs(testCase, InvalidInput)
...
"WrongTypeText",struct( ...
"Input",{{ 123 }},...
"Error","MATLAB:validators:mustBeText"),...
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
...
"InvalidModelNameType",struct( ...
"Input",{{"bla", "ModelName", 0 }},...
Expand All @@ -84,6 +92,20 @@ function testInvalidInputs(testCase, InvalidInput)
"Input",{{"bla", "Dimensions", "123" }},...
"Error","MATLAB:validators:mustBeNumericOrLogical"),...
...
"InvalidDimensionValue",struct( ...
"Input",{{"bla", "Dimensions", "-11" }},...
"Error","MATLAB:validators:mustBeNumericOrLogical"),...
...
"LargeDimensionValueForModelLarge",struct( ...
"Input",{{"bla", "ModelName", "text-embedding-3-large", ...
"Dimensions", 3073, "ApiKey", "fake-key" }},...
"Error","llms:dimensionsMustBeSmallerThan"),...
...
"LargeDimensionValueForModelSmall",struct( ...
"Input",{{"bla", "ModelName", "text-embedding-3-small", ...
"Dimensions", 1537, "ApiKey", "fake-key" }},...
"Error","llms:dimensionsMustBeSmallerThan"),...
...
"InvalidDimensionSize",struct( ...
"Input",{{"bla", "Dimensions", [123, 123] }},...
"Error","MATLAB:validation:IncompatibleSize"),...
Expand Down

0 comments on commit 69ca8b6

Please sign in to comment.