Skip to content

Commit 12ba7f4

Browse files
authored
Merge pull request #57 from MacPaw/streaming
Add streaming session and ability to use streaming
2 parents eefb14b + 6bb1456 commit 12ba7f4

21 files changed

+474
-45
lines changed

Demo/DemoChat/Sources/ChatStore.swift

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,53 +53,71 @@ public final class ChatStore: ObservableObject {
5353
}
5454

5555
@MainActor
56-
func sendMessage(_ message: Message, conversationId: Conversation.ID) async {
56+
func sendMessage(
57+
_ message: Message,
58+
conversationId: Conversation.ID,
59+
model: Model
60+
) async {
5761
guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else {
5862
return
5963
}
6064
conversations[conversationIndex].messages.append(message)
6165

62-
await completeChat(conversationId: conversationId)
66+
await completeChat(
67+
conversationId: conversationId,
68+
model: model
69+
)
6370
}
6471

6572
@MainActor
66-
func completeChat(conversationId: Conversation.ID) async {
73+
func completeChat(
74+
conversationId: Conversation.ID,
75+
model: Model
76+
) async {
6777
guard let conversation = conversations.first(where: { $0.id == conversationId }) else {
6878
return
6979
}
7080

7181
conversationErrors[conversationId] = nil
7282

7383
do {
74-
let response = try await openAIClient.chats(
84+
guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else {
85+
return
86+
}
87+
88+
let chatsStream: AsyncThrowingStream<ChatStreamResult, Error> = openAIClient.chatsStream(
7589
query: ChatQuery(
76-
model: .gpt3_5Turbo,
90+
model: model,
7791
messages: conversation.messages.map { message in
7892
Chat(role: message.role, content: message.content)
7993
}
8094
)
8195
)
82-
83-
guard let conversationIndex = conversations.firstIndex(where: { $0.id == conversationId }) else {
84-
return
85-
}
86-
87-
let existingMessages = conversations[conversationIndex].messages
88-
89-
for completionMessage in response.choices.map(\.message) {
90-
let message = Message(
91-
id: response.id,
92-
role: completionMessage.role,
93-
content: completionMessage.content,
94-
createdAt: Date(timeIntervalSince1970: TimeInterval(response.created))
95-
)
96-
97-
if existingMessages.contains(message) {
98-
continue
96+
97+
for try await partialChatResult in chatsStream {
98+
for choice in partialChatResult.choices {
99+
let existingMessages = conversations[conversationIndex].messages
100+
let message = Message(
101+
id: partialChatResult.id,
102+
role: choice.delta.role ?? .assistant,
103+
content: choice.delta.content ?? "",
104+
createdAt: Date(timeIntervalSince1970: TimeInterval(partialChatResult.created))
105+
)
106+
if let existingMessageIndex = existingMessages.firstIndex(where: { $0.id == partialChatResult.id }) {
107+
// Meld into previous message
108+
let previousMessage = existingMessages[existingMessageIndex]
109+
let combinedMessage = Message(
110+
id: message.id, // id stays the same for different deltas
111+
role: message.role,
112+
content: previousMessage.content + message.content,
113+
createdAt: message.createdAt
114+
)
115+
conversations[conversationIndex].messages[existingMessageIndex] = combinedMessage
116+
} else {
117+
conversations[conversationIndex].messages.append(message)
118+
}
99119
}
100-
conversations[conversationIndex].messages.append(message)
101120
}
102-
103121
} catch {
104122
conversationErrors[conversationId] = error
105123
}

Demo/DemoChat/Sources/UI/ChatView.swift

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public struct ChatView: View {
4646
DetailView(
4747
conversation: conversation,
4848
error: store.conversationErrors[conversation.id],
49-
sendMessage: { message in
49+
sendMessage: { message, selectedModel in
5050
Task {
5151
await store.sendMessage(
5252
Message(
@@ -55,7 +55,8 @@ public struct ChatView: View {
5555
content: message,
5656
createdAt: dateProvider()
5757
),
58-
conversationId: conversation.id
58+
conversationId: conversation.id,
59+
model: selectedModel
5960
)
6061
}
6162
}

Demo/DemoChat/Sources/UI/DetailView.swift

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,20 @@ import UIKit
1010
#elseif os(macOS)
1111
import AppKit
1212
#endif
13+
import OpenAI
1314
import SwiftUI
1415

1516
struct DetailView: View {
1617
@State var inputText: String = ""
1718
@FocusState private var isFocused: Bool
19+
@State private var showsModelSelectionSheet = false
20+
@State private var selectedChatModel: Model = .gpt3_5Turbo
21+
22+
private let availableChatModels: [Model] = [.gpt3_5Turbo, .gpt4]
1823

1924
let conversation: Conversation
2025
let error: Error?
21-
let sendMessage: (String) -> Void
26+
let sendMessage: (String, Model) -> Void
2227

2328
private var fillColor: Color {
2429
#if os(iOS)
@@ -61,6 +66,51 @@ struct DetailView: View {
6166
inputBar(scrollViewProxy: scrollViewProxy)
6267
}
6368
.navigationTitle("Chat")
69+
.safeAreaInset(edge: .top) {
70+
HStack {
71+
Text(
72+
"Model: \(selectedChatModel)"
73+
)
74+
.font(.caption)
75+
.foregroundColor(.secondary)
76+
Spacer()
77+
}
78+
.padding(.horizontal, 16)
79+
.padding(.vertical, 8)
80+
}
81+
.toolbar {
82+
ToolbarItem(placement: .navigationBarTrailing) {
83+
Button(action: {
84+
showsModelSelectionSheet.toggle()
85+
}) {
86+
Image(systemName: "cpu")
87+
}
88+
}
89+
}
90+
.confirmationDialog(
91+
"Select model",
92+
isPresented: $showsModelSelectionSheet,
93+
titleVisibility: .visible,
94+
actions: {
95+
ForEach(availableChatModels, id: \.self) { model in
96+
Button {
97+
selectedChatModel = model
98+
} label: {
99+
Text(model)
100+
}
101+
}
102+
103+
Button("Cancel", role: .cancel) {
104+
showsModelSelectionSheet = false
105+
}
106+
},
107+
message: {
108+
Text(
109+
"View https://platform.openai.com/docs/models/overview for details"
110+
)
111+
.font(.caption)
112+
}
113+
)
64114
}
65115
}
66116
}
@@ -133,7 +183,7 @@ struct DetailView: View {
133183
private func tapSendMessage(
134184
scrollViewProxy: ScrollViewProxy
135185
) {
136-
sendMessage(inputText)
186+
sendMessage(inputText, selectedChatModel)
137187
inputText = ""
138188

139189
// if let lastMessage = conversation.messages.last {
@@ -206,7 +256,7 @@ struct DetailView_Previews: PreviewProvider {
206256
]
207257
),
208258
error: nil,
209-
sendMessage: { _ in }
259+
sendMessage: { _, _ in }
210260
)
211261
}
212262
}

Demo/DemoChat/Sources/UI/ModerationChatView.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public struct ModerationChatView: View {
2121
DetailView(
2222
conversation: store.moderationConversation,
2323
error: store.moderationConversationError,
24-
sendMessage: { message in
24+
sendMessage: { message, _ in
2525
Task {
2626
await store.sendModerationMessage(
2727
Message(

README.md

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ This repository contains Swift community-maintained implementation over [OpenAI]
1616
- [Usage](#usage)
1717
- [Initialization](#initialization)
1818
- [Completions](#completions)
19+
- [Completions Streaming](#completions-streaming)
1920
- [Chats](#chats)
21+
- [Chats Streaming](#chats-streaming)
2022
- [Images](#images)
2123
- [Audio](#audio)
2224
- [Audio Transcriptions](#audio-transcriptions)
@@ -146,6 +148,43 @@ let result = try await openAI.completions(query: query)
146148
- index : 0
147149
```
148150

151+
#### Completions Streaming
152+
153+
Completions streaming is available by using the `completionsStream` function. Tokens will be sent one-by-one.
154+
155+
**Closures**
156+
```swift
157+
openAI.completionsStream(query: query) { partialResult in
158+
switch partialResult {
159+
case .success(let result):
160+
print(result.choices)
161+
case .failure(let error):
162+
//Handle chunk error here
163+
}
164+
} completion: { error in
165+
//Handle streaming error here
166+
}
167+
```
168+
169+
**Combine**
170+
171+
```swift
172+
openAI
173+
.completionsStream(query: query)
174+
.sink { completion in
175+
//Handle completion result here
176+
} receiveValue: { result in
177+
//Handle chunk here
178+
}.store(in: &cancellables)
179+
```
180+
181+
**Structured concurrency**
182+
```swift
183+
for try await result in openAI.completionsStream(query: query) {
184+
//Handle result here
185+
}
186+
```
187+
149188
Review [Completions Documentation](https://platform.openai.com/docs/api-reference/completions) for more info.
150189

151190
### Chats
@@ -175,8 +214,6 @@ Using the OpenAI Chat API, you can build your own applications with `gpt-3.5-tur
175214
public let topP: Double?
176215
/// How many chat completion choices to generate for each input message.
177216
public let n: Int?
178-
/// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only `server-sent events` as they become available, with the stream terminated by a data: [DONE] message.
179-
public let stream: Bool?
180217
/// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
181218
public let stop: [String]?
182219
/// The maximum number of tokens to generate in the completion.
@@ -244,6 +281,43 @@ let result = try await openAI.chats(query: query)
244281
- total_tokens : 49
245282
```
246283

284+
#### Chats Streaming
285+
286+
Chats streaming is available by using the `chatsStream` function. Tokens will be sent one-by-one.
287+
288+
**Closures**
289+
```swift
290+
openAI.chatsStream(query: query) { partialResult in
291+
switch partialResult {
292+
case .success(let result):
293+
print(result.choices)
294+
case .failure(let error):
295+
//Handle chunk error here
296+
}
297+
} completion: { error in
298+
//Handle streaming error here
299+
}
300+
```
301+
302+
**Combine**
303+
304+
```swift
305+
openAI
306+
.chatsStream(query: query)
307+
.sink { completion in
308+
//Handle completion result here
309+
} receiveValue: { result in
310+
//Handle chunk here
311+
}.store(in: &cancellables)
312+
```
313+
314+
**Structured concurrency**
315+
```swift
316+
for try await result in openAI.chatsStream(query: query) {
317+
//Handle result here
318+
}
319+
```
320+
247321
Review [Chat Documentation](https://platform.openai.com/docs/guides/chat) for more info.
248322

249323
### Images

Sources/OpenAI/OpenAI.swift

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ final public class OpenAI: OpenAIProtocol {
3535
}
3636

3737
private let session: URLSessionProtocol
38+
private var streamingSessions: [NSObject] = []
3839

3940
public let configuration: Configuration
4041

@@ -59,6 +60,10 @@ final public class OpenAI: OpenAIProtocol {
5960
performRequest(request: JSONRequest<CompletionsResult>(body: query, url: buildURL(path: .completions)), completion: completion)
6061
}
6162

63+
public func completionsStream(query: CompletionsQuery, onResult: @escaping (Result<CompletionsResult, Error>) -> Void, completion: ((Error?) -> Void)?) {
64+
performSteamingRequest(request: JSONRequest<CompletionsResult>(body: query.makeStreamable(), url: buildURL(path: .completions)), onResult: onResult, completion: completion)
65+
}
66+
6267
public func images(query: ImagesQuery, completion: @escaping (Result<ImagesResult, Error>) -> Void) {
6368
performRequest(request: JSONRequest<ImagesResult>(body: query, url: buildURL(path: .images)), completion: completion)
6469
}
@@ -71,6 +76,10 @@ final public class OpenAI: OpenAIProtocol {
7176
performRequest(request: JSONRequest<ChatResult>(body: query, url: buildURL(path: .chats)), completion: completion)
7277
}
7378

79+
public func chatsStream(query: ChatQuery, onResult: @escaping (Result<ChatStreamResult, Error>) -> Void, completion: ((Error?) -> Void)?) {
80+
performSteamingRequest(request: JSONRequest<ChatResult>(body: query.makeStreamable(), url: buildURL(path: .chats)), onResult: onResult, completion: completion)
81+
}
82+
7483
public func edits(query: EditsQuery, completion: @escaping (Result<EditsResult, Error>) -> Void) {
7584
performRequest(request: JSONRequest<EditsResult>(body: query, url: buildURL(path: .edits)), completion: completion)
7685
}
@@ -131,7 +140,27 @@ extension OpenAI {
131140
task.resume()
132141
} catch {
133142
completion(.failure(error))
134-
return
143+
}
144+
}
145+
146+
func performSteamingRequest<ResultType: Codable>(request: any URLRequestBuildable, onResult: @escaping (Result<ResultType, Error>) -> Void, completion: ((Error?) -> Void)?) {
147+
do {
148+
let request = try request.build(token: configuration.token, organizationIdentifier: configuration.organizationIdentifier, timeoutInterval: configuration.timeoutInterval)
149+
let session = StreamingSession<ResultType>(urlRequest: request)
150+
session.onReceiveContent = {_, object in
151+
onResult(.success(object))
152+
}
153+
session.onProcessingError = {_, error in
154+
onResult(.failure(error))
155+
}
156+
session.onComplete = { [weak self] object, error in
157+
self?.streamingSessions.removeAll(where: { $0 == object })
158+
completion?(error)
159+
}
160+
session.perform()
161+
streamingSessions.append(session)
162+
} catch {
163+
completion?(error)
135164
}
136165
}
137166
}

0 commit comments

Comments
 (0)