Skip to content

Commit

Permalink
Add Generator to wrap some async calls in Task.Run().Result (#896)
Browse files Browse the repository at this point in the history
  • Loading branch information
AnaCoda authored Jul 13, 2023
1 parent 549b42c commit f73f16c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ public AsyncToSyncMethodTransformer(
CancellationToken token
) : base( model, token ) { }

// Need to disable D2L0018 in the method if we add Task.Run() to syncify something
private bool m_disableTaskRunWarningFlag;

public TransformResult<MethodDeclarationSyntax> Transform( MethodDeclarationSyntax decl ) {
// TODO: remove CancellationToken parameters
decl = decl.WithAttributeLists( ReplaceGenerateSyncAttribute( decl.AttributeLists ) )
Expand All @@ -19,6 +22,18 @@ public TransformResult<MethodDeclarationSyntax> Transform( MethodDeclarationSynt
.WithReturnType( TransformType( decl.ReturnType, isReturnType: true ) )
.WithExpressionBody( MaybeTransform( decl.ExpressionBody, Transform ) )
.WithBody( MaybeTransform( decl.Body, Transform ) );

if ( m_disableTaskRunWarningFlag ) {
PragmaWarningDirectiveTriviaSyntax restorePragma = SyntaxFactory.PragmaWarningDirectiveTrivia( SyntaxFactory.Token( SyntaxKind.RestoreKeyword ), true )
.AddErrorCodes( SyntaxFactory.IdentifierName( "D2L0018" ) ).NormalizeWhitespace().WithLeadingTrivia( SyntaxFactory.SyntaxTrivia( SyntaxKind.EndOfLineTrivia, "\n" ) ); ;
PragmaWarningDirectiveTriviaSyntax disablePragma = SyntaxFactory.PragmaWarningDirectiveTrivia( SyntaxFactory.Token( SyntaxKind.DisableKeyword ), true )
.AddErrorCodes( SyntaxFactory.IdentifierName( "D2L0018" ) ).NormalizeWhitespace();
decl = decl
.WithLeadingTrivia( decl.GetLeadingTrivia().Add( SyntaxFactory.Trivia( disablePragma ) ) )
.WithTrailingTrivia( decl.GetTrailingTrivia().Insert( 0, SyntaxFactory.Trivia( restorePragma ) ) );
m_disableTaskRunWarningFlag = false;
}

return GetResult( decl );
}

Expand Down Expand Up @@ -272,19 +287,30 @@ private ExpressionSyntax Transform( InvocationExpressionSyntax invocationExpr) {
newExpr = memberAccess.Expression;
return Transform( newExpr );
}
} else if( memberAccess is not null && ShouldWrapMemberAccessInTaskRun( memberAccess ) ) {
m_disableTaskRunWarningFlag = true;
return SyntaxFactory.ParseExpression( $"Task.Run(() => {invocationExpr}).Result" );
}

return invocationExpr
.WithExpression( Transform( invocationExpr.Expression ) )
.WithArgumentList( TransformAll( invocationExpr.ArgumentList, Transform ) );
}

// TODO: These two methods may need future modification for more specificity (make sure it's Task.FromResult or Content.ReadAsStringAsync)
bool ShouldRemoveReturnedMemberAccess( MemberAccessExpressionSyntax memberAccessExpr )
=> memberAccessExpr.Name.Identifier.ValueText switch {
"FromResult" => true,
"CompletedTask" => true,
_ => false
};

bool ShouldWrapMemberAccessInTaskRun( MemberAccessExpressionSyntax memberAccessExpr )
=> memberAccessExpr.Name.Identifier.ValueText switch {
"ReadAsStringAsync" => true,
_ => false
};

private ExpressionSyntax Transform( MemberAccessExpressionSyntax memberAccessExpr ) {
if( memberAccessExpr.IsKind( SyntaxKind.SimpleMemberAccessExpression ) ) {
if( ShouldRemoveReturnedMemberAccess( memberAccessExpr ) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ public void SimpleLambda() {
Assert.AreEqual( "[Blocking] void Bar() { Func<int, Task> baz = quux => fred.Delay( 2*y ); }", actual.Value.ToFullString() );
}

[Test]
public void WrapInTaskRun() {
var actual = Transform( @"[GenerateSync] async Task BarAsync() { string baz = await response.Content.ReadAsStringAsync(); }" );

Assert.IsTrue( actual.Success );
Assert.IsEmpty( actual.Diagnostics );
Assert.AreEqual( "#pragma warning disable D2L0018\r\n[Blocking] void Bar() { string baz = Task.Run(() => response.Content.ReadAsStringAsync()).Result; }\n#pragma warning restore D2L0018\r\n", actual.Value.ToFullString() );
}

[Test]
public void TryCatch() {
var actual = Transform( @"[GenerateSync]
Expand Down

0 comments on commit f73f16c

Please sign in to comment.