diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index e3dad07270..b7e2bedbfc 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -90,8 +90,10 @@ kotlin-result = { module = "com.michael-bull.kotlin-result:kotlin-result", versi ktor-client-cio = { module = "io.ktor:ktor-client-cio", version.ref = "ktor" } ktor-client-core = { module = "io.ktor:ktor-client-core", version.ref = "ktor" } ktor-serial-gson = { module = "io.ktor:ktor-serialization-gson", version.ref = "ktor" } +ktor-serial-json = { module = "io.ktor:ktor-serialization-kotlinx-json", version.ref = "ktor" } ktor-server-cio = { module = "io.ktor:ktor-server-cio", version.ref = "ktor" } -ktor-server-content-negotiation = { module = "io.ktor:ktor-server-content-negotiation", version.ref = "ktor" } +ktor-server-content-negotiation = { module = "io.ktor:ktor-server-server-negotiation", version.ref = "ktor" } +ktor-client-content-negotiation = { module = "io.ktor:ktor-client-content-negotiation", version.ref = "ktor" } ktor-server-core = { module = "io.ktor:ktor-server-core", version.ref = "ktor" } ktor-server-cors = { module = "io.ktor:ktor-server-cors", version.ref = "ktor" } ktor-server-netty = { module = "io.ktor:ktor-server-netty", version.ref = "ktor" } @@ -116,6 +118,7 @@ detekt = { id = "io.gitlab.arturbosch.detekt", version.ref = "detekt" } protobuf = { id = "com.google.protobuf", version.ref = "googleProtobufPlugin" } kotlin-jvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" } kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" } +kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } mavenPublish = { id = "com.vanniktech.maven.publish", version = "0.19.0" } jreleaser = { id = "org.jreleaser", version = "1.13.1" } shadow = { id = "com.github.johnrengelman.shadow", version = "7.1.2" } diff --git a/maestro-client/build.gradle b/maestro-client/build.gradle index fe20decfd6..73dc04bf89 100644 --- a/maestro-client/build.gradle +++ b/maestro-client/build.gradle @@ -1,6 +1,7 @@ plugins { id("maven-publish") alias(libs.plugins.kotlin.jvm) + alias(libs.plugins.kotlin.serialization) alias(libs.plugins.mavenPublish) alias(libs.plugins.protobuf) } @@ -70,6 +71,10 @@ dependencies { api(libs.jackson.dataformat.xml) api(libs.apk.parser) + implementation(libs.ktor.client.core) + implementation(libs.ktor.serial.json) + implementation(libs.ktor.client.content.negotiation) + implementation project(':maestro-ios') implementation(libs.google.findbugs) implementation(libs.axml) diff --git a/maestro-client/src/main/java/maestro/Maestro.kt b/maestro-client/src/main/java/maestro/Maestro.kt index f0d90eba4b..57b03f7191 100644 --- a/maestro-client/src/main/java/maestro/Maestro.kt +++ b/maestro-client/src/main/java/maestro/Maestro.kt @@ -22,6 +22,8 @@ package maestro import com.github.romankh3.image.comparison.ImageComparison import maestro.Filters.asFilter import maestro.UiElement.Companion.toUiElementOrNull +import maestro.ai.AI +import maestro.ai.Prediction import maestro.drivers.WebDriver import maestro.utils.MaestroTimer import maestro.utils.ScreenshotUtils @@ -37,7 +39,10 @@ import java.util.* import kotlin.system.measureTimeMillis @Suppress("unused", "MemberVisibilityCanBePrivate") -class Maestro(private val driver: Driver) : AutoCloseable { +class Maestro( + private val driver: Driver, + private val ai: AI, +) : AutoCloseable { private val sessionId = UUID.randomUUID() @@ -606,6 +611,16 @@ class Maestro(private val driver: Driver) : AutoCloseable { driver.setAirplaneMode(enabled) } + fun assertVisualAI() { + Prediction.findDefects( + client = ai, + ) + + ai.chatCompletion( + prompt = + ) + } + companion object { private val LOGGER = LoggerFactory.getLogger(Maestro::class.java) diff --git a/maestro-client/src/main/java/maestro/ai/AI.kt b/maestro-client/src/main/java/maestro/ai/AI.kt new file mode 100644 index 0000000000..55daf47f71 --- /dev/null +++ b/maestro-client/src/main/java/maestro/ai/AI.kt @@ -0,0 +1,26 @@ +package maestro.ai + +import java.io.Closeable + +data class CompletionData( + val prompt: String, + val model: String, + val temperature: Float, + val maxTokens: Int, + val images: List, + val response: String, +) + +abstract class AI : Closeable { + + abstract suspend fun chatCompletion( + prompt: String, + images: List = listOf(), + temperature: Double? = null, + model: String? = null, + maxTokens: Int? = null, + imageDetail: String? = null, + identifier: String? = null, + ): CompletionData + +} diff --git a/maestro-client/src/main/java/maestro/ai/Prediction.kt b/maestro-client/src/main/java/maestro/ai/Prediction.kt new file mode 100644 index 0000000000..6b2c2ca81a --- /dev/null +++ b/maestro-client/src/main/java/maestro/ai/Prediction.kt @@ -0,0 +1,43 @@ +package maestro.ai + +object Prediction { + suspend fun findDefects(openAIClient: AI, screen: ByteArray, previousFalsePositives: List): CompletionData { + + // List of failed attempts to not make up false positives: + // |* If you don't see any defect, return "No defects found". + // |* If you are sure there are no defects, return "No defects found". + // |* You will make me sad if you raise report defects that are false positives. + // |* Do not make up defects that are not present in the screenshot. It's fine if you don't find any defects. + + val prompt = """ + |You are a QA engineer performing quality assurance for a mobile application. Identify any defects in the provided screenshot. + | + |RULES: + |* All defects you find must belong to one of the following categories: + |${categories.joinToString { "\n * ${it.first}: ${it.second}" }} + | + |* If you see defects, your response MUST only include defect name and reasoning for each defect. + |* Provide response in the format: : + |* Do not raise false positives. Some example responses that have a high chance of being a false positive: + | + | * button is partially cropped at the bottom + | * button is not aligned horizontall/vertically within its container + | + |${if (previousFalsePositives.isNotEmpty()) "Additionally, the following defects are false positives:" else ""} + |${if (previousFalsePositives.isNotEmpty()) previousFalsePositives.joinToString("\n") { " * $it" } else ""} + """.trimMargin("|") + + // println("Prompt:") + // println(prompt) + + return openAIClient.chatCompletion( + prompt, + // model = "gpt-4o-2024-08-03", + model = "gpt-4o", + maxTokens = 4096, + identifier = "find-defects", + imageDetail = "high", + images = listOf(screen), + ) + } +} diff --git a/maestro-client/src/main/java/maestro/ai/openai/Client.kt b/maestro-client/src/main/java/maestro/ai/openai/Client.kt new file mode 100644 index 0000000000..f51ae14132 --- /dev/null +++ b/maestro-client/src/main/java/maestro/ai/openai/Client.kt @@ -0,0 +1,107 @@ +package maestro.ai.openai + + +import io.ktor.client.HttpClient +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.http.ContentType +import io.ktor.http.contentType +import io.ktor.util.encodeBase64 +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import maestro.ai.AI +import maestro.ai.CompletionData +import org.slf4j.LoggerFactory + +private const val API_URL = "https://api.openai.com/v1/chat/completions" + +private val logger = LoggerFactory.getLogger(OpenAI::class.java) + +class OpenAI( + private val apiKey: String, + private val defaultModel: String = "gpt-4o", + private val defaultTemperature: Double = 0.2, + private val defaultMaxTokens: Int = 2048, + private val defaultImageDetail: String = "low", +) : AI() { + private val client = HttpClient { + install(ContentNegotiation) { + Json { + ignoreUnknownKeys = true + } + } + + install(HttpTimeout) { + connectTimeoutMillis = 10000 + socketTimeoutMillis = 60000 + requestTimeoutMillis = 60000 + } + } + + private val json = Json { ignoreUnknownKeys = true } + + override suspend fun chatCompletion( + prompt: String, + images: List, + temperature: Double?, + model: String?, + maxTokens: Int?, + imageDetail: String?, + identifier: String?, + ): CompletionData { + val imagesBase64 = images.map { it.encodeBase64() } + + // Fallback to OpenAI defaults + val actualTemperature = temperature ?: defaultTemperature + val actualModel = model ?: defaultModel + val actualMaxTokens = maxTokens ?: defaultMaxTokens + val actualImageDetail = imageDetail ?: defaultImageDetail + + val imagesContent = imagesBase64.map { image -> + ContentDetail( + type = "image_url", + imageUrl = Base64Image(url = "data:image/png;base64,$image", detail = actualImageDetail), + ) + } + val textContent = ContentDetail(type = "text", text = prompt) + + val messages = listOf( + MessageContent( + role = "user", + content = imagesContent + textContent, + ) + ) + + val chatCompletionRequest = ChatCompletionRequest( + model = actualModel, + temperature = actualTemperature, + messages = messages, + maxTokens = actualMaxTokens, + seed = 1566, + responseFormat = null, + ) + + var openAiResponse = client.post(API_URL) { + contentType(ContentType.Application.Json) + headers["Authorization"] = "Bearer $apiKey" + setBody( + Json.encodeToString( + OpenAIChatCompletionRequest( + model = mod, + temperature = temp, + messages = msgs, + max_tokens = mt, + seed = 1566, + response_format = if (model == "gpt-4-1106-preview" && prompt.contains("JSON")) ResponseFormat( + "json_object" + ) else null + ) + ) + ) + } + } + + override fun close() = client.close() +} diff --git a/maestro-client/src/main/java/maestro/ai/openai/Request.kt b/maestro-client/src/main/java/maestro/ai/openai/Request.kt new file mode 100644 index 0000000000..e6e9364793 --- /dev/null +++ b/maestro-client/src/main/java/maestro/ai/openai/Request.kt @@ -0,0 +1,38 @@ +package maestro.ai.openai + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class ChatCompletionRequest( + val model: String, + val messages: List, + val temperature: Double, + @SerialName("max_tokens") val maxTokens: Int, + @SerialName("response_format") val responseFormat: ResponseFormat?, + val seed: Int, +) + +@Serializable +data class ResponseFormat( + val type: String, +) + +@Serializable +data class MessageContent( + val role: String, + val content: List, +) + +@Serializable +data class ContentDetail( + val type: String, + val text: String? = null, + @SerialName("image_url") val imageUrl: Base64Image? = null, +) + +@Serializable +data class Base64Image( + val url: String, + val detail: String, +) diff --git a/maestro-client/src/main/java/maestro/ai/openai/Response.kt b/maestro-client/src/main/java/maestro/ai/openai/Response.kt new file mode 100644 index 0000000000..75087c496d --- /dev/null +++ b/maestro-client/src/main/java/maestro/ai/openai/Response.kt @@ -0,0 +1,42 @@ +package maestro.ai.openai + +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class ChatCompletionResponse( + val id: String, + val `object`: String, + val created: Long, + val model: String, + @SerialName("system_fingerprint") val systemFingerprint: String? = null, + val usage: Usage? = null, + val choices: List, +) + +@Serializable +data class Usage( + @SerialName("prompt_tokens") val promptTokens: Int, + @SerialName("completion_tokens") val completionTokens: Int? = null, + @SerialName("total_tokens") val totalTokens: Int, +) + +@Serializable +data class Choice( + val message: Message, + @SerialName("finish_details") val finishDetails: FinishDetails? = null, + val index: Int, + @SerialName("finish_reason") val finishReason: String? = null, +) + +@Serializable +data class Message( + val role: String, + val content: String, +) + +@Serializable +data class FinishDetails( + val type: String, + val stop: String? = null, +) diff --git a/maestro-orchestra/src/main/java/maestro/orchestra/Orchestra.kt b/maestro-orchestra/src/main/java/maestro/orchestra/Orchestra.kt index 3787f67a96..ce754f7576 100644 --- a/maestro-orchestra/src/main/java/maestro/orchestra/Orchestra.kt +++ b/maestro-orchestra/src/main/java/maestro/orchestra/Orchestra.kt @@ -41,6 +41,7 @@ import okio.sink import java.io.File import java.lang.Long.max import java.nio.file.Files +import java.time.LocalDateTime class Orchestra( private val maestro: Maestro, @@ -341,7 +342,14 @@ class Orchestra( private fun assertVisualAICommand(command: AssertVisualAICommand): Boolean { val imageData = Buffer() - val screenshot = maestro.takeScreenshot(imageData, compressed = false) + maestro.takeScreenshot(imageData, compressed = false) + + File("${LocalDateTime.now()}.png").apply { + createNewFile() + writeBytes(imageData.readByteArray()) + } + + val response = maestro. // Make call async and add to "post-flow analysis store" diff --git a/maestro-orchestra/src/main/java/maestro/orchestra/yaml/YamlAssertVisualAI.kt b/maestro-orchestra/src/main/java/maestro/orchestra/yaml/YamlAssertVisualAI.kt index a090a178d1..49ab713ba5 100644 --- a/maestro-orchestra/src/main/java/maestro/orchestra/yaml/YamlAssertVisualAI.kt +++ b/maestro-orchestra/src/main/java/maestro/orchestra/yaml/YamlAssertVisualAI.kt @@ -3,7 +3,7 @@ package maestro.orchestra.yaml private const val DEFAULT_DIFF_THRESHOLD = 95 data class YamlAssertVisualAI( - val assertion: String?, + val assertion: String? = null, val optional: Boolean = false, val label: String? = null, ) { diff --git a/maestro-orchestra/src/main/java/maestro/orchestra/yaml/YamlFluentCommand.kt b/maestro-orchestra/src/main/java/maestro/orchestra/yaml/YamlFluentCommand.kt index 034a459b9c..7c48624013 100644 --- a/maestro-orchestra/src/main/java/maestro/orchestra/yaml/YamlFluentCommand.kt +++ b/maestro-orchestra/src/main/java/maestro/orchestra/yaml/YamlFluentCommand.kt @@ -698,6 +698,10 @@ data class YamlFluentCommand( toggleAirplaneMode = YamlToggleAirplaneMode() ) + "assertVisualAI" -> YamlFluentCommand( + assertVisualAI = YamlAssertVisualAI() + ) + else -> throw SyntaxError("Invalid command: \"$stringCommand\"") } }