-
Notifications
You must be signed in to change notification settings - Fork 273
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
20c8fc7
commit 89ca7ca
Showing
11 changed files
with
295 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
package maestro.ai | ||
|
||
import java.io.Closeable | ||
|
||
/**
 * Bundles the settings of a completion request with the response text.
 *
 * @property prompt the prompt text.
 * @property model the model identifier.
 * @property temperature the sampling temperature.
 * @property maxTokens the token limit.
 * @property images the images as strings (presumably base64-encoded; confirm against the producer).
 * @property response the response text.
 */
data class CompletionData(
    val prompt: String,
    val model: String,
    val temperature: Float,
    val maxTokens: Int,
    val images: List<String>,
    val response: String,
)
|
||
/**
 * Base type for AI backends that answer a text prompt, optionally accompanied by images.
 *
 * The type is [Closeable]; what `close()` releases is defined by each subclass.
 */
abstract class AI : Closeable {

    /**
     * Requests a completion for [prompt].
     *
     * Every optional argument defaults to `null` (an empty list for [images]);
     * how a `null` is interpreted is left to the implementation.
     *
     * @param prompt the text prompt.
     * @param images raw image bytes to send along with the prompt.
     * @param temperature sampling temperature, or `null` to let the implementation decide.
     * @param model model identifier, or `null` to let the implementation decide.
     * @param maxTokens token limit, or `null` to let the implementation decide.
     * @param imageDetail image detail setting, or `null` to let the implementation decide.
     * @param identifier optional caller-supplied label for the request.
     * @return the request settings together with the response text.
     */
    abstract suspend fun chatCompletion(
        prompt: String,
        images: List<ByteArray> = emptyList(),
        temperature: Double? = null,
        model: String? = null,
        maxTokens: Int? = null,
        imageDetail: String? = null,
        identifier: String? = null,
    ): CompletionData

}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package maestro.ai | ||
|
||
/** Prompt-based predictions that run on top of an [AI] backend. */
object Prediction {

    /**
     * Asks the model to identify defects in a screenshot.
     *
     * The prompt lists the allowed defect categories (taken from `categories`, declared
     * outside this file), the expected `<defect name>:<reasoning>` response format, and
     * examples of likely false positives. Entries of [previousFalsePositives] are appended
     * as additional false positives when the list is not empty.
     *
     * The request is sent with model `gpt-4o`, a 4096-token limit, image detail `high`
     * and the identifier `find-defects`.
     *
     * @param openAIClient backend used to run the completion.
     * @param screen raw bytes of the screenshot to inspect.
     * @param previousFalsePositives descriptions of defects already known to be false positives.
     * @return the completion returned by [openAIClient].
     */
    suspend fun findDefects(openAIClient: AI, screen: ByteArray, previousFalsePositives: List<String>): CompletionData {

        // List of failed attempts to not make up false positives:
        // |* If you don't see any defect, return "No defects found".
        // |* If you are sure there are no defects, return "No defects found".
        // |* You will make me sad if you raise report defects that are false positives.
        // |* Do not make up defects that are not present in the screenshot. It's fine if you don't find any defects.

        // One bullet per line. The separator must be explicit: joinToString defaults to ", ",
        // which left a stray ", " at the end of every category line in the prompt.
        val categoryBullets = categories.joinToString(separator = "\n") { " * ${it.first}: ${it.second}" }

        val prompt = """
            |You are a QA engineer performing quality assurance for a mobile application. Identify any defects in the provided screenshot.
            |
            |RULES:
            |* All defects you find must belong to one of the following categories:
            |$categoryBullets
            |
            |* If you see defects, your response MUST only include defect name and reasoning for each defect.
            |* Provide response in the format: <defect name>:<reasoning>
            |* Do not raise false positives. Some example responses that have a high chance of being a false positive:
            |
            | * button is partially cropped at the bottom
            | * button is not aligned horizontall/vertically within its container
            |
            |${if (previousFalsePositives.isNotEmpty()) "Additionally, the following defects are false positives:" else ""}
            |${if (previousFalsePositives.isNotEmpty()) previousFalsePositives.joinToString("\n") { " * $it" } else ""}
        """.trimMargin("|")

        // println("Prompt:")
        // println(prompt)

        return openAIClient.chatCompletion(
            prompt,
            // model = "gpt-4o-2024-08-03",
            model = "gpt-4o",
            maxTokens = 4096,
            identifier = "find-defects",
            imageDetail = "high",
            images = listOf(screen),
        )
    }
}
107 changes: 107 additions & 0 deletions
107
maestro-client/src/main/java/maestro/ai/openai/Client.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
package maestro.ai.openai | ||
|
||
|
||
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.statement.bodyAsText
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.util.encodeBase64
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import maestro.ai.AI
import maestro.ai.CompletionData
import org.slf4j.LoggerFactory
|
||
private const val API_URL = "https://api.openai.com/v1/chat/completions"

