Skip to content

Commit 60ab8be

Browse files
authored
Merge pull request #600 from betalgo/feature/Improve-Assistant-streaming-methods
Improved Assistant streaming methods
2 parents b25961b + 64e376b commit 60ab8be

File tree

7 files changed

+216
-28
lines changed

7 files changed

+216
-28
lines changed

OpenAI.Playground/TestHelpers/AssistantHelpers/RunTestHelper.cs

Lines changed: 127 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -165,19 +165,58 @@ public static async Task CreateRunAsStreamTest(IOpenAIService openAI)
165165
var result = openAI.Beta.Runs.RunCreateAsStream(CreatedThreadId, new()
166166
{
167167
AssistantId = assistantResult.Id
168-
});
168+
},justDataMode:false);
169169

170170
await foreach (var run in result)
171171
{
172172
if (run.Successful)
173173
{
174-
if (string.IsNullOrEmpty(run.Status))
174+
Console.WriteLine($"Event:{run.StreamEvent}");
175+
if (run is RunResponse runResponse)
175176
{
176-
Console.Write(".");
177+
if (string.IsNullOrEmpty(runResponse.Status))
178+
{
179+
Console.Write(".");
180+
}
181+
else
182+
{
183+
ConsoleExtensions.WriteLine($"Run Id: {runResponse.Id}, Status: {runResponse.Status}");
184+
}
185+
}
186+
187+
else if (run is RunStepResponse runStepResponse)
188+
{
189+
if (string.IsNullOrEmpty(runStepResponse.Status))
190+
{
191+
Console.Write(".");
192+
}
193+
else
194+
{
195+
ConsoleExtensions.WriteLine($"Run Step Id: {runStepResponse.Id}, Status: {runStepResponse.Status}");
196+
}
197+
}
198+
199+
else if (run is MessageResponse messageResponse)
200+
{
201+
if (string.IsNullOrEmpty(messageResponse.Id))
202+
{
203+
Console.Write(".");
204+
}
205+
else
206+
{
207+
ConsoleExtensions.WriteLine($"Message Id: {messageResponse.Id}, Message: {messageResponse.Content?.FirstOrDefault()?.Text?.Value}");
208+
}
177209
}
178210
else
179211
{
180-
ConsoleExtensions.WriteLine($"Run Id: {run.Id}, Status: {run.Status}");
212+
if (run.StreamEvent!=null)
213+
{
214+
Console.WriteLine(run.StreamEvent);
215+
}
216+
else
217+
{
218+
Console.Write(".");
219+
}
181220
}
182221
}
183222
else
@@ -450,13 +489,52 @@ public static async Task SubmitToolOutputsAsStreamToRunTest(IOpenAIService openA
450489
{
451490
if (run.Successful)
452491
{
453-
if (string.IsNullOrEmpty(run.Status))
492+
Console.WriteLine($"Event:{run.StreamEvent}");
493+
if (run is RunResponse runResponse)
454494
{
455-
Console.Write(".");
495+
if (string.IsNullOrEmpty(runResponse.Status))
496+
{
497+
Console.Write(".");
498+
}
499+
else
500+
{
501+
ConsoleExtensions.WriteLine($"Run Id: {runResponse.Id}, Status: {runResponse.Status}");
502+
}
503+
}
504+
505+
else if (run is RunStepResponse runStepResponse)
506+
{
507+
if (string.IsNullOrEmpty(runStepResponse.Status))
508+
{
509+
Console.Write(".");
510+
}
511+
else
512+
{
513+
ConsoleExtensions.WriteLine($"Run Step Id: {runStepResponse.Id}, Status: {runStepResponse.Status}");
514+
}
515+
}
516+
517+
else if (run is MessageResponse messageResponse)
518+
{
519+
if (string.IsNullOrEmpty(messageResponse.Id))
520+
{
521+
Console.Write(".");
522+
}
523+
else
524+
{
525+
ConsoleExtensions.WriteLine($"Message Id: {messageResponse.Id}, Message: {messageResponse.Content?.FirstOrDefault()?.Text?.Value}");
526+
}
456527
}
457528
else
458529
{
459-
ConsoleExtensions.WriteLine($"Run Id: {run.Id}, Status: {run.Status}");
530+
if (run.StreamEvent != null)
531+
{
532+
Console.WriteLine(run.StreamEvent);
533+
}
534+
else
535+
{
536+
Console.Write(".");
537+
}
460538
}
461539
}
462540
else
@@ -642,13 +720,52 @@ public static async Task CreateThreadAndRunAsStream(IOpenAIService sdk)
642720
{
643721
if (run.Successful)
644722
{
645-
if (string.IsNullOrEmpty(run.Status))
723+
Console.WriteLine($"Event:{run.StreamEvent}");
724+
if (run is RunResponse runResponse)
646725
{
647-
Console.Write(".");
726+
if (string.IsNullOrEmpty(runResponse.Status))
727+
{
728+
Console.Write(".");
729+
}
730+
else
731+
{
732+
ConsoleExtensions.WriteLine($"Run Id: {runResponse.Id}, Status: {runResponse.Status}");
733+
}
734+
}
735+
736+
else if (run is RunStepResponse runStepResponse)
737+
{
738+
if (string.IsNullOrEmpty(runStepResponse.Status))
739+
{
740+
Console.Write(".");
741+
}
742+
else
743+
{
744+
ConsoleExtensions.WriteLine($"Run Step Id: {runStepResponse.Id}, Status: {runStepResponse.Status}");
745+
}
746+
}
747+
748+
else if (run is MessageResponse messageResponse)
749+
{
750+
if (string.IsNullOrEmpty(messageResponse.Id))
751+
{
752+
Console.Write(".");
753+
}
754+
else
755+
{
756+
ConsoleExtensions.WriteLine($"Message Id: {messageResponse.Id}, Message: {messageResponse.Content?.FirstOrDefault()?.Text?.Value}");
757+
}
648758
}
649759
else
650760
{
651-
ConsoleExtensions.WriteLine($"Run Id: {run.Id}, Status: {run.Status}");
761+
if (run.StreamEvent != null)
762+
{
763+
Console.WriteLine(run.StreamEvent);
764+
}
765+
else
766+
{
767+
Console.Write(".");
768+
}
652769
}
653770
}
654771
else
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System.Text.Json;
2+
using OpenAI.ObjectModels.ResponseModels;
3+
using OpenAI.ObjectModels.SharedModels;
4+
5+
namespace OpenAI.Extensions;
6+
7+
public static class JsonToObjectRouterExtension
8+
{
9+
public static Type Route(string json)
10+
{
11+
var apiResponse = JsonSerializer.Deserialize<ObjectBaseResponse>(json);
12+
13+
return apiResponse?.ObjectTypeName switch
14+
{
15+
"thread.run.step" => typeof(RunStepResponse),
16+
"thread.run" => typeof(RunResponse),
17+
"thread.message" => typeof(MessageResponse),
18+
"thread.message.delta" => typeof(MessageResponse),
19+
_ => typeof(BaseResponse)
20+
};
21+
}
22+
}

OpenAI.SDK/Extensions/StreamHandleExtension.cs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Runtime.CompilerServices;
1+
using System.Collections.Generic;
2+
using System.Runtime.CompilerServices;
23
using System.Text.Json;
34
using OpenAI.ObjectModels;
45
using OpenAI.ObjectModels.RequestModels;
@@ -8,6 +9,10 @@ namespace OpenAI.Extensions;
89

910
public static class StreamHandleExtension
1011
{
12+
public static async IAsyncEnumerable<BaseResponse> AsStream(this HttpResponseMessage response, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default)
13+
{
14+
await foreach (var baseResponse in AsStream<BaseResponse>(response, justDataMode, cancellationToken)) yield return baseResponse;
15+
}
1116
public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpResponseMessage response, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default)
1217
where TResponse : BaseResponse, new()
1318
{
@@ -20,13 +25,15 @@ public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpRes
2025

2126
await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
2227
using var reader = new StreamReader(stream);
23-
28+
string? tempStreamEvent = null;
29+
bool isEventDelta;
2430
// Continuously read the stream until the end of it
2531
while (true)
2632
{
2733
cancellationToken.ThrowIfCancellationRequested();
2834

2935
var line = await reader.ReadLineAsync();
36+
// Console.WriteLine("---" + line);
3037
// Break the loop if we have reached the end of the stream
3138
if (line == null)
3239
{
@@ -39,11 +46,28 @@ public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpRes
3946
continue;
4047
}
4148

49+
if (line.StartsWith("event: "))
50+
{
51+
line = line.RemoveIfStartWith("event: ");
52+
tempStreamEvent = line;
53+
isEventDelta = true;
54+
}
55+
else
56+
{
57+
isEventDelta = false;
58+
}
59+
4260
if (justDataMode && !line.StartsWith("data: "))
4361
{
4462
continue;
4563
}
4664

65+
if (!justDataMode && isEventDelta )
66+
{
67+
yield return new(){ObjectTypeName = "base.stream.event",StreamEvent = tempStreamEvent};
68+
continue;
69+
}
70+
4771
line = line.RemoveIfStartWith("data: ");
4872

4973
// Exit the loop if the stream is done
@@ -56,7 +80,14 @@ public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpRes
5680
try
5781
{
5882
// When the response is good, each line is a serializable CompletionCreateRequest
59-
block = JsonSerializer.Deserialize<TResponse>(line);
83+
if (typeof(TResponse) == typeof(BaseResponse))
84+
{
85+
block =JsonSerializer.Deserialize(line, JsonToObjectRouterExtension.Route(line), new JsonSerializerOptions()) as TResponse;
86+
}
87+
else
88+
{
89+
block = JsonSerializer.Deserialize<TResponse>(line);
90+
}
6091
}
6192
catch (Exception)
6293
{
@@ -78,6 +109,8 @@ public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpRes
78109
{
79110
block.HttpStatusCode = httpStatusCode;
80111
block.HeaderValues = headerValues;
112+
block.StreamEvent = tempStreamEvent;
113+
tempStreamEvent = null;
81114
yield return block;
82115
}
83116
}

OpenAI.SDK/Interfaces/IRunService.cs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System.Runtime.CompilerServices;
22
using OpenAI.ObjectModels.RequestModels;
3+
using OpenAI.ObjectModels.ResponseModels;
34
using OpenAI.ObjectModels.SharedModels;
45

56
namespace OpenAI.Interfaces;
@@ -24,8 +25,8 @@ public interface IRunService
2425
/// <param name="modelId"></param>
2526
/// <param name="justDataMode"></param>
2627
/// <param name="cancellationToken"></param>
27-
/// <returns></returns>
28-
IAsyncEnumerable<RunResponse> RunCreateAsStream(string threadId, RunCreateRequest request, string? modelId = null, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);
28+
/// <returns><see cref="BaseResponse"/> also returns <see cref="RunResponse"/>,<see cref="RunStepResponse"/>, <see cref="MessageResponse"/> </returns>
29+
IAsyncEnumerable<BaseResponse> RunCreateAsStream(string threadId, RunCreateRequest request, string? modelId = null, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);
2930

3031
/// <summary>
3132
/// Retrieves a run.
@@ -71,9 +72,10 @@ public interface IRunService
7172
/// <param name="threadId"></param>
7273
/// <param name="runId"></param>
7374
/// <param name="request"></param>
75+
/// <param name="justDataMode"></param>
7476
/// <param name="cancellationToken"></param>
75-
/// <returns></returns>
76-
IAsyncEnumerable<RunResponse> RunSubmitToolOutputsAsStream(string threadId, string runId, SubmitToolOutputsToRunRequest request, [EnumeratorCancellation] CancellationToken cancellationToken = default);
77+
/// <returns><see cref="BaseResponse"/> also returns <see cref="RunResponse"/>,<see cref="RunStepResponse"/>, <see cref="MessageResponse"/> </returns>
78+
IAsyncEnumerable<BaseResponse> RunSubmitToolOutputsAsStream(string threadId, string runId, SubmitToolOutputsToRunRequest request, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);
7779

7880
/// <summary>
7981
/// Modifies a run.
@@ -93,7 +95,8 @@ public interface IRunService
9395
/// <summary>
9496
/// Create a thread and run it in one request as Stream.
9597
/// </summary>
96-
IAsyncEnumerable<RunResponse> CreateThreadAndRunAsStream(CreateThreadAndRunRequest createThreadAndRunRequest, string? modelId = null, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);
98+
/// <returns><see cref="BaseResponse"/> also returns <see cref="RunResponse"/>,<see cref="RunStepResponse"/>, <see cref="MessageResponse"/> </returns>
99+
IAsyncEnumerable<BaseResponse> CreateThreadAndRunAsStream(CreateThreadAndRunRequest createThreadAndRunRequest, string? modelId = null, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);
97100

98101
/// <summary>
99102
/// Returns a list of runs belonging to a thread.

0 commit comments

Comments
 (0)