RunAsync in C# (#16890)
Implement C# binding for RunAsync.

---------

Co-authored-by: Randy Shuai <[email protected]>
RandySheriffH and RandyShuai authored Aug 8, 2023
1 parent 249917a commit 063e905
Showing 3 changed files with 235 additions and 4 deletions.
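For readers skimming the diff, the sketch below shows how the new RunAsync binding might be consumed. It is a hedged illustration, not part of the commit: the model path "model.onnx" and the tensor names "input"/"output" are placeholders (they mirror the FLOAT16 test model exercised in the new tests), while the OrtValue helpers and the RunAsync signature follow what this change introduces.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

public static class RunAsyncUsageSketch
{
    public static async Task RunOnceAsync()
    {
        // RunAsync schedules work on the intra-op thread pool, so give it more than one thread.
        using var opt = new SessionOptions { IntraOpNumThreads = 2 };
        using var session = new InferenceSession("model.onnx", opt); // illustrative model path

        float[] inputData = { 1f, 2f, 3f, 4f, 5f };
        long[] shape = { 1, 5 };

        using var input = OrtValue.CreateTensorValueFromMemory(inputData, shape);
        using var output = OrtValue.CreateAllocatedTensorValue(
            OrtAllocator.DefaultInstance, TensorElementType.Float, shape);

        // Outputs are pre-allocated by the caller; the returned task completes
        // when the native completion callback fires.
        var outputs = await session.RunAsync(
            null,                              // RunOptions may be null
            new List<string> { "input" },      // illustrative input name
            new List<OrtValue> { input },
            new List<string> { "output" },     // illustrative output name
            new List<OrtValue> { output });

        float[] result = outputs.ElementAt(0).GetTensorDataAsSpan<float>().ToArray();
        Console.WriteLine(string.Join(", ", result));
    }
}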
130 changes: 128 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs
@@ -6,7 +6,9 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net.NetworkInformation;
using System.Runtime.InteropServices;
using System.Threading.Tasks;

namespace Microsoft.ML.OnnxRuntime
{
@@ -604,7 +606,7 @@ public IDisposableReadOnlyCollection<OrtValue> Run(RunOptions runOptions, IReadO
throw new ArgumentException($"Length of {nameof(inputNames)} ({inputNames.Count}) must match that of {nameof(inputValues)} ({inputValues.Count}).");
}

var inputNamesArray = LookupUtf8Names(inputNames, n => n, LookupInputMetadata);
var inputNamesArray = LookupUtf8Names(inputNames, n => n, LookupInputMetadata);
var inputHandlesArray = inputValues.Select(v => v.Handle).ToArray();

var outputNamesArray = LookupUtf8Names(outputNames, n => n, LookupOutputMetadata);
@@ -636,7 +638,7 @@ public IDisposableReadOnlyCollection<OrtValue> Run(RunOptions runOptions, IReadO
IntPtr[] inputHandlesArray = new IntPtr[inputs.Count];

int count = 0;
foreach(var input in inputs)
foreach (var input in inputs)
{
inputNamesArray[count] = LookupInputMetadata(input.Key).ZeroTerminatedName;
inputHandlesArray[count] = input.Value.Handle;
@@ -1044,6 +1046,130 @@ public ulong ProfilingStartTimeNs
}
}

private static void OrtCallback(IntPtr userData, IntPtr[] outputs, uint numOutputs, IntPtr status)
{
var hostHdl = GCHandle.FromIntPtr(userData);
CallbackHost host = (CallbackHost)hostHdl.Target;
try
{
host.callback(host.outputValues, status);
}
finally
{
hostHdl.Free();
}
}

private delegate void OrtCallbackDelegate(IntPtr userData, IntPtr[] outputs, uint numOutputs, IntPtr status);

private static OrtCallbackDelegate ortCallback = new OrtCallbackDelegate(OrtCallback);

private delegate void UserCallbackDelegate(IReadOnlyCollection<OrtValue> outputs, IntPtr status);

private class CallbackHost
{
public IReadOnlyCollection<string> inputNames { get; }
public IReadOnlyCollection<OrtValue> inputValues { get; }
public IReadOnlyCollection<string> outputNames { get; }
public IReadOnlyCollection<OrtValue> outputValues { get; }
public UserCallbackDelegate callback { get; }

public IntPtr[] rawInputNames { get; }
public IntPtr[] rawInputValues { get; }
public IntPtr[] rawOutputNames { get; }
public IntPtr[] rawOutputValues { get; }

public CallbackHost(InferenceSession session,
IReadOnlyCollection<string> cbInputNames,
IReadOnlyCollection<OrtValue> cbinputValues,
IReadOnlyCollection<string> cbOutputNames,
IReadOnlyCollection<OrtValue> cbOutputValues,
UserCallbackDelegate userCallback)
{

inputNames = cbInputNames;
inputValues = cbinputValues;
outputNames = cbOutputNames;
outputValues = cbOutputValues;
callback = userCallback;

rawInputNames = LookupUtf8Names(inputNames, n => n, session.LookupInputMetadata);
rawInputValues = inputValues.Select(v => v.Handle).ToArray();

rawOutputNames = LookupUtf8Names(outputNames, n => n, session.LookupOutputMetadata);
rawOutputValues = outputValues.Select(v => v.Handle).ToArray();
}
}

private void RunAsyncInternal(RunOptions options,
IReadOnlyCollection<string> inputNames,
IReadOnlyCollection<OrtValue> inputValues,
IReadOnlyCollection<string> outputNames,
IReadOnlyCollection<OrtValue> outputValues,
UserCallbackDelegate callback)
{
CallbackHost host = new CallbackHost(this, inputNames, inputValues, outputNames, outputValues, callback);
var host_hdl = GCHandle.Alloc(host, GCHandleType.Normal);

try
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtRunAsync(
_nativeHandle,
options == null ? (IntPtr)null : options.Handle,
host.rawInputNames,
host.rawInputValues,
(UIntPtr)host.rawInputNames.Length,
host.rawOutputNames,
(UIntPtr)host.rawOutputNames.Length,
host.rawOutputValues,
Marshal.GetFunctionPointerForDelegate(ortCallback),
GCHandle.ToIntPtr(host_hdl)
));
}
catch (OnnxRuntimeException)
{
host_hdl.Free();
throw;
}
}

/// <summary>
/// Run inference asynchronously on a thread from the intra-op thread pool
/// </summary>
/// <param name="options">run options, may be null</param>
/// <param name="inputNames">names of the inputs</param>
/// <param name="inputValues">input OrtValues</param>
/// <param name="outputNames">names of the outputs</param>
/// <param name="outputValues">pre-allocated output OrtValues</param>
/// <returns>task to be awaited</returns>
/// <exception cref="OnnxRuntimeException"></exception>
public async Task<IReadOnlyCollection<OrtValue>> RunAsync(RunOptions options,
IReadOnlyCollection<string> inputNames,
IReadOnlyCollection<OrtValue> inputValues,
IReadOnlyCollection<string> outputNames,
IReadOnlyCollection<OrtValue> outputValues)
{
var promise = new TaskCompletionSource<IReadOnlyCollection<OrtValue>>();
RunAsyncInternal(options,
inputNames,
inputValues,
outputNames,
outputValues,
(IReadOnlyCollection<OrtValue> outputs, IntPtr status) =>
{
try
{
NativeApiStatus.VerifySuccess(status);
promise.SetResult(outputs);
}
catch (Exception ex)
{
promise.SetException(ex);
}
});
return await promise.Task;
}

#endregion

#region private methods
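As an aside on the lifetime handling above: RunAsyncInternal allocates a GCHandle for the CallbackHost, passes it to native code as user_data, and the static OrtCallback frees it after invoking the user callback (the catch block frees it instead if OrtRunAsync fails synchronously). The fragment below is a minimal, self-contained sketch of that round trip; it is illustrative only and not ONNX Runtime code.

using System;
using System.Runtime.InteropServices;

static class GCHandleRoundTrip
{
    private sealed class PendingRun
    {
        public string Name = "example run"; // stands in for CallbackHost's state
    }

    // Plays the role of the native completion callback receiving user_data.
    private static void OnComplete(IntPtr userData)
    {
        var handle = GCHandle.FromIntPtr(userData);
        try
        {
            var state = (PendingRun)handle.Target;
            Console.WriteLine($"completed: {state.Name}");
        }
        finally
        {
            handle.Free(); // mirrors the Free() in the static OrtCallback
        }
    }

    public static void Main()
    {
        var handle = GCHandle.Alloc(new PendingRun(), GCHandleType.Normal);
        // A real binding would hand GCHandle.ToIntPtr(handle) to native code and the
        // callback would fire later; here the callback is invoked directly.
        OnComplete(GCHandle.ToIntPtr(handle));
    }
}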
30 changes: 30 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -292,6 +292,8 @@ public struct OrtApi
public IntPtr UpdateROCMProviderOptions;
public IntPtr GetROCMProviderOptionsAsString;
public IntPtr ReleaseROCMProviderOptions;
public IntPtr CreateAndRegisterAllocatorV2;
public IntPtr RunAsync;
}

internal static class NativeMethods
@@ -510,6 +512,8 @@ static NativeMethods()
OrtUpdateROCMProviderOptions = (DOrtUpdateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateROCMProviderOptions, typeof(DOrtUpdateROCMProviderOptions));
OrtGetROCMProviderOptionsAsString = (DOrtGetROCMProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetROCMProviderOptionsAsString, typeof(DOrtGetROCMProviderOptionsAsString));
OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions));
OrtCreateAndRegisterAllocatorV2 = (DCreateAndRegisterAllocatorV2)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocatorV2, typeof(DCreateAndRegisterAllocatorV2));
OrtRunAsync = (DOrtRunAsync)Marshal.GetDelegateForFunctionPointer(api_.RunAsync, typeof(DOrtRunAsync));
}

internal class NativeLib
@@ -916,6 +920,32 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
out UIntPtr /*(ulong* out)*/ startTime);
public static DOrtSessionGetProfilingStartTimeNs OrtSessionGetProfilingStartTimeNs;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(ONNStatus*)*/ DCreateAndRegisterAllocatorV2(
IntPtr /* (OrtEnv*) */ environment,
IntPtr /*(char*)*/ provider_type,
IntPtr /*(OrtMemoryInfo*)*/ mem_info,
IntPtr /*(OrtArenaCfg*)*/ arena_cfg,
IntPtr /*(char**)*/ provider_options_keys,
IntPtr /*(char**)*/ provider_options_values,
UIntPtr /*(size_t)*/ num_keys);
public static DCreateAndRegisterAllocatorV2 OrtCreateAndRegisterAllocatorV2;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(ONNStatus*)*/ DOrtRunAsync(
IntPtr /*(OrtSession*)*/ session,
IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options
IntPtr[] /*(char**)*/ inputNames,
IntPtr[] /*(OrtValue*[])*/ inputValues,
UIntPtr /*(size_t)*/ inputCount,
IntPtr[] /*(char**)*/ outputNames,
UIntPtr /*(size_t)*/ outputCount,
IntPtr[] /*(OrtValue*[])*/ outputValues,
IntPtr /*(void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status))*/ callback, // callback function
IntPtr /*(void*)*/ user_data
);
public static DOrtRunAsync OrtRunAsync;

#endregion InferenceSession API

#region SessionOptions API
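One more note on the interop in this change: Marshal.GetFunctionPointerForDelegate does not root the managed delegate, which is why InferenceSession keeps ortCallback in a static field. The sketch below illustrates that rooting pattern in isolation; the names are hypothetical and it is not ONNX Runtime code.

using System;
using System.Runtime.InteropServices;

static class CallbackPinning
{
    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    private delegate void CompletionCallback(IntPtr userData);

    // Kept in a static field so the GC never collects the delegate while native
    // code may still invoke the function pointer derived from it.
    private static readonly CompletionCallback s_callback = OnComplete;

    private static void OnComplete(IntPtr userData)
    {
        Console.WriteLine("native side reported completion");
    }

    public static IntPtr GetNativeCallbackPointer()
    {
        // Valid only while s_callback remains reachable.
        return Marshal.GetFunctionPointerForDelegate(s_callback);
    }
}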
79 changes: 77 additions & 2 deletions csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs
@@ -5,7 +5,10 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;
@@ -476,7 +479,7 @@ public void RunInferenceUsingPreAllocatedOutputsAndDictionary()
TensorElementType.Float, expectedShape))
{
// Run inference
var inputValues = new List<OrtValue>{ inputOrtValue }.AsReadOnly();
var inputValues = new List<OrtValue> { inputOrtValue }.AsReadOnly();
var outputValues = new List<OrtValue> { outputOrtValue }.AsReadOnly();
session.Run(runOptions, inputNames, inputValues,
expectedOutputNames, outputValues);
@@ -1342,7 +1345,7 @@ private void TestModelInputFLOAT16()
[Fact(DisplayName = "TestModelInputBFLOAT16")]
private void TestModelInputBFLOAT16()
{
BFloat16[] modelInput = { new BFloat16(16256), new BFloat16(16384),
BFloat16[] modelInput = { new BFloat16(16256), new BFloat16(16384),
new BFloat16(16448), new BFloat16(16512), new BFloat16(16544) };
int[] inputShape = { 1, 5 };
// model takes 1x5 input of fixed type, echoes back
@@ -2025,6 +2028,78 @@ public SkipNonPackageTests()
}
}
}

[Fact(DisplayName = "TestModelRunAsyncTask")]
private async void TestModelRunAsyncTask()
{
Float16[] inputData = { new Float16(15360), new Float16(16384), new Float16(16896), new Float16(17408), new Float16(17664) };
long[] shape = { 1, 5 };

var inputNames = new List<string> { "input" };
var inputValues = new List<OrtValue> { OrtValue.CreateTensorValueFromMemory(inputData, shape) };

var outputNames = new List<string> { "output" };
var outputValues = new List<OrtValue> { OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance,
TensorElementType.Float16, shape) };

var model = TestDataLoader.LoadModelFromEmbeddedResource("test_types_FLOAT16.onnx");
using (SessionOptions opt = new SessionOptions())
{
opt.IntraOpNumThreads = 2;
using (var session = new InferenceSession(model, opt))
{
try
{
var task = session.RunAsync(null, inputNames, inputValues, outputNames, outputValues);
var outputs = await task;
var valueOut = outputs.ElementAt<OrtValue>(0);
var float16s = valueOut.GetTensorDataAsSpan<Float16>().ToArray();
Assert.Equal(new Float16(16896), float16s[2]);
}
catch
{
Assert.True(false);
}
}
}
}

[Fact(DisplayName = "TestModelRunAsyncTaskFail")]
private async void TestModelRunAsyncTaskFail()
{
Float16[] inputData = { new Float16(15360), new Float16(16384), new Float16(16896), new Float16(17408), new Float16(17664) };
long[] shape = { 1, 5 };

var inputNames = new List<string> { "input" };
var inputValues = new List<OrtValue> { OrtValue.CreateTensorValueFromMemory(inputData, shape) };

var outputNames = new List<string> { "output" };
var outputValues = new List<OrtValue> { OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance,
TensorElementType.Float16, shape) };

var model = TestDataLoader.LoadModelFromEmbeddedResource("test_types_FLOAT16.onnx");
using (SessionOptions opt = new SessionOptions())
{
opt.IntraOpNumThreads = 1; // this will make RunAsync fail
string err = "";
using (var session = new InferenceSession(model, opt))
{
try
{
var task = session.RunAsync(null, inputNames, inputValues, outputNames, outputValues);
var outputs = await task;
}
catch (Exception ex)
{
err = ex.Message;
}
finally
{
Assert.Contains("intra op thread pool must have at least one thread for RunAsync", err);
}
}
}
}
}

}
