Adds support for ChatGPT (gpt-3.5-turbo and gpt-3.5-turbo-0301) (#10)
* Adds support for gpt-3.5-turbo and gpt-3.5-turbo-0301

* Adds a test for the chat API.
arthurgarzajr authored Mar 2, 2023
1 parent 0ebaf8b commit 95e2a12
Showing 6 changed files with 202 additions and 0 deletions.
38 changes: 38 additions & 0 deletions Sources/OpenAIKit/Chat/Chat.swift
@@ -0,0 +1,38 @@
import Foundation

/**
 Given a prompt, the model will return one or more predicted chat completions, and can also return the probabilities of alternative tokens at each position.
 */
public struct Chat {
    public let id: String
    public let object: String
    public let created: Date
    public let model: String
    public let choices: [Choice]
    public let usage: Usage
}

extension Chat: Codable {}

extension Chat {
    public struct Choice {
        public let index: Int
        public let message: Message
        public let finishReason: String?
    }
}

extension Chat.Choice: Codable {}

extension Chat {
    public struct Message {
        public let role: String
        public let content: String

        public init(role: String, content: String) {
            self.role = role
            self.content = content
        }
    }
}

extension Chat.Message: Codable {}
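
For illustration only, here is how the new Chat type decodes from a typical chat completion response. This sketch assumes the library's RequestHandler uses a decoder that converts snake_case keys (so finish_reason maps to finishReason) and decodes created from a Unix timestamp, and that the existing Usage type carries the prompt/completion/total token counts; none of that configuration appears in this commit, and the sample values are invented.

import Foundation

// Illustrative only, e.g. inside a test. The decoder settings below are
// assumptions about how the library's RequestHandler is configured; they
// are not part of this commit.
let sampleResponse = """
{
  "id": "chatcmpl-abc123",
  "object": "chat.completion",
  "created": 1677652288,
  "model": "gpt-3.5-turbo-0301",
  "choices": [
    {
      "index": 0,
      "message": { "role": "assistant", "content": "An old silent pond..." },
      "finish_reason": "stop"
    }
  ],
  "usage": { "prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21 }
}
""".data(using: .utf8)!

let decoder = JSONDecoder()
decoder.keyDecodingStrategy = .convertFromSnakeCase  // finish_reason -> finishReason
decoder.dateDecodingStrategy = .secondsSince1970     // created as a Unix timestamp

let chat = try decoder.decode(Chat.self, from: sampleResponse)
print(chat.choices.first?.message.content ?? "")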
49 changes: 49 additions & 0 deletions Sources/OpenAIKit/Chat/ChatProvider.swift
@@ -0,0 +1,49 @@
public struct ChatProvider {

    private let requestHandler: RequestHandler

    init(requestHandler: RequestHandler) {
        self.requestHandler = requestHandler
    }

    /**
     Create chat completion
     POST
     https://api.openai.com/v1/chat/completions

     Creates a chat completion for the provided prompt and parameters
     */
    public func create(
        model: ModelID,
        messages: [Chat.Message] = [],
        temperature: Double = 1.0,
        topP: Double = 1.0,
        n: Int = 1,
        stream: Bool = false,
        stops: [String] = [],
        presencePenalty: Double = 0.0,
        frequencyPenalty: Double = 0.0,
        logitBias: [String: Int] = [:],
        user: String? = nil
    ) async throws -> Chat {
        let request = try CreateChatRequest(
            model: model.id,
            messages: messages,
            temperature: temperature,
            topP: topP,
            n: n,
            stream: stream,
            stops: stops,
            presencePenalty: presencePenalty,
            frequencyPenalty: frequencyPenalty,
            logitBias: logitBias,
            user: user
        )
        let chat: Chat = try await requestHandler.perform(request: request)

        return chat
    }
}
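
A quick usage sketch of the new provider: only model is required, every other parameter has a default, so a minimal call just supplies messages, and tuning parameters can be added as needed. The client value here is assumed to be a configured OpenAIKit client, as in the package's test target.

// Sketch only: assumes a configured OpenAIKit client, as in the tests below.
let chat = try await client.chats.create(
    model: Model.GPT3.gpt3_5Turbo,
    messages: [
        Chat.Message(role: "system", content: "You are a terse assistant."),
        Chat.Message(role: "user", content: "Summarize Swift concurrency in one sentence.")
    ],
    temperature: 0.7,       // optional tuning parameter
    stops: ["\n\n"]         // optional stop sequences
)

print(chat.choices.first?.message.content ?? "")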
97 changes: 97 additions & 0 deletions Sources/OpenAIKit/Chat/CreateChatRequest.swift
@@ -0,0 +1,97 @@
import AsyncHTTPClient
import NIOHTTP1
import Foundation

struct CreateChatRequest: Request {
    let method: HTTPMethod = .POST
    let path = "/v1/chat/completions"
    let body: HTTPClient.Body?

    init(
        model: String,
        messages: [Chat.Message],
        temperature: Double,
        topP: Double,
        n: Int,
        stream: Bool,
        stops: [String],
        presencePenalty: Double,
        frequencyPenalty: Double,
        logitBias: [String: Int],
        user: String?
    ) throws {
        let body = Body(
            model: model,
            messages: messages,
            temperature: temperature,
            topP: topP,
            n: n,
            stream: stream,
            stops: stops,
            presencePenalty: presencePenalty,
            frequencyPenalty: frequencyPenalty,
            logitBias: logitBias,
            user: user
        )

        self.body = .data(try Self.encoder.encode(body))
    }
}

extension CreateChatRequest {
    struct Body: Encodable {
        let model: String
        let messages: [Chat.Message]
        let temperature: Double
        let topP: Double
        let n: Int
        let stream: Bool
        let stops: [String]
        let presencePenalty: Double
        let frequencyPenalty: Double
        let logitBias: [String: Int]
        let user: String?

        enum CodingKeys: CodingKey {
            case model
            case messages
            case temperature
            case topP
            case n
            case stream
            case stop
            case presencePenalty
            case frequencyPenalty
            case logitBias
            case user
        }

        func encode(to encoder: Encoder) throws {
            var container = encoder.container(keyedBy: CodingKeys.self)
            try container.encode(model, forKey: .model)

            if !messages.isEmpty {
                try container.encode(messages, forKey: .messages)
            }

            try container.encode(temperature, forKey: .temperature)
            try container.encode(topP, forKey: .topP)
            try container.encode(n, forKey: .n)
            try container.encode(stream, forKey: .stream)

            if !stops.isEmpty {
                try container.encode(stops, forKey: .stop)
            }

            try container.encode(presencePenalty, forKey: .presencePenalty)
            try container.encode(frequencyPenalty, forKey: .frequencyPenalty)

            if !logitBias.isEmpty {
                try container.encode(logitBias, forKey: .logitBias)
            }

            try container.encodeIfPresent(user, forKey: .user)
        }
    }
}
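
To see what the custom encode(to:) above puts on the wire: empty messages, stops and logitBias, and a nil user, are omitted from the payload instead of being sent as empty values. The sketch below encodes a minimal Body directly; the snake_case key-encoding strategy is an assumption standing in for the package's shared Request encoder, which the camelCase CodingKeys imply but this commit does not show.

import Foundation

// Sketch only (module-internal code, e.g. in a test target): shows the JSON
// produced for a minimal request. The snake_case strategy below is an
// assumption standing in for Self.encoder used by CreateChatRequest.
let body = CreateChatRequest.Body(
    model: "gpt-3.5-turbo",
    messages: [Chat.Message(role: "user", content: "Hello")],
    temperature: 1.0,
    topP: 1.0,
    n: 1,
    stream: false,
    stops: [],
    presencePenalty: 0.0,
    frequencyPenalty: 0.0,
    logitBias: [:],
    user: nil
)

let encoder = JSONEncoder()
encoder.keyEncodingStrategy = .convertToSnakeCase
print(String(data: try encoder.encode(body), encoding: .utf8)!)
// The payload contains model, messages, temperature, top_p, n, stream,
// presence_penalty and frequency_penalty; stop, logit_bias and user are
// absent because they were empty or nil.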
2 changes: 2 additions & 0 deletions Sources/OpenAIKit/Client/Client.swift
@@ -7,6 +7,7 @@ public struct Client {

    public let models: ModelProvider
    public let completions: CompletionProvider
    public let chats: ChatProvider
    public let edits: EditProvider
    public let images: ImageProvider
    public let embeddings: EmbeddingProvider
@@ -25,6 +26,7 @@ public struct Client {

        self.models = ModelProvider(requestHandler: requestHandler)
        self.completions = CompletionProvider(requestHandler: requestHandler)
        self.chats = ChatProvider(requestHandler: requestHandler)
        self.edits = EditProvider(requestHandler: requestHandler)
        self.images = ImageProvider(requestHandler: requestHandler)
        self.embeddings = EmbeddingProvider(requestHandler: requestHandler)
2 changes: 2 additions & 0 deletions Sources/OpenAIKit/Model/Model.swift
@@ -36,6 +36,8 @@ public protocol ModelID {

extension Model {
    public enum GPT3: String, ModelID {
        case gpt3_5Turbo = "gpt-3.5-turbo"
        case gpt3_5Turbo0301 = "gpt-3.5-turbo-0301"
        case textDavinci003 = "text-davinci-003"
        case textDavinci002 = "text-davinci-002"
        case textCurie001 = "text-curie-001"
14 changes: 14 additions & 0 deletions Tests/OpenAIKitTests/OpenAIKitTests.swift
@@ -50,6 +50,20 @@ final class OpenAIKitTests: XCTestCase {
        print(completion)
    }

    func test_createChat() async throws {
        let completion = try await client.chats.create(
            model: Model.GPT3.gpt3_5Turbo,
            messages: [
                Chat.Message(
                    role: "user",
                    content: "Write a haiku"
                )
            ]
        )

        print(completion)
    }

    func test_createEdit() async throws {
        let edit = try await client.edits.create(
            input: "Whay day of the wek is it?",
