Skip to content

Commit

Permalink
feat: add new model in gen image
Browse files Browse the repository at this point in the history
  • Loading branch information
redevrx committed Jan 18, 2024
1 parent 1f711f1 commit f359c13
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 31 deletions.
2 changes: 1 addition & 1 deletion example/lib/generate_img_screen.dart
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class _GenImgScreenState extends State<GenImgScreen> {
const prompt = "Snake eat cat.";

final request = GenerateImage(prompt, 1,
size: ImageSize.size256, responseFormat: Format.url);
model: DallE3(), size: ImageSize.size256, responseFormat: Format.url);
final response = await openAI.generateImage(request);
setState(() {
img = "${response?.data?.last?.url}";
Expand Down
2 changes: 1 addition & 1 deletion example_app/openai_app/lib/bloc/openai/openai_bloc.dart
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class OpenAIBloc extends Cubit<OpenAIState> {
///[generateImage]
void generateImage() async {
final request = GenerateImage(_txtInput.value.text, 1,
size: ImageSize.size1024, responseFormat: Format.url);
model: DallE3(), size: ImageSize.size1024, responseFormat: Format.url);

///update user chat message
list.add(Message(isBot: false, message: getTextInput().value.text));
Expand Down
1 change: 1 addition & 0 deletions lib/chat_gpt_sdk.dart
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@ export 'src/model/chat_complete/enum/function_call.dart';
export 'src/model/chat_complete/request/messages.dart';
export 'src/model/chat_complete/request/function_data.dart';
export 'src/model/chat_complete/request/response_format.dart';
export 'src/model/gen_image/enum/generate_image_model.dart';
10 changes: 0 additions & 10 deletions lib/src/model/complete_text/request/complete_text.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,6 @@ class CompleteText {
final String prompt;

/// ## completion models
/// - TextDavinci3Model();
/// - TextDavinci2Model();
/// - CodeDavinci2Model();
/// - TextCurie001Model();
/// - TextBabbage001Model();
/// - TextAda001Model();
/// - DavinciModel();
/// - CurieModel();
/// - BabbageModel();
/// - AdaModel();
/// - ModelFromValue(model: 'your-model-name');
final Model model;
final double temperature;
Expand Down
15 changes: 15 additions & 0 deletions lib/src/model/gen_image/enum/generate_image_model.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import 'package:chat_gpt_sdk/chat_gpt_sdk.dart';

sealed class GenerateImageModel {
final String model;

GenerateImageModel({required this.model});
}

class DallE2 extends GenerateImageModel {
DallE2() : super(model: kDallE2);
}

class DallE3 extends GenerateImageModel {
DallE3() : super(model: kDallE3);
}
5 changes: 5 additions & 0 deletions lib/src/model/gen_image/request/generate_image.dart
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import 'package:chat_gpt_sdk/src/model/gen_image/enum/format.dart';
import 'package:chat_gpt_sdk/src/model/gen_image/enum/generate_image_model.dart';
import 'package:chat_gpt_sdk/src/model/gen_image/enum/image_size.dart';

class GenerateImage {
Expand All @@ -17,16 +18,20 @@ class GenerateImage {
///A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
final String user;

final GenerateImageModel model;

GenerateImage(
this.prompt,
this.n, {
required this.model,
this.size = ImageSize.size1024,
this.responseFormat = Format.url,
this.user = "",
}) : assert(1 <= n && n <= 10, 'n must be between 1 and 10.');

Map<String, dynamic> toJson() => Map.of({
"prompt": prompt,
"model": model.model,
"n": n,
"size": size?.size,
"response_format": responseFormat?.getName(),
Expand Down
4 changes: 4 additions & 0 deletions lib/src/utils/constants.dart
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ const kDavinci002Model = 'davinci-002';
const kTextMStable = 'text-moderation-stable';
const kTextMLast = 'text-moderation-latest';

///generate image model
const kDallE2 = 'dall-e-2';
const kDallE3 = 'dall-e-3';

///default header
Map<String, String> kHeader(
String? token,
Expand Down
38 changes: 28 additions & 10 deletions test/model/gen_image/request/generate_image_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ main() {
final target = GenerateImage(
'test',
2,
model: DallE3(),
size: ImageSize.size1024,
responseFormat: Format.url,
);
Expand All @@ -19,6 +20,7 @@ main() {

test('set value with enum', () {
final target1 = GenerateImage(
model: DallE3(),
'test',
1,
size: ImageSize.size256,
Expand All @@ -33,6 +35,7 @@ main() {
expect(target1.user, 'user');

final target2 = GenerateImage(
model: DallE3(),
'test',
2,
size: ImageSize.size512,
Expand All @@ -42,41 +45,56 @@ main() {
expect(target2.size?.size, '512x512');
expect(target2.responseFormat?.name, 'url');

final target3 =
GenerateImage('test', 1, size: ImageSize.size1024, user: 'user');
final target3 = GenerateImage('test', 1,
model: DallE3(), size: ImageSize.size1024, user: 'user');
expect(target3.size?.size, '1024x1024');
});

group('GeneratedImageSize', () {
test('normal', () {
expect(GenerateImage('test', 2).size?.size, '1024x1024');
expect(
GenerateImage('test', 2, size: ImageSize.size256).size?.size,
GenerateImage(
'test',
2,
model: DallE3(),
).size?.size,
'1024x1024');
expect(
GenerateImage('test', 2, model: DallE3(), size: ImageSize.size256)
.size
?.size,
'256x256',
);
expect(
GenerateImage('test', 2, size: ImageSize.size512).size?.size,
GenerateImage('test', 2, model: DallE3(), size: ImageSize.size512)
.size
?.size,
'512x512',
);
expect(
GenerateImage('test', 2, size: ImageSize.size1024).size?.size,
GenerateImage('test', 2, model: DallE3(), size: ImageSize.size1024)
.size
?.size,
'1024x1024',
);
});
});

test('n must be between 1 and 10', () {
expect(() => GenerateImage('test', 0), throwsA(isA<AssertionError>()));
expect(() => GenerateImage('test', 11), throwsA(isA<AssertionError>()));
expect(() => GenerateImage('test', model: DallE3(), 0),
throwsA(isA<AssertionError>()));
expect(() => GenerateImage('test', model: DallE3(), 11),
throwsA(isA<AssertionError>()));

expect(GenerateImage('test', 1).n, 1);
expect(GenerateImage('test', 10).n, 10);
expect(GenerateImage('test', model: DallE3(), 1).n, 1);
expect(GenerateImage('test', model: DallE3(), 10).n, 10);
});

group('toJson', () {
test('example', () {
final json = GenerateImage(
'test',
model: DallE3(),
1,
size: ImageSize.size256,
responseFormat: Format.b64Json,
Expand Down
22 changes: 13 additions & 9 deletions test/openai_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ void main() async {
expect(
() => ai.generateImage(GenerateImage(
'prompt',
model: DallE3(),
1,
size: ImageSize.size256,
responseFormat: Format.url,
Expand Down Expand Up @@ -631,7 +632,8 @@ void main() async {

group('chatGPT Image Generate With Prompt test case', () {
test('chatGPT Image Generate With Prompt success case test', () async {
final request = GenerateImage('snake red eating cat.', 2);
final request =
GenerateImage('snake red eating cat.', model: DallE3(), 2);

when(openAI.generateImage(request))
.thenAnswer((_) async => GenImgResponse());
Expand All @@ -645,7 +647,8 @@ void main() async {
test(
'chatGPT Image Generate With Prompt success case return value test',
() async {
final request = GenerateImage('snake red eating cat.', 2);
final request =
GenerateImage('snake red eating cat.', model: DallE3(), 2);

when(openAI.generateImage(request))
.thenAnswer((_) async => GenImgResponse(created: 1221120));
Expand All @@ -661,7 +664,8 @@ void main() async {
test(
'chatGPT Image Generate With Prompt error case with n is 0 test',
() async {
final request = GenerateImage('snake red eating cat.', 1);
final request =
GenerateImage('snake red eating cat.', model: DallE3(), 1);

when(openAI.generateImage(request))
.thenAnswer((_) async => GenImgResponse(created: 1221120));
Expand All @@ -673,7 +677,7 @@ void main() async {
expect(response?.created, 1221120);
expect(response?.data, null);
expect(
() => GenerateImage('snake red eating cat.', 0),
() => GenerateImage('snake red eating cat.', model: DallE3(), 0),
throwsA(isA<AssertionError>()),
);
},
Expand Down Expand Up @@ -760,7 +764,7 @@ void main() async {
test(
'chatGPT Image Generate with success case',
() async {
final request = GenerateImage('cat eating snake', 1);
final request = GenerateImage('cat eating snake', model: DallE3(), 1);

when(openAI.generateImage(request))
.thenAnswer((realInvocation) async => GenImgResponse(
Expand All @@ -779,7 +783,7 @@ void main() async {
test(
'chatGPT Image Generate with cancel gen success case',
() {
final request = GenerateImage('cat eating snake', 1);
final request = GenerateImage('cat eating snake', model: DallE3(), 1);

when(openAI.generateImage(request))
.thenAnswer((realInvocation) async => GenImgResponse(
Expand All @@ -796,7 +800,7 @@ void main() async {
test(
'chatGPT Image Generate with return two image success case',
() async {
final request = GenerateImage('cat eating snake', 2);
final request = GenerateImage('cat eating snake', model: DallE3(), 2);

when(openAI.generateImage(request)).thenAnswer(
(realInvocation) async => GenImgResponse(created: 912312, data: [
Expand All @@ -816,7 +820,7 @@ void main() async {
test(
'chatGPT Image Generate with error case',
() async {
final request = GenerateImage('', 1);
final request = GenerateImage('', model: DallE3(), 1);

when(openAI.generateImage(request))
.thenAnswer((realInvocation) async => GenImgResponse(
Expand All @@ -833,7 +837,7 @@ void main() async {
test(
'chatGPT Image Generate with openai auth error case',
() async {
final request = GenerateImage('snake', 1);
final request = GenerateImage('snake', model: DallE3(), 1);

when(openAI.generateImage(request))
.thenThrow(OpenAIAuthError(code: 404));
Expand Down

0 comments on commit f359c13

Please sign in to comment.