Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix IAsyncEnumerable controller methods to allow setting headers #57924

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.AspNetCore.Http.Json;
using Microsoft.AspNetCore.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;

Expand Down Expand Up @@ -91,7 +92,9 @@ public static Task WriteAsJsonAsync<TValue>(
response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
// Don't call StartAsync for IAsyncEnumerable. Headers might be set at the beginning of the generator which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (!response.HasStarted && !AsyncEnumerableHelper.IsIAsyncEnumerable(typeof(TValue)))
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
Expand Down Expand Up @@ -132,7 +135,9 @@ public static Task WriteAsJsonAsync<TValue>(
response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
// Don't call StartAsync for IAsyncEnumerable. Headers might be set at the beginning of the generator which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (!response.HasStarted && !AsyncEnumerableHelper.IsIAsyncEnumerable(typeof(TValue)))
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
Expand Down Expand Up @@ -185,7 +190,9 @@ public static Task WriteAsJsonAsync(
response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
// Don't call StartAsync for IAsyncEnumerable. Headers might be set at the beginning of the generator which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (!response.HasStarted && value is not null && !AsyncEnumerableHelper.IsIAsyncEnumerable(value.GetType()))
BrennanConroy marked this conversation as resolved.
Show resolved Hide resolved
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
Expand Down Expand Up @@ -305,7 +312,9 @@ public static Task WriteAsJsonAsync(
response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
// Don't call StartAsync for IAsyncEnumerable. Headers might be set at the beginning of the generator which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (!response.HasStarted && !AsyncEnumerableHelper.IsIAsyncEnumerable(type))
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
Expand Down Expand Up @@ -368,7 +377,9 @@ public static Task WriteAsJsonAsync(
response.ContentType = contentType ?? ContentTypeConstants.JsonContentTypeWithCharset;

var startTask = Task.CompletedTask;
if (!response.HasStarted)
// Don't call StartAsync for IAsyncEnumerable. Headers might be set at the beginning of the generator which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (!response.HasStarted && !AsyncEnumerableHelper.IsIAsyncEnumerable(type))
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
startTask = response.StartAsync(cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
<Compile Remove="$(RepoRoot)src\Components\Endpoints\src\FormMapping\HttpContextFormDataProvider.cs" LinkBase="SharedFormMapping" />
<Compile Remove="$(RepoRoot)src\Components\Endpoints\src\FormMapping\BrowserFileFromFormFile.cs" LinkBase="SharedFormMapping" />
<Compile Include="$(SharedSourceRoot)ContentTypeConstants.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)Reflection\AsyncEnumerableHelper.cs" LinkBase="Shared" />
</ItemGroup>

<ItemGroup>
Expand Down
117 changes: 116 additions & 1 deletion src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO.Pipelines;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.AspNetCore.Http.Features;

#nullable enable

Expand Down Expand Up @@ -481,6 +482,83 @@ public async Task WriteAsJsonAsync_NullValue_WithJsonTypeInfo_JsonResponse()
Assert.Equal("null", data);
}

[Fact]
public async Task WriteAsJsonAsyncGeneric_AsyncEnumerableStartAsyncNotCalled()
{
// Arrange
var body = new MemoryStream();
var context = new DefaultHttpContext();
context.Response.Body = body;
var responseBodyFeature = new TestHttpResponseBodyFeature(context.Features.GetRequiredFeature<IHttpResponseBodyFeature>());
context.Features.Set<IHttpResponseBodyFeature>(responseBodyFeature);

// Act
await context.Response.WriteAsJsonAsync(AsyncEnumerable());

// Assert
Assert.Equal(ContentTypeConstants.JsonContentTypeWithCharset, context.Response.ContentType);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);

Assert.Equal("[1,2]", Encoding.UTF8.GetString(body.ToArray()));

async IAsyncEnumerable<int> AsyncEnumerable()
{
Assert.False(responseBodyFeature.StartCalled);
await Task.Yield();
BrennanConroy marked this conversation as resolved.
Show resolved Hide resolved
yield return 1;
yield return 2;
}
}

[Fact]
public async Task WriteAsJsonAsync_AsyncEnumerableStartAsyncNotCalled()
{
// Arrange
var body = new MemoryStream();
var context = new DefaultHttpContext();
context.Response.Body = body;
var responseBodyFeature = new TestHttpResponseBodyFeature(context.Features.GetRequiredFeature<IHttpResponseBodyFeature>());
context.Features.Set<IHttpResponseBodyFeature>(responseBodyFeature);

// Act
await context.Response.WriteAsJsonAsync(AsyncEnumerable(), typeof(IAsyncEnumerable<int>));

// Assert
Assert.Equal(ContentTypeConstants.JsonContentTypeWithCharset, context.Response.ContentType);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);

Assert.Equal("[1,2]", Encoding.UTF8.GetString(body.ToArray()));

async IAsyncEnumerable<int> AsyncEnumerable()
{
Assert.False(responseBodyFeature.StartCalled);
await Task.Yield();
yield return 1;
yield return 2;
}
}

[Fact]
public async Task WriteAsJsonAsync_StartAsyncCalled()
{
// Arrange
var body = new MemoryStream();
var context = new DefaultHttpContext();
context.Response.Body = body;
var responseBodyFeature = new TestHttpResponseBodyFeature(context.Features.GetRequiredFeature<IHttpResponseBodyFeature>());
context.Features.Set<IHttpResponseBodyFeature>(responseBodyFeature);

// Act
await context.Response.WriteAsJsonAsync(new int[] {1, 2}, typeof(int[]));

// Assert
Assert.Equal(ContentTypeConstants.JsonContentTypeWithCharset, context.Response.ContentType);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);

Assert.Equal("[1,2]", Encoding.UTF8.GetString(body.ToArray()));
Assert.True(responseBodyFeature.StartCalled);
}

public class TestObject
{
public string? StringProperty { get; set; }
Expand Down Expand Up @@ -530,4 +608,41 @@ public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationTo
return new ValueTask(tcs.Task);
}
}

public class TestHttpResponseBodyFeature : IHttpResponseBodyFeature
{
private readonly IHttpResponseBodyFeature _inner;

public bool StartCalled;

public TestHttpResponseBodyFeature(IHttpResponseBodyFeature inner)
{
_inner = inner;
}

public Stream Stream => _inner.Stream;

public PipeWriter Writer => _inner.Writer;

public Task CompleteAsync()
{
return _inner.CompleteAsync();
}

public void DisableBuffering()
{
_inner.DisableBuffering();
}

public Task SendFileAsync(string path, long offset, long? count, CancellationToken cancellationToken = default)
{
return _inner.SendFileAsync(path, offset, count, cancellationToken);
}

public Task StartAsync(CancellationToken cancellationToken = default)
{
StartCalled = true;
return _inner.StartAsync(cancellationToken);
}
}
}
12 changes: 10 additions & 2 deletions src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Internal;

