diff --git a/maestro-ai/README.md b/maestro-ai/README.md new file mode 100644 index 0000000000..81118c7fd1 --- /dev/null +++ b/maestro-ai/README.md @@ -0,0 +1,19 @@ +# maestro-ai + +This project implements AI support for use in Maestro. + +It's both a library and a demo-app executable. + +### Demo app + +Build it: + +```console +./gradlew :maestro-ai:installDist +``` + +then learn how to use it: + +```console +./maestro-ai/build/install/maestro-ai-demo/bin/maestro-ai-demo --help +``` diff --git a/maestro-ai/src/main/java/maestro/ai/DemoApp.kt b/maestro-ai/src/main/java/maestro/ai/DemoApp.kt index 25f42db65f..0e00781c84 100644 --- a/maestro-ai/src/main/java/maestro/ai/DemoApp.kt +++ b/maestro-ai/src/main/java/maestro/ai/DemoApp.kt @@ -8,6 +8,9 @@ import com.github.ajalt.clikt.parameters.options.flag import com.github.ajalt.clikt.parameters.options.option import com.github.ajalt.clikt.parameters.types.float import com.github.ajalt.clikt.parameters.types.path +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import maestro.ai.antrophic.Claude @@ -52,9 +55,7 @@ fun main(args: Array) = DemoApp().main(args) * - bar_good_2.png */ class DemoApp : CliktCommand() { - private val inputFiles: List by argument(help = "screenshots to use") - .path(mustExist = true) - .multiple() + private val inputFiles: List by argument(help = "screenshots to use").path(mustExist = true).multiple() private val model: String by option(help = "LLM to use").default("gpt-4o-2024-08-06") @@ -64,19 +65,18 @@ class DemoApp : CliktCommand() { private val temperature: Float by option(help = "Temperature for LLM").float().default(0.2f) + private val parallel: Boolean by option(help = "Run in parallel. May get rate limited").flag() + override fun run() = runBlocking { val apiKey = System.getenv("MAESTRO_CLI_AI_KEY") require(apiKey != null) { "OpenAI API key is not provided" } - val testCases = inputFiles - .map { it.toFile() } - .map { file -> + val testCases = inputFiles.map { it.toFile() }.map { file -> require(!file.isDirectory) { "Provided file is a directory, not a file" } require(file.exists()) { "Provided file does not exist" } require(file.extension == "png") { "Provided file is not a PNG file" } file - } - .map { file -> + }.map { file -> val filename = file.nameWithoutExtension val parts = filename.split("_") require(parts.size == 3) { "Screenshot name is invalid: ${file.name}" } @@ -86,15 +86,7 @@ class DemoApp : CliktCommand() { val promptFile = "${file.parent}/${appName}_${parts[1]}_$index.txt" println("Prompt file: $promptFile") - val prompt = File(promptFile) - .run { - if (exists()) { - readText() - } else { - println("There is no prompt for ${file.path}") - null - } - } + val prompt = File(promptFile).run { if (exists()) readText() else null } TestCase( screenshot = file, @@ -121,12 +113,10 @@ class DemoApp : CliktCommand() { else -> throw IllegalArgumentException("Unknown model: $model") } - // println("---\nRunning ${testCases.size} test cases\n---") - testCases.forEach { testCase -> val bytes = testCase.screenshot.readBytes() - launch { + val job = async { val defects = Prediction.findDefects( aiClient = aiClient, screen = bytes, @@ -138,7 +128,11 @@ class DemoApp : CliktCommand() { verify(testCase, defects) } + + if (parallel) job.await() } + + println("Exited, bye!") } private fun verify(testCase: TestCase, defects: List) { @@ -147,9 +141,9 @@ class DemoApp : CliktCommand() { if (defects.isNotEmpty()) { println( """ - PASS ${testCase.screenshot.name}: ${defects.size} defects found (as expected) - ${defects.joinToString("\n") { "\t* ${it.category}: ${it.reasoning}" }} - """.trimIndent() + PASS ${testCase.screenshot.name}: ${defects.size} defects found (as expected) + ${defects.joinToString("\n") { "\t* ${it.category}: ${it.reasoning}" }} + """.trimIndent() ) } else { println("FAIL ${testCase.screenshot.name} false-negative: No defects found but some were expected") @@ -160,15 +154,15 @@ class DemoApp : CliktCommand() { if (defects.isEmpty()) { println( """ - PASS ${testCase.screenshot.name}: No defects found (as expected) - """.trimIndent() + PASS ${testCase.screenshot.name}: No defects found (as expected) + """.trimIndent() ) } else { println( """ - FAIL ${testCase.screenshot.name} false-positive: ${defects.size} defects found but none were expected - ${defects.joinToString("\n") { "\t* ${it.category}: ${it.reasoning}" }} - """.trimIndent() + FAIL ${testCase.screenshot.name} false-positive: ${defects.size} defects found but none were expected + ${defects.joinToString("\n") { "\t* ${it.category}: ${it.reasoning}" }} + """.trimIndent() ) } } diff --git a/maestro-ai/src/main/java/maestro/ai/Prediction.kt b/maestro-ai/src/main/java/maestro/ai/Prediction.kt index 2eba4b1fa3..e84b3326e2 100644 --- a/maestro-ai/src/main/java/maestro/ai/Prediction.kt +++ b/maestro-ai/src/main/java/maestro/ai/Prediction.kt @@ -49,23 +49,33 @@ object Prediction { """.trimIndent() ) - if (assertion != null) { - appendLine("Additionally, you must ensure the following assertion is true: $assertion") - } - append( """ + | |RULES: |* All defects you find must belong to one of the following categories: |${categories.joinToString(separator = "\n") { " * ${it.first}: ${it.second}" }} |* If you see defects, your response MUST only include defect name and detailed reasoning for each defect. |* Provide response as a list of JSON objects, each representing : |* Do not raise false positives. Some example responses that have a high chance of being a false positive: - |* button is partially cropped at the bottom - |* button is not aligned horizontally/vertically within its container + | * button is partially cropped at the bottom + | * button is not aligned horizontally/vertically within its container """.trimMargin("|") ) + if (assertion != null) { + append( + """ + | + | + |Additionally, if the following assertion isn't true, consider it as a defect with category "assertion": + | + | "${assertion.removeSuffix("\n")}" + | + |""".trimMargin("|") + ) + } + // Claude doesn't have a JSON mode as of 21-08-2024 // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency if (aiClient is Claude) { @@ -98,6 +108,8 @@ object Prediction { appendLine(" * $falsePositive") } } + + appendLine("Be brief.") } if (printPrompt) { diff --git a/maestro-ai/src/main/java/maestro/ai/openai/Client.kt b/maestro-ai/src/main/java/maestro/ai/openai/Client.kt index 1c00206326..a5e3cceca0 100644 --- a/maestro-ai/src/main/java/maestro/ai/openai/Client.kt +++ b/maestro-ai/src/main/java/maestro/ai/openai/Client.kt @@ -10,6 +10,7 @@ import io.ktor.http.ContentType import io.ktor.http.contentType import io.ktor.http.isSuccess import io.ktor.util.encodeBase64 +import kotlinx.serialization.SerializationException import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonObject @@ -103,8 +104,15 @@ class OpenAI( throw Exception("Failed to complete request to OpenAI: ${httpResponse.status}, $body") } + print(body) + json.decodeFromString(body) - } catch (e: Exception) { + } catch (e: SerializationException) { + logger.error("Failed to parse response from OpenAI", e) + logger.error("Response body: ${e.message}") + throw e + } + catch (e: Exception) { logger.error("Failed to complete request to OpenAI", e) throw e }