Skip to content

Commit

Permalink
add Message & FinishReason enums
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanshine committed Mar 17, 2023
1 parent 6652520 commit 504969a
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 23 deletions.
64 changes: 55 additions & 9 deletions Sources/OpenAIKit/Chat/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,67 @@ extension Chat {
public struct Choice {
public let index: Int
public let message: Message
public let finishReason: String?
public let finishReason: FinishReason?

public enum FinishReason: String {
/// API returned complete model output
case stop

/// Incomplete model output due to max_tokens parameter or token limit
case length

/// Omitted content due to a flag from our content filters
case contentFilter = "content_filter"
}
}
}

extension Chat.Choice: Codable {}

extension Chat.Choice.FinishReason: Codable {}

extension Chat {
public struct Message {
public let role: String
public let content: String
public init(role: String, content: String) {
self.role = role
self.content = content
}
public enum Message {
case system(content: String)
case user(content: String)
case assistant(content: String)
}
}

extension Chat.Message: Codable {}
extension Chat.Message: Codable {
private enum CodingKeys: String, CodingKey {
case role
case content
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let role = try container.decode(String.self, forKey: .role)
let content = try container.decode(String.self, forKey: .content)
switch role {
case "system":
self = .system(content: content)
case "user":
self = .user(content: content)
case "assistant":
self = .assistant(content: content)
default:
throw DecodingError.dataCorruptedError(forKey: .role, in: container, debugDescription: "Invalid type")
}
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
switch self {
case .system(let content):
try container.encode("system", forKey: .role)
try container.encode(content, forKey: .content)
case .user(let content):
try container.encode("user", forKey: .role)
try container.encode(content, forKey: .content)
case .assistant(let content):
try container.encode("assistant", forKey: .role)
try container.encode(content, forKey: .content)
}
}
}
2 changes: 1 addition & 1 deletion Sources/OpenAIKit/Request Handler/Request.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ extension Request {
}

var keyDecodingStrategy: JSONDecoder.KeyDecodingStrategy { .convertFromSnakeCase }
var dateDecodingStrategy: JSONDecoder.DateDecodingStrategy { .millisecondsSince1970 }
var dateDecodingStrategy: JSONDecoder.DateDecodingStrategy { .secondsSince1970 }
}

extension JSONEncoder {
Expand Down
11 changes: 2 additions & 9 deletions Sources/OpenAIKit/Usage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,8 @@ import Foundation

public struct Usage {
public let promptTokens: Int
public let completionTokens: Int
public let completionTokens: Int?
public let totalTokens: Int
}

extension Usage: Codable {
public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.promptTokens = try container.decodeIfPresent(Int.self, forKey: .promptTokens) ?? 0
self.completionTokens = try container.decodeIfPresent(Int.self, forKey: .completionTokens) ?? 0
self.totalTokens = try container.decodeIfPresent(Int.self, forKey: .totalTokens) ?? 0
}
}
extension Usage: Codable {}
153 changes: 153 additions & 0 deletions Tests/OpenAIKitTests/MessageTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
//
// MessageTests.swift
//
//
// Created by Ronald Mannak on 3/6/23.
//
import XCTest
@testable import OpenAIKit

