diff --git a/openAIImages.m b/openAIImages.m index e60aa16..479c60a 100644 --- a/openAIImages.m +++ b/openAIImages.m @@ -85,8 +85,8 @@ % Only "dall-e-3" supports this parameter. arguments - this (1,1) openAIImages - prompt {mustBeTextScalar} + this (1,1) openAIImages + prompt {mustBeNonzeroLengthTextScalar} nvp.NumImages (1,1) {mustBePositive, mustBeInteger,... mustBeLessThanOrEqual(nvp.NumImages,10)} = 1 nvp.Size (1,1) string {mustBeMember(nvp.Size, ["256x256", "512x512", ... @@ -176,7 +176,7 @@ arguments this (1,1) openAIImages imagePath {mustBeValidFileType(imagePath)} - prompt {mustBeTextScalar} + prompt {mustBeNonzeroLengthTextScalar} nvp.MaskImagePath {mustBeValidFileType(nvp.MaskImagePath)} nvp.NumImages (1,1) {mustBePositive, mustBeInteger,... mustBeLessThanOrEqual(nvp.NumImages,10)} = 1 @@ -345,7 +345,9 @@ function validatePromptSize(model, prompt) function mustBeValidFileType(filePath) mustBeFile(filePath); s = dir(filePath); - if ~endsWith(s.name, ".png") + imgDetails = imfinfo(filePath); + imgFormat = imgDetails.Format; + if ~(imgFormat=="png") error("llms:pngExpected", ... llms.utils.errorMessageCatalog.getMessage("llms:pngExpected")); end diff --git a/tests/test_files/solar.png b/tests/test_files/solar.png new file mode 100644 index 0000000..fc37277 Binary files /dev/null and b/tests/test_files/solar.png differ diff --git a/tests/topenAIImages.m b/tests/topenAIImages.m index 4e70f9a..2b625e1 100644 --- a/tests/topenAIImages.m +++ b/tests/topenAIImages.m @@ -105,6 +105,12 @@ function constructModelWithAllNVP(testCase) testCase.verifyEqual(mdl.ModelName, modelName); end + function fakePNGImage(testCase) + mdl = openAIImages(ApiKey="this-is-not-a-real-key"); + fakePng = fullfile("test_files", "solar.png"); + testCase.verifyError(@()edit(mdl,fakePng, "bla"), "llms:pngExpected"); + end + function invalidInputsConstructor(testCase, InvalidConstructorInput) testCase.verifyError(@()openAIImages(InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error); end @@ -157,11 +163,15 @@ function invalidInputsVariation(testCase, InvalidVariationInput) invalidGenerateInput = struct( ... "EmptyInput",struct( ... "Input",{{ [] }},... - "Error","MATLAB:validators:mustBeTextScalar"),... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... ... "InvalidInputType",struct( ... "Input",{{ 123 }},... - "Error","MATLAB:validators:mustBeTextScalar"),... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... + ... + "InvalidPromptLen",struct( ... + "Input",{{ "" }},... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... ... "InvalidNumImagesType",struct( ... "Input",{{ "prompt" "NumImages" "2" }},... @@ -233,17 +243,21 @@ function invalidInputsVariation(testCase, InvalidVariationInput) "Input",{{ 123, "prompt" }},... "Error","MATLAB:validators:mustBeNonzeroLengthText"),... ... + "InvalidPromptLen",struct( ... + "Input",{{ validImage, "" }},... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... + ... "InvalidImageExtension",struct( ... "Input",{{ nonPNGImage, "prompt" }},... "Error","llms:pngExpected"),... ... "EmptyPrompt",struct( ... "Input",{{ validImage, [] }},... - "Error","MATLAB:validators:mustBeTextScalar"),... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... ... "InvalidPromptType",struct( ... "Input",{{ validImage, 123 }},... - "Error","MATLAB:validators:mustBeTextScalar"),... + "Error","MATLAB:validators:mustBeNonzeroLengthText"),... ... "InvalidMaskImage",struct( ... "Input",{{ validImage, "foo", "MaskImagePath", 123 }},...