Skip to content

Commit

Permalink
Add Stream support when create Chat completions (#36)
Browse files Browse the repository at this point in the history
* Add support to stream chats

* remove dead API key
  • Loading branch information
dylanshine committed Apr 13, 2023
1 parent a25c282 commit a7ef198
Show file tree
Hide file tree
Showing 19 changed files with 212 additions and 68 deletions.
4 changes: 2 additions & 2 deletions Sources/OpenAIKit/Audio/CreateTranscriptionRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import AsyncHTTPClient
struct CreateTranscriptionRequest: Request {
let method: HTTPMethod = .POST
let path = "/v1/audio/transcriptions"
let body: HTTPClient.Body?
let body: Data?
private let boundary = UUID().uuidString

var headers: HTTPHeaders {
Expand Down Expand Up @@ -58,7 +58,7 @@ struct CreateTranscriptionRequest: Request {
builder.addTextField(named: "language", value: language.rawValue)
}

self.body = .data(builder.build())
self.body = builder.build()
}
}

4 changes: 2 additions & 2 deletions Sources/OpenAIKit/Audio/CreateTranslationRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import AsyncHTTPClient
struct CreateTranslationRequest: Request {
let method: HTTPMethod = .POST
let path = "/v1/audio/translations"
let body: HTTPClient.Body?
let body: Data?
private let boundary = UUID().uuidString

var headers: HTTPHeaders {
Expand Down Expand Up @@ -53,6 +53,6 @@ struct CreateTranslationRequest: Request {
builder.addTextField(named: "temperature", value: String(temperature))
}

self.body = .data(builder.build())
self.body = builder.build()
}
}
19 changes: 1 addition & 18 deletions Sources/OpenAIKit/Chat/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,11 @@ extension Chat {
public let index: Int
public let message: Message
public let finishReason: FinishReason?

public enum FinishReason: String {
/// API returned complete model output
case stop

/// Incomplete model output due to max_tokens parameter or token limit
case length

/// Omitted content due to a flag from our content filters
case contentFilter = "content_filter"
}
}
}

extension Chat.Choice: Codable {}

extension Chat.Choice.FinishReason: Codable {}

