Skip to content

Commit

Permalink
Merge pull request #21 from matlab-deep-learning/fix-image-gen-bugs
Browse files Browse the repository at this point in the history
Fixing checks for images and empty prompts.
  • Loading branch information
debymf committed Apr 23, 2024
2 parents 69ca8b6 + 2d27fec commit 83ea7d0
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
10 changes: 6 additions & 4 deletions openAIImages.m
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@
% Only "dall-e-3" supports this parameter.

arguments
this (1,1) openAIImages
prompt {mustBeTextScalar}
this (1,1) openAIImages
prompt {mustBeNonzeroLengthTextScalar}
nvp.NumImages (1,1) {mustBePositive, mustBeInteger,...
mustBeLessThanOrEqual(nvp.NumImages,10)} = 1
nvp.Size (1,1) string {mustBeMember(nvp.Size, ["256x256", "512x512", ...
Expand Down Expand Up @@ -176,7 +176,7 @@
arguments
this (1,1) openAIImages
imagePath {mustBeValidFileType(imagePath)}
prompt {mustBeTextScalar}
prompt {mustBeNonzeroLengthTextScalar}
nvp.MaskImagePath {mustBeValidFileType(nvp.MaskImagePath)}
nvp.NumImages (1,1) {mustBePositive, mustBeInteger,...
mustBeLessThanOrEqual(nvp.NumImages,10)} = 1
Expand Down Expand Up @@ -345,7 +345,9 @@ function validatePromptSize(model, prompt)
function mustBeValidFileType(filePath)
mustBeFile(filePath);
s = dir(filePath);
if ~endsWith(s.name, ".png")
imgDetails = imfinfo(filePath);
imgFormat = imgDetails.Format;
if ~(imgFormat=="png")
error("llms:pngExpected", ...
llms.utils.errorMessageCatalog.getMessage("llms:pngExpected"));
end
Expand Down
Binary file added tests/test_files/solar.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 18 additions & 4 deletions tests/topenAIImages.m
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ function constructModelWithAllNVP(testCase)
testCase.verifyEqual(mdl.ModelName, modelName);
end

function fakePNGImage(testCase)
mdl = openAIImages(ApiKey="this-is-not-a-real-key");
fakePng = fullfile("test_files", "solar.png");
testCase.verifyError(@()edit(mdl,fakePng, "bla"), "llms:pngExpected");
end

function invalidInputsConstructor(testCase, InvalidConstructorInput)
testCase.verifyError(@()openAIImages(InvalidConstructorInput.Input{:}), InvalidConstructorInput.Error);
end
Expand Down Expand Up @@ -157,11 +163,15 @@ function invalidInputsVariation(testCase, InvalidVariationInput)
invalidGenerateInput = struct( ...
"EmptyInput",struct( ...
"Input",{{ [] }},...
"Error","MATLAB:validators:mustBeTextScalar"),...
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
...
"InvalidInputType",struct( ...
"Input",{{ 123 }},...
"Error","MATLAB:validators:mustBeTextScalar"),...
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
...
"InvalidPromptLen",struct( ...
"Input",{{ "" }},...
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
...
"InvalidNumImagesType",struct( ...
"Input",{{ "prompt" "NumImages" "2" }},...
Expand Down Expand Up @@ -233,17 +243,21 @@ function invalidInputsVariation(testCase, InvalidVariationInput)
"Input",{{ 123, "prompt" }},...
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
...
"InvalidPromptLen",struct( ...
"Input",{{ validImage, "" }},...
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
...
"InvalidImageExtension",struct( ...
"Input",{{ nonPNGImage, "prompt" }},...
"Error","llms:pngExpected"),...
...
"EmptyPrompt",struct( ...
"Input",{{ validImage, [] }},...
"Error","MATLAB:validators:mustBeTextScalar"),...
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
...
"InvalidPromptType",struct( ...
"Input",{{ validImage, 123 }},...
"Error","MATLAB:validators:mustBeTextScalar"),...
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
...
"InvalidMaskImage",struct( ...
"Input",{{ validImage, "foo", "MaskImagePath", 123 }},...
Expand Down

0 comments on commit 83ea7d0

Please sign in to comment.