Skip to content

Commit

Permalink
finish adding the 2 new commands
Browse files Browse the repository at this point in the history
  • Loading branch information
bartekpacia committed Aug 27, 2024
1 parent c0b69c5 commit 0aed452
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 63 deletions.
7 changes: 0 additions & 7 deletions maestro-ai/src/main/java/maestro/ai/AI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,6 @@ abstract class AI(
// * OpenAI: https://platform.openai.com/docs/guides/structured-outputs
// * Gemini: https://ai.google.dev/gemini-api/docs/json-mode

val checkAssertion: String = run {
val resourceStream = this::class.java.getResourceAsStream("/checkAssertion_schema.json")
?: throw IllegalStateException("Could not find checkAssertion_schema.json in resources")

resourceStream.bufferedReader().use { it.readText() }
}

val askForDefectsSchema: String = run {
val resourceStream = this::class.java.getResourceAsStream("/askForDefects_schema.json")
?: throw IllegalStateException("Could not find askForDefects_schema.json in resources")
Expand Down
16 changes: 14 additions & 2 deletions maestro-ai/src/main/java/maestro/ai/DemoApp.kt
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,25 @@ class DemoApp : CliktCommand() {
val bytes = testCase.screenshot.readBytes()

val job = async {
val defects = Prediction.findDefects(
val defects = if (testCase.prompt == null) Prediction.findDefects(
aiClient = aiClient,
screen = bytes,
previousFalsePositives = listOf(),
printPrompt = showPrompts,
printRawResponse = showRawResponse,
)
) else {
val result = Prediction.performAssertion(
aiClient = aiClient,
screen = bytes,
assertion = testCase.prompt,
printPrompt = showPrompts,
printRawResponse = showRawResponse,
)

if (result == null) emptyList()
else listOf(result)
}

verify(testCase, defects)
}

Expand Down
45 changes: 28 additions & 17 deletions maestro-ai/src/main/java/maestro/ai/Prediction.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,10 @@ data class Defect(
)

@Serializable
private data class FindDefectsResponse(
private data class ModelResponse(
val defects: List<Defect>,
)

@Serializable
data class PerformAssertionResult(
val passed: Boolean,
val reasoning: String,
)

object Prediction {
private val json = Json { ignoreUnknownKeys = true }

Expand All @@ -30,6 +24,8 @@ object Prediction {
"layout" to "Some UI elements are overlapping or are cropped",
)

private val allDefectCategories = defectCategories + listOf("assertion" to "The assertion is not true")

suspend fun findDefects(
aiClient: AI,
screen: ByteArray,
Expand Down Expand Up @@ -126,7 +122,7 @@ object Prediction {
println("--- RAW RESPONSE END ---")
}

val defects = json.decodeFromString<FindDefectsResponse>(aiResponse.response)
val defects = json.decodeFromString<ModelResponse>(aiResponse.response)
return defects.defects
}

Expand All @@ -136,7 +132,7 @@ object Prediction {
assertion: String,
printPrompt: Boolean = false,
printRawResponse: Boolean = false,
): PerformAssertionResult {
): Defect? {
val prompt = buildString {

appendLine(
Expand All @@ -150,22 +146,37 @@ object Prediction {
""".trimMargin("|")
)

append(
"""
|
|RULES:
|* Provide response as a valid JSON, with structure described below.
|* If the assertion is false, the list in the JSON output MUST be empty.
|* If assertion is false:
| * Your response MUST only include a single defect with category "assertion".
| * Provide detailed reasoning to explain why you think the assertion is false.
""".trimMargin("|")
)

// Claude doesn't have a JSON mode as of 21-08-2024
// https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency
// We could do "if (aiClient is Claude)", but actually, this also helps with gpt-4o sometimes
// generating never-ending stream of output.
// generatig never-ending stream of output.
append(
"""
|
|* You must provide result as a valid JSON object, matching this structure:
|
| {
| "result": {
| "passed": "<boolean>",
| "reasoning": "<string>"
| },
| "defect": [
| {
| "category": "assertion",
| "reasoning": "<reasoning, string>"
| },
| ]
| }
|
|The "defects" array MUST contain at most a single JSON object.
|DO NOT output any other information in the JSON object.
""".trimMargin("|")
)
Expand All @@ -184,7 +195,7 @@ object Prediction {
identifier = "perform-assertion",
imageDetail = "high",
images = listOf(screen),
jsonSchema = if (aiClient is OpenAI) json.parseToJsonElement(AI.checkAssertion).jsonObject else null,
jsonSchema = if (aiClient is OpenAI) json.parseToJsonElement(AI.askForDefectsSchema).jsonObject else null,
)

if (printRawResponse) {
Expand All @@ -193,7 +204,7 @@ object Prediction {
println("--- RAW RESPONSE END ---")
}

val result = json.decodeFromString<PerformAssertionResult>(aiResponse.response)
return result
val response = json.decodeFromString<ModelResponse>(aiResponse.response)
return response.defects.firstOrNull()
}
}
25 changes: 0 additions & 25 deletions maestro-ai/src/main/resources/checkAssertion_schema.json

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ class AnsiResultView(
CommandStatus.COMPLETED -> ""
CommandStatus.FAILED -> ""
CommandStatus.RUNNING -> ""
CommandStatus.PENDING -> "\uD83D\uDD32"
CommandStatus.SKIPPED -> "️️"
CommandStatus.PENDING -> "\uD83D\uDD32 " // 🔲
CommandStatus.SKIPPED -> ""
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ data class AssertWithAICommand(
override fun description(): String {
if (label != null) return label

return "Assert no defects with AI: $assertion"
return "Assert with AI: $assertion"
}

override fun evaluateScripts(jsEngine: JsEngine): Command {
Expand Down
15 changes: 6 additions & 9 deletions maestro-orchestra/src/main/java/maestro/orchestra/Orchestra.kt
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,6 @@ class Orchestra(

val defects = Prediction.findDefects(
aiClient = ai,
assertion = null,
screen = imageData.copy().readByteArray(),
previousFalsePositives = listOf(), // TODO(bartekpacia): take it from WorkspaceConfig (or MaestroConfig?)
)
Expand All @@ -363,7 +362,7 @@ class Orchestra(

val word = if (defects.size == 1) "defect" else "defects"
throw MaestroException.AssertionFailure(
"Ffound ${defects.size} possible $word. See the report after the test completes to learn more.",
"Found ${defects.size} possible $word. See the report after the test completes to learn more.",
maestro.viewHierarchy().root,
)
}
Expand All @@ -381,21 +380,19 @@ class Orchestra(
val imageData = Buffer()
maestro.takeScreenshot(imageData, compressed = false)

val defects = Prediction.findDefects(
val defect = Prediction.performAssertion(
aiClient = ai,
assertion = command.assertion,
screen = imageData.copy().readByteArray(),
previousFalsePositives = listOf(), // TODO(bartekpacia): take it from WorkspaceConfig (or MaestroConfig?)
assertion = command.assertion,
)

if (defects.isNotEmpty()) {
onCommandGeneratedOutput(command, defects, imageData)
if (defect != null) {
onCommandGeneratedOutput(command, listOf(defect), imageData)

if (command.optional) throw CommandSkipped

val word = if (defects.size == 1) "defect" else "defects"
throw MaestroException.AssertionFailure(
"Visual AI found ${defects.size} possible $word. See the report to learn more.",
"Assertion failed. See the report to learn more.",
maestro.viewHierarchy().root,
)
}
Expand Down

0 comments on commit 0aed452

Please sign in to comment.