namespace Microsoft.AspNetCore.Mvc.Formatters;

Expand Down Expand Up @@ -88,10 +89,17 @@ public sealed override async Task WriteResponseBodyAsync(OutputFormatterWriteCon
try
{
var responseWriter = httpContext.Response.BodyWriter;

if (!httpContext.Response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
await httpContext.Response.StartAsync();
var typeToCheck = context.ObjectType ?? context.Object?.GetType();
// Don't call StartAsync for IAsyncEnumerable methods. Headers might be set in the controller method which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (typeToCheck is not null && !AsyncEnumerableHelper.IsIAsyncEnumerable(typeToCheck))
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
await httpContext.Response.StartAsync();
}
}

if (jsonTypeInfo is not null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,13 @@ public async Task ExecuteAsync(ActionContext context, JsonResult result)
var responseWriter = response.BodyWriter;
if (!response.HasStarted)
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
await response.StartAsync();
// Don't call StartAsync for IAsyncEnumerable methods. Headers might be set in the controller method which isn't invoked until
// JsonSerializer starts iterating over the IAsyncEnumerable.
if (!AsyncEnumerableHelper.IsIAsyncEnumerable(objectType))
{
// Flush headers before starting Json serialization. This avoids an extra layer of buffering before the first flush.
await response.StartAsync();
}
}

await JsonSerializer.SerializeAsync(responseWriter, value, objectType, jsonSerializerOptions, context.HttpContext.RequestAborted);
Expand Down
1 change: 1 addition & 0 deletions src/Mvc/Mvc.Core/src/Microsoft.AspNetCore.Mvc.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Microsoft.AspNetCore.Mvc.RouteAttribute</Description>
<Compile Include="$(SharedSourceRoot)HttpParseResult.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)HttpRuleParser.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)Json\JsonSerializerExtensions.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)Reflection\AsyncEnumerableHelper.cs" LinkBase="Shared" />
</ItemGroup>

<ItemGroup>
Expand Down
117 changes: 117 additions & 0 deletions src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO.Pipelines;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.InternalTesting;
using Microsoft.DotNet.RemoteExecutor;
using Microsoft.Extensions.Primitives;
Expand Down Expand Up @@ -113,6 +115,121 @@ public async Task WriteResponseBodyAsync_ForLargeAsyncEnumerable()
Assert.Equal(expected.ToArray(), body.ToArray());
}

