Skip to content

Commit

Permalink
Add code execution support (#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Aug 1, 2024
1 parent 672fdb7 commit 2bf9fe5
Show file tree
Hide file tree
Showing 11 changed files with 644 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class FunctionCallingViewModel: ObservableObject {
case let .functionCall(functionCall):
messages.insert(functionCall.chatMessage(), at: messages.count - 1)
functionCalls.append(functionCall)
case .data, .fileData, .functionResponse:
case .data, .fileData, .functionResponse, .executableCode, .codeExecutionResult:
fatalError("Unsupported response content.")
}
}
Expand Down
3 changes: 2 additions & 1 deletion Sources/GoogleAI/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ public class Chat {
case let .text(str):
combinedText += str

case .data, .fileData, .functionCall, .functionResponse:
case .data, .fileData, .functionCall, .functionResponse, .executableCode,
.codeExecutionResult:
// Don't combine it, just add to the content. If there's any text pending, add that as
// a part.
if !combinedText.isEmpty {
Expand Down
85 changes: 84 additions & 1 deletion Sources/GoogleAI/FunctionCalling.swift
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ public struct Tool {
/// A list of `FunctionDeclarations` available to the model.
let functionDeclarations: [FunctionDeclaration]?

/// Enables the model to execute code as part of generation.
let codeExecution: CodeExecution?

/// Constructs a new `Tool`.
///
/// - Parameters:
Expand All @@ -172,8 +175,11 @@ public struct Tool {
/// populating ``FunctionCall`` in the response. The next conversation turn may contain a
/// ``FunctionResponse`` in ``ModelContent/Part/functionResponse(_:)`` with the
/// ``ModelContent/role`` "function", providing generation context for the next model turn.
public init(functionDeclarations: [FunctionDeclaration]?) {
/// - codeExecution: Enables the model to execute code as part of generation, if provided.
public init(functionDeclarations: [FunctionDeclaration]? = nil,
codeExecution: CodeExecution? = nil) {
self.functionDeclarations = functionDeclarations
self.codeExecution = codeExecution
}
}

Expand Down Expand Up @@ -244,6 +250,55 @@ public struct FunctionResponse: Equatable {
}
}

/// A tool that runs code generated by the model and automatically hands the result back to the
/// model.
///
/// The type declares no fields. ``ExecutableCode`` and ``CodeExecutionResult`` values are produced
/// only while this tool is in use.
public struct CodeExecution {
  /// Creates a `CodeExecution` tool.
  public init() {
    // Intentionally empty: there are no stored properties to initialize.
  }
}

/// Model-generated code that is intended to be executed, with the result returned to the model.
///
/// Produced only when the ``CodeExecution`` tool is in use; in that case the code is executed
/// automatically and a matching ``CodeExecutionResult`` is generated as well.
public struct ExecutableCode: Equatable {
  /// The programming language of the ``code``.
  public let language: String

  /// The code to be executed.
  public let code: String

  /// Creates an `ExecutableCode` value from its two fields.
  ///
  /// - Parameters:
  ///   - language: The programming language of `code`.
  ///   - code: The code to be executed.
  init(language: String, code: String) {
    self.language = language
    self.code = code
  }
}

/// The result produced by executing an ``ExecutableCode``.
///
/// Generated only when the ``CodeExecution`` tool is in use, and always follows a part that
/// contains the ``ExecutableCode``.
public struct CodeExecutionResult: Equatable {
  /// Outcome of the code execution.
  public let outcome: Outcome

  /// Contains `stdout` when code execution is successful, `stderr` or other description otherwise.
  public let output: String

  /// Creates a `CodeExecutionResult` from its two fields.
  ///
  /// - Parameters:
  ///   - outcome: Outcome of the code execution.
  ///   - output: Text output associated with the execution.
  init(outcome: Outcome, output: String) {
    self.outcome = outcome
    self.output = output
  }

  /// The possible outcomes of a code execution, keyed by their string raw values.
  public enum Outcome: String {
    /// The provided code execution outcome was not recognized.
    case unknown = "OUTCOME_UNKNOWN"

    /// Unspecified status; this value should not be used.
    case unspecified = "OUTCOME_UNSPECIFIED"

    /// Code execution completed successfully.
    case ok = "OUTCOME_OK"

    /// Code execution finished but with a failure; ``CodeExecutionResult/output`` should contain
    /// the failure details from `stderr`.
    case failed = "OUTCOME_FAILED"

    /// Code execution ran for too long and was cancelled. A partial
    /// ``CodeExecutionResult/output`` may or may not be present.
    case deadlineExceeded = "OUTCOME_DEADLINE_EXCEEDED"
  }
}

// MARK: - Codable Conformance

extension FunctionCall: Decodable {
Expand Down Expand Up @@ -293,3 +348,31 @@ extension FunctionCallingConfig.Mode: Encodable {}
extension ToolConfig: Encodable {}

extension FunctionResponse: Encodable {}

// Compiler-synthesized `Encodable` conformance; `CodeExecution` declares no stored properties.
extension CodeExecution: Encodable {}

// Compiler-synthesized `Codable` conformance covering the `language` and `code` properties.
extension ExecutableCode: Codable {}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension CodeExecutionResult.Outcome: Codable {
  /// Decodes an outcome from a single string value.
  ///
  /// A string that does not match any known raw value is logged and mapped to ``unknown`` rather
  /// than causing decoding to fail.
  ///
  /// - Parameter decoder: The decoder to read the single string value from.
  /// - Throws: Any error raised while obtaining the single-value container or decoding a `String`
  ///   from it.
  public init(from decoder: any Decoder) throws {
    let singleValue = try decoder.singleValueContainer()
    let value = try singleValue.decode(String.self)

    if let recognizedOutcome = CodeExecutionResult.Outcome(rawValue: value) {
      self = recognizedOutcome
    } else {
      // Tolerate raw values this enum does not list instead of throwing.
      Logging.default
        .error("[GoogleGenerativeAI] Unrecognized Outcome with value \"\(value)\".")
      self = .unknown
    }
  }
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension CodeExecutionResult: Codable {
  /// Decodes a result from a keyed container.
  ///
  /// `outcome` is required. `output` is optional in the payload and falls back to an empty string
  /// when the key is absent or null.
  ///
  /// - Parameter decoder: The decoder to read the keyed container from.
  /// - Throws: Any error raised while obtaining the container or decoding either field.
  public init(from decoder: any Decoder) throws {
    let fields = try decoder.container(keyedBy: CodingKeys.self)

    // Decode in the same order as the stored properties: outcome first, then output.
    let decodedOutcome = try fields.decode(Outcome.self, forKey: .outcome)
    let decodedOutput = try fields.decodeIfPresent(String.self, forKey: .output)

    self.outcome = decodedOutcome
    self.output = decodedOutput ?? ""
  }
}
21 changes: 18 additions & 3 deletions Sources/GoogleAI/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,31 @@ public struct GenerateContentResponse {
return nil
}
let textValues: [String] = candidate.content.parts.compactMap { part in
guard case let .text(text) = part else {
switch part {
case let .text(text):
return text
case let .executableCode(executableCode):
let codeBlockLanguage: String
if executableCode.language == "LANGUAGE_UNSPECIFIED" {
codeBlockLanguage = ""
} else {
codeBlockLanguage = executableCode.language.lowercased()
}
return "```\(codeBlockLanguage)\n\(executableCode.code)\n```"
case let .codeExecutionResult(codeExecutionResult):
if codeExecutionResult.output.isEmpty {
return nil
}
return "```\n\(codeExecutionResult.output)\n```"
case .data, .fileData, .functionCall, .functionResponse:
return nil
}
return text
}
guard textValues.count > 0 else {
Logging.default.error("Could not get a text part from the first candidate.")
return nil
}
return textValues.joined(separator: " ")
return textValues.joined(separator: "\n")
}

/// Returns function calls found in any `Part`s of the first candidate of the response, if any.
Expand Down
19 changes: 19 additions & 0 deletions Sources/GoogleAI/ModelContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ public struct ModelContent: Equatable {
/// A response to a function call.
case functionResponse(FunctionResponse)

/// Code generated by the model that is meant to be executed.
case executableCode(ExecutableCode)

/// Result of executing the ``ExecutableCode``.
case codeExecutionResult(CodeExecutionResult)

// MARK: Convenience Initializers

/// Convenience function for populating a Part with JPEG data.
Expand Down Expand Up @@ -129,6 +135,8 @@ extension ModelContent.Part: Codable {
case fileData
case functionCall
case functionResponse
case executableCode
case codeExecutionResult
}

enum InlineDataKeys: String, CodingKey {
Expand Down Expand Up @@ -164,6 +172,10 @@ extension ModelContent.Part: Codable {
try container.encode(functionCall, forKey: .functionCall)
case let .functionResponse(functionResponse):
try container.encode(functionResponse, forKey: .functionResponse)
case let .executableCode(executableCode):
try container.encode(executableCode, forKey: .executableCode)
case let .codeExecutionResult(codeExecutionResult):
try container.encode(codeExecutionResult, forKey: .codeExecutionResult)
}
}

Expand All @@ -181,6 +193,13 @@ extension ModelContent.Part: Codable {
self = .data(mimetype: mimetype, bytes)
} else if values.contains(.functionCall) {
self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall))
} else if values.contains(.executableCode) {
self = try .executableCode(values.decode(ExecutableCode.self, forKey: .executableCode))
} else if values.contains(.codeExecutionResult) {
self = try .codeExecutionResult(values.decode(
CodeExecutionResult.self,
forKey: .codeExecutionResult
))
} else {
throw DecodingError.dataCorrupted(.init(
codingPath: [CodingKeys.text, CodingKeys.inlineData],
Expand Down
154 changes: 154 additions & 0 deletions Tests/GoogleAITests/CodeExecutionTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest

@testable import GoogleGenerativeAI

/// Unit tests for JSON encoding and decoding of the code-execution types: `CodeExecution`,
/// `ExecutableCode`, `CodeExecutionResult` and `CodeExecutionResult.Outcome`.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
final class CodeExecutionTests: XCTestCase {
// Shared coders used by every test; the encoder is configured in `setUp()`.
let decoder = JSONDecoder()
let encoder = JSONEncoder()

// JSON keys and sample values interpolated into the fixtures below.
let languageKey = "language"
let languageValue = "PYTHON"
let codeKey = "code"
let codeValue = "print('Hello, world!')"
let outcomeKey = "outcome"
let outcomeValue = "OUTCOME_OK"
let outputKey = "output"
let outputValue = "Hello, world!"

/// Configures the shared encoder with pretty-printing, sorted keys and unescaped slashes; the
/// encode tests below compare its output against fixed JSON strings.
// NOTE(review): this override does not call `super.setUp()`, and the option set is built by
// calling `init(arrayLiteral:)` directly rather than with an array literal such as
// `[.prettyPrinted, .sortedKeys, .withoutEscapingSlashes]` — consider changing both.
override func setUp() {
encoder.outputFormatting = .init(
arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes
)
}

/// Encoding a `CodeExecution` yields the JSON object literal asserted below.
func testEncodeCodeExecution() throws {
let jsonData = try encoder.encode(CodeExecution())

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
}
""")
}

/// Decoding an object with `language` and `code` keys produces an equal `ExecutableCode`.
func testDecodeExecutableCode() throws {
let expectedExecutableCode = ExecutableCode(language: languageValue, code: codeValue)
let json = """
{
"\(languageKey)": "\(languageValue)",
"\(codeKey)": "\(codeValue)"
}
"""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let executableCode = try XCTUnwrap(decoder.decode(ExecutableCode.self, from: jsonData))

XCTAssertEqual(executableCode, expectedExecutableCode)
}

/// Encoding an `ExecutableCode` emits its `code` and `language` fields, in sorted key order.
func testEncodeExecutableCode() throws {
let executableCode = ExecutableCode(language: languageValue, code: codeValue)

let jsonData = try encoder.encode(executableCode)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"\(codeKey)" : "\(codeValue)",
"\(languageKey)" : "\(languageValue)"
}
""")
}

/// The string "OUTCOME_OK" decodes to `Outcome.ok`.
func testDecodeCodeExecutionResultOutcome_ok() throws {
let expectedOutcome = CodeExecutionResult.Outcome.ok
let json = "\"\(outcomeValue)\""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let outcome = try XCTUnwrap(decoder.decode(CodeExecutionResult.Outcome.self, from: jsonData))

XCTAssertEqual(outcome, expectedOutcome)
}

/// A string that matches no known raw value decodes to `Outcome.unknown` instead of throwing.
func testDecodeCodeExecutionResultOutcome_unknown() throws {
let expectedOutcome = CodeExecutionResult.Outcome.unknown
let json = "\"OUTCOME_NEW_VALUE\""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let outcome = try XCTUnwrap(decoder.decode(CodeExecutionResult.Outcome.self, from: jsonData))

XCTAssertEqual(outcome, expectedOutcome)
}

/// `Outcome.ok` encodes as the JSON string "OUTCOME_OK".
func testEncodeCodeExecutionResultOutcome() throws {
let jsonData = try encoder.encode(CodeExecutionResult.Outcome.ok)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, "\"\(outcomeValue)\"")
}

/// Decoding an object with both `outcome` and `output` produces an equal `CodeExecutionResult`.
// NOTE(review): the expected value spells out "Hello, world!" instead of using `outputValue`;
// the two are equal today, but the fixture and the expectation could drift apart.
func testDecodeCodeExecutionResult() throws {
let expectedCodeExecutionResult = CodeExecutionResult(outcome: .ok, output: "Hello, world!")
let json = """
{
"\(outcomeKey)": "\(outcomeValue)",
"\(outputKey)": "\(outputValue)"
}
"""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let codeExecutionResult = try XCTUnwrap(decoder.decode(
CodeExecutionResult.self,
from: jsonData
))

XCTAssertEqual(codeExecutionResult, expectedCodeExecutionResult)
}

/// When the `output` key is absent, decoding succeeds and `output` is the empty string.
func testDecodeCodeExecutionResult_missingOutput() throws {
let expectedCodeExecutionResult = CodeExecutionResult(outcome: .deadlineExceeded, output: "")
let json = """
{
"\(outcomeKey)": "OUTCOME_DEADLINE_EXCEEDED"
}
"""
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let codeExecutionResult = try XCTUnwrap(decoder.decode(
CodeExecutionResult.self,
from: jsonData
))

XCTAssertEqual(codeExecutionResult, expectedCodeExecutionResult)
}

/// Encoding a `CodeExecutionResult` emits its `outcome` and `output` fields, in sorted key order.
func testEncodeCodeExecutionResult() throws {
let codeExecutionResult = CodeExecutionResult(outcome: .ok, output: outputValue)

let jsonData = try encoder.encode(codeExecutionResult)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"\(outcomeKey)" : "\(outcomeValue)",
"\(outputKey)" : "\(outputValue)"
}
""")
}
}
Loading

0 comments on commit 2bf9fe5

Please sign in to comment.