extension Chat {
public enum Message {
case system(content: String)
Expand Down Expand Up @@ -86,11 +73,7 @@ extension Chat.Message: Codable {
extension Chat.Message {
public var content: String {
switch self {
case .system(let content):
return content
case .user(let content):
return content
case .assistant(let content):
case .system(let content), .user(let content), .assistant(let content):
return content
}
}
Expand Down
54 changes: 49 additions & 5 deletions Sources/OpenAIKit/Chat/ChatProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ public struct ChatProvider {
temperature: Double = 1.0,
topP: Double = 1.0,
n: Int = 1,
stream: Bool = false,
stops: [String] = [],
maxTokens: Int? = nil,
presencePenalty: Double = 0.0,
Expand All @@ -35,17 +34,62 @@ public struct ChatProvider {
temperature: temperature,
topP: topP,
n: n,
stream: stream,
stream: false,
stops: stops,
maxTokens: maxTokens,
presencePenalty: presencePenalty,
frequencyPenalty: frequencyPenalty,
logitBias: logitBias,
user: user
)
let chat: Chat = try await requestHandler.perform(request: request)

return chat

return try await requestHandler.perform(request: request)

}

/**
Create chat completion
POST
https://api.openai.com/v1/chat/completions
Creates a chat completion for the provided prompt and parameters
stream If set, partial message deltas will be sent, like in ChatGPT.
Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message.
https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
*/
public func stream(
model: ModelID,
messages: [Chat.Message] = [],
temperature: Double = 1.0,
topP: Double = 1.0,
n: Int = 1,
stops: [String] = [],
maxTokens: Int? = nil,
presencePenalty: Double = 0.0,
frequencyPenalty: Double = 0.0,
logitBias: [String : Int] = [:],
user: String? = nil
) async throws -> AsyncThrowingStream<ChatStream, Error> {

let request = try CreateChatRequest(
model: model.id,
messages: messages,
temperature: temperature,
topP: topP,
n: n,
stream: true,
stops: stops,
maxTokens: maxTokens,
presencePenalty: presencePenalty,
frequencyPenalty: frequencyPenalty,
logitBias: logitBias,
user: user
)

return try await requestHandler.stream(request: request)

}
}
32 changes: 32 additions & 0 deletions Sources/OpenAIKit/Chat/ChatStream.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import Foundation

public struct ChatStream {
public let id: String
public let object: String
public let created: Date
public let model: String
public let choices: [ChatStream.Choice]
}

extension ChatStream: Codable {}

extension ChatStream {
public struct Choice {
let index: Int
let finishReason: FinishReason?
let delta: ChatStream.Choice.Message
}
}

extension ChatStream.Choice: Codable {}

extension ChatStream.Choice {
public struct Message {
let content: String?
let role: String?
}
}

extension ChatStream.Choice.Message: Codable {}


4 changes: 2 additions & 2 deletions Sources/OpenAIKit/Chat/CreateChatRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Foundation
struct CreateChatRequest: Request {
let method: HTTPMethod = .POST
let path = "/v1/chat/completions"
let body: HTTPClient.Body?
let body: Data?

init(
model: String,
Expand Down Expand Up @@ -37,7 +37,7 @@ struct CreateChatRequest: Request {
user: user
)

self.body = .data(try Self.encoder.encode(body))
self.body = try Self.encoder.encode(body)
}
}

Expand Down
14 changes: 14 additions & 0 deletions Sources/OpenAIKit/Chat/FinishReason.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import Foundation

public enum FinishReason: String {
/// API returned complete model output
case stop

/// Incomplete model output due to max_tokens parameter or token limit
case length

/// Omitted content due to a flag from our content filters
case contentFilter = "content_filter"
}

extension FinishReason: Codable {}
4 changes: 2 additions & 2 deletions Sources/OpenAIKit/Completion/CreateCompletionRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Foundation
struct CreateCompletionRequest: Request {
let method: HTTPMethod = .POST
let path = "/v1/completions"
let body: HTTPClient.Body?
let body: Data?

init(
model: String,
Expand Down Expand Up @@ -45,7 +45,7 @@ struct CreateCompletionRequest: Request {
user: user
)

self.body = .data(try Self.encoder.encode(body))
self.body = try Self.encoder.encode(body)
}
}

Expand Down
4 changes: 2 additions & 2 deletions Sources/OpenAIKit/Edit/CreateEditRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Foundation
struct CreateEditRequest: Request {
let method: HTTPMethod = .POST
let path = "/v1/edits"
let body: HTTPClient.Body?
let body: Data?

init(
model: String,
Expand All @@ -25,7 +25,7 @@ struct CreateEditRequest: Request {
topP: topP
)

self.body = .data(try Self.encoder.encode(body))
self.body = try Self.encoder.encode(body)
}
}

Expand Down
4 changes: 2 additions & 2 deletions Sources/OpenAIKit/Embedding/CreateEmbeddingRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Foundation
struct CreateEmbeddingRequest: Request {
let method: HTTPMethod = .POST
let path = "/v1/embeddings"
let body: HTTPClient.Body?
let body: Data?

init(
model: String,
Expand All @@ -19,7 +19,7 @@ struct CreateEmbeddingRequest: Request {
user: user
)

self.body = .data(try Self.encoder.encode(body))
self.body = try Self.encoder.encode(body)
}
}

Expand Down
4 changes: 2 additions & 2 deletions Sources/OpenAIKit/File/Request/UploadFileRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import AsyncHTTPClient
struct UploadFileRequest: Request {
let method: HTTPMethod = .POST
let path = "/v1/files"
let body: HTTPClient.Body?
let body: Data?
private let boundary = UUID().uuidString

var headers: HTTPHeaders {
Expand All @@ -30,7 +30,7 @@ struct UploadFileRequest: Request {

builder.addTextField(named: "purpose", value: purpose.rawValue)

self.body = .data(builder.build())
self.body = builder.build()
}
}

4 changes: 2 additions & 2 deletions Sources/OpenAIKit/Image/Requests/CreateImageEditRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Foundation
struct CreateImageEditRequest: Request {
let method: HTTPMethod = .POST
let path = "/v1/images/edits"
let body: HTTPClient.Body?
let body: Data?
private let boundary = UUID().uuidString

var headers: HTTPHeaders {
Expand Down Expand Up @@ -49,6 +49,6 @@ struct CreateImageEditRequest: Request {
builder.addTextField(named: "user", value: user)
}

self.body = .data(builder.build())
self.body = builder.build()
}
}
4 changes: 2 additions & 2 deletions Sources/OpenAIKit/Image/Requests/CreateImageRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Foundation
struct CreateImageRequest: Request {
let method: HTTPMethod = .POST
let path = "/v1/images/generations"
let body: HTTPClient.Body?
let body: Data?

init(
prompt: String,
Expand All @@ -21,7 +21,7 @@ struct CreateImageRequest: Request {
user: user
)

self.body = .data(try Self.encoder.encode(body))
self.body = try Self.encoder.encode(body)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Foundation
struct CreateImageVariationRequest: Request {
let method: HTTPMethod = .POST
let path = "/v1/images/variations"
let body: HTTPClient.Body?
let body: Data?
private let boundary = UUID().uuidString

var headers: HTTPHeaders {
Expand Down Expand Up @@ -37,6 +37,6 @@ struct CreateImageVariationRequest: Request {
builder.addTextField(named: "user", value: user)
}

self.body = .data(builder.build())
self.body = builder.build()
}
}
4 changes: 2 additions & 2 deletions Sources/OpenAIKit/Moderation/CreateModerationRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Foundation
struct CreateModerationRequest: Request {
let method: HTTPMethod = .POST
let path = "/v1/moderations"
let body: HTTPClient.Body?
let body: Data?

init(
input: String,
Expand All @@ -17,7 +17,7 @@ struct CreateModerationRequest: Request {
model: model
)

self.body = .data(try Self.encoder.encode(body))
self.body = try Self.encoder.encode(body)
}
}

Expand Down
4 changes: 2 additions & 2 deletions Sources/OpenAIKit/Request Handler/Request.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ protocol Request {
var scheme: String { get }
var host: String { get }
var path: String { get }
var body: HTTPClient.Body? { get }
var body: Data? { get }
var headers: HTTPHeaders { get }
var keyDecodingStrategy: JSONDecoder.KeyDecodingStrategy { get }
var dateDecodingStrategy: JSONDecoder.DateDecodingStrategy { get }
Expand All @@ -18,7 +18,7 @@ extension Request {

var scheme: String { "https" }
var host: String { "api.openai.com" }
var body: HTTPClient.Body? { nil }
var body: Data? { nil }

var headers: HTTPHeaders {
var headers = HTTPHeaders()
Expand Down
Loading

0 comments on commit a7ef198

Please sign in to comment.