Skip to content

Commit

Permalink
more progress
Browse files Browse the repository at this point in the history
  • Loading branch information
bartekpacia committed Aug 9, 2024
1 parent 20c8fc7 commit 89ca7ca
Show file tree
Hide file tree
Showing 11 changed files with 295 additions and 4 deletions.
5 changes: 4 additions & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ kotlin-result = { module = "com.michael-bull.kotlin-result:kotlin-result", versi
ktor-client-cio = { module = "io.ktor:ktor-client-cio", version.ref = "ktor" }
ktor-client-core = { module = "io.ktor:ktor-client-core", version.ref = "ktor" }
ktor-serial-gson = { module = "io.ktor:ktor-serialization-gson", version.ref = "ktor" }
ktor-serial-json = { module = "io.ktor:ktor-serialization-kotlinx-json", version.ref = "ktor" }
ktor-server-cio = { module = "io.ktor:ktor-server-cio", version.ref = "ktor" }
ktor-server-content-negotiation = { module = "io.ktor:ktor-server-content-negotiation", version.ref = "ktor" }
ktor-server-content-negotiation = { module = "io.ktor:ktor-server-content-negotiation", version.ref = "ktor" }
ktor-client-content-negotiation = { module = "io.ktor:ktor-client-content-negotiation", version.ref = "ktor" }
ktor-server-core = { module = "io.ktor:ktor-server-core", version.ref = "ktor" }
ktor-server-cors = { module = "io.ktor:ktor-server-cors", version.ref = "ktor" }
ktor-server-netty = { module = "io.ktor:ktor-server-netty", version.ref = "ktor" }
Expand All @@ -116,6 +118,7 @@ detekt = { id = "io.gitlab.arturbosch.detekt", version.ref = "detekt" }
protobuf = { id = "com.google.protobuf", version.ref = "googleProtobufPlugin" }
kotlin-jvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" }
kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }
mavenPublish = { id = "com.vanniktech.maven.publish", version = "0.19.0" }
jreleaser = { id = "org.jreleaser", version = "1.13.1" }
shadow = { id = "com.github.johnrengelman.shadow", version = "7.1.2" }
5 changes: 5 additions & 0 deletions maestro-client/build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Gradle plugins applied to the maestro-client module.
plugins {
id("maven-publish")
alias(libs.plugins.kotlin.jvm)
// Kotlin serialization compiler plugin (`kotlin-serialization` in the version catalog).
alias(libs.plugins.kotlin.serialization)
alias(libs.plugins.mavenPublish)
alias(libs.plugins.protobuf)
}
Expand Down Expand Up @@ -70,6 +71,10 @@ dependencies {
api(libs.jackson.dataformat.xml)
api(libs.apk.parser)

implementation(libs.ktor.client.core)
implementation(libs.ktor.serial.json)
implementation(libs.ktor.client.content.negotiation)

implementation project(':maestro-ios')
implementation(libs.google.findbugs)
implementation(libs.axml)
Expand Down
17 changes: 16 additions & 1 deletion maestro-client/src/main/java/maestro/Maestro.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ package maestro
import com.github.romankh3.image.comparison.ImageComparison
import maestro.Filters.asFilter
import maestro.UiElement.Companion.toUiElementOrNull
import maestro.ai.AI
import maestro.ai.Prediction
import maestro.drivers.WebDriver
import maestro.utils.MaestroTimer
import maestro.utils.ScreenshotUtils
Expand All @@ -37,7 +39,10 @@ import java.util.*
import kotlin.system.measureTimeMillis

@Suppress("unused", "MemberVisibilityCanBePrivate")
class Maestro(private val driver: Driver) : AutoCloseable {
class Maestro(
private val driver: Driver,
private val ai: AI,
) : AutoCloseable {

private val sessionId = UUID.randomUUID()

Expand Down Expand Up @@ -606,6 +611,16 @@ class Maestro(private val driver: Driver) : AutoCloseable {
driver.setAirplaneMode(enabled)
}

/**
 * Takes an uncompressed screenshot of the current screen and asks the configured [AI] client
 * to look for visual defects in it.
 *
 * This call blocks the calling thread until the AI client has answered.
 *
 * @param previousFalsePositives defect descriptions the model should not report again.
 * @return the completion returned by [Prediction.findDefects].
 */
fun assertVisualAI(previousFalsePositives: List<String> = emptyList()): maestro.ai.CompletionData {
    // Fully-qualified names are used so this method does not depend on imports outside its own lines.
    val imageData = okio.Buffer()
    takeScreenshot(imageData, compressed = false)
    val screen = imageData.readByteArray()

    // Prediction.findDefects is a suspend function; bridge it into this blocking method.
    return kotlinx.coroutines.runBlocking {
        Prediction.findDefects(
            openAIClient = ai,
            screen = screen,
            previousFalsePositives = previousFalsePositives,
        )
    }
}

companion object {

private val LOGGER = LoggerFactory.getLogger(Maestro::class.java)
Expand Down
26 changes: 26 additions & 0 deletions maestro-client/src/main/java/maestro/ai/AI.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package maestro.ai

import java.io.Closeable

/**
 * Value returned by [AI.chatCompletion].
 *
 * Presumably echoes the parameters a completion was requested with, plus the model's
 * reply in [response] — TODO confirm against the code that constructs it.
 *
 * NOTE(review): [temperature] is a Float here while [AI.chatCompletion] takes a Double, and
 * [images] is a list of strings while the request takes raw bytes — confirm this is intended.
 */
data class CompletionData(
val prompt: String,
val model: String,
val temperature: Float,
val maxTokens: Int,
val images: List<String>,
val response: String,
)

/**
 * Base class for AI backends that can answer a prompt, optionally accompanied by images.
 *
 * Implementations are [Closeable] and must provide [close].
 */
abstract class AI : Closeable {

/**
 * Sends [prompt] and [images] to the backing model and returns the completion.
 *
 * @param prompt text sent to the model.
 * @param images raw image bytes to attach; defaults to no images.
 * @param temperature sampling temperature, or null to let the implementation choose.
 * @param model model name, or null to let the implementation choose.
 * @param maxTokens token limit for the answer, or null to let the implementation choose.
 * @param imageDetail detail level requested for attached images, or null to let the implementation choose.
 * @param identifier label for the request (for example "find-defects"); how it is used is
 * up to the implementation — TODO confirm intended purpose.
 */
abstract suspend fun chatCompletion(
prompt: String,
images: List<ByteArray> = listOf(),
temperature: Double? = null,
model: String? = null,
maxTokens: Int? = null,
imageDetail: String? = null,
identifier: String? = null,
): CompletionData

}
43 changes: 43 additions & 0 deletions maestro-client/src/main/java/maestro/ai/Prediction.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package maestro.ai

/**
 * Prompt-based predictions that are answered by an [AI] client.
 */
object Prediction {

    /**
     * Defect categories the model may report, as `name to description` pairs.
     *
     * NOTE(review): the prompt referenced `categories` without any definition in this file;
     * this initial set is a starting point — adjust it to the defects that should be detected.
     */
    private val categories: List<Pair<String, String>> = listOf(
        "localization" to "Inconsistent use of language, for example mixed English and Portuguese",
        "layout" to "Some UI elements are overlapping or are cropped",
    )

    /**
     * Asks the model to identify visual defects in a screenshot.
     *
     * @param openAIClient client used to run the completion.
     * @param screen raw bytes of the screenshot to inspect.
     * @param previousFalsePositives defect descriptions the model is told to treat as false positives.
     * @return the completion produced by [openAIClient].
     */
    suspend fun findDefects(
        openAIClient: AI,
        screen: ByteArray,
        previousFalsePositives: List<String> = emptyList(),
    ): CompletionData {
        return openAIClient.chatCompletion(
            prompt = buildFindDefectsPrompt(previousFalsePositives),
            // model = "gpt-4o-2024-08-03",
            model = "gpt-4o",
            maxTokens = 4096,
            identifier = "find-defects",
            imageDetail = "high",
            images = listOf(screen),
        )
    }

    /**
     * Builds the prompt sent by [findDefects].
     *
     * The text is assembled line by line rather than with `trimMargin`, so caller-supplied
     * false positives can never be mistaken for margin-prefixed template lines.
     */
    private fun buildFindDefectsPrompt(previousFalsePositives: List<String>): String {
        // List of failed attempts to not make up false positives:
        // * If you don't see any defect, return "No defects found".
        // * If you are sure there are no defects, return "No defects found".
        // * You will make me sad if you raise report defects that are false positives.
        // * Do not make up defects that are not present in the screenshot. It's fine if you don't find any defects.

        return buildString {
            appendLine("You are a QA engineer performing quality assurance for a mobile application. Identify any defects in the provided screenshot.")
            appendLine()
            appendLine("RULES:")
            appendLine("* All defects you find must belong to one of the following categories:")
            // One bullet per category; no separator is needed because each entry is its own line.
            categories.forEach { (name, description) -> appendLine("  * $name: $description") }
            appendLine()
            appendLine("* If you see defects, your response MUST only include defect name and reasoning for each defect.")
            appendLine("* Provide response in the format: <defect name>:<reasoning>")
            appendLine("* Do not raise false positives. Some example responses that have a high chance of being a false positive:")
            appendLine()
            appendLine("  * button is partially cropped at the bottom")
            appendLine("  * button is not aligned horizontally/vertically within its container")

            if (previousFalsePositives.isNotEmpty()) {
                appendLine()
                appendLine("Additionally, the following defects are false positives:")
                previousFalsePositives.forEach { appendLine("  * $it") }
            }
        }.trimEnd()
    }
}
107 changes: 107 additions & 0 deletions maestro-client/src/main/java/maestro/ai/openai/Client.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package maestro.ai.openai


import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.statement.bodyAsText
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.serialization.kotlinx.json.json
import io.ktor.util.encodeBase64
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import maestro.ai.AI
import maestro.ai.CompletionData
import org.slf4j.LoggerFactory

// Endpoint that chat completion requests are POSTed to.
private const val API_URL = "https://api.openai.com/v1/chat/completions"

// SLF4J logger named after the OpenAI client class.
private val logger = LoggerFactory.getLogger(OpenAI::class.java)

/**
 * [AI] implementation that talks to the OpenAI chat completions endpoint.
 *
 * @param apiKey sent as a bearer token in the `Authorization` header of every request.
 * @param defaultModel model used when a call passes no `model`.
 * @param defaultTemperature temperature used when a call passes no `temperature`.
 * @param defaultMaxTokens token limit used when a call passes no `maxTokens`.
 * @param defaultImageDetail image detail level used when a call passes no `imageDetail`.
 */
class OpenAI(
    private val apiKey: String,
    private val defaultModel: String = "gpt-4o",
    private val defaultTemperature: Double = 0.2,
    private val defaultMaxTokens: Int = 2048,
    private val defaultImageDetail: String = "low",
) : AI() {

    // Declared before [client]: the HttpClient configuration below reads it while the client is constructed.
    private val jsonFormat = Json { ignoreUnknownKeys = true }

    private val client = HttpClient {
        install(ContentNegotiation) {
            // Register the JSON converter; a bare `Json { ... }` expression here is built and then discarded.
            json(jsonFormat)
        }

        install(HttpTimeout) {
            connectTimeoutMillis = 10000
            socketTimeoutMillis = 60000
            requestTimeoutMillis = 60000
        }
    }

    /**
     * Sends [prompt] together with [images] as a single user message and returns the first choice.
     *
     * Null tuning parameters fall back to the defaults given to this client's constructor.
     * Images are base64-encoded and sent as `data:image/png;base64,` URLs.
     *
     * @throws IllegalStateException if the API answers with a non-success status or returns no choices.
     */
    override suspend fun chatCompletion(
        prompt: String,
        images: List<ByteArray>,
        temperature: Double?,
        model: String?,
        maxTokens: Int?,
        imageDetail: String?,
        identifier: String?,
    ): CompletionData {
        val imagesBase64 = images.map { it.encodeBase64() }

        // Fall back to this client's constructor defaults.
        val actualTemperature = temperature ?: defaultTemperature
        val actualModel = model ?: defaultModel
        val actualMaxTokens = maxTokens ?: defaultMaxTokens
        val actualImageDetail = imageDetail ?: defaultImageDetail

        val imagesContent = imagesBase64.map { image ->
            ContentDetail(
                type = "image_url",
                imageUrl = Base64Image(url = "data:image/png;base64,$image", detail = actualImageDetail),
            )
        }
        val textContent = ContentDetail(type = "text", text = prompt)

        val chatCompletionRequest = ChatCompletionRequest(
            model = actualModel,
            temperature = actualTemperature,
            messages = listOf(
                MessageContent(
                    role = "user",
                    content = imagesContent + textContent,
                ),
            ),
            maxTokens = actualMaxTokens,
            seed = 1566,
            responseFormat = null,
        )

        val httpResponse = client.post(API_URL) {
            contentType(ContentType.Application.Json)
            headers["Authorization"] = "Bearer $apiKey"
            setBody(jsonFormat.encodeToString(chatCompletionRequest))
        }

        // Read the body before checking the status so that API error details end up in the failure message.
        val responseBody = httpResponse.bodyAsText()
        if (!httpResponse.status.isSuccess()) {
            val message = "Failed to complete request to OpenAI: ${httpResponse.status}, $responseBody"
            logger.error(message)
            throw IllegalStateException(message)
        }

        val chatCompletionResponse = jsonFormat.decodeFromString<ChatCompletionResponse>(responseBody)
        val answer = chatCompletionResponse.choices.firstOrNull()?.message?.content
            ?: throw IllegalStateException("OpenAI response contained no choices")

        return CompletionData(
            prompt = prompt,
            model = actualModel,
            // CompletionData stores the temperature as a Float.
            temperature = actualTemperature.toFloat(),
            maxTokens = actualMaxTokens,
            images = imagesBase64,
            response = answer,
        )
    }

    /** Closes the underlying HTTP client. */
    override fun close() = client.close()
}
38 changes: 38 additions & 0 deletions maestro-client/src/main/java/maestro/ai/openai/Request.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package maestro.ai.openai

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

/**
 * Serializable request body for the chat completions endpoint.
 *
 * [maxTokens] and [responseFormat] are serialized as `max_tokens` and `response_format`.
 */
@Serializable
data class ChatCompletionRequest(
val model: String,
val messages: List<MessageContent>,
val temperature: Double,
@SerialName("max_tokens") val maxTokens: Int,
@SerialName("response_format") val responseFormat: ResponseFormat?,
val seed: Int,
)

/**
 * Value of [ChatCompletionRequest.responseFormat]; carries a single [type] string.
 */
@Serializable
data class ResponseFormat(
val type: String,
)

/**
 * One message of a [ChatCompletionRequest]: the sender [role] and the content parts it carries.
 */
@Serializable
data class MessageContent(
val role: String,
val content: List<ContentDetail>,
)

/**
 * One part of a message's content.
 *
 * The OpenAI client creates parts with [type] "text" (setting [text]) and with [type]
 * "image_url" (setting [imageUrl]); the field that does not apply keeps its null default.
 * [imageUrl] is serialized as `image_url`.
 */
@Serializable
data class ContentDetail(
val type: String,
val text: String? = null,
@SerialName("image_url") val imageUrl: Base64Image? = null,
)

/**
 * Image reference attached to a [ContentDetail].
 *
 * The OpenAI client fills [url] with a `data:image/png;base64,` URL and [detail] with the
 * requested image detail level.
 */
@Serializable
data class Base64Image(
val url: String,
val detail: String,
)
42 changes: 42 additions & 0 deletions maestro-client/src/main/java/maestro/ai/openai/Response.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package maestro.ai.openai

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

/**
 * Serializable model of a chat completion response — presumably mirrors the JSON returned by
 * the chat completions endpoint (TODO confirm against the API reference).
 *
 * [systemFingerprint] (serialized as `system_fingerprint`) and [usage] default to null, so they
 * may be absent when decoding; every other property is required.
 */
@Serializable
data class ChatCompletionResponse(
val id: String,
val `object`: String,
val created: Long,
val model: String,
@SerialName("system_fingerprint") val systemFingerprint: String? = null,
val usage: Usage? = null,
val choices: List<Choice>,
)

/**
 * Token counts attached to a [ChatCompletionResponse].
 *
 * Properties are serialized as `prompt_tokens`, `completion_tokens` and `total_tokens`;
 * only [completionTokens] has a default and may be absent when decoding.
 */
@Serializable
data class Usage(
@SerialName("prompt_tokens") val promptTokens: Int,
@SerialName("completion_tokens") val completionTokens: Int? = null,
@SerialName("total_tokens") val totalTokens: Int,
)

/**
 * One entry of [ChatCompletionResponse.choices].
 *
 * [finishDetails] (`finish_details`) and [finishReason] (`finish_reason`) default to null and
 * may be absent when decoding; [message] and [index] are required.
 */
@Serializable
data class Choice(
val message: Message,
@SerialName("finish_details") val finishDetails: FinishDetails? = null,
val index: Int,
@SerialName("finish_reason") val finishReason: String? = null,
)

/**
 * Message carried by a [Choice]: the sender [role] and its text [content].
 *
 * NOTE(review): [content] is non-nullable, so decoding fails if a response carries a null or
 * missing `content` — confirm the endpoint always populates it.
 */
@Serializable
data class Message(
val role: String,
val content: String,
)

/**
 * Optional details attached to a [Choice] as `finish_details`.
 *
 * [type] is required; [stop] defaults to null and may be absent when decoding.
 */
@Serializable
data class FinishDetails(
val type: String,
val stop: String? = null,
)
10 changes: 9 additions & 1 deletion maestro-orchestra/src/main/java/maestro/orchestra/Orchestra.kt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import okio.sink
import java.io.File
import java.lang.Long.max
import java.nio.file.Files
import java.time.LocalDateTime

class Orchestra(
private val maestro: Maestro,
Expand Down Expand Up @@ -341,7 +342,14 @@ class Orchestra(

private fun assertVisualAICommand(command: AssertVisualAICommand): Boolean {
val imageData = Buffer()
val screenshot = maestro.takeScreenshot(imageData, compressed = false)
maestro.takeScreenshot(imageData, compressed = false)

File("${LocalDateTime.now()}.png").apply {
createNewFile()
writeBytes(imageData.readByteArray())
}

val response = maestro.

// Make call async and add to "post-flow analysis store"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package maestro.orchestra.yaml
private const val DEFAULT_DIFF_THRESHOLD = 95

data class YamlAssertVisualAI(
val assertion: String?,
val assertion: String? = null,
val optional: Boolean = false,
val label: String? = null,
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,10 @@ data class YamlFluentCommand(
toggleAirplaneMode = YamlToggleAirplaneMode()
)

"assertVisualAI" -> YamlFluentCommand(
assertVisualAI = YamlAssertVisualAI()
)

else -> throw SyntaxError("Invalid command: \"$stringCommand\"")
}
}
Expand Down

0 comments on commit 89ca7ca

Please sign in to comment.