final class MessageTests: XCTestCase {

let decoder = JSONDecoder()
let encoder = JSONEncoder()

override func setUpWithError() throws {
decoder.dateDecodingStrategy = .secondsSince1970
decoder.keyDecodingStrategy = .convertFromSnakeCase
encoder.keyEncodingStrategy = .convertToSnakeCase
}

override func tearDownWithError() throws {
// Put teardown code here. This method is called after the invocation of each test method in the class.
}

func testFinishReasonContentFilterCoding() throws {
let filter: Chat.Choice.FinishReason = .contentFilter
let encoded = try encoder.encode(filter)
XCTAssertEqual("\"content_filter\"", String(data: encoded, encoding: .utf8)!)

let decoded = try decoder.decode(Chat.Choice.FinishReason.self, from: encoded)
XCTAssertEqual(decoded, Chat.Choice.FinishReason.contentFilter)
XCTAssertNotEqual(decoded, Chat.Choice.FinishReason.length)
XCTAssertNotEqual(decoded, Chat.Choice.FinishReason.stop)
}

func testFinishReasonLengthCoding() throws {
let filter: Chat.Choice.FinishReason = .length
let encoded = try encoder.encode(filter)
XCTAssertEqual("\"length\"", String(data: encoded, encoding: .utf8)!)

let decoded = try decoder.decode(Chat.Choice.FinishReason.self, from: encoded)
XCTAssertEqual(decoded, Chat.Choice.FinishReason.length)
XCTAssertNotEqual(decoded, Chat.Choice.FinishReason.contentFilter)
XCTAssertNotEqual(decoded, Chat.Choice.FinishReason.stop)
}

func testFinishStopCoding() throws {
let filter: Chat.Choice.FinishReason = .stop
let encoded = try encoder.encode(filter)
XCTAssertEqual("\"stop\"", String(data: encoded, encoding: .utf8)!)

let decoded = try decoder.decode(Chat.Choice.FinishReason.self, from: encoded)
XCTAssertEqual(decoded, Chat.Choice.FinishReason.stop)
XCTAssertNotEqual(decoded, Chat.Choice.FinishReason.contentFilter)
XCTAssertNotEqual(decoded, Chat.Choice.FinishReason.length)
}


func testMessageCoding() throws {
let messageData = """
{"role": "user", "content": "Translate the following English text to French: "}
""".data(using: .utf8)!
let message = try decoder.decode(Chat.Message.self, from: messageData)
switch message {
case .system(_):
XCTFail("incorrect role")
case .user(let content):
XCTAssertEqual(content, "Translate the following English text to French: ")
case .assistant(_):
XCTFail("incorrect role")
}
}

func testMessageRoundtrip() throws {
let message = Chat.Message.system(content: "You are a helpful assistant that translates English to French.")
let encoded = try encoder.encode(message)
let decoded = try decoder.decode(Chat.Message.self, from: encoded)
print(String(data: encoded, encoding: .utf8)!)
switch decoded {
case .system(let content):
guard case let .system(content: original) = message else {
XCTFail()
return
}
XCTAssertEqual(content, original)
case .user(_):
XCTFail("incorrect role")
case .assistant(_):
XCTFail("incorrect role")
}
}

func testChatDecoding() throws {
let exampleResponse = """
{
"id": "chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve",
"object": "chat.completion",
"created": 1677649420,
"model": "gpt-3.5-turbo",
"usage": {"prompt_tokens": 56, "completion_tokens": 31, "total_tokens": 87},
"choices": [
{
"message": {
"role": "assistant",
"content": "The 2020 World Series was played in Arlington, Texas at the Globe Life Field, which was the new home stadium for the Texas Rangers."},
"finish_reason": "stop",
"index": 0
}
]
}
""".data(using: .utf8)!
let chat = try decoder.decode(Chat.self, from: exampleResponse)
XCTAssertEqual(chat.id, "chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve")
XCTAssertEqual(chat.created.timeIntervalSince1970, 1677649420)
XCTAssertEqual(chat.usage.promptTokens, 56)
XCTAssertEqual(chat.usage.completionTokens, 31)
XCTAssertEqual(chat.usage.totalTokens, 87)

XCTAssertEqual(chat.choices.count, 1)
let firstChoice = chat.choices.first!
XCTAssertEqual(firstChoice.index, 0)
switch firstChoice.message {
case .system(_):
XCTFail()
case .assistant(let content):
XCTAssertEqual(content, "The 2020 World Series was played in Arlington, Texas at the Globe Life Field, which was the new home stadium for the Texas Rangers.")
case .user(_):
XCTFail()
}
}

func testChatRequest() throws {
let request = try CreateChatRequest(
model: "gpt-3.5-turbo", //.gpt3_5Turbo,
messages: [
.system(content: "You are Malcolm Tucker from The Thick of It, an unfriendly assistant for writing mail and explaining science and history. You write text in your voice for me."),
.user(content: "tell me a joke"),
],
temperature: 1.0,
topP: 1.0,
n: 1,
stream: false,
stops: [],
maxTokens: nil,
presencePenalty: 0.0,
frequencyPenalty: 0.0,
logitBias: [:],
user: nil
)

print(request.body)
}
}
5 changes: 1 addition & 4 deletions Tests/OpenAIKitTests/OpenAIKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ final class OpenAIKitTests: XCTestCase {
let completion = try await client.chats.create(
model: Model.GPT3.gpt3_5Turbo,
messages: [
Chat.Message(
role: "user",
content: "Write a haiku"
)
.user(content: "Write a haiki")
]
)

Expand Down

0 comments on commit 504969a

Please sign in to comment.