Skip to content

Commit

Permalink
Send GenerateContentRequest in CountTokensRequest (google-gemini#175
Browse files Browse the repository at this point in the history
)
  • Loading branch information
andrewheard authored May 29, 2024
1 parent d8b1fbb commit 97a81a2
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 4 deletions.
4 changes: 2 additions & 2 deletions Sources/GoogleAI/CountTokensRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Foundation
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
struct CountTokensRequest {
let model: String
let contents: [ModelContent]
let generateContentRequest: GenerateContentRequest
let options: RequestOptions
}

Expand All @@ -42,7 +42,7 @@ public struct CountTokensResponse {
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension CountTokensRequest: Encodable {
enum CodingKeys: CodingKey {
case contents
case generateContentRequest
}
}

Expand Down
1 change: 1 addition & 0 deletions Sources/GoogleAI/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct GenerateContentRequest {
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension GenerateContentRequest: Encodable {
enum CodingKeys: String, CodingKey {
case model
case contents
case generationConfig
case safetySettings
Expand Down
13 changes: 11 additions & 2 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,18 @@ public final class GenerativeModel {
public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws
-> CountTokensResponse {
do {
let countTokensRequest = try CountTokensRequest(
let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
contents: content(),
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
isStreaming: false,
options: requestOptions)
let countTokensRequest = CountTokensRequest(
model: modelResourceName,
contents: content(),
generateContentRequest: generateContentRequest,
options: requestOptions
)
return try await generativeAIService.loadRequest(request: countTokensRequest)
Expand Down
144 changes: 144 additions & 0 deletions Tests/GoogleAITests/GenerateContentRequestTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation
import XCTest

@testable import GoogleGenerativeAI

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
final class GenerateContentRequestTests: XCTestCase {
let encoder = JSONEncoder()
let role = "test-role"
let prompt = "test-prompt"
let modelName = "test-model-name"

override func setUp() {
encoder.outputFormatting = .init(
arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes
)
}

// MARK: GenerateContentRequest Encoding

func testEncodeRequest_allFieldsIncluded() throws {
let content = [ModelContent(role: role, parts: prompt)]
let request = GenerateContentRequest(
model: modelName,
contents: content,
generationConfig: GenerationConfig(temperature: 0.5),
safetySettings: [SafetySetting(
harmCategory: .dangerousContent,
threshold: .blockLowAndAbove
)],
tools: [Tool(functionDeclarations: [FunctionDeclaration(
name: "test-function-name",
description: "test-function-description",
parameters: nil
)])],
toolConfig: ToolConfig(functionCallingConfig: FunctionCallingConfig(mode: .auto)),
systemInstruction: ModelContent(role: "system", parts: "test-system-instruction"),
isStreaming: false,
options: RequestOptions()
)

let jsonData = try encoder.encode(request)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"contents" : [
{
"parts" : [
{
"text" : "\(prompt)"
}
],
"role" : "\(role)"
}
],
"generationConfig" : {
"temperature" : 0.5
},
"model" : "\(modelName)",
"safetySettings" : [
{
"category" : "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold" : "BLOCK_LOW_AND_ABOVE"
}
],
"systemInstruction" : {
"parts" : [
{
"text" : "test-system-instruction"
}
],
"role" : "system"
},
"toolConfig" : {
"functionCallingConfig" : {
"mode" : "AUTO"
}
},
"tools" : [
{
"functionDeclarations" : [
{
"description" : "test-function-description",
"name" : "test-function-name",
"parameters" : {
"type" : "OBJECT"
}
}
]
}
]
}
""")
}

func testEncodeRequest_optionalFieldsOmitted() throws {
let content = [ModelContent(role: role, parts: prompt)]
let request = GenerateContentRequest(
model: modelName,
contents: content,
generationConfig: nil,
safetySettings: nil,
tools: nil,
toolConfig: nil,
systemInstruction: nil,
isStreaming: false,
options: RequestOptions()
)

let jsonData = try encoder.encode(request)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"contents" : [
{
"parts" : [
{
"text" : "\(prompt)"
}
],
"role" : "\(role)"
}
],
"model" : "\(modelName)"
}
""")
}
}

0 comments on commit 97a81a2

Please sign in to comment.