// Regression test: https://github.com/dotnet/aspnetcore/issues/57895
[Fact]
public async Task WriteResponseBodyAsync_AsyncEnumerableStartAsyncNotCalled()
{
// Arrange
TestHttpResponseBodyFeature responseBodyFeature = null;
var expected = new MemoryStream();
await JsonSerializer.SerializeAsync(expected, AsyncEnumerable(), new JsonSerializerOptions(JsonSerializerDefaults.Web));
var formatter = GetOutputFormatter();
var mediaType = MediaTypeHeaderValue.Parse("application/json; charset=utf-8");
var encoding = CreateOrGetSupportedEncoding(formatter, "utf-8", isDefaultEncoding: true);

var body = new MemoryStream();

var actionContext = GetActionContext(mediaType, body);
responseBodyFeature = new TestHttpResponseBodyFeature(actionContext.HttpContext.Features.Get<IHttpResponseBodyFeature>());
actionContext.HttpContext.Features.Set<IHttpResponseBodyFeature>(responseBodyFeature);

var asyncEnumerable = AsyncEnumerable();
var outputFormatterContext = new OutputFormatterWriteContext(
actionContext.HttpContext,
new TestHttpResponseStreamWriterFactory().CreateWriter,
asyncEnumerable.GetType(),
asyncEnumerable)
{
ContentType = new StringSegment(mediaType.ToString()),
};

// Act
await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8"));

// Assert
Assert.Equal(expected.ToArray(), body.ToArray());

async IAsyncEnumerable<int> AsyncEnumerable()
{
// StartAsync shouldn't be called by SystemTestJsonOutputFormatter when using IAsyncEnumerable
// This allows Controller methods to set Headers, etc.
Assert.False(responseBodyFeature?.StartCalled ?? false);
await Task.Yield();
yield return 1;
}
}

[Fact]
public async Task WriteResponseBodyAsync_StartAsyncCalled()
{
// Arrange
TestHttpResponseBodyFeature responseBodyFeature = null;
var expected = new MemoryStream();
await JsonSerializer.SerializeAsync(expected, 1, new JsonSerializerOptions(JsonSerializerDefaults.Web));
var formatter = GetOutputFormatter();
var mediaType = MediaTypeHeaderValue.Parse("application/json; charset=utf-8");
var encoding = CreateOrGetSupportedEncoding(formatter, "utf-8", isDefaultEncoding: true);

var body = new MemoryStream();

var actionContext = GetActionContext(mediaType, body);
responseBodyFeature = new TestHttpResponseBodyFeature(actionContext.HttpContext.Features.Get<IHttpResponseBodyFeature>());
actionContext.HttpContext.Features.Set<IHttpResponseBodyFeature>(responseBodyFeature);

var outputFormatterContext = new OutputFormatterWriteContext(
actionContext.HttpContext,
new TestHttpResponseStreamWriterFactory().CreateWriter,
typeof(int),
1)
{
ContentType = new StringSegment(mediaType.ToString()),
};

// Act
await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8"));

// Assert
Assert.Equal(expected.ToArray(), body.ToArray());
Assert.True(responseBodyFeature.StartCalled);
}

public class TestHttpResponseBodyFeature : IHttpResponseBodyFeature
{
private readonly IHttpResponseBodyFeature _inner;

public bool StartCalled;

public TestHttpResponseBodyFeature(IHttpResponseBodyFeature inner)
{
_inner = inner;
}

public Stream Stream => _inner.Stream;

public PipeWriter Writer => _inner.Writer;

public Task CompleteAsync()
{
return _inner.CompleteAsync();
}

public void DisableBuffering()
{
_inner.DisableBuffering();
}

public Task SendFileAsync(string path, long offset, long? count, CancellationToken cancellationToken = default)
{
return _inner.SendFileAsync(path, offset, count, cancellationToken);
}

public Task StartAsync(CancellationToken cancellationToken = default)
{
StartCalled = true;
return _inner.StartAsync(cancellationToken);
}
}

[Fact]
public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses()
{
Expand Down
Loading
Loading