private val logger = LoggerFactory.getLogger(OpenAI::class.java)

/**
 * [AI] implementation that posts chat completion requests to [API_URL].
 *
 * The `default*` constructor arguments are used whenever the matching argument of
 * [chatCompletion] is `null`.
 *
 * @param apiKey key sent as a bearer token in the `Authorization` header.
 * @param defaultModel model used when the caller passes none.
 * @param defaultTemperature temperature used when the caller passes none.
 * @param defaultMaxTokens token limit used when the caller passes none.
 * @param defaultImageDetail image detail used when the caller passes none.
 */
class OpenAI(
    private val apiKey: String,
    private val defaultModel: String = "gpt-4o",
    private val defaultTemperature: Double = 0.2,
    private val defaultMaxTokens: Int = 2048,
    private val defaultImageDetail: String = "low",
) : AI() {
    private val client = HttpClient {
        // NOTE(review): this lambda only builds a Json instance and discards it, so it looks
        // like no converter is registered with ContentNegotiation. Request and response bodies
        // are (de)serialized by hand with `json` below, so nothing depends on it - confirm and
        // either register a converter or drop the plugin.
        install(ContentNegotiation) {
            Json {
                ignoreUnknownKeys = true
            }
        }

        install(HttpTimeout) {
            connectTimeoutMillis = 10000
            socketTimeoutMillis = 60000
            requestTimeoutMillis = 60000
        }
    }

    private val json = Json { ignoreUnknownKeys = true }

    /**
     * Sends [prompt] and [images] as a single user message and returns the first choice.
     *
     * Images are base64-encoded and embedded as `data:image/png;base64,...` URLs, placed
     * before the text part of the message. [identifier] is not used by this implementation.
     *
     * @return the effective request settings plus the content of the first returned choice.
     * @throws IllegalStateException if the HTTP status is not successful or the response
     * contains no choices.
     */
    override suspend fun chatCompletion(
        prompt: String,
        images: List<ByteArray>,
        temperature: Double?,
        model: String?,
        maxTokens: Int?,
        imageDetail: String?,
        identifier: String?,
    ): CompletionData {
        val imagesBase64 = images.map { it.encodeBase64() }

        // Fall back to the defaults given at construction time.
        val actualTemperature = temperature ?: defaultTemperature
        val actualModel = model ?: defaultModel
        val actualMaxTokens = maxTokens ?: defaultMaxTokens
        val actualImageDetail = imageDetail ?: defaultImageDetail

        val imagesContent = imagesBase64.map { image ->
            ContentDetail(
                type = "image_url",
                imageUrl = Base64Image(url = "data:image/png;base64,$image", detail = actualImageDetail),
            )
        }
        val textContent = ContentDetail(type = "text", text = prompt)

        val messages = listOf(
            MessageContent(
                role = "user",
                content = imagesContent + textContent,
            )
        )

        val chatCompletionRequest = ChatCompletionRequest(
            model = actualModel,
            temperature = actualTemperature,
            messages = messages,
            maxTokens = actualMaxTokens,
            seed = 1566,
            responseFormat = null,
        )

        val httpResponse = client.post(API_URL) {
            contentType(ContentType.Application.Json)
            headers["Authorization"] = "Bearer $apiKey"
            // Encode with the configured `json` instance and send the request built above.
            setBody(json.encodeToString(chatCompletionRequest))
        }

        // Read the body before checking the status so the error message can include it.
        val body = httpResponse.bodyAsText()
        if (!httpResponse.status.isSuccess()) {
            val message = "Failed to complete request to OpenAI: ${httpResponse.status}, $body"
            logger.error(message)
            error(message)
        }

        // Explicit serializer: uses the member overload, so no extra import is required.
        val chatCompletionResponse = json.decodeFromString(ChatCompletionResponse.serializer(), body)
        val firstChoice = chatCompletionResponse.choices.firstOrNull()
            ?: error("OpenAI response contained no choices")

        return CompletionData(
            prompt = prompt,
            model = actualModel,
            temperature = actualTemperature.toFloat(),
            maxTokens = actualMaxTokens,
            images = imagesBase64,
            response = firstChoice.message.content,
        )
    }

    override fun close() = client.close()
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
package maestro.ai.openai | ||
|
||
import kotlinx.serialization.SerialName | ||
import kotlinx.serialization.Serializable | ||
|
||
/**
 * Serializable body of a chat completion request.
 *
 * [maxTokens] and [responseFormat] are serialized as `max_tokens` and `response_format`.
 */
@Serializable
data class ChatCompletionRequest(
    val model: String,
    val messages: List<MessageContent>,
    val temperature: Double,

    @SerialName("max_tokens")
    val maxTokens: Int,

    @SerialName("response_format")
    val responseFormat: ResponseFormat?,

    val seed: Int,
)
|
||
/** Serializable wrapper around a single response format [type] string. */
@Serializable
data class ResponseFormat(val type: String)
|
||
/**
 * One message of a chat request: the [role] of its author and the list of
 * [content] parts that make up the message.
 */
@Serializable
data class MessageContent(
    val role: String,
    val content: List<ContentDetail>,
)
|
||
/**
 * A single part of a message.
 *
 * [type] names the kind of part; [text] and [imageUrl] are both optional and default
 * to `null`. [imageUrl] is serialized as `image_url`.
 */
@Serializable
data class ContentDetail(
    val type: String,
    val text: String? = null,

    @SerialName("image_url")
    val imageUrl: Base64Image? = null,
)
|
||
/** Image reference consisting of a [url] and a [detail] setting. */
@Serializable
data class Base64Image(
    val url: String,
    val detail: String,
)
42 changes: 42 additions & 0 deletions
42
maestro-client/src/main/java/maestro/ai/openai/Response.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package maestro.ai.openai | ||
|
||
import kotlinx.serialization.SerialName | ||
import kotlinx.serialization.Serializable | ||
|
||
/**
 * Serializable shape of a chat completion response.
 *
 * [systemFingerprint] (serialized as `system_fingerprint`) and [usage] are optional and
 * default to `null`; every other field is required.
 */
@Serializable
data class ChatCompletionResponse(
    val id: String,
    val `object`: String,
    val created: Long,
    val model: String,

    @SerialName("system_fingerprint")
    val systemFingerprint: String? = null,

    val usage: Usage? = null,
    val choices: List<Choice>,
)
|
||
/**
 * Token counts, serialized as `prompt_tokens`, `completion_tokens` and `total_tokens`.
 *
 * [completionTokens] is optional and defaults to `null`.
 */
@Serializable
data class Usage(
    @SerialName("prompt_tokens")
    val promptTokens: Int,

    @SerialName("completion_tokens")
    val completionTokens: Int? = null,

    @SerialName("total_tokens")
    val totalTokens: Int,
)
|
||
/**
 * One entry of a response's choice list.
 *
 * [finishDetails] (`finish_details`) and [finishReason] (`finish_reason`) are optional
 * and default to `null`.
 */
@Serializable
data class Choice(
    val message: Message,

    @SerialName("finish_details")
    val finishDetails: FinishDetails? = null,

    val index: Int,

    @SerialName("finish_reason")
    val finishReason: String? = null,
)
|
||
/** A response message: the [role] of its author and its text [content]. */
@Serializable
data class Message(
    val role: String,
    val content: String,
)
|
||
/** Finish details of a choice: a required [type] and an optional [stop] value (default `null`). */
@Serializable
data class FinishDetails(
    val type: String,
    val stop: String? = null,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters