/// <summary>
/// Static trampoline handed to the native RunAsync entry point (through <see cref="ortCallback"/>).
/// The native signature it mirrors is
/// <c>void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status)</c>.
/// </summary>
/// <param name="userData">IntPtr form of the GCHandle that wraps the <see cref="CallbackHost"/> for this run</param>
/// <param name="outputs">
/// raw native pointer to the output array. It is intentionally not marshalled to a managed array:
/// the host already holds the managed output OrtValues and this pointer is never read here.
/// </param>
/// <param name="numOutputs">number of entries behind <paramref name="outputs"/> (native size_t)</param>
/// <param name="status">native status pointer, forwarded untouched to the user callback</param>
private static void OrtCallback(IntPtr userData, IntPtr outputs, UIntPtr numOutputs, IntPtr status)
{
    var hostHdl = GCHandle.FromIntPtr(userData);
    try
    {
        var host = (CallbackHost)hostHdl.Target;
        // NOTE(review): anything thrown by host.callback propagates back into the native caller;
        // the callback built by RunAsync catches all exceptions itself.
        host.callback(host.outputValues, status);
    }
    finally
    {
        // The handle was allocated for exactly one run; release it whatever the callback did.
        hostHdl.Free();
    }
}

/// <summary>
/// Managed shape of the native completion callback. size_t is declared as UIntPtr, matching the
/// other native signatures in this assembly, so the width is right on both 32- and 64-bit targets.
/// </summary>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
private delegate void OrtCallbackDelegate(IntPtr userData, IntPtr outputs, UIntPtr numOutputs, IntPtr status);

// Held in a static field so the delegate instance stays referenced for as long as
// native code may call through the function pointer derived from it.
private static readonly OrtCallbackDelegate ortCallback = new OrtCallbackDelegate(OrtCallback);
/// <summary>
/// Managed continuation invoked when an asynchronous run completes.
/// </summary>
/// <param name="outputs">the output OrtValues that were supplied to the run</param>
/// <param name="status">native status pointer reported by the runtime</param>
private delegate void UserCallbackDelegate(IReadOnlyCollection<OrtValue> outputs, IntPtr status);

/// <summary>
/// Bundles everything one asynchronous run needs: the caller's collections, the user callback,
/// and the native pointer arrays derived from them. An instance is wrapped in a GCHandle and
/// travels through native code as the user_data pointer.
/// </summary>
private class CallbackHost
{
    public IReadOnlyCollection<string> inputNames { get; }
    public IReadOnlyCollection<OrtValue> inputValues { get; }
    public IReadOnlyCollection<string> outputNames { get; }
    public IReadOnlyCollection<OrtValue> outputValues { get; }
    public UserCallbackDelegate callback { get; }

    public IntPtr[] rawInputNames { get; }
    public IntPtr[] rawInputValues { get; }
    public IntPtr[] rawOutputNames { get; }
    public IntPtr[] rawOutputValues { get; }

    /// <summary>
    /// Validates the arguments and builds the native name/handle arrays.
    /// </summary>
    /// <exception cref="ArgumentNullException">any argument is null</exception>
    /// <exception cref="ArgumentException">a names collection and its values collection differ in length</exception>
    public CallbackHost(InferenceSession session,
                        IReadOnlyCollection<string> cbInputNames,
                        IReadOnlyCollection<OrtValue> cbinputValues,
                        IReadOnlyCollection<string> cbOutputNames,
                        IReadOnlyCollection<OrtValue> cbOutputValues,
                        UserCallbackDelegate userCallback)
    {
        if (session == null)
        {
            throw new ArgumentNullException(nameof(session));
        }
        if (cbInputNames == null)
        {
            throw new ArgumentNullException(nameof(inputNames));
        }
        if (cbinputValues == null)
        {
            throw new ArgumentNullException(nameof(inputValues));
        }
        if (cbOutputNames == null)
        {
            throw new ArgumentNullException(nameof(outputNames));
        }
        if (cbOutputValues == null)
        {
            throw new ArgumentNullException(nameof(outputValues));
        }
        if (userCallback == null)
        {
            throw new ArgumentNullException(nameof(callback));
        }

        // Only the name counts are passed to native code as array lengths, so a shorter
        // values collection would be read out of bounds. Reject it up front, as the
        // synchronous Run overloads do.
        if (cbInputNames.Count != cbinputValues.Count)
        {
            throw new ArgumentException($"Length of {nameof(inputNames)} ({cbInputNames.Count}) must match that of {nameof(inputValues)} ({cbinputValues.Count}).");
        }
        if (cbOutputNames.Count != cbOutputValues.Count)
        {
            throw new ArgumentException($"Length of {nameof(outputNames)} ({cbOutputNames.Count}) must match that of {nameof(outputValues)} ({cbOutputValues.Count}).");
        }

        inputNames = cbInputNames;
        inputValues = cbinputValues;
        outputNames = cbOutputNames;
        outputValues = cbOutputValues;
        callback = userCallback;

        rawInputNames = LookupUtf8Names(inputNames, n => n, session.LookupInputMetadata);
        rawInputValues = inputValues.Select(v => v.Handle).ToArray();

        rawOutputNames = LookupUtf8Names(outputNames, n => n, session.LookupOutputMetadata);
        rawOutputValues = outputValues.Select(v => v.Handle).ToArray();
    }
}

/// <summary>
/// Starts an asynchronous run through the native OrtRunAsync entry point.
/// </summary>
/// <param name="options">run options, null selects the defaults</param>
/// <param name="inputNames">names of the inputs</param>
/// <param name="inputValues">input OrtValues, one per name</param>
/// <param name="outputNames">names of the outputs</param>
/// <param name="outputValues">pre-allocated output OrtValues, one per name</param>
/// <param name="callback">invoked from the native completion callback with the outputs and status</param>
private void RunAsyncInternal(RunOptions options,
                              IReadOnlyCollection<string> inputNames,
                              IReadOnlyCollection<OrtValue> inputValues,
                              IReadOnlyCollection<string> outputNames,
                              IReadOnlyCollection<OrtValue> outputValues,
                              UserCallbackDelegate callback)
{
    var host = new CallbackHost(this, inputNames, inputValues, outputNames, outputValues, callback);

    // Keeps the host alive while native code owns the user_data pointer.
    // On success the handle is released by OrtCallback once the run completes.
    var host_hdl = GCHandle.Alloc(host, GCHandleType.Normal);

    // NOTE(review): the IntPtr[] arrays below are only pinned by the marshaller for the duration
    // of the call itself. If the native side keeps these pointers until the callback fires,
    // the arrays need pinned GCHandles of their own - confirm against the native implementation.
    try
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtRunAsync(
                                        _nativeHandle,
                                        options == null ? IntPtr.Zero : options.Handle,
                                        host.rawInputNames,
                                        host.rawInputValues,
                                        (UIntPtr)host.rawInputNames.Length,
                                        host.rawOutputNames,
                                        (UIntPtr)host.rawOutputNames.Length,
                                        host.rawOutputValues,
                                        Marshal.GetFunctionPointerForDelegate(ortCallback),
                                        GCHandle.ToIntPtr(host_hdl)
                                        ));
    }
    catch
    {
        // The run was not started, so the callback will never release the handle.
        // Free it for every failure type, not only OnnxRuntimeException.
        host_hdl.Free();
        throw;
    }
}
/// <summary>
/// Run inference asynchronously on a thread of the intra-op thread pool.
/// </summary>
/// <param name="options">run options, can be null</param>
/// <param name="inputNames">names of the inputs</param>
/// <param name="inputValues">input OrtValues</param>
/// <param name="outputNames">names of the outputs</param>
/// <param name="outputValues">pre-allocated output OrtValues</param>
/// <returns>
/// a task to be awaited; on success it yields the same collection that was passed as
/// <paramref name="outputValues"/>, on failure it carries the exception raised for the run
/// </returns>
public async Task<IReadOnlyCollection<OrtValue>> RunAsync(RunOptions options,
                                                           IReadOnlyCollection<string> inputNames,
                                                           IReadOnlyCollection<OrtValue> inputValues,
                                                           IReadOnlyCollection<string> outputNames,
                                                           IReadOnlyCollection<OrtValue> outputValues)
{
    // The completion callback is invoked from native code. RunContinuationsAsynchronously keeps
    // the awaiting caller's continuation from executing inline inside that callback.
    var promise = new TaskCompletionSource<IReadOnlyCollection<OrtValue>>(
        TaskCreationOptions.RunContinuationsAsynchronously);

    RunAsyncInternal(options,
                     inputNames,
                     inputValues,
                     outputNames,
                     outputValues,
                     (IReadOnlyCollection<OrtValue> outputs, IntPtr status) =>
                     {
                         try
                         {
                             NativeApiStatus.VerifySuccess(status);
                             // Try* variants: completing the task must never throw back
                             // into the native caller.
                             promise.TrySetResult(outputs);
                         }
                         catch (Exception ex)
                         {
                             promise.TrySetException(ex);
                         }
                     });

    return await promise.Task.ConfigureAwait(false);
}
/// <summary>
/// Managed signature of the native OrtApi entry CreateAndRegisterAllocatorV2. The static
/// constructor binds it from <c>api_.CreateAndRegisterAllocatorV2</c>.
/// </summary>
/// <returns>native status pointer, to be checked by the caller</returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DCreateAndRegisterAllocatorV2(
                                        IntPtr /* (OrtEnv*) */ environment,
                                        IntPtr /*(char*)*/ provider_type,
                                        IntPtr /*(OrtMemoryInfo*)*/ mem_info,
                                        IntPtr /*(OrtArenaCfg*)*/ arena_cfg,
                                        IntPtr /*(char**)*/ provider_options_keys,
                                        IntPtr /*(char**)*/ provider_options_values,
                                        UIntPtr /*(size_t)*/num_keys);
public static DCreateAndRegisterAllocatorV2 OrtCreateAndRegisterAllocatorV2;

/// <summary>
/// Managed signature of the native OrtApi entry RunAsync. The static constructor binds it
/// from <c>api_.RunAsync</c>. Completion is reported through <c>callback</c>, which receives
/// <c>user_data</c> back unchanged.
/// </summary>
/// <returns>native status pointer, to be checked by the caller</returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtRunAsync(
                                        IntPtr /*(OrtSession*)*/ session,
                                        IntPtr /*(OrtRunOptions*)*/ runOptions,  // can be null to use the default options
                                        IntPtr[] /*(char**)*/ inputNames,
                                        IntPtr[] /*(OrtValue*[])*/ inputValues,
                                        UIntPtr /*(size_t)*/ inputCount,   // length of both inputNames and inputValues
                                        IntPtr[] /*(char**)*/ outputNames,
                                        UIntPtr /*(size_t)*/ outputCount,  // length of both outputNames and outputValues
                                        IntPtr[] /*(OrtValue*[])*/ outputValues,
                                        IntPtr /*(void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status))*/ callback, // callback function
                                        IntPtr /*(void*)*/ user_data
                                        );
public static DOrtRunAsync OrtRunAsync;
/// <summary>
/// Runs the FLOAT16 test model through RunAsync with two intra-op threads and checks one
/// element of the output.
/// </summary>
[Fact(DisplayName = "TestModelRunAsyncTask")]
private async Task TestModelRunAsyncTask()
{
    Float16[] inputData = { new Float16(15360), new Float16(16384), new Float16(16896), new Float16(17408), new Float16(17664) };
    long[] shape = { 1, 5 };

    var inputNames = new List<string> { "input" };
    var outputNames = new List<string> { "output" };
    var model = TestDataLoader.LoadModelFromEmbeddedResource("test_types_FLOAT16.onnx");

    // Dispose the OrtValues deterministically, as the pre-allocated-output test in this file does.
    using (var inputValue = OrtValue.CreateTensorValueFromMemory(inputData, shape))
    using (var outputValue = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance,
                                                                 TensorElementType.Float16, shape))
    using (SessionOptions opt = new SessionOptions())
    {
        opt.IntraOpNumThreads = 2;
        using (var session = new InferenceSession(model, opt))
        {
            var inputValues = new List<OrtValue> { inputValue };
            var outputValues = new List<OrtValue> { outputValue };

            // No try/catch: an unexpected exception should fail the test with its own
            // message and stack trace.
            var outputs = await session.RunAsync(null, inputNames, inputValues, outputNames, outputValues);

            // The span is consumed within a single expression because it cannot be held
            // in a local of an async method.
            var float16s = outputs.ElementAt(0).GetTensorDataAsSpan<Float16>().ToArray();
            Assert.Equal(new Float16(16896), float16s[2]);
        }
    }
}
/// <summary>
/// With a single intra-op thread RunAsync is expected to fail; verifies that the failure
/// surfaces through the awaited task and carries the expected message.
/// </summary>
[Fact(DisplayName = "TestModelRunAsyncTaskFail")]
private async Task TestModelRunAsyncTaskFail()
{
    Float16[] inputData = { new Float16(15360), new Float16(16384), new Float16(16896), new Float16(17408), new Float16(17664) };
    long[] shape = { 1, 5 };

    var inputNames = new List<string> { "input" };
    var outputNames = new List<string> { "output" };
    var model = TestDataLoader.LoadModelFromEmbeddedResource("test_types_FLOAT16.onnx");

    using (var inputValue = OrtValue.CreateTensorValueFromMemory(inputData, shape))
    using (var outputValue = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance,
                                                                 TensorElementType.Float16, shape))
    using (SessionOptions opt = new SessionOptions())
    {
        opt.IntraOpNumThreads = 1; // this will make RunAsync fail
        using (var session = new InferenceSession(model, opt))
        {
            var inputValues = new List<OrtValue> { inputValue };
            var outputValues = new List<OrtValue> { outputValue };

            // Fails the test if nothing is thrown; otherwise hands back the exception
            // so its message can be inspected.
            var ex = await Assert.ThrowsAnyAsync<Exception>(
                () => session.RunAsync(null, inputNames, inputValues, outputNames, outputValues));

            Assert.Contains("intra op thread pool must have at least one thread for RunAsync", ex.Message);
        }
    }
}