diff --git a/+llms/+utils/errorMessageCatalog.m b/+llms/+utils/errorMessageCatalog.m index caf4c5e..fe13f0d 100644 --- a/+llms/+utils/errorMessageCatalog.m +++ b/+llms/+utils/errorMessageCatalog.m @@ -55,4 +55,5 @@ catalog("llms:pngExpected") = "Argument must be a PNG image."; catalog("llms:warningJsonInstruction") = "When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message."; catalog("llms:apiReturnedError") = "OpenAI API Error: {1}"; +catalog("llms:dimensionsMustBeSmallerThan") = "Dimensions must be less than or equal to {1}."; end \ No newline at end of file diff --git a/extractOpenAIEmbeddings.m b/extractOpenAIEmbeddings.m index 9660052..4be564c 100644 --- a/extractOpenAIEmbeddings.m +++ b/extractOpenAIEmbeddings.m @@ -20,14 +20,14 @@ % [emb, response] = EXTRACTOPENAIEMBEDDINGS(...) also returns the full % response from the OpenAI API call. % -% Copyright 2023 The MathWorks, Inc. +% Copyright 2023-2024 The MathWorks, Inc. arguments - text (1,:) {mustBeText} + text (1,:) {mustBeNonzeroLengthText} nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["text-embedding-ada-002", ... "text-embedding-3-large", "text-embedding-3-small"])} = "text-embedding-ada-002" nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10 - nvp.Dimensions (1,1) {mustBeInteger} + nvp.Dimensions (1,1) {mustBeInteger,mustBePositive} nvp.ApiKey {llms.utils.mustBeNonzeroLengthTextScalar} end @@ -42,6 +42,7 @@ error("llms:invalidOptionForModel", ... llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "Dimensions", nvp.ModelName)); end + mustBeCorrectDimensions(nvp.Dimensions,nvp.ModelName); parameters.dimensions = nvp.Dimensions; end @@ -53,4 +54,17 @@ emb = emb'; else emb = []; +end +end + +function mustBeCorrectDimensions(dimensions,modelName) + model2dim = .... + dictionary(["text-embedding-3-large", "text-embedding-3-small"], ... + [3072,1536]); + + if dimensions>model2dim(modelName) + error("llms:dimensionsMustBeSmallerThan", ... + llms.utils.errorMessageCatalog.getMessage("llms:dimensionsMustBeSmallerThan", ... + string(model2dim(modelName)))); + end end \ No newline at end of file diff --git a/tests/textractOpenAIEmbeddings.m b/tests/textractOpenAIEmbeddings.m index a58c1f6..1bf33a6 100644 --- a/tests/textractOpenAIEmbeddings.m +++ b/tests/textractOpenAIEmbeddings.m @@ -1,7 +1,7 @@ classdef textractOpenAIEmbeddings < matlab.unittest.TestCase % Tests for extractOpenAIEmbeddings -% Copyright 2023 The MathWorks, Inc. +% Copyright 2023-2024 The MathWorks, Inc. methods (TestClassSetup) function saveEnvVar(testCase) @@ -56,6 +56,14 @@ function testInvalidInputs(testCase, InvalidInput) function invalidInput = iGetInvalidInput invalidInput = struct( ... + "InvalidEmptyText", struct( ... + "Input",{{ "" }},... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... + "InvalidEmptyTextArray", struct( ... + "Input",{{ ["", ""] }},... + "Error", "MATLAB:validators:mustBeNonzeroLengthText"), ... + ... "InvalidTimeOutType", struct( ... "Input",{{ "bla", "TimeOut", "2" }},... "Error", "MATLAB:validators:mustBeReal"), ... @@ -66,7 +74,7 @@ function testInvalidInputs(testCase, InvalidInput) ... "WrongTypeText",struct( ... "Input",{{ 123 }},... - "Error","MATLAB:validators:mustBeText"),... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... ... "InvalidModelNameType",struct( ... "Input",{{"bla", "ModelName", 0 }},... @@ -84,6 +92,20 @@ function testInvalidInputs(testCase, InvalidInput) "Input",{{"bla", "Dimensions", "123" }},... "Error","MATLAB:validators:mustBeNumericOrLogical"),... ... + "InvalidDimensionValue",struct( ... + "Input",{{"bla", "Dimensions", "-11" }},... + "Error","MATLAB:validators:mustBeNumericOrLogical"),... + ... + "LargeDimensionValueForModelLarge",struct( ... + "Input",{{"bla", "ModelName", "text-embedding-3-large", ... + "Dimensions", 3073, "ApiKey", "fake-key" }},... + "Error","llms:dimensionsMustBeSmallerThan"),... + ... + "LargeDimensionValueForModelSmall",struct( ... + "Input",{{"bla", "ModelName", "text-embedding-3-small", ... + "Dimensions", 1537, "ApiKey", "fake-key" }},... + "Error","llms:dimensionsMustBeSmallerThan"),... + ... "InvalidDimensionSize",struct( ... "Input",{{"bla", "Dimensions", [123, 123] }},... "Error","MATLAB:validation:IncompatibleSize"),...