From d63c664ca0021fbac31cee57ff1eaa8bce3d1903 Mon Sep 17 00:00:00 2001 From: rui-ren Date: Thu, 15 Feb 2024 00:02:08 -0800 Subject: [PATCH 001/237] fix rocm ci pipeline (#19525) ### Description The ROCm CI pipeline fails because the pinned `datasets==1.9.0` package can no longer download the `wikitext-2-raw-v1` dataset: ``` Downloading and preparing dataset wikitext/wikitext-2-raw-v1 (download: 4.50 MiB, generated: 12.91 MiB, post-processed: Unknown size, total: 17.41 MiB) to /home/onnxruntimedev/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20... main() File "/stage/huggingface-transformers/examples/pytorch/language-modeling/run_mlm.py", line 242, in main datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir) File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/load.py", line 856, in load_dataset builder_instance.download_and_prepare( File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/builder.py", line 583, in download_and_prepare self._download_and_prepare( File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/builder.py", line 639, in _download_and_prepare split_generators = self._split_generators(dl_manager, **split_generators_kwargs) File "/home/onnxruntimedev/.cache/huggingface/modules/datasets_modules/datasets/wikitext/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/wikitext.py", line 138, in _split_generators data_file = dl_manager.download_and_extract(self.config.data_url) File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/utils/download_manager.py", line 289, in download_and_extract return self.extract(self.download(url_or_urls)) File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/utils/download_manager.py", line 197, in download downloaded_path_or_paths = map_nested( File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/utils/py_utils.py", line 195, in map_nested return function(data_struct) File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/utils/download_manager.py", line 220, in _download return cached_path(url_or_filename, download_config=download_config) File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/utils/file_utils.py", line 281, in cached_path output_path = get_from_cache( File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/utils/file_utils.py", line 634, in get_from_cache raise ConnectionError("Couldn't reach {}".format(url)) ConnectionError: Couldn't reach https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip ``` ### Motivation and Context Update the `datasets` package to the latest version, `2.17.0`.
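A quick way to sanity-check the dependency bump is to load the same dataset configuration that fails in the traceback above from inside the updated CI image. This is only an illustrative sketch, assuming the `wikitext`/`wikitext-2-raw-v1` builder and config names shown in the error log and a working network connection:

```python
# Sketch: verify that the upgraded datasets package (2.17.0) can fetch the
# wikitext config that previously failed with a ConnectionError.
from datasets import load_dataset

dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
print(dataset)  # expect train/validation/test splits rather than a download error
```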
--- tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 64710a982a29d..496b57b417fbd 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -112,7 +112,7 @@ RUN pip install \ cerberus \ sympy \ h5py \ - datasets==1.9.0 \ + datasets==2.17.0 \ requests \ sacrebleu==1.5.1 \ sacremoses \ From d0061d6fb15d40eeb35fa1b40a414cd231d51db9 Mon Sep 17 00:00:00 2001 From: sophies927 <107952697+sophies927@users.noreply.github.com> Date: Thu, 15 Feb 2024 17:03:11 -0800 Subject: [PATCH 002/237] Update stale.yml to use old version as a bug fix (#19532) ### Description Changed the actions/stale version back to v8 from v9. ### Motivation and Context There is a well-documented issue with the new actions/stale version (v9.0.0) that causes the following error: "Error delete _state: [403] Resource not accessible by integration". See https://github.com/actions/stale/issues/1133 for more context. This issue has prevented the stale bot from labeling stale issues since the version was updated, because the action can no longer access the cache and cannot apply labels to all issues due to GitHub API rate limiting. There are two potential fixes if we continue to use the new version: (1) run the action on all PRs/issues to avoid using the cache or (2) give write access to the endpoints listed in https://docs.github.com/en/rest/authentication/permissions-required-for-fine-grained-personal-access-tokens?apiVersion=2022-11-28#repository-permissions-for-actions. Neither of these options is preferable, so I am going to wait until the bug is fixed. Note: The old version (v8.0.0) uses Node 16, which will be deprecated in Spring 2024, instead of Node 20, so we should keep an eye on [this issue](https://github.com/actions/stale/issues/1133) and switch back to the new version once the fix is made. --- .github/workflows/stale.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index c94e3fa5bcb8c..181f3fb17d332 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v9.0.0 + - uses: actions/stale@v8 with: # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: contributions welcome, feature request, regression From 4bfa69def85476b33ccfaf68cf070f3fb65d39f7 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 15 Feb 2024 20:22:36 -0800 Subject: [PATCH 003/237] Speed Up DecoderMaskedSelfAttentionTest (#19531) ### Description The unit tests take 19 minutes to run (in a debug build) because of too many combinations. I reduced the number of combinations while retaining good test coverage. After the change, the tests finish in 51 seconds.
Before: [----------] 2 tests from DecoderMaskedSelfAttentionTest [ RUN ] DecoderMaskedSelfAttentionTest.Test_fp32 [ OK ] DecoderMaskedSelfAttentionTest.Test_fp32 (394086 ms) [ RUN ] DecoderMaskedSelfAttentionTest.Test_fp16 [ OK ] DecoderMaskedSelfAttentionTest.Test_fp16 (747035 ms) [----------] 2 tests from DecoderMaskedSelfAttentionTest (1141122 ms total) After: [----------] 2 tests from DecoderMaskedSelfAttentionTest [ RUN ] DecoderMaskedSelfAttentionTest.Test_fp32 [ OK ] DecoderMaskedSelfAttentionTest.Test_fp32 (21057 ms) [ RUN ] DecoderMaskedSelfAttentionTest.Test_fp16 [ OK ] DecoderMaskedSelfAttentionTest.Test_fp16 (30653 ms) [----------] 2 tests from DecoderMaskedSelfAttentionTest (51710 ms total) ### Motivation and Context Reduce test time, and improve build pipeline efficiency. --- ...oder_masked_multihead_attention_op_test.cc | 451 ++++++++++-------- 1 file changed, 242 insertions(+), 209 deletions(-) diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc index 6afb61bd1f0a1..8ea37ad054ed0 100644 --- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc @@ -640,122 +640,139 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { return; } - // Vary batch size - for (int batch_size = 1; batch_size <= 5; batch_size += 2) { - // Vary kv_lengths - for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) { - int sequence_length = 1; - int number_of_heads = 12; - // Vary head_size / hidden_size - int hidden_sizes[3] = {384, 768, 1536}; - for (int hidden_size : hidden_sizes) { - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; - - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); - - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); - - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); - - // Mask - tester.AddOptionalInputEdge(); - - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; - - auto kv_cache = CreateRandom(past_present_size); - - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); - - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(float); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, 
number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); - - tester.AddInput("past", past_dims, reordered_kv_cache); - - // Rel - tester.AddOptionalInputEdge(); - - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); - - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); - - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); - - auto k_merged = pair.first; - auto k_transpose = pair.second; - - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); - - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); - - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); - - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + // Buckets for test data: + // batch_size: 1, >=2 + // past_sequence_length 0~30, 31~2046, >=2047 (so that total_sequence_length: 1~31, 32~2047, >=2048) + // head_size: 32, 64, 128 + struct MyTestCase { + int batch_size; + int past_sequence_length; + int hidden_size; + } test_cases[] = { + {1, 0, 768}, + {1, 1, 384}, + {2, 30, 768}, + {3, 31, 1536}, + {4, 512, 384}, + {1, 1024, 768}, + {1, 2046, 1536}, + {2, 2047, 384}, + {3, 3000, 768}, + }; + + constexpr int sequence_length = 1; + constexpr int number_of_heads = 12; + + for (MyTestCase test_case : test_cases) { + int batch_size = test_case.batch_size; + int past_sequence_length = test_case.past_sequence_length; + int hidden_size = test_case.hidden_size; + + int head_size = (hidden_size / number_of_heads); + int total_sequence_length = sequence_length + past_sequence_length; + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + + OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(1)); + + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector weights_dims = {hidden_size, 3 * hidden_size}; + std::vector bias_dims = {3 * hidden_size}; + std::vector output_dims = {batch_size, sequence_length, hidden_size}; + + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); + + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); + + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); + + // Mask + tester.AddOptionalInputEdge(); + + // Past + std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + + auto kv_cache = 
CreateRandom(past_present_size); + + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + number_of_heads, past_sequence_length, head_size, max_sequence_length); + + // Validate if reordering went well - by transposing and checking equality + int chunk_size = 16 / sizeof(float); + int num_chunks = head_size / chunk_size; + auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); + + tester.AddInput("past", past_dims, reordered_kv_cache); + + // Rel + tester.AddOptionalInputEdge(); + + // Past sequence length + std::vector arr_past_sequence_len(1, past_sequence_length); + tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + + // QKV MatMul + auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); + auto* qkv_matrix = qkv.data(); + + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, + max_sequence_length, head_size); + + auto k_merged = pair.first; + auto k_transpose = pair.second; + + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, + total_sequence_length, head_size); + + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, + sequence_length, total_sequence_length, head_size); + + auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); + + // Validate our test logic + // We want to validate if our merged "unordered" K is the same as + // the merged "ordered" K so that the QKT we do in our test code + // is equivalent to the QKT we do in the kernel + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); + + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, number_of_heads, + sequence_length, total_sequence_length, + max_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); - - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); - - // Output(s) - tester.AddOutput("output", input_dims, output); + // Output(s) + tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); + tester.AddOutput("present", past_dims, present); - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } + // Run - Regular kernel execution path + { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + 
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; + // Test alternate kernel path of loading more KV data "in flight" + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - } + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } } @@ -766,122 +783,138 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { return; } - // Vary batch size - for (int batch_size = 1; batch_size <= 5; batch_size += 2) { - // Vary kv_lengths - for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) { - int sequence_length = 1; - int number_of_heads = 12; - - // Vary head_size / hidden_size - int hidden_sizes[3] = {384, 768, 1536}; - for (int hidden_size : hidden_sizes) { - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; - - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); - - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); - - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); - - // Mask - tester.AddOptionalInputEdge(); - - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; - - auto kv_cache = CreateRandom(past_present_size); - - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); + // Buckets for test data: + // batch_size: 1, >=2 + // past_sequence_length 0, 1~30, 31~2046, >=2047 (so that total_sequence_length: 1, 2-31, 32~2047, >=2048) + // head_size: 32, 64, 128 + struct MyTestCase { + int batch_size; + int past_sequence_length; + int hidden_size; + } test_cases[] = { + {1, 0, 768}, + {1, 1, 768}, + {3, 30, 384}, + {8, 31, 1536}, + {4, 256, 384}, + {3, 1024, 768}, + {2, 2046, 1536}, + {1, 2047, 384}, + {2, 3000, 768}, + }; + + constexpr int sequence_length = 1; + constexpr int number_of_heads = 12; + + for (MyTestCase test_case : test_cases) { + int batch_size = test_case.batch_size; + int past_sequence_length = test_case.past_sequence_length; + int hidden_size = test_case.hidden_size; + + 
int head_size = (hidden_size / number_of_heads); + int total_sequence_length = sequence_length + past_sequence_length; + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + + OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(1)); + + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector weights_dims = {hidden_size, 3 * hidden_size}; + std::vector bias_dims = {3 * hidden_size}; + std::vector output_dims = {batch_size, sequence_length, hidden_size}; + + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); + + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); + + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); + + // Mask + tester.AddOptionalInputEdge(); + + // Past + std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + + auto kv_cache = CreateRandom(past_present_size); + + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + number_of_heads, past_sequence_length, head_size, max_sequence_length); - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(MLFloat16); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); + // Validate if reordering went well - by transposing and checking equality + int chunk_size = 16 / sizeof(MLFloat16); + int num_chunks = head_size / chunk_size; + auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); - tester.AddInput("past", past_dims, reordered_kv_cache); + tester.AddInput("past", past_dims, reordered_kv_cache); - // Rel - tester.AddOptionalInputEdge(); + // Rel + tester.AddOptionalInputEdge(); - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + // Past sequence length + std::vector arr_past_sequence_len(1, past_sequence_length); + tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); + // QKV MatMul + auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); + auto* qkv_matrix = qkv.data(); - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, + max_sequence_length, head_size); - auto k_merged = pair.first; - auto k_transpose = pair.second; + auto k_merged = 
pair.first; + auto k_transpose = pair.second; - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, + total_sequence_length, head_size); - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, + sequence_length, total_sequence_length, head_size); - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + // Validate our test logic + // We want to validate if our merged "unordered" K is the same as + // the merged "ordered" K so that the QKT we do in our test code + // is equivalent to the QKT we do in the kernel + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, number_of_heads, + sequence_length, total_sequence_length, + max_sequence_length, head_size); - // Output(s) - tester.AddOutput("output", input_dims, output); + // Output(s) + tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); + tester.AddOutput("present", past_dims, present); - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } + // Run - Regular kernel execution path + { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; + // Test alternate kernel path of loading more KV data "in flight" + { + ScopedEnvironmentVariables 
scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - } + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } } @@ -889,4 +922,4 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { #endif } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime From ef0b71308c0e2395d3ea63e627515ff8e624ad45 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Fri, 16 Feb 2024 05:34:55 -0800 Subject: [PATCH 004/237] Optimize KahnsTopologicalSort and PriorityNodeCompare (#19475) **Description** 1) During SessionInitialization, KahnsTopologicalSort is a major cause of perf degradation. The main cause of the slowdown is that the topological sort needs to keep track of the nodes to visit in order and reorder them based on priority (as informed by a comparator). The existing implementation uses a priority_queue backed by a std::vector container, but vectors are poorly suited to insertion and reordering. The appropriate data structure for this operation is a linked list. However, linked lists like std::list cannot serve as the container for std::priority_queue, because std::priority_queue requires random access, which linked lists do not provide. For this simple use case, we can instead keep a std::list under the hood and perform insertions manually using std::upper_bound (sketched below). This drastically reduces the time taken by the method; the vector-based implementation causes numerous re-copies and a lot of element movement inside the list of graph nodes to visit. 2) In the comparator, I hide the forward and backward attribute checking behind the #ifdef ENABLE_TRAINING macro, as I believe it is only relevant in the training scenario. 3) In the NoopElimination transformer, I avoid creating an Initializer (which unpacks TensorProto data) for every node and only create initializers when Add/Sub/Mul/Div op nodes are detected. **Motivation and Context** Session creation time for many models is quite slow.
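To make the data-structure argument in point 1) concrete, here is an illustrative sketch in Python (not the C++ change in this patch): the list of ready-to-visit nodes is kept sorted by the comparator, so each push is a binary search plus a single insertion rather than re-heapifying a vector-backed priority queue. The node ids, edge map, and priority function are hypothetical.

```python
# Illustrative sketch only (not ORT code): Kahn's topological sort where the
# ready list stays sorted by priority via binary-search insertion, mirroring
# the std::list + std::upper_bound approach described above.
import bisect

def kahns_sort(nodes, edges, priority):
    """nodes: iterable of orderable node ids; edges: dict node -> successors;
    priority(node): lower values are visited first (hypothetical ordering)."""
    in_degree = {n: 0 for n in nodes}
    for successors in edges.values():
        for dst in successors:
            in_degree[dst] += 1

    # Seed with all zero-in-degree nodes, kept sorted as (priority, node) pairs.
    to_visit = sorted((priority(n), n) for n in nodes if in_degree[n] == 0)
    topo_order = []
    while to_visit:
        _, current = to_visit.pop(0)  # lowest-priority-value node first
        topo_order.append(current)
        for nxt in edges.get(current, ()):
            in_degree[nxt] -= 1
            if in_degree[nxt] == 0:
                # O(log n) search + one insertion instead of re-heapifying.
                bisect.insort(to_visit, (priority(nxt), nxt))

    if len(topo_order) != len(in_degree):
        raise ValueError("graph has a cycle")
    return topo_order
```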
--------- Co-authored-by: Sheil Kumar --- onnxruntime/core/graph/graph.cc | 37 ++++++++-- onnxruntime/core/graph/graph_viewer.cc | 18 +++-- .../core/optimizer/noop_elimination.cc | 73 +++++++++++-------- .../ort_optimizer_api_impl.cc | 2 +- 4 files changed, 85 insertions(+), 45 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 902839bee04ba..305122c56b865 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,16 +1818,36 @@ void Graph::ReverseDFSFrom(gsl::span from, } } +template +struct VisitorPriorityQueue { + using ComparatorType = std::function; + std::list list_; + const ComparatorType comparator_ = nullptr; + VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} + + void push(T node) { + list_.insert( + std::upper_bound(list_.begin(), list_.end(), node, comparator_), + node); + } + bool empty() { return list_.empty(); } + T top() { return list_.back(); } + void pop() { list_.pop_back(); } +}; + #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { - std::unordered_map in_degree; - std::priority_queue, decltype(comp)> to_visit(comp); - std::vector topo_order; + InlinedVector in_degree(MaxNodeIndex(), 0); + InlinedVector topo_order; + VisitorPriorityQueue to_visit(comp); + + auto number_of_nodes = NumberOfNodes(); + topo_order.reserve(number_of_nodes); for (auto& node : Nodes()) { size_t input_edge_count = node.GetInputEdgesCount(); - in_degree.insert({node.Index(), input_edge_count}); + in_degree[node.Index()] = input_edge_count; if (input_edge_count == 0) { to_visit.push(&node); } @@ -1844,16 +1864,17 @@ void Graph::KahnsTopologicalSort(const std::function& enter, } for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) { - in_degree[node_it->Index()]--; + auto& node_in_degree = in_degree[node_it->Index()]; + node_in_degree--; - if (in_degree[node_it->Index()] == 0) { + if (node_in_degree == 0) { to_visit.push(&*node_it); } } topo_order.push_back(current->Index()); } - if (NumberOfNodes() != static_cast(topo_order.size())) { + if (number_of_nodes != static_cast(topo_order.size())) { ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } } @@ -2843,7 +2864,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { const gsl::not_null tensor_added{graph_proto_->add_initializer()}; *(tensor_added) = tensor; - name_to_initial_tensor_[tensor.name()] = tensor_added; + name_to_initial_tensor_.emplace(tensor.name(), tensor_added); SetGraphResolveNeeded(); if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) { // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. 
diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index acf7b3a16541f..119d420066a84 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const { struct PriorityNodeCompare { inline bool IsHighPri(const Node* n) const { // local statics so we can compare std::strings in the checks - static const std::string shape_op("Shape"); - static const std::string size_op("Size"); + static constexpr std::string_view shape_op("Shape"); + static constexpr std::string_view size_op("Size"); const auto& op_type = n->OpType(); return op_type == shape_op || op_type == size_op; @@ -26,15 +26,20 @@ struct PriorityNodeCompare { // If return true, n2 will be output first bool operator()(const Node* n1, const Node* n2) const { // nodes in global high priority list will be output first - if (IsHighPri(n1) != IsHighPri(n2)) { - return IsHighPri(n2); + const bool isN1HighPri = IsHighPri(n1); + const bool isN2HighPri = IsHighPri(n2); + if (isN1HighPri != isN2HighPri) { + return isN2HighPri; } // nodes with lower priority value will be output first - if (n1->Priority() != n2->Priority()) { - return n1->Priority() > n2->Priority(); + const auto n1_priority = n1->Priority(); + const auto n2_priority = n2->Priority(); + if (n1_priority != n2_priority) { + return n1_priority > n2_priority; } +#ifdef ENABLE_TRAINING // nodes of forward pass will be output first auto n1_attrs = n1->GetAttributes(); auto n2_attrs = n2->GetAttributes(); @@ -45,6 +50,7 @@ struct PriorityNodeCompare { if (n1_is_forward != n2_is_forward) { return n2_is_forward > n1_is_forward; } +#endif // otherwise, nodes with lower index will be output first return n1->Index() > n2->Index(); diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index b3c2991d54b28..bba39b698a27a 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -42,49 +42,62 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con // if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule, // but it won't happen if the case is accepted, thus reject it - auto initializer_rank = initializer->dims().size(); + const auto& dims = initializer->dims(); + auto initializer_rank = dims.size(); const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 
1 : 0]->Shape(); if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) { return false; } - int32_t data_type = initializer->data_type(); - Initializer add_init(*initializer, graph.ModelPath()); - if (add_init.size() > 1) { + int64_t tensor_size = 1; + for (auto i : dims) { + tensor_size *= i; + } + + if (tensor_size > 1) { return false; } + // handle edge case where the total size of the initializer is 0 - if (add_init.size() == 0) { + if (tensor_size == 0) { return true; } - float value = 0.0f; - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - value = *add_init.data(); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - value = math::halfToFloat(add_init.data()->val); - break; - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - value = static_cast(*add_init.data()); - break; - default: + if (op_type == "Add" || + op_type == "Sub" || + op_type == "Mul" || + op_type == "Div") { + int32_t data_type = initializer->data_type(); + Initializer add_init(*initializer, graph.ModelPath()); + + float value = 0.0f; + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + value = *add_init.data(); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + value = math::halfToFloat(add_init.data()->val); + break; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + value = static_cast(*add_init.data()); + break; + default: + return false; + } + + if (value != 0.0f && (op_type == "Add" || op_type == "Sub")) { return false; - } + } - if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) { - return false; - } - - if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) { - return false; + if (value != 1.0f && (op_type == "Mul" || op_type == "Div")) { + return false; + } } // reject node output is graph output for now diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index d9f08ffe1171e..c532f56b3d3d9 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -115,7 +115,7 @@ class ApiGraph final : public api::GraphRef { const auto& graph_outputs = graph_.GetOutputs(); graph_outputs_.reserve(graph_outputs.size()); for (const auto* output : graph_outputs) { - graph_outputs_.insert(output->Name()); + graph_outputs_.emplace(output->Name()); } } From b84712151c06f0f59359916be572f71bd36721a4 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 16 Feb 2024 14:36:05 -0800 Subject: [PATCH 005/237] QNN EP: Fuse DQ -> Q sequences into a QNN Convert op (#19511) ### Description Fuses DQ -> Q sequences into a QNN Convert operator if: - Converting from one qtype to another. Ex: Dequantize(uint8 to float) -> Quantize(float to uint16) - The DQ and Q operators are not part of another node unit (i.e., standalone) - The Q operator is the only consumer for the DQ operator. 
### Motivation and Context Allows faster execution of QDQ models with mixed activation types by leveraging the QNN Convert operator, which converts between quantization types. For certain models, this results in inference latency speed-ups of up to 2x (depends on the number of DQ -> Q sequences). #### Example for Add node unit with 16-bit I/O: Original: ``` u8 ----> DQ ---> Q ---u16--> Add ---u16--> ^ | u16 --------------------------+ ``` After fusing DQ -> Q: ``` u8 ----> Convert ---u16--> Add ---u16--> ^ | u16 ------------------------+ ``` --- .../optimizer/qdq_transformer/qdq_util.cc | 43 ++++++++ .../core/optimizer/qdq_transformer/qdq_util.h | 12 ++ .../qnn/builder/op_builder_factory.h | 23 ++++ .../builder/opbuilder/convert_op_builder.cc | 103 ++++++++++++++++++ .../core/providers/qnn/builder/qnn_model.cc | 35 +++++- .../providers/qnn/qnn_execution_provider.cc | 88 +++++++++------ .../providers/qnn/qnn_execution_provider.h | 1 - .../test/providers/qnn/simple_op_htp_test.cc | 55 ++++++++++ 8 files changed, 319 insertions(+), 41 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index b1ab641a23256..4e3dff705bd41 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -76,6 +76,49 @@ bool IsQDQPairSupported( } } +bool IsDQQConversion( + const Node& dq_node, const Node& q_node, + const GetConstantInitializerFn& get_const_initializer, + const Path& model_path) { + ConstPointerContainer> dq_input_defs = dq_node.InputDefs(); + ConstPointerContainer> q_input_defs = q_node.InputDefs(); + + // Q/DQ contains optional input is not supported + // non-scalar Q/DQ scale and zero point needs are not supported + if (dq_input_defs.size() != InputIndex::TOTAL_COUNT || + q_input_defs.size() != InputIndex::TOTAL_COUNT || + !optimizer_utils::IsScalar(*q_input_defs[InputIndex::SCALE_ID]) || + !optimizer_utils::IsScalar(*q_input_defs[InputIndex::ZERO_POINT_ID]) || + !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::SCALE_ID]) || + !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::ZERO_POINT_ID])) { + return false; + } + + // if Q/DQ scale and zero point are not constant, return false + const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = + get_const_initializer(dq_input_defs[InputIndex::SCALE_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = + get_const_initializer(q_input_defs[InputIndex::SCALE_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = + get_const_initializer(dq_input_defs[InputIndex::ZERO_POINT_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = + get_const_initializer(q_input_defs[InputIndex::ZERO_POINT_ID]->Name()); + if (nullptr == q_zp_tensor_proto || + nullptr == dq_zp_tensor_proto || + nullptr == q_scale_tensor_proto || + nullptr == dq_scale_tensor_proto) { + return false; + } + + // check Q/DQ have same scale type and different zero point type + Initializer q_zp(*q_zp_tensor_proto, model_path); + Initializer q_scale(*q_scale_tensor_proto, model_path); + Initializer dq_zp(*dq_zp_tensor_proto, model_path); + Initializer dq_scale(*dq_scale_tensor_proto, model_path); + + return (dq_zp.data_type() != q_zp.data_type()) && (dq_scale.data_type() == q_scale.data_type()); +} + bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) { 
bool zero_point_exists = false; if (!QOrDQNodeHasConstantScalarScaleAndZeroPoint(dq_node, get_const_initializer, zero_point_exists)) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index bb0bf9438cfcb..8333168b0093f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -38,6 +38,18 @@ bool IsQDQPairSupported( const GetConstantInitializerFn& get_const_initializer, const Path& model_path); +// Check if a DQ -> Q sequence represents a conversion in quantization data type. +// Example of uint8 to uint16: +// Dequantize (uint8 to float) -> Quantize (float to uint16) +// Requires: +// 1. Q/DQ doesn't have optional input. +// 2. scale and zero-point are constant scalars. +// 3. Q and DQ have the same scale *type* and different zero-point *types*. +bool IsDQQConversion( + const Node& dq_node, const Node& q_node, + const GetConstantInitializerFn& get_const_initializer, + const Path& model_path); + // Check if DQ is supported in extended level QDQ transformers. It requires: // 1. DQ doesn't have optional input. // 2. scale and zero point is constant scalar diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index d95e2baa9457f..4a9106f0c06af 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -94,5 +94,28 @@ void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +struct HandleConvertResult { + Status status; // Indicates an unexpected error. Check if q_node_unit != nullptr to determine + // whether a DQ -> Q sequence was successfully merged into a Convert. + const NodeUnit* q_node_unit; // Non-null if successfully merged DQ -> Q sequence. + // Set to nullptr if this node unit could not be merged into a Convert. +}; + +/** + * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from + * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). + * + * \param qnn_model_wrapper The QNN model that is being built. + * \param maybe_dq_node_unit The node unit that could potentially start the DQ -> Q sequence. + * \param logger The logger. + * \param do_op_validation True if should call QNN operator validation APIs. + * \return An qnn::HandleConvertResult object that indicates success/failure and provides a pointer + * to the Q node unit that was successfully merged with the provided DQ node unit. + */ +HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& maybe_dq_node_unit, + const std::unordered_map& node_unit_map, + const logging::Logger& logger, + bool do_op_validation); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc new file mode 100644 index 0000000000000..977a9e0b3d9d0 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/graph/graph_utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/common/safeint.h" +#include "onnx/defs/data_type_utils.h" + +#include "QnnOpDef.h" // From QNN SDK: contains QNN constants (e.g., op names, param values). + +namespace onnxruntime { +namespace qnn { + +class ConvertOpBuilder : public BaseOpBuilder { + public: + ConvertOpBuilder() : BaseOpBuilder("ConvertOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConvertOpBuilder); + + Status AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger, + bool do_op_validation) const ORT_MUST_USE_RESULT; +}; + +Status ConvertOpBuilder::AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger, + bool do_op_validation) const { + std::vector input_names; + + // Process the input from the DQ node + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, dq_node_unit.Inputs()[0], logger, input_names)); + + // Process the output from the Q node. Override the QNN operator type to "Convert". + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, q_node_unit, std::move(input_names), {}, + logger, do_op_validation, QNN_OP_CONVERT)); + return Status::OK(); +} + +HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& maybe_dq_node_unit, + const std::unordered_map& node_unit_map, + const logging::Logger& logger, + bool do_op_validation) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + // Looking for a standalone DQ to start the sequence. + if (maybe_dq_node_unit.OpType() != QDQ::DQOpName || maybe_dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return {}; + } + + const Node& dq_node = maybe_dq_node_unit.GetNode(); + + // DQ must have a single Q child. DQ must not produce a graph output. + auto children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName); + if (children.size() != 1 || dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { + return {}; + } + + const Node& q_node = *children[0]; + const auto q_node_unit_it = node_unit_map.find(&q_node); + + if (q_node_unit_it == node_unit_map.end()) { + return {ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node does not have a corresponding NodeUnit"), nullptr}; + } + + const NodeUnit* q_node_unit = q_node_unit_it->second; + + // Q child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return {}; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + // DQ and Q must have equal scale type and different zp type. + if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { + return {}; + } + + ConvertOpBuilder op_builder; + + LOGS(logger, VERBOSE) << " Adding QNN Convert. 
dq_node name: [" << dq_node.Name() + << "] dq_node optype: [" << dq_node.OpType() + << "] q_node name: [" << q_node_unit->Name() + << "] q_node optype: [" << q_node_unit->OpType() + << "]"; + + auto status = op_builder.AddConvertToModelBuilder(qnn_model_wrapper, maybe_dq_node_unit, *q_node_unit, logger, + do_op_validation); + return status.IsOK() ? HandleConvertResult{status, q_node_unit} : HandleConvertResult{status, nullptr}; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 314cab4a36ca9..dc91b9dfa199e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -114,6 +114,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper."); } + std::unordered_set handled_node_units; + // Op builer const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); for (size_t i = 0; i < node_indices.size(); i++) { @@ -122,20 +124,43 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // Check whether it's part of NodeUnit const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map); // Q, DQ nodes in the node unit only carry the quantization parameters - // Add the QNN node when it is the target node (It's a normal node or a singel Q/DQ node) + // Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node) const std::string& op_type = node_unit.OpType(); + + if (node != &node_unit.GetNode()) { + continue; + } + + if (handled_node_units.count(&node_unit) != 0) { + continue; // Already handled. + } + + // Try to convert particular DQ -> Q sequences into QNN Convert op + auto convert_result = TryHandleConvertSequence(qnn_model_wrapper, + node_unit, + node_unit_map, + logger_, + false /*do_op_validation*/); + ORT_RETURN_IF_ERROR(convert_result.status); + + if (convert_result.q_node_unit) { + // Successfully merged DQ -> Q sequence into a QNN Convert op. + // Mark both of these node units as handled. 
+ handled_node_units.insert(&node_unit); + handled_node_units.insert(convert_result.q_node_unit); + continue; + } + LOGS(logger_, VERBOSE) << " node name: [" << node->Name() << "] node optype: [" << op_type << "] as part of the NodeUnit type: [" << node_unit.OpType() << "] name: [" << node_unit.Name() << "]"; - if (node != &node_unit.GetNode()) { - continue; - } - if (const auto* op_builder = GetOpBuilder(op_type)) { ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_)); } + + handled_node_units.insert(&node_unit); } ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph."); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index b58f6e10df94c..f5a166d36b15a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -286,33 +286,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - std::unordered_map& node_unit_supported_result, const logging::Logger& logger) const { - // If we have visited one of the nodes in the node_unit, use the result directly - const auto it = node_unit_supported_result.find(&node_unit); - if (it != node_unit_supported_result.cend()) { - return it->second; + const std::string& op_type = node_unit.OpType(); + bool supported = false; + const auto* op_builder = qnn::GetOpBuilder(op_type); + if (op_builder == nullptr) { + LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." + << node_unit.OpType() << " node `" << node_unit.Name() + << "` will not be assigned to QNN EP."; } else { - const std::string& op_type = node_unit.OpType(); - - bool supported = false; - const auto* op_builder = qnn::GetOpBuilder(op_type); - if (op_builder == nullptr) { - LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." 
- << node_unit.OpType() << " node `" << node_unit.Name() - << "` will not be assigned to QNN EP."; - } else { - auto status = op_builder->IsOpSupported(qnn_model_wrapper, - node_unit, logger); - if (Status::OK() != status) { - LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() - << "` is not supported: " << status.ErrorMessage(); - } - supported = (Status::OK() == status); + auto status = op_builder->IsOpSupported(qnn_model_wrapper, + node_unit, logger); + if (Status::OK() != status) { + LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() + << "` is not supported: " << status.ErrorMessage(); } - node_unit_supported_result[&node_unit] = supported; - return supported; + supported = (Status::OK() == status); } + return supported; } std::unordered_set @@ -391,24 +382,51 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, if (node != &node_unit->GetNode()) { continue; } - const bool supported = IsNodeSupported(qnn_model_wrapper, - *node_unit, - node_unit_supported_result, - logger); - LOGS(logger, VERBOSE) << "Node supported: [" << supported - << "] index: [" << node->Index() - << "] name: [" << node->Name() - << "] Operator type: [" << node->OpType() - << "] as part of the NodeUnit type: [" << node_unit->OpType() - << "] index: [" << node_unit->Index() - << "] name: [" << node_unit->Name() - << "]"; + + if (node_unit_supported_result.count(node_unit) != 0) { + continue; // Already handled this node unit + } + + // Try to convert certain standalone DQ -> Q sequences into QNN Convert op + auto convert_result = TryHandleConvertSequence(qnn_model_wrapper, + *node_unit, + node_unit_map, + logger, + true /*do_op_validation*/); + if (!convert_result.status.IsOK()) { + LOGS(logger, WARNING) << "Failed to convert DQ -> Q sequence to QNN Convert. " + << "Type: " << node_unit->OpType() << ", Node name: " << node_unit->Name() << ", " + << "Message: " << convert_result.status.ErrorMessage(); + } + + bool supported = false; + + if (convert_result.status.IsOK() && convert_result.q_node_unit) { // Merged DQ -> Q sequence into QNN Convert op + supported = true; + + // Mark the Q node unit as handled and supported here so that we don't try to process it again. + node_unit_supported_result.insert({convert_result.q_node_unit, true}); + supported_nodes.insert(&convert_result.q_node_unit->GetNode()); + } else { + supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger); + LOGS(logger, VERBOSE) << "Node supported: [" << supported + << "] index: [" << node->Index() + << "] name: [" << node->Name() + << "] Operator type: [" << node->OpType() + << "] as part of the NodeUnit type: [" << node_unit->OpType() + << "] index: [" << node_unit->Index() + << "] name: [" << node_unit->Name() + << "]"; + } + if (supported) { // If the node_unit is supported, add all of its nodes to the supported list. 
for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) { supported_nodes.insert(node_in_group); } } + + node_unit_supported_result.insert({node_unit, supported}); } return supported_nodes; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 09bcb24db4dc2..0bcaa39b22f6d 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -42,7 +42,6 @@ class QNNExecutionProvider : public IExecutionProvider { private: bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - std::unordered_map& node_unit_supported_result, const logging::Logger& logger) const; std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 2f3b0e84a123e..a6422407d79fd 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1110,6 +1110,61 @@ TEST_F(QnnHTPBackendTests, LpNormalization_u16_rank4) { kOnnxDomain, true); } + +static GetTestQDQModelFn BuildQDQConvertAddTestCase(const TestInputDef& input0_def, + const TestInputDef& input1_def) { + return [input0_def, input1_def](ModelTestBuilder& builder, std::vector>& output_qparams) { + constexpr bool use_contrib_qdq = true; + + // Input0 -> Quantize(u8) -> Dequantize(u8 to float) -> input0_after_qdq + NodeArg* input0 = MakeTestInput(builder, input0_def); + QuantParams input0_u8_qparams = GetTestInputQuantParams(input0_def); + NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, input0_u8_qparams.scale, + input0_u8_qparams.zero_point, use_contrib_qdq); + + // input0_after_qdq -> Quantize(u16) -> Dequantize(u16 to float) + QuantParams input0_u16_qparams = GetTestInputQuantParams(input0_def); + NodeArg* input0_after_convert = AddQDQNodePair(builder, input0_after_qdq, input0_u16_qparams.scale, + input0_u16_qparams.zero_point, use_contrib_qdq); + + // Input1 -> Quantize(u16) -> Dequantize(u16 to float) -> input1_after_qdq + NodeArg* input1 = MakeTestInput(builder, input1_def); + QuantParams input1_qparams = GetTestInputQuantParams(input1_def); + NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, + input1_qparams.zero_point, use_contrib_qdq); + + // Add op -> op_output + auto* op_output = builder.MakeIntermediate(); + builder.AddNode("Add", {input0_after_convert, input1_after_qdq}, {op_output}); + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Test quantization type conversion (mixed precision) with Add. +// First input is converted from uint8_t to uint16_t. 
+TEST_F(QnnHTPBackendTests, Add_U8_U16_Convert) { + std::vector input0_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + std::vector input1_data = GetFloatDataInRange(-20.0f, 20.0f, 8); + TestInputDef input0_def({1, 2, 2, 2}, false, input0_data); + TestInputDef input1_def({1, 2, 2, 2}, false, input1_data); + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + TestQDQModelAccuracy(BuildOpTestCase("Add", {input0_def, input1_def}, {}, {}, kOnnxDomain), + BuildQDQConvertAddTestCase(input0_def, input1_def), + provider_options, + 18, + ExpectedEPNodeAssignment::All); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test From 1dce5e17321d50bf345022b525a937933473415a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 16 Feb 2024 14:41:11 -0800 Subject: [PATCH 006/237] Disable TF32 in Linux_Test stage of Linux GPU CI Pipeline (#19541) ### Description Some test thresholds that previously worked in T4 GPU does not work anymore. The reason is current pipeline uses A10, and TF32 is enabled by default. Disable TF32 in Linux GPU CI Pipeline in testing to avoid such random test failure. ### Motivation and Context Linux Test has random failure at tests: ProviderOptionsTest > testCUDAOptions() FAILED org.opentest4j.AssertionFailedError: array contents differ at index [446], expected: <0.0419757> but was: <0.041948937> at app//org.junit.jupiter.api.AssertionFailureBuilder.build(AssertionFailureBuilder.java:151) at app//org.junit.jupiter.api.AssertionFailureBuilder.buildAndThrow(AssertionFailureBuilder.java:132) at app//org.junit.jupiter.api.AssertArrayEquals.failArraysNotEqual(AssertArrayEquals.java:440) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:290) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:123) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:119) at app//org.junit.jupiter.api.Assertions.assertArrayEquals(Assertions.java:1360) at app//ai.onnxruntime.providers.ProviderOptionsTest.runProvider(ProviderOptionsTest.java:99) at app//ai.onnxruntime.providers.ProviderOptionsTest.testCUDAOptions(ProviderOptionsTest.java:43) org.opentest4j.AssertionFailedError: array contents differ at index [6], expected: <0.0225981> but was: <0.022587791> at app//org.junit.jupiter.api.AssertionFailureBuilder.build(AssertionFailureBuilder.java:151) at app//org.junit.jupiter.api.AssertionFailureBuilder.buildAndThrow(AssertionFailureBuilder.java:132) at app//org.junit.jupiter.api.AssertArrayEquals.failArraysNotEqual(AssertArrayEquals.java:440) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:290) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:123) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:119) at app//org.junit.jupiter.api.Assertions.assertArrayEquals(Assertions.java:1360) at app//ai.onnxruntime.InferenceTest.runProvider(InferenceTest.java:676) at app//ai.onnxruntime.InferenceTest.testCUDA(InferenceTest.java:615) --- tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index b19a8b11db265..24319184dd0b8 100644 --- 
a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -204,6 +204,7 @@ jobs: --volume /data/models:/build/models:ro \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ --volume /data/onnx:/data/onnx \ + -e NVIDIA_TF32_OVERRIDE=0 \ $(Repository) \ /bin/bash -c " set -ex; \ From 44d8ad93b20efdba921ca80f23485c084b5174d0 Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:21:43 -0800 Subject: [PATCH 007/237] Whisper Timestamps and Temperature (#19509) ### Description This PR updates exporting and running the Whisper model with beam search by adding the following. - Adds temperature as a graph input to the exported model - Fixes the token ids by adding them as attributes to `WhisperBeamSearch` - Fixes the timestamps test cases so they pass now - Fixes a bug with invoking `torch.onnx.export` - Cleans up the Whisper scripts and groups the arguments in `convert_to_onnx.py` - Adds a `requirements.txt` file to specify package dependencies - Adds `whisper-large-v3` to list of pretrained models - Fixes a bug with missing cross-attention KV cache inputs in the decoder subgraph ### Motivation and Context - This is a follow-up to [this PR](https://github.com/microsoft/onnxruntime/pull/19188). - The incorrect token ids in the timestamps processor were first noticed during [this PR review](https://github.com/microsoft/onnxruntime/pull/17500#discussion_r1333520007). When they were originally added in [this PR](https://github.com/microsoft/onnxruntime/pull/15853), the offsets were previously constant across the Whisper model sizes. When comparing the new `whisper-large-v3` variant, the English-only variants (e.g. `whisper-tiny.en`), and the original variants (e.g. `whisper-tiny`), both the values and the offsets differ. Therefore, it is easier to set the token ids as attributes to `WhisperBeamSearch` when exporting to ensure the right values are used in the timestamps processor. - The Hugging Face API for returning timestamps and the expected outputs from the PyTorch model have both changed. - The fix for `torch.onnx.export` is a follow-up to [this PR review](https://github.com/microsoft/onnxruntime/pull/17179#issuecomment-1683001470). - The argument grouping is a follow-up to [this PR review](https://github.com/microsoft/onnxruntime/pull/17500#discussion_r1333521721). - Specific package versions are needed to run the Whisper scripts and the `requirements.txt` file ensures that these versions are installed. - The `whisper-large-v3` variant is released and should be in the list of official pretrained models. - After the changes from [this PR](https://github.com/microsoft/onnxruntime/pull/17316), the exported model is not loading in an ORT inference session because the cross-attention KV cache inputs are missing in the decoder subgraph. 
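For reference, below is a minimal sketch of feeding the new `temperature` input to the exported beam-search model through ONNX Runtime. The `temperature` and `logits_processor` names and dtypes follow the updated `benchmark.py`; the model filename, the `input_features` name, the mel-feature shape, and the remaining input names are illustrative assumptions that depend on the export options and model size used.

```
import numpy as np
import onnxruntime as ort

# Hypothetical model path; actual graph inputs depend on convert_to_onnx.py options.
session = ort.InferenceSession("whisper_beamsearch.onnx", providers=["CPUExecutionProvider"])

candidate_inputs = {
    "input_features": np.zeros((1, 80, 3000), dtype=np.float32),  # assumed log-mel spectrogram shape
    "max_length": np.array([128], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([5], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
    "logits_processor": np.array([1], dtype=np.int32),  # 1 enables the timestamps logits processor
    "temperature": np.array([1.0], dtype=np.float32),   # new optional graph input added by this PR
}

# Only feed the inputs that the exported graph actually declares.
graph_input_names = {graph_input.name for graph_input in session.get_inputs()}
feeds = {name: value for name, value in candidate_inputs.items() if name in graph_input_names}

sequences = session.run(None, feeds)[0]
print(sequences.shape)  # (batch_size, num_return_sequences, max_sequence_length)
```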
--- docs/ContribOperators.md | 32 +- .../transformers/beam_search_impl_whisper.h | 4 +- .../transformers/beam_search_parameters.cc | 8 +- .../cpu/transformers/generation_shared.h | 9 +- .../cpu/transformers/logits_processor.h | 81 +++-- .../transformers/generation_device_helper.cc | 12 +- .../core/graph/contrib_ops/contrib_defs.cc | 40 +-- .../transformers/models/whisper/README.md | 46 ++- .../transformers/models/whisper/benchmark.py | 22 +- .../models/whisper/benchmark_all.py | 6 + .../models/whisper/convert_to_onnx.py | 277 ++++++++++-------- .../models/whisper/requirements-cpu.txt | 2 + .../models/whisper/requirements-cuda.txt | 4 + .../models/whisper/requirements.txt | 11 + .../models/whisper/whisper_chain.py | 272 +++++++++-------- .../models/whisper/whisper_decoder.py | 2 +- .../whisper/whisper_encoder_decoder_init.py | 6 +- .../models/whisper/whisper_helper.py | 79 ++--- .../transformers/torch_onnx_export_helper.py | 3 +- .../python/transformers/test_generation.py | 19 +- .../test_whisper_timestamp_processor.py | 4 +- 21 files changed, 560 insertions(+), 379 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt create mode 100644 onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt create mode 100644 onnxruntime/python/tools/transformers/models/whisper/requirements.txt diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e7b537d6894c8..f523e97293427 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -461,7 +461,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
- Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+ Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -2252,7 +2252,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
- Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+ Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5154,7 +5154,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
- Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+ Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5743,12 +5743,14 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
+ beginning_timestamp_token_id : int
+ The id of the first timestamp
decoder : graph (required)
Decoder subgraph to execute in a loop.
decoder_output_cross_qk : int
If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.
decoder_start_token_id : int
- The id of the token that indicates decoding starts.
+ The id of the token that indicates decoding starts (i.e. the start of transcription token id)
early_stopping : int
early stop or not
encoder : graph
@@ -5761,10 +5763,18 @@ This version of the operator has been available since version 1 of the 'com.micr
Must be 2 for whisper
no_repeat_ngram_size : int
no repeat ngrams size
- no_speech_token : int
+ no_speech_token_id : int
The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.
+ no_timestamps_token_id : int
+ The id of the token that indicates no timestamps
pad_token_id : int (required)
The id of the padding token
+ start_of_lm_token_id : int
+ The id of the token that indicates LM starts
+ transcribe_token_id : int
+ The id of the transcribe task
+ translate_token_id : int
+ The id of the translate task
vocab_size : int
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
@@ -5783,11 +5793,11 @@ This version of the operator has been available since version 1 of the 'com.micr
num_return_sequences : I
The number of returned sequences in the batch. Shape is (1)
length_penalty (optional) : T
- Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
+ Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. Shape is (1,)
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
- Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+ Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5797,7 +5807,7 @@ This version of the operator has been available since version 1 of the 'com.micr
logits_processor (optional) : I
Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)
cross_qk_layer_head (optional) : I
- Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
+ Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
temperature (optional) : T
@@ -5812,11 +5822,11 @@ This version of the operator has been available since version 1 of the 'com.micr
sequences_scores (optional) : T
Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
scores (optional) : T
- Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+ Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
cross_qk (optional) : V
- Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
+ Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
non_speech_probs (optional) : T
- For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]
+ For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]
#### Type Constraints diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 72e6d3930a548..af0904b7d6e4b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -134,8 +134,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe TensorShape no_speech_probs_shape{parameters->batch_size}; Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape); if (no_speech_probs && no_speech_probs->MutableData()) { - ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size, - "no_speech_token id out of range, it is ", parameters->no_speech_token, + ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size, + "no_speech_token_id is out of range, it is ", parameters->no_speech_token_id, ", vocab_size is ", parameters->vocab_size); this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index bb6885c3216bc..93837e785b4a4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -153,7 +153,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeWhisper)); ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper); - no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); + // Token ids are defined below in the order that they appear in the tokenizer + translate_token_id = static_cast(info.GetAttrOrDefault("translate_token_id", -1LL)); + transcribe_token_id = static_cast(info.GetAttrOrDefault("transcribe_token_id", -1LL)); + start_of_lm_token_id = static_cast(info.GetAttrOrDefault("start_of_lm_token_id", -1LL)); + no_speech_token_id = static_cast(info.GetAttrOrDefault("no_speech_token_id", -1LL)); + no_timestamps_token_id = static_cast(info.GetAttrOrDefault("no_timestamps_token_id", -1LL)); + beginning_timestamp_token_id = static_cast(info.GetAttrOrDefault("beginning_timestamp_token_id", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; cross_qk_output_id = 3; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index cb62e2f7bf4da..b1dd55eb20f34 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -183,7 +183,14 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span extra_decoding_ids; - int32_t no_speech_token = -1; + + // Token ids are defined below in the order that they appear in the tokenizer + int32_t translate_token_id = -1; + int32_t transcribe_token_id = -1; + int32_t start_of_lm_token_id = -1; + int32_t no_speech_token_id = -1; + int32_t no_timestamps_token_id = -1; + int32_t beginning_timestamp_token_id = -1; void* no_speech_probs = nullptr; int cross_qk_layer_head_input_id = -1; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h 
b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 03d4e89ac20fe..231eb17d1a947 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -10,6 +10,7 @@ #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" +#include namespace onnxruntime { namespace contrib { @@ -34,6 +35,14 @@ struct NextTokenScores { } }; +#ifdef DEBUG_GENERATION +template +void DumpScores(const char* name, const NextTokenScores& next_token_scores) { + std::cout << name << std::endl; + ORT_UNUSED_PARAMETER(next_token_scores); +} +#endif + // Interface for all scorers for beam search or beam sample. template class ILogitsProcessor { @@ -150,19 +159,25 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} + TimestampLogitsProcessor(int end_of_text_token_id, // <|endoftext|> + int start_of_transcript_token_id, // <|startoftranscript|> + int translate_token_id, // <|translate|> + int transcribe_token_id, // <|transcribe|> + int start_of_lm_token_id, // <|startoflm|> + int no_timestamps_token_id, // <|notimestamps|> + int beginning_timestamp_token_id, // <|0.00|> + int max_initial_timestamp_index) + : end_of_text_token_id_(end_of_text_token_id), + start_of_transcript_token_id_(start_of_transcript_token_id), + translate_token_id_(translate_token_id), + transcribe_token_id_(transcribe_token_id), + start_of_lm_token_id_(start_of_lm_token_id), + no_timestamps_token_id_(no_timestamps_token_id), + beginning_timestamp_token_id_(beginning_timestamp_token_id), + max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override { - // TODO: translate_token_id_ and transcribe_token_id_ need to support both multilingual and English-only models. 
- const int beg_token_id_ = eos_token_id_ + 107; - const int not_token_id_ = eos_token_id_ + 106; - const int solm_token_id_ = eos_token_id_ + 105; - const int sot_token_id_ = eos_token_id_ + 1; - constexpr int translate_token_id_ = 50358; - constexpr int transcribe_token_id_ = 50359; - const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; for (int i = 0; i < batch_beam_size; i++) { @@ -174,7 +189,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { size_t sample_begin = 0; for (size_t j = 0; j < seq_length; j++) { sample_begin++; - if (sequence[j] >= beg_token_id_) { + if (sequence[j] >= beginning_timestamp_token_id_) { break; } } @@ -182,30 +197,30 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Suppress tokens for (int j = 0; j < vocab_size; j++) { // Suppress notimestamps and solm tokens - if (j == not_token_id_ || j == solm_token_id_) { + if (j == no_timestamps_token_id_ || j == start_of_lm_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } // Suppress sot, translate and transcribe tokens if (seq_length > sample_begin) { - if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { + if (j == start_of_transcript_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } } } // Timestamps should be in pair except the first one - const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; - const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; + const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beginning_timestamp_token_id_; + const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beginning_timestamp_token_id_; if (last_was_timestamp) { if (penultimate_was_timestamp) { // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated - for (int j = beg_token_id_; j < vocab_size; j++) { + for (int j = beginning_timestamp_token_id_; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } else { // If timestamp doesn't show up in pair, generate timestamp - for (int j = 0; j < eos_token_id_; j++) { + for (int j = 0; j < end_of_text_token_id_; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -214,7 +229,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Find timestamp tokens std::vector timestamps; for (const auto& word_id : sequence) { - if (word_id >= beg_token_id_) { + if (word_id >= beginning_timestamp_token_id_) { timestamps.push_back(word_id); } } @@ -231,13 +246,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { timestamp_last = timestamps.back() + 1; } - for (int j = beg_token_id_; j < timestamp_last; j++) { + for (int j = beginning_timestamp_token_id_; j < timestamp_last; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } if (seq_length == sample_begin) { - const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; + const int last_allowed = beginning_timestamp_token_id_ + max_initial_timestamp_index_; for (int j = last_allowed + 1; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } @@ -247,8 +262,8 @@ class TimestampLogitsProcessor : public ILogitsProcessor { float timestamp_logprob = std::numeric_limits::lowest(); { float logsumexp = 0.0f; - const float logprob_max = 
*std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); - for (int j = beg_token_id_; j < vocab_size; ++j) { + const float logprob_max = *std::max_element(beam_token_scores.begin() + beginning_timestamp_token_id_, beam_token_scores.end()); + for (int j = beginning_timestamp_token_id_; j < vocab_size; ++j) { if (beam_token_scores[j] > std::numeric_limits::lowest()) { logsumexp += expf(beam_token_scores[j] - logprob_max); } @@ -258,9 +273,9 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } } - const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); + const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beginning_timestamp_token_id_); if (timestamp_logprob > max_text_token_logprob) { - for (int j = 0; j < beg_token_id_; ++j) { + for (int j = 0; j < beginning_timestamp_token_id_; ++j) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -268,7 +283,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } private: - int eos_token_id_; + int end_of_text_token_id_; + int start_of_transcript_token_id_; + int translate_token_id_; + int transcribe_token_id_; + int start_of_lm_token_id_; + int no_timestamps_token_id_; + int beginning_timestamp_token_id_; int max_initial_timestamp_index_; }; @@ -330,7 +351,15 @@ class LogitsProcessorList : public ILogitsProcessorList { // Add timestamp processor for whisper model if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) { constexpr int max_initial_timestamp_index = 50; - timestamp_processor_ = std::make_unique>(parameters.eos_token_id, max_initial_timestamp_index); + // Token ids are passed below in the order that they appear in the tokenizer + timestamp_processor_ = std::make_unique>(parameters.eos_token_id, + parameters.decoder_start_token_id, + parameters.translate_token_id, + parameters.transcribe_token_id, + parameters.start_of_lm_token_id, + parameters.no_timestamps_token_id, + parameters.beginning_timestamp_token_id, + max_initial_timestamp_index); processor_list_.push_back(timestamp_processor_.get()); } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index bba30805ae1be..7adc2fe0a67ea 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -424,7 +424,7 @@ Status ProcessLogits(const OrtValue& logits, // const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper); if (step == 1 && is_whisper_model && parameters->no_speech_probs) { cuda::LaunchSaveNoSpeechProbs( - (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream); + (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token_id, cuda_stream); } // NOTE: currently we treat extra decoding ids are same @@ -469,7 +469,15 @@ Status ProcessLogits(const OrtValue& logits, // cudaMemcpyDeviceToHost, cuda_stream)); constexpr int max_initial_timestamp_index = 50; - onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index); + // Token ids are passed 
below in the order that they appear in the tokenizer + onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, + parameters->decoder_start_token_id, + parameters->translate_token_id, + parameters->transcribe_token_id, + parameters->start_of_lm_token_id, + parameters->no_timestamps_token_id, + parameters->beginning_timestamp_token_id, + max_initial_timestamp_index); onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 27c968a59eb91..e33ce20737f80 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1163,7 +1163,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) @@ -1188,7 +1188,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .SetDoc("Beam Search for whisper model, especiall with cross_qk features etc.") .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) - .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts (i.e. the start of transcription token id)", AttributeProto::INT, static_cast(-1)) + .Attr("translate_token_id", "The id of the translate task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("start_of_lm_token_id", "The id of the token that indicates LM starts", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_speech_token_id", + "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. 
Default -1.", + AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_timestamps_token_id", "The id of the token that indicates no timestamps", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("beginning_timestamp_token_id", "The id of the first timestamp", AttributeProto::INT, OPTIONAL_VALUE) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast(2)) @@ -1203,27 +1211,24 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "If not provided, it will be inferred from the decoder subgraph's output shape", AttributeProto::INT, static_cast(-1)) .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("no_speech_token", - "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", - AttributeProto::INT, OPTIONAL_VALUE) .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") .Input(5, "length_penalty", - "Exponential penalty to the length. Default value 1.0 means no penalty." - "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." + "Exponential penalty to the length. Default value 1.0 means no penalty. " + "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. " "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) .Input(12, "cross_qk_layer_head", - "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all" + "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. 
Default collect all " "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]", "I", OpSchema::Optional) .Input(13, "extra_decoding_ids", @@ -1235,20 +1240,19 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", - "Processed beam scores for each vocabulary token at each generation step." - "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." + "Processed beam scores for each vocabulary token at each generation step. " + "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. " "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) .Output(3, "cross_qk", "Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, " - "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers," - "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]." + "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, " + "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. " "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]", "V", OpSchema::Optional) .Output(4, "non_speech_probs", - "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token." - "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." - "The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]", + "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. " + "The shape of non_speech_probs is [B]", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") @@ -1322,7 +1326,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. 
Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") @@ -1363,7 +1367,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 02100266200f8..7a678f2734ade 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -1,5 +1,22 @@ # Whisper +## Prerequisites + +Please note the package versions needed for using Whisper in the `requirements.txt` file that fits your scenario. +- `requirements-cpu.txt` + - For running Whisper on CPU +- `requirements-cuda.txt` + - For running Whisper on CUDA + - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. +- `requirements.txt` + - Package versions needed in each of the above files + +In addition to the above packages, you will need to install `ffmpeg` on your machine. Visit the [FFmpeg website](https://ffmpeg.org/) for details. You can also install it natively using package managers. + +- Linux: `sudo apt-get install ffmpeg` +- MacOS: `sudo brew install ffmpeg` +- Windows: Download from website + ## Exporting Whisper with Beam Search There are several ways to export Whisper with beam search (using Whisper tiny as an example). 
@@ -10,10 +27,10 @@ There are several ways to export Whisper with beam search (using Whisper tiny as # From source $ git clone https://github.com/microsoft/onnxruntime $ cd onnxruntime/onnxruntime/python/tools/transformers/ -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format # From wheel -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format ``` ### Option 2: end-to-end model from [Olive](https://github.com/microsoft/Olive/tree/main/examples/whisper) @@ -39,40 +56,49 @@ model.save_pretrained(model_name.split("/")[-1] + "-onnx") Here are some additional examples for exporting Whisper with beam search. +To see all available options +``` +# From source: +$ python3 -m models.whisper.convert_to_onnx --help + +# From wheel: +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx --help +``` + Export with Forced Decoder Input Ids ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids ``` Export + Optimize for FP32 ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32 # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32 ``` Export + Optimize for FP16 and GPU ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision 
fp16 --use_gpu --provider cuda --disable_auto_mixed_precision ``` Export + Quantize for INT8 ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer ``` ## Benchmark Whisper diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 759ae6d14f184..e57385aa6db8f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import argparse import ast import datetime @@ -54,6 +60,8 @@ def load_via_numpy(): inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32) if args.has_logits_processor: inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32) + if args.has_temperature: + inputs["temperature"] = np.array([args.temperature], dtype=np.float32) # Measure time taken to load audio file logger.info(f"Load audio: {args.audio_path}") @@ -163,6 +171,7 @@ def get_model(args: argparse.Namespace): def time_fn(args, fn, inputs): warmup_inputs = inputs[0] if type(inputs) is tuple else inputs benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs + torch_device = torch.device(args.target_device) # Warm up warmup_range = ( @@ -180,7 +189,7 @@ def time_fn(args, fn, inputs): # Benchmark if args.device != "cpu": - torch.cuda.synchronize() + torch.cuda.synchronize(torch_device) start_time = time.time() bench_range = ( @@ -192,7 +201,7 @@ def time_fn(args, fn, inputs): fn(benchmark_inputs) if args.device != "cpu": - torch.cuda.synchronize() + torch.cuda.synchronize(torch_device) end_time = time.time() # Newline print after trange in order to print metrics on new lines without progress bar on same line @@ -500,7 +509,13 @@ def parse_args(): "--logits-processor", type=int, default=1, - help="Type of logits processor to use. 
See `BeamSearch` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/contrib_ops/contrib_defs.cc for details.", + help="Whether to use timestamps logits processor or not (0 for false, 1 for true).", + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="Temperature value for generation.", ) # Args for accessing detailed info @@ -581,6 +596,7 @@ def main(): args.has_audio_stream = "audio_stream" in ort_model_inputs setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs) # noqa: B010 setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs) # noqa: B010 + setattr(args, "has_temperature", "temperature" in ort_model_inputs) # noqa: B010 if args.decoder_input_ids == []: args.decoder_input_ids = [config.decoder_start_token_id] diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index d205a2d340721..814b0dd1ef6ac 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import argparse import datetime import json diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index bb697fe1e1506..35211aab272e4 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -28,17 +28,25 @@ def parse_arguments(argv=None): parser = argparse.ArgumentParser() - pretrained_models = PRETRAINED_WHISPER_MODELS - parser.add_argument( + conversion_args = parser.add_argument_group("Conversion Process Args") + optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)") + optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)") + quant_args = parser.add_argument_group("INT8 Quantization Args") + + ################################# + # Conversion options for Whisper + ################################# + + conversion_args.add_argument( "-m", "--model_name_or_path", required=False, default=PRETRAINED_WHISPER_MODELS[0], type=str, - help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models), + help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS), ) - parser.add_argument( + conversion_args.add_argument( "--model_impl", required=False, default="hf", @@ -47,7 +55,7 @@ def parse_arguments(argv=None): help="Select implementation for export of encoder and decoder subgraphs", ) - parser.add_argument( + conversion_args.add_argument( "--cache_dir", required=False, type=str, @@ -55,7 +63,7 @@ def parse_arguments(argv=None): help="Directory to cache pre-trained models", ) - parser.add_argument( + conversion_args.add_argument( "--output", required=False, type=str, @@ -63,19 +71,24 @@ def parse_arguments(argv=None): help="Output directory", ) - parser.add_argument( + conversion_args.add_argument( "-o", "--optimize_onnx", required=False, action="store_true", help="Use 
optimizer.py to optimize onnx model", ) - parser.set_defaults(optimize_onnx=False) + conversion_args.set_defaults(optimize_onnx=False) - parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference") - parser.set_defaults(use_gpu=False) + conversion_args.add_argument( + "--use_gpu", + required=False, + action="store_true", + help="Use GPU for model inference", + ) + conversion_args.set_defaults(use_gpu=False) - parser.add_argument( + conversion_args.add_argument( "-p", "--precision", required=False, @@ -85,221 +98,226 @@ def parse_arguments(argv=None): help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization", ) - parser.add_argument("--verbose", required=False, action="store_true") - parser.set_defaults(verbose=False) - - parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true") - parser.set_defaults(use_external_data_format=False) - - parser.add_argument( - "-s", - "--use_decoder_start_token", + conversion_args.add_argument( + "--use_int64_inputs", required=False, action="store_true", - help="Use config.decoder_start_token_id. Otherwise, add an extra graph input to \ - the encoder-decoder-init subgraph for decoder_input_ids.", + help="Use int64 instead of int32 for input_ids and attention_mask.", ) - parser.set_defaults(use_decoder_start_token=False) + conversion_args.set_defaults(use_int64_inputs=False) - parser.add_argument( - "-f", - "--use_forced_decoder_ids", + conversion_args.add_argument( + "--disable_auto_mixed_precision", required=False, action="store_true", - help="Use decoder_input_ids as an extra graph input to the beam search op", + help="Use pure fp16 instead of mixed precision", ) - parser.set_defaults(use_forced_decoder_ids=False) + conversion_args.set_defaults(disable_auto_mixed_precision=False) - parser.add_argument( - "-l", - "--use_logits_processor", + conversion_args.add_argument( + "-r", + "--provider", required=False, - action="store_true", - help="Use logits_processor as an extra graph input to enable specific logits processing", + type=str, + default="cpu", + choices=list(PROVIDERS.keys()), + help="Provider to benchmark. Default is CPUExecutionProvider.", ) - parser.set_defaults(use_specific_logits_processor=False) - parser.add_argument( - "-v", - "--use_vocab_mask", + conversion_args.add_argument( + "--verbose", required=False, action="store_true", - help="Use vocab_mask as an extra graph input to enable specific logits processing", + help="Enable verbose logging", ) - parser.set_defaults(use_vocab_mask=False) + conversion_args.set_defaults(verbose=False) - parser.add_argument( - "-u", - "--use_prefix_vocab_mask", + conversion_args.add_argument( + "-e", + "--use_external_data_format", required=False, action="store_true", - help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", + help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. 
Optional for 'tiny' and 'base' models.", ) - parser.set_defaults(use_prefix_vocab_mask=False) + conversion_args.set_defaults(use_external_data_format=False) - parser.add_argument( + conversion_args.add_argument( "-w", "--overwrite", required=False, action="store_true", - help="overwrite existing ONNX model", + help="Overwrite existing ONNX model", ) - parser.set_defaults(overwrite=False) + conversion_args.set_defaults(overwrite=False) - parser.add_argument( - "--disable_auto_mixed_precision", + conversion_args.add_argument( + "--separate_encoder_and_decoder_init", required=False, action="store_true", - help="use pure fp16 instead of mixed precision", + help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.", ) - parser.set_defaults(disable_auto_mixed_precision=False) + conversion_args.set_defaults(separate_encoder_and_decoder_init=False) - parser.add_argument( - "--separate_encoder_and_decoder_init", + conversion_args.add_argument( + "--no_beam_search_op", required=False, action="store_true", - help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.", + help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.", ) - parser.set_defaults(separate_encoder_and_decoder_init=False) + conversion_args.set_defaults(no_beam_search_op=False) - parser.add_argument( - "--use_int64_inputs", + conversion_args.add_argument( + "--state_dict_path", + type=str, + default="", + help="Filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", + ) + + ############################################################# + # Optional inputs for Whisper + # (listed below in the order that WhisperBeamSearch expects) + ############################################################# + + optional_inputs.add_argument( + "-v", + "--use_vocab_mask", required=False, action="store_true", - help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.", + help="Use vocab_mask as an extra graph input to enable specific logits processing", ) - parser.set_defaults(use_int64_inputs=False) + optional_inputs.set_defaults(use_vocab_mask=False) - parser.add_argument( - "--chain_model", + optional_inputs.add_argument( + "-u", + "--use_prefix_vocab_mask", required=False, action="store_true", - help="Produce beam search model with chained encdecinit and decoder.", + help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", ) - parser.set_defaults(chain_model=True) + optional_inputs.set_defaults(use_prefix_vocab_mask=False) - parser.add_argument( - "--use_whisper_beamsearch", + optional_inputs.add_argument( + "-f", + "--use_forced_decoder_ids", required=False, action="store_true", - help="When chain_model, using WhisperBeamSearch operator rather than BeamSearch operator. \ - It will be set to true when collect_cross_qk, extra_decoding_ids or output_no_speech_probs is set.", + help="Use decoder_input_ids as an extra graph input to the beam search op", ) - parser.set_defaults(use_whisper_beamsearch=False) + optional_inputs.set_defaults(use_forced_decoder_ids=False) - parser.add_argument( - "--extra_decoding_ids", + optional_inputs.add_argument( + "-l", + "--use_logits_processor", required=False, action="store_true", - help="Need extra starting decoding ids for some feature like cross qk. 
Default if false.", + help="Use logits_processor as an extra graph input to enable specific logits processing", ) - parser.set_defaults(extra_decoding_ids=False) + optional_inputs.set_defaults(use_specific_logits_processor=False) - parser.add_argument( + optional_inputs.add_argument( "--collect_cross_qk", required=False, action="store_true", help="Beam search model collect stacked cross QK.", ) - parser.set_defaults(collect_cross_qk=False) + optional_inputs.set_defaults(collect_cross_qk=False) - parser.add_argument( - "--output_cross_qk", + optional_inputs.add_argument( + "--extra_decoding_ids", required=False, action="store_true", - help="Beam search model output collected qk as output. Also hint collect_cross_qk", + help="Need extra starting decoding ids for some feature like cross qk. Default if false.", ) - parser.set_defaults(output_cross_qk=False) + optional_inputs.set_defaults(extra_decoding_ids=False) - parser.add_argument( - "--no_speech_token_id", - default=50362, + optional_inputs.add_argument( + "-t", + "--use_temperature", + required=False, + action="store_true", + help="Use temperature as an extra graph input for the WhisperBeamSearch op", + ) + optional_inputs.set_defaults(use_temperature=False) + + optional_inputs.add_argument( + "--no_repeat_ngram_size", type=int, - help="specify no_speech_token_id. Default is 50362. if >= 0, will be add into beam search attr. \ - Note that default value maybe different between the multilingual and English-only models.", + default=0, + help="default to 0", ) - parser.add_argument( - "--output_no_speech_probs", + ############################################################# + # Optional outputs for Whisper + # (listed below in the order that WhisperBeamSearch expects) + ############################################################# + + optional_outputs.add_argument( + "--output_sequence_scores", required=False, action="store_true", - help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", + help="Beam search model output scores for each generated sequence.", ) - parser.set_defaults(output_no_speech_probs=False) + optional_outputs.set_defaults(output_sequence_scores=False) - parser.add_argument( + optional_outputs.add_argument( "--output_scores", required=False, action="store_true", help="Beam search model output scores over vocab per generated token.", ) - parser.set_defaults(output_scores=False) + optional_outputs.set_defaults(output_scores=False) - parser.add_argument( - "--output_sequence_scores", + optional_outputs.add_argument( + "--output_cross_qk", required=False, action="store_true", - help="Beam search model output scores for each generated sequence.", + help="Beam search model output collected qk as output. 
Also hint collect_cross_qk", ) - parser.set_defaults(output_sequence_scores=False) + optional_outputs.set_defaults(output_cross_qk=False) - parser.add_argument( + optional_outputs.add_argument( "--cross_qk_onnx_model", required=False, type=str, default=None, - help="the model which consume cross_qk.", + help="The model which consumes cross_qk outputs.", ) - parser.add_argument( - "--beam_output_model", - type=str, - default="whisper_beamsearch.onnx", - help="default name is whisper_beamsearch.onnx.", + optional_outputs.add_argument( + "--output_no_speech_probs", + required=False, + action="store_true", + help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", ) + optional_outputs.set_defaults(output_no_speech_probs=False) - parser.add_argument( + ################################### + # Quantization options for Whisper + ################################### + + quant_args.add_argument( "--quantize_embedding_layer", required=False, action="store_true", help="Quantize MatMul, GEMM, and Gather.", ) - parser.set_defaults(quantize_embedding_layer=False) + quant_args.set_defaults(quantize_embedding_layer=False) - parser.add_argument( + quant_args.add_argument( "--quantize_per_channel", required=False, action="store_true", help="Quantize weights per each channel.", ) - parser.set_defaults(quantize_per_channel=False) + quant_args.set_defaults(quantize_per_channel=False) - parser.add_argument( + quant_args.add_argument( "--quantize_reduce_range", required=False, action="store_true", help="Quantize weights with 7 bits.", ) - parser.set_defaults(quantize_reduce_range=False) - - parser.add_argument("--no_repeat_ngram_size", type=int, default=0, help="default to 0") - - parser.add_argument( - "--state_dict_path", - type=str, - default="", - help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", - ) - - parser.add_argument( - "-r", - "--provider", - required=False, - type=str, - default="cpu", - choices=list(PROVIDERS.keys()), - help="Provider to benchmark. Default is CPUExecutionProvider.", - ) + quant_args.set_defaults(quantize_reduce_range=False) args = parser.parse_args(argv) args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk @@ -317,7 +335,7 @@ def export_onnx_models( optimize_onnx, precision, verbose, - use_decoder_start_token: bool = False, + use_forced_decoder_ids: bool = False, merge_encoder_and_decoder_init: bool = True, overwrite: bool = False, disable_auto_mixed_precision: bool = False, @@ -362,7 +380,6 @@ def export_onnx_models( onnx_path, verbose, use_external_data_format, - use_decoder_input_ids=not use_decoder_start_token, use_int32_inputs=use_int32_inputs, ) else: @@ -406,7 +423,7 @@ def export_onnx_models( extra_options={"MatMulConstBOnly": True}, ) else: - logger.info(f"Skip optimizing: existed ONNX model {onnx_path}") + logger.info(f"Skip optimizing: existing ONNX model {onnx_path}") else: output_path = onnx_path @@ -449,7 +466,7 @@ def main(argv=None): args.optimize_onnx, args.precision, args.verbose, - args.use_decoder_start_token, + args.use_forced_decoder_ids, not args.separate_encoder_and_decoder_init, args.overwrite, args.disable_auto_mixed_precision, @@ -462,7 +479,7 @@ def main(argv=None): ) max_diff = 0 - if args.chain_model: + if not args.no_beam_search_op: logger.info("Chaining model ... 
:")
         args.beam_model_output_dir = WhisperHelper.get_onnx_path(
             output_dir,
diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt
new file mode 100644
index 0000000000000..db2cd95324328
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt
@@ -0,0 +1,2 @@
+-r requirements.txt
+onnxruntime>=1.17.1
\ No newline at end of file
diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt
new file mode 100644
index 0000000000000..9bd215de9bc09
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt
@@ -0,0 +1,4 @@
+-r requirements.txt
+# Please manually install torch>=1.13.0 with CUDA enabled for the CUDA version installed in your system.
+# Instructions can be found here: https://pytorch.org/get-started/locally/
+onnxruntime-gpu>=1.17.1
diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt
new file mode 100644
index 0000000000000..c307a3665f8a0
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt
@@ -0,0 +1,11 @@
+torch>=1.13.0
+transformers>=4.24.0
+openai-whisper
+ffmpeg-python
+datasets
+soundfile
+librosa
+optimum
+onnxruntime-extensions>=0.9.0
+protobuf==3.20.2
+numpy==1.23.3
\ No newline at end of file
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
index a74666b7af297..14691da4ad643 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
@@ -1,3 +1,9 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# -------------------------------------------------------------------------- + import logging import os @@ -9,7 +15,7 @@ update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha, ) from onnx import TensorProto, helper -from transformers import WhisperConfig +from transformers import WhisperConfig, WhisperTokenizer logger = logging.getLogger(__name__) @@ -23,11 +29,22 @@ def verify_inputs(beam_inputs, graph_inputs): assert graph_input.name in beam_input +def clean_list(arr, remove_all_strings=True): + if remove_all_strings: + # Remove all empty strings in list + return list(filter(lambda elm: elm != "", arr)) + + # Remove empty strings at end of list + while len(arr) > 0: + if arr[-1] == "": + arr.pop() + else: + break + return arr + + def chain_model(args): - # Load encoder/decoder and insert necessary (but unused) graph inputs expected by BeamSearch op or WhisperBeamSearch op - args.use_whisper_beamsearch = ( - args.use_whisper_beamsearch or args.collect_cross_qk or args.output_no_speech_probs or args.extra_decoding_ids - ) + # Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op encoder_model = onnx.load_model(args.encoder_path, load_external_data=True) encoder_model.graph.name = "encoderdecoderinit subgraph" @@ -35,7 +52,10 @@ def chain_model(args): decoder_model.graph.name = "decoder subgraph" config = WhisperConfig.from_pretrained(args.model_name_or_path) + tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path) + # Create inputs/outputs for WhisperBeamSearch op + temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature" beam_inputs = [ "input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features", "max_length", @@ -44,38 +64,27 @@ def chain_model(args): "num_return_sequences", "length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty", "repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty", - "vocab_mask" if args.use_prefix_vocab_mask else "", + "vocab_mask" if args.use_vocab_mask else "", "prefix_vocab_mask" if args.use_prefix_vocab_mask else "", "", # attention mask "decoder_input_ids" if args.use_forced_decoder_ids else "", "logits_processor" if args.use_logits_processor else "", + "cross_qk_layer_head" if args.collect_cross_qk else "", + "extra_decoding_ids" if args.extra_decoding_ids else "", + temperature_name if args.use_temperature else "", ] - beam_outputs = ["sequences"] - if args.output_sequence_scores: - beam_outputs.append("sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores") - if args.output_scores: - beam_outputs.append("scores_fp16" if args.precision == Precision.FLOAT16 else "scores") - - if args.use_whisper_beamsearch: - assert len(beam_inputs) == 12 - beam_inputs.extend( - [ - "cross_qk_layer_head" if args.collect_cross_qk else "", - "extra_decoding_ids" if args.extra_decoding_ids else "", - ] - ) - if args.collect_cross_qk: - while len(beam_outputs) < 3: - beam_outputs.extend([""]) - beam_outputs.extend(["cross_qk"]) - if args.output_no_speech_probs: - while len(beam_outputs) < 4: - beam_outputs.extend([""]) - beam_outputs.extend(["no_speech_probs_beam"]) - - input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None - output_scores_cast_node = output_sequence_scores_cast_node = None + sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores" + scores_name = "scores_fp16" if 
args.precision == Precision.FLOAT16 else "scores" + beam_outputs = [ + "sequences", + sequence_scores_name if args.output_sequence_scores else "", + scores_name if args.output_scores else "", + "cross_qk" if args.collect_cross_qk else "", + "no_speech_probs_beam" if args.output_no_speech_probs else "", + ] + + graph_nodes = [] if args.precision == Precision.FLOAT16: input_features_cast_node = helper.make_node( "Cast", @@ -98,6 +107,18 @@ def chain_model(args): name="CastRepetitionPenaltyToFp16", to=TensorProto.FLOAT16, ) + graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node]) + + if args.use_temperature: + temp_cast_node = helper.make_node( + "Cast", + inputs=["temperature"], + outputs=["temperature_fp16"], + name="temperature_to_fp16", + to=TensorProto.FLOAT16, + ) + graph_nodes.append(temp_cast_node) + if args.output_sequence_scores: output_sequence_scores_cast_node = helper.make_node( "Cast", @@ -106,6 +127,8 @@ def chain_model(args): name="CastOutputSequenceScoresToFp32", to=TensorProto.FLOAT, ) + graph_nodes.append(output_sequence_scores_cast_node) + if args.output_scores: output_scores_cast_node = helper.make_node( "Cast", @@ -114,26 +137,38 @@ def chain_model(args): name="CastScoresToFp32", to=TensorProto.FLOAT, ) - - operator_type = "WhisperBeamSearch" if args.use_whisper_beamsearch else "BeamSearch" - node = helper.make_node(operator_type, inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode") - node.domain = "com.microsoft" - node.attribute.extend( - [ - helper.make_attribute("eos_token_id", config.eos_token_id), - helper.make_attribute("pad_token_id", config.pad_token_id), - helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), - helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), - helper.make_attribute("early_stopping", True), - helper.make_attribute("model_type", 2), - ] + graph_nodes.append(output_scores_cast_node) + + # Create WhisperBeamSearch op + beam_search_attrs = [ + helper.make_attribute("eos_token_id", config.eos_token_id), + helper.make_attribute("pad_token_id", config.pad_token_id), + helper.make_attribute( + "decoder_start_token_id", config.decoder_start_token_id + ), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0] + helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]), + helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]), + helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]), + helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(["<|nospeech|>"])[0]) + if args.output_no_speech_probs + else "", + helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]), + helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]), + helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + helper.make_attribute("early_stopping", True), + helper.make_attribute("model_type", 2), + helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "", + ] + node = helper.make_node( + "WhisperBeamSearch", + inputs=clean_list(beam_inputs, remove_all_strings=False), + outputs=clean_list(beam_outputs, remove_all_strings=False), + name="BeamSearch", + domain="com.microsoft", ) - if args.use_whisper_beamsearch: - if args.collect_cross_qk: - 
node.attribute.extend([helper.make_attribute("decoder_output_cross_qk", 1)]) - if args.no_speech_token_id >= 0: - node.attribute.extend([helper.make_attribute("no_speech_token", args.no_speech_token_id)]) + node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True)) + # Graph inputs input_features = helper.make_tensor_value_info( "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"] ) @@ -143,73 +178,63 @@ def chain_model(args): num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1]) length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) + vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size]) + prefix_vocab_mask = helper.make_tensor_value_info( + "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size] + ) + decoder_input_ids = helper.make_tensor_value_info( + "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] + ) + logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) + cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2]) + extra_decoding_ids = helper.make_tensor_value_info( + "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] + ) + temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1]) - graph_inputs = [ - input_features, - max_length, - min_length, - num_beams, - num_return_sequences, - length_penalty, - repetition_penalty, - ] - if args.use_vocab_mask: - vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size]) - graph_inputs.append(vocab_mask) - - if args.use_prefix_vocab_mask: - prefix_vocab_mask = helper.make_tensor_value_info( - "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size] - ) - graph_inputs.append(prefix_vocab_mask) - - if args.use_forced_decoder_ids: - decoder_input_ids = helper.make_tensor_value_info( - "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] - ) - graph_inputs.append(decoder_input_ids) - - if args.use_logits_processor: - logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) - graph_inputs.append(logits_processor) - - if args.collect_cross_qk: - cross_qk_layer_head = helper.make_tensor_value_info( - "cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2] - ) - graph_inputs.append(cross_qk_layer_head) - - if args.extra_decoding_ids: - extra_decoding_ids = helper.make_tensor_value_info( - "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] - ) - graph_inputs.append(extra_decoding_ids) + graph_inputs = clean_list( + [ + input_features, + max_length, + min_length, + num_beams, + num_return_sequences, + length_penalty, + repetition_penalty, + vocab_mask if args.use_vocab_mask else "", + prefix_vocab_mask if args.use_prefix_vocab_mask else "", + decoder_input_ids if args.use_forced_decoder_ids else "", + logits_processor if args.use_logits_processor else "", + cross_qk_layer_head if args.collect_cross_qk else "", + extra_decoding_ids if args.extra_decoding_ids else "", + temperature if args.use_temperature else "", + ] + ) - # graph outputs + # Graph outputs sequences = helper.make_tensor_value_info( "sequences", 
TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"] ) - graph_outputs = [sequences] - if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk): - cross_qk = helper.make_tensor_value_info( - "cross_qk", - TensorProto.FLOAT, - ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], - ) - graph_outputs.extend([cross_qk]) - - if args.output_no_speech_probs: - no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([no_speech_probs]) - - if args.output_sequence_scores: - sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([sequence_scores]) + sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) + scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) + cross_qk = helper.make_tensor_value_info( + "cross_qk", + TensorProto.FLOAT, + ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], + ) + no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - if args.output_scores: - scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([scores]) + graph_outputs = clean_list( + [ + sequences, + sequence_scores if args.output_sequence_scores else "", + scores if args.output_scores else "", + cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "", + no_speech_probs if args.output_no_speech_probs else "", + ] + ) + # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference if hasattr(args, "use_gpu") and args.use_gpu: if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph): logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!") @@ -230,19 +255,7 @@ def chain_model(args): opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)] - graph_nodes = ( - [ - input_features_cast_node, - len_pen_cast_node, - rep_pen_cast_node, - node, - output_sequence_scores_cast_node, - output_scores_cast_node, - ] - if args.precision == Precision.FLOAT16 - else [node] - ) - graph_nodes = [node for node in graph_nodes if node is not None] + graph_nodes.append(node) if args.output_no_speech_probs: prob_cast_node = helper.make_node( "Cast", @@ -251,9 +264,16 @@ def chain_model(args): name="no_speech_probs_cast_to_fp32", to=TensorProto.FLOAT, ) - graph_nodes.extend([prob_cast_node]) - - beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers) + graph_nodes.append(prob_cast_node) + + # Make graph with WhisperBeamSearch op + beam_graph = helper.make_graph( + graph_nodes, + name="WhisperBeamSearch Graph", + inputs=graph_inputs, + outputs=graph_outputs, + initializer=initializers, + ) beam_graph_input_names = [gi.name for gi in graph_inputs] beam_graph_output_names = [go.name for go in graph_outputs] @@ -287,10 +307,12 @@ def chain_model(args): ir_version=decoder_model.ir_version, ) + # Save WhisperBeamSearch graph and external data if os.path.isfile(args.beam_model_output_dir): logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}") os.remove(args.beam_model_output_dir) os.remove(args.beam_model_output_dir + 
".data") + onnx.save( beam_model, args.beam_model_output_dir, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index 0d69960a095ac..93fd64c9eb7d3 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -170,7 +170,7 @@ def create_dummy( cross_attention_past_shape = [ batch_size, num_attention_heads, - past_decode_sequence_length, + encode_sequence_length, head_size, ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 351173f525727..832f692e9980d 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -75,7 +75,7 @@ def create_dummy( config: WhisperConfig, batch_size: int, encode_sequence_length: int, - use_decoder_input_ids: int, + use_decoder_input_ids: bool, device: torch.device, use_int32_inputs: bool = False, ): # -> WhisperEncoderDecoderInitInputs: @@ -125,7 +125,7 @@ def export_onnx( model.config, batch_size=2, encode_sequence_length=3000, - use_decoder_input_ids=use_decoder_input_ids, + use_decoder_input_ids=True, device=device, use_int32_inputs=use_int32_inputs, ) @@ -159,7 +159,7 @@ def export_onnx( hidden_size = str(model.config.d_model) head_size = str(model.config.d_model // model.config.encoder_attention_heads) dynamic_axes = { - "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"}, + "encoder_input_ids": {0: "batch_size", 1: "feature_size"}, "encoder_hidden_states": { 0: "batch_size", 1: "encode_sequence_length", diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index e2dc79ca247ce..1b47b9426d983 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -6,12 +6,14 @@ import logging import os -import sys from pathlib import Path from typing import Dict, Tuple, Union import numpy as np import torch +from float16 import float_to_float16_max_diff +from onnx_model import OnnxModel +from optimizer import optimize_model from packaging import version from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor from transformers import __version__ as transformers_version @@ -21,24 +23,20 @@ from onnxruntime import InferenceSession -sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from float16 import float_to_float16_max_diff -from onnx_model import OnnxModel -from optimizer import optimize_model - logger = logging.getLogger(__name__) PRETRAINED_WHISPER_MODELS = [ "whisper-tiny", "whisper-tiny.en", + "whisper-base", + "whisper-base.en", "whisper-small", "whisper-small.en", "whisper-medium", "whisper-medium.en", - "whisper-base", - "whisper-base.en", "whisper-large", "whisper-large-v2", + "whisper-large-v3", ] @@ -346,7 +344,12 @@ def verify_onnx( ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features - batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 + start_id = 
[config.decoder_start_token_id] # ex: [50258] + prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe") + prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363] + forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363] + + batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1 length_penalty, repetition_penalty = 1.0, 1.0 inputs = { "input_features": input_features.to(device), @@ -383,43 +386,51 @@ def verify_onnx( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - raw_input_ids = ( - [[config.decoder_start_token_id]] - if use_extra_decoding_ids - else [[config.decoder_start_token_id, 50259, 50359, 50363]] - ) + raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids] inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) elif name == "cross_qk_layer_head": inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) elif name == "extra_decoding_ids": - inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) + inputs[name] = np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0) + elif name == "temperature": + inputs[name] = np.array([1.0], dtype=ort_to_np[dtype]) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] - if pt_outputs.shape != ort_outputs.shape: - logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") + expected_transcription_no_comma = ( + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." + ) + expected_transcription_with_comma = ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ) + expected_transcription_with_quote_and_comma = ( + ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ) + expected_transcription_options = { + expected_transcription_no_comma, + expected_transcription_with_comma, + expected_transcription_with_quote_and_comma, + } + pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0] + ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0] - diff = pt_outputs - ort_outputs - max_diff = max(diff.min(), diff.max(), key=abs) + parity = ( + pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options + ) + max_diff = 0 - if max_diff > 0: - # For ONNX Runtime INT8 model - pt_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." - ) - pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) - ort_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." 
- ) - ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) + if not parity: + if pt_outputs.shape != ort_outputs.shape: + diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])] + else: + diff = pt_outputs - ort_outputs + max_diff = max(diff.min(), diff.max(), key=abs) - parity = ( - pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0] - ) - if parity: - max_diff = 0 + if max_diff != 0: + logger.warning(f"PyTorch outputs: {pt_transcription}") + logger.warning(f"ONNX Runtime outputs: {ort_transcription}") return max_diff diff --git a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py index f3e67930adbff..66f24c47f6cdb 100644 --- a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py +++ b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- import torch +from torch._C._onnx import OperatorExportTypes TrainingMode = torch.onnx.TrainingMode from packaging.version import Version # noqa: E402 @@ -18,7 +19,7 @@ def torch_onnx_export( training=TrainingMode.EVAL, input_names=None, output_names=None, - operator_export_type=None, + operator_export_type=OperatorExportTypes.ONNX, opset_version=None, _retain_param_name=None, do_constant_folding=True, diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index 40ea8cf774918..33ec1bd7728fe 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -381,22 +381,23 @@ def test_logits_processor(self): @pytest.mark.slow def test_cross_qk_overall(self): - decoder_input_ids = [ - "--chain_model", - "--collect_cross_qk", - "--output_cross_qk", - "--use_forced_decoder_ids", - "--extra_decoding_ids", - "--output_no_speech_probs", + cross_qk_input_args = [ "--use_vocab_mask", "--use_prefix_vocab_mask", + "--use_forced_decoder_ids", "--use_logits_processor", + "--collect_cross_qk", + "--extra_decoding_ids", ] - self.run_configs(decoder_input_ids) + cross_qk_output_args = [ + "--output_cross_qk", + "--output_no_speech_probs", + ] + self.run_configs(cross_qk_input_args + cross_qk_output_args) @pytest.mark.slow def test_openai_impl_whisper(self): - optional_args = ["--model_impl", "openai", "--chain_model", "--use_whisper_beamsearch"] + optional_args = ["--model_impl", "openai"] self.run_configs(optional_args) diff --git a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py index 77ce09d7e793b..7892000ae45a0 100644 --- a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py +++ b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py @@ -50,7 +50,7 @@ def run_timestamp(self, provider: str): ort_out = sess.run(None, ort_inputs) ort_out_tensor = torch.from_numpy(ort_out[0]) ort_transcription = processor.batch_decode( - ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True + ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True, decode_with_timestamps=True ) print(ort_transcription) expected_transcription = [ @@ -58,7 +58,7 @@ def run_timestamp(self, provider: str): "text": "<|0.00|> Mr. 
Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>",
             "offsets": [
                 {
-                    "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>",
+                    "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
                     "timestamp": (0.0, 5.44),
                 }
             ],

From 4874a41008138ecc1f26e9cd17e5d9d7febb29aa Mon Sep 17 00:00:00 2001
From: Adrian Lizarraga
Date: Fri, 16 Feb 2024 16:59:43 -0800
Subject: [PATCH 008/237] [QNN EP] Update default QNN SDK to 2.19.2.240210 (#19546)

### Description
Updates the default QNN SDK version to 2.19.2.240210.

### Motivation and Context
Build and test the latest version of QNN SDK in our pipelines.

---
 .../android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 2 +-
 tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml | 2 +-
 .../github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml | 2 +-
 .../github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml | 2 +-
 tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml | 2 +-
 5 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml
index 2b181810b0788..d37266a8e96d8 100644
--- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml
@@ -31,7 +31,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: qnn-v2.18.0.240101
+  default: qnn-v2.19.2.240210

 jobs:
 - job: Build_QNN_EP
diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
index 0312b70d2b1d5..8fa5bdbf90931 100644
--- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
@@ -32,7 +32,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: qnn-v2.18.0.240101
+  default: qnn-v2.19.2.240210

 jobs:
 - job: Build_QNN_EP
diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml
index b0509467e1689..9a38513d04a79 100644
--- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml
@@ -2,7 +2,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK Version
   type: string
-  default: qnn-v2.18.0.240101_win
+  default: qnn-v2.19.2.240210_win

 - name: build_config
   displayName: Build Configuration
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
index 13d4589a67cdc..dc861f7f1ed79 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
@@ -32,7 +32,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: qnn-v2.18.0.240101_win
+  default: qnn-v2.19.2.240210_win

 jobs:
 - job: 'build'
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
index 6246bb83566e5..534d5c6d6135b 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
@@ -32,7 +32,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: qnn-v2.18.0.240101_win
+  default: qnn-v2.19.2.240210_win

 jobs:
 - job: 'build'

From 06269a3952fb1759d93235b9d66f9beb10ae8663 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Fri, 16 Feb 2024 18:28:27 -0800
Subject: [PATCH 009/237] [js/webgpu] allow uint8 tensors for webgpu (#19545)

### Description
allow uint8 tensors for webgpu

---
 js/common/lib/tensor-impl.ts | 2 +-
 js/common/lib/tensor.ts | 2 +-
 js/web/lib/wasm/wasm-common.ts | 3 ++-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts
index e3e2b9c728556..de18126a9d0ae 100644
--- a/js/common/lib/tensor-impl.ts
+++ b/js/common/lib/tensor-impl.ts
@@ -103,7 +103,7 @@ export class Tensor implements TensorInterface {
       }
       case 'gpu-buffer': {
         if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' &&
-             type !== 'bool')) {
+             type !== 'uint8' && type !== 'bool')) {
           throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`);
         }
         this.gpuBufferData = arg0.gpuBuffer;
diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts
index 6c08d1fe8e057..d5da33640dc7d 100644
--- a/js/common/lib/tensor.ts
+++ b/js/common/lib/tensor.ts
@@ -135,7 +135,7 @@ export declare namespace Tensor {
   /**
    * supported data types for constructing a tensor from a WebGPU buffer
    */
-  export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool';
+  export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool';

   /**
    * represent where the tensor data is stored
diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts
index b9eff45e890c4..93910af1f1bf0 100644
--- a/js/web/lib/wasm/wasm-common.ts
+++ b/js/web/lib/wasm/wasm-common.ts
@@ -169,7 +169,8 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro
  * Check whether the given tensor type is supported by GPU buffer
  */
 export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' ||
-    type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32';
+    type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint8' ||
+    type === 'bool';

 /**
  * Map string data location to integer value

From dfeda9019cfed2d6df5bcacc54269c7de481bdee Mon Sep 17 00:00:00 2001
From: satyajandhyala
Date: Sat, 17 Feb 2024 09:19:17 -0800
Subject: [PATCH 010/237] [JS/WebGPU] Add MatMulNBits (#19446)

### Description
Add MatMulNBits to support MatMul using 4-bit quantized weights

### Motivation and Context

---
 js/web/docs/webgpu-operators.md | 1 +
 js/web/lib/wasm/jsep/util.ts | 28 +
 .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 +
 .../lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 184 ++
 js/web/test/data/ops/matmulnbits.jsonc | 1527 +++++++++++++++++
 js/web/test/suite-test-list.jsonc | 1 +
 .../contrib_ops/js/js_contrib_kernels.cc | 16 +-
 .../js/quantization/matmul_nbits.cc | 25 +
 .../js/quantization/matmul_nbits.h | 48 +
 9 files changed, 1825 insertions(+), 7 deletions(-)
 create mode 100644
onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc create mode 100644 onnxruntime/contrib_ops/js/quantization/matmul_nbits.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index b21af8e715db3..4a8c92bb97bfd 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -62,6 +62,7 @@ Do not modify directly.* | LessOrEqual | ai.onnx(12-15,16+) | | | Log | ai.onnx(6-12,13+) | | | MatMul | ai.onnx(1-12,13+) | | +| MatMulNBits | com.microsoft(1+) | | | MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation | | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index 6922d7ff5df6e..c0517ce363644 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -92,6 +92,34 @@ export class ShapeUtil { return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length); } + /** + * convert dims corresponding to type change to pack. ex. uint8 data to uint32 + */ + static convertShape(dims: readonly number[], size = 4): readonly number[] { + const rank = dims.length; + if (rank === 0) { + return []; + } + const newDims = new Array(rank); + let i = rank - 1; + while (i >= 0) { + if (dims[i] % size === 0) { + newDims[i] = dims[i] / size; + break; + } + if (size % dims[i] !== 0) { + throw new Error('cannot convert shape'); + } + newDims[i] = 1; + size /= dims[i]; + i--; + } + for (i--; i >= 0; i--) { + newDims[i] = dims[i]; + } + return newDims; + } + /** * calculate the size (number of elements) from the given axis (inclusive) */ diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index ac08c5fb1f7ab..ba874c8dd0f80 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -20,6 +20,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm} from './ops/instance-norm'; import {layerNorm} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {matMulNBits, parseMatMulNBitsAttributes} from './ops/matmulnbits'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad} from './ops/pad'; import * as pool from './ops/pool'; @@ -92,6 +93,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['LessOrEqual', [binaryOps.lessOrEqual]], ['Log', [unaryOps.log]], ['MatMul', [matMul]], + ['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]], // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts new file mode 100644 index 0000000000000..ead7635cf3ac4 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; + +// TODO support quantization bits not equal to 4 +export interface MatMulNBitsAttributes extends AttributeWithCacheKey { + k: number; + n: number; + accuracyLevel: number; + bits: number; + blockSize: number; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => { + if (inputs.length < 3 || inputs.length > 4) { + throw new Error('MatMulNBits requires 3 or 4 inputs'); + } + const a = inputs[0]; + const aRank = a.dims.length; + if (a.dims[aRank - 1] !== attributes.k) { + throw new Error('The last dim of input shape does not match the k value'); + } + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const b = inputs[1]; + if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) { + throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize'); + } + const scales = inputs[2]; + const scalesShape = scales.dims; + if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) { + throw new Error('scales input size error.'); + } + if (inputs.length === 4) { + const zeroPoints = inputs[3]; + const zeroPointsShape = zeroPoints.dims; + const expectedZeroPointsSize = + attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2); + if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) { + throw new Error('zeroPoints input size error.'); + } + } +}; + +export const createMatMulNBitsProgramInfo = + (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => { + const a = inputs[0]; + const b = inputs[1]; + const scales = inputs[2]; + const aRank = a.dims.length; + const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n); + const outputSize = ShapeUtil.size(outputShape); + + + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k}, + {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel}, + {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize} + ]; + programUniforms.push(...createTensorShapeVariables(a.dims)); + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(b.dims))); + programUniforms.push(...createTensorShapeVariables(scales.dims)); + if (inputs.length === 4) { + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const a = inputVariable('a', inputs[0].dataType, inputs[0].dims.length); + const b = inputVariable('b', DataType.uint32, inputs[1].dims.length); + const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); + const inputVariables = [a, b, scales]; + const zeroPoints = + inputs.length === 4 ? 
inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined; + if (zeroPoints) { + inputVariables.push(zeroPoints); + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'}, + {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'} + ]; + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const wordPerBlob = blobSize / 4; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return ` + fn ortUnpack8x4snorm(value: u32) -> array<${dataType}, 8>{ + var result = array<${dataType}, 8>(); + var offset: u32 = 0; + let count: u32 = 4; + for (var i: u32 = 0; i < 8u; i++) { + result[i] = ${dataType}(extractBits(value, offset, count)); + offset += count; + } + return result; + } + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + var value: ${dataType} = 0.0; + let output_indices = ${output.offsetToIndices('global_idx')}; + var a_indices: ${a.type.indices} = output_indices; + var n = ${output.indicesGet('output_indices', aRank - 1)}; + // Two zero points are packed into one byte because uniforms.bits <= 4. + // zero_point_offset is either 0 or 4. It is bit offset within one byte. + // TODO support zero_point_offset for bits > 4 + ${ + zeroPoints ? ` + var zero_point_index: u32 = n * ((${nBlocksPerCol} + 1) / 2) / 4; + var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; + var zero_point_offset: u32 = 0;` : + ''} + var scale_idex = n * ${nBlocksPerCol}; + var b_indices: ${b.type.indices}; + ${b.indicesSet('b_indices', '0', 'n')}; + var block_offset: u32 = 0; + for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { + // The scale and zero points are computed per block. + let scale = ${scales.getByOffset('scale_idex')}; + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point: ${dataType} = ${ + zeroPoints ? `${dataType}(extractBits(zero_point_word, zero_point_offset, 4))` : 8.0}; + ${b.indicesSet('b_indices', '1', 'block')}; + var word_offset: u32 = block_offset; + for (var word: u32 = 0; word < ${wordPerBlob}; word++) { + ${b.indicesSet('b_indices', '2', 'word')}; + let b_value = ${b.getByIndices('b_indices')}; + let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value); + // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 + var offset: u32 = word_offset; + for (var i: u32 = 0; i < 8; i++) { + ${a.indicesSet('a_indices', aRank - 1, 'offset')}; + let a_value = ${a.getByIndices('a_indices')}; + let b_quantized_value = b_quantized_values[i]; + let b_dequantized_value = (b_quantized_value - zero_point) * scale; + value += a_value * b_dequantized_value; + offset++; + } + word_offset += 8; + } + scale_idex++; + ${ + zeroPoints ? 
` + if (zero_point_offset == 28) { + zero_point_offset = 0; + zero_point_index++; + zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; + } else { + zero_point_offset += 4; + }` : + ''} + block_offset += uniforms.block_size; + } + ${output.setByOffset('global_idx', 'value')}; + } + `; + }; + return { + name: 'MatMulNBits', + shaderCache: + {hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64)}, + programUniforms + }), + getShaderSource + }; + }; + +export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { + validateInputs(context.inputs, attributes); + context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes)); +}; + +export const parseMatMulNBitsAttributes = (attributes: Record): MatMulNBitsAttributes => + createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc new file mode 100644 index 0000000000000..c57c431afb3ce --- /dev/null +++ b/js/web/test/data/ops/matmulnbits.jsonc @@ -0,0 +1,1527 @@ +[ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 
124, 125, 126, 127, + 128 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, -2944, -1161, -3040, -715, -2880, -13, -2464, 945, 0, + -1073, -3808, -2643, -6848, -3445, -9120, -3479, -10624, -2745, -11360, -1243, -11328, 1027, -10528, 4065, + 0, -1761, -6496, -4323, -11712, -5605, -15648, -5607, -18304, -4329, -19680, -1771, -19776, 2067, -18592, + 7185, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, -25984, -5913, -28000, -2299, -28224, 3107, + -26656, 10305, 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, -33664, -7497, -36320, -2827, + -36672, 4147, -34720, 13425, 0, -3825, -14560, -9363, -26304, -12085, -35232, -11991, -41344, -9081, + -44640, -3355, -45120, 5187, -42784, 16545, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, + -49024, -10665, -52960, -3883, -53568, 6227, -50848, 19665, 0, -5201, -19936, -12723, -36032, -16405, + -48288, -16247, -56704, -12249, -61280, -4411, -62016, 7267, -58912, 22785, 0, -5889, -22624, -14403, + -40896, -18565, -54816, -18375, -64384, -13833, -69600, -4939, -70464, 8307, -66976, 25905, 0, -6577, + -25312, -16083, -45760, -20725, -61344, -20503, -72064, -15417, -77920, -5467, -78912, 9347, -75040, + 29025, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, -79744, -17001, -86240, -5995, -87360, + 10387, -83104, 32145, 0, -7953, -30688, -19443, -55488, -25045, -74400, -24759, -87424, -18585, -94560, + -6523, -95808, 11427, -91168, 35265, 0, -8641, -33376, -21123, -60352, -27205, -80928, -26887, -95104, + -20169, -102880, -7051, -104256, 12467, -99232, 38385, 0, -9329, -36064, -22803, -65216, -29365, -87456, + -29015, -102784, -21753, -111200, -7579, -112704, 13507, -107296, 41505, 0, -10017, -38752, -24483, + -70080, -31525, -93984, -31143, -110464, -23337, -119520, -8107, -121152, 14547, -115360, 44625, 0, + -10705, -41440, -26163, -74944, -33685, -100512, -33271, -118144, -24921, -127840, -8635, -129600, 15587, + -123424, 47745 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 
201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, 728, 688, 2376, 1632, 4280, 2832, 6440, 4288, 8856, 6000, 11528, 7968, 14456, 10192, 17640, 0, 2200, + 1840, 7176, 4448, 12920, 7824, 19432, 11968, 26712, 16880, 34760, 22560, 43576, 29008, 53160, 0, 3672, + 2992, 11976, 7264, 21560, 12816, 32424, 19648, 44568, 27760, 57992, 37152, 72696, 47824, 88680, 0, 5144, + 4144, 16776, 10080, 30200, 17808, 45416, 27328, 62424, 38640, 81224, 51744, 101816, 66640, 124200, 0, + 6616, 5296, 21576, 12896, 38840, 22800, 58408, 35008, 80280, 49520, 104456, 66336, 130936, 85456, 159720, + 0, 8088, 6448, 26376, 15712, 47480, 27792, 71400, 42688, 98136, 60400, 127688, 80928, 160056, 104272, + 195240, 0, 9560, 7600, 31176, 18528, 56120, 32784, 84392, 50368, 115992, 71280, 150920, 95520, 189176, + 123088, 230760, 0, 11032, 8752, 35976, 21344, 64760, 37776, 97384, 58048, 133848, 82160, 174152, 110112, + 218296, 141904, 266280, 0, 12504, 9904, 40776, 24160, 73400, 42768, 110376, 65728, 151704, 93040, 197384, + 124704, 247416, 160720, 301800, 0, 13976, 11056, 45576, 26976, 82040, 47760, 123368, 73408, 169560, + 103920, 220616, 139296, 276536, 179536, 337320, 0, 15448, 12208, 50376, 29792, 90680, 52752, 136360, + 81088, 187416, 114800, 243848, 153888, 305656, 198352, 372840, 0, 16920, 13360, 55176, 32608, 99320, + 57744, 149352, 88768, 205272, 125680, 267080, 168480, 334776, 217168, 408360, 0, 18392, 14512, 59976, + 35424, 107960, 62736, 162344, 96448, 223128, 136560, 290312, 183072, 363896, 235984, 443880, 0, 19864, + 15664, 64776, 38240, 116600, 67728, 175336, 104128, 240984, 147440, 313544, 197664, 393016, 254800, + 479400, 0, 21336, 16816, 69576, 41056, 125240, 72720, 188328, 111808, 258840, 158320, 336776, 212256, + 422136, 273616, 514920, 0, 22808, 17968, 74376, 43872, 133880, 77712, 201320, 119488, 276696, 169200, + 360008, 226848, 451256, 292432, 550440 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; 
symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 
200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -428, -1288, -1068, -2288, -1420, -3000, -1484, -3424, -1260, -3560, -748, -3408, 52, -2968, 1140, + -2272, 2516, -1224, 4180, 80, 6132, 1672, 8372, 3552, 10900, 5720, 13716, 8176, 16820, 10920, 12276, 0, + -1116, -3976, -2748, -7152, -3580, -9528, -3612, -11104, -2844, -11880, -1276, -11856, 1092, -11032, 4260, + -8160, 8228, -6984, 12996, -3760, 18564, 264, 24932, 5088, 32100, 10712, 40068, 17136, 48836, 24360, + 42532, 0, -1804, -6664, -4428, -12016, -5740, -16056, -5740, -18784, -4428, -20200, -1804, -20304, 2132, + -19096, 7380, -14048, 13940, -12744, 21812, -7600, 30996, -1144, 41492, 6624, 53300, 15704, 66420, 26096, + 80852, 37800, 72788, 0, -2492, -9352, -6108, -16880, -7900, -22584, -7868, -26464, -6012, -28520, -2332, + -28752, 3172, -27160, 10500, -19936, 19652, -18504, 30628, -11440, 43428, -2552, 58052, 8160, 74500, + 20696, 92772, 35056, 112868, 51240, 103044, 0, -3180, -12040, -7788, -21744, -10060, -29112, -9996, + -34144, -7596, -36840, -2860, -37200, 4212, -35224, 13620, -25824, 25364, -24264, 39444, -15280, 55860, + -3960, 74612, 9696, 95700, 25688, 119124, 44016, 144884, 64680, 133300, 0, -3868, -14728, -9468, -26608, + -12220, -35640, -12124, -41824, -9180, -45160, -3388, -45648, 5252, -43288, 16740, -31712, 31076, -30024, + 48260, -19120, 68292, -5368, 91172, 11232, 116900, 30680, 145476, 52976, 176900, 78120, 163556, 0, -4556, + -17416, -11148, -31472, -14380, -42168, -14252, -49504, -10764, -53480, -3916, -54096, 6292, -51352, + 19860, -37600, 36788, -35784, 57076, -22960, 80724, -6776, 107732, 12768, 138100, 35672, 171828, 61936, + 208916, 91560, 193812, 0, -5244, -20104, -12828, -36336, -16540, -48696, -16380, -57184, -12348, -61800, + -4444, -62544, 7332, -59416, 22980, -43488, 42500, -41544, 65892, -26800, 93156, -8184, 124292, 14304, + 159300, 40664, 198180, 70896, 240932, 105000, 224068, 0, -5932, -22792, -14508, -41200, -18700, -55224, + -18508, -64864, -13932, -70120, -4972, -70992, 8372, -67480, 26100, -49376, 48212, -47304, 74708, -30640, + 105588, -9592, 140852, 15840, 180500, 45656, 224532, 79856, 272948, 118440, 254324, 0, -6620, -25480, + -16188, -46064, -20860, -61752, -20636, -72544, -15516, -78440, -5500, -79440, 9412, -75544, 29220, + -55264, 53924, -53064, 83524, -34480, 118020, -11000, 157412, 17376, 201700, 50648, 250884, 88816, 304964, + 131880, 284580, 0, -7308, -28168, -17868, -50928, -23020, -68280, -22764, -80224, -17100, -86760, -6028, + -87888, 10452, -83608, 32340, -61152, 59636, -58824, 92340, -38320, 130452, -12408, 173972, 18912, 222900, + 55640, 277236, 97776, 336980, 145320, 314836, 0, -7996, -30856, -19548, -55792, -25180, -74808, -24892, + -87904, -18684, -95080, -6556, -96336, 11492, -91672, 35460, -67040, 65348, -64584, 101156, -42160, + 142884, -13816, 190532, 20448, 244100, 60632, 303588, 106736, 368996, 158760, 345092, 0, -8684, -33544, + -21228, -60656, -27340, -81336, -27020, -95584, -20268, -103400, -7084, -104784, 12532, -99736, 38580, + -72928, 71060, -70344, 109972, 
-46000, 155316, -15224, 207092, 21984, 265300, 65624, 329940, 115696, + 401012, 172200, 375348, 0, -9372, -36232, -22908, -65520, -29500, -87864, -29148, -103264, -21852, + -111720, -7612, -113232, 13572, -107800, 41700, -78816, 76772, -76104, 118788, -49840, 167748, -16632, + 223652, 23520, 286500, 70616, 356292, 124656, 433028, 185640, 405604, 0, -10060, -38920, -24588, -70384, + -31660, -94392, -31276, -110944, -23436, -120040, -8140, -121680, 14612, -115864, 44820, -84704, 82484, + -81864, 127604, -53680, 180180, -18040, 240212, 25056, 307700, 75608, 382644, 133616, 465044, 199080, + 435860, 0, -10748, -41608, -26268, -75248, -33820, -100920, -33404, -118624, -25020, -128360, -8668, + -130128, 15652, -123928, 47940, -90592, 88196, -87624, 136420, -57520, 192612, -19448, 256772, 26592, + 328900, 80600, 408996, 142576, 497060, 212520, 466116, 0, -11436, -44296, -27948, -80112, -35980, -107448, + -35532, -126304, -26604, -136680, -9196, -138576, 16692, -131992, 51060, -96480, 93908, -93384, 145236, + -61360, 205044, -20856, 273332, 28128, 350100, 85592, 435348, 151536, 529076, 225960, 496372, 0, -12124, + -46984, -29628, -84976, -38140, -113976, -37660, -133984, -28188, -145000, -9724, -147024, 17732, -140056, + 54180, -102368, 99620, -99144, 154052, -65200, 217476, -22264, 289892, 29664, 371300, 90584, 461700, + 160496, 561092, 239400, 526628, 0, -12812, -49672, -31308, -89840, -40300, -120504, -39788, -141664, + -29772, -153320, -10252, -155472, 18772, -148120, 57300, -108256, 105332, -104904, 162868, -69040, 229908, + -23672, 306452, 31200, 392500, 95576, 488052, 169456, 593108, 252840, 556884, 0, -13500, -52360, -32988, + -94704, -42460, -127032, -41916, -149344, -31356, -161640, -10780, -163920, 19812, -156184, 60420, + -114144, 111044, -110664, 171684, -72880, 242340, -25080, 323012, 32736, 413700, 100568, 514404, 178416, + 625124, 266280, 587140, 0, -14188, -55048, -34668, -99568, -44620, -133560, -44044, -157024, -32940, + -169960, -11308, -172368, 20852, -164248, 63540, -120032, 116756, -116424, 180500, -76720, 254772, -26488, + 339572, 34272, 434900, 105560, 540756, 187376, 657140, 279720, 617396, 0, -14876, -57736, -36348, -104432, + -46780, -140088, -46172, -164704, -34524, -178280, -11836, -180816, 21892, -172312, 66660, -125920, + 122468, -122184, 189316, -80560, 267204, -27896, 356132, 35808, 456100, 110552, 567108, 196336, 689156, + 293160, 647652, 0, -15564, -60424, -38028, -109296, -48940, -146616, -48300, -172384, -36108, -186600, + -12364, -189264, 22932, -180376, 69780, -131808, 128180, -127944, 198132, -84400, 279636, -29304, 372692, + 37344, 477300, 115544, 593460, 205296, 721172, 306600, 677908, 0, -16252, -63112, -39708, -114160, -51100, + -153144, -50428, -180064, -37692, -194920, -12892, -197712, 23972, -188440, 72900, -137696, 133892, + -133704, 206948, -88240, 292068, -30712, 389252, 38880, 498500, 120536, 619812, 214256, 753188, 320040, + 708164, 0, -16940, -65800, -41388, -119024, -53260, -159672, -52556, -187744, -39276, -203240, -13420, + -206160, 25012, -196504, 76020, -143584, 139604, -139464, 215764, -92080, 304500, -32120, 405812, 40416, + 519700, 125528, 646164, 223216, 785204, 333480, 738420, 0, -17628, -68488, -43068, -123888, -55420, + -166200, -54684, -195424, -40860, -211560, -13948, -214608, 26052, -204568, 79140, -149472, 145316, + -145224, 224580, -95920, 316932, -33528, 422372, 41952, 540900, 130520, 672516, 232176, 817220, 346920, + 768676, 0, -18316, -71176, -44748, -128752, -57580, -172728, -56812, -203104, -42444, -219880, -14476, + 
-223056, 27092, -212632, 82260, -155360, 151028, -150984, 233396, -99760, 329364, -34936, 438932, 43488, + 562100, 135512, 698868, 241136, 849236, 360360, 798932, 0, -19004, -73864, -46428, -133616, -59740, + -179256, -58940, -210784, -44028, -228200, -15004, -231504, 28132, -220696, 85380, -161248, 156740, + -156744, 242212, -103600, 341796, -36344, 455492, 45024, 583300, 140504, 725220, 250096, 881252, 373800, + 829188, 0, -19692, -76552, -48108, -138480, -61900, -185784, -61068, -218464, -45612, -236520, -15532, + -239952, 29172, -228760, 88500, -167136, 162452, -162504, 251028, -107440, 354228, -37752, 472052, 46560, + 604500, 145496, 751572, 259056, 913268, 387240, 859444, 0, -20380, -79240, -49788, -143344, -64060, + -192312, -63196, -226144, -47196, -244840, -16060, -248400, 30212, -236824, 91620, -173024, 168164, + -168264, 259844, -111280, 366660, -39160, 488612, 48096, 625700, 150488, 777924, 268016, 945284, 400680, + 889700, 0, -21068, -81928, -51468, -148208, -66220, -198840, -65324, -233824, -48780, -253160, -16588, + -256848, 31252, -244888, 94740, -178912, 173876, -174024, 268660, -115120, 379092, -40568, 505172, 49632, + 646900, 155480, 804276, 276976, 977300, 414120, 919956, 0, -21756, -84616, -53148, -153072, -68380, + -205368, -67452, -241504, -50364, -261480, -17116, -265296, 32292, -252952, 97860, -184800, 179588, + -179784, 277476, -118960, 391524, -41976, 521732, 51168, 668100, 160472, 830628, 285936, 1009316, 427560, + 950212 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 
330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 660, 888, 2196, 2064, 4020, 3528, 6132, 5280, 8532, 7320, 11220, 9648, 14196, 12264, 17460, 15136, + 21012, 18360, 24852, 21840, 28980, 25608, 33396, 29664, 38100, 34008, 43092, 38640, 48372, 43560, 46004, + 0, 2020, 2296, 6660, 5392, 12100, 9288, 18340, 13984, 25380, 19480, 33220, 25776, 41860, 32872, 51300, + 42016, 61540, 49464, 72580, 58960, 84420, 69256, 97060, 80352, 110500, 92248, 124740, 104944, 139780, + 118440, 139748, 0, 3380, 3704, 11124, 8720, 20180, 15048, 30548, 22688, 42228, 31640, 55220, 41904, 69524, + 53480, 85140, 68896, 102068, 80568, 120308, 96080, 139860, 112904, 160724, 131040, 182900, 150488, 206388, + 171248, 231188, 193320, 233492, 0, 4740, 5112, 15588, 12048, 28260, 20808, 42756, 31392, 59076, 43800, + 77220, 58032, 97188, 74088, 118980, 95776, 142596, 111672, 168036, 133200, 195300, 
156552, 224388, 181728, + 255300, 208728, 288036, 237552, 322596, 268200, 327236, 0, 6100, 6520, 20052, 15376, 36340, 26568, 54964, + 40096, 75924, 55960, 99220, 74160, 124852, 94696, 152820, 122656, 183124, 142776, 215764, 170320, 250740, + 200200, 288052, 232416, 327700, 266968, 369684, 303856, 414004, 343080, 420980, 0, 7460, 7928, 24516, + 18704, 44420, 32328, 67172, 48800, 92772, 68120, 121220, 90288, 152516, 115304, 186660, 149536, 223652, + 173880, 263492, 207440, 306180, 243848, 351716, 283104, 400100, 325208, 451332, 370160, 505412, 417960, + 514724, 0, 8820, 9336, 28980, 22032, 52500, 38088, 79380, 57504, 109620, 80280, 143220, 106416, 180180, + 135912, 220500, 176416, 264180, 204984, 311220, 244560, 361620, 287496, 415380, 333792, 472500, 383448, + 532980, 436464, 596820, 492840, 608468, 0, 10180, 10744, 33444, 25360, 60580, 43848, 91588, 66208, 126468, + 92440, 165220, 122544, 207844, 156520, 254340, 203296, 304708, 236088, 358948, 281680, 417060, 331144, + 479044, 384480, 544900, 441688, 614628, 502768, 688228, 567720, 702212, 0, 11540, 12152, 37908, 28688, + 68660, 49608, 103796, 74912, 143316, 104600, 187220, 138672, 235508, 177128, 288180, 230176, 345236, + 267192, 406676, 318800, 472500, 374792, 542708, 435168, 617300, 499928, 696276, 569072, 779636, 642600, + 795956, 0, 12900, 13560, 42372, 32016, 76740, 55368, 116004, 83616, 160164, 116760, 209220, 154800, + 263172, 197736, 322020, 257056, 385764, 298296, 454404, 355920, 527940, 418440, 606372, 485856, 689700, + 558168, 777924, 635376, 871044, 717480, 889700, 0, 14260, 14968, 46836, 35344, 84820, 61128, 128212, + 92320, 177012, 128920, 231220, 170928, 290836, 218344, 355860, 283936, 426292, 329400, 502132, 393040, + 583380, 462088, 670036, 536544, 762100, 616408, 859572, 701680, 962452, 792360, 983444, 0, 15620, 16376, + 51300, 38672, 92900, 66888, 140420, 101024, 193860, 141080, 253220, 187056, 318500, 238952, 389700, + 310816, 466820, 360504, 549860, 430160, 638820, 505736, 733700, 587232, 834500, 674648, 941220, 767984, + 1053860, 867240, 1077188, 0, 16980, 17784, 55764, 42000, 100980, 72648, 152628, 109728, 210708, 153240, + 275220, 203184, 346164, 259560, 423540, 337696, 507348, 391608, 597588, 467280, 694260, 549384, 797364, + 637920, 906900, 732888, 1022868, 834288, 1145268, 942120, 1170932, 0, 18340, 19192, 60228, 45328, 109060, + 78408, 164836, 118432, 227556, 165400, 297220, 219312, 373828, 280168, 457380, 364576, 547876, 422712, + 645316, 504400, 749700, 593032, 861028, 688608, 979300, 791128, 1104516, 900592, 1236676, 1017000, + 1264676, 0, 19700, 20600, 64692, 48656, 117140, 84168, 177044, 127136, 244404, 177560, 319220, 235440, + 401492, 300776, 491220, 391456, 588404, 453816, 693044, 541520, 805140, 636680, 924692, 739296, 1051700, + 849368, 1186164, 966896, 1328084, 1091880, 1358420, 0, 21060, 22008, 69156, 51984, 125220, 89928, 189252, + 135840, 261252, 189720, 341220, 251568, 429156, 321384, 525060, 418336, 628932, 484920, 740772, 578640, + 860580, 680328, 988356, 789984, 1124100, 907608, 1267812, 1033200, 1419492, 1166760, 1452164, 0, 22420, + 23416, 73620, 55312, 133300, 95688, 201460, 144544, 278100, 201880, 363220, 267696, 456820, 341992, + 558900, 445216, 669460, 516024, 788500, 615760, 916020, 723976, 1052020, 840672, 1196500, 965848, 1349460, + 1099504, 1510900, 1241640, 1545908, 0, 23780, 24824, 78084, 58640, 141380, 101448, 213668, 153248, 294948, + 214040, 385220, 283824, 484484, 362600, 592740, 472096, 709988, 547128, 836228, 652880, 971460, 767624, + 1115684, 891360, 1268900, 1024088, 
1431108, 1165808, 1602308, 1316520, 1639652, 0, 25140, 26232, 82548, + 61968, 149460, 107208, 225876, 161952, 311796, 226200, 407220, 299952, 512148, 383208, 626580, 498976, + 750516, 578232, 883956, 690000, 1026900, 811272, 1179348, 942048, 1341300, 1082328, 1512756, 1232112, + 1693716, 1391400, 1733396, 0, 26500, 27640, 87012, 65296, 157540, 112968, 238084, 170656, 328644, 238360, + 429220, 316080, 539812, 403816, 660420, 525856, 791044, 609336, 931684, 727120, 1082340, 854920, 1243012, + 992736, 1413700, 1140568, 1594404, 1298416, 1785124, 1466280, 1827140, 0, 27860, 29048, 91476, 68624, + 165620, 118728, 250292, 179360, 345492, 250520, 451220, 332208, 567476, 424424, 694260, 552736, 831572, + 640440, 979412, 764240, 1137780, 898568, 1306676, 1043424, 1486100, 1198808, 1676052, 1364720, 1876532, + 1541160, 1920884, 0, 29220, 30456, 95940, 71952, 173700, 124488, 262500, 188064, 362340, 262680, 473220, + 348336, 595140, 445032, 728100, 579616, 872100, 671544, 1027140, 801360, 1193220, 942216, 1370340, + 1094112, 1558500, 1257048, 1757700, 1431024, 1967940, 1616040, 2014628, 0, 30580, 31864, 100404, 75280, + 181780, 130248, 274708, 196768, 379188, 274840, 495220, 364464, 622804, 465640, 761940, 606496, 912628, + 702648, 1074868, 838480, 1248660, 985864, 1434004, 1144800, 1630900, 1315288, 1839348, 1497328, 2059348, + 1690920, 2108372, 0, 31940, 33272, 104868, 78608, 189860, 136008, 286916, 205472, 396036, 287000, 517220, + 380592, 650468, 486248, 795780, 633376, 953156, 733752, 1122596, 875600, 1304100, 1029512, 1497668, + 1195488, 1703300, 1373528, 1920996, 1563632, 2150756, 1765800, 2202116, 0, 33300, 34680, 109332, 81936, + 197940, 141768, 299124, 214176, 412884, 299160, 539220, 396720, 678132, 506856, 829620, 660256, 993684, + 764856, 1170324, 912720, 1359540, 1073160, 1561332, 1246176, 1775700, 1431768, 2002644, 1629936, 2242164, + 1840680, 2295860, 0, 34660, 36088, 113796, 85264, 206020, 147528, 311332, 222880, 429732, 311320, 561220, + 412848, 705796, 527464, 863460, 687136, 1034212, 795960, 1218052, 949840, 1414980, 1116808, 1624996, + 1296864, 1848100, 1490008, 2084292, 1696240, 2333572, 1915560, 2389604, 0, 36020, 37496, 118260, 88592, + 214100, 153288, 323540, 231584, 446580, 323480, 583220, 428976, 733460, 548072, 897300, 714016, 1074740, + 827064, 1265780, 986960, 1470420, 1160456, 1688660, 1347552, 1920500, 1548248, 2165940, 1762544, 2424980, + 1990440, 2483348, 0, 37380, 38904, 122724, 91920, 222180, 159048, 335748, 240288, 463428, 335640, 605220, + 445104, 761124, 568680, 931140, 740896, 1115268, 858168, 1313508, 1024080, 1525860, 1204104, 1752324, + 1398240, 1992900, 1606488, 2247588, 1828848, 2516388, 2065320, 2577092, 0, 38740, 40312, 127188, 95248, + 230260, 164808, 347956, 248992, 480276, 347800, 627220, 461232, 788788, 589288, 964980, 767776, 1155796, + 889272, 1361236, 1061200, 1581300, 1247752, 1815988, 1448928, 2065300, 1664728, 2329236, 1895152, 2607796, + 2140200, 2670836, 0, 40100, 41720, 131652, 98576, 238340, 170568, 360164, 257696, 497124, 359960, 649220, + 477360, 816452, 609896, 998820, 794656, 1196324, 920376, 1408964, 1098320, 1636740, 1291400, 1879652, + 1499616, 2137700, 1722968, 2410884, 1961456, 2699204, 2215080, 2764580, 0, 41460, 43128, 136116, 101904, + 246420, 176328, 372372, 266400, 513972, 372120, 671220, 493488, 844116, 630504, 1032660, 821536, 1236852, + 951480, 1456692, 1135440, 1692180, 1335048, 1943316, 1550304, 2210100, 1781208, 2492532, 2027760, 2790612, + 2289960, 2858324, 0, 42820, 44536, 140580, 105232, 254500, 182088, 384580, 275104, 
530820, 384280, 693220, + 509616, 871780, 651112, 1066500, 848416, 1277380, 982584, 1504420, 1172560, 1747620, 1378696, 2006980, + 1600992, 2282500, 1839448, 2574180, 2094064, 2882020, 2364840, 2952068 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 
61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, 56908, 81124, + 108476, 138964, 140844, -3868, -21508, -33964, -41236, -43324, -40228, -31948, -18484, 5252, 23996, 53012, + 87212, 126596, 171164, 220916, 228236, -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, + 4900, 30108, 70196, 117516, 172068, 233852, 302868, 315628, -6620, -38980, -62060, -75860, -80380, -75620, + -61580, -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -7996, -47716, -76108, -93172, + -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, 466772, 490412, -9372, + -56452, -90156, -110484, -117436, -111012, -91212, -58036, 3844, 48444, 121748, 208428, 308484, 421916, + 548724, 577804, -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, + 238732, 353956, 484604, 630676, 665196, -12124, -73924, -118252, -145108, -154492, -146404, -120844, + -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -13500, -82660, -132300, -162420, + -173020, -164100, -135660, -87700, 2788, 66780, 173300, 299340, 444900, 609980, 794580, 839980, -14876, + -91396, -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, + 876532, 927372, -16252, -100132, -160396, -197044, -210076, -199492, -165292, -107476, 2084, 79004, + 207668, 359948, 535844, 735356, 958484, 1014764, -17628, -108868, -174444, -214356, -228604, -217188, + -180108, -117364, 1732, 85116, 224852, 390252, 581316, 798044, 1040436, 1102156, -19004, -117604, -188492, + -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, 860732, 1122388, + 1189548, -20380, -126340, -202540, -248980, -265660, -252580, -209740, -137140, 1028, 97340, 259220, + 450860, 672260, 923420, 1204340, 1276940, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": 
"N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 
157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, 170956, 205540, 243260, + 284116, 296364, -3868, -2948, 3156, 14444, 30916, 52572, 79412, 111436, 153732, 191036, 238612, 291372, + 349316, 412444, 480756, 506636, -5244, -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, + 337716, 411788, 493092, 581628, 677396, 716908, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, + 284100, 350716, 436820, 532204, 636868, 750812, 874036, 927180, -7996, -4580, 10164, 36236, 73636, 122364, + 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, -9372, -5124, 12500, + 43500, 87876, 145628, 216756, 301260, 414468, 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, + -10748, -5668, 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, + 1258364, 1463956, 1557996, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, 544836, 670076, + 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -13500, -6756, 19508, 65292, 130596, 215420, 319764, + 443628, 610020, 749916, 932340, 1134284, 1355748, 1596732, 1857236, 1978540, -14876, -7300, 21844, 72556, + 144836, 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, + -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, 1130548, 1375116, 1643300, + 1935100, 2250516, 2399084, -17628, -8388, 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, + 1229652, 1495532, 1787076, 2104284, 2447156, 2609356, -19004, -8932, 28852, 94348, 187556, 308476, 457108, + 633452, 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -20380, -9476, 31188, + 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, 2442652, 2840436, + 3029900, -21756, -10020, 33524, 108876, 216036, 355004, 525780, 728364, 1001124, 1228956, 1526964, + 1856780, 2218404, 2611836, 3037076, 3240172 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 
21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 
734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 
403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -59740, -53956, -47084, -39124, -30076, -19940, -8716, 3596, 16996, 31484, 47060, 63724, 81476, + 100316, 120244, 109004, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, + 56908, 81124, 108476, 138964, 140844, -199356, -184548, -166604, -145524, -121308, -93956, -63468, -29844, + 6916, 46812, 89844, 136012, 185316, 237756, 293332, 287532, -3868, -21508, -33964, -41236, -43324, -40228, + -31948, -18484, 5252, 23996, 53012, 87212, 126596, 171164, 220916, 228236, -338972, -315140, -286124, + -251924, -212540, -167972, -118220, -63284, -3164, 62140, 132628, 208300, 289156, 375196, 466420, 466060, + -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, 4900, 30108, 70196, 117516, 172068, 233852, + 302868, 315628, -478588, -445732, -405644, -358324, -303772, -241988, -172972, -96724, -13244, 77468, + 175412, 280588, 392996, 512636, 639508, 644588, -6620, -38980, -62060, -75860, -80380, -75620, -61580, + -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -618204, -576324, -525164, -464724, + -395004, -316004, -227724, -130164, -23324, 92796, 218196, 352876, 496836, 650076, 812596, 823116, -7996, + -47716, -76108, -93172, -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, + 466772, 490412, -757820, -706916, -644684, -571124, -486236, -390020, -282476, -163604, -33404, 108124, + 260980, 425164, 600676, 787516, 985684, 1001644, -9372, -56452, -90156, -110484, -117436, -111012, -91212, + -58036, 3844, 48444, 121748, 208428, 308484, 421916, 548724, 577804, -897436, -837508, -764204, -677524, + -577468, -464036, -337228, -197044, -43484, 123452, 303764, 497452, 704516, 924956, 1158772, 1180172, + -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, 238732, 353956, + 484604, 630676, 665196, -1037052, -968100, -883724, -783924, -668700, -538052, -391980, -230484, -53564, + 138780, 346548, 569740, 808356, 1062396, 1331860, 1358700, -12124, -73924, -118252, -145108, -154492, + -146404, -120844, -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -1176668, -1098692, + -1003244, -890324, -759932, -612068, -446732, -263924, -63644, 154108, 389332, 642028, 912196, 1199836, + 1504948, 1537228, -13500, -82660, -132300, -162420, -173020, -164100, -135660, -87700, 2788, 66780, + 173300, 299340, 444900, 609980, 794580, 839980, -1316284, -1229284, -1122764, -996724, -851164, -686084, + -501484, -297364, -73724, 169436, 432116, 
714316, 1016036, 1337276, 1678036, 1715756, -14876, -91396, + -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, 876532, + 927372, -1455900, -1359876, -1242284, -1103124, -942396, -760100, -556236, -330804, -83804, 184764, + 474900, 786604, 1119876, 1474716, 1851124, 1894284, -16252, -100132, -160396, -197044, -210076, -199492, + -165292, -107476, 2084, 79004, 207668, 359948, 535844, 735356, 958484, 1014764, -1595516, -1490468, + -1361804, -1209524, -1033628, -834116, -610988, -364244, -93884, 200092, 517684, 858892, 1223716, 1612156, + 2024212, 2072812, -17628, -108868, -174444, -214356, -228604, -217188, -180108, -117364, 1732, 85116, + 224852, 390252, 581316, 798044, 1040436, 1102156, -1735132, -1621060, -1481324, -1315924, -1124860, + -908132, -665740, -397684, -103964, 215420, 560468, 931180, 1327556, 1749596, 2197300, 2251340, -19004, + -117604, -188492, -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, + 860732, 1122388, 1189548, -1874748, -1751652, -1600844, -1422324, -1216092, -982148, -720492, -431124, + -114044, 230748, 603252, 1003468, 1431396, 1887036, 2370388, 2429868, -20380, -126340, -202540, -248980, + -265660, -252580, -209740, -137140, 1028, 97340, 259220, 450860, 672260, 923420, 1204340, 1276940, + -2014364, -1882244, -1720364, -1528724, -1307324, -1056164, -775244, -464564, -124124, 246076, 646036, + 1075756, 1535236, 2024476, 2543476, 2608396, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332, -2153980, -2012836, -1839884, + -1635124, -1398556, -1130180, -829996, -498004, -134204, 261404, 688820, 1148044, 1639076, 2161916, + 2716564, 2786924, -23132, -143812, -230636, -283604, -302716, -287972, -239372, -156916, 324, 109564, + 293588, 511468, 763204, 1048796, 1368244, 1451724, -2293596, -2143428, -1959404, -1741524, -1489788, + -1204196, -884748, -531444, -144284, 276732, 731604, 1220332, 1742916, 2299356, 2889652, 2965452, -24508, + -152548, -244684, -300916, -321244, -305668, -254188, -166804, -28, 115676, 310772, 541772, 808676, + 1111484, 1450196, 1539116, -2433212, -2274020, -2078924, -1847924, -1581020, -1278212, -939500, -564884, + -154364, 292060, 774388, 1292620, 1846756, 2436796, 3062740, 3143980, -25884, -161284, -258732, -318228, + -339772, -323364, -269004, -176692, -380, 121788, 327956, 572076, 854148, 1174172, 1532148, 1626508, + -2572828, -2404612, -2198444, -1954324, -1672252, -1352228, -994252, -598324, -164444, 307388, 817172, + 1364908, 1950596, 2574236, 3235828, 3322508, -27260, -170020, -272780, -335540, -358300, -341060, -283820, + -186580, -732, 127900, 345140, 602380, 899620, 1236860, 1614100, 1713900, -2712444, -2535204, -2317964, + -2060724, -1763484, -1426244, -1049004, -631764, -174524, 322716, 859956, 1437196, 2054436, 2711676, + 3408916, 3501036, -28636, -178756, -286828, -352852, -376828, -358756, -298636, -196468, -1084, 134012, + 362324, 632684, 945092, 1299548, 1696052, 1801292, -2852060, -2665796, -2437484, -2167124, -1854716, + -1500260, -1103756, -665204, -184604, 338044, 902740, 1509484, 2158276, 2849116, 3582004, 3679564, -30012, + -187492, -300876, -370164, -395356, -376452, -313452, -206356, -1436, 140124, 379508, 662988, 990564, + 1362236, 1778004, 1888684, -2991676, -2796388, -2557004, -2273524, -1945948, -1574276, -1158508, -698644, + -194684, 353372, 945524, 1581772, 2262116, 2986556, 3755092, 3858092, -31388, -196228, -314924, -387476, + -413884, -394148, 
-328268, -216244, -1788, 146236, 396692, 693292, 1036036, 1424924, 1859956, 1976076, + -3131292, -2926980, -2676524, -2379924, -2037180, -1648292, -1213260, -732084, -204764, 368700, 988308, + 1654060, 2365956, 3123996, 3928180, 4036620, -32764, -204964, -328972, -404788, -432412, -411844, -343084, + -226132, -2140, 152348, 413876, 723596, 1081508, 1487612, 1941908, 2063468, -3270908, -3057572, -2796044, + -2486324, -2128412, -1722308, -1268012, -765524, -214844, 384028, 1031092, 1726348, 2469796, 3261436, + 4101268, 4215148, -34140, -213700, -343020, -422100, -450940, -429540, -357900, -236020, -2492, 158460, + 431060, 753900, 1126980, 1550300, 2023860, 2150860, -3410524, -3188164, -2915564, -2592724, -2219644, + -1796324, -1322764, -798964, -224924, 399356, 1073876, 1798636, 2573636, 3398876, 4274356, 4393676, + -35516, -222436, -357068, -439412, -469468, -447236, -372716, -245908, -2844, 164572, 448244, 784204, + 1172452, 1612988, 2105812, 2238252, -3550140, -3318756, -3035084, -2699124, -2310876, -1870340, -1377516, + -832404, -235004, 414684, 1116660, 1870924, 2677476, 3536316, 4447444, 4572204, -36892, -231172, -371116, + -456724, -487996, -464932, -387532, -255796, -3196, 170684, 465428, 814508, 1217924, 1675676, 2187764, + 2325644, -3689756, -3449348, -3154604, -2805524, -2402108, -1944356, -1432268, -865844, -245084, 430012, + 1159444, 1943212, 2781316, 3673756, 4620532, 4750732, -38268, -239908, -385164, -474036, -506524, -482628, + -402348, -265684, -3548, 176796, 482612, 844812, 1263396, 1738364, 2269716, 2413036, -3829372, -3579940, + -3274124, -2911924, -2493340, -2018372, -1487020, -899284, -255164, 445340, 1202228, 2015500, 2885156, + 3811196, 4793620, 4929260, -39644, -248644, -399212, -491348, -525052, -500324, -417164, -275572, -3900, + 182908, 499796, 875116, 1308868, 1801052, 2351668, 2500428, -3968988, -3710532, -3393644, -3018324, + -2584572, -2092388, -1541772, -932724, -265244, 460668, 1245012, 2087788, 2988996, 3948636, 4966708, + 5107788, -41020, -257380, -413260, -508660, -543580, -518020, -431980, -285460, -4252, 189020, 516980, + 905420, 1354340, 1863740, 2433620, 2587820, -4108604, -3841124, -3513164, -3124724, -2675804, -2166404, + -1596524, -966164, -275324, 475996, 1287796, 2160076, 3092836, 4086076, 5139796, 5286316, -42396, -266116, + -427308, -525972, -562108, -535716, -446796, -295348, -4604, 195132, 534164, 935724, 1399812, 1926428, + 2515572, 2675212, -4248220, -3971716, -3632684, -3231124, -2767036, -2240420, -1651276, -999604, -285404, + 491324, 1330580, 2232364, 3196676, 4223516, 5312884, 5464844, -43772, -274852, -441356, -543284, -580636, + -553412, -461612, -305236, -4956, 201244, 551348, 966028, 1445284, 1989116, 2597524, 2762604, -4387836, + -4102308, -3752204, -3337524, -2858268, -2314436, -1706028, -1033044, -295484, 506652, 1373364, 2304652, + 3300516, 4360956, 5485972, 5643372 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 
43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 
752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 
421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -24924, -16964, -7916, 2220, 13444, 25756, 39156, 53644, 69220, 85884, 103636, 122476, 142404, + 163420, 185524, 176460, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, + 170956, 205540, 243260, 284116, 296364, -33468, -8292, 20020, 51468, 86052, 123772, 164628, 208620, + 255748, 306012, 359412, 415948, 475620, 538428, 604372, 608940, -3868, -2948, 3156, 14444, 30916, 52572, + 79412, 111436, 153732, 191036, 238612, 291372, 349316, 412444, 480756, 506636, -42012, 380, 47956, 100716, + 158660, 221788, 290100, 363596, 442276, 526140, 615188, 709420, 808836, 913436, 1023220, 1041420, -5244, + -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, 337716, 411788, 493092, 581628, 677396, + 716908, -50556, 9052, 75892, 149964, 231268, 319804, 415572, 518572, 628804, 746268, 870964, 1002892, + 1142052, 1288444, 1442068, 1473900, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, 284100, + 350716, 436820, 532204, 636868, 750812, 874036, 927180, -59100, 17724, 103828, 199212, 303876, 417820, + 541044, 673548, 815332, 966396, 1126740, 1296364, 1475268, 1663452, 1860916, 1906380, -7996, -4580, 10164, + 36236, 73636, 122364, 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, + -67644, 26396, 131764, 248460, 376484, 515836, 666516, 828524, 1001860, 1186524, 1382516, 1589836, + 1808484, 2038460, 2279764, 2338860, -9372, -5124, 12500, 43500, 87876, 145628, 216756, 301260, 414468, + 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, -76188, 35068, 159700, 297708, 449092, 613852, + 791988, 983500, 1188388, 1406652, 1638292, 1883308, 2141700, 2413468, 2698612, 2771340, -10748, -5668, + 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, 1258364, 1463956, + 1557996, -84732, 43740, 187636, 346956, 521700, 711868, 917460, 1138476, 1374916, 1626780, 1894068, + 2176780, 2474916, 2788476, 3117460, 3203820, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, + 544836, 670076, 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -93276, 52412, 215572, 396204, + 594308, 809884, 1042932, 1293452, 1561444, 1846908, 2149844, 2470252, 2808132, 3163484, 3536308, 3636300, + -13500, -6756, 19508, 65292, 130596, 215420, 319764, 443628, 610020, 749916, 932340, 1134284, 1355748, + 1596732, 1857236, 1978540, -101820, 61084, 
243508, 445452, 666916, 907900, 1168404, 1448428, 1747972, + 2067036, 2405620, 2763724, 3141348, 3538492, 3955156, 4068780, -14876, -7300, 21844, 72556, 144836, + 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, -110364, + 69756, 271444, 494700, 739524, 1005916, 1293876, 1603404, 1934500, 2287164, 2661396, 3057196, 3474564, + 3913500, 4374004, 4501260, -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, + 1130548, 1375116, 1643300, 1935100, 2250516, 2399084, -118908, 78428, 299380, 543948, 812132, 1103932, + 1419348, 1758380, 2121028, 2507292, 2917172, 3350668, 3807780, 4288508, 4792852, 4933740, -17628, -8388, + 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, 1229652, 1495532, 1787076, 2104284, 2447156, + 2609356, -127452, 87100, 327316, 593196, 884740, 1201948, 1544820, 1913356, 2307556, 2727420, 3172948, + 3644140, 4140996, 4663516, 5211700, 5366220, -19004, -8932, 28852, 94348, 187556, 308476, 457108, 633452, + 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -135996, 95772, 355252, 642444, + 957348, 1299964, 1670292, 2068332, 2494084, 2947548, 3428724, 3937612, 4474212, 5038524, 5630548, 5798700, + -20380, -9476, 31188, 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, + 2442652, 2840436, 3029900, -144540, 104444, 383188, 691692, 1029956, 1397980, 1795764, 2223308, 2680612, + 3167676, 3684500, 4231084, 4807428, 5413532, 6049396, 6231180, -21756, -10020, 33524, 108876, 216036, + 355004, 525780, 728364, 1001124, 1228956, 1526964, 1856780, 2218404, 2611836, 3037076, 3240172, -153084, + 113116, 411124, 740940, 1102564, 1495996, 1921236, 2378284, 2867140, 3387804, 3940276, 4524556, 5140644, + 5788540, 6468244, 6663660, -23132, -10564, 35860, 116140, 230276, 378268, 560116, 775820, 1066308, + 1308796, 1626068, 1977196, 2362180, 2781020, 3233716, 3450444, -161628, 121788, 439060, 790188, 1175172, + 1594012, 2046708, 2533260, 3053668, 3607932, 4196052, 4818028, 5473860, 6163548, 6887092, 7096140, -24508, + -11108, 38196, 123404, 244516, 401532, 594452, 823276, 1131492, 1388636, 1725172, 2097612, 2505956, + 2950204, 3430356, 3660716, -170172, 130460, 466996, 839436, 1247780, 1692028, 2172180, 2688236, 3240196, + 3828060, 4451828, 5111500, 5807076, 6538556, 7305940, 7528620, -25884, -11652, 40532, 130668, 258756, + 424796, 628788, 870732, 1196676, 1468476, 1824276, 2218028, 2649732, 3119388, 3626996, 3870988, -178716, + 139132, 494932, 888684, 1320388, 1790044, 2297652, 2843212, 3426724, 4048188, 4707604, 5404972, 6140292, + 6913564, 7724788, 7961100, -27260, -12196, 42868, 137932, 272996, 448060, 663124, 918188, 1261860, + 1548316, 1923380, 2338444, 2793508, 3288572, 3823636, 4081260, -187260, 147804, 522868, 937932, 1392996, + 1888060, 2423124, 2998188, 3613252, 4268316, 4963380, 5698444, 6473508, 7288572, 8143636, 8393580, -28636, + -12740, 45204, 145196, 287236, 471324, 697460, 965644, 1327044, 1628156, 2022484, 2458860, 2937284, + 3457756, 4020276, 4291532, -195804, 156476, 550804, 987180, 1465604, 1986076, 2548596, 3153164, 3799780, + 4488444, 5219156, 5991916, 6806724, 7663580, 8562484, 8826060, -30012, -13284, 47540, 152460, 301476, + 494588, 731796, 1013100, 1392228, 1707996, 2121588, 2579276, 3081060, 3626940, 4216916, 4501804, -204348, + 165148, 578740, 1036428, 1538212, 2084092, 2674068, 3308140, 3986308, 4708572, 5474932, 6285388, 7139940, + 8038588, 8981332, 9258540, -31388, -13828, 49876, 159724, 315716, 517852, 766132, 1060556, 1457412, + 
1787836, 2220692, 2699692, 3224836, 3796124, 4413556, 4712076, -212892, 173820, 606676, 1085676, 1610820, + 2182108, 2799540, 3463116, 4172836, 4928700, 5730708, 6578860, 7473156, 8413596, 9400180, 9691020, -32764, + -14372, 52212, 166988, 329956, 541116, 800468, 1108012, 1522596, 1867676, 2319796, 2820108, 3368612, + 3965308, 4610196, 4922348, -221436, 182492, 634612, 1134924, 1683428, 2280124, 2925012, 3618092, 4359364, + 5148828, 5986484, 6872332, 7806372, 8788604, 9819028, 10123500, -34140, -14916, 54548, 174252, 344196, + 564380, 834804, 1155468, 1587780, 1947516, 2418900, 2940524, 3512388, 4134492, 4806836, 5132620, -229980, + 191164, 662548, 1184172, 1756036, 2378140, 3050484, 3773068, 4545892, 5368956, 6242260, 7165804, 8139588, + 9163612, 10237876, 10555980, -35516, -15460, 56884, 181516, 358436, 587644, 869140, 1202924, 1652964, + 2027356, 2518004, 3060940, 3656164, 4303676, 5003476, 5342892, -238524, 199836, 690484, 1233420, 1828644, + 2476156, 3175956, 3928044, 4732420, 5589084, 6498036, 7459276, 8472804, 9538620, 10656724, 10988460, + -36892, -16004, 59220, 188780, 372676, 610908, 903476, 1250380, 1718148, 2107196, 2617108, 3181356, + 3799940, 4472860, 5200116, 5553164, -247068, 208508, 718420, 1282668, 1901252, 2574172, 3301428, 4083020, + 4918948, 5809212, 6753812, 7752748, 8806020, 9913628, 11075572, 11420940, -38268, -16548, 61556, 196044, + 386916, 634172, 937812, 1297836, 1783332, 2187036, 2716212, 3301772, 3943716, 4642044, 5396756, 5763436, + -255612, 217180, 746356, 1331916, 1973860, 2672188, 3426900, 4237996, 5105476, 6029340, 7009588, 8046220, + 9139236, 10288636, 11494420, 11853420, -39644, -17092, 63892, 203308, 401156, 657436, 972148, 1345292, + 1848516, 2266876, 2815316, 3422188, 4087492, 4811228, 5593396, 5973708, -264156, 225852, 774292, 1381164, + 2046468, 2770204, 3552372, 4392972, 5292004, 6249468, 7265364, 8339692, 9472452, 10663644, 11913268, + 12285900, -41020, -17636, 66228, 210572, 415396, 680700, 1006484, 1392748, 1913700, 2346716, 2914420, + 3542604, 4231268, 4980412, 5790036, 6183980, -272700, 234524, 802228, 1430412, 2119076, 2868220, 3677844, + 4547948, 5478532, 6469596, 7521140, 8633164, 9805668, 11038652, 12332116, 12718380, -42396, -18180, 68564, + 217836, 429636, 703964, 1040820, 1440204, 1978884, 2426556, 3013524, 3663020, 4375044, 5149596, 5986676, + 6394252, -281244, 243196, 830164, 1479660, 2191684, 2966236, 3803316, 4702924, 5665060, 6689724, 7776916, + 8926636, 10138884, 11413660, 12750964, 13150860, -43772, -18724, 70900, 225100, 443876, 727228, 1075156, + 1487660, 2044068, 2506396, 3112628, 3783436, 4518820, 5318780, 6183316, 6604524, -289788, 251868, 858100, + 1528908, 2264292, 3064252, 3928788, 4857900, 5851588, 6909852, 8032692, 9220108, 10472100, 11788668, + 13169812, 13583340 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 
67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 
771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 
440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -1560, -2576, -3048, -2976, -2360, -1200, 504, 2736, 5544, 8880, 12760, 17184, 22152, 27664, 26040, + -29312, -26520, -23184, -19304, -14880, -9912, -4400, 1656, 8256, 15400, 23088, 31320, 40096, 49416, + 59280, 53816, 0, -5368, -9168, -11400, -12064, -11160, -8688, -4648, 2224, 8136, 16880, 27192, 39072, + 52520, 67536, 68760, -98432, -91256, -82512, -72200, -60320, -46872, -31856, -15272, 2880, 22600, 43888, + 66744, 91168, 117160, 144720, 142104, 0, -9176, -15760, -19752, -21152, -19960, -16176, -9800, 1712, + 10728, 24880, 41624, 60960, 82888, 107408, 111480, -167552, -155992, -141840, -125096, -105760, -83832, + -59312, -32200, -2496, 29800, 64688, 102168, 142240, 184904, 230160, 230392, 0, -12984, -22352, -28104, + -30240, -28760, -23664, -14952, 1200, 13320, 32880, 56056, 82848, 113256, 147280, 154200, -236672, + -220728, -201168, -177992, -151200, -120792, -86768, -49128, -7872, 37000, 85488, 137592, 193312, 252648, + 315600, 318680, 0, -16792, -28944, -36456, -39328, -37560, -31152, -20104, 688, 15912, 40880, 70488, + 104736, 143624, 187152, 196920, -305792, -285464, -260496, -230888, -196640, -157752, -114224, -66056, + -13248, 44200, 106288, 173016, 244384, 320392, 401040, 406968, 0, -20600, -35536, -44808, -48416, -46360, + -38640, -25256, 176, 18504, 48880, 84920, 126624, 173992, 227024, 239640, -374912, -350200, -319824, + -283784, -242080, -194712, -141680, -82984, -18624, 51400, 127088, 208440, 295456, 388136, 486480, 495256, + 0, -24408, -42128, -53160, -57504, -55160, -46128, -30408, -336, 21096, 56880, 99352, 148512, 204360, + 266896, 282360, -444032, -414936, -379152, -336680, -287520, -231672, -169136, -99912, -24000, 58600, + 147888, 243864, 346528, 455880, 571920, 583544, 0, -28216, -48720, -61512, -66592, -63960, -53616, -35560, + -848, 23688, 64880, 113784, 170400, 234728, 306768, 325080, -513152, -479672, -438480, -389576, -332960, + -268632, -196592, -116840, -29376, 65800, 168688, 279288, 397600, 523624, 657360, 671832, 0, -32024, + -55312, -69864, -75680, -72760, -61104, -40712, -1360, 26280, 72880, 128216, 192288, 265096, 346640, + 367800, -582272, -544408, -497808, -442472, -378400, -305592, -224048, -133768, -34752, 73000, 189488, + 314712, 448672, 591368, 742800, 760120, 0, -35832, -61904, -78216, -84768, -81560, -68592, -45864, -1872, + 28872, 80880, 142648, 214176, 295464, 386512, 410520, -651392, -609144, -557136, -495368, -423840, + -342552, -251504, -150696, -40128, 80200, 210288, 350136, 499744, 659112, 828240, 848408, 0, -39640, + -68496, -86568, -93856, -90360, -76080, -51016, -2384, 31464, 88880, 157080, 236064, 325832, 426384, + 453240, -720512, -673880, -616464, -548264, -469280, -379512, -278960, -167624, -45504, 87400, 231088, + 385560, 550816, 726856, 913680, 936696, 0, -43448, -75088, -94920, -102944, -99160, -83568, -56168, -2896, + 34056, 96880, 171512, 257952, 356200, 466256, 495960, -789632, -738616, 
-675792, -601160, -514720, + -416472, -306416, -184552, -50880, 94600, 251888, 420984, 601888, 794600, 999120, 1024984, 0, -47256, + -81680, -103272, -112032, -107960, -91056, -61320, -3408, 36648, 104880, 185944, 279840, 386568, 506128, + 538680, -858752, -803352, -735120, -654056, -560160, -453432, -333872, -201480, -56256, 101800, 272688, + 456408, 652960, 862344, 1084560, 1113272, 0, -51064, -88272, -111624, -121120, -116760, -98544, -66472, + -3920, 39240, 112880, 200376, 301728, 416936, 546000, 581400, -927872, -868088, -794448, -706952, -605600, + -490392, -361328, -218408, -61632, 109000, 293488, 491832, 704032, 930088, 1170000, 1201560, 0, -54872, + -94864, -119976, -130208, -125560, -106032, -71624, -4432, 41832, 120880, 214808, 323616, 447304, 585872, + 624120, -996992, -932824, -853776, -759848, -651040, -527352, -388784, -235336, -67008, 116200, 314288, + 527256, 755104, 997832, 1255440, 1289848, 0, -58680, -101456, -128328, -139296, -134360, -113520, -76776, + -4944, 44424, 128880, 229240, 345504, 477672, 625744, 666840, -1066112, -997560, -913104, -812744, + -696480, -564312, -416240, -252264, -72384, 123400, 335088, 562680, 806176, 1065576, 1340880, 1378136, 0, + -62488, -108048, -136680, -148384, -143160, -121008, -81928, -5456, 47016, 136880, 243672, 367392, 508040, + 665616, 709560, -1135232, -1062296, -972432, -865640, -741920, -601272, -443696, -269192, -77760, 130600, + 355888, 598104, 857248, 1133320, 1426320, 1466424, 0, -66296, -114640, -145032, -157472, -151960, -128496, + -87080, -5968, 49608, 144880, 258104, 389280, 538408, 705488, 752280, -1204352, -1127032, -1031760, + -918536, -787360, -638232, -471152, -286120, -83136, 137800, 376688, 633528, 908320, 1201064, 1511760, + 1554712, 0, -70104, -121232, -153384, -166560, -160760, -135984, -92232, -6480, 52200, 152880, 272536, + 411168, 568776, 745360, 795000, -1273472, -1191768, -1091088, -971432, -832800, -675192, -498608, -303048, + -88512, 145000, 397488, 668952, 959392, 1268808, 1597200, 1643000, 0, -73912, -127824, -161736, -175648, + -169560, -143472, -97384, -6992, 54792, 160880, 286968, 433056, 599144, 785232, 837720, -1342592, + -1256504, -1150416, -1024328, -878240, -712152, -526064, -319976, -93888, 152200, 418288, 704376, 1010464, + 1336552, 1682640, 1731288, 0, -77720, -134416, -170088, -184736, -178360, -150960, -102536, -7504, 57384, + 168880, 301400, 454944, 629512, 825104, 880440, -1411712, -1321240, -1209744, -1077224, -923680, -749112, + -553520, -336904, -99264, 159400, 439088, 739800, 1061536, 1404296, 1768080, 1819576, 0, -81528, -141008, + -178440, -193824, -187160, -158448, -107688, -8016, 59976, 176880, 315832, 476832, 659880, 864976, 923160, + -1480832, -1385976, -1269072, -1130120, -969120, -786072, -580976, -353832, -104640, 166600, 459888, + 775224, 1112608, 1472040, 1853520, 1907864, 0, -85336, -147600, -186792, -202912, -195960, -165936, + -112840, -8528, 62568, 184880, 330264, 498720, 690248, 904848, 965880, -1549952, -1450712, -1328400, + -1183016, -1014560, -823032, -608432, -370760, -110016, 173800, 480688, 810648, 1163680, 1539784, 1938960, + 1996152, 0, -89144, -154192, -195144, -212000, -204760, -173424, -117992, -9040, 65160, 192880, 344696, + 520608, 720616, 944720, 1008600, -1619072, -1515448, -1387728, -1235912, -1060000, -859992, -635888, + -387688, -115392, 181000, 501488, 846072, 1214752, 1607528, 2024400, 2084440, 0, -92952, -160784, -203496, + -221088, -213560, -180912, -123144, -9552, 67752, 200880, 359128, 542496, 750984, 984592, 1051320, + -1688192, -1580184, 
-1447056, -1288808, -1105440, -896952, -663344, -404616, -120768, 188200, 522288, + 881496, 1265824, 1675272, 2109840, 2172728, 0, -96760, -167376, -211848, -230176, -222360, -188400, + -128296, -10064, 70344, 208880, 373560, 564384, 781352, 1024464, 1094040, -1757312, -1644920, -1506384, + -1341704, -1150880, -933912, -690800, -421544, -126144, 195400, 543088, 916920, 1316896, 1743016, 2195280, + 2261016, 0, -100568, -173968, -220200, -239264, -231160, -195888, -133448, -10576, 72936, 216880, 387992, + 586272, 811720, 1064336, 1136760, -1826432, -1709656, -1565712, -1394600, -1196320, -970872, -718256, + -438472, -131520, 202600, 563888, 952344, 1367968, 1810760, 2280720, 2349304, 0, -104376, -180560, + -228552, -248352, -239960, -203376, -138600, -11088, 75528, 224880, 402424, 608160, 842088, 1104208, + 1179480, -1895552, -1774392, -1625040, -1447496, -1241760, -1007832, -745712, -455400, -136896, 209800, + 584688, 987768, 1419040, 1878504, 2366160, 2437592, 0, -108184, -187152, -236904, -257440, -248760, + -210864, -143752, -11600, 78120, 232880, 416856, 630048, 872456, 1144080, 1222200, -1964672, -1839128, + -1684368, -1500392, -1287200, -1044792, -773168, -472328, -142272, 217000, 605488, 1023192, 1470112, + 1946248, 2451600, 2525880, 0, -111992, -193744, -245256, -266528, -257560, -218352, -148904, -12112, + 80712, 240880, 431288, 651936, 902824, 1183952, 1264920, -2033792, -1903864, -1743696, -1553288, -1332640, + -1081752, -800624, -489256, -147648, 224200, 626288, 1058616, 1521184, 2013992, 2537040, 2614168, 0, + -115800, -200336, -253608, -275616, -266360, -225840, -154056, -12624, 83304, 248880, 445720, 673824, + 933192, 1223824, 1307640, -2102912, -1968600, -1803024, -1606184, -1378080, -1118712, -828080, -506184, + -153024, 231400, 647088, 1094040, 1572256, 2081736, 2622480, 2702456, 0, -119608, -206928, -261960, + -284704, -275160, -233328, -159208, -13136, 85896, 256880, 460152, 695712, 963560, 1263696, 1350360, + -2172032, -2033336, -1862352, -1659080, -1423520, -1155672, -855536, -523112, -158400, 238600, 667888, + 1129464, 1623328, 2149480, 2707920, 2790744 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 
205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 
902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 
128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 2664, 5872, 9624, 13920, 18760, 24144, 30072, 36528, 43560, 51120, 59224, 67872, 77064, 86800, 89400, + 38272, 45288, 52848, 60952, 69600, 78792, 88528, 98808, 109632, 121000, 132912, 145368, 158368, 171912, + 186000, 184760, 0, 7048, 15664, 25848, 37600, 50920, 65808, 82264, 101552, 119880, 141040, 163768, 188064, + 213928, 241360, 255000, 100224, 119816, 140976, 163704, 188000, 213864, 241296, 270296, 300864, 333000, + 366704, 401976, 438816, 477224, 517200, 527000, 0, 11432, 25456, 42072, 61280, 83080, 107472, 134456, + 166576, 196200, 230960, 268312, 308256, 350792, 395920, 420600, 162176, 194344, 229104, 266456, 306400, + 348936, 394064, 441784, 492096, 545000, 600496, 658584, 719264, 782536, 848400, 869240, 0, 15816, 35248, + 58296, 84960, 115240, 149136, 186648, 231600, 272520, 320880, 372856, 428448, 487656, 550480, 586200, + 224128, 268872, 317232, 369208, 424800, 484008, 546832, 613272, 683328, 757000, 834288, 915192, 999712, + 1087848, 1179600, 1211480, 0, 20200, 45040, 74520, 108640, 147400, 190800, 238840, 296624, 348840, 410800, + 477400, 548640, 624520, 705040, 751800, 286080, 343400, 405360, 471960, 543200, 619080, 699600, 784760, + 874560, 969000, 1068080, 1171800, 1280160, 1393160, 1510800, 1553720, 0, 24584, 54832, 90744, 132320, + 179560, 232464, 291032, 361648, 425160, 500720, 581944, 668832, 761384, 859600, 917400, 348032, 417928, + 493488, 574712, 661600, 754152, 852368, 956248, 1065792, 1181000, 1301872, 1428408, 1560608, 1698472, + 1842000, 1895960, 0, 28968, 64624, 106968, 156000, 211720, 274128, 343224, 426672, 501480, 590640, 686488, + 789024, 898248, 1014160, 1083000, 409984, 492456, 581616, 677464, 780000, 889224, 1005136, 1127736, + 1257024, 1393000, 1535664, 1685016, 1841056, 2003784, 2173200, 2238200, 0, 33352, 74416, 123192, 179680, + 243880, 315792, 395416, 491696, 577800, 680560, 791032, 909216, 1035112, 1168720, 1248600, 471936, 566984, + 669744, 780216, 898400, 1024296, 1157904, 1299224, 1448256, 1605000, 1769456, 1941624, 2121504, 2309096, + 2504400, 2580440, 0, 37736, 84208, 139416, 203360, 276040, 357456, 447608, 556720, 654120, 770480, 895576, + 1029408, 1171976, 1323280, 1414200, 533888, 641512, 757872, 882968, 1016800, 1159368, 1310672, 1470712, + 1639488, 1817000, 2003248, 2198232, 2401952, 2614408, 2835600, 2922680, 0, 42120, 94000, 155640, 227040, + 308200, 399120, 499800, 621744, 730440, 860400, 1000120, 1149600, 1308840, 1477840, 1579800, 595840, + 716040, 846000, 985720, 1135200, 1294440, 1463440, 1642200, 1830720, 2029000, 2237040, 2454840, 2682400, + 2919720, 3166800, 3264920, 0, 46504, 103792, 171864, 250720, 340360, 440784, 551992, 686768, 806760, + 950320, 1104664, 1269792, 1445704, 1632400, 1745400, 657792, 790568, 934128, 1088472, 1253600, 1429512, + 1616208, 1813688, 2021952, 2241000, 2470832, 2711448, 2962848, 3225032, 3498000, 3607160, 0, 50888, + 113584, 188088, 274400, 372520, 482448, 604184, 751792, 883080, 1040240, 1209208, 1389984, 1582568, + 1786960, 1911000, 719744, 865096, 1022256, 1191224, 1372000, 1564584, 1768976, 1985176, 2213184, 2453000, + 2704624, 2968056, 3243296, 3530344, 3829200, 3949400, 0, 55272, 123376, 204312, 298080, 404680, 524112, + 656376, 816816, 959400, 1130160, 1313752, 1510176, 1719432, 1941520, 2076600, 781696, 939624, 1110384, + 1293976, 1490400, 1699656, 1921744, 2156664, 2404416, 2665000, 2938416, 
3224664, 3523744, 3835656, + 4160400, 4291640, 0, 59656, 133168, 220536, 321760, 436840, 565776, 708568, 881840, 1035720, 1220080, + 1418296, 1630368, 1856296, 2096080, 2242200, 843648, 1014152, 1198512, 1396728, 1608800, 1834728, 2074512, + 2328152, 2595648, 2877000, 3172208, 3481272, 3804192, 4140968, 4491600, 4633880, 0, 64040, 142960, 236760, + 345440, 469000, 607440, 760760, 946864, 1112040, 1310000, 1522840, 1750560, 1993160, 2250640, 2407800, + 905600, 1088680, 1286640, 1499480, 1727200, 1969800, 2227280, 2499640, 2786880, 3089000, 3406000, 3737880, + 4084640, 4446280, 4822800, 4976120, 0, 68424, 152752, 252984, 369120, 501160, 649104, 812952, 1011888, + 1188360, 1399920, 1627384, 1870752, 2130024, 2405200, 2573400, 967552, 1163208, 1374768, 1602232, 1845600, + 2104872, 2380048, 2671128, 2978112, 3301000, 3639792, 3994488, 4365088, 4751592, 5154000, 5318360, 0, + 72808, 162544, 269208, 392800, 533320, 690768, 865144, 1076912, 1264680, 1489840, 1731928, 1990944, + 2266888, 2559760, 2739000, 1029504, 1237736, 1462896, 1704984, 1964000, 2239944, 2532816, 2842616, + 3169344, 3513000, 3873584, 4251096, 4645536, 5056904, 5485200, 5660600, 0, 77192, 172336, 285432, 416480, + 565480, 732432, 917336, 1141936, 1341000, 1579760, 1836472, 2111136, 2403752, 2714320, 2904600, 1091456, + 1312264, 1551024, 1807736, 2082400, 2375016, 2685584, 3014104, 3360576, 3725000, 4107376, 4507704, + 4925984, 5362216, 5816400, 6002840, 0, 81576, 182128, 301656, 440160, 597640, 774096, 969528, 1206960, + 1417320, 1669680, 1941016, 2231328, 2540616, 2868880, 3070200, 1153408, 1386792, 1639152, 1910488, + 2200800, 2510088, 2838352, 3185592, 3551808, 3937000, 4341168, 4764312, 5206432, 5667528, 6147600, + 6345080, 0, 85960, 191920, 317880, 463840, 629800, 815760, 1021720, 1271984, 1493640, 1759600, 2045560, + 2351520, 2677480, 3023440, 3235800, 1215360, 1461320, 1727280, 2013240, 2319200, 2645160, 2991120, + 3357080, 3743040, 4149000, 4574960, 5020920, 5486880, 5972840, 6478800, 6687320, 0, 90344, 201712, 334104, + 487520, 661960, 857424, 1073912, 1337008, 1569960, 1849520, 2150104, 2471712, 2814344, 3178000, 3401400, + 1277312, 1535848, 1815408, 2115992, 2437600, 2780232, 3143888, 3528568, 3934272, 4361000, 4808752, + 5277528, 5767328, 6278152, 6810000, 7029560, 0, 94728, 211504, 350328, 511200, 694120, 899088, 1126104, + 1402032, 1646280, 1939440, 2254648, 2591904, 2951208, 3332560, 3567000, 1339264, 1610376, 1903536, + 2218744, 2556000, 2915304, 3296656, 3700056, 4125504, 4573000, 5042544, 5534136, 6047776, 6583464, + 7141200, 7371800, 0, 99112, 221296, 366552, 534880, 726280, 940752, 1178296, 1467056, 1722600, 2029360, + 2359192, 2712096, 3088072, 3487120, 3732600, 1401216, 1684904, 1991664, 2321496, 2674400, 3050376, + 3449424, 3871544, 4316736, 4785000, 5276336, 5790744, 6328224, 6888776, 7472400, 7714040, 0, 103496, + 231088, 382776, 558560, 758440, 982416, 1230488, 1532080, 1798920, 2119280, 2463736, 2832288, 3224936, + 3641680, 3898200, 1463168, 1759432, 2079792, 2424248, 2792800, 3185448, 3602192, 4043032, 4507968, + 4997000, 5510128, 6047352, 6608672, 7194088, 7803600, 8056280, 0, 107880, 240880, 399000, 582240, 790600, + 1024080, 1282680, 1597104, 1875240, 2209200, 2568280, 2952480, 3361800, 3796240, 4063800, 1525120, + 1833960, 2167920, 2527000, 2911200, 3320520, 3754960, 4214520, 4699200, 5209000, 5743920, 6303960, + 6889120, 7499400, 8134800, 8398520, 0, 112264, 250672, 415224, 605920, 822760, 1065744, 1334872, 1662128, + 1951560, 2299120, 2672824, 3072672, 3498664, 3950800, 4229400, 1587072, 
1908488, 2256048, 2629752, + 3029600, 3455592, 3907728, 4386008, 4890432, 5421000, 5977712, 6560568, 7169568, 7804712, 8466000, + 8740760, 0, 116648, 260464, 431448, 629600, 854920, 1107408, 1387064, 1727152, 2027880, 2389040, 2777368, + 3192864, 3635528, 4105360, 4395000, 1649024, 1983016, 2344176, 2732504, 3148000, 3590664, 4060496, + 4557496, 5081664, 5633000, 6211504, 6817176, 7450016, 8110024, 8797200, 9083000, 0, 121032, 270256, + 447672, 653280, 887080, 1149072, 1439256, 1792176, 2104200, 2478960, 2881912, 3313056, 3772392, 4259920, + 4560600, 1710976, 2057544, 2432304, 2835256, 3266400, 3725736, 4213264, 4728984, 5272896, 5845000, + 6445296, 7073784, 7730464, 8415336, 9128400, 9425240, 0, 125416, 280048, 463896, 676960, 919240, 1190736, + 1491448, 1857200, 2180520, 2568880, 2986456, 3433248, 3909256, 4414480, 4726200, 1772928, 2132072, + 2520432, 2938008, 3384800, 3860808, 4366032, 4900472, 5464128, 6057000, 6679088, 7330392, 8010912, + 8720648, 9459600, 9767480, 0, 129800, 289840, 480120, 700640, 951400, 1232400, 1543640, 1922224, 2256840, + 2658800, 3091000, 3553440, 4046120, 4569040, 4891800, 1834880, 2206600, 2608560, 3040760, 3503200, + 3995880, 4518800, 5071960, 5655360, 6269000, 6912880, 7587000, 8291360, 9025960, 9790800, 10109720, 0, + 134184, 299632, 496344, 724320, 983560, 1274064, 1595832, 1987248, 2333160, 2748720, 3195544, 3673632, + 4182984, 4723600, 5057400, 1896832, 2281128, 2696688, 3143512, 3621600, 4130952, 4671568, 5243448, + 5846592, 6481000, 7146672, 7843608, 8571808, 9331272, 10122000, 10451960, 0, 138568, 309424, 512568, + 748000, 1015720, 1315728, 1648024, 2052272, 2409480, 2838640, 3300088, 3793824, 4319848, 4878160, 5223000, + 1958784, 2355656, 2784816, 3246264, 3740000, 4266024, 4824336, 5414936, 6037824, 6693000, 7380464, + 8100216, 8852256, 9636584, 10453200, 10794200 + ] + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 55b21283025c2..1c61518ddcdd2 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1362,6 +1362,7 @@ "less.jsonc", "log.jsonc", "matmul.jsonc", + "matmulnbits.jsonc", "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index bd58dded026a6..25e7567a2e9fc 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -8,13 +8,14 @@ namespace contrib { namespace js { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); -class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -25,14 +26,15 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + SkipLayerNormalization)>}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..888db0fd161f2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/js/quantization/matmul_nbits.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", JsepSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..cca2c4757765b --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; + +class MatMulNBits final : public JsKernel { + public: + MatMulNBits(const OpKernelInfo& info) : JsKernel(info), + K_{narrow(info.GetAttr("K"))}, + N_{narrow(info.GetAttr("N"))}, + accuracy_level_{info.GetAttrOrDefault("accuracy_level", 0)}, + nbits_{narrow(info.GetAttr("bits"))}, + block_size_{narrow(info.GetAttr("block_size"))} { + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)), + "Block size must be a power of 2 and greater than or equal to 16."); + JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({ + "k" : $1, + "n" : $2, + "accuracyLevel" : $3, + "bits" : $4, + "blockSize" : $5 + }), + static_cast(K_), + static_cast(N_), + static_cast(accuracy_level_), + static_cast(nbits_), + static_cast(block_size_)); + } + + private: + const size_t K_; + const size_t N_; + const int64_t accuracy_level_; + const size_t nbits_; + const size_t block_size_; +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime From b55260d076da309f3a4634eb5248a0eb541e8ca0 Mon Sep 17 00:00:00 2001 From: pengwa Date: Mon, 19 Feb 2024 10:21:19 +0800 Subject: [PATCH 011/237] Minor fix for cmake (#19552) ### Minor fix for cmake When build on Linux, get a warning saying " CMake Warning at CMakeLists.txt:1603 (message): MPI and NCCL disabled on Win build. " This message is not correct. So have such a fix to avoid any misunderstanding from users. ![image](https://github.com/microsoft/onnxruntime/assets/10530022/848c2d77-a538-4e31-8e0d-4b539233e515) ### Motivation and Context --- cmake/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ff1c7a84f077f..c9be4aa65d0cc 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1600,7 +1600,7 @@ if (UNIX AND onnxruntime_USE_NCCL) else() set(onnxruntime_USE_NCCL OFF) set(onnxruntime_USE_MPI OFF) -message( WARNING "MPI and NCCL disabled on Win build." ) + message( WARNING "MPI and NCCL are disabled because build is on Windows or USE_NCCL is set to OFF." ) endif() if (onnxruntime_USE_MPI) From f3e3b531fe4c0d33d70928b101fb5d445e4174a8 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Tue, 20 Feb 2024 10:31:39 +0800 Subject: [PATCH 012/237] Update build directory clean up stage for python package pipeline (#19553) Fix to make clean up stage take effect. If the `SourceFolder ` is empty, the task deletes files from the root folder of the repository as though [$(Build.SourcesDirectory)](https://learn.microsoft.com/en-us/azure/devops/pipelines/build/variables) was specified. 
--- .../component-governance-component-detection-steps.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml index c2ef565a6e9ee..f1418e75bffa2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml @@ -5,10 +5,12 @@ parameters: default: 'succeeded' # could be 'ci_only', 'always', 'succeeded' steps: -- ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: +- ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: - task: DeleteFiles@1 inputs: - contents: $(Build.BinariesDirectory)/* + SourceFolder: '$(Build.BinariesDirectory)' + contents: | + **/* displayName: 'Clean up build directory' - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 From e832562d70685ffeaab7e3bfa20cd5e9aec916a3 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Tue, 20 Feb 2024 09:06:03 +0100 Subject: [PATCH 013/237] Fix invalid usage of designated initializers. (#19497) ### Description I've replaced all occurrences of C++ designated initializers in the CUDA NHWC tests with member initialization. ### Motivation and Context Designated initializers were introduced in C++20. GCC accepts them even in C++17, the standard used to compile onnxruntime, but MSVC conforms to the standard and only accepts the feature from C++20 onward, which leads to compile failures on Windows without this change. --- .../test/providers/cuda/nhwc/conv_test.cc | 23 +++++++--- .../cuda/nhwc/conv_transpose_test.cc | 40 +++++++++------- .../providers/cuda/nhwc/nhwc_cuda_helper.h | 6 ++- .../test/providers/cuda/nhwc/norm_test.cc | 7 ++- .../test/providers/cuda/nhwc/pool_test.cc | 46 ++++++++++--------- 5 files changed, 72 insertions(+), 50 deletions(-) diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_test.cc index 13d4546d669e3..b6a760f7041ad 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_test.cc @@ -9,8 +9,8 @@ namespace test { template struct ConvOp { - const std::vector input_dims; - const std::vector kernel_shape; + std::vector input_dims; + std::vector kernel_shape; int64_t channels; int64_t group = 1; bool bias = false; @@ -52,20 +52,31 @@ struct ConvOp { }; TYPED_TEST(CudaNhwcTypedTest, ConvNhwcBias) { - auto op = ConvOp{.input_dims = {1, 16, 64, 64}, .kernel_shape = {3, 3}, .channels = 16, .bias = true}; + auto op = ConvOp{}; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.bias = true; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvNhwcGroupNoBias) { - auto op = ConvOp{.input_dims = {1, 16, 64, 64}, .kernel_shape = {3, 3}, .channels = 16, .group = 4}; + auto op = ConvOp{}; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.group = 4; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvNhwcPadding) { - auto op = - ConvOp{.input_dims = {2, 4, 64, 64}, .kernel_shape = {3, 3}, .channels = 4, .padding = {4, 4, 4, 4}}; + auto op = ConvOp{}; + op.input_dims = {2, 4, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 4; + op.padding = {4, 4, 4, 
4}; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc index 6514feadf0ff7..786b2cb4cedc4 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc @@ -9,8 +9,8 @@ namespace test { template struct ConvTransposeOp { - const std::vector input_dims; - const std::vector kernel_shape; + std::vector input_dims; + std::vector kernel_shape; int64_t channels; int64_t group = 1; bool bias = false; @@ -60,15 +60,21 @@ struct ConvTransposeOp { }; TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcGroupNoBias) { - auto op = - ConvTransposeOp{.input_dims = {8, 8, 32, 32}, .kernel_shape = {3, 3}, .channels = 16, .group = 4}; + auto op = ConvTransposeOp{}; + op.input_dims = {8, 8, 32, 32}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.group = 4; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { - auto op = - ConvTransposeOp{.input_dims = {1, 8, 80, 80}, .kernel_shape = {5, 5}, .channels = 16, .bias = true}; + auto op = ConvTransposeOp{}; + op.input_dims = {1, 8, 80, 80}; + op.kernel_shape = {5, 5}; + op.channels = 16; + op.bias = true; if (HasCudaEnvironment(800)) { MAKE_PROVIDERS_EPS(1e-2) @@ -78,21 +84,23 @@ TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcPad) { - auto op = ConvTransposeOp{.input_dims = {1, 16, 8, 8}, - .kernel_shape = {3, 3}, - .channels = 32, - .padding = {2, 2, 2, 2}, - .output_padding = {}}; + auto op = ConvTransposeOp{}; + op.input_dims = {1, 16, 8, 8}; + op.kernel_shape = {3, 3}; + op.channels = 32; + op.padding = {2, 2, 2, 2}; + op.output_padding = {}; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcOutPad) { - auto op = ConvTransposeOp{.input_dims = {1, 32, 8, 8}, - .kernel_shape = {3, 3}, - .channels = 32, - .strides = {2, 2}, - .output_padding = {1, 1, 1, 1}}; + auto op = ConvTransposeOp{}; + op.input_dims = {1, 32, 8, 8}; + op.kernel_shape = {3, 3}; + op.channels = 32; + op.strides = {2, 2}; + op.output_padding = {1, 1, 1, 1}; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } diff --git a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h index 2c942bb790096..82b6a286409cd 100644 --- a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h +++ b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h @@ -16,11 +16,13 @@ #define MAKE_PROVIDERS_EPS(eps) \ std::vector> execution_providers; \ - OrtCUDAProviderOptionsV2 nhwc = {.prefer_nhwc = true}; \ + OrtCUDAProviderOptionsV2 nhwc{}; \ + nhwc.prefer_nhwc = true; \ execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc)); \ \ double error_tolerance = eps; \ - OrtCUDAProviderOptionsV2 nchw = {.prefer_nhwc = false}; \ + OrtCUDAProviderOptionsV2 nchw{}; \ + nchw.prefer_nhwc = false; \ auto source_ep = CudaExecutionProviderWithOptions(&nchw); \ auto test = op.get_test(); \ test->CompareEPs(std::move(source_ep), execution_providers, error_tolerance); diff --git a/onnxruntime/test/providers/cuda/nhwc/norm_test.cc b/onnxruntime/test/providers/cuda/nhwc/norm_test.cc index 52da8ba557c2d..40f69e3bd5b4f 100644 --- a/onnxruntime/test/providers/cuda/nhwc/norm_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/norm_test.cc @@ -9,7 +9,7 @@ namespace test { template struct BatchNormOp { - const std::vector input_dims; + std::vector input_dims; 
std::unique_ptr get_test() { // create rand inputs @@ -40,9 +40,8 @@ struct BatchNormOp { }; TYPED_TEST(CudaNhwcTypedTest, BatchNormNhwc) { - auto op = BatchNormOp{ - .input_dims = {4, 16, 64, 64}, - }; + auto op = BatchNormOp{}; + op.input_dims = {4, 16, 64, 64}; MAKE_PROVIDERS() } diff --git a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc index e0d59901da80c..426170b9588f1 100644 --- a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc @@ -9,9 +9,9 @@ namespace test { template struct PoolOp { - const std::string pooling_type; - const std::vector input_dims; - const std::vector kernel_shape; + std::string pooling_type; + std::vector input_dims; + std::vector kernel_shape; int64_t channels; int64_t group = 1; std::vector strides = {1, 1}; @@ -41,22 +41,21 @@ struct PoolOp { }; TYPED_TEST(CudaNhwcTypedTest, AveragePoolNhwc) { - auto op = PoolOp{ - .pooling_type = "AveragePool", - .input_dims = {1, 16, 64, 64}, - .kernel_shape = {3, 3}, - .channels = 16, - }; + auto op = PoolOp{}; + op.pooling_type = "AveragePool"; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + MAKE_PROVIDERS() } TYPED_TEST(CudaNhwcTypedTest, MaxPoolNhwc) { - auto op = PoolOp{ - .pooling_type = "MaxPool", - .input_dims = {1, 16, 64, 64}, - .kernel_shape = {3, 3}, - .channels = 16, - }; + auto op = PoolOp{}; + op.pooling_type = "MaxPool"; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; MAKE_PROVIDERS() } @@ -72,21 +71,24 @@ TYPED_TEST(CudaNhwcTypedTest, GlobalMaxPoolNhwc) { test->AddOutput("Y", output_dims, output_data); std::vector> execution_providers; - OrtCUDAProviderOptionsV2 nhwc = {.prefer_nhwc = true}; + OrtCUDAProviderOptionsV2 nhwc{}; + nhwc.prefer_nhwc = true; execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc)); double error_tolerance = 1e-3; - OrtCUDAProviderOptionsV2 nchw = {.prefer_nhwc = false}; + OrtCUDAProviderOptionsV2 nchw{}; + nchw.prefer_nhwc = false; auto source_ep = CudaExecutionProviderWithOptions(&nchw); test->CompareEPs(std::move(source_ep), execution_providers, error_tolerance); } TYPED_TEST(CudaNhwcTypedTest, AveragePoolNhwcPad) { - auto op = PoolOp{.pooling_type = "AveragePool", - .input_dims = {1, 16, 64, 64}, - .kernel_shape = {3, 3}, - .channels = 16, - .padding = {2, 2, 2, 2}}; + auto op = PoolOp{}; + op.pooling_type = "AveragePool"; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.padding = {2, 2, 2, 2}; MAKE_PROVIDERS() } From 7efb0dbe12cf8736d97dcc3b8f41eb96c5c34719 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 20 Feb 2024 17:22:44 +0100 Subject: [PATCH 014/237] add option DefaultTensorType to specify the default tensor type to quantize (#19455) ### Description The current quantization tool relies on shape inference to provide the type of every intermediate tensor, then the tool knows which type it must dequantize into (float32, float16). However, this information is not available if shape inference fails. That happens every time the model include an operator from a custom domain such as com.microsoft. This PR introduces an extra option `DefaultTensorType` as a fall back when the quantizer cannot find the type it needs. ### Motivation and Context This fixes issue #19409. 
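For illustration, a minimal sketch of how a caller can opt into the new fall-back (this mirrors the `quantize_helper.py` change below; the model paths are placeholders, not part of the patch):

```python
import onnx
from onnxruntime.quantization import QuantType, quantize_dynamic

# Dynamic quantization of a model containing com.microsoft nodes, where shape
# inference cannot type every intermediate tensor. DefaultTensorType tells the
# quantizer which element type to assume when that lookup fails.
quantize_dynamic(
    "model_fp32.onnx",   # placeholder input path
    "model_int8.onnx",   # placeholder output path
    weight_type=QuantType.QInt8,
    extra_options={
        "MatMulConstBOnly": True,
        "DefaultTensorType": onnx.TensorProto.FLOAT,  # fall-back element type
    },
)
```

Without `DefaultTensorType` the quantizer still raises a `RuntimeError`, but with a more descriptive message that points at this option.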
--- .../tools/quantization/onnx_quantizer.py | 25 ++++- .../tools/transformers/quantize_helper.py | 3 +- .../test_quantizer_shape_inference.py | 92 +++++++++++++++++++ 3 files changed, 115 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/test/python/quantization/test_quantizer_shape_inference.py diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index ecfbaa569ca0a..9450426f12444 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -385,7 +385,7 @@ def add_new_nodes(self, nodes): def quantize_model(self): if self.has_QDQ_nodes(): logging.warning( - "Please check if the model is already quantized." + "Please check if the model is already quantized. " "Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly." ) @@ -442,6 +442,23 @@ def is_valid_quantize_weight(self, weight_name): return False return self.parent.is_valid_quantize_weight(weight_name) + def _get_default_tensor_type(self, tensor_name): + if "DefaultTensorType" in self.extra_options: + logging.info( + "get_tensor_type returns DefaultTensorType for tensor name %r, use %d", + tensor_name, + self.extra_options["DefaultTensorType"], + ) + return self.extra_options["DefaultTensorType"] + raise RuntimeError( + f"Unable to find data type for weight_name={tensor_name!r}. " + f"shape_inference failed to return a type probably this node is " + f"from a different domain or using an input produced by such an operator. " + f"This may happen if you quantize a model already quantized. " + f"You may use extra_options `DefaultTensorType` to indicate " + f"the default weight type, usually `onnx.TensorProto.FLOAT`." 
+ ) + def get_tensor_type(self, tensor_name, mandatory=False): weight = find_by_name(tensor_name, self.model.initializer()) if weight is not None: @@ -450,11 +467,11 @@ def get_tensor_type(self, tensor_name, mandatory=False): vi = self.value_infos[tensor_name] if vi.type.HasField("tensor_type"): if mandatory and vi.type.tensor_type.elem_type == 0: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return vi.type.tensor_type.elem_type if (not self.enable_subgraph_quantization) or (self.parent is None): if mandatory: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return None otype = self.parent.is_valid_quantize_weight(tensor_name) if otype is not None: @@ -464,7 +481,7 @@ def get_tensor_type(self, tensor_name, mandatory=False): if res is not None: return res if mandatory: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return None def is_float_tensor(self, tensor_name): diff --git a/onnxruntime/python/tools/transformers/quantize_helper.py b/onnxruntime/python/tools/transformers/quantize_helper.py index a449e881ad361..6a25196dbc24c 100644 --- a/onnxruntime/python/tools/transformers/quantize_helper.py +++ b/onnxruntime/python/tools/transformers/quantize_helper.py @@ -7,7 +7,7 @@ import logging import os -import onnx # noqa: F401 +import onnx import torch from transformers.modeling_utils import Conv1D @@ -69,6 +69,7 @@ def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data onnx_model_path, quantized_model_path, use_external_data_format=use_external_data_format, + extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT}, ) logger.info(f"quantized model saved to:{quantized_model_path}") # TODO: inlcude external data in total model size. diff --git a/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py new file mode 100644 index 0000000000000..2b5d1f36070e5 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +import unittest + +import numpy as np +import onnx +import onnx.helper as oh +import onnx.numpy_helper as onh + +from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer +from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType + + +class TestQuantizerShapeInference(unittest.TestCase): + def test_com_microsoft(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("MatMul", ["X", "W1"], ["T1"]), + oh.make_node("FusedMatMul", ["T1", "W2"], ["T2"], domain="com.microsoft"), + oh.make_node("MatMul", ["T2", "W3"], ["T3"]), + oh.make_node("MatMul", ["T3", "W4"], ["Y"]), + ], + "name", + [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 4])], + [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4])], + [ + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W1"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W2"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W3"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W4"), + ], + ), + opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)], + ) + model_shaped = onnx.shape_inference.infer_shapes(model) + shaped_results = set(t.name for t in model_shaped.graph.value_info) + # every result after T1 depends on T2 coming from a node com.microsoft, + # shape_inference cannot go beyond this point + self.assertEqual(shaped_results, {"T1"}) + + # first try: checks it raises an exception + quantizer = ONNXQuantizer( + model, + False, # per_channel + False, # reduce_range + QuantizationMode.IntegerOps, # mode + False, # static + QuantType.QInt8, # weight_type, + QuantType.QUInt8, # dynamic activation only supports uint8 + None, + [], # nodes_to_quantize, + [], # nodes_to_exclude + ["MatMul"], # op_types_to_quantize, + {"MatMulConstBOnly": True}, # extra_options, + # {'DefaultTensorType': 1, } + ) + + with self.assertRaises(RuntimeError) as e: + quantizer.quantize_model() + self.assertIn("Unable to find data type for weight_name=", str(e)) + + # second try: checks it works + quantizer = ONNXQuantizer( + model, + False, # per_channel + False, # reduce_range + QuantizationMode.IntegerOps, # mode + False, # static + QuantType.QInt8, # weight_type, + QuantType.QUInt8, # dynamic activation only supports uint8 + None, + [], # nodes_to_quantize, + [], # nodes_to_exclude + ["MatMul"], # op_types_to_quantize, + { + "MatMulConstBOnly": True, + "DefaultTensorType": 1, + }, + ) + + model = quantizer.quantize_model() + ops = {n.op_type for n in model.graph.node} + self.assertEqual(ops, {"Cast", "FusedMatMul", "MatMulInteger", "DynamicQuantizeLinear", "Mul"}) + + +if __name__ == "__main__": + unittest.main(verbosity=2) From 1b48054e1b7991ccef664fbedd659ec95d0e7ca7 Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Wed, 21 Feb 2024 01:24:34 +0800 Subject: [PATCH 015/237] [js/webgpu] Create Split indices helpers by rank, not by shape (#19554) ### Description This is required to make shape uniforms really work. ### Motivation and Context The bug was unveiled in a model with multiple Split nodes. The later nodes would try to reuse a previous pipeline cache, while the old shapes were hardcoded as constants in cache. 
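A rough Python stand-in for the caching pitfall (not the actual js/webgpu code; the names below are made up for illustration):

```python
# A pipeline cache whose key does not include the concrete input shape.
# Old behaviour: the generated shader froze the dims in as constants, so the
# entry compiled for the first Split node was wrong for later ones.
# Fixed behaviour: the source depends only on the rank and the dims are read
# from shape uniforms at dispatch time, so reuse stays correct.
cache = {}

def get_split_program(dims, build_by_rank=True):
    key = f"split_rank_{len(dims)}"                 # shape is not part of the key
    source = (f"split(rank={len(dims)}) /* dims via uniforms */"
              if build_by_rank
              else f"split(dims={dims}) /* dims baked in as constants */")
    return cache.setdefault(key, source)

print(get_split_program([2, 8, 4]))
print(get_split_program([2, 8, 6]))  # same cached program, still valid
```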
--- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 14d6f37927590..a09ac78b17006 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -68,7 +68,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const dataType = inputs[0].dataType; const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); const outputs = new Array(attributes.numOutputs); - const input = inputVariable('input', dataType, inputShape); + const input = inputVariable('input', dataType, inputShape.length); const sizeInSplitAxis = new Array(attributes.numOutputs); const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; @@ -80,7 +80,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const outputShape = inputShape.slice(); outputShape[attributes.axis] = attributes.splitSizes[i]; outputShapes.push(outputShape); - outputs[i] = outputVariable(`output${i}`, dataType, outputShape); + outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } programUniforms.push( From 3c49aacd5667b320a4e02626a176098f7423d7c0 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Tue, 20 Feb 2024 13:13:40 -0800 Subject: [PATCH 016/237] Disable __cpuid check on arm64 builds as intrinsic is not available (#19574) Disable __cpuid check on arm64 builds as intrinsic is not available Motivation Breaking the arm64 build. Co-authored-by: Sheil Kumar --- winml/lib/Api/HardwareCoreEnumerator.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp index fa069c7fb66a7..b6b44690f4f6c 100644 --- a/winml/lib/Api/HardwareCoreEnumerator.cpp +++ b/winml/lib/Api/HardwareCoreEnumerator.cpp @@ -84,6 +84,7 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. auto cores = GetNumberOPhysicalAndEngineeringCores(); +#if !defined(_M_ARM64) && !defined(__aarch64__) const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" int regs_leaf0[4]; int regs_leaf7[4]; @@ -100,6 +101,7 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // On Intel Hybrid processors, numSocCores == cores.Num2CacheCores return cores.PhysicalCores - cores.Num2CacheCores; } +#endif return cores.PhysicalCores; } From ec9c8cbdc9686ccda6553674d6aab61cfd245cf0 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 21 Feb 2024 07:40:35 +1000 Subject: [PATCH 017/237] Use xcode parallel build flags to speed up iOS CI that is timing out (#19570) ### Description Provide specific xcodebuild flags instead of depending on cmake to do the right thing. This built in just over an hour with a ccache miss. Previous CIs with a ccache miss were timing out after 150 minutes. 
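A simplified sketch of the flag selection (the generator handling mirrors the `build.py` diff below; the job count is illustrative):

```python
# Per-generator parallelism flags handed to the underlying build tool.
def parallel_build_flags(cmake_generator: str, num_parallel_jobs: int) -> list:
    if cmake_generator == "Xcode":
        # passed to xcodebuild directly instead of relying on `cmake --build --parallel`
        return ["-parallelizeTargets", "-jobs", str(num_parallel_jobs)]
    return [f"-j{num_parallel_jobs}"]  # make/ninja style fallback

print(parallel_build_flags("Xcode", 8))
```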
### Motivation and Context --- tools/ci_build/build.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 244bebd81474d..5b715bb29e5a1 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1631,9 +1631,11 @@ def generate_build_tree( [ *temp_cmake_args, f"-DCMAKE_BUILD_TYPE={config}", - f"-DCMAKE_PREFIX_PATH={build_dir}/{config}/installed" - if preinstalled_dir.exists() and not (args.arm64 or args.arm64ec or args.arm) - else "", + ( + f"-DCMAKE_PREFIX_PATH={build_dir}/{config}/installed" + if preinstalled_dir.exists() and not (args.arm64 or args.arm64ec or args.arm) + else "" + ), ], cwd=config_build_dir, cuda_home=cuda_home, @@ -1667,8 +1669,11 @@ def build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, targe f"/p:CL_MPCount={num_parallel_jobs}", ] elif args.cmake_generator == "Xcode": - # CMake will generate correct build tool args for Xcode - cmd_args += ["--parallel", str(num_parallel_jobs)] + build_tool_args += [ + "-parallelizeTargets", + "-jobs", + str(num_parallel_jobs), + ] else: build_tool_args += [f"-j{num_parallel_jobs}"] From 7a5860e4909387448cb51351d3af50933238ba10 Mon Sep 17 00:00:00 2001 From: Jake Mathern Date: Tue, 20 Feb 2024 13:41:40 -0800 Subject: [PATCH 018/237] Fix cmake function duplicate lib (#19547) ### Description Fixes cmake function definition in winml.cmake to copy link flags. ### Motivation and Context XFGCheck errors in WindowsAI because this function does not transfer linker flags --- cmake/winml.cmake | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 268ee3960e75a..57cecd3e66adb 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -827,6 +827,7 @@ if (winml_is_inbox) get_target_property(compile_options ${target} COMPILE_OPTIONS) get_target_property(include_directories ${target} INCLUDE_DIRECTORIES) get_target_property(link_libraries ${target} LINK_LIBRARIES) + get_target_property(link_flags ${target} LINK_FLAGS) get_target_property(link_options ${target} LINK_OPTIONS) add_library(${new_target} SHARED ${sources}) @@ -835,6 +836,7 @@ if (winml_is_inbox) target_compile_options(${new_target} PRIVATE ${compile_options}) target_include_directories(${new_target} PRIVATE ${include_directories}) target_link_libraries(${new_target} PRIVATE ${link_libraries}) + set_property(TARGET ${new_target} PROPERTY LINK_FLAGS "${link_flags}") target_link_options(${new_target} PRIVATE ${link_options}) endfunction() From 97ff17c2cbb6ee6f27c052e9c4302c70a41af485 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:02:11 -0800 Subject: [PATCH 019/237] update script of run CI for external PRs to add "Big Models" (#19576) ### Description update script of run CI for external PRs to add "Big Models" --- tools/python/run_CIs_for_external_pr.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index 7a77839c4a4e7..df4e70b1e51fe 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -93,6 +93,8 @@ def main(): # checks "onnxruntime-python-checks-ci-pipeline", "onnxruntime-binary-size-checks-ci-pipeline", + # big models + "Big Models", # not currently required, but running ensures we're hitting all mobile platforms "Android CI Pipeline", "iOS CI Pipeline", From 3fe2c137ee5923ee369062453d528fe0e33bf4bc Mon Sep 17 00:00:00 2001 From: 
Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:23:01 -0800 Subject: [PATCH 020/237] [js] small fix to workaround formatter (#19400) ### Description Rename shader variable names to snake_case naming and also to avoid formatter behaving inconsistently in win/linux. --- js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 3f73d9cb7c5bc..d5f97213e49ce 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -85,28 +85,28 @@ const createLayerNormProgramInfo = ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')} let offset = global_idx * uniforms.norm_size_vectorized; - var meanVector = ${fillVector('f32', components)}; - var meanSquareVector = ${fillVector('f32', components)}; + var mean_vector = ${fillVector('f32', components)}; + var mean_square_vector = ${fillVector('f32', components)}; for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) { let value = ${castToF32(dataType, components, 'x[h + offset]')}; - meanVector += value; - meanSquareVector += value * value; + mean_vector += value; + mean_square_vector += value * value; } - let mean = ${sumVector('meanVector', components)} / uniforms.norm_size; - let invStdDev = - inverseSqrt(${sumVector('meanSquareVector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); + let mean = ${sumVector('mean_vector', components)} / uniforms.norm_size; + let inv_std_dev = inverseSqrt(${ + sumVector('mean_square_vector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; let f32scale = ${castToF32(dataType, components, 'scale[j]')}; - output[j + offset] = ${variables[0].type.value}((f32input - mean) * invStdDev * f32scale + output[j + offset] = ${variables[0].type.value}((f32input - mean) * inv_std_dev * f32scale ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''} ); } ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''}; - ${hasInvStdOutput ? 'inv_std_output[global_idx] = invStdDev' : ''}; + ${hasInvStdOutput ? 'inv_std_output[global_idx] = inv_std_dev' : ''}; }`; }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; From 70567a4b3a8bc74fb0f1a9ed9ea5a5be6b99b378 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:33:21 -0800 Subject: [PATCH 021/237] [js/web] use ApiTensor insteadof onnxjs Tensor in TensorResultValidator (#19358) ### Description use ApiTensor insteadof onnxjs Tensor in TensorResultValidator. Make test runner less depend on onnxjs classes. --- js/web/test/test-runner.ts | 26 +++++++------------ .../unittests/backends/webgl/test-conv-new.ts | 4 ++- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index b01d474788f25..ecc7d4b4a09a5 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -39,10 +39,6 @@ const ONNXRUNTIME_THRESHOLD_RELATIVE_ERROR = 1.00001; */ const now = (typeof performance !== 'undefined' && performance.now) ? 
() => performance.now() : Date.now; -function toInternalTensor(tensor: ort.Tensor): Tensor { - return new Tensor( - tensor.dims, tensor.type as Tensor.DataType, undefined, undefined, tensor.data as Tensor.NumberType); -} function fromInternalTensor(tensor: Tensor): ort.Tensor { return new ort.Tensor(tensor.type, tensor.data as ort.Tensor.DataType, tensor.dims); } @@ -330,6 +326,10 @@ export class TensorResultValidator { } checkTensorResult(actual: Tensor[], expected: Tensor[]): void { + this.checkApiTensorResult(actual.map(fromInternalTensor), expected.map(fromInternalTensor)); + } + + checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void { // check output size expect(actual.length, 'size of output tensors').to.equal(expected.length); @@ -347,10 +347,6 @@ export class TensorResultValidator { } } - checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void { - this.checkTensorResult(actual.map(toInternalTensor), expected.map(toInternalTensor)); - } - checkNamedTensorResult(actual: Record, expected: Test.NamedTensor[]): void { // check output size expect(Object.getOwnPropertyNames(actual).length, 'size of output tensors').to.equal(expected.length); @@ -364,7 +360,7 @@ export class TensorResultValidator { } // This function check whether 2 tensors should be considered as 'match' or not - areEqual(actual: Tensor, expected: Tensor): boolean { + areEqual(actual: ort.Tensor, expected: ort.Tensor): boolean { if (!actual || !expected) { return false; } @@ -392,13 +388,13 @@ export class TensorResultValidator { switch (actualType) { case 'string': - return this.strictEqual(actual.stringData, expected.stringData); + return this.strictEqual(actual.data, expected.data); case 'float32': case 'float64': return this.floatEqual( - actual.numberData as number[] | Float32Array | Float64Array, - expected.numberData as number[] | Float32Array | Float64Array); + actual.data as number[] | Float32Array | Float64Array, + expected.data as number[] | Float32Array | Float64Array); case 'uint8': case 'int8': @@ -409,10 +405,8 @@ export class TensorResultValidator { case 'int64': case 'bool': return TensorResultValidator.integerEqual( - actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | - Int32Array, - expected.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | - Int32Array); + actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, + expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array); default: throw new Error('type not implemented or not supported'); diff --git a/js/web/test/unittests/backends/webgl/test-conv-new.ts b/js/web/test/unittests/backends/webgl/test-conv-new.ts index 8c186b9b36451..014fc57f21558 100644 --- a/js/web/test/unittests/backends/webgl/test-conv-new.ts +++ b/js/web/test/unittests/backends/webgl/test-conv-new.ts @@ -893,7 +893,9 @@ describe('New Conv tests', () => { const expected = cpuConv( inputTensor, kernelTensor, biasTensor, testData.autoPad, testData.dilations, testData.pads, testData.strides); - if (!validator.areEqual(actual, expected)) { + try { + validator.checkTensorResult([actual], [expected]); + } catch { console.log(actual.dims, `[${actual.numberData.slice(0, 20).join(',')},...]`); console.log(expected.dims, `[${expected.numberData.slice(0, 20).join(',')},...]`); throw new Error('Expected and Actual did not match'); From 6e04e36e3faf2d8115c0962c85b86a6a8b48ac5b Mon 
Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:33:37 -0800 Subject: [PATCH 022/237] [js/common] upgrade tsc in common from 4.9.5 to 5.2.2 (#19317) ### Description upgrade tsc in common from 4.9.5 to 5.2.2 --- js/common/package-lock.json | 106 +++++++++++++++++------------------ js/common/package.json | 4 +- js/common/test/tsconfig.json | 2 +- 3 files changed, 56 insertions(+), 56 deletions(-) diff --git a/js/common/package-lock.json b/js/common/package-lock.json index a5ada877b916a..3988ac80707e0 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -9,13 +9,13 @@ "version": "1.18.0", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "node_modules/balanced-match": { @@ -34,9 +34,9 @@ } }, "node_modules/jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "node_modules/lunr": { @@ -46,9 +46,9 @@ "dev": true }, "node_modules/marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true, "bin": { "marked": "bin/marked.js" @@ -58,24 +58,24 @@ } }, "node_modules/minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "dependencies": { "brace-expansion": "^2.0.1" }, "engines": { - "node": ">=10" + "node": ">=16 || 14 >=14.17" }, "funding": { "url": "https://github.com/sponsors/isaacs" } }, "node_modules/shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "dependencies": { "ansi-sequence-parser": "^1.1.0", @@ -85,30 +85,30 
@@ } }, "node_modules/typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "dependencies": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" }, "bin": { "typedoc": "bin/typedoc" }, "engines": { - "node": ">= 14.14" + "node": ">= 16" }, "peerDependencies": { - "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x" + "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x || 5.0.x || 5.1.x || 5.2.x || 5.3.x" } }, "node_modules/typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true, "bin": { @@ -116,7 +116,7 @@ "tsserver": "bin/tsserver" }, "engines": { - "node": ">=4.2.0" + "node": ">=14.17" } }, "node_modules/vscode-oniguruma": { @@ -134,9 +134,9 @@ }, "dependencies": { "ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "balanced-match": { @@ -155,9 +155,9 @@ } }, "jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "lunr": { @@ -167,24 +167,24 @@ "dev": true }, "marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true }, "minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "requires": 
{ "brace-expansion": "^2.0.1" } }, "shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "requires": { "ansi-sequence-parser": "^1.1.0", @@ -194,21 +194,21 @@ } }, "typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "requires": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" } }, "typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true }, diff --git a/js/common/package.json b/js/common/package.json index 64ab2736adbe3..cd2612aab4984 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -9,7 +9,7 @@ }, "author": "fs-eire", "scripts": { - "build:cjs": "tsc --module commonjs --outDir ./dist/cjs", + "build:cjs": "tsc --module commonjs --moduleResolution node10 --outDir ./dist/cjs", "build:esm": "tsc", "build:bundles": "webpack", "build": "node ./build.js", @@ -18,7 +18,7 @@ "test": "mocha ./test/**/*.js --timeout 30000" }, "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" }, "main": "dist/cjs/index.js", "exports": { diff --git a/js/common/test/tsconfig.json b/js/common/test/tsconfig.json index 2e4927ac3b325..e9068ad837a81 100644 --- a/js/common/test/tsconfig.json +++ b/js/common/test/tsconfig.json @@ -2,7 +2,7 @@ "extends": "../../tsconfig.tools.json", "exclude": ["type-tests/**/*.ts"], "compilerOptions": { - "module": "ES2022", + "module": "Node16", "sourceMap": true } } From 45e20bf7810689ecf385957c34434c6d2456e32b Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 21 Feb 2024 12:38:37 +1000 Subject: [PATCH 023/237] Use build.py to build in py-win-gpu.yml so parallelization parameters are set (#19578) ### Description build.py sets a few parallelization parameters when building. Using msbuild directly lacks those. https://github.com/microsoft/onnxruntime/blob/7a5860e4909387448cb51351d3af50933238ba10/tools/ci_build/build.py#L1665-L1669 Changed to use build.py. If there's a concern with that we _could_ set the parameters in the yaml, but that will be uglier due to duplicating logic in multiple places. 
### Motivation and Context --- .../azure-pipelines/templates/py-win-gpu.yml | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml index 18368e59cad52..4315eae503ebd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml @@ -120,17 +120,17 @@ jobs: $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} workingDirectory: '$(Build.BinariesDirectory)' - - task: VSBuild@1 + # building with build.py so the parallelization parameters are added to the msbuild command + - task: PythonScript@0 displayName: 'Build' inputs: - solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' - platform: x64 - configuration: RelWithDebInfo - msbuildArchitecture: $(buildArch) - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - createLogFile: true + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --parallel --build + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} + workingDirectory: '$(Build.BinariesDirectory)' # Esrp signing - template: win-esrp-dll.yml @@ -188,7 +188,7 @@ jobs: condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) inputs: GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - template: component-governance-component-detection-steps.yml parameters: From 0c4421cb7867434e1e08b4274f16f6c2f14cb4ce Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Wed, 21 Feb 2024 03:39:43 +0100 Subject: [PATCH 024/237] Fix compile warnings (as errors) for functions which miss returning required return value (#19079) Added dummy return values to functions which specify a return value, but do not return an value value. ### Motivation and Context Fix compiler errors with 'warnings as errors' enabled. From 8fadc6c913bc30edff2e89756da515b9bd75d256 Mon Sep 17 00:00:00 2001 From: zhijiang <43435212+zhijxu-MS@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:41:42 +0800 Subject: [PATCH 025/237] Zhijxu/cleanup cached tensors when oom (#19306) in pytorch, when oom happens at bp, user could decrease the batch size and rerun it without restarting the process. while in ORT, the intermediate tensors are kept even OOM, so decrease batch size still fail. this is torch run, we can see after oom failure, torch will release tensor before next step ![image](https://github.com/microsoft/onnxruntime/assets/43435212/92b8a2e3-454b-448a-a223-17cb91d463c2) this is from ort, we can see ort not release its tensors after OOM failure. ![image](https://github.com/microsoft/onnxruntime/assets/43435212/bb6a3882-8e14-4f37-8079-e7f70fc2546b) ort with the PR, we can see memory is released, **the 4GB memory is not own by ort, and will be released by torch at the end**. 
![image](https://github.com/microsoft/onnxruntime/assets/43435212/7f39d711-4e36-47d5-aecf-3805433a6d01) --- onnxruntime/core/framework/execution_frame.cc | 21 +++++++++++++++ onnxruntime/core/framework/execution_frame.h | 2 ++ .../training/ortmodule/_training_manager.py | 26 ++++++++++--------- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 8c08152986cf6..32a5f749af084 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -204,6 +204,14 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtDevice& info) const { Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); } +#ifdef ENABLE_TRAINING +void IExecutionFrame::ReleaseAllMLValues() { + for (size_t ort_value_idx = 0; ort_value_idx < all_values_.size(); ort_value_idx++) { + all_values_[ort_value_idx] = OrtValue(); + } +} +#endif + Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast(ort_value_idx) >= all_values_size_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx); @@ -831,7 +839,20 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const { // This method is not thread safe! // Return S_OK and nullptr if index map to a value that is an unused optional input/output Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) { +#ifdef ENABLE_TRAINING + try { + auto status = AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); + return status; + } catch (const std::exception& e) { + LOGS(session_state_.Logger(), WARNING) + << "Exception caught when allocating memory for ort_value with index: " << ort_value_idx + << "so clean up all OrtValues"; + ReleaseAllMLValues(); + return Status(ONNXRUNTIME, FAIL, e.what()); + } +#else return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); +#endif } void ExecutionFrame::VerifyOutputSizes(int output_index, const Node& node, const TensorShape& output_shape) { diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index 1576c16684faa..18d210ffd48f7 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -67,6 +67,8 @@ class IExecutionFrame { const std::unordered_map& initializers); Status GetOutputs(gsl::span fetch_mlvalue_idxs, std::vector& fetches); + // if OOM happens, then release all values, so session can run next batch. + void ReleaseAllMLValues(); #endif // TO DO: make it thread safe diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index cc533e549db92..73c32a2f51e41 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -196,18 +196,20 @@ def backward(ctx, *grad_outputs): # Run and get results backward_outputs = C.OrtValueVector() - self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) - # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not - # affect peak memory usage in a subsequent graph run. 
- del ctx.run_info.state - - # Fast version: all backward_outputs are converted first. - # This version only works if backward_outputs is an OrtValueVector. - transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) - - self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) - - return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) + try: + self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) + # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not + # affect peak memory usage in a subsequent graph run. + + # Fast version: all backward_outputs are converted first. + # This version only works if backward_outputs is an OrtValueVector. + transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) + + self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) + res = tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) + return res + finally: + del ctx.run_info.state return _ORTModuleFunction From 6226c5f62f3d16b9702d5c40993ee9bf1cbd119c Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:08:48 +0800 Subject: [PATCH 026/237] [ROCm] Add SkipGroupNorm for ROCm EP (#19303) Add SkipGroupNorm for ROCm EP. --------- Co-authored-by: Peixuan Zuo --- cmake/onnxruntime_rocm_hipify.cmake | 5 - .../contrib_ops/rocm/diffusion/group_norm.cc | 152 ------------- .../rocm/diffusion/group_norm_ck.cuh | 35 +-- .../diffusion/group_norm_ck_impl/impl.cuh | 10 +- .../diffusion/group_norm_ck_impl/impl_fp16.cu | 8 +- .../diffusion/group_norm_ck_impl/impl_fp32.cu | 8 +- .../rocm/diffusion/group_norm_common.h | 125 +++------- .../rocm/diffusion/group_norm_impl.cu | 47 ++-- .../rocm/diffusion/group_norm_impl.h | 47 ---- .../rocm/diffusion/group_norm_impl_kernel.cuh | 213 ------------------ .../rocm/diffusion/group_norm_triton.cuh | 29 +-- .../rocm/diffusion/group_norm_triton.py | 16 +- .../rocm/diffusion/group_norm_tunable_op.h | 153 +++++++------ .../contrib_ops/rocm/rocm_contrib_kernels.cc | 2 + .../kernel_explorer/kernels/groupnorm_test.py | 136 ++++++++--- .../kernels/rocm/group_norm.cu | 112 +++++---- .../contrib_ops/skip_group_norm_op_test.cc | 14 +- tools/ci_build/amd_hipify.py | 2 + 18 files changed, 382 insertions(+), 732 deletions(-) delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index d485abe6bb1a6..85a9bf50460d3 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -44,12 +44,7 @@ set(contrib_ops_excluded_files "bert/packed_multihead_attention.cc" "bert/packed_multihead_attention_impl.h" "bert/packed_multihead_attention_impl.cu" - "diffusion/group_norm.cc" "diffusion/group_norm_impl.cu" - "diffusion/group_norm_impl.h" - "diffusion/group_norm_impl_kernel.cuh" - "diffusion/group_norm_common_base.h" - "diffusion/group_norm_common_base.cc" "diffusion/nhwc_conv.cc" "math/gemm_float8.cc" "math/gemm_float8.cu" diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc deleted file mode 100644 index 
e82e15a304f4c..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define GROUP_NORM_TYPES float, MLFloat16 - -ONNX_OPERATOR_KERNEL_EX( - GroupNorm, kMSDomain, 1, kRocmExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); - -using namespace ONNX_NAMESPACE; - -namespace { -template -struct DispatchGroupNorm { - Status operator()(RocmTuningContext* tuning_ctx, - Stream* stream, - Tensor* output, - const Tensor* input, - const Tensor* gamma, - const Tensor* beta, - void* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_swish_activation) { - typedef typename ToHipType::MappedType HipT; - return LaunchGroupNormKernel( - tuning_ctx, - stream, - reinterpret_cast(output->MutableData()), - reinterpret_cast(input->Data()), - gamma->Data(), - beta->Data(), - workspace, - epsilon, - batch_size, - num_channels, - height, - width, - num_groups, - use_swish_activation); - } -}; - -} // namespace - -GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { - epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); - ORT_ENFORCE(epsilon_ >= 0); - - int64_t num_groups; - ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK()); - ORT_ENFORCE(num_groups >= 0); - num_groups_ = static_cast(num_groups); - - int64_t activation; - ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK()); - ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish - use_swish_activation_ = (activation == 1); - - channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); -} - -Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, - bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { - is_packed = false; - return Status::OK(); -} - -Status GroupNorm::ComputeInternal(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* gamma = context->Input(1); - const Tensor* beta = context->Input(2); - Tensor* output = context->Output(0, input->Shape()); - - if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "only the channels_last layout is supported"); - } - - const auto& input_dims = input->Shape().GetDims(); - if (input_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 4 dimensions, got ", input_dims.size()); - } - - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != input_dims[3]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in gamma and input does not match"); - } - - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != input_dims[3]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in beta 
and input does not match"); - } - - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisible by num_groups"); - } - - if (context->GetUseDeterministicCompute()) { - static std::once_flag log_warning; - std::call_once(log_warning, []() { - LOGS_DEFAULT(WARNING) << "GroupNorm has no deterministic GPU kernel, its outputs may still be nondeterministic."; - }); - } - - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); - - utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(GetTuningContext(), context->GetComputeStream(), - output, input, gamma, beta, workspace.get(), - epsilon_, - batch_size, - num_channels, - height, - width, - num_groups_, - use_swish_activation_); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index fb7091592c16e..d0a0d09fcbae3 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -26,13 +26,18 @@ namespace rocm { using onnxruntime::rocm::CKDataTypeAdaptor; -using Swish = ck::tensor_operation::element_wise::Swish; +// The SiLU function is a special case of Swish function, +// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as: +// SiLU(x) = x * sigmoid(x) +// Swish(x) = x * sigmoid(bx) +// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here. +using Silu = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; constexpr int Rank = 5; constexpr int NumReduceDim = 3; -template +template auto GetCKGroupNormNHWCTypeStringAndOps() { using XDataType = typename CKDataTypeAdaptor::type; using YDataType = typename CKDataTypeAdaptor::type; @@ -40,26 +45,30 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { using GammaDataType = float; using BetaDataType = float; - using Activation = std::conditional_t; + using Activation = std::conditional_t; - std::vector>>> ret; + std::vector>>> ret; for (auto&& impl : internal::GetDeviceGroupNormInstances()) { - std::string swish_suffix = WithSwish ? "_Swish" : "_Pass"; - auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix; + std::string silu_suffix = WithSilu ? 
"_Silu" : "_Pass"; + auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix; auto invoker = impl->MakeInvokerPointer(); - auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams* params) -> Status { - if constexpr (WithSwish) { + auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), + "Input skip or bias is not supported by composable kernel."); + if constexpr (WithSilu) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->withSwish, "Swish version only support groupnorm with swish"); + !params->use_silu, "Silu version only support groupnorm with silu"); } else { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->withSwish, "Pass version only support groupnorm without swish"); + params->use_silu, "Pass version only support groupnorm without silu"); } - std::vector in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup}; - std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1}; - std::vector gamma_beta_strides{0, 0, 0, params->cPerGroup, 1}; + std::vector in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group}; + std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, + params->c, params->channels_per_group, 1}; + std::vector gamma_beta_strides{0, 0, 0, params->channels_per_group, 1}; std::vector reduce_dims{1, 2, 4}; auto activation = Activation{}; diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh index 19b081881dcec..4cb371fdcf960 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -18,7 +18,7 @@ namespace internal { using F16 = ck::half_t; using F32 = float; -using Swish = ck::tensor_operation::element_wise::Swish; +using Silu = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface @@ -101,9 +101,9 @@ GetDeviceGroupNormInstances() { template <> std::vector>> + F16, F32, F32, F16, F32, Silu, 5, 3>>> GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Swish, 5, 3>(); + F16, F32, F32, F16, F32, Silu, 5, 3>(); template <> std::vector std::vector>> + F32, F32, F32, F32, F32, Silu, 5, 3>>> GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Swish, 5, 3>(); + F32, F32, F32, F32, F32, Silu, 5, 3>(); template <> std::vector -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, - device_normalization_f16_instances{}); + device_normalization_f16_instances{}); return instances; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu index 9b0ccab17b4c1..ceb53ed442abc 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -11,12 +11,12 @@ namespace rocm { namespace internal { template 
<> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, - device_normalization_f32_instances{}); + device_normalization_f32_instances{}); return instances; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h index 008ae20b0561f..7cff640db2f34 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h @@ -8,110 +8,47 @@ #include "core/providers/rocm/cu_inc/common.cuh" #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/rocm/diffusion/group_norm_common_base.h" namespace onnxruntime { namespace contrib { namespace rocm { -using onnxruntime::rocm::CeilDiv; - -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; - for (int32_t i = 1; i <= std::sqrt(n); i++) { - if (n % i == 0) { - int32_t divisor1 = n / i; - int32_t divisor2 = i; - - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; - } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; - } - } - } - return maxDivisor; -} - template -struct GroupNormNHWCParams : OpParams { - GroupNormNHWCParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, T* dst, float* redBuffer, const T* src, const float* gamma, - const float* beta, int32_t n, int32_t h, int32_t w, int32_t c, int32_t groups, float epsilon, bool withSwish) - : OpParams(tuning_ctx, stream), dst(dst), src(src), gamma(gamma), beta(beta), redBuffer(redBuffer), epsilon(epsilon), n(n), h(h), w(w), c(c), groups(groups), withSwish(withSwish) { - int32_t maxBlocksPerHW = 1024; - switch (c) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; - } - - hw = h * w; - const int32_t blocksPerHW = findMaxDivisor(hw, maxBlocksPerHW); - hwPerBlock = CeilDiv(hw, blocksPerHW); - cPerGroup = c / groups; - hwc = hw * c; - invHWC = 1.F / (float)(hw * cPerGroup); - groupsPerBlock = cPerBlock / cPerGroup; - } +struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams { + GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx, + onnxruntime::Stream* ort_stream, + T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + float* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) + : OpParams(tuning_ctx, ort_stream), + GroupNormNHWCParams(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, + num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {} std::string Signature() const override { - std::string swish_suffix = withSwish ? "_Swish" : "_Pass"; - std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups) + swish_suffix; + std::string silu_suffix = this->use_silu ? "_silu" : "_pass"; + std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip"; + std::string broadcast_suffix = this->broadcast_skip ? 
"_broadcast" : "_nobroadcast"; + std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias"; + std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + + std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix + + skip_suffix + broadcast_suffix + bias_suffix; return sig; } - - // The output buffer. Layout NHWC. - T* dst; - // The input buffer. Layout NHWC. - T const* src; - // The gamma scaling factor. - float const* gamma; - // The beta term to add in GN. - float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; - float epsilon; - - // The number of instances in the batch. - int32_t n; - // The height and width of each activation map. - int32_t h; - int32_t w; - // The number of channels. - int32_t c; - // The number of groups. - int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; - - // Precomputed values and parameters to control the execution of the kernels. - - // The number of activations per instance (h * w) and the number of - // activations per block. - int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; - - // The precomputed stride between instances. - int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; - // The precomputed number of groups per block. - int32_t groupsPerBlock; }; } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu index dbd5009e63676..142aaf14e8d2d 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu @@ -15,9 +15,12 @@ namespace rocm { template Status LaunchGroupNormKernel( RocmTuningContext* tuning_ctx, - Stream* stream, + Stream* ort_stream, T* output, + T* add_out, const T* input, + const T* skip, + const T* bias, const float* gamma, const float* beta, void* workspace, @@ -27,19 +30,26 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool use_swish_activation) { - if (batch_size > static_cast(kMaxGroupNormBatchSize)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only support batch_size <= 32. Got", batch_size); - } + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCTunableParams params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta, + reinterpret_cast(workspace), epsilon, batch_size, num_channels, + height, width, num_groups, use_silu, broadcast_skip, channels_per_block); - if (num_groups != static_cast(kGroupNormNumberOfGroups)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only num_groups=32 is supported. 
Got", num_groups); + if (params.channels_per_block % params.channels_per_group != 0 || + params.channels_per_block > kMaxSize || + (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in ROCM does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - GroupNormNHWCParams params(tuning_ctx, stream, output, reinterpret_cast(workspace), input, gamma, beta, - batch_size, height, width, num_channels, num_groups, epsilon, use_swish_activation); + HIP_RETURN_IF_ERROR(hipMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle())); if (tuning_ctx->IsTunableOpEnabled()) { static GroupNormNHWCTunableOp op; @@ -50,14 +60,17 @@ Status LaunchGroupNormKernel( } template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, - float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + half* add_out, const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, - float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + float* add_out, const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); + } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h deleted file mode 100644 index a0f7e0aca5def..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include - -#include "core/common/common.h" -#include "core/common/status.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::tunable::RocmTuningContext; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { - // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; -} - -template -Status LaunchGroupNormKernel( - RocmTuningContext* tuning_ctx, - Stream* stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh deleted file mode 100644 index d6322a12a9363..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// The ROCm kernel is modified from TensorRT 8.5. -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -static inline __device__ __host__ float sigmoid(float x) { - return 1.F / (1.F + expf(-x)); -} - -struct GroupSums { - // Is it the 1st element of the group? - int32_t flag; - // The sum. - float sum; - // The sum of squares. - float sumSq; -}; - -struct GroupSumsOp { - inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { - GroupSums dst; - dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? 
b.sumSq : (a.sumSq + b.sumSq); - dst.flag = a.flag + b.flag; - return dst; - } -}; - -template -inline __device__ void UpdateSum(const T* src, int64_t offset, U& sum, U& sumSq) { - using VecT = onnxruntime::rocm::aligned_vector; - const VecT input_v = *reinterpret_cast(src + offset); - -#pragma unroll - for (int i = 0; i < ILP; i++) { - const U val = static_cast(input_v.val[i]); - sum += val; - sumSq += val * val; - } -} - -template -__global__ void groupNormNHWCSumKernel(const T* src, float* redBuffer, int32_t cPerBlock, int32_t hwPerBlock, int32_t hw, - int32_t hwc, int32_t c, int32_t cPerGroup, int32_t groups, int32_t groupsPerBlock) { - // The object in charge of doing the sums for the different blocks. - typedef hipcub::BlockScan BlockScan; - - // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[ThreadsPerBlock]; - - // The instance in the batch. - int32_t ni = blockIdx.z; - // The channel loaded by that thread (ILP channels per thread). - int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP; - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + hwPerBlock, hw); - - // The sums. - float sum = 0.F; - float sumSq = 0.F; - - // Iterate over the activations to compute the sums. - if (ci < c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * hwc + static_cast(hwi) * c + ci; - UpdateSum(src, offset, sum, sumSq); - } - } - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * ILP / cPerGroup; - int32_t cj = threadIdx.x * ILP - cPerGroup * gi; - - // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; - - // Do the segmented scan. - GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); - - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == cPerGroup - ILP) { // ILP channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The global group index. - int32_t gj = blockIdx.x * groupsPerBlock + threadIdx.x; - - // Threads that have nothing left to do, exit. - if (threadIdx.x >= groupsPerBlock || gj >= groups) { - return; - } - - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. 
- atomicAdd(&redBuffer[(2 * ni + 0) * groups + gj], sums.x); - atomicAdd(&redBuffer[(2 * ni + 1) * groups + gj], sums.y); -} - -template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, U mean, U invStdDev, - const U* gamma_v, const U* beta_v, bool swish) { - using VecT = onnxruntime::rocm::aligned_vector; - const VecT input_v = *reinterpret_cast(src + offset); - VecT output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - U val = static_cast(input_v.val[i]); - val = (val - mean) * invStdDev; - val = gamma_v[i] * val + beta_v[i]; - - if (swish) { - val = val * sigmoid(val); - } - output_v.val[i] = static_cast(val); - } - *(reinterpret_cast(dst + offset)) = output_v; -} - -template -__global__ void groupNormNHWCScaleKernel(T* dst, const T* src, const float* gamma, const float* beta, const float* redBuffer, float epsilon, int32_t c, int32_t cPerBlock, - int32_t cPerGroup, int32_t groups, int32_t hwc, float invHWC, int32_t hw, int32_t hwPerBlock, bool withSwish) { - // The channel loaded by that thread (ILP channels per thread for F16x2). - int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP; - if (ci >= c) { - return; - } - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / cPerGroup; - - // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; - if (gi < groups) { - sum = redBuffer[(2 * ni + 0) * groups + gi]; - sumSq = redBuffer[(2 * ni + 1) * groups + gi]; - } - - using VecF = onnxruntime::rocm::aligned_vector; - - const VecF gamma_v = *reinterpret_cast(gamma + ci); - const VecF beta_v = *reinterpret_cast(beta + ci); - - // Compute the mean. - float mean = sum * invHWC; - // Compute the variance. - float var = sumSq * invHWC - (mean * mean); - // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var + epsilon); - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + hwPerBlock, hw); - - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * hwc + hwi * c + ci; - - // Fetch ILP channels per thread. - computeGroupNorm(src, dst, offset, mean, invStdDev, gamma_v.val, beta_v.val, withSwish); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index b7b9441ac997d..b3d3e92209b39 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -20,21 +20,21 @@ namespace rocm { namespace { -template +template std::string GetGroupNormTritonGroupName() { std::string ret = "GroupNormTriton_"; - std::string swish_suffix = WithSwish ? "Swish_" : "Pass_"; - ret += swish_suffix; + std::string silu_suffix = WithSilu ? 
"Silu_" : "Pass_"; + ret += silu_suffix; ret += GetDataTypeName(); return ret; } } // namespace -template +template auto GetTritonGroupNormNHWCTypeStringAndOps() { - std::vector>>> ret; - auto group_name = GetGroupNormTritonGroupName(); + std::vector>>> ret; + auto group_name = GetGroupNormTritonGroupName(); auto* kernel_list = GetOrtTritonKernelByGroup(group_name); if (kernel_list == nullptr) { return ret; @@ -45,16 +45,19 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { auto* metadata = GetOrtTritonKernelMetadata(i); auto block_size = metadata->constants.at("BLOCK_SIZE"); auto hw_size = metadata->constants.at("HW_SIZE"); - auto impl = [i, block_size, hw_size](const GroupNormNHWCParams* params) -> Status { + auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), + "Input skip or bias is not supported by triton kernel."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->cPerGroup > block_size || params->cPerGroup * 2 <= block_size, - "Arg block_size (", block_size, ") is not the next power of 2 of cPerGroup (", params->cPerGroup, ")."); + params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, + "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", + params->channels_per_group, ")."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); - if constexpr (WithSwish) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->withSwish, "Swish version does not support GN w/o swish."); + if constexpr (WithSilu) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu."); } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->withSwish, "Pass version does not support GN w/ swish."); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu."); } // Construct args for launch kernel struct { @@ -73,7 +76,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { (const void*)params->beta, params->hw, params->c, - params->cPerGroup, + params->channels_per_group, params->epsilon}; // Grid dim is (batch_count, groups, 1) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index 56b3a030b289e..5368cb1cf635b 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -21,7 +21,7 @@ def group_norm_kernel( eps, BLOCK_SIZE: tl.constexpr, HW_SIZE: tl.constexpr, - ACTIVATION_SWISH: tl.constexpr, + ACTIVATION_SILU: tl.constexpr, ): row_x = tl.program_id(0) row_y = tl.program_id(1) @@ -62,7 +62,7 @@ def group_norm_kernel( x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) x_hat = (x - group_mean) * rstd y = x_hat * gamma + beta - if ACTIVATION_SWISH: + if ACTIVATION_SILU: y *= tl.sigmoid(y) tl.store(y_ptr + offsets, y, mask=mask) @@ -71,7 +71,7 @@ def group_norm_kernel( # blocks = [16, 32, 64, 128, 256, 512] # hw_sizes = [8, 16, 32, 64, 128, 256, 512] # but this will result in too many functions and slow down the compilation. 
-with_swish = [True, False] +with_silu = [True, False] dtypes = ["fp32", "fp16"] blocks = [16, 32, 64, 128] hw_sizes = [8, 16, 32, 64, 128, 256] @@ -84,14 +84,14 @@ def group_norm_kernel( def get_function_table(): func_table = [] - for swish, dtype, hw_size, warp, b in product(with_swish, dtypes, hw_sizes, warps, blocks): - swish_suffix = "Swish" if swish else "Pass" - name = name_pattern.format(swish_suffix, dtype, b, hw_size, warp) - group = group_pattern.format(swish_suffix, dtype) + for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks): + silu_suffix = "Silu" if silu else "Pass" + name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) + group = group_pattern.format(silu_suffix, dtype) sig = sig_pattern.format(dtype, dtype) kwargs = { "num_warps": warp, - "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SWISH": int(swish)}, + "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, } func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs} func_table.append(func_desc) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h index 25d820f7ed326..e6831f764b418 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h @@ -20,115 +20,117 @@ namespace rocm { using onnxruntime::rocm::GPU_WARP_SIZE; template -void groupNormNHWCSum(const GroupNormNHWCParams* params) { - // Make sure the values are as we expect. - ORT_ENFORCE(params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0); - +void GroupNormNHWCSum(const GroupNormNHWCTunableParams* params) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params->c / params->cPerBlock; + grid.x = DivUp(params->c, params->channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.y = DivUp(params->hw, params->hw_per_block); // The number of instances. grid.z = params->n; -#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ - groupNormNHWCSumKernel \ - <<StreamHandle()>>>( \ - params->src, params->redBuffer, params->cPerBlock, \ - params->hwPerBlock, params->hw, params->hwc, params->c, \ - params->cPerGroup, params->groups, params->groupsPerBlock); \ +#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ + GroupNormNHWCSumKernel \ + <<StreamHandle()>>>( \ + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \ + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \ + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \ break; - switch (params->cPerBlock) { - case 320: - LAUNCH_GROUPNORM_SUM(256, 2) - case 480: - LAUNCH_GROUPNORM_SUM(256, 2) + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
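+  // Worked example (illustrative): a block covering 512 channels with CHANNELS_PER_THREAD = 2
+  // launches 512 / 2 = 256 threads, which is why the cases below enumerate 256/192/160/128/64.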
+ switch (params->threads_per_block) { case 256: - LAUNCH_GROUPNORM_SUM(128, 2) + LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) case 128: - LAUNCH_GROUPNORM_SUM(64, 2) + LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) default: ORT_NOT_IMPLEMENTED("Not implemented"); } } template -Status GroupNormNHWCSumOp(const GroupNormNHWCParams* params) { +Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams* params) { dim3 grid; - grid.x = params->c / params->cPerBlock; - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); grid.z = params->n; - groupNormNHWCSumKernel + GroupNormNHWCSumKernel <<StreamHandle()>>>( - params->src, params->redBuffer, params->cPerBlock, params->hwPerBlock, - params->hw, params->hwc, params->c, params->cPerGroup, params->groups, params->groupsPerBlock); + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); return HIP_CALL(hipGetLastError()); } template -void groupNormNHWCScale(const GroupNormNHWCParams* params) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params->c % params->cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0); - +void GroupNormNHWCScale(const GroupNormNHWCTunableParams* params) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params->c / params->cPerBlock; + grid.x = DivUp(params->c, params->channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.y = DivUp(params->hw, params->hw_per_block); // The number of instances. grid.z = params->n; -#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ - groupNormNHWCScaleKernel \ - <<StreamHandle()>>>( \ - params->dst, params->src, params->gamma, params->beta, \ - params->redBuffer, params->epsilon, params->c, params->cPerBlock, \ - params->cPerGroup, params->groups, params->hwc, params->invHWC, \ - params->hw, params->hwPerBlock, params->withSwish); \ +#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ + GroupNormNHWCScaleKernel \ + <<StreamHandle()>>>( \ + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \ + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \ + params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \ + params->hw, params->hw_per_block, params->use_silu); \ break; - switch (params->cPerBlock) { - case 320: - LAUNCH_GROUPNORM_SCALE(256, 2) - case 480: - LAUNCH_GROUPNORM_SCALE(256, 2) + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
+ switch (params->threads_per_block) { case 256: - LAUNCH_GROUPNORM_SCALE(128, 2) + LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) case 128: - LAUNCH_GROUPNORM_SCALE(64, 2) + LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) default: ORT_NOT_IMPLEMENTED("Not implemented"); } } template -Status GroupNormNHWCScaleOp(const GroupNormNHWCParams* params) { +Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams* params) { dim3 grid; - grid.x = params->c / params->cPerBlock; - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); grid.z = params->n; - groupNormNHWCScaleKernel + GroupNormNHWCScaleKernel <<StreamHandle()>>>( - params->dst, params->src, params->gamma, params->beta, params->redBuffer, params->epsilon, params->c, params->cPerBlock, - params->cPerGroup, params->groups, params->hwc, params->invHWC, params->hw, params->hwPerBlock, params->withSwish); + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, + params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, + params->use_silu); return HIP_CALL(hipGetLastError()); } template class GroupNormNHWCOp { public: - Status operator()(const GroupNormNHWCParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle())); + Status operator()(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); auto status = GroupNormNHWCSumOp(params); ORT_RETURN_IF_ERROR(status); HIP_RETURN_IF_ERROR(hipGetLastError()); @@ -138,29 +140,30 @@ class GroupNormNHWCOp { return Status::OK(); } - Status IsSupported(const GroupNormNHWCParams* params) { + Status IsSupported(const GroupNormNHWCTunableParams* params) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0), - "The number of channels (", params->c, ") or the number of channels per group (", params->cPerGroup, + !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0), + "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group, ") isn't divisible by the number of vector size: ", VecSize); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock % params->cPerGroup == 0 && - params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0), - "The value of attributes don't meet the requirements."); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock <= ThreadsPerBlock * VecSize && - params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize && + params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), "Configuration: Threads (", ThreadsPerBlock, "), vector size (", - VecSize, ") is redundant for the number of channels per group: ", params->cPerBlock); + VecSize, ") is redundant for the number of channels per group: ", + 
params->channels_per_block); return Status::OK(); } }; template -Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle())); - groupNormNHWCSum(params); +Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); + GroupNormNHWCSum(params); HIP_RETURN_IF_ERROR(hipGetLastError()); - groupNormNHWCScale(params); + GroupNormNHWCScale(params); HIP_RETURN_IF_ERROR(hipGetLastError()); return Status::OK(); } @@ -178,30 +181,30 @@ Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { ADD_OP_FOR_ALL_VEC_SIZE(name, 320) template -class GroupNormNHWCTunableOp : public TunableOp> { +class GroupNormNHWCTunableOp : public TunableOp> { public: GroupNormNHWCTunableOp() { this->RegisterOp(GroupNormNHWCStaticSelection); ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) #ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 55cd6a1d112f5..382a3951f3a83 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -93,6 +93,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Samp class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -246,6 +247,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py index e32cb032798fc..8334d20e47c86 100644 --- 
a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py @@ -35,7 +35,11 @@ def sigmoid_function(x): return 1.0 / (1.0 + np.exp(-x)) -def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish): +def group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, with_silu, has_skip): + add_output = None + if has_skip: + input_x = input_x + skip_x + bias_x + add_output = input_x n, h, w, c = input_x.shape input_x = input_x.transpose([0, 3, 1, 2]) assert c % num_groups == 0 @@ -45,46 +49,70 @@ def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish): x = x.transpose([0, 2, 3, 1]) x = x * gamma + beta - if with_swish: + if with_silu: x = x * sigmoid_function(x) - return x + return x, add_output -def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func): +def run_group_norm( + batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, silu: bool, has_skip: bool, func +): np.random.seed(0) width = height input_x = np.random.rand(batch_size, height, width, num_channels).astype(np.float32) gamma = np.random.rand(num_channels).astype(np.float32) beta = np.random.rand(num_channels).astype(np.float32) # the size of workspace is defined in onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h L18 - workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * 32 * 32).astype(np.float32) + workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * batch_size * num_groups).astype(np.float32) epsilon = 1e-05 output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - use_swish = swish - host_x = input_x.astype(dtype) - input_d = ke.DeviceArray(host_x) + skip_x = ( + np.random.rand(batch_size, height, width, num_channels).astype(np.float32) + if has_skip + else np.empty((0), dtype=dtype) + ) + bias_x = np.random.rand(num_channels).astype(np.float32) if has_skip else np.empty((0), dtype=dtype) + add_output = ( + np.random.rand(batch_size, height, width, num_channels).astype(dtype) + if has_skip + else np.empty((0), dtype=dtype) + ) + use_silu = silu + broadcast_skip = False + channels_per_block = 0 # Compute in params initialization + + input_d = ke.DeviceArray(input_x.astype(dtype)) + skip_d = ke.DeviceArray(skip_x.astype(dtype)) + bias_d = ke.DeviceArray(bias_x.astype(dtype)) gamma_d = ke.DeviceArray(gamma) beta_d = ke.DeviceArray(beta) workspace_d = ke.DeviceArray(workspace) y_d = ke.DeviceArray(output_y) + y_add_d = ke.DeviceArray(add_output) f = getattr(ke, func) my_op = f( y_d, - workspace_d, + y_add_d, input_d, + skip_d, + bias_d, gamma_d, beta_d, + workspace_d, + epsilon, batch_size, + num_channels, height, width, - num_channels, num_groups, - epsilon, - use_swish, + use_silu, + broadcast_skip, + channels_per_block, ) - y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype) + y_ref, y_add_d_ref = group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, use_silu, has_skip) + y_ref = y_ref.astype(dtype) for impl in my_op.ListOps(): if not my_op.SelectOp(impl): @@ -95,6 +123,10 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: y_d.UpdateHostNumpyArray() np.testing.assert_allclose(y_ref, output_y, atol=1e-02) + if has_skip: + y_add_d_ref = y_add_d_ref.astype(dtype) + y_add_d.UpdateHostNumpyArray() + np.testing.assert_allclose(y_add_d_ref, add_output, atol=1e-02) dtypes = ["float32", "float16"] @@ 
-102,19 +134,21 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: @pytest.mark.parametrize("sd_sizes", get_sd_sizes()) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("swish", [True]) -def test_group_norm(sd_sizes, dtype, swish): +@pytest.mark.parametrize("silu", [True]) +@pytest.mark.parametrize("has_skip", [True, False]) +def test_group_norm(sd_sizes, dtype, silu, has_skip): for func in dtype_to_funcs(dtype): - run_group_norm(*sd_sizes, dtype, swish, func) + run_group_norm(*sd_sizes, dtype, silu, has_skip, func) @pytest.mark.parametrize("sd_sizes", get_sd_sizes()) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("swish", [True]) -def test_group_norm_ck(sd_sizes, dtype, swish): - swish_suffix = "Swish" if swish else "Pass" - ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype) - run_group_norm(*sd_sizes, dtype, swish, ck_f_name) +@pytest.mark.parametrize("silu", [True]) +@pytest.mark.parametrize("has_skip", [False]) +def test_group_norm_ck(sd_sizes, dtype, silu, has_skip): + silu_suffix = "Silu" if silu else "Pass" + ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype) + run_group_norm(*sd_sizes, dtype, silu, has_skip, ck_f_name) @dataclass @@ -136,37 +170,67 @@ def report(self): def profile_group_norm_func( - batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func + batch_size: int, + height: int, + width: int, + num_channels: int, + num_groups: int, + dtype: str, + silu: bool, + has_skip: bool, + func, ): np.random.seed(0) input_x = np.random.rand(batch_size, height, width, num_channels).astype(dtype) gamma = np.random.rand(num_channels).astype(np.float32) beta = np.random.rand(num_channels).astype(np.float32) - workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * 32 * 32).astype(np.float32) + workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * batch_size * num_groups).astype(np.float32) epsilon = 0.05 output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - use_swish = swish + + skip_x = ( + np.random.rand(batch_size, height, width, num_channels).astype(dtype) + if has_skip + else np.empty((0), dtype=dtype) + ) + bias_x = np.random.rand(num_channels).astype(dtype) if has_skip else np.empty((0), dtype=dtype) + add_output = ( + np.random.rand(batch_size, height, width, num_channels).astype(dtype) + if has_skip + else np.empty((0), dtype=dtype) + ) + use_silu = silu + broadcast_skip = False + channels_per_block = 0 # Compute in params initialization input_d = ke.DeviceArray(input_x) + skip_d = ke.DeviceArray(skip_x) + bias_d = ke.DeviceArray(bias_x) gamma_d = ke.DeviceArray(gamma) beta_d = ke.DeviceArray(beta) workspace_d = ke.DeviceArray(workspace) y_d = ke.DeviceArray(output_y) + y_add_d = ke.DeviceArray(add_output) f = getattr(ke, func) my_op = f( y_d, - workspace_d, + y_add_d, input_d, + skip_d, + bias_d, gamma_d, beta_d, + workspace_d, + epsilon, batch_size, + num_channels, height, width, - num_channels, num_groups, - epsilon, - use_swish, + use_silu, + broadcast_skip, + channels_per_block, ) for impl in my_op.ListOps(): duration_ms = -1 @@ -181,14 +245,14 @@ def profile_group_norm_func( ) -def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, swish=True, sort=True): +def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, silu=True, has_skip=True, sort=True): with ke.benchmark(sort): for func in dtype_to_funcs(dtype): 
- profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, func) + profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, func) # ck function - swish_suffix = "Swish" if swish else "Pass" - ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype) - profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, ck_f_name) + silu_suffix = "Silu" if silu else "Pass" + ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype) + profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, ck_f_name) sd_profile_sizes = [ @@ -227,7 +291,8 @@ def profile(): group.add_argument("num_channels", type=int) group.add_argument("num_groups", type=int) group.add_argument("dtype", choices=dtypes) - group.add_argument("--swish", action="store_true") + group.add_argument("--silu", action="store_true") + group.add_argument("--has_skip", action="store_true") group.add_argument("--sort", action="store_true") if len(sys.argv) == 1: @@ -241,6 +306,7 @@ def profile(): args.num_channels, args.num_groups, args.dtype, - args.swish, + args.silu, + args.has_skip, args.sort, ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu index 0bd47b2c0387e..6af163ab94b10 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu @@ -12,17 +12,21 @@ #include "python/tools/kernel_explorer/kernel_explorer_interface.h" namespace py = pybind11; - +using onnxruntime::contrib::rocm::GetGroupNormWorkspaceSizeInBytes; namespace onnxruntime { template class GroupNormNHWC : public IKernelExplorer { public: - GroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + GroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& bias, + DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu, + bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { type_string_ = "GroupNormNHWC_" + std::to_string(ThreadsPerBlock) + "_" + std::to_string(VecSize); } @@ -40,7 +44,7 @@ class GroupNormNHWC : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; ParamsT params_{}; contrib::rocm::GroupNormNHWCOp op_{}; std::string type_string_{}; @@ -49,11 +53,15 @@ class GroupNormNHWC : public IKernelExplorer { template class GroupNormNHWCStaticSelection : public 
IKernelExplorer { public: - GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { type_string_ = "GroupNormNHWCStaticSelection"; } @@ -71,7 +79,7 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; ParamsT params_{}; std::string type_string_{}; }; @@ -79,11 +87,15 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer { template class GroupNormNHWCTunable : public IKernelExplorer { public: - GroupNormNHWCTunable(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + GroupNormNHWCTunable(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { params_.TuningContext()->EnableTunableOpAndTuning(); } @@ -100,21 +112,25 @@ class GroupNormNHWCTunable : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; ParamsT params_{}; contrib::rocm::GroupNormNHWCTunableOp op_{}; }; #ifdef USE_COMPOSABLE_KERNEL -template +template class CKGroupNormNHWC : public IKernelExplorer { public: - CKGroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool 
use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { - for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps()) { + CKGroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { + for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } @@ -141,7 +157,7 @@ class CKGroupNormNHWC : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; using OpT = rocm::tunable::Op; ParamsT params_{}; std::vector ops_; @@ -151,15 +167,19 @@ class CKGroupNormNHWC : public IKernelExplorer { #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL -template +template class GroupNormNHWCTriton : public IKernelExplorer { public: - GroupNormNHWCTriton(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { - for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps()) { + GroupNormNHWCTriton(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { + for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps()) { name_strings_.emplace_back(name); ops_.emplace_back(std::move(op)); } @@ -186,7 +206,7 @@ class GroupNormNHWCTriton : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; using OpT = rocm::tunable::Op; ParamsT params_{}; std::vector ops_; @@ -198,7 +218,8 @@ class GroupNormNHWCTriton : public IKernelExplorer { #define REGISTER_OP(name, type, threads_per_block, vec_size) 
\ py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ .def(py::init()) \ + DeviceArray&, DeviceArray&, DeviceArray&, float, \ + int, int, int, int, int, bool, bool, int>()) \ .def("SetRepeats", &name::SetRepeats) \ .def("Profile", &name::Profile) \ .def("Run", &name::Run) \ @@ -220,7 +241,8 @@ class GroupNormNHWCTriton : public IKernelExplorer { #define REGISTER_COMMON(name, type, ...) \ py::class_>(m, name) \ .def(py::init()) \ + DeviceArray&, DeviceArray&, DeviceArray&, float, \ + int, int, int, int, int, bool, bool, int>()) \ .def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \ .def("Profile", &type<__VA_ARGS__>::Profile) \ .def("Run", &type<__VA_ARGS__>::Run) \ @@ -230,11 +252,11 @@ class GroupNormNHWCTriton : public IKernelExplorer { #define REGISTER_OP_TYPED(name, type) \ REGISTER_COMMON(#name "_" #type, name, type) -#define REGISTER_CK(type, with_swish, swish_suffix) \ - REGISTER_COMMON("CKGroupNormNHWC" swish_suffix "_" #type, CKGroupNormNHWC, type, with_swish) +#define REGISTER_CK(type, with_silu, silu_suffix) \ + REGISTER_COMMON("CKGroupNormNHWC" silu_suffix "_" #type, CKGroupNormNHWC, type, with_silu) -#define REGISTER_TRITON(type, with_swish, swish_suffix) \ - REGISTER_COMMON("GroupNormNHWCTriton" swish_suffix "_" #type, GroupNormNHWCTriton, type, with_swish) +#define REGISTER_TRITON(type, with_silu, silu_suffix) \ + REGISTER_COMMON("GroupNormNHWCTriton" silu_suffix "_" #type, GroupNormNHWCTriton, type, with_silu) KE_REGISTER(m) { REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWC, half); @@ -248,16 +270,16 @@ KE_REGISTER(m) { #ifdef USE_COMPOSABLE_KERNEL REGISTER_CK(half, false, "Pass"); - REGISTER_CK(half, true, "Swish"); + REGISTER_CK(half, true, "Silu"); REGISTER_CK(float, false, "Pass"); - REGISTER_CK(float, true, "Swish"); + REGISTER_CK(float, true, "Silu"); #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL REGISTER_TRITON(half, false, "Pass"); - REGISTER_TRITON(half, true, "Swish"); + REGISTER_TRITON(half, true, "Silu"); REGISTER_TRITON(float, false, "Pass"); - REGISTER_TRITON(float, true, "Swish"); + REGISTER_TRITON(float, true, "Silu"); #endif } diff --git a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc index fefd5722054de..ea8537f243f5d 100644 --- a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc @@ -114,16 +114,21 @@ TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) { int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); std::array channels_last_values = {-1, 1}; for (const int channels_last : channels_last_values) { - if (enable_cuda) { + if (enable_cuda || enable_rocm) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); } + if (enable_rocm && channels_last != 0) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } + // Don't run the test if no providers are supported if (execution_providers.empty()) { continue; @@ -230,6 +235,7 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); std::array has_add_out_values = {true, false}; std::array skip_dims = {2, 4}; @@ -237,12 +243,16 @@ 
TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { constexpr int channels_last = 1; for (const int skip_dim : skip_dims) { for (const bool has_add_out : has_add_out_values) { - if (enable_cuda) { + if (enable_cuda || enable_rocm) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); } + if (enable_rocm && channels_last != 0) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } + // Don't run the test if no providers are supported if (execution_providers.empty()) { continue; diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index e286236ba6447..f1d3702e3245e 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -181,6 +181,8 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("rocm_device_prop_", "cuda_device_prop_") s = s.replace("rocm_device_arch_", "cuda_device_arch_") + s = s.replace("HipTuningContext", "RocmTuningContext") + # We want hipfft, which needs hipDataType etc, but only do this for files that have "fft" in their names # And we do this last, undoing or fixing hipify mistakes. if "fft" in src_file_path: From 124bde985ae883566c44f5cd84d351612006100c Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 20 Feb 2024 19:20:42 -0800 Subject: [PATCH 027/237] Bring QAT POC back to a functional state (#19290) --- .../test/python/qat_poc_example/README.md | 2 +- .../test/python/qat_poc_example/model.py | 56 +++++++------------ .../test/python/qat_poc_example/qat.py | 2 +- .../test/python/qat_poc_example/train.py | 18 ++---- 4 files changed, 27 insertions(+), 51 deletions(-) diff --git a/orttraining/orttraining/test/python/qat_poc_example/README.md b/orttraining/orttraining/test/python/qat_poc_example/README.md index 6840e98bd9c86..05072b410b730 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/README.md +++ b/orttraining/orttraining/test/python/qat_poc_example/README.md @@ -48,7 +48,7 @@ We use `onnxruntime.training.onnxblock` to perform the above operations to get t > **_NOTE:_** As of this writing, ORT does not have its own `"Observers"`. Instead, we rely on the `onnxruntime.quantization` tool to quantize the model and give us an initial estimate of the quantization parameters using its calibration process. Here the calibration process is used as a substitute for the observers to present the POC. -> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag AddQDQPairToWeight=True` +> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag `AddQDQPairToWeight=True` > **_NOTE:_** Typically, the bias term in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since it is quantized as int32 as opposed to int8. 
So, we disable quantizing the bias term using the flag QuantizeBias=False` diff --git a/orttraining/orttraining/test/python/qat_poc_example/model.py b/orttraining/orttraining/test/python/qat_poc_example/model.py index 91d7ccd7294f5..601362a59e379 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/model.py +++ b/orttraining/orttraining/test/python/qat_poc_example/model.py @@ -5,7 +5,7 @@ import onnx import torch -import onnxruntime.training.onnxblock as onnxblock +from onnxruntime.training import artifacts class MNIST(torch.nn.Module): @@ -96,42 +96,26 @@ def create_training_artifacts(model_path, artifacts_dir, model_prefix): 4. The checkpoint file """ - class MNISTWithLoss(onnxblock.TrainingModel): - def __init__(self): - super().__init__() - self.loss = onnxblock.loss.CrossEntropyLoss() - - def build(self, output_name): - return self.loss(output_name) - - mnist_with_loss = MNISTWithLoss() - onnx_model, eval_model, optimizer_model = onnx.load(model_path), None, None - - # Build the training and eval graphs - logging.info("Using onnxblock to create the training artifacts.") - with onnxblock.onnx_model(onnx_model) as model_accessor: - _ = mnist_with_loss(onnx_model.graph.output[0].name) - eval_model = model_accessor.eval_model - - # Build the optimizer graph - optimizer = onnxblock.optim.AdamW() - with onnxblock.onnx_model() as accessor: - _ = optimizer(mnist_with_loss.parameters()) - optimizer_model = accessor.model + onnx_model = onnx.load(model_path) + + requires_grad = [ + param.name + for param in onnx_model.graph.initializer + if (not param.name.endswith("_scale") and not param.name.endswith("_zero_point")) + ] + artifacts.generate_artifacts( + onnx_model, + requires_grad=requires_grad, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=artifacts_dir, + prefix=model_prefix, + ) # Create the training artifacts - train_model_path = os.path.join(artifacts_dir, f"{model_prefix}_train.onnx") - logging.info(f"Saving the training model to {train_model_path}.") - onnx.save(onnx_model, train_model_path) - eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}_eval.onnx") - logging.info(f"Saving the eval model to {eval_model_path}.") - onnx.save(eval_model, eval_model_path) - optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}_optimizer.onnx") - logging.info(f"Saving the optimizer model to {optimizer_model_path}.") - onnx.save(optimizer_model, optimizer_model_path) - trainable_params, non_trainable_params = mnist_with_loss.parameters() - checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}_checkpoint.ckpt") - logging.info(f"Saving the checkpoint to {checkpoint_path}.") - onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_path) + train_model_path = os.path.join(artifacts_dir, f"{model_prefix}training_model.onnx") + eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}eval_model.onnx") + optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}optimizer_model.onnx") + checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}checkpoint") return train_model_path, eval_model_path, optimizer_model_path, checkpoint_path diff --git a/orttraining/orttraining/test/python/qat_poc_example/qat.py b/orttraining/orttraining/test/python/qat_poc_example/qat.py index 51a15475ee911..dcc9e116fda7d 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/qat.py +++ b/orttraining/orttraining/test/python/qat_poc_example/qat.py @@ -46,7 +46,7 @@ ) 
logging.info("Preparing the training artifacts for QAT.") - training_model_name = "mnist_qat" + training_model_name = "mnist_qat_" artifacts_dir = os.path.join(model_dir, "training_artifacts") utils.makedir(artifacts_dir) training_artifacts = create_training_artifacts( diff --git a/orttraining/orttraining/test/python/qat_poc_example/train.py b/orttraining/orttraining/test/python/qat_poc_example/train.py index 9a429d2adc6f1..a25c071c58a48 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/train.py +++ b/orttraining/orttraining/test/python/qat_poc_example/train.py @@ -26,14 +26,10 @@ def _train_epoch(model, optimizer, train_loader): model.train() cumulative_loss = 0 for data, target in train_loader: - forward_inputs = [ - data.reshape(len(data), 784).numpy(), - target.numpy().astype(np.int32), - ] - train_loss = model(forward_inputs) + train_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)) optimizer.step() model.lazy_reset_grad() - cumulative_loss += train_loss[0] + cumulative_loss += train_loss return cumulative_loss / len(train_loader) @@ -43,12 +39,8 @@ def _eval(model, test_loader): model.eval() cumulative_loss = 0 for data, target in test_loader: - forward_inputs = [ - data.reshape(len(data), 784).numpy(), - target.numpy().astype(np.int32), - ] - test_loss = model(forward_inputs) - cumulative_loss += test_loss[0] + test_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)) + cumulative_loss += test_loss return cumulative_loss / len(test_loader) @@ -65,7 +57,7 @@ def train_model(qat_train_model, qat_eval_model, qat_optimizer_model, qat_checkp train_loader, test_loader = _get_dataloaders("data", batch_size) # Load the checkpoint state. - state = orttraining.CheckpointState(qat_checkpoint) + state = orttraining.CheckpointState.load_checkpoint(qat_checkpoint) # Create the training module. model = orttraining.Module(qat_train_model, state, qat_eval_model) From 8092a89688f92dee83d1d0111acaa1e1d2dfdb85 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Tue, 20 Feb 2024 21:18:54 -0800 Subject: [PATCH 028/237] Changed command line argpasrse to process '--symmetric [True|False]'. (#19577) ### Description Accept the command line option --symmetric and its optional value correctly. If the optional value matches uncased to 'True' then set symmetric to True else set symmetric to False. Asymmetric quantization will generate zero_point input. ``` usage: matmul_4bits_quantizer.py [-h] --input_model INPUT_MODEL --output_model OUTPUT_MODEL [--block_size BLOCK_SIZE] [--symmetric [{True,False}]] [--accuracy_level ACCURACY_LEVEL] [-v] [--nodes_to_exclude NODES_TO_EXCLUDE [NODES_TO_EXCLUDE ...]] ``` ### Motivation and Context --- .../python/tools/quantization/matmul_4bits_quantizer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 3e9f9a6544a71..eb7bbec997d59 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -349,6 +349,10 @@ def process(self): self.int4_quant_algo() +def ort_convert_str_to_bool(value): + return value.lower() in ("true", "1") + + def parse_args(): parser = argparse.ArgumentParser( description="""Blockwise int4 quantization for MatMul 2D weight matrices. 
@@ -366,7 +370,10 @@ def parse_args(): "--symmetric", required=False, default=True, - type=bool, + const=True, + nargs="?", + type=ort_convert_str_to_bool, + choices=[True, False], help="Indicate whether to quantize the model symmetrically", ) parser.add_argument( From 58f4921686bf0a5b0442fb6df92d1b1972a118cc Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 21 Feb 2024 00:31:06 -0800 Subject: [PATCH 029/237] [js] changes to allow Float16Array if any polyfill is available (#19305) ### Description This change adds only necessary code to enable ort-web works with any Float16Array polyfill. Unlike #19302, in this PR, ort-web does not include any specific polyfill; instead, it's user's choice for how to use a polyfill. ORT-web uses Float16Array if it's available; otherwise, fallback to use Uint16Array. ```js // case 1: user does not use polyfill: import * as ort from 'onnxruntime-web'; const myF16Data = new Uint16Array(...); // need to use Uint16Array const myF16tensor = new ort.Tensor('float16', myF16Data, dims); ``` ```js // case 2: user use polyfill: import * as ort from 'onnxruntime-web'; import { Float16Array, isFloat16Array, isTypedArray, getFloat16, setFloat16, f16round, } from "@petamoriken/float16"; globalThis.Float16Array = Float16Array; // ort-web will pick the global Float16Array const myF16Data = new Float16Array(...); // Use the polyfilled Float16Array type const myF16tensor = new ort.Tensor('float16', myF16Data, dims); ``` --- js/common/lib/tensor-impl-type-mapping.ts | 34 +++++++++++++++-------- js/common/lib/tensor-impl.ts | 10 ++++--- js/web/lib/wasm/wasm-common.ts | 9 +++++- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts index c4a43ea27fea1..b29cb8cbd6d35 100644 --- a/js/common/lib/tensor-impl-type-mapping.ts +++ b/js/common/lib/tensor-impl-type-mapping.ts @@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map { - if (!isBigIntChecked) { - isBigIntChecked = true; - const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function'; - const isBigUint64ArrayAvailable = - typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + +// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for +// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array +// polyfill if available. 
+let isTypedArrayChecked = false; +export const checkTypedArray = () => { + if (!isTypedArrayChecked) { + isTypedArrayChecked = true; + const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from; + const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from; + const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from; if (isBigInt64ArrayAvailable) { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array); @@ -53,5 +58,12 @@ export const checkBigInt = () => { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array); NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64'); } + if (isFloat16ArrayAvailable) { + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array); + NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16'); + } else { + // if Float16Array is not available, use 'Uint16Array' to store the data. + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array); + } } }; diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index de18126a9d0ae..56682ef98e117 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js'; import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js'; import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js'; import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js'; -import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; +import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; import {calculateSize, tensorReshape} from './tensor-utils-impl.js'; import {Tensor as TensorInterface} from './tensor.js'; @@ -67,8 +67,8 @@ export class Tensor implements TensorInterface { arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters| TextureConstructorParameters|GpuBufferConstructorParameters, arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) { - // perform one-time check for BigInt support - checkBigInt(); + // perform one-time check for BigInt/Float16Array support + checkTypedArray(); let type: TensorType; let dims: readonly number[]; @@ -142,7 +142,9 @@ export class Tensor implements TensorInterface { throw new TypeError(`Unsupported tensor type: ${arg0}.`); } if (Array.isArray(arg1)) { - if (arg0 === 'float16') { + if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) { + // When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array. + // // Throw error here because when user try to use number array as data, // e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call // Uint16Array.from(arg1) which generates wrong data. 
diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index 93910af1f1bf0..54eaf5e0c43cc 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -3,6 +3,12 @@ import {Tensor} from 'onnxruntime-common'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + // This file includes common definitions. They do NOT have dependency on the WebAssembly instance. /** @@ -117,7 +123,8 @@ export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32Arr Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => { switch (type) { case 'float16': - return Uint16Array; + // allow Float16Array polyfill. + return typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : Uint16Array; case 'float32': return Float32Array; case 'uint8': From 57d6819212464f49b30db047528be0f409dadc67 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Thu, 22 Feb 2024 00:08:47 +0800 Subject: [PATCH 030/237] [js/web] Fix fused-conv is not included in npm test (#19581) BUG: https://github.com/microsoft/onnxruntime/issues/18855 ### Description ### Motivation and Context --- js/web/test/suite-test-list.jsonc | 1 + 1 file changed, 1 insertion(+) diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 1c61518ddcdd2..b43b1ac37e37d 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1354,6 +1354,7 @@ "expand.jsonc", "fast-gelu.jsonc", "floor.jsonc", + "fused-conv.jsonc", "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", From e5ce81ae847d0b347a3dfe95abfc9e407e2f0469 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Wed, 21 Feb 2024 15:24:41 -0500 Subject: [PATCH 031/237] [java] Adding ML program flag for CoreML (#19551) ### Description Adds the new CoreML enum flags to enable ML Program support in Java. ### Motivation and Context Adds support for #19347 to the Java API. --- .../ai/onnxruntime/providers/CoreMLFlags.java | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java index eb124decf75f3..cec3fadf446ca 100644 --- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -14,7 +14,18 @@ public enum CoreMLFlags implements OrtFlags { /** Enables CoreML on subgraphs. */ ENABLE_ON_SUBGRAPH(2), // COREML_FLAG_ENABLE_ON_SUBGRAPH(0x002) /** Only enable usage of CoreML if the device has an Apple Neural Engine. */ - ONLY_ENABLE_DEVICE_WITH_ANE(4); // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004), + ONLY_ENABLE_DEVICE_WITH_ANE(4), // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004) + /** + * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also + * allow inputs with dynamic shapes. However, the performance may be negatively impacted if inputs + * have dynamic shapes. 
+ */ + ONLY_ALLOW_STATIC_INPUT_SHAPES(8), // COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES(0x008) + /** + * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or + * later. + */ + CREATE_MLPROGRAM(16); // COREML_FLAG_CREATE_MLPROGRAM(0x010) /** The native value of the enum. */ public final int value; From 3afb38cfb7d4263f262dea33bcfa16d35c67fede Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 21 Feb 2024 12:46:16 -0800 Subject: [PATCH 032/237] [CUDA] Add use_tf32 cuda provider option (for FP32 Conv) (#19426) Follow up of https://github.com/microsoft/onnxruntime/pull/19357 to apply the use_tf32 option on fp32 cuDNN convolution. When use_tf32 = 0, we will disable TF32 in cuDNN convolution for FP32 inputs. https://docs.nvidia.com/deeplearning/cudnn/api/cudnn-graph-library.html#cudnnmathtype-t **CUDNN_FMA_MATH** - Restricted to only kernels that use FMA instructions. - On pre-NVIDIA A100 GPU devices, CUDNN_DEFAULT_MATH and CUDNN_FMA_MATH have the same behavior: Tensor Core kernels will not be selected. - With NVIDIA Ampere architecture and CUDA toolkit 11, CUDNN_DEFAULT_MATH permits TF32 Tensor Core operation and CUDNN_FMA_MATH does not. - The TF32 behavior for CUDNN_DEFAULT_MATH and the other Tensor Core math types can be explicitly disabled by the environment variable NVIDIA_TF32_OVERRIDE=0. --- onnxruntime/core/providers/cuda/nn/conv.cc | 17 ++++++++++++++--- onnxruntime/core/providers/cuda/nn/conv.h | 3 ++- .../core/providers/cuda/nn/conv_transpose.cc | 10 ++++++++-- .../training_ops/cuda/nn/conv_grad.cc | 3 ++- .../training_ops/cuda/nn/conv_shared.cc | 6 ++++-- .../training_ops/cuda/nn/conv_shared.h | 2 +- .../training_ops/cuda/nn/conv_transpose_grad.cc | 6 ++++-- 7 files changed, 35 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 82f3503919237..a417be5a86c32 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -326,7 +326,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), - CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType())); + CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType(), + UseTF32())); if (context->InputCount() >= 3) { const Tensor* B = context->Input(2); @@ -351,8 +352,13 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) { // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) + if constexpr (std::is_same::value) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } cudnnConvolutionFwdAlgoPerf_t perf; int algo_count = 1; @@ -399,6 +405,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); if (std::is_same::value) { perf.mathType = CUDNN_TENSOR_OP_MATH; + } else if (std::is_same::value && !UseTF32()) { + perf.mathType = CUDNN_FMA_MATH; } else { perf.mathType = CUDNN_DEFAULT_MATH; } @@ -480,7 +488,8 @@ Status CudnnConvolutionDescriptor::Set( const gsl::span& dilations, int groups, cudnnConvolutionMode_t mode, - cudnnDataType_t 
data_type) { + cudnnDataType_t data_type, + bool use_tf32) { if (!desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_)); @@ -513,6 +522,8 @@ Status CudnnConvolutionDescriptor::Set( CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH)); if (data_type == CUDNN_DATA_HALF) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH)); + } else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH)); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index bcaa4d855b81e..181fbc99fd8e9 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -29,7 +29,8 @@ class CudnnConvolutionDescriptor final { const gsl::span& dilations, int groups, cudnnConvolutionMode_t mode, - cudnnDataType_t data_type); + cudnnDataType_t data_type, + bool use_tf32); operator cudnnConvolutionDescriptor_t() const { return desc_; } diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 55dceaa2698e8..939b9959af818 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -167,7 +167,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, gsl::narrow_cast(conv_transpose_attrs_.group), mode, - CudnnTensor::GetDataType())); + CudnnTensor::GetDataType(), + UseTF32())); if (has_bias) { const auto& b_shape = p.B->Shape(); @@ -187,8 +188,13 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) + if constexpr (std::is_same::value) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } cudnnConvolutionBwdDataAlgoPerf_t perf; int algo_count = 1; diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc index f6c58445c0a5d..fc5d9b65d0f89 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc @@ -114,7 +114,8 @@ Status ConvGrad::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type)); ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args_.params.data_type)); + args_.params.data_type, + UseTF32())); if (dB) { const TensorShape& db_shape = dB->Shape(); diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc index 5dc16c68f6210..d23905496c9bb 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc @@ -233,11 +233,13 @@ bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const } template -Status 
AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { +Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results, bool use_tf32) { perf_results.resize(1); perf_results[0].algo = AlgoSearch::DEFAULT_ALGO; if (args.params.data_type == CUDNN_DATA_HALF) { perf_results[0].mathType = CUDNN_TENSOR_OP_MATH; + } else if (args.params.data_type == CUDNN_DATA_FLOAT && !use_tf32) { + perf_results[0].mathType = CUDNN_FMA_MATH; } else { perf_results[0].mathType = CUDNN_DEFAULT_MATH; } @@ -256,7 +258,7 @@ Status AlgoIterator::TryAll(const CUDAExecutionProvider* provider, const std::vector perf_results; ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault - ? OnlyDefaultAlgorithm(args_, perf_results) + ? OnlyDefaultAlgorithm(args_, perf_results, provider->UseTF32()) : AlgoSearch::FindAlgorithms(args_, provider, allocator, perf_results)); for (auto& algo_perf : perf_results) { if (f(algo_perf) == Status::OK()) { diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h index a2d4bf3bdc006..3fdb4306bfbbb 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h @@ -75,7 +75,7 @@ class AlgoIterator { Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, std::function f); - static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results); + static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results, bool use_tf32); private: const ConvArgs& args_; diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc index 5f7206fc121ec..d3f5a89434a48 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc @@ -182,7 +182,8 @@ Status ConvTransposeGrad::PrepareConvForwardArgs(const Tensor& X, const Tenso ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args.params.data_type)); + args.params.data_type, + UseTF32())); } return Status::OK(); @@ -287,7 +288,8 @@ Status ConvTransposeGrad::PrepareConvBackwardFilterArgs(const Tensor& X, cons ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args.params.data_type)); + args.params.data_type, + UseTF32())); if (dB) { const auto& b_shape = dB->Shape(); From ebd220b0730f9898aaa0275ef0d8195ce70057d0 Mon Sep 17 00:00:00 2001 From: Matttttt <18152455+martholomew@users.noreply.github.com> Date: Wed, 21 Feb 2024 21:38:18 +0000 Subject: [PATCH 033/237] Misspelling in README.md (#19433) Fixed a misspelling. 
--- js/web/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/README.md b/js/web/README.md index c75a40ad6da28..906c78a1b7ec4 100644 --- a/js/web/README.md +++ b/js/web/README.md @@ -12,7 +12,7 @@ The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard f With ONNX Runtime Web, web developers can score models directly on browsers with various benefits including reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience. -ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web complies the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. +ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web compiles the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports. @@ -22,7 +22,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun ## Documents -### Developement +### Development Refer to the following links for development information: From 38c34323939bac03b9648b2e59dbbe8de0bd7092 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 21 Feb 2024 13:58:53 -0800 Subject: [PATCH 034/237] Bump ip from 1.1.8 to 1.1.9 in /js/react_native (#19582) Bumps [ip](https://github.com/indutny/node-ip) from 1.1.8 to 1.1.9.
[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ip&package-manager=npm_and_yarn&previous-version=1.1.8&new-version=1.1.9)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) Dependabot will merge this PR once CI passes on it, as requested by @fs-eire. [//]: # (dependabot-automerge-end) ---
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/react_native/yarn.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index 4dca90d7415cf..bbb0c4f3d1e22 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -3701,9 +3701,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-absolute@^1.0.0: version "1.0.0" From 5197db19802a39e47d19ac829cd08a94bacbdfbb Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Wed, 21 Feb 2024 15:45:44 -0800 Subject: [PATCH 035/237] Diable __cpuid call for ARM64EC (#19592) Diable __cpuid call for ARM64EC Co-authored-by: Sheil Kumar --- winml/lib/Api/HardwareCoreEnumerator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp index b6b44690f4f6c..d04e276347170 100644 --- a/winml/lib/Api/HardwareCoreEnumerator.cpp +++ b/winml/lib/Api/HardwareCoreEnumerator.cpp @@ -84,7 +84,7 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. auto cores = GetNumberOPhysicalAndEngineeringCores(); -#if !defined(_M_ARM64) && !defined(__aarch64__) +#if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__) const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" int regs_leaf0[4]; int regs_leaf7[4]; From 3d88487c96bf467c4b83dff179c9e282602e2d64 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 22 Feb 2024 10:35:26 +0800 Subject: [PATCH 036/237] Minor Triton Fix (#19589) Including removing a unnecessary assert, and add support of passing string attribute from ONNX node attribute to python functoin kwargs (mainly for passing debug info from graph to python for now). 
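As a side note on the `next_power_of_2` change: removing the 32-bit-only assert and adding the extra `n |= n >> 32` step lets the helper round values above 2**32 correctly as well. A minimal standalone check of the updated helper, with illustrative sample inputs (not taken from any real workload):

```python
def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 greater than or equal to n"""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32  # extra step so inputs larger than 2**32 are handled too
    n += 1
    return n


# Illustrative values only.
assert next_power_of_2(1) == 1
assert next_power_of_2(3000) == 4096
assert next_power_of_2(2**32 + 1) == 2**33  # would have tripped the removed 32-bit assert
```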
--- .../orttraining/core/framework/triton/triton_op_executor.cc | 2 ++ orttraining/orttraining/python/training/ort_triton/_utils.py | 3 ++- orttraining/orttraining/training_ops/cpu/triton/triton_op.h | 5 ++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/core/framework/triton/triton_op_executor.cc b/orttraining/orttraining/core/framework/triton/triton_op_executor.cc index 092ab89d5d760..f30d6ddee253a 100644 --- a/orttraining/orttraining/core/framework/triton/triton_op_executor.cc +++ b/orttraining/orttraining/core/framework/triton/triton_op_executor.cc @@ -106,6 +106,8 @@ void TritonOpExecutor::ExecuteByFuncName(const std::string& func_name, const Inl PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyLong_FromLongLong(std::stoll(kv.second.first))); } else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyFloat_FromDouble(std::stod(kv.second.first))); + } else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_STRING) { + PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyUnicode_FromString(kv.second.first.c_str())); } else { ORT_THROW("Unsupported kwargs data type: ", kv.second.second); } diff --git a/orttraining/orttraining/python/training/ort_triton/_utils.py b/orttraining/orttraining/python/training/ort_triton/_utils.py index 95e6703be8783..877eacc0b775f 100644 --- a/orttraining/orttraining/python/training/ort_triton/_utils.py +++ b/orttraining/orttraining/python/training/ort_triton/_utils.py @@ -141,13 +141,14 @@ def get_reduce_info(node: NodeProto, graph: GraphProto, input_rank: int) -> Tupl def next_power_of_2(n: int) -> int: - assert n <= 2**32, "32-bit only" + """Return the smallest power of 2 greater than or equal to n""" n -= 1 n |= n >> 1 n |= n >> 2 n |= n >> 4 n |= n >> 8 n |= n >> 16 + n |= n >> 32 n += 1 return n diff --git a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h index f226db76f7ed7..db8e8558ab884 100644 --- a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h +++ b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h @@ -25,12 +25,15 @@ class TritonOp final : public OpKernel { attr.first == "onnx_string") { continue; } - // Support int64 and float only for now, skip other types. + // Support int64, float and string only for now, skip other types. 
if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT) { kwargs_.insert({attr.first, {std::to_string(attr.second.i()), ONNX_NAMESPACE::TensorProto_DataType_INT64}}); } else if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_FLOAT) { kwargs_.insert({attr.first, {std::to_string(attr.second.f()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT}}); + } else if (attr.second.type() == + ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_STRING) { + kwargs_.insert({attr.first, {attr.second.s(), ONNX_NAMESPACE::TensorProto_DataType_STRING}}); } } } From 8354329086ebb190db9ea0cb6a3fa72f53f8f881 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:34:45 +0800 Subject: [PATCH 037/237] [ROCm] SkipGroupNorm triton (#19408) Change GroupNorm triton to support SkipGroupNorm --- .../rocm/diffusion/group_norm_triton.cuh | 23 ++++++++--- .../rocm/diffusion/group_norm_triton.py | 39 +++++++++++++++++-- .../kernel_explorer/kernels/groupnorm_test.py | 12 ++++++ 3 files changed, 64 insertions(+), 10 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index b3d3e92209b39..c6ca16bfdfc80 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -46,8 +46,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { auto block_size = metadata->constants.at("BLOCK_SIZE"); auto hw_size = metadata->constants.at("HW_SIZE"); auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), - "Input skip or bias is not supported by triton kernel."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", @@ -61,23 +59,36 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { } // Construct args for launch kernel struct { - void* X; - void* Y; + const void* src; + const void* skip; + const void* bias; + void* out; + void* add_out; const void* gamma; const void* beta; int hw; int c; int c_per_group; float eps; + bool has_skip; + bool has_bias; + bool broadcast_skip; } args = { - (void*)params->src, + (const void*)params->src, + (const void*)params->skip, + (const void*)params->bias, (void*)params->dst, + (void*)params->skip_workspace, (const void*)params->gamma, (const void*)params->beta, params->hw, params->c, params->channels_per_group, - params->epsilon}; + params->epsilon, + params->skip != nullptr, + params->bias != nullptr, + params->broadcast_skip, + }; // Grid dim is (batch_count, groups, 1) return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index 5368cb1cf635b..5ba96ebc117f0 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -12,13 +12,19 @@ @triton.jit def group_norm_kernel( input_ptr, + skip_ptr, + bias_ptr, output_ptr, + add_out_ptr, gamma_ptr, beta_ptr, img_size, c, c_per_group, eps, + has_skip, + 
has_bias, + broadcast_skip, BLOCK_SIZE: tl.constexpr, HW_SIZE: tl.constexpr, ACTIVATION_SILU: tl.constexpr, @@ -36,14 +42,35 @@ def group_norm_kernel( offsets = hw[:, None] * c + cols[None, :] mask = (cols < c_per_group)[None, :] + bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + if has_skip: + add_out_ptr += row_x * stride + row_y * c_per_group + if broadcast_skip: + broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group + bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + else: + skip_ptr += row_x * stride + row_y * c_per_group + if has_bias: + bias_ptr += row_y * c_per_group + bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + # Calculate mean and variance _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): x_ptr = input_ptr + i * HW_SIZE * c a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip and not broadcast_skip: + s_ptr = skip_ptr + i * HW_SIZE * c + s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + a += s + if has_bias or broadcast_skip: + a += bias _sum += a _square_sum += a * a + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + tl.store(add_y_ptr + offsets, a, mask=mask) # Set axis=None (or leave it unspecified) to reduce all axes. # TODO: In older Triton we have to reduce an axis at a time, but in our case @@ -57,9 +84,13 @@ def group_norm_kernel( gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): - x_ptr = input_ptr + i * HW_SIZE * c y_ptr = output_ptr + i * HW_SIZE * c - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + else: + x_ptr = input_ptr + i * HW_SIZE * c + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) x_hat = (x - group_mean) * rstd y = x_hat * gamma + beta if ACTIVATION_SILU: @@ -77,7 +108,7 @@ def group_norm_kernel( hw_sizes = [8, 16, 32, 64, 128, 256] warps = [1, 2, 4, 8, 16] name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32" +sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" group_pattern = "GroupNormTriton_{}_{}" @@ -88,7 +119,7 @@ def get_function_table(): silu_suffix = "Silu" if silu else "Pass" name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) group = group_pattern.format(silu_suffix, dtype) - sig = sig_pattern.format(dtype, dtype) + sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) kwargs = { "num_warps": warp, "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py index 8334d20e47c86..400a9d8a7a187 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py @@ -80,6 +80,18 @@ def run_group_norm( ) use_silu = silu broadcast_skip = False + if has_skip: + skip_x_shape = skip_x.shape + b2 = len(skip_x_shape) == 2 and skip_x_shape[0] == batch_size and skip_x_shape[1] == num_channels + b4 = ( + len(skip_x_shape) == 4 + and skip_x_shape[0] == 
batch_size + and skip_x_shape[1] == 1 + and skip_x_shape[2] == 1 + and skip_x_shape[3] == num_channels + ) + if b2 or b4: + broadcast_skip = True channels_per_block = 0 # Compute in params initialization input_d = ke.DeviceArray(input_x.astype(dtype)) From 05ed89f46980b7e5a5328bc20af8b32ca9f1f715 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:34:55 +0800 Subject: [PATCH 038/237] [ROCm] Add excluded libs for ROCm python package (#19586) The rocm lib version has changed in rocm 6.0 Using libs packaged in whl might cause errors. For example, `libamdhip64.so.6` packaged in whl will cause compute error when training gpt2 model. The root cause still in investigating. --- setup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.py b/setup.py index 03e1cb75ba581..9a5fc29dd5e02 100644 --- a/setup.py +++ b/setup.py @@ -205,18 +205,23 @@ def run(self): rocm_dependencies = [ "libamd_comgr.so.2", "libamdhip64.so.5", + "libamdhip64.so.6", "libdrm.so.2", "libdrm_amdgpu.so.1", "libelf.so.1", "libhipfft.so.0", "libhiprtc.so.5", + "libhiprtc.so.6", "libhsa-runtime64.so.1", "libMIOpen.so.1", "libnuma.so.1", "librccl.so.1", "librocblas.so.3", + "librocblas.so.4", "librocfft.so.0", + "libroctx64.so.4", "librocm_smi64.so.5", + "librocm_smi64.so.6", "libroctracer64.so.4", "libtinfo.so.6", "libmigraphx_c.so.3", From 6b73ab3e3e72a9f2008e8d0e221b0be77d2993b1 Mon Sep 17 00:00:00 2001 From: cao lei Date: Thu, 22 Feb 2024 10:19:08 -0800 Subject: [PATCH 039/237] Introduce reused_buffer_index_per_stream in allocation planner which will be reset after computing the reuse buffer for each stream (#19515) ### Description Introduce reused_buffer_index_per_stream in allocation planner which will be reset after computing the reuse buffer for each stream. So if a NodeArg is an input of several Ops across different streams and reuses other NodeArg, the reused NodeArg won't be involved when computing the second stream's reuse plan. ### Motivation and Context This is to fix https://github.com/microsoft/onnxruntime/issues/19480, which is a crash for the scenario mentioned above. --------- Co-authored-by: Lei Cao --- .../core/framework/allocation_planner.cc | 44 ++++++------ .../test/framework/allocation_planner_test.cc | 68 ++++++++++++++++++ .../multi_stream_models/issue_19480.onnx | Bin 0 -> 760 bytes 3 files changed, 91 insertions(+), 21 deletions(-) create mode 100644 onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index ea7a6432a7507..158ab8ed610f4 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -182,7 +182,6 @@ class PlannerImpl { // upstream_node_0 and upstream_node_1 are the immmediate upstream nodes of downstream_node // upstream_node_2 is the immediate nodes ahead of downstream_node in the same logic stream InlinedHashMap> dependence_graph_; - InlinedHashMap> value_consumer_map_; InlinedHashMap value_node_map_; // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: @@ -295,7 +294,7 @@ class PlannerImpl { } #endif - // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node. + // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node. 
bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input, bool* is_strided_tensor) { *is_strided_tensor = false; @@ -530,6 +529,7 @@ class PlannerImpl { // Initialize allocation plan: plan_.allocation_plan.resize(num_ml_values); + for (int i = 0; static_cast(i) < num_ml_values; i++) AllocPlan(i).reused_buffer = i; } bool HasExternalOutputs(const Node& node) const { @@ -1065,7 +1065,8 @@ class PlannerImpl { // build the consumer list for each value int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; - value_consumer_map_.reserve(num_ml_values); + InlinedHashMap> value_consumer_map; + value_consumer_map.reserve(num_ml_values); // iterate each stream from back, so the first element is the last consumer in single stream case for (auto& stream : stream_nodes_) { @@ -1078,10 +1079,10 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer - value_consumer_map_[origin].insert(node_index); + value_consumer_map[origin].insert(node_index); } } return Status::OK(); @@ -1138,8 +1139,8 @@ class PlannerImpl { std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl; allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; - value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); found_reusable = true; break; @@ -1168,8 +1169,8 @@ class PlannerImpl { allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; - value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); continue; } // if @@ -1187,11 +1188,11 @@ class PlannerImpl { OrtValueIndex input_arg_index{}; if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() && allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) { - if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { + if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = input_arg_index; - value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(input_arg_index); } } @@ -1266,7 +1267,7 @@ class PlannerImpl { } bool all_covered = true; - for (auto consumer : 
value_consumer_map_[output_idx_global]) { + for (auto consumer : value_consumer_map[output_idx_global]) { if (deps->find(consumer) == deps->end()) { all_covered = false; break; @@ -1277,9 +1278,9 @@ class PlannerImpl { allocation_plan[downstream_value].reused_buffer = output_idx_global; get_reused = true; // add new consumer for the value to be reused - value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]); - value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(), - value_consumer_map_[downstream_value].end()); + value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]); + value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(), + value_consumer_map[downstream_value].end()); node_iter = size_iter->second.erase(node_iter); if (size_iter->second.empty()) { local_iter->second.erase(size_iter); @@ -1342,8 +1343,9 @@ class PlannerImpl { ort_value_usecount.reserve(ort_value_info_.size()); #endif for (size_t i = 0; i < stream_nodes_.size(); ++i) { - // compute use count first + // compute use count first. TODO(leca): call ComputeReuseCount() only once is enough! ORT_RETURN_IF_ERROR(ComputeReuseCount()); + for (int j = 0; static_cast(j) < ort_value_info_.size(); j++) Buffer(j) = j; #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) if (i == 0) { for (auto ort_value_info : ort_value_info_) { @@ -1693,8 +1695,8 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer value_consumers[origin].push_back(node_index); } @@ -1889,7 +1891,7 @@ class PlannerImpl { // 2. the consumer is in the same stream(non-cpu device), but it consumes a CPU tensor from an non-shape op. // for example, a resize cuda kernel consumer a tensor from MemCpyToHost cuda kernel on the same stream. 
// in this case, the FIFO can't guarantee the cpu tensor is ready when resize kernel is launching - OrtDevice::DeviceType output_arg_device = plan_.allocation_plan[output_arg_idx].location.Type(); + OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type(); WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device); if ((node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) { if (node_to_notification.find(node_index) == node_to_notification.end()) { diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index d7b1de5c930c5..3e0d94e94e48c 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -1974,6 +1974,74 @@ TEST_F(PlannerTest, TestCpuIf) { ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep); } } + +// model looks like: +// |-----------> Gather +// |-----------> Gather +// |-----------> Gather +// |-----------> Gather +// Shape ----------------> Reshape --> Shape ------------------> Reshape +// ^ ^ +// InstanceNormalization ----| InstanceNormalization ------| +// +// Python script to create this model: +// def CreateModelFor19480(): +// #shape->reshape->shape->reshape, 4 gather +// graphNodes = [] +// graphNodes.append(h.make_node('Shape', inputs=['shape_input'], outputs=['9'])) +// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in0_input', 'scale0', 'B0'], outputs=['8'])) +// graphNodes.append(h.make_node('Reshape', inputs=['8', '9'], outputs=['Reshape15_output'])) +// graphNodes.append(h.make_node('Shape', inputs=['Reshape15_output'], outputs=['281'])) +// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in1_input', 'scale1', 'B1'], outputs=['293'])) +// graphNodes.append(h.make_node('Reshape', inputs=['293', '281'], outputs=['output0'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices1'], outputs=['output1'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices2'], outputs=['output2'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices3'], outputs=['output3'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices4'], outputs=['output4'])) +// g = h.make_graph(graphNodes, 'issue_19480', +// [h.make_tensor_value_info('shape_input', tp.FLOAT, ['batch', 128, None, None]), +// h.make_tensor_value_info('in0_input', tp.FLOAT, ['batch', 32, None]), +// h.make_tensor_value_info('scale0', tp.FLOAT, [32]), +// h.make_tensor_value_info('B0', tp.FLOAT, [32]), +// h.make_tensor_value_info('in1_input', tp.FLOAT, ['batch', 32, None]), +// h.make_tensor_value_info('scale1', tp.FLOAT, [32]), +// h.make_tensor_value_info('B1', tp.FLOAT, [32]), +// h.make_tensor_value_info('indices1', tp.INT32, []), +// h.make_tensor_value_info('indices2', tp.INT32, []), +// h.make_tensor_value_info('indices3', tp.INT32, []), +// h.make_tensor_value_info('indices4', tp.INT32, [])], +// [h.make_tensor_value_info('output0', tp.FLOAT, None), +// h.make_tensor_value_info('output1', tp.INT64, None), +// h.make_tensor_value_info('output2', tp.INT64, None), +// h.make_tensor_value_info('output3', tp.INT64, None), +// h.make_tensor_value_info('output4', tp.INT64, None)]) +// model = h.make_model(g, opset_imports=[h.make_operatorsetid("", 17)], producer_name='producer_name') +// onnx.save(model, 'issue_19480.onnx') +// 
+TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) { + SessionOptions sess_opt; + sess_opt.graph_optimization_level = TransformerLevel::Default; + + InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/issue_19480.onnx")); + auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()); + status = sess.Load(); + status = sess.Initialize(); + ASSERT_TRUE(status.IsOK()) << "No crash"; + const SequentialExecutionPlan* plan = sess.GetSessionState().GetExecutionPlan(); + ASSERT_EQ(plan->allocation_plan[14].alloc_kind, AllocKind::kReuse) << "The input of reshape and gather will reuse the output of shape"; + + int gather_count = 0; + for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) { + if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) { + const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex()); + if (node->OpType() == "Gather") + gather_count++; + else + FAIL() << "CPU stream should contain only gather ops"; + } + } + ASSERT_EQ(gather_count, 4) << "4 gather ops are all placed in CPU stream"; +} #endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx b/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx new file mode 100644 index 0000000000000000000000000000000000000000..dc7d39206dd49f4ef6daf65b7d58c5b456ecf331 GIT binary patch literal 760 zcmaixKTm@|7>Bw3f%9#m_0_6_F_p#1gab^#v5RqW(2b?J(o1>?g{IKO$uH`6@t{2^ z#m0q%=Y9D7j(aJ^(>&%f;j=_M&c!l&{_evy4DtnEiK$Fin*vE__dm*aU~nQ+XN$p9 zA11aA3GP#64zs z+VGAUzBYVq;6Ud2Mod}g2Tt_Ry!IQoq685v?9X@+FQ7}m2pC{Q_TCzB1Q$v>tF;at zE9X-02LY%OdZ2hTthTjJs;u4R{+GpCSxtg_7ivO}nrK8dbFt05KbWuCO#ReubJg)l W4Oj)N8n}nRI|Tj~OnP7p&wl_GGP8{U literal 0 HcmV?d00001 From 3bdb10d5ca4f258ec444863bcd5e839eeac5c238 Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Thu, 22 Feb 2024 10:56:25 -0800 Subject: [PATCH 040/237] Move import to when needed to avoid circular dependency error (#19579) ### Description Move import to when needed to avoid circular dependency error ### Motivation and Context Fixes dependency error described here: https://github.com/microsoft/DeepSpeed/issues/5140 --------- Co-authored-by: Thiago Crepaldi --- .../python/training/ortmodule/_graph_execution_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 779b6bfe50422..fda6e345da235 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -20,7 +20,6 @@ from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype -from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . 
import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils from ._fallback import ( @@ -143,6 +142,9 @@ def __init__( self._zero_stage3_param_map = {} if self._runtime_options.enable_zero_stage3_support: + # Move import to here to avoid circular dependency error + from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 # type: ignore[import] + # Cannot toggle feature enabling/disabling after the first time enabled. configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True) From fe82fccf1a4d7ea6c24c8448d7264df36605c370 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 23 Feb 2024 05:09:28 +0800 Subject: [PATCH 041/237] [js/webgpu] Fix Conv2DTransposeMatMul f16 compilation failure (#19596) This is used in sam-h-decoder-f16. ### Description ### Motivation and Context --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index b5b6a2a15cd8c..11c8778b72335 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; -import {biasSnippet, typeSnippet} from './activation_util'; +import {biasSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dTransposeCommonSnippet = - (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => { - const type = typeSnippet(innerElementSize, 'f32'); + (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string, + innerElementSize = 4): string => { const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: @@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet = let v1 = w[getIndexFromCoords4D(coord1, vec4(uniforms.w_shape))]; let v2 = w[getIndexFromCoords4D(coord2, vec4(uniforms.w_shape))]; let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))]; - return vec4(v0, v1, v2, v3); + return ${type}(v0, v1, v2, v3); `; default: throw new Error(`innerElementSize ${innerElementSize} is not supported.`); @@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo = const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); inputVariables.push(bias); declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${bias.type.value} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? 
'/ 4' : ''}]; }`; } @@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo = {name: 'pads', type: 'i32', length: pads.length} ]; appendActivationUniforms(attributes, uniforms); + const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1); + if (elemType !== 'f16' && elemType !== 'f32') { + throw new Error(`elemType ${elemType} is not supported.`); + } return ` ${utilFunctions('uniforms.result_strides')} ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false, undefined, sequentialAccessByThreads)}`; }; From 09622418c45b265977a8f1f17581e15719357423 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 22 Feb 2024 13:15:13 -0800 Subject: [PATCH 042/237] Add special handling if there is only 1 graph inside the cached QNN context binary (#19594) Add special handling if there is only 1 graph inside the cached QNN context binary. No need to make the EPContext node name match the QNN graph name. This is for better backward compatibility in case the QNN context model is generated before the PR for QNN context binary model support multi-partition. --- .../qnn/builder/onnx_ctx_model_helper.cc | 6 +- .../qnn/builder/onnx_ctx_model_helper.h | 3 +- .../qnn/builder/qnn_backend_manager.cc | 15 ++-- .../providers/qnn/qnn_execution_provider.cc | 3 +- .../test/providers/qnn/qnn_ep_context_test.cc | 83 ++++++++++++++++++- 5 files changed, 99 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index c2e71081b898e..2d8ec295d613b 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -151,12 +151,14 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - std::unordered_map>& qnn_models) { + std::unordered_map>& qnn_models, + const logging::Logger& logger) { Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models); // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model if (!status.IsOK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage()); + LOGS(logger, ERROR) << "Failed to load from EpContext model. " << status.ErrorMessage(); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContext model. 
", status.ErrorMessage()); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index b1360b4e576fa..7d56b45a1dbcd 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -56,7 +56,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - std::unordered_map>& qnn_models); + std::unordered_map>& qnn_models, + const logging::Logger& logger); Status CreateEPContextNodes(Model* model, unsigned char* buffer, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 5f0b87c7cb9d7..ca34a1efa6ca7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -573,11 +573,16 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t // More work to support multiple partition, how to map the graph name in compile to qnn graph name // Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile - for (uint32_t i = 0; i < graph_count; ++i) { - std::string graph_name(graphs_info[i].graphInfoV1.graphName); - auto qnn_model_pos = qnn_models.find(graph_name); - ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names."); - ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i])); + if (1 == graph_count) { + auto qnn_model_pose = qnn_models.begin(); + ORT_RETURN_IF_ERROR(qnn_model_pose->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[0])); + } else { + for (uint32_t i = 0; i < graph_count; ++i) { + std::string graph_name(graphs_info[i].graphInfoV1.graphName); + auto qnn_model_pos = qnn_models.find(graph_name); + ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names."); + ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i])); + } } qnn_sys_interface_.systemContextFree(sys_ctx_handle); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index f5a166d36b15a..9a6540a3efea5 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -670,7 +670,8 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, context_cache_path, qnn_backend_manager_.get(), - qnn_models)); + qnn_models, + logger)); for (auto fused_node_and_graph : fused_nodes_and_graphs) { const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index b1f3b52e77553..eaef6f6315157 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -463,7 +463,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) { InferenceSessionWrapper session_object{so, GetEnvironment()}; - std::string 
provider_type = kCpuExecutionProvider; ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast(qnn_ctx_model_data.size()))); // Verify the return status with code INVALID_GRAPH @@ -486,7 +485,6 @@ std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) { auto* graph_output = helper.MakeOutput(shape); Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); ep_context_node.AddAttribute("embed_mode", static_cast(0)); - // The .. in the path will cause INVALID_GRAPH ep_context_node.AddAttribute("ep_cache_context", external_bin_path); ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); ep_context_node.AddAttribute("source", "QNN"); @@ -651,6 +649,87 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } +// Context binary only contains a single QNN graph, generated context cache model (detached mode) only has 1 EPContext node +// Create another Onnx model which also reference to the bin file, +// but the node name is not same with the QNN graph name inside the bin file. +// This is to support backward compitable for the models generated before the PR that +// make context generation support multi-partition +TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphNameInCtx) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + std::remove(context_binary_file.c_str()); + std::remove(context_bin.string().c_str()); + + std::unordered_map session_option_pairs; + session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); + + const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); + const std::string op_type = "Atan"; + + // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
+ // 1st run will generate the Onnx skeleton file + Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); + + // Check the Onnx skeleton file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + // Check the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_bin)); + + const std::unordered_map domain_to_version = {{"", 11}, {kMSDomain, 1}}; + auto& logging_manager = DefaultLoggingManager(); + onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + std::vector shape = {1, 2, 3}; + NodeArg* graph_input = MakeTestInput(helper, TestInputDef(shape, false, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f})); + auto* graph_output = helper.MakeOutput(shape); + Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); + ep_context_node.AddAttribute("embed_mode", static_cast(0)); + ep_context_node.AddAttribute("ep_cache_context", context_bin.string()); + ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); + ep_context_node.AddAttribute("source", "QNNExecutionProvider"); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(graph.Resolve()); + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + // loads and run from Onnx skeleton file + Qnn context cache binary file + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::OK); + + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + ASSERT_EQ(std::remove(context_bin.string().c_str()), 0); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test From 76a2a487a12c7ec579f453a36932429164494ef6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:58:17 -0800 Subject: [PATCH 043/237] Bump ip from 1.1.8 to 1.1.9 in /js/react_native/e2e (#19583) Bumps [ip](https://github.com/indutny/node-ip) from 1.1.8 to 1.1.9.
[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ip&package-manager=npm_and_yarn&previous-version=1.1.8&new-version=1.1.9)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) Dependabot will merge this PR once CI passes on it, as requested by @fs-eire. [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/react_native/e2e/yarn.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/react_native/e2e/yarn.lock b/js/react_native/e2e/yarn.lock index 9e20a286c4e27..6f05faf046098 100644 --- a/js/react_native/e2e/yarn.lock +++ b/js/react_native/e2e/yarn.lock @@ -3351,9 +3351,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-accessor-descriptor@^0.1.6: version "0.1.6"
From 5e5c36f6df95dfbb25787ea385f733f8c9ef691e Mon Sep 17 00:00:00 2001 From: AtomicVar Date: Fri, 23 Feb 2024 09:03:56 +0800 Subject: [PATCH 044/237] Fix citation author name issue (#19597) Use `name` rather than `given-names` to set the author name. ### Motivation and Context The old CITATION.cff uses `given-names` to set author names, which is not rendered properly by some BibTeX styles in LaTeX (screenshot omitted). The problem is that **`"ONNX Runtime developers"` is regarded as a human name**. The fix: using `name` to set the author name makes the generated BibTeX entry enclose `"ONNX Runtime developers"` in `{}`, so it is displayed literally (screenshot omitted). --- CITATION.cff | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index 82bcac5a7b750..10b7290022aef 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -3,8 +3,7 @@ title: ONNX Runtime message: "Please use this information to cite ONNX Runtime in research or other publications." authors: - - affiliation: Microsoft Corporation - given-names: ONNX Runtime developers + - name: ONNX Runtime developers date-released: 2018-11-29 url: "https://onnxruntime.ai" repository-code: "https://github.com/microsoft/onnxruntime"
From 4ab497603e915ca992b96ef1ec25bfcf8b9a2ad5 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 22 Feb 2024 17:04:59 -0800 Subject: [PATCH 045/237] Enable user to set QNN HTP performance mode for every session run (#19521) ### Description Currently, the QNN HTP performance mode is set during session creation and cannot be changed afterwards. However, there is a requirement to switch to a high performance mode for high-priority requests and to switch back to a low performance mode to save power once incoming requests go idle. The performance mode is still kept at the session level in the QNN EP options and serves as the default; the QNN EP sets it once if the user specifies it. In addition, run options (qnn.htp_perf_mode and qnn.htp_perf_mode_post_run) can now change the performance mode before and after a session run. The recommended usage is to set a high performance mode before the inference run so the result comes back as soon as possible, and to set a low performance mode after the inference to save power.
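As a usage illustration (not part of this change), below is a minimal C++ sketch of how a caller could drive these run options through the ONNX Runtime C++ API's Ort::RunOptions::AddConfigEntry. Only the "qnn.*" config keys and the mode strings come from this patch; the helper name, input/output arguments, and the latency value are placeholders.

```
// Sketch only: per-run QNN HTP performance control through run-option config entries.
// The "qnn.*" keys and mode strings ("burst", "low_power_saver", ...) are defined in
// onnxruntime_run_options_config_keys.h; the function name, I/O arguments, and the
// latency value below are illustrative placeholders.
#include <vector>
#include "onnxruntime_cxx_api.h"

std::vector<Ort::Value> RunHighPriorityRequest(Ort::Session& session,
                                               const char* const* input_names,
                                               const Ort::Value* inputs, size_t input_count,
                                               const char* const* output_names, size_t output_count) {
  Ort::RunOptions run_options;
  // Raise the HTP performance mode for this request so the result comes back as fast as possible.
  run_options.AddConfigEntry("qnn.htp_perf_mode", "burst");
  // Ask the EP to drop to a power-saving mode in OnRunEnd(), after this run finishes.
  run_options.AddConfigEntry("qnn.htp_perf_mode_post_run", "low_power_saver");
  // Optionally tighten the RPC control latency while the request is in flight.
  run_options.AddConfigEntry("qnn.rpc_control_latency", "100");

  return session.Run(run_options, input_names, inputs, input_count, output_names, output_count);
}
```

Without these per-run entries, a run keeps the thread's default configuration, which comes from the htp_performance_mode and rpc_control_latency QNN EP provider options.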
--- .../core/framework/execution_provider.h | 10 +- .../onnxruntime_run_options_config_keys.h | 12 + .../framework/stream_execution_context.cc | 4 +- .../providers/cann/cann_execution_provider.cc | 2 +- .../providers/cann/cann_execution_provider.h | 2 +- .../providers/cuda/cuda_execution_provider.cc | 4 +- .../providers/cuda/cuda_execution_provider.h | 5 +- .../src/ExecutionProvider.h | 4 +- .../providers/js/js_execution_provider.cc | 4 +- .../core/providers/js/js_execution_provider.h | 4 +- .../migraphx/migraphx_execution_provider.cc | 4 +- .../migraphx/migraphx_execution_provider.h | 4 +- .../qnn/builder/qnn_backend_manager.cc | 75 +++--- .../qnn/builder/qnn_backend_manager.h | 19 +- .../providers/qnn/qnn_execution_provider.cc | 198 +++++++++++++++- .../providers/qnn/qnn_execution_provider.h | 73 +++++- .../providers/rocm/rocm_execution_provider.cc | 4 +- .../providers/rocm/rocm_execution_provider.h | 4 +- .../tensorrt/tensorrt_execution_provider.cc | 4 +- .../tensorrt/tensorrt_execution_provider.h | 4 +- onnxruntime/core/session/inference_session.cc | 12 +- .../cuda_execution_provider_test.cc | 13 +- .../test/providers/qnn/qnn_basic_test.cc | 217 ++++++++++++++++-- 23 files changed, 577 insertions(+), 105 deletions(-) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 31c988f500779..c1cc69edc17d8 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -33,6 +33,8 @@ class Node; #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" +struct OrtRunOptions; + namespace onnxruntime { /** @@ -51,6 +53,8 @@ struct NodeComputeInfo { DestroyFunctionStateFunc release_state_func; }; +using RunOptions = OrtRunOptions; + enum class DataLayout { NCHW, NHWC, @@ -184,7 +188,7 @@ class IExecutionProvider { Run may not be finished on device This function should be regarded as the point after which a new Run would start to submit commands from CPU */ - virtual common::Status OnRunStart() { return Status::OK(); } + virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } /** Called when InferenceSession::Run ended @@ -192,7 +196,9 @@ class IExecutionProvider { may not be finished on device This function should be regarded as the point that all commands of current Run has been submmited by CPU */ - virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); } + virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) { + return Status::OK(); + } /** Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h index 1f5fcd50e185c..b0a17e175fef3 100644 --- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h @@ -30,3 +30,15 @@ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memor // Per default it will be set to '0' // Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream. static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; + +// Set HTP performance mode for QNN HTP backend before session run. 
+// options for HTP performance mode: "burst", "balanced", "default", "high_performance", +// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", +// "sustained_high_performance". Default to "default". +static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode"; + +// Set HTP performance mode for QNN HTP backend post session run. +static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run"; + +// Set RPC control latency for QNN HTP backend +static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency"; diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc index 875e7f395bfa8..dd7f4d35b34bd 100644 --- a/onnxruntime/core/framework/stream_execution_context.cc +++ b/onnxruntime/core/framework/stream_execution_context.cc @@ -181,11 +181,13 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess } #ifdef USE_CANN + // Leave it to CANN EP to fill the gap if they want to use run_options + static onnxruntime::RunOptions run_options; // For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool, // which is different from CUDA Runtime API, but similar to CUDA Driver API. auto& execution_providers = ctx.GetSessionState().GetExecutionProviders(); for (auto& xp : execution_providers) { - auto status = xp->OnRunStart(); + auto status = xp->OnRunStart(run_options); if (!status.IsOK()) { ctx.SetStatus(status); return; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 752b742805a7c..9a242919665bb 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1045,7 +1045,7 @@ CANNExecutionProvider::~CANNExecutionProvider() { } // All threads share the same context and stream -Status CANNExecutionProvider::OnRunStart() { +Status CANNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { CANN_RETURN_IF_ERROR(aclrtSetDevice(info_.device_id)); return Status::OK(); diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index 63ae980869c65..d83bd88d6958f 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -33,7 +33,7 @@ class CANNExecutionProvider : public IExecutionProvider { explicit CANNExecutionProvider(const CANNExecutionProviderInfo& info); virtual ~CANNExecutionProvider(); - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; template Status Fill(Tensor* y, void* addr, aclrtStream stream) const { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 48a952e6dd98f..0dd568c5ecc05 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -386,7 +386,7 @@ Status CUDAExecutionProvider::Sync() const { return Status::OK(); } -Status CUDAExecutionProvider::OnRunStart() { +Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { // always set CUDA device when session::Run() in case it runs in a worker thread 
CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId())); if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { @@ -396,7 +396,7 @@ Status CUDAExecutionProvider::OnRunStart() { return Status::OK(); } -Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) { +Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { if (GetPerThreadContext().IsGraphCaptureAllowed()) { GetPerThreadContext().CaptureEnd(); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 55f0b5570e0ee..5f62f313b86a2 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -29,9 +29,9 @@ class CUDAExecutionProvider : public IExecutionProvider { Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; DataLayout GetPreferredLayout() const override; @@ -115,6 +115,7 @@ class CUDAExecutionProvider : public IExecutionProvider { PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy, CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg); ~PerThreadContext(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); cublasHandle_t CublasHandle() const { return cublas_handle_; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 5617bc7bdcac6..841d6244a983e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -270,7 +270,7 @@ namespace Dml return m_impl->OnSessionInitializationEnd(); } - virtual onnxruntime::Status Sync() const final override + onnxruntime::Status Sync() const final override { // Completely wait until the device has completed all preceding tasks. // The application could have called SynchronizeBoundOutputs(). @@ -278,7 +278,7 @@ namespace Dml return Status::OK(); } - virtual onnxruntime::Status OnRunEnd(bool /*sync_stream*/) final override + onnxruntime::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) final override { // Flush any pending work to the GPU, but don't block for completion, permitting it // to overlap other work. 
diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 799d4172f2b64..62c3981682cfc 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -756,7 +756,7 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer JsExecutionProvider::~JsExecutionProvider() { } -Status JsExecutionProvider::OnRunStart() { +Status JsExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; EM_ASM({ Module.jsepCaptureBegin(); }); @@ -764,7 +764,7 @@ Status JsExecutionProvider::OnRunStart() { return Status::OK(); } -Status JsExecutionProvider::OnRunEnd(bool sync_stream) { +Status JsExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { if (IsGraphCaptureAllowed()) { EM_ASM({ Module.jsepCaptureEnd(); }); diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 91a3256ec2bd5..b4518c67d1e60 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -59,8 +59,8 @@ class JsExecutionProvider : public IExecutionProvider { std::vector CreatePreferredAllocators() override; - Status OnRunStart() override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured() const override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 40e76a0a67782..50782569ee80a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1383,11 +1383,11 @@ Status MIGraphXExecutionProvider::Sync() const { return Status::OK(); } -Status MIGraphXExecutionProvider::OnRunStart() { +Status MIGraphXExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } -Status MIGraphXExecutionProvider::OnRunEnd(bool) { +Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) { auto status = hipStreamQuery(stream_); if (status != hipSuccess) { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index d582338c7e067..c3617f409e72c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -56,9 +56,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider { #ifdef MIGRAPHX_STREAM_SYNC Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; #endif std::vector> diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc 
index ca34a1efa6ca7..e354bf6562722 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -634,11 +634,6 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ LOGS(logger, VERBOSE) << "CreateContext succeed."; } - if (htp_performance_mode_ != HtpPerformanceMode::kHtpDefault) { - ORT_RETURN_IF_ERROR(SetHtpPowerConfig()); - LOGS(logger, VERBOSE) << "SetHtpPowerConfig succeed."; - } - LOGS(logger, VERBOSE) << "QNN SetupBackend succeed"; backend_setup_completed_ = true; @@ -646,7 +641,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ return Status::OK(); } -Status QnnBackendManager::SetHtpPowerConfig() { +Status QnnBackendManager::CreateHtpPowerCfgId(uint32_t device_id, uint32_t core_id, uint32_t& htp_power_config_id) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -656,23 +651,37 @@ Status QnnBackendManager::SetHtpPowerConfig() { "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; // Get power client id - status = htp_perf_infra.createPowerConfigId(/*device_id=*/0, /*core_id=*/0, &htp_power_config_client_id_); + status = htp_perf_infra.createPowerConfigId(device_id, core_id, &htp_power_config_id); ORT_RETURN_IF(QNN_SUCCESS != status, "createPowerConfigId failed."); + return Status::OK(); +} + +Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, + HtpPerformanceMode htp_performance_mode) { + QnnDevice_Infrastructure_t qnn_device_infra = nullptr; + auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); + ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); + + auto* htp_infra = static_cast(qnn_device_infra); + ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType, + "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); + QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; + constexpr const int kNumConfigs = 1; std::vector power_configs( kNumConfigs); QnnHtpPerfInfrastructure_PowerConfig_t& dcvs_config = power_configs[0]; dcvs_config.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_DCVS_V3; QnnHtpPerfInfrastructure_DcvsV3_t& dcvs_v3 = dcvs_config.dcvsV3Config; - dcvs_v3.contextId = htp_power_config_client_id_; + dcvs_v3.contextId = htp_power_config_client_id; dcvs_v3.setSleepDisable = 0; dcvs_v3.sleepDisable = 0; dcvs_v3.setDcvsEnable = 1; dcvs_v3.dcvsEnable = kDcvsDisable; dcvs_v3.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_PERFORMANCE_MODE; // choose performance mode - switch (htp_performance_mode_) { + switch (htp_performance_mode) { case HtpPerformanceMode::kHtpBurst: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMinLatency; @@ -771,25 +780,40 @@ Status QnnBackendManager::SetHtpPowerConfig() { dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS; break; default: - ORT_THROW("Invalid performance profile %d", static_cast(htp_performance_mode_)); + ORT_THROW("Invalid performance profile %d", static_cast(htp_performance_mode)); break; } std::vector perf_power_configs_ptr = ObtainNullTermPtrVector(power_configs); - status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, 
perf_power_configs_ptr.data()); + status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for HTP performance mode."); - // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. - if (rpc_control_latency_ != 0) { + return Status::OK(); +} + +Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency) { + if (rpc_control_latency != 0) { + QnnDevice_Infrastructure_t qnn_device_infra = nullptr; + auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); + ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); + + auto* htp_infra = static_cast(qnn_device_infra); + ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType, + "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); + QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; + + // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. constexpr int kNumRpcPollingPowerConfigs = 2; std::vector rpc_power_configs(kNumRpcPollingPowerConfigs); - QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency = rpc_power_configs[0]; + QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency_cfg = rpc_power_configs[0]; // v68 doesn't support this. QnnHtpPerfInfrastructure_PowerConfig_t& rpc_polling_time = rpc_power_configs[1]; - rpc_control_latency.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; + rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; rpc_polling_time.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; - rpc_control_latency.rpcControlLatencyConfig = rpc_control_latency_; - perf_power_configs_ptr = ObtainNullTermPtrVector(rpc_power_configs); - status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, perf_power_configs_ptr.data()); + rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; + std::vector perf_power_configs_ptr = + ObtainNullTermPtrVector(rpc_power_configs); + status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for RPC control latency."); } @@ -810,11 +834,7 @@ void QnnBackendManager::Split(std::vector& split_string, } } -Status QnnBackendManager::DestroyHTPPowerConfigID() { - if (htp_performance_mode_ == HtpPerformanceMode::kHtpDefault) { - return Status::OK(); - } - +Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -824,7 +844,7 @@ Status QnnBackendManager::DestroyHTPPowerConfigID() { "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; - Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_client_id_); + Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_id); ORT_RETURN_IF(QNN_SUCCESS != destroy_ret, "destroyPowerConfigId failed."); return Status::OK(); } @@ -834,12 +854,7 @@ void QnnBackendManager::ReleaseResources() { return; } - auto result = DestroyHTPPowerConfigID(); - if 
(Status::OK() != result) { - ORT_THROW("Failed to DestroyHTPPowerConfigID."); - } - - result = ReleaseContext(); + auto result = ReleaseContext(); if (Status::OK() != result) { ORT_THROW("Failed to ReleaseContext."); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 36375522b5a0a..ff97c4c3a991c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -33,8 +33,6 @@ class QnnBackendManager { public: QnnBackendManager(std::string&& backend_path, ProfilingLevel profiling_level, - uint32_t rpc_control_latency, - HtpPerformanceMode htp_performance_mode, ContextPriority context_priority, std::string&& qnn_saver_path, uint32_t device_id, @@ -42,8 +40,6 @@ class QnnBackendManager { uint32_t soc_model) : backend_path_(backend_path), profiling_level_(profiling_level), - rpc_control_latency_(rpc_control_latency), - htp_performance_mode_(htp_performance_mode), context_priority_(context_priority), qnn_saver_path_(qnn_saver_path), device_id_(device_id), @@ -92,7 +88,13 @@ class QnnBackendManager { Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context); - Status SetHtpPowerConfig(); + Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id); + + Status SetHtpPowerConfig(uint32_t htp_power_config_client_id, + HtpPerformanceMode htp_performance_mode); + + Status SetRpcControlLatency(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency); const QNN_INTERFACE_VER_TYPE& GetQnnInterface() { return qnn_interface_; } @@ -141,6 +143,8 @@ class QnnBackendManager { const std::string& GetSdkVersion() { return sdk_build_version_; } + Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id); + private: void* LoadLib(const char* file_name, int flags, std::string& error_msg); @@ -150,8 +154,6 @@ class QnnBackendManager { Status UnloadLib(void* handle); - Status DestroyHTPPowerConfigID(); - void* LibFunction(void* handle, const char* symbol, std::string& error_msg); template @@ -232,15 +234,12 @@ class QnnBackendManager { QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; Qnn_ProfileHandle_t profile_backend_handle_ = nullptr; std::vector op_package_paths_; - uint32_t rpc_control_latency_ = 0; - HtpPerformanceMode htp_performance_mode_; ContextPriority context_priority_; std::string sdk_build_version_ = ""; #ifdef _WIN32 std::set mod_handles_; #endif const std::string qnn_saver_path_; - uint32_t htp_power_config_client_id_ = 0; uint32_t device_id_ = 0; QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE; uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 9a6540a3efea5..3d9cfd92b7922 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -7,6 +7,7 @@ #include "core/framework/compute_capability.h" #include "core/graph/graph_viewer.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/kernel_registry.h" #include "core/platform/env.h" @@ -18,11 +19,36 @@ #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_def.h" #include 
"core/providers/qnn/builder/onnx_ctx_model_helper.h" +#include "core/framework/run_options.h" namespace onnxruntime { constexpr const char* QNN = "QNN"; +static std::unique_ptr>> s_run_on_unload_; + +void RunOnUnload(std::function function) { + OrtMutex mutex; + std::lock_guard guard(mutex); + if (!s_run_on_unload_) { + s_run_on_unload_ = std::make_unique>>(); + } + s_run_on_unload_->push_back(std::move(function)); +} + +struct OnUnload { + ~OnUnload() { + if (!s_run_on_unload_) + return; + + for (auto& function : *s_run_on_unload_) + function(); + + s_run_on_unload_.reset(); + } + +} g_on_unload; + static void ParseProfilingLevel(std::string profiling_level_string, qnn::ProfilingLevel& profiling_level) { std::transform(profiling_level_string.begin(), @@ -193,18 +219,18 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } static const std::string RPC_CONTROL_LANTENCY = "rpc_control_latency"; - uint32_t rpc_control_latency = 0; auto latency_pos = provider_options_map.find(RPC_CONTROL_LANTENCY); if (latency_pos != provider_options_map.end()) { - rpc_control_latency = static_cast(std::stoul(latency_pos->second)); - LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; + default_rpc_control_latency_ = static_cast(std::stoul(latency_pos->second)); + LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << default_rpc_control_latency_; } - qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + // default_htp_performance_mode from QNN EP option. + // set it once only for each thread as default so user don't need to set it for every session run static const std::string HTP_PERFORMANCE_MODE = "htp_performance_mode"; auto htp_performance_mode_pos = provider_options_map.find(HTP_PERFORMANCE_MODE); if (htp_performance_mode_pos != provider_options_map.end()) { - ParseHtpPerformanceMode(htp_performance_mode_pos->second, htp_performance_mode); + ParseHtpPerformanceMode(htp_performance_mode_pos->second, default_htp_performance_mode_); } htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; @@ -241,15 +267,14 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } static const std::string QNN_DEVICE_ID = "device_id"; - uint32_t device_id = 0; auto dev_id_pos = provider_options_map.find(QNN_DEVICE_ID); if (dev_id_pos != provider_options_map.end()) { int value = std::stoi(dev_id_pos->second); if (value < 0) { LOGS_DEFAULT(WARNING) << "Invalid device ID '" << value - << "', only >= 0 allowed. Set to " << device_id << "."; + << "', only >= 0 allowed. 
Set to " << device_id_ << "."; } else { - device_id = static_cast(value); + device_id_ = static_cast(value); } } @@ -276,15 +301,23 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio qnn_backend_manager_ = std::make_unique( std::move(backend_path), profiling_level, - rpc_control_latency, - htp_performance_mode, context_priority, std::move(qnn_saver_path), - device_id, + device_id_, htp_arch, soc_model); } +QNNExecutionProvider::~QNNExecutionProvider() { + // clean up thread local context caches + std::lock_guard lock(context_state_.mutex); + for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) { + const auto cache = cache_weak.lock(); + if (!cache) continue; + ORT_IGNORE_RETURN_VALUE(cache->erase(this)); + } +} + bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const { const std::string& op_type = node_unit.OpType(); @@ -725,4 +758,147 @@ const InlinedVector QNNExecutionProvider::GetEpContextNodes() const return ep_context_nodes; } + +QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, + uint32_t device_id, + uint32_t core_id, + qnn::HtpPerformanceMode default_htp_performance_mode, + uint32_t default_rpc_control_latency) + : qnn_backend_manager_(qnn_backend_manager) { + Status rt = qnn_backend_manager_->CreateHtpPowerCfgId(device_id, core_id, htp_power_config_id_); + is_htp_power_config_id_valid_ = rt.IsOK(); + // default_htp_performance_mode and default_rpc_control_latency are from QNN EP option. + // set it once only for each thread as default so user don't need to set it for every session run + if (is_htp_power_config_id_valid_) { + if (qnn::HtpPerformanceMode::kHtpDefault != default_htp_performance_mode) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetHtpPowerConfig(htp_power_config_id_, + default_htp_performance_mode)); + } + if (default_rpc_control_latency > 0) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcControlLatency(htp_power_config_id_, + default_rpc_control_latency)); + } + } +} + +QNNExecutionProvider::PerThreadContext::~PerThreadContext() { + if (is_htp_power_config_id_valid_) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->DestroyHTPPowerConfigID(htp_power_config_id_)); + } +} + +QNNExecutionProvider::PerThreadContext& QNNExecutionProvider::GetPerThreadContext() const { + const auto& per_thread_context_cache = PerThreadContextCache(); + + // try to use cached context + auto cached_context_it = per_thread_context_cache->find(this); + if (cached_context_it != per_thread_context_cache->end()) { + auto cached_context = cached_context_it->second.lock(); + ORT_ENFORCE(cached_context); + return *cached_context; + } + + // get context and update cache + std::shared_ptr context; + { + std::lock_guard lock(context_state_.mutex); + + // get or create a context + if (context_state_.retired_context_pool.empty()) { + uint32_t core_id = 0; + context = std::make_shared(qnn_backend_manager_.get(), device_id_, core_id, + default_htp_performance_mode_, default_rpc_control_latency_); + } else { + context = context_state_.retired_context_pool.back(); + context_state_.retired_context_pool.pop_back(); + } + + // insert into active_contexts, should not already be present + const auto active_contexts_insert_result = context_state_.active_contexts.insert(context); + ORT_ENFORCE(active_contexts_insert_result.second); + + // insert into caches_to_update_on_destruction, may already be 
present + ORT_IGNORE_RETURN_VALUE(context_state_.caches_to_update_on_destruction.insert(per_thread_context_cache)); + } + + per_thread_context_cache->insert(std::make_pair(this, context)); + + return *context; +} + +void QNNExecutionProvider::ReleasePerThreadContext() const { + const auto& per_thread_context_cache = PerThreadContextCache(); + + auto cached_context_it = per_thread_context_cache->find(this); + ORT_ENFORCE(cached_context_it != per_thread_context_cache->end()); + auto cached_context = cached_context_it->second.lock(); + ORT_ENFORCE(cached_context); + + { + std::lock_guard lock(context_state_.mutex); + context_state_.active_contexts.erase(cached_context); + context_state_.retired_context_pool.push_back(cached_context); + } + + per_thread_context_cache->erase(cached_context_it); +} + +Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { + auto backend_type = qnn_backend_manager_->GetQnnBackendType(); + if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) { + return Status::OK(); + } + + std::string htp_perf_mode = ""; + qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnPerfMode, htp_perf_mode)) { + // set power mode + ParseHtpPerformanceMode(htp_perf_mode, htp_performance_mode); + } + + std::string rpc_latency = ""; + uint32_t rpc_control_latency = 0; + if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnRpcControlLatency, rpc_latency)) { + rpc_control_latency = static_cast(std::stoul(rpc_latency)); + LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; + } + + if (GetPerThreadContext().IsHtpPowerConfigIdValid()) { + if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), + htp_performance_mode)); + } + + if (rpc_control_latency > 0) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcControlLatency(GetPerThreadContext().GetHtpPowerConfigId(), + rpc_control_latency)); + } + } + + return Status::OK(); +} + +Status QNNExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& run_options) { + auto backend_type = qnn_backend_manager_->GetQnnBackendType(); + if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) { + return Status::OK(); + } + + std::string htp_perf_mode = ""; + qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, htp_perf_mode)) { + // set power mode + ParseHtpPerformanceMode(htp_perf_mode, htp_performance_mode); + } + + if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { + if (!GetPerThreadContext().IsHtpPowerConfigIdValid()) { + return Status::OK(); + } + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), + htp_performance_mode)); + } + + return Status::OK(); +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 0bcaa39b22f6d..43b5e7bff827e 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -12,14 +12,19 @@ #include "core/providers/qnn/builder/qnn_model.h" #include 
"core/providers/qnn/builder/qnn_configs_helper.h" #include "HTP/QnnHtpGraph.h" +#include +#include +#include namespace onnxruntime { +void RunOnUnload(std::function function); + // Logical device representation. class QNNExecutionProvider : public IExecutionProvider { public: explicit QNNExecutionProvider(const ProviderOptions& provider_options_map, const SessionOptions* session_options); - virtual ~QNNExecutionProvider() = default; + virtual ~QNNExecutionProvider(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QNNExecutionProvider); // we implement the Compile that takes FusedNodeAndGraph instances @@ -40,6 +45,10 @@ class QNNExecutionProvider : public IExecutionProvider { const InlinedVector GetEpContextNodes() const override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + private: bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const; @@ -72,6 +81,68 @@ class QNNExecutionProvider : public IExecutionProvider { int32_t vtcm_size_in_mb_ = 0; std::unique_ptr qnn_ep_context_model_; ModelMetadefIdGenerator metadef_id_generator_; + uint32_t device_id_ = 0; + qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; + uint32_t default_rpc_control_latency_ = 0; + + class PerThreadContext final { + public: + PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, + uint32_t device_id, uint32_t core_id, + qnn::HtpPerformanceMode default_htp_performance_mode, + uint32_t default_rpc_control_latency); + ~PerThreadContext(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); + + bool IsHtpPowerConfigIdValid() { return is_htp_power_config_id_valid_; } + + uint32_t GetHtpPowerConfigId() { return htp_power_config_id_; } + + private: + bool is_htp_power_config_id_valid_ = false; + uint32_t htp_power_config_id_ = 0; + qnn::QnnBackendManager* qnn_backend_manager_; + }; + + using PerThreadContextMap = std::unordered_map>; + + struct ContextCacheHolder { + ContextCacheHolder() { + RunOnUnload([&, weak_p_ = std::weak_ptr(p)] { + if (auto lock = weak_p_.lock()) + p.reset(); + }); + } + + std::shared_ptr p = std::make_shared(); + }; + + static const std::shared_ptr& PerThreadContextCache() { + thread_local const ContextCacheHolder per_thread_context_cache; + return per_thread_context_cache.p; + } + + struct PerThreadContextState { + // contexts that are currently active + std::set, std::owner_less>> active_contexts; + // contexts available for reuse + std::vector> retired_context_pool; + // weak references to thread local caches from which this QNNExecutionProvider instance's entry should be removed + // upon destruction + std::set, std::owner_less>> + caches_to_update_on_destruction; + // synchronizes access to PerThreadContextState members + OrtMutex mutex; + }; + + // The execution provider maintains the PerThreadContexts in this structure. + // Synchronization is required to update the contained structures. + // On the other hand, access to an individual PerThreadContext is assumed to be from a single thread at a time, + // so synchronization is not required for that. 
+ mutable PerThreadContextState context_state_; + + PerThreadContext& GetPerThreadContext() const; + void ReleasePerThreadContext() const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index ee3578326ac6d..3fd5423681b81 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -353,7 +353,7 @@ Status ROCMExecutionProvider::Sync() const { return Status::OK(); } -Status ROCMExecutionProvider::OnRunStart() { +Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { @@ -363,7 +363,7 @@ Status ROCMExecutionProvider::OnRunStart() { return Status::OK(); } -Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) { +Status ROCMExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { if (GetPerThreadContext().IsGraphCaptureAllowed()) { GetPerThreadContext().CaptureEnd(); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 37d5f7b42210f..da671d9e863bb 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -28,9 +28,9 @@ class ROCMExecutionProvider : public IExecutionProvider { Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; const void* GetExecutionHandle() const noexcept override { // The ROCM interface does not return anything interesting. 
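The hook signature change above is what lets per-`Run()` settings (such as the QNN HTP performance mode handled earlier in this patch) reach an execution provider. As a minimal caller-side sketch of how that is used — the literal key strings and the helper name are illustrative assumptions; the canonical constants live in `onnxruntime_run_options_config_keys.h`:

```
// Sketch only: pass per-run config entries that the EPs read in OnRunStart/OnRunEnd.
#include <onnxruntime_cxx_api.h>
#include <vector>

std::vector<Ort::Value> RunWithPerfMode(Ort::Session& session,
                                        const char* const* input_names, const Ort::Value* inputs, size_t input_count,
                                        const char* const* output_names, size_t output_count) {
  Ort::RunOptions run_options;
  // Assumed key strings; see kOrtRunOptionsConfigQnnPerfMode / ...PerfModePostRun for the real constants.
  run_options.AddConfigEntry("qnn.htp_perf_mode", "burst");             // consumed in OnRunStart
  run_options.AddConfigEntry("qnn.htp_perf_mode_post_run", "default");  // consumed in OnRunEnd
  return session.Run(run_options, input_names, inputs, input_count, output_names, output_count);
}
```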
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index c0bf29e486c88..81346671f2aad 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1818,11 +1818,11 @@ std::unique_ptr TensorrtExecutionProvider::GetDataTransfer() cons return onnxruntime::CreateGPUDataTransfer(); } -Status TensorrtExecutionProvider::OnRunStart() { +Status TensorrtExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } -Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { +Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (sync_stream && external_stream_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index e86f997b6597a..26f6b2dcc3020 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -233,8 +233,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; - Status OnRunStart() override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; ProviderOptions GetProviderOptions() const override { return TensorrtExecutionProviderInfo::ToProviderOptions(info_); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b045f30a59797..efd7db4ea7629 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2289,8 +2289,8 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options, // TODO: only call OnRunStart for all providers in-use for (auto& xp : execution_providers_) { // call OnRunStart and add to exec_providers_to_stop if successful - auto start_func = [&xp, &exec_providers_to_stop]() { - auto status = xp->OnRunStart(); + auto start_func = [&xp, &exec_providers_to_stop, run_options]() { + auto status = xp->OnRunStart(run_options); if (status.IsOK()) exec_providers_to_stop.push_back(xp.get()); @@ -2326,7 +2326,7 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options, // info all execution providers InferenceSession:Run ended for (auto* xp : exec_providers_to_stop) { - auto status = xp->OnRunEnd(/*sync_stream*/ false); + auto status = xp->OnRunEnd(/*sync_stream*/ false, run_options); ORT_CHECK_AND_SET_RETVAL(status); } @@ -2448,8 +2448,8 @@ Status InferenceSession::Run(const RunOptions& run_options, // TODO: only call OnRunStart for all providers in-use for (auto& xp : execution_providers_) { // call OnRunStart and add to exec_providers_to_stop if successful - auto start_func = [&xp, &exec_providers_to_stop]() { - auto status = xp->OnRunStart(); + auto start_func = [&xp, &exec_providers_to_stop, &run_options]() { + auto status = xp->OnRunStart(run_options); if (status.IsOK()) exec_providers_to_stop.push_back(xp.get()); @@ -2490,7 +2490,7 @@ Status InferenceSession::Run(const RunOptions& run_options, // info all execution providers 
InferenceSession:Run ended for (auto* xp : exec_providers_to_stop) { bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0"; - auto status = xp->OnRunEnd(synchronize_execution_providers); + auto status = xp->OnRunEnd(synchronize_execution_providers, run_options); ORT_CHECK_AND_SET_RETVAL(status); } diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index a70e439cdf755..5505d689381c9 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -22,6 +22,8 @@ TEST(TestDeferredRelease, WithArena) { CUDAExecutionProvider ep(info); AllocatorPtr gpu_alloctor = ep.CreatePreferredAllocators()[0]; + RunOptions run_opts; + run_opts.run_tag = "log1"; // Allocator for call cudaMallocHost and cudaFreeHost // For details, see CUDAPinnedAllocator in cuda_allocator.cc. AllocatorPtr cpu_pinned_alloc = ep.CreatePreferredAllocators()[1]; @@ -31,7 +33,7 @@ TEST(TestDeferredRelease, WithArena) { // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; - ORT_THROW_IF_ERROR(ep.OnRunStart()); + ORT_THROW_IF_ERROR(ep.OnRunStart(run_opts)); for (size_t i = 0; i < n_allocs; ++i) { // Allocate 10MB CUDA pinned memory. auto pinned_buffer = IAllocator::MakeUniquePtr(cpu_pinned_alloc, n_bytes); @@ -44,7 +46,7 @@ TEST(TestDeferredRelease, WithArena) { cpu_pinned_alloc->GetStats(&stats); ASSERT_EQ(stats.num_allocs, n_allocs); ORT_THROW_IF_ERROR(stream.CleanUpOnRunEnd()); - ORT_THROW_IF_ERROR(ep.OnRunEnd(true)); + ORT_THROW_IF_ERROR(ep.OnRunEnd(true, run_opts)); } TEST(TestDeferredRelease, WithoutArena) { @@ -52,6 +54,9 @@ TEST(TestDeferredRelease, WithoutArena) { CUDAExecutionProviderInfo info; CUDAExecutionProvider ep(info); + RunOptions run_opts; + run_opts.run_tag = "log1"; + OrtDevice pinned_device{OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, DEFAULT_CPU_ALLOCATOR_DEVICE_ID}; // Create allocator without BFCArena AllocatorCreationInfo pinned_memory_info( @@ -70,7 +75,7 @@ TEST(TestDeferredRelease, WithoutArena) { // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; - ORT_THROW_IF_ERROR(ep.OnRunStart()); + ORT_THROW_IF_ERROR(ep.OnRunStart(run_opts)); for (size_t i = 0; i < n_allocs; ++i) { // Allocate 10MB CUDA pinned memory. 
auto pinned_buffer = IAllocator::MakeUniquePtr(cuda_pinned_alloc, n_bytes); @@ -79,7 +84,7 @@ TEST(TestDeferredRelease, WithoutArena) { } ORT_THROW_IF_ERROR(stream.CleanUpOnRunEnd()); - ORT_THROW_IF_ERROR(ep.OnRunEnd(true)); + ORT_THROW_IF_ERROR(ep.OnRunEnd(true, run_opts)); } } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 4e1aef2c40b2b..8f07c2ce77e77 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -7,6 +7,7 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #include "core/providers/cpu/cpu_provider_factory.h" // For OrtSessionOptionsAppendExecutionProvider_CPU #include "core/session/inference_session.h" @@ -332,19 +333,23 @@ static void CreateModelInMemory(std::unique_ptr& result, static void RunSessionAndVerify(InferenceSession& session, const RunOptions& run_options, const NameMLValMap& feeds, const std::vector& output_names, const std::vector>& output_shapes, - const std::vector>& expected_values) { - std::vector fetches; - auto status = session.Run(run_options, feeds, output_names, &fetches); - ASSERT_TRUE(status.IsOK()); - - for (size_t i = 0; i < fetches.size(); i++) { - auto& tensor = fetches[i].Get(); - TensorShape expected_shape(output_shapes[i]); - ASSERT_EQ(expected_shape, tensor.Shape()); - - gsl::span actual = tensor.DataAsSpan(); - gsl::span expected(expected_values[i].data(), expected_values[i].size()); - ASSERT_EQ(expected, actual); + const std::vector>& expected_values, + int loop_count = 10) { + // Let it run for a while + for (int it = 0; it < loop_count; ++it) { + std::vector fetches; + auto status = session.Run(run_options, feeds, output_names, &fetches); + ASSERT_TRUE(status.IsOK()); + + for (size_t i = 0; i < fetches.size(); i++) { + auto& tensor = fetches[i].Get(); + TensorShape expected_shape(output_shapes[i]); + ASSERT_EQ(expected_shape, tensor.Shape()); + + gsl::span actual = tensor.DataAsSpan(); + gsl::span expected(expected_values[i].data(), expected_values[i].size()); + ASSERT_EQ(expected, actual); + } } } @@ -404,11 +409,11 @@ TEST_F(QnnCPUBackendTests, MultithreadSessionRun) { std::vector threads; constexpr int num_threads = 5; - + constexpr int loop_count = 10; for (int i = 0; i < num_threads; i++) { threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, model->builder.feeds_, model->builder.output_names_, - output_shapes, output_values)); + output_shapes, output_values, loop_count)); } for (auto& th : threads) { @@ -484,11 +489,191 @@ TEST_F(QnnHTPBackendTests, MultithreadSessionRun) { std::vector threads; constexpr int num_threads = 5; + constexpr int loop_count = 10; for (int i = 0; i < num_threads; i++) { threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, model->builder.feeds_, model->builder.output_names_, - output_shapes, output_values)); + output_shapes, output_values, loop_count)); + } + + for (auto& th : threads) { + th.join(); + } +} + +// Tests running a single session in multiple threads on the HTP backend with run option to set power config +TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgSessionRunOption) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values 
= {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + constexpr int loop_count = 10; + + std::vector perf_modes{ + "burst", "balanced", "default", "high_performance", "high_power_saver", + "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver"}; + + size_t post_i = perf_modes.size() - 1; + ASSERT_TRUE(post_i > num_threads); + for (int i = 0; i < num_threads; ++i, --post_i) { + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + auto rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, perf_modes[i].c_str()); + ASSERT_TRUE(rt.IsOK()); + rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, perf_modes[post_i].c_str()); + ASSERT_TRUE(rt.IsOK()); + + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, model->builder.output_names_, + output_shapes, output_values, loop_count)); + } + + for (auto& th : threads) { + th.join(); + } +} + +// Tests running a single session in multiple threads on the HTP backend with EP option to set default power config +TEST_F(QnnHTPBackendTests, MultithreadDefaultHtpPowerCfgFromEpOption) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + options["htp_performance_mode"] = "burst"; + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + constexpr int loop_count = 10; + + for (int i = 0; i < num_threads; i++) { + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, 
model->builder.output_names_, + output_shapes, output_values, loop_count)); + } + + for (auto& th : threads) { + th.join(); + } +} + +// Tests running a single session in multiple threads on the HTP backend with +// EP option to set default power config + run option to set power config for each run +TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgDefaultAndRunOption) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + options["htp_performance_mode"] = "burst"; + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + constexpr int loop_count = 10; + + std::vector perf_modes{ + "burst", "balanced", "default", "high_performance", "high_power_saver", + "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver"}; + + size_t post_i = perf_modes.size() - 1; + ASSERT_TRUE(post_i > num_threads); + for (int i = 0; i < num_threads; ++i, --post_i) { + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + auto rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, perf_modes[i].c_str()); + ASSERT_TRUE(rt.IsOK()); + rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, perf_modes[post_i].c_str()); + ASSERT_TRUE(rt.IsOK()); + + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, model->builder.output_names_, + output_shapes, output_values, loop_count)); } for (auto& th : threads) { From 29b1106033e291947debb49c3fd03feb479c4b1b Mon Sep 17 00:00:00 2001 From: Segev Finer Date: Fri, 23 Feb 2024 04:53:50 +0200 Subject: [PATCH 046/237] [node] Switch to setImmediate to avoid starving the Node.js event loop (#19610) ### Description Switch to setImmediate to avoid starving the Node.js event loop There should really be a true async version though, running computationally intensive things on the event loop will stop everything else from happening while it is running, e.g. a web server from answering requests. This can be done by wrapping `RunAsync` behind a [`napi::Promise`](https://github.com/nodejs/node-addon-api/blob/main/doc/promises.md) to run on the onnxruntime thread pool or [`AsyncWorker`]( https://github.com/nodejs/node-addon-api/blob/main/doc/async_worker.md) for the Node.js/libuv thread pool. 
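For readers unfamiliar with that pattern, a rough `AsyncWorker` sketch is below. This is not the real `onnxruntime-node` binding code — the class name, callback, and return value are placeholders — it only illustrates how the blocking run could be pushed onto the libuv thread pool behind a promise:

```
// Hypothetical sketch using node-addon-api; RunWorker and the lambda body are placeholders.
#include <napi.h>
#include <functional>

class RunWorker : public Napi::AsyncWorker {
 public:
  RunWorker(Napi::Env env, std::function<double()> run_fn)
      : Napi::AsyncWorker(env), deferred_(Napi::Promise::Deferred::New(env)), run_fn_(std::move(run_fn)) {}

  Napi::Promise GetPromise() { return deferred_.Promise(); }

  void Execute() override {
    // Runs on the libuv thread pool, so the JS event loop keeps serving other work.
    result_ = run_fn_();
  }
  void OnOK() override { deferred_.Resolve(Napi::Number::New(Env(), result_)); }
  void OnError(const Napi::Error& e) override { deferred_.Reject(e.Value()); }

 private:
  Napi::Promise::Deferred deferred_;
  std::function<double()> run_fn_;
  double result_ = 0.0;
};

// Binding method: queue the worker and hand the promise back to JavaScript.
Napi::Value RunAsyncExample(const Napi::CallbackInfo& info) {
  auto* worker = new RunWorker(info.Env(), [] { return 42.0; /* stand-in for the actual inference call */ });
  worker->Queue();  // the worker deletes itself after OnOK/OnError
  return worker->GetPromise();
}
```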
### Motivation and Context Without this, if you run inference in a tight loop with nothing else async/deferred in between, `process.nextTick` starves the event loop and nothing else gets a chance to run; `setImmediate` at least lets the event loop spin between calls to `run`. See https://dev.to/ynmanware/setimmediate-settimeout-and-process-nexttick-3mfd Contributed on behalf of [Swimm](https://swimm.io/) --- js/node/lib/backend.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index e8eb0e9babf5a..927953b4f1dd6 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -36,7 +36,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise { return new Promise((resolve, reject) => { - process.nextTick(() => { + setImmediate(() => { try { resolve(this.#inferenceSession.run(feeds, fetches, options)); } catch (e) { @@ -56,7 +56,7 @@ class OnnxruntimeBackend implements Backend { async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { return new Promise((resolve, reject) => { - process.nextTick(() => { + setImmediate(() => { try { resolve(new OnnxruntimeSessionHandler(pathOrBuffer, options || {})); } catch (e) { From ae92d593c0e2b06decbea64797f9145bc10f34af Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 23 Feb 2024 11:05:16 +0800 Subject: [PATCH 047/237] ONNX Gelu Op in Opset 20 (#19560) ### ONNX Gelu Op in Opset 20 Refactor code to support MSDomain Gelu and ONNX Gelu-opset20 Op
1. Move the CPU GELU implementation from `onnxruntime/contrib_ops/cpu/activations.h/cc` to `onnxruntime/core/providers/cpu/tensor/gelu.h/cc`, as the implementation for the approximate attribute value 'none'.
2. Duplicate some logic from `onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc` to `onnxruntime/core/providers/cpu/tensor/gelu.h/cc`, as the implementation for the approximate attribute value 'tanh'.
3. Register the ONNX domain Gelu CPU kernel for opset 20 in `onnxruntime/core/providers/cpu/cpu_execution_provider.cc`.
4. Move `onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h/cu` to `onnxruntime/core/providers/cuda/tensor/gelu_impl.h` and `onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu` respectively, as the implementation for the approximate attribute value 'tanh'.
5. Implement the logic for the approximate attribute value 'none' in `onnxruntime/core/providers/cuda/tensor/gelu_impl.cu`.
6. Register the ONNX domain Gelu CUDA kernel for opset 20 in `onnxruntime/core/providers/cuda/cuda_execution_provider.cc`.
7. ROCm EP related changes.
8. Enrich the tests for ONNX domain Gelu in `onnxruntime/test/providers/cpu/activation/activation_op_test.cc`.
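For reference, the two `approximate` modes that this change wires up compute the following (a scalar restatement of the math in the new kernels, which chunk the input and call `MlasComputeErf` / `MlasComputeTanh` on CPU):

```
// Scalar reference only; the kernels in this patch vectorize these formulas.
#include <cmath>

float GeluNone(float x) {                // approximate = "none" (exact erf form)
  return 0.5f * x * (1.0f + std::erf(x * 0.707106781186547524f /* 1/sqrt(2) */));
}

float GeluTanh(float x) {                // approximate = "tanh" (FastGelu form)
  constexpr float kB = 0.7978845608028654f;    // sqrt(2/pi)
  constexpr float kC = 0.035677408136300125f;  // 0.044715 * sqrt(2/pi)
  return 0.5f * x * (1.0f + std::tanh(x * (kC * x * x + kB)));
}
```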
--- cmake/onnxruntime_rocm_hipify.cmake | 4 - .../InferenceTest.netcore.cs | 2 +- docs/OperatorKernels.md | 2 + .../core/providers/cuda/cuda_resource.h | 2 +- onnxruntime/contrib_ops/cpu/activations.cc | 10 +- onnxruntime/contrib_ops/cpu/activations.h | 41 ------- .../cuda/activation/activations.cc | 1 - .../contrib_ops/cuda/activation/activations.h | 11 -- .../cuda/activation/activations_impl.cu | 14 --- .../cuda/activation/activations_impl.h | 2 - .../contrib_ops/cuda/bert/fast_gelu.cc | 20 +++- onnxruntime/contrib_ops/cuda/bert/fast_gelu.h | 2 +- .../contrib_ops/rocm/bert/fast_gelu.cc | 59 ---------- onnxruntime/contrib_ops/rocm/bert/fast_gelu.h | 24 ---- .../providers/cpu/cpu_execution_provider.cc | 2 + onnxruntime/core/providers/cpu/tensor/gelu.cc | 108 ++++++++++++++++++ onnxruntime/core/providers/cpu/tensor/gelu.h | 18 +++ .../providers/cuda/cuda_execution_provider.cc | 10 ++ .../core/providers/cuda/tensor/gelu.cc | 89 +++++++++++++++ onnxruntime/core/providers/cuda/tensor/gelu.h | 28 +++++ .../cuda/tensor/gelu_approximate_impl.cu} | 17 ++- .../core/providers/cuda/tensor/gelu_impl.cu | 48 ++++++++ .../providers/cuda/tensor/gelu_impl.h} | 7 +- .../test/contrib_ops/activation_op_test.cc | 13 ++- .../test/onnx/microbenchmark/activation.cc | 3 +- .../cpu/activation/activation_op_test.cc | 48 ++++++-- .../cpu/activation/activation_op_test.h | 7 +- 27 files changed, 395 insertions(+), 197 deletions(-) delete mode 100644 onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc delete mode 100644 onnxruntime/contrib_ops/rocm/bert/fast_gelu.h create mode 100644 onnxruntime/core/providers/cpu/tensor/gelu.cc create mode 100644 onnxruntime/core/providers/cpu/tensor/gelu.h create mode 100644 onnxruntime/core/providers/cuda/tensor/gelu.cc create mode 100644 onnxruntime/core/providers/cuda/tensor/gelu.h rename onnxruntime/{contrib_ops/cuda/bert/fast_gelu_impl.cu => core/providers/cuda/tensor/gelu_approximate_impl.cu} (88%) create mode 100644 onnxruntime/core/providers/cuda/tensor/gelu_impl.cu rename onnxruntime/{contrib_ops/cuda/bert/fast_gelu_impl.h => core/providers/cuda/tensor/gelu_impl.h} (80%) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 85a9bf50460d3..1bb70e9c2ed27 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -20,10 +20,6 @@ set(contrib_ops_excluded_files "bert/fastertransformer_decoder_attention/*" "bert/multihead_attention.cc" "bert/multihead_attention.h" - "bert/fast_gelu_impl.cu" - "bert/fast_gelu_impl.h" - "bert/fast_gelu.cc" - "bert/fast_gelu.h" "bert/relative_attn_bias.cc" "bert/relative_attn_bias.h" "bert/relative_attn_bias_impl.cu" diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index 715aed7e1d64f..7f3d5d6624b07 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -145,7 +145,7 @@ private void TestCUDAProviderOptions() private void CanRunInferenceOnAModelWithTensorRT() { string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx"); - + int deviceId = 0; string deviceIdStr = System.Environment.GetEnvironmentVariable("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); if (!string.IsNullOrEmpty(deviceIdStr) && int.TryParse(deviceIdStr, out int parsedValue) && parsedValue >= 0) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 
8ff2135c6b1f6..46149c577a106 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -127,6 +127,7 @@ Do not modify directly.* |GatherND|*in* data:**T**<br/> *in* indices:**tensor(int64)**<br/> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)| |||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)| |||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)| +|Gelu|*in* X:**T**<br/> *out* Y:**T**|20+|**T** = tensor(float)| |Gemm|*in* A:**T**<br/> *in* B:**T**<br/> *in* C:**T**<br/> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[9, 10]|**T** = tensor(double), tensor(float)|
@@ -606,6 +607,7 @@ Do not modify directly.* |GatherND|*in* data:**T**<br/> *in* indices:**tensor(int64)**<br/> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)| |||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)| |||11|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)| +|Gelu|*in* X:**T**<br/> *out* Y:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16)| |Gemm|*in* A:**T**<br/> *in* B:**T**<br/> *in* C:**T**<br/>
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index 1fef077860be3..00e7dec5727d1 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -19,4 +19,4 @@ enum CudaResource : int { enable_skip_layer_norm_strict_mode_t, prefer_nhwc_t, use_tf32_t, -}; \ No newline at end of file +}; diff --git a/onnxruntime/contrib_ops/cpu/activations.cc b/onnxruntime/contrib_ops/cpu/activations.cc index 556699192d2eb..3e0533dd8b9e5 100644 --- a/onnxruntime/contrib_ops/cpu/activations.cc +++ b/onnxruntime/contrib_ops/cpu/activations.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/activation/activations.h" -#include "activations.h" +#include "contrib_ops/cpu/activations.h" namespace onnxruntime { namespace contrib { @@ -26,14 +26,6 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), ThresholdedRelu); -ONNX_OPERATOR_KERNEL_EX( - Gelu, - kMSDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Gelu); - ONNX_OPERATOR_KERNEL_EX( QuickGelu, kMSDomain, diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h index aed4c2229215d..7e64235d3fc3d 100644 --- a/onnxruntime/contrib_ops/cpu/activations.h +++ b/onnxruntime/contrib_ops/cpu/activations.h @@ -54,47 +54,6 @@ namespace contrib { DEFINE_ELE_KERNEL(ScaledTanh); DEFINE_ELE_KERNEL(ParametricSoftplus); -template -class Gelu : public OpKernel { - public: - Gelu(const OpKernelInfo& info) : OpKernel(info) { - } - - Status Compute(OpKernelContext* context) const override { - const Tensor* input = context->Input(0); - const T* input_data = input->Data(); - - Tensor* output = context->Output(0, input->Shape()); - T* output_data = output->MutableData(); - - concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); - int64_t elem_count = input->Shape().Size(); - constexpr int64_t length_per_task = 4096; // this number comes from FastGelu. - int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; - concurrency::ThreadPool::TryBatchParallelFor( - tp, static_cast(task_count), - [&](ptrdiff_t task_idx) { - const auto start = task_idx * length_per_task; - const T* p_input = input_data + start; - T* p_output = output_data + start; - int64_t count = std::min(length_per_task, elem_count - start); - - for (int64_t i = 0; i < count; i++) { - T value = p_input[i]; - p_output[i] = value * static_cast(M_SQRT1_2); - } - - MlasComputeErf(p_output, p_output, narrow(count)); - - for (int64_t i = 0; i < count; i++) { - p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); - } - }, - 0); - return Status::OK(); - } -}; - // Implement a new one instead of inheriting from ElementWiseRangedTransform so that we can call // MlasComputeLogistic instead of using Eigen for better perf. 
template diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.cc b/onnxruntime/contrib_ops/cuda/activation/activations.cc index 1a86c5dbece5a..6303858b9bd48 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations.cc +++ b/onnxruntime/contrib_ops/cuda/activation/activations.cc @@ -49,7 +49,6 @@ namespace cuda { UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain); UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain); UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain); -UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain); UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain); REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16) diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.h b/onnxruntime/contrib_ops/cuda/activation/activations.h index ab339f276c2bd..fc9a71b0b7fa1 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations.h +++ b/onnxruntime/contrib_ops/cuda/activation/activations.h @@ -66,17 +66,6 @@ class ScaledTanh final : public UnaryElementwise { float beta_; }; -template -class Gelu final : public UnaryElementwise { - public: - Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {} - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - MAKE_FUNC_CTX_NULL() -}; - template class QuickGelu final : public UnaryElementwise { public: diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu index 0c856815fd437..36f33fbb24c18 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu @@ -36,20 +36,6 @@ struct OP_ScaledTanh : public CtxScaledTanh { } }; -template -struct OP_Gelu : public CtxGelu { - __device__ __inline__ T operator()(const T& a) const { - return _Gelu(a); - } -}; - -template <> -struct OP_Gelu : public CtxGelu { - __device__ __inline__ half operator()(const half& a) const { - return static_cast(_Gelu(static_cast(a))); - } -}; - template struct OP_QuickGelu : public CtxQuickGelu { __device__ __inline__ T operator()(const T& a) const { diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h index 5d18283a395e3..782d4bf59a5ad 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h @@ -11,14 +11,12 @@ namespace cuda { typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine; typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus; typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh; -typedef onnxruntime::cuda::CtxNull CtxGelu; typedef onnxruntime::cuda::CtxAlpha CtxQuickGelu; #define UNARY_CONTRIB_ACTIVATION_OPS() \ UNARY_ACTIVATION_OP_NAME(ScaledTanh) \ UNARY_ACTIVATION_OP_NAME(Affine) \ UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \ - UNARY_ACTIVATION_OP_NAME(Gelu) \ UNARY_ACTIVATION_OP_NAME(QuickGelu) #define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name); diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 892f5c181a607..e8974a29476b6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -4,9 +4,14 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cudnn_common.h" #include "fast_gelu.h" -#include "fast_gelu_impl.h" +#include "core/providers/cuda/tensor/gelu_impl.h" #include 
"contrib_ops/cpu/bert/bias_gelu_helper.h" -#include "transformer_common.h" +#ifdef USE_ROCM +#include "contrib_ops/rocm/bert/elementwise.h" +#endif +#ifdef USE_CUDA +#include "contrib_ops/cuda/bert/transformer_common.h" +#endif namespace onnxruntime { namespace contrib { @@ -31,8 +36,10 @@ using namespace ONNX_NAMESPACE; template FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { +#ifdef USE_CUDA const TransformerOptions* options = TransformerOptions::GetInstance(); use_half2_ = !options->DisableHalf2(); +#endif } template @@ -50,6 +57,14 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); typedef typename ToCudaType::MappedType CudaT; +#ifdef USE_ROCM + return LaunchElementwiseKernel( + GetTuningContext(), context->GetComputeStream(), + reinterpret_cast(input->Data()), static_cast(input_length), + (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, static_cast(bias_length), + reinterpret_cast(output->MutableData())); +#endif +#ifdef USE_CUDA return LaunchFastGeluKernel(GetDeviceProp(), Stream(context), static_cast(input_length), @@ -58,6 +73,7 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, reinterpret_cast(output->MutableData()), use_half2_); +#endif } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h index 3e642a70afef5..d563556593e6e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h @@ -18,7 +18,7 @@ class FastGelu final : public CudaKernel { Status ComputeInternal(OpKernelContext* ctx) const override; private: - bool use_half2_; + bool use_half2_; // Only applicable to CUDA kernel (not ROCM). }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc deleted file mode 100644 index 9cb414e4e8980..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/fast_gelu.h" - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/miopen_common.h" -#include "contrib_ops/cpu/bert/bias_gelu_helper.h" -#include "contrib_ops/rocm/bert/elementwise.h" -#include "contrib_ops/rocm/bert/transformer_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - FastGelu, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - FastGelu); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) - -using namespace ONNX_NAMESPACE; - -template -Status FastGelu::ComputeInternal(OpKernelContext* context) const { - ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context)); - - const Tensor* input = context->Input(0); - const Tensor* bias = context->Input(1); - Tensor* output = context->Output(0, input->Shape()); - - int64_t input_length = input->Shape().Size(); - if (input_length == 0) { - return Status::OK(); - } - int64_t bias_length = (nullptr == bias) ? 
0 : bias->Shape().Size(); - typedef typename ToHipType::MappedType HipT; - - const HipT* input_buffer = reinterpret_cast(input->Data()); - const HipT* bias_buffer = (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr; - return LaunchElementwiseKernel( - GetTuningContext(), context->GetComputeStream(), - input_buffer, static_cast(input_length), - bias_buffer, static_cast(bias_length), - reinterpret_cast(output->MutableData())); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h deleted file mode 100644 index 42bfe5a0b0246..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class FastGelu final : public RocmKernel { - public: - FastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {} - Status ComputeInternal(OpKernelContext* ctx) const override; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 813fdc54ecd0d..48e4617b33b4d 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1035,6 +1035,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN); @@ -2562,6 +2563,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc new file mode 100644 index 0000000000000..d55973eda180f --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/common.h" +#include "core/common/narrow.h" +#include "core/framework/op_kernel.h" +#include "core/util/math_cpuonly.h" +#include "core/mlas/inc/mlas.h" + +#include "core/platform/threadpool.h" +#include +#include "core/providers/cpu/element_wise_ranged_transform.h" +#include "core/providers/cpu/tensor/gelu.h" + +using onnxruntime::narrow; +using namespace onnxruntime::common; + +namespace onnxruntime { + +// May revisit the implementations to support inplace computation, if needed. 
+ +ONNX_CPU_OPERATOR_KERNEL( + Gelu, + 20, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gelu); + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib { +ONNX_OPERATOR_KERNEL_EX( + Gelu, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gelu); +} +#endif + +template +Status Gelu::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const T* input_data = input->Data(); + + Tensor* output = context->Output(0, input->Shape()); + T* output_data = output->MutableData(); + + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); + int64_t elem_count = input->Shape().Size(); + constexpr int64_t length_per_task = 4096; // this number comes from FastGelu. + int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; + + if (approximation_algorithm_ == "tanh") { + // FastGelu allows optional bias. Here we split input data into chunks. Each chunk + // has N elements (except the last chunk), and use thread pool to parallel chunks. + // N = 4096 is selected based on performance test results on input shape 1x128x768. + // FastGelu uses approximation for Gelu. The formula is 0.5 * (1 + Tanh(x * (C * x * x + B))) * x. + static constexpr float B = 0.7978845608028654f; // sqrt(2.0 / M_PI) + static constexpr float C = 0.035677408136300125f; // 0.044715 * sqrt(2.0 / M_PI) + + concurrency::ThreadPool::TryBatchParallelFor( + tp, static_cast(task_count), + [&](ptrdiff_t task_idx) { + const auto start = task_idx * length_per_task; + const T* p_input = input_data + start; + T* p_output = output_data + start; + int64_t count = std::min(length_per_task, elem_count - start); + + for (int64_t i = 0; i < count; i++) { + T value = p_input[i]; + p_output[i] = value * (static_cast(C) * value * value + static_cast(B)); + } + + MlasComputeTanh(p_output, p_output, narrow(count)); + + for (int64_t i = 0; i < count; i++) { + p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); + } + }, + 0); + return Status::OK(); + } else if (approximation_algorithm_ == "none") { + concurrency::ThreadPool::TryBatchParallelFor( + tp, static_cast(task_count), + [&](ptrdiff_t task_idx) { + const auto start = task_idx * length_per_task; + const T* p_input = input_data + start; + T* p_output = output_data + start; + int64_t count = std::min(length_per_task, elem_count - start); + + for (int64_t i = 0; i < count; i++) { + T value = p_input[i]; + p_output[i] = value * static_cast(M_SQRT1_2); + } + + MlasComputeErf(p_output, p_output, narrow(count)); + + for (int64_t i = 0; i < count; i++) { + p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); + } + }, + 0); + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.h b/onnxruntime/core/providers/cpu/tensor/gelu.h new file mode 100644 index 0000000000000..13238028d878a --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/gelu.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +namespace onnxruntime { + +template +class Gelu final : public OpKernel { + public: + explicit Gelu(const OpKernelInfo& info) : OpKernel(info) { + approximation_algorithm_ = info.GetAttrOrDefault("approximate", "none"); + } + Status Compute(OpKernelContext* ctx) const override; + + private: + std::string approximation_algorithm_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 0dd568c5ecc05..be2530aec49fa 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1329,6 +1329,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Shape); #endif +// Opset 20 +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); + template <> KernelCreateInfo BuildKernelCreateInfo() { return {}; @@ -2222,6 +2227,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // Opset 20 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/providers/cuda/tensor/gelu.cc b/onnxruntime/core/providers/cuda/tensor/gelu.cc new file mode 100644 index 0000000000000..67b2fad373a7f --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/gelu.cc @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/tensor/gelu.h" +#include "core/providers/cuda/tensor/gelu_impl.h" + +namespace onnxruntime { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Gelu, \ + kOnnxDomain, \ + 20, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .MayInplace(0, 0), \ + Gelu); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(double) + +template +Status Gelu::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() < 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 0 is expected to have 1 or more dimensions, got ", input_dims.size()); + } + + Tensor* output = context->Output(0, input->Shape()); + + int64_t input_length = input->Shape().Size(); + if (input_length == 0) { + return Status::OK(); + } + + typedef typename ToCudaType::MappedType CudaT; + + if (approximation_algorithm_ == "tanh") { + return LaunchFastGeluKernel(GetDeviceProp(), + Stream(context), + static_cast(input_length), + 0 /* no bias */, + reinterpret_cast(input->Data()), + nullptr /* no bias */, + reinterpret_cast(output->MutableData()), + use_half2_); + } else if (approximation_algorithm_ == "none") { + return LaunchGeluKernel(Stream(context), + reinterpret_cast(input->Data()), + reinterpret_cast(output->MutableData()), + static_cast(input_length)); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_); +} + +} // namespace cuda + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib::cuda { +#define REGISTER_CONTRIB_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Gelu, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .MayInplace(0, 0), \ + onnxruntime::cuda::Gelu); + +REGISTER_CONTRIB_KERNEL_TYPED(float) +REGISTER_CONTRIB_KERNEL_TYPED(MLFloat16) +REGISTER_CONTRIB_KERNEL_TYPED(double) + +#undef REGISTER_CONTRIB_KERNEL_TYPED +} // namespace contrib::cuda +#endif + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/gelu.h b/onnxruntime/core/providers/cuda/tensor/gelu.h new file mode 100644 index 0000000000000..1c8189ab24121 --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/gelu.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/math/unary_elementwise_ops.h" + +namespace onnxruntime { +namespace cuda { + +template +class Gelu final : public UnaryElementwise { + public: + Gelu(const OpKernelInfo& info) : UnaryElementwise(info) { + approximation_algorithm_ = info.GetAttrOrDefault("approximate", "none"); + } + + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + const bool use_half2_{true}; + + std::string approximation_algorithm_; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu similarity index 88% rename from onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu rename to onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu index c9498eb1bcd7b..3292650584de8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu @@ -24,12 +24,9 @@ limitations under the License. #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/shared_inc/cuda_call.h" -#include "contrib_ops/cuda/bert/fast_gelu_impl.h" - -using namespace onnxruntime::cuda; +#include "core/providers/cuda/tensor/gelu_impl.h" namespace onnxruntime { -namespace contrib { namespace cuda { // constants for approximating the normal cdf @@ -75,6 +72,17 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int return CUDA_CALL(cudaGetLastError()); } +template <> +Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, + const double* input, const double* bias, double* output, bool /*use_half2*/) { + constexpr int blockSize = 256; + const int gridSize = (input_length + blockSize - 1) / blockSize; + FastGeluKernel<<>>(A, B, C, input_length, bias_length, + input, bias, output); + + return CUDA_CALL(cudaGetLastError()); +} + template <> Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, const half* input, const half* bias, half* output, bool use_half2) { @@ -114,5 +122,4 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int } } // namespace cuda -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu new file mode 100644 index 0000000000000..3f96da38b37bb --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include "core/providers/cuda/tensor/gelu_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cu_inc/unary_elementwise_impl.cuh" + +namespace onnxruntime { +namespace cuda { + +template +struct OP_Gelu { + __device__ __inline__ T operator()(const T& a) const { + return _Gelu(a); + } +}; + +template <> +struct OP_Gelu { + __device__ __inline__ half operator()(const half& a) const { + return static_cast(_Gelu(static_cast(a))); + } +}; + +template +Status LaunchGeluKernel( + cudaStream_t stream, + const T* input_data, + T* output_data, + size_t count) { + UnaryElementWiseImpl(stream, input_data, output_data, OP_Gelu(), count); + + return CUDA_CALL(cudaGetLastError()); +} + +#define SPECIALIZED_GELU_IMPL(T) \ + template Status LaunchGeluKernel(cudaStream_t stream, const T* input_data, T* output_data, \ + size_t count); + +SPECIALIZED_GELU_IMPL(float); +SPECIALIZED_GELU_IMPL(half); +SPECIALIZED_GELU_IMPL(double); + +#undef SPECIALIZED_GELU_IMPL + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h b/onnxruntime/core/providers/cuda/tensor/gelu_impl.h similarity index 80% rename from onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h rename to onnxruntime/core/providers/cuda/tensor/gelu_impl.h index ba78310f5dfc2..2ea0d3441fda3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/gelu_impl.h @@ -1,17 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #pragma once + #include "core/common/common.h" namespace onnxruntime { -namespace contrib { namespace cuda { +template +Status LaunchGeluKernel(cudaStream_t stream, const T* input, T* output, size_t count); + template Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, const T* input, const T* bias, T* output, bool use_half2); } // namespace cuda -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/activation_op_test.cc b/onnxruntime/test/contrib_ops/activation_op_test.cc index b1e54ec605a39..2a56991ec5af4 100644 --- a/onnxruntime/test/contrib_ops/activation_op_test.cc +++ b/onnxruntime/test/contrib_ops/activation_op_test.cc @@ -22,7 +22,8 @@ namespace test { TEST_F(ActivationOpTest, ThresholdedRelu_version_1_to_9) { float alpha = 0.1f; TestActivationOp( - "ThresholdedRelu", input_values, [alpha](float x) { return (x >= alpha) ? x : 0; }, {{"alpha", alpha}}, true, 1); + "ThresholdedRelu", input_values, [alpha](float x) { return (x >= alpha) ? x : 0; }, {{"alpha", alpha}}, {}, + true, 1); } TEST_F(ActivationOpTest, ScaledTanh) { @@ -46,13 +47,13 @@ TEST_F(ActivationOpTest, ParametricSoftplus) { else return alpha * logf(expf(bx) + 1); }, - {{"alpha", alpha}, {"beta", beta}}, false); // Disable TensorRT due to result mismatch + {{"alpha", alpha}, {"beta", beta}}, {}, false); // Disable TensorRT due to result mismatch } TEST_F(ActivationOpTest, Gelu) { TestActivationOp( "Gelu", input_values, [](float x) { return x * 0.5f * (1.0f + std::erf(x * static_cast(M_SQRT1_2))); }, {}, - false, 1, kMSDomain); + {}, false, 1, kMSDomain); } #if defined(USE_DNNL) @@ -115,7 +116,7 @@ TEST_F(ActivationOpTest, QuickGelu) { y = tmp >= 0 ? y : 1 - y; return x * y; }, - {{"alpha", alpha}}, false, 1, kMSDomain); + {{"alpha", alpha}}, {}, false, 1, kMSDomain); } // Silu = x*sigmoid(x), i.e., alpha = 1.0f. 
@@ -129,7 +130,7 @@ TEST_F(ActivationOpTest, QuickGelu) { y = tmp >= 0 ? y : 1 - y; return x * y; }, - {{"alpha", alpha}}, false, 1, kMSDomain); + {{"alpha", alpha}}, {}, false, 1, kMSDomain); } // Negative alpha. @@ -143,7 +144,7 @@ TEST_F(ActivationOpTest, QuickGelu) { y = tmp >= 0 ? y : 1 - y; return x * y; }, - {{"alpha", alpha}}, false, 1, kMSDomain); + {{"alpha", alpha}}, {}, false, 1, kMSDomain); } } diff --git a/onnxruntime/test/onnx/microbenchmark/activation.cc b/onnxruntime/test/onnx/microbenchmark/activation.cc index cf859facf4765..69ee72996365e 100644 --- a/onnxruntime/test/onnx/microbenchmark/activation.cc +++ b/onnxruntime/test/onnx/microbenchmark/activation.cc @@ -11,6 +11,7 @@ #include "core/framework/node_index_info.h" #include "core/framework/execution_frame.h" #include "contrib_ops/cpu/activations.h" +#include "core/providers/cpu/tensor/gelu.h" #include "core/providers/cpu/activation/activations.h" #include #include @@ -182,7 +183,7 @@ static void RunSingleNode(const std::string& op_name, const std::string& domain, } static void BM_GeluCompute(benchmark::State& state) { - RunSingleNode>("Gelu", kMSDomain, {}, state); + RunSingleNode>("Gelu", kMSDomain, {}, state); } BENCHMARK(BM_GeluCompute) diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index ddb0a6620619c..acd513172f95d 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -116,13 +116,13 @@ TEST_F(ActivationOpTest, Relu) { "Relu", input_values_double, [](double x) { return std::max(x, 0.0); }, - {}, + {}, {}, /*is_tensorrt_supported=*/false); TestActivationOp( "Relu", input_values_int8, [](int8_t x) { return std::max(x, static_cast(0)); }, - {}, + {}, {}, /*is_tensorrt_supported=*/false, /*opset_version= */ 14); #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -133,7 +133,7 @@ TEST_F(ActivationOpTest, Relu) { if (x.ToFloat() > 0.0f) return x; return MLFloat16(); }, - {}, + {}, {}, /*is_tensorrt_supported=*/false, /*opset_version= */ 11); #endif // MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -402,7 +402,7 @@ TEST_F(ActivationOpTest, Celu) { // TODO: Investigate why gcc 4 fails to compile without the explicit cast [alpha](float x) { return std::max(0.0f, x) + std::min(0.0f, alpha * (static_cast(exp(x / alpha)) - 1)); }, // Disable on TensorRT as it seems like it doesn't yet support Celu - {{"alpha", alpha}}, false, 12); + {{"alpha", alpha}}, {}, false, 12); } TEST_F(ActivationOpTest, LeakyRelu) { @@ -410,7 +410,7 @@ TEST_F(ActivationOpTest, LeakyRelu) { TestActivationOp("LeakyRelu", input_values, [alpha](float x) { return (x >= 0) ? x : alpha * x; }, - {{"alpha", alpha}}); + {{"alpha", alpha}}, {}); } #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -442,7 +442,7 @@ TEST_F(ActivationOpTest, ThresholdedRelu) { "ThresholdedRelu", input_values, [alpha](float x) { return (x >= alpha) ? x : 0; }, - {{"alpha", alpha}}, true, 10); + {{"alpha", alpha}}, {}, true, 10); } TEST_F(ActivationOpTest, Selu) { @@ -452,7 +452,7 @@ TEST_F(ActivationOpTest, Selu) { TestActivationOp("Selu", input_values, [](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; }, - {{"alpha", alpha}, {"gamma", gamma}}); + {{"alpha", alpha}, {"gamma", gamma}}, {}); } TEST_F(ActivationOpTest, Selu_Attributes) { @@ -462,7 +462,7 @@ TEST_F(ActivationOpTest, Selu_Attributes) { TestActivationOp("Selu", input_values, [](float x) { return x <= 0 ? 
gamma * (alpha * exp(x) - alpha) : gamma * x; }, - {{"alpha", alpha}, {"gamma", gamma}}); + {{"alpha", alpha}, {"gamma", gamma}}, {}); } TEST_F(ActivationOpTest, Selu_GH10726) { @@ -472,7 +472,7 @@ TEST_F(ActivationOpTest, Selu_GH10726) { TestActivationOp("Selu", {{1.f, -1.f}}, [](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; }, - {{"alpha", alpha}, {"gamma", gamma}}); + {{"alpha", alpha}, {"gamma", gamma}}, {}); } TEST_F(ActivationOpTest, PRelu) { @@ -625,7 +625,7 @@ TEST_F(ActivationOpNoInfTest, Softsign) { return result; }, - {}, false); // Disable TensorRT because result mismatches + {}, {}, false); // Disable TensorRT because result mismatches } #if defined(ENABLE_TRAINING_OPS) @@ -695,5 +695,33 @@ TEST(LeakyReluGradInferenceTest, Basic) { } #endif +// Remove DNNL from running this test because DNNL Gelu op seems not check domain for kernel implementation. +// It will run the DNNL Gelu op which only be part of standard of Gelu-20 op. +#if !defined(USE_DNNL) && !defined(USE_QNN) +TEST_F(ActivationOpTest, ONNX_Gelu) { + TestActivationOp( + "Gelu", + input_values, + [](float x) { return 0.5 * x * (1 + erf(x * M_SQRT1_2)); }, {}, + {{"approximate", "none"}}, true, 20); + + TestActivationOp( + "Gelu", + input_values, + [](float x) { return 0.5 * x * (1 + erf(x * M_SQRT1_2)); }, + {}, + {/*default value of approximate attribute is none */}, true, 20); + + TestActivationOp( + "Gelu", + input_values, + [](float x) { + return 0.5 * x * (1 + tanh(sqrt(2 / M_PI) * (x + 0.044715 * x * x * x))); + }, + {}, + {{"approximate", "tanh"}}, true, 20); +} +#endif + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.h b/onnxruntime/test/providers/cpu/activation/activation_op_test.h index b5ec1402584fb..984b8f4437a3b 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.h +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.h @@ -17,13 +17,16 @@ namespace test { template inline void TestActivationOp(const char* szOp, const std::vector>& input_vals_vec, std::function expected_func, - const std::unordered_map attribs = {}, + const std::unordered_map float_attribs = {}, + const std::unordered_map string_attribs = {}, bool is_tensorrt_supported = true, int opset_version = 7, const char* domain = kOnnxDomain) { for (const std::vector& input_vals : input_vals_vec) { OpTester test(szOp, opset_version, domain); - for (auto attr : attribs) test.AddAttribute(attr.first, attr.second); + for (auto attr : float_attribs) test.AddAttribute(attr.first, attr.second); + for (auto attr : string_attribs) test.AddAttribute(attr.first, attr.second); + std::vector dims{(int64_t)input_vals.size()}; std::vector expected_vals; From 5e432a3ae69dbbed603420493c52ba48b3726471 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Fri, 23 Feb 2024 04:47:15 +0100 Subject: [PATCH 048/237] Add support for NHWC GridSample in the CUDA EP and enable grid_sample_test for all EPs (#19562) I've added NHWC GridSample support to the CUDA EP to reduce the number of layout transforms. Also I've enabled the full set of GridSampleTests for all EPs. I've also added the GridSample OpSet 16 to the registered kernels. ### Motivation and Context This is the first PR is a series of enhancements of the CUDA EP improving NHWC support to avoid costly layout transforms between NWHC and NCHW nodes which are layout sensitive. Also testing was quite rudimentary for the CUDA EP while it was great for the CPU path. 
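To make the layout difference concrete, here is a minimal sketch of the offset arithmetic the templated kernel has to abstract over; it mirrors the `PixelOffset` helper added to `grid_sample_impl.cu` further down in this patch, but the free function and constant names below are illustrative only, not code from the patch.

```cpp
#include <cstdint>

// Illustrative sketch (names are not from this patch): how the same (n, c, y, x)
// element maps to a linear offset in channels-first (NCHW) vs channels-last (NHWC).
constexpr bool kLayoutNCHW = false;
constexpr bool kLayoutNHWC = true;

template <bool Layout>
int64_t PixelOffset(int64_t n, int64_t c, int64_t y, int64_t x,
                    int64_t C, int64_t H, int64_t W) {
  if constexpr (Layout == kLayoutNCHW) {
    return ((n * C + c) * H + y) * W + x;  // channels vary slowest after batch
  } else {
    return ((n * H + y) * W + x) * C + c;  // channels vary fastest
  }
}
```

Keeping layout-sensitive kernels such as GridSample available in NHWC means the layout transformer no longer has to wrap them in transposes when the surrounding graph runs in NHWC.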
I've regenerated grid_sample_test.cc enabling tests for other platforms as well. Those tests resurfaced #10607 again which is fixed as well. --- docs/OperatorKernels.md | 1 + .../contrib_ops/cuda/cuda_contrib_kernels.cc | 7 + onnxruntime/contrib_ops/cuda/grid_sample.cc | 35 ++-- onnxruntime/contrib_ops/cuda/grid_sample.h | 2 +- .../contrib_ops/cuda/grid_sample_impl.cu | 101 ++++++---- .../contrib_ops/cuda/grid_sample_impl.h | 2 +- .../layout_transformation.cc | 2 + .../providers/cuda/cuda_execution_provider.cc | 2 + .../providers/cuda/shared_inc/cuda_utils.h | 26 +++ .../providers/cpu/tensor/grid_sample_test.cc | 172 ++++++++---------- .../cpu/tensor/grid_sample_test_gen.py | 2 +- onnxruntime/test/util/default_providers.cc | 16 ++ .../test/util/include/default_providers.h | 3 + 13 files changed, 223 insertions(+), 148 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 46149c577a106..b0ed68d595c42 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -619,6 +619,7 @@ Do not modify directly.* |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| |GreaterOrEqual|*in* A:**T**
<br> *in* B:**T**<br> *out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br> **T1** = tensor(bool)|
|||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br> **T1** = tensor(bool)|
+|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float)<br> **T2** = tensor(float)|
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|Identity|*in* input:**T**<br> *out* output:**T**<br><br>or<br><br> *in* input:**V**<br>
*out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index be8c0dc86c135..57e951d3a68ff 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -203,6 +203,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze); #endif +#ifdef ENABLE_CUDA_NHWC_OPS +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample); +#endif + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -408,6 +412,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif +#ifdef ENABLE_CUDA_NHWC_OPS + BuildKernelCreateInfo, +#endif }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.cc b/onnxruntime/contrib_ops/cuda/grid_sample.cc index 4c2999c279e0a..2500de39d3536 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample.cc +++ b/onnxruntime/contrib_ops/cuda/grid_sample.cc @@ -9,22 +9,23 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ +#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ GridSample, \ - kMSDomain, \ - 1, \ + DOMAIN, \ + VERSION, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - GridSample); + onnxruntime::contrib::cuda::GridSample); -REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain) +REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain) -template -GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { +template +GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros"); align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0)); @@ -48,8 +49,8 @@ GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { } } -template 
-Status GridSample::ComputeInternal(OpKernelContext* context) const { +template +Status GridSample::ComputeInternal(OpKernelContext* context) const { const Tensor* X = context->Input(0); const auto& dims_input = X->Shape().GetDims(); const Tensor* Grid = context->Input(1); @@ -61,11 +62,13 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]); ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2"); + using Ch = Channels; + TensorShapeVector dims_output(4); - dims_output[0] = dims_input[0]; - dims_output[1] = dims_input[1]; - dims_output[2] = dims_grid[1]; - dims_output[3] = dims_grid[2]; + dims_output[Ch::N] = dims_input[Ch::N]; + dims_output[Ch::C] = dims_input[Ch::C]; + dims_output[Ch::H] = dims_grid[1 /* Grid::H */]; + dims_output[Ch::W] = dims_grid[2 /* Grid::W */]; Tensor* Y = context->Output(0, dims_output); // Return early if the output tensor is going to be of size 0 if (Y->Shape().Size() == 0) { @@ -74,7 +77,7 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; CudaT* Y_data = reinterpret_cast(Y->MutableData()); - GridSampleImpl( + GridSampleImpl( Stream(context), reinterpret_cast(X->Data()), reinterpret_cast(Grid->Data()), @@ -89,4 +92,8 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { } } // namespace cuda } // namespace contrib + +namespace cuda { +REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain) +} // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.h b/onnxruntime/contrib_ops/cuda/grid_sample.h index 08ca58c7cc458..16581bfe77482 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample.h +++ b/onnxruntime/contrib_ops/cuda/grid_sample.h @@ -12,7 +12,7 @@ namespace cuda { using namespace onnxruntime::cuda; -template +template class GridSample final : public CudaKernel { public: explicit GridSample(const OpKernelInfo& info); diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu index 8a391eca7e86a..b23da635bc83d 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu @@ -50,28 +50,34 @@ __device__ T GsReflect(T x, float x_min, float x_max) { return static_cast(fx); } -template +template __device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x, - int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) { + int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) { T pixel = 0.0f; + + auto PixelOffset = [bIdx, cIdx, C, H, W](int64_t x, int64_t y) -> int64_t { + return Layout == LAYOUT_NCHW + ? 
(bIdx * C * H * W + cIdx * H * W + y * W + x) + : (bIdx * H * W * C + y * W * C + x * C + cIdx); + }; + if (padding_mode == 0) { // zeros if (x >= 0 && x < W && y >= 0 && y < H) { - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + pixel = input_data[PixelOffset(x, y)]; } - } else if (padding_mode == 1) { //border + } else if (padding_mode == 1) { // border x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x)); y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y)); - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + pixel = input_data[PixelOffset(x, y)]; } else { // Reflection - x = (int64_t) GsReflect(x, border[0], border[2]); - y = (int64_t) GsReflect(y, border[1], border[3]); - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + x = (int64_t)GsReflect(x, border[0], border[2]); + y = (int64_t)GsReflect(y, border[1], border[3]); + pixel = input_data[PixelOffset(x, y)]; } return pixel; } -__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) -{ +__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) { float cubic_alpha = -0.75f; x = abs(x); coeffs[0] = (((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha); @@ -93,7 +99,7 @@ __device__ T GsBicubicInterpolate(T p[4][4], float x, float y) { return pixel; } -template +template __global__ void _GridSampleKernel( const T* input_data, const T* grid_data, @@ -110,16 +116,32 @@ __global__ void _GridSampleKernel( { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out); // extract batch index, channel index, y index, x index for current thread - int BIdx = idx / (C * H_out * W_out ); - int tmpBCnt = BIdx * (C * H_out * W_out); + int BIdx, yIdx, xIdx, cIdx; + if constexpr (Layout == LAYOUT_NCHW) { + BIdx = idx / (C * H_out * W_out); + int tmpBCnt = BIdx * (C * H_out * W_out); + + cIdx = (idx - tmpBCnt) / (H_out * W_out); + int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out); - int cIdx = (idx - tmpBCnt) / (H_out * W_out); - int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out); + yIdx = (idx - tmpCCnt) / W_out; + int tmpHCnt = tmpCCnt + yIdx * W_out; - int yIdx = (idx - tmpCCnt) / W_out; - int tmpHCnt = tmpCCnt + yIdx * W_out; + xIdx = (idx - tmpHCnt); + } else { + static_assert(Layout == LAYOUT_NHWC, "Unsupported layout"); - int xIdx = (idx - tmpHCnt); + BIdx = idx / (H_out * W_out * C); + int tmpBCnt = BIdx * (H_out * W_out * C); + + yIdx = (idx - tmpBCnt) / (W_out * C); + int tmpHCnt = tmpBCnt + yIdx * (W_out * C); + + xIdx = (idx - tmpHCnt) / C; + int tmpWCnt = tmpHCnt + xIdx * C; + + cIdx = (idx - tmpWCnt); + } int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx; T grid_X = grid_data[grid_idx * 2 + 0]; @@ -147,8 +169,9 @@ __global__ void _GridSampleKernel( if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max || grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound if (padding_mode == 1) { // border - grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); - grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); + // Clamping must not be done here, see #10607 + // grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); + // grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); } else if (padding_mode == 2) { // reflection grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max); grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max); @@ -175,10 +198,10 @@ __global__ void _GridSampleKernel( w_lb = w_b * w_l; w_rb = w_b * w_r; - T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, 
padding_mode, N, C, H_in, W_in, border); - T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); - T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); - T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); + T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border); + T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); + T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); + T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v; output_data[outIdx] = interpoV; return; @@ -186,7 +209,8 @@ __global__ void _GridSampleKernel( if (mode == 1) { // nearest int x_n = grid_x_imgSpace; int y_n = grid_y_imgSpace; - output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); + output_data[outIdx] = + PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); return; } if (mode == 2) { // bicubic @@ -195,7 +219,8 @@ __global__ void _GridSampleKernel( T p[4][4] = {}; // [H][W] for (int64_t h = 0; h < 4; h++) { for (int64_t w = 0; w < 4; w++) { - p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); + p[h][w] = + PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); } } T dx = grid_x_imgSpace - x0 - 1; @@ -204,7 +229,7 @@ __global__ void _GridSampleKernel( } } -template +template void GridSampleImpl( cudaStream_t stream, const T* input_data, @@ -216,17 +241,23 @@ void GridSampleImpl( const int64_t H_out, const int64_t W_out, T* output_data) { - int blocksPerGrid = (int)(ceil(static_cast(dims[0] * dims[1] * H_out * W_out) / GridDim::maxThreadsPerBlock)); - _GridSampleKernel<<>>( - input_data, grid_data, mode, padding_mode, align_corners, dims[0], dims[1], dims[2], dims[3], H_out, W_out, output_data); + using Ch = Channels; + + int blocksPerGrid = static_cast( + ceil(static_cast(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock)); + _GridSampleKernel<<>>( + input_data, grid_data, mode, padding_mode, align_corners, + dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W], + H_out, W_out, output_data); } -#define SPECIALIZED_IMPL(T) \ - template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \ - const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \ - const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data); +#define SPECIALIZED_IMPL(T, IsNHWC) \ + template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \ + const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \ + const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data); -SPECIALIZED_IMPL(float) +SPECIALIZED_IMPL(float, false) // NCHW +SPECIALIZED_IMPL(float, true) // NHWC } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h index 6df86ce161908..62cd66a48fa84 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h @@ -8,7 +8,7 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template void 
GridSampleImpl( cudaStream_t stream, const T* input_data, diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 4505d4afdf1e0..a8717b99a8750 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -31,6 +31,7 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a } #if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS +// TODO(mtavenrath) generate list from registered kernels using nhwc domain const std::unordered_set& GetCUDALayoutSensitiveOps() { static std::unordered_set cuda_nhwc_ops = []() { return std::unordered_set{ @@ -41,6 +42,7 @@ const std::unordered_set& GetCUDALayoutSensitiveOps() { "MaxPool", "GlobalAveragePool", "AveragePool", + "GridSample", }; }(); return cuda_nhwc_ops; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index be2530aec49fa..00783bcbc2665 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1256,6 +1256,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -2148,6 +2149,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index fa987866c002f..54c024793ff0b 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -168,5 +168,31 @@ struct NumericLimits { } }; +// TODO Where to put this? 
good places might be +// core/framework/tensor_shape.h +// core/util/matrix_layout.h + +constexpr bool LAYOUT_NCHW = false; +constexpr bool LAYOUT_NHWC = true; + +template +struct Channels; + +template <> +struct Channels { + static constexpr size_t N = 0; + static constexpr size_t H = 1; + static constexpr size_t W = 2; + static constexpr size_t C = 3; +}; + +template <> +struct Channels { + static constexpr size_t N = 0; + static constexpr size_t C = 1; + static constexpr size_t H = 2; + static constexpr size_t W = 3; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc index 0f097622abff0..5c89d6ea7bd75 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -6,6 +6,33 @@ namespace onnxruntime { namespace test { + +std::vector> GetExecutionProviders(int opset_version) { + ORT_UNUSED_PARAMETER(opset_version); + + std::vector> execution_providers; + + execution_providers.emplace_back(DefaultCpuExecutionProvider()); +#ifdef USE_CUDA + if (opset_version < 20) { + execution_providers.emplace_back(DefaultCudaExecutionProvider()); +#ifdef ENABLE_CUDA_NHWC_OPS + execution_providers.push_back(DefaultCudaNHWCExecutionProvider()); +#endif + } + +#endif + return execution_providers; +} + +template +void RunTests(T& test, std::vector>&& execution_providers) { + for (size_t idx = 0; idx < execution_providers.size(); ++idx) { + test.ConfigEp(std::move(execution_providers[idx])).RunWithConfig(); + } + execution_providers.clear(); +} + // DO NOT edit following tests. They are generated by: // onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { @@ -25,8 +52,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { @@ -46,8 +72,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { @@ -67,8 +92,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { @@ -88,8 +112,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) 
{ @@ -109,8 +132,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) { @@ -130,8 +152,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { @@ -151,8 +172,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { @@ -172,8 +192,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { @@ -193,8 +212,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { @@ -214,8 +232,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { @@ -235,8 +252,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) { @@ -256,8 +272,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { @@ -277,8 +292,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { 
test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { @@ -298,8 +312,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { @@ -319,8 +332,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { @@ -340,8 +352,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { @@ -361,8 +372,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) { @@ -382,8 +392,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { @@ -403,8 +412,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { @@ -424,8 +432,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { @@ -445,8 +452,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, 
Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { @@ -466,8 +472,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { @@ -487,8 +492,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { @@ -508,8 +512,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { @@ -529,8 +532,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { @@ -550,8 +552,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { @@ -571,8 +572,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { @@ -592,8 +592,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) { @@ -613,8 +612,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, 
GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) { @@ -634,8 +632,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { @@ -655,8 +652,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { @@ -676,8 +672,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { @@ -697,8 +692,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { @@ -718,8 +712,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { @@ -739,8 +732,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { @@ -760,8 +752,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { @@ -781,8 +772,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { @@ -802,8 +792,7 
@@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { @@ -823,8 +812,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { @@ -844,8 +832,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) { @@ -865,8 +852,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) { @@ -886,8 +872,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { @@ -907,8 +892,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { @@ -928,8 +912,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { @@ -949,8 +932,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { @@ -970,8 +952,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { 
test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { @@ -991,8 +972,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) { @@ -1012,8 +992,8 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py index e4d58e79243ef..c60e55617774f 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py @@ -76,6 +76,6 @@ print('test.AddAttribute("padding_mode", padding_mode);') print('test.AddAttribute("align_corners", align_corners);') print('test.AddOutput("Y", Y_shape, Y_data);') - print("test.Run();") + print(f"RunTests(test, GetExecutionProviders({opset_version}));") print("}") print("\n") diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 40b40136af1af..b404c12db3582 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -8,6 +8,9 @@ #ifdef USE_COREML #include "core/providers/coreml/coreml_provider_factory.h" #endif +#if defined(ENABLE_CUDA_NHWC_OPS) +#include +#endif #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/session_options.h" @@ -118,6 +121,19 @@ std::unique_ptr DefaultCudaExecutionProvider() { return nullptr; } +#ifdef ENABLE_CUDA_NHWC_OPS +std::unique_ptr DefaultCudaNHWCExecutionProvider() { +#if defined(USE_CUDA) + OrtCUDAProviderOptionsV2 provider_options{}; + provider_options.do_copy_in_default_stream = true; + provider_options.prefer_nhwc = true; + if (auto factory = CudaProviderFactoryCreator::Create(&provider_options)) + return factory->CreateProvider(); +#endif + return nullptr; +} +#endif + std::unique_ptr CudaExecutionProviderWithOptions(const OrtCUDAProviderOptionsV2* provider_options) { #ifdef USE_CUDA if (auto factory = CudaProviderFactoryCreator::Create(provider_options)) diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 9f78e0a0d4eb2..738fc66d775c6 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -35,6 +35,9 @@ namespace test { // unique_ptr providers with default values for session registration std::unique_ptr DefaultCpuExecutionProvider(bool enable_arena = true); std::unique_ptr DefaultCudaExecutionProvider(); +#ifdef ENABLE_CUDA_NHWC_OPS +std::unique_ptr DefaultCudaNHWCExecutionProvider(); +#endif std::unique_ptr 
CudaExecutionProviderWithOptions(const OrtCUDAProviderOptionsV2* provider_options); std::unique_ptr DefaultDnnlExecutionProvider(); std::unique_ptr DnnlExecutionProviderWithOptions(const OrtDnnlProviderOptions* provider_options); From ae3d73c9818c34af42c785ff2bd9558007ba315f Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Fri, 23 Feb 2024 00:21:15 -0800 Subject: [PATCH 049/237] [JS/WebGPU] Fix Split and Where to handle corner cases. (#19613) ### Description 1. Fix Where operator to handle Boolean input less than 4 bytes. 2. Fix JSEP test harness to use tensor names consistently. ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/where.ts | 3 ++- js/web/test/data/ops/where.jsonc | 34 ++++++++++++++++++++++++ js/web/test/test-runner.ts | 4 +-- 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index cfee07a9239d7..a6375847fc42f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -27,7 +27,7 @@ const createWhereOpProgramShader = const expressionA = `a_data[index_a${x}][component_a${x}]`; const expressionB = `b_data[index_b${x}][component_b${x}]`; // eslint-disable-next-line no-bitwise - const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`; return ` let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)}; @@ -38,6 +38,7 @@ const createWhereOpProgramShader = let index_c${x} = offset_c${x} / 4u; let component_a${x} = offset_a${x} % 4u; let component_b${x} = offset_b${x} % 4u; + let component_c${x} = offset_c${x} % 4u; ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); `; }; diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc index 047fd6fd7511b..990120dd3708e 100644 --- a/js/web/test/data/ops/where.jsonc +++ b/js/web/test/data/ops/where.jsonc @@ -168,5 +168,39 @@ ] } ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1", + "inputs": [ + { + "data": [true, false], + "dims": [1, 1, 2, 1], + "type": "bool" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 4], + "type": "float32" + }, + { + "data": [5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index ecc7d4b4a09a5..a4adf5c4ce144 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -627,8 +627,8 @@ export async function runModelTestSet( try { const feeds: Record = {}; const outputsMetaInfo: Record = {}; - testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor); - testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor); + testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor); + testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor); const [start, end, outputs] = await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding}); if (context.perfData.count === 0) { From f4306004321efe9a0e65a19a707bf2266ffd7b16 Mon Sep 17 
00:00:00 2001 From: cao lei Date: Fri, 23 Feb 2024 06:02:05 -0800 Subject: [PATCH 050/237] Enable streams for DML EP. This change is to revert PR 19481 since the bug 19480 is fixed by PR 19515 (#19609) ### Description Enable streams for DML EP. This change is to revert PR 19481 since the bug 19480 is fixed by PR 19515 ### Motivation and Context Enable streams for DML EP. This change is to revert PR 19481 since the bug 19480 is fixed by PR 19515 --- cmake/adjust_global_compile_flags.cmake | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index a56864ebf4644..8161ea574b8cc 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -92,13 +92,8 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() -# Enable stream for all the non-minimal build, except for DML. There's currently a bug -# in the allocation planner when reusing buffers and more than one streams are used that -# make it possible (although rarely) to reach a reference count of 0 for a buffer that is -# still being used. Since DML doesn't benefit from multiple streams, disabling it is the -# safest option for now. -# https://github.com/microsoft/onnxruntime/issues/19480 -if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML) +# Enable stream for all the non-minimal build +if (NOT onnxruntime_MINIMAL_BUILD) add_compile_definitions(ORT_ENABLE_STREAM) endif() From efbe2b84556c195e7d7f3353321eb3f410a1e645 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Fri, 23 Feb 2024 17:45:17 +0100 Subject: [PATCH 051/237] Fix cuDNN v9 build by replacing removed cuDNN v6 RNN API usage by cuDNN v8 RNN API and reenable RNN tests for CUDA EP (#19419) Replace deprecated cuDNN RNN based API by cuDNN v8 RNN API and re-enable RNN tests for the CUDA EP. ### Motivation and Context The deprecated cuDNN RNN API might vanish soon and in addition for the current CUDA EP RNN implementation all RNN tests are disabled due to failures. With this change the deprecated API has been removed and the new updated implemented doesn't fail the tests anymore. 
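At the core of the migration, the per-gate weight and bias lookup now goes through the single cuDNN v8 call `cudnnGetRNNWeightParams`, which returns both the matrix and the bias of one linear layer in a single query, replacing the removed `cudnnGetRNNLinLayerMatrixParams` / `cudnnGetRNNLinLayerBiasParams` pair. A minimal sketch of that call (the wrapper name and elided error handling are illustrative, not the actual ORT code in the diff below):

```cpp
#include <cudnn.h>

// Illustrative wrapper only: mat_desc/bias_desc are pre-created tensor descriptors,
// weight_space is the packed weight buffer of weight_space_size bytes, and
// lin_layer_id selects the gate (e.g. one of the LSTM input/forget/cell/output gates).
cudnnStatus_t QueryGateParams(cudnnHandle_t handle, cudnnRNNDescriptor_t rnn_desc,
                              int pseudo_layer, int lin_layer_id,
                              size_t weight_space_size, const void* weight_space,
                              cudnnTensorDescriptor_t mat_desc, void** mat_addr,
                              cudnnTensorDescriptor_t bias_desc, void** bias_addr) {
  // Fills both descriptors and returns device pointers into weight_space.
  return cudnnGetRNNWeightParams(handle, rnn_desc, pseudo_layer,
                                 weight_space_size, weight_space, lin_layer_id,
                                 mat_desc, mat_addr, bias_desc, bias_addr);
}
```

The changes to `SetWeightBias` in `cudnn_rnn_base.cc` below build on exactly this call.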
--- .../core/providers/cuda/cudnn_common.h | 4 +- .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 350 +++++++++--------- .../core/providers/cuda/rnn/cudnn_rnn_base.h | 55 +-- onnxruntime/core/providers/cuda/rnn/rnn.cc | 3 +- onnxruntime/core/providers/cuda/rnn/rnn.h | 1 + .../core/providers/cuda/rnn/rnn_impl.cu | 91 +---- .../core/providers/cuda/rnn/rnn_impl.h | 14 +- .../test/providers/cpu/rnn/rnn_op_test.cc | 24 +- 8 files changed, 240 insertions(+), 302 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index fdd14dedad47e..2cbeb13696270 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -24,12 +24,12 @@ class CudnnTensor final { operator cudnnTensorDescriptor_t() const { return tensor_; } + Status CreateTensorIfNeeded(); + template static cudnnDataType_t GetDataType(); private: - Status CreateTensorIfNeeded(); - cudnnTensorDescriptor_t tensor_; }; diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index 99c1f48e21c74..b61b104790fe5 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -9,40 +9,49 @@ namespace onnxruntime { namespace cuda { template -void CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnn_desc, - const int pseudo_layer, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, - const cudnnFilterDescriptor_t filter_desc, - const void* reorganized_w_data, - const int lin_layer_id, - const T* pos, - int& offset, - bool is_matrix, - cudaStream_t cuda_stream) const { +Status CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnn_desc, + const int pseudo_layer, + size_t reorganized_w_data_size, + const void* reorganized_w_data, + const int lin_layer_id, + const T* pos, + int& offset, + bool is_matrix, + cudaStream_t cuda_stream) const { int numDims; - std::vector matDims(3); + std::array matDims; + std::array strideA; cudnnDataType_t dt; - cudnnTensorFormat_t tf; T* mem_offset; - if (is_matrix) { - cudnnGetRNNLinLayerMatrixParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset); - } else { - cudnnGetRNNLinLayerBiasParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset); - } + CudnnTensor tensor_desc_matrix, tensor_desc_bias; + ORT_RETURN_IF_ERROR(tensor_desc_bias.CreateTensorIfNeeded()); + ORT_RETURN_IF_ERROR(tensor_desc_matrix.CreateTensorIfNeeded()); - cudnnGetFilterNdDescriptor(filter_desc, 3, &dt, &tf, &numDims, matDims.data()); + T *mem_offset_matrix, *mem_offset_bias; + CUDNN_RETURN_IF_ERROR(cudnnGetRNNWeightParams( + handle, rnn_desc, pseudo_layer, reorganized_w_data_size, reorganized_w_data, + lin_layer_id, tensor_desc_matrix, (void**)&mem_offset_matrix, tensor_desc_bias, (void**)&mem_offset_bias)); + CUDNN_RETURN_IF_ERROR(cudnnGetTensorNdDescriptor( + is_matrix ? tensor_desc_matrix : tensor_desc_bias, 3, &dt, &numDims, matDims.data(), strideA.data())); + + mem_offset = is_matrix ? 
mem_offset_matrix : mem_offset_bias; int count = matDims[0] * matDims[1] * matDims[2]; + + if (strideA[0] != count) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, "Stride is not packed"); + } CUDA_CALL_THROW(cudaMemcpyAsync(mem_offset, pos + offset, count * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream)); + offset += count; + + return Status::OK(); } template Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const cudnnRNNDescriptor_t rnn_desc, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, + size_t reorganized_w_data_size, void* reorganized_w_data, const T* W_data, const T* R_data, @@ -51,18 +60,22 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, int w_offset = 0; int r_offset = 0; int bias_offset = 0; - CudnnFilterDescriptor filter_desc; for (int layer = 0; layer < RNN_NUM_LAYERS * num_directions_; ++layer) { for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias( + cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream)); if (B_data != nullptr) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream)); } } for (size_t idx = 0; idx < R_lin_layer_id_.size(); ++idx) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream)); if (B_data != nullptr) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream)); } } } @@ -72,6 +85,7 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, template Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B, + size_t& reorganized_w_data_size_in_bytes, IAllocatorUniquePtr& reorganized_w_data, CudnnFilterDescriptor& target_w_desc, CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const { @@ -91,19 +105,16 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons TensorShapeVector dims_w({w_size, 1, 1}); ORT_RETURN_IF_ERROR(target_w_desc.Set(dims_w, CudnnTensor::GetDataType())); - TensorShapeVector fake_dims_x({1, input_size, 1}); - CudnnTensor fake_x_desc; - ORT_RETURN_IF_ERROR(fake_x_desc.Set(fake_dims_x, CudnnTensor::GetDataType())); - // Prepare the weight data - reorganized_w_data = GetScratchBuffer(w_size * sizeof(T), ort_stream); + reorganized_w_data_size_in_bytes = w_size * sizeof(T); + reorganized_w_data = GetScratchBuffer(reorganized_w_data_size_in_bytes, ort_stream); // In many cases, this allocation is bigger than needed, leaving part of - // the buffer 
unintialized. non-zero garbage data leads to wrong result + // the buffer uninitialized. non-zero garbage data leads to wrong result // in call to cudnnRNNForwardInference() // TODO! refine allocation size for each case. cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - cudaMemsetAsync(reorganized_w_data.get(), 0, w_size * sizeof(T), cuda_stream); + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(reorganized_w_data.get(), 0, reorganized_w_data_size_in_bytes, cuda_stream)); const T* W_data = W->Data(); const T* R_data = R->Data(); @@ -111,8 +122,9 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons auto* ort_cuda_stream = dynamic_cast(ort_stream); cudnnHandle_t cudnn_handle = ort_cuda_stream ? ort_cuda_stream->cudnn_handle_ : DefaultCudnnHandle(); - ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, fake_x_desc, target_w_desc, - reorganized_w_data.get(), W_data, R_data, B_data, cuda_stream)); + ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, + reorganized_w_data_size_in_bytes, reorganized_w_data.get(), + W_data, R_data, B_data, cuda_stream)); return Status::OK(); } @@ -128,22 +140,31 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { bool get_R = info.TryGetConstantInput(RNN_Input_Index::R, &R); bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B); + bool has_bias = B != nullptr; + if (get_W && get_R) { CudnnRNN tmp_rnn_desc; - ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(DefaultCudnnHandle(), + auto proj_size = hidden_size_; + ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(W->Shape()[2], // input_size hidden_size_, + proj_size, RNN_NUM_LAYERS, cudnn_dropout_desc_, cudnn_direction_mode_, rnn_mode_, - CudnnTensor::GetDataType(), - GetDeviceProp())); + has_bias, + CudnnTensor::GetDataType())); if (get_B) { - ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr)); + ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, + w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, + tmp_rnn_desc, nullptr)); } else { - ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr)); + ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, + w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, + tmp_rnn_desc, nullptr)); } cudaStreamSynchronize(nullptr); + weight_cached_ = true; } @@ -158,17 +179,72 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { ORT_ENFORCE(nullptr != X); // optional inputs - const Tensor* sequence_lens = ctx->Input(RNN_Input_Index::sequence_lens); // [batch_size] - const Tensor* initial_h = ctx->Input(RNN_Input_Index::initial_h); // initial hidden. [num_directions_, batch_size, hidden_size_] + // [batch_size] + const Tensor* sequence_lens = ctx->Input(RNN_Input_Index::sequence_lens); + // initial hidden. [num_directions_, batch_size, hidden_size_] + const Tensor* initial_h = ctx->Input(RNN_Input_Index::initial_h); const Tensor* initial_c(nullptr); if (rnn_mode_ == CUDNN_LSTM) { - initial_c = ctx->Input(RNN_Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_] + // initial cell. [num_directions_, batch_size, hidden_size_] + initial_c = ctx->Input(RNN_Input_Index::initial_c); } + size_t proj_size = hidden_size_; int64_t seq_length = X->Shape()[0]; int64_t batch_size = X->Shape()[1]; int64_t input_size = X->Shape()[2]; + // we thread a single input as sequence_lens of length 1, require to expand to [batch_size]? 
+ std::vector sequence_lengths_temp; + if (!sequence_lens) { + sequence_lengths_temp.resize(batch_size, gsl::narrow_cast(seq_length)); + } + + const int32_t* sequence_lens_data = (sequence_lens == nullptr) + ? sequence_lengths_temp.data() + : sequence_lens->Data(); + + // cuDNN doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1 + // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence + int64_t zero_seq_count = 0; + std::vector zero_seq_index_cache(batch_size, 0); + + CudaAsyncBuffer sequence_lens_buffer(this, batch_size); + int32_t* seq_len_array = sequence_lens_buffer.CpuPtr(); + + // 0-len sequences are not supported by cuDNN. + // Replace them by sequences of len 1 and mask them out with SetZeroSequences + for (int i = 0; i < batch_size; ++i) { + if (0 == sequence_lens_data[i]) { + seq_len_array[i] = 1; + zero_seq_index_cache[zero_seq_count] = i; + ++zero_seq_count; + } else { + seq_len_array[i] = sequence_lens_data[i]; + } + } + + // Calculate the zero position cache for reverse direction if it's bidirectional + // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since + // we hacked the 0 sequence to 1 + if (zero_seq_count && num_directions_ > 1) { + zero_seq_index_cache.resize(zero_seq_count * num_directions_); + for (int64_t i = 0; i < zero_seq_count; ++i) { + zero_seq_index_cache[static_cast(zero_seq_count) + i] = + static_cast(batch_size + zero_seq_index_cache[i]); + } + zero_seq_count *= num_directions_; + } + + // Prior to cuDNN 8.9.1 the sequence lens buffer must be passed to cudnnRNNForward and thus is must + // be copied to the GPU always. + ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + // Starting with cuDNN 8.9.1 the sequence lens buffer is ignored by cudnnRNNForward and thus it must + // be copied to the GPU only for the ReverseBySequence kernels. 
+ // if (reverse_) { + // ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + // } + // optional outputs TensorShapeVector dims_Y({seq_length, num_directions_, batch_size, hidden_size_}); TensorShapeVector dims_hxy({RNN_NUM_LAYERS * num_directions_, batch_size, hidden_size_}); @@ -177,25 +253,6 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { Tensor* Y_h = ctx->Output(Output_Index::Y_h, dims_hxy); Tensor* Y_c = ctx->Output(Output_Index::Y_c, dims_yc); - std::vector dims_x({batch_size, input_size, 1}); - std::vector dims_y({batch_size, hidden_size_ * num_directions_, 1}); - - CudnnTensor x_desc_temp; - ORT_RETURN_IF_ERROR(x_desc_temp.Set(dims_x, CudnnTensor::GetDataType())); - CudnnTensor y_desc_temp; - ORT_RETURN_IF_ERROR(y_desc_temp.Set(dims_y, CudnnTensor::GetDataType())); - std::vector x_desc(seq_length, x_desc_temp); - std::vector y_desc(seq_length, y_desc_temp); - - CudnnTensor hx_desc; - CudnnTensor cx_desc; - CudnnTensor y_h_desc; - CudnnTensor y_c_desc; - ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_h_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_c_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - IAllocatorUniquePtr x_reversed_data; const T* x_data = X->Data(); if (reverse_) { @@ -203,6 +260,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { x_reversed_data = GetScratchBuffer(seq_length * batch_size * input_size, ctx->GetComputeStream()); ReverseBySequence(Stream(ctx), gsl::narrow_cast(seq_length), + sequence_lens_buffer.GpuPtr(), gsl::narrow_cast(batch_size), gsl::narrow_cast(input_size), reinterpret_cast(x_data), @@ -226,115 +284,82 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { y_data = y_alloc_data.get(); } - const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? nullptr : sequence_lens->Data(); + const Tensor* B = ctx->Input(RNN_Input_Index::B); + bool has_bias = B != nullptr; CudnnRNN rnn_desc; - ORT_RETURN_IF_ERROR(rnn_desc.Set(GetCudnnHandle(ctx), + ORT_RETURN_IF_ERROR(rnn_desc.Set(input_size, hidden_size_, + proj_size, RNN_NUM_LAYERS, cudnn_dropout_desc_, cudnn_direction_mode_, rnn_mode_, - CudnnTensor::GetDataType(), - GetDeviceProp())); + has_bias, + CudnnTensor::GetDataType())); // Prepare the weight data + size_t w_data_size_in_bytes = 0; IAllocatorUniquePtr w_data; CudnnFilterDescriptor w_desc; if (!weight_cached_) { const Tensor& W = *ctx->Input(RNN_Input_Index::W); const Tensor& R = *ctx->Input(RNN_Input_Index::R); const Tensor* B = ctx->Input(RNN_Input_Index::B); - ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data, w_desc, rnn_desc, ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc, + rnn_desc, ctx->GetComputeStream())); } - // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences - CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED)); + CudnnDataTensor x_desc1; + ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, + input_size, seq_len_array)); + CudnnDataTensor y_desc1; + ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, + ((rnn_mode_ == CUDNN_LSTM) ? 
proj_size : hidden_size_) * num_directions_, + seq_len_array)); - size_t workspace_bytes; - CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(GetCudnnHandle(ctx), rnn_desc, gsl::narrow_cast(seq_length), x_desc.data(), &workspace_bytes)); - auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); - int64_t zero_seq_count = 0; - std::vector zero_seq_index_cache(batch_size, 0); - int64_t zero_seq_index_cache_size = 0; - - if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) { - CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(GetCudnnHandle(ctx), - rnn_desc, - gsl::narrow_cast(seq_length), - x_desc.data(), - x_data_input, - hx_desc, - hx_data, - cx_desc, - cx_data, - weight_cached_ ? w_desc_cache_ : w_desc, - weight_cached_ ? w_data_cache_.get() : w_data.get(), - y_desc.data(), - y_data, - y_h_desc, - y_h_data, - y_c_desc, - y_c_data, - workspace_cuda.get(), - workspace_bytes)); - } else { - // cudnn doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1 - // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence - std::vector seq_len_array(sequence_lens_data, sequence_lens_data + batch_size); - for (int i = 0; i < batch_size; ++i) { - if (0 == seq_len_array[i]) { - seq_len_array[i] = 1; - zero_seq_index_cache[zero_seq_count] = i; - ++zero_seq_count; - } - } + CudnnTensor cx_desc; + ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - // Calculate the zero position cache for reverse direction if it's bidirectional - // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since - // we hacked the 0 sequence to 1 - if (zero_seq_count && num_directions_ > 1) { - zero_seq_index_cache_size = zero_seq_count * num_directions_; - zero_seq_index_cache.resize(zero_seq_index_cache_size); - for (int64_t i = 0; i < zero_seq_count; ++i) { - zero_seq_index_cache[static_cast(zero_seq_count) + i] = static_cast(batch_size + zero_seq_index_cache[i]); - } - } + CudnnTensor hx_desc; + ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); + + // reserveSpaceSize is not required cudnnRNNForward, but returned by cudnnGetRNNTempSpaceSizes + size_t workspace_bytes, reservespace_bytes; - CudnnDataTensor x_desc1; - ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, input_size, seq_len_array.data())); - CudnnDataTensor y_desc1; - ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data())); - - CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(GetCudnnHandle(ctx), - rnn_desc, - x_desc1, - x_data_input, - hx_desc, - hx_data, - cx_desc, - cx_data, - weight_cached_ ? w_desc_cache_ : w_desc, - weight_cached_ ? w_data_cache_.get() : w_data.get(), - y_desc1, - y_data, - y_h_desc, - y_h_data, - y_c_desc, - y_c_data, - nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, - workspace_cuda.get(), - workspace_bytes)); - - // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, no need the following code to retrive Y_h from Y data. 
- if (nullptr == Y) { + CUDNN_RETURN_IF_ERROR(cudnnGetRNNTempSpaceSizes(GetCudnnHandle(ctx), rnn_desc, CUDNN_FWD_MODE_INFERENCE, + x_desc1, &workspace_bytes, &reservespace_bytes)); + auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); + auto reservespace_cuda = GetScratchBuffer(reservespace_bytes, ctx->GetComputeStream()); + + CUDNN_RETURN_IF_ERROR(cudnnRNNForward(GetCudnnHandle(ctx), + rnn_desc, + CUDNN_FWD_MODE_INFERENCE, + sequence_lens_buffer.GpuPtr(), // should be zero starting with cudnn 8.9.1 + x_desc1, + x_data_input, + y_desc1, + y_data, // output + hx_desc, + hx_data, // input + y_h_data, // output + cx_desc, cx_data, y_c_data, + weight_cached_ ? w_data_cache_size_in_bytes_ : w_data_size_in_bytes, + weight_cached_ ? w_data_cache_.get() : w_data.get(), + workspace_bytes, + workspace_cuda.get(), + reservespace_bytes, + reservespace_cuda.get())); + + // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, + // no need the following code to retrieve Y_h from Y data. + if (nullptr == Y) { + // Mask on output for 0 sequence batches + if (zero_seq_count > 0) { // Mask on output for 0 sequence batches - if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); - } - return Status::OK(); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); } + return Status::OK(); } IAllocatorUniquePtr y_reorganized_data; @@ -345,6 +370,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // reverse output data ReverseBySequence(Stream(ctx), gsl::narrow_cast(seq_length), + sequence_lens_buffer.GpuPtr(), gsl::narrow_cast(batch_size), gsl::narrow_cast(hidden_size_), reinterpret_cast(y_data), @@ -361,8 +387,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { } if (Y != nullptr) { - // User specified this optional output, so need to copy the reversed data to orignial place - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T), cudaMemcpyDeviceToDevice, Stream(ctx))); + // User specified this optional output, so need to copy the reversed data to original place + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T), + cudaMemcpyDeviceToDevice, Stream(ctx))); } else { y_data = y_reorganized_data.get(); } @@ -370,23 +397,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Mask on output for 0 sequence batches if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); } - if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) { - CudaAsyncBuffer sequence_lens_buffer(this, batch_size); - memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, batch_size * sizeof(int32_t)); - ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); - RnnMaskImpl(Stream(ctx), - gsl::narrow_cast(num_directions_), - gsl::narrow_cast(seq_length), - gsl::narrow_cast(batch_size), - gsl::narrow_cast(hidden_size_), - sequence_lens_buffer.GpuPtr(), - reinterpret_cast(y_data), - reinterpret_cast(y_h_data), - output_size); - } return Status::OK(); } @@ -399,7 +412,8 @@ void 
CudnnRnnBase::SetZeroSequences(const int64_t zero_seq_index_cache_size, onnxruntime::Stream* ort_stream) const { typedef typename ToCudaType::MappedType CudaT; CudaAsyncBuffer zero_seq_index_cache_async_buffer(this, zero_seq_index_cache_size); - memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t)); + memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), + zero_seq_index_cache_size * sizeof(int32_t)); ORT_THROW_IF_ERROR(zero_seq_index_cache_async_buffer.CopyToGpu(ort_stream)); cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; MaskZeroSequences(cuda_stream, diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 1c9483b2afd38..0fa01d3486e99 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -38,26 +38,28 @@ class CudnnRNN { } } - Status Set(const cudnnHandle_t& cudnnHandle, int64_t hidden_size, int num_layers, + Status Set(int64_t input_size, int64_t hidden_size, int64_t proj_size, int num_layers, cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model, - cudnnRNNMode_t rnn_mode, cudnnDataType_t dataType, const cudaDeviceProp& prop) { + cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType) { if (!cudnn_rnn_desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_)); - CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v6(cudnnHandle, - cudnn_rnn_desc_, + CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v8(cudnn_rnn_desc_, + CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC + rnn_mode, + has_bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS, + cudnn_direction_model, + CUDNN_LINEAR_INPUT, + dataType, + dataType, + dataType == CUDNN_DATA_HALF ? 
CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH, + gsl::narrow_cast(input_size), gsl::narrow_cast(hidden_size), + gsl::narrow_cast(proj_size), // projected size num_layers, cudnn_dropout_desc, - CUDNN_LINEAR_INPUT, // We can also skip the input matrix transformation - cudnn_direction_model, - rnn_mode, - CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC - dataType)); - - if (prop.major >= 7 && dataType == CUDNN_DATA_HALF) { - cudnnSetRNNMatrixMathType(cudnn_rnn_desc_, CUDNN_TENSOR_OP_MATH); - } + // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences + CUDNN_RNN_PADDED_IO_ENABLED)); return Status::OK(); } @@ -119,8 +121,7 @@ class CudnnRnnBase : public CudaKernel { private: Status SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const cudnnRNNDescriptor_t rnn_desc, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, + size_t w_data_size, void* w_data, const T* W_data, const T* R_data, @@ -128,23 +129,22 @@ class CudnnRnnBase : public CudaKernel { cudaStream_t cuda_stream) const; Status ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B, + size_t& target_w_data_size_in_bytes, IAllocatorUniquePtr& target_w_data, CudnnFilterDescriptor& target_w_desc, CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const; - void SetWeightBias(const cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnn_desc, - const int pseudo_layer, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, - const cudnnFilterDescriptor_t filter_desc, - const void* w_data, - const int lin_layer_id, - const T* pos, - int& offset, - bool is_matrix, - cudaStream_t cuda_stream) const; + Status SetWeightBias(const cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnn_desc, + const int pseudo_layer, + size_t w_data_size, + const void* w_data, + const int lin_layer_id, + const T* pos, + int& offset, + bool is_matrix, + cudaStream_t cuda_stream) const; void SetZeroSequences(const int64_t zero_seq_index_cache_size, const std::vector zero_seq_index_cache, @@ -167,6 +167,7 @@ class CudnnRnnBase : public CudaKernel { cudnnRNNMode_t rnn_mode_; // w_desc_cache_ & w_data_cache_ are changed in Constructor if we can get the weights as constant input CudnnFilterDescriptor w_desc_cache_; + size_t w_data_cache_size_in_bytes_; IAllocatorUniquePtr w_data_cache_; bool weight_cached_; int64_t layout_; diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.cc b/onnxruntime/core/providers/cuda/rnn/rnn.cc index 4bd22340ef2bb..ed8be63679707 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.cc +++ b/onnxruntime/core/providers/cuda/rnn/rnn.cc @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/shared_library/provider_api.h" #include "rnn.h" + +#include "core/providers/shared_library/provider_api.h" #include "rnn_impl.h" #include "core/providers/cuda/cudnn_common.h" diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.h b/onnxruntime/core/providers/cuda/rnn/rnn.h index e4e50046b3725..6221afb003b22 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn.h @@ -4,6 +4,7 @@ #pragma once #include "cudnn_rnn_base.h" + #include "core/providers/cuda/cuda_common.h" #include diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu index d485855ddb417..94c8036be6cdf 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu @@ -8,22 +8,32 @@ namespace onnxruntime { namespace cuda { template -__global__ void _ReverseBySequenceKernel(const int32_t seq_length, +__global__ void _ReverseBySequenceKernel(const int32_t max_seq_length, + const int32_t* seq_lengths, const int32_t block_size, const fast_divmod div_batch_block, + const fast_divmod div_input_or_hidden_size, const T* data, T* reversed_data, const CUDA_LONG N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); int seq_id, offset; div_batch_block.divmod(id, seq_id, offset); - int org_id = (seq_length - seq_id - 1) * block_size + offset; - reversed_data[id] = data[org_id]; + int batch, batch_offset; + div_input_or_hidden_size.divmod(offset, batch, batch_offset); + int seq_id_org = seq_lengths[batch] - seq_id - 1; + if (seq_id_org >= 0) { + int org_id = seq_id_org * block_size + offset; + reversed_data[id] = data[org_id]; + } else { + reversed_data[id] = T{}; + } } template void ReverseBySequence(cudaStream_t stream, - const int32_t seq_length, + const int32_t max_seq_length, + const int32_t *seq_lengths, const int32_t batch_size, const int32_t input_or_hidden_size, const T* data, @@ -32,9 +42,10 @@ void ReverseBySequence(cudaStream_t stream, // kerneral int32_t block_size = batch_size * input_or_hidden_size; fast_divmod div_batch_block(block_size); + fast_divmod div_input_or_hidden_size(input_or_hidden_size); int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); _ReverseBySequenceKernel<<>>( - seq_length, block_size, div_batch_block, data, reversed_data, (CUDA_LONG)N); + max_seq_length, seq_lengths, block_size, div_batch_block, div_input_or_hidden_size, data, reversed_data, (CUDA_LONG)N); } template @@ -82,60 +93,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream, data, reordered_data, (CUDA_LONG)N); } -template -__global__ void _RnnMaskKernel(const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - const fast_divmod div_seq_block, - const fast_divmod div_dir_block, - const fast_divmod div_batch_block, - T* y_output_data, - T* y_h_output_data, - const CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - - int seq_id, direction_id, batch_id, offset; - div_seq_block.divmod(id, seq_id, offset); - div_dir_block.divmod(offset, direction_id, offset); - div_batch_block.divmod(offset, batch_id, offset); - int32_t batch_seq_length = sequence_lens[batch_id]; - - if (batch_id >= batch_size || batch_seq_length == seq_length) { - return; - } - - if (seq_id >= batch_seq_length) { - y_output_data[id] = 0; - return; - } - - if ((y_h_output_data != nullptr) && - ((direction_id == 0 && (seq_id + 1) == batch_seq_length) || (direction_id == 1 && seq_id == 0))) { - int hy_idx = 
direction_id * batch_size * hidden_size + batch_id * hidden_size + offset; - y_h_output_data[hy_idx] = y_output_data[id]; - } -} - -template -void RnnMaskImpl(cudaStream_t stream, - const int32_t num_directions, - const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - T* y_output_data, - T* y_h_output_data, - const size_t N) { - fast_divmod div_seq_block(batch_size * hidden_size * num_directions); - fast_divmod div_dir_block(batch_size * hidden_size); - fast_divmod div_batch_block(hidden_size); - int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - _RnnMaskKernel<<>>( - seq_length, batch_size, hidden_size, sequence_lens, div_seq_block, - div_dir_block, div_batch_block, y_output_data, y_h_output_data, (CUDA_LONG)N); -} - template __global__ void _MaskZeroSequences(const int32_t hidden_size, T* y_output_data, @@ -180,17 +137,9 @@ void MaskZeroSequences(cudaStream_t stream, } #define SPECIALIZED_RNN_IMPL(T) \ - template void RnnMaskImpl(cudaStream_t stream, \ - const int32_t num_directions, \ - const int32_t seq_length, \ - const int32_t batch_size, \ - const int32_t hidden_size, \ - const int32_t* sequence_lens, \ - T* y_output_data, \ - T* y_h_output_data, \ - const size_t N); \ - template void ReverseBySequence(cudaStream_t stream, \ - const int32_t seq_length, \ + template void ReverseBySequence(cudaStream_t stream, \ + const int32_t max_seq_length, \ + const int32_t* seq_lengths, \ const int32_t batch_size, \ const int32_t hidden_size, \ const T* data, \ @@ -203,7 +152,7 @@ void MaskZeroSequences(cudaStream_t stream, const T* data, \ T* reordered_data, \ const size_t N); \ -template void MaskZeroSequences(cudaStream_t stream, \ +template void MaskZeroSequences(cudaStream_t stream, \ const int32_t hidden_size, \ T* y_output_data, \ T* y_h_output_data, \ diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h index 9844e04ff6ec5..ba876011f6b67 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h @@ -10,7 +10,8 @@ namespace cuda { template void ReverseBySequence(cudaStream_t stream, - const int32_t seq_length, + const int32_t max_seq_length, + const int32_t* seq_lengths, const int32_t batch_size, const int32_t input_or_hidden_size, const T* data, @@ -26,17 +27,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream, T* reordered_data, const size_t N); -template -void RnnMaskImpl(cudaStream_t stream, - const int32_t num_directions, - const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - T* y_output_data, - T* y_h_output_data, - const size_t N); - template void MaskZeroSequences(cudaStream_t stream, const int32_t hidden_size, diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index b9875b9553a55..1a31743e2f7e7 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -120,15 +120,11 @@ TEST(RNNTest, RNN_bidirectional_bias_initial_zigged_batch) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); // TensorRT failed on RNN tests - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } // Doesn't work with CUDA 11.4 on Windows. Need investigation. 
-#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_bidirectional_zigged_batch) { -#else TEST(RNNTest, RNN_bidirectional_zigged_batch) { -#endif OpTester test("RNN"); int64_t num_directions = 2, input_size = 2, hidden_size = 3, seq_length = 5; @@ -275,15 +271,11 @@ TEST(RNNTest, RNN_reverse_direction_zigged_batch) { std::vector Y_h_data({0.87014002F, 0.09402763F, -0.54269236F, 0.64809889F, -0.19472955F, -0.24271242F}); test.AddOutput("Y_h", Y_h_dims, Y_h_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } // Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_forward_direction_zigged_batch) { -#else TEST(RNNTest, RNN_forward_direction_zigged_batch) { -#endif OpTester test("RNN"); int64_t num_directions = 1, input_size = 2, hidden_size = 3, seq_length = 5; @@ -357,12 +349,7 @@ TEST(RNNTest, RNN_forward_direction_zigged_batch) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_bidirectional_0) { -#else TEST(RNNTest, RNN_bidirectional_0) { -#endif OpTester test("RNN"); int64_t num_directions = 2, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5; @@ -424,12 +411,7 @@ TEST(RNNTest, RNN_bidirectional_0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_bidirectional_1) { -#else TEST(RNNTest, RNN_bidirectional_1) { -#endif OpTester test("RNN"); int64_t num_directions = 2, input_size = 2, hidden_size = 2, batch_size = 1, seq_length = 1; @@ -597,7 +579,7 @@ TEST(RNNTest, DISABLED_RNN_default_attributes_and_forward_direction) { } } -TEST(RNNTest, DISABLED_RNN_reverse_direction) { +TEST(RNNTest, RNN_reverse_direction) { int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5; // In case of useDefault, attributes, inputs or outputs are not set. From aec2389ad0463d218b8cf3b1e245d4c34e98364a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 23 Feb 2024 12:52:47 -0800 Subject: [PATCH 052/237] [js/webgpu] allows a ProgramInfo's RunData to use zero sized output (#19614) ### Description This PR allows zero-sized output. To make the implementation simple, it does not support partial zero-sized tensor. Which means, either all outputs are zero-sized, or an error will be reported. 
added 2 tests: - op test of `Add` with input T[2,0] T[2,1], and - test_split_zero_size_splits --- js/web/lib/wasm/jsep/backend-webgpu.ts | 32 ++++++++++++++++++++++---- js/web/lib/wasm/jsep/init.ts | 3 ++- js/web/lib/wasm/jsep/util.ts | 11 ++++++++- js/web/test/data/ops/add.jsonc | 22 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 2 +- js/web/test/test-runner.ts | 10 ++++++-- 6 files changed, 71 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 98990a6fe477b..3e3a191ec3ead 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -385,11 +385,16 @@ export class WebGpuBackend { // create info for inputs const inputDatas: GpuData[] = []; for (let i = 0; i < inputTensorViews.length; ++i) { - const gpuData = this.gpuDataManager.get(inputTensorViews[i].data); + const data = inputTensorViews[i].data; + // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. + if (data === 0) { + continue; + } + const gpuData = this.gpuDataManager.get(data); if (!gpuData) { - throw new Error(`no GPU data for input: ${inputTensorViews[i].data}`); + throw new Error(`no GPU data for input: ${data}`); } - inputDatas[i] = gpuData; + inputDatas.push(gpuData); } const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews); @@ -419,6 +424,11 @@ export class WebGpuBackend { const tensorView = (isTemporary || isPersistent) ? createIntermediateOutput(outputs[i].dataType, outputs[i].dims) : createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); + outputTensorViews.push(tensorView); + // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. + if (tensorView.data === 0) { + continue; + } const gpuData = this.gpuDataManager.get(tensorView.data); if (!gpuData) { throw new Error(`no GPU data for output: ${tensorView.data}`); @@ -434,10 +444,24 @@ export class WebGpuBackend { } persistentData.push(gpuData); } - outputTensorViews.push(tensorView); outputDatas.push(gpuData); } + // when there are any zero-sized tensor in the inputs or outputs, we should report error unless all outputs are + // zero-sized tensors. + if (inputDatas.length !== inputTensorViews.length || outputDatas.length !== outputTensorViews.length) { + // if all outputs are zero-sized tensors, there is no need to run the program. + if (outputDatas.length === 0) { + TRACE_FUNC_END(program.name); + return outputTensorViews; + } + // if some outputs are zero-sized tensors, report an error. + // + // TODO: so far we don't see any use case that outputs include both zero-sized tensors and non-zero-sized tensors. + // If we see such use case, we need to make a change here to support it. + throw new Error( + `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`); + } // load uniforms // TODO: add cache for uniform (is it necessary?) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 786ae41646554..b64abf9cc5424 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -104,7 +104,8 @@ class ComputeContextImpl implements ComputeContext { throw new Error(`Unsupported data type: ${dataType}`); } const bufferSize = elementSize * ShapeUtil.size(dims); - return new TensorViewImpl(this.module, dataType, this.backend.gpuDataManager.create(bufferSize).id, dims); + const gpuDataId = bufferSize > 0 ? 
this.backend.gpuDataManager.create(bufferSize).id : 0; + return new TensorViewImpl(this.module, dataType, gpuDataId, dims); }; return this.backend.run(program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput); } diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index c0517ce363644..9a1d5463f7843 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -56,7 +56,16 @@ export class BroadcastUtil { if (aLen !== bLen && aLen > 1 && bLen > 1) { return undefined; } - cdims[crank - i] = Math.max(aLen, bLen); + const max = Math.max(aLen, bLen); + if (aLen && bLen) { + cdims[crank - i] = Math.max(aLen, bLen); + } else { + // when either aLen or bLen is 0, the other should be either 0 or 1, otherwise it is not broadcastable. + if (max > 1) { + return undefined; + } + cdims[crank - i] = 0; + } } return cdims; diff --git a/js/web/test/data/ops/add.jsonc b/js/web/test/data/ops/add.jsonc index e5b4ff2b53148..dd15134861ef0 100644 --- a/js/web/test/data/ops/add.jsonc +++ b/js/web/test/data/ops/add.jsonc @@ -157,6 +157,28 @@ "type": "float32" } ] + }, + { + "name": "T[2,0] T[2,1]", + "inputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [2, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + } + ] } ] } diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index b43b1ac37e37d..88555a27be82e 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1231,7 +1231,7 @@ "test_split_variable_parts_1d", "test_split_variable_parts_2d", "test_split_variable_parts_default_axis", - // // "test_split_zero_size_splits", + "test_split_zero_size_splits", "test_sqrt_example", "test_sqrt", "test_squeeze_negative_axes", diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index a4adf5c4ce144..7c03e5b915fd7 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -573,7 +573,9 @@ export async function sessionRun(options: { // replace the CPU tensors in feeds into GPU tensors for (const name in feeds) { if (Object.hasOwnProperty.call(feeds, name)) { - feeds[name] = createGpuTensorForInput(feeds[name]); + if (feeds[name].size > 0) { + feeds[name] = createGpuTensorForInput(feeds[name]); + } } } } @@ -582,7 +584,11 @@ export async function sessionRun(options: { for (const name in options.outputsMetaInfo) { if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) { const {type, dims} = options.outputsMetaInfo[name]; - fetches[name] = createGpuTensorForOutput(type, dims); + if (dims.some(d => d === 0)) { + fetches[name] = new ort.Tensor(type, [], dims); + } else { + fetches[name] = createGpuTensorForOutput(type, dims); + } } } } From bb43a0f1338b05e93fcbbe5c5cb53ebf017625ba Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 23 Feb 2024 15:45:30 -0800 Subject: [PATCH 053/237] [js/webgpu] minor fixes to make tinyllama work (#19564) --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 4 +++- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index b06c9fb496d15..b142a82e551a7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -154,7 +154,9 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P export const concat = (context: 
ComputeContext, attributes: ConcatAttributes): void => { validateInputs(context.inputs); - context.compute(createConcatProgramInfo(context.inputs, attributes.axis)); + // 0 length tensors are valid for concat, remove them + const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0); + context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs}); }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 5c31e6dd86c00..d48bb909f7f8f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -55,7 +55,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath if (idx${x} < 0) { idx${x} = idx${x} + uniforms.axisDimLimit; } - var dataIndices${x} = ${data.type.indices}(0); + var dataIndices${x} : ${data.type.indices}; `; for (let i = 0, j = 0; i < inputRank; i++) { if (i === axis) { From 46c4d7fe4ad457d517fe92db7681c38849c51beb Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Fri, 23 Feb 2024 18:20:22 -0800 Subject: [PATCH 054/237] Disable gemm activation for non-float data types (#19612) ### Description Disable gemm activation for non-float data types ### Motivation and Context When a float16 model contains a Gemm+Relu subgraph, the gemm_activation_fusion will kick in and cause the two nodes to be eliminated and replaced with a FusedGemm. This however is only registered for the float data type. This causes model load failures. Disable the fusion for non-float data types. --------- Co-authored-by: Sheil Kumar --- onnxruntime/core/optimizer/gemm_activation_fusion.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc index c62887da09fdc..50be2cbd48f7b 100644 --- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc @@ -56,6 +56,13 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l continue; } + NodeArg* node_output = node.MutableOutputDefs()[0]; + auto data_type = node_output->TypeAsProto()->tensor_type().elem_type(); + if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + // FusedGemm is only registered for float data type in fused_gemm.cc! + continue; + } + const Node& next_node = *(node.OutputNodesBegin()); if (!IsFusableActivation(next_node) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { continue; From c12a20bef95df5437189687b94e7ba2f1bad1505 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 24 Feb 2024 14:06:30 +1000 Subject: [PATCH 055/237] Add helper to run CIs for a branch using `az pipelines`. (#16843) ### Description Add helper to run CIs for a branch using `az pipelines`. This can be used to easily kick off multiple CIs for a branch prior to creating a PR. Update run_CIs_for_external_pr.py so the CI list can be shared. Request json output from `gh pr view` so the current state is more easily parsed. 
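At its core the helper just shells out to the Azure CLI once per selected pipeline and reads the queued run id back from the JSON output. A minimal sketch of that interaction (the helper name here is illustrative; it assumes `az` with the DevOps extension is installed and the onnxruntime org/project are configured as defaults, as described in the script's docstring):

```python
# Illustrative sketch only: the real script adds include/exclude filtering,
# --dry-run support and error handling on top of this.
import json
import subprocess


def queue_pipeline(branch: str, pipeline: str) -> str:
    # `az pipelines run` queues a build for the branch and prints JSON describing the run.
    out = subprocess.run(
        ["az", "pipelines", "run", "--branch", branch, "--name", pipeline],
        capture_output=True, text=True, check=True,
    )
    run = json.loads(out.stdout)
    # The queued run's id maps directly to a build results URL.
    return f"https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId={run['id']}"
```

Typical command-line usage then follows the examples in the script's docstring below, e.g. combining `--include linux --exclude training` to narrow the set of pipelines before queuing.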
### Motivation and Context --- tools/python/run_CIs_for_branch.py | 116 +++++++++++++++++++++++ tools/python/run_CIs_for_external_pr.py | 120 +++++++++++++----------- 2 files changed, 181 insertions(+), 55 deletions(-) create mode 100644 tools/python/run_CIs_for_branch.py diff --git a/tools/python/run_CIs_for_branch.py b/tools/python/run_CIs_for_branch.py new file mode 100644 index 0000000000000..c507cae0d9f43 --- /dev/null +++ b/tools/python/run_CIs_for_branch.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import json +import os +import subprocess +import sys +import typing + +from run_CIs_for_external_pr import get_pipeline_names +from util.platform_helpers import is_windows + + +def _parse_args(): + parser = argparse.ArgumentParser( + os.path.basename(__file__), + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""Run the CIs used to validate PRs for the specified branch. + + If specified, the `--include` filter is applied first, followed by any `--exclude` filter. + + Requires the Azure CLI with DevOps extension to be installed. + Azure CLI: https://learn.microsoft.com/en-us/cli/azure/install-azure-cli + DevOps extension: https://github.com/Azure/azure-devops-cli-extension + + Configuration: + Login:`az login` + Configure ORT repo as default: + `az devops configure --defaults organization=https://dev.azure.com/onnxruntime project=onnxruntime` + + Example usage: + List all CIs + `python run_CIs_for_branch.py --dry-run my/BranchName` + Run all CIs + `python run_CIs_for_branch.py my/BranchName` + Run only Linux CIs + `python run_CIs_for_branch.py --include linux my/BranchName` + Exclude training CIs + `python run_CIs_for_branch.py --exclude training my/BranchName` + Run non-training Linux CIs + `python run_CIs_for_branch.py --include linux --exclude training my/BranchName` + """, + ) + + parser.add_argument("-i", "--include", type=str, help="Include CIs that match this string. Case insensitive.") + parser.add_argument("-e", "--exclude", type=str, help="Exclude CIs that match this string. 
Case insensitive.") + parser.add_argument("--dry-run", action="store_true", help="Print selected CIs but do not run them.") + parser.add_argument("branch", type=str, help="Specify the branch to run.") + + args = parser.parse_args() + return args + + +def _run_az_pipelines_command(command: typing.List[str]): + try: + az = "az.cmd" if is_windows() else "az" + az_output = subprocess.run([az, "pipelines", *command], capture_output=True, text=True, check=True) + except subprocess.CalledProcessError as cpe: + print(cpe) + print(cpe.stderr) + sys.exit(-1) + + return az_output + + +def main(): + args = _parse_args() + branch = args.branch + + # To debug available pipelines: + # az_out = az_pipelines = _run_az_pipelines_command(["list"]) + # pipeline_info = json.loads(az_out.stdout) + # print(pipeline_info) + + pipelines = get_pipeline_names() + pipelines_to_run = [] + if args.include: + value = args.include.lower().strip() + for p in pipelines: + if value in p.lower(): + print(f"Including {p}") + pipelines_to_run.append(p) + else: + pipelines_to_run = pipelines + + if args.exclude: + value = args.exclude.lower().strip() + cur_pipelines = pipelines_to_run + pipelines_to_run = [] + for p in cur_pipelines: + if value in p.lower(): + print(f"Excluding {p}") + else: + pipelines_to_run.append(p) + + print("Pipelines to run:") + for p in pipelines_to_run: + print(f"\t{p}") + + if args.dry_run: + sys.exit(0) + + for pipeline in pipelines_to_run: + az_out = _run_az_pipelines_command(["run", "--branch", branch, "--name", pipeline]) + run_output = json.loads(az_out.stdout) + if "id" in run_output: + build_url = f"https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId={run_output['id']}" + print(f"{pipeline} build results: {build_url}&view=results") + else: + raise ValueError("Build id was not found in az output:\n" + run_output) + + +if __name__ == "__main__": + main() diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index df4e70b1e51fe..dcafe898b3bdf 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -3,13 +3,54 @@ # Licensed under the MIT License. import argparse +import json import os import subprocess import sys import typing -def parse_args(): +def get_pipeline_names(): + # Current pipelines. These change semi-frequently and may need updating. + # There is no easy way to get the list of "required" pipelines using `azp` before they are run, + # so we need to maintain this list manually. 
+ # NOTE: This list is also used by run_CIs_for_branch.py + pipelines = [ + # windows + "Windows ARM64 QNN CI Pipeline", + "Windows x64 QNN CI Pipeline", + "Windows CPU CI Pipeline", + "Windows GPU CI Pipeline", + "Windows GPU TensorRT CI Pipeline", + "ONNX Runtime Web CI Pipeline", + # linux + "Linux CPU CI Pipeline", + "Linux CPU Minimal Build E2E CI Pipeline", + "Linux GPU CI Pipeline", + "Linux GPU TensorRT CI Pipeline", + "Linux OpenVINO CI Pipeline", + "Linux QNN CI Pipeline", + # mac + "MacOS CI Pipeline", + # training + "orttraining-amd-gpu-ci-pipeline", + "orttraining-linux-ci-pipeline", + "orttraining-linux-gpu-ci-pipeline", + "orttraining-ortmodule-distributed", + # checks + "onnxruntime-binary-size-checks-ci-pipeline", + # big models + "Big Models", + # not currently required, but running ensures we're hitting all mobile platforms + "Android CI Pipeline", + "iOS CI Pipeline", + "ONNX Runtime React Native CI Pipeline", + ] + + return pipelines + + +def _parse_args(): parser = argparse.ArgumentParser( os.path.basename(__file__), formatter_class=argparse.RawDescriptionHelpFormatter, @@ -25,7 +66,7 @@ def parse_args(): return args -def run_gh_pr_command(command: typing.List[str], check=True): +def run_gh_pr_command(command: typing.List[str], check: bool = True): try: return subprocess.run(["gh", "pr", *command], capture_output=True, text=True, check=check) except subprocess.CalledProcessError as cpe: @@ -35,23 +76,25 @@ def run_gh_pr_command(command: typing.List[str], check=True): def main(): - args = parse_args() + args = _parse_args() pr_id = args.pr # validate PR - gh_out = run_gh_pr_command(["view", pr_id]) - info = gh_out.stdout.split("\n") - for line in info: - pieces = line.split("\t") - if len(pieces) != 2: - continue - - if pieces[0] == "state:": - if pieces[1] != "OPEN": - print(f"PR {pr_id} is not OPEN. Currently in state {pieces[1]}.") - sys.exit(-1) - - print("Check passed pipelines") + print("Checking PR is open") + gh_out = run_gh_pr_command(["view", "--json", "state", pr_id]) + info = json.loads(gh_out.stdout) + if "state" not in info: + print(f"Could not get current state from `gh pr view` response of\n{gh_out.stdout}") + sys.exit(-1) + + if info["state"] != "OPEN": + print(f"PR {pr_id} is not OPEN. Currently in state {info['state']}.") + sys.exit(0) + + # This will return CIs that have run previously but not passed. We filter the CIs to run based on this, so it's + # fine for the initial response to have no info in it. + # `gh pr checks` exits with non-zero exit code when failures in pipeline exist, so we set `check` to False. + print("Checking for pipelines that have passed.") gh_out = run_gh_pr_command(["checks", pr_id, "--required"], check=False) # output format is a tab separated list of columns: # (pipeline name) "\t" (status) "\t" (ran time) "\t" (url) @@ -61,54 +104,21 @@ def main(): if len(columns) == 4 and columns[1] == "pass" ] - print("Adding azp run commands") - - # Current pipelines. These change semi-frequently and may need updating. - # - # Note: there is no easy way to get the list for azp "required" pipelines before they starts. - # we need to maintain this list manually. 
- # - pipelines = [ - # windows - "Windows ARM64 QNN CI Pipeline", - "Windows x64 QNN CI Pipeline", - "Windows CPU CI Pipeline", - "Windows GPU CI Pipeline", - "Windows GPU TensorRT CI Pipeline", - "ONNX Runtime Web CI Pipeline", - # linux - "Linux CPU CI Pipeline", - "Linux CPU Minimal Build E2E CI Pipeline", - "Linux GPU CI Pipeline", - "Linux GPU TensorRT CI Pipeline", - "Linux OpenVINO CI Pipeline", - "Linux QNN CI Pipeline", - # mac - "MacOS CI Pipeline", - # training - "orttraining-amd-gpu-ci-pipeline", - "orttraining-linux-ci-pipeline", - "orttraining-linux-gpu-ci-pipeline", - "orttraining-ortmodule-distributed", - # checks - "onnxruntime-python-checks-ci-pipeline", - "onnxruntime-binary-size-checks-ci-pipeline", - # big models - "Big Models", - # not currently required, but running ensures we're hitting all mobile platforms - "Android CI Pipeline", - "iOS CI Pipeline", - "ONNX Runtime React Native CI Pipeline", - ] + pipelines = get_pipeline_names() # remove pipelines that have already run successfully pipelines = [p for p in pipelines if p not in checked_pipelines] + print("Pipelines to run:") + for p in pipelines: + print("\t" + p) + # azp run is limited to 10 pipelines at a time max_pipelines_per_comment = 10 start = 0 num_pipelines = len(pipelines) + print("Adding azp run commands") while start < num_pipelines: end = start + max_pipelines_per_comment if end > num_pipelines: From 9ccdc4961ad76355289ed3a36ccb8307e8dc7789 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 23 Feb 2024 22:31:57 -0800 Subject: [PATCH 056/237] Stop using apiset in OneCore build: use onecoreuap.lib instead of onecoreuap_apiset.lib (#19632) ### Description Stop using apiset in OneCore build: use onecoreuap.lib instead of onecoreuap_apiset.lib in onecore build. ### Motivation and Context 1. Now all Windows Editions come with Reverse Forwarders. We should just use the normal onecore libs. 2. Many new Windows APIs are only available in [windows umbrella libraries](https://learn.microsoft.com/en-us/windows/win32/apiindex/windows-umbrella-libraries). So these libraries are not specific for Windows CoreOS or Onecore. 3. Going forward we should use "IsApiSetImplemented" to guard our API usages: https://learn.microsoft.com/en-us/windows/win32/apiindex/detect-api-set-availability . After this change, our built binaries can pass apivalidator's check. ``` C:\local\apivalidator>apivalidator.exe -BinaryPath:C:\src\onnxruntime\b\Debug\Debug\onnxruntime.dll -SupportedApiXmlFiles:onecoreuap_DDIs.xml ApiValidation: Summary: "C:\src\onnxruntime\b\Debug\Debug\onnxruntime.dll" is Universal ApiValidation: All binaries are Universal ``` So it will give an easy way to test ONNX Runtime's compatibility to Windows versions. --- cmake/CMakeLists.txt | 6 ++---- cmake/wcos_rules_override.cmake | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index c9be4aa65d0cc..ed9043f2adc4a 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1729,14 +1729,12 @@ if(onnxruntime_BUILD_KERNEL_EXPLORER) endif() # When GDK_PLATFORM is set then WINAPI_FAMILY is defined in gdk_toolchain.cmake (along with other relevant flags/definitions). -if (WIN32 AND NOT GDK_PLATFORM) +if (WIN32 AND NOT GDK_PLATFORM AND NOT CMAKE_CROSSCOMPILING) if (NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) # On onecore, link to the onecore build of the MSVC runtime get_filename_component(msvc_path "${CMAKE_C_COMPILER}/../../../.." 
ABSOLUTE) link_directories(BEFORE "${msvc_path}/lib/onecore/${onnxruntime_target_platform}") - # The .lib files in the MSVC runtime have a DEFAULITLIB entry for onecore.lib, which in turn links to reverse forwarders. - # We ignore that entry and use onecore_apiset.lib instead, since system components must not rely on reverse forwarders. - add_link_options("/NODEFAULTLIB:onecore.lib") + # The .lib files in the MSVC runtime have a DEFAULITLIB entry for onecore.lib, but it shold not cause any conflict with onecoreuap.lib endif() endif() diff --git a/cmake/wcos_rules_override.cmake b/cmake/wcos_rules_override.cmake index f3d8093629a42..ec2303b073d5e 100644 --- a/cmake/wcos_rules_override.cmake +++ b/cmake/wcos_rules_override.cmake @@ -1,2 +1,2 @@ -set(CMAKE_C_STANDARD_LIBRARIES_INIT onecoreuap_apiset.lib) -set(CMAKE_CXX_STANDARD_LIBRARIES_INIT onecoreuap_apiset.lib) +set(CMAKE_C_STANDARD_LIBRARIES_INIT onecoreuap.lib) +set(CMAKE_CXX_STANDARD_LIBRARIES_INIT onecoreuap.lib) From 0edb03580823c9d9e97ba1a6ea941fcd70a2500b Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 24 Feb 2024 10:09:07 -0800 Subject: [PATCH 057/237] [js/web] fix suite test list for zero sized tensor (#19638) ### Description Fixes build break brought by #19614 Currently WebGL backend does not support zero sized tensor. This change split test data into 2 parts, and only enable zero sized tensor tests for WebGPU. --- js/web/test/data/ops/add.jsonc | 22 - js/web/test/data/ops/add_zero-sized.jsonc | 31 + js/web/test/data/ops/concat_zero-sized.jsonc | 561 +++++++++++++++++++ js/web/test/suite-test-list.jsonc | 2 + 4 files changed, 594 insertions(+), 22 deletions(-) create mode 100644 js/web/test/data/ops/add_zero-sized.jsonc create mode 100644 js/web/test/data/ops/concat_zero-sized.jsonc diff --git a/js/web/test/data/ops/add.jsonc b/js/web/test/data/ops/add.jsonc index dd15134861ef0..e5b4ff2b53148 100644 --- a/js/web/test/data/ops/add.jsonc +++ b/js/web/test/data/ops/add.jsonc @@ -157,28 +157,6 @@ "type": "float32" } ] - }, - { - "name": "T[2,0] T[2,1]", - "inputs": [ - { - "data": [], - "dims": [2, 0], - "type": "float32" - }, - { - "data": [1, 2], - "dims": [2, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [], - "dims": [2, 0], - "type": "float32" - } - ] } ] } diff --git a/js/web/test/data/ops/add_zero-sized.jsonc b/js/web/test/data/ops/add_zero-sized.jsonc new file mode 100644 index 0000000000000..37e08cd7f20ac --- /dev/null +++ b/js/web/test/data/ops/add_zero-sized.jsonc @@ -0,0 +1,31 @@ +[ + { + "name": "Add with no attributes", + "operator": "Add", + "attributes": [], + "cases": [ + { + "name": "T[2,0] T[2,1]", + "inputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [2, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/concat_zero-sized.jsonc b/js/web/test/data/ops/concat_zero-sized.jsonc new file mode 100644 index 0000000000000..7be8e8c1cc602 --- /dev/null +++ b/js/web/test/data/ops/concat_zero-sized.jsonc @@ -0,0 +1,561 @@ +[ + { + "name": "Concat 2D axis=0", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": -2, "type": "int" }], + "cases": [ + { + "name": "X", + "inputs": [ + { + "data": [], + "dims": [1, 4, 0, 64], + "type": "float32" + }, + { + "data": [ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], + "dims": [1, 4, 36, 64], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], + "dims": [1, 4, 36, 64], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 88555a27be82e..e96a0aa045bc8 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1334,6 +1334,7 @@ "acos.jsonc", "add.jsonc", "add_int32.jsonc", + "add_zero-sized.jsonc", //"and.jsonc", "asin.jsonc", "attention.jsonc", @@ -1343,6 +1344,7 @@ "ceil.jsonc", "concat.jsonc", "concat_int32.jsonc", + "concat_zero-sized.jsonc", "cast.jsonc", "conv.jsonc", "cos.jsonc", 
From c980149c857facc2463668a11944af3c6c12365b Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Sun, 25 Feb 2024 05:00:53 +0800 Subject: [PATCH 058/237] Add log for random exception in Linux GPU Test Stage. (#19569) ### Description 1. Check GPU status inside the Docker container. 2. Use stages so that the test stage can leverage the existing build artifacts. ### Motivation and Context To investigate the root cause of the random exception `CUDA failure 100: no CUDA-capable device is detected` --- .../azure-pipelines/linux-gpu-ci-pipeline.yml | 351 ++++++++++-------- 1 file changed, 198 insertions(+), 153 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 24319184dd0b8..822bc559d992d 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -34,6 +34,17 @@ parameters: values: - 11.8 - 12.2 + + - name: SpecificArtifact + displayName: Use Specific Artifact + type: boolean + default: false + + - name: BuildId + displayName: Specific Artifact's BuildId + type: string + default: '0' + resources: repositories: - repository: manylinux @@ -61,163 +72,197 @@ variables: ${{ if eq(parameters.CudaVersion, '12.2') }}: value: 'onnxruntimecuda12build' -jobs: -- job: Linux_Build - timeoutInMinutes: 120 - variables: - skipComponentGovernanceDetection: true - CCACHE_DIR: $(Pipeline.Workspace)/ccache - workspace: - clean: all - pool: onnxruntime-Ubuntu2204-AMD-CPU - - steps: - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - checkout: self - clean: true - submodules: none - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: " - --network=host - --build-arg BASEIMAGE=$(docker_base_image) - --build-arg TRT_VERSION=$(linux_trt_version) - --build-arg BUILD_UID=$( id -u ) - " - Repository: $(Repository) - - - task: Cache@2 - inputs: - key: '"ccache" | "${{parameters.CudaVersion}}" |"$(Build.SourceBranch)" | "$(Build.SourceVersion)"' - path: $(CCACHE_DIR) - restoreKeys: | - "ccache" | "${{parameters.CudaVersion}}" | "$(Build.SourceBranch)" - "ccache" - cacheHitVar: CACHE_RESTORED - displayName: Cach Task - - - script: | - sudo mkdir -p $(Pipeline.Workspace)/ccache - condition: ne(variables.CACHE_RESTORED, 'true') - displayName: Create Cache Dir - - - script: | - set -e -x - mkdir -p $HOME/.onnx - docker run -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" --rm \ - --volume /data/onnx:/data/onnx:ro \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume /data/models:/build/models:ro \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ - --volume $(Pipeline.Workspace)/ccache:/cache \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - -e CCACHE_DIR=/cache \ - $(Repository) \ - /bin/bash -c " - set -ex; \ - env; \ - ccache -s; \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build --cmake_generator Ninja \ - 
--config Release --update --build \ - --skip_submodule_sync \ - --build_shared_lib \ - --parallel --use_binskim_compliant_compile_flags \ - --build_wheel \ - --enable_onnx_tests --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda-${{parameters.CudaVersion}} --cudnn_home=/usr/local/cuda-${{parameters.CudaVersion}} \ - --enable_cuda_profiling --enable_cuda_nhwc_ops \ - --enable_pybind --build_java \ - --use_cache \ - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86; \ - ccache -sv; \ - ccache -z" - workingDirectory: $(Build.SourcesDirectory) - displayName: Build Onnxruntime - - - task: CmdLine@2 - inputs: - script: | - rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 - rm -f $(Build.BinariesDirectory)/Release/models - find $(Build.BinariesDirectory)/Release/_deps -mindepth 1 ! -regex '^$(Build.BinariesDirectory)/Release/_deps/onnx-src\(/.*\)?' -delete - cd $(Build.BinariesDirectory)/Release - find -executable -type f > $(Build.BinariesDirectory)/Release/perms.txt - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact' - inputs: - artifactName: 'drop-linux' - targetPath: '$(Build.BinariesDirectory)/Release' - - - template: templates/explicitly-defined-final-tasks.yml - -- job: Linux_Test - timeoutInMinutes: 180 - variables: - skipComponentGovernanceDetection: true - workspace: - clean: all - pool: onnxruntime-Linux-GPU-A10 - dependsOn: - - Linux_Build - steps: - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - buildType: 'current' - artifactName: 'drop-linux' - targetPath: '$(Build.BinariesDirectory)/Release' - - - checkout: self - clean: true - submodules: none - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: " - --network=host - --build-arg BASEIMAGE=$(docker_base_image) - --build-arg TRT_VERSION=$(linux_trt_version) - --build-arg BUILD_UID=$( id -u ) - " - Repository: $(Repository) - - - task: CmdLine@2 - inputs: - script: | +stages: +- stage: Linux_Build + jobs: + - job: Linux_Build + timeoutInMinutes: 120 + variables: + skipComponentGovernanceDetection: true + CCACHE_DIR: $(Pipeline.Workspace)/ccache + workspace: + clean: all + pool: onnxruntime-Ubuntu2204-AMD-CPU + + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: none + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=$(docker_base_image) + --build-arg TRT_VERSION=$(linux_trt_version) + --build-arg BUILD_UID=$( id -u ) + " + Repository: $(Repository) + + - task: Cache@2 + inputs: + key: '"ccache" | "${{parameters.CudaVersion}}" |"$(Build.SourceBranch)" | "$(Build.SourceVersion)"' + path: $(CCACHE_DIR) + restoreKeys: | + "ccache" | "${{parameters.CudaVersion}}" | "$(Build.SourceBranch)" + "ccache" + cacheHitVar: CACHE_RESTORED + displayName: Cach Task + + - script: | + sudo mkdir -p $(Pipeline.Workspace)/ccache + condition: ne(variables.CACHE_RESTORED, 'true') + displayName: Create Cache Dir + + - script: | set -e -x mkdir -p $HOME/.onnx - docker run --gpus all --rm \ - --volume 
$(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory)/Release:/build/Release \ + docker run -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ --volume /data/models:/build/models:ro \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ - --volume /data/onnx:/data/onnx \ - -e NVIDIA_TF32_OVERRIDE=0 \ + --volume $(Pipeline.Workspace)/ccache:/cache \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + -e CCACHE_DIR=/cache \ $(Repository) \ /bin/bash -c " set -ex; \ - cp /onnxruntime_src/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt /tmp/requirements.txt; \ - ln -s /opt/python/cp38-cp38/bin/python3 /tmp/python3; \ - /tmp/python3 -m pip install -r /tmp/requirements.txt; \ - /tmp/python3 -m pip install /build/Release/dist/*.whl; \ - cd /build/Release && xargs -a /build/Release/perms.txt chmod a+x; \ - cd /onnxruntime_src/java && /onnxruntime_src/java/gradlew cmakeCheck -DcmakeBuildDir=/build/Release -DUSE_CUDA=1; \ - cd /tmp; \ - /tmp/python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build --config Release --test --skip_submodule_sync --build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests \ - --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda --cudnn_home=/usr/local/cuda \ - --enable_pybind --build_java --ctest_path '' " - - - template: templates/clean-agent-build-directory-step.yml + env; \ + ccache -s; \ + /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ + --build_dir /build --cmake_generator Ninja \ + --config Release --update --build \ + --skip_submodule_sync \ + --build_shared_lib \ + --parallel --use_binskim_compliant_compile_flags \ + --build_wheel \ + --enable_onnx_tests --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda-${{parameters.CudaVersion}} --cudnn_home=/usr/local/cuda-${{parameters.CudaVersion}} \ + --enable_cuda_profiling --enable_cuda_nhwc_ops \ + --enable_pybind --build_java \ + --use_cache \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86; \ + ccache -sv; \ + ccache -z" + workingDirectory: $(Build.SourcesDirectory) + displayName: Build Onnxruntime + + - task: CmdLine@2 + inputs: + script: | + rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 + rm -f $(Build.BinariesDirectory)/Release/models + find $(Build.BinariesDirectory)/Release/_deps -mindepth 1 ! -regex '^$(Build.BinariesDirectory)/Release/_deps/onnx-src\(/.*\)?' 
-delete + cd $(Build.BinariesDirectory)/Release + find -executable -type f > $(Build.BinariesDirectory)/Release/perms.txt + + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline Artifact' + inputs: + artifactName: 'drop-linux' + targetPath: '$(Build.BinariesDirectory)/Release' + + - template: templates/explicitly-defined-final-tasks.yml + +- stage: Linux_Test + dependsOn: + - Linux_Build + jobs: + - job: Linux_Test + timeoutInMinutes: 180 + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: onnxruntime-Linux-GPU-A10 + steps: + - checkout: self + clean: true + submodules: none + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + ArtifactName: 'drop-linux' + StepName: 'Download Pipeline Artifact - Linux Build' + TargetPath: '$(Build.BinariesDirectory)/Release' + SpecificArtifact: ${{ parameters.SpecificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=$(docker_base_image) + --build-arg TRT_VERSION=$(linux_trt_version) + --build-arg BUILD_UID=$( id -u ) + " + Repository: $(Repository) + + - task: CmdLine@2 + inputs: + script: | + set -e -x + mkdir -p $HOME/.onnx + docker run --gpus all --rm \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory)/Release:/build/Release \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + --volume /data/onnx:/data/onnx \ + -e NVIDIA_TF32_OVERRIDE=0 \ + $(Repository) \ + /bin/bash -c ' + nvidia-smi; \ + /sbin/ldconfig -N -v $(sed "s/:/ /" <<< $LD_LIBRARY_PATH) 2>/dev/null | grep -E "libcudart.so|libcudnn.so|libnvinfer.so"; \ + cat /usr/local/cuda/include/cuda.h | grep -m1 CUDA_VERSION; \ + cat /usr/include/cudnn_version.h | grep CUDNN_MAJOR -m1 -A 2; \ + ln -s /opt/python/cp38-cp38/bin/python3 /tmp/python3; \ + /tmp/python3 -m pip install /build/Release/dist/*.whl; \ + /tmp/python3 -u -c "from onnxruntime.capi._pybind_state import (OrtDevice as C_OrtDevice) ; \ + ort_device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0); \ + print(ort_device); print(ort_device.device_type(), C_OrtDevice.cuda()); \ + assert(ort_device.device_type()==1); assert(C_OrtDevice.cuda()==1);" \ + ' + displayName: 'Check GPU' + + - task: CmdLine@2 + inputs: + script: | + set -e -x + mkdir -p $HOME/.onnx + docker run --gpus all --rm \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory)/Release:/build/Release \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + --volume /data/onnx:/data/onnx \ + -e NVIDIA_TF32_OVERRIDE=0 \ + $(Repository) \ + /bin/bash -c ' + set -ex; \ + cp /onnxruntime_src/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt /tmp/requirements.txt; \ + ln -s /opt/python/cp38-cp38/bin/python3 /tmp/python3; \ + /tmp/python3 -m pip install -r /tmp/requirements.txt; \ + /tmp/python3 -m pip install /build/Release/dist/*.whl; \ + cd /build/Release && xargs -a /build/Release/perms.txt chmod a+x; \ + cd /onnxruntime_src/java && /onnxruntime_src/java/gradlew cmakeCheck -DcmakeBuildDir=/build/Release -DUSE_CUDA=1; \ + cd /tmp; \ + /tmp/python3 /onnxruntime_src/tools/ci_build/build.py \ + --build_dir /build --config Release --test --skip_submodule_sync 
--build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests \ + --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda --cudnn_home=/usr/local/cuda \ + --enable_pybind --build_java --ctest_path "" ; \ + ' + displayName: 'Run Tests' + + - template: templates/clean-agent-build-directory-step.yml From 0fcc6fb7601893bd1e2b53baea4436a7a51b7f8d Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Sun, 25 Feb 2024 14:04:22 +0800 Subject: [PATCH 059/237] Add Whisper model in CI (#19604) ### Description Add Whisper Conversion and E2E into Big Models pipeline ### Motivation and Context --------- Co-authored-by: Your Name Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> --- .../tools/transformers/benchmark_helper.py | 4 +- .../transformers/models/whisper/benchmark.py | 3 +- .../models/whisper/requirements.txt | 5 +- .../models/whisper/test/1272-141231-0002.mp3 | Bin 0 -> 92124 bytes .../whisper/test/whisper_ort_output.txt | 1 + .../azure-pipelines/bigmodels-ci-pipeline.yml | 101 +++++++++++++++++- .../docker/Dockerfile.package_ubuntu_2004_gpu | 9 +- 7 files changed, 115 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/whisper/test/1272-141231-0002.mp3 create mode 100644 onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index c7d93470a729e..c9c815f01e053 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -589,7 +589,7 @@ def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None): if max_usage is None: return None - print(f"GPU memory usage: before={memory_before_test} peak={max_usage}") + logger.info(f"GPU memory usage: before={memory_before_test} peak={max_usage}") if len(memory_before_test) >= 1 and len(max_usage) >= 1 and len(memory_before_test) == len(max_usage): # When there are multiple GPUs, we will check the one with maximum usage. 
max_used = 0 @@ -620,7 +620,7 @@ def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None): monitor.keep_measuring = False max_usage = mem_thread.result() - print(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB") + logger.info(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB") return max_usage - memory_before_test diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index e57385aa6db8f..11e596cadc2cb 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -410,7 +410,8 @@ def handle_output(output): actual_output = handle_output(ort_outputs[0][0]) logger.info(f"Generated token length: {len(actual_output)} tokens") transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0] - logger.info(f"Transcription: {transcription}") + # print to stdout as the output for comparison + print(f"{transcription}") measure_fn(args, generate_fn, ort_inputs) diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index c307a3665f8a0..956922dc83d51 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -8,4 +8,7 @@ librosa optimum onnxruntime-extensions>=0.9.0 protobuf==3.20.2 -numpy==1.23.3 \ No newline at end of file +numpy==1.23.3 +onnx>=1.15.0 +psutil +py3nvml diff --git a/onnxruntime/python/tools/transformers/models/whisper/test/1272-141231-0002.mp3 b/onnxruntime/python/tools/transformers/models/whisper/test/1272-141231-0002.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..6d220f5ede6a7c54893b1dda32b7876c31059fcf GIT binary patch literal 92124 zcmce-^;?wh^FF-0G%QH7^wJGWr%E?ScP!lqh@{}sEse{9G)Q*{NQbm^igdTUKndZI z_xJex3(x+tdF(ycTr=mK*UWi8UPku={_kPr#qnOa)e*gHD8x_fxN@%0Z34hf5hj*b72oR)#g z&MPebTvl1r(A3)A(cL%jZFpjGdTw#~$J*xh{?Y0Af7ds+_YeQBr=>2Vp&-c14}+m= z{hujh*vx223;^KzlO2C2YX<-Ci~paWAD{d~AdTs4@$bUGGEhrU%fcZx#YZn(Qjbp# z+bpz%_%8`9uqmFMXfc%TZQ{JN&JHT@?Rj>YE%-S9$np5dYx(#WhuLU$!KHdZL^`in zyZ7C+3^Y_PWm1iq^;?G6W-ByH+mMe;DMxXFqe@^?hqK5s!Nv4mr2g7ZxES5E0 z?NEjUr%XYGKhD))YNq4gr?urAXd82^MyNG4|}JMwWF8?Ce8$IF;>*A9Wv zBO*fem^5K%ByHKu9aDCr$v!t;j=LgvtUE{iet)?e_E=bS2%`Oe0prw0UGy!!{WV;B zyj1RArUxzCTd2f-)hDcZ8DI>^qE~qN`n8ljx37i7Y4b_mEw@6<+fN)MIlGA#cW%h< zZJI;zj_I8E<@(=FPnRDf{{AfYcHfykYFaS)YAdf)toqyCTm`*Wncy$PW&Klu_Ju6S z9HQ+wp*+YhCcy;_9bD3%R)mr<>fgPJNi88u0cSx};5So^i^B{gPxfrE9NC*#_$OT% zXE+f>NnwZW4#l?K)rTf7cYX5(W&ZbEqo@P2=X1CW6}4GTZQ9*`A22K5f59jEMlV#G zcvg2Vr5!R-GT-F(_WUUZ**_LR*A+WH$fP27Ic1t<2!`*Zz|myZ2o9R0)=7QI^U(!T_}CV5q4btAjC zuP@0_U1rYNF5i09lC3{NF6c9AdXX=vK3~y-TAs?+XdwI2IeFZDaFMExP7}Mo;f^1( zzs$o+>x^xj*%E3|4osEbvdpiKNk}s~bax8QkKS8MCi@b|sZ`m>=H?wl<4}FU>$W3S z&@h46;EjJKCrI_m9@D`K97Ts?S=l`PdrjF^U+6jJZfqtCx+7DgnhM%Mv0>tiDx#i( zqiPwW+qv`RtyhwA11;unkN&!n+x^ax2$fL#6j~f6m!RZen4Lu`f3; zZl{&#eIwQM`>ws_efC?PK=&H|Y9szYXqP~>mGdgml`zk1_qfbAJa>^^GZ7AEcTGZi z7FpU_E4Rnf?Sq*71lR8?ncYEO-cX%Q&IIe}wST%nIX2t3U_P`XVzG~^nqRg29QyLX zPq0(FE!oE?rSz@Ql6rl3kTky@)-5}?L{T(h5v?HWsQAZi_beq5e$kHwI^I1^wp@;S zMbwNTar)7$8d6U-AG2?6T3c9u+xZ#Cp(4xuIgTOAqg97w@|6;$M4ECQL*j4&>3><# zK8JN}J~6KQmu1~5@jSC5tSnd(SFR#DzEMzxKS+yOwp^l{l5oTNW-X zS+!uOnZFe`QV|%=*owMdTs+N#*<$>dP*{K=6>uLJ8daXWW_0Q*{_lRUJtu?M=4Sp~)O-w=`*#AW=f^I>F;DS)#-UqoZ4scgs;#wUTf$?K 
[Remaining GIT binary patch data for onnxruntime/python/tools/transformers/models/whisper/test/1272-141231-0002.mp3 (92124 bytes) omitted; the encoded payload is truncated in this excerpt.]
z>FKXxXlc*okU>p{EWkL7WXgCz9C0=dDrYu(TCNW>5;IbHN0$O4BViQuRZfCoiXg63 zQwB1QkpgjZa>B#11q=%Ts7N?);E-m(Y#a@ZS4|v>=0{{oCf$$b?8LE2yKF<<0bnGw zii`joh(S%$o5YQl=uIe+a2rw5eA}S-726AkM7t>}OBMl0Lubfm)6-@E^J!>zAohR+ z6I~M`fUN_o2CK$Cmx#TBL%5ku?BH}snKh)eDmjH1y)|1vNhgGi!p%2XO-iPlgY+R^ z2+`~QxWrG0V`^ej>1C1tV}={WwcRn$_I2{V4^TI8cPO8W3DI%A3KuaG{k?H zjeTc*zm6z$T8xdgJLG7NRid2VO0@;cQ$9H9`0OvOv#JJrn=3IP?+lgZUvbIs!nfh+ z!vo4GZ{CrMy0m|KdNRLs;A|B_WRzbp6{%$L?jk3>gnSueBr$;MD2UmqW!%Os!A4)oMLr>Z~UR`#;1WQf4#6!4CyNc9^hcVm0&i`>pC& zCI{b!zwbAxXZ6c1);yjx;9{V$fnR3|b=$f-M6Ybt<~OW&)uf){Vz^NcYYF70zv$IH zvKt%Q-Y$$?!euVkc?T51V1c9{Sw{D!nVHC9o^boT?^o@eadF2)ZW9!`WG}nQu;(eV zgBTD2%<;(W-U)tXyk^XUqbm0BCmd=fE?lx47qM;mtKD>3G);4LTF_-^_428!<3%3wN>{3!+@DS$cV|m4s#O<6naS?#Xu(n)E zD&|xapD_>$xA3vm8ON~U0Ju|caA0#X4XPus>m_WX;_o1Ar- zb&QXGTMqmq{VpEfaV)GTxQ;;32l2yNyB)>eO=vxXqk4j3K&PAgQ$7ib3T{B2pmWWTrVaN(XDH zqn!vF$2G{2NGZE&Qg1X%8%995ul!nH_t)^&V_P-=OT}*r=4rj^AZBB5&8iptI#h~e zV)PA!;LU1Cfv<^osxXo6zg%W7LX*noUyeWcc1Pfco9pxsMo5vUJvc4qLd-R=SBBxBR zh96=_8|@^)U5GH5fxiIJVQTlPyMg6%9bqEvrY=jy@T#g)69EIN=^-ZHV49CadimBO zw0L=@xhB>5pxD0TRcJmk_9hwqfOd#iOZ9@%3`OdtL*{9005?(Txah$+u|FPpp&jC8 z$~NH(xgn1F>!N~d=(Y(VvnR<}-~Up0Jo~jR)L3v$1Wd@_rR6RRj-i>%2WOM!SierB zh8L*C(-ZQrRFi7a(7egWg=f$}q>(ccjr7V@-c|aAH1D!A${{K=rwMANShWQWNpa$F z%Ge@FE3f7lD;MTKgcu^;!<30SLOTHU=+rcz2C4-5lhfsclbiF6wx27mN*WH{VFXsB z22vxxs!ioA{fMS?9s=IhPcV>k)`TEpzrG?<*buGcD4m+J%GKtW-pt|?m8Vj2Xqnbv zJLzr}?oB7283VddRbKa(!YYpP+BsL1VXh18KZ||XiN-RyeM`=a!ff6|Z*4ub4Z^3oomlr0iDC1^ILzWiL?5yyzg=vLpwp%l!_G1|7xCd z29s0~%mt$Wjn7)@yc2YH0+r70geEWKbPZ;{PpfQ+Xxp`i&?mfGE*-dKl@N2%Nz5%X!DfN#V|J(P0v_;OPZPRxumFX%%U+UOzrgu74{MRbYyyhlt z;Ewix+|IH?`H^+4z9M;>Q2-?t`E)FGx*+xiV^5>?`y#a{3lAz>5Z*Ul#n?sy@FyF? zjs}ky;jD$2GqJ2n2fBe#+wmfGGOSWp5boD{Q>w!$G%+?I4yr!rx~0ahz+ZN>Qq757 zlO#&d+V9(lz9d;%%|A{u&xUxYqvRxym_G%1K8_lXeIsIA^T~NO4!EcoI>rs2{F(&Y7Vq`v> zlmm80O8G9TH=cl!v!h;xgzJleWZ?Fo!Y_0`|8(2p@fA69HTghw|GUyLRTC3guG)Xw zcizUk_V8eA_Y^+fQSBGtwjBqr8krRNl@**2A~01tgcJ8_A!7>t4=XxF*Gku`}m7Z}N zuD-^!T@-0HLDks)Zzgn(Tbj7wKKO<14)5)K%+OKuQN3~E4-+=AqtXa-zm3R1Dn4nZ zaxCj+_M|uFPoK?NoIgLaZ0d&?Zj-%h{IBP}(yqB^azO6JUXQ;}6fsLgNu=e@O|v%y zmgBQj-SOyn_2LoopE6S2alRuxG2&W^5L7HFDIUX`5yGKufz*dg{}-vWb?{+2uoVD# z?w<+PtB&6&XP0Z1N6b<15$9%Lh_`9$kEpfBwOn zUVkc1#&TuE7-}g}5g^}Zf^1gcuP*I;_>JD#uQQHLFi=c?8vlACH+E<4zXuUelONuQ3=9D(ig*4Fb~+FDA3Gz8ObQLv9jBiadRoQ4XcD>L z>5iyM$7-nIA|{g-VdV+tkjojYd5xu@U;Q>VW2JtMt9^oCLH{{ntre`Rv%Qm~=E{^} zw%>u$SJ&0|k76?08pokWA*m8?0$G)!*@Z1@p6dU5pOOE{y;swksie?h&;WBaedUzG zCSzy9JH6FiJnXim)QnSfmR1CD_sXM%(7N~-_@?G{5f2YnKaKqq(+2UOza4}a)POKu z+ZWMvc(p2t@#{*YU5G7P`8jVSR}w(sYa$b#V=C?>f#OPE_7z!`YC^3$N1R^EH)QS9 zk1$!&_BjN;*|54OXZpMU?K^t~PLdTOe)cwp>*%t>Z@wrQKWb>qz*qL-uIKuCb9b{V zKAIQWO3rlcN(!Zbdqd}6u55B5`4zqIJ%+o0LLEl7!NEFqMu#@^BMal8HZ{jZhi z*aU2GeqPpxXSiI%GPwkpiJ1B4?&B988HBhS0ab?matAa;(%)tl=X?ZRWRb$g?e`6} zLYPJM5lG)s1BN(Uj7xk-7ERR!L#4?$M~eSsO+_(NGzc;)iU4V zwvuJf!(SvUyUrz?V2Y|*b0a?|;*W79bB&RCK^AFtHEp^2+Pn$RI-itxs}_4AzJStQ z+=n=)=lss;4U(S%<vb zNUrppx^84%9lWD4bnLbE!r#U_@tIUaaVJAwE_1Nz_ay=v%6X?^Dnc`AE&oT|RFlCS z?)7^Is}Rk_HNMw_%MzX;`1hCI5+qcoere%>91{+eHET-C*X z2^}CR@^bf8`(l$zs@9-4^>=@Gudf4EroASMX9iNBKYmx=!82F{Gd%2G*vlYF>G)qg zp9qPVm*Sr$y$t4JC)$t6U5~hbzdxDcj> z9OHiCORK(`>O(&@p}WnA;_e@p!VuQ>{lqyXn}Yykc1_-uJyHI9A+~%SIoz%{#&b5W z;@W*H4|YPP$53L$@g(Vi+-7jxi0#<8f3T3io0^)=i&$yDW;+2uS?vGA%nu@PdBp7h zUSX^V1<>KaSp4YN;~BMUjQy)`Q|wg^y`KSs@@tW-h$+GOotZq@Xo98$`#c^^t^(<;&h2MPlUUNzk%79gAY1l%rIG%T%}Ksn zbX{3|eq`=1*GakZxO&N&66nY<$iA)#NPNfmIx2rMGM_M%`E`t| zqVGl1o*iV#Frjy`s5h!<=9*$ngl#~VV!Bj%ARRmckt3=9&L@JB8&JZGi77VvYM_Fm zO(C$?<}iVH;o0~}>Md~OzqJ{e)Z6}ngQqZJ*=qtVna2Ql-3&yLG7KcT>YDTpK)6ke 
zKeRX>#0eZ#i5y1niTJWqhxnu-)ERsPXu+Ml2f?AA67uI?Ej=5_yx4x->C!{JKI*+& z)%=&c)-l41H$vx4=-H3s9`(nWtc1Z?tvfOgtFxC2w0q14U=G`-t*xIIetxWSb56N% z|D5ewz;xxwD;+HKw*`7~{QTiFw(ak98rI%~E*2+?dSaraAw@PmUh-sW#hD9Y87DI) z@3u~TMv(LMSie`bzI@6RiH9_i$W6)D?F4>Q1ttbLj0LYGiWXLgF^}Cu0Q;^?et4Nj z0ye-x`avP{fF=m+La}^!92aIvFozU`$z!tuOJm@uWCJFJAh@-ju$&GY!~!Pbpr7M3 zp>)EBRS+u3M=E)D&IRX3-Js-P627oOUN{ItNOv7J%JkSS!9vAh0##^9t6b@-v?v`Q zI9%JvAPEO&S+}S;$Mjz|C>YXN#JV)y_D_Ky?8rU8$hrpM2>xk9zveCVa5*|y|6H01 z2DtWnPlIe)a`Ghxf22)<(-)o7Ot;ef@W{_p%ZYa&!5YTKW>W|`Qe4P^?aQl9j*9ls z%afudobc$#9b%mDZogDh0Z=Mvi$M)n!TI(M7SiFSW#$hxOhHtccQ~uoCc9~%H=J+V zC@o94q*NTz{(Sq^p>Eb^;a80nNs1-^^Uyg&yuLB6aafw zSjluS5OFT*t36N!2jf*DJ7G4aV5b6LB#B8Hi5WWp#OQJ2P_>bZgQ4d{IC67TQhx6s zimEN0s)$bmsQ5%Afk`GI0TAIf zoLLkGG3?;R6#6QE2E%z_Fko0M>BL3M2RPn*qa!9U4cdSjOqQSS9QOhjU$(KA0v|60 zz+NiZQoTr6moqwV!|?a4)xtD3w{55u+ia~Nq<8m9pOzqniS8|Hs;oSB{ zHCfGZ^pFT@=HRbHPFxn8_N!Omj{^iqLdp;}?TxMV{DZuPYR*E=in8l=oLy&r!ySvGUl9sXB0$m~4%hX+Gf! zv{>>7bvI24A*ZQ1_{%Dg-*DnmP%#d_(yJV{&$hMP6J>of@Yl6Jq<$U-yIar3RScl) zE(EILxbIz}JRRhrnGOpQCjavHspof8zRblxz*8zN+rN&I!t1*QMcF<1uQuQ?mS}GI zQDJVnOX`^%SO-|oMg{(r*Ona!6^bRl@i^8AOR3bZkL(X{-XmKLr*vQ!sCSX#gY^>Z z;H0|>uW^)9LO#Q(2ryYl=i#lm!bEa>A_Vk1*kp`C4U}0Bxkz`$y15Ov`MGI>@Ml^v z;xR9{pId!sXw|4&tJwW8^z+%IXMgsPi44dRbF|k2I7!kE7%0&~DT3``;&D=#ayAtp z6@v;M8xNc+!d_iu2#>^ZXxE$lfb)tRgTZqBu8f+$J`YbC59*lLE*~6E4pdxfT^)3- z-`YIZPh@z?+?Og31PS<@to>BtVx|4rpWxiT_~5}8Ih=P%9P~$*qud4>WXs@4?~~-2 z=Ct2;U}lTnR`%)D`9-npSwvao6nhelJ$uL4975H!8 z2cH*FnE+R304b1UH5_DLLJUphxUK|ETj`^8CJqXS@lLm7A{Nd*DM2no9GocW&-yIxD z;i{9sNYBBM@+20_06-*GLTylJXcTGGgxnmc1`LTw!Q`VF%&{(UX!_b)!5$C7)xo7? z2buO7h`0zxrFh%@2Jr}C32>Qs@o9Vkbkj6g&jDR*6CLUAd0_uJui^CY82Z4z=bG#$yMJ{g+p4 z!4Dii6F@aS7_T`_r zXSfxC^U^mooz?jo-oUa|JbxD!C_aX0*F1G-1Znc9$}%Nld|c{!ioN z_Sz8(Lk;2zcbIbf9@&hM^%Puj6j{ZjlNgl1$9;Di{JnX8n;CE!-Tbk(+!2jI`2@8L529 zHgR~x-r1GS4|?R#7j;=;wI7!s$e-wA*t%#(MwZ+EeEqbGd3e5rr=!L-WUJ6$fB^QwGeOgYfwF_mOa|0zzE+l+`?qE=G1#vBC?(o#8 z_(3J?a=|6`cv+JEsoFYsEx)C%f^SO2?>~K#`g_~`>HChNpwgwC!u8eVGF%J0oG@pr z(DI>CN%JW;aM(hr_q))_e(55#S$wPR=NK(dt*;}h&T)>#;{ID}8jT`Ofox(H?i`j5 zVTqUDr2N(YIIdee?yb4vnOwjztD z0%2K!g=12*PXp#j9UU6|Y}2zl`td!TXZ!XlZL39%)xbO+>cA5ZQ}E9Q~D!6CAxguqJ02 z!BEgPbWz8QEnHiwXnib}f1%vX!5&=P81lYN&%h{L>_Txq+_}(1&j49qT%~JkQTtbQ zN~!kLr>TwM$1-W)fUwTVT4T?p)e!8d3oMH8nWyO54SOQ!fA6&-w1Cm_{Lz;Rt?OT| zl)_=!OH@+zN>&#egCqe1HnW*+GVWd(t(c~Dh7no(-7RU#(((_V`}n;GH4Uk|XY$kO z*t%2(@=Az2Zhd%<;^N>=ssK8ah!&&2cjep;U&-5^PFioU?DRFo6}kCTyMd6XSk?Pr z;3k#p(;gzt1L7tp=0`cqX40Hb;VhH0|2>uwjsgq9OOI9P7j`5<`9D)Q-Iq;y(0aE( zNa`x$dlDaNzb}qZCeKEf&^grUbYEx$vfu4(N@`~|I-qqaBZ8E${B^PYJ=R)Dezg1&Fn z@z+b?>qIchp5m2bY)hvqe$JMmmSr%sT<>^;R4-&ahyP;2l=B*-qEKor>v+RIqISXS zY%c8OBk8h(Em8aG_@~ZIUJPsMr;xdFm*p5n;-C$~b_#}aow!h#?X2I&lu}touDLzG zGQAQXqof}3Kh96IBJ7RkV+#R5kv^V;{F`t5u2qj* z>ql+|ziWSg_`uA$JPF4aOljlKFWaI{QSqS_6I^~^Z_t#x1lRuj;b{}zpgw2zYv~^j zI(*#^XZ!H)%3&W*wLgA&XRe^6(@YwI_bQy^V{X&nmvB$dq~d}DEvDt!A03$GQIZDU zM>^{VflP<__0xXeCn(*+*6pQE827#hMeKi5OW!Q{2Fr)S||H6#2Q7Q+g*xKR<@eXR)NWP zpFA$P4tdxsT&kr>U7Xk}7;7wn)J0C-+ zUXTp`qs~?nljF-?y`|OZ`3-+^ZMZv7p3eL0A}mTpxuNG_;a=2v&yIx&a816kq}N9- zyMVkSrTHeFAJ`TtPww*K-+rCK&0Qv#=CqvoK7+q~hG}XHxbc2?PsL*>qMHP;Q+AR} z#8493ZgEh3oYj6I`Kdw3?S;~9eySE&ZhP}6Fl63u<3`$J?=r@ub_A4UNa}Opd~KYR zqwF!>5Gc;7tE^@;!K)?iXl|VBTD#~f>9AO=f4g2*SHr(a2@mGaG~=2~&8^L~GN)y; z6kvOsw;6a}cCy>LU){E>Y^FlPdur0si}1P(`CWKv=!gkh05B0UAzn=OMUaH8P)8hU z=&K#?`Aaq-dL{e3^(?n9qE%i4SPGa}K#&9-1iMs}_B^4xMU^&Ds=1T)HUGNVSt`i> zuNy+s&TUqQ*L`%;*<$eq`@`20Hy|jbREpyTN6XjFE7Hcu`@!F9st4y8 zMR@YqM5$MNZ&(6vHnnBVXDu0BLl)>q7K-lehbcMv1V_rw+ewgw$PlGlR$xo*Lg5GMXIGJckzkGXhXVwQRDE0(l)`S 
znhg0wOQw2OWy|5yXUoQ^#aW`t_EC0YP6hM@ZEE`-qc6D{v>%l!9yuPpqREd5-U=Eh zv9~a9H*e6Ib)c1@#h+=~Ov~GQMYhPSIQRWsjXtb2J~@u^#EUEkUn6)&u|-+r&|i4& zh1$ZkX^wJfb;wK%VS_PdsTY(XzNA8?K8J0euh%Vzo75+*Ev{K#UiW$37lLwCVaBu+ zibmu#8Z{qkSVWPad?at+Dn=QB#eVI^P!-rXjRiV!mFsVBH8Ke$qmY=$3i9W~!W@HDPt$x?PrwBP_u7wQy4 z2>Xfr8xRV6U42BM%=4Wv_vAcw3@dP;$=b`=UTk*&pG88#y3?`-$D_`#@6z;tB z*lT+~I2{P0@_APD4UYmj_$W82nW4cXw&~h}gU-Zx!iGc1vOhq_VVuYj$s4Eh5(k0i zcv{NAS?$-lRB383zEjIW@g}+IZ~Kp-L@e=j{iz?LB@*Kq2@6~%BRdrRKQ@Nsq!+Mz zu~gzH90Dt8Ce!-dPzm?u?t~ z)q8kcC51?P{^}N)IUy0}NupJtLC5i-%aG{WC$tRYbSCx~xUAx!0U%_ zFQ;&q8XmjWx7HSCHl66en@s;}ug4~>(JI!%BX17*mZ&<}o2MP6(p%j2|Tq%~_O zV`RdWshFj>#b6;xEN=}a0rr|OaWw_iGUjnHCS2?jLrf$nWSR2=Yt77e>WmBCByFv+ zAES*Ljh>X}|4D96w-t_%Z5_Ygcmv-sFlW=9n>F;%_eg3Dr7*MrPox@7L*w6fRo4W< zzDQrXO|_k*F2$7Wp~wuK4!L$JHJ-!?4%Wl}!x0w2hQ*Zri}M?BHmb%)UM>KpJS7*8 z^q$>9pAYlD)lxKduwTmU`{pIrk=@3dxXR!+$aZMZqIqka`YJnC%Yt?*<;{XSAGd@x%EztI9-^FV)3P|5k;IGLunt zdaMqubST_eq1J^abZ-+1x|E9)U-~et@2rAv8_P%r5*@@ve&z-%AWUYwlZy}U=ky@x0ZT!B?|2#7G z`51`1!$XWWk5{GmKH^;UFthI?o=cb1fz;U5t;nH?459~@W?et*+Gg9dg?(#EgD%<- zjJTsbOpEQJXSZWphI6<@LHGF!H*;n7S&!iF!|l%xFat^AKd4>~j!dyCibP00d>lxl z5gAEfaHl$mmRAVGnI=iF?uZ?z)`N@mW?63Tyf2TzePUIGqSY^qiQ^^BQpn)us=(xt zwj(1dB`d99Y6y{=EK@O{)BRuXW`kFns>-S_{Vt4ZhonfzEF<98sGzZ}KAXEzI@Cge zNBz-^(}4ZQZI&RWsd zF%7I?iBXDiT%#;gUI)+uL_-pT(#mTgkS(YHFD(7_(pyH}mM*!0z=be-8VVg^>**4O_ z<3iz}&{*uh+$z7#6(?>Pj6kBlu%FlC|0bGQr#IeMAH&cin-?<`C`td2w-kcu7Ls^) zhj5ggtisOvy`zyrWvdsPFiwHpA%b9HkYML{Yga(=*RRe&d?G(WQ$a;F8k&KZlz<$p z4!LNSsBzYo^T&ofBdv*#Y4DJmdPX<3rQ{d^*j14?({AQnOkZ~U_FQVjpE!lY`QN>~ zCwD?s#}vx+0l{BArWHUf6ijODrD7{GyW_3fveSm_onRFvGa8#SGAqol>wzOEUzBqa z012}W83>Luzbw$vfu>^}_xmJ{V{*;=fsiH_(OrbPa%Syu%(PlWm(bxI%g&;(uY-_fbl z>fw`V=hc5;(PwZFj1@2BlT(B$Gi3}dNPxfF})L-TXp8eQ-=i}CIE4sx_wtmp{kR>4gpRgRYpK7l_%)j^A z6xt=rqv_W&^rfUIWn-f*GgRO%;B7owHxPrI8LT+Mf0g~#txs)@>I>QgfbEOypx7h1 z{VZkc2^|WHtSZfudnuiy%35{cz&d7_xBQXL9si)c@Mw*yPi>pC_vvl+mrM^t&AYWm z-L{qmk*rL2OA$q}rKxk5_fan6r&P5ri~LWCE=AUcydRG0=Pe(Osp_({vi)lM%F+t= z#E6P7d0zZ3?pgkIYnm}91dI&=Kh!f5gx89g0@1(C$F?F(^*3sQ{rDdg!+wwS-|pm zaxY(MuYLJR$oaY}I!49nT|~(Cr)J!&MgIsBOPqu#a4Gq6Ao1C~TA5s23=&TY6M)*> zye}k!{cvxr6P0{a7m$N96?aIrl#fX!qwOx3)Ke1gxmTr`8pb*Kr32*S`Cl9`C zKg6j#JrR5K;w2WlW1Hp1%>JE_!~tb-8mRq;d6lcRtg?GHUKf`&R^ApeKS=>)?0O!8 zOeL#Wm7vg5V8wvI85k^M1@*No6{Vx|QUCkOA%AHvTnZk(( zC%ReIKWlbm9I|t!90*Dq9_VfR^Qv9Mj59#9B@$gM&IMJIEcWE7*{XXKXi`+z;Qppo z%JU0uMtoRL3es5wPWA}lMvIw3_FANPzW?cu@IXRH=BO0Lfgn9DBK+hKZ~OWG&8%%A zPN)jP=b4>O_&ix^D64(IpPsA){nveJIYHOe?#XO6@!RJJbdM9TUGHYm9~WF$FBzk}>5=R_rSE^dhO5)wgm&txy6A-fY`__jjn{DNN4 z79LhF4MA-5YIGc(An=DzP)UDsbrs);(~Ce7GFy?037j%WrFzAklJrkA_PeR179J9_ zD9DVLEC_#S%p2{Ep`EcogCG=d5|IqGEix<&>Npm_06}wvi!FoeSrrPEBMRBRj$`tLR)kc2Hw4v$14gEA{L2 zr^ubESy!*xq=aVj>NRI{)MnVSmY?&f(Yuy?dG^pyTa%JJmgm4$_pD{MhO^@MsExfP z<0Pv_(5{9$HDp#_UAej>b4p(?BUseEC3~{KBQm44ssRf}9)*FYNcdl<%&Rip)lb9p zN9if1F4&mLN?JsSf3amv;3r0PkBJiyLpz5ps;@Vhp?^9a(}Km?E{4A&ep2Q2A%hkg z!Kl~~ih-iFp+%Qa*a-uxNB5C58Up6fq8z@hE}m03jL7GNtau<;{aQDgnYjq)?I5t2nK^8v!V z^W1p@`*n*hip6z{@qz)K`aklm9kz93)qm~dTCq4oE+UJIBKUj{{9gu{b%amDCe^#2 zJvD#l$n$ggD!bfdGjW(zd=UeHO@)sc*))Ry4npFr*D&`nBfk=IL`GT%q4^?bOI8L8OX}dLT)-#{=!3#d!AWNVJu1lK zqhy8XV$We8qi|u&C~7UN-UJiNNA+Y3RtQNX3JXRLu!gY$*kR-VRu~kZXSNLkZVN1M z3qIxL3)p~p5u!@K;Fk!XlP-`-z21jA^t;k)4d($LI@f7yVw8Hnsn5TB@O_vYl``|o z-+23UfC>h{PeyqMpAUobx+=w;lIEnP);O3bVkiIx0SYw%GV$Xk^|h8O-B|-(@wm9c zoyK3reKqGQ1}ns&xd3V~tg>EAArZ=gKnP+;y2U_jUQm8UdJvmO+i9v}Wi~E^ zje7helkUX^NE-qicI5!3VC*tE>^yvSMK#~ny*pym81R3I60B$dlf(s5Gu!xAxA>^=psY}G`4mG{IG@s8m;Lu$*jdt zXlogadTU|Su5~EpH|qv$L+d6S6p|L_#QFpNAo5S79}-BSZ95TZZQGjWW}``-W=mE6 
z+n7ECsN-?FjtR=e1f+;69StTCm^Y2mX2s;E-vs3-(WzUrRZ(!E_F!^Ix6Hj@391U8 zLl1+YVyoSvkW6m!_?^p{4`gzSbrbpbhza_FD;qIM)R72Pb*6H>xuRnO>uf{TNc=TM zM&QT)d_R6^JtJomllUF0@mcm+*H6nzsK@;|OV0%T?}ZnqhzpnHjh2HuG`Cn?F4+J_Y-~e9ioe z)_bwbs`jkipw8a;=as4I_G{-tJ(}!T?_I(+yQqmWBLFVv;EMi&1i(uqNM&xIeYcDVGEIreg&=Eg*uG3>j~6ve8c_9{X00 zTEVc^m$BV*sut7eA)8>aPmf~uyF=O7xSKFLC>&Tqnb5u(NkFVFI7lzwO&Jwo-0@|-Ox+NbY5ifFK}MC6>oLIy8JncUF%b=QND%;Ei&a-U!ydlKC~ zPaZ-5HXBG!jQZoBxDpf!$d2N{0025Z{q#|Lwr@qDBJq_H{m`poLzN1;@j;gw6oLkD zG-H4Kk86Dhn<&ea{S3}vn2oiNQ7?wt9j!%W+(fIND^ajj5wUtgl(t9cJ z4*R`(Z4h{SIQ29rK#3^S4O0=WC2Zv6njz1i(Ku)jPy;)UiBCEXBX${u7D+05w*iOb z^NQfH-~08ze)W?{FK`E1uoNaP=Uz|{^N;YfeY>`G**SP~+cmG|y{D8MzYk{8BRdN^ zL7j&@R2-;xE%CkxaM}uJv&g%aPnbS@%)9^H%c4qD8Qq_Ci>^i}eK!?ug@`f3SXCq@ z$+?CFK$V&aTyj_h5FSDfH@9>Mmwqd)t_dKP4{kpK1fsz)-0~~8-vE)4B6VX?ZSOm0 z{9cC3&C%g`;6+y4hS2bEHLZVBQ(rVPO;A&+jB-88tm6`S$rv-iDzWrAu&Bb{wgCty z|4K&b;ecKGNGovf!c6eRFHj7oO(%cTV@!ev&0HeWtnNVAamwf5XOZX!5$b ziRslGU5HxMi|BkO{LC0lm_B>{2c7r1F=fBKNw?OQqu7s^$q&Pq7<@eV(h2M&&YG@T z6%A5Abh1KDy`z=3R3J?gUNV}};8?7E9$Ykq2oVT9VsYJzZK zO-fimi_vkUUhg?g&qO7S9uLq!$QluY!AGm_vhp1*M;t$fvIfhEkYlF+fF(jFcnU6t z6r>p-#&R4Kt=X=p5&(>v1vo96^V;#@Ss#^{!)Yu7*E%j`mf%Y- z`1*~77H{@%J?~sEuFJEpQDW@6LO?zmLe$U%qIA2C3=^i?3;@$ebOYt*u_>@AXt_r% zUu7QQ8HG}{ntHTUUk?_YOVEsW7B}sW{0b_XA10Uig`_mJ9W8C> zmgW+S$Ez``xuPM3mIlmH38TGWNo6U#%{Ps$Et;6@dez7iEKdt1e-`Czg>S|%ChhRp`SZ=9{h zl=X{@lc;R;DvfFL9#(V01)k=t8Pji7JfW=WjB9r{zM*_mwV&pEy!OpxFHT=s+4;?m z6+!2H#Ga?oubzUX0#_rGHr;*Y7*e=4kz%j=5ez#9MoD65B2Ub4i>(@$V;|$>9OoYw zAMnMpy%hkg2Sh4OeXSV8Aju9|`4^wZcqypfF>A^Jw9pxWiV0(>pwt9i0E1A`K2(Sc zV%CS}vxUnTRTS0nDr1H-D!Ns~P_Q4SVc)31;o;@O!$v5AoVGow;-wt3*OfNY({dg| zci&3?%y28XtE>|U{eBl0dzxb!WuKk*YPq>ZOEgWtht%=@GMF^JidE;hSVpixD~Ijl zQI^84o{L6PzM*%Q22`YJL1^Dc;bHmRmE$H4etQ;P(B8t$jU9yMVSE`Yuqb32nSg3Z;peA;FrY@lE)@zSV)@TdJpwg}s3PFxS#lgrXp{9SO4L zIia5zpcV&?w7GL%DwDdu^DfdvvI;7ic{AoP_LZTY2T1; zQ*b3`%Q7e!aD8n6jT((c|8?_J+Zp_|Cn!vpIg7 zE{LIO`^(#aDJh;{<~-4y{^Abo{BM8RJT#X!_O$zxOQw@&Nt)@h+nSy0{coP~DX&IT zYWz40woEXLcX*?MV(r*BDN6!;fQ^gc{9EVm)_hbp!G&WGf#19X6ne-BF@&5Fe$fnJ zQh_~Ath!N95JfhY?uBc44QbLE6lXBc&x0*^qK2d`(UKQn--y>nz)^z3RqaM|wRBpkaN2Qx0vA8i!fcy z)P80v1pbjQiCd^z0dk7kb;Q=ILXeQmo4@?t$U*4WxP7-c_8S*ql==PLva#$&JA#Uw zlizIsQI&05B|Mr;s$Ra-VRT;z9=ePI&>P$Ooey{?c1IwnZb%?6A!dSaES3bqfwykZjzq zzgj|ZEK-En%kHzDkjx|u-Ki2kkPzbBm-LVJhL~hM zSe5WfuC;An9}90|75S6&%=7DVi|)tUlRZefQ{bR^#gnauN6_<#qciLMwix%p0vB}Q zFGKQ4^`N5(EL$ISNYq6Y>0F}#tgr{SzKnqtK@Ao3}Jq+`;j#jU>V#v2CHts*{87%9H!TPJ*q z=alIWI`ALwZFv9M2e5r08|wA6Yj{JhvBUZeT=NeaJmYtr+7wt#H5oMyMVbdA$t$Q)}<( zELc&DE)|{Uh|1pLOhr`9IgBrrmg~wj-cc^EBs7FrHFa@NFK&e5hT=OGtNP*qNGf6) zI0`R8XaY5`5d$f+sAkKKso%kCdZI@sacEm`B=~i0NyXGQ- z4o=+Mn>{o3ng_1#M?_CVk?w5~=~V8@xPDkpKi=DzI;G`ruZd@9$Uv7|# z&8o@1!H0b;nQS_pH1-C2>-H-`yrmA?RqZt|f3h5~_h!D7YWMwoCI6}Gpr6TN>c6~h zLg6d8=)Gkxj%`6n!z-B!pAM0-Dv>#E9n0CYL(a4xeJQeUCqC}Jd#VZFlb&nN$5_mf z%hlezwPht|izUZJMDhSgaIQ+gc+}OOsO(0HBNrOqMp9UoF#_)3fM`2n1H(RVqkY?jBh&wsX>-g5NVD%S+6}-%551Uq!_Xu$~8pOT&G@*pl&*^&kL>t-h zFQX*)ew8&52qF+``EVS(Oot7y#yzEA2R$c%YN)aUoB;W7-be}u%SsJ`o$j4Z zde89*&zT{SjB?dBht8zgxJ>I*m*|6kdEIV6F1Yc1-VZ9BEXA0^IChOp5+af~KX!(1 zNk)NMS((siy;{?igY}22J!#Y1XKih8iStLQBMN^Nshh}y_(XRY?jSn)&;k+~+(0$x z6k~{JQi*spBMqGz1yqQIQ{=g{s5_$$Q0^+O+V>q}&Me%fb@kf$Ezl1!>Y8nse%I-LqoV3`&`8kgq-?HsdshB&~@|OwN8QY2Z@aoiulzo5wXQUK;F*#&#t*Toi|g8-tQ<=XcT4z*kfAK21v5SS?GH~@VSd_c zMEC&k&a<_WaC4{D86`9*#ZBem^AHYJj(1$+){xxgaT)FnCm-xX}f zUp~PrM{iIAMbB}6xD=_-sqm}lu+s*WC2oCkd4wNQ03p@n=V)eB>0MMHP+GJMGwazY zv^%&*N-CC&SwL1)Sxe|dN5{|#Baz7UGYCZB=@TUK3|Uf9;|*ghvl+zLrX71{aHApt40d{wd$~YHMRMw#6u~CUM?M#R3N0i; zbd-$ezq&eaENIkbLms5pctS$gwj1{vdR!skfS8BUz0bU4LQ 
zNR@x)LgU*Vq6v!{5ub;E@kR@vW~zm`-$pm9U*{z8#u)^)nk!K;t_1p5kK4TMh**?j ziD=n1^_}_#F#S3YD1y~7q0%@|%>m!WyR){tT7d{NVcg=hVKaR5_vT<(NfZ2rTkr#h zuz$hk%OmQUH0r?3$b1@0l`uxtSzfu^IjtqBS;J(%rQ3!}j)fJUJ}T?WDZ>tNc_Qx* z?D3guf3estKg(A5(k_#!G(}jqkh;2sx_(V<&vg=777O<`eFo#XXp4azA3ly~M`r)s zwS0OsRR4tvNtwDOJM^+g*4(O@O$q3E^K_ZpHhmbs|9F^IBT|jcHRL+kU;{Kxdl0L(>(p={X1RjohI}_z`3?3b`lLUka;3`vBI?rVQ~5O z{8@ln%l)G{-FT-n2u|etzjePNffR-86H@%+?kc>aa%&O=^}9^QT3W6E^oE=Ebbk&+ zz9jEIEZ%_#8DX;RGSzCplNL*c=;=|m1^+?mxw4W-$M$)dDMv4NO6?q?Aef9A=OUE6 zLKGkoBWYA9FP#>MTQJ30#1$0f&*JLF#ePt6RT+E+OS~ost@i=V%0FxXkmeq14?mkG zN2-tEyT(n}!YD!Miv4h>SA za4~+k_Hr?_o|{{T1!177>zgmT{yV1gSkbF|-u6?^yc_5>tH9CCZoqxb7%P=*i^3t$N4Jx8Huj(%lO3u5tNQw}s&f7v8sg+#y2hfgI3ZG)*}g_v*pcj1KRK z-PJ1n`UfSgP&M7?fi%+8Qdhace0FWv=Yw55E7yM9O&IFzqod71_`8DIy#6+@ep}(E zCrxE9+BeHDFZ6i6j(wnqtpy>SO{E}|M14h&-`s48_5-A(pg4q}mOqTv0ww513X36? zZ;`yxCnfsq!zX7dL)_pT9pt{W$9Fmj${CzTgnFF~Jk3EyT$OcnJAR`Y;_^c8w+RM( zoV&6vKb}}2Rm7c&trqHc)5ClbFdQdaCKYFS(bZfnF_gTgc)fx!7fdx)7N)}?r6!NU z3u0R|ggzm?D*&la?vZ$*bTD!fAMf#KsGro z-$Dr&Ko|p11BPXb>5z+4)T^8Jvf(~7S+$l(X=>)2&ry|Kkc8RRUm1|$+>5E*%_Zi| zQK`x_*JH|NCB=D};m%<>kV>V6@9$W~3)(xYEte$>gfmZ~fu>lUVnedhK{!>U_gEss zo;E{2EZfyOv&c7#^#S8sp@$!T@v;(syw0{Vsy2^R)t~{wtI<41**q@!_D%)UTg2#~ zGaG>x&i1+59wR2^C$0b1?k_Ljf?aw(91zdJR6Qj{L{mMWrc^{+gZw0f_oH|W@XKVw zc?3xb#iHp?mWC`>mfLJ+QPx{p0*>#X_F4v574hEr#5$+=d<9aMJE4<k5Jblr*?*#rj zt?T@wt-^u1P9#5Vb&l+^IXtS@wJjlnKXFqzmg};**Sl(;q&RxBm#^V-Jau!O%Qn#E z>jUlZ3Zq%hs9G{;LMTbZ!DR;Ghczt@AAo|Ez@zZzlF~gu5dIE<4I|DL4`KDEgQF@^a#p}WjLW>6s^?!y#LwU|dgdkOjMz40|Lfm_P6 z{OMJs!b)I?%O9bcAf;jo+0odoU>v;R82+~nMie1_;`XUyAVP!4mCi^|bE~L-M2G02 zc{}QpK+65#3dczex*xwfxa`w$cVGl&ms>hy$K&c$N$u7tWw#rq8qmpm1zIAJDotrf1rYr@V6{aY8n0*|NShC4baOF)|2tc<9 z2b5PJ3xbI)u6$zgM~36o+Z2dRWM3bRCD=lFbwIajQ{hYB@nFfu6Dzq{g{Wd%%ca~` zx`-Q#z%!G@xOjX}?3U^m)-N*yQZDd2@w)N#z2$S`2B&F;L1bSH?^6ww!1VXb%RCVTDmz{HRGJ1SB!VWGJzCJf- zQhl~weHbwIMR>tCd!bw~K7~DtSi%&@!u;1Hx778%C4|a`PM5CneJ$oxkY{tI z!q1Rd%bY>(KD`eJ!lAnq{`&0#zhy^6SY9F$_Ih!y_Kbum1U*p67IYKL_y|ZATT@M# z2Wsa{T8`r2&>t^8s}Bwi(Jnhuw~f+!#4b=cFQ2Hhk8o~4HcgMb9plflMatvJ&c7D( z{TJ)L34~IV+@gWk!rL>abgn+ld-=$v!k-tv-ZW98H)Ij3>#Z^;$O6SewJWUOZ3`L?>dj*of!*H$3nA4J5ei-=>!yEDjF z&6D@ueOx$u=-QR58VenI_`_c5Mov1&KJ%^Urb7f$2w?J!FT^I5kup|cdYZ@i7(Qwn zt^~phCDvj<_DvK~G|S1TbF-vdT07n0=@)gFI1> z?+q;@A@{DcyHYahCf^uNj}}?gGp4e3H+BT9e5)QAK!XIM5Hn{e{_g#?okwnY$jjU7 z@27;$ij`S#2kp?@4A1p%cONjhPXQDHSUqwH8qRYC`plj*uH`FY)}B1;N_AJi$bl|_d|Nns{AL+<)O5dAZP z{Qy5(G}&r9tj4mE2`pMNKAHMiPE8^O+zif6EORPLHc5rMaE^OC7c`1aa6+rc>$`TB zOt;2Gi%zrcigojUQ{-wKeE%MFzIOoD{mCtx@~V5vweHR2Fq|Z$K5I^4_xDnSg2>Zf zEP89Dogq9-;vH_Aiq7q;k8ZopsVQnNeMqH$v0ivM+h{&8KEl3vpgU&=_;TuH#z_|& z>OhTSry@Y+MMZ?;a;Pv$426rM6v+1JyLImrZ4Uw;2S+84IyfXNf?YXAFdDoAfbAZoC#>0c7K(cV|T|HkFOU2Vs`QRil6IUz%-C%X5HK@FEdj3g7Esfa<8RuHgtX9z!a zpd(vSs4_xSbMj@=9LB*b_q=@mkamr&nLBwX;0^>+E4ScMILKzt)+s}rNdb<+?KLE) zaKS|5vmxy8^PB@s{pdls%Dp0rZ&$yqRsGKGvv7F)SYd5<`Sr}$_xGEv9lU|;lz@XE zR<7RkOds7WCYj(6RRMI0i2C*;O2>PkZDF;Prv@e@5qN3ij(ShFB}mSt=EYkzw|Rq~ zKZxA+)1L^>Dg+!#?sBH){sVcBX)!Y< z#tUd#uf_0G$Xd7v{P-)f>gvhKE6Gpl3iGkeG=@UYhmoQvWHTC6^>JYGWZ+zK0aJ&a zVP1LC^uAWRexfT*qV`_1TUfZXB*HQRyGPSvb56qjHsM|V^5B9i?Z1Dgo!^?G_U3MX zs#`wdjQlJS<>u^aq=CL5Zx0`+%ar=??YK(HfKmThnIC0J!|4~fFzpcz^ObYOv3HlU z&Un$~1RRM}GwCDkJVTew4eA3IPJ#f5Bbj(Dy|i=>Jfiei2Yg9KXc}PkKKpaM@82&NzgVTq02(8_^Ic&BLoiL z15;5S?t{!|;^(D4C713~2};WE0A5Ga>D$Q(nRkw&rLzv5J(mD%J2evM&m3$2<(Yh# ze|%OEuf@Xc%TBk+pTXFpl+nyV4ALNcM*^e0oKp^$vg%CJ`U?`eWAn`H*<`cLbcOjW zd#q<$rMa`Fva7pY^QjB_u|APzeD;cSmG#fbv=@dcY$_;e!$vvexk!p2rZK&TJ%!{^ zo`$7TZiOKS{_#Cg{Pn9ZxGcCczWe)v%pp+5(nT%fmZ?g!V5h?jtTg#_Aj6mTf)! 
zCw4bvmNgzL4J?*=c6N7$H1XzH(YH?~{m6&*$ED@zEsKmU{7sZFx&76(tmiuuqw(_o zxco7F>&bc3O5<&#FV_MO_sXJ`1zo|Roz6mvKSb!+js$7<)iJr=XrH}BG%i0g@M~#8 zvg=AU%zlFTIY1+dL=LQ8p=Kn@~WlpnZx-RTJcY#VPYjYvPSSq6O)8wlJ|56|Ort>}REgCsinyfSCW$%P3dcnPy} zj4Y$x73S8yEcs{i4xiin9lwpvsuj^HMan8&AtiIR`x-0dlAN(DfzOm`X$WW5Jw>J>lJ<^_5*ZOqPZCq8N^o9X_B;|8>T;(56S(hOd0?@N zQyKRN_6&3e&Rmo#tiWl<&QnsJJ@XY0!k|W^4DL+^iR>kjNqG9ZAV((xaKvqfXLWJn z3j4-?zQK1D?qF-5XJ5|OVW&y#5+vN@5d%r z%XrB8oE*k|@xI+3=p3U9oA(}OYZ|K@zTX@t%!H)JRtY8)4^}7X8k2sG7CXB?>q_L7)HsQPoZBD7EX2L9{b%zr zB<8Q#UEDZ4bdIIV{#1Q?mkSHoShmFOR-D(YiwNW|YY!0=BbmwA6=~dMuDFaeE=&q? z!V1$s6K2 zB~KTTrEEMw;cHvFWsj?y?kg=n@}ab+W8`q3`_XAYU8tBl=i%Yj-;4u$ z6I!qAVRag1QL%!1x%&2+1#kY@ybn3=w&=D`+N_Z>@Eq&(|J9|2+A*QiQx!$Jk7N^)}+D|x(Ji0 zg+@F-be7KlwYQ$R`!MUM>~=eTzoR1O@JDrArxN)Iu{QJh&FRHr{m#2qL9s#IA_$0u z{#YPYkae`oTg@)LV4jPIOepKYuOnUIbZxD;jlGG07s8(#_=Q-W{O226`~qA5{299c zIP+Y*pJpyi-WIi1v+Dl^0k5}Y`P0!2{Jav^1u1y@%nEMqmc4G<{F^oL&9(D1<)3Ef zgWNpr%Qru6TawRycbJYWG6XGfI|8)aJ!4I0Q-elcbokEF$wWrj;Gc`Wyq8YW=bU7T z+)7o|!S7&Mx~W?lY$fUab0Am6Ml$3&Tkm%iK7Bu-kg>tWKALkSQ7le8 zWZ>OAop)wd)}{J%EyPgS{N@;v0*-1Y90KMXq^jE0xOhL zs`W_Y;41eklQun-hmp>UO$=clqcv7gaEFg=bUOw8njCop(YnQz@RvnP7!_nX6iYdkbyvpe6A{69V6|NkcJIYa*e4i4Tm&fWN5 GUjIMKg6h=( literal 0 HcmV?d00001 diff --git a/onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt b/onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt new file mode 100644 index 0000000000000..e3dbef248d0b2 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt @@ -0,0 +1 @@ + the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index 65866fc9827a5..43dedbc394c38 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -314,12 +314,111 @@ stages: pushd /workspace/onnxruntime/python/tools/transformers/ ; \ python3 -m pip install --upgrade pip ; \ pushd models/llama ; \ - python3 -m pip install -r requirements-cuda.txt ; \ + python3 -m pip install -r requirements.txt ; \ popd ; \ python3 -m pip install /ort-artifact/*.whl ; \ + python3 -m pip uninstall -y torch ; \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --input /meta-llama2 --small_gpu ;\ popd ; \ " displayName: 'Run Llama2 to Onnx F16 and parity Test' workingDirectory: $(Build.SourcesDirectory) + +- stage: Whisper_ONNX + dependsOn: + - Build_Onnxruntime_Cuda + jobs: + - job: Whisper_ONNX + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: Onnxruntime-Linux-A10-24G + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: none + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Onnxruntime Artifact' + ArtifactName: 'drop-ort-linux-gpu' + TargetPath: '$(Build.BinariesDirectory)/ort-artifact/' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - template: 
templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu + Context: tools/ci_build/github/linux/docker/ + ScriptName: tools/ci_build/get_docker_image.py + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimepackagestest + UpdateDepsTxt: false + + - task: DownloadPackage@1 + # The model data in artifact is downloaded from openai/whisper-large-v3 in huggingface model hub + # In order to save size, removed .git directory and pickled files, and keep the safetensors model files + displayName: 'Download Whisper Model' + inputs: + packageType: upack + feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' + version: 1.0.0 + definition: 'b583ce7c-1a8f-4099-ae28-5d5f56c478b1' + downloadPath: $(Agent.TempDirectory)/whisper_large_v3 + + - script: | + docker run --rm --gpus all -v $(Build.SourcesDirectory):/workspace \ + -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ + -v $(Agent.TempDirectory)/whisper_large_v3:/whisper_large_v3 \ + onnxruntimepackagestest \ + bash -c ' + set -ex; \ + pushd /workspace/onnxruntime/python/tools/transformers/ ; \ + python3 -m pip install --upgrade pip ; \ + pushd models/whisper ; \ + python3 -m pip install -r requirements.txt ; \ + popd ; \ + python3 -m pip install /ort-artifact/*.whl ; \ + python3 -m pip uninstall -y torch ; \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ + python3 -m models.whisper.convert_to_onnx -m /whisper_large_v3 --output whisperlargev3 --use_external_data_format ; \ + popd ; \ + ' + displayName: 'Convert Whisper Model' + workingDirectory: $(Build.SourcesDirectory) + + - script: | + docker run --rm --gpus all -v $(Build.SourcesDirectory):/workspace \ + -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ + -v $(Agent.TempDirectory)/whisper_large_v3:/whisper_large_v3 \ + onnxruntimepackagestest \ + bash -c ' + set -ex; \ + pushd /workspace/onnxruntime/python/tools/transformers/ ; \ + python3 -m pip install --upgrade pip ; \ + pushd models/whisper ; \ + python3 -m pip install -r requirements.txt ; \ + popd ; \ + python3 -m pip install /ort-artifact/*.whl ; \ + python3 -m pip uninstall -y torch ; \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ + ls whisperlargev3; \ + python3 -m models.whisper.benchmark \ + --benchmark-type ort \ + --audio-path models/whisper/test/1272-141231-0002.mp3 \ + --model-name openai/whisper-large-v3 \ + --ort-model-path /workspace/onnxruntime/python/tools/transformers/whisperlargev3/whisper_large_v3_beamsearch.onnx \ + --precision fp32 \ + --device cuda > ort_output.txt ; \ + cat ort_output.txt ; \ + diff ort_output.txt /workspace/onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt && exit 0 || exit 1 + popd ; \ + ' + displayName: 'Test Whisper ONNX Model' + workingDirectory: $(Build.SourcesDirectory) diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu index 9b9dc9ecae822..c9038afc0954c 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu @@ -16,15 +16,18 @@ ENV DEBIAN_FRONTEND=noninteractive ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG}:${LD_LIBRARY_PATH} RUN apt-get update &&\ - apt-get install -y git bash wget + apt-get install -y git bash wget diffutils # Install python3 RUN apt-get install -y 
--no-install-recommends \
     python3 \
     python3-pip \
     python3-dev \
-    python3-wheel
-
+    python3-wheel
+
+# Install ffmpeg, which couldn't be installed in UBI8
+# https://stackoverflow.com/questions/73597789/how-to-install-ffmpeg-on-ubi-docker-images
+RUN apt-get install -y --no-install-recommends ffmpeg
 
 RUN pip install --upgrade pip

From 430a086f22684ad0020819dc3e7712f36fe9f016 Mon Sep 17 00:00:00 2001
From: Yufeng Li
Date: Sun, 25 Feb 2024 08:50:45 -0800
Subject: [PATCH 060/237] fix memory mapping on Windows (#19623)

### Description
The Windows memory-mapping code casts mapped_offset to DWORD directly, so the
offset is truncated whenever it is larger than 2^32-1. We need to pass the
high part of the offset through dwFileOffsetHigh in this case.

### Motivation and Context
The bug was found from #19450

---
 onnxruntime/core/platform/windows/env.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc
index 0eb34cbfbc9eb..983cc6089bb4c 100644
--- a/onnxruntime/core/platform/windows/env.cc
+++ b/onnxruntime/core/platform/windows/env.cc
@@ -459,8 +459,8 @@ Status WindowsEnv::MapFileIntoMemory(_In_z_ const ORTCHAR_T* file_path,
   void* const mapped_base = MapViewOfFile(file_mapping_handle.get(),
                                           FILE_MAP_READ,
-                                          0,
-                                          static_cast<DWORD>(mapped_offset),
+                                          static_cast<DWORD>((mapped_offset >> 32) & 0xFFFFFFFF),
+                                          static_cast<DWORD>(mapped_offset & 0xFFFFFFFF),
                                           mapped_length);
   GSL_SUPPRESS(r.11)
   mapped_memory =

From a9568935a52b3d51ec802a4ab89ab3852129fc1e Mon Sep 17 00:00:00 2001
From: Sumit Agarwal
Date: Mon, 26 Feb 2024 11:35:13 -0800
Subject: [PATCH 061/237] [DML EP] Enable DML Graph Serialization (#19505)

### Description
This PR adds a feature to serialize every DML EP partition into DML currency
individually for a given model. The feature can be turned on dynamically
through the DML EP option `ep.dml.enable_graph_serialization`.

### Motivation and Context
Useful when the user wants to capture the DML EP-specific partitions into DML
currency, mitigating the dependency on the framework.
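For reference, a minimal sketch of how this option might be toggled from the ONNX Runtime C++ API. Only the config key `ep.dml.enable_graph_serialization` comes from this change; the surrounding session setup, headers, and the `model.onnx` path are illustrative assumptions rather than part of this PR.

```cpp
// Hypothetical usage sketch: enable DML graph serialization for one session.
#include <onnxruntime_cxx_api.h>
#include <dml_provider_factory.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "dml-graph-serialization");
  Ort::SessionOptions session_options;

  // Session-options config entry introduced by this PR; "1" turns serialization on.
  session_options.AddConfigEntry("ep.dml.enable_graph_serialization", "1");

  // Attach the DML execution provider (device 0) so its fused partitions are
  // compiled, and with the flag above also serialized, when the session is created.
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0));

  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);  // assumed model path
  return 0;
}
```

Judging from the diffstat below (command_args_parser.cc and ort_test_session.cc), the same switch also appears to be wired through onnxruntime_perf_test.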
--- .../inc/IWinmlExecutionProvider.h | 7 +- .../DmlExecutionProvider/src/ApiTraits.cpp | 570 +++++++ .../src/DmlGraphDeserialization.cpp | 554 +++++++ .../src/DmlGraphFusionHelper.cpp | 247 ++- .../src/DmlGraphFusionHelper.h | 19 +- .../src/DmlGraphFusionTransformer.cpp | 41 +- .../src/DmlGraphFusionTransformer.h | 4 +- .../src/DmlGraphSerialization.cpp | 580 ++++++++ .../src/DmlRuntimeFusedGraphKernel.cpp | 30 +- .../src/External/DirectMLHelpers/ApiTraits.h | 453 +++++- .../External/DirectMLHelpers/DirectMLSchema.h | 112 +- .../DirectMLHelpers/DmlGraphDesc_generated.h | 788 ++++++++++ .../DirectMLHelpers/DmlGraphDeserialization.h | 14 + .../DirectMLHelpers/DmlGraphSerialization.h | 8 + .../DirectMLHelpers/DmlSerializedGraphDesc.h | 73 + .../DirectMLHelpers/GeneratedSchemaHelpers.h | 92 +- .../DirectMLHelpers/GeneratedSchemaTypes.h | 32 +- .../OperatorFieldTypes_generated.h | 1318 +++++++++++++++++ .../External/DirectMLHelpers/SchemaHelpers.h | 54 +- .../src/GraphDescBuilder.cpp | 404 ++--- .../src/GraphDescBuilder.h | 21 +- .../src/MLOperatorAuthorImpl.cpp | 30 +- .../src/Operators/DmlOperator.cpp | 4 +- .../src/Operators/DmlOperatorAttention.cpp | 2 +- .../src/Operators/DmlOperatorBiasAdd.cpp | 2 +- .../Operators/DmlOperatorBiasSplitGelu.cpp | 2 +- .../DmlOperatorEmbedLayerNormalization.cpp | 2 +- .../src/Operators/DmlOperatorGroupNorm.cpp | 2 +- .../DmlOperatorLayerNormalization.cpp | 2 +- .../Operators/DmlOperatorQLinearConcat.cpp | 2 +- .../Operators/DmlOperatorQLinearSigmoid.cpp | 2 +- .../src/Operators/DmlOperatorQuickGelu.cpp | 2 +- .../Operators/DmlOperatorRotaryEmbedding.cpp | 2 +- .../DmlOperatorSkipLayerNormalization.cpp | 2 +- .../dml/DmlExecutionProvider/src/Utility.h | 141 ++ .../dml/DmlExecutionProvider/src/precomp.h | 7 + .../MLOperatorAuthorPrivate.h | 11 +- .../dml/dml_session_options_config_keys.h | 1 + onnxruntime/core/session/inference_session.cc | 9 +- .../test/perftest/command_args_parser.cc | 1 + onnxruntime/test/perftest/ort_test_session.cc | 10 + 41 files changed, 5203 insertions(+), 454 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index f29cc3afc3cda..88e3dd487d427 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -80,15 +80,10 @@ namespace Windows::AI::MachineLearning::Adapter }; // This is 
the counterpart to the MLOperatorGraphDesc ABI struct which owns its memory and uses containers. - // Either nodesAsOperatorDesc or nodesAsIDMLOperator can have non-zero size. struct DmlGraphNodeCreateInfo { uint32_t nodeCount = 0; - std::vector> nodesAsOperatorDesc; - - // TODO (jeffbloo): Remove this - std::vector> nodesAsIDMLOperator; - + std::vector> nodes; std::vector inputEdges; std::vector outputEdges; std::vector intermediateEdges; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp new file mode 100644 index 0000000000000..bf9800458102b --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp @@ -0,0 +1,570 @@ +//--------------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// This file is automatically generated. Please do not edit it directly. +// To modify this file, edit the schema: dml/Tools/DirectMLSchema.json +// And run this script to regenerate: dml/Tools/GenerateSchema.ps1 +// +// #dml-new-operator-location +//--------------------------------------------------------------------------- + +#pragma once + +#include "precomp.h" + +template +T ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ +#ifndef WAI_BUILD_LINUX + // Clang will instantiate this template even if it isn't used, + // so this static_assert will always fire and break the build. + static_assert(false, "Not implemented for this type"); +#endif +} + +template <> +DML_TENSOR_DATA_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_TENSOR_DATA_TYPE_UNKNOWN", DML_TENSOR_DATA_TYPE_UNKNOWN}, + {"DML_TENSOR_DATA_TYPE_FLOAT32", DML_TENSOR_DATA_TYPE_FLOAT32}, + {"DML_TENSOR_DATA_TYPE_FLOAT16", DML_TENSOR_DATA_TYPE_FLOAT16}, + {"DML_TENSOR_DATA_TYPE_UINT32", DML_TENSOR_DATA_TYPE_UINT32}, + {"DML_TENSOR_DATA_TYPE_UINT16", DML_TENSOR_DATA_TYPE_UINT16}, + {"DML_TENSOR_DATA_TYPE_UINT8", DML_TENSOR_DATA_TYPE_UINT8}, + {"DML_TENSOR_DATA_TYPE_INT32", DML_TENSOR_DATA_TYPE_INT32}, + {"DML_TENSOR_DATA_TYPE_INT16", DML_TENSOR_DATA_TYPE_INT16}, + {"DML_TENSOR_DATA_TYPE_INT8", DML_TENSOR_DATA_TYPE_INT8}, + {"DML_TENSOR_DATA_TYPE_FLOAT64", DML_TENSOR_DATA_TYPE_FLOAT64}, + {"DML_TENSOR_DATA_TYPE_UINT64", DML_TENSOR_DATA_TYPE_UINT64}, + {"DML_TENSOR_DATA_TYPE_INT64", DML_TENSOR_DATA_TYPE_INT64}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_TENSOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_TENSOR_TYPE_INVALID", DML_TENSOR_TYPE_INVALID}, + {"DML_TENSOR_TYPE_BUFFER", DML_TENSOR_TYPE_BUFFER}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_OPERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_OPERATOR_INVALID", DML_OPERATOR_INVALID}, + {"DML_OPERATOR_ELEMENT_WISE_IDENTITY", DML_OPERATOR_ELEMENT_WISE_IDENTITY}, + {"DML_OPERATOR_ELEMENT_WISE_ABS", DML_OPERATOR_ELEMENT_WISE_ABS}, + {"DML_OPERATOR_ELEMENT_WISE_ACOS", DML_OPERATOR_ELEMENT_WISE_ACOS}, + {"DML_OPERATOR_ELEMENT_WISE_ADD", 
DML_OPERATOR_ELEMENT_WISE_ADD}, + {"DML_OPERATOR_ELEMENT_WISE_ASIN", DML_OPERATOR_ELEMENT_WISE_ASIN}, + {"DML_OPERATOR_ELEMENT_WISE_ATAN", DML_OPERATOR_ELEMENT_WISE_ATAN}, + {"DML_OPERATOR_ELEMENT_WISE_CEIL", DML_OPERATOR_ELEMENT_WISE_CEIL}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP", DML_OPERATOR_ELEMENT_WISE_CLIP}, + {"DML_OPERATOR_ELEMENT_WISE_COS", DML_OPERATOR_ELEMENT_WISE_COS}, + {"DML_OPERATOR_ELEMENT_WISE_DIVIDE", DML_OPERATOR_ELEMENT_WISE_DIVIDE}, + {"DML_OPERATOR_ELEMENT_WISE_EXP", DML_OPERATOR_ELEMENT_WISE_EXP}, + {"DML_OPERATOR_ELEMENT_WISE_FLOOR", DML_OPERATOR_ELEMENT_WISE_FLOOR}, + {"DML_OPERATOR_ELEMENT_WISE_LOG", DML_OPERATOR_ELEMENT_WISE_LOG}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND", DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS", DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN", DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN", DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL", DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL", DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT", DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR", DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR", DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR}, + {"DML_OPERATOR_ELEMENT_WISE_MAX", DML_OPERATOR_ELEMENT_WISE_MAX}, + {"DML_OPERATOR_ELEMENT_WISE_MEAN", DML_OPERATOR_ELEMENT_WISE_MEAN}, + {"DML_OPERATOR_ELEMENT_WISE_MIN", DML_OPERATOR_ELEMENT_WISE_MIN}, + {"DML_OPERATOR_ELEMENT_WISE_MULTIPLY", DML_OPERATOR_ELEMENT_WISE_MULTIPLY}, + {"DML_OPERATOR_ELEMENT_WISE_POW", DML_OPERATOR_ELEMENT_WISE_POW}, + {"DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW", DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW}, + {"DML_OPERATOR_ELEMENT_WISE_RECIP", DML_OPERATOR_ELEMENT_WISE_RECIP}, + {"DML_OPERATOR_ELEMENT_WISE_SIN", DML_OPERATOR_ELEMENT_WISE_SIN}, + {"DML_OPERATOR_ELEMENT_WISE_SQRT", DML_OPERATOR_ELEMENT_WISE_SQRT}, + {"DML_OPERATOR_ELEMENT_WISE_SUBTRACT", DML_OPERATOR_ELEMENT_WISE_SUBTRACT}, + {"DML_OPERATOR_ELEMENT_WISE_TAN", DML_OPERATOR_ELEMENT_WISE_TAN}, + {"DML_OPERATOR_ELEMENT_WISE_THRESHOLD", DML_OPERATOR_ELEMENT_WISE_THRESHOLD}, + {"DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR", DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR}, + {"DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR", DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR}, + {"DML_OPERATOR_ACTIVATION_ELU", DML_OPERATOR_ACTIVATION_ELU}, + {"DML_OPERATOR_ACTIVATION_CELU", DML_OPERATOR_ACTIVATION_CELU}, + {"DML_OPERATOR_ACTIVATION_HARDMAX", DML_OPERATOR_ACTIVATION_HARDMAX}, + {"DML_OPERATOR_ACTIVATION_HARDMAX1", DML_OPERATOR_ACTIVATION_HARDMAX1}, + {"DML_OPERATOR_ACTIVATION_HARD_SIGMOID", DML_OPERATOR_ACTIVATION_HARD_SIGMOID}, + {"DML_OPERATOR_ACTIVATION_IDENTITY", DML_OPERATOR_ACTIVATION_IDENTITY}, + {"DML_OPERATOR_ACTIVATION_LEAKY_RELU", DML_OPERATOR_ACTIVATION_LEAKY_RELU}, + {"DML_OPERATOR_ACTIVATION_LINEAR", DML_OPERATOR_ACTIVATION_LINEAR}, + {"DML_OPERATOR_ACTIVATION_LOG_SOFTMAX", DML_OPERATOR_ACTIVATION_LOG_SOFTMAX}, + {"DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1", DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1}, + {"DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU", DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU}, + {"DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS", DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS}, + {"DML_OPERATOR_ACTIVATION_RELU", 
DML_OPERATOR_ACTIVATION_RELU}, + {"DML_OPERATOR_ACTIVATION_SCALED_ELU", DML_OPERATOR_ACTIVATION_SCALED_ELU}, + {"DML_OPERATOR_ACTIVATION_SCALED_TANH", DML_OPERATOR_ACTIVATION_SCALED_TANH}, + {"DML_OPERATOR_ACTIVATION_SIGMOID", DML_OPERATOR_ACTIVATION_SIGMOID}, + {"DML_OPERATOR_ACTIVATION_SOFTMAX", DML_OPERATOR_ACTIVATION_SOFTMAX}, + {"DML_OPERATOR_ACTIVATION_SOFTMAX1", DML_OPERATOR_ACTIVATION_SOFTMAX1}, + {"DML_OPERATOR_ACTIVATION_SOFTPLUS", DML_OPERATOR_ACTIVATION_SOFTPLUS}, + {"DML_OPERATOR_ACTIVATION_SOFTSIGN", DML_OPERATOR_ACTIVATION_SOFTSIGN}, + {"DML_OPERATOR_ACTIVATION_TANH", DML_OPERATOR_ACTIVATION_TANH}, + {"DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU", DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU}, + {"DML_OPERATOR_CONVOLUTION", DML_OPERATOR_CONVOLUTION}, + {"DML_OPERATOR_GEMM", DML_OPERATOR_GEMM}, + {"DML_OPERATOR_REDUCE", DML_OPERATOR_REDUCE}, + {"DML_OPERATOR_AVERAGE_POOLING", DML_OPERATOR_AVERAGE_POOLING}, + {"DML_OPERATOR_AVERAGE_POOLING1", DML_OPERATOR_AVERAGE_POOLING1}, + {"DML_OPERATOR_LP_POOLING", DML_OPERATOR_LP_POOLING}, + {"DML_OPERATOR_LP_POOLING1", DML_OPERATOR_LP_POOLING1}, + {"DML_OPERATOR_MAX_POOLING", DML_OPERATOR_MAX_POOLING}, + {"DML_OPERATOR_ROI_POOLING", DML_OPERATOR_ROI_POOLING}, + {"DML_OPERATOR_SLICE", DML_OPERATOR_SLICE}, + {"DML_OPERATOR_CAST", DML_OPERATOR_CAST}, + {"DML_OPERATOR_SPLIT", DML_OPERATOR_SPLIT}, + {"DML_OPERATOR_JOIN", DML_OPERATOR_JOIN}, + {"DML_OPERATOR_PADDING", DML_OPERATOR_PADDING}, + {"DML_OPERATOR_PADDING1", DML_OPERATOR_PADDING1}, + {"DML_OPERATOR_VALUE_SCALE_2D", DML_OPERATOR_VALUE_SCALE_2D}, + {"DML_OPERATOR_UPSAMPLE_2D", DML_OPERATOR_UPSAMPLE_2D}, + {"DML_OPERATOR_GATHER", DML_OPERATOR_GATHER}, + {"DML_OPERATOR_SPACE_TO_DEPTH", DML_OPERATOR_SPACE_TO_DEPTH}, + {"DML_OPERATOR_DEPTH_TO_SPACE", DML_OPERATOR_DEPTH_TO_SPACE}, + {"DML_OPERATOR_TILE", DML_OPERATOR_TILE}, + {"DML_OPERATOR_TOP_K", DML_OPERATOR_TOP_K}, + {"DML_OPERATOR_BATCH_NORMALIZATION", DML_OPERATOR_BATCH_NORMALIZATION}, + {"DML_OPERATOR_BATCH_NORMALIZATION_TRAINING", DML_OPERATOR_BATCH_NORMALIZATION_TRAINING}, + {"DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION", DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION}, + {"DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION", DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION}, + {"DML_OPERATOR_LP_NORMALIZATION", DML_OPERATOR_LP_NORMALIZATION}, + {"DML_OPERATOR_RNN", DML_OPERATOR_RNN}, + {"DML_OPERATOR_LSTM", DML_OPERATOR_LSTM}, + {"DML_OPERATOR_GRU", DML_OPERATOR_GRU}, + {"DML_OPERATOR_ELEMENT_WISE_SIGN", DML_OPERATOR_ELEMENT_WISE_SIGN}, + {"DML_OPERATOR_ELEMENT_WISE_IS_NAN", DML_OPERATOR_ELEMENT_WISE_IS_NAN}, + {"DML_OPERATOR_ELEMENT_WISE_ERF", DML_OPERATOR_ELEMENT_WISE_ERF}, + {"DML_OPERATOR_ELEMENT_WISE_SINH", DML_OPERATOR_ELEMENT_WISE_SINH}, + {"DML_OPERATOR_ELEMENT_WISE_COSH", DML_OPERATOR_ELEMENT_WISE_COSH}, + {"DML_OPERATOR_ELEMENT_WISE_TANH", DML_OPERATOR_ELEMENT_WISE_TANH}, + {"DML_OPERATOR_ELEMENT_WISE_ASINH", DML_OPERATOR_ELEMENT_WISE_ASINH}, + {"DML_OPERATOR_ELEMENT_WISE_ACOSH", DML_OPERATOR_ELEMENT_WISE_ACOSH}, + {"DML_OPERATOR_ELEMENT_WISE_ATANH", DML_OPERATOR_ELEMENT_WISE_ATANH}, + {"DML_OPERATOR_ELEMENT_WISE_IF", DML_OPERATOR_ELEMENT_WISE_IF}, + {"DML_OPERATOR_ELEMENT_WISE_ADD1", DML_OPERATOR_ELEMENT_WISE_ADD1}, + {"DML_OPERATOR_ACTIVATION_SHRINK", DML_OPERATOR_ACTIVATION_SHRINK}, + {"DML_OPERATOR_MAX_POOLING1", DML_OPERATOR_MAX_POOLING1}, + {"DML_OPERATOR_MAX_UNPOOLING", DML_OPERATOR_MAX_UNPOOLING}, + {"DML_OPERATOR_DIAGONAL_MATRIX", DML_OPERATOR_DIAGONAL_MATRIX}, + {"DML_OPERATOR_SCATTER", DML_OPERATOR_SCATTER}, + {"DML_OPERATOR_ONE_HOT", 
DML_OPERATOR_ONE_HOT}, + {"DML_OPERATOR_RESAMPLE", DML_OPERATOR_RESAMPLE}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT", DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT", DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT}, + {"DML_OPERATOR_ELEMENT_WISE_ROUND", DML_OPERATOR_ELEMENT_WISE_ROUND}, + {"DML_OPERATOR_ELEMENT_WISE_IS_INFINITY", DML_OPERATOR_ELEMENT_WISE_IS_INFINITY}, + {"DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE", DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE}, + {"DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR", DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR}, + {"DML_OPERATOR_FILL_VALUE_SEQUENCE", DML_OPERATOR_FILL_VALUE_SEQUENCE}, + {"DML_OPERATOR_FILL_VALUE_CONSTANT", DML_OPERATOR_FILL_VALUE_CONSTANT}, + {"DML_OPERATOR_CUMULATIVE_SUMMATION", DML_OPERATOR_CUMULATIVE_SUMMATION}, + {"DML_OPERATOR_REVERSE_SUBSEQUENCES", DML_OPERATOR_REVERSE_SUBSEQUENCES}, + {"DML_OPERATOR_GATHER_ELEMENTS", DML_OPERATOR_GATHER_ELEMENTS}, + {"DML_OPERATOR_GATHER_ND", DML_OPERATOR_GATHER_ND}, + {"DML_OPERATOR_SCATTER_ND", DML_OPERATOR_SCATTER_ND}, + {"DML_OPERATOR_MAX_POOLING2", DML_OPERATOR_MAX_POOLING2}, + {"DML_OPERATOR_SLICE1", DML_OPERATOR_SLICE1}, + {"DML_OPERATOR_TOP_K1", DML_OPERATOR_TOP_K1}, + {"DML_OPERATOR_DEPTH_TO_SPACE1", DML_OPERATOR_DEPTH_TO_SPACE1}, + {"DML_OPERATOR_SPACE_TO_DEPTH1", DML_OPERATOR_SPACE_TO_DEPTH1}, + {"DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1", DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1}, + {"DML_OPERATOR_RESAMPLE1", DML_OPERATOR_RESAMPLE1}, + {"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER", DML_OPERATOR_MATRIX_MULTIPLY_INTEGER}, + {"DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY", DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY}, + {"DML_OPERATOR_CONVOLUTION_INTEGER", DML_OPERATOR_CONVOLUTION_INTEGER}, + {"DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION", DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_AND", DML_OPERATOR_ELEMENT_WISE_BIT_AND}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_OR", DML_OPERATOR_ELEMENT_WISE_BIT_OR}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_XOR", DML_OPERATOR_ELEMENT_WISE_BIT_XOR}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_NOT", DML_OPERATOR_ELEMENT_WISE_BIT_NOT}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_COUNT", DML_OPERATOR_ELEMENT_WISE_BIT_COUNT}, + {"DML_OPERATOR_ACTIVATION_RELU_GRAD", DML_OPERATOR_ACTIVATION_RELU_GRAD}, + {"DML_OPERATOR_AVERAGE_POOLING_GRAD", DML_OPERATOR_AVERAGE_POOLING_GRAD}, + {"DML_OPERATOR_MAX_POOLING_GRAD", DML_OPERATOR_MAX_POOLING_GRAD}, + {"DML_OPERATOR_RANDOM_GENERATOR", DML_OPERATOR_RANDOM_GENERATOR}, + {"DML_OPERATOR_NONZERO_COORDINATES", DML_OPERATOR_NONZERO_COORDINATES}, + {"DML_OPERATOR_RESAMPLE_GRAD", DML_OPERATOR_RESAMPLE_GRAD}, + {"DML_OPERATOR_SLICE_GRAD", DML_OPERATOR_SLICE_GRAD}, + {"DML_OPERATOR_ADAM_OPTIMIZER", DML_OPERATOR_ADAM_OPTIMIZER}, + {"DML_OPERATOR_ARGMIN", DML_OPERATOR_ARGMIN}, + {"DML_OPERATOR_ARGMAX", DML_OPERATOR_ARGMAX}, + {"DML_OPERATOR_ROI_ALIGN", DML_OPERATOR_ROI_ALIGN}, + {"DML_OPERATOR_GATHER_ND1", DML_OPERATOR_GATHER_ND1}, + {"DML_OPERATOR_ELEMENT_WISE_ATAN_YX", DML_OPERATOR_ELEMENT_WISE_ATAN_YX}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD", DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD}, + {"DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE", DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE}, + {"DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD", DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD}, + {"DML_OPERATOR_CUMULATIVE_PRODUCT", DML_OPERATOR_CUMULATIVE_PRODUCT}, + {"DML_OPERATOR_BATCH_NORMALIZATION_GRAD", DML_OPERATOR_BATCH_NORMALIZATION_GRAD}, + {"DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD", 
DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD}, + {"DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD", DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD}, + {"DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR", DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR}, + {"DML_OPERATOR_ROI_ALIGN1", DML_OPERATOR_ROI_ALIGN1}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP1", DML_OPERATOR_ELEMENT_WISE_CLIP1}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1", DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1}, + {"DML_OPERATOR_ELEMENT_WISE_NEGATE", DML_OPERATOR_ELEMENT_WISE_NEGATE}, + {"DML_OPERATOR_ACTIVATION_GELU", DML_OPERATOR_ACTIVATION_GELU}, + {"DML_OPERATOR_ACTIVATION_SWISH", DML_OPERATOR_ACTIVATION_SWISH}, + {"DML_OPERATOR_ACTIVATION_HARD_SWISH", DML_OPERATOR_ACTIVATION_HARD_SWISH}, + {"DML_OPERATOR_RESAMPLE2", DML_OPERATOR_RESAMPLE2}, + {"DML_OPERATOR_RESAMPLE_GRAD1", DML_OPERATOR_RESAMPLE_GRAD1}, + {"DML_OPERATOR_DIAGONAL_MATRIX1", DML_OPERATOR_DIAGONAL_MATRIX1}, + {"DML_OPERATOR_MULTIHEAD_ATTENTION", DML_OPERATOR_MULTIHEAD_ATTENTION}, + {"DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING}, + {"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_BINDING_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_BINDING_TYPE_NONE", DML_BINDING_TYPE_NONE}, + {"DML_BINDING_TYPE_BUFFER", DML_BINDING_TYPE_BUFFER}, + {"DML_BINDING_TYPE_BUFFER_ARRAY", DML_BINDING_TYPE_BUFFER_ARRAY}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_REDUCE_FUNCTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_REDUCE_FUNCTION_ARGMAX", DML_REDUCE_FUNCTION_ARGMAX}, + {"DML_REDUCE_FUNCTION_ARGMIN", DML_REDUCE_FUNCTION_ARGMIN}, + {"DML_REDUCE_FUNCTION_AVERAGE", DML_REDUCE_FUNCTION_AVERAGE}, + {"DML_REDUCE_FUNCTION_L1", DML_REDUCE_FUNCTION_L1}, + {"DML_REDUCE_FUNCTION_L2", DML_REDUCE_FUNCTION_L2}, + {"DML_REDUCE_FUNCTION_LOG_SUM", DML_REDUCE_FUNCTION_LOG_SUM}, + {"DML_REDUCE_FUNCTION_LOG_SUM_EXP", DML_REDUCE_FUNCTION_LOG_SUM_EXP}, + {"DML_REDUCE_FUNCTION_MAX", DML_REDUCE_FUNCTION_MAX}, + {"DML_REDUCE_FUNCTION_MIN", DML_REDUCE_FUNCTION_MIN}, + {"DML_REDUCE_FUNCTION_MULTIPLY", DML_REDUCE_FUNCTION_MULTIPLY}, + {"DML_REDUCE_FUNCTION_SUM", DML_REDUCE_FUNCTION_SUM}, + {"DML_REDUCE_FUNCTION_SUM_SQUARE", DML_REDUCE_FUNCTION_SUM_SQUARE}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + +template <> +DML_MATRIX_TRANSFORM ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_MATRIX_TRANSFORM_NONE", DML_MATRIX_TRANSFORM_NONE}, + {"DML_MATRIX_TRANSFORM_TRANSPOSE", DML_MATRIX_TRANSFORM_TRANSPOSE}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_CONVOLUTION_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_CONVOLUTION_MODE_CONVOLUTION", 
DML_CONVOLUTION_MODE_CONVOLUTION}, + {"DML_CONVOLUTION_MODE_CROSS_CORRELATION", DML_CONVOLUTION_MODE_CROSS_CORRELATION}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_CONVOLUTION_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_CONVOLUTION_DIRECTION_FORWARD", DML_CONVOLUTION_DIRECTION_FORWARD}, + {"DML_CONVOLUTION_DIRECTION_BACKWARD", DML_CONVOLUTION_DIRECTION_BACKWARD}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + +template <> +DML_PADDING_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_PADDING_MODE_CONSTANT", DML_PADDING_MODE_CONSTANT}, + {"DML_PADDING_MODE_EDGE", DML_PADDING_MODE_EDGE}, + {"DML_PADDING_MODE_REFLECTION", DML_PADDING_MODE_REFLECTION}, + {"DML_PADDING_MODE_SYMMETRIC", DML_PADDING_MODE_SYMMETRIC}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_INTERPOLATION_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR}, + {"DML_INTERPOLATION_MODE_LINEAR", DML_INTERPOLATION_MODE_LINEAR}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_RECURRENT_NETWORK_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_RECURRENT_NETWORK_DIRECTION_FORWARD", DML_RECURRENT_NETWORK_DIRECTION_FORWARD}, + {"DML_RECURRENT_NETWORK_DIRECTION_BACKWARD", DML_RECURRENT_NETWORK_DIRECTION_BACKWARD}, + {"DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL", DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_FEATURE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT", DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT}, + {"DML_FEATURE_FEATURE_LEVELS", DML_FEATURE_FEATURE_LEVELS}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_FEATURE_LEVEL ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_FEATURE_LEVEL_1_0", DML_FEATURE_LEVEL_1_0}, + {"DML_FEATURE_LEVEL_2_0", DML_FEATURE_LEVEL_2_0}, + {"DML_FEATURE_LEVEL_2_1", DML_FEATURE_LEVEL_2_1}, + {"DML_FEATURE_LEVEL_3_0", DML_FEATURE_LEVEL_3_0}, + {"DML_FEATURE_LEVEL_3_1", DML_FEATURE_LEVEL_3_1}, + {"DML_FEATURE_LEVEL_4_0", DML_FEATURE_LEVEL_4_0}, + {"DML_FEATURE_LEVEL_4_1", DML_FEATURE_LEVEL_4_1}, + {"DML_FEATURE_LEVEL_5_0", DML_FEATURE_LEVEL_5_0}, + {"DML_FEATURE_LEVEL_5_1", DML_FEATURE_LEVEL_5_1}, + {"DML_FEATURE_LEVEL_5_2", DML_FEATURE_LEVEL_5_2}, + {"DML_FEATURE_LEVEL_6_0", 
DML_FEATURE_LEVEL_6_0}, + {"DML_FEATURE_LEVEL_6_1", DML_FEATURE_LEVEL_6_1}, + {"DML_FEATURE_LEVEL_6_2", DML_FEATURE_LEVEL_6_2}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_IS_INFINITY_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_IS_INFINITY_MODE_EITHER", DML_IS_INFINITY_MODE_EITHER}, + {"DML_IS_INFINITY_MODE_POSITIVE", DML_IS_INFINITY_MODE_POSITIVE}, + {"DML_IS_INFINITY_MODE_NEGATIVE", DML_IS_INFINITY_MODE_NEGATIVE}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_DEPTH_SPACE_ORDER ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW", DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW}, + {"DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH", DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_AXIS_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_AXIS_DIRECTION_INCREASING", DML_AXIS_DIRECTION_INCREASING}, + {"DML_AXIS_DIRECTION_DECREASING", DML_AXIS_DIRECTION_DECREASING}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_ROUNDING_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN", DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN}, + {"DML_ROUNDING_MODE_TOWARD_ZERO", DML_ROUNDING_MODE_TOWARD_ZERO}, + {"DML_ROUNDING_MODE_TOWARD_INFINITY", DML_ROUNDING_MODE_TOWARD_INFINITY}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_RANDOM_GENERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10", DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_MULTIHEAD_ATTENTION_MASK_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE", DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN", DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); 
+ } + return static_cast(*index); +} + diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp new file mode 100644 index 0000000000000..7d8ed17e7d925 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp @@ -0,0 +1,554 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "precomp.h" + +OperatorFieldVariant CreateAttribute( + const DML_SCHEMA_FIELD* schemaField, + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc); + +OperatorFieldVariant CreateActivation( + const dml::ir::operatorFieldTypes::Activation* activationDesc) +{ + DML_OPERATOR_TYPE activationOperatorType = ApiTraits::StringifyHelpers::FromString(activationDesc->type()->c_str()); + const DML_OPERATOR_SCHEMA& activationSchema = SchemaHelpers::GetSchema(activationOperatorType); + std::vector activationOperatorFields(activationSchema.FieldCount); + uint32_t attributeIndex = 0; + + for (uint32_t fieldIndex = 0; fieldIndex < activationSchema.FieldCount; fieldIndex++) + { + const DML_SCHEMA_FIELD* schemaField = &activationSchema.Fields[fieldIndex]; + OperatorFieldVariant field; + switch (schemaField->Kind) + { + case DML_SCHEMA_FIELD_KIND_INPUT_TENSOR: + case DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR: + { + if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC) + { + field = OperatorFieldTypes::TensorDesc(); + } + else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY) + { + field = OperatorFieldTypes::TensorDescArray(); + } + break; + } + case DML_SCHEMA_FIELD_KIND_ATTRIBUTE: + { + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc = + attributeIndex >= activationDesc->attributes()->size() ? + nullptr : + activationDesc->attributes()->Get(attributeIndex++); + field = CreateAttribute(schemaField, attributeDesc); + break; + } + } + + activationOperatorFields[fieldIndex] = OperatorField(schemaField, std::move(field)); + } + + return AbstractOperatorDesc(&activationSchema, std::move(activationOperatorFields)); +} + +OperatorFieldVariant CreateActivations( + const dml::ir::operatorFieldTypes::ActivationArray* activationDescs) +{ + std::vector activations; + for (uint32_t index = 0; index < static_cast(activationDescs->data()->size()); index++) + { + OperatorFieldVariant activation = CreateActivation(activationDescs->data()->Get(index)); + activations.push_back(std::get(activation).value()); + } + return activations; +} + +OperatorFieldVariant CreateAttribute( + const DML_SCHEMA_FIELD* schemaField, + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc) +{ + switch (schemaField->Type) + { + case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC: + { + return attributeDesc != nullptr && attributeDesc->val_as_Activation() != nullptr ? + CreateActivation(attributeDesc->val_as_Activation()) : + OperatorFieldTypes::FusedActivationOperatorDesc(); + } + case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY: + { + return attributeDesc != nullptr && attributeDesc->val_as_ActivationArray() != nullptr ? 
+ CreateActivations(attributeDesc->val_as_ActivationArray()) : + OperatorFieldTypes::FusedActivationOperatorDescArray(); + } + case DML_SCHEMA_FIELD_TYPE_UINT: + { + OperatorFieldTypes::UInt data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_UInt32()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_UINT64: + { + OperatorFieldTypes::UInt64 data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_UInt64()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_INT: + { + OperatorFieldTypes::Int data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_Int32()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_FLOAT: + { + OperatorFieldTypes::Float data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_Float32()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY: + { + OperatorFieldTypes::UIntArray data; + if (attributeDesc != nullptr) + { + data.assign(attributeDesc->val_as_UIntArray()->data()->begin(), attributeDesc->val_as_UIntArray()->data()->end()); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_INT_ARRAY: + { + OperatorFieldTypes::IntArray data; + if (attributeDesc != nullptr) + { + data.assign(attributeDesc->val_as_IntArray()->data()->begin(), attributeDesc->val_as_IntArray()->data()->end()); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY: + { + OperatorFieldTypes::FloatArray data; + if (attributeDesc != nullptr) + { + data.assign(attributeDesc->val_as_FloatArray()->data()->begin(), attributeDesc->val_as_FloatArray()->data()->end()); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_SCALE_BIAS: + { + OperatorFieldTypes::ScaleBias scaleBias; + const dml::ir::operatorFieldTypes::ScaleBias* scaleBiasAttribute = attributeDesc->val_as_ScaleBias(); + if (scaleBiasAttribute != nullptr) + { + scaleBias = {scaleBiasAttribute->scale(), scaleBiasAttribute->bias()}; + } + return scaleBias; + } + case DML_SCHEMA_FIELD_TYPE_SIZE_2D: + { + OperatorFieldTypes::Size2D size2d = {}; + if (attributeDesc != nullptr) + { + size2d.Height = attributeDesc->val_as_Size2D()->height(); + size2d.Width = attributeDesc->val_as_Size2D()->width(); + } + return size2d; + } + case DML_SCHEMA_FIELD_TYPE_SCALAR_UNION: + { + DML_SCALAR_UNION scalarUnion; + if (attributeDesc != nullptr) + { + const dml::ir::operatorFieldTypes::ByteArray* byteArr = attributeDesc->val_as_ScalarUnionData()->data_as_ByteArray(); + std::copy(byteArr->data()->begin(), byteArr->data()->end(), scalarUnion.Bytes); + } + return scalarUnion; + } + case DML_SCHEMA_FIELD_TYPE_BOOL: + { + OperatorFieldTypes::Bool data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_Bool()->data(); + } + return data; + } + default: + { + throw std::invalid_argument("Invalid attribute type."); + } + } +} + +OperatorFieldTypes::TensorDesc CreateBufferTensorDesc( + const dml::ir::DmlBufferTensorDesc* tensorDesc, + const bool isConstantTensor = false) +{ + DmlBufferTensorDesc bufferTensorDesc = {}; + bufferTensorDesc.dataType = ApiTraits::StringifyHelpers::FromString(tensorDesc->dataType()->c_str()); + if (isConstantTensor) + { + bufferTensorDesc.flags = DML_TENSOR_FLAG_OWNED_BY_DML; + } + bufferTensorDesc.sizes.assign(tensorDesc->sizes()->begin(), tensorDesc->sizes()->end()); + if (flatbuffers::IsFieldPresent(tensorDesc, dml::ir::DmlBufferTensorDesc::VT_STRIDES)) + { + bufferTensorDesc.strides.emplace(tensorDesc->strides()->begin(), tensorDesc->strides()->end()); + } + bufferTensorDesc.totalTensorSizeInBytes = 
tensorDesc->totalTensorSizeInBytes(); + return bufferTensorDesc; +} + +AbstractOperatorDesc CreateAbstractOperatorDesc( + uint32_t nodeIndex, + const dml::ir::OperatorNodeDesc* flatbufferOperatorNodeDesc, + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* nodeInputNames, + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* nodeOutputNames, + const std::unordered_set& constantInputs) +{ + DML_OPERATOR_TYPE type = ApiTraits::StringifyHelpers::FromString(flatbufferOperatorNodeDesc->type()->c_str()); + if (type == DML_OPERATOR_INVALID) + { + throw std::invalid_argument("Graph operator node at index:" + std::to_string(nodeIndex) + + " either has empty or invalid operator type."); + } + const DML_OPERATOR_SCHEMA& schema = SchemaHelpers::GetSchema(type); + std::vector operatorFields(schema.FieldCount); + + auto inputNameItr = nodeInputNames->begin(); + uint32_t inputTensorDescIndex = 0; + + uint32_t outputTensorDescIndex = 0; + auto outputNameItr = nodeOutputNames->begin(); + + uint32_t attributeIndex = 0; + + + for (uint32_t fieldIndex = 0; fieldIndex < schema.FieldCount; fieldIndex++) + { + const DML_SCHEMA_FIELD* schemaField = &schema.Fields[fieldIndex]; + + OperatorFieldVariant field; + switch (schemaField->Kind) + { + case DML_SCHEMA_FIELD_KIND_INPUT_TENSOR: + { + if (inputNameItr == nodeInputNames->end()) + { + throw std::invalid_argument("Missing input names for node at index:" + std::to_string(nodeIndex)); + } + + if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC) + { + const flatbuffers::String* inputName = *inputNameItr; + inputNameItr++; + if (inputName->size() == 0) + { + field = OperatorFieldTypes::TensorDesc(); + break; + } + bool isConstantTensor = !constantInputs.empty() && constantInputs.find(inputName->c_str()) != constantInputs.end(); + + if (flatbufferOperatorNodeDesc->inputs()->size() <= inputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(inputTensorDescIndex + 1) + + "input tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->inputs()->Get(inputTensorDescIndex++); + field = CreateBufferTensorDesc(tensorDesc, isConstantTensor); + } + else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY) + { + std::vector tensors; + while (inputTensorDescIndex < static_cast(flatbufferOperatorNodeDesc->inputs()->size())) + { + const flatbuffers::String* inputName = *inputNameItr; + inputNameItr++; + bool isConstantTensor = !constantInputs.empty() && constantInputs.find(inputName->c_str()) != constantInputs.end(); + + if (flatbufferOperatorNodeDesc->inputs()->size() <= inputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(inputTensorDescIndex + 1) + + "input tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->inputs()->Get(inputTensorDescIndex++); + tensors.push_back(CreateBufferTensorDesc(tensorDesc, isConstantTensor).value()); + } + field = tensors; + } + break; + } + case DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR: + { + if (outputNameItr == nodeOutputNames->end()) + { + throw std::invalid_argument("Missing output names for node at index:" + std::to_string(nodeIndex)); + } + + if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC) + { + const flatbuffers::String* outputName = *outputNameItr; + outputNameItr++; + + if 
(outputName->size() == 0) + { + field = OperatorFieldTypes::TensorDesc(); + break; + } + + if (flatbufferOperatorNodeDesc->outputs()->size() <= outputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(outputTensorDescIndex + 1) + + "output tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->outputs()->Get(outputTensorDescIndex++); + field = CreateBufferTensorDesc(tensorDesc); + } + else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY) + { + std::vector tensors; + while (outputTensorDescIndex < static_cast(flatbufferOperatorNodeDesc->outputs()->size())) + { + if (flatbufferOperatorNodeDesc->outputs()->size() <= outputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(outputTensorDescIndex + 1) + + "output tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->outputs()->Get(outputTensorDescIndex++); + tensors.push_back(CreateBufferTensorDesc(tensorDesc).value()); + } + field = tensors; + } + break; + } + case DML_SCHEMA_FIELD_KIND_ATTRIBUTE: + { + if (flatbufferOperatorNodeDesc->attributes()->size() <= attributeIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(attributeIndex + 1) + + "attributes for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc = + attributeIndex >= flatbufferOperatorNodeDesc->attributes()->size() ? + nullptr : + flatbufferOperatorNodeDesc->attributes()->Get(attributeIndex++); + field = CreateAttribute(schemaField, attributeDesc); + break; + } + } + + operatorFields[fieldIndex] = OperatorField(schemaField, std::move(field)); + } + + return AbstractOperatorDesc(&schema, std::move(operatorFields)); +} + +std::unordered_map ConvertToEdgeNameToIndexMap( + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* list) +{ + std::unordered_map nameToIndexMap; + for (uint32_t index = 0; index < list->size(); index++) + { + const flatbuffers::String* name = list->GetAsString(index); + if (name->size() == 0) + { + continue; + } + nameToIndexMap[name->string_view()] = index; + } + return nameToIndexMap; // NRVO will automatically move it. 
no need to use std::move +} + +template void PopulateEdges( + const uint32_t nodeIndex, + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* edgeNames, + const std::unordered_map& edgeNameToIndexMap, + /*out*/ std::vector& edges, + /*out*/ std::vector& intermediateEdges, + /*out*/ std::unordered_map& edgeToOutgoingNodeIndexMap) +{ + for (flatbuffers::uoffset_t edgeIndex = 0; edgeIndex < edgeNames->size(); edgeIndex++) + { + const flatbuffers::String* edgeName = edgeNames->Get(edgeIndex); + if (edgeName->size() == 0) + { + // This must be optional input/output + continue; + } + // edge can be graphInput or graphOutput + if (edgeNameToIndexMap.find(edgeName->string_view()) != edgeNameToIndexMap.end()) + { + EdgeType edge = {}; + edge.Name = edgeName->str(); + + if constexpr (std::is_same_v) + { + edge.GraphInputIndex = edgeNameToIndexMap.at(edgeName->string_view()); + edge.ToNodeIndex = nodeIndex; + edge.ToNodeInputIndex = edgeIndex; + } + else if constexpr (std::is_same_v) + { + edge.GraphOutputIndex = edgeNameToIndexMap.at(edgeName->string_view()); + edge.FromNodeIndex = nodeIndex; + edge.FromNodeOutputIndex = edgeIndex; + edgeToOutgoingNodeIndexMap[edgeName->string_view()] = {nodeIndex, edgeIndex}; + } + + edges.push_back(edge); + } + // edge is intermediate edge + else + { + if constexpr (std::is_same_v) + { + if (edgeToOutgoingNodeIndexMap.find(edgeName->string_view()) == edgeToOutgoingNodeIndexMap.end()) + { + throw std::range_error("Neither there is any graph input with name " + edgeName->str() + + "nor there is any node which has " + edgeName->str() + " as one of the output."); + } + auto& intermediateEdgeNodeIndex = edgeToOutgoingNodeIndexMap[edgeName->string_view()]; + DmlIntermediateSerializedGraphEdge intermediateEdge = {}; + intermediateEdge.Name = edgeName->str(); + intermediateEdge.FromNodeIndex = intermediateEdgeNodeIndex.nodeIndex; + intermediateEdge.FromNodeOutputIndex = intermediateEdgeNodeIndex.nodeOutputIndex; + intermediateEdge.ToNodeIndex = nodeIndex; + intermediateEdge.ToNodeInputIndex = edgeIndex; + intermediateEdges.push_back(std::move(intermediateEdge)); + } + else if constexpr (std::is_same_v) + { + edgeToOutgoingNodeIndexMap[edgeName->string_view()] = {nodeIndex, edgeIndex}; + } + } + } +} + +/* +* - Handling of empty optional input/output/attibute for non-constant node: +* input/output +* - and will have an null entry +* but the actual OperatorNodeDesc variant's +* and will not have any entry. 
+* attribute +* - will have null entry +*/ +DmlSerializedGraphDesc DeserializeDmlGraph( + const uint8_t* flatbufferGraphDescBlob, + /*out*/ std::vector>& rawData) +{ + if (flatbufferGraphDescBlob == nullptr) + { + throw std::invalid_argument("Given pointer to flatbuffer blob is null"); + } + const dml::ir::DmlGraphDesc* flatbufferGraphDesc = dml::ir::GetDmlGraphDesc(flatbufferGraphDescBlob); + + std::unordered_map graphInputEdgeToIndexMap = ConvertToEdgeNameToIndexMap(flatbufferGraphDesc->graphInputNames()); + std::unordered_map graphOutputEdgeToIndexMap = ConvertToEdgeNameToIndexMap(flatbufferGraphDesc->graphOutputNames()); + + std::unordered_map edgeToOutgoingNodeIndexMap; + std::unordered_set constantInputs; + + std::vector nodes(flatbufferGraphDesc->nodes()->size()); + std::vector inputEdges; + std::vector outputEdges; + std::vector intermediateEdges; + + for (uint32_t nodeIndex = 0; nodeIndex < flatbufferGraphDesc->nodes()->size(); nodeIndex++) + { + const dml::ir::DmlGraphNode* flatbufferNode = flatbufferGraphDesc->nodes()->Get(nodeIndex); + + PopulateEdges( + nodeIndex, + flatbufferNode->inputNames(), + graphInputEdgeToIndexMap, + inputEdges, + intermediateEdges, + edgeToOutgoingNodeIndexMap); + PopulateEdges( + nodeIndex, + flatbufferNode->outputNames(), + graphOutputEdgeToIndexMap, + outputEdges, + intermediateEdges, + edgeToOutgoingNodeIndexMap); + + DmlSerializedGraphNode node = {}; + if (flatbufferNode->name()->size() == 0) + { + throw std::invalid_argument("Graph node at index:" + std::to_string(nodeIndex) + " doesn't have any name"); + } + node.Name = flatbufferNode->name()->c_str(); + + if (flatbufferNode->desc_type() == dml::ir::NodeDesc_ConstantNodeDesc) + { + const dml::ir::ConstantNodeDesc* flatbufferConstantNode = flatbufferNode->desc_as_ConstantNodeDesc(); + if (flatbufferConstantNode->data_type() == dml::ir::ConstantNodeDescDetail_ConstantName) + { + if (flatbufferConstantNode->data_as_ConstantName()->name()->size() == 0) + { + throw std::invalid_argument("Constant node at index:" + std::to_string(nodeIndex) + + " doesn't have constant data name."); + } + + ConstantName constantNode = {flatbufferConstantNode->data_as_ConstantName()->name()->c_str()}; + node.Desc = constantNode; + // output of this node will part of constantInputs list + for (uint32_t outputIndex = 0; outputIndex < flatbufferNode->outputNames()->size(); outputIndex++) + { + constantInputs.insert(flatbufferNode->outputNames()->Get(outputIndex)->c_str()); + } + } + else if (flatbufferConstantNode->data_type() == dml::ir::ConstantNodeDescDetail_ConstantRawData) + { + + uint32_t rawDataSize = flatbufferConstantNode->data_as_ConstantRawData()->data()->size(); + rawData.push_back(std::make_unique(rawDataSize)); + std::transform( + flatbufferConstantNode->data_as_ConstantRawData()->data()->begin(), + flatbufferConstantNode->data_as_ConstantRawData()->data()->end(), + rawData.back().get(), + [](uint8_t b) {return static_cast(b);}); + + ConstantData constantData = {}; + constantData.dataSize = rawDataSize; + constantData.data = rawData.back().get(); + node.Desc = constantData; + } + + + } + else if (flatbufferNode->desc_type() == dml::ir::NodeDesc::NodeDesc_OperatorNodeDesc) + { + // convert dml::ir::OperatorNodeDesc to AbstractOperatorDesc + const dml::ir::OperatorNodeDesc* flatbufferOperatorNodeDesc = flatbufferNode->desc_as_OperatorNodeDesc(); + node.Desc = CreateAbstractOperatorDesc( + nodeIndex, + flatbufferOperatorNodeDesc, + flatbufferNode->inputNames(), + flatbufferNode->outputNames(), + 
constantInputs); + } + + nodes[nodeIndex] = node; + } + + DmlSerializedGraphDesc graphDesc; + graphDesc.InputCount = flatbufferGraphDesc->graphInputNames()->size(); + graphDesc.OutputCount = flatbufferGraphDesc->graphOutputNames()->size(); + graphDesc.InputEdges = std::move(inputEdges); + graphDesc.IntermediateEdges = std::move(intermediateEdges); + graphDesc.OutputEdges = std::move(outputEdges); + graphDesc.Nodes = std::move(nodes); + return graphDesc; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 642d9aa03eeef..202b762d99e01 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -135,8 +135,10 @@ namespace DmlGraphFusionHelper void ProcessInputData( const ExecutionProviderImpl* providerImpl, + const bool graphSerializationEnabled, const std::vector& isInputsUploadedByDmlEP, - const std::vector& inputEdges, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex, const gsl::span subGraphInputArgNames, const std::unordered_map>& initializerNameToInitializerMap, onnxruntime::Graph& graph, @@ -162,8 +164,17 @@ namespace DmlGraphFusionHelper // Walk through each graph edge and mark used inputs inputsUsed.assign(fusedNodeInputCount, false); - for (const DML_INPUT_GRAPH_EDGE_DESC& edge : inputEdges) { - inputsUsed[edge.GraphInputIndex] = true; + for (auto it = serializedGraphInputIndexToSubgraphInputIndex->begin(); it != serializedGraphInputIndexToSubgraphInputIndex->end(); it++) { + inputsUsed[it->second] = true; + } + for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex->begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex->end(); it++) { + inputsUsed[it->second] = true; + } + + std::wstring modelName; + if (graphSerializationEnabled) + { + modelName = GetModelName(graph.ModelPath()); } for (uint32_t i = 0; i < initInputBindings.size(); i++) @@ -209,6 +220,10 @@ namespace DmlGraphFusionHelper // Tensor sizes in DML must be a multiple of 4 bytes large. 
tensorByteSize = AlignToPow2(tensorByteSize, 4); + if(graphSerializationEnabled) + { + WriteToFile(modelName, ConvertToWString(iter->first) + L".bin", reinterpret_cast(tensorPtr), tensorByteSize); + } if (inputRawData) { @@ -287,55 +302,158 @@ namespace DmlGraphFusionHelper return initializerPartitionMap; } + inline uint32_t GetConstantNodeGraphInputIndex( + const std::string& constantName, + const std::unordered_map* serializedGraphConstantNameToMainGraphInputIndex, + uint32_t& graphMaxInputIndex, + std::unordered_map& localConstantNameToIndexMap) + { + if (serializedGraphConstantNameToMainGraphInputIndex == nullptr) + { + if (localConstantNameToIndexMap.find(constantName) == localConstantNameToIndexMap.end()) + { + localConstantNameToIndexMap[constantName] = ++graphMaxInputIndex; + } + return localConstantNameToIndexMap[constantName]; + } + else + { + graphMaxInputIndex = std::max(graphMaxInputIndex, serializedGraphConstantNameToMainGraphInputIndex->at(constantName)); + return serializedGraphConstantNameToMainGraphInputIndex->at(constantName); + } + } + + template void ConvertGraphDesc( const Dml::GraphDescBuilder::GraphDesc& graphDesc, - _Out_ DML_GRAPH_DESC& dmlGraphDesc, const uint32_t inputCount, const uint32_t outputCount, - _Inout_ std::vector& dmlOperatorGraphNodes, - _Inout_ std::vector& dmlConstantGraphNodes, + IDMLDevice* device, + StackAllocator& allocator, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex, + _Out_ DML_GRAPH_DESC& dmlGraphDesc, + _Inout_ std::vector>& dmlOperators, _Inout_ std::vector& dmlGraphNodes, _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, _Inout_ std::vector& dmlIntermediateEdges) { - for (size_t i = 0; i < graphDesc.nodes.size(); ++i) + std::unordered_map oldNodeIndexToNewNodeIndexMap; + for (uint32_t index = 0; index < static_cast(graphDesc.Nodes.size()); index++) { - auto& nodeInfo = graphDesc.nodes[i]; - - if (std::holds_alternative>(nodeInfo.nodeDef)) + const DmlSerializedGraphNode& node = graphDesc.Nodes[index]; + if (std::holds_alternative(node.Desc)) { - dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()}; - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + oldNodeIndexToNewNodeIndexMap[index] = static_cast(dmlGraphNodes.size()); + DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(std::get(node.Desc), &allocator); + ComPtr op; + ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op))); + dmlOperators.push_back(op); + DML_OPERATOR_GRAPH_NODE_DESC* dmlOperatorGraphNode = allocator.template Allocate(); + dmlOperatorGraphNode->Name = node.Name.data(); + dmlOperatorGraphNode->Operator = op.Get(); + dmlGraphNodes.push_back(DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, dmlOperatorGraphNode}); } else { - auto& nodeDefinitionData = std::get>(nodeInfo.nodeDef); - dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC{ - nodeDefinitionData.data(), - nodeDefinitionData.size(), - nodeInfo.name.data() - }; - - // TODO: Change as new header is ingested - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{static_cast(2), &dmlConstantGraphNodes[i]}; + auto& constantNodeVariant = std::get(node.Desc); + if (std::holds_alternative(constantNodeVariant)) + { + oldNodeIndexToNewNodeIndexMap[index] = static_cast(dmlGraphNodes.size()); + + auto& constantData = std::get(constantNodeVariant); + + 
DML_CONSTANT_DATA_GRAPH_NODE_DESC* constantNode = allocator.template Allocate(); + constantNode->Name = node.Name.data(); + constantNode->DataSize = constantData.dataSize; + constantNode->Data = constantData.data; + dmlGraphNodes.push_back(DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_CONSTANT, constantNode}); + } } } - for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i) + uint32_t graphMaxInputIndex = 0; + + for (size_t i = 0; i < graphDesc.InputEdges.size(); ++i) { - dmlInputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i]}; + DML_INPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + // 1. If serializedGraphInputIndexToMainGraphInputIndex is not null: + // then use the corresponding main graph input index, because the caller will use corresponding + // main graph input index for extracting the actual input tensor from the main graph and + // the caller does not own the creation of dml bindings directly. + // Use Case: When the caller is ORT (DML EP) or DmlEngine. + // + // 2. If serializedGraphInputIndexToMainGraphInputIndex is null: + // then assign the sequential graph input index, because it owns the creation of dml bindings + // directly. + edge->GraphInputIndex = serializedGraphInputIndexToSubgraphInputIndex == nullptr ? + graphDesc.InputEdges[i].GraphInputIndex : + serializedGraphInputIndexToSubgraphInputIndex->at(graphDesc.InputEdges[i].GraphInputIndex); + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.InputEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.InputEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.InputEdges[i].Name.data(); + + graphMaxInputIndex = std::max(graphMaxInputIndex, edge->GraphInputIndex); + dmlInputEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, edge}); } - for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i) + for (size_t i = 0; i < graphDesc.OutputEdges.size(); ++i) { - dmlOutputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i]}; + DML_OUTPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->GraphOutputIndex = graphDesc.OutputEdges[i].GraphOutputIndex; + edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.OutputEdges[i].FromNodeIndex]; + edge->FromNodeOutputIndex = graphDesc.OutputEdges[i].FromNodeOutputIndex; + edge->Name = graphDesc.OutputEdges[i].Name.data(); + + dmlOutputEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, edge}); } - for (size_t i = 0; i < graphDesc.intermediateEdges.size(); ++i) + std::unordered_map localConstantNameToIndexMap; + for (uint32_t i = 0; i < static_cast(graphDesc.IntermediateEdges.size()); ++i) { - dmlIntermediateEdges[i] = - DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i]}; + DmlSerializedGraphNodeDescVariant descVariant = graphDesc.Nodes[graphDesc.IntermediateEdges[i].FromNodeIndex].Desc; + bool isConstantEdge = std::holds_alternative(descVariant); + if (isConstantEdge) + { + auto& constantNodeVariant = std::get(descVariant); + if (std::holds_alternative(constantNodeVariant)) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].FromNodeIndex]; + edge->FromNodeOutputIndex = graphDesc.IntermediateEdges[i].FromNodeOutputIndex; + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex; + edge->Name = 
graphDesc.IntermediateEdges[i].Name.data(); + dmlIntermediateEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, edge}); + } + else + { + const std::string& constantName = graphDesc.Nodes[graphDesc.IntermediateEdges[i].FromNodeIndex].Name; + + DML_INPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->GraphInputIndex = GetConstantNodeGraphInputIndex( + constantName, + serializedGraphLargeConstantNameToSubgraphInputIndex, + graphMaxInputIndex, + localConstantNameToIndexMap); + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.IntermediateEdges[i].Name.data(); + + dmlInputEdges.push_back({DML_GRAPH_EDGE_TYPE_INPUT, edge}); + } + } + else + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].FromNodeIndex]; + edge->FromNodeOutputIndex = graphDesc.IntermediateEdges[i].FromNodeOutputIndex; + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.IntermediateEdges[i].Name.data(); + dmlIntermediateEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, edge}); + } } dmlGraphDesc.InputCount = inputCount; @@ -400,27 +518,34 @@ namespace DmlGraphFusionHelper Microsoft::WRL::ComPtr TryCreateCompiledOperator( const GraphDescBuilder::GraphDesc& graphDesc, const onnxruntime::IndexedSubGraph& indexedSubGraph, - const ExecutionProviderImpl* providerImpl) + const ExecutionProviderImpl* providerImpl, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex) { const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); const uint32_t fusedNodeOutputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->outputs.size()); // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator - DML_GRAPH_DESC dmlGraphDesc = {}; - std::vector dmlOperatorGraphNodes(graphDesc.nodes.size()); - std::vector dmlConstantGraphNodes(graphDesc.nodes.size()); + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); - std::vector dmlGraphNodes(graphDesc.nodes.size()); - std::vector dmlInputEdges(graphDesc.inputEdges.size()); - std::vector dmlOutputEdges(graphDesc.outputEdges.size()); - std::vector dmlIntermediateEdges(graphDesc.intermediateEdges.size()); + StackAllocator<1024> allocator; + DML_GRAPH_DESC dmlGraphDesc = {}; + std::vector> dmlOperators; + std::vector dmlGraphNodes; + std::vector dmlInputEdges; + std::vector dmlOutputEdges; + std::vector dmlIntermediateEdges; ConvertGraphDesc( graphDesc, - dmlGraphDesc, fusedNodeInputCount, fusedNodeOutputCount, - dmlOperatorGraphNodes, - dmlConstantGraphNodes, + device.Get(), + allocator, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, + dmlGraphDesc, + dmlOperators, dmlGraphNodes, dmlInputEdges, dmlOutputEdges, @@ -438,8 +563,6 @@ namespace DmlGraphFusionHelper executionFlags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS; } - ComPtr device; - ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); ComPtr device1; ORT_THROW_IF_FAILED(device.As(&device1)); @@ -460,6 +583,7 @@ namespace 
DmlGraphFusionHelper } void FusePartitionAndRegisterKernel( + const uint32_t partitionIndex, onnxruntime::Graph& graph, onnxruntime::KernelRegistry* registryForPartitionKernels, const std::unordered_map>& initializerNameToInitializerMap, @@ -467,8 +591,43 @@ namespace DmlGraphFusionHelper const onnxruntime::IndexedSubGraph& indexedSubGraph, std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, - Microsoft::WRL::ComPtr compiledExecutionPlanOperator) + Microsoft::WRL::ComPtr compiledExecutionPlanOperator, + const bool graphSerializationEnabled, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex) { + if (graphSerializationEnabled) + { + + const std::wstring modelName = GetModelName(graph.ModelPath()); + auto buffer = SerializeDmlGraph(graphDesc); + + const std::wstring partitionName = + L"Partition_" + + std::to_wstring(partitionIndex) + + L".bin"; + WriteToFile(modelName, partitionName, buffer.data(), buffer.size()); + + std::vector> rawData; + DmlSerializedGraphDesc deserializedGraphDesc = DeserializeDmlGraph(buffer.data(), rawData); + GraphDescBuilder::GraphDesc deserializedDmlGraphDesc = {}; + deserializedDmlGraphDesc.InputCount = deserializedGraphDesc.InputCount; + deserializedDmlGraphDesc.InputEdges = std::move(deserializedGraphDesc.InputEdges); + deserializedDmlGraphDesc.IntermediateEdges = std::move(deserializedGraphDesc.IntermediateEdges); + deserializedDmlGraphDesc.Nodes = std::move(deserializedGraphDesc.Nodes); + deserializedDmlGraphDesc.OutputCount = deserializedGraphDesc.OutputCount; + deserializedDmlGraphDesc.OutputEdges = std::move(deserializedGraphDesc.OutputEdges); + deserializedDmlGraphDesc.reuseCommandList = graphDesc.reuseCommandList; + deserializedDmlGraphDesc.outputShapes = graphDesc.outputShapes; + + compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( + deserializedDmlGraphDesc, + indexedSubGraph, + providerImpl, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex); + } + auto& fusedNode = graph.BeginFuseSubGraph(indexedSubGraph, indexedSubGraph.GetMetaDef()->name); fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); @@ -482,8 +641,10 @@ namespace DmlGraphFusionHelper std::vector inputsUsed; ProcessInputData( providerImpl, + graphSerializationEnabled, isInputsUploadedByDmlEP, - graphDesc.inputEdges, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, indexedSubGraph.GetMetaDef()->inputs, initializerNameToInitializerMap, graph, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index f8f6162aaa1e0..f1e9654021196 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -45,12 +45,17 @@ namespace DmlGraphFusionHelper gsl::span> partitions ); + template void ConvertGraphDesc( const Dml::GraphDescBuilder::GraphDesc& graphDesc, - _Out_ DML_GRAPH_DESC& dmlGraphDesc, const uint32_t inputCount, const uint32_t outputCount, - _Inout_ std::vector& dmlOperatorGraphNodes, + IDMLDevice* device, + StackAllocator& allocator, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* 
serializedGraphLargeConstantNameToSubgraphInputIndex, + _Out_ DML_GRAPH_DESC& dmlGraphDesc, + _Inout_ std::vector>& dmlOperators, _Inout_ std::vector& dmlGraphNodes, _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, @@ -69,9 +74,12 @@ namespace DmlGraphFusionHelper Microsoft::WRL::ComPtr TryCreateCompiledOperator( const GraphDescBuilder::GraphDesc& graphDesc, const onnxruntime::IndexedSubGraph& indexedSubGraph, - const ExecutionProviderImpl* providerImpl); + const ExecutionProviderImpl* providerImpl, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex); void FusePartitionAndRegisterKernel( + const uint32_t partitionIndex, onnxruntime::Graph& graph, onnxruntime::KernelRegistry* registryForPartitionKernels, const std::unordered_map>& initializerNameToInitializerMap, @@ -79,7 +87,10 @@ namespace DmlGraphFusionHelper const onnxruntime::IndexedSubGraph& indexedSubGraph, std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, - Microsoft::WRL::ComPtr compiledExecutionPlanOperator); + Microsoft::WRL::ComPtr compiledExecutionPlanOperator, + const bool graphSerializationEnabled, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex = nullptr, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex = nullptr); void RegisterDynamicKernel( onnxruntime::Graph& graph, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 679738b639ec9..35a2c451a49a5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -24,15 +24,20 @@ namespace Dml std::vector isInputsUploadedByDmlEP; GraphDescBuilder::GraphDesc graphDesc; std::unordered_map> isInitializerTransferable; + std::vector> smallConstantData; // Need to keep it alive for maintaining lifetime + std::unordered_map serializedGraphInputIndexToSubgraphInputIndex; + std::unordered_map serializedGraphLargeConstantNameToSubgraphInputIndex; }; } DmlGraphFusionTransformer::DmlGraphFusionTransformer( const std::string& name, - const onnxruntime::IExecutionProvider* provider + const onnxruntime::IExecutionProvider* provider, + const bool graphSerializationEnabled ) :onnxruntime::GraphTransformer(name), - m_providerImpl(static_cast(provider)->GetImpl()) + m_providerImpl(static_cast(provider)->GetImpl()), + graphSerializationEnabled(graphSerializationEnabled) { } @@ -227,23 +232,39 @@ namespace Dml ComPtr device; ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf())); + // This map will be used to transfer the initializer to D3D12 system heap memory. + // 'serializedDmlGraphDesc' will have constant input as intermediate edges, that's why + // we need a mapping between intermediateEdgeIndex and indexedSubGraph's (a given partition) + // input arg index. + // For ex: Let's say intermediate edge index = idx, then + // indexedSubGraphInputArgIdx = constantEdgeIdxToSubgraphInputArgIdxMap[idx]; + // corresponding constant tensor = initializerNameToInitializerMap[indexedSubGraph.GetMetaDef()->inputs[indexedSubGraphInputArgIdx]] + // We are using intermediate edge index as a key because same constant tensor can be used by + // multiple nodes. 
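As an illustrative aside to the comment above (this is a sketch only, with simplified map types and a hypothetical helper name; it is not part of the patch): the two remapping tables declared next are meant to be consulted like this when resolving a serialized-graph input back to the subgraph input argument that owns the actual initializer tensor.

// Minimal, self-contained sketch. ResolveSubgraphInputArg is a hypothetical
// helper introduced here only to show how the two maps would be used.
#include <string>
#include <unordered_map>

inline uint32_t ResolveSubgraphInputArg(
    const std::unordered_map<uint32_t, uint32_t>& serializedInputIndexToSubgraphInput,
    const std::unordered_map<std::string, uint32_t>& largeConstantNameToSubgraphInput,
    uint32_t serializedInputIndex,
    const std::string* largeConstantName /* nullptr for an ordinary graph input */)
{
    if (largeConstantName != nullptr)
    {
        // Large constants are keyed by name because the same constant tensor
        // can feed multiple nodes in the serialized graph.
        return largeConstantNameToSubgraphInput.at(*largeConstantName);
    }
    // Ordinary inputs are remapped by their serialized-graph input index.
    return serializedInputIndexToSubgraphInput.at(serializedInputIndex);
}

The returned index can then be used to look up the initializer by name via the subgraph's input list, as the comment above describes.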
+ std::unordered_map serializedGraphInputIndexToSubgraphInputIndex; + std::unordered_map serializedGraphLargeConstantNameToSubgraphInputIndex; + std::vector> smallConstantData; GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), isInitializerTransferable, partitionNodePropsMap, - device.Get(), m_providerImpl, modelPath, subgraphNodes, subgraphInputs, - subgraphOutputs); + subgraphOutputs, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, + smallConstantData); // Compile the operator auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator( graphDesc, indexedSubGraph, - m_providerImpl); + m_providerImpl, + &serializedGraphInputIndexToSubgraphInputIndex, + &serializedGraphLargeConstantNameToSubgraphInputIndex); if (!compiledPartition) { @@ -264,6 +285,9 @@ namespace Dml compiledPartitionInfo->isInputsUploadedByDmlEP = std::move(isInputsUploadedByDmlEP); compiledPartitionInfo->graphDesc = std::move(graphDesc); compiledPartitionInfo->isInitializerTransferable = std::move(isInitializerTransferable); + compiledPartitionInfo->smallConstantData = std::move(smallConstantData); + compiledPartitionInfo->serializedGraphInputIndexToSubgraphInputIndex = std::move(serializedGraphInputIndexToSubgraphInputIndex); + compiledPartitionInfo->serializedGraphLargeConstantNameToSubgraphInputIndex = std::move(serializedGraphLargeConstantNameToSubgraphInputIndex); compiledPartitionInfos[partitionIndex] = std::move(compiledPartitionInfo); } } @@ -271,12 +295,14 @@ namespace Dml } while (!additionalSplittingNodes.empty()); + uint32_t partitionIndex = 0; for (auto&& compiledPartitionInfo : compiledPartitionInfos) { // Null compiled operators were not DML partitions if (compiledPartitionInfo) { DmlGraphFusionHelper::FusePartitionAndRegisterKernel( + partitionIndex++, graph, m_providerImpl->GetKernelRegistry().get(), compiledPartitionInfo->isInitializerTransferable, @@ -284,7 +310,10 @@ namespace Dml compiledPartitionInfo->indexedSubGraph, std::move(compiledPartitionInfo->isInputsUploadedByDmlEP), compiledPartitionInfo->graphDesc, - compiledPartitionInfo->compiledOperator); + compiledPartitionInfo->compiledOperator, + graphSerializationEnabled, + &compiledPartitionInfo->serializedGraphInputIndexToSubgraphInputIndex, + &compiledPartitionInfo->serializedGraphLargeConstantNameToSubgraphInputIndex); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h index 19dab0c89943c..b370f3ef9043c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h @@ -16,7 +16,8 @@ class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer public: DmlGraphFusionTransformer( const std::string& name, - const onnxruntime::IExecutionProvider* provider + const onnxruntime::IExecutionProvider* provider, + const bool graphSerializationEnabled ); public: @@ -38,5 +39,6 @@ class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer private: const ExecutionProviderImpl* m_providerImpl = nullptr; + const bool graphSerializationEnabled = false; }; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp new file mode 100644 index 0000000000000..5355964e8db74 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp @@ -0,0 +1,580 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "precomp.h" + +template +T* ReadAs(uint8_t* base, size_t byteOffset) +{ + return reinterpret_cast(base + byteOffset); +} + +void SerializeAttributeDescs( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& operatorDesc, + /*out*/ std::vector>& attributeDescs); + +flatbuffers::Offset serializeActivation( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& activationOperatorDesc) +{ + std::vector> attributeDescs; + SerializeAttributeDescs(builder, activationOperatorDesc, attributeDescs); + + flatbuffers::Offset offset = dml::ir::operatorFieldTypes::CreateActivationDirect( + builder, + activationOperatorDesc.schema->OperatorName, + &attributeDescs); + return offset; +} + +void SerializeAttributeDescs( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& operatorDesc, + /*out*/ std::vector>& attributeDescs) +{ + for (const OperatorField& field : operatorDesc.fields) + { + if (field.GetSchema()->Kind == DML_SCHEMA_FIELD_KIND_INPUT_TENSOR || + field.GetSchema()->Kind == DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR) + { + continue; + } + + flatbuffers::Offset offset; + + if (std::holds_alternative(field.GetData())) + { + const OperatorFieldTypes::FusedActivationOperatorDesc& fusedActivation = field.AsFusedActivationOperatorDesc(); + if (!fusedActivation.has_value()) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation); + } + else + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation, + serializeActivation(builder, fusedActivation.value()).Union()); + } + } + else if (std::holds_alternative(field.GetData())) + { + const OperatorFieldTypes::FusedActivationOperatorDescArray& fusedActivations = + field.AsFusedActivationOperatorDescArray(); + if (!fusedActivations.has_value()) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray); + } + else + { + std::vector> fbActivations; + + for (AbstractOperatorDesc activationOpDesc : fusedActivations.value()) + { + flatbuffers::Offset fbActivation = + serializeActivation(builder, activationOpDesc); + fbActivations.push_back(fbActivation); + } + + flatbuffers::Offset activationOffset = + dml::ir::operatorFieldTypes::CreateActivationArrayDirect(builder, &fbActivations); + + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray, + activationOffset.Union()); + } + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt32, + builder.CreateStruct(dml::ir::operatorFieldTypes::UInt32(field.AsUInt())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, 
+ dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt64, + builder.CreateStruct(dml::ir::operatorFieldTypes::UInt64(field.AsUInt64())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Int32, + builder.CreateStruct(dml::ir::operatorFieldTypes::Int32(field.AsInt())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Float32, + builder.CreateStruct(dml::ir::operatorFieldTypes::Float32(field.AsFloat())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_UIntArray, + dml::ir::operatorFieldTypes::CreateUIntArray(builder, builder.CreateVector(field.AsUIntArray())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_IntArray, + dml::ir::operatorFieldTypes::CreateIntArray(builder, builder.CreateVector(field.AsIntArray())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_FloatArray, + dml::ir::operatorFieldTypes::CreateFloatArray(builder, builder.CreateVector(field.AsFloatArray())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + const OperatorFieldTypes::ScaleBias& scaleBias = field.AsScaleBias(); + if (!scaleBias.has_value()) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias); + } + else + { + dml::ir::operatorFieldTypes::ScaleBias fbScaleBias(scaleBias.value().Scale, scaleBias.value().Bias); + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias, + builder.CreateStruct(fbScaleBias).Union()); + } + } + else if (std::holds_alternative(field.GetData())) + { + const DML_SIZE_2D size2d = field.AsSize2D(); + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Size2D, + builder.CreateStruct(dml::ir::operatorFieldTypes::Size2D(size2d.Width, size2d.Height)).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + OperatorFieldTypes::ScalarUnion scalarUnion = field.AsScalarUnion(); + dml::ir::operatorFieldTypes::ByteArray byteArr; + for (uint32_t index = 0; index < static_cast(sizeof(scalarUnion.Bytes)); index++) + { + byteArr.mutable_data()->Mutate(index, scalarUnion.Bytes[index]); + } + + flatbuffers::Offset scalarUnionOffset = + dml::ir::operatorFieldTypes::CreateScalarUnionData( + builder, + dml::ir::operatorFieldTypes::ScalarVariant_ByteArray, + builder.CreateStruct(byteArr).Union()); + + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ScalarUnionData, + 
scalarUnionOffset.Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Bool, + builder.CreateStruct(dml::ir::operatorFieldTypes::Bool(field.AsBool())).Union()); + } + else + { + continue; + } + + attributeDescs.push_back(offset); + } +} + +flatbuffers::Offset SerializeDmlTensorDesc( + flatbuffers::FlatBufferBuilder& builder, + const DmlBufferTensorDesc* tensorDesc) +{ + const std::vector *strides = nullptr; + if (tensorDesc->strides.has_value()) + { + strides = &tensorDesc->strides.value(); + } + + flatbuffers::Offset offset = dml::ir::CreateDmlBufferTensorDescDirect( + builder, + ApiTraits::StringifyHelpers::ToString(tensorDesc->dataType), + &tensorDesc->sizes, + strides, + tensorDesc->totalTensorSizeInBytes); + return offset; +} + +flatbuffers::Offset SerializeOperatorNodeDesc( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& operatorDesc) +{ + const DML_OPERATOR_SCHEMA* operatorSchema = operatorDesc.schema; + + std::vector> inputTensorDescs; + std::vector> outputTensorDescs; + + for (const DmlBufferTensorDesc* tensorDesc : operatorDesc.GetInputTensors()) + { + if (tensorDesc == nullptr) + { + continue; + } + flatbuffers::Offset serializedDmlTensorDesc = SerializeDmlTensorDesc(builder, tensorDesc); + inputTensorDescs.push_back(serializedDmlTensorDesc); + } + + for (const DmlBufferTensorDesc* tensorDesc : operatorDesc.GetOutputTensors()) + { + if (tensorDesc == nullptr) + { + continue; + } + flatbuffers::Offset serializedDmlTensorDesc = SerializeDmlTensorDesc(builder, tensorDesc); + outputTensorDescs.push_back(serializedDmlTensorDesc); + } + + std::vector> attributeDescs; + SerializeAttributeDescs(builder, operatorDesc, attributeDescs); + + flatbuffers::Offset offset = dml::ir::CreateOperatorNodeDesc( + builder, + builder.CreateString(operatorSchema->OperatorName), + builder.CreateVector(inputTensorDescs), + builder.CreateVector(outputTensorDescs), + builder.CreateVector(attributeDescs)); + return offset.Union(); +} + +flatbuffers::Offset SerializeConstantNodeDesc( + flatbuffers::FlatBufferBuilder& builder, + uint32_t nodeIndex, + const DmlSerializedGraphNodeConstantVariant& constantNodeDesc) +{ + flatbuffers::Offset offset; + + if (std::holds_alternative(constantNodeDesc)) + { + auto& constantName = std::get(constantNodeDesc); + if (constantName.name.empty()) + { + throw std::invalid_argument("Graph constant node at index:" + std::to_string(nodeIndex) + + " doesn't have the constant data name."); + } + + flatbuffers::Offset constantNameOffset = dml::ir::CreateConstantName( + builder, + builder.CreateString(constantName.name)); + + offset = dml::ir::CreateConstantNodeDesc( + builder, + dml::ir::ConstantNodeDescDetail_ConstantName, + constantNameOffset.Union()); + } + else + { + auto& constantData = std::get(constantNodeDesc); + std::vector rawBytes; + std::transform(constantData.data, constantData.data + constantData.dataSize, + std::back_inserter(rawBytes), [](std::byte b) {return static_cast(b); }); + flatbuffers::Offset constantDataOffset = dml::ir::CreateConstantRawDataDirect( + builder, + &rawBytes); + + offset = dml::ir::CreateConstantNodeDesc( + builder, + dml::ir::ConstantNodeDescDetail_ConstantRawData, + constantDataOffset.Union()); + } + + return offset.Union(); +} + +flatbuffers::Offset SerializeNode( + flatbuffers::FlatBufferBuilder& builder, + const uint32_t nodeIndex, + const 
DmlSerializedGraphNode& graphNode, + const std::vector>& nodeInputNames, + const std::vector>& nodeOutputNames) +{ + if (graphNode.Name.empty()) + { + throw std::invalid_argument("Graph node at index:" + std::to_string(nodeIndex) + + " does not have any name."); + } + + flatbuffers::Offset offset; + if (std::holds_alternative(graphNode.Desc)) + { + auto& operatorNode = std::get(graphNode.Desc); + offset = dml::ir::CreateDmlGraphNode( + builder, + dml::ir::NodeDesc_OperatorNodeDesc, + SerializeOperatorNodeDesc(builder, operatorNode), + builder.CreateString(graphNode.Name), + builder.CreateVector(nodeInputNames), + builder.CreateVector(nodeOutputNames)); + } + else + { + auto& constantNodeVariant = std::get(graphNode.Desc); + offset = dml::ir::CreateDmlGraphNode( + builder, + dml::ir::NodeDesc_ConstantNodeDesc, + SerializeConstantNodeDesc(builder, nodeIndex, constantNodeVariant), + builder.CreateString(graphNode.Name), + builder.CreateVector(nodeInputNames), + builder.CreateVector(nodeOutputNames)); + } + return offset; +} + +/* +* validates input/output edges and throws exception if an edge +* does not have a name or if an edge has more than 1 names. +*/ +template +std::unordered_map> ConvertToEdgeIndexToNameMap( + const std::vector& edges, + flatbuffers::FlatBufferBuilder& builder) +{ + std::unordered_map> edgeIndexToNameMap; + for (auto& edge : edges) + { + uint32_t index; + if constexpr (std::is_same_v) + { + index = edge.GraphInputIndex; + } + else if constexpr (std::is_same_v) + { + index = edge.GraphOutputIndex; + } + + if (edge.Name.empty()) + { + throw std::invalid_argument("Graph input or output edge at index " + std::to_string(index) + " does not have name."); + } + + if (edgeIndexToNameMap.find(index) != edgeIndexToNameMap.end()) + { + flatbuffers::String* edgeName = ReadAs( + builder.GetCurrentBufferPointer(), + builder.GetSize() - edgeIndexToNameMap[index].o); + if (edge.Name != edgeName->str()) + { + throw std::invalid_argument("Graph input or output edge at index " + std::to_string(index) + " has more than 1 names."); + } + } + + edgeIndexToNameMap[index] = builder.CreateString(edge.Name); + } + return edgeIndexToNameMap; // NRVO will automatically move it. no need to use std::move +} + +void PopulateNonConstantNodeInputOutputCount( + const std::vector& nodes, + /*out*/ std::vector& nodeInputCounts, + /*out*/ std::vector& nodeOutputCounts) +{ + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(nodes.size()); nodeIndex++) + { + auto& node = nodes[nodeIndex]; + if (std::holds_alternative(node.Desc)) + { + auto& operatorNode = std::get(node.Desc); + nodeInputCounts[nodeIndex] = std::max( + nodeInputCounts[nodeIndex], + static_cast(operatorNode.GetInputTensors().size())); + + nodeOutputCounts[nodeIndex] = std::max( + nodeOutputCounts[nodeIndex], + static_cast(operatorNode.GetOutputTensors().size())); + } + } +} + +void PopulateConstantNodeInputOutputCount( + const std::vector& edges, + /*out*/std::vector& maxInputIndexForNodes, + /*out*/std::vector& maxOutputIndexForNodes) +{ + for (auto& edge : edges) + { + maxInputIndexForNodes[edge.ToNodeIndex] = std::max(maxInputIndexForNodes[edge.ToNodeIndex], edge.ToNodeInputIndex + 1); + maxOutputIndexForNodes[edge.FromNodeIndex] = std::max(maxOutputIndexForNodes[edge.FromNodeIndex], edge.FromNodeOutputIndex + 1); + } +} + +/* +* validates intermediate edge and throws exception if an edge +* does not have a name or if an edge has more than 1 names. 
+*/ +void PopulateNodeInputOutputNames( + flatbuffers::FlatBufferBuilder& builder, + const DmlSerializedGraphDesc& graphDesc, + const std::unordered_map>& graphInputIndexToNameMap, + const std::unordered_map>& graphOutputIndexToNameMap, + /*out*/std::vector>>& nodeToInputNames, + /*out*/std::vector>>& nodeToOutputNames) +{ + for (auto& edge : graphDesc.InputEdges) + { + nodeToInputNames[edge.ToNodeIndex][edge.ToNodeInputIndex] = graphInputIndexToNameMap.at(edge.GraphInputIndex); + } + + for (auto& edge : graphDesc.OutputEdges) + { + nodeToOutputNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = graphOutputIndexToNameMap.at(edge.GraphOutputIndex); + } + + std::unordered_map>> intermediateEdgeNames; + for (uint32_t edgeIndex = 0; edgeIndex < static_cast(graphDesc.IntermediateEdges.size()); edgeIndex++) + { + auto& edge = graphDesc.IntermediateEdges[edgeIndex]; + if (edge.Name.empty()) + { + throw std::invalid_argument( + "Graph intermediate edge from nodeIndex:" + std::to_string(edge.FromNodeIndex) + + " & nodeOutputIndex:" + std::to_string(edge.FromNodeOutputIndex) + " doesn't have name."); + } + + if (intermediateEdgeNames.find(edge.FromNodeIndex) != intermediateEdgeNames.end() && + intermediateEdgeNames[edge.FromNodeIndex].find(edge.FromNodeOutputIndex) != intermediateEdgeNames[edge.FromNodeIndex].end()) + { + flatbuffers::Offset edgeNameOffset = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex]; + flatbuffers::String* edgeName = ReadAs( + builder.GetCurrentBufferPointer(), + builder.GetSize() - edgeNameOffset.o); + + if (edgeName->str() != edge.Name) + { + throw std::invalid_argument( + "Graph intermediate edge from nodeIndex:" + std::to_string(edge.FromNodeIndex) + + " & nodeOutputIndex:" + std::to_string(edge.FromNodeOutputIndex) + " has more than 1 names."); + } + } + else + { + intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = builder.CreateString(edge.Name.c_str()); + } + nodeToInputNames[edge.ToNodeIndex][edge.ToNodeInputIndex] = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex]; + nodeToOutputNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex]; + } +} + + +/* +* - If an edge is connected to multiple nodes, then there will be multiple instances +* of input or intermediate edges, all with the same name. +* - The input will be validated incrementally throughout the execution +* of the method. +* - Handling of empty optional input/output/attibute for non-constant node: +* input/output +* - and will have an null entry +* but the actual OperatorNodeDesc variant's +* and will not have any entry. +* attribute +* - will have null entry +*/ +flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& graphDesc) +{ + + flatbuffers::FlatBufferBuilder builder(1024); + if (graphDesc.Nodes.empty()) + { + return builder.Release(); + } + + // create input/output edge index to name map + std::unordered_map> graphInputIndexToNameMap = + ConvertToEdgeIndexToNameMap(graphDesc.InputEdges, builder); + std::unordered_map> graphOutputIndexToNameMap = + ConvertToEdgeIndexToNameMap(graphDesc.OutputEdges, builder); + + /* + * - Calculate number of input/output for each operator to allocate + * appropriate amount of memory for each node to store input/output names. + * - Non-constant node's input/output count can be determined by the + * AbstractOperatorDesc. + * - Constant node will only have outgoing edges and those outgoing edges + * will be intermediate edges. 
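+    *   Their output counts are therefore derived from the largest
+    *   intermediate-edge output index observed for each node (max index + 1).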
+ */ + std::vector nodeInputCounts(graphDesc.Nodes.size(), 0); + std::vector nodeOutputCounts(graphDesc.Nodes.size(), 0); + PopulateNonConstantNodeInputOutputCount(graphDesc.Nodes, nodeInputCounts, nodeOutputCounts); + PopulateConstantNodeInputOutputCount(graphDesc.IntermediateEdges, nodeInputCounts, nodeOutputCounts); + + // populate node input/output names. + std::vector>> nodeToInputNames(graphDesc.Nodes.size()); + std::vector>> nodeToOutputNames(graphDesc.Nodes.size()); + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(graphDesc.Nodes.size()); nodeIndex++) + { + nodeToInputNames[nodeIndex].assign(nodeInputCounts[nodeIndex], builder.CreateString(nullptr, 0)); + nodeToOutputNames[nodeIndex].assign(nodeOutputCounts[nodeIndex], builder.CreateString(nullptr, 0)); + } + PopulateNodeInputOutputNames(builder, graphDesc, graphInputIndexToNameMap, graphOutputIndexToNameMap, nodeToInputNames, nodeToOutputNames); + + // Create flatbuffer node objects + std::vector> nodes(graphDesc.Nodes.size()); + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(graphDesc.Nodes.size()); nodeIndex++) + { + nodes[nodeIndex] = SerializeNode( + builder, + nodeIndex, + graphDesc.Nodes[nodeIndex], + nodeToInputNames[nodeIndex], + nodeToOutputNames[nodeIndex]); + } + + // Convert to std::vector to create the object. + std::vector> graphInputNames(graphDesc.InputCount, builder.CreateString(nullptr, 0)); + std::vector> graphOutputNames(graphDesc.OutputCount, builder.CreateString(nullptr, 0)); + for (const auto& [key, value] : graphInputIndexToNameMap) + { + graphInputNames[key] = value; + } + for (const auto& [key, value] : graphOutputIndexToNameMap) + { + graphOutputNames[key] = value; + } + + flatbuffers::Offset dmlGraphDescOffset = dml::ir::CreateDmlGraphDescDirect( + builder, + &nodes, + &graphInputNames, + &graphOutputNames); + builder.Finish(dmlGraphDescOffset); + return builder.Release(); +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 5c7b7bff1e370..0f0d445a95bae 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -180,32 +180,50 @@ namespace Dml // Convert partitionONNXGraph into DML EP GraphDesc ComPtr device; ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + // This map will be used to transfer the initializer to D3D12 system heap memory. + // 'serializedDmlGraphDesc' will have constant input as intermediate edges, that's why + // we need a mapping between intermediateEdgeIndex and indexedSubGraph's (a given partition) + // input arg index. + // For ex: Let's say intermediate edge index = idx, then + // indexedSubGraphInputArgIdx = constantEdgeIdxToSubgraphInputArgIdxMap[idx]; + // corresponding constant tensor = initializerNameToInitializerMap[indexedSubGraph.GetMetaDef()->inputs[indexedSubGraphInputArgIdx]] + // We are using intermediate edge index as a key because same constant tensor can be used by + // multiple nodes. 
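+        // Illustrative sketch of how the two maps declared below are consumed
+        // (loop variables such as serializedGraphInputIdx and constantName are
+        // hypothetical; the map and member names mirror the code that follows):
+        //
+        //   // Regular graph input: serialized-graph input index -> fused-node input index.
+        //   uint32_t inputArgIdx = serializedGraphInputIndexToSubgraphInputIndex.at(serializedGraphInputIdx);
+        //
+        //   // Large constant captured as an intermediate edge: constant name -> fused-node
+        //   // input index, which then names the initializer to transfer to GPU memory.
+        //   uint32_t constantArgIdx = serializedGraphLargeConstantNameToSubgraphInputIndex.at(constantName);
+        //   const std::string& initializerName = m_indexedSubGraph->GetMetaDef()->inputs[constantArgIdx];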
+ std::unordered_map serializedGraphInputIndexToSubgraphInputIndex; + std::unordered_map serializedGraphLargeConstantNameToSubgraphInputIndex; + std::vector> smallConstantData; GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), m_isInitializerTransferable, m_partitionNodePropsMap, - device.Get(), providerImpl, m_modelPath, m_subgraphNodePointers, m_subgraphInputs, - m_subgraphOutputs); + m_subgraphOutputs, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, + smallConstantData); m_outputShapes = graphDesc.outputShapes; // Walk through each graph edge and mark used inputs m_inputsUsed.resize(fusedNodeInputCount, false); - for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) - { - m_inputsUsed[edge.GraphInputIndex] = true; + for (auto it = serializedGraphInputIndexToSubgraphInputIndex.begin(); it != serializedGraphInputIndexToSubgraphInputIndex.end(); it++) { + m_inputsUsed[it->second] = true; + } + for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex.begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex.end(); it++) { + m_inputsUsed[it->second] = true; } // Compile the operator m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( graphDesc, *m_indexedSubGraph, - providerImpl); + providerImpl, + &serializedGraphInputIndexToSubgraphInputIndex, + &serializedGraphLargeConstantNameToSubgraphInputIndex); // Queue references to objects which must be kept alive until resulting GPU work completes m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index a5415ba85f3d3..e1e7eacfbd85d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -24,8 +24,8 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 161; - static constexpr size_t ActivationFunctionCount = 24; + static constexpr auto ValueCount = 168; + static constexpr size_t ActivationFunctionCount = 26; }; template <> @@ -62,7 +62,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 4; + static constexpr auto ValueCount = 5; }; template <> @@ -86,7 +86,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 8; + static constexpr auto ValueCount = 13; }; template <> @@ -119,6 +119,12 @@ struct EnumTraits static constexpr auto ValueCount = 1; }; +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 5; +}; + template constexpr auto EnumValueCount = EnumTraits::ValueCount; @@ -495,12 +501,6 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING; }; -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; -}; - template <> struct OperatorDescTraits { @@ -1029,6 +1029,24 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DIAGONAL_MATRIX1; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION; +}; + +template <> 
+struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT; +}; + template <> struct OperatorDescTraits { @@ -1174,9 +1192,15 @@ struct OperatorDescTraits }; template <> -struct OperatorDescTraits +struct OperatorDescTraits { - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION; + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SWISH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARD_SWISH; }; template @@ -1502,12 +1526,6 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING> using DescType = DML_ROI_POOLING_OPERATOR_DESC; }; -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING> -{ - using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC; -}; - template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE> { @@ -2036,6 +2054,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DIAGONAL_MATRIX1> using DescType = DML_DIAGONAL_MATRIX1_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION> +{ + using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING> +{ + using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT> +{ + using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU> { @@ -2181,14 +2217,20 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_GELU> }; template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SWISH> { - using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC; + using DescType = DML_ACTIVATION_SWISH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARD_SWISH> +{ + using DescType = DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC; }; // Calls a visitor functor, supplying an empty operator desc corresponding to the given DML_OPERATOR_TYPE as // the first argument. -// +// // For example: // Visit(DML_OPERATOR_ELEMENT_WISE_IDENTITY, [](auto tag) { // using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs @@ -2485,6 +2527,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... 
args return std::invoke(std::forward(visitor), DML_DIAGONAL_MATRIX1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_MULTIHEAD_ATTENTION: return std::invoke(std::forward(visitor), DML_MULTIHEAD_ATTENTION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: + return std::invoke(std::forward(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_ELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_CELU: @@ -2533,13 +2579,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ACTIVATION_SHRINK_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_GELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward(args)...); - -#pragma warning(push) -#pragma warning(disable: 4063) - case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: - return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); -#pragma warning(pop) - + case DML_OPERATOR_ACTIVATION_SWISH: + return std::invoke(std::forward(visitor), DML_ACTIVATION_SWISH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_HARD_SWISH: + return std::invoke(std::forward(visitor), DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC{}, std::forward(args)...); default: ORT_THROW_HR(E_INVALIDARG); return std::invoke(std::forward(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward(args)...); @@ -2547,7 +2590,55 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args } #pragma warning(pop) +namespace StringifyHelpers +{ +template +inline gsl::czstring ToString(T value) +{ +#ifndef WAI_BUILD_LINUX + // Clang will instantiate this template even if it isn't used, + // so this static_assert will always fire and break the build. 
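+    // (A dependent condition such as sizeof(T) == 0, or a dependent_false<T>
+    // trait, is the usual portable way to defer this failure until the primary
+    // template is actually instantiated.)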
+ static_assert(false, "Not implemented for this type"); +#endif +} + +template <> +inline gsl::czstring ToString(DML_TENSOR_DATA_TYPE value) +{ + switch (value) + { + case DML_TENSOR_DATA_TYPE_UNKNOWN: return "DML_TENSOR_DATA_TYPE_UNKNOWN"; + case DML_TENSOR_DATA_TYPE_FLOAT32: return "DML_TENSOR_DATA_TYPE_FLOAT32"; + case DML_TENSOR_DATA_TYPE_FLOAT16: return "DML_TENSOR_DATA_TYPE_FLOAT16"; + case DML_TENSOR_DATA_TYPE_UINT32: return "DML_TENSOR_DATA_TYPE_UINT32"; + case DML_TENSOR_DATA_TYPE_UINT16: return "DML_TENSOR_DATA_TYPE_UINT16"; + case DML_TENSOR_DATA_TYPE_UINT8: return "DML_TENSOR_DATA_TYPE_UINT8"; + case DML_TENSOR_DATA_TYPE_INT32: return "DML_TENSOR_DATA_TYPE_INT32"; + case DML_TENSOR_DATA_TYPE_INT16: return "DML_TENSOR_DATA_TYPE_INT16"; + case DML_TENSOR_DATA_TYPE_INT8: return "DML_TENSOR_DATA_TYPE_INT8"; + case DML_TENSOR_DATA_TYPE_FLOAT64: return "DML_TENSOR_DATA_TYPE_FLOAT64"; + case DML_TENSOR_DATA_TYPE_UINT64: return "DML_TENSOR_DATA_TYPE_UINT64"; + case DML_TENSOR_DATA_TYPE_INT64: return "DML_TENSOR_DATA_TYPE_INT64"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_TENSOR_TYPE value) +{ + switch (value) + { + case DML_TENSOR_TYPE_INVALID: return "DML_TENSOR_TYPE_INVALID"; + case DML_TENSOR_TYPE_BUFFER: return "DML_TENSOR_TYPE_BUFFER"; + default: + assert(false); + return ""; + } +} +template <> inline gsl::czstring ToString(DML_OPERATOR_TYPE value) { switch (value) @@ -2561,9 +2652,6 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_ATAN: return "DML_OPERATOR_ELEMENT_WISE_ATAN"; case DML_OPERATOR_ELEMENT_WISE_CEIL: return "DML_OPERATOR_ELEMENT_WISE_CEIL"; case DML_OPERATOR_ELEMENT_WISE_CLIP: return "DML_OPERATOR_ELEMENT_WISE_CLIP"; - case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1"; - case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD"; - case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1"; case DML_OPERATOR_ELEMENT_WISE_COS: return "DML_OPERATOR_ELEMENT_WISE_COS"; case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return "DML_OPERATOR_ELEMENT_WISE_DIVIDE"; case DML_OPERATOR_ELEMENT_WISE_EXP: return "DML_OPERATOR_ELEMENT_WISE_EXP"; @@ -2587,24 +2675,41 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_RECIP: return "DML_OPERATOR_ELEMENT_WISE_RECIP"; case DML_OPERATOR_ELEMENT_WISE_SIN: return "DML_OPERATOR_ELEMENT_WISE_SIN"; case DML_OPERATOR_ELEMENT_WISE_SQRT: return "DML_OPERATOR_ELEMENT_WISE_SQRT"; - case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE"; - case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX"; case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return "DML_OPERATOR_ELEMENT_WISE_SUBTRACT"; case DML_OPERATOR_ELEMENT_WISE_TAN: return "DML_OPERATOR_ELEMENT_WISE_TAN"; case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return "DML_OPERATOR_ELEMENT_WISE_THRESHOLD"; case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR"; case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR"; + case DML_OPERATOR_ACTIVATION_ELU: return "DML_OPERATOR_ACTIVATION_ELU"; + case DML_OPERATOR_ACTIVATION_CELU: return "DML_OPERATOR_ACTIVATION_CELU"; + case DML_OPERATOR_ACTIVATION_HARDMAX: return "DML_OPERATOR_ACTIVATION_HARDMAX"; + case DML_OPERATOR_ACTIVATION_HARDMAX1: return "DML_OPERATOR_ACTIVATION_HARDMAX1"; + case 
DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return "DML_OPERATOR_ACTIVATION_HARD_SIGMOID"; + case DML_OPERATOR_ACTIVATION_IDENTITY: return "DML_OPERATOR_ACTIVATION_IDENTITY"; + case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return "DML_OPERATOR_ACTIVATION_LEAKY_RELU"; + case DML_OPERATOR_ACTIVATION_LINEAR: return "DML_OPERATOR_ACTIVATION_LINEAR"; + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX"; + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: return "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1"; + case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return "DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU"; + case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return "DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS"; + case DML_OPERATOR_ACTIVATION_RELU: return "DML_OPERATOR_ACTIVATION_RELU"; + case DML_OPERATOR_ACTIVATION_SCALED_ELU: return "DML_OPERATOR_ACTIVATION_SCALED_ELU"; + case DML_OPERATOR_ACTIVATION_SCALED_TANH: return "DML_OPERATOR_ACTIVATION_SCALED_TANH"; + case DML_OPERATOR_ACTIVATION_SIGMOID: return "DML_OPERATOR_ACTIVATION_SIGMOID"; + case DML_OPERATOR_ACTIVATION_SOFTMAX: return "DML_OPERATOR_ACTIVATION_SOFTMAX"; + case DML_OPERATOR_ACTIVATION_SOFTMAX1: return "DML_OPERATOR_ACTIVATION_SOFTMAX1"; + case DML_OPERATOR_ACTIVATION_SOFTPLUS: return "DML_OPERATOR_ACTIVATION_SOFTPLUS"; + case DML_OPERATOR_ACTIVATION_SOFTSIGN: return "DML_OPERATOR_ACTIVATION_SOFTSIGN"; + case DML_OPERATOR_ACTIVATION_TANH: return "DML_OPERATOR_ACTIVATION_TANH"; + case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return "DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU"; case DML_OPERATOR_CONVOLUTION: return "DML_OPERATOR_CONVOLUTION"; case DML_OPERATOR_GEMM: return "DML_OPERATOR_GEMM"; case DML_OPERATOR_REDUCE: return "DML_OPERATOR_REDUCE"; - case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN"; - case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX"; case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING"; case DML_OPERATOR_AVERAGE_POOLING1: return "DML_OPERATOR_AVERAGE_POOLING1"; case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING"; case DML_OPERATOR_LP_POOLING1: return "DML_OPERATOR_LP_POOLING1"; case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING"; - case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING"; case DML_OPERATOR_SLICE: return "DML_OPERATOR_SLICE"; case DML_OPERATOR_CAST: return "DML_OPERATOR_CAST"; @@ -2620,18 +2725,15 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_TILE: return "DML_OPERATOR_TILE"; case DML_OPERATOR_TOP_K: return "DML_OPERATOR_TOP_K"; case DML_OPERATOR_BATCH_NORMALIZATION: return "DML_OPERATOR_BATCH_NORMALIZATION"; - case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD"; - case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD"; + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING"; case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION"; case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION"; - case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD"; case DML_OPERATOR_LP_NORMALIZATION: return "DML_OPERATOR_LP_NORMALIZATION"; case DML_OPERATOR_RNN: return "DML_OPERATOR_RNN"; case DML_OPERATOR_LSTM: return "DML_OPERATOR_LSTM"; case DML_OPERATOR_GRU: return 
"DML_OPERATOR_GRU"; case DML_OPERATOR_ELEMENT_WISE_SIGN: return "DML_OPERATOR_ELEMENT_WISE_SIGN"; case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return "DML_OPERATOR_ELEMENT_WISE_IS_NAN"; - case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE"; case DML_OPERATOR_ELEMENT_WISE_ERF: return "DML_OPERATOR_ELEMENT_WISE_ERF"; case DML_OPERATOR_ELEMENT_WISE_SINH: return "DML_OPERATOR_ELEMENT_WISE_SINH"; case DML_OPERATOR_ELEMENT_WISE_COSH: return "DML_OPERATOR_ELEMENT_WISE_COSH"; @@ -2641,6 +2743,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_ATANH: return "DML_OPERATOR_ELEMENT_WISE_ATANH"; case DML_OPERATOR_ELEMENT_WISE_IF: return "DML_OPERATOR_ELEMENT_WISE_IF"; case DML_OPERATOR_ELEMENT_WISE_ADD1: return "DML_OPERATOR_ELEMENT_WISE_ADD1"; + case DML_OPERATOR_ACTIVATION_SHRINK: return "DML_OPERATOR_ACTIVATION_SHRINK"; + case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; case DML_OPERATOR_MAX_UNPOOLING: return "DML_OPERATOR_MAX_UNPOOLING"; case DML_OPERATOR_DIAGONAL_MATRIX: return "DML_OPERATOR_DIAGONAL_MATRIX"; case DML_OPERATOR_SCATTER: return "DML_OPERATOR_SCATTER"; @@ -2652,10 +2756,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: return "DML_OPERATOR_ELEMENT_WISE_IS_INFINITY"; case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE"; case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR"; - case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT"; case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE"; + case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT"; case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION"; - case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT"; case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES"; case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS"; case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND"; @@ -2684,20 +2787,278 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_RESAMPLE_GRAD: return "DML_OPERATOR_RESAMPLE_GRAD"; case DML_OPERATOR_SLICE_GRAD: return "DML_OPERATOR_SLICE_GRAD"; case DML_OPERATOR_ADAM_OPTIMIZER: return "DML_OPERATOR_ADAM_OPTIMIZER"; + case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN"; + case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX"; case DML_OPERATOR_ROI_ALIGN: return "DML_OPERATOR_ROI_ALIGN"; - case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1"; case DML_OPERATOR_GATHER_ND1: return "DML_OPERATOR_GATHER_ND1"; - case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR"; + case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX"; + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD"; + case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE"; + case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD"; + case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT"; + case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD"; + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD"; case 
DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD"; - case DML_OPERATOR_ROI_ALIGN_GRAD: return "DML_OPERATOR_ROI_ALIGN_GRAD"; - case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING"; + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR"; + case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1"; + case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1"; + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1"; + case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE"; + case DML_OPERATOR_ACTIVATION_GELU: return "DML_OPERATOR_ACTIVATION_GELU"; + case DML_OPERATOR_ACTIVATION_SWISH: return "DML_OPERATOR_ACTIVATION_SWISH"; + case DML_OPERATOR_ACTIVATION_HARD_SWISH: return "DML_OPERATOR_ACTIVATION_HARD_SWISH"; case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2"; case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1"; case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1"; case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION"; + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING"; + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_BINDING_TYPE value) +{ + switch (value) + { + case DML_BINDING_TYPE_NONE: return "DML_BINDING_TYPE_NONE"; + case DML_BINDING_TYPE_BUFFER: return "DML_BINDING_TYPE_BUFFER"; + case DML_BINDING_TYPE_BUFFER_ARRAY: return "DML_BINDING_TYPE_BUFFER_ARRAY"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_REDUCE_FUNCTION value) +{ + switch (value) + { + case DML_REDUCE_FUNCTION_ARGMAX: return "DML_REDUCE_FUNCTION_ARGMAX"; + case DML_REDUCE_FUNCTION_ARGMIN: return "DML_REDUCE_FUNCTION_ARGMIN"; + case DML_REDUCE_FUNCTION_AVERAGE: return "DML_REDUCE_FUNCTION_AVERAGE"; + case DML_REDUCE_FUNCTION_L1: return "DML_REDUCE_FUNCTION_L1"; + case DML_REDUCE_FUNCTION_L2: return "DML_REDUCE_FUNCTION_L2"; + case DML_REDUCE_FUNCTION_LOG_SUM: return "DML_REDUCE_FUNCTION_LOG_SUM"; + case DML_REDUCE_FUNCTION_LOG_SUM_EXP: return "DML_REDUCE_FUNCTION_LOG_SUM_EXP"; + case DML_REDUCE_FUNCTION_MAX: return "DML_REDUCE_FUNCTION_MAX"; + case DML_REDUCE_FUNCTION_MIN: return "DML_REDUCE_FUNCTION_MIN"; + case DML_REDUCE_FUNCTION_MULTIPLY: return "DML_REDUCE_FUNCTION_MULTIPLY"; + case DML_REDUCE_FUNCTION_SUM: return "DML_REDUCE_FUNCTION_SUM"; + case DML_REDUCE_FUNCTION_SUM_SQUARE: return "DML_REDUCE_FUNCTION_SUM_SQUARE"; default: assert(false); return ""; } } + +template <> +inline gsl::czstring ToString(DML_MATRIX_TRANSFORM value) +{ + switch (value) + { + case DML_MATRIX_TRANSFORM_NONE: return "DML_MATRIX_TRANSFORM_NONE"; + case DML_MATRIX_TRANSFORM_TRANSPOSE: return "DML_MATRIX_TRANSFORM_TRANSPOSE"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_CONVOLUTION_MODE value) +{ + switch (value) + { + case DML_CONVOLUTION_MODE_CONVOLUTION: return "DML_CONVOLUTION_MODE_CONVOLUTION"; + case DML_CONVOLUTION_MODE_CROSS_CORRELATION: return "DML_CONVOLUTION_MODE_CROSS_CORRELATION"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_CONVOLUTION_DIRECTION value) +{ + switch 
(value) + { + case DML_CONVOLUTION_DIRECTION_FORWARD: return "DML_CONVOLUTION_DIRECTION_FORWARD"; + case DML_CONVOLUTION_DIRECTION_BACKWARD: return "DML_CONVOLUTION_DIRECTION_BACKWARD"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_PADDING_MODE value) +{ + switch (value) + { + case DML_PADDING_MODE_CONSTANT: return "DML_PADDING_MODE_CONSTANT"; + case DML_PADDING_MODE_EDGE: return "DML_PADDING_MODE_EDGE"; + case DML_PADDING_MODE_REFLECTION: return "DML_PADDING_MODE_REFLECTION"; + case DML_PADDING_MODE_SYMMETRIC: return "DML_PADDING_MODE_SYMMETRIC"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_INTERPOLATION_MODE value) +{ + switch (value) + { + case DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR: return "DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR"; + case DML_INTERPOLATION_MODE_LINEAR: return "DML_INTERPOLATION_MODE_LINEAR"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_RECURRENT_NETWORK_DIRECTION value) +{ + switch (value) + { + case DML_RECURRENT_NETWORK_DIRECTION_FORWARD: return "DML_RECURRENT_NETWORK_DIRECTION_FORWARD"; + case DML_RECURRENT_NETWORK_DIRECTION_BACKWARD: return "DML_RECURRENT_NETWORK_DIRECTION_BACKWARD"; + case DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL: return "DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_FEATURE value) +{ + switch (value) + { + case DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT: return "DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT"; + case DML_FEATURE_FEATURE_LEVELS: return "DML_FEATURE_FEATURE_LEVELS"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_FEATURE_LEVEL value) +{ + switch (value) + { + case DML_FEATURE_LEVEL_1_0: return "DML_FEATURE_LEVEL_1_0"; + case DML_FEATURE_LEVEL_2_0: return "DML_FEATURE_LEVEL_2_0"; + case DML_FEATURE_LEVEL_2_1: return "DML_FEATURE_LEVEL_2_1"; + case DML_FEATURE_LEVEL_3_0: return "DML_FEATURE_LEVEL_3_0"; + case DML_FEATURE_LEVEL_3_1: return "DML_FEATURE_LEVEL_3_1"; + case DML_FEATURE_LEVEL_4_0: return "DML_FEATURE_LEVEL_4_0"; + case DML_FEATURE_LEVEL_4_1: return "DML_FEATURE_LEVEL_4_1"; + case DML_FEATURE_LEVEL_5_0: return "DML_FEATURE_LEVEL_5_0"; + case DML_FEATURE_LEVEL_5_1: return "DML_FEATURE_LEVEL_5_1"; + case DML_FEATURE_LEVEL_5_2: return "DML_FEATURE_LEVEL_5_2"; + case DML_FEATURE_LEVEL_6_0: return "DML_FEATURE_LEVEL_6_0"; + case DML_FEATURE_LEVEL_6_1: return "DML_FEATURE_LEVEL_6_1"; + case DML_FEATURE_LEVEL_6_2: return "DML_FEATURE_LEVEL_6_2"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_IS_INFINITY_MODE value) +{ + switch (value) + { + case DML_IS_INFINITY_MODE_EITHER: return "DML_IS_INFINITY_MODE_EITHER"; + case DML_IS_INFINITY_MODE_POSITIVE: return "DML_IS_INFINITY_MODE_POSITIVE"; + case DML_IS_INFINITY_MODE_NEGATIVE: return "DML_IS_INFINITY_MODE_NEGATIVE"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_DEPTH_SPACE_ORDER value) +{ + switch (value) + { + case DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW: return "DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW"; + case DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH: return "DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_AXIS_DIRECTION value) +{ + switch (value) + { + case DML_AXIS_DIRECTION_INCREASING: return 
"DML_AXIS_DIRECTION_INCREASING"; + case DML_AXIS_DIRECTION_DECREASING: return "DML_AXIS_DIRECTION_DECREASING"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_ROUNDING_MODE value) +{ + switch (value) + { + case DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN: return "DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN"; + case DML_ROUNDING_MODE_TOWARD_ZERO: return "DML_ROUNDING_MODE_TOWARD_ZERO"; + case DML_ROUNDING_MODE_TOWARD_INFINITY: return "DML_ROUNDING_MODE_TOWARD_INFINITY"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_RANDOM_GENERATOR_TYPE value) +{ + switch (value) + { + case DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10: return "DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_MULTIHEAD_ATTENTION_MASK_TYPE value) +{ + switch (value) + { + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN"; + default: + assert(false); + return ""; + } +} + + +template +T FromString(std::string_view value); + +} } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 2a82c12872a72..5fe6603c2a0bf 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -618,7 +618,7 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA { constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, }; @@ -633,7 +633,7 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA { constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", true }, DML_SCHEMA_FIELD 
{ DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, }; @@ -869,31 +869,6 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA { DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS, }; - -constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { - "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", - static_cast(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING), - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 13, - DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, -}; - constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -1146,7 +1121,7 @@ constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleGradientTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputBiasGradientTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, }; constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA { @@ -2312,7 +2287,7 @@ constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA { DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA_FIELDS, }; -constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8]{ +constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8] { DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false }, @@ -2323,7 +2298,7 @@ constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8]{ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false }, }; -constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE2_OPERATOR_SCHEMA{ +constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE2_OPERATOR_SCHEMA { "DML_OPERATOR_RESAMPLE2", DML_OPERATOR_RESAMPLE2, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, @@ -2342,7 +2317,7 @@ constexpr DML_SCHEMA_FIELD DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA_FIELDS[8]{ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false }, }; -constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA{ +constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA { "DML_OPERATOR_RESAMPLE_GRAD1", DML_OPERATOR_RESAMPLE_GRAD1, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, @@ -2350,7 +2325,7 @@ constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA{ DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA_FIELDS, }; -constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6]{ +constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ValueDataType", false }, @@ -2359,7 +2334,7 @@ constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6]{ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT, "DiagonalFillEnd", false }, }; -constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA{ +constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA { "DML_OPERATOR_DIAGONAL_MATRIX1", DML_OPERATOR_DIAGONAL_MATRIX1, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, @@ -2396,6 +2371,48 @@ constexpr DML_OPERATOR_SCHEMA DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA { DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", 
false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", + DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 13, + DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA { + "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", + DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS, +}; constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -2732,6 +2749,35 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_GELU_OPERATOR_SCHEMA { DML_ACTIVATION_GELU_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SWISH_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SigmoidInputScale", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SWISH_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_SWISH", + DML_OPERATOR_ACTIVATION_SWISH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_ACTIVATION_SWISH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_HARD_SWISH", + DML_OPERATOR_ACTIVATION_HARD_SWISH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_RNN_ZERO_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h new file mode 100644 index 0000000000000..72059b9a3f911 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h @@ -0,0 +1,788 @@ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_ +#define FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_ + +#include "flatbuffers/flatbuffers.h" + +#include "OperatorFieldTypes_generated.h" + +namespace dml { +namespace ir { + +struct ConstantRawData; +struct ConstantRawDataBuilder; + +struct ConstantName; +struct ConstantNameBuilder; + +struct ConstantNodeDesc; +struct ConstantNodeDescBuilder; + +struct DmlBufferTensorDesc; +struct DmlBufferTensorDescBuilder; + +struct OperatorNodeDesc; +struct OperatorNodeDescBuilder; + +struct DmlGraphNode; +struct DmlGraphNodeBuilder; + +struct DmlGraphDesc; +struct DmlGraphDescBuilder; + +enum ConstantNodeDescDetail { + ConstantNodeDescDetail_NONE = 0, + ConstantNodeDescDetail_ConstantName = 1, + ConstantNodeDescDetail_ConstantRawData = 2, + ConstantNodeDescDetail_MIN = ConstantNodeDescDetail_NONE, + ConstantNodeDescDetail_MAX = ConstantNodeDescDetail_ConstantRawData +}; + +inline const ConstantNodeDescDetail (&EnumValuesConstantNodeDescDetail())[3] { + static const ConstantNodeDescDetail values[] = { + ConstantNodeDescDetail_NONE, + ConstantNodeDescDetail_ConstantName, + ConstantNodeDescDetail_ConstantRawData + }; + return values; +} + +inline const char * const *EnumNamesConstantNodeDescDetail() { + static const char * const names[4] = { + "NONE", + "ConstantName", + "ConstantRawData", + nullptr + }; + return names; +} + +inline const char *EnumNameConstantNodeDescDetail(ConstantNodeDescDetail e) { + if (flatbuffers::IsOutRange(e, ConstantNodeDescDetail_NONE, ConstantNodeDescDetail_ConstantRawData)) return ""; + const size_t index = static_cast(e); + return EnumNamesConstantNodeDescDetail()[index]; +} + +template struct ConstantNodeDescDetailTraits { + static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_NONE; +}; + +template<> struct ConstantNodeDescDetailTraits { + static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_ConstantName; +}; + +template<> struct ConstantNodeDescDetailTraits { + static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_ConstantRawData; +}; + +bool VerifyConstantNodeDescDetail(flatbuffers::Verifier &verifier, const void *obj, ConstantNodeDescDetail type); +bool VerifyConstantNodeDescDetailVector(flatbuffers::Verifier &verifier, const 
flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +enum NodeDesc { + NodeDesc_NONE = 0, + NodeDesc_OperatorNodeDesc = 1, + NodeDesc_ConstantNodeDesc = 2, + NodeDesc_MIN = NodeDesc_NONE, + NodeDesc_MAX = NodeDesc_ConstantNodeDesc +}; + +inline const NodeDesc (&EnumValuesNodeDesc())[3] { + static const NodeDesc values[] = { + NodeDesc_NONE, + NodeDesc_OperatorNodeDesc, + NodeDesc_ConstantNodeDesc + }; + return values; +} + +inline const char * const *EnumNamesNodeDesc() { + static const char * const names[4] = { + "NONE", + "OperatorNodeDesc", + "ConstantNodeDesc", + nullptr + }; + return names; +} + +inline const char *EnumNameNodeDesc(NodeDesc e) { + if (flatbuffers::IsOutRange(e, NodeDesc_NONE, NodeDesc_ConstantNodeDesc)) return ""; + const size_t index = static_cast(e); + return EnumNamesNodeDesc()[index]; +} + +template struct NodeDescTraits { + static const NodeDesc enum_value = NodeDesc_NONE; +}; + +template<> struct NodeDescTraits { + static const NodeDesc enum_value = NodeDesc_OperatorNodeDesc; +}; + +template<> struct NodeDescTraits { + static const NodeDesc enum_value = NodeDesc_ConstantNodeDesc; +}; + +bool VerifyNodeDesc(flatbuffers::Verifier &verifier, const void *obj, NodeDesc type); +bool VerifyNodeDescVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +struct ConstantRawData FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConstantRawDataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct ConstantRawDataBuilder { + typedef ConstantRawData Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(ConstantRawData::VT_DATA, data); + } + explicit ConstantRawDataBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConstantRawDataBuilder &operator=(const ConstantRawDataBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConstantRawData( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + ConstantRawDataBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateConstantRawDataDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? 
_fbb.CreateVector(*data) : 0; + return dml::ir::CreateConstantRawData( + _fbb, + data__); +} + +struct ConstantName FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConstantNameBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + verifier.EndTable(); + } +}; + +struct ConstantNameBuilder { + typedef ConstantName Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(ConstantName::VT_NAME, name); + } + explicit ConstantNameBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConstantNameBuilder &operator=(const ConstantNameBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConstantName( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0) { + ConstantNameBuilder builder_(_fbb); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateConstantNameDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + return dml::ir::CreateConstantName( + _fbb, + name__); +} + +struct ConstantNodeDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConstantNodeDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA_TYPE = 4, + VT_DATA = 6 + }; + dml::ir::ConstantNodeDescDetail data_type() const { + return static_cast(GetField(VT_DATA_TYPE, 0)); + } + const void *data() const { + return GetPointer(VT_DATA); + } + template const T *data_as() const; + const dml::ir::ConstantName *data_as_ConstantName() const { + return data_type() == dml::ir::ConstantNodeDescDetail_ConstantName ? static_cast(data()) : nullptr; + } + const dml::ir::ConstantRawData *data_as_ConstantRawData() const { + return data_type() == dml::ir::ConstantNodeDescDetail_ConstantRawData ? 
static_cast(data()) : nullptr; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DATA_TYPE) && + VerifyOffset(verifier, VT_DATA) && + VerifyConstantNodeDescDetail(verifier, data(), data_type()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::ConstantName *ConstantNodeDesc::data_as() const { + return data_as_ConstantName(); +} + +template<> inline const dml::ir::ConstantRawData *ConstantNodeDesc::data_as() const { + return data_as_ConstantRawData(); +} + +struct ConstantNodeDescBuilder { + typedef ConstantNodeDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data_type(dml::ir::ConstantNodeDescDetail data_type) { + fbb_.AddElement(ConstantNodeDesc::VT_DATA_TYPE, static_cast(data_type), 0); + } + void add_data(flatbuffers::Offset data) { + fbb_.AddOffset(ConstantNodeDesc::VT_DATA, data); + } + explicit ConstantNodeDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConstantNodeDescBuilder &operator=(const ConstantNodeDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConstantNodeDesc( + flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::ConstantNodeDescDetail data_type = dml::ir::ConstantNodeDescDetail_NONE, + flatbuffers::Offset data = 0) { + ConstantNodeDescBuilder builder_(_fbb); + builder_.add_data(data); + builder_.add_data_type(data_type); + return builder_.Finish(); +} + +struct DmlBufferTensorDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DmlBufferTensorDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATATYPE = 4, + VT_SIZES = 6, + VT_STRIDES = 8, + VT_TOTALTENSORSIZEINBYTES = 10 + }; + const flatbuffers::String *dataType() const { + return GetPointer(VT_DATATYPE); + } + const flatbuffers::Vector *sizes() const { + return GetPointer *>(VT_SIZES); + } + const flatbuffers::Vector *strides() const { + return GetPointer *>(VT_STRIDES); + } + uint64_t totalTensorSizeInBytes() const { + return GetField(VT_TOTALTENSORSIZEINBYTES, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATATYPE) && + verifier.VerifyString(dataType()) && + VerifyOffset(verifier, VT_SIZES) && + verifier.VerifyVector(sizes()) && + VerifyOffset(verifier, VT_STRIDES) && + verifier.VerifyVector(strides()) && + VerifyField(verifier, VT_TOTALTENSORSIZEINBYTES) && + verifier.EndTable(); + } +}; + +struct DmlBufferTensorDescBuilder { + typedef DmlBufferTensorDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_dataType(flatbuffers::Offset dataType) { + fbb_.AddOffset(DmlBufferTensorDesc::VT_DATATYPE, dataType); + } + void add_sizes(flatbuffers::Offset> sizes) { + fbb_.AddOffset(DmlBufferTensorDesc::VT_SIZES, sizes); + } + void add_strides(flatbuffers::Offset> strides) { + fbb_.AddOffset(DmlBufferTensorDesc::VT_STRIDES, strides); + } + void add_totalTensorSizeInBytes(uint64_t totalTensorSizeInBytes) { + fbb_.AddElement(DmlBufferTensorDesc::VT_TOTALTENSORSIZEINBYTES, totalTensorSizeInBytes, 0); + } + explicit DmlBufferTensorDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DmlBufferTensorDescBuilder &operator=(const DmlBufferTensorDescBuilder &); + flatbuffers::Offset 
Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDmlBufferTensorDesc( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset dataType = 0, + flatbuffers::Offset> sizes = 0, + flatbuffers::Offset> strides = 0, + uint64_t totalTensorSizeInBytes = 0) { + DmlBufferTensorDescBuilder builder_(_fbb); + builder_.add_totalTensorSizeInBytes(totalTensorSizeInBytes); + builder_.add_strides(strides); + builder_.add_sizes(sizes); + builder_.add_dataType(dataType); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDmlBufferTensorDescDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *dataType = nullptr, + const std::vector *sizes = nullptr, + const std::vector *strides = nullptr, + uint64_t totalTensorSizeInBytes = 0) { + auto dataType__ = dataType ? _fbb.CreateString(dataType) : 0; + auto sizes__ = sizes ? _fbb.CreateVector(*sizes) : 0; + auto strides__ = strides ? _fbb.CreateVector(*strides) : 0; + return dml::ir::CreateDmlBufferTensorDesc( + _fbb, + dataType__, + sizes__, + strides__, + totalTensorSizeInBytes); +} + +struct OperatorNodeDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OperatorNodeDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TYPE = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_ATTRIBUTES = 10 + }; + const flatbuffers::String *type() const { + return GetPointer(VT_TYPE); + } + const flatbuffers::Vector> *inputs() const { + return GetPointer> *>(VT_INPUTS); + } + const flatbuffers::Vector> *outputs() const { + return GetPointer> *>(VT_OUTPUTS); + } + const flatbuffers::Vector> *attributes() const { + return GetPointer> *>(VT_ATTRIBUTES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TYPE) && + verifier.VerifyString(type()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && + verifier.VerifyVectorOfTables(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + verifier.VerifyVectorOfTables(outputs()) && + VerifyOffset(verifier, VT_ATTRIBUTES) && + verifier.VerifyVector(attributes()) && + verifier.VerifyVectorOfTables(attributes()) && + verifier.EndTable(); + } +}; + +struct OperatorNodeDescBuilder { + typedef OperatorNodeDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_type(flatbuffers::Offset type) { + fbb_.AddOffset(OperatorNodeDesc::VT_TYPE, type); + } + void add_inputs(flatbuffers::Offset>> inputs) { + fbb_.AddOffset(OperatorNodeDesc::VT_INPUTS, inputs); + } + void add_outputs(flatbuffers::Offset>> outputs) { + fbb_.AddOffset(OperatorNodeDesc::VT_OUTPUTS, outputs); + } + void add_attributes(flatbuffers::Offset>> attributes) { + fbb_.AddOffset(OperatorNodeDesc::VT_ATTRIBUTES, attributes); + } + explicit OperatorNodeDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OperatorNodeDescBuilder &operator=(const OperatorNodeDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOperatorNodeDesc( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset type = 0, + flatbuffers::Offset>> inputs = 0, + flatbuffers::Offset>> outputs = 0, + flatbuffers::Offset>> attributes = 0) { + OperatorNodeDescBuilder builder_(_fbb); + 
builder_.add_attributes(attributes); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_type(type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateOperatorNodeDescDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *type = nullptr, + const std::vector> *inputs = nullptr, + const std::vector> *outputs = nullptr, + const std::vector> *attributes = nullptr) { + auto type__ = type ? _fbb.CreateString(type) : 0; + auto inputs__ = inputs ? _fbb.CreateVector>(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector>(*outputs) : 0; + auto attributes__ = attributes ? _fbb.CreateVector>(*attributes) : 0; + return dml::ir::CreateOperatorNodeDesc( + _fbb, + type__, + inputs__, + outputs__, + attributes__); +} + +struct DmlGraphNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DmlGraphNodeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DESC_TYPE = 4, + VT_DESC = 6, + VT_NAME = 8, + VT_INPUTNAMES = 10, + VT_OUTPUTNAMES = 12 + }; + dml::ir::NodeDesc desc_type() const { + return static_cast(GetField(VT_DESC_TYPE, 0)); + } + const void *desc() const { + return GetPointer(VT_DESC); + } + template const T *desc_as() const; + const dml::ir::OperatorNodeDesc *desc_as_OperatorNodeDesc() const { + return desc_type() == dml::ir::NodeDesc_OperatorNodeDesc ? static_cast(desc()) : nullptr; + } + const dml::ir::ConstantNodeDesc *desc_as_ConstantNodeDesc() const { + return desc_type() == dml::ir::NodeDesc_ConstantNodeDesc ? static_cast(desc()) : nullptr; + } + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + const flatbuffers::Vector> *inputNames() const { + return GetPointer> *>(VT_INPUTNAMES); + } + const flatbuffers::Vector> *outputNames() const { + return GetPointer> *>(VT_OUTPUTNAMES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DESC_TYPE) && + VerifyOffset(verifier, VT_DESC) && + VerifyNodeDesc(verifier, desc(), desc_type()) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_INPUTNAMES) && + verifier.VerifyVector(inputNames()) && + verifier.VerifyVectorOfStrings(inputNames()) && + VerifyOffset(verifier, VT_OUTPUTNAMES) && + verifier.VerifyVector(outputNames()) && + verifier.VerifyVectorOfStrings(outputNames()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::OperatorNodeDesc *DmlGraphNode::desc_as() const { + return desc_as_OperatorNodeDesc(); +} + +template<> inline const dml::ir::ConstantNodeDesc *DmlGraphNode::desc_as() const { + return desc_as_ConstantNodeDesc(); +} + +struct DmlGraphNodeBuilder { + typedef DmlGraphNode Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_desc_type(dml::ir::NodeDesc desc_type) { + fbb_.AddElement(DmlGraphNode::VT_DESC_TYPE, static_cast(desc_type), 0); + } + void add_desc(flatbuffers::Offset desc) { + fbb_.AddOffset(DmlGraphNode::VT_DESC, desc); + } + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(DmlGraphNode::VT_NAME, name); + } + void add_inputNames(flatbuffers::Offset>> inputNames) { + fbb_.AddOffset(DmlGraphNode::VT_INPUTNAMES, inputNames); + } + void add_outputNames(flatbuffers::Offset>> outputNames) { + fbb_.AddOffset(DmlGraphNode::VT_OUTPUTNAMES, outputNames); + } + explicit DmlGraphNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DmlGraphNodeBuilder 
&operator=(const DmlGraphNodeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDmlGraphNode( + flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::NodeDesc desc_type = dml::ir::NodeDesc_NONE, + flatbuffers::Offset desc = 0, + flatbuffers::Offset name = 0, + flatbuffers::Offset>> inputNames = 0, + flatbuffers::Offset>> outputNames = 0) { + DmlGraphNodeBuilder builder_(_fbb); + builder_.add_outputNames(outputNames); + builder_.add_inputNames(inputNames); + builder_.add_name(name); + builder_.add_desc(desc); + builder_.add_desc_type(desc_type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDmlGraphNodeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::NodeDesc desc_type = dml::ir::NodeDesc_NONE, + flatbuffers::Offset desc = 0, + const char *name = nullptr, + const std::vector> *inputNames = nullptr, + const std::vector> *outputNames = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto inputNames__ = inputNames ? _fbb.CreateVector>(*inputNames) : 0; + auto outputNames__ = outputNames ? _fbb.CreateVector>(*outputNames) : 0; + return dml::ir::CreateDmlGraphNode( + _fbb, + desc_type, + desc, + name__, + inputNames__, + outputNames__); +} + +struct DmlGraphDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DmlGraphDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NODES = 4, + VT_GRAPHINPUTNAMES = 6, + VT_GRAPHOUTPUTNAMES = 8 + }; + const flatbuffers::Vector> *nodes() const { + return GetPointer> *>(VT_NODES); + } + const flatbuffers::Vector> *graphInputNames() const { + return GetPointer> *>(VT_GRAPHINPUTNAMES); + } + const flatbuffers::Vector> *graphOutputNames() const { + return GetPointer> *>(VT_GRAPHOUTPUTNAMES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NODES) && + verifier.VerifyVector(nodes()) && + verifier.VerifyVectorOfTables(nodes()) && + VerifyOffset(verifier, VT_GRAPHINPUTNAMES) && + verifier.VerifyVector(graphInputNames()) && + verifier.VerifyVectorOfStrings(graphInputNames()) && + VerifyOffset(verifier, VT_GRAPHOUTPUTNAMES) && + verifier.VerifyVector(graphOutputNames()) && + verifier.VerifyVectorOfStrings(graphOutputNames()) && + verifier.EndTable(); + } +}; + +struct DmlGraphDescBuilder { + typedef DmlGraphDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_nodes(flatbuffers::Offset>> nodes) { + fbb_.AddOffset(DmlGraphDesc::VT_NODES, nodes); + } + void add_graphInputNames(flatbuffers::Offset>> graphInputNames) { + fbb_.AddOffset(DmlGraphDesc::VT_GRAPHINPUTNAMES, graphInputNames); + } + void add_graphOutputNames(flatbuffers::Offset>> graphOutputNames) { + fbb_.AddOffset(DmlGraphDesc::VT_GRAPHOUTPUTNAMES, graphOutputNames); + } + explicit DmlGraphDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DmlGraphDescBuilder &operator=(const DmlGraphDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDmlGraphDesc( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> nodes = 0, + flatbuffers::Offset>> graphInputNames = 0, + flatbuffers::Offset>> graphOutputNames = 0) { + DmlGraphDescBuilder builder_(_fbb); + 
builder_.add_graphOutputNames(graphOutputNames); + builder_.add_graphInputNames(graphInputNames); + builder_.add_nodes(nodes); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDmlGraphDescDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector> *nodes = nullptr, + const std::vector> *graphInputNames = nullptr, + const std::vector> *graphOutputNames = nullptr) { + auto nodes__ = nodes ? _fbb.CreateVector>(*nodes) : 0; + auto graphInputNames__ = graphInputNames ? _fbb.CreateVector>(*graphInputNames) : 0; + auto graphOutputNames__ = graphOutputNames ? _fbb.CreateVector>(*graphOutputNames) : 0; + return dml::ir::CreateDmlGraphDesc( + _fbb, + nodes__, + graphInputNames__, + graphOutputNames__); +} + +inline bool VerifyConstantNodeDescDetail(flatbuffers::Verifier &verifier, const void *obj, ConstantNodeDescDetail type) { + switch (type) { + case ConstantNodeDescDetail_NONE: { + return true; + } + case ConstantNodeDescDetail_ConstantName: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case ConstantNodeDescDetail_ConstantRawData: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyConstantNodeDescDetailVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyConstantNodeDescDetail( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline bool VerifyNodeDesc(flatbuffers::Verifier &verifier, const void *obj, NodeDesc type) { + switch (type) { + case NodeDesc_NONE: { + return true; + } + case NodeDesc_OperatorNodeDesc: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case NodeDesc_ConstantNodeDesc: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyNodeDescVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyNodeDesc( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline const dml::ir::DmlGraphDesc *GetDmlGraphDesc(const void *buf) { + return flatbuffers::GetRoot(buf); +} + +inline const dml::ir::DmlGraphDesc *GetSizePrefixedDmlGraphDesc(const void *buf) { + return flatbuffers::GetSizePrefixedRoot(buf); +} + +inline bool VerifyDmlGraphDescBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(nullptr); +} + +inline bool VerifySizePrefixedDmlGraphDescBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer(nullptr); +} + +inline void FinishDmlGraphDescBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.Finish(root); +} + +inline void FinishSizePrefixedDmlGraphDescBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.FinishSizePrefixed(root); +} + +} // namespace ir +} // namespace dml + +#endif // FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_ diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h 
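For context (a sketch, not part of this patch): the generated builder API above can be used to assemble, finish, and verify a DmlGraphDesc buffer directly. A minimal, hypothetical example follows; the operator type string and the empty name lists are placeholders, and the helper name `BuildTinyGraph` is invented for illustration.

```
// Sketch only: build, finish, and verify a tiny DmlGraphDesc flatbuffer.
#include <cassert>
#include <vector>
#include "DmlGraphDesc_generated.h"

flatbuffers::DetachedBuffer BuildTinyGraph()
{
    flatbuffers::FlatBufferBuilder fbb;

    // A placeholder operator node with no inputs, outputs, or attributes.
    auto opDesc = dml::ir::CreateOperatorNodeDescDirect(fbb, "DML_OPERATOR_IDENTITY");
    auto node = dml::ir::CreateDmlGraphNodeDirect(
        fbb,
        dml::ir::NodeDesc_OperatorNodeDesc,
        opDesc.Union(),
        "node0");

    std::vector<flatbuffers::Offset<dml::ir::DmlGraphNode>> nodes = { node };
    std::vector<flatbuffers::Offset<flatbuffers::String>> graphInputs;   // none
    std::vector<flatbuffers::Offset<flatbuffers::String>> graphOutputs;  // none

    auto graph = dml::ir::CreateDmlGraphDescDirect(fbb, &nodes, &graphInputs, &graphOutputs);
    dml::ir::FinishDmlGraphDescBuffer(fbb, graph);

    // Verify the finished buffer before handing it out.
    flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize());
    assert(dml::ir::VerifyDmlGraphDescBuffer(verifier));

    return fbb.Release();
}
```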
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h new file mode 100644 index 0000000000000..9decf0dce1bb2 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "DmlSerializedGraphDesc.h" + +struct NodeIndex +{ + uint32_t nodeIndex; + uint32_t nodeOutputIndex; +}; + +DmlSerializedGraphDesc DeserializeDmlGraph( + const uint8_t* flatbufferGraphDescBlob, + /*out*/ std::vector>& rawData); \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h new file mode 100644 index 0000000000000..d8d069da906b7 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "DmlGraphDesc_generated.h" + +struct DmlSerializedGraphDesc; + +flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& graphDesc); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h new file mode 100644 index 0000000000000..51c3d6c81244b --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h @@ -0,0 +1,73 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. 
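For context (a sketch, not part of this patch): the two new headers above expose the serializer entry point `SerializeDmlGraph` and the deserializer entry point `DeserializeDmlGraph`. A minimal round trip might look like the following, using the `DmlSerializedGraphDesc` type defined in DmlSerializedGraphDesc.h just below; the element type of the `rawData` out-parameter is an assumption (its template arguments were lost in the text above), and `RoundTripDmlGraph` is a hypothetical helper name.

```
// Sketch only: serialize an in-memory graph description and read it back.
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>
#include "DmlGraphSerialization.h"
#include "DmlGraphDeserialization.h"

void RoundTripDmlGraph(const DmlSerializedGraphDesc& graphDesc)
{
    // Flatten the graph (nodes plus input/output/intermediate edges) into a flatbuffer blob.
    flatbuffers::DetachedBuffer blob = SerializeDmlGraph(graphDesc);

    // Deserialize it back. rawData is assumed to receive ownership of any constant
    // data referenced by the returned graph, so it must outlive that graph.
    std::vector<std::unique_ptr<std::byte[]>> rawData;
    DmlSerializedGraphDesc roundTripped = DeserializeDmlGraph(blob.data(), rawData);

    // Node and edge counts should survive the round trip unchanged.
    assert(roundTripped.Nodes.size() == graphDesc.Nodes.size());
    assert(roundTripped.IntermediateEdges.size() == graphDesc.IntermediateEdges.size());
}
```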
+// +//----------------------------------------------------------------------------- + +#pragma once + +struct ConstantName +{ + std::string name; +}; + +struct ConstantData +{ + std::byte* data; + uint64_t dataSize; +}; + +using DmlSerializedGraphNodeConstantVariant = std::variant< + ConstantName, + ConstantData +>; + +using DmlSerializedGraphNodeDescVariant = std::variant< + AbstractOperatorDesc, + DmlSerializedGraphNodeConstantVariant +>; + +struct DmlSerializedGraphNode +{ + DmlSerializedGraphNodeDescVariant Desc; + std::string Name; +}; + +struct DmlInputSerializedGraphEdge +{ + uint32_t GraphInputIndex; + uint32_t ToNodeIndex; + uint32_t ToNodeInputIndex; + std::string Name; +}; + +struct DmlOutputSerializedGraphEdge +{ + uint32_t FromNodeIndex; + uint32_t FromNodeOutputIndex; + uint32_t GraphOutputIndex; + std::string Name; +}; + +struct DmlIntermediateSerializedGraphEdge +{ + uint32_t FromNodeIndex; + uint32_t FromNodeOutputIndex; + uint32_t ToNodeIndex; + uint32_t ToNodeInputIndex; + std::string Name; +}; + +struct DmlSerializedGraphDesc +{ + uint32_t InputCount; + uint32_t OutputCount; + // Nodes must be listed in topological order for deserialization to work: + // when an intermediate edge is created during deserialization, the node the + // edge comes from must already have been visited before the node it feeds into. + std::vector Nodes; + std::vector InputEdges; + std::vector OutputEdges; + std::vector IntermediateEdges; +}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 99218c135f058..4be41ad3924a2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -425,7 +425,6 @@ inline std::vector GetFields(const DML_AVERAGE_POOLING_OPERATOR_D OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.IncludePadding))), }; } - inline std::vector GetFields(const DML_AVERAGE_POOLING1_OPERATOR_DESC& desc) { return { @@ -502,24 +501,6 @@ inline std::vector GetFields(const DML_ROI_POOLING_OPERATOR_DESC& OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.PooledSize))), }; } -inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6],
ToOperatorFieldType(static_cast(desc.DimensionCount))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), - }; -} inline std::vector GetFields(const DML_SLICE_OPERATOR_DESC& desc) { return { @@ -1488,6 +1469,37 @@ inline std::vector GetFields(const DML_MULTIHEAD_ATTENTION_OPERAT OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[17], ToOperatorFieldType(static_cast(desc.MaskType))), }; } +inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} +inline std::vector GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + 
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1680,6 +1692,23 @@ inline std::vector GetFields(const DML_ACTIVATION_GELU_OPERATOR_D OperatorField(&DML_ACTIVATION_GELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_ACTIVATION_SWISH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.SigmoidInputScale))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Beta))), + }; +} inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) { switch (operatorType) @@ -1826,6 +1855,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_RESAMPLE_GRAD1: return DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA; case DML_OPERATOR_DIAGONAL_MATRIX1: return DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA; case DML_OPERATOR_MULTIHEAD_ATTENTION: return DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA; + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; @@ -1850,6 +1881,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_SHRINK: return DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_GELU: return 
DML_ACTIVATION_GELU_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_SWISH: return DML_ACTIVATION_SWISH_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_HARD_SWISH: return DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA; default: ORT_THROW_HR(E_INVALIDARG); @@ -2431,6 +2464,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return AbstractOperatorDesc( + &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: + return AbstractOperatorDesc( + &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, @@ -2527,13 +2568,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ACTIVATION_GELU_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); -#pragma warning(push) -#pragma warning(disable: 4063) - case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + case DML_OPERATOR_ACTIVATION_SWISH: return AbstractOperatorDesc( - &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); -#pragma warning(pop) + &DML_ACTIVATION_SWISH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_HARD_SWISH: + return AbstractOperatorDesc( + &DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); default: ORT_THROW_HR(E_INVALIDARG); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h index 25f0dd26c6067..a94bb67b68d36 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h @@ -15,32 +15,34 @@ using ApiAttributeVariant = std::variant< const FLOAT*, const DML_SCALE_BIAS*, DML_SIZE_2D, - DML_SCALAR_UNION + DML_SCALAR_UNION, + BOOL >; namespace OperatorFieldTypes { using TensorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC using TensorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY - using OperatorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC - using OperatorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY + using FusedActivationOperatorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC + using FusedActivationOperatorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY using UInt = uint32_t; // DML_SCHEMA_FIELD_TYPE_UINT using UInt64 = uint64_t; // DML_SCHEMA_FIELD_TYPE_UINT64 using Int = int32_t; // DML_SCHEMA_FIELD_TYPE_INT using Float = float; // DML_SCHEMA_FIELD_TYPE_FLOAT - using UIntArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY - using IntArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_INT_ARRAY - using FloatArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY + using UIntArray = std::vector; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY + using IntArray = std::vector; // DML_SCHEMA_FIELD_TYPE_INT_ARRAY + using FloatArray = std::vector; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY using 
ScaleBias = std::optional; // DML_SCHEMA_FIELD_TYPE_SCALE_BIAS using Size2D = DML_SIZE_2D; // DML_SCHEMA_FIELD_TYPE_SIZE_2D using ScalarUnion = DML_SCALAR_UNION; // DML_SCHEMA_FIELD_TYPE_SCALAR_UNION + using Bool = bool; // DML_SCHEMA_FIELD_TYPE_BOOL } using OperatorFieldVariant = std::variant< OperatorFieldTypes::TensorDesc, OperatorFieldTypes::TensorDescArray, - OperatorFieldTypes::OperatorDesc, - OperatorFieldTypes::OperatorDescArray, + OperatorFieldTypes::FusedActivationOperatorDesc, + OperatorFieldTypes::FusedActivationOperatorDescArray, OperatorFieldTypes::UInt, OperatorFieldTypes::UInt64, OperatorFieldTypes::Int, @@ -50,7 +52,8 @@ using OperatorFieldVariant = std::variant< OperatorFieldTypes::FloatArray, OperatorFieldTypes::ScaleBias, OperatorFieldTypes::Size2D, - OperatorFieldTypes::ScalarUnion + OperatorFieldTypes::ScalarUnion, + OperatorFieldTypes::Bool >; class OperatorField @@ -80,11 +83,11 @@ class OperatorField const OperatorFieldTypes::TensorDescArray& AsTensorDescArray() const { return std::get(m_data); } OperatorFieldTypes::TensorDescArray& AsTensorDescArray() { return std::get(m_data); } - const OperatorFieldTypes::OperatorDesc& AsOperatorDesc() const { return std::get(m_data); } - OperatorFieldTypes::OperatorDesc& AsOperatorDesc() { return std::get(m_data); } + const OperatorFieldTypes::FusedActivationOperatorDesc& AsFusedActivationOperatorDesc() const { return std::get(m_data); } + OperatorFieldTypes::FusedActivationOperatorDesc& AsFusedActivationOperatorDesc() { return std::get(m_data); } - const OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() const { return std::get(m_data); } - OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() { return std::get(m_data); } + const OperatorFieldTypes::FusedActivationOperatorDescArray& AsFusedActivationOperatorDescArray() const { return std::get(m_data); } + OperatorFieldTypes::FusedActivationOperatorDescArray& AsFusedActivationOperatorDescArray() { return std::get(m_data); } const OperatorFieldTypes::UInt& AsUInt() const { return std::get(m_data); } OperatorFieldTypes::UInt& AsUInt() { return std::get(m_data); } @@ -116,6 +119,9 @@ class OperatorField const OperatorFieldTypes::ScalarUnion& AsScalarUnion() const { return std::get(m_data); } OperatorFieldTypes::ScalarUnion& AsScalarUnion() { return std::get(m_data); } + const OperatorFieldTypes::Bool& AsBool() const { return std::get(m_data); } + OperatorFieldTypes::Bool& AsBool() { return std::get(m_data); } + private: const DML_SCHEMA_FIELD* m_schema; OperatorFieldVariant m_data; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h new file mode 100644 index 0000000000000..167a913bb0132 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h @@ -0,0 +1,1318 @@ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_ +#define FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace dml { +namespace ir { +namespace operatorFieldTypes { + +struct AttributeDesc; +struct AttributeDescBuilder; + +struct Activation; +struct ActivationBuilder; + +struct ActivationArray; +struct ActivationArrayBuilder; + +struct UInt8; + +struct UInt16; + +struct UInt32; 
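For context (a sketch, not part of this patch): the flatbuffers `AttributeFieldVariant` union declared just below is meant to mirror the C++ `OperatorFieldVariant` above, including the newly added `Bool` alternative. The hypothetical helper below illustrates the intended correspondence for a few alternatives; it is not the serializer's actual mapping code, and it assumes the aliases from GeneratedSchemaTypes.h are visible.

```
// Sketch only: map a C++ OperatorFieldVariant alternative to the flatbuffers union tag.
#include <variant>
#include "GeneratedSchemaTypes.h"
#include "OperatorFieldTypes_generated.h"

inline dml::ir::operatorFieldTypes::AttributeFieldVariant
GetAttributeFieldVariantTag(const OperatorFieldVariant& value)
{
    using namespace dml::ir::operatorFieldTypes;
    if (std::holds_alternative<OperatorFieldTypes::FusedActivationOperatorDesc>(value)) return AttributeFieldVariant_Activation;
    if (std::holds_alternative<OperatorFieldTypes::FusedActivationOperatorDescArray>(value)) return AttributeFieldVariant_ActivationArray;
    if (std::holds_alternative<OperatorFieldTypes::UInt>(value))   return AttributeFieldVariant_UInt32;
    if (std::holds_alternative<OperatorFieldTypes::UInt64>(value)) return AttributeFieldVariant_UInt64;
    if (std::holds_alternative<OperatorFieldTypes::Int>(value))    return AttributeFieldVariant_Int32;
    if (std::holds_alternative<OperatorFieldTypes::Float>(value))  return AttributeFieldVariant_Float32;
    if (std::holds_alternative<OperatorFieldTypes::Bool>(value))   return AttributeFieldVariant_Bool;
    // Array and struct alternatives (UIntArray, IntArray, FloatArray, ScaleBias,
    // Size2D, ScalarUnion) would map analogously onto their union counterparts.
    return AttributeFieldVariant_NONE;
}
```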
+ +struct UInt64; + +struct Int8; + +struct Int16; + +struct Int32; + +struct Int64; + +struct Float32; + +struct Float64; + +struct UIntArray; +struct UIntArrayBuilder; + +struct IntArray; +struct IntArrayBuilder; + +struct FloatArray; +struct FloatArrayBuilder; + +struct ScaleBias; + +struct Size2D; + +struct ByteArray; + +struct ScalarUnionData; +struct ScalarUnionDataBuilder; + +struct Bool; + +enum AttributeFieldVariant { + AttributeFieldVariant_NONE = 0, + AttributeFieldVariant_Activation = 1, + AttributeFieldVariant_ActivationArray = 2, + AttributeFieldVariant_UInt32 = 3, + AttributeFieldVariant_UInt64 = 4, + AttributeFieldVariant_Int32 = 5, + AttributeFieldVariant_Float32 = 6, + AttributeFieldVariant_UIntArray = 7, + AttributeFieldVariant_IntArray = 8, + AttributeFieldVariant_FloatArray = 9, + AttributeFieldVariant_ScaleBias = 10, + AttributeFieldVariant_Size2D = 11, + AttributeFieldVariant_ScalarUnionData = 12, + AttributeFieldVariant_Bool = 13, + AttributeFieldVariant_MIN = AttributeFieldVariant_NONE, + AttributeFieldVariant_MAX = AttributeFieldVariant_Bool +}; + +inline const AttributeFieldVariant (&EnumValuesAttributeFieldVariant())[14] { + static const AttributeFieldVariant values[] = { + AttributeFieldVariant_NONE, + AttributeFieldVariant_Activation, + AttributeFieldVariant_ActivationArray, + AttributeFieldVariant_UInt32, + AttributeFieldVariant_UInt64, + AttributeFieldVariant_Int32, + AttributeFieldVariant_Float32, + AttributeFieldVariant_UIntArray, + AttributeFieldVariant_IntArray, + AttributeFieldVariant_FloatArray, + AttributeFieldVariant_ScaleBias, + AttributeFieldVariant_Size2D, + AttributeFieldVariant_ScalarUnionData, + AttributeFieldVariant_Bool + }; + return values; +} + +inline const char * const *EnumNamesAttributeFieldVariant() { + static const char * const names[15] = { + "NONE", + "Activation", + "ActivationArray", + "UInt32", + "UInt64", + "Int32", + "Float32", + "UIntArray", + "IntArray", + "FloatArray", + "ScaleBias", + "Size2D", + "ScalarUnionData", + "Bool", + nullptr + }; + return names; +} + +inline const char *EnumNameAttributeFieldVariant(AttributeFieldVariant e) { + if (flatbuffers::IsOutRange(e, AttributeFieldVariant_NONE, AttributeFieldVariant_Bool)) return ""; + const size_t index = static_cast(e); + return EnumNamesAttributeFieldVariant()[index]; +} + +template struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_NONE; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Activation; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_ActivationArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_UInt32; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_UInt64; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Int32; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Float32; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_UIntArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = 
AttributeFieldVariant_IntArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_FloatArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_ScaleBias; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Size2D; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_ScalarUnionData; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Bool; +}; + +bool VerifyAttributeFieldVariant(flatbuffers::Verifier &verifier, const void *obj, AttributeFieldVariant type); +bool VerifyAttributeFieldVariantVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +enum ScalarVariant { + ScalarVariant_NONE = 0, + ScalarVariant_ByteArray = 1, + ScalarVariant_Int8 = 2, + ScalarVariant_UInt8 = 3, + ScalarVariant_Int16 = 4, + ScalarVariant_UInt16 = 5, + ScalarVariant_Int32 = 6, + ScalarVariant_UInt32 = 7, + ScalarVariant_Int64 = 8, + ScalarVariant_UInt64 = 9, + ScalarVariant_Float32 = 10, + ScalarVariant_Float64 = 11, + ScalarVariant_MIN = ScalarVariant_NONE, + ScalarVariant_MAX = ScalarVariant_Float64 +}; + +inline const ScalarVariant (&EnumValuesScalarVariant())[12] { + static const ScalarVariant values[] = { + ScalarVariant_NONE, + ScalarVariant_ByteArray, + ScalarVariant_Int8, + ScalarVariant_UInt8, + ScalarVariant_Int16, + ScalarVariant_UInt16, + ScalarVariant_Int32, + ScalarVariant_UInt32, + ScalarVariant_Int64, + ScalarVariant_UInt64, + ScalarVariant_Float32, + ScalarVariant_Float64 + }; + return values; +} + +inline const char * const *EnumNamesScalarVariant() { + static const char * const names[13] = { + "NONE", + "ByteArray", + "Int8", + "UInt8", + "Int16", + "UInt16", + "Int32", + "UInt32", + "Int64", + "UInt64", + "Float32", + "Float64", + nullptr + }; + return names; +} + +inline const char *EnumNameScalarVariant(ScalarVariant e) { + if (flatbuffers::IsOutRange(e, ScalarVariant_NONE, ScalarVariant_Float64)) return ""; + const size_t index = static_cast(e); + return EnumNamesScalarVariant()[index]; +} + +template struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_NONE; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_ByteArray; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int8; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt8; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int16; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt16; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int32; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt32; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int64; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt64; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = 
ScalarVariant_Float32; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Float64; +}; + +bool VerifyScalarVariant(flatbuffers::Verifier &verifier, const void *obj, ScalarVariant type); +bool VerifyScalarVariantVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) UInt8 FLATBUFFERS_FINAL_CLASS { + private: + uint8_t data_; + + public: + UInt8() { + memset(static_cast(this), 0, sizeof(UInt8)); + } + UInt8(uint8_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + uint8_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(uint8_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt8, 1); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(2) UInt16 FLATBUFFERS_FINAL_CLASS { + private: + uint16_t data_; + + public: + UInt16() { + memset(static_cast(this), 0, sizeof(UInt16)); + } + UInt16(uint16_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + uint16_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(uint16_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt16, 2); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) UInt32 FLATBUFFERS_FINAL_CLASS { + private: + uint32_t data_; + + public: + UInt32() { + memset(static_cast(this), 0, sizeof(UInt32)); + } + UInt32(uint32_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + uint32_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(uint32_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt32, 4); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) UInt64 FLATBUFFERS_FINAL_CLASS { + private: + uint64_t data_; + + public: + UInt64() { + memset(static_cast(this), 0, sizeof(UInt64)); + } + UInt64(uint64_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + uint64_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(uint64_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt64, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) Int8 FLATBUFFERS_FINAL_CLASS { + private: + int8_t data_; + + public: + Int8() { + memset(static_cast(this), 0, sizeof(Int8)); + } + Int8(int8_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + int8_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(int8_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int8, 1); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(2) Int16 FLATBUFFERS_FINAL_CLASS { + private: + int16_t data_; + + public: + Int16() { + memset(static_cast(this), 0, sizeof(Int16)); + } + Int16(int16_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + int16_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(int16_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int16, 2); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Int32 FLATBUFFERS_FINAL_CLASS { + private: + int32_t data_; + + public: + Int32() { + memset(static_cast(this), 0, sizeof(Int32)); + } + Int32(int32_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + int32_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(int32_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int32, 4); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Int64 
FLATBUFFERS_FINAL_CLASS { + private: + int64_t data_; + + public: + Int64() { + memset(static_cast(this), 0, sizeof(Int64)); + } + Int64(int64_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + int64_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(int64_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int64, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Float32 FLATBUFFERS_FINAL_CLASS { + private: + float data_; + + public: + Float32() { + memset(static_cast(this), 0, sizeof(Float32)); + } + Float32(float _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + float data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(float _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Float32, 4); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Float64 FLATBUFFERS_FINAL_CLASS { + private: + double data_; + + public: + Float64() { + memset(static_cast(this), 0, sizeof(Float64)); + } + Float64(double _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + double data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(double _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Float64, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) ScaleBias FLATBUFFERS_FINAL_CLASS { + private: + float scale_; + float bias_; + + public: + ScaleBias() { + memset(static_cast(this), 0, sizeof(ScaleBias)); + } + ScaleBias(float _scale, float _bias) + : scale_(flatbuffers::EndianScalar(_scale)), + bias_(flatbuffers::EndianScalar(_bias)) { + } + float scale() const { + return flatbuffers::EndianScalar(scale_); + } + void mutate_scale(float _scale) { + flatbuffers::WriteScalar(&scale_, _scale); + } + float bias() const { + return flatbuffers::EndianScalar(bias_); + } + void mutate_bias(float _bias) { + flatbuffers::WriteScalar(&bias_, _bias); + } +}; +FLATBUFFERS_STRUCT_END(ScaleBias, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Size2D FLATBUFFERS_FINAL_CLASS { + private: + uint32_t width_; + uint32_t height_; + + public: + Size2D() { + memset(static_cast(this), 0, sizeof(Size2D)); + } + Size2D(uint32_t _width, uint32_t _height) + : width_(flatbuffers::EndianScalar(_width)), + height_(flatbuffers::EndianScalar(_height)) { + } + uint32_t width() const { + return flatbuffers::EndianScalar(width_); + } + void mutate_width(uint32_t _width) { + flatbuffers::WriteScalar(&width_, _width); + } + uint32_t height() const { + return flatbuffers::EndianScalar(height_); + } + void mutate_height(uint32_t _height) { + flatbuffers::WriteScalar(&height_, _height); + } +}; +FLATBUFFERS_STRUCT_END(Size2D, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) ByteArray FLATBUFFERS_FINAL_CLASS { + private: + uint8_t data_[8]; + + public: + ByteArray() { + memset(static_cast(this), 0, sizeof(ByteArray)); + } + const flatbuffers::Array *data() const { + return reinterpret_cast *>(data_); + } + flatbuffers::Array *mutable_data() { + return reinterpret_cast *>(data_); + } +}; +FLATBUFFERS_STRUCT_END(ByteArray, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) Bool FLATBUFFERS_FINAL_CLASS { + private: + uint8_t data_; + + public: + Bool() { + memset(static_cast(this), 0, sizeof(Bool)); + } + Bool(bool _data) + : data_(flatbuffers::EndianScalar(static_cast(_data))) { + } + bool data() const { + return flatbuffers::EndianScalar(data_) != 0; + } + void mutate_data(bool _data) { + flatbuffers::WriteScalar(&data_, static_cast(_data)); + } +}; +FLATBUFFERS_STRUCT_END(Bool, 
1); + +struct AttributeDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef AttributeDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_VAL_TYPE = 6, + VT_VAL = 8 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + flatbuffers::String *mutable_name() { + return GetPointer(VT_NAME); + } + dml::ir::operatorFieldTypes::AttributeFieldVariant val_type() const { + return static_cast(GetField(VT_VAL_TYPE, 0)); + } + const void *val() const { + return GetPointer(VT_VAL); + } + template const T *val_as() const; + const dml::ir::operatorFieldTypes::Activation *val_as_Activation() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::ActivationArray *val_as_ActivationArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt32 *val_as_UInt32() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt32 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt64 *val_as_UInt64() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt64 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int32 *val_as_Int32() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Int32 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Float32 *val_as_Float32() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Float32 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::UIntArray *val_as_UIntArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UIntArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::IntArray *val_as_IntArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_IntArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::FloatArray *val_as_FloatArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_FloatArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::ScaleBias *val_as_ScaleBias() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Size2D *val_as_Size2D() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Size2D ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::ScalarUnionData *val_as_ScalarUnionData() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ScalarUnionData ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Bool *val_as_Bool() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Bool ? 
static_cast(val()) : nullptr; + } + void *mutable_val() { + return GetPointer(VT_VAL); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyField(verifier, VT_VAL_TYPE) && + VerifyOffset(verifier, VT_VAL) && + VerifyAttributeFieldVariant(verifier, val(), val_type()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::operatorFieldTypes::Activation *AttributeDesc::val_as() const { + return val_as_Activation(); +} + +template<> inline const dml::ir::operatorFieldTypes::ActivationArray *AttributeDesc::val_as() const { + return val_as_ActivationArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt32 *AttributeDesc::val_as() const { + return val_as_UInt32(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt64 *AttributeDesc::val_as() const { + return val_as_UInt64(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int32 *AttributeDesc::val_as() const { + return val_as_Int32(); +} + +template<> inline const dml::ir::operatorFieldTypes::Float32 *AttributeDesc::val_as() const { + return val_as_Float32(); +} + +template<> inline const dml::ir::operatorFieldTypes::UIntArray *AttributeDesc::val_as() const { + return val_as_UIntArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::IntArray *AttributeDesc::val_as() const { + return val_as_IntArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::FloatArray *AttributeDesc::val_as() const { + return val_as_FloatArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::ScaleBias *AttributeDesc::val_as() const { + return val_as_ScaleBias(); +} + +template<> inline const dml::ir::operatorFieldTypes::Size2D *AttributeDesc::val_as() const { + return val_as_Size2D(); +} + +template<> inline const dml::ir::operatorFieldTypes::ScalarUnionData *AttributeDesc::val_as() const { + return val_as_ScalarUnionData(); +} + +template<> inline const dml::ir::operatorFieldTypes::Bool *AttributeDesc::val_as() const { + return val_as_Bool(); +} + +struct AttributeDescBuilder { + typedef AttributeDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(AttributeDesc::VT_NAME, name); + } + void add_val_type(dml::ir::operatorFieldTypes::AttributeFieldVariant val_type) { + fbb_.AddElement(AttributeDesc::VT_VAL_TYPE, static_cast(val_type), 0); + } + void add_val(flatbuffers::Offset val) { + fbb_.AddOffset(AttributeDesc::VT_VAL, val); + } + explicit AttributeDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + AttributeDescBuilder &operator=(const AttributeDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateAttributeDesc( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0, + dml::ir::operatorFieldTypes::AttributeFieldVariant val_type = dml::ir::operatorFieldTypes::AttributeFieldVariant_NONE, + flatbuffers::Offset val = 0) { + AttributeDescBuilder builder_(_fbb); + builder_.add_val(val); + builder_.add_name(name); + builder_.add_val_type(val_type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateAttributeDescDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant val_type = 
dml::ir::operatorFieldTypes::AttributeFieldVariant_NONE, + flatbuffers::Offset val = 0) { + auto name__ = name ? _fbb.CreateString(name) : 0; + return dml::ir::operatorFieldTypes::CreateAttributeDesc( + _fbb, + name__, + val_type, + val); +} + +struct Activation FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ActivationBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TYPE = 4, + VT_ATTRIBUTES = 6 + }; + const flatbuffers::String *type() const { + return GetPointer(VT_TYPE); + } + flatbuffers::String *mutable_type() { + return GetPointer(VT_TYPE); + } + const flatbuffers::Vector> *attributes() const { + return GetPointer> *>(VT_ATTRIBUTES); + } + flatbuffers::Vector> *mutable_attributes() { + return GetPointer> *>(VT_ATTRIBUTES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TYPE) && + verifier.VerifyString(type()) && + VerifyOffset(verifier, VT_ATTRIBUTES) && + verifier.VerifyVector(attributes()) && + verifier.VerifyVectorOfTables(attributes()) && + verifier.EndTable(); + } +}; + +struct ActivationBuilder { + typedef Activation Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_type(flatbuffers::Offset type) { + fbb_.AddOffset(Activation::VT_TYPE, type); + } + void add_attributes(flatbuffers::Offset>> attributes) { + fbb_.AddOffset(Activation::VT_ATTRIBUTES, attributes); + } + explicit ActivationBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ActivationBuilder &operator=(const ActivationBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateActivation( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset type = 0, + flatbuffers::Offset>> attributes = 0) { + ActivationBuilder builder_(_fbb); + builder_.add_attributes(attributes); + builder_.add_type(type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateActivationDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *type = nullptr, + const std::vector> *attributes = nullptr) { + auto type__ = type ? _fbb.CreateString(type) : 0; + auto attributes__ = attributes ? 
_fbb.CreateVector>(*attributes) : 0; + return dml::ir::operatorFieldTypes::CreateActivation( + _fbb, + type__, + attributes__); +} + +struct ActivationArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ActivationArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector> *data() const { + return GetPointer> *>(VT_DATA); + } + flatbuffers::Vector> *mutable_data() { + return GetPointer> *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.VerifyVectorOfTables(data()) && + verifier.EndTable(); + } +}; + +struct ActivationArrayBuilder { + typedef ActivationArray Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset>> data) { + fbb_.AddOffset(ActivationArray::VT_DATA, data); + } + explicit ActivationArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ActivationArrayBuilder &operator=(const ActivationArrayBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateActivationArray( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> data = 0) { + ActivationArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateActivationArrayDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector> *data = nullptr) { + auto data__ = data ? _fbb.CreateVector>(*data) : 0; + return dml::ir::operatorFieldTypes::CreateActivationArray( + _fbb, + data__); +} + +struct UIntArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef UIntArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + flatbuffers::Vector *mutable_data() { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct UIntArrayBuilder { + typedef UIntArray Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(UIntArray::VT_DATA, data); + } + explicit UIntArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + UIntArrayBuilder &operator=(const UIntArrayBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateUIntArray( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + UIntArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateUIntArrayDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? 
_fbb.CreateVector(*data) : 0; + return dml::ir::operatorFieldTypes::CreateUIntArray( + _fbb, + data__); +} + +struct IntArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef IntArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + flatbuffers::Vector *mutable_data() { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct IntArrayBuilder { + typedef IntArray Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(IntArray::VT_DATA, data); + } + explicit IntArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + IntArrayBuilder &operator=(const IntArrayBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateIntArray( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + IntArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateIntArrayDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? _fbb.CreateVector(*data) : 0; + return dml::ir::operatorFieldTypes::CreateIntArray( + _fbb, + data__); +} + +struct FloatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef FloatArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + flatbuffers::Vector *mutable_data() { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct FloatArrayBuilder { + typedef FloatArray Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(FloatArray::VT_DATA, data); + } + explicit FloatArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + FloatArrayBuilder &operator=(const FloatArrayBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateFloatArray( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + FloatArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateFloatArrayDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? 
_fbb.CreateVector(*data) : 0; + return dml::ir::operatorFieldTypes::CreateFloatArray( + _fbb, + data__); +} + +struct ScalarUnionData FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ScalarUnionDataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA_TYPE = 4, + VT_DATA = 6 + }; + dml::ir::operatorFieldTypes::ScalarVariant data_type() const { + return static_cast(GetField(VT_DATA_TYPE, 0)); + } + const void *data() const { + return GetPointer(VT_DATA); + } + template const T *data_as() const; + const dml::ir::operatorFieldTypes::ByteArray *data_as_ByteArray() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_ByteArray ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int8 *data_as_Int8() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int8 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt8 *data_as_UInt8() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt8 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int16 *data_as_Int16() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int16 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt16 *data_as_UInt16() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt16 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int32 *data_as_Int32() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int32 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt32 *data_as_UInt32() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt32 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int64 *data_as_Int64() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int64 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt64 *data_as_UInt64() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt64 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Float32 *data_as_Float32() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Float32 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Float64 *data_as_Float64() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Float64 ? 
static_cast(data()) : nullptr; + } + void *mutable_data() { + return GetPointer(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DATA_TYPE) && + VerifyOffset(verifier, VT_DATA) && + VerifyScalarVariant(verifier, data(), data_type()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::operatorFieldTypes::ByteArray *ScalarUnionData::data_as() const { + return data_as_ByteArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int8 *ScalarUnionData::data_as() const { + return data_as_Int8(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt8 *ScalarUnionData::data_as() const { + return data_as_UInt8(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int16 *ScalarUnionData::data_as() const { + return data_as_Int16(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt16 *ScalarUnionData::data_as() const { + return data_as_UInt16(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int32 *ScalarUnionData::data_as() const { + return data_as_Int32(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt32 *ScalarUnionData::data_as() const { + return data_as_UInt32(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int64 *ScalarUnionData::data_as() const { + return data_as_Int64(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt64 *ScalarUnionData::data_as() const { + return data_as_UInt64(); +} + +template<> inline const dml::ir::operatorFieldTypes::Float32 *ScalarUnionData::data_as() const { + return data_as_Float32(); +} + +template<> inline const dml::ir::operatorFieldTypes::Float64 *ScalarUnionData::data_as() const { + return data_as_Float64(); +} + +struct ScalarUnionDataBuilder { + typedef ScalarUnionData Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data_type(dml::ir::operatorFieldTypes::ScalarVariant data_type) { + fbb_.AddElement(ScalarUnionData::VT_DATA_TYPE, static_cast(data_type), 0); + } + void add_data(flatbuffers::Offset data) { + fbb_.AddOffset(ScalarUnionData::VT_DATA, data); + } + explicit ScalarUnionDataBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ScalarUnionDataBuilder &operator=(const ScalarUnionDataBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateScalarUnionData( + flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::operatorFieldTypes::ScalarVariant data_type = dml::ir::operatorFieldTypes::ScalarVariant_NONE, + flatbuffers::Offset data = 0) { + ScalarUnionDataBuilder builder_(_fbb); + builder_.add_data(data); + builder_.add_data_type(data_type); + return builder_.Finish(); +} + +inline bool VerifyAttributeFieldVariant(flatbuffers::Verifier &verifier, const void *obj, AttributeFieldVariant type) { + switch (type) { + case AttributeFieldVariant_NONE: { + return true; + } + case AttributeFieldVariant_Activation: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_ActivationArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_UInt32: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_UInt64: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_Int32: { + return 
verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_Float32: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_UIntArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_IntArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_FloatArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_ScaleBias: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_Size2D: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_ScalarUnionData: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_Bool: { + return verifier.Verify(static_cast(obj), 0); + } + default: return true; + } +} + +inline bool VerifyAttributeFieldVariantVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyAttributeFieldVariant( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline bool VerifyScalarVariant(flatbuffers::Verifier &verifier, const void *obj, ScalarVariant type) { + switch (type) { + case ScalarVariant_NONE: { + return true; + } + case ScalarVariant_ByteArray: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Int8: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_UInt8: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Int16: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_UInt16: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Int32: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_UInt32: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Int64: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_UInt64: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Float32: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Float64: { + return verifier.Verify(static_cast(obj), 0); + } + default: return true; + } +} + +inline bool VerifyScalarVariantVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyScalarVariant( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +} // namespace operatorFieldTypes +} // namespace ir +} // namespace dml + +#endif // FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_ diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h index 5285481485184..1bc694dfe90c2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h @@ -26,14 +26,14 @@ namespace 
SchemaHelpers return field; } - inline OperatorFieldTypes::OperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value) + inline OperatorFieldTypes::FusedActivationOperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value) { - return value ? OperatorFieldTypes::OperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt; + return value ? OperatorFieldTypes::FusedActivationOperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt; } - inline OperatorFieldTypes::OperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count) + inline OperatorFieldTypes::FusedActivationOperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count) { - OperatorFieldTypes::OperatorDescArray field; + OperatorFieldTypes::FusedActivationOperatorDescArray field; if (values && count != 0) { field.emplace(count); @@ -65,13 +65,17 @@ namespace SchemaHelpers return value; } + inline OperatorFieldTypes::Bool ToOperatorFieldType(bool value) + { + return value; + } + inline OperatorFieldTypes::UIntArray ToOperatorFieldType(const uint32_t* values, uint32_t count) { OperatorFieldTypes::UIntArray field; if (values && count != 0) { - field.emplace(count); - std::copy_n(values, count, field->begin()); + field.assign(values, values + count); } return field; } @@ -81,8 +85,7 @@ namespace SchemaHelpers OperatorFieldTypes::IntArray field; if (values && count != 0) { - field.emplace(count); - std::copy_n(values, count, field->begin()); + field.assign(values, values + count); } return field; } @@ -92,8 +95,7 @@ namespace SchemaHelpers OperatorFieldTypes::FloatArray field; if (values && count != 0) { - field.emplace(count); - std::copy_n(values, count, field->begin()); + field.assign(values, values + count); } return field; } @@ -237,7 +239,7 @@ namespace SchemaHelpers { DML_OPERATOR_DESC* desc = nullptr; - const auto& value = field.AsOperatorDesc(); + const auto& value = field.AsFusedActivationOperatorDesc(); if (value) { desc = allocator->template Allocate(); @@ -251,7 +253,7 @@ namespace SchemaHelpers { DML_OPERATOR_DESC* descs = nullptr; - const auto& values = field.AsOperatorDescArray(); + const auto& values = field.AsFusedActivationOperatorDescArray(); if (values) { descs = allocator->template Allocate(values->size()); @@ -288,16 +290,20 @@ namespace SchemaHelpers dst->Write(value); } break; + case DML_SCHEMA_FIELD_TYPE_BOOL: + { + // OperatorFieldTypes::Bool is a 'bool' (1 byte) but written as 'BOOL' in op descs (4 bytes). 
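The DML_SCHEMA_FIELD_TYPE_BOOL case above is the one width change worth calling out: the in-memory field is a one-byte C++ `bool`, but the written op desc expects a four-byte `BOOL`, so the value is widened before the write. A minimal standalone sketch of that widening (the `ByteWriter` and `BOOL32` names are stand-ins invented for the example, not types from this patch):

```cpp
// Sketch: widen a 1-byte C++ bool to a 4-byte BOOL-style value before appending
// it to a byte stream, as the DML_SCHEMA_FIELD_TYPE_BOOL case does via dst->Write.
#include <cstdint>
#include <vector>

using BOOL32 = int32_t;  // stand-in for the Win32 BOOL typedef

// Hypothetical little byte-stream writer; the real code writes through 'dst'.
struct ByteWriter
{
    std::vector<uint8_t> bytes;

    template <typename T>
    void Write(const T& value)
    {
        const auto* p = reinterpret_cast<const uint8_t*>(&value);
        bytes.insert(bytes.end(), p, p + sizeof(T));
    }
};

int main()
{
    bool fieldValue = true;               // OperatorFieldTypes::Bool is 1 byte
    BOOL32 widened = fieldValue ? 1 : 0;  // serialized as a 4-byte BOOL
    ByteWriter dst;
    dst.Write(widened);                   // 4 bytes land in the stream
    return dst.bytes.size() == sizeof(BOOL32) ? 0 : 1;
}
```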
+ BOOL value = static_cast(field.AsBool()); + dst->Write(value); + } break; + case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY: { uint32_t* arrayPtr = nullptr; const auto& values = field.AsUIntArray(); - if (values) - { - arrayPtr = allocator->template Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } + arrayPtr = allocator->template Allocate(values.size()); + std::copy(values.begin(), values.end(), arrayPtr); dst->Write(arrayPtr); } break; @@ -307,11 +313,8 @@ namespace SchemaHelpers int32_t* arrayPtr = nullptr; const auto& values = field.AsIntArray(); - if (values) - { - arrayPtr = allocator->template Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } + arrayPtr = allocator->template Allocate(values.size()); + std::copy(values.begin(), values.end(), arrayPtr); dst->Write(arrayPtr); } break; @@ -321,11 +324,8 @@ namespace SchemaHelpers float* arrayPtr = nullptr; const auto& values = field.AsFloatArray(); - if (values) - { - arrayPtr = allocator->template Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } + arrayPtr = allocator->template Allocate(values.size()); + std::copy(values.begin(), values.end(), arrayPtr); dst->Write(arrayPtr); } break; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 2456b396de3f6..e6f008af5c23f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -33,10 +33,10 @@ namespace Dml::GraphDescBuilder #pragma warning(pop) static void RemoveUnconnectedNodes( - std::vector& graphNodes, - std::vector& graphInputEdges, - std::vector& graphIntermediateEdges, - std::vector& graphOutputEdges) + std::vector& graphNodes, + std::vector& graphInputEdges, + std::vector& graphIntermediateEdges, + std::vector& graphOutputEdges) { enum class NodeState { @@ -52,7 +52,7 @@ namespace Dml::GraphDescBuilder }; std::vector nodesData(graphNodes.size()); - for (const DML_INTERMEDIATE_GRAPH_EDGE_DESC& intermediateEdge : graphIntermediateEdges) + for (const DmlIntermediateSerializedGraphEdge& intermediateEdge : graphIntermediateEdges) { nodesData[intermediateEdge.ToNodeIndex].predecessorIndices.push_back(intermediateEdge.FromNodeIndex); } @@ -60,7 +60,7 @@ namespace Dml::GraphDescBuilder std::stack nodeIndicesToVisit; // Start from the outputs of the graph and traverse upwards - for (const DML_OUTPUT_GRAPH_EDGE_DESC& outputEdge : graphOutputEdges) + for (const DmlOutputSerializedGraphEdge& outputEdge : graphOutputEdges) { nodeIndicesToVisit.push(outputEdge.FromNodeIndex); } @@ -143,17 +143,44 @@ namespace Dml::GraphDescBuilder } } + + uint32_t SetAndGetDmlGraphNodeIndex( + const uint32_t operatorDmlGraphNodeIndex, + const std::string& nodeNamePrefix, + AbstractOperatorDesc& operatorDesc, + /*in_out*/std::unordered_map& operatorDmlGraphToDmlGraphNodeIndexMap, + /*in_out*/std::vector& dmlGraphNodes) + { + auto iter = operatorDmlGraphToDmlGraphNodeIndexMap.find(operatorDmlGraphNodeIndex); + if (iter != operatorDmlGraphToDmlGraphNodeIndexMap.end()) + { + return iter->second; + } + operatorDmlGraphToDmlGraphNodeIndexMap[operatorDmlGraphNodeIndex] = static_cast(dmlGraphNodes.size()); + dmlGraphNodes.push_back({operatorDesc, nodeNamePrefix + std::to_string(operatorDmlGraphNodeIndex)}); + return operatorDmlGraphToDmlGraphNodeIndexMap[operatorDmlGraphNodeIndex]; + } + 
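`SetAndGetDmlGraphNodeIndex` above memoizes the mapping from an operator-local node index to its slot in the flat, graph-wide node list, so every edge that refers to the same operator-graph node resolves to one shared entry in `dmlGraphNodes`. A stripped-down sketch of the same idea, with a toy node type standing in for `DmlSerializedGraphNode` (all names here are illustrative):

```cpp
// Sketch: assign each operator-local node a stable index in a flat node list,
// creating the entry on first sight and returning the cached index afterwards.
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

struct ToyGraphNode { std::string name; };  // stand-in for DmlSerializedGraphNode

uint32_t SetAndGetNodeIndex(
    uint32_t localNodeIndex,
    const std::string& namePrefix,
    std::unordered_map<uint32_t, uint32_t>& localToGlobal,
    std::vector<ToyGraphNode>& globalNodes)
{
    auto iter = localToGlobal.find(localNodeIndex);
    if (iter != localToGlobal.end())
    {
        return iter->second;  // node already materialized in the flat list
    }
    uint32_t globalIndex = static_cast<uint32_t>(globalNodes.size());
    localToGlobal[localNodeIndex] = globalIndex;
    globalNodes.push_back({namePrefix + std::to_string(localNodeIndex)});
    return globalIndex;
}

int main()
{
    std::unordered_map<uint32_t, uint32_t> map;
    std::vector<ToyGraphNode> nodes;
    uint32_t a = SetAndGetNodeIndex(3, "Conv", map, nodes);
    uint32_t b = SetAndGetNodeIndex(3, "Conv", map, nodes);  // same local index, same global index
    assert(a == b && nodes.size() == 1);
    return 0;
}
```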
+ // Terminology: + // Subgraph: partitioned ONNX graph from the original (main) ONNX graph + // DmlGraph: a graph in DML currency converted from subgraph. + // operatorDmlGraph: a graph in DML currency for a given node or operator + // Main Points to note: + // - GraphDesc will always has sequential indices for input and intermediate edges. + // - 1 onnx node can be converted to one or more dml nodes. GraphDesc BuildGraphDesc( const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, - IDMLDevice* device, const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, - gsl::span subgraphOutputs) + gsl::span subgraphOutputs, + /*out*/ std::unordered_map& serializedGraphInputIndexToSubgraphInputIndex, + /*out*/ std::unordered_map& serializedGraphLargeConstantNameToSubgraphInputIndex, + /*out*/ std::vector>& smallConstantData) { struct NodeAndIndex { @@ -161,19 +188,34 @@ namespace Dml::GraphDescBuilder uint32_t targetIndex; // The index of the input/output on the node (e.g. 1 for the second input on a node) }; - // Map from Lotus node argument names to the new node and index where it will be produced - std::unordered_map nameToNodeAndIndexMap; - std::unordered_map nodeOutputShapes; - // Map from Lotus node argument names to input indices of the fused kernel node. - std::unordered_map nameToDmlFusedNodeInputIndex; + // Map from ORT subgraph input names to indices + std::unordered_map subgraphInputNameToIndexMap; + + // - Map from ORT node's output names to DmlGraph . + // - Once a given ORT node (or operator) will be transformed into a operatorDmlGraph, + // then ORT node's output names will become output edges for the operatorDmlGraph. + // - This map will be populated for those output edges. + std::unordered_map dmlGraphNodeOutputNameToNodeAndIndexMap; + + // This map will be used to re-index an subGraphInputIndex to sequential input index + // for DmlGraph + std::unordered_map subGraphInputIndexToDmlGraphInputIndex; + + // Iterate through each node and create a corresponding node in the new graph + // We can iterate the nodes in any order because the edge connectivity will take care of the topological order + std::unordered_map> inferredOutputShapes; + + std::vector dmlGraphNodes; + std::vector dmlGraphInputEdges; + std::vector dmlGraphIntermediateEdges; + std::vector dmlGraphOutputEdges; for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex) { - const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex]; - - if (!graphInput) + const onnxruntime::NodeArg* subgraphInput = subgraphInputs[inputIndex]; + if (!subgraphInput) { // This is a workaround for when node inputs get manipulated by transformers outside of our control, // which then causes them to have a different name. If that happens we can't figure out how to @@ -181,45 +223,21 @@ namespace Dml::GraphDescBuilder // just bail early. ORT_THROW_HR(E_UNEXPECTED); } - - nameToDmlFusedNodeInputIndex.emplace(graphInput->Name(), gsl::narrow_cast(inputIndex)); - } - - StackAllocator<1024> allocator; // Used for converting abstract operator descs into DML_OPERATOR_DESC - - std::vector graphNodes; - std::vector graphInputEdges; - std::vector graphIntermediateEdges; - std::vector graphOutputEdges; - - // Avoid using separate command lists for small graphs. 
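One of the "main points" above is that the serialized GraphDesc always uses dense, sequential input-edge indices even though the subgraph's own input indices become sparse once constant inputs are handled separately; the `subGraphInputIndexToDmlGraphInputIndex` map declared above does that remapping on first use. A small self-contained sketch of the re-indexing (the index values are made up for illustration):

```cpp
// Sketch: remap sparse subgraph input indices to dense, sequential DmlGraph
// input indices, in order of first use.
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

int main()
{
    // Suppose only subgraph inputs 4 and 1 feed non-constant graph inputs.
    std::vector<uint32_t> usedSubgraphInputs = {4, 1, 4};

    std::unordered_map<uint32_t, uint32_t> subgraphToDmlGraphInputIndex;
    std::vector<uint32_t> dmlGraphInputEdgeIndices;

    for (uint32_t subgraphInputIndex : usedSubgraphInputs)
    {
        if (subgraphToDmlGraphInputIndex.find(subgraphInputIndex) == subgraphToDmlGraphInputIndex.end())
        {
            subgraphToDmlGraphInputIndex[subgraphInputIndex] =
                static_cast<uint32_t>(subgraphToDmlGraphInputIndex.size());
        }
        dmlGraphInputEdgeIndices.push_back(subgraphToDmlGraphInputIndex[subgraphInputIndex]);
    }

    // Input 4 was seen first, so it becomes DmlGraph input 0; input 1 becomes 1.
    assert((dmlGraphInputEdgeIndices == std::vector<uint32_t>{0, 1, 0}));
    return 0;
}
```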
This value can be reduced by tuning the - // flushing behavior of DmlCommandRecorder. Its current behavior is to assume that graphs contain - // enough GPU work to be worth flushing immediately. - const uint32_t minNodeCountToReuseCommandList = 5; - bool reuseCommandList = false; - - if (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice()) - { - reuseCommandList = true; + subgraphInputNameToIndexMap.emplace(subgraphInput->Name(), gsl::narrow_cast(inputIndex)); } auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName) { ComPtr tensorWrapper; - auto iter = isInitializerTransferable.find(argName); if (iter != isInitializerTransferable.end()) { // Using const_cast here is simpler than making surrounding code const correct. tensorWrapper = wil::MakeOrThrow(const_cast(iter->second.first), modelPath); } - return tensorWrapper; }; - // Iterate through each node and create a corresponding node in the new graph - // We can iterate the nodes in any order because the edge connectivity will take care of the topological order - std::unordered_map> inferredOutputShapes; for (const onnxruntime::Node* subgraphNode : subgraphNodes) { @@ -277,195 +295,206 @@ namespace Dml::GraphDescBuilder } EdgeShapes outputShapes; - DmlGraphNodeCreateInfo graphNodeCreateInfo; + DmlGraphNodeCreateInfo operatorDmlGraphCreateInfo; graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, &inputShapesOverrides, /*out*/ &outputShapes, - /*out*/ &graphNodeCreateInfo + /*out*/ &operatorDmlGraphCreateInfo ); ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size()); for (int i = 0; i < node.OutputDefs().size(); ++i) { inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i); - } - - // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex. - std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; - uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); - const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0; - size_t firstOpDescGraphNodeIndex = graphNodes.size(); - - if (isNodeAsOpDesc) + } + + // Algorithm: + // 1. Create constant nodes by iterating through operatorDmlGraph's input edges and keep a map of it, + // because there would be an intermediate edge from the constantNode and source of the intermediate edge + // should come before the destination. + // 2. Again iterate through operatorDmlGraph's input edges to create mainGraph's input and intermediate edges. + // 3. Iterate through operatorDmlGraph's intermediate edges to create mainGraph's intermediate edges. + // 4. Iterate through operatorDmlGraph's output edges to populate outputEdgeNameToDmlGraphNodeAndIndex + // 5. While performing step 2, 3, and 4, insert operatorDmlGraphNode to the mainDmlGraphNode list. + + for (auto& operatorDmlGraphInputEdge : operatorDmlGraphCreateInfo.inputEdges) { - // Can't populate graphNodes vector at this point, because operatorDesc may get modified later. 
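Step 1 of the algorithm above creates one constant node per constant graph input, and the loop that follows picks between two representations: constants smaller than `c_maxConstNodeDataSize` (8 bytes) are duplicated inline as raw bytes, while larger ones are referenced by a sanitized file name so their data stays out of the serialized graph. A toy version of just that classification (the `InlineBytes`/`NamedConstant` types are illustrative stand-ins for `ConstantData`/`ConstantName`):

```cpp
// Sketch: choose between inlining a small constant's bytes and referencing a
// larger constant by name, mirroring the ConstantData vs. ConstantName split.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <variant>
#include <vector>

constexpr size_t kMaxConstNodeDataSize = 8;  // mirrors c_maxConstNodeDataSize

struct InlineBytes   { std::vector<uint8_t> data; };  // stand-in for ConstantData
struct NamedConstant { std::string fileName; };       // stand-in for ConstantName
using ConstantDesc = std::variant<InlineBytes, NamedConstant>;

ConstantDesc DescribeConstant(const std::string& inputName, const std::vector<uint8_t>& bytes)
{
    if (bytes.size() < kMaxConstNodeDataSize)
    {
        return InlineBytes{bytes};    // small: duplicate the bytes into the graph
    }
    return NamedConstant{inputName};  // large: refer to it by (sanitized) name
}

int main()
{
    auto small = DescribeConstant("scale", {1, 0, 0, 0});
    auto large = DescribeConstant("weights", std::vector<uint8_t>(1024, 0));
    std::cout << std::boolalpha
              << std::holds_alternative<InlineBytes>(small) << ' '      // true
              << std::holds_alternative<NamedConstant>(large) << '\n';  // true
    return 0;
}
```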
- for (uint32_t nodeIndex = 0; nodeIndex < graphNodeCreateInfo.nodeCount; nodeIndex++) + const onnxruntime::NodeArg* arg = node.InputDefs()[operatorDmlGraphInputEdge.GraphInputIndex]; + if (arg->Exists()) { - ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]); - operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); - } + auto iter = subgraphInputNameToIndexMap.find(arg->Name()); + if (iter != subgraphInputNameToIndexMap.end() && + iter->second < isConstGpuGraphInputCount && + isConstGpuGraphInput[iter->second]) + { + DmlSerializedGraphNode constantNode = {}; + constantNode.Name = arg->Name(); + + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently + // only used for small inputs. + auto& operatorDmlGraphInputNode = operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex]; + std::vector toNodeInputTensorDescs = operatorDmlGraphInputNode->GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorDmlGraphInputEdge.ToNodeInputIndex]; + ComPtr constantInput; + + if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) + { + constantInput = constantCpuGraphInputGetter(arg->Name()); + } - graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount); - } - else - { - for (uint32_t nodeIndex = 0; nodeIndex < graphNodeCreateInfo.nodeCount; nodeIndex++) - { - ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get()); - operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); - NodeInfo nodeInfo = {}; - nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]); - graphNodes.push_back(std::move(nodeInfo)); + if (constantInput) + { + // The tensor description's size should be no larger than the constant input unless it was rounded to + // the required alignment. + assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); + size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); + auto data = static_cast(constantInput->GetData()); + std::vector tensorData(data, data + minimumConstantSize); + + smallConstantData.push_back(std::make_unique(tensorData.size())); + std::transform(tensorData.begin(), tensorData.end(), smallConstantData.back().get(), [](uint8_t b) {return static_cast(b);}); + + ConstantData constantData = {smallConstantData.back().get(), tensorData.size()}; + constantNode.Desc = constantData; + } + else + { + ConstantName constantFileName = {GetSanitizedFileName(arg->Name())}; + constantNode.Desc = constantFileName; + } + dmlGraphNodeOutputNameToNodeAndIndexMap[arg->Name()] = {static_cast(dmlGraphNodes.size()), 0}; + dmlGraphNodes.push_back(constantNode); + } } } - // map operatorGraphInputEdge as either mainGraphInputEdge or mainGraphIntermediateEdge - for (auto& operatorGraphInputEdge : graphNodeCreateInfo.inputEdges) - { - // operatorGraphInputEdge.GraphInputIndex will be the ONNX input index. - const onnxruntime::NodeArg* arg = node.InputDefs()[operatorGraphInputEdge.GraphInputIndex]; + // Create a map between operatorGraphNodeIndex to dmlGraphNodeIndex. 
+ std::unordered_map operatorDmlGraphToDmlGraphNodeIndexMap; + // map operatorDmlGraphInputEdge as either mainDmlGraphInputEdge or mainDmlGraphIntermediateEdge + for (auto& operatorDmlGraphInputEdge : operatorDmlGraphCreateInfo.inputEdges) + { + // operatorDmlGraphInputEdge.GraphInputIndex will be the ONNX input index. + const onnxruntime::NodeArg* arg = node.InputDefs()[operatorDmlGraphInputEdge.GraphInputIndex]; if (arg->Exists()) { - auto iter = nameToDmlFusedNodeInputIndex.find(arg->Name()); - uint32_t mainGraphNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphInputEdge.ToNodeIndex]; - - if (iter != nameToDmlFusedNodeInputIndex.end()) + uint32_t dmlGraphNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorDmlGraphInputEdge.ToNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + + auto iter = subgraphInputNameToIndexMap.find(arg->Name()); + if (iter != subgraphInputNameToIndexMap.end()) { - // This is a graph input - - const uint32_t dmlFusedNodeInputIndex = iter->second; - - // If this is a constant input, set the appropriate flags on the desc - if (isNodeAsOpDesc && - dmlFusedNodeInputIndex < isConstGpuGraphInputCount && - isConstGpuGraphInput[dmlFusedNodeInputIndex]) + const uint32_t subgraphInputIndex = iter->second; + + // Either this edge will be + // a constant input, then it will be an intermediate edge and + // set the OWNED_BY_DML flag if it is large constant + // or, + // a non-constant input, then it will be a mainDmlGraphInputEdge. + if (subgraphInputIndex < isConstGpuGraphInputCount && + isConstGpuGraphInput[subgraphInputIndex]) { - // This is a highly inefficient approach to generating constant nodes. It duplicates constant data - // across the graph input as well as every consumer's unique constant node. However it is currently - // only used for small inputs. - uint32_t c_maxConstNodeDataSize = 8; - - - auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; - std::vector toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors(); - DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; - ComPtr constantInput; - - if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) - { - constantInput = constantCpuGraphInputGetter(arg->Name()); - } - - if (constantInput) - { - // The tensor description's size should be no larger than the constant input unless it was rounded to - // the required alignment. 
- assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); - size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); - auto data = static_cast(constantInput->GetData()); - std::vector tensorData(data, data + minimumConstantSize); - - NodeInfo nodeInfo = {}; - nodeInfo.nodeDef = std::move(tensorData); - graphNodes.push_back(std::move(nodeInfo)); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; - edge.FromNodeIndex = static_cast(graphNodes.size() - 1); - edge.FromNodeOutputIndex = 0; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphIntermediateEdges.push_back(edge); - } - else + const auto& constantNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(arg->Name()); + auto& constantNodeVariant = std::get(dmlGraphNodes[constantNodeAndIndex.nodeIndex].Desc); + if (std::holds_alternative(constantNodeVariant)) { - DML_INPUT_GRAPH_EDGE_DESC edge = {}; - edge.GraphInputIndex = dmlFusedNodeInputIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphInputEdges.push_back(edge); - + auto& mainDmlGraphNode = dmlGraphNodes[dmlGraphNodeIndex]; + AbstractOperatorDesc& abstractOperatorDesc = std::get(mainDmlGraphNode.Desc); + std::vector toNodeInputTensorDescs = abstractOperatorDesc.GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorDmlGraphInputEdge.ToNodeInputIndex]; tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; + serializedGraphLargeConstantNameToSubgraphInputIndex[arg->Name()] = subgraphInputIndex; } + + DmlIntermediateSerializedGraphEdge edge = {}; + edge.FromNodeIndex = constantNodeAndIndex.nodeIndex; + edge.FromNodeOutputIndex = constantNodeAndIndex.targetIndex; + edge.ToNodeIndex = dmlGraphNodeIndex; + edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex; + edge.Name = arg->Name() + "-nodeIdx:" + std::to_string(edge.FromNodeIndex) + "-outputIdx:" + std::to_string(edge.FromNodeOutputIndex); + dmlGraphIntermediateEdges.push_back(edge); } else { - DML_INPUT_GRAPH_EDGE_DESC edge = {}; - edge.GraphInputIndex = dmlFusedNodeInputIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphInputEdges.push_back(edge); + DmlInputSerializedGraphEdge edge = {}; + if (subGraphInputIndexToDmlGraphInputIndex.find(subgraphInputIndex) == subGraphInputIndexToDmlGraphInputIndex.end()) + { + subGraphInputIndexToDmlGraphInputIndex[subgraphInputIndex] = static_cast(subGraphInputIndexToDmlGraphInputIndex.size()); + } + + edge.GraphInputIndex = subGraphInputIndexToDmlGraphInputIndex[subgraphInputIndex]; + edge.ToNodeIndex = dmlGraphNodeIndex; + edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex; // ?? 
might need to point inputIndex + edge.Name = arg->Name(); + + serializedGraphInputIndexToSubgraphInputIndex[edge.GraphInputIndex] = subgraphInputIndex; + dmlGraphInputEdges.push_back(edge); } } else { - const auto& inputNodeAndIndex = nameToNodeAndIndexMap.at(arg->Name()); + const auto& inputNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(arg->Name()); - DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; + DmlIntermediateSerializedGraphEdge edge = {}; edge.FromNodeIndex = inputNodeAndIndex.nodeIndex; edge.FromNodeOutputIndex = inputNodeAndIndex.targetIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphIntermediateEdges.push_back(edge); + edge.ToNodeIndex = dmlGraphNodeIndex; + edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex; + edge.Name = arg->Name(); + dmlGraphIntermediateEdges.push_back(edge); } } } // map operatorGraphIntermediateEdges as mainGraphIntermediateEdge - for (auto& operatorGraphIntermediateEdge : graphNodeCreateInfo.intermediateEdges) + for (auto& operatorGraphIntermediateEdge : operatorDmlGraphCreateInfo.intermediateEdges) { - DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; - edge.FromNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphIntermediateEdge.FromNodeIndex]; + DmlIntermediateSerializedGraphEdge edge = {}; + uint32_t shiftedFromNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorGraphIntermediateEdge.FromNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorGraphIntermediateEdge.FromNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + uint32_t shiftedToNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorGraphIntermediateEdge.ToNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorGraphIntermediateEdge.ToNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + + edge.FromNodeIndex = shiftedFromNodeIndex; edge.FromNodeOutputIndex = operatorGraphIntermediateEdge.FromNodeOutputIndex; - edge.ToNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphIntermediateEdge.ToNodeIndex]; + edge.ToNodeIndex = shiftedToNodeIndex; edge.ToNodeInputIndex = operatorGraphIntermediateEdge.ToNodeInputIndex; - graphIntermediateEdges.push_back(edge); + edge.Name = "nodeIdx:" + std::to_string(shiftedFromNodeIndex) + "-outputIdx:" + std::to_string(operatorGraphIntermediateEdge.FromNodeOutputIndex); + dmlGraphIntermediateEdges.push_back(edge); } - + // populate nameToNodeAndIndexMap (which will be used by above loop) for operatorGraphOutputEdges - for (auto& operatorGraphOutputEdge : graphNodeCreateInfo.outputEdges) + for (auto& operatorGraphOutputEdge : operatorDmlGraphCreateInfo.outputEdges) { const onnxruntime::NodeArg* arg = node.OutputDefs()[operatorGraphOutputEdge.GraphOutputIndex]; if (arg->Exists()) { - nameToNodeAndIndexMap[arg->Name()] = NodeAndIndex { - operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex], - operatorGraphOutputEdge.FromNodeOutputIndex - }; - + uint32_t shiftedNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorGraphOutputEdge.FromNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorGraphOutputEdge.FromNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + dmlGraphNodeOutputNameToNodeAndIndexMap[arg->Name()] = {shiftedNodeIndex, operatorGraphOutputEdge.FromNodeOutputIndex}; nodeOutputShapes[arg->Name()] = outputShapes; } } - - if (isNodeAsOpDesc) - { - for (size_t i = 0; i < 
graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i) - { - auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i]; - - DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator); - - // TODO: Change as new header is ingested - if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING) - dmlDesc.Type = (DML_OPERATOR_TYPE) 169; - - // TODO: Change as new header is ingested - if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT) - dmlDesc.Type = (DML_OPERATOR_TYPE) 170; - - ComPtr op; - ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op))); - allocator.Reset(); - - NodeInfo nodeInfo = {}; - nodeInfo.nodeDef = std::move(op); - nodeInfo.name = node.Name(); - graphNodes[firstOpDescGraphNodeIndex + i] = std::move(nodeInfo); - } - } } EdgeShapes graphOutputShapes(subgraphOutputs.size()); @@ -476,24 +505,27 @@ namespace Dml::GraphDescBuilder const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex]; ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); - const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); + const auto& outputNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(graphOutput->Name()); - DML_OUTPUT_GRAPH_EDGE_DESC edge = {}; + DmlOutputSerializedGraphEdge edge = {}; edge.FromNodeIndex = outputNodeAndIndex.nodeIndex; edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex; edge.GraphOutputIndex = gsl::narrow_cast(outputIndex); - graphOutputEdges.push_back(edge); + edge.Name = graphOutput->Name(); + dmlGraphOutputEdges.push_back(edge); graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex); } - RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges); + RemoveUnconnectedNodes(dmlGraphNodes, dmlGraphInputEdges, dmlGraphIntermediateEdges, dmlGraphOutputEdges); GraphDesc graphDesc{}; - graphDesc.nodes = std::move(graphNodes); - graphDesc.inputEdges = std::move(graphInputEdges); - graphDesc.outputEdges = std::move(graphOutputEdges); - graphDesc.intermediateEdges = std::move(graphIntermediateEdges); - graphDesc.reuseCommandList = reuseCommandList; + graphDesc.InputCount = static_cast(dmlGraphInputEdges.size()); + graphDesc.OutputCount = static_cast(subgraphOutputs.size()); + graphDesc.Nodes = std::move(dmlGraphNodes); + graphDesc.InputEdges = std::move(dmlGraphInputEdges); + graphDesc.OutputEdges = std::move(dmlGraphOutputEdges); + graphDesc.IntermediateEdges = std::move(dmlGraphIntermediateEdges); + graphDesc.reuseCommandList = (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice()); graphDesc.outputShapes = std::move(graphOutputShapes); return graphDesc; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index c95e89b45541b..4055984b40405 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -22,22 +22,15 @@ namespace Dml namespace GraphDescBuilder { + constexpr uint32_t minNodeCountToReuseCommandList = 5; + constexpr uint32_t c_maxConstNodeDataSize = 8; + // Gets a unique name for the node which survives recreation and graph manipulations between the point // that graph partitioning occurs and kernel 
creation happens const std::string& GetUniqueNodeName(const onnxruntime::Node& node); - struct NodeInfo - { - std::variant, std::vector> nodeDef; - std::string name; - }; - - struct GraphDesc + struct GraphDesc : DmlSerializedGraphDesc { - std::vector nodes; - std::vector inputEdges; - std::vector outputEdges; - std::vector intermediateEdges; bool reuseCommandList; Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes; }; @@ -47,11 +40,13 @@ namespace Dml const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, - IDMLDevice* device, const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, - gsl::span subgraphOutputs); + gsl::span subgraphOutputs, + /*out*/ std::unordered_map& serializedGraphInputIndexToSubgraphInputIndex, + /*out*/ std::unordered_map& serializedGraphLargeConstantNameToSubgraphInputIndex, + /*out*/ std::vector>& smallConstantData); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index d524780de71b8..f29fbc7a1a65b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1508,31 +1508,17 @@ namespace Windows::AI::MachineLearning::Adapter ORT_TRY { assert(operatorGraphDesc != nullptr); - // Either nodesAsOpDesc or nodesIDMLOperator can be present. - assert(operatorGraphDesc->nodeCount == 0 || (!operatorGraphDesc->nodesAsOpDesc ^ !operatorGraphDesc->nodesAsIDMLOperator)); + assert(operatorGraphDesc->nodeCount == 0 || operatorGraphDesc->nodes); - if (operatorGraphDesc->nodesAsOpDesc) + m_graphNodeCreateInfo->nodes = std::vector>(); + for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) { - m_graphNodeCreateInfo->nodesAsOperatorDesc = std::vector>(); - for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) - { - auto* node = operatorGraphDesc->nodesAsOpDesc[nodeIndex]; - assert(node != nullptr); - AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node); - m_graphNodeCreateInfo->nodesAsOperatorDesc.push_back(std::make_unique(std::move(abstractDesc))); - } - } - else - { - m_graphNodeCreateInfo->nodesAsIDMLOperator = std::vector>(); - for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) - { - auto* node = operatorGraphDesc->nodesAsIDMLOperator[nodeIndex]; - assert(node != nullptr); - m_graphNodeCreateInfo->nodesAsIDMLOperator.push_back(node); - } + auto* node = operatorGraphDesc->nodes[nodeIndex]; + assert(node != nullptr); + AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node); + m_graphNodeCreateInfo->nodes.push_back(std::make_unique(std::move(abstractDesc))); } - + // There can be operators (or kernels) which don't require any input. 
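GraphDesc now derives from DmlSerializedGraphDesc, and `reuseCommandList` is computed at the end of BuildGraphDesc from the node-count heuristic that used to live in the removed comment earlier: small fused graphs are assumed not to carry enough GPU work to justify a persistent command list, unless the device is an MCDM device. A trivial sketch of that decision; the function name and the boolean parameter are invented for the example:

```cpp
// Sketch: the command-list reuse heuristic used when filling GraphDesc.
#include <cstddef>
#include <iostream>

constexpr size_t kMinNodeCountToReuseCommandList = 5;  // mirrors minNodeCountToReuseCommandList

bool ShouldReuseCommandList(size_t subgraphNodeCount, bool isMcdmDevice)
{
    return subgraphNodeCount >= kMinNodeCountToReuseCommandList || isMcdmDevice;
}

int main()
{
    std::cout << std::boolalpha
              << ShouldReuseCommandList(2, /*isMcdmDevice*/ false) << ' '   // false
              << ShouldReuseCommandList(2, /*isMcdmDevice*/ true)  << ' '   // true
              << ShouldReuseCommandList(7, /*isMcdmDevice*/ false) << '\n'; // true
    return 0;
}
```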
assert(operatorGraphDesc->inputEdgeCount == 0 || operatorGraphDesc->inputEdges != nullptr); m_graphNodeCreateInfo->inputEdges.insert( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp index c3bb1a52210f5..287f1e5b6dfe7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp @@ -53,7 +53,7 @@ namespace Dml MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = 1; const DML_OPERATOR_DESC* opDescs{&operatorDesc}; - operatorGraphDesc.nodesAsOpDesc = &opDescs; + operatorGraphDesc.nodes = &opDescs; std::vector inputEdges; for (uint32_t inputIndex = 0; inputIndex < m_kernelInputIndices.size(); inputIndex++) @@ -796,7 +796,7 @@ namespace Dml for (size_t i = 0; i < graphDesc.NodeCount; ++i) { // Create the operator. - ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodesAsOpDesc[i], IID_PPV_ARGS(&dmlOperators[i]))); + ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodes[i], IID_PPV_ARGS(&dmlOperators[i]))); dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{dmlOperators[i].Get()}; dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index c8ca6806e75f7..73c2d57e984af 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -531,7 +531,7 @@ class DmlOperatorAttention : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp index 1c851c94c4ddc..5aceebbdabfe3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp @@ -103,7 +103,7 @@ class DmlOperatorBiasAdd : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp index 501ce14f1fc08..1e10214ffd463 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp @@ -137,7 +137,7 @@ class DmlOperatorBiasSplitGelu : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp index 6a8333cd72561..3c9458658c4d0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp @@ -484,7 +484,7 @@ class DmlOperatorEmbedLayerNormalization : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp index fed0e4645ffd8..8b275fc550f3e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp @@ -287,7 +287,7 @@ class DmlOperatorGroupNorm : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp index 5c64059f7caa9..80e6fefc2fb80 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp @@ -247,7 +247,7 @@ class DmlOperatorLayerNormalization : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp index c97b03dc36b62..8727610ff3112 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp @@ -166,7 +166,7 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = static_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2; uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp index 35f926d62c92a..f658e7c7da323 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp @@ -113,7 +113,7 @@ class DmlOperatorQLinearSigmoid : public DmlOperator MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = 3; std::vector opDescs{&opDesc1, &opDesc2, &opDesc3}; - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); // set input edges std::pair nodeToNodeInputIndex[5] {{0, 0}, {0, 1}, {0, 2}, {2, 1}, {2, 2}}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp index 3683ab7b0b0b3..e62b7d707ba78 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp @@ -123,7 +123,7 @@ class DmlOperatorQuickGelu : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 44004b5d77f70..0f15ebf342b3a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -441,7 +441,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp index 4dafd78f21ea8..094c45a0e38e5 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp @@ -198,7 +198,7 @@ class DmlOperatorSkipLayerNormalization : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h new file mode 100644 index 0000000000000..02166f992449e --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include +#include + + +namespace Dml +{ + static inline std::wstring ConvertToWString(std::string_view str) + { + std::wstring_convert,wchar_t> g_converterToUtf16; + return g_converterToUtf16.from_bytes(str.data()); + } + + static inline std::wstring GetModelName(const onnxruntime::Path& modelPath) + { + if (modelPath.GetComponents().empty()) + { + return L""; + } + + const onnxruntime::PathString& pathString = modelPath.GetComponents().back(); + size_t dotPosition = pathString.find_last_of('.'); + if (dotPosition == std::string::npos) + { + return L""; + } + + return pathString.substr(0, dotPosition); + } + + static inline std::wstring GetSanitizedFileName(std::wstring_view name) + { + std::wstring newName(name); + for (wchar_t& c : newName) + { + switch (c) + { + case '\\': + case '/': + case '\"': + case '|': + case '<': + case '>': + case ':': + case '?': + case '*': + c = '_'; + break; + } + } + return newName; + } + + static inline std::string GetSanitizedFileName(std::string_view name) + { + std::string newName(name); + for (char& c : newName) + { + switch (c) + { + case '\\': + case '/': + case '\"': + case '|': + case '<': + case '>': + case ':': + case '?': + case '*': + c = '_'; + break; + } + } + return newName; + } + + static inline void WriteToFile(std::wstring_view directoryName, std::wstring_view fileName, std::uint8_t* data, size_t dataSize) + { + std::wstring sanitizedFileName = GetSanitizedFileName(fileName); + std::filesystem::create_directory(directoryName); + std::wstring fullSanitizedFileName = std::wstring(directoryName) + + (directoryName.empty() ? L"" : L"/") + + sanitizedFileName; + std::ofstream file(fullSanitizedFileName, std::ios::binary); + if (!file.is_open()) + { + std::wstring_convert,wchar_t> g_converterToUtf16; + std::stringstream errorMessage; + errorMessage << "File named: " << g_converterToUtf16.to_bytes(fileName.data()) << " could not be opened\n"; + throw std::ios::failure(errorMessage.str()); + } + file.write(reinterpret_cast(data), dataSize); + } + +} + +namespace StringUtil +{ + struct NameAndIndex + { + const char* name; // Null terminated. + uint32_t index; + }; + + struct WideNameAndIndex + { + const wchar_t* name; // Null terminated. 
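Utility.h above adds `GetSanitizedFileName` overloads that map characters which are illegal in file names to underscores; the large-constant path in GraphDescBuilder uses this to turn an initializer's ONNX name into a usable file name. A self-contained sketch of the same rule with a small usage check (it re-implements the rule locally instead of including the real header; the example name is made up):

```cpp
// Sketch: replicate the GetSanitizedFileName rule so a name such as
// "model/layer1:weight" becomes a valid file name.
#include <cassert>
#include <string>
#include <string_view>

std::string SanitizeFileName(std::string_view name)
{
    std::string sanitized(name);
    for (char& c : sanitized)
    {
        switch (c)
        {
        case '\\': case '/': case '\"': case '|':
        case '<': case '>': case ':': case '?': case '*':
            c = '_';  // characters not allowed in Windows file names
            break;
        }
    }
    return sanitized;
}

int main()
{
    assert(SanitizeFileName("model/layer1:weight") == "model_layer1_weight");
    return 0;
}
```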
+ uint32_t index; + }; + + inline std::optional MapToIndex(std::string_view mode, gsl::span nameAndIndexList) + { + for (auto& nameAndIndex : nameAndIndexList) + { + if (strncmp(nameAndIndex.name, mode.data(), mode.size()) == 0) + { + return nameAndIndex.index; + } + } + + return {}; + } + + inline std::optional MapToIndex(std::wstring_view mode, gsl::span nameAndIndexList) + { + for (auto& nameAndIndex : nameAndIndexList) + { + if (wcsncmp(nameAndIndex.name, mode.data(), mode.size()) == 0) + { + return nameAndIndex.index; + } + } + + return {}; + } +} \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h index 83737d2ba4848..332bf86685e8a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h @@ -17,6 +17,8 @@ #include #include #include +#include +#include #include #include @@ -37,6 +39,7 @@ #include #include "External/D3DX12/d3dx12.h" #endif +#include "flatbuffers/flatbuffers.h" #include "GraphicsUnknownHelper.h" @@ -53,6 +56,9 @@ #include "External/DirectMLHelpers/SchemaHelpers.h" #include "External/DirectMLHelpers/GeneratedSchemaHelpers.h" #include "External/DirectMLHelpers/DirectMLX.h" +#include "External/DirectMLHelpers/DmlSerializedGraphDesc.h" +#include "External/DirectMLHelpers/DmlGraphSerialization.h" +#include "External/DirectMLHelpers/DmlGraphDeserialization.h" using Microsoft::WRL::ComPtr; @@ -67,3 +73,4 @@ using Microsoft::WRL::ComPtr; #include "TensorDesc.h" #include "DescriptorPool.h" #include "IExecutionProvider.h" +#include "Utility.h" \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index 3bec8d3864cba..ac3a3eb1268b8 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -10,18 +10,11 @@ struct DML_INPUT_GRAPH_EDGE_DESC; struct DML_OUTPUT_GRAPH_EDGE_DESC; struct DML_INTERMEDIATE_GRAPH_EDGE_DESC; -// Either nodesAsOpDesc or nodesAsIDMLOperator is present. -// 1) Operator kernels which implement operators using only a single DML operator will pass a DML_OPERATOR_DESC. -// These kernels pass DML_OPERATOR_DESC, because while building Dml graph (inside FusedGraphKernel.cpp) we can change the -// the flag of constant inputs to DML_TENSOR_FLAG_OWNED_BY_DML. -// 2) Operator kernels which implement operators using DMLX graph, they will pass IDMLOperator and won't be able -// to use DML_TENSOR_FLAG_OWNED_BY_DML. struct MLOperatorGraphDesc { uint32_t nodeCount; - _Field_size_opt_(nodeCount) const DML_OPERATOR_DESC** nodesAsOpDesc; - _Field_size_opt_(nodeCount) IDMLOperator** nodesAsIDMLOperator; - + _Field_size_opt_(nodeCount) const DML_OPERATOR_DESC** nodes; + uint32_t inputEdgeCount; _Field_size_(inputEdgeCount) const DML_INPUT_GRAPH_EDGE_DESC* inputEdges; diff --git a/onnxruntime/core/providers/dml/dml_session_options_config_keys.h b/onnxruntime/core/providers/dml/dml_session_options_config_keys.h index d11fa7516e713..5b5f371f51616 100644 --- a/onnxruntime/core/providers/dml/dml_session_options_config_keys.h +++ b/onnxruntime/core/providers/dml/dml_session_options_config_keys.h @@ -21,3 +21,4 @@ // "1": disabled (disallowed). Graph fusion will never be used. 
// The default value is "0" static const char* const kOrtSessionOptionsConfigDisableDmlGraphFusion = "ep.dml.disable_graph_fusion"; +static const char* const kOrtSessionOptionsConfigEnableGraphSerialization = "ep.dml.enable_graph_serialization"; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index efd7db4ea7629..5fd66c459d382 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1725,10 +1725,17 @@ common::Status InferenceSession::Initialize() { // graph optimization level and is generally always applied. bool dml_graph_fusion_enabled = session_options_.optimized_model_filepath.empty() && session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigDisableDmlGraphFusion, "0") == "0"; + std::string dml_graph_serialization_enabled_config_val = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigEnableGraphSerialization, "0"); + std::transform(dml_graph_serialization_enabled_config_val.begin(), + dml_graph_serialization_enabled_config_val.end(), + dml_graph_serialization_enabled_config_val.begin(), + [](char ch) { return std::tolower(ch); }); + bool dml_graph_serialization_enabled = dml_graph_serialization_enabled_config_val == "true"; if (dml_graph_fusion_enabled) { std::unique_ptr dmlGraphFusionTransformer = std::make_unique("DmlGraphFusionTransformer", - dmlExecutionProvider); + dmlExecutionProvider, + dml_graph_serialization_enabled); if (dmlGraphFusionTransformer == nullptr) { return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr"); } diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 3874901f86387..7d4111e3b9c39 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -68,6 +68,7 @@ namespace perftest { "\t [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n" "\t [DML only] [disable_metacommands]: Options: 'true', 'false', \n" "\t [DML only] [enable_dynamic_graph_fusion]: Options: 'true', 'false', \n" + "\t [DML only] [enable_graph_serialization]: Options: 'true', 'false', \n" "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" "\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 87506c7240578..1934314b8ce43 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -18,6 +18,7 @@ #ifdef USE_DML #include "core/providers/dml/dml_provider_factory.h" +#include "core/providers/dml/dml_session_options_config_keys.h" #endif #ifdef _WIN32 @@ -542,6 +543,15 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); "[ERROR] [DML] You have selcted wrong value for the key 'enable_dynamic_graph_fusion'. 
" "Select from 'true' or 'false' \n"); } + } else if (key == "enable_graph_serialization") { + std::set ov_supported_values = {"true", "True", "false", "False"}; + if (ov_supported_values.find(value) != ov_supported_values.end()) { + session_options.AddConfigEntry(kOrtSessionOptionsConfigEnableGraphSerialization, value.data()); + } else { + ORT_THROW( + "[ERROR] [DML] You have selcted wrong value for the key 'enable_graph_serialization'. " + "Select from 'true' or 'false' \n"); + } } } session_options.AppendExecutionProvider("DML", dml_options); From 8bd943be39301639e3f50f524f8fd71c7f2b2a34 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Tue, 27 Feb 2024 09:31:32 +1000 Subject: [PATCH 062/237] Retry flaky XCode iOS UI tests if we get a known error (#19639) ### Description Xcode UI tests seem to be flaky: https://github.com/orgs/community/discussions/68807 Add a couple of retries if we get a "Timed out while loading Accessibility." error which is transient. ### Motivation and Context --- .../github/apple/test_apple_packages.py | 61 ++++++++++++++----- 1 file changed, 45 insertions(+), 16 deletions(-) diff --git a/tools/ci_build/github/apple/test_apple_packages.py b/tools/ci_build/github/apple/test_apple_packages.py index cd360a63a3a0f..3c0df994ffd3d 100644 --- a/tools/ci_build/github/apple/test_apple_packages.py +++ b/tools/ci_build/github/apple/test_apple_packages.py @@ -130,22 +130,51 @@ def _test_apple_packages(args): simulator_device_info = json.loads(simulator_device_info) - subprocess.run( - [ - "xcrun", - "xcodebuild", - "test", - "-workspace", - "./apple_package_test.xcworkspace", - "-scheme", - "ios_package_test", - "-destination", - f"platform=iOS Simulator,id={simulator_device_info['device_udid']}", - ], - shell=False, - check=True, - cwd=target_proj_path, - ) + # Xcode UI tests seem to be flaky: https://github.com/orgs/community/discussions/68807 + # Add a couple of retries if we get this error: + # ios_package_testUITests-Runner Failed to initialize for UI testing: + # Error Domain=com.apple.dt.XCTest.XCTFuture Code=1000 "Timed out while loading Accessibility." + attempts = 0 + cmd = [ + "xcrun", + "xcodebuild", + "test", + "-workspace", + "./apple_package_test.xcworkspace", + "-scheme", + "ios_package_test", + "-destination", + f"platform=iOS Simulator,id={simulator_device_info['device_udid']}", + ] + + while True: + attempts += 1 + completed_process = subprocess.run( + cmd, + shell=False, + capture_output=True, + check=False, + text=True, + cwd=target_proj_path, + ) + + # print so it's in CI output + print(completed_process.stdout) + + if completed_process.returncode != 0: + print(f"Running ios_package_test failed. Return code was {completed_process.returncode}") + print("xcrun xcodebuild test stderr:") + print(completed_process.stderr) + print("---") + + if "Timed out while loading Accessibility" in completed_process.stderr and attempts < 3: + continue + + raise subprocess.CalledProcessError( + completed_process.returncode, " ".join(cmd), completed_process.stdout, completed_process.stderr + ) + + break if PackageVariant[args.variant] != PackageVariant.Mobile and not args.skip_macos_test: subprocess.run( From 18c8fab1ae03e68a906fe42698ac322e9e49e218 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Mon, 26 Feb 2024 15:58:09 -0800 Subject: [PATCH 063/237] Fix a bug in build.py (#19652) ### Description Fix a bug in build.py that accidentally disabled C# tests for most builds when "--build_nuget" is specified. ### Motivation and Context The bug was introduced in PR #8892 . 
--- tools/ci_build/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 5b715bb29e5a1..74c473d34f548 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2592,7 +2592,7 @@ def main(): raise BuildError("Using --get-api-doc requires a single build config") # Disabling unit tests for GPU on nuget creation - if args.use_openvino != "CPU_FP32" and args.build_nuget: + if args.use_openvino and args.use_openvino != "CPU_FP32" and args.build_nuget: args.test = False # GDK builds don't support testing From 8a71b657654d63437267014b324bf124a80de347 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Tue, 27 Feb 2024 11:35:27 +1000 Subject: [PATCH 064/237] Remove skipping of Reshape from NNAPI EP (#19618) ### Description A number of Qualcomm Snapdragon chipsets do not produce correct output if we skip the Reshape, which ironically was a performance optimization for Snapdragon chips. Perf testing showed that Squeeze also seems to execute on CPU so there's no benefit to using that as an alternative where possible e.g. Global*Pool -> Reshape to 2D -> Gemm could be potentially be replaced with Global*Pool -> Squeeze dims 2 and 3 -> Gemm if that offered better performance. ### Motivation and Context #19518 --- .../builders/op_builder_helpers.cc | 30 ++++++++++++++----- .../builders/op_builder_helpers.h | 3 -- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc index a066c64dac67d..466865f23f49a 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc @@ -965,6 +965,18 @@ Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, return Status::OK(); } +// NOTE: Skipping Reshape results in invalid output on some SnapDragon chipsets. Whilst the NNAPI spec says the input +// to FullyConnnected can be > 2D, those chipsets don't handle this correctly. +// +// CanSkipReshape could potentially be re-enabled in the future if we no longer want to support those old chipsets. +// However, the Reshape of newer chipsets may not run on CPU so there may not be a performance issue to try and avoid, +// so CanSkipReshape could be redundant anyway. +// +// Known bad chipsets: Qualcomm Snapdragon 850, 855, 865, 870. +// +// See https://github.com/microsoft/onnxruntime/issues/19518 + +/* // We can skip the Reshape if all the output edges satisfies both the following conditions // 1. The output of the reshape/flatten is not an output of the graph // 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators, @@ -977,7 +989,7 @@ Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, // between NNAPI CPU impl and Hardware Accelerator impl and will speed up the execution // If we are going to skip the reshape, we will still add correct shape and operand type for the output in // onnxruntime::nnapi::Model. -bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit, +static bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit, size_t input_rank, size_t output_rank) { // Since we know this is a Reshape NodeUnit, so we can safely assume there is only 1 output // and the node_unit has only one output node. 
@@ -1039,33 +1051,37 @@ bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit << node_unit.Name() << "] with output, " << output_name; return true; } +*/ Status AddReshapeOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, const std::string& input, const std::vector& shape) { auto& shaper(model_builder.GetShaper()); - const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); const auto& output = node_unit.Outputs()[0].node_arg.Name(); const auto input_shape = shaper[input]; const auto output_shape = shaper[output]; - const auto input_rank = input_shape.size(); - const auto output_rank = output_shape.size(); // For reshape, the output type should be the same as the input type except the shape is different auto output_operand_type = operand_types.at(input); output_operand_type.SetDimensions(output_shape); + /* See CanSkipReshape definition above for explanation of why this is disabled. // Since Reshape is not running using hardware in NNAPI for some CPU (e.g. Qualcomm SD for now) // We will try to see if we the skip the Reshape to prevent context switching between // NNAPI CPU impl and NNAPI hardware accelerator impl if (CanSkipReshape(model_builder, node_unit, input_rank, output_rank)) { - // Since reshape can be skipped, only register the dimension and type, with same index and new name + const auto& operand_indices(model_builder.GetOperandIndices()); + const auto input_rank = input_shape.size(); + const auto output_rank = output_shape.size(); + // Since reshape can be skipped, only register the dimension and type, with same index and new name. + // This essentially redirects the downstream operator builders to the input of the skipped Reshape node, + // but with the output shape of the Reshape node. model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type); - } else { - // We still need to perform a reshape here + } else */ + { std::string shape_name = model_builder.GetUniqueName(node_unit.Name() + input + "newshape"); ORT_RETURN_IF_ERROR(op_builder_helpers::AddNnapiReshape(model_builder, input, shape_name, shape, output)); } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h index 7ccf4c1ef7555..61a16ceff752f 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h @@ -181,9 +181,6 @@ Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, Status AddReshapeOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, const std::string& input, const std::vector& shape); -bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit, - size_t input_rank, size_t output_rank); - Status GetAxesForSqueezeAndUnSqueeze(ModelBuilder& model_builder, const NodeUnit& node_unit, std::vector& axes); From 6f566562cedff9996e55dbf623b1f0141733d52c Mon Sep 17 00:00:00 2001 From: kailums <109063327+kailums@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:31:03 +0800 Subject: [PATCH 065/237] support user_compute_stream for rocm ep (#19619) ### Description According to the pr #19229 supporting cuda EP use external compute stream, we add support for rocm EP. 
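A possible usage sketch (an illustration only, not code from this change): create a dedicated stream on the PyTorch side and hand its raw handle to the ROCm EP through the provider options. `model.onnx` is a placeholder, and both packages are assumed to be ROCm builds.

```python
# Hypothetical usage sketch for sharing a PyTorch (HIP) stream with the ROCm EP.
import torch
import onnxruntime as ort

stream = torch.cuda.Stream()  # a non-default stream; the default stream's handle is 0,
                              # which is currently treated as "no user stream"
provider_options = {
    "device_id": "0",
    "user_compute_stream": str(stream.cuda_stream),  # raw stream handle passed as a string
}
sess = ort.InferenceSession(
    "model.onnx", providers=[("ROCMExecutionProvider", provider_options)]
)
```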
And when we testing this feature with torch, we found torch use stream 0 for the default stream, and `torch.cuda.current_stream()` returns `0` for current stream, but ort treat `0` or `nullptr` as invalid, and reset has_user_compute_stream to false. Will remove has_user_compute_stream option in the future. ### Motivation and Context The motivation for this pr is that we want to use torch.cuda.graph to capture ort running kernel, which requires torch and ort are running in the same stream, so we use this API to set ort's working stream. --- .../rocm/rocm_execution_provider_info.cc | 20 +++++++++++++++++++ .../test/python/onnxruntime_test_python.py | 10 ++++++++++ 2 files changed, 30 insertions(+) diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index b557f92287f2b..3cb826437a54f 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -13,6 +13,8 @@ namespace onnxruntime { namespace rocm { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; +constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; +constexpr const char* kUserComputeStream = "user_compute_stream"; constexpr const char* kMemLimit = "gpu_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; constexpr const char* kMiopenConvExhaustiveSearch = "miopen_conv_exhaustive_search"; @@ -38,6 +40,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P void* alloc = nullptr; void* free = nullptr; void* empty_cache = nullptr; + void* user_compute_stream = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -52,6 +55,15 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); return Status::OK(); }) + .AddAssignmentToReference(rocm::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) + .AddValueParser( + rocm::provider_option_names::kUserComputeStream, + [&user_compute_stream](const std::string& value_str) -> Status { + size_t address; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + user_compute_stream = reinterpret_cast(address); + return Status::OK(); + }) .AddValueParser( rocm::provider_option_names::kGpuExternalAlloc, [&alloc](const std::string& value_str) -> Status { @@ -108,12 +120,18 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P ROCMExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; info.external_allocator_info = alloc_info; + + info.user_compute_stream = user_compute_stream; + info.has_user_compute_stream = (user_compute_stream != nullptr); + return info; } ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecutionProviderInfo& info) { const ProviderOptions options{ {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, + {rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {rocm::provider_option_names::kGpuExternalAlloc, 
MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, {rocm::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, @@ -135,6 +153,8 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const OrtROCMProviderOptions& info) { const ProviderOptions options{ {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, + {rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {rocm::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast(info.arena_extend_strategy))}, {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 91b6c71e735a8..ab56f3fa0f37f 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -559,6 +559,16 @@ def test_get_and_set_option_with_values(option_name, option_values): test_get_and_set_option_with_values("enable_hip_graph", ["1", "0"]) + # test for user_compute_stream + option = options["ROCMExecutionProvider"] + option["user_compute_stream"] = "1" + sess.set_providers(["ROCMExecutionProvider"], [option]) + new_options = sess.get_provider_options() + new_option = new_options["ROCMExecutionProvider"] + self.assertEqual(new_option["user_compute_stream"], "1") + # set user_compute_stream will set has_user_compute_stream to 1 too + self.assertEqual(new_option["has_user_compute_stream"], "1") + run_rocm_options_test() def test_invalid_set_providers(self): From 5bb58a10e739f8720e9867d19c4313081b12d948 Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Mon, 26 Feb 2024 20:00:14 -0800 Subject: [PATCH 066/237] Enable the most verbose logging level in detox E2E React Native CI (#19659) ### Description The RN CI has intermittent failure error with "app seems to idle". enable the most verbose logging level (and can add steps to dump device.log from the detox folder/artifacts if necessary) to at least get more information. 
### Motivation and Context --------- Co-authored-by: rachguo --- .../github/azure-pipelines/templates/react-native-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 47cd72f412c67..1b7962059e301 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -279,7 +279,7 @@ stages: - script: | JEST_JUNIT_OUTPUT_FILE=$(Build.SourcesDirectory)/js/react_native/e2e/android-test-results.xml \ - detox test --record-logs all --configuration android.emu.release + detox test --record-logs all --configuration android.emu.release --loglevel trace workingDirectory: '$(Build.SourcesDirectory)/js/react_native/e2e' displayName: Run React Native Detox Android e2e Tests @@ -329,7 +329,7 @@ stages: - script: | JEST_JUNIT_OUTPUT_FILE=$(Build.SourcesDirectory)/js/react_native/e2e/ios-test-results.xml \ - detox test --record-logs all --configuration ios.sim.release + detox test --record-logs all --configuration ios.sim.release --loglevel trace workingDirectory: '$(Build.SourcesDirectory)/js/react_native/e2e' displayName: Run React Native Detox iOS e2e Tests From 9e19684944adfda4a414fc91a67259894fce2898 Mon Sep 17 00:00:00 2001 From: duanshengliu <44742794+duanshengliu@users.noreply.github.com> Date: Tue, 27 Feb 2024 12:56:32 +0800 Subject: [PATCH 067/237] Fix the TypeError issue in quantize.py (#19459) ### Description Fix related bug as described in https://github.com/microsoft/onnxruntime/issues/19430 --- onnxruntime/python/tools/quantization/quantize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 1bd2ef42151d0..05d3ac248c92c 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -479,7 +479,7 @@ def inc_dataloader(): del dataloader model = sq.transform(extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True)) sq_path = tempfile.TemporaryDirectory(prefix="ort.quant.") - model_input = Path(sq_path).joinpath("sq_model.onnx").as_posix() + model_input = Path(sq_path.name).joinpath("sq_model.onnx").as_posix() model.save(model_input) nodes_to_exclude.extend([i.name for i in model.model.graph.node if i.name not in orig_nodes]) model = load_model_with_shape_infer(Path(model_input)) # use smooth quant model for calibration From 1e69b612382205b0588f08d2b808b12e32a50a51 Mon Sep 17 00:00:00 2001 From: cloudhan Date: Tue, 27 Feb 2024 16:06:06 +0800 Subject: [PATCH 068/237] Make version string detection more robust (#19615) `/opt/rocm/.info/version-dev` is only available if the `rocm-dev` metapackage is installed. This will bring a lot of unused packages which are not needed by the users, they may opt for fine grained control. Fallback to `rocm_version.h` in case `rocm-dev` is not installed. 
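The fallback order can be summarised with this rough Python model (my paraphrase; the real implementation is the CMake change below, the candidate paths come from the patch, and the sample header contents vary by ROCm release):

```python
# Rough Python model of the version-detection fallback used in CMakeLists.txt.
import os
import re

def detect_rocm_version(rocm_home):
    candidates = [
        # present only when the rocm-dev metapackage is installed
        (os.path.join(rocm_home, ".info", "version-dev"), r"^(\d+)\.(\d+)\.(\d+)-.*$"),
        # fine-grained installs: a quoted version string inside the header
        (os.path.join(rocm_home, "include", "rocm_version.h"), r"\"(\d+)\.(\d+)\.(\d+).*\""),
        (os.path.join(rocm_home, "include", "rocm-core", "rocm_version.h"), r"\"(\d+)\.(\d+)\.(\d+).*\""),
    ]
    for path, pattern in candidates:
        if os.path.exists(path):
            with open(path) as f:
                match = re.search(pattern, f.read(), re.MULTILINE)
            if match:
                major, minor, patch = (int(g) for g in match.groups())
                return f"{major}.{minor}.{patch}", major * 10000 + minor * 100 + patch
    raise RuntimeError("Cannot determine ROCm version string")
```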
--- cmake/CMakeLists.txt | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ed9043f2adc4a..1376c90fbcefe 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -324,15 +324,27 @@ if (onnxruntime_USE_ROCM) endif() # replicate strategy used by pytorch to get ROCM_VERSION - # https://github.com/pytorch/pytorch/blob/8eb21488fdcdb8b0e6fa2e46179b5fa6c42e75af/cmake/public/LoadHIP.cmake#L153-L173 - file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW) - string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW}) - if (ROCM_VERSION_DEV_MATCH) + # https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake + # with modification + if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version-dev") + file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW) + string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + endif() + + if (ROCM_VERSION_MATCH) set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") + else() + message(FATAL_ERROR "Cannot determine ROCm version string") endif() message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version-dev ****\n") message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") From 4838cb6b3e98273fcdd6a3e54da74cd584167780 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 27 Feb 2024 02:27:35 -0800 Subject: [PATCH 069/237] [QNN Quantization] Ensure fused nodes have names (#19650) ### Description - Updates the `qnn_preprocess_model()` method to set a name for any new nodes added to the graph (due to fusion). - Updates the `qnn_preprocess_model()` method to set a name for any unnamed nodes that previously existed in the original graph. - Adds unit tests for fusions (previously missing) - Checks that fused node names exist and are unique - Checks that fused graph is equivalent to original graph ### Motivation and Context Nodes are not strictly required to have names. However, a planned/upcoming feature to support mixed-precision (integer) quantized models needs nodes to have names. 
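The renaming scheme boils down to scanning existing names for the largest integer suffix behind a fixed prefix and continuing from there. A condensed sketch with made-up node names follows (the real code walks `onnx` graph nodes rather than plain strings):

```python
# Condensed sketch of the suffix-based unique-naming scheme; node names are invented.
def get_largest_node_name_suffix(node_names, prefix):
    largest = -1
    for name in node_names:
        if name and name.startswith(prefix):
            try:
                largest = max(largest, int(name[len(prefix):]))
            except ValueError:
                continue  # prefix followed by something that is not an integer
    return largest

existing = ["qnn_preproc_node_0", "qnn_preproc_node_7", "Conv_12", ""]
prefix = "qnn_preproc_node_"
next_suffix = get_largest_node_name_suffix(existing, prefix) + 1  # -> 8

renamed = []
for name in existing:
    if name:
        renamed.append(name)
    else:  # unnamed node gets a fresh, unique name
        renamed.append(f"{prefix}{next_suffix}")
        next_suffix += 1

print(renamed)  # ['qnn_preproc_node_0', 'qnn_preproc_node_7', 'Conv_12', 'qnn_preproc_node_8']
```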
--- .../execution_providers/qnn/fusion_lpnorm.py | 7 +- .../execution_providers/qnn/preprocess.py | 11 + .../tools/quantization/fusions/fusion.py | 15 + .../tools/quantization/fusions/fusion_gelu.py | 25 +- .../quantization/fusions/fusion_layernorm.py | 1 + .../python/tools/quantization/onnx_model.py | 17 + .../test/python/quantization/test_fusions.py | 401 ++++++++++++++++++ 7 files changed, 465 insertions(+), 12 deletions(-) create mode 100644 onnxruntime/test/python/quantization/test_fusions.py diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py index 9ebf400498e0e..fbf954febdda4 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py @@ -122,6 +122,11 @@ def fuse( self.nodes_to_remove.extend(subgraph_nodes) fused_node = onnx.helper.make_node( - self.fused_op_type, inputs=[subgraph_input], outputs=[subgraph_output], p=2, axis=-1 + self.fused_op_type, + name=self.create_unique_node_name(), + inputs=[subgraph_input], + outputs=[subgraph_output], + p=2, + axis=-1, ) self.nodes_to_add.append(fused_node) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index becbaceab184e..b1c114fe1f9fd 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -44,6 +44,17 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: if fusion_layernorm.apply(): modified = True + # Make sure all nodes have a name. + unnamed_node_prefix = "qnn_preproc_node_" + available_suffix = onnx_model.get_largest_node_name_suffix(unnamed_node_prefix) + 1 + for node in onnx_model.model.graph.node: + if node.op_type != "Constant" and not node.name: + new_node_name = f"{unnamed_node_prefix}{available_suffix!s}" + available_suffix += 1 + node.name = new_node_name + modified = True + logging.warning(f"Node of type {node.op_type} does not have a name. Renamed to {new_node_name}.") + if modified: onnx_model.topological_sort() onnx.save_model(model, model_output) diff --git a/onnxruntime/python/tools/quantization/fusions/fusion.py b/onnxruntime/python/tools/quantization/fusions/fusion.py index b54b421226f1a..4bdc5c26cc946 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion.py @@ -24,6 +24,9 @@ def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str): self.nodes_to_remove: list = [] self.nodes_to_add: list = [] + self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_" + self._new_node_name_suffix = None # int|None used to create unique node names for the fused ops. 
+ def fuse( self, node: onnx.NodeProto, @@ -57,6 +60,18 @@ def apply(self) -> bool: return graph_updated + def create_unique_node_name(self): + prefix = self._new_node_name_prefix + + if self._new_node_name_suffix is None: + largest_suffix: int = self.model.get_largest_node_name_suffix(prefix) + self._new_node_name_suffix = largest_suffix + 1 + + new_name = f"{prefix}{self._new_node_name_suffix!s}" + self._new_node_name_suffix += 1 + + return new_name + @staticmethod def is_safe_to_fuse_nodes( nodes_to_remove: list[onnx.NodeProto], diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py index a20d6dbffd7a7..42c4a11833641 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py @@ -112,7 +112,9 @@ def fuse_1( return False self.nodes_to_remove.extend(subgraph_nodes) - fused_node = onnx.helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output]) + fused_node = onnx.helper.make_node( + "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output] + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) return True @@ -173,11 +175,9 @@ def fuse_2( if not self.has_constant_input(sqrt_node, 2.0): return False - root_node = self.model.get_parent(div, 0, output_name_to_node) - if root_node is None: - return False + subgraph_input = div.input[0] - if root_node.output[0] not in mul.input: + if subgraph_input not in mul.input: return False subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul] @@ -188,7 +188,9 @@ def fuse_2( return False self.nodes_to_remove.extend(subgraph_nodes) - fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]]) + fused_node = onnx.helper.make_node( + "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[mul.output[0]] + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) return True @@ -239,9 +241,8 @@ def fuse_3( if i < 0: return False - root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node) - if root_node is None: - return False + root_input_index = 1 - i + subgraph_input = first_mul.input[root_input_index] if mul_half.output[0] not in input_name_to_nodes: return False @@ -250,7 +251,7 @@ def fuse_3( return False last_mul = children[0] - if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]): + if not (last_mul.input[0] == subgraph_input or last_mul.input[1] == subgraph_input): return False subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul] @@ -263,7 +264,9 @@ def fuse_3( return False self.nodes_to_remove.extend(subgraph_nodes) - fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]]) + fused_node = onnx.helper.make_node( + "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[last_mul.output[0]] + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) return True diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py index d7fb89236d3d2..7d58c1c180822 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py @@ -127,6 +127,7 @@ def fuse( normalize_node = onnx.helper.make_node( 
"LayerNormalization", + name=self.create_unique_node_name(), inputs=[reduce_mean_node.input[0], weight_input, bias_input], outputs=[last_add_node.output[0]], ) diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index 4591c9c950e6e..46d245d353a07 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -283,6 +283,23 @@ def find_node_by_name(self, node_name, new_nodes_list, graph): node = find_by_name(node_name, graph_nodes_list) return node + def get_largest_node_name_suffix(self, node_name_prefix): + """ + Gets the largest node name (int) suffix for all node names that begin with `node_name_prefix`. + Example: for nodes my_prefix_0 and my_prefix_3, this method returns 3. + """ + suffix = -1 + + for node in self.model.graph.node: + if node.name and node.name.startswith(node_name_prefix): + try: + index = int(node.name[len(node_name_prefix) :]) + suffix = max(index, suffix) + except ValueError: + continue + + return suffix + def find_nodes_by_initializer(self, graph, initializer): """ Find all nodes with given initializer as an input. diff --git a/onnxruntime/test/python/quantization/test_fusions.py b/onnxruntime/test/python/quantization/test_fusions.py new file mode 100644 index 0000000000000..bea110e566fb9 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_fusions.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import unittest + +import numpy as np +import onnx + +import onnxruntime +from onnxruntime.quantization.execution_providers.qnn.fusion_lpnorm import FusionLpNormalization +from onnxruntime.quantization.fusions import FusionGelu, FusionLayerNormalization +from onnxruntime.quantization.onnx_model import ONNXModel + + +class TestFusions(unittest.TestCase): + def check_fused_model_correctness(self, orig_model, fused_model, inputs, rtol=1e-7, atol=0): + """ + Checks that the output of the fused model matches the output of the original model. + """ + orig_session = onnxruntime.InferenceSession(orig_model.SerializeToString(), providers=["CPUExecutionProvider"]) + orig_results = orig_session.run(None, inputs) + + fused_session = onnxruntime.InferenceSession( + fused_model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + fused_results = fused_session.run([], inputs) + + self.assertEqual(len(orig_results), len(fused_results), "Number of outputs for fused model differs") + for idx, expected_output in enumerate(orig_results): + actual_output = fused_results[idx] + np.testing.assert_allclose( + expected_output, + actual_output, + rtol=rtol, + atol=atol, + err_msg=f"Fused model output {idx} differs", + ) + + def build_erf_sequence_1_model(self, shape): + """ + Erf sequence that fuses into Gelu: + +-------Mul(0.5)---------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul --> + (B=1.4142...) 
(1) + + This method builds 2 of these Erf sequences: + + [root] -> ERF_SEQUENCE1 -> ERF_SEQUENCE2 -> output + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + # First Erf sequence + mul0_node = onnx.helper.make_node("Mul", ["root", "half_const"], ["mul0_out"]) + div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"]) + erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul1_node = onnx.helper.make_node("Mul", ["add_out", "mul0_out"], ["seq1_output"]) + + # Second Erf sequence + mul0_node_dup = onnx.helper.make_node("Mul", ["seq1_output", "half_const"], ["mul0_out_dup"]) + div_node_dup = onnx.helper.make_node("Div", ["seq1_output", "root2_const"], ["div_out_dup"]) + erf_node_dup = onnx.helper.make_node("Erf", ["div_out_dup"], ["erf_out_dup"]) + add_node_dup = onnx.helper.make_node("Add", ["erf_out_dup", "one_const"], ["add_out_dup"]) + mul1_node_dup = onnx.helper.make_node("Mul", ["add_out_dup", "mul0_out_dup"], ["output"]) + + graph = onnx.helper.make_graph( + [ + mul0_node, + div_node, + erf_node, + add_node, + mul1_node, + mul0_node_dup, + div_node_dup, + erf_node_dup, + add_node_dup, + mul1_node_dup, + ], + "two_erf_sequences", + [root_inp], + [output], + initializer=[one_const, half_const, root2_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_erf_sequence_2_model(self, shape): + """ + +------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul --> + (B=1.4142...) 
(1) (0.5) + + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"]) + erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul0_node = onnx.helper.make_node("Mul", ["add_out", "root"], ["mul0_out"]) + mul1_node = onnx.helper.make_node("Mul", ["mul0_out", "half_const"], ["output"]) + + graph = onnx.helper.make_graph( + [div_node, erf_node, add_node, mul0_node, mul1_node], + "erf_sequence_2", + [root_inp], + [output], + initializer=[one_const, half_const, root2_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_erf_sequence_3_model(self, shape): + """ + +------------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul + (B=1.4142...) (A=1) (A=0.5) + + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"]) + erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul0_node = onnx.helper.make_node("Mul", ["add_out", "half_const"], ["mul0_out"]) + mul1_node = onnx.helper.make_node("Mul", ["mul0_out", "root"], ["output"]) + + graph = onnx.helper.make_graph( + [div_node, erf_node, add_node, mul0_node, mul1_node], + "erf_sequence_3", + [root_inp], + [output], + initializer=[one_const, half_const, root2_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_erf_sequence_4_model(self, shape): + """ + +----------------------------------------------+ + | | + | v + [root] --> Mul -----> Erf --> Add --> Mul -->Mul + (A=0.7071067690849304) (B=1) (B=0.5) + + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + frac_const = onnx.numpy_helper.from_array(np.array(0.7071067690849304, dtype=np.float32), "frac_const") + + mul0_node = onnx.helper.make_node("Mul", ["root", "frac_const"], ["mul0_out"]) + erf_node = onnx.helper.make_node("Erf", ["mul0_out"], ["erf_out"]) + add_node 
= onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul1_node = onnx.helper.make_node("Mul", ["add_out", "half_const"], ["mul1_out"]) + mul2_node = onnx.helper.make_node("Mul", ["mul1_out", "root"], ["output"]) + + graph = onnx.helper.make_graph( + [mul0_node, erf_node, add_node, mul1_node, mul2_node], + "erf_sequence_4", + [root_inp], + [output], + initializer=[one_const, half_const, frac_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_reduce_mean_sequence_model(self, shape, scale_val, bias_val, axis=-1): + """ + +----------------------+ + | | + | v + [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add + (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^ ^ ^ + | | | | + +-------------------------------------------------+ [Scale] [Bias] + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const") + bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const") + axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const") + two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const") + eps_const = onnx.numpy_helper.from_array(np.array(1.0e-8, dtype=np.float32), "eps_const") + + rm0_node = onnx.helper.make_node("ReduceMean", ["root", "axes_const"], ["rm0_out"]) + sub_node = onnx.helper.make_node("Sub", ["root", "rm0_out"], ["sub_out"]) + pow_node = onnx.helper.make_node("Pow", ["sub_out", "two_const"], ["pow_out"]) + rm1_node = onnx.helper.make_node("ReduceMean", ["pow_out", "axes_const"], ["rm1_out"]) + add0_node = onnx.helper.make_node("Add", ["rm1_out", "eps_const"], ["add0_out"]) + sqrt_node = onnx.helper.make_node("Sqrt", ["add0_out"], ["sqrt_out"]) + div_node = onnx.helper.make_node("Div", ["sub_out", "sqrt_out"], ["div_out"]) + mul_node = onnx.helper.make_node("Mul", ["div_out", "scale_const"], ["mul_out"]) + add1_node = onnx.helper.make_node("Add", ["mul_out", "bias_const"], ["output"]) + + graph = onnx.helper.make_graph( + [rm0_node, sub_node, pow_node, rm1_node, add0_node, sqrt_node, div_node, mul_node, add1_node], + "reduce_mean_sequence", + [root_inp], + [output], + initializer=[scale_const, bias_const, axes_const, two_const, eps_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_reduce_l2_sequence_model(self, shape, epsilon_val, axis=-1): + """ + [root] --> ReduceL2 -----> Clip --> Expand ----> Div --> + | (axis=-1) (min=epsilon) (shape=root) ^ + | (keepdims=True) | + | | + +-----------------------------------------------+ + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const") + eps_const = onnx.numpy_helper.from_array(np.array(epsilon_val, dtype=np.float32), "eps_const") + shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const") + + rl2_node = 
onnx.helper.make_node("ReduceL2", ["root", "axes_const"], ["rl2_out"], keepdims=1) + clip_node = onnx.helper.make_node("Clip", ["rl2_out", "eps_const"], ["clip_out"]) + expand_node = onnx.helper.make_node("Expand", ["clip_out", "shape_const"], ["expand_out"]) + div_node = onnx.helper.make_node("Div", ["root", "expand_out"], ["output"]) + + graph = onnx.helper.make_graph( + [rl2_node, clip_node, expand_node, div_node], + "reducel2_sequence", + [root_inp], + [output], + initializer=[axes_const, eps_const, shape_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def test_fuse_erf_to_gelu_1(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_1_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 2 Gelu nodes. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 2) + + gelu_node_0 = model.model.graph.node[0] + gelu_node_1 = model.model.graph.node[1] + self.assertEqual(gelu_node_0.op_type, "Gelu") + self.assertEqual(gelu_node_1.op_type, "Gelu") + + self.assertTrue(gelu_node_0.name) + self.assertTrue(gelu_node_1.name) + self.assertNotEqual(gelu_node_0.name, gelu_node_1.name) # Generated names should not be equal + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_erf_to_gelu_2(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_2_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 Gelu node. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + gelu_node = model.model.graph.node[0] + self.assertEqual(gelu_node.op_type, "Gelu") + self.assertTrue(gelu_node.name) + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_erf_to_gelu_3(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_3_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 Gelu node. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + gelu_node = model.model.graph.node[0] + self.assertEqual(gelu_node.op_type, "Gelu") + self.assertTrue(gelu_node.name) + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_erf_to_gelu_4(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_4_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 Gelu node. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + gelu_node = model.model.graph.node[0] + self.assertEqual(gelu_node.op_type, "Gelu") + self.assertTrue(gelu_node.name) + + # Check that fusion is equivalent to original Erf model. 
+ inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_reduce_l2_to_lpnorm(self): + shape = (1, 2, 3) + model = self.build_reduce_l2_sequence_model(shape, 1e-12, axis=-1) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 LpNormalization node. + modified = FusionLpNormalization(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + lpnorm_node = model.model.graph.node[0] + self.assertEqual(lpnorm_node.op_type, "LpNormalization") + self.assertTrue(lpnorm_node.name) + + # LpNorm's p attribute should be set to 2 + p_attr = next(attr for attr in lpnorm_node.attribute if attr.name == "p") + self.assertEqual(p_attr.i, 2) + + def test_fuse_reduce_mean_to_layer_norm(self): + shape = (1, 2, 3) + model = self.build_reduce_mean_sequence_model(shape, [2.0, 2.0, 2.0], [1.0, 1.0, 1.0], axis=-1) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 LayerNormalization node. + modified = FusionLayerNormalization(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + layer_norm_node = model.model.graph.node[0] + self.assertEqual(layer_norm_node.op_type, "LayerNormalization") + self.assertTrue(layer_norm_node.name) + + # Check that fused model is equivalent to original model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + +if __name__ == "__main__": + unittest.main() From 3b46ab643944a3bcc9e4d9eb2c155ead0bad5cdb Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 28 Feb 2024 00:46:29 +0800 Subject: [PATCH 070/237] Re-add testing removed by mistake. 
(#19647) --- .../azure-pipelines/linux-ci-pipeline.yml | 42 ++++++++++++++++++- .../docker/scripts/manylinux/requirements.txt | 1 + 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml index a4bd24b4dd18b..02147c321fab3 100644 --- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -115,6 +115,7 @@ stages: searchFolder: '$(Build.BinariesDirectory)' testRunTitle: 'Unit Test Run' condition: succeededOrFailed() + - job: Linux_Release timeoutInMinutes: 180 workspace: @@ -243,7 +244,46 @@ stages: ln -s /data/models $(Build.BinariesDirectory)/models displayName: link model dir - + - bash: | + mkdir -p $HOME/.onnx + docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + onnxruntimecpubuild \ + /bin/bash -c " + set -ex; \ + pushd /onnxruntime_src/csharp; \ + dotnet restore /onnxruntime_src/csharp/OnnxRuntime.DesktopOnly.CSharp.sln; \ + dotnet build /onnxruntime_src/csharp/OnnxRuntime.DesktopOnly.CSharp.sln -c Release; \ + dotnet test /onnxruntime_src/csharp/OnnxRuntime.DesktopOnly.CSharp.sln -c Release -f net6.0 --no-build -l \"console;verbosity=normal\"; \ + popd + " + displayName: 'Dotnet build C# sln and Test' + + - bash: | + mkdir -p $HOME/.onnx + docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + onnxruntimecpubuild \ + /bin/bash -c " + set -ex; \ + /bin/bash /onnxruntime_src/tools/scripts/python_test.sh /onnxruntime_src /build Release && \ + /bin/bash /onnxruntime_src/tools/scripts/symbolic_shape_infer_test.sh /build + " + displayName: 'Run Release tests and symbolic shape infer test' - task: PublishTestResults@2 displayName: 'Publish unit test results' diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index 94f52f476579b..886f19388d01e 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -10,3 +10,4 @@ protobuf==4.21.12 sympy==1.12 flatbuffers neural-compressor>=2.2.1 +triton From 580ee20dfce2849029229eb213dc8c7c87a89483 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 28 Feb 2024 02:56:16 +1000 Subject: [PATCH 071/237] Tweak Windows build parallelization settings (#19664) ### Description Use UseMultiToolTask and limit the number of cl.exe instances running. MultiToolTask info: https://devblogs.microsoft.com/cppblog/improved-parallelism-in-msbuild/ Info on why limiting CL_MPCount can help: https://github.com/Microsoft/checkedc-clang/wiki/Parallel-builds-of-clang-on-Windows The current CIs have 4 cores (both physical and logical). Hardcoded the GPU build in win-ci.yml to use CL_MPCount of 2 as that seems to work fine. 
Can adjust if needed to base it on the actual number of cores or to use build.py to build. Caveat: I've run about 16 builds and haven't seen a slow build yet, but as the root cause of the slow builds isn't really known this isn't guaranteed to be a fix. ### Motivation and Context Try and prevent super slow GPU builds by reducing number of tasks potentially running in parallel. --- tools/ci_build/build.py | 15 ++++++++++++++- .../github/azure-pipelines/templates/win-ci.yml | 3 ++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 74c473d34f548..1056c4ed84510 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1451,6 +1451,13 @@ def generate_build_tree( # tools need to use the symbols. add_default_definition(cmake_extra_defines, "CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "ProgramDatabase") + if number_of_parallel_jobs(args) > 0: + # https://devblogs.microsoft.com/cppblog/improved-parallelism-in-msbuild/ + # NOTE: this disables /MP if set (according to comments on blog post). + # By default, MultiProcMaxCount and CL_MPCount value are equal to the number of CPU logical processors. + # See logic around setting CL_MPCount below + cmake_args += ["-DCMAKE_VS_GLOBALS=UseMultiToolTask=true;EnforceProcessCountAcrossBuilds=true"] + cmake_args += [f"-D{define}" for define in cmake_extra_defines] cmake_args += cmake_extra_args @@ -1662,11 +1669,17 @@ def build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, targe build_tool_args = [] if num_parallel_jobs != 1: if is_windows() and args.cmake_generator != "Ninja" and not args.build_wasm: + # https://github.com/Microsoft/checkedc-clang/wiki/Parallel-builds-of-clang-on-Windows suggests + # not maxing out CL_MPCount + # Start by having one less than num_parallel_jobs (default is num logical cores), + # limited to a range of 1..3 + # that gives maxcpucount projects building using up to 3 cl.exe instances each build_tool_args += [ f"/maxcpucount:{num_parallel_jobs}", + # one less than num_parallel_jobs, at least 1, up to 3 + f"/p:CL_MPCount={min(max(num_parallel_jobs - 1, 1), 3)}", # if nodeReuse is true, msbuild processes will stay around for a bit after the build completes "/nodeReuse:False", - f"/p:CL_MPCount={num_parallel_jobs}", ] elif args.cmake_generator == "Xcode": build_tool_args += [ diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 8ed22153fd947..e32956d6eb913 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -162,10 +162,11 @@ stages: platform: ${{ parameters.msbuildPlatform }} configuration: RelWithDebInfo msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true + maximumCpuCount: true # default is num logical cores worth of projects building concurrently logProjectEvents: true workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' createLogFile: true + msbuildArgs: "/p:CL_MPCount=2" # 2x cl.exe per project building. 
- task: PythonScript@0 displayName: 'test' From 1c468a03b90aa8122d49b3148152a67b0519d36e Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 28 Feb 2024 03:27:43 +1000 Subject: [PATCH 072/237] Improve Nuget-CUDA-Packaging-Pipeline (#19668) ### Description * Publish the artifacts as late as possible * once published the artifacts are immutable, and any retry will fail if they exist * if any step fails after publishing the stage cannot be retried * use powershell to cleanup * DeleteFiles is taking >30 mins and causing the stage to timeout * powershell took < 1s ### Motivation and Context Make pipeline more robust --- .../stages/nuget-combine-cuda-stage.yml | 13 ++++++------- ...mponent-governance-component-detection-steps.yml | 7 ++----- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml index 8ca3d9148b514..064e2ea91d194 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml @@ -213,13 +213,6 @@ stages: PlatformsSupported: 'linux-x64' VerifyNugetSigning: false - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline NuGet Artifact' - inputs: - artifactName: 'drop-signed-nuget-GPU' - targetPath: '$(Build.ArtifactStagingDirectory)' - - - task: MSBuild@1 displayName: 'Clean C#' inputs: @@ -241,6 +234,12 @@ stages: parameters: condition: 'succeeded' + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline NuGet Artifact' + inputs: + artifactName: 'drop-signed-nuget-GPU' + targetPath: '$(Build.ArtifactStagingDirectory)' + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml index f1418e75bffa2..3d128fdb78eee 100644 --- a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml @@ -6,11 +6,8 @@ parameters: steps: - ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: - - task: DeleteFiles@1 - inputs: - SourceFolder: '$(Build.BinariesDirectory)' - contents: | - **/* + - powershell: | + Remove-Item $(Build.BinariesDirectory)/* -Recurse -Force displayName: 'Clean up build directory' - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 From 2e4d1b8f1ba928fe5879077eced9cd5191760cfb Mon Sep 17 00:00:00 2001 From: zesongw Date: Wed, 28 Feb 2024 02:01:12 +0800 Subject: [PATCH 073/237] [WebNN EP] Add support for Op MatMul of WebNN CPU backend (#19413) Enable MatMul support for WebNN CPU backend to support more models. 
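For anyone wanting to exercise the 2-D case that the change below routes through WebNN's gemm (inputs with 3 or more dimensions still use matmul), a minimal test model can be built with the ONNX helper API. This is only a sketch; the shapes and file name are arbitrary:

```python
import onnx
from onnx import TensorProto, helper

# Minimal 2-D MatMul graph; on the WebNN CPU backend this shape now takes the
# gemm fallback added in this change.
a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [4, 8])
b = helper.make_tensor_value_info("B", TensorProto.FLOAT, [8, 2])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 2])
graph = helper.make_graph(
    [helper.make_node("MatMul", ["A", "B"], ["Y"])], "matmul_2d", [a, b], [y]
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)
onnx.save(model, "matmul_2d.onnx")
```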
--- onnxruntime/core/providers/webnn/builders/helper.h | 2 +- .../webnn/builders/impl/gemm_op_builder.cc | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index d94729e60d029..d7892fe02c1ba 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -195,7 +195,7 @@ static const InlinedHashMap op_map = { {"LessOrEqual", {"lesserOrEqual", false}}, {"Log", {"log", false}}, {"LpPool", {"l2Pool2d", false}}, - {"MatMul", {"matmul", false}}, + {"MatMul", {"matmul", true}}, {"MatMulInteger", {"matmulInteger", false}}, {"Max", {"max", true}}, {"MaxPool", {"maxPool2d", true}}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 4bf991a1b0105..d5f84f853f7de 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -29,7 +29,7 @@ class GemmOpBuilder : public BaseOpBuilder { // Add operator related. Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /* logger */) const { + const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C @@ -38,7 +38,17 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name()); emscripten::val output = emscripten::val::object(); if (op_type == "MatMul") { - output = model_builder.GetBuilder().call("matmul", a, b); + std::vector a_shape; + if (!GetShape(*input_defs[a_idx], a_shape, logger)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Can not get shape of A."); + } + // The inputs of MatMul must be at least 3D for WebNN CPU backend. Use GEMM for 2D case. + // TODO: Remove this workaround when it is fixed in Chromium. + if (model_builder.GetWebnnDeviceType() == WebnnDeviceType::CPU && a_shape.size() == 2) { + output = model_builder.GetBuilder().call("gemm", a, b); + } else { + output = model_builder.GetBuilder().call("matmul", a, b); + } } else if (op_type == "MatMulInteger") { emscripten::val a_zero_point = emscripten::val::null(); emscripten::val b_zero_point = emscripten::val::null(); From 3cb81cdde25d059af5674506f6a5b899c9c0f5ee Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:07:15 -0800 Subject: [PATCH 074/237] [js/common] move 'env.wasm.trace' to 'env.trace' (#19617) ### Description Try to move 'env.wasm.trace' to 'env.trace' to make it less confusing, because it also works in webgpu. Marked 'env.wasm.trace' as deprecated. --- js/common/lib/env.ts | 9 +++++++++ js/common/lib/trace.ts | 6 +++--- js/web/lib/wasm/jsep/backend-webgpu.ts | 3 ++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 6299c26159400..73a47d1a4f937 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -36,6 +36,7 @@ export declare namespace Env { /** * set or get a boolean value indicating whether to enable trace. * + * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored. 
* @defaultValue `false` */ trace?: boolean; @@ -167,6 +168,7 @@ export interface Env { * @defaultValue `'warning'` */ logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal'; + /** * Indicate whether run in debug mode. * @@ -174,6 +176,13 @@ export interface Env { */ debug?: boolean; + /** + * set or get a boolean value indicating whether to enable trace. + * + * @defaultValue `false` + */ + trace?: boolean; + /** * Get version of the current package. */ diff --git a/js/common/lib/trace.ts b/js/common/lib/trace.ts index 404f7ef8089af..7e0487b350198 100644 --- a/js/common/lib/trace.ts +++ b/js/common/lib/trace.ts @@ -4,7 +4,7 @@ import {env} from './env-impl.js'; export const TRACE = (deviceType: string, label: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } // eslint-disable-next-line no-console @@ -30,14 +30,14 @@ const TRACE_FUNC = (msg: string, extraMsg?: string) => { }; export const TRACE_FUNC_BEGIN = (extraMsg?: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } TRACE_FUNC('BEGIN', extraMsg); }; export const TRACE_FUNC_END = (extraMsg?: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } TRACE_FUNC('END', extraMsg); diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 3e3a191ec3ead..27c5566ab9fed 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -710,7 +710,8 @@ export class WebGpuBackend { } setQueryType(): void { this.queryType = 'none'; - if (this.env.webgpu.profiling?.mode === 'default' || this.env.wasm.trace) { + if (this.env.webgpu.profiling?.mode === 'default' || + (typeof this.env.trace === 'undefined' ? 
this.env.wasm.trace : this.env.trace)) { if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) { this.queryType = 'inside-passes'; } else if (this.device.features.has('timestamp-query')) { From c20ced4132d111e3e63844e292f0d8e318cffab2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Tue, 27 Feb 2024 20:26:48 +0100 Subject: [PATCH 075/237] Use CMake's find package for CUDA libs (#19673) ### Description Answers issue #19640 More details are in the issue, basically I am changing all the include directory and link directory usage to CMake's `CUDA::*` targets --- cmake/CMakeLists.txt | 4 ++++ cmake/adjust_global_compile_flags.cmake | 2 +- .../external/onnxruntime_external_deps.cmake | 3 +-- cmake/onnxruntime_providers_cuda.cmake | 20 +++++++++---------- cmake/onnxruntime_providers_tensorrt.cmake | 11 +++++----- cmake/onnxruntime_python.cmake | 5 +---- cmake/onnxruntime_unittests.cmake | 4 ++-- .../core/providers/cuda/nvtx_profile.cc | 5 ----- 8 files changed, 25 insertions(+), 29 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 1376c90fbcefe..8453da19ce3a6 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1412,6 +1412,10 @@ endif() if (onnxruntime_USE_CUDA) set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) set(CMAKE_CUDA_STANDARD 17) + if(onnxruntime_CUDA_HOME) + file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME}) + endif() + find_package(CUDAToolkit REQUIRED) if(onnxruntime_CUDNN_HOME) file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) endif() diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 8161ea574b8cc..d3f9256105127 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -205,7 +205,7 @@ endif() macro(check_nvcc_compiler_flag _FLAG _RESULT) - execute_process(COMMAND ${onnxruntime_CUDA_HOME}/bin/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR) + execute_process(COMMAND ${CUDAToolkit_BIN_DIR}/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR) message("NVCC_ERROR = ${NVCC_ERROR}") message("NVCC_OUT = ${NVCC_OUT}") if ("${NVCC_OUT}" MATCHES "0") diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 22d12b128dc1f..09d57164b4ee1 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -556,16 +556,15 @@ message("Finished fetching external dependencies") set(onnxruntime_LINK_DIRS ) if (onnxruntime_USE_CUDA) #TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same + find_package(CUDAToolkit REQUIRED) if (WIN32) if(onnxruntime_CUDNN_HOME) list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64) endif() - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/x64/lib64) else() if(onnxruntime_CUDNN_HOME) list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64) endif() - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/lib64) endif() endif() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 9887d615c92d7..0f6d48bdb6ec8 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -178,15 +178,16 @@ add_dependencies(${target} onnxruntime_providers_shared 
${onnxruntime_EXTERNAL_DEPENDENCIES}) if(onnxruntime_CUDA_MINIMAL) target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL) - target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart) else() - target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand CUDA::cufft CUDA::cudart + ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) if(onnxruntime_CUDNN_HOME) target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) endif() endif() - + if (onnxruntime_USE_TRITON_KERNEL) # compile triton kernel, generate .a and .h files include(onnxruntime_compile_triton_kernel.cmake) @@ -196,25 +197,24 @@ target_include_directories(${target} PRIVATE ${triton_kernel_header_dir}) target_link_libraries(${target} PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive) # lib cuda needed by cuLaunchKernel - target_link_libraries(${target} PRIVATE cuda) + target_link_libraries(${target} PRIVATE CUDA::cuda_driver) endif() include(cutlass) target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) - target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} + PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA) set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime") if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling - target_include_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/include) - target_link_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/lib64) - target_link_libraries(${target} PRIVATE cupti) + target_link_libraries(${target} PRIVATE CUDA::cupti) endif() - if (onnxruntime_ENABLE_NVTX_PROFILE AND NOT WIN32) - target_link_libraries(${target} PRIVATE nvToolsExt) + if (onnxruntime_ENABLE_NVTX_PROFILE) + target_link_libraries(${target} PRIVATE CUDA::nvtx3) endif() if (onnxruntime_ENABLE_TRAINING_OPS) diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 686a993de3a4a..15ffc29e79ff4 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -8,7 +8,7 @@ set(BUILD_LIBRARY_ONLY 1) add_definitions("-DONNX_ML=1") add_definitions("-DONNX_NAMESPACE=onnx") - set(CUDA_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + set(CUDA_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIRS}) set(TENSORRT_ROOT ${onnxruntime_TENSORRT_HOME}) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) set(PROTOBUF_LIBRARY ${PROTOBUF_LIB}) @@ -58,7 +58,7 @@ URL_HASH SHA1=${DEP_SHA1_onnx_tensorrt} ) if (NOT CUDA_INCLUDE_DIR) - set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # onnx-tensorrt repo needs this 
variable to build + set(CUDA_INCLUDE_DIR ${CUDAToolkit_INCLUDE_DIRS}) # onnx-tensorrt repo needs this variable to build endif() # The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses # unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose. @@ -102,11 +102,12 @@ onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS}) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart) else() - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS}) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart) endif() - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} + PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) if(onnxruntime_CUDNN_HOME) target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include) endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 3f20787e87425..23c6e5e430875 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -282,10 +282,7 @@ if (WIN32) get_filename_component(CUDNN_DLL_NAME ${CUDNN_DLL_PATH} NAME_WE) string(REPLACE "cudnn64_" "" CUDNN_VERSION "${CUDNN_DLL_NAME}") if(NOT onnxruntime_CUDA_VERSION) - message("Reading json file ${onnxruntime_CUDA_HOME}/version.json") - set(CUDA_SDK_JSON_FILE_PATH "${onnxruntime_CUDA_HOME}/version.json") - file(READ ${CUDA_SDK_JSON_FILE_PATH} CUDA_SDK_JSON_CONTENT) - string(JSON onnxruntime_CUDA_VERSION GET ${CUDA_SDK_JSON_CONTENT} "cuda" "version") + set(onnxruntime_CUDA_VERSION ${CUDAToolkit_VERSION}) message("onnxruntime_CUDA_VERSION=${onnxruntime_CUDA_VERSION}") endif() file(APPEND "${VERSION_INFO_FILE}" diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 3ed695327c183..88f662075e177 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -67,7 +67,7 @@ function(AddTest) if(onnxruntime_USE_CUDA) #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, # otherwise it will impact when CUDA DLLs can be unloaded. 
- target_link_libraries(${_UT_TARGET} PRIVATE cudart) + target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart) endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -1268,7 +1268,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo) endif() if (onnxruntime_USE_CUDA) - list(APPEND onnxruntime_shared_lib_test_LIBS cudart) + list(APPEND onnxruntime_shared_lib_test_LIBS CUDA::cudart) endif() if (onnxruntime_USE_ROCM) list(APPEND onnxruntime_shared_lib_test_LIBS hip::host) diff --git a/onnxruntime/core/providers/cuda/nvtx_profile.cc b/onnxruntime/core/providers/cuda/nvtx_profile.cc index 6c7c594066b86..867e7c1f24584 100644 --- a/onnxruntime/core/providers/cuda/nvtx_profile.cc +++ b/onnxruntime/core/providers/cuda/nvtx_profile.cc @@ -4,13 +4,8 @@ #ifdef ENABLE_NVTX_PROFILE #include "nvtx_profile.h" #include "core/common/common.h" -#if defined(_WIN32) || defined(WIN32) || defined(__CYGWIN__) || defined(__MINGW32__) || defined(__BORLANDC__) #include #include -#else -#include -#include -#endif namespace onnxruntime { namespace profile { From f95c0773a129a4605b2161f5f9fddb8116c948d0 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 28 Feb 2024 10:40:40 +0800 Subject: [PATCH 076/237] Add share memory Flag in docker (#19672) ### Description ### Motivation and Context Ref: https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem Co-authored-by: Your Name --- tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 822bc559d992d..165bd804a8ad5 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -241,7 +241,7 @@ stages: script: | set -e -x mkdir -p $HOME/.onnx - docker run --gpus all --rm \ + docker run --gpus all --shm-size=1g --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --rm \ --volume $(Build.SourcesDirectory):/onnxruntime_src \ --volume $(Build.BinariesDirectory)/Release:/build/Release \ --volume /data/models:/build/models:ro \ From 026e3178ae71cfcc5cfa2decde9a7d64b935d255 Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 28 Feb 2024 15:57:05 +0800 Subject: [PATCH 077/237] Improve memory matrix for ORTModule (#19620) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Memory matrix for ORTModule Collect parameter/gradient/buffers sizes also. Exposed as a function, can be used externally for debugging purpose. 
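A minimal usage sketch of the exposed helper (assuming onnxruntime-training and a CUDA device are available; the `model` and `logger` names are placeholders):

```python
import logging
import torch
from onnxruntime.training.utils import log_memory_usage

logger = logging.getLogger("ort.memory")
model = torch.nn.Linear(8, 8).cuda()  # placeholder; any CUDA-resident nn.Module works

# Passing `module` makes the helper also report parameter/gradient/buffer sizes;
# with logger=None the summary line is printed to stdout instead.
log_memory_usage(
    "pre_forward",        # free-form phase label
    rank_0_only=True,     # only log from rank 0 when torch.distributed is initialized
    step_info="step 0",
    logger=logger,
    module=model,
)
```

The output below is from ORTModule's built-in memory observer, which now calls the same function: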
``` 2024-02-27 07:18:55,283 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,322 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,358 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,438 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|■| 2/3200 [01:27<32:05:11, 36.12s/it]2024-02-27 07:18:55,498 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,537 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,576 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,657 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|■| 3/3200 [01:27<17:30:57, 19.72s/it]2024-02-27 07:18:55,711 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,750 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,786 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,867 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 [2024-02-27 07:18:55,886] [INFO] [loss_scaler.py:190:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, but hysteresis is 2. 
Reducing hysteresis to 1 0%|▎ | 4/3200 [01:28<10:39:52, 12.01s/it]2024-02-27 07:18:55,902 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,944 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,979 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,060 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|■| 5/3200 [01:28<6:53:04, 7.76s/it]2024-02-27 07:18:56,115 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,154 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,190 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,270 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|■| 6/3200 [01:28<4:36:19, 5.19s/it]2024-02-27 07:18:56,323 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,365 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,398 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,478 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▌ | 7/3200 [01:28<3:09:33, 3.56s/it]2024-02-27 07:18:56,533 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,572 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | 
max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,608 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,727 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▌ | 8/3200 [01:28<2:13:48, 2.52s/it]2024-02-27 07:18:56,806 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,846 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,882 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,962 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▋ | 9/3200 [01:29<1:36:03, 1.81s/it]2024-02-27 07:18:57,053 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:57,094 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 ``` --- .../training/ortmodule/_runtime_inspector.py | 37 +++------ .../python/training/utils/__init__.py | 2 + .../training/utils/torch_profile_utils.py | 76 +++++++++++++++++++ 3 files changed, 88 insertions(+), 27 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index 078ce4d27cd6f..772b9bd9e31ae 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -14,7 +14,7 @@ from sympy import Symbol, simplify from sympy.parsing.sympy_parser import parse_expr -from onnxruntime.training.utils import PTable +from onnxruntime.training.utils import PTable, log_memory_usage from ._execution_agent import TrainingAgent from .options import _MemoryOptimizationLevel, _RuntimeOptions @@ -509,6 +509,8 @@ def __init__(self, m: torch.nn.Module, logger: Logger): self._is_first_inspect = True + self._m = m + def is_enabled(self) -> bool: """Check if memory inspector is enabled.""" return self._is_enabled @@ -621,29 +623,13 @@ def inspect_memory(self, cur_phase: Phase): need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0) if need_print: - cur_mem_allocated = self._normalize(torch.cuda.memory_allocated()) - max_mem_allocated = 
self._normalize(torch.cuda.max_memory_allocated()) - cur_mem_cached = self._normalize(torch.cuda.memory_reserved()) - max_mem_cached = self._normalize(torch.cuda.max_memory_reserved()) - torch_mem_stat = torch.cuda.memory_stats() - cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) - max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) - - mem_stats = [ - ["phase", _convert_phase_to_string(cur_phase)], - ["allocated", cur_mem_allocated], # current memory allocated for tensors - ["max allocated", max_mem_allocated], # peak memory allocated for tensors - ["cached", cur_mem_cached], # current memory cached for the caching allocator - ["max cached", max_mem_cached], # peak memory cached for caching allocator. - ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory - ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory - ] - - summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})" - for stat in mem_stats: - summ += f" | {stat[0]}: {stat[1]}" - - self._logger.info(summ) + log_memory_usage( + _convert_phase_to_string(cur_phase), + rank_0_only=True, + step_info=f"step {self._current_step}", + logger=self._logger, + module=self._m, + ) if cur_phase == self._last_phase: self._increase_step() @@ -655,9 +641,6 @@ def inspect_memory(self, cur_phase: Phase): def _increase_step(self): self._current_step += 1 - def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str: - return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}" - def display_memory_optimization_plans(self, memory_optimizer_config, details=False) -> Tuple[List[str], PTable]: mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map) diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index b4a518d573998..ecfb7d7907f3c 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -12,6 +12,7 @@ unflatten_data_using_schema, ) from onnxruntime.training.utils.torch_profile_utils import ( + log_memory_usage, nvtx_function_decorator, torch_nvtx_range_pop, torch_nvtx_range_push, @@ -31,6 +32,7 @@ "torch_nvtx_range_push", "torch_nvtx_range_pop", "nvtx_function_decorator", + "log_memory_usage", "pytorch_type_to_onnx_dtype", "onnx_dtype_to_pytorch_dtype", "pytorch_scalar_type_to_pytorch_dtype", diff --git a/orttraining/orttraining/python/training/utils/torch_profile_utils.py b/orttraining/orttraining/python/training/utils/torch_profile_utils.py index 382d7dac142fe..9e8a41e0dc7c8 100644 --- a/orttraining/orttraining/python/training/utils/torch_profile_utils.py +++ b/orttraining/orttraining/python/training/utils/torch_profile_utils.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + import torch @@ -26,3 +28,77 @@ def wrapped_fn(*args, **kwargs): return ret_val return wrapped_fn + + +def log_memory_usage(cur_phase: str, rank_0_only=True, step_info="", logger=None, module=None): + """Log memory usage for the current phase. + Args: + cur_phase (str): The current phase. + rank_0_only (bool, optional): Only log the memory usage for rank 0. Defaults to True. + step_info (str, optional): The step information. Defaults to "". 
+ logger (logging.Logger, optional): The logger to log the memory usage. Defaults to None, which means print to stdout. + module (torch.nn.Module, optional): The module to get parameter, buffer and grad sizes. Defaults to None. + """ + rank = 0 + if rank_0_only is True: + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + if rank != 0: + return + + _normalizer_factor = float(1024 * 1024) + _normalizer_unit = "MiB" + + def _normalize(mem_size_in_bytes: float | int) -> str: + return f"{float(mem_size_in_bytes) / _normalizer_factor:.0f}" + + cur_mem_allocated = _normalize(torch.cuda.memory_allocated()) + max_mem_allocated = _normalize(torch.cuda.max_memory_allocated()) + cur_mem_cached = _normalize(torch.cuda.memory_reserved()) + max_mem_cached = _normalize(torch.cuda.max_memory_reserved()) + torch_mem_stat = torch.cuda.memory_stats() + cur_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) + max_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) + + mem_stats = [ + ["phase", cur_phase], + ["allocated", cur_mem_allocated], # current memory allocated for tensors + ["max allocated", max_mem_allocated], # peak memory allocated for tensors + ["cached", cur_mem_cached], # current memory cached for the caching allocator + ["max cached", max_mem_cached], # peak memory cached for caching allocator. + ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory + ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory + ] + + # Calculate the total size of parameters and gradients in the model + if module: + param_total_size = 0 + grad_total_size = 0 + for p in module.parameters(): + if p.is_cuda: + param_total_size += p.numel() * p.element_size() + if p.grad is not None and p.grad.is_cuda: + grad_total_size += p.grad.numel() * p.grad.element_size() + + # Calculate the total size of buffers in the model + buffer_total_size = 0 + for b in module.buffers(): + if b.is_cuda: + buffer_total_size += b.numel() * b.element_size() + + mem_stats.extend( + [ + ["param", _normalize(param_total_size)], + ["grad", _normalize(grad_total_size)], + ["buffer", _normalize(buffer_total_size)], + ] + ) + + summ = f"rank-{rank} {step_info} memory ({_normalizer_unit})" + for stat in mem_stats: + summ += f" | {stat[0]}: {stat[1]}" + + if logger is None: + print(summ) + else: + logger.info(summ) From 7a147fc6f76a30b8d5875352afced515431ec7e5 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 28 Feb 2024 02:20:53 -0800 Subject: [PATCH 078/237] Remove a bash task from webgpu CI pipeline (#19682) ### Description It is a "Bash" task that requires running bash on Windows. Most Windows operating systems do not have Bash installed. Given this task is only debugging purposes, we can remove it for now. ### Motivation and Context I am making this change because I am regenerating the VM image in a different manner, and the new image does not contain bash. Once this PR is in, I can switch the images. 
--- .../github/azure-pipelines/templates/win-web-ci.yml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 8ba3517530edd..043da233cc674 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -155,12 +155,7 @@ jobs: path: $(Build.SourcesDirectory)/js/test/ cacheHitVar: CACHE_RESTORED displayName: 'Cache ONNX node test data' - - task: Bash@3 - inputs: - targetType: 'inline' - script: find "$(Build.SourcesDirectory)/js/test/" -type f - condition: and(not(canceled()), eq(variables.CACHE_RESTORED, 'true')) - displayName: 'List ONNX node test data' + - task: PowerShell@2 inputs: filePath: '$(Build.SourcesDirectory)\tools\ci_build\github\js\pack-npm-packages.ps1' From 913bdc7306e11b65644f733861684a3a460e8db0 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 28 Feb 2024 08:30:12 -0800 Subject: [PATCH 079/237] [QNN Quant] Handle external data for QNN preprocessing/quant (#19670) ### Description - Adds parameters to `qnn_preprocess_model()` to allow saving the new model with external data. - Updates `get_qnn_qdq_config()` to: - Load model without external data (it is not needed) - Return a quantization configuration with `use_external_data_format` set to `True` if the model has external data or if the model is >= 2GB. ### Motivation and Context Update QNN quantization to better handle large models that use external data. --- .../execution_providers/qnn/preprocess.py | 51 +++++- .../execution_providers/qnn/quant_config.py | 15 +- .../quantization/test_qnn_preprocess_model.py | 170 ++++++++++++++++++ .../test_tensor_quant_overrides_option.py | 30 ++++ 4 files changed, 261 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/test/python/quantization/test_qnn_preprocess_model.py diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index b1c114fe1f9fd..b0dab81830c8b 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from __future__ import annotations + import logging from pathlib import Path @@ -13,7 +15,44 @@ from .fusion_lpnorm import FusionLpNormalization -def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: bool = False) -> bool: +def qnn_preprocess_model( + model_input: Path, + model_output: Path, + fuse_layernorm: bool = False, + save_as_external_data: bool = False, + all_tensors_to_one_file: bool = False, + external_data_location: str | None = None, + external_data_size_threshold: int = 1024, + external_data_convert_attribute: bool = False, +) -> bool: + """ + If necessary, this method creates a new "pre-processed" model in preparation for + quantization of a model to be used in QNN EP. Returns true if a new model was created. + + This method perfoms the following operations: + - Fuse Erf sequence into a single Gelu node. + - Fuse ReduceL2 sequence into a single LpNormalization node (p == 2). + - (Optional) Fuse ReduceMean sequence into a single LayerNormalization node. 
+ + Args: + model_input: Path to the input model file. + model_output: Path the output model file, which is only created if this method returns True. + fuse_layernorm: True if ReduceMean sequences should be fused into LayerNormalization nodes. + Defaults to False. + save_as_external_data: True if output model should be saved with external data. Defaults to false. + all_tensors_to_one_file: Effective only if save_as_external_data is true. Defaults to false. + If true, save all tensors to one external file specified by external_data_location. + If false, save each tensor to a file named with the tensor name. + external_data_location: Effective only if save_as_external_data is true. Defaults to None. + Specify the external file to which all tensors are saved. Path is relative + to the model path. If not specified, the model's name is used. + external_data_size_threshold: Effective only if save_as_external_data is true. Defaults to 1024. + Tensors with a data size >= external_data_size_threshold are converted to external data. + To convert every tensor with raw data to external data, set to 0. + external_data_convert_attribute: Effective only if save_as_external_data is true. Defaults to false. + If true, convert all tensors to external data. + If false, convert only non-attribute tensors to external data. + """ modified = False model = onnx.load_model(model_input) onnx_model = ONNXModel(model) @@ -57,6 +96,14 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: if modified: onnx_model.topological_sort() - onnx.save_model(model, model_output) + onnx.save_model( + model, + model_output, + save_as_external_data=save_as_external_data, + all_tensors_to_one_file=all_tensors_to_one_file, + location=external_data_location, + size_threshold=external_data_size_threshold, + convert_attribute=external_data_convert_attribute, + ) return modified diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index 7c2fa4f65ae1b..e9affae7ac263 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -15,6 +15,7 @@ Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16} Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8} OP_TYPES_TO_EXCLUDE = {"Cast"} +MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB def get_qnn_qdq_config( @@ -28,14 +29,21 @@ def get_qnn_qdq_config( if per_channel: raise ValueError("QNN EP does not yet support per-channel quantization.") - # Process model nodes to setup overrides. - model = onnx.load_model(model_input) + model = onnx.load_model(model_input, load_external_data=False) op_types = set() tensor_quant_overrides = {} + model_has_external_data = False + name_to_initializer = {} - name_to_initializer = {initializer.name: initializer for initializer in model.graph.initializer} + # Build map of initializers (name -> initializer) and + # check if the model has external data. 
+ for initializer in model.graph.initializer: + name_to_initializer[initializer.name] = initializer + if onnx.external_data_helper.uses_external_data(initializer): + model_has_external_data = True + # Setup quantization overrides for specific operator types for node in model.graph.node: op_types.add(node.op_type) @@ -89,5 +97,6 @@ def get_qnn_qdq_config( activation_type=activation_type, weight_type=weight_type, op_types_to_quantize=list(op_types.difference(OP_TYPES_TO_EXCLUDE)), + use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), extra_options=extra_options, ) diff --git a/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py new file mode 100644 index 0000000000000..9b67fd41caac3 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import unittest +from pathlib import Path + +import numpy as np +import onnx + +from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model +from onnxruntime.quantization.quant_utils import model_has_external_data, ms_domain + + +class TestQnnPreprocessModel(unittest.TestCase): + def build_model(self, shape, scale_val, bias_val): + """ + Build a model that supports 3 kinds of fusions: + - Erf sequence to Gelu + - ReduceL2 sequence to LpNormalization + - ReduceMean sequence to LayerNormalization + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + + # Erf sequence + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + e_mul0_node = onnx.helper.make_node("Mul", ["root", "half_const"], ["e_mul0_out"]) + e_div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["e_div_out"]) + e_erf_node = onnx.helper.make_node("Erf", ["e_div_out"], ["e_erf_out"]) + e_add_node = onnx.helper.make_node("Add", ["e_erf_out", "one_const"], ["e_add_out"]) + e_mul1_node = onnx.helper.make_node("Mul", ["e_add_out", "e_mul0_out"], ["erf_seq_output"]) + + # ReduceL2 sequence + axes_const = onnx.numpy_helper.from_array(np.array([-1], dtype=np.int64), "axes_const") + eps_const = onnx.numpy_helper.from_array(np.array(1e-12, dtype=np.float32), "eps_const") + shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const") + + l2_rl2_node = onnx.helper.make_node("ReduceL2", ["erf_seq_output", "axes_const"], ["l2_rl2_out"], keepdims=1) + l2_clip_node = onnx.helper.make_node("Clip", ["l2_rl2_out", "eps_const"], ["l2_clip_out"]) + l2_expand_node = onnx.helper.make_node("Expand", ["l2_clip_out", "shape_const"], ["l2_expand_out"]) + l2_div_node = onnx.helper.make_node("Div", ["erf_seq_output", "l2_expand_out"], ["l2_seq_output"]) + + # ReduceMean sequence + scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const") + 
bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const") + two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const") + + m_rm0_node = onnx.helper.make_node("ReduceMean", ["l2_seq_output", "axes_const"], ["m_rm0_out"]) + m_sub_node = onnx.helper.make_node("Sub", ["l2_seq_output", "m_rm0_out"], ["m_sub_out"]) + m_pow_node = onnx.helper.make_node("Pow", ["m_sub_out", "two_const"], ["m_pow_out"]) + m_rm1_node = onnx.helper.make_node("ReduceMean", ["m_pow_out", "axes_const"], ["m_rm1_out"]) + m_add0_node = onnx.helper.make_node("Add", ["m_rm1_out", "eps_const"], ["m_add0_out"]) + m_sqrt_node = onnx.helper.make_node("Sqrt", ["m_add0_out"], ["m_sqrt_out"]) + m_div_node = onnx.helper.make_node("Div", ["m_sub_out", "m_sqrt_out"], ["m_div_out"]) + m_mul_node = onnx.helper.make_node("Mul", ["m_div_out", "scale_const"], ["m_mul_out"]) + m_add1_node = onnx.helper.make_node("Add", ["m_mul_out", "bias_const"], ["output"]) + + graph = onnx.helper.make_graph( + [ + e_mul0_node, + e_div_node, + e_erf_node, + e_add_node, + e_mul1_node, + l2_rl2_node, + l2_clip_node, + l2_expand_node, + l2_div_node, + m_rm0_node, + m_sub_node, + m_pow_node, + m_rm1_node, + m_add0_node, + m_sqrt_node, + m_div_node, + m_mul_node, + m_add1_node, + ], + "qnn_f32_model", + [root_inp], + [output], + initializer=[ + one_const, + half_const, + root2_const, + axes_const, + eps_const, + shape_const, + scale_const, + bias_const, + two_const, + ], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_all_fusions(self): + """ + Test calling qnn_preprocess_model() with a model that supports all 3 fusions. + """ + model = self.build_model((1, 2, 3), [2.0, 2.0, 2.0], [1.0, 1.0, 1.0]) + onnx.save_model(model, "model.onnx") + modified = qnn_preprocess_model("model.onnx", "model.qnn_pp.onnx", fuse_layernorm=True) + + self.assertTrue(modified) + + fused_model = onnx.load_model("model.qnn_pp.onnx") + + # 3 fused Ops: Gelu, LpNorm, LayerNorm + self.assertEqual(len(fused_model.graph.node), 3) + expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"} + for node in fused_model.graph.node: + self.assertIn(node.op_type, expected_op_types) + + # Should have added "com.microsoft" opset import because we added a Gelu. + ms_domain_opset = next((opset for opset in fused_model.opset_import if opset.domain == ms_domain), None) + self.assertIsNotNone(ms_domain_opset) + self.assertEqual(ms_domain_opset.version, 1) + + def test_external_data(self): + """ + Test calling qnn_preprocess_model() with a model that uses external data. + The new preprocessed model should also have external data. + """ + model = self.build_model((1, 2, 3), [2.0, 2.0, 2.0], [1.0, 1.0, 1.0]) + onnx.save_model( + model, + "model.onnx", + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.bin", + size_threshold=0, + ) + modified = qnn_preprocess_model( + "model.onnx", + "model.qnn_pp.onnx", + fuse_layernorm=True, + save_as_external_data=True, + all_tensors_to_one_file=True, + external_data_location="weights2.bin", + external_data_size_threshold=0, + ) + + self.assertTrue(modified) + + # Model should still have external data. 
+ self.assertTrue(model_has_external_data(Path("model.qnn_pp.onnx"))) + + fused_model = onnx.load_model("model.qnn_pp.onnx", load_external_data=False) + + # 3 fused Ops: Gelu, LpNorm, LayerNorm + self.assertEqual(len(fused_model.graph.node), 3) + expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"} + for node in fused_model.graph.node: + self.assertIn(node.op_type, expected_op_types) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 0470953e385b6..cbb6b3ae2e776 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -555,6 +555,36 @@ def test_get_qnn_qdq_config(self): self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16) self.assertEqual(sig_out_sc.float_data[0], np.float32(1.0 / 65536.0)) + def test_get_qnn_qdq_config_ext_data(self): + """ + Test that get_qnn_qdq_config() returns a config that enables external data + if the input model has external data. + """ + + # Create model with a weight large enough (> 1024 bytes) to be stored externally. + large_weight = onnx.numpy_helper.from_array(np.random.random((1, 32, 32)).astype(np.float32), "weight") + graph = onnx.helper.make_graph( + [onnx.helper.make_node("Add", ["input", "weight"], ["output"])], + "add_ext_data", + [onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, (1, 32, 32))], + [onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, (1, 32, 32))], + initializer=[large_weight], + ) + model = onnx.helper.make_model( + graph, + opset_imports=[onnx.helper.make_opsetid("", 18)], + ) + onnx.save_model( + model, + "add_ext_data.onnx", + save_as_external_data=True, + all_tensors_to_one_file=True, + location="add_ext_data.bin", + ) + + qnn_config = get_qnn_qdq_config("add_ext_data.onnx", DummyDataReader(self.activations)) + self.assertTrue(qnn_config.use_external_data_format) + if __name__ == "__main__": t = TestTensorQuantOverridesOption() From a93c31e3c9971063d8dfe45a627a80cbdcf99ed9 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 28 Feb 2024 12:03:17 -0800 Subject: [PATCH 080/237] Update dml-vs-2022.yml (#19687) ### Description Fix a build error in "Zip-Nuget-Java-Nodejs Packaging Pipeline" which deletes files too early. 
--- .../nuget/templates/dml-vs-2022.yml | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 9393fb07d718a..d6bb415a68ee6 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -55,6 +55,9 @@ stages: - checkout: self clean: true submodules: recursive + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() - powershell: | if($env:TELEMETRYGUID) @@ -231,14 +234,7 @@ stages: searchPattern: '**/*.pdb' symbolServerType: teamServices - - ${{ if eq(parameters['DoCompliance'], 'true') }}: - - template: ../../templates/compliance.yml - parameters : - msbuildPlatform: ${{ parameters.sln_platform }} - - template: ../../templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' # Node.js Publish - ${{ if eq(parameters['DoNodejsPack'], 'true') }}: @@ -294,6 +290,12 @@ stages: targetPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}' artifactName: 'drop-onnxruntime-nodejs-win-${{ parameters.sln_platform }}-dml' - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() + + - ${{ if eq(parameters['DoCompliance'], 'true') }}: + - template: ../../templates/compliance.yml + parameters : + msbuildPlatform: ${{ parameters.sln_platform }} + + - template: ../../templates/component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' From e30618d05535d3fe0fdc34d350d78e8ad01b64d5 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Feb 2024 16:05:08 -0800 Subject: [PATCH 081/237] [js/webgpu] use Headless for webgpu test by default (#19702) ### Description use Chromium Headless for webgpu test by default. Still use normal Chromium with window when debug=true or perfMode=true. Use the [`--headless=new`](https://developer.chrome.com/docs/chromium/new-headless) mode. ### Motivation and Context try to use a more stable way to launch npm tests to avoid a "chrome not found" issue in pipeline, which may potentially caused by windowed application. --- js/web/karma.conf.js | 4 ++-- js/web/script/test-runner-cli.ts | 29 +++++++---------------------- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js index 8fce79843f617..9e44d9c0d9652 100644 --- a/js/web/karma.conf.js +++ b/js/web/karma.conf.js @@ -86,11 +86,11 @@ module.exports = function(config) { hostname, listenAddress, customLaunchers: { - // the following flags are used to make sure Edge on CI agents to initialize WebGPU correctly. 
+ // Chromium-based browsers EdgeTest: {base: 'Edge', flags: chromiumFlags}, ChromeTest: {base: 'Chrome', flags: chromiumFlags}, - ChromeTestHeadless: {base: 'ChromeHeadless', flags: chromiumFlags}, ChromeCanaryTest: {base: 'ChromeCanary', flags: chromiumFlags}, + // // ==== BrowserStack browsers ==== // diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index 9105c02412e34..59bd0d5f6313a 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -542,14 +542,13 @@ async function main() { npmlog.info('TestRunnerCli.Run', '(4/4) Running karma to start test runner...'); const webgpu = args.backends.indexOf('webgpu') > -1; const webnn = args.backends.indexOf('webnn') > -1; - const browser = getBrowserNameFromEnv( - args.env, - args.bundleMode === 'perf' ? 'perf' : - args.debug ? 'debug' : - 'test', - webgpu); + const browser = getBrowserNameFromEnv(args.env); const karmaArgs = ['karma', 'start', `--browsers ${browser}`]; const chromiumFlags = ['--enable-features=SharedArrayBuffer', ...args.chromiumFlags]; + if (args.bundleMode === 'dev' && !args.debug) { + // use headless for 'test' mode (when 'perf' and 'debug' are OFF) + chromiumFlags.push('--headless=new'); + } if (args.debug) { karmaArgs.push('--log-level info --timeout-mocha 9999999'); chromiumFlags.push('--remote-debugging-port=9333'); @@ -662,10 +661,10 @@ async function main() { fs.writeJSONSync(path.join(TEST_ROOT, './testdata-config.json'), config); } - function getBrowserNameFromEnv(env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean) { + function getBrowserNameFromEnv(env: TestRunnerCliArgs['env']) { switch (env) { case 'chrome': - return selectChromeBrowser(mode, webgpu); + return 'ChromeTest'; case 'edge': return 'EdgeTest'; case 'firefox': @@ -680,20 +679,6 @@ async function main() { throw new Error(`env "${env}" not supported.`); } } - - function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean) { - if (webgpu) { - return 'ChromeTest'; - } else { - switch (mode) { - case 'debug': - case 'perf': - return 'ChromeTest'; - default: - return 'ChromeTestHeadless'; - } - } - } } void main(); From 250779474de0ce50f0ef4b39f7b050755e1019ba Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 28 Feb 2024 19:36:26 -0800 Subject: [PATCH 082/237] Change "onnxruntime-Linux-CPU-For-Android-CI" machine pool to "onnxruntime-Ubuntu2204-AMD-CPU" (#19698) ### Description The original one reports "out of disk space", which needs to be investigated. 
--- .../android-x86_64-crosscompile-ci-pipeline.yml | 6 +++--- .../azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml | 2 +- .../github/azure-pipelines/mac-react-native-ci-pipeline.yml | 2 +- .../templates/android-binary-size-check-stage.yml | 3 ++- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml index 9136b21aec626..d0a22aae07741 100644 --- a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml @@ -53,7 +53,7 @@ stages: Codeql.Enabled: false jobs: - job: Build_CPU_EP - pool: onnxruntime-Linux-CPU-For-Android-CI + pool: onnxruntime-Ubuntu2204-AMD-CPU workspace: clean: all timeoutInMinutes: 30 @@ -140,7 +140,7 @@ stages: jobs: - job: Build_NNAPI_EP - pool: onnxruntime-Linux-CPU-For-Android-CI + pool: onnxruntime-Ubuntu2204-AMD-CPU timeoutInMinutes: ${{ variables.JobsTimeout }} workspace: clean: all @@ -456,7 +456,7 @@ stages: variables: - name: skipComponentGovernanceDetection value: true - pool: 'onnxruntime-Linux-CPU-For-Android-CI' + pool: 'onnxruntime-Ubuntu2204-AMD-CPU' condition: and(succeeded(), in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI')) dependsOn: - NNAPI_EP_MASTER diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml index 1053a2518125f..bbea7a0d114e8 100644 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml @@ -59,7 +59,7 @@ jobs: timeoutInMinutes: 120 workspace: clean: all - pool: onnxruntime-Linux-CPU-For-Android-CI + pool: onnxruntime-Ubuntu2204-AMD-CPU variables: ORT_CACHE_DIR: $(Pipeline.Workspace)/ort_ccache TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] diff --git a/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml index e8f4931d5ad9f..886bacf5aac4d 100644 --- a/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml @@ -61,4 +61,4 @@ stages: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} BuildConfig: 'Release' - PoolName: 'onnxruntime-Linux-CPU-For-Android-CI' + PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' diff --git a/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml b/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml index 733cafdeeb8c0..9822950127112 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml @@ -31,7 +31,7 @@ stages: timeoutInMinutes: 60 workspace: clean: all - pool: onnxruntime-Linux-CPU-For-Android-CI + pool: onnxruntime-Ubuntu2204-AMD-CPU steps: - checkout: self clean: true @@ -49,6 +49,7 @@ stages: - task: PythonScript@0 displayName: 'Set variables from config file "${{ parameters.BuildConfigFile }}"' inputs: + pythonInterpreter: /usr/bin/python3 scriptSource: inline script: | import json From 7455dd1f32af760984f42e8e6d752b675a4a0852 Mon Sep 17 00:00:00 2001 From: Sophie Schoenmeyer 
<107952697+sophies927@users.noreply.github.com> Date: Wed, 28 Feb 2024 21:10:25 -0800 Subject: [PATCH 083/237] Update labeler.yml to change permissions (#19709) ### Description Updated github/issue-labeler permissions to give write access for issues. Tried to submit the same PR last week, but the checks kept failing, so I couldn't merge. ### Motivation and Context Enables issue labeling again, which has been broken since GitHub Actions permissions were changed a couple weeks ago. --- .github/workflows/labeler.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 936ab0de899a2..a196226a4b836 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -3,6 +3,9 @@ on: issues: types: [opened, edited] +permissions: + issues: write + jobs: triage: runs-on: ubuntu-latest From d2e6dd25ea8bd528f614250ba0165a535734305e Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 29 Feb 2024 13:45:58 +0800 Subject: [PATCH 084/237] Merge GatherToSplitFusion and #19218 to a General Fusion (#19600) #19218 tried to fuse Gather/Slice to Split, but the logic has a problem. A scalar indices value and a 1-dim indices value in a Gather node produce different results: a scalar value produces a result tensor with the axis dim removed, while a 1-dim indices value keeps that dim, even when the dim value is 1. For example,

Node |-> Gather(indices=[0], axis=axis)
     |-> Gather(indices=[1], axis=axis)
     |-> Slice(index=2, axis=axis)

is the same as

Node |-> Split(axis=axis)

But

Node |-> Gather(indices=0, axis=axis)
     |-> Gather(indices=1, axis=axis)
     |-> Slice(index=2, axis=axis)

is the same as

Node |-> Split(axis=axis) |-> Squeeze(axis=axis)
                          |-> Squeeze(axis=axis)
                          |->

The previous PR doesn't take such cases related to Squeeze/Unsqueeze into account. This PR merges #19218 and GatherToSplitFusion into a general fusion, which relaxes the limit on the number of Gather and Slice nodes and checks all Gather and Slice consumers: if the indices of the Gather nodes and the start/end of the Slice nodes cover the specific dim of the input tensor, they can be fused into one Split, adding a Squeeze where necessary according to the dim count of the indices tensor in Gather. @rui-ren, please check if the fix can still be applied to your model.
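A minimal sketch (editorial note, not part of the patch) of why the indices rank matters: ONNX Gather follows numpy.take shape rules, so a scalar index drops the gathered axis while a one-element 1-D index keeps it, and a Split output always keeps the axis, which is why the fusion must insert a Squeeze only for scalar-index Gathers. The tensor shapes below are illustrative assumptions.

```python
import numpy as np

# Illustrative only: numpy.take mirrors ONNX Gather's output-shape behavior.
x = np.zeros((2, 5, 3), dtype=np.float32)
axis = 1

scalar_gather = np.take(x, 0, axis=axis)    # indices=0   -> shape (2, 3), axis dim removed
vector_gather = np.take(x, [0], axis=axis)  # indices=[0] -> shape (2, 1, 3), axis dim kept

# A Split along axis=1 with split size 1 yields (2, 1, 3) chunks, which match the
# 1-D-indices Gather directly; the scalar-index Gather additionally needs a Squeeze.
split_chunk = np.split(x, [1, 2], axis=axis)[0]   # shape (2, 1, 3)
squeezed = np.squeeze(split_chunk, axis=axis)     # shape (2, 3)

print(scalar_gather.shape, vector_gather.shape, split_chunk.shape, squeezed.shape)
```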
--- onnxruntime/core/optimizer/gather_fusion.cc | 318 ++++++---- onnxruntime/core/optimizer/gather_fusion.h | 16 +- .../core/optimizer/gather_slice_fusion.cc | 344 ----------- .../core/optimizer/gather_slice_fusion.h | 32 - .../core/optimizer/graph_transformer_utils.cc | 4 +- .../test/optimizer/graph_transform_test.cc | 550 +++++------------- .../core/optimizer/graph_transformer_utils.cc | 4 +- 7 files changed, 352 insertions(+), 916 deletions(-) delete mode 100644 onnxruntime/core/optimizer/gather_slice_fusion.cc delete mode 100644 onnxruntime/core/optimizer/gather_slice_fusion.h diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index 4903bc1d6b961..90cabff88122c 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -9,55 +9,144 @@ namespace onnxruntime { -bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, - int64_t& indices_n_dims) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || +namespace { +static int64_t GetGatherAxis(const Node& node, int64_t rank) { + int64_t axis = 0; + auto& attrs = node.GetAttributes(); + if (attrs.find("axis") != attrs.end()) { + auto& axis_attr = attrs.at("axis"); + if (utils::HasInt(axis_attr)) { + axis = axis_attr.i(); + if (axis < 0) axis += rank; + } + } + return axis; +} + +static bool GetScalarInt64Initializer(const Graph& graph, const NodeArg& node_arg, int64_t& value, int64_t& rank) { + if (!optimizer_utils::IsScalar(node_arg)) return false; + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node_arg.Name()); + if (!tensor_proto || tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; + Initializer init_const{*tensor_proto, graph.ModelPath()}; + value = *(init_const.data()); + rank = tensor_proto->dims_size(); + return true; +} + +static bool GetSliceAxis(const Graph& graph, const Node& node, int64_t rank, int64_t& axis) { + if (node.InputDefs().size() < 4) return false; + int64_t unused = 0; + if (!GetScalarInt64Initializer(graph, *node.InputDefs()[3], axis, unused)) return false; + if (axis < 0) axis += rank; + return true; +} + +static bool GetAxis(const Graph& graph, const Node& node, int64_t rank, int64_t& axis) { + if (node.OpType() == "Gather") { + axis = GetGatherAxis(node, rank); + return true; + } + if (node.OpType() == "Slice") { + return GetSliceAxis(graph, node, rank, axis); + } + return false; +} + +} // namespace + +bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t rank, + int64_t target_axis, int64_t dim_size, InlinedVector& consumed, + int64_t& start, bool& need_squeeze) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { return false; } - const NodeArg& input_arg = *(node.InputDefs()[1]); - if (!optimizer_utils::IsScalar(input_arg)) return false; - const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); - if (!tensor_proto) return false; - if (tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) return false; - Initializer init_const{*tensor_proto, graph.ModelPath()}; - index = *(init_const.data()); - axis = 0; // Default value. 
- auto& attrs = node.GetAttributes(); - if (attrs.find("axis") != attrs.end()) { - auto& axis_attr = attrs.at("axis"); - if (utils::HasInt(axis_attr)) axis = axis_attr.i(); + if (GetGatherAxis(node, rank) != target_axis) return false; + // Require the indices input to be a scalar tensor for now. Normally if not, the exporter will choose Slice. + // We can relax this later if needed. + int64_t indices_n_dims = 0; + if (!GetScalarInt64Initializer(graph, *(node.InputDefs()[1]), start, indices_n_dims)) return false; + if (start < 0) start += dim_size; + if (start < 0 || start >= dim_size || consumed[static_cast(start)]) return false; + consumed[static_cast(start)] = true; + need_squeeze = indices_n_dims == 0; + return true; +} + +bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, + int64_t dim_size, InlinedVector& consumed, int64_t& start, + int64_t& end) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + int64_t axis = 0; + if (!GetSliceAxis(graph, node, rank, axis) || axis != target_axis) return false; + int64_t unused = 0; + if (!GetScalarInt64Initializer(graph, *node.InputDefs()[1], start, unused) || + !GetScalarInt64Initializer(graph, *node.InputDefs()[2], end, unused)) { + return false; + } + // Handling start and end according to schema definition. + if (start < 0) start += dim_size; + if (end < 0) end += dim_size; + if (start < 0) + start = 0; + else if (start > dim_size) + start = dim_size; + if (end < 0) + end = 0; + else if (end > dim_size) + end = dim_size; + if (start >= end) return false; + if (node.InputDefs().size() >= 5) { + int64_t step = 0; + if (!GetScalarInt64Initializer(graph, *node.InputDefs()[4], step, unused) || step != 1) return false; + } + for (int64_t i = start; i < end; ++i) { + if (consumed[static_cast(i)]) return false; + consumed[static_cast(i)] = true; } - indices_n_dims = tensor_proto->dims_size(); return true; } /* -GatherToSplitFusion is to fuse: -Node -> Gather(index=0, axis=axis) - |-> Gather(index=1, axis=axis) - |-> Gather(index=2, axis=axis) +GatherSliceToSplitFusion is to fuse: +Node -> Gather(indices=0, axis=axis) + |-> Gather(indices=[1], axis=axis) + |-> Slice(start=2, end=3, axes=[axis]) |... To Node -> Split -> Squeeze(axis=axis) - |-> Squeeze(axis=axis) - |-> Squeeze(axis=axis) + |-> + |-> |... So that we can use one kernel to finish the job. +The fusion requires that the indices of Gather nodes and start/end of Slice nodes are not overlapping and cover +all the elements in the target axis. Step of Slice node should be 1. */ -Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { +Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + // Squeeze, Gather, Slice and Split have different schemas before and after OpSet 13. + // To make code simple, support OpSet >= 13 only. 
+ int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + if (onnx_opset_version < 13) return Status::OK(); + GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - InlinedVector node_args; + InlinedVector candidate_args; for (auto node_arg : graph.GetInputs()) { if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) { - node_args.push_back(node_arg); + candidate_args.push_back(node_arg); } } @@ -65,7 +154,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (graph.GetConsumerNodes(entry.first).size() > 1) { auto node_arg = graph.GetNodeArg(entry.first); if (node_arg) { - node_args.push_back(node_arg); + candidate_args.push_back(node_arg); } } } @@ -90,129 +179,108 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le size_t output_count = node.GetOutputEdgesCount(); if (output_count <= 1) continue; - node_args.push_back(node.OutputDefs()[0]); + candidate_args.push_back(node.OutputDefs()[0]); } - for (const NodeArg* node_arg : node_args) { + for (const NodeArg* node_arg : candidate_args) { auto shape = node_arg->Shape(); if (!shape) continue; int64_t rank = static_cast(shape->dim_size()); - - bool can_fuse = true; - bool first_edge = true; - int64_t split_axis = 0; - int64_t indices_n_dims = -1; auto consumers = graph.GetConsumerNodes(node_arg->Name()); - size_t consumer_count = consumers.size(); - InlinedVector gather_outputs(consumer_count, nullptr); - InlinedVector> nodes_to_fuse; + InlinedVector condidate_consumers; for (auto consumer : consumers) { - int64_t index, axis, dims; - if (!consumer || consumer->InputDefs()[0] != node_arg || - !IsSupportedGather(graph, *consumer, index, axis, dims)) { - can_fuse = false; - break; - } - if (indices_n_dims == -1) { - indices_n_dims = dims; - } else if (indices_n_dims != dims) { - // Not the same number of dimensions (0 or 1) for all scalar indices. 
- can_fuse = false; - break; + if (consumer && consumer->InputDefs()[0] == node_arg && + (consumer->OpType() == "Gather" || consumer->OpType() == "Slice")) { + condidate_consumers.emplace_back(consumer); } - if (axis < 0) axis += rank; - if (first_edge) { - auto dim = shape->dim(static_cast(axis)); - if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(consumer_count)) { - can_fuse = false; - break; - } - split_axis = axis; - first_edge = false; - } else if (axis != split_axis) { + } + if (condidate_consumers.size() < 2) continue; + int64_t axis = 0; + if (!GetAxis(graph, *condidate_consumers[0], rank, axis)) continue; + auto dim = shape->dim(static_cast(axis)); + if (!utils::HasDimValue(dim)) continue; + int64_t dim_size = dim.dim_value(); + InlinedVector consumed(static_cast(dim_size), false); + bool can_fuse = true; + InlinedVector> nodes_to_fuse; + InlinedVector starts; + InlinedHashMap> output_info_map; + for (auto consumer : condidate_consumers) { + if (!consumer || consumer->InputDefs()[0] != node_arg) { can_fuse = false; break; } - if (index < 0) index += static_cast(consumer_count); - if (index < 0 || index >= static_cast(consumer_count) || gather_outputs[static_cast(index)]) { + int64_t start = 0, end = 0; + bool need_squeeze = false; + if (IsSupportedGather(graph, *consumer, rank, axis, dim_size, consumed, start, need_squeeze)) { + Node& gather_node = *graph.GetNode(consumer->Index()); + nodes_to_fuse.emplace_back(gather_node); + starts.emplace_back(start); + output_info_map[start] = std::make_tuple(gather_node.MutableOutputDefs()[0], 1, need_squeeze); + } else if (IsSupportedSlice(graph, *consumer, rank, axis, dim_size, consumed, start, end)) { + Node& slice_node = *graph.GetNode(consumer->Index()); + nodes_to_fuse.emplace_back(slice_node); + starts.emplace_back(start); + output_info_map[start] = std::make_tuple(slice_node.MutableOutputDefs()[0], end - start, false); + } else { can_fuse = false; break; } - Node& gather_node = *graph.GetNode(consumer->Index()); - nodes_to_fuse.emplace_back(gather_node); - gather_outputs[static_cast(index)] = gather_node.MutableOutputDefs()[0]; - } - - if (!can_fuse) continue; - - ONNX_NAMESPACE::TypeProto split_output_type; - const ONNX_NAMESPACE::TensorProto_DataType element_type = - static_cast(node_arg->TypeAsProto()->tensor_type().elem_type()); - split_output_type.mutable_tensor_type()->set_elem_type(element_type); - for (int64_t i = 0; i < rank; ++i) { - if (i == split_axis) { - split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); - } else { - *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); - } } + if (!can_fuse || std::find(consumed.begin(), consumed.end(), false) != consumed.end()) continue; + std::sort(starts.begin(), starts.end()); InlinedVector split_outputs; - bool add_squeeze_node = indices_n_dims == 0; - if (add_squeeze_node) { - for (size_t i = 0; i < consumer_count; ++i) { - split_outputs.emplace_back( - &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type)); - } - } - - Node& split_node = - graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", - {graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs); - split_node.AddAttribute("axis", split_axis); - split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); - - // Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas. 
- int onnx_opset_version = -1; - if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { - onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); - } - - if (onnx_opset_version < 13) { - if (add_squeeze_node) { - for (size_t i = 0; i < consumer_count; ++i) { - Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", - "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]}); - squeeze_node.AddAttribute("axes", std::vector{split_axis}); - squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + InlinedVector split_values; + for (int64_t start : starts) { + auto& output_info = output_info_map[start]; + NodeArg* original_output_arg = std::get<0>(output_info); + int64_t split_value = std::get<1>(output_info); + split_values.emplace_back(split_value); + if (std::get<2>(output_info)) { + ONNX_NAMESPACE::TypeProto split_output_type; + const ONNX_NAMESPACE::TensorProto_DataType element_type = + static_cast(node_arg->TypeAsProto()->tensor_type().elem_type()); + split_output_type.mutable_tensor_type()->set_elem_type(element_type); + for (int64_t i = 0; i < rank; ++i) { + if (i == axis) { + split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(split_value); + } else { + *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); + } } - } - } else { - if (onnx_opset_version >= 18) { - split_node.AddAttribute("num_outputs", static_cast(consumer_count)); - } - - if (add_squeeze_node) { + NodeArg* split_output_arg = + &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split_output"), &split_output_type); ONNX_NAMESPACE::TensorProto axes_initializer_proto; - axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer")); + axes_initializer_proto.set_name(graph.GenerateNodeName("squeeze_axes")); axes_initializer_proto.add_dims(static_cast(1)); axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - InlinedVector axes_value{split_axis}; - axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t)); + axes_initializer_proto.add_int64_data(axis); NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); - - for (size_t i = 0; i < consumer_count; ++i) { - Node& squeeze_node = - graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", - "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]}); - squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); - } + Node& squeeze_node = + graph.AddNode(graph.GenerateNodeName("Squeeze"), "Squeeze", "Squeeze for Fused Gather nodes", + {split_output_arg, axes_arg}, {original_output_arg}); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + split_outputs.emplace_back(split_output_arg); + } else { + split_outputs.emplace_back(original_output_arg); } } - for (Node& n : nodes_to_fuse) { - graph_utils::RemoveNodeOutputEdges(graph, n); - graph.RemoveNode(n.Index()); + ONNX_NAMESPACE::TensorProto split_initializer_proto; + split_initializer_proto.set_name(graph.GenerateNodeName("splits")); + split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + split_initializer_proto.add_dims(static_cast(split_values.size())); + split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end()); + NodeArg* 
split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", + {graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs); + split_node.AddAttribute("axis", axis); + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + + for (Node& node : nodes_to_fuse) { + graph_utils::RemoveNodeOutputEdges(graph, node); + graph.RemoveNode(node.Index()); } modified = true; diff --git a/onnxruntime/core/optimizer/gather_fusion.h b/onnxruntime/core/optimizer/gather_fusion.h index 44c235915b6cc..098278a77dafe 100644 --- a/onnxruntime/core/optimizer/gather_fusion.h +++ b/onnxruntime/core/optimizer/gather_fusion.h @@ -8,19 +8,23 @@ namespace onnxruntime { /** -@Class GatherToSplitFusion +@Class GatherSliceToSplitFusion -Fuse multiple Gather nodes that comsuming one output to one Split node. +Fuse multiple Gather/Slice nodes that comsuming one output to one Split node. */ -class GatherToSplitFusion : public GraphTransformer { +class GatherSliceToSplitFusion : public GraphTransformer { public: - GatherToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("GatherToSplitFusion", compatible_execution_providers) {} + GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; private: - bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const; + bool IsSupportedGather(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size, + InlinedVector& consumed, int64_t& start, bool& need_squeeze) const; + + bool IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size, + InlinedVector& consumed, int64_t& start, int64_t& end) const; }; /** diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc deleted file mode 100644 index 21266d356a020..0000000000000 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ /dev/null @@ -1,344 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/optimizer/gather_slice_fusion.h" -#include "core/graph/graph_utils.h" -#include "core/optimizer/initializer.h" -#include "core/optimizer/utils.h" - -namespace onnxruntime { - -bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, - int64_t& axis, int64_t& indices_n_dims) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || - !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { - return false; - } - - const NodeArg& input_arg = *(node.InputDefs()[1]); - - if (!optimizer_utils::IsScalar(input_arg)) return false; - - const ONNX_NAMESPACE::TensorProto* indices_init = graph_utils::GetConstantInitializer(graph, input_arg.Name()); - - if (!indices_init) return false; - - if (indices_init->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; - - // get the index value - Initializer init_const(*indices_init, graph.ModelPath()); - index = *(init_const.data()); - - // get attributes value - axis = 0; - auto& attrs = node.GetAttributes(); - if (attrs.find("axis") != attrs.end()) { - auto& axis_attr = attrs.at("axis"); - if (utils::HasInt(axis_attr)) axis = axis_attr.i(); - } - - indices_n_dims = indices_init->dims_size(); - return true; -} - -bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, - InlinedVector& starts, - InlinedVector& ends, - InlinedVector& axes, - InlinedVector& steps) const { - // check the version of Slice ops - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}) || - !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { - return false; - } - - // get the opset version - int onnx_opset_version = -1; - if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { - onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); - } - - // If Slice op of opset version 1 - if (onnx_opset_version == 1) { - if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) || - !graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) || - starts.size() != ends.size()) { - return false; - } - - if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { - return false; - } - } - - // If Slice op of opset version >= 10 - if (onnx_opset_version >= 10) { - // node inputs include: starts - ends - axes - steps - - // return a pointer to the corresponding NodeArg if input of the node at the index exists - auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* { - const auto& input_defs = node.InputDefs(); - const NodeArg* input = (input_defs.size() > input_index) ? input_defs[input_index] : nullptr; - return (input == nullptr || !input->Exists()) ? nullptr : input; - }; - - // return a pointer to the initializer if it is constant; otherwise, a nullptr - auto get_initializer_if_constant = - [&graph, get_input_if_exists](size_t input_index) -> const ONNX_NAMESPACE::TensorProto* { - const NodeArg* input = get_input_if_exists(input_index); - return input ? 
graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr; - }; - - // return the initialization data if it is constant - auto get_initializer_data = - [&graph](const ONNX_NAMESPACE::TensorProto* slice_initializer) -> InlinedVector { - Initializer init(*slice_initializer, graph.ModelPath()); - if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { - int32_t* init_data = init.data(); - return InlinedVector(init_data, init_data + init.size()); - } - - if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) { - int64_t* init_data = init.data(); - return InlinedVector(init_data, init_data + init.size()); - } - return {}; - }; - - // starts and ends inputs have to exist, be constants and be of the same size. - const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1); - const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2); - const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3); - const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4); - - if (!starts_init || !ends_init || !axes_init || !steps_init) { - return false; - } - - starts = get_initializer_data(starts_init); - ends = get_initializer_data(ends_init); - axes = get_initializer_data(axes_init); - steps = get_initializer_data(steps_init); - - if (starts.size() == 0 || ends.size() == 0 || starts.size() != ends.size()) { - return false; - } - - if (axes_init->dims_size() != 1 || static_cast(axes_init->dims().Get(0)) != starts.size()) { - return false; - } - - // if steps exists, it should be constant and all value should be 1 - if (steps.size() != starts.size()) { - return false; - } - - for (int64_t step : steps) { - if (step != 1) { - return false; - } - } - } - - return true; -} - -/* -GatherToSplitFusion is to fuse: - Node - |-> Gather(index=0, axis=axis) - |-> Gather(index=1, axis=axis) - |-> Slice(index=2, axis=axis) -To - Node - |-> Split(index=0) -So that we can use one kernel to finish the job. -*/ - -Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { - GraphViewer graph_viewer(graph); - - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - - InlinedVector output_args; - - // Iterate the topological order and get Reshape ops - for (auto node_index : node_topology_list) { - auto* p_node = graph.GetNode(node_index); - - if (p_node == nullptr) continue; - - Node& node = *p_node; - - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - - // Currently only catch after Reshape ops, optimize in the future - if (node.OpType() != "Reshape") continue; - - size_t output_count = node.GetOutputEdgesCount(); - - // We only catch 1 scenario for Multi Query Attention for now. - // |---> Gather - // Reshape |---> Gather - // |---> Slice - // |... 
or (other ops) - - // Get the output into node args - if (output_count < 3) continue; - - output_args.push_back(node.OutputDefs()[0]); - } - - // iterate the children of Reshape node - for (const NodeArg* node_arg : output_args) { - auto shape = node_arg->Shape(); - if (!shape) continue; - - auto consumers = graph.GetConsumerNodes(node_arg->Name()); - size_t consumer_count = consumers.size(); - - // get the tensor rank - int64_t rank = static_cast(shape->dim_size()); - - bool can_fuse = true; - bool first_edge = true; - int64_t split_axis = 0; - int64_t indices_n_dims = -1; - - // Fuse 2 Gathers and 1 slice to Split - // Get those outputs as Split outputs - InlinedVector split_outputs(3); - - InlinedVector> nodes_to_fuse; - size_t gather_node_count = 2, slice_node_count = 0; - - // find the nodes to be merged - for (auto consumer : consumers) { - int64_t index, axis, dims; - InlinedVector starts, ends, axes, steps; - - bool IsSupportedGatherOps = IsSupportedGather(graph, *consumer, index, axis, dims); - bool IsSupportedSliceOps = IsSupportedSlice(graph, *consumer, starts, ends, axes, steps); - - if ((!consumer || consumer->InputDefs()[0] != node_arg) || - (!IsSupportedGatherOps && !IsSupportedSliceOps)) { - break; - } - - if (IsSupportedGatherOps) { - if (indices_n_dims == -1) { - indices_n_dims = dims; - } else if (indices_n_dims != dims) { - // Not the same number of dimensions (0 or 1) for all scalar indices. - can_fuse = false; - break; - } - - if (axis < 0) axis += rank; - - if (first_edge) { - auto dim = shape->dim(static_cast(axis)); - // dim.dim_value() = 73 - if (!utils::HasDimValue(dim)) { - can_fuse = false; - break; - } - split_axis = axis; - first_edge = false; - } else if (axis != split_axis) { - can_fuse = false; - break; - } - - if (index < 0) index += static_cast(consumer_count); - if (index < 0 || index >= static_cast(consumer_count)) { - can_fuse = false; - break; - } - - Node& gather_node = *graph.GetNode(consumer->Index()); - nodes_to_fuse.push_back(gather_node); - NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; - split_outputs[gather_node_count--] = gather_output_args; - } - - // check the Slice Ops - if (IsSupportedSliceOps) { - if (axes[0] != axis && !first_edge) { - can_fuse = false; - break; - } - - Node& slice_node = *graph.GetNode(consumer->Index()); - NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; - nodes_to_fuse.push_back(slice_node); - split_outputs[slice_node_count++] = slice_output_args; - } - } - - // condition check - if (!can_fuse || gather_node_count != 0 || slice_node_count != 1) continue; - - // generate the split node and merge the kernel - ONNX_NAMESPACE::TypeProto split_output_type; - const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( - node_arg->TypeAsProto()->tensor_type().elem_type()); - - split_output_type.mutable_tensor_type()->set_elem_type(element_type); - - for (int64_t i = 0; i < rank; i++) { - if (i == split_axis) - split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); - else - *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); - } - - InlinedVector split_output_types; - - for (size_t i = 0; i < consumer_count; ++i) { - split_output_types.push_back( - &graph.GetOrCreateNodeArg( - graph.GenerateNodeArgName("fused_split_" + std::to_string(i)), &split_output_type)); - } - - // Generate the Split Node - ONNX_NAMESPACE::TensorProto split_initializer_proto; - 
split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); - split_initializer_proto.add_dims(static_cast(3)); - split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - - auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); - // Optimize 2 Gather Nodes, so Slice_dim = dim_value - 2 - int64_t slice_dim = static_cast(dim_value - 2); - InlinedVector split_value{{slice_dim, 1, 1}}; - split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); - NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); - - Node& split_node = - graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", - {graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs); - - split_node.AddAttribute("axis", split_axis); - - split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); - - int onnx_opset_version = -1; - if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { - onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); - } - - if (onnx_opset_version >= 18) { - split_node.AddAttribute("num_outputs", static_cast(consumer_count)); - } - - for (Node& node_to_fuse : nodes_to_fuse) { - graph_utils::RemoveNodeOutputEdges(graph, node_to_fuse); - graph.RemoveNode(node_to_fuse.Index()); - } - modified = true; - } - - return Status::OK(); -} -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.h b/onnxruntime/core/optimizer/gather_slice_fusion.h deleted file mode 100644 index 1c5c307efed7f..0000000000000 --- a/onnxruntime/core/optimizer/gather_slice_fusion.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/optimizer/graph_transformer.h" - -namespace onnxruntime { - -/** -@class GatherSliceToSplitFusion -Fuse (2 Gather nodes + 1 Slice) to 1 split node. 
-*/ - -class GatherSliceToSplitFusion : public GraphTransformer { - private: - bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, - int64_t& indices_n_dims) const; - - bool IsSupportedSlice(const Graph& graph, const Node& node, - InlinedVector& starts, - InlinedVector& ends, - InlinedVector& axes, - InlinedVector& steps) const; - - public: - GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} - - Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; -}; -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 4e939fe3c7b6b..8376b87aee6b2 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -37,7 +37,6 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" -#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -307,9 +306,8 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index e1fcf835c6043..16f38bac62713 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -42,7 +42,6 @@ #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/gather_fusion.h" -#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -7059,13 +7058,13 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShouldNotShareForGraphOutput) { } } -TEST_F(GraphTransformationTests, GatherToSplitFusion) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_AllGather) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{54}}); auto* shape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_2 = builder.MakeInitializer({1}, {static_cast(1)}); auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); auto* gather_out_1 = builder.MakeIntermediate(); auto* gather_out_2 = builder.MakeIntermediate(); @@ -7082,7 +7081,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { builder.AddNode("Gather", 
{reshape_out, gather_index_3}, {gather_out_3}) .AddAttribute("axis", static_cast(2)); builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); }; @@ -7091,27 +7091,16 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { return Status::OK(); }; - // OpSet-12 + // OpSet-12, not support { auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axes").ints().at(0))); - } - } + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); return Status::OK(); }; - std::unique_ptr transformer = std::make_unique(); + std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } @@ -7121,7 +7110,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { auto post_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 2); for (auto& node : graph.Nodes()) { if (node.OpType() == "Split") { auto& attrs = node.GetAttributes(); @@ -7140,249 +7129,140 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { return Status::OK(); }; - std::unique_ptr transformer = std::make_unique(); + std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } - - // OpSet-18 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - const NodeArg& input_arg = *(node.InputDefs()[1]); - const ONNX_NAMESPACE::TensorProto* tensor_proto = - graph_utils::GetConstantInitializer(graph, input_arg.Name()); - TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; - TEST_RETURN_IF_NOT(tensor_proto->data_type() == 
ONNX_NAMESPACE::TensorProto_DataType_INT64); - TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } } -TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_AllSlice_GraphInput) { auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{4}}); - auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); - auto* gather_index_1 = builder.MakeInitializer({1}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({1}, {static_cast(1)}); - auto* gather_index_3 = builder.MakeInitializer({1}, {static_cast(2)}); - auto* gather_out_1 = builder.MakeIntermediate(); - auto* gather_out_2 = builder.MakeIntermediate(); - auto* gather_out_3 = builder.MakeIntermediate(); + auto* data_arg = builder.MakeInput({{2, 3, 8, 3}}); + auto* starts_1 = builder.MakeInitializer({1}, {0}); + auto* ends_1 = builder.MakeInitializer({1}, {2}); + auto* axes_1 = builder.MakeInitializer({1}, {2}); + auto* steps_1 = builder.MakeInitializer({1}, {1}); + auto* starts_2 = builder.MakeInitializer({1}, {2}); + auto* ends_2 = builder.MakeInitializer({1}, {-2}); + auto* axes_2 = builder.MakeInitializer({1}, {-2}); + auto* steps_2 = builder.MakeInitializer({1}, {1}); + auto* starts_3 = builder.MakeInitializer({1}, {-2}); + auto* ends_3 = builder.MakeInitializer({1}, {16}); + auto* axes_3 = builder.MakeInitializer({1}, {2}); + auto* slice_out_1 = builder.MakeIntermediate(); + auto* slice_out_2 = builder.MakeIntermediate(); + auto* slice_out_3 = builder.MakeIntermediate(); auto* transpose_out_1 = builder.MakeOutput(); auto* transpose_out_2 = builder.MakeOutput(); auto* transpose_out_3 = builder.MakeOutput(); - builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); - builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(-2)); - builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Slice", {data_arg, starts_1, ends_1, axes_1, steps_1}, {slice_out_1}); + builder.AddNode("Slice", {data_arg, starts_2, ends_2, axes_2, steps_2}, {slice_out_2}); + builder.AddNode("Slice", {data_arg, starts_3, ends_3, axes_3}, {slice_out_3}); + builder.AddNode("Transpose", {slice_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_3}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); }; auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 3); 
return Status::OK(); }; - // OpSet-12 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } - - // OpSet-14 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } - - // OpSet-18 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); } - return Status::OK(); - }; + } + return Status::OK(); + }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); } -TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Input) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Combined) { auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{2, 3, 3, 3}}); - auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); - auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); + auto* data_arg = builder.MakeInput({{144}}); + auto* shape_arg = builder.MakeInput({{4}}); + auto* reshape_out = 
builder.MakeIntermediate({{2, 8, 3, 3}}); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(5)}); + auto* starts_2 = builder.MakeInitializer({1}, {6}); + auto* ends_2 = builder.MakeInitializer({1}, {8}); + auto* axes_2 = builder.MakeInitializer({1}, {-3}); + auto* steps_2 = builder.MakeInitializer({1}, {1}); + auto* gather_index_3 = builder.MakeInitializer({1}, {static_cast(4)}); + auto* starts_4 = builder.MakeInitializer({1}, {-16}); + auto* ends_4 = builder.MakeInitializer({1}, {4}); + auto* axes_4 = builder.MakeInitializer({1}, {1}); auto* gather_out_1 = builder.MakeIntermediate(); - auto* gather_out_2 = builder.MakeIntermediate(); + auto* slice_out_2 = builder.MakeIntermediate(); auto* gather_out_3 = builder.MakeIntermediate(); + auto* slice_out_4 = builder.MakeIntermediate(); auto* transpose_out_1 = builder.MakeOutput(); auto* transpose_out_2 = builder.MakeOutput(); auto* transpose_out_3 = builder.MakeOutput(); + auto* transpose_out_4 = builder.MakeOutput(); - builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast(2)); - builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(-2)); - builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(1)); + builder.AddNode("Slice", {reshape_out, starts_2, ends_2, axes_2, steps_2}, {slice_out_2}); + builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) + .AddAttribute("axis", static_cast(-3)); + builder.AddNode("Slice", {reshape_out, starts_4, ends_4, axes_4}, {slice_out_4}); builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_4}, {transpose_out_4}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); }; auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 2); return Status::OK(); }; - // OpSet-12 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axes").ints().at(0))); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - 
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } - - // OpSet-14 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - const NodeArg& input_arg = *(node.InputDefs()[1]); - const ONNX_NAMESPACE::TensorProto* tensor_proto = - graph_utils::GetConstantInitializer(graph, input_arg.Name()); - TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; - TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); - TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } - - // OpSet-18 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - const NodeArg& input_arg = *(node.InputDefs()[1]); - const ONNX_NAMESPACE::TensorProto* tensor_proto = - graph_utils::GetConstantInitializer(graph, input_arg.Name()); - TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; - TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); - TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); - } + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 1); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(1 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(1 == static_cast(*(init_const.data()))); } - return Status::OK(); - }; + } + return Status::OK(); + }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, 
post_graph_checker)); - } + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); } -TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Consume_Initializer) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInitializer({2, 3, 3, 3}, std::vector(54)); auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); @@ -7430,31 +7310,31 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) { return Status::OK(); }; - std::unique_ptr transformer = std::make_unique(); + std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } -TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) { auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] > 0 || CountOpsInGraph(graph)["Slice"] > 0); return Status::OK(); }; auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] > 0 || CountOpsInGraph(graph)["Slice"] > 0); TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); return Status::OK(); }; - // Invalid shape. + // Not cover all elements of specific dimension. 
{ auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{72}}); - auto* shape_arg = builder.MakeInput({{1}}); + auto* shape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 3, 4, 3}}); auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_2 = builder.MakeInitializer({1}, {static_cast(1)}); auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); auto* gather_out_1 = builder.MakeIntermediate(); auto* gather_out_2 = builder.MakeIntermediate(); @@ -7467,63 +7347,65 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) .AddAttribute("axis", static_cast(2)); builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(2)); + .AddAttribute("axis", static_cast(-2)); builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) .AddAttribute("axis", static_cast(2)); builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) .AddAttribute("perm", std::vector{0, 2, 1}); builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) - .AddAttribute("perm", std::vector{0, 2, 1}); + .AddAttribute("perm", std::vector{0, 2, 1, 3}); builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}) .AddAttribute("perm", std::vector{0, 2, 1}); }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } - // Invalid Gather indices. + // Has overlap. 
{ auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{1}}); - auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); - auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); - auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(1)}); - auto* gather_out_1 = builder.MakeIntermediate(); - auto* gather_out_2 = builder.MakeIntermediate(); - auto* gather_out_3 = builder.MakeIntermediate(); + auto* data_arg = builder.MakeInput({{2, 3, 8, 3}}); + auto* starts_1 = builder.MakeInitializer({1}, {0}); + auto* ends_1 = builder.MakeInitializer({1}, {3}); + auto* axes_1 = builder.MakeInitializer({1}, {2}); + auto* steps_1 = builder.MakeInitializer({1}, {1}); + auto* starts_2 = builder.MakeInitializer({1}, {2}); + auto* ends_2 = builder.MakeInitializer({1}, {-2}); + auto* axes_2 = builder.MakeInitializer({1}, {-2}); + auto* steps_2 = builder.MakeInitializer({1}, {1}); + auto* starts_3 = builder.MakeInitializer({1}, {-2}); + auto* ends_3 = builder.MakeInitializer({1}, {16}); + auto* axes_3 = builder.MakeInitializer({1}, {2}); + auto* slice_out_1 = builder.MakeIntermediate(); + auto* slice_out_2 = builder.MakeIntermediate(); + auto* slice_out_3 = builder.MakeIntermediate(); auto* transpose_out_1 = builder.MakeOutput(); auto* transpose_out_2 = builder.MakeOutput(); auto* transpose_out_3 = builder.MakeOutput(); - builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); - builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) - .AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) - .AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}) - .AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Slice", {data_arg, starts_1, ends_1, axes_1, steps_1}, {slice_out_1}); + builder.AddNode("Slice", {data_arg, starts_2, ends_2, axes_2, steps_2}, {slice_out_2}); + builder.AddNode("Slice", {data_arg, starts_3, ends_3, axes_3}, {slice_out_3}); + builder.AddNode("Transpose", {slice_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_3}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } - // Invalid Gather axis. + // Invalid axis. 
{ auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{1}}); + auto* shape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); @@ -7550,7 +7432,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { .AddAttribute("perm", std::vector{0, 2, 1}); }; - std::unique_ptr transformer = std::make_unique(); + std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } @@ -7643,143 +7525,5 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { } } -TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { - { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{54}}); - auto* reshape_arg = builder.MakeInput({{4}}); - auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); - builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); - - // Create Gather-1 Ops - auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); - auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); - builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) - .AddAttribute("axis", static_cast(2)); - - // Create Transpose 1-Ops - auto* transpose_out_1 = builder.MakeOutput(); - builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); - - // Create Gather-2 Ops - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(-1)}); - auto* gather_out_2 = builder.MakeIntermediate({{2, 512, 1, 64}}); - builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(2)); - - // Create Transpose-2 Ops - auto* transpose_out_2 = builder.MakeOutput(); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); - - // Create Slice Ops - auto* slice_output = builder.MakeIntermediate(); - auto* starts = builder.MakeInitializer({1}, {0}); - auto* ends = builder.MakeInitializer({1}, {-2}); - auto* axes = builder.MakeInitializer({1}, {2}); - auto* steps = builder.MakeInitializer({1}, {1}); - builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); - - // Create Shape-1 Ops - auto* shape_output_1 = builder.MakeOutput(); - builder.AddNode("Shape", {slice_output}, {shape_output_1}); - - // Create Shape-2 Ops - auto* shape_output_2 = builder.MakeOutput(); - builder.AddNode("Shape", {slice_output}, {shape_output_2}); - - // Create Transpose-3 Ops - auto* transpose_out_3 = builder.MakeOutput(); - builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); - }; - - auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); - return Status::OK(); - }; - - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = 
node.GetAttributes(); - TEST_RETURN_IF_NOT(static_cast(attrs.at("axis").i()) == 2); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } -} - -TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) { - { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{54}}); - auto* reshape_arg = builder.MakeInput({{4}}); - auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); - builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); - - // Create Gather-1 Ops - auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); - auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); - builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) - .AddAttribute("axis", static_cast(2)); - - // Create Transpose 1-Ops - auto* transpose_out_1 = builder.MakeOutput(); - builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); - - // Create Slice Ops - auto* slice_output = builder.MakeIntermediate(); - auto* starts = builder.MakeInitializer({1}, {0}); - auto* ends = builder.MakeInitializer({1}, {-2}); - auto* axes = builder.MakeInitializer({1}, {2}); - auto* steps = builder.MakeInitializer({1}, {1}); - builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); - - // Create Shape-1 Ops - auto* shape_output_1 = builder.MakeOutput(); - builder.AddNode("Shape", {slice_output}, {shape_output_1}); - - // Create Shape-2 Ops - auto* shape_output_2 = builder.MakeOutput(); - builder.AddNode("Shape", {slice_output}, {shape_output_2}); - - // Create Transpose-3 Ops - auto* transpose_out_3 = builder.MakeOutput(); - builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); - }; - - auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); - return Status::OK(); - }; - - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } -} - } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 0b68dc65e41cd..5d527369a1b75 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -24,7 +24,6 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" -#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -139,9 +138,8 @@ std::vector> GeneratePreTrainingTransformers( 
transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); - transformers.emplace_back(std::make_unique(compatible_eps)); - transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); + transformers.emplace_back(std::make_unique(compatible_eps)); // If a model with Q, DQ nodes is being used for the purpose of training, it must be for // Quantization Aware Training. So, replace QDQ nodes with FakeQuant. transformers.emplace_back(std::make_unique(compatible_eps)); From c1bf7fcd2fb105e067dc1f2edd408c399a61a1fe Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Thu, 29 Feb 2024 01:19:25 -0800 Subject: [PATCH 085/237] [QNN Quant] Ensure 16bit tensor quant overrides set MS domain (#19684) ### Description Ensures that DQ and Q ops use the msft domain if tensor quantization overrides specify 16-bit integer types. ### Motivation and Context ONNX does not yet support 16bit integer types for QuantizeLinear and DequantizeLinear ops (coming soon). For now, DQ/Q ops must use the MSFT domain. We have to also check if tensor quantization overrides force the use of 16-bit quantization types. If so, we must correctly set the domain for Q/DQ ops. --- .../tools/quantization/onnx_quantizer.py | 11 ++++--- .../tools/quantization/qdq_quantizer.py | 5 ++- .../test_tensor_quant_overrides_option.py | 32 ++++++++++++++++++- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 9450426f12444..19a72e38dea33 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -154,7 +154,7 @@ def __init__( if self.mode not in QuantizationMode: raise ValueError(f"unsupported quantization mode {self.mode}") - self.tensor_quant_overrides = self._get_and_check_tensor_quant_overrides() + self.tensor_quant_overrides, self.tensor_quant_override_types = self._get_and_check_tensor_quant_overrides() self.quantization_params = self.calculate_quantization_params() # QuantizeRange tensor name and zero tensor name for scale and zero point calculation. @@ -177,8 +177,10 @@ def __init__( def _get_and_check_tensor_quant_overrides(self): """ Get tensor quantization overrides and check correctness. + Also returns a set of quantization types (as TensorProto) specified across all overrides. """ tensor_quant_overrides = self.extra_options.get("TensorQuantOverrides", {}) + tensor_quant_override_types = set() # Validate that compatible/valid overrides are provided. if tensor_quant_overrides: @@ -211,6 +213,8 @@ def _get_and_check_tensor_quant_overrides(self): # other channels. if index == 0: quant_type = quant_overrides.get("quant_type") + if quant_type is not None: + tensor_quant_override_types.add(quant_type.tensor_type) elif quant_type != quant_overrides.get("quant_type"): raise ValueError( "Channel quantization types for tensor '{tensor_name}' do not match at index {index}." 
@@ -231,7 +235,7 @@ def _get_and_check_tensor_quant_overrides(self): f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'" ) - return tensor_quant_overrides + return tensor_quant_overrides, tensor_quant_override_types def get_per_tensor_quant_overrides(self, tensor_name): quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{}]) @@ -747,8 +751,7 @@ def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=Non raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") scale_values = np.array([params["scale"]]) assert scale_values.dtype != np.float64 - # zero_point_type = params["quant_type"] - assert zero_point_type == params["quant_type"] + zero_point_type = params["quant_type"] else: zero_point_values = np.array([use_zeropoint]) scale_values = np.array([use_scale]) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 775a3e8b8b588..76cd0d21fca37 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -116,7 +116,10 @@ def __init__( # if the activation or weight types are 16-bit integers. # TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support. int16_types = (TensorProto.UINT16, TensorProto.INT16) - if not self.qdq_op_domain and (self.activation_qType in int16_types or self.weight_qType in int16_types): + overrides_have_int16 = any(t in int16_types for t in self.tensor_quant_override_types) + if not self.qdq_op_domain and ( + self.activation_qType in int16_types or self.weight_qType in int16_types or overrides_have_int16 + ): logging.warning( "ONNX QuantizeLinear and DequantizeLinear operators do not support 16-bit integer quantization types. " f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to " diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index cbb6b3ae2e776..9ea4719f3c595 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -13,7 +13,7 @@ from onnxruntime import quantization from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config -from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType +from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType, ms_domain class DummyDataReader(quantization.CalibrationDataReader): @@ -423,6 +423,36 @@ def test_qdq_overrides_per_channel2(self): self.assertEqual(zp, expected_zp) self.assertEqual(scale, np.float32(expected_scale)) + def test_16bit_overrides_set_ms_domain(self): + """ + Test that overriding a tensor to 16bit (when default is 8bit) automatically sets the 'com.microsoft' + domain on DQ and Q ops. 
+ """ + qdq_model_name = "model_quant_overrides_to_16bit.onnx" + inp_zp, _, sig_out_zp, _, _, _, _, _, out_zp, _ = self.perform_qdq_quantization( + qdq_model_name, + activation_type=onnx.TensorProto.UINT8, # Default to 8bit activations + extra_options={ + "TensorQuantOverrides": { + "INP": [{"quant_type": quantization.QuantType.QUInt16}], + "SIG_OUT": [{"quant_type": quantization.QuantType.QUInt16}], + } + }, + ) + + # Input and Sigmoid's output should be overridden to 16bit + self.assertEqual(inp_zp.data_type, onnx.TensorProto.UINT16) + self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16) + + # Output should the default uint8 type + self.assertEqual(out_zp.data_type, onnx.TensorProto.UINT8) + + # Q/DQ ops should all have the 'com.microsoft' domain + qdq_model = onnx.load_model(qdq_model_name) + for node in qdq_model.graph.node: + if node.op_type in {"QuantizeLinear", "DequantizeLinear"}: + self.assertEqual(node.domain, ms_domain) + def test_override_validation_nonexisting_tensor(self): """ Test that specifying a non-existing tensor should fail. From c311d1faf50167e38613927e44c8a430ffcc8e89 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 29 Feb 2024 17:51:29 +0800 Subject: [PATCH 086/237] [ROCm] Update dockerfile (#19661) Update dockerfile to ROCm6.0 --- dockerfiles/Dockerfile.migraphx | 43 +++------------------------------ dockerfiles/Dockerfile.rocm | 4 +-- dockerfiles/README.md | 4 +-- 3 files changed, 8 insertions(+), 43 deletions(-) diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx index bc513a8e8ba6d..c3541a8bd3425 100644 --- a/dockerfiles/Dockerfile.migraphx +++ b/dockerfiles/Dockerfile.migraphx @@ -5,57 +5,22 @@ # Dockerfile to run ONNXRuntime with MIGraphX integration #-------------------------------------------------------------------------- -FROM ubuntu:20.04 +FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main -ARG ROCM_VERSION=5.4 -# MIGraphX version should be the same as ROCm version -ARG MIGRAPHX_VERSION=rocm-5.4.0 -ENV DEBIAN_FRONTEND noninteractive -ENV MIGRAPHX_DISABLE_FAST_GELU=1 -RUN apt-get clean && apt-get update && apt-get install -y locales -RUN locale-gen en_US.UTF-8 -RUN update-locale LANG=en_US.UTF-8 -ENV LC_ALL C.UTF-8 -ENV LANG C.UTF-8 +ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} -# Install rocm -RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \ - curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/${ROCM_VERSION}/ ubuntu main > /etc/apt/sources.list.d/rocm.list' - -RUN apt-get update &&\ - apt-get install -y sudo git bash build-essential rocm-dev python3-dev python3-pip miopen-hip \ - rocblas half aria2 libnuma-dev pkg-config - -RUN aria2c -q -d /tmp -o cmake-3.27.3-linux-x86_64.tar.gz \ -https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-x86_64.tar.gz &&\ -tar -zxf /tmp/cmake-3.27.3-linux-x86_64.tar.gz --strip=1 -C /usr - -# Install rbuild -RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz numpy yapf==0.28.0 - -ENV PATH /opt/miniconda/bin:/code/cmake-3.27.3-linux-x86_64/bin:${PATH} - -# Install MIGraphX from source -RUN mkdir -p /migraphx -RUN cd /migraphx && git clone --depth=1 --branch ${MIGRAPHX_VERSION} https://github.com/ROCmSoftwarePlatform/AMDMIGraphX src -RUN cd /migraphx && rbuild 
package --cxx /opt/rocm/llvm/bin/clang++ -d /migraphx/deps -B /migraphx/build -S /migraphx/src/ -DPYTHON_EXECUTABLE=/usr/bin/python3 -RUN dpkg -i /migraphx/build/*.deb -RUN rm -rf /migraphx - -# Install rocm ep dependencies RUN apt-get update &&\ - apt-get install -y rocrand rccl hipsparse hipfft hipcub hipblas rocthrust + apt-get install -y migraphx WORKDIR /code # Prepare onnxruntime repository & build onnxruntime RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ - cd onnxruntime &&\ + cd onnxruntime && pip install --upgrade pip &&\ /bin/sh ./build.sh --allow_running_as_root --cmake_extra_defines ONNXRUNTIME_VERSION=`cat ./VERSION_NUMBER` --config Release --parallel \ --skip_tests --build_wheel --use_rocm --rocm_version=${ROCM_VERSION} --rocm_home /opt/rocm --use_migraphx &&\ pip install /code/onnxruntime/build/Linux/Release/dist/*.whl diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm index 35a676383337b..c242933f677f0 100644 --- a/dockerfiles/Dockerfile.rocm +++ b/dockerfiles/Dockerfile.rocm @@ -5,14 +5,14 @@ # Dockerfile to run ONNXRuntime with ROCm integration #-------------------------------------------------------------------------- -FROM rocm/pytorch:rocm5.4_ubuntu20.04_py3.7_pytorch_1.12.1 +FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main WORKDIR /code -ENV PATH /opt/miniconda/bin:/code/cmake-3.27.3-linux-x86_64/bin:${PATH} +ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} # Prepare onnxruntime repository & build onnxruntime RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ diff --git a/dockerfiles/README.md b/dockerfiles/README.md index f226ebfe8b193..a2e99d66d4654 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -277,7 +277,7 @@ Nothing else from ONNX Runtime source tree will be copied/installed to the image Note: When running the container you built in Docker, please either use 'nvidia-docker' command instead of 'docker', or use Docker command-line options to make sure NVIDIA runtime will be used and appropiate files mounted from host. Otherwise, CUDA libraries won't be found. You can also [set NVIDIA runtime as default in Docker](https://github.com/dusty-nv/jetson-containers#docker-default-runtime). ## MIGraphX -**Ubuntu 20.04, ROCm5.4, AMDMIGraphX v1.2** +**Ubuntu 20.04, ROCm6.0, MIGraphX** 1. Build the docker image from the Dockerfile in this repository. ``` @@ -291,7 +291,7 @@ Note: When running the container you built in Docker, please either use 'nvidia- ``` ## ROCm -**Ubuntu 20.04, ROCm5.4** +**Ubuntu 20.04, ROCm6.0** 1. Build the docker image from the Dockerfile in this repository. ``` From 937cdd651e4f656e65053d027c71b51f1e1411ec Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 29 Feb 2024 23:03:57 +0800 Subject: [PATCH 087/237] [ORTMODULE] Support Register Custom Triton Kernel (#19690) Add support for registering custom Triton kernel function. 
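
A minimal usage sketch of the new registration hook (the import path, kernel name, and kernel body below are illustrative assumptions; only `register_triton_kernel` and the by-name lookup in `call_triton_by_name` come from this change):

```python
import torch

# Assumed import path for the ort_triton package shipped with ORTModule.
from onnxruntime.training.ort_triton import register_triton_kernel


@register_triton_kernel
def scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # The executor converts DLPack inputs to torch tensors and calls the registered
    # function by name, so the body may launch a hand-written Triton kernel; a plain
    # PyTorch expression keeps this sketch self-contained.
    return x + alpha * y
```

After registration, `call_triton_by_name("scaled_add", ...)` falls back to the custom-kernel table when the name is not found among the built-in kernels.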
--- .../python/training/ort_triton/__init__.py | 1 + .../python/training/ort_triton/triton_op_executor.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ort_triton/__init__.py b/orttraining/orttraining/python/training/ort_triton/__init__.py index fbb59d1354ae7..5f2d0c62ffa50 100644 --- a/orttraining/orttraining/python/training/ort_triton/__init__.py +++ b/orttraining/orttraining/python/training/ort_triton/__init__.py @@ -9,6 +9,7 @@ from onnxruntime.capi import _pybind_state as _C from .kernel import * # noqa: F403 +from .triton_op_executor import register_triton_kernel # noqa: F401 from .triton_op_executor import call_triton_by_name, call_triton_by_onnx, get_config diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py index f16abc71251ed..e104ea13c59a3 100644 --- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py +++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py @@ -23,6 +23,8 @@ _DEBUG_MODE = "ORTMODULE_TRITON_DEBUG" in os.environ and int(os.getenv("ORTMODULE_TRITON_DEBUG")) == 1 +_CUSTOM_KERNELS = dict() + @functools.lru_cache(None) def _gen_module_internal(sorted_graph: SortedGraph) -> Tuple[str, str, ModuleType]: @@ -113,7 +115,10 @@ def call_triton_by_name(func_name: str, *tensors, **kwargs): """ torch_tensors = [_from_dlpack(tensor) if tensor is not None else None for tensor in tensors] - func = getattr(sys.modules[".".join(__name__.split(".")[:-1])], func_name) + func = getattr(sys.modules[".".join(__name__.split(".")[:-1])], func_name, None) + if func is None: + func = _CUSTOM_KERNELS.get(func_name) + assert func is not None, f"Function {func_name} is not found in the registered kernels." output = func(*torch_tensors, **kwargs) if output is not None: if isinstance(output, tuple): @@ -138,3 +143,8 @@ def call_triton_by_onnx(onnx_key: int, onnx_str: bytes, *tensors): if isinstance(output, tuple): return tuple([to_dlpack(tensor) for tensor in output]) return to_dlpack(output) + + +def register_triton_kernel(fn): + _CUSTOM_KERNELS[fn.__name__] = fn + return fn From ec0e4d3b6572c18a3462eb6efb3bb007ec3a2962 Mon Sep 17 00:00:00 2001 From: Yi-Hong Lyu Date: Thu, 29 Feb 2024 10:31:57 -0800 Subject: [PATCH 088/237] Parallel Transpose_BSNH_to_BNSH (#19406) Achieved a speedup of 1.098 in MultiHeadAttention and an end-to-end speedup of 1.021 in the OCR model through parallelization of the Transpose_BSNH_to_BNSH operation. 
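
For reference, a NumPy sketch (illustrative only, not the C++ code touched by this patch) of what Transpose_BSNH_to_BNSH computes; the patch routes this permutation through the operator's thread pool instead of running it single-threaded:

```python
import numpy as np

# (batch, sequence, num_heads, head_size) -> (batch, num_heads, sequence, head_size)
batch, seq_len, num_heads, head_size = 2, 128, 12, 64
x_bsnh = np.random.rand(batch, seq_len, num_heads, head_size).astype(np.float32)
x_bnsh = np.transpose(x_bsnh, (0, 2, 1, 3))  # permutation {0, 2, 1, 3}, as in the kernel
assert x_bnsh.shape == (batch, num_heads, seq_len, head_size)
```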
--- onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index eb25d0fd7cc1e..c4e4b4ec707fb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -58,11 +58,12 @@ Status Reshape_BSD_to_BSNH(Tensor* qkv, // Transpose Q/K/V from BxSxNxH to BxNxSxH Status Transpose_BSNH_to_BNSH(const Tensor* qkv, - OrtValue& qkv_transposed) { + OrtValue& qkv_transposed, + concurrency::ThreadPool* tp = nullptr) { std::vector permutations({0, 2, 1, 3}); gsl::span permutations_span{permutations}; size_t from = 2, to = 1; - SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable(), from, to); + SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable(), from, to, nullptr, tp); return Status::OK(); } @@ -143,7 +144,8 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable(), batch_size, sequence_length, num_heads, head_size)); // Transpose Q from BxSxNxH to BxNxSxH - ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable(), qkv_with_bias_transposed)); + auto tp = context->GetOperatorThreadPool(); + ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable(), qkv_with_bias_transposed, tp)); return Status::OK(); } From d5606cd7ee394ba9444ef509021720ebe63c9856 Mon Sep 17 00:00:00 2001 From: Adam Louly Date: Thu, 29 Feb 2024 13:40:56 -0800 Subject: [PATCH 089/237] Introducing customizable input names for loss in generate_artifacts. (#19705) # loss function extra inputs. Currently, the loss functions in onnxblock expect exactly two inputs in their build method. Occasionally, models may pass additional inputs, causing the build function to fail. To solve this issue, we can let users pass a list of loss input names to be used in the loss function. --- .../orttraining/python/training/artifacts.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index 7a4eb251bc5bc..4e76174d8255e 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -48,6 +48,7 @@ def generate_artifacts( custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None, additional_output_names: Optional[List[str]] = None, nominal_checkpoint: bool = False, + loss_input_names: Optional[List[str]] = None, ) -> None: """Generates artifacts required for training with ORT training api. @@ -77,7 +78,9 @@ def generate_artifacts( Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model parameters. It can be used on the device to reduce overhead while constructing the training model as well as to reduce the size of the checkpoint packaged with the on-device application. - + loss_input_names: Specifies a list of input names to be used specifically for the loss computation. When provided, + only these inputs will be passed to the loss function. If `None`, all graph outputs are passed to + the loss function. Raises: RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block` RuntimeError: If the optimizer provided is not one of the supported optimizers. 
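
A minimal sketch of the `loss_input_names` argument documented above (the model file, parameter names, and output name are placeholders, not taken from this patch); only the listed graph outputs are forwarded to the loss block instead of every graph output:

```python
import onnx
from onnxruntime.training import artifacts

model = onnx.load("model_with_extra_outputs.onnx")  # hypothetical model path

artifacts.generate_artifacts(
    model,
    requires_grad=["fc.weight", "fc.bias"],      # hypothetical trainable parameters
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    loss_input_names=["logits"],                 # feed only this output to the loss
    artifact_directory="training_artifacts",
)
```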
@@ -111,11 +114,16 @@ def generate_artifacts( logging.info("Custom loss block provided: %s", loss.__class__.__name__) class _TrainingBlock(onnxblock.TrainingBlock): - def __init__(self, _loss): + def __init__(self, _loss, _loss_input_names=None): super().__init__() self._loss = _loss + self._loss_input_names = _loss_input_names def build(self, *inputs_to_loss): + # If loss_input_names is passed, only pass the specified input names to the loss function. + if self._loss_input_names: + inputs_to_loss = self._loss_input_names + if additional_output_names: # If additional output names is not a list, raise an error if not isinstance(additional_output_names, list): @@ -132,7 +140,7 @@ def build(self, *inputs_to_loss): return self._loss(*inputs_to_loss) - training_block = _TrainingBlock(loss_block) + training_block = _TrainingBlock(loss_block, loss_input_names) if requires_grad is not None and frozen_params is not None and set(requires_grad).intersection(set(frozen_params)): raise RuntimeError( @@ -157,9 +165,11 @@ def build(self, *inputs_to_loss): logging.info("Custom op library provided: %s", custom_op_library) custom_op_library_path = pathlib.Path(custom_op_library) - with onnxblock.base(model), onnxblock.custom_op_library( - custom_op_library_path - ) if custom_op_library is not None else contextlib.nullcontext(): + with onnxblock.base(model), ( + onnxblock.custom_op_library(custom_op_library_path) + if custom_op_library is not None + else contextlib.nullcontext() + ): _ = training_block(*[output.name for output in model.graph.output]) training_model, eval_model = training_block.to_model_proto() model_params = training_block.parameters() From 5ee62a6bcc228e63704f64f2de46d61d2c57a281 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 29 Feb 2024 14:46:42 -0800 Subject: [PATCH 090/237] CUDA Resize-18 implementation (#19595) ### Description Implement Resize-18 on CUDA. ### Motivation and Context Performance --- docs/OperatorKernels.md | 3 +- .../providers/cpu/cpu_execution_provider.cc | 6 +- .../core/providers/cpu/cpu_provider_shared.cc | 8 + .../core/providers/cpu/cpu_provider_shared.h | 5 + .../core/providers/cpu/tensor/upsample.cc | 79 +- .../core/providers/cpu/tensor/upsample.h | 14 +- .../providers/cpu/tensor/upsample_antialias.h | 95 +- .../core/providers/cpu/tensor/upsamplebase.h | 191 ++- .../core/providers/cuda/cu_inc/common.cuh | 12 +- .../providers/cuda/cuda_execution_provider.cc | 30 +- .../core/providers/cuda/tensor/resize.cc | 14 +- .../cuda/tensor/resize_antialias_impl.cu | 1179 +++++++++++++++++ .../core/providers/cuda/tensor/resize_impl.cu | 254 ++-- .../core/providers/cuda/tensor/resize_impl.h | 111 ++ .../core/providers/cuda/tensor/upsample.cc | 254 +++- .../core/providers/cuda/tensor/upsample.h | 10 +- .../providers/rocm/rocm_execution_provider.cc | 40 +- .../provider_bridge_provider.cc | 7 +- .../core/providers/xnnpack/tensor/resize.cc | 2 +- .../providers/cpu/tensor/resize_op_test.cc | 171 ++- 20 files changed, 2090 insertions(+), 395 deletions(-) create mode 100644 onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index b0ed68d595c42..1eaf0fb6dad76 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -734,7 +734,8 @@ Do not modify directly.* |||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)|
|||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **shape** = tensor(int64)|
|||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Resize|*in* X:**T**<br/> *in* scales:**tensor(float)**<br/> *out* Y:**T**<br/><br/>or<br/><br/>*in* X:**T1**<br/> *in* roi:**T2**<br/> *in* scales:**tensor(float)**<br/> *in* sizes:**tensor(int64)**<br/> *out* Y:**T1**|13+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|Resize|*in* X:**T**<br/> *in* scales:**tensor(float)**<br/> *out* Y:**T**<br/><br/>or<br/><br/>*in* X:**T1**<br/> *in* roi:**T2**<br/> *in* scales:**tensor(float)**<br/> *in* sizes:**tensor(int64)**<br/> *out* Y:**T1**|18+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
+|||[13, 17]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)|
|ReverseSequence|*in* input:**T**<br/> *in* sequence_lens:**tensor(int64)**<br/>
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 48e4617b33b4d..37e7e42150413 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -2008,8 +2008,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Greater)>, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo namespace onnxruntime { // The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor." @@ -292,6 +294,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); } Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); } + void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, + gsl::span input_dims, + InlinedVector& scales) const override { + p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales); + } + #ifdef ENABLE_ATEN Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); } #endif diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index f33eec4b93e98..c0e674827e4d1 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -24,6 +24,7 @@ class SliceOp__PrepareForComputeMetadata; // Directly maps to SliceOp::PrepareF class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Prepare class contrib__AdamWOptimizerBase__Prepare; class contrib__SGDOptimizerV2Base__Prepare; +class UpsampleBase; using PadsVector = InlinedVector; @@ -202,6 +203,10 @@ struct ProviderHostCPU { virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0; virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; + virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, + gsl::span input_dims, + InlinedVector& scales) const = 0; + #ifdef ENABLE_ATEN virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0; #endif diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index fa69e144be554..babbac0b7be17 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -1,10 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/providers/cpu/tensor/upsample.h" + +#include + +#include "core/common/inlined_containers.h" #include "core/common/safeint.h" #include "core/platform/threadpool.h" -#include "core/providers/cpu/tensor/upsample.h" #include "core/providers/cpu/tensor/upsample_antialias.h" + using namespace onnxruntime::common; using namespace std; using onnxruntime::narrow; @@ -30,6 +35,46 @@ REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(int8_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9); +void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, + InlinedVector& scales) const { + // AspectRatioPolicy::STRETCH is default policy when opset < 18 + if (keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH) { + return; + } + + InlinedHashSet axes_set(axes_.begin(), axes_.end()); + + float scale_in_policy = 0.0f; + if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) { + scale_in_policy = std::numeric_limits::max(); + + for (size_t i = 0; i < scales.size(); i++) { + if (axes_set.empty() || axes_set.count(i) > 0) { + scale_in_policy = std::min(scale_in_policy, scales[i]); + } + } + } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) { + scale_in_policy = std::numeric_limits::min(); + + for (size_t i = 0; i < scales.size(); i++) { + if (axes_set.empty() || axes_set.count(i) > 0) { + scale_in_policy = std::max(scale_in_policy, scales[i]); + } + } + } + + for (size_t i = 0; i < scales.size(); i++) { + // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes + if (axes_set.empty() || axes_set.count(i) > 0) { + scales[i] = scale_in_policy; + output_dims[i] = static_cast(std::round(scales[i] * input_dims[i])); + } else { + scales[i] = 1.0f; + output_dims[i] = input_dims[i]; + } + } +} + template void UpsampleNearest2x(int64_t batch_size, int64_t num_channels, @@ -94,8 +139,8 @@ UpsampleNearestSetupInputMappings(int64_t n_dim, const TensorShape& input_shape, const TensorShape& output_shape, const std::vector& input_dim_factor, - const vector& scales, - const vector& roi, + gsl::span scales, + gsl::span roi, bool extrapolation_enabled, const GetOriginalCoordinateFunc& get_original_coordinate, const GetNearestPixelFunc& get_nearest_pixel) { @@ -141,8 +186,8 @@ static Status UpsampleNearestImpl(const T* input, T* output, const TensorShape& input_shape, const TensorShape& output_shape, - const vector& scales, - const vector& roi, + gsl::span scales, + gsl::span roi, bool extrapolation_enabled, const T extrapolation_value, const GetOriginalCoordinateFunc& get_original_coordinate, @@ -285,8 +330,8 @@ static Status UpsampleNearest(const T* input, T* output, const TensorShape& input_shape, const TensorShape& output_shape, - const vector& scales, - const vector& roi, + gsl::span scales, + gsl::span roi, bool is_resize, bool extrapolation_enabled, T extrapolation_value, @@ -412,7 +457,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw) { @@ -518,7 +563,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const 
bool is_nchw) { @@ -650,7 +695,7 @@ static TrilinearParams SetupUpsampleTrilinear(int64_t input_depth, float depth_scale, float height_scale, float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate) { TrilinearParams p; @@ -796,7 +841,7 @@ void UpsampleTrilinear(int64_t batch_size, float depth_scale, float height_scale, float width_scale, - const std::vector& roi, + gsl::span roi, bool use_extrapolation, float extrapolation_value, const T* XdataBase, @@ -929,7 +974,7 @@ void ResizeBiCubic(int64_t batch_size, bool use_extrapolation, float extrapolation_value, bool exclude_outside, - const std::vector& roi, + gsl::span roi, const T* Xdata, T* Ydata, const GetOriginalCoordinateFunc& get_original_coordinate) { @@ -1067,9 +1112,9 @@ void ResizeBiCubic(int64_t batch_size, template Status Upsample::BaseCompute(OpKernelContext* context, - const std::vector& roi, - const std::vector& scales, - const gsl::span& output_dims) const { + gsl::span roi, + gsl::span scales, + gsl::span output_dims) const { const auto* X = context->Input(0); auto dims = X->Shape().GetDims(); ORT_RETURN_IF_NOT(output_dims.size() == dims.size(), "Rank of input and output tensor should be same."); @@ -1327,7 +1372,7 @@ Status Upsample::Compute(OpKernelContext* context) const { // Initialize the roi array to all zeros as this will be the most common case // Roi data is needed only when coordinate transformation mode is set to tf_crop_and_resize // for all other cases we need a 0 initialized roi array - std::vector roi_array(roi_); + InlinedVector roi_array(roi_); if (!roi_cached_) { bool use_default_roi = true; @@ -1353,7 +1398,7 @@ Status Upsample::Compute(OpKernelContext* context) const { ComputeROIWithAxes(roi_array, input_dims.size()); // Get scales data - std::vector scales_array(input_dims.size()); + InlinedVector scales_array(input_dims.size()); if (OpKernel::Node().InputDefs().size() == 1) { // Compute output shape from scales and input dims diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.h b/onnxruntime/core/providers/cpu/tensor/upsample.h index 3046ee4b8260d..8ff04781f6ad0 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample.h @@ -66,8 +66,8 @@ class Upsample : public UpsampleBase, public OpKernel { Status Compute(OpKernelContext* context) const override; - Status BaseCompute(OpKernelContext* context, const std::vector& roi, const std::vector& scales, - const gsl::span& output_dims) const; + Status BaseCompute(OpKernelContext* context, gsl::span roi, gsl::span scales, + gsl::span output_dims) const; }; BilinearParams SetupUpsampleBilinear(const int32_t input_height, @@ -76,7 +76,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw); @@ -90,7 +90,7 @@ void UpsampleBilinear(const int32_t batch_size, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const bool use_extrapolation, const float extrapolation_value, const T* const XdataBase, @@ -144,7 +144,7 @@ void NhwcUpsampleBilinear(const int32_t batch_size, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const float extrapolation_value, 
const T* const XdataBase, T* const YdataBase, @@ -227,7 +227,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw); @@ -241,7 +241,7 @@ void NhwcUpsampleBilinearInteger(const int32_t batch_size, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const float extrapolation_value, const T* const XdataBase, T* const YdataBase, diff --git a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h index e1dcaf500a325..1e32b7e874b1a 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h @@ -21,32 +21,6 @@ namespace onnxruntime { -namespace ConstValue { -constexpr int32_t mag_factor = 1 << (22 - 1); -} - -namespace { -const uint8_t* GetLookupTableShared() { - // initialized once - static const auto* lookup_table = []() { - // if we have already initialized the lookup table, just return - // ideally we could have a global lookup table, but that account for too much space. - /* Handles values form -640 to 639. */ - static uint8_t table[1280] = {0}; - - // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94 - // we need to handle negative values - // it's equivalent to :x = np.clip(x, 0, 255) where x \in [-640, 639] - // we will accept a negative x for (&table[640])[x] means table +640 -x - for (int i = 0; i < 1280; ++i) { - table[i] = static_cast(std::min(std::max(i - 640, 0), 255)); - } - return table; - }(); - return lookup_table; -} -} // namespace - template struct FilterParamsBaseAntiAlias { std::vector bound; @@ -57,15 +31,15 @@ struct FilterParamsBaseAntiAlias { template struct FilterParamsAntiAlias { - float support_size = 2.0f; - float cubic_coeff_a = -0.75f; + float support_size = antialias_constants::kSupportSize; + float cubic_coeff_a = antialias_constants::kCubicCoeffA; FilterParamsBaseAntiAlias dim_x; FilterParamsBaseAntiAlias dim_y; FilterParamsBaseAntiAlias dim_z; const uint8_t* GetClip8LookupTable() const { - return GetLookupTableShared(); + return UpsampleBase::GetLookupTableShared(); } virtual ~FilterParamsAntiAlias() = default; virtual float Filter(float x) const = 0; @@ -89,7 +63,7 @@ struct BilinearParamsAntiAlias : FilterParamsAntiAlias { template struct BiCubicParamsAntiAlias : FilterParamsAntiAlias { BiCubicParamsAntiAlias() { - this->support_size = 4.0f; + this->support_size = antialias_constants::kBiCubicSupportSize; } // taken from @@ -124,27 +98,6 @@ struct TriLinearParamsAntiAlias : FilterParamsAntiAlias { } }; -template -struct AccumulateType { - using type = int32_t; - using Dtype = T; -}; - -template <> -struct AccumulateType { - using type = float; -}; - -template <> -struct AccumulateType { - using type = float; -}; - -template <> -struct AccumulateType { - using type = double; -}; - // The following method supports a 3/4/5-D input in 'Linear mode, cubic mode' // that amounts to 'Bilinear,TriLinear, Bicubic/Tricubic' Upsampling/Resizing in the sense that it assumes // A N-D tensor has @@ -156,19 +109,20 @@ struct AccumulateType { // - [N, H, W, C] and the scales are [1.0, height_scale, width_scale, 1.0] template void 
SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, - const gsl::span input_h_w_c, - const gsl::span output_h_w_c, - const gsl::span scale_h_w_c, - const std::vector& roi, + gsl::span input_h_w_c, + gsl::span output_h_w_c, + gsl::span scale_h_w_c, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, bool exclude_outside, const bool is_nchw) { - auto compute_weight_coefficients = [&alloc, &roi, &get_original_coordinate, exclude_outside](const FilterParamsAntiAlias& p, - const int64_t input_size, - const int64_t output_size, - size_t rindex, - FilterParamsBaseAntiAlias& param_base, - const float rscale) -> int64_t { + auto compute_weight_coefficients = [&alloc, roi, &get_original_coordinate, exclude_outside]( + const FilterParamsAntiAlias& p, + const int64_t input_size, + const int64_t output_size, + size_t rindex, + FilterParamsBaseAntiAlias& param_base, + const float rscale) -> int64_t { param_base.bound.reserve(static_cast(output_size) * 2); param_base.out_of_bound_idx.reserve(static_cast(output_size)); @@ -245,13 +199,14 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, // normalize the scale to 1 << 22 for int8/uint8 if constexpr (std::is_same::value) { - scale_buffer_int[x] = static_cast(std::round(scale_buffer[x] * ConstValue::mag_factor * 2.f)); + scale_buffer_int[x] = static_cast(std::round(scale_buffer[x] * ConstValue::mag_factor_x_2)); } } /*for (; x < window_size; x++) { scale_buffer[x] = 0; }*/ } + return window_size; }; @@ -269,9 +224,6 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, } } -template -inline constexpr bool is_8bit_v = std::is_same::value || std::is_same::value; - /** * @brief To compute interpolation along with the last axis. * For brief,we assume the input tensor has 3 dimensions and we all it CHW for each character represent a dim. @@ -398,6 +350,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in output += *Xdata_offset * (*weight_coeff_start++); Xdata_offset += output_width; } + if constexpr (is_8bit_v) { *Ydata_offset++ = static_cast(clip8_lookups[output >> 22]); } else if constexpr (std::is_same::value) { @@ -444,6 +397,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in output += *Xdata_offset * (*weight_coeff_start++); Xdata_offset += output_width; } + if constexpr (is_8bit_v) { *Ydata_offset++ = static_cast(clip8_lookups[output >> 22]); } else if constexpr (std::is_same::value) { @@ -515,6 +469,7 @@ void UpsampleBaseAntiAlias(FilterParamsAntiAlias& p, narrow(input_height * num_channels * input_width)); auto ydata_span = gsl::make_span(image_temp_buffer.get(), narrow(input_height * num_channels * output_width)); + // This computes only the width direction.Thus height keeps unchanged. 
ComputeInterpolationAtLevel1(num_channels, input_height, input_width, input_height, output_width, xdata_span, ydata_span, p, p.dim_x, tp); } @@ -546,7 +501,7 @@ void UpsampleBilinearAntiAlias(const int64_t batch_size, const int64_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const bool use_extrapolation, const float extrapolation_value, bool exclude_outside, @@ -575,7 +530,7 @@ void NhwcUpsampleBilinearAntiAlias(const int64_t batch_size, const int64_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const bool use_extrapolation, const float extrapolation_value, bool exclude_outside, @@ -608,7 +563,7 @@ void NhwcResizeBiCubicAntiAlias(const int64_t batch_size, bool use_extrapolation, float extrapolation_value, bool exclude_outside, - const std::vector& roi, + gsl::span roi, const Tensor* X, T* Ydata_base, AllocatorPtr& alloc, @@ -688,7 +643,7 @@ void ResizeBiCubicAntiAlias(int64_t batch_size, bool use_extrapolation, float extrapolation_value, bool exclude_outside, - const std::vector& roi, + gsl::span roi, const Tensor* X, T* Ydata_base, AllocatorPtr& alloc, @@ -719,7 +674,7 @@ void UpsampleTrilinearAntiAlias(int64_t batch_size, float depth_scale, float height_scale, float width_scale, - const std::vector& roi, + gsl::span roi, bool use_extrapolation, float extrapolation_value, bool exclude_outside, diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index a0e7ca1084fef..b768fedd8513a 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -3,11 +3,13 @@ #pragma once +#include #include #include #include #include -#include + +#include #include "core/common/status.h" #include #include @@ -58,7 +60,73 @@ enum class AspectRatioPolicy { NOT_SMALLER, }; +// Antialias types +template +struct AccumulateType { + using type = int32_t; + using Dtype = T; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = double; +}; + +namespace antialias_constants { +constexpr float kCubicCoeffA = -0.75f; +constexpr float kSupportSize = 2.0f; +constexpr float kBiCubicSupportSize = 4.0f; +} // namespace antialias_constants + +namespace ConstValue { +constexpr int32_t mag_factor = 1 << (22 - 1); +// We use to multiply by 2, let's make a constant which is twice as big +constexpr int32_t mag_factor_x_2 = 1 << 22; +} // namespace ConstValue + +template +inline constexpr bool is_8bit_v = std::is_same::value || std::is_same::value; + +template +void PrintAntiAliasBuffers(std::ostream& os, gsl::span bounds, gsl::span out_of_bounds, + gsl::span weight_coefficients) { + os << "#### Bounds: "; + std::copy(bounds.begin(), bounds.end(), std::ostream_iterator(os, " ")); + os << std::endl; + + os << "#### Out of Bounds: "; + std::copy(out_of_bounds.begin(), out_of_bounds.end(), + std::ostream_iterator(os, " ")); + os << std::endl; + + os << "#### Scale Buffer: "; + std::copy(weight_coefficients.begin(), weight_coefficients.end(), + std::ostream_iterator(os, " ")); + os << std::endl; +} + class UpsampleBase { + public: + // Make this available in other EP via provider bridge + // it works iff output_shape is specified + void AdjustOutputSizeAsPolicy(TensorShapeVector& 
output_dims, gsl::span input_dims, + InlinedVector& scales) const; + protected: explicit UpsampleBase(const OpKernelInfo& info) : scales_cached_(false), roi_cached_(false), use_extrapolation_(false) { @@ -69,23 +137,32 @@ class UpsampleBase { std::string mode; ORT_ENFORCE(info.GetAttr("mode", &mode).IsOK()); mode_ = StringToUpsampleMode(mode); - antialias_ = info.GetAttrOrDefault("antialias", 0) == 0 ? false : true; - if (antialias_) { - ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_), - "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`."); - } auto input_count = info.GetInputCount(); if (input_count == 1) { // opset < 10 - ORT_THROW_IF_ERROR(info.GetAttrs("scales", scales_)); - ORT_THROW_IF_ERROR(ScalesValidation(scales_, mode_)); + std::vector scales; + ORT_THROW_IF_ERROR(info.GetAttrs("scales", scales)); + ORT_THROW_IF_ERROR(ScalesValidation(scales, mode_)); + scales_.assign(scales.cbegin(), scales.cend()); scales_cached_ = true; } - std::string keep_aspect_ratio_policy = info.GetAttrOrDefault("keep_aspect_ratio_policy", "stretch"); - keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy); + if (opset >= 18) { + antialias_ = info.GetAttrOrDefault("antialias", 0) == 0 ? false : true; + + if (antialias_) { + ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_), + "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`."); + } - axes_ = info.GetAttrsOrDefault("axes"); + // The attribute is absent in opset < 18, but the default value as if stretch. + std::string keep_aspect_ratio_policy = info.GetAttrOrDefault("keep_aspect_ratio_policy", "stretch"); + keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy); + + // guard against unit tests that can add an attribute + auto axes = info.GetAttrsOrDefault("axes"); + axes_.assign(axes.cbegin(), axes.cend()); + } extrapolation_value_ = info.GetAttrOrDefault("extrapolation_value", 0.0f); @@ -112,7 +189,7 @@ class UpsampleBase { nearest_mode_ = StringToNearestMode(nearest_mode_name); get_nearest_pixel_ = GetNearestPixelFromOriginal(nearest_mode_); - cubic_coeff_a_ = info.GetAttrOrDefault("cubic_coeff_a", -0.75f); + cubic_coeff_a_ = info.GetAttrOrDefault("cubic_coeff_a", antialias_constants::kCubicCoeffA); exclude_outside_ = info.GetAttrOrDefault("exclude_outside", 0) == 0 ? 
false : true; if ((exclude_outside_ == 1 && mode_ != CUBIC) && (antialias_ == false || mode_ != LINEAR)) { @@ -166,7 +243,7 @@ class UpsampleBase { ResizeCoordinateTransformationMode coordinate_transform_mode_; GetOriginalCoordinateFunc get_original_coordinate_; ResizeNearestMode nearest_mode_; - AspectRatioPolicy keep_aspect_ratio_policy_; + AspectRatioPolicy keep_aspect_ratio_policy_{AspectRatioPolicy::STRETCH}; GetNearestPixelFunc get_nearest_pixel_; float cubic_coeff_a_; bool exclude_outside_; @@ -174,9 +251,9 @@ class UpsampleBase { float extrapolation_value_; bool use_nearest2x_optimization_ = false; - std::vector scales_; - std::vector roi_; - std::vector axes_; + InlinedVector scales_; + InlinedVector roi_; + TensorShapeVector axes_; bool scales_cached_; bool roi_cached_; @@ -335,7 +412,7 @@ class UpsampleBase { } } - [[nodiscard]] Status ScalesValidation(const std::vector& scales, const UpsampleMode mode) const { + [[nodiscard]] Status ScalesValidation(gsl::span scales, const UpsampleMode mode) const { if (!is_resize_) { for (auto& scale : scales) { ORT_RETURN_IF_NOT(scale >= 1, "Scale value should be greater than or equal to 1."); @@ -372,7 +449,7 @@ class UpsampleBase { } [[nodiscard]] Status - ParseScalesData(const Tensor* scale, std::vector& scales, int64_t rank) const { + ParseScalesData(const Tensor* scale, InlinedVector& scales, int64_t rank) const { const auto* scale_data = scale->Data(); int64_t scales_size = scale->Shape().Size(); ORT_RETURN_IF_NOT(scales_size > 0, "scales size should be greater than 0."); @@ -387,19 +464,19 @@ class UpsampleBase { // in which case the other axes is ignored and use default scale of 1 // scales_size == axes_.size() should be guaranteed if axes is not empty if (rank > 0 && (scales_size != rank || axes_.size())) { - std::vector new_scales(size_t(rank), 1.0f); + InlinedVector new_scales(size_t(rank), 1.0f); ORT_RETURN_IF_NOT(*std::max_element(axes_.begin(), axes_.end()) < rank && (int64_t(axes_.size()) == scales_size), "all values in axes should be less than rank of the data"); for (size_t i = 0; i < axes_.size(); i++) { new_scales[static_cast(axes_[i])] = scales[i]; } - scales = new_scales; + scales.swap(new_scales); } return ScalesValidation(scales, mode_); } - void ParseRoiData(const Tensor* roi, std::vector& roi_array) const { + void ParseRoiData(const Tensor* roi, InlinedVector& roi_array) const { int64_t roi_size = roi->Shape().Size(); if (roi_size > 0) { roi_array.resize(onnxruntime::narrow(roi_size)); @@ -429,52 +506,11 @@ class UpsampleBase { return Status::OK(); } - // it works iff output_shape is specified - void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, - std::vector& scales) const { - std::unordered_set axes_set(axes_.begin(), axes_.end()); - - // AspectRatioPolicy::STRETCH is default policy when opset < 18 - if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::STRETCH) { - return; - } - - float scale_in_policy = 0.0f; - if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) { - scale_in_policy = std::numeric_limits::max(); - - for (size_t i = 0; i < scales.size(); i++) { - if (axes_set.empty() || axes_set.count(i) > 0) { - scale_in_policy = std::min(scale_in_policy, scales[i]); - } - } - } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) { - scale_in_policy = std::numeric_limits::min(); - - for (size_t i = 0; i < scales.size(); i++) { - if (axes_set.empty() || axes_set.count(i) > 0) { - scale_in_policy = std::max(scale_in_policy, scales[i]); - } - } - } - - 
for (size_t i = 0; i < scales.size(); i++) { - // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes - if (axes_set.empty() || axes_set.count(i) > 0) { - scales[i] = scale_in_policy; - output_dims[i] = static_cast(std::round(scales[i] * input_dims[i])); - } else { - scales[i] = 1.0f; - output_dims[i] = input_dims[i]; - } - } - } - // It's different in Opset 18 and before. // we will modify output_shape by sorts of policy even if it's specified [[nodiscard]] Status ParseScalesDataAndAdjustOutputSize(TensorShapeVector& output_dims, gsl::span input_dims, - std::vector& scales) const { + InlinedVector& scales) const { for (size_t i = 0, end = input_dims.size(); i < end; ++i) { // Handle corner case to avoid dividing by zero in the next step if (input_dims[i] == 0) { @@ -507,9 +543,9 @@ class UpsampleBase { // Roi is redefined in Opset-18, we have a concept of axes. // So we need to update it accordingly. - void ComputeROIWithAxes(std::vector& roi_array, size_t rank) const { + void ComputeROIWithAxes(InlinedVector& roi_array, size_t rank) const { if (axes_.size()) { - std::vector roi_tmp(rank * 2, 0); + InlinedVector roi_tmp(rank * 2, 0); for (size_t i = rank; i < rank * 2; ++i) { roi_tmp[i] = 1; } @@ -518,9 +554,32 @@ class UpsampleBase { roi_tmp[v_in_axes] = (roi_array[i]); roi_tmp[rank + v_in_axes] = (roi_array[axes_.size() + i]); } - roi_array = roi_tmp; + roi_array.swap(roi_tmp); } } + + public: + static constexpr size_t kLookupTableSize = 1280; + + static const uint8_t* GetLookupTableShared() { + // initialized once + static const auto* lookup_table = []() { + // if we have already initialized the lookup table, just return + // ideally we could have a global lookup table, but that account for too much space. + /* Handles values form -640 to 639. 
*/ + static uint8_t table[kLookupTableSize] = {0}; + + // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94 + // we need to handle negative values + // it's equivalent to :x = np.clip(x, 0, 255) where x \in [-640, 639] + // we will accept a negative x for (&table[640])[x] means table +640 -x + for (int i = 0; i < static_cast(kLookupTableSize); ++i) { + table[i] = static_cast(std::min(std::max(i - 640, 0), 255)); + } + return table; + }(); + return lookup_table; + } }; // UpsampleBase } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 0d9928baa86e0..66794f88d8670 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -194,13 +194,13 @@ template <> __device__ __inline__ half _Ceil(half a) { return half(ceilf((float)a)); } template -__device__ __inline__ T _Floor(T a); +__device__ __host__ __inline__ T _Floor(T a); template <> -__device__ __inline__ float _Floor(float a) { return floorf(a); } +__device__ __host__ __inline__ float _Floor(float a) { return floorf(a); } template <> -__device__ __inline__ double _Floor(double a) { return floor(a); } +__device__ __host__ __inline__ double _Floor(double a) { return floor(a); } template <> __device__ __inline__ half _Floor(half a) { return half(floorf((float)a)); } @@ -230,13 +230,13 @@ template <> __device__ __inline__ half _Erf(half a) { return half(erff((float)a)); } template -__device__ __inline__ T _Round(T a); +__device__ __host__ __inline__ T _Round(T a); template <> -__device__ __inline__ float _Round(float a) { return rintf(a); } +__device__ __host__ __inline__ float _Round(float a) { return rintf(a); } template <> -__device__ __inline__ double _Round(double a) { return rint(a); } +__device__ __host__ __inline__ double _Round(double a) { return rint(a); } template <> __device__ __inline__ half _Round(half a) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 00783bcbc2665..1ce089fd93044 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1109,11 +1109,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Resize); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Flatten); @@ -1277,6 +1277,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize); // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast); @@ -2009,11 +2014,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2176,6 +2181,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/resize.cc b/onnxruntime/core/providers/cuda/tensor/resize.cc index 764172a8d1fac..97d4eb71e970a 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize.cc +++ b/onnxruntime/core/providers/cuda/tensor/resize.cc @@ -28,10 +28,22 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 3) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ Resize); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Resize, \ + kOnnxDomain, \ + 13, 17, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ + Resize); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Resize, \ kOnnxDomain, \ - 13, \ + 18, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ diff --git a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu new file mode 100644 index 
0000000000000..56b7c3f499303 --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu @@ -0,0 +1,1179 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/tensor/resize_impl.h" + +#define FUNC_DEF __device__ + +namespace onnxruntime { +namespace cuda { + +using onnxruntime::ResizeCoordinateTransformationMode; +using onnxruntime::UpsampleMode; + +/// +/// Compute a buffer for bilinear data for CUDA antialias resizing. +/// +static std::tuple ComputeBilinearScaleBufferSize( + int64_t output_height, int64_t output_width, + float height_rscale, float width_rscale, + float support_value, + float& scaled_support_height, float& scaled_support_width, + int32_t& window_size_height, int32_t& window_size_width) { + scaled_support_height = ComputeScaledSupportValue(support_value, height_rscale); + scaled_support_width = ComputeScaledSupportValue(support_value, width_rscale); + window_size_height = ComputeWindowSize(scaled_support_height); + window_size_width = ComputeWindowSize(scaled_support_width); + + auto height_buffer_size = ComputeWeightedCoeffBufferSize(output_height, window_size_height); + auto width_buffer_size = ComputeWeightedCoeffBufferSize(output_width, window_size_width); + + return std::make_tuple(height_buffer_size, width_buffer_size); +} + +/// +/// Compute a buffer for btrilinear data for CUDA antialias resizing. +/// +static std::tuple ComputeTrilinearScaleBufferSize( + int64_t output_depth, int64_t output_height, int64_t output_width, + float depth_rscale, float height_rscale, float width_rscale, + float support_value, + float& scaled_support_depth, float& scaled_support_height, + float& scaled_support_width, int32_t& window_size_depth, + int32_t& window_size_height, int32_t& window_size_width) { + scaled_support_depth = ComputeScaledSupportValue(support_value, depth_rscale); + window_size_depth = ComputeWindowSize(scaled_support_depth); + auto depth_buffer_size = ComputeWeightedCoeffBufferSize(output_depth, window_size_depth); + + const auto [y_buffer_size, w_buffer_size] = ComputeBilinearScaleBufferSize(output_height, + output_width, height_rscale, + width_rscale, support_value, + scaled_support_height, + scaled_support_width, + window_size_height, window_size_width); + return std::make_tuple(depth_buffer_size, y_buffer_size, w_buffer_size); +} + +// Antialiasing filters +struct BilinearFilter { + __device__ __host__ float operator()(float x, float /* cubic_coeff_a */) const { + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return 1.0f - x; + } + return 0.0f; + } +}; + +struct BiCubicFilter { + __device__ __host__ float operator()(float x, float cubic_coeff_a) const { + /* https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm + */ + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return ((cubic_coeff_a + 2.0f) * x - (cubic_coeff_a + 3.0f)) * x * x + 1; + } + if (x < 2.0f) { + return (((x - 5.0f) * x + 8.f) * x - 4.f) * cubic_coeff_a; + } + return 0.0f; + } +}; + +struct TriLinearFilter { + __device__ __host__ float operator()(float x, float /* cubic_coeff_a */) const { + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return 1.0f - x; + } + return 0.0f; + } +}; + +template +struct AccumTypeCaster { + static __device__ __host__ AccumType* cast(AccumType* p) { + return p; + } +}; + +template <> +struct AccumTypeCaster { + static __device__ __host__ float* cast(int32_t* p) { + return 
reinterpret_cast(p); + } +}; + +template +__global__ void _ComputeInterpolationAtLevel1( + int64_t num_channels, + int64_t input_height, int64_t input_width, + int64_t output_height, int64_t output_width, + const fast_divmod div_output_width, + const fast_divmod div_output_image, + int32_t window_size, + const uint8_t* clip8_table, + const int64_t* bound_data, + std::tuple outof_bounds_buffers, + const AccumType* weight_coefficients, + const T* Xdata, T* Ydata, + const int N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + // No need to do scale + if (output_width == input_width) { + Ydata[id] = Xdata[id]; + return; + } + + int bxc, output_image_index; + div_output_image.divmod(id, bxc, output_image_index); + + int output_y, output_x; + div_output_width.divmod(output_image_index, output_y, output_x); + + CUDA_LONG input_index = static_cast(bxc * num_channels * input_height * input_width); + CUDA_LONG output_index = static_cast(bxc * num_channels * output_height * output_width); + + auto* Ydata_offset = Ydata + output_index + output_width * output_y + output_x; + const auto* bound = bound_data; + + AccumType output = onnxruntime::is_8bit_v ? ConstValue::mag_factor : 0; + + const auto* weight_coeff = weight_coefficients + window_size * output_x; + int64_t xmin = bound[static_cast(output_x) * 2]; + int64_t xmax = bound[static_cast(output_x) * 2 + 1]; + + // Input window + const auto* Xdata_offset = Xdata + input_index + input_width * output_y + xmin; + + for (; xmin < xmax; ++xmin) { + if constexpr (std::is_same::value) { + // This cast is needed when we deal with half + output += static_cast((*Xdata_offset++)) * (*weight_coeff++); + } else { + output += (*Xdata_offset++) * (*weight_coeff++); + } + } + + if constexpr (onnxruntime::is_8bit_v) { + const uint8_t* clip8_lookups = &clip8_table[640]; + *Ydata_offset = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset = static_cast(std::round(output)); + } else { + *Ydata_offset = static_cast(output); + } +} + +template +__global__ void _ComputeInterpolationAtLevel2( + int64_t num_channels, + int64_t input_height, int64_t input_width, + int64_t output_height, int64_t output_width, + const fast_divmod div_output_height, + const fast_divmod div_output_width, + const fast_divmod div_output_image, + int32_t window_size, + bool use_extrapolation, float extrapolation_value, + const uint8_t* clip8_table, + const int64_t* bound_data, + std::tuple outof_bounds_buffers, + const AccumType* weight_coefficients, + const T* Xdata, T* Ydata, int N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + // No need to do scale + if (output_height == input_height) { + Ydata[id] = Xdata[id]; + return; + } + + int bxc, output_image_index; + div_output_image.divmod(id, bxc, output_image_index); + + int output_z, output_y, output_x, temp; + div_output_height.divmod(output_image_index, output_z, temp); + div_output_width.divmod(temp, output_y, output_x); + + CUDA_LONG input_index = static_cast(bxc * num_channels * input_height * input_width + + output_z * input_height * input_width); + CUDA_LONG output_index = static_cast(bxc * num_channels * output_height * output_width + + output_z * output_height * output_width); + + auto* Ydata_offset = Ydata + output_index + output_width * output_y + output_x; + + if (use_extrapolation) { + const auto* w_outof_bounds = std::get<1>(outof_bounds_buffers); + // Extrapolate along the w dimension + if (w_outof_bounds[static_cast(output_x)] != -1) { + *Ydata_offset = 
static_cast(extrapolation_value); + return; + } + + // Extrapolate along the y dimension + const auto* y_outof_bounds = std::get<0>(outof_bounds_buffers); + if (y_outof_bounds[static_cast(output_y)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + } + + const auto* bound = bound_data; + + AccumType output = onnxruntime::is_8bit_v ? ConstValue::mag_factor : 0; + + const auto* weight_coeff = weight_coefficients + window_size * output_y; + int64_t ymin = bound[static_cast(output_y) * 2]; + int64_t ymax = bound[static_cast(output_y) * 2 + 1]; + + const auto* Xdata_offset = Xdata + input_index + ymin * output_width + output_x; + + for (; ymin < ymax; ++ymin) { + if constexpr (std::is_same::value) { + // We cast to AccumType to resolve ambiguous call to operator* for half in CUDA + output += static_cast((*Xdata_offset)) * (*weight_coeff++); + } else { + output += (*Xdata_offset) * (*weight_coeff++); + } + Xdata_offset += input_width; + } + + if constexpr (onnxruntime::is_8bit_v) { + const uint8_t* clip8_lookups = &clip8_table[640]; + *Ydata_offset = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset = static_cast(std::round(output)); + } else { + *Ydata_offset = output; + } +} + +template +__global__ void _ComputeInterpolationAtLevel3( + int64_t input_depth, + int64_t input_height, int64_t input_width, + int64_t output_depth, + int64_t output_height, int64_t output_width, + const fast_divmod div_output_height, + const fast_divmod div_output_width, + const fast_divmod div_output_image, + int32_t window_size, + bool use_extrapolation, float extrapolation_value, + const uint8_t* clip8_table, + const int64_t* bound_data, + std::tuple outof_bounds_buffers, + const AccumType* weight_coefficients, + const T* Xdata, T* Ydata, int N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + // No need to do scale + if (input_depth == output_depth) { + Ydata[id] = Xdata[id]; + return; + } + + int bxc, output_image_index; + div_output_image.divmod(id, bxc, output_image_index); + + int output_z, output_y, output_x, temp; + div_output_height.divmod(output_image_index, output_z, temp); + div_output_width.divmod(temp, output_y, output_x); + + CUDA_LONG input_index = static_cast(bxc * input_depth * input_height * input_width); + + auto* Ydata_offset = Ydata + id; + + if (use_extrapolation) { + const auto* w_outof_bounds = std::get<2>(outof_bounds_buffers); + // Extrapolate along the w dimension + if (w_outof_bounds[static_cast(output_x)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + + // Extrapolate along the y dimension + const auto* y_outof_bounds = std::get<1>(outof_bounds_buffers); + if (y_outof_bounds[static_cast(output_y)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + + // Extrapolate along the z dimension + const int64_t* z_outof_bounds = std::get<0>(outof_bounds_buffers); + if (z_outof_bounds != nullptr && z_outof_bounds[static_cast(output_z)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + } + + const auto* bound = bound_data; + + AccumType output = onnxruntime::is_8bit_v ? 
ConstValue::mag_factor : 0; + + const auto* weight_coeff = weight_coefficients + window_size * output_z; + int64_t zmin = bound[static_cast(output_z) * 2]; + int64_t zmax = bound[static_cast(output_z) * 2 + 1]; + + const auto z_step = input_height * input_width; + const auto* Xdata_offset = Xdata + input_index + zmin * z_step + output_y * output_width + output_x; + + for (; zmin < zmax; ++zmin) { + if constexpr (std::is_same::value) { + // We cast to AccumType to resolve ambiguous call to operator* for half in CUDA + output += static_cast((*Xdata_offset)) * (*weight_coeff++); + } else { + output += (*Xdata_offset) * (*weight_coeff++); + } + Xdata_offset += z_step; + } + + if constexpr (onnxruntime::is_8bit_v) { + const uint8_t* clip8_lookups = &clip8_table[640]; + *Ydata_offset = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset = static_cast(std::round(output)); + } else { + *Ydata_offset = output; + } +} + +/// +/// This function expects the following buffers to be pre-allocated on device +/// 1. bounds: int64_t[output_size * 2] +/// 2. out_of_bounds: int64_t[output_size] +/// 3. scale_data: T[output_size * window_size] +/// +/// Template parameter AccumType +/// +template +FUNC_DEF void SetupUpsampleFilterAnitAliasImpl( + int64_t i, + int64_t input_size, int64_t output_size, + float rscale, + float roi_start, float roi_end, + float scaled_support, int32_t window_size, bool exclude_outside, + float cubic_coeff_a, + int64_t* bounds, + int64_t* out_of_bounds, + AccumType* scale_data) { + Filter filter{}; + CudaFunctionOriginalCoordinate get_original_coordinate{}; + + const auto scale = 1.f / rscale; + const float inv_scale = (scale >= 1.0f) ? 1.0f / scale : 1.0f; + + const float id = static_cast(i); + float center = 0.5f; + if (scale == 1.0f) { + center += id; + } else { + center += get_original_coordinate(id, rscale, + static_cast(output_size), + static_cast(input_size), + roi_start, roi_end); + } + + if (center - 0.5f < 0 || center - 0.5f > static_cast(input_size - 1)) { + out_of_bounds[i] = i; + } else { + out_of_bounds[i] = -1; + } + + float total_weight{0}; + + auto fmin = _Floor(center - scaled_support + 0.5f); + auto fmax = _Floor(center + scaled_support + 0.5f); + + int64_t min_real = static_cast(fmin); + int64_t max_real = static_cast(fmax); + int64_t min_cut = std::max(min_real, 0); + int64_t max_cut = std::min(max_real, input_size); + + int64_t min_val = exclude_outside ? min_cut : min_real; + int64_t max_val = exclude_outside ? max_cut : max_real; + bounds[i * 2] = min_cut; + bounds[i * 2 + 1] = max_cut; + + // This is done for int32_t case, when the final result is in int32_t, but + // we perform calculations in float. All other types as is. + auto* scale_buffer = AccumTypeCaster::cast(&scale_data[i * window_size]); + + max_val -= min_val; + for (int64_t x = 0; x < max_val; x++) { + const float arg = (x + min_val - center + 0.5f) * inv_scale; + const auto w = filter(arg, cubic_coeff_a); + scale_buffer[x] = w; + total_weight += w; + } + + if (!exclude_outside) { + int64_t neg_xsize = min_val < 0 ? -min_val : 0; + for (int64_t x = 0; x < neg_xsize; x++) { + scale_buffer[neg_xsize] += scale_buffer[x]; + } + + int64_t bound_size = + max_val + min_val > input_size ? 
max_val + min_val - input_size : 0; + for (int64_t x = max_val - bound_size; x < max_val; x++) { + scale_buffer[max_val - bound_size - 1] += + scale_buffer[x]; + } + + for (int64_t x = 0; (neg_xsize | bound_size) > 0 && x < max_cut - min_cut; x++) { + scale_buffer[x] = scale_buffer[x + neg_xsize]; + } + } + + const float total_weight_inv = (total_weight == 0) ? 1.f : (1.f / total_weight); + if constexpr (std::is_same::value) { + auto* scale_buffer_int = reinterpret_cast(scale_buffer); + for (int64_t x = 0; x < max_cut - min_cut; x++) { + scale_buffer[x] *= total_weight_inv; + // normalize the scale to 1 << 22 for int8/uint8 + scale_buffer_int[x] = static_cast(_Round(scale_buffer[x] * ConstValue::mag_factor_x_2)); + } + } else { + for (int64_t x = 0; x < max_cut - min_cut; x++) { + scale_buffer[x] *= total_weight_inv; + } + } +} + +/// This kernel computes antialias filter for bilinear or bicubic upsampling. +/// The function expects the following buffers to be pre-allocated on device +/// 1. bounds: int64_t[output_size * 2] for each of the two dimensions +/// 2. out_of_bounds: int64_t[output_size] for each of the two dimensions +/// 3. scale_data: AccumType[output_size * window_size] for each of the two dimensions +/// Buffers layout [h_data, w_data] +template +__global__ void _SetupBilinearUpsampleFilterAntiAlias( + std::tuple input_dims, // h, w + std::tuple output_dims, // h, w + std::tuple inv_scale_vals, // h, w + std::tuple roi_start_vals, // h, w + std::tuple roi_end_vals, // h, w + std::tuple dim_scaled_support, // Pre-computed scaled support values h, w + std::tuple dim_window_size, // Pre-computed windows sizes h, w + float cubic_coeff_a, + bool exclude_outside, + int64_t* bounds, + int64_t* out_of_bounds, + std::tuple weighted_coefficients // y, h buffers +) { + const auto N = std::get<0>(output_dims) + std::get<1>(output_dims); + + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + if (id < std::get<0>(output_dims)) { + // Setup for y + int64_t input_size = std::get<0>(input_dims); + int64_t output_size = std::get<0>(output_dims); + float inv_scale = std::get<0>(inv_scale_vals); + float roi_start = std::get<0>(roi_start_vals); + float roi_end = std::get<0>(roi_end_vals); + float scaled_support = std::get<0>(dim_scaled_support); + int32_t window_size = std::get<0>(dim_window_size); + + SetupUpsampleFilterAnitAliasImpl( + id, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outside, + cubic_coeff_a, + bounds, + out_of_bounds, + std::get<0>(weighted_coefficients)); + + } else { + // Setup for w + // w = id - output_height + + int64_t input_size = std::get<1>(input_dims); + int64_t output_size = std::get<1>(output_dims); + float inv_scale = std::get<1>(inv_scale_vals); + float roi_start = std::get<1>(roi_start_vals); + float roi_end = std::get<1>(roi_end_vals); + + float scaled_support = std::get<1>(dim_scaled_support); + int32_t window_size = std::get<1>(dim_window_size); + + // Adjust buffer positions + const auto y_output_size = std::get<0>(output_dims); + + auto i = id - y_output_size; + bounds += (y_output_size * 2); + out_of_bounds += y_output_size; + + SetupUpsampleFilterAnitAliasImpl( + i, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outside, + cubic_coeff_a, + bounds, + out_of_bounds, + std::get<1>(weighted_coefficients)); + } +} + +/// +/// Compute AntiAlias filter for trilinear upsampling, all in one go +/// The function expects the following buffers to be 
pre-allocated on device +/// 1. bounds: int64_t[output_size * 2] for each of the three dimensions +/// 2. out_of_bounds: int64_t[output_size] for each of the three dimensions +/// 3. scale_data: AccumType[output_size * window_size] for each of the three dimensions +/// Each kind of buffer contains data for all 3 dims. +/// Buffers layout [d_data, h_data, w_data] +/// +template +__global__ void _SetupTrilinerarUpsampleFilterAntiAlias( + std::tuple input_dims, // d, h, w + std::tuple output_dims, // d, h, w + std::tuple inv_scale_vals, // d, h, w + std::tuple roi_start_vals, // d, h, w + std::tuple roi_end_vals, // d, h, w + std::tuple dim_scaled_support, // Pre-computed scaled support values d, h, w + std::tuple dim_window_size, // Pre-computed windows sizes d, h, w + bool exclude_outisde, + int64_t* bounds, + int64_t* out_of_bounds, + std::tuple weighted_coefficients) { + const auto N = std::get<0>(output_dims) + std::get<1>(output_dims) + std::get<2>(output_dims); + + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + if (id < std::get<0>(output_dims)) { + // Setup for d by default (id < output_depth) + int64_t input_size = std::get<0>(input_dims); + int64_t output_size = std::get<0>(output_dims); + float inv_scale = std::get<0>(inv_scale_vals); + float roi_start = std::get<0>(roi_start_vals); + float roi_end = std::get<0>(roi_end_vals); + float scaled_support = std::get<0>(dim_scaled_support); + int32_t window_size = std::get<0>(dim_window_size); + + SetupUpsampleFilterAnitAliasImpl( + id, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outisde, + onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear + bounds, + out_of_bounds, + std::get<0>(weighted_coefficients)); + + } else if (id >= std::get<0>(output_dims) && id < (std::get<0>(output_dims) + std::get<1>(output_dims))) { + int64_t input_size = std::get<1>(input_dims); + int64_t output_size = std::get<1>(output_dims); + float inv_scale = std::get<1>(inv_scale_vals); + float roi_start = std::get<1>(roi_start_vals); + float roi_end = std::get<1>(roi_end_vals); + + float scaled_support = std::get<1>(dim_scaled_support); + int32_t window_size = std::get<1>(dim_window_size); + + // Adjust buffer positions + const auto d_output_size = std::get<0>(output_dims); + + auto i = id - d_output_size; + bounds += d_output_size * 2; + out_of_bounds += d_output_size; + + SetupUpsampleFilterAnitAliasImpl( + i, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outisde, + onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear + bounds, + out_of_bounds, + std::get<1>(weighted_coefficients)); + } else { + int64_t input_size = std::get<2>(input_dims); + int64_t output_size = std::get<2>(output_dims); + float inv_scale = std::get<2>(inv_scale_vals); + float roi_start = std::get<2>(roi_start_vals); + float roi_end = std::get<2>(roi_end_vals); + float scaled_support = std::get<2>(dim_scaled_support); + int32_t window_size = std::get<2>(dim_window_size); + + // Adjust buffer positions + const auto d_y_output_size = std::get<0>(output_dims) + std::get<1>(output_dims); + + auto i = id - d_y_output_size; + bounds += (d_y_output_size * 2); + out_of_bounds += d_y_output_size; + + SetupUpsampleFilterAnitAliasImpl( + i, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outisde, + onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear + 
bounds, + out_of_bounds, + std::get<2>(weighted_coefficients)); + } +} + +#define CASEA_COORD_ANTIALIAS(coordinate_mode, TransformCoordType, ...) \ + case coordinate_mode: { \ + using coord_t = TransformCoordType; \ + return __VA_ARGS__(); \ + break; \ + } + +#define DISPATCH_ANTIALIAS_FILTER_SETUP(coord_enum, ...) \ + [&] { \ + const auto the_type = coord_enum; \ + switch (the_type) { \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::HALF_PIXEL, \ + TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::ASYMMETRIC, \ + TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL, \ + TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::ALIGN_CORNERS, \ + TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, \ + TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, \ + TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \ + default: \ + ORT_THROW("unknown ResizeCoordinateTransformationMode"); \ + } \ + }() + +namespace { +template +IAllocatorUniquePtr AllocateTyped( + const TempSpaceAllocateFunc& alloc, + size_t elements) { + return alloc(elements * sizeof(T)); +} + +template +T* GetTyped(IAllocatorUniquePtr& bytes) { + return reinterpret_cast(bytes.get()); +} +} // namespace + +template +void ResizeTrilinearUpsample( + cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + const TempSpaceAllocateFunc& allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + using AccumType = typename onnxruntime::AccumulateType::type; + + const bool use_extrapolation = extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? 
*extrapolation : 0.f; + + int64_t input_depth, input_height, input_width; + std::tie(input_depth, input_height, input_width) = inferred_input_dims; + + int64_t output_depth, output_height, output_width; + std::tie(output_depth, output_height, output_width) = inferred_output_dims; + + int blocksPerDimsMappingGrid = + static_cast(ceil((output_depth + output_height + output_width) / 32.0)); + + int blocksPerGrid = static_cast(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); + + constexpr float support_value = antialias_constants::kSupportSize; + float z_scale, h_scale, w_scale; + std::tie(z_scale, h_scale, w_scale) = inferred_dim_rscales; + + const auto& div_output_width = output_div_pitches[rank - 2]; + + SafeInt bounds_buffer_size = (SafeInt(output_depth) + output_height + output_width) * 2; + SafeInt out_of_bounds_buffer_size = (SafeInt(output_depth) + output_height + output_width); + + auto bounds_buffer_ptr = AllocateTyped(allocate_temp_space, bounds_buffer_size); + auto out_of_bounds_buffer_ptr = AllocateTyped(allocate_temp_space, out_of_bounds_buffer_size); + + int64_t* z_bounds_buffer = GetTyped(bounds_buffer_ptr); + int64_t* y_bounds_buffer = z_bounds_buffer + output_depth * 2; + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* z_outof_bounds_buffer = GetTyped(out_of_bounds_buffer_ptr); + int64_t* y_outof_bounds_buffer = z_outof_bounds_buffer + output_depth; + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + + float z_scaled_support, h_scaled_support, w_scaled_support; + int32_t z_window_size, h_window_size, w_window_size; + const auto [z_buffer_size, y_buffer_size, w_buffer_size] = ComputeTrilinearScaleBufferSize( + output_depth, output_height, output_width, + z_scale, h_scale, w_scale, support_value, + z_scaled_support, h_scaled_support, w_scaled_support, + z_window_size, h_window_size, w_window_size); + + const int64_t weighted_buffer_size = SafeInt(z_buffer_size) + y_buffer_size + w_buffer_size; + + auto weighted_buffer_ptr = AllocateTyped(allocate_temp_space, weighted_buffer_size); + AccumType* z_weighted_buffer = GetTyped(weighted_buffer_ptr); + AccumType* y_weighted_buffer = z_weighted_buffer + z_buffer_size; + AccumType* w_weighted_buffer = y_weighted_buffer + y_buffer_size; + + const auto h_w_interpolate_temp_buf_size = SafeInt(batch_size) * num_channels * + input_depth * input_height * output_width; + auto h_w_interpolate_temp_buffer_ptr = AllocateTyped(allocate_temp_space, + narrow(h_w_interpolate_temp_buf_size)); + + const auto h_w_interpolate_result_buffer_size = SafeInt(batch_size) * num_channels * + input_depth * output_height * output_width; + auto h_w_interpolate_result_buffer_ptr = AllocateTyped(allocate_temp_space, h_w_interpolate_result_buffer_size); + + // clang-format off + DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { + _SetupTrilinerarUpsampleFilterAntiAlias<<>>( + inferred_input_dims, + inferred_output_dims, + inferred_dim_rscales, + std::make_tuple(roi_vals[rank - 3], roi_vals[rank - 2], roi_vals[rank - 1]), // roi starts d, h, w + std::make_tuple(roi_vals[rank - 3 + rank], roi_vals[rank - 2 + rank], // roi ends d, h, w + roi_vals[rank - 1 + rank]), + std::make_tuple(z_scaled_support, h_scaled_support, w_scaled_support), + std::make_tuple(z_window_size, h_window_size, w_window_size), + exclude_outside, + GetTyped(bounds_buffer_ptr), + GetTyped(out_of_bounds_buffer_ptr), + std::make_tuple(z_weighted_buffer, y_weighted_buffer, w_weighted_buffer)); + }); + + // clang-format on + const 
fast_divmod div_w_image(narrow(num_channels * input_depth * input_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel1<<>>( + num_channels * input_depth, input_height, input_width, input_height, output_width, + div_output_width, + div_w_image, + w_window_size, + clip8_lookups, + w_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + w_weighted_buffer, input_data, + GetTyped(h_w_interpolate_temp_buffer_ptr), + narrow(h_w_interpolate_temp_buf_size)); + + // clang-format on + const fast_divmod div_output_height{narrow(output_height * output_width)}; + const fast_divmod div_h_w_image(narrow(num_channels * input_depth * output_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel2<<>>( + num_channels * input_depth, input_height, output_width, output_height, output_width, + div_output_height, + div_output_width, + div_h_w_image, + h_window_size, + false, 0.f, // No extrapolation + clip8_lookups, + y_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + y_weighted_buffer, GetTyped(h_w_interpolate_temp_buffer_ptr), + GetTyped(h_w_interpolate_result_buffer_ptr), + narrow(h_w_interpolate_result_buffer_size)); + + // clang-format on + const fast_divmod div_z_h_w_image(narrow(input_depth * output_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel3<<>>( + input_depth, output_height, output_width, + output_depth, output_height, output_width, + div_output_height, + div_output_width, + div_z_h_w_image, + z_window_size, + use_extrapolation, extrapolation_value, + clip8_lookups, + z_bounds_buffer, + std::make_tuple(z_outof_bounds_buffer, y_outof_bounds_buffer, w_outof_bounds_buffer), + z_weighted_buffer, GetTyped(h_w_interpolate_result_buffer_ptr), + output_data, + narrow(N)); + // clang-format on +} + +template +void ResizeBiLinearUpsample(cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + const TempSpaceAllocateFunc& allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + using AccumType = typename onnxruntime::AccumulateType::type; + + const bool use_extrapolation = extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f; + + int64_t input_depth, input_height, input_width; + std::tie(input_depth, input_height, input_width) = inferred_input_dims; + + int64_t output_depth, output_height, output_width; + std::tie(output_depth, output_height, output_width) = inferred_output_dims; + + int blocksPerDimsMappingGrid = + narrow(CeilDiv((output_depth + output_height + output_width), 32)); + + // rank 2 or 4 + const fast_divmod div_output_image = (rank > 2) ? 
output_div_pitches[rank - 4] + : fast_divmod(gsl::narrow_cast(N)); + const fast_divmod& div_output_width = output_div_pitches[rank - 2]; + + constexpr float support_value = antialias_constants::kSupportSize; + + float h_scale, w_scale; + std::tie(std::ignore, h_scale, w_scale) = inferred_dim_rscales; + + int blocksPerGrid = narrow(CeilDiv(N, GridDim::maxThreadsPerBlock)); + + SafeInt bounds_buffer_size = (SafeInt(output_height) + output_width) * 2; + SafeInt out_of_bounds_buffer_size = (SafeInt(output_height) + output_width); + + float h_scaled_support, w_scaled_support; + int32_t h_window_size, w_window_size; + const auto [weighted_y_size, weighted_w_size] = + ComputeBilinearScaleBufferSize(output_height, output_width, + h_scale, w_scale, support_value, + h_scaled_support, w_scaled_support, h_window_size, w_window_size); + + auto bounds_buffer_ptr = AllocateTyped(allocate_temp_space, bounds_buffer_size); + auto out_of_bounds_buffer_ptr = AllocateTyped(allocate_temp_space, out_of_bounds_buffer_size); + + int64_t* y_bounds_buffer = GetTyped(bounds_buffer_ptr); + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* y_outof_bounds_buffer = GetTyped(out_of_bounds_buffer_ptr); + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + + const int64_t weighted_buffer_size = SafeInt(weighted_y_size) + weighted_w_size; + auto weighted_buffer_ptr = AllocateTyped(allocate_temp_space, narrow(weighted_buffer_size)); + + AccumType* y_weighted_buffer = GetTyped(weighted_buffer_ptr); + AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size; + + const auto temp_buf_size = num_channels * input_height * output_width; + auto image_temp_buffer = AllocateTyped(allocate_temp_space, narrow(temp_buf_size)); + + // clang-format off + DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { + // Data is d, h, w in tuples + + _SetupBilinearUpsampleFilterAntiAlias<<>>( + std::make_tuple(input_height, input_width), + std::make_tuple(output_height, output_width), + std::make_tuple(h_scale, w_scale), + std::make_tuple(roi_vals[rank - 2], roi_vals[rank - 1]), // roi starts h, w + std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]), // roi ends h, w + std::make_tuple(h_scaled_support, w_scaled_support), + std::make_tuple(h_window_size, w_window_size), + onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside, + GetTyped(bounds_buffer_ptr), + GetTyped(out_of_bounds_buffer_ptr), + std::make_tuple(y_weighted_buffer, w_weighted_buffer)); + }); + + // clang-format on + const fast_divmod div_step_image{narrow(num_channels * input_height * output_width)}; + // clang-format off + _ComputeInterpolationAtLevel1<<>>( + num_channels, input_height, input_width, input_height, output_width, + div_output_width, + div_step_image, + w_window_size, + clip8_lookups, + w_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + w_weighted_buffer, input_data, GetTyped(image_temp_buffer), + narrow(temp_buf_size)); + + // clang-format on + const fast_divmod div_output_height{narrow(output_height * output_width)}; + // clang-format off + _ComputeInterpolationAtLevel2<<>>( + num_channels, input_height, output_width, output_height, output_width, + div_output_height, + div_output_width, + div_output_image, + h_window_size, + use_extrapolation, extrapolation_value, + clip8_lookups, + y_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + y_weighted_buffer, GetTyped(image_temp_buffer), output_data, + narrow(N)); 
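
The bilinear path above is separable: the first launch resamples only along W into a temporary buffer, and the second resamples that buffer along H into the output, with extrapolation applied only in the final pass. Below is a self-contained host sketch of the same two-pass structure; it uses a plain 2-tap linear filter with no antialias window purely to stay short, and `LinearPass` plus its parameters are illustrative names rather than anything in the kernels.

```
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Resample one axis of a row-major [height x width] float image with a 2-tap
// linear filter in half-pixel coordinates. along_width == true resamples W,
// false resamples H. Only the two-pass structure is mirrored here; the CUDA
// kernels use the wider antialias windows computed at setup time.
static std::vector<float> LinearPass(const std::vector<float>& src, int64_t height, int64_t width,
                                     int64_t out_size, bool along_width) {
  const int64_t out_h = along_width ? height : out_size;
  const int64_t out_w = along_width ? out_size : width;
  const int64_t in_size = along_width ? width : height;
  const float scale = static_cast<float>(out_size) / static_cast<float>(in_size);

  std::vector<float> dst(static_cast<size_t>(out_h * out_w));
  for (int64_t y = 0; y < out_h; ++y) {
    for (int64_t x = 0; x < out_w; ++x) {
      const int64_t o = along_width ? x : y;
      const float center = (static_cast<float>(o) + 0.5f) / scale - 0.5f;
      const int64_t lo = static_cast<int64_t>(std::floor(center));
      const float frac = center - static_cast<float>(lo);
      const int64_t i0 = std::min(std::max<int64_t>(lo, 0), in_size - 1);
      const int64_t i1 = std::min(lo + 1, in_size - 1);
      const auto at = [&](int64_t i) {
        return along_width ? src[static_cast<size_t>(y * width + i)]
                           : src[static_cast<size_t>(i * width + x)];
      };
      dst[static_cast<size_t>(y * out_w + x)] = (1.0f - frac) * at(i0) + frac * at(i1);
    }
  }
  return dst;
}

int main() {
  std::vector<float> img(16);
  for (int i = 0; i < 16; ++i) img[i] = static_cast<float>(i);  // 4x4 ramp
  // Pass 1: width 4 -> 6 (temp buffer is 4x6). Pass 2: height 4 -> 2 (output is 2x6).
  const std::vector<float> tmp = LinearPass(img, 4, 4, 6, /*along_width=*/true);
  const std::vector<float> out = LinearPass(tmp, 4, 6, 2, /*along_width=*/false);
  return out.size() == 12 ? 0 : 1;
}
```

Separating the passes keeps each kernel's inner loop one-dimensional and lets the intermediate buffer feed the H pass directly, at the cost of one temporary device allocation.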
+ + // clang-format on +} + +template +void ResizeBicubicUpsample(cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + // const TArray& input_strides, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + const TempSpaceAllocateFunc& allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + using AccumType = typename onnxruntime::AccumulateType::type; + + const bool use_extrapolation = extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f; + + int blocksPerGrid = narrow(CeilDiv(N, GridDim::maxThreadsPerBlock)); + const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4] + : fast_divmod(gsl::narrow_cast(N)); + const fast_divmod& div_output_width = output_div_pitches[rank - 2]; + + constexpr float support_value = antialias_constants::kBiCubicSupportSize; + + int64_t input_depth, input_height, input_width; + std::tie(input_depth, input_height, input_width) = inferred_input_dims; + + int64_t output_depth, output_height, output_width; + std::tie(output_depth, output_height, output_width) = inferred_output_dims; + + int blocksPerDimsMappingGrid = + narrow(CeilDiv((output_depth + output_height + output_width), 32)); + + float h_scale, w_scale; + std::tie(std::ignore, h_scale, w_scale) = inferred_dim_rscales; + + SafeInt bounds_buffer_size = (SafeInt(output_height) + output_width) * 2; + SafeInt out_of_bounds_buffer_size = (SafeInt(output_height) + output_width); + + float h_scaled_support, w_scaled_support; + int32_t h_window_size, w_window_size; + const auto [weighted_y_size, weighted_w_size] = + ComputeBilinearScaleBufferSize(output_height, output_width, + h_scale, w_scale, support_value, + h_scaled_support, w_scaled_support, h_window_size, w_window_size); + + auto bounds_buffer_ptr = AllocateTyped(allocate_temp_space, bounds_buffer_size); + auto out_of_bounds_buffer_ptr = AllocateTyped(allocate_temp_space, out_of_bounds_buffer_size); + + int64_t* y_bounds_buffer = GetTyped(bounds_buffer_ptr); + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* y_outof_bounds_buffer = GetTyped(out_of_bounds_buffer_ptr); + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + + const int64_t weighted_buffer_size = SafeInt(weighted_y_size) + + weighted_w_size; + auto weighted_buffer_ptr = AllocateTyped(allocate_temp_space, weighted_buffer_size); + + AccumType* y_weighted_buffer = GetTyped(weighted_buffer_ptr); + AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size; + + const auto temp_buf_size = SafeInt(batch_size) * num_channels * input_height * output_width; + auto image_temp_buffer = AllocateTyped(allocate_temp_space, narrow(temp_buf_size)); + + // clang-format off + DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { + _SetupBilinearUpsampleFilterAntiAlias<<>>( + std::make_tuple(input_height, input_width), + std::make_tuple(output_height, output_width), + std::make_tuple(h_scale, w_scale), + std::make_tuple(roi_vals[rank - 2], roi_vals[rank - 1]), // roi starts h, w + std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]), // roi ends h, w + 
std::make_tuple(h_scaled_support, w_scaled_support), + std::make_tuple(h_window_size, w_window_size), + onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside, + GetTyped(bounds_buffer_ptr), + GetTyped(out_of_bounds_buffer_ptr), + std::make_tuple(y_weighted_buffer, w_weighted_buffer)); + }); + // clang-format on + const fast_divmod div_step_image(narrow(num_channels * input_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel1<<>>( + num_channels, input_height, input_width, input_height, output_width, + div_output_width, + div_step_image, + w_window_size, + clip8_lookups, + w_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + w_weighted_buffer, input_data, GetTyped(image_temp_buffer), + narrow(temp_buf_size)); + // clang-format on + + const fast_divmod div_output_height{narrow(output_height * output_width)}; + // clang-format off + _ComputeInterpolationAtLevel2<<>>( + num_channels, input_height, output_width, output_height, output_width, + div_output_height, + div_output_width, + div_output_image, + h_window_size, + use_extrapolation, extrapolation_value, + clip8_lookups, + y_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + y_weighted_buffer, GetTyped(image_temp_buffer), output_data, + narrow(N)); + // clang-format on +} + +template +void ResizeAntiAliasImpl( + cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + TempSpaceAllocateFunc allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + // We support a special case of bilinear or bicubic if the input data is 4D with the outer 2 scales being 1.0 + // We would have validated the outer scale values by the time execution reaches this + const bool is_2D = (rank == 2 || rank == 4); + + // We support a special case of trilinear or tricubic if the input data is 5D with the outer 2 scales being 1.0 + // We would have validated the outer scale values by the time execution reaches this + const bool is_3D = (rank == 3 || rank == 5); + + // Should not hit this as we have already validated input rank/scales and we provide verbose error messages + // to the user. 
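
All of the 8-bit kernel paths above share the same fixed-point scheme: weights are pre-scaled by `1 << 22`, the accumulator starts at `1 << 21` so the final right shift rounds to nearest, and the shifted result is clamped through the 1280-entry table (valid for indices -640..639) returned by `GetLookupTableShared`, addressed from `&clip8_table[640]`. Here is a standalone sketch of that accumulation; `Clip8Lookups` and `WeightedSumU8` are illustrative names only.

```
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Clamp-to-[0, 255] lookup covering indices -640..639, addressed from its midpoint.
static const uint8_t* Clip8Lookups() {
  static const std::array<uint8_t, 1280> table = [] {
    std::array<uint8_t, 1280> t{};
    for (int i = 0; i < 1280; ++i) {
      t[static_cast<size_t>(i)] = static_cast<uint8_t>(std::min(std::max(i - 640, 0), 255));
    }
    return t;
  }();
  return table.data() + 640;  // lookups[x] is then valid for x in [-640, 639]
}

// Weighted sum of uint8 pixels in 10.22 fixed point, then round, shift, clamp.
static uint8_t WeightedSumU8(const std::vector<uint8_t>& pixels,
                             const std::vector<float>& weights) {
  constexpr int32_t kMagFactor = 1 << 21;    // 0.5 in 10.22, makes the shift round to nearest
  constexpr int32_t kMagFactorX2 = 1 << 22;  // weight scale
  const uint8_t* lookups = Clip8Lookups();

  int32_t acc = kMagFactor;
  for (size_t i = 0; i < pixels.size(); ++i) {
    const int32_t w = static_cast<int32_t>(std::lround(weights[i] * kMagFactorX2));
    acc += static_cast<int32_t>(pixels[i]) * w;
  }
  return lookups[acc >> 22];  // shift back to integer pixels, then clamp via the table
}

int main() {
  // Average of two pixels: (10 + 20) / 2 == 15.
  std::cout << static_cast<int>(WeightedSumU8({10, 20}, {0.5f, 0.5f})) << "\n";
  // Negative lobes (as in bicubic) can undershoot below 0; the lookup clamps to 0.
  std::cout << static_cast<int>(WeightedSumU8({0, 0, 255}, {0.6f, 0.6f, -0.2f})) << "\n";
  return 0;
}
```

Keeping the 8-bit accumulation in int32 avoids a per-tap float round, and the `1 << 21` offset replaces an explicit rounding step before the shift.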
+ ORT_ENFORCE(is_2D || is_3D, "Only bilinear/trilinear and bicubic modes are supported in Resize anti-alias mode"); + + switch (upsample_mode) { + case UpsampleMode::LINEAR: { + if (is_2D) { + ResizeBiLinearUpsample(stream, rank, upsample_mode, coordinate_transform_mode, + input_shape, output_shape, batch_size, num_channels, + inferred_input_dims, inferred_output_dims, inferred_dim_rscales, + output_div_pitches, roi_vals, extrapolation, exclude_outside, + allocate_temp_space, clip8_lookups, input_data, output_data, N); + } else if (is_3D) { + ResizeTrilinearUpsample(stream, rank, upsample_mode, coordinate_transform_mode, + input_shape, output_shape, batch_size, num_channels, + inferred_input_dims, inferred_output_dims, inferred_dim_rscales, + output_div_pitches, roi_vals, extrapolation, exclude_outside, + allocate_temp_space, clip8_lookups, input_data, output_data, N); + } else { + ORT_NOT_IMPLEMENTED("Resize supports only 2-D or 3-D in LINEAR mode."); + } + } break; + case CUBIC: { + if (is_2D) { + ResizeBicubicUpsample(stream, rank, upsample_mode, coordinate_transform_mode, + input_shape, output_shape, batch_size, num_channels, + inferred_input_dims, inferred_output_dims, inferred_dim_rscales, + output_div_pitches, roi_vals, extrapolation, exclude_outside, + allocate_temp_space, clip8_lookups, input_data, output_data, N); + } else { + ORT_NOT_IMPLEMENTED("Resize supports only 2-D in CUBIC mode."); + } + } break; + default: + ORT_NOT_IMPLEMENTED("Only bilinear/trilinear and bicubic modes are supported in Resize anti-alias mode"); + break; + } +} + +#define SPECIALIZED_ANTIALIAS_IMPL(T) \ + template void ResizeAntiAliasImpl( \ + cudaStream_t stream, \ + int rank, \ + const UpsampleMode upsample_mode, \ + ResizeCoordinateTransformationMode coordinate_transform_mode, \ + gsl::span input_shape, \ + gsl::span output_shape, \ + int64_t batch_size, int64_t num_channels, \ + std::tuple inferred_input_dims, \ + std::tuple inferred_output_dims, \ + std::tuple inferred_dim_rscales, \ + const TArray& output_div_pitches, \ + gsl::span roi_vals, \ + const std::optional& extrapolation_value, \ + bool exclude_outside, \ + TempSpaceAllocateFunc allocate_temp_space, \ + const uint8_t* clip8_lookups, \ + const T* input_data, \ + T* output_data, \ + const size_t N); + +SPECIALIZED_ANTIALIAS_IMPL(float) +SPECIALIZED_ANTIALIAS_IMPL(double) +SPECIALIZED_ANTIALIAS_IMPL(half) +SPECIALIZED_ANTIALIAS_IMPL(int32_t) +SPECIALIZED_ANTIALIAS_IMPL(uint8_t) + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu index 1a94c7705e913..0cde0ed8e8681 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu @@ -12,7 +12,7 @@ using onnxruntime::ResizeNearestMode; using onnxruntime::UpsampleMode; struct NearestPixel_SIMPLE { - __device__ __forceinline__ int operator() (float x_original, bool is_down_sampling) const { + __device__ __forceinline__ int operator()(float x_original, bool is_down_sampling) const { if (is_down_sampling) { return static_cast(_Ceil(x_original)); } @@ -21,7 +21,7 @@ struct NearestPixel_SIMPLE { }; struct NearestPixel_ROUND_PREFER_FLOOR { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { if (x_original == static_cast(x_original) + 0.5f) { return static_cast(_Floor(x_original)); } @@ -30,62 +30,23 @@ struct 
NearestPixel_ROUND_PREFER_FLOOR { }; struct NearestPixel_ROUND_PREFER_CEIL { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { return static_cast(roundf(x_original)); } }; struct NearestPixel_FLOOR { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { return static_cast(_Floor(x_original)); } }; struct NearestPixel_CEIL { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { return static_cast(_Ceil(x_original)); } }; -struct TransformCoordinate_ASYMMETRIC { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const { - return x_resized / x_scale; - } -}; - -struct TransformCoordinate_HALF_PIXEL { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const { - return ((x_resized + 0.5f) / x_scale) - 0.5f; - } -}; - -struct TransformCoordinate_PYTORCH_HALF_PIXEL { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float length_resized, float, float, float) const { - return length_resized > 1 ? (x_resized + 0.5f) / x_scale - 0.5f : 0.0f; - } -}; - -struct TransformCoordinate_TF_HALF_PIXEL_FOR_NN { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const { - return (x_resized + 0.5f) / x_scale; - } -}; - -struct TransformCoordinate_ALIGN_CORNERS { - __device__ __forceinline__ float operator() (float x_resized, float, float length_resized, float length_original, float, float) const { - return length_resized == 1 ? 0 : x_resized * (length_original - 1) / (length_resized - 1); - } -}; - -struct TransformCoordinate_TF_CROP_AND_RESIZE { - __device__ __forceinline__ float operator() (float x_resized, float, float length_resized, float length_original, float roi_start, float roi_end) const { - auto orig = length_resized > 1 - ? roi_start * (length_original - 1) + (x_resized * (roi_end - roi_start) * (length_original - 1)) / (length_resized - 1) - : 0.5 * (roi_start + roi_end) * (length_original - 1); - return static_cast(orig); - } -}; - #define CASE_TYPE_USING_HINT(enum_type, type, HINT, ...) \ case enum_type: { \ using HINT = type; \ @@ -95,20 +56,24 @@ struct TransformCoordinate_TF_CROP_AND_RESIZE { #define CASE_TYPE_COORD(enum_type, type, ...) \ CASE_TYPE_USING_HINT(enum_type, type, coord_t, __VA_ARGS__) -#define DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(TYPE, ...) 
\ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - switch (the_type) { \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::HALF_PIXEL, TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ASYMMETRIC, TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL, TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ALIGN_CORNERS, TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \ - default: \ - ORT_THROW("unknown ResizeCoordinateTransformationMode"); \ - } \ +#define DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(TYPE, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + switch (the_type) { \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::HALF_PIXEL, TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ASYMMETRIC, TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL, \ + TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ALIGN_CORNERS, \ + TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, \ + TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, \ + TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \ + default: \ + ORT_THROW("unknown ResizeCoordinateTransformationMode"); \ + } \ }() #define CASE_TYPE_NEAREST(enum_type, type, ...) 
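// Worked example of the transforms dispatched above, for an axis resized from
// length 4 to length 8 (x_scale = 2) at output index x_resized = 3:
//   ASYMMETRIC:            3 / 2                   = 1.5
//   HALF_PIXEL:            (3 + 0.5) / 2 - 0.5     = 1.25
//   PYTORCH_HALF_PIXEL:    same as HALF_PIXEL here = 1.25   (length_resized > 1)
//   TF_HALF_PIXEL_FOR_NN:  (3 + 0.5) / 2           = 1.75
//   ALIGN_CORNERS:         3 * (4 - 1) / (8 - 1)   = 9 / 7 ~= 1.2857
// TF_CROP_AND_RESIZE additionally mixes in the per-axis roi start/end values.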
\ @@ -119,11 +84,11 @@ struct TransformCoordinate_TF_CROP_AND_RESIZE { const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ switch (the_type) { \ - CASE_TYPE_NEAREST(ResizeNearestMode::SIMPLE, NearestPixel_SIMPLE, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::SIMPLE, NearestPixel_SIMPLE, __VA_ARGS__) \ CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_FLOOR, NearestPixel_ROUND_PREFER_FLOOR, __VA_ARGS__) \ - CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_CEIL, NearestPixel_ROUND_PREFER_CEIL, __VA_ARGS__) \ - CASE_TYPE_NEAREST(ResizeNearestMode::FLOOR, NearestPixel_FLOOR, __VA_ARGS__) \ - CASE_TYPE_NEAREST(ResizeNearestMode::CEIL, NearestPixel_CEIL, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_CEIL, NearestPixel_ROUND_PREFER_CEIL, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::FLOOR, NearestPixel_FLOOR, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::CEIL, NearestPixel_CEIL, __VA_ARGS__) \ default: \ ORT_THROW("unknown ResizeNearestMode"); \ } \ @@ -151,10 +116,12 @@ __global__ void _ResizeNearestMappingKernel2D( // only apply co-ordinate transformation if scale != 1.0 if (scales_height == 1.0f) { - dims_mapping[id].extrapolate_ = 0; + dims_mapping[id].extrapolate_ = 0; } else { - float orig_coord = transform_coordinate(static_cast(dim), scales_height, static_cast(output_height), - static_cast(input_height), roi_start_height, roi_end_height); + float orig_coord = transform_coordinate(static_cast(dim), scales_height, + static_cast(output_height), + static_cast(input_height), + roi_start_height, roi_end_height); dims_mapping[id].extrapolate_ = static_cast( extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast(input_height - 1))); dim = calc_nearest_pixel(orig_coord, scales_height < 1); @@ -210,9 +177,12 @@ __global__ void _ResizeNearestMappingKernel( if (scales[axis] == 1.0f) { dims_mapping[id].extrapolate_ = 0; } else { - float orig_coord = transform_coordinate(static_cast(dim), scales[axis], static_cast(output_shape[axis]), + float orig_coord = transform_coordinate(static_cast(dim), scales[axis], + static_cast(output_shape[axis]), static_cast(input_shape[axis]), roi[axis], roi[axis + rank]); - dims_mapping[id].extrapolate_ = static_cast(extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast(input_shape[axis] - 1))); + dims_mapping[id].extrapolate_ = static_cast(extrapolation_enabled && + (orig_coord < 0.f || + orig_coord > static_cast(input_shape[axis] - 1))); dim = calc_nearest_pixel(orig_coord, scales[axis] < 1); if (dim >= input_shape[axis]) dim = input_shape[axis] - 1; if (dim < 0) dim = 0; @@ -293,21 +263,27 @@ __global__ void _ResizeBilinearCoordinateMapping( LinearMappingInfo* dims_mapping) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, SumHW); if (id < output_height) { // y = id - float input_y = scale_height == 1 ? static_cast(id) : - transform_coordinate(static_cast(id), scale_height, - static_cast(output_height), static_cast(input_height), - roi_height_start, roi_height_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_y < 0 || input_y > static_cast(input_height - 1))); + float input_y = scale_height == 1 ? 
static_cast(id) + : transform_coordinate(static_cast(id), scale_height, + static_cast(output_height), + static_cast(input_height), + roi_height_start, roi_height_end); + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_y < 0 || + input_y > static_cast(input_height - 1)))); input_y = max(0.0f, min(input_y, static_cast(input_height - 1))); int y_int = static_cast(input_y); dims_mapping[id].origin_ = y_int; dims_mapping[id].weight_ = (y_int >= input_height - 1) ? 0.5f : input_y - y_int; - } else { //x = id - output_height - float input_x = scale_width == 1 ? static_cast(id - output_height) : - transform_coordinate(static_cast(id - output_height), scale_width, - static_cast(output_width), static_cast(input_width), - roi_width_start, roi_width_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || input_x > static_cast(input_width - 1))); + } else { // x = id - output_height + float input_x = scale_width == 1 ? static_cast(id - output_height) + : transform_coordinate(static_cast(id - output_height), + scale_width, static_cast(output_width), + static_cast(input_width), roi_width_start, + roi_width_end); + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_x < 0 || + input_x > static_cast(input_width - 1)))); input_x = max(0.0f, min(input_x, static_cast(input_width - 1))); int x_int = static_cast(input_x); dims_mapping[id].origin_ = x_int; @@ -371,32 +347,40 @@ __global__ void _ResizeTrilinearCoordinateMapping( LinearMappingInfo* dims_mapping) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, SumDHW); if (id < output_depth) { // z = id - float input_z = scale_depth == 1 ? static_cast(id) : - transform_coordinate(static_cast(id), scale_depth, - static_cast(output_depth), static_cast(input_depth), - roi_depth_start, roi_depth_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_z < 0 || input_z > static_cast(input_depth - 1))); + float input_z = scale_depth == 1 ? static_cast(id) + : transform_coordinate(static_cast(id), scale_depth, + static_cast(output_depth), + static_cast(input_depth), + roi_depth_start, roi_depth_end); + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_z < 0 || + input_z > static_cast(input_depth - 1)))); input_z = max(0.0f, min(input_z, static_cast(input_depth - 1))); int z_int = static_cast(input_z); dims_mapping[id].origin_ = z_int; dims_mapping[id].weight_ = (z_int >= input_depth - 1) ? 0.5f : input_z - z_int; } else if (id >= output_depth && id < (output_depth + output_height)) { // y = id - output_depth - float input_y = scale_height == 1 ? static_cast(id - output_depth) : - transform_coordinate(static_cast(id - output_depth), scale_height, - static_cast(output_height), static_cast(input_height), - roi_height_start, roi_height_end); - - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_y < 0 || input_y > static_cast(input_height - 1))); + float input_y = scale_height == 1 ? static_cast(id - output_depth) + : transform_coordinate(static_cast(id - output_depth), + scale_height, static_cast(output_height), + static_cast(input_height), + roi_height_start, roi_height_end); + + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_y < 0 || + input_y > static_cast(input_height - 1)))); input_y = max(0.0f, min(input_y, static_cast(input_height - 1))); int y_int = static_cast(input_y); dims_mapping[id].origin_ = y_int; dims_mapping[id].weight_ = (y_int >= input_height - 1) ? 
0.5f : input_y - y_int; - } else { //x = id - output_depth - output_height - float input_x = scale_width == 1 ? static_cast(id - output_depth - output_height) : - transform_coordinate(static_cast(id - output_depth - output_height), scale_width, - static_cast(output_width), static_cast(input_width), - roi_width_start, roi_width_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || input_x > static_cast(input_width - 1))); + } else { // x = id - output_depth - output_height + float input_x = scale_width == 1 ? static_cast(id - output_depth - output_height) + : transform_coordinate(static_cast(id - output_depth - output_height), + scale_width, static_cast(output_width), + static_cast(input_width), + roi_width_start, roi_width_end); + dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || + input_x > static_cast(input_width - 1))); input_x = max(0.0f, min(input_x, static_cast(input_width - 1))); int x_int = static_cast(input_x); dims_mapping[id].origin_ = x_int; @@ -513,21 +497,33 @@ __global__ void _ResizeCubicCoordinateMapping( int max_input_coord = static_cast(is_y_axis ? input_height : input_width); float scale = is_y_axis ? scale_height : scale_width; - float input_coordinat = scale == 1 ? (is_y_axis ? id : id - output_height) : - transform_coordinate( - static_cast(is_y_axis ? id : id - output_height), - scale, - static_cast(is_y_axis ? output_height : output_width), - static_cast(max_input_coord), - (is_y_axis ? roi_height_start : roi_width_start), - (is_y_axis ? roi_height_end : roi_width_end)); + float input_coordinat = scale == 1 ? (is_y_axis ? id : id - output_height) + : transform_coordinate( + static_cast(is_y_axis ? id : id - output_height), + scale, + static_cast(is_y_axis ? output_height : output_width), + static_cast(max_input_coord), + (is_y_axis ? roi_height_start : roi_width_start), + (is_y_axis ? roi_height_end : roi_width_end)); int coord_int = static_cast(_Floor(input_coordinat)); float s_coord = abs(input_coordinat - coord_int); float coeff_sum = 1.0f; - float coeff_0 = static_cast(((cubic_coeff_a * (s_coord + 1) - 5 * cubic_coeff_a) * (s_coord + 1) + 8 * cubic_coeff_a) * (s_coord + 1) - 4 * cubic_coeff_a); - float coeff_1 = static_cast(((cubic_coeff_a + 2) * s_coord - (cubic_coeff_a + 3)) * s_coord * s_coord + 1); - float coeff_2 = static_cast(((cubic_coeff_a + 2) * (1 - s_coord) - (cubic_coeff_a + 3)) * (1 - s_coord) * (1 - s_coord) + 1); - float coeff_3 = static_cast(((cubic_coeff_a * (2 - s_coord) - 5 * cubic_coeff_a) * (2 - s_coord) + 8 * cubic_coeff_a) * (2 - s_coord) - 4 * cubic_coeff_a); + float coeff_0 = static_cast(((cubic_coeff_a * (s_coord + 1) - 5 * cubic_coeff_a) * + (s_coord + 1) + + 8 * cubic_coeff_a) * + (s_coord + 1) - + 4 * cubic_coeff_a); + float coeff_1 = static_cast(((cubic_coeff_a + 2) * s_coord - (cubic_coeff_a + 3)) * + s_coord * s_coord + + 1); + float coeff_2 = static_cast(((cubic_coeff_a + 2) * (1 - s_coord) - (cubic_coeff_a + 3)) * + (1 - s_coord) * (1 - s_coord) + + 1); + float coeff_3 = static_cast(((cubic_coeff_a * (2 - s_coord) - 5 * cubic_coeff_a) * + (2 - s_coord) + + 8 * cubic_coeff_a) * + (2 - s_coord) - + 4 * cubic_coeff_a); if (exclude_outside) { coeff_0 = (coord_int - 1 < 0 || coord_int - 1 >= max_input_coord) ? 0.0 : coeff_0; coeff_1 = (coord_int + 0 < 0 || coord_int + 0 >= max_input_coord) ? 
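// Example of the cubic weights computed above, using the ONNX default
// cubic_coeff_a = -0.75 and a fractional offset s_coord = 0.5 with all four
// taps in range:
//   coeff_0 = coeff_3 = -0.09375,   coeff_1 = coeff_2 = 0.59375
// These already sum to 1, so the coeff_sum normalization below only changes the
// result when exclude_outside zeroes taps that fall outside the input.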
0.0 : coeff_1; @@ -540,7 +536,8 @@ __global__ void _ResizeCubicCoordinateMapping( dm.coeff1_ = coeff_1 / coeff_sum; dm.coeff2_ = coeff_2 / coeff_sum; dm.coeff3_ = coeff_3 / coeff_sum; - dm.extrapolate_ = (int)(extrapolation_enabled && (input_coordinat < 0 || input_coordinat > static_cast(max_input_coord - 1))); + dm.extrapolate_ = (int)(extrapolation_enabled && (input_coordinat < 0 || + input_coordinat > static_cast(max_input_coord - 1))); } template @@ -569,21 +566,30 @@ __global__ void _ResizeBiCubicKernel( int x_int = x_info.origin_; int y_int = y_info.origin_; const T* image = input_data + input_index; - output_data[id] = y_info.coeff0_ * CubicInterpolationRowwise(image, x_int, y_int - 1, input_height, input_width, w0, w1, w2, w3) + - y_info.coeff1_ * CubicInterpolationRowwise(image, x_int, y_int, input_height, input_width, w0, w1, w2, w3) + - y_info.coeff2_ * CubicInterpolationRowwise(image, x_int, y_int + 1, input_height, input_width, w0, w1, w2, w3) + - y_info.coeff3_ * CubicInterpolationRowwise(image, x_int, y_int + 2, input_height, input_width, w0, w1, w2, w3); + output_data[id] = y_info.coeff0_ * + CubicInterpolationRowwise(image, x_int, y_int - 1, input_height, input_width, w0, w1, w2, w3) + + y_info.coeff1_ * + CubicInterpolationRowwise(image, x_int, y_int, input_height, input_width, w0, w1, w2, w3) + + y_info.coeff2_ * + CubicInterpolationRowwise(image, x_int, y_int + 1, input_height, input_width, w0, w1, w2, w3) + + y_info.coeff3_ * + CubicInterpolationRowwise(image, x_int, y_int + 2, input_height, input_width, w0, w1, w2, w3); } size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode, const gsl::span& output_dims) { switch (upsample_mode) { case UpsampleMode::NN: - return sizeof(int64_t) * output_dims.size() + sizeof(NearestMappingInfo) * static_cast(std::accumulate(output_dims.begin(), output_dims.end(), (int64_t)0)); + return sizeof(int64_t) * output_dims.size() + + sizeof(NearestMappingInfo) * + static_cast(std::accumulate(output_dims.begin(), + output_dims.end(), (int64_t)0)); case UpsampleMode::LINEAR: - return sizeof(LinearMappingInfo) * static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); + return sizeof(LinearMappingInfo) * + static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); case UpsampleMode::CUBIC: - return sizeof(CubicMappingInfo) * static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); + return sizeof(CubicMappingInfo) * + static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); } return 0; } @@ -616,7 +622,8 @@ void ResizeNearestImpl( if (could2d) { int64_t output_height = output_shape[rank - 2]; int64_t output_width = output_shape[rank - 1]; - fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 3] : fast_divmod(static_cast(output_height * output_width)); + fast_divmod div_output_image = (rank > 2) ? 
output_div_pitches[rank - 3] + : fast_divmod(static_cast(output_height * output_width)); int blocksPerDimsMappingGrid = static_cast(ceil((output_height + output_width) / 32.0)); DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(transform_coordinate, [&]() { @@ -694,13 +701,6 @@ void ResizeImpl( ResizeCoordinateTransformationMode coordinate_transform_mode, ResizeNearestMode nearest_mode, void* dims_mapping) { - bool isSame = std::all_of(scales_vals.Data(), scales_vals.Data() + rank, [](float v) { return v == 1.0f; }) && - (coordinate_transform_mode != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE); - if (isSame) { - CUDA_CALL_THROW(cudaMemcpyAsync(output_data, input_data, N * sizeof(T), cudaMemcpyDeviceToDevice, stream)); - return; - } - if (upsample_mode == UpsampleMode::NN) { ResizeNearestImpl( stream, rank, input_shape, output_shape, input_strides, output_div_pitches, @@ -761,7 +761,7 @@ void ResizeImpl( } else if (is_3D) { DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(coordinate_transform_mode, [&]() { _ResizeTrilinearCoordinateMapping<<>>( - input_shape[rank - 3] , input_shape[rank - 2], input_shape[rank - 1], + input_shape[rank - 3], input_shape[rank - 2], input_shape[rank - 1], output_depth, output_height, output_width, scales_vals[rank - 3], scales_vals[rank - 2], scales_vals[rank - 1], roi_vals[rank - 3], roi_vals[rank - 3 + rank], @@ -778,7 +778,7 @@ void ResizeImpl( reinterpret_cast(dims_mapping)); return; } - ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize"); + ORT_THROW("Resize support 2-D and 3-D dimensions in LINEAR mode."); break; case UpsampleMode::CUBIC: if (is_2D) { @@ -801,7 +801,7 @@ void ResizeImpl( reinterpret_cast(dims_mapping)); return; } - ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize"); + ORT_THROW("Resize supports only 2-D in CUBIC mode."); case UpsampleMode::NN: ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize"); } @@ -809,7 +809,7 @@ void ResizeImpl( #define SPECIALIZED_IMPL(T) \ template void ResizeImpl( \ - cudaStream_t stream, \ + cudaStream_t stream, \ const UpsampleMode upsample_mode, \ const int rank, \ TArray& input_shape, \ diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.h b/onnxruntime/core/providers/cuda/tensor/resize_impl.h index d459dbff18d3e..ad06eebb9efb1 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.h @@ -2,15 +2,69 @@ // Licensed under the MIT License. #pragma once + #include + +#include + #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/common/common.h" #include "core/providers/cpu/tensor/upsamplebase.h" #include "core/providers/cuda/cuda_common.h" namespace onnxruntime { +template <> +struct AccumulateType { + using type = float; +}; namespace cuda { +struct TransformCoordinate_ASYMMETRIC { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, + float, float, float, float) const { + return x_resized / x_scale; + } +}; + +struct TransformCoordinate_HALF_PIXEL { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, + float, float, float, float) const { + return ((x_resized + 0.5f) / x_scale) - 0.5f; + } +}; + +struct TransformCoordinate_PYTORCH_HALF_PIXEL { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, float length_resized, float, + float, float) const { + return length_resized > 1 ? 
(x_resized + 0.5f) / x_scale - 0.5f : 0.0f; + } +}; + +struct TransformCoordinate_TF_HALF_PIXEL_FOR_NN { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, + float, float, float, float) const { + return (x_resized + 0.5f) / x_scale; + } +}; + +struct TransformCoordinate_ALIGN_CORNERS { + __device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized, + float length_original, float, float) const { + return length_resized == 1 ? 0 : x_resized * (length_original - 1) / (length_resized - 1); + } +}; + +struct TransformCoordinate_TF_CROP_AND_RESIZE { + __device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized, + float length_original, float roi_start, float roi_end) const { + auto orig = length_resized > 1 + ? roi_start * (length_original - 1) + + (x_resized * (roi_end - roi_start) * (length_original - 1)) / (length_resized - 1) + : 0.5 * (roi_start + roi_end) * (length_original - 1); + return static_cast(orig); + } +}; + size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode, const gsl::span& output_dims); @@ -36,5 +90,62 @@ void ResizeImpl( onnxruntime::ResizeNearestMode nearest_mode, void* dims_mapping); +using TempSpaceAllocateFunc = std::function(size_t buffer_size)>; + +template +void ResizeAntiAliasImpl( + cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, // CPU + const std::optional& extrapolation_value, + bool exclude_outside, + TempSpaceAllocateFunc allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N); + +/// +/// Compute scaled support value for a given dimension inverse scale +/// +/// Support value from parameters +/// inverse scale value comes from input/attr for +/// +inline float ComputeScaledSupportValue(float support_value, float rscale) { + const float scale = 1.0f / rscale; + float scaled_support = (scale >= 1.0f) ? (support_value * 0.5f) * scale : support_value * 0.5f; + return scaled_support; +} + +/// +/// Compute window size for a given dimension scaled support value. +/// +/// +/// +inline int32_t ComputeWindowSize(float scaled_support) { + SafeInt window_size(ceilf(scaled_support)); + return window_size * 2 + 1; +} + +/// +/// Computes scale buffer size in number of elements for allocation purposes. +/// +/// +/// +/// Number of elements to fit in the buffer +inline SafeInt ComputeWeightedCoeffBufferSize(int64_t output_size, int32_t window_size) { + SafeInt buffer_size(output_size); + return buffer_size * window_size; +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index ae12ca328bc7c..17533eb3d9a72 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -2,6 +2,9 @@ // Licensed under the MIT License. 
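// A standalone host-side sketch (not part of this patch) of how the antialias
// helpers declared in resize_impl.h size the filter. The support values are
// assumptions here (2.0 for the linear filter, 4.0 for the cubic filter;
// kBiLinearSupportSize / kBiCubicSupportSize are not shown in this hunk), and
// `rscale` is treated as the output/input scale factor so the support widens
// when downsampling.
#include <cmath>
#include <cstdint>
#include <cstdio>

static float ScaledSupport(float support_value, float rscale) {
  const float scale = 1.0f / rscale;  // input/output ratio
  return (scale >= 1.0f) ? 0.5f * support_value * scale : 0.5f * support_value;
}

static int32_t WindowSize(float scaled_support) {
  return static_cast<int32_t>(std::ceil(scaled_support)) * 2 + 1;
}

static int64_t WeightedCoeffBufferSize(int64_t output_size, int32_t window_size) {
  return output_size * window_size;  // one row of coefficients per output element
}

int main() {
  // Halving the width of an 8-wide image with the linear filter (rscale = 0.5):
  const float support = ScaledSupport(2.0f, 0.5f);            // 0.5 * 2.0 * 2.0 = 2.0
  const int32_t window = WindowSize(support);                 // ceil(2.0) * 2 + 1 = 5
  const int64_t coeffs = WeightedCoeffBufferSize(4, window);  // 4 outputs * 5 taps = 20
  std::printf("support=%.1f window=%d coeff_buffer=%lld\n",
              support, window, static_cast<long long>(coeffs));
  return 0;
}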
#include "upsample.h" + +#include + #include "upsample_impl.h" #include "core/providers/cuda/tensor/resize_impl.h" #include "core/providers/cpu/tensor/utils.h" @@ -37,11 +40,23 @@ REGISTER_VERSIONED_TYPED_KERNEL(MLFloat16, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9); +template +Upsample::Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { + if (UpsampleBase::antialias_) { + // Copy the table on DEVICE + const uint8_t* lookup_table = GetLookupTableShared(); + auto alloc = info.GetAllocator(OrtMemTypeDefault); + shared_lookup_table_ondevice_ = IAllocator::MakeUniquePtr(std::move(alloc), kLookupTableSize); + CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice_.get(), lookup_table, kLookupTableSize, + cudaMemcpyHostToDevice, nullptr)); + } +} + template Status Upsample::BaseCompute(OpKernelContext* context, - const std::vector& roi, - const std::vector& scales, - const gsl::span& output_dims) const { + gsl::span roi, + gsl::span scales, + gsl::span output_dims) const { const Tensor* X = context->Input(0); auto X_dims = X->Shape().GetDims(); int32_t rank = static_cast(X_dims.size()); @@ -52,7 +67,8 @@ Status Upsample::BaseCompute(OpKernelContext* context, is_resize_ ? "Resize: input tensor cannot be scalar." : "Upsample: input tensor cannot be scalar."); if (rank != static_cast(scales.size())) return Status(ONNXRUNTIME, INVALID_ARGUMENT, - is_resize_ ? "Resize: input tensor's dimension does not match the scales." : "Upsample: input tensor's dimension does not match the scales."); + is_resize_ ? "Resize: input tensor's dimension does not match the scales." + : "Upsample: input tensor's dimension does not match the scales."); if (roi.size() != 2 * X_dims.size()) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Resize: size of roi array should be 2 * N where N is the rank of input tensor X."); @@ -79,22 +95,194 @@ Status Upsample::BaseCompute(OpKernelContext* context, size_t output_count = Y->Shape().Size(); if (is_resize_) { - TArray input_shape(X_dims); - TArray output_shape(output_dims); - TArray roi_vals(roi); - TArray scales_vals(scales); - - size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims); - auto dims_mapping_buffer = GetScratchBuffer(temp_buffer_size, context->GetComputeStream()); - void* dims_mapping = reinterpret_cast(dims_mapping_buffer.get()); - ResizeImpl(Stream(context), mode_, (int)rank, input_shape, output_shape, - input_strides, output_div_pitches, scales_vals, roi_vals, - reinterpret_cast(X->Data()), - reinterpret_cast(Y->MutableData()), - output_count, use_extrapolation_, ToCudaType::FromFloat(extrapolation_value_), - cubic_coeff_a_, exclude_outside_, - coordinate_transform_mode_, nearest_mode_, - dims_mapping); + const bool is_same = std::all_of(scales.begin(), scales.end(), [](float v) { return v == 1.0f; }) && + (coordinate_transform_mode_ != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE); + if (is_same) { + CUDA_CALL_THROW(cudaMemcpyAsync(Y->MutableData(), X->Data(), + output_count * sizeof(T), cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); + } + + if (antialias_) { + TempSpaceAllocateFunc allocate_temp_space = [&](size_t bytes_size) { + return GetScratchBuffer(bytes_size, context->GetComputeStream()); + }; + + std::optional extrapolation_value; + if (use_extrapolation_) + extrapolation_value.emplace(extrapolation_value_); + + switch (mode_) { + case UpsampleMode::LINEAR: { + if (X_dims.size() == 2 || X_dims.size() == 4) { + const bool is_2D 
= X_dims.size() == 2; + + int64_t batch_size = 1; + int64_t num_channels = 1; + + int64_t input_height; + int64_t input_width; + + int64_t output_height; + int64_t output_width; + + float height_scale; + float width_scale; + + if (is_2D) { + input_height = X_dims[0]; + input_width = X_dims[1]; + + output_height = output_dims[0]; + output_width = output_dims[1]; + + height_scale = scales[0]; + width_scale = scales[1]; + } else { + if (scales[0] == 1.0f && scales[1] == 1.0f) { + batch_size = X_dims[Channels::N]; + num_channels = X_dims[Channels::C]; + input_height = X_dims[Channels::H]; + input_width = X_dims[Channels::W]; + + output_height = output_dims[Channels::H]; + output_width = output_dims[Channels::W]; + + height_scale = scales[2]; + width_scale = scales[3]; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", ": NHWC is not supported yet"); + } + } + + ResizeAntiAliasImpl(Stream(context), + rank, + mode_, + coordinate_transform_mode_, + X_dims, output_dims, + batch_size, num_channels, + std::make_tuple(0, input_height, input_width), + std::make_tuple(0, output_height, output_width), + std::make_tuple(0.f, height_scale, width_scale), + output_div_pitches, + roi, + extrapolation_value, + exclude_outside_, + allocate_temp_space, + shared_lookup_table_ondevice_.get(), + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); + + } else if (X_dims.size() == 3 || X_dims.size() == 5) { + const bool is_3D = X_dims.size() == 3; + + if (!is_3D) { + if (!(scales[0] == 1.0f && scales[1] == 1.0f)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", ": NDHWC is not supported yet"); + } + } + + const int64_t batch_size = is_3D ? 1 : X_dims[0]; + const int64_t num_channels = is_3D ? 1 : X_dims[1]; + const int64_t input_depth = is_3D ? X_dims[0] : X_dims[2]; + const int64_t input_height = is_3D ? X_dims[1] : X_dims[3]; + const int64_t input_width = is_3D ? X_dims[2] : X_dims[4]; + + const int64_t output_depth = is_3D ? output_dims[0] : output_dims[2]; + const int64_t output_height = is_3D ? output_dims[1] : output_dims[3]; + const int64_t output_width = is_3D ? output_dims[2] : output_dims[4]; + + const float depth_scale = is_3D ? scales[0] : scales[2]; + const float height_scale = is_3D ? scales[1] : scales[3]; + const float width_scale = is_3D ? scales[2] : scales[4]; + + ResizeAntiAliasImpl(Stream(context), + rank, + mode_, + coordinate_transform_mode_, + X_dims, output_dims, + batch_size, num_channels, + std::make_tuple(input_depth, input_height, input_width), + std::make_tuple(output_depth, output_height, output_width), + std::make_tuple(depth_scale, height_scale, width_scale), + output_div_pitches, + roi, + extrapolation_value, + exclude_outside_, + allocate_temp_space, + shared_lookup_table_ondevice_.get(), + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", + ": 'Linear' mode only support 2-D inputs or 3-D inputs ('Bilinear', 'Trilinear') " + "or 4-D inputs or 5-D inputs with the corresponding outermost 2 scale values " + "being 1."); + } + } break; + case UpsampleMode::CUBIC: { + if (X_dims.size() != 2 && X_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", + ": 'Cubic' mode only support 2-D inputs ('Bicubic') or 4-D inputs " + "with the corresponding outermost 2 scale values being 1."); + } + + const bool is_2D = X_dims.size() == 2; + const bool is_nchw = is_2D ? 
true : (scales[1] == 1.0f && scales[1] == 1.0f); + + ORT_RETURN_IF_NOT(is_nchw, + "Resize 'Cubic' mode only supports NCWH layout " + " with 2-D or 4-D with leading dims equal to 1"); + + const int64_t batch_size = is_2D ? 1 : X_dims[Channels::N]; + const int64_t num_channels = is_2D ? 1 : X_dims[Channels::C]; + const int64_t input_height = is_2D ? X_dims[0] : X_dims[Channels::H]; + const int64_t input_width = is_2D ? X_dims[1] : X_dims[Channels::W]; + + const int64_t output_height = is_2D ? output_dims[0] : output_dims[Channels::H]; + const int64_t output_width = is_2D ? output_dims[1] : output_dims[Channels::W]; + const float height_scale = is_2D ? scales[0] : scales[2]; + const float width_scale = is_2D ? scales[1] : scales[3]; + + ResizeAntiAliasImpl(Stream(context), rank, mode_, coordinate_transform_mode_, + X_dims, output_dims, + batch_size, num_channels, + std::make_tuple(0, input_height, input_width), + std::make_tuple(0, output_height, output_width), + std::make_tuple(0.f, height_scale, width_scale), + output_div_pitches, + roi, + extrapolation_value, + exclude_outside_, + allocate_temp_space, + shared_lookup_table_ondevice_.get(), + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); + } break; + default: + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Resize: unexpected mode"); + } + } else { + TArray input_shape(X_dims); + TArray output_shape(output_dims); + TArray roi_vals(roi); + TArray scales_vals(scales); + + size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims); + auto dims_mapping_buffer = GetScratchBuffer(temp_buffer_size, context->GetComputeStream()); + void* dims_mapping = reinterpret_cast(dims_mapping_buffer.get()); + ResizeImpl(Stream(context), mode_, rank, input_shape, output_shape, + input_strides, output_div_pitches, scales_vals, roi_vals, + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count, use_extrapolation_, ToCudaType::FromFloat(extrapolation_value_), + cubic_coeff_a_, exclude_outside_, + coordinate_transform_mode_, nearest_mode_, + dims_mapping); + } } else { TArray scales_div(rank); @@ -124,7 +312,7 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { auto input_dims = X->Shape().GetDims(); TensorShapeVector output_dims(input_dims.size()); - std::vector roi_array(input_dims.size() * 2, 0.0f); + InlinedVector roi_array(input_dims.size() * 2, 0.0f); if (!roi_cached_) { bool use_default_roi = true; if (need_roi_input_) { @@ -147,29 +335,37 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { } } - const std::vector& roi = roi_cached_ ? 
roi_ : roi_array; - std::vector scales_array = scales_; + ComputeROIWithAxes(roi_array, input_dims.size()); + InlinedVector scales_array(input_dims.size()); + // opset < 10 if (OpKernel::Node().InputDefs().size() == 1) { - // Compute output shape from scales and input dims + // Compute output shape from scales attributes and input dims + scales_array = scales_; + ComputeOutputShape(scales_array, input_dims, output_dims); - return BaseCompute(context, roi, scales_, output_dims); + return BaseCompute(context, roi_array, scales_, output_dims); } const Tensor* scales = context->Input(scales_input_idx_); const Tensor* sizes = context->Input(sizes_input_idx_); + // This is when scales are obtained and cached from a constant initializer if (scales_cached_) { - ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input."); + ORT_RETURN_IF_NOT(sizes == nullptr, "Only one of scales or sizes must be provided as input."); + scales_array = scales_; + // Compute output shape from scales and input dims ComputeOutputShape(scales_array, input_dims, output_dims); - return BaseCompute(context, roi, scales_, output_dims); + return BaseCompute(context, roi_array, scales_array, output_dims); } - scales_array.resize((input_dims.size())); + // Scales and sizes are input to the node if (scales != nullptr && scales->Shape().Size() != 0) { // use scales input data ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input."); ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, input_dims.size())); + + // Compute output shape from scales and input dims ComputeOutputShape(scales_array, input_dims, output_dims); } else { // When sizes input is available directly populate it into the output_dims array. @@ -179,7 +375,7 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(ParseScalesDataAndAdjustOutputSize(output_dims, input_dims, scales_array)); } - return BaseCompute(context, roi, scales_array, output_dims); + return BaseCompute(context, roi_array, scales_array, output_dims); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.h b/onnxruntime/core/providers/cuda/tensor/upsample.h index 7bf2a23ede399..50597e0fba1b9 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.h +++ b/onnxruntime/core/providers/cuda/tensor/upsample.h @@ -13,12 +13,14 @@ namespace cuda { template class Upsample : public UpsampleBase, public CudaKernel { public: - Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { - } + explicit Upsample(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; - Status BaseCompute(OpKernelContext* context, const std::vector& roi, const std::vector& scales, - const gsl::span& output_dims) const; + Status BaseCompute(OpKernelContext* context, gsl::span roi, gsl::span scales, + gsl::span output_dims) const; + + private: + IAllocatorUniquePtr shared_lookup_table_ondevice_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 3fd5423681b81..0265c06b9a938 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1145,11 +1145,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Flatten); @@ -1304,6 +1304,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, uint8_t, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); // Opset 19 @@ -2081,11 +2086,16 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2240,6 +2250,16 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // Opset 19 diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc 
b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index da17135878fe5..7b73ab36b3742 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -24,6 +24,7 @@ #include "core/providers/cpu/tensor/size.h" #include "core/providers/cpu/tensor/scatter_nd.h" #include "core/providers/cpu/tensor/unsqueeze.h" +#include "core/providers/cpu/tensor/upsamplebase.h" #include "core/providers/cpu/tensor/tile.h" #ifndef DISABLE_CONTRIB_OPS @@ -572,6 +573,11 @@ std::unique_ptr> EinsumTypedComputeProcessor template <> std::unique_ptr> EinsumTypedComputeProcessor::Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) { return g_host_cpu.EinsumTypedComputeProcessor_MLFloat16__Create(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); } +void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, + InlinedVector& scales) const { + g_host_cpu.UpsampleBase__AdjustOutputSizeAsPolicy(this, output_dims, input_dims, scales); +} + #ifndef DISABLE_CONTRIB_OPS namespace contrib { Status embed_layer_norm::CheckInputs(const OpKernelContext* context, bool quantizedVersion) { @@ -648,7 +654,6 @@ Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, c const SessionState& subgraph_session_state) { return g_host_cpu.Sampling__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state); } - } // namespace transformers #ifdef ENABLE_ATEN diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index 0c9e2e9fc17a2..09666c8039402 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -288,7 +288,7 @@ Status Resize::Compute(OpKernelContext* ctx) const { // Get scales data const auto* scales = ctx->Input(scales_input_idx_); - std::vector scales_array(X->Shape().GetDims().size()); + InlinedVector scales_array(X->Shape().GetDims().size()); if (scales != nullptr && scales->Shape().Size() != 0) { ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, output_shape.size())); diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 10f02349a24d5..1d31f3fdb4eb4 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -11,7 +11,8 @@ namespace test { TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_tf_crop_and_resize) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.20000028610229492, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] " + << "is 0.20000028610229492, which exceeds threshold"; } OpTester test("Resize", 13); @@ -32,7 +33,8 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_tf_crop_and_resize) { test.AddInput("X", {H, W}, X); test.AddInput("roi", {4}, roi); - test.AddInput("", {0}, scales); // opset13 requires either 'sizes' or 'scales' must be provided, but not both of them + // opset13 requires either 'sizes' or 'scales' must be provided, but not both of them + 
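// The empty-named, zero-sized tensor added below stands in for the omitted
// optional 'scales' input so that 'sizes' keeps its positional slot; Resize-13
// then derives the output shape from 'sizes' alone.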
test.AddInput("", {0}, scales); test.AddInput("sizes", {2}, sizes); std::vector Y = {7.600004f, 7.9f, 8.2f, @@ -188,7 +190,9 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch // DML: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8) { @@ -317,7 +321,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { // The output size is [1,1,2,4].*[1,1,0.6,0.6]=[1,1,1,2] // NNAPI will recaluclate the scales as the output size divided by input size // scales = [1,1,1,2]./[1,1,2,4] = [1,1,0.5,0.5] -// See, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/reference/reference_ops.h +// See:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/reference/reference_ops.h // So the result of the above example will be different than CPU EP // Add the following 2 tests to test with scales valid to NNAPI TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1) { @@ -475,7 +479,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_int TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_2DBilinear_pytorch_half_pixel) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << " The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold"; } OpTester test("Resize", 13); @@ -533,7 +538,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch // DML: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) { @@ -721,7 +727,8 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_2DBilinear_align_corners) { TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_3DTrilinear_pytorch_half_pixel) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold"; } OpTester test("Resize", 13); @@ -1088,7 +1095,8 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Floor_Align_Corners) { TEST(ResizeOpTest, ResizeOpNearest_OneToOneMappingBetweenInputAndOutputDataDims) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference 
between expected[i] and output[i] is 3, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 3, which exceeds threshold"; } OpTester test("Resize", 12); // tf_half_pixel_for_nn is deprecated since opset 13 @@ -1480,7 +1488,8 @@ TEST(ResizeOpTest, ResizeOpCubicUpSampleTest_tf_half_pixel_for_nn) { TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_Ver10) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold"; } OpTester test("Resize", 10); @@ -1505,7 +1514,8 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_Ver10) { TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_2DBilinear_Ver10) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold "; } OpTester test("Resize", 10); @@ -1530,7 +1540,8 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_2DBilinear_Ver10) { TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_Ver10) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.5, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 0.5, which exceeds threshold"; } OpTester test("Resize", 10); @@ -1565,7 +1576,8 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_Ver10) { TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_2DBilinear_Ver10) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.5, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 0.5, which exceeds threshold"; } OpTester test("Resize", 10); @@ -1676,7 +1688,8 @@ TEST(UpsampleOpTest, ResizeOpNearestNoScaleTest_Ver10) { TEST(ResizeOpTest, ResizeOp_MissingRoiAndMissingScalesOptionalInputs) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1876): The parameter is incorrect."; + GTEST_SKIP() << "Skipping because of the following error: " + << "MLOperatorAuthorImpl.cpp(1876): The parameter is incorrect."; } OpTester test("Resize", 13); @@ -1827,7 +1840,8 @@ template void TestAntialiasing(std::map attributes, std::vector input_shape, std::vector input_data, - std::vector output_shape_or_scale, std::vector output_data) { + std::vector output_shape_or_scale, std::vector output_data, + gsl::span excluded_ep = {}) { auto parse_attr = [](const std::string& str, auto typed_v) { using Tdata = decltype(typed_v); 
std::vector vect; @@ -1891,13 +1905,22 @@ void TestAntialiasing(std::map attributes, } test.AddOutput("Y", output_shape, output_data); - // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accurarcy issue. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + + std::unordered_set excluded_eps; + std::transform(excluded_ep.begin(), excluded_ep.end(), + std::inserter(excluded_eps, excluded_eps.end()), [](std::string_view ep) { + return std::string(ep); + }); + // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accuracy issue. + excluded_eps.insert(kTensorrtExecutionProvider); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_eps); } TEST(ResizeOpTest, Antialias_Bilinear_No_ExcludeOutside) { if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + GTEST_SKIP() << "Skipping because dml implementation of antialias " + << "is slightly different and doesn't match in all cases."; } std::vector X(16); std::iota(X.begin(), X.end(), 1.f); @@ -1939,7 +1962,8 @@ TEST(ResizeOpTest, Antialias_Bilinear_dtype) { std::vector Y = {1, 3, 4, 6, 8, 9, 11, 13, 14}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y); + InlinedVector excluded_eps = {kCudaExecutionProvider}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y, excluded_eps); } { std::vector X(16); @@ -1982,17 +2006,21 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear) { 33.5f, 73.5f, 113.5f, 35.074074f, 75.07407f, 115.07407f, 36.590908f, 76.59091f, 116.59091f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 5, 8, 3}, X, {1, 4, 5, 3}, Y); + + // Nchw is not supported by CUDA Resize implementation + InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 5, 8, 3}, X, {1, 4, 5, 3}, Y, excluded_eps); } TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { + InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; { std::vector X(16); std::iota(X.begin(), X.end(), uint8_t(0)); std::vector Y = {1, 3, 4, 6, 8, 9, 11, 13, 14}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y, excluded_eps); } { std::vector X(16); @@ -2000,7 +2028,7 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { std::vector Y = {1, 3, 4, 6, 8, 9, 11, 13, 14}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y, excluded_eps); } { std::vector X(16); @@ -2008,13 +2036,14 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { std::vector Y = {1, 3, 4, 6, 8, 9, 11, 13, 14}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y, excluded_eps); } } TEST(ResizeOpTest, Antialias_Trilinear_No_ExcludeOutside) { if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all 
cases."; + GTEST_SKIP() << "Skipping because dml implementation of " + << "antialias is slightly different and doesn't match in all cases."; } std::vector X(16 * 4); std::iota(X.begin(), X.end(), 0.f); @@ -2038,13 +2067,17 @@ TEST(ResizeOpTest, Antialias_Trilinear_ExcludeOutside) { TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) { if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + GTEST_SKIP() << "Skipping because dml implementation of antialias" + << " is slightly different and doesn't match in all cases."; } + + InlinedVector excluded_eps = {kCudaExecutionProvider}; std::vector X(16 * 4 * 4); std::iota(X.begin(), X.end(), 0.f); { std::vector Y = X; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 4}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 4}, Y, + excluded_eps); } { std::vector Y = {0.625f, 2.375f, 4.625f, 6.375f, 8.625f, 10.375f, 12.625f, @@ -2066,7 +2099,8 @@ TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) { 224.625f, 226.375f, 228.625f, 230.375f, 232.625f, 234.375f, 236.625f, 238.375f, 240.625f, 242.375f, 244.625f, 246.375f, 248.625f, 250.375f, 252.625f, 254.375f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 2}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 2}, Y, + excluded_eps); } { std::vector Y = {2.5f, 3.5f, 4.5f, 5.5f, 9.5f, 10.5f, 11.5f, 12.5f, 18.5f, @@ -2084,7 +2118,8 @@ TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) { 217.5f, 218.5f, 219.5f, 220.5f, 226.5f, 227.5f, 228.5f, 229.5f, 233.5f, 234.5f, 235.5f, 236.5f, 242.5f, 243.5f, 244.5f, 245.5f, 249.5f, 250.5f, 251.5f, 252.5f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 2, 4}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 2, 4}, Y, + excluded_eps); } } @@ -2124,12 +2159,15 @@ TEST(ResizeOpTest, Antialias_NHWCBicubic_ExcludeOutside) { 19.576872f, 43.57687f, 21.126253f, 45.126255f, 22.606192f, 46.606194f, 19.878183f, 43.87818f, 21.358122f, 45.35812f, 22.907503f, 46.907505f, 24.387442f, 48.387444f}; - TestAntialiasing({{"mode", "cubic"}, {"exclude_outside", "0"}}, {1, 4, 6, 2}, X, {1, 8, 4, 2}, Y); + + InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; + TestAntialiasing({{"mode", "cubic"}, {"exclude_outside", "0"}}, {1, 4, 6, 2}, X, {1, 8, 4, 2}, Y, excluded_eps); } TEST(ResizeOpTest, Antialias_Linear_AlignCorners) { if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + GTEST_SKIP() << "Skipping because dml implementation of antialias" + << "is slightly different and doesn't match in all cases."; } std::vector X(256); std::iota(X.begin(), X.end(), 0.0f); @@ -2145,9 +2183,40 @@ TEST(ResizeOpTest, Antialias_Linear_AlignCorners) { 187.08333f, 195.91667f, 198.41667f, 205.91667f, 208.41667f, 217.25f, 219.75f, 227.25f, 229.75f, 238.58333f, 241.08333f, 248.58333f, 251.08333f}; + InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; TestAntialiasing( {{"mode", "linear"}, {"exclude_outside", "0"}, {"coordinate_transformation_mode", "align_corners"}}, - {4, 1, 4, 
4, 4}, X, {4, 1, 3, 2, 2}, Y); + {4, 1, 4, 4, 4}, X, {4, 1, 3, 2, 2}, Y, excluded_eps); +} + +TEST(ResizeOpTest, Antialias_Linear_AlignCorners_3D) { + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly " + << "different and doesn't match in all cases."; + } + std::vector X(256); + std::iota(X.begin(), X.end(), 0.0f); + std::vector Y{ + 1.25f, 3.75f, 11.25f, 13.75f, + 17.25f, 19.75f, 27.25f, 29.75f, + 33.25f, 35.75f, 43.25f, 45.75f, + 49.25f, 51.75f, 59.25f, 61.75f, + 65.25f, 67.75f, 75.25f, 77.75f, + 81.25f, 83.75f, 91.25f, 93.75f, + 97.25f, 99.75f, 107.25f, 109.75f, + 113.25f, 115.75f, 123.25f, 125.75f, + 129.25f, 131.75f, 139.25f, 141.75f, + 145.25f, 147.75f, 155.25f, 157.75f, + 161.25f, 163.75f, 171.25f, 173.75f, + 177.25f, 179.75f, 187.25f, 189.75f, + 193.25f, 195.75f, 203.25f, 205.75f, + 209.25f, 211.75f, 219.25f, 221.75f, + 225.25f, 227.75f, 235.25f, 237.75f, + 241.25f, 243.75f, 251.25f, 253.75f}; + + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "0"}, {"coordinate_transformation_mode", "align_corners"}}, + {16, 4, 4}, X, {16, 2, 2}, Y); } TEST(ResizeOpTest, Antialias_Bicubic_ExcludeOutside) { @@ -2166,19 +2235,23 @@ TEST(ResizeOpTest, Antialias_Bicubic_Dtype) { std::vector X(36); std::iota(X.begin(), X.end(), uint8_t(0)); std::vector Y = {4, 6, 7, 16, 18, 19, 28, 30, 31}; - TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y); + TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, + X, {1, 1, 3, 3}, Y); } { std::vector X(36); std::iota(X.begin(), X.end(), int8_t(0)); std::vector Y = {4, 6, 7, 16, 18, 19, 28, 30, 31}; - TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y); + InlinedVector excluded_eps = {kCudaExecutionProvider}; + TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, + X, {1, 1, 3, 3}, Y, excluded_eps); } { std::vector X(36); std::iota(X.begin(), X.end(), 0); std::vector Y = {4, 6, 7, 16, 18, 19, 28, 30, 31}; - TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y); + TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, + X, {1, 1, 3, 3}, Y); } } @@ -2189,8 +2262,10 @@ TEST(ResizeOpTest, Antialias_Axes_and_Scale) { std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, {1, 1, 4, 4, 4}, X, - std::vector{3 / 4.0f, 3 / 4.0f, 3 / 4.0f}, Y); + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, + {1, 1, 4, 4, 4}, X, + std::vector{3 / 4.0f, 3 / 4.0f, 3 / 4.0f}, Y); } TEST(ResizeOpTest, Antialias_Axes_and_Size) { @@ -2199,8 +2274,10 @@ TEST(ResizeOpTest, Antialias_Axes_and_Size) { std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, 
{"output_shape", "{1,1,3,3,3}"}}, {1, 1, 4, 4, 4}, X, - {3, 3, 3}, Y); + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, + {1, 1, 4, 4, 4}, X, + {3, 3, 3}, Y); } TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoLarger) { @@ -2209,9 +2286,13 @@ TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoLarger) { std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, {"policy", "not_larger"}}, - {1, 1, 4, 4, 4}, X, - {3, 4, 5}, Y); + // clang-format off + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, + {"policy", "not_larger"}}, + {1, 1, 4, 4, 4}, X, + {3, 4, 5}, Y); + // clang-format on } TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoSmaller) { @@ -2220,9 +2301,13 @@ TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoSmaller) { std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, {"policy", "not_smaller"}}, - {1, 1, 4, 4, 4}, X, - {1, 2, 3}, Y); + // clang-format off + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, + {"policy", "not_smaller"}}, + {1, 1, 4, 4, 4}, X, + {1, 2, 3}, Y); + // clang-format on } TEST(ResizeOpTest, Antialias_Use_Extrapolation) { From 2a857d9a86ca3049829256df3347521069ccd6b4 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 1 Mar 2024 10:23:29 +1000 Subject: [PATCH 091/237] Add ML Program support for more operators (#19527) ### Description Add support for: - Clip/Relu/Relu6 - Add/Mul/Div/Sub/Pow - GlobalAveragePool/GlobalMaxPool/AveragePool/MaxPool - Reshape - Gemm/MatMul Fix some build issues/warnings from changes. Fix a couple of potential issues with the Resize op as well (noticed due to change to reject inputs with empty data at a higher level). 
### Motivation and Context Enable mobilenetv2 with ML Program --- cmake/onnxruntime_providers_coreml.cmake | 2 +- .../providers/coreml/builders/coreml_spec.h | 7 +- .../core/providers/coreml/builders/helper.cc | 14 +- .../coreml/builders/impl/base_op_builder.cc | 13 +- .../coreml/builders/impl/base_op_builder.h | 6 +- .../coreml/builders/impl/binary_op_builder.cc | 113 +++--- .../coreml/builders/impl/builder_utils.cc | 68 ++++ .../coreml/builders/impl/builder_utils.h | 17 +- .../coreml/builders/impl/clip_op_builder.cc | 187 ++++++--- .../coreml/builders/impl/conv_op_builder.cc | 94 +---- .../coreml/builders/impl/gemm_op_builder.cc | 332 +++++++++++----- .../coreml/builders/impl/pool_op_builder.cc | 218 +++++++---- .../builders/impl/reshape_op_builder.cc | 70 ++-- .../coreml/builders/impl/resize_op_builder.cc | 16 +- .../coreml/builders/impl/slice_op_builder.cc | 2 +- .../builders/impl/softmax_op_builder.cc | 4 +- .../coreml/builders/model_builder.cc | 366 +++++++++++++----- .../providers/coreml/builders/model_builder.h | 63 +-- .../coreml/coreml_execution_provider.cc | 82 ++-- .../providers/coreml/dump_mlprogram_model.py | 27 ++ .../core/providers/coreml/model/host_utils.h | 6 + .../core/providers/coreml/model/host_utils.mm | 10 + .../core/providers/coreml/model/model.h | 19 +- .../core/providers/coreml/model/model.mm | 13 + .../core/providers/coreml/model/model_stub.cc | 4 + .../providers/cpu/tensor/reshape_helper.h | 6 +- .../test/perftest/command_args_parser.cc | 25 +- onnxruntime/test/perftest/ort_test_session.cc | 30 +- .../providers/coreml/coreml_basic_test.cc | 20 + .../test/providers/cpu/math/clip_test.cc | 27 +- .../test/providers/cpu/math/gemm_test.cc | 37 +- .../providers/cpu/nn/batch_norm_op_test.cc | 37 ++ .../providers/cpu/tensor/resize_op_test.cc | 4 +- 33 files changed, 1344 insertions(+), 595 deletions(-) create mode 100644 onnxruntime/core/providers/coreml/dump_mlprogram_model.py diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake index c9f35e5337f9b..8f3b1828e1c61 100644 --- a/cmake/onnxruntime_providers_coreml.cmake +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -111,7 +111,7 @@ if(_enable_ML_PROGRAM) file(GLOB onnxruntime_providers_coreml_modelpackage_cc_srcs CONFIGURE_DEPENDS "${coremltools_SOURCE_DIR}/modelpackage/src/ModelPackage.?pp" - "${coremltools_SOURCE_DIR}/modelpackage/src/Utils/JsonMap.?pp" + "${coremltools_SOURCE_DIR}/modelpackage/src/utils/JsonMap.?pp" ) set(coremltools_srcs diff --git a/onnxruntime/core/providers/coreml/builders/coreml_spec.h b/onnxruntime/core/providers/coreml/builders/coreml_spec.h index c9adba9e579d0..9448f1167990e 100644 --- a/onnxruntime/core/providers/coreml/builders/coreml_spec.h +++ b/onnxruntime/core/providers/coreml/builders/coreml_spec.h @@ -17,14 +17,19 @@ #ifdef HAS_SHORTEN_64_TO_32 #pragma GCC diagnostic ignored "-Wshorten-64-to-32" #endif +#elif defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4244) // conversion from long to int #endif // Model.pb.h is generated in the build output directory from the CoreML protobuf files in -// onnxruntime/core/providers/coreml/coremltools/mlmodel/format +// /_deps/coremltools-src/mlmodel/format #include "coreml_proto/Model.pb.h" #if defined(__GNUC__) #pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) #endif namespace COREML_SPEC = CoreML::Specification; diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc index 
bc3ba4432e66d..b8ebbd05a2a20 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.cc +++ b/onnxruntime/core/providers/coreml/builders/helper.cc @@ -85,9 +85,15 @@ bool IsInputSupported(const Node& node, const NodeArg& input, } if (dim == 0) { - LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. Input:" << input_name - << ", shape: " << Shape2String(shape); - return false; + if (node.OpType() == "Resize" && &input == node.InputDefs()[1]) { + // one special case. Resize 'roi' input was originally a required input but is rarely used. + // ROI is not supported in the CoreML implementation so we will ignore the value, but is often added + // (at least in the unit tests) as an initializer with shape {0}. + } else { + LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. Input:" << input_name + << ", shape: " << Shape2String(shape); + return false; + } } } @@ -125,7 +131,7 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe bool CheckIsConstantInitializer(const NodeArg& node_arg, const GraphViewer& graph_viewer, const logging::Logger& logger, std::string_view input_description) { - if (graph_viewer.GetConstantInitializer(node_arg.Name(), true) == nullptr) { + if (graph_viewer.GetConstantInitializer(node_arg.Name()) == nullptr) { LOGS(logger, VERBOSE) << input_description << " (NodeArg name: '" << node_arg.Name() << "') is not a constant initializer tensor"; return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 2570e6d88ae0d..83a572f4b60fa 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -83,9 +83,14 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputPar } /* static */ -bool BaseOpBuilder::IsInput0Supported(const Node& node, const OpBuilderInputParams& /*input_params*/, - const logging::Logger& logger) { - const auto& input = *node.InputDefs()[0]; +bool BaseOpBuilder::IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) { + if (idx >= node.InputDefs().size()) { + LOGS(logger, VERBOSE) << "Input index [" << idx << "] is out of range"; + return false; + } + + const auto& input = *node.InputDefs()[idx]; int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; @@ -102,7 +107,7 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInpu const logging::Logger& logger) const { // We only check the type of input 0 by default // specific op builder can override this - return IsInput0Supported(node, input_params, logger); + return IsInputFloat(node, 0, input_params, logger); } bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index 06c4dd94ea30d..63f0b813d654c 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -28,9 +28,9 @@ class BaseOpBuilder : public IOpBuilder { void AddInitializersToSkip(ModelBuilder& /*model_builder*/, const Node& /*node*/) const override {} protected: - // check if the first input's data type is supported. 
- static bool IsInput0Supported(const Node& node, const OpBuilderInputParams& input_params, - const logging::Logger& logger); + // currently we only support float + static bool IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& input_params, + const logging::Logger& logger); private: virtual bool IsOpSupportedImpl(const Node& /*node*/, const OpBuilderInputParams& /*input_params*/, diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index 6074fba1433d9..fb8e07633621f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -5,6 +5,7 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" @@ -19,6 +20,8 @@ class BinaryOpBuilder : public BaseOpBuilder { bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; namespace { @@ -57,38 +60,72 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - std::unique_ptr layer = model_builder.CreateNNLayer(node); - - if (op_type == "Add") { - // original mutable_add() has limited broadcasting support - // updated to use CoreML::AddBroadcastableLayerParams which has more general broadcasting support - if (CheckIfBothInputShapesMatch(node, logger)) { - layer->mutable_add(); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_binary + std::string_view coreml_op_type; + if (op_type == "Add") { + coreml_op_type = "add"; + } else if (op_type == "Mul") { + coreml_op_type = "mul"; + } else if (op_type == "Sub") { + coreml_op_type = "sub"; + } else if (op_type == "Div") { + // we only support fp32 currently. 
when we add support for integers we need to check the type and use + // "floor_div" or "real_div" accordingly + coreml_op_type = "real_div"; + } else if (op_type == "Pow") { + coreml_op_type = "pow"; } else { - layer->mutable_addbroadcastable(); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BinaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); } - } else if (op_type == "Mul") { - if (CheckIfBothInputShapesMatch(node, logger)) { - layer->mutable_multiply(); + + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", input_defs[0]->Name()); + AddOperationInput(*op, "y", input_defs[1]->Name()); + AddOperationOutput(*op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(op)); + } else +#endif // defined (COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + if (op_type == "Add") { + // original mutable_add() has limited broadcasting support + // updated to use CoreML::AddBroadcastableLayerParams which has more general broadcasting support + if (CheckIfBothInputShapesMatch(node, logger)) { + layer->mutable_add(); + } else { + layer->mutable_addbroadcastable(); + } + } else if (op_type == "Mul") { + if (CheckIfBothInputShapesMatch(node, logger)) { + layer->mutable_multiply(); + } else { + layer->mutable_multiplybroadcastable(); + } + } else if (op_type == "Sub") { + layer->mutable_subtractbroadcastable(); + } else if (op_type == "Div") { + layer->mutable_dividebroadcastable(); + } else if (op_type == "Pow") { + layer->mutable_powbroadcastable(); } else { - layer->mutable_multiplybroadcastable(); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BinaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); } - } else if (op_type == "Sub") { - layer->mutable_subtractbroadcastable(); - } else if (op_type == "Div") { - layer->mutable_dividebroadcastable(); - } else if (op_type == "Pow") { - layer->mutable_powbroadcastable(); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BinaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); - } - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_input()->Add() = input_defs[1]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_input()->Add() = input_defs[1]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } @@ -99,25 +136,11 @@ int BinaryOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - if (node.OpType() != "Pow") { - return IsInput0Supported(node, input_params, logger); - } - - const auto& input_1 = *node.InputDefs()[0]; - const auto& input_2 = *node.InputDefs()[1]; - - // Pow we only support both inputs as fp32 for now - int32_t input_type_1; - int32_t input_type_2; - if (!GetType(input_1, input_type_1, logger) || - !GetType(input_2, input_type_2, logger)) { - return false; - } - - if (input_type_1 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT || input_type_1 != input_type_2) { - LOGS(logger, VERBOSE) << "Pow only supports fp32 inputs, actual input type" - << ", Input type 1: " << input_type_1 - << ", Input type 2: " << input_type_2; + // Add/Sub/Mul/Div spec says inputs 
must be of the same type. + // Pow spec says inputs can be different types. + // We only support float for all of these inputs. + if (!IsInputFloat(node, 0, input_params, logger) || + ((node.OpType() == "Pow") && !IsInputFloat(node, 1, input_params, logger))) { return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index 710f596b2a562..cbea969904ed5 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -7,6 +7,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/providers/coreml/builders/coreml_spec.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/shared/utils/utils.h" #include "core/optimizer/initializer.h" @@ -132,6 +133,7 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::spansize(); + size_t num_dims = num_pads / 2; + std::vector reordered_pads(num_pads, 0); + for (size_t i = 0; i < num_pads; ++i) { + auto cur_dim = i % num_dims; + if (i < num_dims) { // start values + reordered_pads[cur_dim * 2] = (*onnx_pads)[i]; + } else { // end values + reordered_pads[cur_dim * 2 + 1] = (*onnx_pads)[i]; + } + } + + AddOperationInput(op, "pad", model_builder.AddConstant(op_type, "pad", reordered_pads)); + + break; + } + + // fall through if explicit pads were not provided as the default value for `pads` is all zeros, + // which is the same as 'valid' padding. + [[fallthrough]]; + } + case AutoPadType::VALID: + AddOperationInput(op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string("valid"))); + + break; + case AutoPadType::SAME_UPPER: + case AutoPadType::SAME_LOWER: { + const auto pad_type = (auto_pad_type == AutoPadType::SAME_UPPER ? "same" : "same_lower"); + AddOperationInput(op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string(pad_type))); + + // despite what the spec says, a 'pad' input seems to be required. + // https://github.com/apple/coremltools/issues/2127 + // Provide the default value as that's what coremltools does for conv/avg_pool/max_pool. 
+ std::vector ignored_pads(num_spatial_dims * 2, 0); + AddOperationInput(op, "pad", model_builder.AddConstant(op_type, "pad", ignored_pads)); + + break; + } + } +} +#endif // defined(COREML_ENABLE_MLPROGRAM) } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index 8126f0c126914..2804589065631 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -11,13 +11,15 @@ #include "core/common/status.h" #include "core/graph/basic_types.h" #include "core/providers/common.h" - #include "core/providers/coreml/builders/coreml_spec.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { class NodeArg; namespace coreml { +class ModelBuilder; + // Try to see if we can map explicit padding to auto padding for Conv/Pool // Since usually use auto padding is more efficient Status HandleAutoPad(const std::vector input_shape, @@ -45,6 +47,7 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); +#if defined(COREML_ENABLE_MLPROGRAM) // // MLProgram utils // @@ -130,5 +133,17 @@ void AddOperationInput(COREML_SPEC::MILSpec::Operation& op, /// Operation to update. /// NodeArg with details of output to add. void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output); + +/// +/// Add pad_type and pad values. +/// +/// Operator to update +/// ModelBuilder to add constants with. +/// Operator type. +/// Node attribute helper. +/// Number of spatial dims in input. Generally rank - 2 (ignore N and C dims). +void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type, + const NodeAttrHelper& helper, int num_spatial_dims); +#endif // defined(COREML_ENABLE_MLPROGRAM) } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index 9aca172abec98..41f4041ef1181 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" @@ -17,11 +18,31 @@ class ClipOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; void ClipOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + bool skip = true; + + if (model_builder.CreateMLProgram()) { + float min, max; + ORT_IGNORE_RETURN_VALUE(GetClipMinMax(model_builder.GetGraphViewer(), node, min, max, model_builder.Logger())); + + bool has_min = min != std::numeric_limits::lowest(); + bool has_max = max != std::numeric_limits::max(); + if (has_min && has_max && min == 0.f && max == 6.f) { + // relu6 - skip both + } else if (has_min && min == 0.f && !has_max) { + // relu - skip both + } else { + // clip - we will use both + skip = false; + } + } + // Both min and max values will be injected into the layer, no need to add to the model - if (node.SinceVersion() >= 11) { + if (skip && node.SinceVersion() >= 11) { if (node.InputDefs().size() > 1) model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); @@ -35,72 +56,126 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const logging::Logger& logger) const { const auto& node_name = node.Name(); const auto& input_name = node.InputDefs()[0]->Name(); - const auto& output_name = node.OutputDefs()[0]->Name(); + const auto& output = *node.OutputDefs()[0]; + const auto& output_name = output.Name(); float min, max; ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetGraphViewer(), node, min, max, logger), "GetClipMinMax failed"); bool has_min = min != std::numeric_limits::lowest(); bool has_max = max != std::numeric_limits::max(); - if (!has_min && !has_max) { - // Clip without min/max is an identity node - // In CoreML we don't have identity, use ActivationLinear instead - std::unique_ptr layer = model_builder.CreateNNLayer(node); - layer->mutable_activation()->mutable_linear()->set_alpha(1.0f); - *layer->mutable_input()->Add() = input_name; - *layer->mutable_output()->Add() = output_name; - - model_builder.AddLayer(std::move(layer)); - } else { - // The implementation of clip(min, max) is done by - // 1. Clipping at min -> max(input, min) is handled by - // min_output = threshold(input, min) - // 2. Clipping at max -> min(min_output, max) is handled by - // output = -1 * (threshold(-min_output, -max)) - - // Now we have at least one or min or max is not default value - // Clipping at max will need take the output of clipping at min, or the node input, if min value is default - // If max value is default, the output of clipping at min will be the output of the node - std::string min_output_name = output_name; - if (has_max) { - min_output_name = has_min - ? model_builder.GetUniqueName(node_name + "min_output") - : input_name; +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + std::unique_ptr op; + if (!has_min && !has_max) { + // Clip without min/max is an identity node. 
+ op = model_builder.CreateOperation(node, "identity"); + Operation& identity_op = *op; + AddOperationInput(identity_op, "x", input_name); + } else { + if (has_min && has_max && min == 0.f && max == 6.f) { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.activation.relu6 + op = model_builder.CreateOperation(node, "relu6"); + Operation& relu6_op = *op; + AddOperationInput(relu6_op, "x", input_name); + } else if (has_min && min == 0.f && !has_max) { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.activation.relu + op = model_builder.CreateOperation(node, "relu"); + Operation& relu_op = *op; + AddOperationInput(relu_op, "x", input_name); + } else { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_unary.clip + op = model_builder.CreateOperation(node, "clip"); + + Operation& clip_op = *op; + AddOperationInput(clip_op, "x", input_name); + + // if min and max were attributes we need to add initializers. otherwise we use the existing inputs + const bool min_max_attribs = node.SinceVersion() < 11; + std::string_view min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", min) + : node.InputDefs()[1]->Name(); + + AddOperationInput(clip_op, "alpha", min_name); + + if (has_max) { + std::string_view max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", max) + : node.InputDefs()[2]->Name(); + AddOperationInput(clip_op, "beta", max_name); + } + } } - // Handle clipping at min first - if (has_min) { - std::unique_ptr min_layer = model_builder.CreateNNLayer(node, "_Clip_min"); - if (min == 0.0f) { // If min is 0. then this min will be handled by relu - min_layer->mutable_activation()->mutable_relu(); - } else { // otherwise, min will be handled by unary->threshold - min_layer->mutable_unary()->set_alpha(min); - min_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); + AddOperationOutput(*op, output); + model_builder.AddOperation(std::move(op)); + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + // TODO: CoreML has a Clip layer for NeuralNetwork. Added in CoreML 4. We could potentially use that if available + // to simplify. + // https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#cliplayerparams + + if (!has_min && !has_max) { + // Clip without min/max is an identity node + // In CoreML we don't have identity, use ActivationLinear instead + std::unique_ptr layer = model_builder.CreateNNLayer(node); + layer->mutable_activation()->mutable_linear()->set_alpha(1.0f); + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; + + model_builder.AddLayer(std::move(layer)); + } else { + // The implementation of clip(min, max) is done by + // 1. Clipping at min -> max(input, min) is handled by + // min_output = threshold(input, min) + // 2. 
Clipping at max -> min(min_output, max) is handled by + // output = -1 * (threshold(-min_output, -max)) + + // Now we have at least one or min or max is not default value + // Clipping at max will need take the output of clipping at min, or the node input, if min value is default + // If max value is default, the output of clipping at min will be the output of the node + std::string min_output_name = output_name; + if (has_max) { + min_output_name = has_min + ? model_builder.GetUniqueName(node_name + "min_output") + : input_name; } - *min_layer->mutable_input()->Add() = input_name; - *min_layer->mutable_output()->Add() = min_output_name; - model_builder.AddLayer(std::move(min_layer)); - } - - // Clipping at max is handled by -1 * (threshold (-min_output, -max)) - if (has_max) { - const auto threshold_output_name = model_builder.GetUniqueName(MakeString(node_name, "threshold_output")); - { // Add threshold layer, which is actually max( -1 * min_output, -max) - auto threshold_layer = model_builder.CreateNNLayer(node, "_Clip_max_threshold"); - threshold_layer->mutable_unary()->set_alpha(-max); - threshold_layer->mutable_unary()->set_scale(-1.0f); - threshold_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); - *threshold_layer->mutable_input()->Add() = min_output_name; - *threshold_layer->mutable_output()->Add() = threshold_output_name; - model_builder.AddLayer(std::move(threshold_layer)); + // Handle clipping at min first + if (has_min) { + std::unique_ptr min_layer = model_builder.CreateNNLayer(node, "_Clip_min"); + if (min == 0.0f) { // If min is 0. then this min will be handled by relu + min_layer->mutable_activation()->mutable_relu(); + } else { // otherwise, min will be handled by unary->threshold + min_layer->mutable_unary()->set_alpha(min); + min_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); + } + + *min_layer->mutable_input()->Add() = input_name; + *min_layer->mutable_output()->Add() = min_output_name; + model_builder.AddLayer(std::move(min_layer)); } - { // Add linear activation layer -1 * threshold_output - auto linear_layer = model_builder.CreateNNLayer(node, "_Clip_max_linear"); - linear_layer->mutable_activation()->mutable_linear()->set_alpha(-1.0f); - *linear_layer->mutable_input()->Add() = threshold_output_name; - *linear_layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(linear_layer)); + + // Clipping at max is handled by -1 * (threshold (-min_output, -max)) + if (has_max) { + const auto threshold_output_name = model_builder.GetUniqueName(MakeString(node_name, "threshold_output")); + { // Add threshold layer, which is actually max( -1 * min_output, -max) + auto threshold_layer = model_builder.CreateNNLayer(node, "_Clip_max_threshold"); + threshold_layer->mutable_unary()->set_alpha(-max); + threshold_layer->mutable_unary()->set_scale(-1.0f); + threshold_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); + *threshold_layer->mutable_input()->Add() = min_output_name; + *threshold_layer->mutable_output()->Add() = threshold_output_name; + model_builder.AddLayer(std::move(threshold_layer)); + } + { // Add linear activation layer -1 * threshold_output + auto linear_layer = model_builder.CreateNNLayer(node, "_Clip_max_linear"); + linear_layer->mutable_activation()->mutable_linear()->set_alpha(-1.0f); + *linear_layer->mutable_input()->Add() = threshold_output_name; + *linear_layer->mutable_output()->Add() = output_name; + 
model_builder.AddLayer(std::move(linear_layer)); + } } } } diff --git a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc index 05e43dbbd16af..38125957bf481 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc @@ -67,99 +67,25 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N AddOperationInput(*conv_op, "bias", input_defs[2]->Name()); } - // ONNX attributes. Add as inputs if specified/required - auto strides = helper.GetInt64s("strides"); - auto dilations = helper.GetInt64s("dilations"); - auto groups = helper.GetInt64("group"); - // we know this input has a valid shape due to the check in IsOpSupportedImpl. ignore N and C dims. const auto num_spatial_dims = input_defs[1]->Shape()->dim_size() - 2; const auto& op_type = conv_op->type(); - if (strides) { - AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", *strides)); - } else { - // spec says optional. testing suggests otherwise for at least the iOS15 target (CoreML5) - static const auto default_value = std::vector(num_spatial_dims, 1); - AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", default_value)); - } + // Spec says strides and dilations are optional, but reality is they're required for at least the iOS15 target + // (CoreML5). + const auto strides = helper.Get("strides", std::vector(num_spatial_dims, 1)); + auto dilations = helper.Get("dilations", std::vector(num_spatial_dims, 1)); + auto groups = helper.GetInt64("group"); - if (dilations) { - AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", *dilations)); - } else { - // spec says optional. testing suggests otherwise for at least the iOS15 target (CoreML5) - static const auto default_value = std::vector(num_spatial_dims, 1); - AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", default_value)); - } + AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", strides)); + AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", dilations)); if (groups) { AddOperationInput(*conv_op, "groups", model_builder.AddScalarConstant(op_type, "groups", *groups)); } - AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); - - // pad type (string) - // valid - no pads (ONNX auto_pad VALID) - // custom - pads input (ONNX NOTSET) - // same - inferred to be `d_out[i] = ceil(d_in[i] / strides[i])` (assuming == ONNX SAME_UPPER) - // same_lower - as per same but any extra rows/cols are added at top/left if padding is odd (ONNX SAME_LOWER) - // - // TODO: See if we want to update HandleAutoPad to support 1D (and 3D) so we can infer if an autopad value - // can be used. TBD if that provides any performance benefit with ML Program though as CoreML could - // potentially do that for us. - switch (auto_pad_type) { - case AutoPadType::NOTSET: { - // use `pads` attribute. - auto onnx_pads = helper.GetInt64s("pads"); // 'pads' must be provided if auto_pad is NOTSET - if (onnx_pads) { - AddOperationInput(*conv_op, "pad_type", - model_builder.AddScalarConstant(op_type, "pad_type", std::string("custom"))); - - // need to re-order from x1_start, x2_start..., x1_end, x2_end... to - // x1_start, x1_end, x2_start, x2_end,... 
- size_t num_pads = onnx_pads->size(); - size_t num_dims = num_pads / 2; - std::vector reordered_pads(num_pads, 0); - for (size_t i = 0; i < num_pads; ++i) { - auto cur_dim = i % num_dims; - if (i < num_dims) { // start values - reordered_pads[cur_dim * 2] = (*onnx_pads)[i]; - } else { // end values - reordered_pads[cur_dim * 2 + 1] = (*onnx_pads)[i]; - } - } - - AddOperationInput(*conv_op, "pad", model_builder.AddConstant(op_type, "pad", reordered_pads)); - - break; - } - - // in theory the pads may not be provided and in that case the default is no padding. - // as that is the same as 'valid', fall through - [[fallthrough]]; - } - case AutoPadType::VALID: - AddOperationInput(*conv_op, "pad_type", - model_builder.AddScalarConstant(op_type, "pad_type", std::string("valid"))); - - break; - case AutoPadType::SAME_UPPER: - case AutoPadType::SAME_LOWER: { - const auto pad_type = (auto_pad_type == AutoPadType::SAME_UPPER ? "same" : "same_lower"); - AddOperationInput(*conv_op, "pad_type", - model_builder.AddScalarConstant(op_type, "pad_type", std::string(pad_type))); - - // despite what the spec says, a 'pad' input seems to be required. - // https://github.com/apple/coremltools/issues/2127 - // provide the default value. passing in an empty vector also works. TBD what's better. - std::vector ignored_pads(num_spatial_dims * 2, 0); - AddOperationInput(*conv_op, "pad", model_builder.AddConstant(op_type, "pad", ignored_pads)); - - break; - } - } + AddPadTypeAndPads(*conv_op, model_builder, op_type, helper, num_spatial_dims); - // set output AddOperationOutput(*conv_op, *node.OutputDefs()[0]); model_builder.AddOperation(std::move(conv_op)); @@ -297,7 +223,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara const auto& input_defs = node.InputDefs(); const auto& weight_name = input_defs[1]->Name(); - const auto* weight = input_params.graph_viewer.GetConstantInitializer(weight_name, true); + const auto* weight = input_params.graph_viewer.GetConstantInitializer(weight_name); #if defined(COREML_ENABLE_MLPROGRAM) if (input_params.create_mlprogram) { @@ -324,7 +250,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return false; } - if (input_defs.size() > 2 && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name(), true)) { + if (input_defs.size() > 2 && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name())) { LOGS(logger, VERBOSE) << "The bias of Conv [" << name << "] must be a constant initializer"; return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index 48f77354d7c30..8daf64dc4a457 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -22,18 +22,51 @@ class GemmOpBuilder : public BaseOpBuilder { Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; - bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, - const logging::Logger& /* logger */) const override; + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& op = node.OpType(); 
const auto& input_defs(node.InputDefs()); - // We have already embedded the weights (matrix B and C(if any)) into the coreml layer - // No need to copy them later to reduce memory consumption - model_builder.AddInitializerToSkip(input_defs[1]->Name()); - if (op == "Gemm" && input_defs.size() > 2) { - model_builder.AddInitializerToSkip(input_defs[2]->Name()); + const bool is_gemm = op == "Gemm"; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + // we have to transpose the weight input of Gemm if transB is false, and potentially override the bias shape + if (is_gemm) { + NodeAttrHelper helper(node); + const auto transB = helper.Get("transB", 0); + if (transB == 0) { + model_builder.AddInitializerToSkip(input_defs[1]->Name()); + } + + if (input_defs.size() > 2) { + // ONNX spec requires B to be 2D and we required it to be a constant initializer so reading N this way is safe + // B is {K, N] by default. or {N, K} if transB is true + int N_dim = transB ? 0 : 1; + int64_t N = input_defs[1]->Shape()->dim().at(N_dim).dim_value(); + + const auto& bias_name = input_defs[2]->Name(); + const auto& bias = *model_builder.GetConstantInitializer(bias_name); + if (bias.dims_size() != 1 || bias.dims(0) != N) { + // we have to override the shape/duplicate data to convert {}, {1} or {1, N} to 1D {N} + // when adding the Gemm operation so skip adding the original initializer + model_builder.AddInitializerToSkip(bias_name); + } + } + } + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + // We have already embedded the weights (matrix B and C(if any)) into the coreml layer + // No need to copy them later to reduce memory consumption + model_builder.AddInitializerToSkip(input_defs[1]->Name()); + if (is_gemm && input_defs.size() > 2) { + model_builder.AddInitializerToSkip(input_defs[2]->Name()); + } } } @@ -57,54 +90,152 @@ static Status GetTensorFloatDataTransposed(const ONNX_NAMESPACE::TensorProto& te } Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /* logger */) const { + const logging::Logger& logger) const { std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - const auto& b_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - const auto& b_shape = b_tensor.dims(); - - auto* coreml_inner_product = layer->mutable_innerproduct(); - - // The coreml innerproduct weight (matrix B) is stored transposed - // - for MatMul and Gemm (transB = 0), the coreml weight is B' - // - for Gemm (transB = 1), the coreml weight is B - if (op_type == "MatMul") { - coreml_inner_product->set_inputchannels(b_shape[0]); - coreml_inner_product->set_outputchannels(b_shape[1]); - // Add weight (b of MatMul) - std::vector b_transposed; - ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(b_tensor, b_transposed)); - CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed); - } else { // Gemm - NodeAttrHelper helper(node); - const auto transB = helper.Get("transB", 0); - if (transB == 0) { - coreml_inner_product->set_inputchannels(b_shape[0]); - coreml_inner_product->set_outputchannels(b_shape[1]); + const auto& a = *input_defs[0]; + const auto& b = *input_defs[1]; + const auto* b_initializer = model_builder.GetConstantInitializer(b.Name()); // MLProgram MatMul may not be constant + + const bool is_matmul = op_type == "MatMul"; + const bool is_gemm = op_type == "Gemm"; + + NodeAttrHelper helper(node); + const 
auto transB = is_gemm ? helper.Get("transB", 0) : 0; + + std::vector b_shape; + ORT_IGNORE_RETURN_VALUE(GetShape(b, b_shape, logger)); + int64_t b0 = -1, b1 = -1; + + // ML Program MatMul supports N-D input + if (model_builder.CreateMLProgram() && is_matmul) { + if (b_shape.size() == 1) { + // B is treated as {b_shape[0], 1} according to the numpy rules. + b0 = b_shape[0]; + b1 = 1; + } else { + // last 2 dims are used + b0 = b_shape[b_shape.size() - 2]; + b1 = b_shape[b_shape.size() - 1]; + } + } else { + // we only support 2D input + b0 = b_shape[0]; + b1 = b_shape[1]; + } + + // B is {K, N} in ONNX spec by default, or {N, K} in Gemm if transB is true + const auto K = transB ? b1 : b0; + const auto N = transB ? b0 : b1; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + if (is_gemm) { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.linear.linear + auto gemm_op = model_builder.CreateOperation(node, "linear"); + AddOperationInput(*gemm_op, "x", a.Name()); + + // CoreML takes weight input as {N, K} which is the reverse of ONNX. + // if transB is true the input weight is {N, K} so can be added directly. + if (transB) { + AddOperationInput(*gemm_op, "weight", b.Name()); + } else { + // transpose from {K, N} to {N, K} + std::vector weight_nk; + std::vector weight_nk_shape = {N, K}; + ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, weight_nk)); + + AddOperationInput(*gemm_op, "weight", + model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); + } + + if (input_defs.size() == 3) { + const auto& bias_arg = *input_defs[2]; + const auto& bias = *model_builder.GetConstantInitializer(bias_arg.Name()); + + // CoreML linear op requires bias to be 1D tensor of size N + if (bias.dims_size() == 1 && bias.dims().at(0) == N) { + // can use existing initializer + AddOperationInput(*gemm_op, "bias", bias_arg.Name()); + } else { + Initializer unpacked_tensor(bias); + auto bias_data = unpacked_tensor.DataAsSpan(); + std::string_view bias_data_name; + if (bias_data.size() == 1) { + // expand scalar to N + std::vector expanded_bias_data(N, bias_data[0]); + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data); + } else { + // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()}) + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data); + } + + AddOperationInput(*gemm_op, "bias", bias_data_name); + } + } + + AddOperationOutput(*gemm_op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(gemm_op)); + } else { + // CoreML implementation is the same as ONNX MatMul. + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.linear.matmul + auto matmul_op = model_builder.CreateOperation(node, "matmul"); + AddOperationInput(*matmul_op, "x", a.Name()); + AddOperationInput(*matmul_op, "y", b.Name()); + + // once again the spec lies and says transpose_y and transpose_x are optional... 
+ auto false_value_name = model_builder.AddScalarConstant(matmul_op->type(), "false", false); + AddOperationInput(*matmul_op, "transpose_x", false_value_name); + AddOperationInput(*matmul_op, "transpose_y", false_value_name); + + AddOperationOutput(*matmul_op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(matmul_op)); + } + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + auto* coreml_inner_product = layer->mutable_innerproduct(); + + *layer->mutable_input()->Add() = a.Name(); + + coreml_inner_product->set_inputchannels(K); + coreml_inner_product->set_outputchannels(N); + + // CoreML takes weight input as {N, K} which is the reverse of ONNX. + // if Gemm's transB is true the input weight is {N, K} and can be added directly. + if (transB) { + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), *b_initializer)); + } else { std::vector b_transposed; - ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(b_tensor, b_transposed)); + ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, b_transposed)); CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed); - } else { - coreml_inner_product->set_inputchannels(b_shape[1]); - coreml_inner_product->set_outputchannels(b_shape[0]); - // Add weight (b of MatMul) - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_tensor)); } - // Add bias if present - if (input_defs.size() > 2) { + if (is_gemm && input_defs.size() > 2) { + // Add bias coreml_inner_product->set_hasbias(true); - const auto& bias_tensor = *model_builder.GetInitializerTensors().at(input_defs[2]->Name()); - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), bias_tensor)); + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + + // if scalar, or single value expand to 1D tensor of size N + // IsOpSupportedImpl enforces it's scalar, {1}, {N}, or {1, N}. 
+ Initializer unpacked_tensor(bias_tensor); + auto bias_data = unpacked_tensor.DataAsSpan(); + if (bias_data.size() == 1 && N > 1) { + std::vector expanded_bias_data(N, bias_data[0]); + CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), expanded_bias_data); + } else { + CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), bias_data); + } } - } - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } @@ -112,98 +243,105 @@ bool GemmOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs(node.InputDefs()); + const bool is_matmul = op_type == "MatMul"; + const bool is_gemm = op_type == "Gemm"; + size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (!Contains(initializers, input_defs[b_idx]->Name())) { - LOGS(logger, VERBOSE) << "B of Gemm/Matmul must be an initializer tensor"; + std::vector a_shape; + if (!GetShape(*input_defs[a_idx], a_shape, logger)) { return false; } - std::vector a_shape; - { - if (!GetShape(*input_defs[a_idx], a_shape, logger)) - return false; - - if (a_shape.size() != 2) { - LOGS(logger, VERBOSE) << "A must be 2D"; - return false; - } + std::vector b_shape; + if (!GetShape(*input_defs[b_idx], b_shape, logger)) { + return false; + } - // TODO is it ok if the shape is dynamic and empty? - if (Product(a_shape) == 0) { - LOGS(logger, VERBOSE) << "A must be non-empty"; + if (!input_params.graph_viewer.GetConstantInitializer(input_defs[b_idx]->Name())) { + if (input_params.create_mlprogram && is_matmul) { + // ML Program MatMul allows non-constant B input + } else { + LOGS(logger, VERBOSE) << op_type << " B input must be a constant initializer"; return false; } } - std::vector b_shape; - { - if (!GetShape(*input_defs[b_idx], b_shape, logger)) - return false; - - if (b_shape.size() != 2) { - LOGS(logger, VERBOSE) << "B must be 2D"; - return false; - } + if (is_matmul) { + if (input_params.create_mlprogram) { + // ML Program matmul op has numpy semantics the same as the ONNX spec so we can use directly + } else { + // we could potentially support 1D and 3D if required. beyond 3D the dims that merge diverge. + // https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/onnx/_operators.py#L1607 + // https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/backend/nn/op_mapping.py#L1374 + // https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#innerproductlayerparams + if (a_shape.size() != 2 || b_shape.size() != 2) { + LOGS(logger, VERBOSE) << "a and b inputs must be 2D. 
"; + return false; + } - if (Product(b_shape) == 0) { - LOGS(logger, VERBOSE) << "B must be non-empty"; - return false; + if (input_defs.size() > 2) { + LOGS(logger, VERBOSE) << "MatMul with C input is not supported"; + return false; + } } } - if (op_type == "Gemm") { + if (is_gemm) { + // A and B are 2D due to the ONNX spec NodeAttrHelper helper(node); const auto transA = helper.Get("transA", 0); const auto transB = helper.Get("transB", 0); const auto alpha = helper.Get("alpha", 1.0f); const auto beta = helper.Get("beta", 1.0f); + + // TODO: We can support transA, alpha and beta by using multiple layers/operations if needed. if (!(transA == 0 && alpha == 1.f && beta == 1.f)) { - LOGS(logger, VERBOSE) << "Only transA == 0, alpha == 1.0 " - << "and beta == 1.0 is supported." + LOGS(logger, VERBOSE) << "Only support for transA == 0, alpha == 1.0 " + << "and beta == 1.0 is currently implemented." << " transA " << transA << " alpha " << alpha << " beta " << beta; return false; } - // C of Gemm - // For now we only support {n} or {1,n} tensor if (input_defs.size() == 3) { - if (!Contains(initializers, input_defs[c_idx]->Name())) { - LOGS(logger, VERBOSE) << "C of Gemm must be an initializer tensor"; + if (!input_params.graph_viewer.GetConstantInitializer(input_defs[c_idx]->Name())) { + LOGS(logger, VERBOSE) << "C of Gemm must be a constant initializer"; return false; } std::vector c_shape; - if (!GetShape(*input_defs[c_idx], c_shape, logger)) + if (!GetShape(*input_defs[c_idx], c_shape, logger)) { return false; + } - size_t c_dim = c_shape.size(); + // B is {K, N} in ONNX spec by default, or {N, K} in Gemm if transB is true + const auto N = transB ? b_shape[0] : b_shape[1]; - if (c_dim == 0) { - LOGS(logger, VERBOSE) << "C of Gemm cannot be a scalar"; - return false; - } + size_t c_rank = c_shape.size(); - if (c_dim != 1) { - // If C is a (2+)d tensor, it must have the format {1, 1, ..., 1, n} - // where every except the last dimension should be 1 - for (size_t i = 0; i < c_dim - 1; ++i) { - if (c_shape[i] != 1) { - LOGS(logger, VERBOSE) << "C of Gemm must be a vector or a tensor with only last dimension != 1"; - return false; + // allowed: scalar, or 1D where the value is 1 or N, 2D with shape {1, N} + bool c_valid = false; + switch (c_rank) { + case 0: + c_valid = true; + break; + case 1: + if (c_shape[0] == 1 || c_shape[0] == N) { + c_valid = true; } - } + break; + case 2: + if (c_shape[0] == 1 && c_shape[1] == N) { + c_valid = true; + } + break; } - auto c_size = c_shape[c_dim - 1]; - if (c_size != (transB == 0 ? b_shape[1] : b_shape[0])) { - LOGS(logger, VERBOSE) << "C of Gemm must be a vector of b_shape[" - << (transB == 0 ? "1" : "0") << "]" - << " b_shape: [" << b_shape[0] << ", " << b_shape[1] << "]" - << " c_size: " << c_size; + if (!c_valid) { + LOGS(logger, VERBOSE) << "Shape of C Gemm input must be {}, {1}, {N}, or {1, N}. 
N:" << N << " C shape:" + << Shape2String(c_shape); return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc index 01aced739b36d..17910ba6fd486 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc @@ -19,104 +19,176 @@ class PoolOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = model_builder.CreateNNLayer(node); - - auto* coreml_pool = layer->mutable_pooling(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - bool is_global_pooling = false; - if (op_type == "GlobalAveragePool") { - is_global_pooling = true; - coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE); - } else if (op_type == "GlobalMaxPool") { - is_global_pooling = true; - coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX); - } else if (op_type == "AveragePool") { - coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE); - } else if (op_type == "MaxPool") { - coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unknown op: ", op_type); - } +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + std::string_view coreml_op_type; + bool is_global = false; + bool is_avg_pool = false; + if (op_type == "GlobalAveragePool") { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.reduction.reduce_mean + coreml_op_type = "reduce_mean"; + is_global = true; + } else if (op_type == "GlobalMaxPool") { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.reduction.reduce_max + coreml_op_type = "reduce_max"; + is_global = true; + } else if (op_type == "AveragePool") { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.pool.avg_pool + coreml_op_type = "avg_pool"; + is_avg_pool = true; + } else if (op_type == "MaxPool") { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.pool.max_pool + coreml_op_type = "max_pool"; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unexpected op: ", op_type); + } - if (is_global_pooling) { - coreml_pool->set_globalpooling(true); - coreml_pool->mutable_valid(); - } else { // AveragePool or MaxPool - NodeAttrHelper helper(node); - const auto kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); - const auto strides = helper.Get("strides", std::vector{1, 1}); - const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - - coreml_pool->add_kernelsize(kernel_shape[0]); - coreml_pool->add_kernelsize(kernel_shape[1]); - coreml_pool->add_stride(strides[0]); - coreml_pool->add_stride(strides[1]); - 
coreml_pool->set_avgpoolexcludepadding(helper.Get("count_include_pad", 0) == 0); - coreml_pool->set_globalpooling(false); - - // Add Padding - // Usually using autopadding is more efficient than using explicit padding - // Try to see if we can map explicit padding to auto padding - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - AutoPadType auto_pad_type; - ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, kernel_shape[0], kernel_shape[1], - onnx_pads, strides, {1, 1} /* dilations */, - StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), - auto_pad_type)); - - if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { - auto* padding_type = coreml_pool->mutable_same(); - if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER - padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + + AddOperationInput(*op, "x", input_defs[0]->Name()); + + if (is_global) { + // keep N and C dims, reduce the rest with keepdims=True. equivalent to the ONNX Global*Pool ops. + std::vector axes{2, 3}; // we only support 4D input currently. + AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", axes)); + AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", true)); + } else { + NodeAttrHelper helper(node); + constexpr int num_spatial_dims = 2; // we only support 4D. -2 for N and C dims. + + AddPadTypeAndPads(*op, model_builder, op->type(), helper, num_spatial_dims); + + const auto kernel_shape = helper.GetInt64s("kernel_shape"); // required + AddOperationInput(*op, "kernel_sizes", model_builder.AddConstant(op->type(), "kernel_sizes", *kernel_shape)); + + // in theory all these values are optional according to the CoreML spec but simpler to just provide default + // values as the actual model compilation tends to require them. 
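As a side note on the mapping above: lowering GlobalAveragePool/GlobalMaxPool to reduce_mean/reduce_max over axes {2, 3} with keep_dims preserves the {N, C, 1, 1} output shape the ONNX ops produce. A minimal NumPy sketch of that equivalence for a 4D NCHW input (illustrative only, not part of the patch):

```python
# Minimal NumPy sketch (assumes a 4D NCHW input) of why GlobalAveragePool can be
# lowered to reduce_mean over the spatial axes with keep_dims=True.
import numpy as np

x = np.random.rand(2, 3, 8, 8).astype(np.float32)   # NCHW
y = np.mean(x, axis=(2, 3), keepdims=True)          # ~ MIL reduce_mean(axes=[2, 3], keep_dims=True)
print(y.shape)                                      # (2, 3, 1, 1), same as ONNX GlobalAveragePool
```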
+ const auto strides = helper.Get("strides", std::vector(num_spatial_dims, 1)); + const bool ceil_mode = helper.Get("ceil_mode", int64_t(0)); // convert int64_t to bool + + AddOperationInput(*op, "strides", model_builder.AddConstant(op->type(), "strides", strides)); + AddOperationInput(*op, "ceil_mode", model_builder.AddScalarConstant(op->type(), "ceil_mode", ceil_mode)); + + if (is_avg_pool) { + const bool count_exclude_pad = helper.Get("count_include_pad", int64_t(0)) == 0; + AddOperationInput(*op, "exclude_padding_from_average", + model_builder.AddScalarConstant(op->type(), "count_exclude_pad", count_exclude_pad)); } + } + + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + auto* coreml_pool = layer->mutable_pooling(); + + bool is_global_pooling = false; + if (op_type == "GlobalAveragePool") { + is_global_pooling = true; + coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE); + } else if (op_type == "GlobalMaxPool") { + is_global_pooling = true; + coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX); + } else if (op_type == "AveragePool") { + coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE); + } else if (op_type == "MaxPool") { + coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX); } else { - auto* padding_type = coreml_pool->mutable_valid(); - if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { - // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts - auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - height_border->set_startedgesize(onnx_pads[0]); - height_border->set_endedgesize(onnx_pads[2]); - auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - width_border->set_startedgesize(onnx_pads[1]); - width_border->set_endedgesize(onnx_pads[3]); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unexpected op: ", op_type); + } + + if (is_global_pooling) { + coreml_pool->set_globalpooling(true); + coreml_pool->mutable_valid(); + } else { // AveragePool or MaxPool + NodeAttrHelper helper(node); + const auto kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); + const auto strides = helper.Get("strides", std::vector{1, 1}); + const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + + coreml_pool->add_kernelsize(kernel_shape[0]); + coreml_pool->add_kernelsize(kernel_shape[1]); + coreml_pool->add_stride(strides[0]); + coreml_pool->add_stride(strides[1]); + coreml_pool->set_avgpoolexcludepadding(helper.Get("count_include_pad", 0) == 0); + coreml_pool->set_globalpooling(false); + + // Add Padding + // Usually using autopadding is more efficient than using explicit padding + // Try to see if we can map explicit padding to auto padding + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + AutoPadType auto_pad_type; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, kernel_shape[0], kernel_shape[1], + onnx_pads, strides, {1, 1} /* dilations */, + StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), + auto_pad_type)); + + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + auto* padding_type = coreml_pool->mutable_same(); + if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is 
SAME_UPPER + padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + } + } else { + auto* padding_type = coreml_pool->mutable_valid(); + if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { + // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts + auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + height_border->set_startedgesize(onnx_pads[0]); + height_border->set_endedgesize(onnx_pads[2]); + auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + width_border->set_startedgesize(onnx_pads[1]); + width_border->set_endedgesize(onnx_pads[3]); + } } } - } - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } -bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, +bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) + if (!GetShape(*input_defs[0], input_shape, logger)) { return false; + } + // TODO: ML Program supports 3D and 5D. Add if we have a use case for that. const auto input_size = input_shape.size(); if (input_size != 4) { - LOGS(logger, VERBOSE) - << op_type << " only supports rank-4 tensor, input [" - << input_defs[0]->Name() << "] has actual dim count " << input_size; + LOGS(logger, VERBOSE) << op_type << " only supports rank-4 tensor, input [" + << input_defs[0]->Name() << "] has actual dim count " << input_size; return false; } if (op_type == "AveragePool" || op_type == "MaxPool") { NodeAttrHelper helper(node); + const auto storage_order = helper.Get("storage_order", 0); if (storage_order == 1) { LOGS(logger, VERBOSE) << "storage_order == 1 is not supported"; @@ -128,12 +200,14 @@ bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return false; } - // TODO, add support of the ceil_mode by adjusting the padding - // See https://stackoverflow.com/questions/59906456/in-pytorchs-maxpool2d-is-padding-added-depending-on-ceil-mode - // and https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/frontend/torch/ops.py#L621-L644 - if (helper.Get("ceil_mode", 0) == 1) { - LOGS(logger, VERBOSE) << "ceil_mode == 1 is not supported for pooling"; - return false; + if (!input_params.create_mlprogram) { + // TODO, add support of the ceil_mode by adjusting the padding + // See https://stackoverflow.com/questions/59906456/in-pytorchs-maxpool2d-is-padding-added-depending-on-ceil-mode + // and https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/frontend/torch/ops.py#L621-L644 + if (helper.Get("ceil_mode", 0) == 1) { + LOGS(logger, VERBOSE) << "ceil_mode == 1 is not supported for pooling"; + return false; + } } if (helper.Get("dilations", std::vector{1, 1}) != diff --git a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc index 
7ae1746be3122..27d24d9c21893 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc @@ -1,11 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/framework/tensorprotoutils.h" #include "core/optimizer/initializer.h" -#include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -26,34 +25,56 @@ class ReshapeOpBuilder : public BaseOpBuilder { // Reshape opset 4- uses attributes for new shape which we do not support for now int GetMinSupportedOpSet(const Node& /* node */) const override { return 5; } + + bool SupportsMLProgram() const override { return true; } }; void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // Skip the second input which is the new shape as we always have to create a new version as the CoreML rules + // are different from ONNX. model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); } Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = model_builder.CreateNNLayer(node); - const auto& input_defs = node.InputDefs(); - const auto& initializers(model_builder.GetInitializerTensors()); - const auto& target_shape_tensor = *initializers.at(input_defs[1]->Name()); - const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty() - ? 
reinterpret_cast(target_shape_tensor.raw_data().data()) - : target_shape_tensor.int64_data().data(); - - const auto size = target_shape_tensor.dims()[0]; - TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size}; std::vector input_shape; - ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - ReshapeHelper helper(TensorShape(input_shape), target_shape); - *layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Cannot get shape of data"); + + const auto& data_name = input_defs[0]->Name(); + const auto& new_shape_name = input_defs[1]->Name(); + Initializer unpacked_tensor(*model_builder.GetConstantInitializer(new_shape_name)); + TensorShapeVector new_shape = ToShapeVector(unpacked_tensor.DataAsSpan()); + + // ReshapeHelper applies the ONNX rules to create the concrete output shape + ReshapeHelper helper(TensorShape(input_shape), new_shape); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; - model_builder.AddLayer(std::move(layer)); + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.reshape + std::unique_ptr reshape_op = model_builder.CreateOperation(node, "reshape"); + + AddOperationInput(*reshape_op, "x", data_name); + AddOperationInput(*reshape_op, "shape", + model_builder.AddConstant(reshape_op->type(), "shape", ToConstSpan(new_shape))); + + AddOperationOutput(*reshape_op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(reshape_op)); + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + *layer->mutable_reshapestatic()->mutable_targetshape() = {new_shape.cbegin(), new_shape.cend()}; + *layer->mutable_input()->Add() = data_name; + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } @@ -61,14 +82,15 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& new_shape_name = input_defs[1]->Name(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (!Contains(initializers, new_shape_name)) { + const auto* new_shape_tensor = input_params.graph_viewer.GetConstantInitializer(new_shape_name); + if (!new_shape_tensor) { + // ONNX has different rules around how -1 and 0 values are used/combined, and + // we can't check if those can be translated to CoreML if the shape is unknown. 
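For context on the new-shape handling above: ReshapeHelper resolves the ONNX shape, which may contain 0 ("copy the input dimension", with the default allowzero=0) and at most one -1 ("infer from the remaining element count"), into a concrete shape before it is passed to CoreML. A rough Python sketch of that resolution (the function name is hypothetical, not from the patch):

```python
# Rough sketch of the ONNX Reshape rules ReshapeHelper applies (assumes allowzero=0
# and a static input shape); the helper name here is illustrative only.
import math

def resolve_reshape(input_shape, new_shape):
    out = list(new_shape)
    for i, d in enumerate(out):        # 0 copies the corresponding input dimension
        if d == 0:
            out[i] = input_shape[i]
    if -1 in out:                      # a single -1 is inferred from the element count
        idx = out.index(-1)
        known = math.prod(d for d in out if d != -1)
        out[idx] = math.prod(input_shape) // known
    assert math.prod(out) == math.prod(input_shape), "element counts must match"
    return out

print(resolve_reshape([2, 3, 4], [0, -1]))  # -> [2, 12]
```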
LOGS(logger, VERBOSE) << "New shape of reshape must be a constant initializer"; return false; } - const auto& new_shape_tensor = *initializers.at(new_shape_name); - Initializer unpacked_tensor(new_shape_tensor); + Initializer unpacked_tensor(*new_shape_tensor); auto new_shape = unpacked_tensor.DataAsSpan(); if (new_shape.empty()) { LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty"; @@ -84,7 +106,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP return false; } - // CoreML reshape doesn't support new shape with more than 5 dimensions + // CoreML reshape doesn't support new shape with more than 5 dimensions. if (new_shape.size() > 5) { LOGS(logger, VERBOSE) << "Reshape does not support new shape with rank greater than 5. Input shape: " << Shape2String(input_shape) << ", new shape: " << Shape2String(new_shape); @@ -93,7 +115,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP // CoreML reshape does not support 0 as dimension NodeAttrHelper helper(node); - const bool allow_zero = helper.Get("allowzero ", 0) == 1; + const bool allow_zero = helper.Get("allowzero", 0) == 1; if (allow_zero) { if (std::find(new_shape.begin(), new_shape.end(), int64_t{0}) != new_shape.end()) { LOGS(logger, VERBOSE) << "Reshape does not support new shape with 0 as dimension when allowzero is enabled. " diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 35dcde41a6bcf..6c2fcc2ace856 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -98,7 +98,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); - if (input_defs.size() == 3) { // use scales + if (input_defs.size() >= 3 && input_defs[2]->Exists()) { // use scales std::vector scales; ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); coreml_upsample->add_scalingfactor(static_cast(scales[2])); @@ -182,20 +182,24 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa return false; } + bool using_scales = input_defs.size() >= 3 && input_defs[2]->Exists(); // scales - if (input_defs.size() == 3 && !Contains(initializers, input_defs[2]->Name())) { - LOGS(logger, VERBOSE) << "Input scales of Resize must be known"; + if (using_scales && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name())) { + LOGS(logger, VERBOSE) << "scales input of Resize must be a constant initializer"; return false; } // sizes - if (input_defs.size() > 3 && !Contains(initializers, input_defs[3]->Name())) { - LOGS(logger, VERBOSE) << "Input sizes of Resize must be known"; + if (!using_scales && + (input_defs.size() < 4 || + !input_defs[3]->Exists() || + !input_params.graph_viewer.GetConstantInitializer(input_defs[3]->Name()))) { + LOGS(logger, VERBOSE) << "sizes input of Resize must be a constant initializer"; return false; } // We want to check if the scales or sizes are not trying to resize on N/C channels here - if (input_defs.size() == 3) { // we are using scales + if (using_scales) { std::vector scales; if (!GetResizeScales(initializers, node, scales, logger)) return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc 
b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index b716af738e1b1..39bfbfe5bba1f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -54,7 +54,7 @@ Status PrepareSliceComputeMetadataFromConstantInitializers(const Node& slice_nod return Status::OK(); } - const auto* tensor_proto = graph_viewer.GetConstantInitializer(input_defs[input_idx]->Name(), true); + const auto* tensor_proto = graph_viewer.GetConstantInitializer(input_defs[input_idx]->Name()); ORT_RETURN_IF_NOT(tensor_proto, "Failed to get constant initializer."); Initializer unpacked_tensor(*tensor_proto, graph_viewer.ModelPath()); const auto data_type = unpacked_tensor.data_type(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc index 266396a0fe90e..d6584124c6aba 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc @@ -52,7 +52,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, target_shape.push_back(size_to_dimension); target_shape.push_back(size_from_dimension); - const auto reshape1_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "reshape1_output")); + const auto reshape1_output_name = model_builder.GetUniqueName(node, "reshape1_output"); { // Add reshape layer auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape1"); *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; @@ -60,7 +60,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, *reshape_layer->mutable_output()->Add() = reshape1_output_name; model_builder.AddLayer(std::move(reshape_layer)); } - const auto softmax_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "softmax_output")); + const auto softmax_output_name = model_builder.GetUniqueName(node, "softmax_output"); { auto* coreml_softmaxnd = layer->mutable_softmaxnd(); coreml_softmaxnd->set_axis(-1); diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index daab36f7b933d..eb4723a3b9746 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -144,14 +144,18 @@ void CopyOnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_prot break; } case ONNX_NAMESPACE::TensorProto_DataType_INT64: { - // from: int64_data/raw, to: longints - if (has_raw_data) { - CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); - - } else { - tensor_value.mutable_longints()->mutable_values()->CopyFrom(tensor_proto.int64_data()); - } - break; + // enable when this is proven to not be the case + ORT_THROW( + "INT64 is unexpected as CoreML uses 32-bit int for indices. 
" + "Most likely an initializer that should have been skipped was not."); + //// from: int64_data/raw, to: longints + // if (has_raw_data) { + // CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); + + //} else { + // tensor_value.mutable_longints()->mutable_values()->CopyFrom(tensor_proto.int64_data()); + //} + // break; } case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { // from: int32_data/raw, to: bytes @@ -186,18 +190,22 @@ void CopyOnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_prot break; } case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { - // from: uint64_data/raw, to: longints - if (has_raw_data) { - CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); - } else { - // TODO: Is this safe? Need to check the CopyFrom implementation. As it's a straight copy of bytes this - // hopefully can do it as one block instead of iterating and potentially doing a static_cast of each - // individual value. - tensor_value.mutable_longints()->mutable_values()->CopyFrom( - reinterpret_cast&>(tensor_proto.uint64_data())); - } - - break; + // enable when this is proven to not be the case + ORT_THROW( + "UINT64 is unexpected as CoreML uses 32-bit int for indices. " + "Most likely an initializer that should have been skipped was not."); + //// from: uint64_data/raw, to: longints + // if (has_raw_data) { + // CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); + // } else { + // // TODO: Is this safe? Need to check the CopyFrom implementation. As it's a straight copy of bytes this + // // hopefully can do it as one block instead of iterating and potentially doing a static_cast of each + // // individual value. + // tensor_value.mutable_longints()->mutable_values()->CopyFrom( + // reinterpret_cast&>(tensor_proto.uint64_data())); + // } + + // break; } case ONNX_NAMESPACE::TensorProto_DataType_BOOL: { // from: int32_data/raw, to: bools @@ -392,23 +400,28 @@ std::string GetModelOutputPath(bool create_ml_program) { } // namespace ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - int32_t coreml_version, uint32_t coreml_flags) + int32_t coreml_version, uint32_t coreml_flags, + std::vector&& onnx_input_names, + std::vector&& onnx_output_names) : graph_viewer_(graph_viewer), logger_(logger), coreml_version_(coreml_version), coreml_flags_(coreml_flags), create_ml_program_((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0), model_output_path_(GetModelOutputPath(create_ml_program_)), + onnx_input_names_(std::move(onnx_input_names)), + onnx_output_names_(std::move(onnx_output_names)), coreml_model_(std::make_unique()) { if (create_ml_program_) { #if defined(COREML_ENABLE_MLPROGRAM) coreml_model_->set_specificationversion(CoreMLSpecVersion()); MILSpec::Program& mlprogram = *coreml_model_->mutable_mlprogram(); - MILSpec::Function& main = (*mlprogram.mutable_functions())["main"]; + mlprogram.set_version(1); + mlprogram_main_fn_ = &(*mlprogram.mutable_functions())["main"]; const std::string coreml_opset = "CoreML" + std::to_string(CoreMLVersion()); - *main.mutable_opset() = coreml_opset; - mlprogram_main_ = &(*main.mutable_block_specializations())[coreml_opset]; + *mlprogram_main_fn_->mutable_opset() = coreml_opset; + mlprogram_main_block_ = &(*mlprogram_main_fn_->mutable_block_specializations())[coreml_opset]; // create the ModelPackage. this creates the output directory. 
mlpackage_ = std::make_unique(model_output_path_, /* create */ true); @@ -426,6 +439,8 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge weights_file_writer_ = std::make_unique(weights_info->path() + "/weight.bin"); #else // should never happen due to handling in coreml_execution_provider.cc + // throw here so all other code in this class can assume create_ml_program_ is only ever true in a build + // where ML Program support is enabled. ORT_THROW("ML Program is not enabled in this build"); #endif } else { @@ -435,6 +450,28 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge neural_network->set_arrayinputshapemapping( CoreML::Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); } + + // populate names. + const auto& initializers = graph_viewer_.GetAllInitializedTensors(); + const auto& inputs = graph_viewer_.GetInputs(); + // rough guess to try and avoid reallocs. most nodes produce one output but some have more so allow for that. + // also need to convert attributes to constants so allow for that + unique_names_.reserve(initializers.size() + inputs.size() + size_t(graph_viewer_.NumberOfNodes() * 1.5)); + for (const auto& pair : initializers) { + unique_names_.insert(pair.first); + } + + for (const auto* input : inputs) { + unique_names_.insert(input->Name()); + } + + for (const auto& node : graph_viewer_.Nodes()) { + for (const auto& def : node.OutputDefs()) { + if (def->Exists()) { + unique_names_.insert(def->Name()); + } + } + } } ModelBuilder::~ModelBuilder() = default; @@ -455,11 +492,94 @@ void ModelBuilder::AddLayer(std::unique_ptr layer) { neural_network->mutable_layers()->AddAllocated(layer.release()); } -#if defined(COREML_ENABLE_MLPROGRAM) - /* * ML Program related helpers */ +#if defined(COREML_ENABLE_MLPROGRAM) +const std::string& ModelBuilder::GetSafeName(const std::string& name) { + // Check the name is valid according to the MILSpec rules + // `Identifiers, generally used for names and keys, must match the regular expression [A-Za-z\_][A-Za-z0-9\_@]*.` + // + // There is a secondary list of reserved words that the coremltools python uses, but it's not clear if those are + // required here, or if we will ever hit a model that uses one of them. Due to that, skip checking them for now as + // it adds cost and code complexity + // https://github.com/apple/coremltools/blob/8b37641f243b1a3e81452feea311c6e30dcc9287/coremltools/converters/mil/mil/passes/defs/preprocess.py#L151C1-L175C10 + // static InlinedHashSet reserved_names = + // {"any", "bool", "program", "func", "tensor", "list", "dict", "tuple", "true", "false", + // "string", "bf16", "fp16", "fp32", "fp64", "int8", "int16", "int32", "int64", + // "uint8", "uint16", "uint32", "uint64"}; + + // handle empty name. shouldn't happen but code below assumes name is not empty + if (name.empty()) { + return name; + } + + // We don't need '@' or '\' even though they're allowed. Optimize for a good name that does not need to be changed. + + // has been sanitized and changed already + const auto entry = values_to_rename_.find(name); + if (entry != values_to_rename_.end()) { + return entry->second; + } + + // Replace anything but a good char with '_'. 
If first char is 0-9 we prefix with '_'; + bool changed = false; + std::string result = name; + + if (std::isdigit(result[0])) { + changed = true; + result = '_' + name; + } + + for (char& c : result) { + if (!std::isalnum(c) && c != '_') { + changed = true; + c = '_'; + } + } + + if (!changed) { + return name; // return original as the return value is a reference that must remain valid + } + + return (values_to_rename_[name] = GetUniqueName(result)); +} + +void ModelBuilder::SanitizeNames() { + // ML Model level inputs/outputs + auto* desc = coreml_model_->mutable_description(); + for (auto& input : *desc->mutable_input()) { + input.set_name(GetSafeName(input.name())); + } + + for (auto& output : *desc->mutable_output()) { + output.set_name(GetSafeName(output.name())); + } + + // main function inputs/outputs. + for (auto& input : *mlprogram_main_fn_->mutable_inputs()) { + input.set_name(GetSafeName(input.name())); + } + + // outputs from block with operations for current coreml version + for (auto& output : *mlprogram_main_block_->mutable_outputs()) { + output = GetSafeName(output); + } + + // iterate operations changing input/output/node names + for (auto& op : *mlprogram_main_block_->mutable_operations()) { + for (auto& input : *op.mutable_inputs()) { + for (auto& arg : *input.second.mutable_arguments()) { + arg.set_name(GetSafeName(arg.name())); + } + } + + for (auto& output : *op.mutable_outputs()) { + output.set_name(GetSafeName(output.name())); + } + } +} + std::unique_ptr ModelBuilder::CreateOperation(const Node& node, std::string_view op_type, std::string_view suffix) { @@ -472,14 +592,9 @@ std::unique_ptr ModelBuilder::CreateOperation(c return op; } -void ModelBuilder::AddConstant(std::string_view name, const ONNX_NAMESPACE::TensorProto& initializer) { - MILSpec::Value coreml_tensor = OnnxTensorToCoreMLTensor(initializer, *weights_file_writer_); - AddConstantOperation(name, std::move(coreml_tensor)); -} - -void ModelBuilder::AddConstantOperation(std::string_view name, MILSpec::Value&& coreml_tensor) { +const std::string& ModelBuilder::AddConstantOperation(std::string_view name, MILSpec::Value&& coreml_tensor) { // Replicates coremltools/converters/mil/backend/mil/load.py translate_const logic - MILSpec::Operation& const_op = *mlprogram_main_->mutable_operations()->Add(); + MILSpec::Operation& const_op = *mlprogram_main_block_->mutable_operations()->Add(); const_op.set_type("const"); MILSpec::NamedValueType& output = *const_op.mutable_outputs()->Add(); @@ -487,58 +602,63 @@ void ModelBuilder::AddConstantOperation(std::string_view name, MILSpec::Value&& *output.mutable_type() = coreml_tensor.type(); auto& attr_map = *const_op.mutable_attributes(); - attr_map["name"] = CreateScalarTensorValue(std::string(name)); + // the operation name doesn't really matter as it isn't used elsewhere, so sanitize name now + attr_map["name"] = CreateScalarTensorValue(GetSafeName(output.name())); attr_map["val"] = std::move(coreml_tensor); + + return output.name(); } // Add operation to the Block for the main function in the ML Program void ModelBuilder::AddOperation(std::unique_ptr operation) { - mlprogram_main_->mutable_operations()->AddAllocated(operation.release()); + mlprogram_main_block_->mutable_operations()->AddAllocated(operation.release()); } -std::string ModelBuilder::AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, - MILSpec::Value&& input_value) { +const std::string& ModelBuilder::AddTensorValueAsConstantOperation(std::string_view op_type, + 
std::string_view value_type, + MILSpec::Value&& input_value) { auto unique_value_name = GetUniqueName(MakeString(op_type, "_", value_type)); - AddConstantOperation(unique_value_name, std::move(input_value)); - return unique_value_name; + return AddConstantOperation(unique_value_name, std::move(input_value)); } template -std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, - std::optional> shape) { +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { // add specialization below static_assert(false_for_T, "Missing specialization for value type"); - return ""; // unreachable + + return "ModelBuilder::AddConstant error"; // unreachable } template <> -std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, - gsl::span value, - std::optional> shape) { +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { auto input_value = CreateTensorValue(value, shape); return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } template <> -std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, - gsl::span value, - std::optional> shape) { +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { auto input_value = CreateTensorValue(value, shape); // CoreML uses int32 return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } template <> -std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, - gsl::span value, - std::optional> shape) { +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { auto input_value = CreateTensorValue(value, shape); return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } template <> -std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, - gsl::span value, - std::optional> shape) { +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { auto input_value = CreateTensorValue(value, shape); return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } @@ -581,11 +701,13 @@ Status ModelBuilder::RegisterInitializers() { continue; } - if (create_ml_program_) { #if defined(COREML_ENABLE_MLPROGRAM) - AddConstant(name, tensor); + if (create_ml_program_) { + MILSpec::Value coreml_tensor = OnnxTensorToCoreMLTensor(tensor, *weights_file_writer_); + ORT_IGNORE_RETURN_VALUE(AddConstantOperation(name, std::move(coreml_tensor))); + } else #endif - } else { + { std::unique_ptr layer = std::make_unique(); layer->set_name(GetUniqueName("initializer_" + name)); @@ -616,32 +738,33 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i if (is_input) { // input should not be an initializer - if (Contains(GetInitializerTensors(), name)) + if (Contains(GetInitializerTensors(), name)) { return Status::OK(); + } // This input will not be used - if (Contains(skipped_inputs_, name)) + if (Contains(skipped_inputs_, name)) { return Status::OK(); + } } auto* model_description = 
coreml_model_->mutable_description(); - auto& input_output = is_input - ? *model_description->mutable_input()->Add() - : *model_description->mutable_output()->Add(); + auto& input_output = is_input ? *model_description->mutable_input()->Add() + : *model_description->mutable_output()->Add(); input_output.set_name(name); + auto* multi_array = input_output.mutable_type()->mutable_multiarraytype(); std::vector shape; - ORT_RETURN_IF_NOT(GetShape(node_arg, shape, logger_), - "Unable to get shape for ", input_output_type, ": ", name); + ORT_RETURN_IF_NOT(GetShape(node_arg, shape, logger_), "Unable to get shape for ", input_output_type, ": ", name); if (shape.empty()) { - // If we have an empty shape, this is a scalar input, - // Since all the input output of CoreML EP is MultiArray, we will make the scalar input output as a {1} MultiArray + // If we have an empty shape, this is a scalar + // Since all the input/output of CoreML EP is MultiArray, we will make the scalar input/output a {1} MultiArray shape.push_back(1); - // we need to change the shapes of these scalar outputs back to {} when CoreML EP returns these values to ORT + // we need to change the shapes of scalar outputs back to {} when CoreML EP returns values to ORT if (!is_input) { AddScalarOutput(name); } @@ -713,13 +836,20 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i #if defined(COREML_ENABLE_MLPROGRAM) if (create_ml_program_) { - MILSpec::Function& main = (*coreml_model_->mutable_mlprogram()->mutable_functions())["main"]; if (is_input) { - // the model inputs need to be wired up as args to the 'main' function - main.mutable_inputs()->Add(CreateNamedTensorValueType(node_arg)); + // the model inputs need to be wired up as args to the 'main' function. + auto tensor_value_type = CreateNamedTensorValueType(node_arg); + tensor_value_type.set_name(name); + if (node_arg.Shape()->dim_size() == 0) { + // update shape from {} to {1} (same change we made at the model input level above). + tensor_value_type.mutable_type()->mutable_tensortype()->set_rank(1); + tensor_value_type.mutable_type()->mutable_tensortype()->add_dimensions()->mutable_constant()->set_size(1); + } + + mlprogram_main_fn_->mutable_inputs()->Add(std::move(tensor_value_type)); } else { // the model outputs need to be set as outputs of the Block for the 'main' function - *mlprogram_main_->mutable_outputs()->Add() = node_arg.Name(); + *mlprogram_main_block_->mutable_outputs()->Add() = name; } } #endif // defined(COREML_ENABLE_MLPROGRAM) @@ -744,7 +874,7 @@ Status ModelBuilder::ProcessNodes() { // This shouldn't happen as this is called from CoreMLExecutionProvider::Compile and should only be processing // nodes that we said were supported and were returned from CoreMLExecutionProvider::GetCapability. return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Node [", node.Name(), "], type [", node.OpType(), "] is not supported"); + "Node [", node.Name(), "], type [", node.OpType(), "] was not able to be processed"); } } @@ -767,6 +897,12 @@ Status ModelBuilder::CreateModel() { ORT_RETURN_IF_ERROR(ProcessNodes()); ORT_RETURN_IF_ERROR(RegisterModelOutputs()); +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + SanitizeNames(); + } +#endif + return Status::OK(); } @@ -795,7 +931,7 @@ Status ModelBuilder::SaveModel() { #if defined(COREML_ENABLE_MLPROGRAM) // need to delete the ModelPackage instance for it to write out the manifest. clear out the other ML Program // related types as well. 
- mlprogram_main_ = nullptr; + mlprogram_main_block_ = nullptr; mlpackage_.reset(); weights_file_writer_.reset(); #endif @@ -804,11 +940,51 @@ Status ModelBuilder::SaveModel() { } Status ModelBuilder::LoadModel(std::unique_ptr& model) { - model = std::make_unique(model_output_path_, - std::move(input_output_info_), - std::move(scalar_outputs_), - std::move(int64_outputs_), - logger_, coreml_flags_); +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + // we need to provide the sanitized names for model inputs/outputs so that info is captured. + // the input/output matching when we execute the model from the CoreML EP is based on order, so the change + // to the names doesn't matter for that. + auto get_sanitized_names = [this](std::vector&& names) -> std::vector { + std::vector output(std::move(names)); + + for (std::string& name : output) { + name = GetSafeName(name); + } + + return output; + }; + + // also need to update the keys in input_output_info_ + auto get_sanitized_io_info = [this](std::unordered_map&& info) { + std::unordered_map output; + output.reserve(info.size()); + + for (auto entry = info.begin(), end = info.end(); entry != end; ++entry) { + output.emplace(GetSafeName(entry->first), std::move(entry->second)); + } + + return output; + }; + + model = std::make_unique(model_output_path_, + get_sanitized_names(std::move(onnx_input_names_)), + get_sanitized_names(std::move(onnx_output_names_)), + get_sanitized_io_info(std::move(input_output_info_)), + std::move(scalar_outputs_), + std::move(int64_outputs_), + logger_, coreml_flags_); + } else +#endif + { + model = std::make_unique(model_output_path_, + std::move(onnx_input_names_), + std::move(onnx_output_names_), + std::move(input_output_info_), + std::move(scalar_outputs_), + std::move(int64_outputs_), + logger_, coreml_flags_); + } return model->LoadModel(); // load using CoreML API, including compilation } @@ -816,8 +992,11 @@ Status ModelBuilder::LoadModel(std::unique_ptr& model) { // static Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger, int32_t coreml_version, uint32_t coreml_flags, + std::vector&& onnx_input_names, + std::vector&& onnx_output_names, std::unique_ptr& model) { - ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_flags); + ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_flags, + std::move(onnx_input_names), std::move(onnx_output_names)); ORT_RETURN_IF_ERROR(builder.CreateModel()); ORT_RETURN_IF_ERROR(builder.SaveModel()); @@ -847,20 +1026,31 @@ void ModelBuilder::AddInputToSkip(const std::string& input_name) { skipped_inputs_.insert(input_name); } -std::string ModelBuilder::GetUniqueName(std::string_view base_name) { +const std::string& ModelBuilder::GetUniqueName(const std::string& base_name) { + if (unique_names_.find(base_name) == unique_names_.end()) { + return *unique_names_.insert(base_name).first; + } + std::string unique_name; - do { - std::ostringstream os; - os << base_name << "_token_" << name_token_++; - unique_name = os.str(); - } while (Contains(unique_names_, unique_name)); + std::string suffix; + + // supports up to 1000 unique names without having to grow in the loop + unique_name.reserve(base_name.size() + 5); + unique_name = base_name; + + while (Contains(unique_names_, unique_name)) { + // assign followed by += to avoid creating temporary strings. 
+ unique_name = base_name; + unique_name += "__"; + unique_name += std::to_string(name_token_++); + } - return unique_name; + return *unique_names_.insert(unique_name).first; } -std::string ModelBuilder::GetUniqueName(const Node& node, std::string_view suffix) { +const std::string& ModelBuilder::GetUniqueName(const Node& node, std::string_view suffix) { if (node.Name().empty()) { - return GetUniqueName(MakeString("Node_", node.Index(), "_", node.OpType(), suffix)); + return GetUniqueName(MakeString(node.OpType(), "_", node.Index(), suffix)); } else { return GetUniqueName(node.Name() + std::string(suffix)); } diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index 961ba647257b5..8f85ab2c09e7c 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -25,17 +25,20 @@ namespace onnxruntime { namespace coreml { class IOpBuilder; -class Model; class ModelBuilder { private: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - int32_t coreml_version, uint32_t coreml_flags); + int32_t coreml_version, uint32_t coreml_flags, + std::vector&& onnx_input_names, + std::vector&& onnx_output_names); public: // Create the CoreML model, serialize to disk, load and compile using the CoreML API and return in `model` static Status Build(const GraphViewer& graph_viewer, const logging::Logger& logger, int32_t coreml_version, uint32_t coreml_flags, + std::vector&& onnx_input_names, + std::vector&& onnx_output_names, std::unique_ptr& model); ~ModelBuilder(); @@ -101,8 +104,8 @@ class ModelBuilder { /// /// Unique name generated for value. template - std::string AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value, - std::optional> shape = std::nullopt) { + std::string_view AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape = std::nullopt) { static_assert(std::is_same_v || std::is_same_v || std::is_same_v || @@ -113,8 +116,8 @@ class ModelBuilder { } template - std::string AddConstant(std::string_view op_type, std::string_view value_type, const std::vector& value, - std::optional> shape = std::nullopt) { + std::string_view AddConstant(std::string_view op_type, std::string_view value_type, const std::vector& value, + std::optional> shape = std::nullopt) { return AddConstant(op_type, value_type, AsSpan(value), shape); } @@ -122,17 +125,10 @@ class ModelBuilder { /// Add a scalar value as a 'const' operation. See AddConstant for details. 
/// template - std::string AddScalarConstant(std::string_view op_type, std::string_view value_type, const T& value) { + std::string_view AddScalarConstant(std::string_view op_type, std::string_view value_type, const T& value) { return AddConstant(op_type, value_type, AsSpan({value}), AsSpan({})); } - /// - /// Add an existing a constant ONNX initializer to the ML Program as a 'const' operation - /// - /// Initializer name - /// Initializer data - void AddConstant(std::string_view name, const ONNX_NAMESPACE::TensorProto& initializer); - // add the operation to the main function void AddOperation(std::unique_ptr operation); #endif @@ -149,18 +145,26 @@ class ModelBuilder { // be added to CoreML model, since CoreML does not like input unused void AddInputToSkip(const std::string& input_name); - std::string GetUniqueName(std::string_view base_name); - std::string GetUniqueName(const Node& node, std::string_view suffix); + const std::string& GetUniqueName(const std::string& base_name); + const std::string& GetUniqueName(const Node& node, std::string_view suffix); + + const logging::Logger& Logger() const { return logger_; } private: #if defined(COREML_ENABLE_MLPROGRAM) template - std::string AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, - std::optional> shape = std::nullopt); - - void AddConstantOperation(std::string_view name, COREML_SPEC::MILSpec::Value&& initializer); - std::string AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, - COREML_SPEC::MILSpec::Value&& input_value); + std::string_view AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape = std::nullopt); + + // apply the CoreML naming rules and fix any invalid names. + const std::string& GetSafeName(const std::string& name); + // sanitize all the names in the ML Model + void SanitizeNames(); + + // add Value as a const operation. return value name in case sanitization changed it + const std::string& AddConstantOperation(std::string_view name, COREML_SPEC::MILSpec::Value&& initializer); + const std::string& AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, + COREML_SPEC::MILSpec::Value&& input_value); #endif // Convert the ONNX model in graph_viewer_ to a CoreML::Specification::Model and serialize to disk. @@ -193,6 +197,9 @@ class ModelBuilder { const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old) const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel + std::vector onnx_input_names_; + std::vector onnx_output_names_; + std::unique_ptr coreml_model_; std::unordered_set scalar_outputs_; std::unordered_set int64_outputs_; @@ -208,9 +215,19 @@ class ModelBuilder { // mlprogram_main_ is the main block of the CoreML ML Program. // It is set in CreateModel to the CoreML Model.mlprogram.functions['main'].block_specializations['CoreML'] // entry we create. - COREML_SPEC::MILSpec::Block* mlprogram_main_{nullptr}; + COREML_SPEC::MILSpec::Function* mlprogram_main_fn_{nullptr}; // Function that contains a Block with the operations + COREML_SPEC::MILSpec::Block* mlprogram_main_block_{nullptr}; // Block that all the operations are added to std::unique_ptr mlpackage_; std::unique_ptr weights_file_writer_; + + // Values must start with [a-zA-A_] + // Additionally they can't be in a list of reserved words. 
+ // If we need to sanitize an initializer name we do so during PreprocessInitializers and apply the change during + // RegisterInitializers. + // We also check inputs in AddOperation and apply the change there. + // This means an op builder author doesn't need to be aware of the renaming. + // https://github.com/apple/coremltools/blob/8b37641f243b1a3e81452feea311c6e30dcc9287/coremltools/converters/mil/mil/passes/defs/preprocess.py#L146-L149 + std::unordered_map values_to_rename_; #endif }; diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index 8e718da07703c..0ba715cc7c6d9 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -114,28 +114,27 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector& node_compute_funcs) { for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { Node& fused_node = fused_node_and_graph.fused_node; - const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); std::unique_ptr coreml_model; - ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_flags_, - coreml_model)); - { - const auto& input_defs = fused_node.InputDefs(); - std::vector onnx_input_names(input_defs.size()); - for (size_t i = 0, end = input_defs.size(); i < end; ++i) { - onnx_input_names[i] = input_defs[i]->Name(); - } - coreml_model->SetOnnxInputs(std::move(onnx_input_names)); - } + auto get_names = [](const ConstPointerContainer>& args) -> std::vector { + std::vector names; + names.reserve(args.size()); - { - const auto& output_defs = fused_node.OutputDefs(); - std::vector onnx_output_names(output_defs.size()); - for (size_t i = 0, end = output_defs.size(); i < end; ++i) { - onnx_output_names[i] = output_defs[i]->Name(); - } - coreml_model->SetOnnxOutputs(std::move(onnx_output_names)); + for (const NodeArg* def : args) { + names.push_back(def->Name()); + } + + return names; + }; + + std::vector onnx_input_names = get_names(fused_node.InputDefs()); + std::vector onnx_output_names = get_names(fused_node.OutputDefs()); + + const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); + ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_flags_, + std::move(onnx_input_names), std::move(onnx_output_names), + coreml_model)); } coreml_models_.emplace(fused_node.Name(), std::move(coreml_model)); @@ -153,13 +152,14 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector(state); - const auto& model_inputs = model->GetOnnxInputs(); - const auto& model_outputs = model->GetOnnxOutputs(); + + // input/output names used by the CoreML model in the order that matches the fused_node InputDefs/OutputDefs + const auto& model_inputs = model->GetOrderedInputs(); + const auto& model_outputs = model->GetOrderedOutputs(); ORT_RETURN_IF_NOT(model_inputs.size() <= num_inputs, "Inconsistent input sizes"); ORT_RETURN_IF_NOT(model_outputs.size() == num_outputs, "Inconsistent output sizes"); @@ -182,28 +182,25 @@ common::Status CoreMLExecutionProvider::Compile(const std::vectorshape; - ORT_RETURN_IF(!coreml::IsStaticShape(inferred_shape) && coreml::DoesShapeSpecifyZeroElements(shape), - "Input (", input_name, ") has a dynamic shape (", coreml::Shape2String(inferred_shape), - ") but the runtime shape (", coreml::Shape2String(shape), - ") has zero elements. 
This is not supported by the CoreML EP."); - } + const auto& inferred_shape = input_info->shape; + ORT_RETURN_IF(!coreml::IsStaticShape(inferred_shape) && coreml::DoesShapeSpecifyZeroElements(shape), + "Input (", input_name, ") has a dynamic shape (", coreml::Shape2String(inferred_shape), + ") but the runtime shape (", coreml::Shape2String(shape), + ") has zero elements. This is not supported by the CoreML EP."); // If we have an empty shape, this is a scalar input, // Since all the input output of CoreML EP is MultiArray, we will make the scalar input as a {1} MultiArray - if (shape.empty()) + if (shape.empty()) { shape.push_back(1); + } // CoreML MLMultiArray API expect input to be non-const // https://developer.apple.com/documentation/coreml/mlmultiarray/2881219-initwithdatapointer?language=objc void* inputBuffer = const_cast(input_tensor.GetTensorRawData()); - inputs.emplace( - input_name, - coreml::OnnxTensorData{ - coreml::OnnxTensorInfo{tensor_info.GetElementType(), shape}, - inputBuffer, - }); + inputs.emplace(input_name, coreml::OnnxTensorData{ + coreml::OnnxTensorInfo{tensor_info.GetElementType(), shape}, + inputBuffer, + }); } // From this point we will need to take the exclusive lock on the model until the Predict is @@ -215,14 +212,13 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector static_shape) -> void* { + [&ctx, &model_outputs](const std::string& name, + int32_t requested_onnx_tensor_element_type, + gsl::span static_shape) -> void* { const auto model_output_it = std::find(model_outputs.begin(), model_outputs.end(), name); ORT_ENFORCE(model_output_it != model_outputs.end(), "Failed to find CoreML model output name: ", name); - const auto output_idx = gsl::narrow_cast(std::distance(model_outputs.begin(), model_output_it)); + const auto output_idx = gsl::narrow_cast(std::distance(model_outputs.begin(), model_output_it)); auto output_tensor = ctx.GetOutput(output_idx, static_shape.data(), static_shape.size()); const auto type_and_shape_info = output_tensor.GetTensorTypeAndShapeInfo(); @@ -243,13 +239,15 @@ common::Status CoreMLExecutionProvider::Compile(const std::vectorIsScalarOutput(output_name)) + if (model->IsScalarOutput(output_name)) { output_shape.clear(); + } // Since CoreML EP only accepts int32 output type and onnx requires int64 output, // We are going to set the model output (from int32) ->int64 - if (model->IsInt64Output(output_name)) + if (model->IsInt64Output(output_name)) { output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64; + } outputs.emplace(output_name, coreml::OnnxTensorInfo{output_type, output_shape}); } diff --git a/onnxruntime/core/providers/coreml/dump_mlprogram_model.py b/onnxruntime/core/providers/coreml/dump_mlprogram_model.py new file mode 100644 index 0000000000000..a3ceee70684dc --- /dev/null +++ b/onnxruntime/core/providers/coreml/dump_mlprogram_model.py @@ -0,0 +1,27 @@ +import sys + +import coremltools as ct + +if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} ") + print("If generated by onnxruntime this will be /Data/com.microsoft.onnxruntime/model.mlmodel") + sys.exit(-1) + +model_path = sys.argv[1] +m = ct.models.MLModel(model_path) + +spec = m.get_spec() +print(spec) + +# Example code if you want to filter output or do more advanced things +# main = spec.mlProgram.functions["main"] +# block = main.block_specializations[main.opset] +# print(f"{len(block.operations)} operators") +# for op in block.operations: +# if op.type == 'const': +# if op.attributes["name"].immediateValue.tensor.strings.values[0] == 
"conv_0_pad_type_0": +# print(f"Conv pad_type={op.attributes['val'].immediateValue.tensor.strings.values}") +# +# if op.type == 'conv': +# #print(op) +# pass diff --git a/onnxruntime/core/providers/coreml/model/host_utils.h b/onnxruntime/core/providers/coreml/model/host_utils.h index 4f9a014c4d885..a9991ccb945ce 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.h +++ b/onnxruntime/core/providers/coreml/model/host_utils.h @@ -67,6 +67,12 @@ int CoreMLVersion(); // Get a temporary macOS/iOS temp file path std::string GetTemporaryFilePath(); +#if !defined(NDEBUG) && defined(__APPLE__) +// Override location the model is written to so that a) it's easily found and b) it is not automatically deleted +// when the EP exits. Use to debug the model that is generated. +// See onnxruntime/core/providers/coreml/dump_mlprogram_model.py for a script to dump the ML Program. +constexpr const char* kOverrideModelOutputDirectoryEnvVar = "ORT_COREML_EP_MODEL_DIR"; +#endif } // namespace util } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/host_utils.mm b/onnxruntime/core/providers/coreml/model/host_utils.mm index 0ae0cf8f0d207..5487ea35388f5 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.mm +++ b/onnxruntime/core/providers/coreml/model/host_utils.mm @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/platform/env.h" #include "core/providers/coreml/model/host_utils.h" #import @@ -31,6 +32,15 @@ int32_t CoreMLVersion() { std::string GetTemporaryFilePath() { // Get temporary directory for user. NSURL* temporary_directory_url = [NSURL fileURLWithPath:NSTemporaryDirectory() isDirectory:YES]; + +#if !defined(NDEBUG) + std::string path_override = Env::Default().GetEnvironmentVar(kOverrideModelOutputDirectoryEnvVar); + if (!path_override.empty()) { + NSString* ns_path_override = [NSString stringWithUTF8String:path_override.c_str()]; + temporary_directory_url = [NSURL fileURLWithPath:ns_path_override isDirectory:YES]; + } +#endif + // Generate a Unique file name to use. NSString* temporary_filename = [[NSProcessInfo processInfo] globallyUniqueString]; diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index b940c4b768aec..e3cd43d786fc3 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -35,6 +35,8 @@ using GetOutputTensorMutableRawDataFn = std::function&& model_input_names, + std::vector&& model_output_names, std::unordered_map&& input_output_info, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, @@ -60,12 +62,11 @@ class Model { // Mutex for exclusive lock to this model object OrtMutex& GetMutex() { return mutex_; } - // Input and output names in the onnx model's order - const std::vector& GetOnnxInputs() const { return onnx_inputs_; } - void SetOnnxInputs(std::vector&& inputs) { onnx_inputs_ = std::move(inputs); } - - const std::vector& GetOnnxOutputs() const { return onnx_outputs_; } - void SetOnnxOutputs(std::vector&& outputs) { onnx_outputs_ = std::move(outputs); } + // Input and output names in the ORT fused node's order. + // Names may have been adjusted from the originals due to CoreML naming rules. + // We do inputs/outputs based on order at the ONNX level so this doesn't matter. 
+ const std::vector& GetOrderedInputs() const { return model_input_names_; } + const std::vector& GetOrderedOutputs() const { return model_output_names_; } const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const { const auto info_it = input_output_info_.find(name); @@ -80,13 +81,13 @@ class Model { private: std::unique_ptr execution_; + std::vector model_input_names_; // input names in the order of the ORT fused node's inputs + std::vector model_output_names_; // output names in the order of the ORT fused node's outputs + std::unordered_map input_output_info_; std::unordered_set scalar_outputs_; std::unordered_set int64_outputs_; - std::vector onnx_inputs_; - std::vector onnx_outputs_; - OrtMutex mutex_; }; diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index d5cd70bff9479..1434043e064f4 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -19,6 +19,7 @@ #include "core/common/narrow.h" #include "core/common/span_utils.h" #include "core/graph/onnx_protobuf.h" +#include "core/platform/env.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/coreml_provider_factory.h" #include "core/providers/coreml/model/host_utils.h" @@ -287,6 +288,14 @@ - (void)cleanup { compiled_model_path_ = nil; } +#if !defined(NDEBUG) + std::string path_override = Env::Default().GetEnvironmentVar(util::kOverrideModelOutputDirectoryEnvVar); + if (!path_override.empty()) { + // don't cleanup + coreml_model_path_ = nil; + } +#endif + if (coreml_model_path_ != nil) { error = nil; [[NSFileManager defaultManager] removeItemAtPath:coreml_model_path_ error:&error]; @@ -487,12 +496,16 @@ Status Predict(const std::unordered_map& inputs, } Model::Model(const std::string& path, + std::vector&& model_input_names, + std::vector&& model_output_names, std::unordered_map&& input_output_info, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, const logging::Logger& logger, uint32_t coreml_flags) : execution_(std::make_unique(path, logger, coreml_flags)), + model_input_names_(std::move(model_input_names)), + model_output_names_(std::move(model_output_names)), input_output_info_(std::move(input_output_info)), scalar_outputs_(std::move(scalar_outputs)), int64_outputs_(std::move(int64_outputs)) { diff --git a/onnxruntime/core/providers/coreml/model/model_stub.cc b/onnxruntime/core/providers/coreml/model/model_stub.cc index 087c9f8c05d5f..c6f2e7401ea1e 100644 --- a/onnxruntime/core/providers/coreml/model/model_stub.cc +++ b/onnxruntime/core/providers/coreml/model/model_stub.cc @@ -9,12 +9,16 @@ namespace coreml { class Execution {}; Model::Model(const std::string& /*path*/, + std::vector&& model_input_names, + std::vector&& model_output_names, std::unordered_map&& input_output_info, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, const logging::Logger& /*logger*/, uint32_t /*coreml_flags*/) : execution_(std::make_unique()), + model_input_names_(std::move(model_input_names)), + model_output_names_(std::move(model_output_names)), input_output_info_(std::move(input_output_info)), scalar_outputs_(std::move(scalar_outputs)), int64_outputs_(std::move(int64_outputs)) { diff --git a/onnxruntime/core/providers/cpu/tensor/reshape_helper.h b/onnxruntime/core/providers/cpu/tensor/reshape_helper.h index 5961686674424..d7ceda16e61ea 100644 --- a/onnxruntime/core/providers/cpu/tensor/reshape_helper.h +++ 
b/onnxruntime/core/providers/cpu/tensor/reshape_helper.h @@ -37,12 +37,14 @@ class ReshapeHelper { if (unknown_dim != -1) { // calculate unknown dimension ORT_ENFORCE(size != 0 && (input_shape_size % size) == 0, - "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, ", requested shape:", TensorShape(requested_shape)); + "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, + ", requested shape:", TensorShape(requested_shape)); requested_shape[unknown_dim] = input_shape_size / size; } else { // check if the output shape is valid. ORT_ENFORCE(input_shape_size == size, - "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, ", requested shape:", TensorShape(requested_shape)); + "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, + ", requested shape:", TensorShape(requested_shape)); } } }; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 7d4111e3b9c39..729ad34368453 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -64,17 +64,22 @@ namespace perftest { "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" "\t [Example] -C \"session.disable_cpu_ep_fallback|1 ep.context_enable|1\" \n" "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" + "\t [Usage]: -e -i '| |'\n" + "\n" "\t [DML only] [performance_preference]: DML device performance preference, options: 'default', 'minimum_power', 'high_performance', \n" "\t [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n" "\t [DML only] [disable_metacommands]: Options: 'true', 'false', \n" "\t [DML only] [enable_dynamic_graph_fusion]: Options: 'true', 'false', \n" "\t [DML only] [enable_graph_serialization]: Options: 'true', 'false', \n" + "\n" "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" "\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n" "\t [OpenVINO only] [num_of_threads]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n" "\t [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n" + "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" + "\n" "\t [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/folderpath/libQnnCpu.so'.\n" "\t [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n" "\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n" @@ -89,9 +94,8 @@ namespace perftest { "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n" "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). 
\n" "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" - "\t [Usage]: -e -i '| |'\n\n" - "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" - "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n\n" + "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n" + "\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" "\t [TensorRT only] [trt_min_subgraph_size]: Minimum size of TensorRT subgraphs.\n" "\t [TensorRT only] [trt_max_workspace_size]: Set TensorRT maximum workspace size in byte.\n" @@ -108,20 +112,23 @@ namespace perftest { "\t [TensorRT only] [trt_force_sequential_engine_build]: Force TensorRT engines to be built sequentially.\n" "\t [TensorRT only] [trt_context_memory_sharing_enable]: Enable TensorRT context memory sharing between subgraphs.\n" "\t [TensorRT only] [trt_layer_norm_fp32_fallback]: Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow.\n" - "\t [Usage]: -e -i '| |'\n\n" - "\t [Example] [For TensorRT EP] -e tensorrt -i 'trt_fp16_enable|true trt_int8_enable|true trt_int8_calibration_table_name|calibration.flatbuffers trt_int8_use_native_calibration_table|false trt_force_sequential_engine_build|false'\n" + "\t [Example] [For TensorRT EP] -e tensorrt -i 'trt_fp16_enable|true trt_int8_enable|true trt_int8_calibration_table_name|calibration.flatbuffers trt_int8_use_native_calibration_table|false trt_force_sequential_engine_build|false'\n" + "\n" "\t [NNAPI only] [NNAPI_FLAG_USE_FP16]: Use fp16 relaxation in NNAPI EP..\n" "\t [NNAPI only] [NNAPI_FLAG_USE_NCHW]: Use the NCHW layout in NNAPI EP.\n" "\t [NNAPI only] [NNAPI_FLAG_CPU_DISABLED]: Prevent NNAPI from using CPU devices.\n" "\t [NNAPI only] [NNAPI_FLAG_CPU_ONLY]: Using CPU only in NNAPI EP.\n" - "\t [Usage]: -e -i ' '\n\n" - "\t [Example] [For NNAPI EP] -e nnapi -i \" NNAPI_FLAG_USE_FP16 NNAPI_FLAG_USE_NCHW NNAPI_FLAG_CPU_DISABLED \"\n" + "\t [Example] [For NNAPI EP] -e nnapi -i \"NNAPI_FLAG_USE_FP16 NNAPI_FLAG_USE_NCHW NNAPI_FLAG_CPU_DISABLED\"\n" + "\n" + "\t [CoreML only] [COREML_FLAG_CREATE_MLPROGRAM]: Create an ML Program model instead of Neural Network.\n" + "\t [Example] [For CoreML EP] -e coreml -i \"COREML_FLAG_CREATE_MLPROGRAM\"\n" + "\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" "\t [SNPE only] [priority]: execution priority, options: 'low', 'normal'. \n" "\t [SNPE only] [buffer_type]: options: 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. default: ITENSOR'. \n" "\t [SNPE only] [enable_init_cache]: enable SNPE init caching feature, set to 1 to enabled it. Disabled by default. 
\n" - "\t [Usage]: -e -i '| |' \n\n" - "\t [Example] [For SNPE EP] -e snpe -i \"runtime|CPU priority|low\" \n\n" + "\t [Example] [For SNPE EP] -e snpe -i \"runtime|CPU priority|low\" \n\n" + "\n" "\t-T [Set intra op thread affinities]: Specify intra op thread affinity string\n" "\t [Example]: -T 1,2;3,4;5,6 or -T 1-2;3-4;5-6 \n" "\t\t Use semicolon to separate configuration between threads.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 1934314b8ce43..9679ca6159464 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -468,7 +468,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); nnapi_flags |= NNAPI_FLAG_CPU_ONLY; } else if (key.empty()) { } else { - ORT_THROW("[ERROR] [NNAPI] wrong key type entered. Choose from the following runtime key options that are available for NNAPI. ['NNAPI_FLAG_USE_FP16', 'NNAPI_FLAG_USE_NCHW', 'NNAPI_FLAG_CPU_DISABLED', 'NNAPI_FLAG_CPU_ONLY'] \n"); + ORT_THROW( + "[ERROR] [NNAPI] wrong key type entered. Choose from the following runtime key options " + "that are available for NNAPI. " + "['NNAPI_FLAG_USE_FP16', 'NNAPI_FLAG_USE_NCHW', 'NNAPI_FLAG_CPU_DISABLED', 'NNAPI_FLAG_CPU_ONLY'] \n"); } } Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, nnapi_flags)); @@ -476,10 +479,31 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); ORT_THROW("NNAPI is not supported in this build\n"); #endif } else if (provider_name_ == onnxruntime::kCoreMLExecutionProvider) { +#ifdef __APPLE__ #ifdef USE_COREML - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, 0)); + uint32_t coreml_flags = 0; + std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; + std::istringstream ss(ov_string); + + std::string key; + while (ss >> key) { + if (key == "COREML_FLAG_CREATE_MLPROGRAM") { + coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM; + std::cout << "Enabling ML Program.\n"; + } else if (key.empty()) { + } else { + ORT_THROW( + "[ERROR] [CoreML] wrong key type entered. Choose from the following runtime key options " + "that are available for CoreML. ['COREML_FLAG_CREATE_MLPROGRAM'] \n"); + } + } + // COREML_FLAG_CREATE_MLPROGRAM + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags)); +#else + ORT_THROW("CoreML is not supported in this build\n"); +#endif #else - ORT_THROW("COREML is not supported in this build\n"); + ORT_THROW("COREML is not supported on this platform.\n"); #endif } else if (provider_name_ == onnxruntime::kDmlExecutionProvider) { #ifdef USE_DML diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index 7b6f1b9244be9..94817158017bd 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -192,5 +192,25 @@ TEST(CoreMLExecutionProviderTest, TestOrtFormatModel) { #endif } +// Test that we fix invalid names in model inputs, initializers and outputs. 
+// Names in CoreML cannot start with [0-9] or contain anything but "[a-z][A-Z][0-9]_" +TEST(CoreMLExecutionProviderTest, TestNameSanitization) { + OpTester test("Clip", 11); + + std::vector dims{3, 3}; + test.AddInput("0", dims, + {-1.0f, 0.0f, 1.0f, + -6.0f, 0.0f, 6.0f, + -5.4f, 2.0f, 6.0f}); + test.AddInput("1.min", {}, {-5}, true); // add as initializers + test.AddInput("2/max", {}, {5}, true); + test.AddOutput("3", dims, + {-1.0f, 0.0f, 1.0f, + -5.0f, 0.0f, 5.0f, + -5.0f, 2.0f, 5.0f}); + + // TensorRT does not support Clip opset 11 yet. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/clip_test.cc b/onnxruntime/test/providers/cpu/math/clip_test.cc index efb46e86d04e4..b5d5f84df950a 100644 --- a/onnxruntime/test/providers/cpu/math/clip_test.cc +++ b/onnxruntime/test/providers/cpu/math/clip_test.cc @@ -182,7 +182,7 @@ TEST(MathOpTest, Clip) { run_test(true); } -// Use clip between [0, 6] as Relu6 (for some EPs, such as NNAPI) +// Use clip between [0, 6] as Relu6 to test optimized path in some EPs, such as NNAPI and CoreML TEST(MathOpTest, Clip_Relu6) { // To test NNAPI EP, we need the min/max to be in initializers auto run_test = [](bool min_max_are_initializer) { @@ -208,6 +208,31 @@ TEST(MathOpTest, Clip_Relu6) { run_test(true); } +// Use clip between [0, inf] as Relu to test optimized path in some EPs, such as CoreML +TEST(MathOpTest, Clip_Relu) { + // To test NNAPI EP, we need the min/max to be in initializers + auto run_test = [](bool min_max_are_initializer) { + OpTester test("Clip", 11); + + std::vector dims{3, 3}; + test.AddInput("X", dims, + {-1.0f, 0.0f, 1.0f, + -6.0f, 3.5f, 6.0f, + -5.4f, 2.0f, 8.0f}); + test.AddInput("min", {}, {0.0f}, min_max_are_initializer); + test.AddOutput("Y", dims, + {0.0f, 0.0f, 1.0f, + 0.0f, 3.5f, 6.0f, + 0.0f, 2.0f, 8.0f}); + + // TensorRT does not support Clip opset 11 yet. 
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + }; + + run_test(false); + run_test(true); +} + // Use clip between [-1, 1] as Relu1 (for some EPs, such as NNAPI) TEST(MathOpTest, Clip_Relu1) { // To test NNAPI EP, we need the min/max to be in initializers diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index bf089e083d67e..428925e154497 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -281,24 +281,31 @@ using GemmOpTypedTestsTypes = ::testing::Types; TYPED_TEST_SUITE(GemmOpTypedTests, GemmOpTypedTestsTypes); TYPED_TEST(GemmOpTypedTests, TestGemmScalarBroadcast) { - OpTester test("Gemm"); + auto run_test = [](bool b_is_initializer, bool c_is_initializer) { + OpTester test("Gemm"); - test.AddAttribute("transA", (int64_t)0); - test.AddAttribute("transB", (int64_t)0); - test.AddAttribute("alpha", 1.0f); - test.AddAttribute("beta", 1.0f); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); - test.AddInput("A", {2, 4}, - {static_cast(1.0f), static_cast(2.0f), static_cast(3.0f), static_cast(4.0f), - static_cast(-1.0f), static_cast(-2.0f), static_cast(-3.0f), static_cast(-4.0f)}); - test.AddInput("B", {4, 3}, std::vector(12, static_cast(1.0f))); - test.AddInput("C", {1}, std::vector{static_cast(1.0f)}); - test.AddOutput("Y", {2, 3}, - {static_cast(11.0f), static_cast(11.0f), static_cast(11.0f), - static_cast(-9.0f), static_cast(-9.0f), static_cast(-9.0f)}); - test.Config(run_with_tunable_op) - .RunWithConfig(); + test.AddInput("A", {2, 4}, + {static_cast(1.0f), static_cast(2.0f), static_cast(3.0f), static_cast(4.0f), + static_cast(-1.0f), static_cast(-2.0f), static_cast(-3.0f), static_cast(-4.0f)}); + test.AddInput("B", {4, 3}, std::vector(12, static_cast(1.0f)), b_is_initializer); + test.AddInput("C", {1}, std::vector{static_cast(1.0f)}, c_is_initializer); + test.AddOutput("Y", {2, 3}, + {static_cast(11.0f), static_cast(11.0f), static_cast(11.0f), + static_cast(-9.0f), static_cast(-9.0f), static_cast(-9.0f)}); + test.Config(run_with_tunable_op) + .RunWithConfig(); + }; + + run_test(false, false); + // CoreML EP requires weight and bias to be initializers + run_test(true, true); } + TYPED_TEST(GemmOpTypedTests, TestGemm2DBroadcast_2) { OpTester test("Gemm"); diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index ee18cf2cea6cb..cbb4531a50b7c 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -75,6 +75,43 @@ TEST(BatchNormTest, PositiveTestCase) { input_data_map.insert({"mean", mean}); input_data_map.insert({"var", var}); + InputShapesMap input_shapes_map; + vector input_shape{1, 1, 7, 7}; + input_shapes_map.insert({"X", input_shape}); + input_shapes_map.insert({"scale", {1}}); + input_shapes_map.insert({"B", {1}}); + input_shapes_map.insert({"mean", {1}}); + input_shapes_map.insert({"var", {1}}); + + auto expected_output = {1.01359f, 0.703983f, 0.641631f, 1.08571f, 0.939167f, 0.762469f, 0.682729f, 0.762401f, 0.787021f, + 1.06744f, 0.604378f, 0.957476f, 0.667302f, 0.901764f, 1.07566f, 1.01117f, 0.928324f, 0.897667f, + 0.705842f, 0.660885f, 0.977291f, 0.878918f, 0.818345f, 1.06608f, 0.839057f, 1.04796f, 0.621471f, + 0.781831f, 0.760527f, 0.835665f, 1.05825f, 
0.611442f, 0.781873f, 1.08437f, 0.907454f, 0.926173f, + 1.03375f, 0.707961f, 0.968646f, 0.621757f, 0.973095f, 0.700301f, 0.916723f, 0.807602f, 0.692598f, + 0.621972f, 0.707334f, 0.63723f, 0.63062f}; + float epsilon = 1e-05f; + TestBatchNorm(input_data_map, input_shapes_map, epsilon, expected_output, input_shape); +} + +TEST(BatchNormTest, PositiveTestCase_5D) { + // This input was taken from the SpatialBN_1.pb, SpatialBN_1_input.pb and SpatialBN_1_output.pb files. + vector X{0.329876f, -0.287158f, -0.411425f, 0.473621f, 0.18156f, -0.170596f, -0.329516f, -0.170733f, -0.121664f, 0.4372f, + -0.485668f, 0.218049f, -0.360263f, 0.107016f, 0.45358f, 0.325056f, 0.15995f, 0.098852f, -0.283453f, -0.373051f, + 0.257542f, 0.0614853f, -0.0592363f, 0.434488f, -0.0179583f, 0.398374f, -0.451602f, -0.132009f, -0.174468f, + -0.0247169f, 0.418897f, -0.47159f, -0.131925f, 0.470943f, 0.118357f, 0.155664f, 0.370062f, -0.279229f, 0.240311f, + -0.451034f, 0.249178f, -0.294496f, 0.13683f, -0.0806475f, -0.309849f, -0.450604f, -0.28048f, -0.420197f, -0.433369f}; + vector scale{0.589433f}; + vector B{-0.384622f}; + vector mean{-2.45673f}; + vector var{1.37998f}; + + InputDataMap input_data_map; + input_data_map.insert({"X", X}); + input_data_map.insert({"scale", scale}); + input_data_map.insert({"B", B}); + input_data_map.insert({"mean", mean}); + input_data_map.insert({"var", var}); + InputShapesMap input_shapes_map; vector input_shape{1, 1, 7, 7, 1}; input_shapes_map.insert({"X", input_shape}); diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 1d31f3fdb4eb4..5addb5dd9ce46 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -572,8 +572,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDmlExecutionProvider}); } -TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers +TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric_scales) { + // To test CoreML/NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; From acbfc29f272b5578145e7600bc42342e116ffbc2 Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 1 Mar 2024 10:57:14 +0800 Subject: [PATCH 092/237] Follow up fix for Gelu impl (#19693) ### Follow up fix for Gelu impl There are two minor comments in https://github.com/microsoft/onnxruntime/pull/19560. Fix them in this pull request. ### Motivation and Context --- docs/ORTModule_Training_Guidelines.md | 2 +- onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc | 8 +++----- onnxruntime/contrib_ops/cuda/bert/fast_gelu.h | 4 +++- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 91057d3dfb120..f50b18b736936 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -293,7 +293,7 @@ A classical usage of disabling the deep copy: when the deep copy before module e export ORTMODULE_MEMORY_OPT_LEVEL=0 ``` -### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT +#### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT - **Feature Area**: *ORTMODULE/Optimizations* - **Description**: By default, the memory-efficient gradient management is turned off. 
The gradient after it is computed in ONNX Runtime, will trigger the corresponding parameter's backward function through `PythonOpGrad` operator. This would help release the gradient buffer managed in ONNX Runtime, which originally is released once all backward computation finishes. diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index e8974a29476b6..8b8e4e267f895 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -8,8 +8,7 @@ #include "contrib_ops/cpu/bert/bias_gelu_helper.h" #ifdef USE_ROCM #include "contrib_ops/rocm/bert/elementwise.h" -#endif -#ifdef USE_CUDA +#else #include "contrib_ops/cuda/bert/transformer_common.h" #endif @@ -36,7 +35,7 @@ using namespace ONNX_NAMESPACE; template FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { -#ifdef USE_CUDA +#ifndef USE_ROCM const TransformerOptions* options = TransformerOptions::GetInstance(); use_half2_ = !options->DisableHalf2(); #endif @@ -63,8 +62,7 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { reinterpret_cast(input->Data()), static_cast(input_length), (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, static_cast(bias_length), reinterpret_cast(output->MutableData())); -#endif -#ifdef USE_CUDA +#else return LaunchFastGeluKernel(GetDeviceProp(), Stream(context), static_cast(input_length), diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h index d563556593e6e..26f3bd5a03928 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h @@ -18,7 +18,9 @@ class FastGelu final : public CudaKernel { Status ComputeInternal(OpKernelContext* ctx) const override; private: - bool use_half2_; // Only applicable to CUDA kernel (not ROCM). +#ifndef USE_ROCM + bool use_half2_; +#endif }; } // namespace cuda From ed550b5fe5aa41e182db84d2b2f2fb768121fd7a Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 29 Feb 2024 20:36:29 -0800 Subject: [PATCH 093/237] Change webgpu CI pipeline to use a preinstalled chrome (#19729) ### Description Change webgpu CI pipeline to use a preinstalled chrome. Hopefully it can increase the stability. Now the chrome got from puppeteer often failed to start. 
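Background for the change below: the karma Chrome launcher used by the web tests resolves the browser executable from the `CHROME_BIN` environment variable, so pointing the new pipeline variable directly at the preinstalled browser replaces the removed puppeteer lookup. The snippet below only illustrates that resolution order (environment variable first, then the default Windows install path used by the pipeline); it is not code from this repository, and the helper name is invented.

```ts
// resolve-chrome.ts - standalone illustration, not part of the ORT test runner.
import { existsSync } from "node:fs";

// Default per-machine install location; matches the CHROME_BIN value set in win-web-ci.yml.
const DEFAULT_WINDOWS_CHROME = "C:\\Program Files\\Google\\Chrome\\Application\\chrome.exe";

export function resolveChromeBinary(): string {
  // An explicitly configured binary (what the pipeline variable provides) wins.
  const fromEnv = process.env.CHROME_BIN;
  if (fromEnv && existsSync(fromEnv)) {
    return fromEnv;
  }
  // Otherwise fall back to the preinstalled browser.
  if (existsSync(DEFAULT_WINDOWS_CHROME)) {
    return DEFAULT_WINDOWS_CHROME;
  }
  throw new Error("Chrome not found; set CHROME_BIN to the browser executable.");
}
```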
--- .../github/azure-pipelines/templates/win-web-ci.yml | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 043da233cc674..b882d6fb167fd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -31,6 +31,7 @@ jobs: variables: webgpuCommandlineExtraFlags: '--chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de' runCodesignValidationInjection: false + CHROME_BIN: 'C:\Program Files\Google\Chrome\Application\chrome.exe' timeoutInMinutes: 60 workspace: clean: all @@ -95,18 +96,6 @@ jobs: targetFolder: $(Build.SourcesDirectory)\js\web\lib\wasm\binding flattenFolders: true displayName: 'Binplace js files' - - script: | - npm i -g puppeteer - workingDirectory: '$(Build.SourcesDirectory)' - displayName: 'Use puppeteer to prepare Chrome for tests' - - script: | - FOR /F "tokens=* USEBACKQ" %%F IN (`where /r %HOMEDRIVE%%HOMEPATH%\.cache\puppeteer chrome.exe`) DO ( - SET var=%%F - ECHO found chrome.exe: %%F - ) - ECHO ##vso[task.setvariable variable=CHROME_BIN;]%var% - workingDirectory: '$(Build.SourcesDirectory)' - displayName: 'Set CHROME_BIN' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)\js' From 5672cdebdf5648815fcc3a001dc00e610a9f9b51 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 1 Mar 2024 11:01:58 -0800 Subject: [PATCH 094/237] Update google benchmark to 1.8.3. (#19734) Update google benchmark to 1.8.3. Update deps_update_and_upload.py script to make it easier to use. --- cgmanifests/generated/cgmanifest.json | 2 +- cmake/deps.txt | 2 +- cmake/deps_update_and_upload.py | 135 ++++++++++++------ .../templates/download-deps.yml | 4 +- 4 files changed, 98 insertions(+), 45 deletions(-) diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index efd901787fdb7..cfad59be6b4c0 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -116,7 +116,7 @@ "component": { "type": "git", "git": { - "commitHash": "361e8d1cfe0c6c36d30b39f1b61302ece5507320", + "commitHash": "344117638c8ff7e239044fd0fa7085839fc03021", "repositoryUrl": "https://github.com/google/benchmark.git" }, "comments": "google_benchmark" diff --git a/cmake/deps.txt b/cmake/deps.txt index cb431f8c77397..9cba25b00157d 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -26,7 +26,7 @@ eigen;https://gitlab.com/libeigen/eigen/-/archive/e7248b26a1ed53fa030c5c459f7ea0 flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;ba0a75fd12dbef8f6557a74e611b7a3d0c5fe7bf fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 -google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908 +google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.3.zip;bf9870756ee3f8d2d3b346b24ee3600a41c74d3d google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034 googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 diff --git a/cmake/deps_update_and_upload.py b/cmake/deps_update_and_upload.py index d357284d91225..63df3f6f03869 100644 --- a/cmake/deps_update_and_upload.py +++ b/cmake/deps_update_and_upload.py @@ -1,56 +1,109 @@ -# in case deps.txt is updated, run this file to update and upload the dependencies so that CI can use them. -# Before running the script, increase the version number found at: +# If deps.txt is updated, run this file to update and upload the dependencies so that CI can use them. +# +# Before running the script, find the latest version number at: # https://aiinfra.visualstudio.com/Lotus/_artifacts/feed/Lotus/UPack/onnxruntime_build_dependencies/versions +# Increment it to obtain a new version number to use. +# # Run without --do-upload once to verify downloading. Use --do-upload when you are ready to publish. -# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --do-upload -# update version number in tools\ci_build\github\azure-pipelines\templates\download-deps.yml +# E.g.: +# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 +# # check contents of C:/temp/onnxruntime_deps +# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --no-download --do-upload +# +# Next, update the version number in tools/ci_build/github/azure-pipelines/templates/download-deps.yml. + +import argparse +import contextlib +import pathlib import re import subprocess -import os -import argparse import tempfile +script_dir = pathlib.Path(__file__).parent + parser = argparse.ArgumentParser(description="Update dependencies and publish to Azure Artifacts") parser.add_argument( - "--root-path", type=str, default=tempfile.gettempdir(), help="Target root path for downloaded files" + "--root-path", + type=pathlib.Path, + help="Target root path for downloaded files. If not provided, a temporary directory is used.", +) +parser.add_argument( + "--version", + type=str, + help="Package version to publish", +) +parser.add_argument( + "--do-upload", + action="store_true", + dest="upload", + help="Upload the package to Azure Artifacts", +) +parser.add_argument( + "--no-download", + action="store_false", + dest="download", + help="Skip downloading the dependency files. " + "Use with '--do-upload' and '--root-path' to upload the package from existing dependency files.", ) -parser.add_argument("--version", type=str, default="1.0.82", help="Package version to publish") -parser.add_argument("--do-upload", action="store_true", help="Upload the package to Azure Artifacts") args = parser.parse_args() -with open("cmake/deps.txt") as file: +if args.upload: + assert args.version is not None, "'--version' must be specified if uploading." + +if args.upload != args.download: + assert args.root_path is not None, "'--root-path' must be specified if only downloading or uploading." 
+ +deps_path = script_dir / "deps.txt" +with open(deps_path) as file: text = file.read() lines = [line for line in text.split("\n") if not line.startswith("#") and ";" in line] -root_path = args.root_path - -for line in lines: - url = re.sub("^[^;]+?;https://([^;]+?);.*", r"https://\1", line) - filename = re.sub("^[^;]+?;https://([^;]+?);.*", r"\1", line) - full_path = os.path.join(root_path, filename) - subprocess.run(["curl", "-sSL", "--create-dirs", "-o", full_path, url]) # noqa: PLW1510 - -package_name = "onnxruntime_build_dependencies" -version = args.version - -# Check if the user is logged in to Azure -result = subprocess.run("az account show", shell=True, capture_output=True, text=True) # noqa: PLW1510 -if "No subscriptions found" in result.stderr: - # Prompt the user to log in to Azure - print("You are not logged in to Azure. Please log in to continue.") - subprocess.run("az login", shell=True) # noqa: PLW1510 - -# Publish the package to Azure Artifacts if --no-upload is not specified - -cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' -if args.do_upload: - subprocess.run(cmd, shell=True) # noqa: PLW1510 -else: - print("would have run: " + cmd) - -cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' -if args.do_upload: - subprocess.run(cmd, shell=True) # noqa: PLW1510 -else: - print("would have run: " + cmd) +with contextlib.ExitStack() as context_stack: + if args.root_path is not None: + root_path = args.root_path.resolve() + root_path.mkdir(parents=True, exist_ok=True) + else: + temp_dir_name = context_stack.enter_context(tempfile.TemporaryDirectory()) + root_path = pathlib.Path(temp_dir_name) + + if args.download: + print(f"Downloading dependencies to directory: {root_path}") + + dep_pattern = re.compile(r"^[^;]+;https://([^;]+);.*$") + + for line in lines: + match = dep_pattern.fullmatch(line) + if match is None: + continue + + dep_path = match[1] + url = f"https://{dep_path}" + full_path = root_path / dep_path + + subprocess.run(["curl", "-sSL", "--create-dirs", "-o", str(full_path), url], check=True) + + package_name = "onnxruntime_build_dependencies" + version = args.version if args.version is not None else "VERSION_PLACEHOLDER" + + if args.upload: + # Check if the user is logged in to Azure + result = subprocess.run("az account show", shell=True, capture_output=True, text=True, check=False) + if "No subscriptions found" in result.stderr: + # Prompt the user to log in to Azure + print("You are not logged in to Azure. 
Please log in to continue.") + subprocess.run("az login", shell=True, check=True) + + # Publish the package to Azure Artifacts if --do-upload is specified + + cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' + if args.upload: + subprocess.run(cmd, shell=True, check=True) + else: + print("would have run: " + cmd) + + cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' + if args.upload: + subprocess.run(cmd, shell=True, check=True) + else: + print("would have run: " + cmd) diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 95e34cd863915..01be343795a56 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.133 + version: 1.0.134 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.133 + version: 1.0.134 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. From 22176a5fa8fe97efe05a63c1e7bb89b0e54cd201 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Fri, 1 Mar 2024 13:44:29 -0800 Subject: [PATCH 095/237] disable gemm f16 on CPU (#19744) ### Description Temporarily disable fp16 gemm on CPU because it usually needs a following Cast which offsets the gain. Need more fp16 operators implementation and performance tuning. Also fix a fusion error of LayerNormalization. ### Motivation and Context --- .vscode/settings.json | 5 ++++- .../core/optimizer/layer_norm_fusion.cc | 14 +++++++++++++ .../providers/cpu/cpu_execution_provider.cc | 21 ------------------- .../test/providers/cpu/math/gemm_test.cc | 2 +- 4 files changed, 19 insertions(+), 23 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 3e2b1f31dd6cf..98d23090fd474 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -21,5 +21,8 @@ "cpplint.filters": [ "-build/include_subdir", "-runtime/references" - ] + ], + "files.associations": { + "span": "cpp" + } } diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index b6ad4fde6c1f7..ce696154adb6d 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -447,6 +447,13 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, NodeArg* x_input = has_leading_cast ? 
graph.GetNode(p_reduce_mean_input_node->Index())->MutableInputDefs()[0] : reduce_mean_node.MutableInputDefs()[0]; + + // CPU doesn't support fp16 + if (reduce_mean_node.GetExecutionProviderType() == kCpuExecutionProvider && + x_input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + continue; + } + InlinedVector layer_norm_input_defs{x_input, scale, bias}; Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName("LayerNormalization"), "LayerNormalization", @@ -689,6 +696,13 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr NodeArg* x_input = has_leading_cast ? graph.GetNode(p_pow_input_node->Index())->MutableInputDefs()[0] : pow_node.MutableInputDefs()[0]; + + // CPU doesn't support fp16 + if (reduce_mean_node.GetExecutionProviderType() == kCpuExecutionProvider && + x_input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + continue; + } + InlinedVector layer_norm_input_defs{x_input, scale}; Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName("SimplifiedLayerNormalization"), "SimplifiedLayerNormalization", diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 37e7e42150413..7e0f919deb0a7 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -143,9 +143,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Aco class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm); -#endif class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); @@ -335,9 +332,6 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm); -#endif class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); @@ -497,9 +491,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Sp class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Gemm); -#endif class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift); @@ -606,9 +597,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm); -#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul); @@ -2617,15 +2605,6 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) { MLFloat16, LeakyRelu)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 428925e154497..1a542fb67418e 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -277,7 +277,7 @@ class GemmOpTypedTests : public ::testing::Test { // On CPUs without fp16 instructions the tests will output a warning: // "registered execution providers CPUExecutionProvider were unable to run the model" // , then they will still pass. -using GemmOpTypedTestsTypes = ::testing::Types; +using GemmOpTypedTestsTypes = ::testing::Types; TYPED_TEST_SUITE(GemmOpTypedTests, GemmOpTypedTestsTypes); TYPED_TEST(GemmOpTypedTests, TestGemmScalarBroadcast) { From f06164ef8b8de42dd67ca2137f6996cdc87a3f72 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 1 Mar 2024 14:50:06 -0800 Subject: [PATCH 096/237] [js/web] transfer input buffer back to caller thread (#19677) ### Description When using proxy worker, input buffers should be transferred back to the caller thread after `run()` call is done. 
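Background: when a tensor's `ArrayBuffer` is passed to `postMessage` via the transfer list, it is detached on the sending side. The proxy transfers the input buffers to the worker for `run()`, so unless the worker lists them again (together with the outputs) in the transfer list of its reply, the caller is left holding zero-length buffers. The snippet below is a standalone illustration of the transfer semantics, not the actual proxy-worker code; `extractTransferableBuffers` in the trailing comment is the real helper touched by this patch.

```ts
// Standalone illustration of ArrayBuffer transfer semantics (Node >= 17 or any modern browser).
const input = new Float32Array([1, 2, 3, 4]);

// Transferring a buffer (as postMessage(msg, [buffer]) does) detaches it on the sender side.
const received = structuredClone(input.buffer, { transfer: [input.buffer] });

console.log(input.buffer.byteLength); // 0  - the sender can no longer read its own data
console.log(received.byteLength);     // 16 - the receiving side now owns it

// Hence the fix: the worker's reply lists the input buffers as well as the outputs, e.g.
//   postMessage({ type, out: outputs },
//               extractTransferableBuffers([...inputs, ...outputs]));
// so ownership of the inputs returns to the caller thread.
```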
Fixes #19488 --- js/web/lib/wasm/proxy-worker/main.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index 6cbd38c76ccc8..3ce37a2d6b652 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -103,7 +103,7 @@ self.onmessage = (ev: MessageEvent): void => { } else { postMessage( {type, out: outputs} as OrtWasmMessage, - extractTransferableBuffers(outputs as SerializableTensorMetadata[])); + extractTransferableBuffers([...inputs, ...outputs] as SerializableTensorMetadata[])); } }, err => { From a0521f899e9d495d57ae044bd4a1fe4d17155782 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 1 Mar 2024 16:23:20 -0800 Subject: [PATCH 097/237] Enable CPUINFO for all Windows build (#19655) ### Description It was disabled in PR #9065. And the reason was: " api-ms-win-core-kernel32-legacy-*.dll wasn't available in Windows 8 and was added in Windows 10, so cpuinfo breaks our Windows 8 support. I'm disabling it again." We no longer support Windows 8. Therefore we can add CPUINFO back. ### Motivation and Context To make the code simpler. If in any case the library doesn't work as expected, we can submit a PR to their code base and fix it. --- .../external/onnxruntime_external_deps.cmake | 9 +- cmake/onnxruntime_common.cmake | 5 -- onnxruntime/core/common/cpuid_info.cc | 82 ++++++++----------- onnxruntime/core/common/cpuid_info.h | 19 ++--- 4 files changed, 42 insertions(+), 73 deletions(-) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 09d57164b4ee1..cb75b0b8751bb 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -256,14 +256,7 @@ if (onnxruntime_ENABLE_CPUINFO) set(CPUINFO_SUPPORTED TRUE) endif() if (WIN32) - # Exclude Windows ARM build and Windows Store - if (${onnxruntime_target_platform} MATCHES "^(ARM.*|arm.*)$" ) - message(WARNING "Cpuinfo not included for compilation problems with Windows ARM.") - set(CPUINFO_SUPPORTED FALSE) - elseif (WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) - message(WARNING "Cpuinfo not included non-Desktop builds") - set(CPUINFO_SUPPORTED FALSE) - endif() + set(CPUINFO_SUPPORTED TRUE) elseif (NOT ${onnxruntime_target_platform} MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$") message(WARNING "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo. " diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 6b8c2560b1714..fb56e3f3445d4 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -201,10 +201,6 @@ endif() if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) - if((WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) OR ((ARM64 OR ARM) AND MSVC)) - # msvc compiler report syntax error with cpuinfo arm source files - # and cpuinfo does not have code for getting arm uarch info under windows - else() # Link cpuinfo if supported # Using it mainly in ARM with Android. # Its functionality in detecting x86 cpu features are lacking, so is support for Windows. 
@@ -212,7 +208,6 @@ if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo ${ONNXRUNTIME_CLOG_TARGET_NAME}) endif() - endif() endif() if (NOT onnxruntime_BUILD_SHARED_LIB) diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 711fd595e90fd..be881f6bc4bc2 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -52,6 +52,13 @@ #if defined(CPUINFO_SUPPORTED) #include +#if defined(CPUIDINFO_ARCH_ARM) +namespace onnxruntime { +// The following function is declared in "core/common/cpuid_uarch.h" but we cannot include the whole header file because +// some of its symbols are conflict with +void decodeMIDR(uint32_t midr, uint32_t uarch[1]); +} // namespace onnxruntime +#endif #else #include "core/common/cpuid_uarch.h" #endif // CPUINFO_SUPPORTED @@ -142,11 +149,6 @@ void CPUIDInfo::ArmLinuxInit() { // Pytorch CPUINFO only works on ARM linux or android // Assuming no hyper-threading, no NUMA groups #ifdef CPUINFO_SUPPORTED - pytorch_cpuinfo_init_ = cpuinfo_initialize(); - if (!pytorch_cpuinfo_init_) { - LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features."; - return; - } is_hybrid_ = cpuinfo_get_uarchs_count() > 1; has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); @@ -239,52 +241,24 @@ void CPUIDInfo::ArmWindowsInit() { lastUarch = uarch; } } - - switch (lastUarch) { - case cpuinfo_uarch_cortex_a55: - case cpuinfo_uarch_cortex_a55r0: - case cpuinfo_uarch_cortex_a76: - case cpuinfo_uarch_neoverse_n1: - case cpuinfo_uarch_cortex_a77: - case cpuinfo_uarch_exynos_m4: - case cpuinfo_uarch_exynos_m5: - has_fp16_ = true; - break; - default: - break; - } - if (!has_fp16_) { - /* - * Detecting fp16 support. Different cores should have the same instruction set. 
- * So we just check the first ID_AA64PFR0_EL1 - * Op0(0b11), Op1(0b000), CRn(0b0000), CRm(0b0100), Op2(0b000), - */ - uint64_t ID_AA64PFR0_EL1; - unsigned long valsize = sizeof(uint64_t); - auto retCode = ::RegGetValueA( - HKEY_LOCAL_MACHINE, - "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", - "CP 4020", RRF_RT_REG_QWORD, nullptr, - &ID_AA64PFR0_EL1, &valsize); - if (retCode == ERROR_SUCCESS) { - // AdvSIMD, bits [23:20] - auto advSimd = ID_AA64PFR0_EL1 >> 20; - if ((advSimd & 0xfULL) == 1) { - has_fp16_ = true; - } - } - } #endif /* Application Family or OneCore Family */ has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); #else has_arm_neon_dot_ = false; #endif - has_fp16_ |= has_arm_neon_dot_; - /* TODO: implement them when hw+sw is available for testing these features */ - has_arm_neon_i8mm_ = false; - has_arm_sve_i8mm_ = false; - has_arm_neon_bf16_ = false; + + if (pytorch_cpuinfo_init_) { + has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); + has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + } else { + has_fp16_ = false; + has_arm_neon_i8mm_ = false; + has_arm_sve_i8mm_ = false; + has_arm_neon_bf16_ = false; + } } #endif /* (arm or arm64) and windows */ @@ -304,5 +278,21 @@ uint32_t CPUIDInfo::GetCurrentCoreIdx() const { return 0xFFFFFFFF; // don't know how to get core index #endif } - +CPUIDInfo::CPUIDInfo() { +#ifdef CPUIDINFO_ARCH_X86 + X86Init(); +#elif defined(CPUIDINFO_ARCH_ARM) +#if CPUINFO_SUPPORTED + pytorch_cpuinfo_init_ = cpuinfo_initialize(); + if (!pytorch_cpuinfo_init_) { + LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features."; + } +#endif +#ifdef __linux__ + ArmLinuxInit(); +#elif defined(_WIN32) + ArmWindowsInit(); +#endif /* (arm or arm64) and windows */ +#endif +} } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 2f8041e39f680..a3936b4bd11a6 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -93,17 +93,7 @@ class CPUIDInfo { } private: - CPUIDInfo() { -#ifdef CPUIDINFO_ARCH_X86 - X86Init(); -#elif defined(CPUIDINFO_ARCH_ARM) -#ifdef __linux__ - ArmLinuxInit(); -#elif defined(_WIN32) - ArmWindowsInit(); -#endif /* (arm or arm64) and windows */ -#endif - } + CPUIDInfo(); bool has_amx_bf16_{false}; bool has_avx_{false}; bool has_avx2_{false}; @@ -131,11 +121,13 @@ class CPUIDInfo { #ifdef CPUIDINFO_ARCH_X86 void X86Init(); - #elif defined(CPUIDINFO_ARCH_ARM) + // Now the following var is only used in ARM build, but later one we may expand the usage. + bool pytorch_cpuinfo_init_{false}; +#endif + #ifdef __linux__ - bool pytorch_cpuinfo_init_{false}; void ArmLinuxInit(); #elif defined(_WIN32) @@ -143,7 +135,6 @@ class CPUIDInfo { void ArmWindowsInit(); #endif /* (arm or arm64) and windows */ -#endif }; } // namespace onnxruntime From de3158e78d09992e4b5085c15da44108d9c6fa83 Mon Sep 17 00:00:00 2001 From: zesongw Date: Sat, 2 Mar 2024 08:55:50 +0800 Subject: [PATCH 098/237] [WebNN EP] Add contraints for MatMul (#19713) ### Description Add constraints to MatMul: - The input must be at least 2D. - CPU backend: The input rank must be the same. - CPU backend: The input shape except for the last two axis must be the same. ### Motivation and Context Prevent regression for some models. 
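In other words: both MatMul inputs must be at least 2-D, and on the WebNN CPU (XNNPACK) backend the two inputs must additionally have the same rank and identical batch dimensions (all axes except the last two), i.e. no broadcasting. A minimal TypeScript sketch of just these new checks follows; the real implementation is the C++ change to `IsOpSupportedImpl` in `gemm_op_builder.cc` below, and the sketch deliberately omits the rest of MatMul validation.

```ts
// Sketch of the new MatMul support checks only; not the EP code itself.
function isMatMulSupported(aShape: number[], bShape: number[], isCpuBackend: boolean): boolean {
  // Both inputs must be at least 2-D.
  if (aShape.length < 2 || bShape.length < 2) {
    return false;
  }
  if (isCpuBackend) {
    // CPU backend: ranks must match ...
    if (aShape.length !== bShape.length) {
      return false;
    }
    // ... and the batch dims (all axes except the last two) must match exactly.
    for (let i = 0; i < aShape.length - 2; ++i) {
      if (aShape[i] !== bShape[i]) {
        return false;
      }
    }
  }
  return true;
}

// [2, 3, 4] x [1, 4, 5] is valid ONNX MatMul broadcasting, but is rejected for the CPU backend:
console.log(isMatMulSupported([2, 3, 4], [1, 4, 5], /* isCpuBackend */ true)); // false
console.log(isMatMulSupported([2, 3, 4], [2, 4, 5], /* isCpuBackend */ true)); // true
```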
--- .../webnn/builders/impl/gemm_op_builder.cc | 73 +++++++++++-------- 1 file changed, 43 insertions(+), 30 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index d5f84f853f7de..455e0e5f16a42 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -91,44 +91,33 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType /* device_type */, + const WebnnDeviceType device_type, const logging::Logger& logger) const { (void)initializers; const auto& op_type = node.OpType(); const auto& input_defs(node.InputDefs()); const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C - if (op_type == "Gemm") { - std::vector a_shape; - { - if (!GetShape(*input_defs[a_idx], a_shape, logger)) - return false; - - if (a_shape.size() != 2) { - LOGS(logger, VERBOSE) << "A must be 2D"; - return false; - } - - if (Product(a_shape) == 0) { - LOGS(logger, VERBOSE) << "A must be non-empty"; - return false; - } - } - - std::vector b_shape; - { - if (!GetShape(*input_defs[b_idx], b_shape, logger)) - return false; + std::vector a_shape; + if (!GetShape(*input_defs[a_idx], a_shape, logger)) + return false; + if (Product(a_shape) == 0) { + LOGS(logger, VERBOSE) << "A must be non-empty"; + return false; + } - if (b_shape.size() != 2) { - LOGS(logger, VERBOSE) << "B must be 2D"; - return false; - } + std::vector b_shape; + if (!GetShape(*input_defs[b_idx], b_shape, logger)) + return false; + if (Product(b_shape) == 0) { + LOGS(logger, VERBOSE) << "B must be non-empty"; + return false; + } - if (Product(b_shape) == 0) { - LOGS(logger, VERBOSE) << "B must be non-empty"; - return false; - } + if (op_type == "Gemm") { + if (a_shape.size() != 2 || b_shape.size() != 2) { + LOGS(logger, VERBOSE) << "A and B must be 2D for Gemm"; + return false; } // C of Gemm. @@ -162,6 +151,30 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } } + if (op_type == "MatMul") { + if (a_shape.size() < 2 || b_shape.size() < 2) { + LOGS(logger, VERBOSE) << "Inputs of MatMul must be at least 2D"; + return false; + } + + // WebNN CPU backend has two more constraints. + // https://source.chromium.org/chromium/chromium/src/+/main:third_party/blink/renderer/modules/ml/webnn/ml_graph_xnnpack.cc;l=1177 + // TODO: Remove this workaround when Chromium enables broadcast for MatMul on WebNN CPU backend. + if (device_type == WebnnDeviceType::CPU) { + if (a_shape.size() != b_shape.size()) { + LOGS(logger, VERBOSE) << "The rank of two inputs for WebNN CPU backend MatMul must be the same."; + return false; + } + + for (size_t i = 0; i < a_shape.size() - 2; i++) { + if (a_shape[i] != b_shape[i]) { + LOGS(logger, VERBOSE) << "WebNN CPU backend can't support broadcasting for MatMul."; + return false; + } + } + } + } + return true; } From 2d79052ec38b831f3254b20e0f6a42b3f98eabc7 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 1 Mar 2024 18:39:51 -0800 Subject: [PATCH 099/237] [QNN Quant] Add preprocessing option to transpose graph inputs/outputs to channel-last (#19731) ### Description Adds the optional parameters `inputs_to_make_channel_last` and `outputs_to_make_channel_last` to the `qnn_preprocess_model()` function. 
```python """ inputs_to_make_channel_last: List of graph input names to transpose to be "channel-last". For example, if "input0" originally has the shape (N, C, D1, D2, ..., Dn), the resulting model will change input0's shape to (N, D1, D2, ..., Dn, C) and add a transpose node after it. Original: input0 (N, C, D1, D2, ..., Dn) --> Updated: input0 (N, D1, D2, ..., Dn, C) --> Transpose --> input0_chanfirst (N, C, D1, D2, ..., Dn) --> This can potentially improve inference latency for QDQ models running on QNN EP because the additional transpose node may allow other transpose nodes inserted during ORT layout transformation to cancel out. outputs_to_make_channel_last: List of graph output names to transpose to be "channel-last". For example, if "output0" originally has the shape (N, C, D1, D2, ..., Dn), the resulting model will change output0's shape to (N, D1, D2, ..., Dn, C) and add a transpose node before it. Original: --> output0 (N, C, D1, D2, ..., Dn) Updated: --> output0_chanfirst (N, C, D1, D2, ..., Dn) --> Transpose --> output0 (N, D1, D2, ..., Dn, C) This can potentially improve inference latency for QDQ models running on QNN EP because the additional transpose node may allow other transpose nodes inserted during ORT layout transformation to cancel out. """ ``` **NOTE: If you use these options with the quantization scripts, you'll have to make sure your data_reader feeds in transposed input data. It won't happen automatically.** ### Motivation and Context Native QNN operators use the channel-last data layout, but ONNX uses channel-first. To bridge the gap, ORT's layout transformer inserts transposes around layout-sensitive nodes and updates their domain to indicate that they now operate on channel-last data. The transpose optimizer is able to remove most of these inserted transposes, but not all transposes can always be removed (i.e., some could remain at the graph's inputs and outputs). We've found that these extra transpose nodes can significantly degrade inference latency on QNN EP. One workaround (provided by this PR) is to add _additional_ transpose nodes at the graph inputs or outputs. These additional nodes can often help the ORT transpose optimizer cancel out any remaining transpose nodes, which significantly improves latency. Additionally, it may make more sense for some kinds of inputs to just be in channel-last form (e.g., images), avoiding the need to pre-transpose of the input data before inference. 
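The before/after transpose patterns are shown in the examples below. As a quick reference, here is a hedged usage sketch of the new options (model paths and tensor names are hypothetical); as the note above says, the calibration data reader and any inference code must then supply channel-last data for the affected inputs:

```python
from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model

# Hypothetical model paths and tensor names, for illustration only.
modified = qnn_preprocess_model(
    "model.onnx",
    "model.qnn_pp.onnx",
    inputs_to_make_channel_last=["pixel_values"],  # originally (N, C, H, W)
    outputs_to_make_channel_last=["feature_map"],  # originally (N, C, H, W)
)

# The preprocessed model now expects "pixel_values" in (N, H, W, C) order, so
# feed transposed data, e.g. numpy_nchw.transpose(0, 2, 3, 1).
```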
Example at the input: ``` Original: input0 (N, C, D1, D2, ..., Dn) --> Updated: input0 (N, D1, D2, ..., Dn, C) --> Transpose --> input0_chanfirst (N, C, D1, D2, ..., Dn) --> ``` Example at the output: ``` Original: --> output0 (N, C, D1, D2, ..., Dn) Updated: --> output0_chanfirst (N, C, D1, D2, ..., Dn) --> Transpose --> output0 (N, D1, D2, ..., Dn, C) ``` --- .../execution_providers/qnn/preprocess.py | 198 ++++++++++++++++++ .../quantization/test_qnn_preprocess_model.py | 93 ++++++++ 2 files changed, 291 insertions(+) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index b0dab81830c8b..e584a65574520 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -24,6 +24,8 @@ def qnn_preprocess_model( external_data_location: str | None = None, external_data_size_threshold: int = 1024, external_data_convert_attribute: bool = False, + inputs_to_make_channel_last: list[str] | None = None, + outputs_to_make_channel_last: list[str] | None = None, ) -> bool: """ If necessary, this method creates a new "pre-processed" model in preparation for @@ -52,6 +54,32 @@ def qnn_preprocess_model( external_data_convert_attribute: Effective only if save_as_external_data is true. Defaults to false. If true, convert all tensors to external data. If false, convert only non-attribute tensors to external data. + inputs_to_make_channel_last: List of graph input names to transpose to be "channel-last". For example, + if "input0" originally has the shape (N, C, D1, D2, ..., Dn), the resulting model will change input0's + shape to (N, D1, D2, ..., Dn, C) and add a transpose node after it. + + Original: + input0 (N, C, D1, D2, ..., Dn) --> + + Updated: + input0 (N, D1, D2, ..., Dn, C) --> Transpose --> input0_chanfirst (N, C, D1, D2, ..., Dn) --> + + This can potentially improve inference latency for QDQ models running on QNN EP because the + additional transpose node may allow other transpose nodes inserted during ORT layout transformation + to cancel out. + outputs_to_make_channel_last: List of graph output names to transpose to be "channel-last". For example, + if "output0" originally has the shape (N, C, D1, D2, ..., Dn), the resulting model will change output0's + shape to (N, D1, D2, ..., Dn, C) and add a transpose node before it. + + Original: + --> output0 (N, C, D1, D2, ..., Dn) + + Updated: + --> output0_chanfirst (N, C, D1, D2, ..., Dn) --> Transpose --> output0 (N, D1, D2, ..., Dn, C) + + This can potentially improve inference latency for QDQ models running on QNN EP because the + additional transpose node may allow other transpose nodes inserted during ORT layout transformation + to cancel out. """ modified = False model = onnx.load_model(model_input) @@ -83,6 +111,19 @@ def qnn_preprocess_model( if fusion_layernorm.apply(): modified = True + # Optionally, transpose inputs and/or outputs to make them "channel-last". 
+ if inputs_to_make_channel_last or outputs_to_make_channel_last: + transpose_node_prefix = "Transpose_channel_" + transpose_node_suffix: int = onnx_model.get_largest_node_name_suffix(transpose_node_prefix) + 1 + update_io_to_channel_last( + onnx_model.model, + inputs_to_make_channel_last, + outputs_to_make_channel_last, + transpose_node_name_prefix=transpose_node_prefix, + transpose_node_name_start_suffix=transpose_node_suffix, + ) + modified = True + # Make sure all nodes have a name. unnamed_node_prefix = "qnn_preproc_node_" available_suffix = onnx_model.get_largest_node_name_suffix(unnamed_node_prefix) + 1 @@ -107,3 +148,160 @@ def qnn_preprocess_model( ) return modified + + +class InputOutputNameMap: + def __init__( + self, + orig_tensor_names: set[str], + orig_graph_inputs: dict[str, onnx.ValueInfoProto], + orig_graph_outputs: dict[str, onnx.ValueInfoProto], + ): + self.orig_tensor_names = orig_tensor_names + self.orig_graph_inputs = orig_graph_inputs + self.orig_graph_outputs = orig_graph_outputs + self.updated_io_names = {} + self.new_value_infos = [] + + def get_new_name(self, orig_name: str): + if orig_name in self.updated_io_names: + return self.updated_io_names[orig_name] + + # Make a new tensor name that is unique among all tensors in the graph. + prefix: str = f"{orig_name}_channel_first_" + suffix: int = -1 + for tensor_name in self.orig_tensor_names: + if tensor_name.startswith(prefix) and tensor_name[len(prefix) :].isdigit(): + index = int(tensor_name[len(prefix) :]) + suffix = max(suffix, index) + + suffix += 1 # This is the first available suffix. + new_name = f"{prefix}{suffix!s}" + + # Add new value_info objects for these new tensors. + orig_value_info = self.orig_graph_inputs.get(orig_name) or self.orig_graph_outputs[orig_name] + value_info_proto = onnx.ValueInfoProto() + value_info_proto.CopyFrom(orig_value_info) + value_info_proto.name = new_name + self.new_value_infos.append(value_info_proto) + + self.updated_io_names[orig_name] = new_name + return self.updated_io_names[orig_name] + + +def update_io_to_channel_last( + model: onnx.ModelProto, + inputs_to_update: list[str] | None, + outputs_to_update: list[str] | None, + transpose_node_name_prefix: str = "Transpose_channel_", + transpose_node_name_start_suffix: int = 0, +): + inputs_to_update = set(inputs_to_update or []) + outputs_to_update = set(outputs_to_update or []) + + if not inputs_to_update and not outputs_to_update: + return + + graph = model.graph + orig_graph_inputs = {ginput.name: ginput for ginput in graph.input} + orig_graph_outputs = {goutput.name: goutput for goutput in graph.output} + + # Check that the user passed in actual input and output names. + for input_name in inputs_to_update: + if input_name not in orig_graph_inputs: + raise ValueError(f"{input_name} is not a graph input") + + for output_name in outputs_to_update: + if output_name not in orig_graph_outputs: + raise ValueError(f"{output_name} is not a graph output") + + orig_tensor_names = set() + orig_tensor_names.update(set(orig_graph_inputs)) + orig_tensor_names.update(set(orig_graph_outputs)) + orig_tensor_names.update(input_name for node in graph.node for input_name in node.input if input_name) + + # Maps original input (or output) name to its updated name used within the graph. + io_map = InputOutputNameMap(orig_tensor_names, orig_graph_inputs, orig_graph_outputs) + + # Update each node's inputs/outputs to use the transposed versions. 
+ for node in graph.node: + for i in range(len(node.input)): + if node.input[i] and node.input[i] in inputs_to_update: + node.input[i] = io_map.get_new_name(node.input[i]) + elif node.input[i] and node.input[i] in outputs_to_update: + node.input[i] = io_map.get_new_name(node.input[i]) + + for i in range(len(node.output)): + if node.output[i] in outputs_to_update: + node.output[i] = io_map.get_new_name(node.output[i]) + + # Update graph inputs to channel-last and a Transpose (to channel-first) after each. + for g_input_name in inputs_to_update: + g_input = orig_graph_inputs[g_input_name] + + if not g_input.type.HasField("tensor_type") or not g_input.type.tensor_type.HasField("shape"): + raise ValueError(f"Expected input {g_input.name} to have a tensor_type with a shape") + + input_shape = g_input.type.tensor_type.shape + input_rank = len(input_shape.dim) + + if input_rank < 3: + raise ValueError(f"Expected input {g_input.name} to be of rank >= 3") + + channel_dim = onnx.TensorShapeProto.Dimension() + channel_dim.CopyFrom(input_shape.dim[1]) + for i in range(1, input_rank - 1): + input_shape.dim[i].CopyFrom(input_shape.dim[i + 1]) + input_shape.dim[input_rank - 1].CopyFrom(channel_dim) + + transpose_perm = list(range(input_rank)) + for i in range(input_rank): + transpose_perm[i] = i if i < 1 else i - 1 + transpose_perm[1] = input_rank - 1 + + transpose_node = onnx.helper.make_node( + "Transpose", + name=f"{transpose_node_name_prefix}{transpose_node_name_start_suffix!s}", + inputs=[g_input.name], + outputs=[io_map.get_new_name(g_input.name)], + perm=transpose_perm, + ) + transpose_node_name_start_suffix += 1 + + graph.node.extend([transpose_node]) + + # Update graph outputs to channel-last and a Transpose (from channel-first) before each. + for g_output_name in outputs_to_update: + g_output = orig_graph_outputs[g_output_name] + if not g_output.type.HasField("tensor_type") or not g_output.type.tensor_type.HasField("shape"): + raise ValueError(f"Expected output {g_output.name} to have a tensor_type with a shape") + + output_shape = g_output.type.tensor_type.shape + output_rank = len(output_shape.dim) + + if output_rank < 3: + raise ValueError(f"Expected output {g_output.name} to be of rank >= 3") + + channel_dim = onnx.TensorShapeProto.Dimension() + channel_dim.CopyFrom(output_shape.dim[1]) + for i in range(1, output_rank - 1): + output_shape.dim[i].CopyFrom(output_shape.dim[i + 1]) + output_shape.dim[output_rank - 1].CopyFrom(channel_dim) + + transpose_perm = list(range(output_rank)) + for i in range(output_rank): + transpose_perm[i] = i if i == 0 else i + 1 + transpose_perm[output_rank - 1] = 1 + + transpose_node = onnx.helper.make_node( + "Transpose", + name=f"{transpose_node_name_prefix}{transpose_node_name_start_suffix!s}", + inputs=[io_map.get_new_name(g_output.name)], + outputs=[g_output.name], + perm=transpose_perm, + ) + transpose_node_name_start_suffix += 1 + + graph.node.extend([transpose_node]) + + graph.value_info.extend(io_map.new_value_infos) diff --git a/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py index 9b67fd41caac3..6503b3223b828 100644 --- a/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py +++ b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py @@ -12,6 +12,7 @@ import numpy as np import onnx +import onnxruntime from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model from onnxruntime.quantization.quant_utils import 
model_has_external_data, ms_domain @@ -165,6 +166,98 @@ def test_external_data(self): for node in fused_model.graph.node: self.assertIn(node.op_type, expected_op_types) + def build_multi_input_output_model(self, shape): + """ + Returns the following model. + +----------> [X] + | + [A] ---> Add ---> Abs -+-> Mul ---> [Y] + ^ ^ + | | + [B] ------+-----------------+ + """ + input_a = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, shape) + input_b = onnx.helper.make_tensor_value_info("B", onnx.TensorProto.FLOAT, shape) + output_x = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape) + output_y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape) + + add_node = onnx.helper.make_node("Add", ["A", "B"], ["add_out"], name="add_node") + abs_node = onnx.helper.make_node("Abs", ["add_out"], ["X"], name="abs_node") + mul_node = onnx.helper.make_node("Mul", ["X", "B"], ["Y"], name="mul_node") + + graph = onnx.helper.make_graph( + [add_node, abs_node, mul_node], + "multi_io_graph", + [input_a, input_b], + [output_x, output_y], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_make_io_channel_last(self): + """ + Test making a model's inputs and outputs channel-last. + """ + model = self.build_multi_input_output_model((1, 2, 3, 4)) + onnx.save_model(model, "model.onnx") + modified = qnn_preprocess_model( + "model.onnx", + "model.qnn_pp.onnx", + inputs_to_make_channel_last=["A", "B"], + outputs_to_make_channel_last=["X", "Y"], + ) + + self.assertTrue(modified) + + preproc_model = onnx.load_model("model.qnn_pp.onnx") + self.assertEqual(len(preproc_model.graph.node), 7) + + num_transposes = sum(1 for node in preproc_model.graph.node if node.op_type == "Transpose") + self.assertEqual(num_transposes, 4) + + # Check that the outputs of the new model are the same, but transposed. + input_a = np.arange(0.0, 24.0, 1.0, dtype=np.float32).reshape((1, 2, 3, 4)) + input_a_t = input_a.transpose(0, 2, 3, 1) + input_b = np.arange(1.0, 25.0, 1.0, dtype=np.float32).reshape((1, 2, 3, 4)) + input_b_t = input_b.transpose(0, 2, 3, 1) + + orig_session = onnxruntime.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"]) + orig_results = orig_session.run(None, {"A": input_a, "B": input_b}) + + new_session = onnxruntime.InferenceSession( + preproc_model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + new_results = new_session.run(None, {"A": input_a_t, "B": input_b_t}) + + self.assertEqual(len(orig_results), len(new_results)) + for idx, orig_output in enumerate(orig_results): + transposed_output = new_results[idx] + np.testing.assert_allclose( + orig_output, + transposed_output.transpose(0, 3, 1, 2), + err_msg=f"Channel-last model output {idx} differs", + ) + + def test_make_io_channel_last_rank_error(self): + """ + Test making a model's inputs and outputs channel-last with a rank < 3 (error). 
+ """ + model = self.build_multi_input_output_model((1, 2)) + onnx.save_model(model, "model.onnx") + + with self.assertRaises(ValueError) as context: + qnn_preprocess_model( + "model.onnx", + "model.qnn_pp.onnx", + inputs_to_make_channel_last=["A", "B"], + outputs_to_make_channel_last=["X", "Y"], + ) + + self.assertIn("to be of rank >= 3", str(context.exception)) + if __name__ == "__main__": unittest.main() From 9460597b2103d8d07e88272b9f4e19700d71d632 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Sat, 2 Mar 2024 11:33:47 +0800 Subject: [PATCH 100/237] Update copying API header files (#19736) ### Description Make Linux logic consistent as Windows ### Motivation and Context onnxruntime_lite_custom_op.h in Windows zip package but not in Linux zip package https://github.com/microsoft/onnxruntime/blob/acbfc29f272b5578145e7600bc42342e116ffbc2/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml#L67 Co-authored-by: Your Name --- tools/ci_build/github/linux/copy_strip_binary.sh | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tools/ci_build/github/linux/copy_strip_binary.sh b/tools/ci_build/github/linux/copy_strip_binary.sh index 42973a8fcb5b8..65d6d97ebf0a8 100755 --- a/tools/ci_build/github/linux/copy_strip_binary.sh +++ b/tools/ci_build/github/linux/copy_strip_binary.sh @@ -44,17 +44,10 @@ elif [[ $LIB_NAME == *.so.* ]] then ln -s $LIB_NAME $BINARY_DIR/$ARTIFACT_NAME/lib/libonnxruntime.so fi -cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_c_api.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_api.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_inline.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_float16.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include +cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_*.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/framework/provider_options.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h $BINARY_DIR/$ARTIFACT_NAME/include +cp $SOURCE_DIR/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h $BINARY_DIR/$ARTIFACT_NAME/include +cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_*.h $BINARY_DIR/$ARTIFACT_NAME/include if [[ -f "$BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_cuda.so" ]]; then # copy headers for context context used in custom ops From 9acaf534a62050705d9b892a57ef0e8409fa62ec Mon Sep 17 00:00:00 2001 From: ironman Date: Mon, 4 Mar 2024 23:29:58 +0800 Subject: [PATCH 101/237] Benchmark - Updating llama-2 requirement files (#19716) ### Description ### Motivation and Context --- .../tools/transformers/models/llama/requirements-cuda.txt | 1 + 
.../python/tools/transformers/models/llama/requirements.txt | 3 ++- .../python/tools/transformers/models/whisper/requirements.txt | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt index acd9c23aa42d0..307afbc122901 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt @@ -2,3 +2,4 @@ # Please manually install torch>=2.2.0 with CUDA enabled for the CUDA version installed in your system. # Instructions can be found here: https://pytorch.org/get-started/locally/ onnxruntime-gpu>=1.16.2 +py3nvml \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index 8b57279295e35..e991c2f27a1a3 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,6 +1,7 @@ optimum>=1.14.1 -transformers>=4.33.2 +transformers>=4.33.2,<= 4.37.2 torch>=2.2.0 onnx>=1.14.0 datasets>=2.8.0 protobuf==3.20.2 +psutil \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index 956922dc83d51..9bbe0d7380406 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -7,8 +7,8 @@ soundfile librosa optimum onnxruntime-extensions>=0.9.0 +onnx>=1.15.0 protobuf==3.20.2 numpy==1.23.3 -onnx>=1.15.0 psutil py3nvml From 2e13d5f0ab54c726ee2400d38983000de7f61b8e Mon Sep 17 00:00:00 2001 From: inisis <46103969+inisis@users.noreply.github.com> Date: Tue, 5 Mar 2024 01:41:36 +0800 Subject: [PATCH 102/237] fix split shape inference error for opset >= 13 (#19756) ### Description get split operator split section by opset ### Motivation and Context for opset higher than 13, split section is treated as an input. 
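In other words, `split` moved from a node attribute (opset < 13) to an optional second input (opset 13 and later), which is why the shape inference code below now reads it per opset. A small illustration, not part of the fix, using standard `onnx.helper` calls:

```python
import onnx
from onnx import helper

# The same 3-way split expressed in both opset styles (illustration only).
# Opset 11: the sizes are carried by the "split" attribute.
split_v11 = helper.make_node(
    "Split", inputs=["X"], outputs=["Y0", "Y1", "Y2"], axis=1, split=[2, 2, 4]
)

# Opset 13+: the sizes arrive as an optional second input (typically an initializer),
# so shape inference must look at the input value instead of the attribute.
split_sizes = helper.make_tensor("split_sizes", onnx.TensorProto.INT64, [3], [2, 2, 4])
split_v13 = helper.make_node(
    "Split", inputs=["X", "split_sizes"], outputs=["Y0", "Y1", "Y2"], axis=1
)
```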
--- onnxruntime/python/tools/symbolic_shape_infer.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 4b56bc1e8d828..4b029f9b172b0 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -1940,8 +1940,17 @@ def _infer_SoftmaxCrossEntropyLoss(self, node): # noqa: N802 def _infer_Split_Common(self, node, make_value_info_func): # noqa: N802 input_sympy_shape = self._get_sympy_shape(node, 0) axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape)) - split = get_attribute(node, "split") - if not split: + op_set = get_opset(self.out_mp_) + + # Depending on op-version 'split' are provided as attribute or via 2nd input + if op_set < 13: + split = get_attribute(node, "split") + assert self._try_get_value(node, 1) is None + else: + split = self._try_get_value(node, 1) + assert get_attribute(node, "split") is None + + if split is None: num_outputs = len(node.output) split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs self._update_computed_dims(split) From 27b1dc91abb71b71fe6a26e1b4ebd30e13524baf Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Mon, 4 Mar 2024 11:55:35 -0800 Subject: [PATCH 103/237] [DML] MatrixMultiplyIntegerToFloat (#19608) ### Description DML Implementation for [com.microsoft.MatMulIntegerToFloat](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulIntegerToFloat) ``` .\onnxruntime_test_all.exe --gtest_filter="*MatMulIntegerToFloat.*" Note: Google Test filter = *MatMulIntegerToFloat.* [==========] Running 22 tests from 1 test suite. [----------] Global test environment set-up. 
[----------] 22 tests from MatMulIntegerToFloat [ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8S8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8S8 (620 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8S8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8S8 (497 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8S8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8S8 (488 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8S8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8S8 (503 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8U8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8U8 (495 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8U8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8U8 (488 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8U8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8U8 (492 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8X8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8X8 (502 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8U8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8U8 (452 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8U8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8U8 (454 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8U8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8U8 (446 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8U8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8U8 (508 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8S8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8S8 (456 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8S8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8S8 (455 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8S8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8S8 (447 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8S8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8S8 (465 ms) [ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8U8 [ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8U8 (111 ms) [ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8S8 [ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8S8 (115 ms) [ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8S8 [ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8S8 (114 ms) [ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8U8 [ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8U8 (110 ms) [ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16 [ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16 (112 ms) [ RUN ] MatMulIntegerToFloat.MatMulInteger_With_ZeroPoint [ OK ] MatMulIntegerToFloat.MatMulInteger_With_ZeroPoint (337 ms) [----------] 22 tests from MatMulIntegerToFloat (8679 ms total) [----------] Global test environment tear-down [==========] 22 tests from 1 test suite ran. (8680 ms total) [ PASSED ] 22 tests. memleakdbg: ----- No memory leaks detected ----- ``` ### Motivation and Context * `CalculateMatMulIntegerToFloat` to replace CPU EP run reference * Added more FP32 testcases to isolate all input datatype combinations * Added fixed input to `MatMulIntegerToFloat_FP16*` test cases as for FP16 test cases. 
* onnxruntime/test/testdata/matmul_integer_to_float.py` is capable of generating FP16 models, but we do not produce any for now --- docs/ContribOperators.md | 2 +- docs/OperatorKernels.md | 1 + .../graph/contrib_ops/quantization_defs.cc | 2 +- .../core/optimizer/graph_transformer_utils.cc | 5 +- .../core/optimizer/matmul_integer_to_float.cc | 23 +- .../src/External/DirectMLHelpers/ApiTraits.h | 12 +- .../External/DirectMLHelpers/DirectMLSchema.h | 37 +- .../DirectMLHelpers/GeneratedSchemaHelpers.h | 36 +- .../DmlOperatorMatMulIntegerToFloat.cpp | 111 +++++ .../src/Operators/OperatorRegistration.cpp | 9 + .../dml/OperatorAuthorHelper/OperatorHelper.h | 2 +- .../OperatorAuthorHelper/OperatorVersions.h | 1 + .../matmul_integer_to_float_test.cc | 414 +++++++++++++++--- .../test/optimizer/graph_transform_test.cc | 18 + .../test/testdata/matmul_integer_to_float.py | 60 ++- .../matmul_integer_to_float_int8.onnx | 4 +- .../matmul_integer_to_float_int8_bias.onnx | 4 +- .../matmul_integer_to_float_int8_int8.onnx | 4 +- ...atmul_integer_to_float_int8_int8_bias.onnx | 4 +- .../matmul_integer_to_float_uint8.onnx | 4 +- .../matmul_integer_to_float_uint8_bias.onnx | 4 +- .../fusion/matmul_integer_to_float.onnx | Bin 1520 -> 1520 bytes .../matmul_integer_to_float16_int8.onnx | 51 +++ 23 files changed, 664 insertions(+), 144 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp create mode 100644 onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f523e97293427..e295dfa203ae5 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2795,7 +2795,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Constrain input A data type to 8-bit integer tensor.
T2 : tensor(int8), tensor(uint8)
Constrain input B data type to 8-bit integer tensor.
-T3 : tensor(float)
+T3 : tensor(float), tensor(float16)
Constrain input a_scale, b_scale and output Y data type as float tensor.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 1eaf0fb6dad76..0e60b4622f2fb 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1268,6 +1268,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| +|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 4313fae767fe5..22a79ef652515 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -434,7 +434,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Output(0, "Y", "Matrix multiply results from A * B", "T3") .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input A data type to 8-bit integer tensor.") .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input B data type to 8-bit integer tensor.") - .TypeConstraint("T3", {"tensor(float)"}, + .TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input a_scale, b_scale and output Y data type as float tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 2, 0); diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 8376b87aee6b2..f319e7254568d 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -278,7 +278,8 @@ InlinedVector> GenerateTransformers( onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider, onnxruntime::kJsExecutionProvider}; - + const InlinedHashSet cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kDmlExecutionProvider}; #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); @@ -296,7 +297,7 @@ InlinedVector> GenerateTransformers( } transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_ep)); + transformers.emplace_back(std::make_unique(cpu_dml_eps)); transformers.emplace_back(std::make_unique(cpu_ep)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); diff --git a/onnxruntime/core/optimizer/matmul_integer_to_float.cc b/onnxruntime/core/optimizer/matmul_integer_to_float.cc index 56e51cb787931..4fee1a6ce224e 100644 --- a/onnxruntime/core/optimizer/matmul_integer_to_float.cc +++ b/onnxruntime/core/optimizer/matmul_integer_to_float.cc @@ -31,6 +31,24 @@ static bool CheckBiasShape(const TensorShapeProto* bias_shape) { return bias_last_dim > 1; } +bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) { + if (!node_arg.Exists()) { + return false; + } + + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto) { + return false; + } + + int32_t actual_data_type; + if (!utils::TryGetElementDataType(*type_proto, actual_data_type)) { + return false; + } + + return data_type == actual_data_type; +} + /** MatMulIntegerToFloatFusion will fuse subgraph like below into MatMulIntegerToFloat: @@ -63,9 +81,10 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g auto& mul_node = *node_ptr; ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger)); - + const bool is_dml_ep = node_ptr->GetExecutionProviderType() == kDmlExecutionProvider; if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) || - !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) { + !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders()) || + (!is_dml_ep && 
HasElementDataType(*mul_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16))) { continue; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index e1e7eacfbd85d..7c25755a7d09e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -879,6 +879,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT; +}; + template <> struct OperatorDescTraits { @@ -1041,12 +1047,6 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; }; -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT; -}; - template <> struct OperatorDescTraits { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 5fe6603c2a0bf..da57c2aa235fd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -1885,6 +1885,25 @@ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHE DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA { + "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", + static_cast(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA_FIELDS[11] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, @@ -2395,24 +2414,6 @@ constexpr DML_OPERATOR_SCHEMA 
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHE DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, }; -constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA { - "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", - DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 8, - DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS, -}; constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 4be41ad3924a2..86c66d8cca26c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -1139,6 +1139,19 @@ inline std::vector GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MU OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), + 
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} inline std::vector GetFields(const DML_CONVOLUTION_INTEGER_OPERATOR_DESC& desc) { return { @@ -1487,19 +1500,6 @@ inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_P OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), }; } -inline std::vector GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1829,6 +1829,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_RESAMPLE1: return DML_RESAMPLE1_OPERATOR_SCHEMA; case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER: return DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA; case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA; + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA; case DML_OPERATOR_CONVOLUTION_INTEGER: return DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA; case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA; @@ -1856,7 +1857,6 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_DIAGONAL_MATRIX1: return DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA; case DML_OPERATOR_MULTIHEAD_ATTENTION: return DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA; case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA; - case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; @@ -2360,6 +2360,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: + return 
AbstractOperatorDesc( + &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_CONVOLUTION_INTEGER: return AbstractOperatorDesc( &DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA, @@ -2468,10 +2472,6 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: - return AbstractOperatorDesc( - &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp new file mode 100644 index 0000000000000..b5a3dd0960b86 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorMatMulIntegerToFloat : public DmlOperator +{ + enum OrtInputTensors : uint32_t + { + ortA, + ortB, + ortAScale, + ortBScale, + ortAZeroPoint, + ortBZeroPoint, + ortBias, + ortInputCount + }; + + enum DmlInputIndex : uint32_t + { + dmlA, + dmlAScale, + dmlAZeroPoint, + dmlB, + dmlBScale, + dmlBZeroPoint, + dmlBias, + dmlInputCount, + }; + +public: + DmlOperatorMatMulIntegerToFloat(const MLOperatorKernelCreationContext& kernelInfo) + : DmlOperator(kernelInfo) + { + std::vector> inputIndices = { OrtInputTensors::ortA, OrtInputTensors::ortAScale, OrtInputTensors::ortAZeroPoint, OrtInputTensors::ortB, OrtInputTensors::ortBScale, OrtInputTensors::ortBZeroPoint, OrtInputTensors::ortBias }; + DmlOperator::Initialize(kernelInfo, inputIndices); + + std::vector inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortA); + std::vector inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortB); + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); + + OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape); + + // Initialize the input descriptions with broadcasting + m_inputTensorDescs[DmlInputIndex::dmlA] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortA, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0); + m_inputTensorDescs[DmlInputIndex::dmlB] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortB, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape1); + + // Broadcast Bias tensor to the shape of the output tensor. + if(kernelInfo.IsInputValid(OrtInputTensors::ortBias)) { + m_inputTensorDescs[DmlInputIndex::dmlBias] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortBias, TensorAxis::DoNotCoerce, + TensorAxis::W, TensorAxis::RightAligned, outputShape); + } + + uint32_t dmlDimSize = m_inputTensorDescs[DmlInputIndex::dmlA].GetDimensionCount(); + // Resize the A Scale to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. 
+ m_inputTensorDescs[DmlInputIndex::dmlAScale] = CreateTensorDescFromInput( + kernelInfo, + OrtInputTensors::ortAScale, + TensorAxis::DoNotCoerce, + TensorAxis::H, + TensorAxis::LeftAligned, + std::nullopt, + dmlDimSize + ); + + // Resize the A ZeroPoint to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + if (kernelInfo.IsInputValid(OrtInputTensors::ortAZeroPoint)) + { + m_inputTensorDescs[DmlInputIndex::dmlAZeroPoint] = CreateTensorDescFromInput( + kernelInfo, + OrtInputTensors::ortAZeroPoint, + TensorAxis::DoNotCoerce, + TensorAxis::H, + TensorAxis::LeftAligned, + std::nullopt, + dmlDimSize + ); + } + + // B Zeropoint and BScale are already aligned in the W dimension so no need to align them + + // Initialize the output description while overriding the shape + m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulDesc = {}; + matMulDesc.ATensor = &inputDescs[DmlInputIndex::dmlA]; + matMulDesc.AScaleTensor = &inputDescs[DmlInputIndex::dmlAScale]; + matMulDesc.AZeroPointTensor = inputDescs[DmlInputIndex::dmlAZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlAZeroPoint] : nullptr; + matMulDesc.BTensor = &inputDescs[DmlInputIndex::dmlB]; + matMulDesc.BScaleTensor = &inputDescs[DmlInputIndex::dmlBScale]; + matMulDesc.BZeroPointTensor = inputDescs[DmlInputIndex::dmlBZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBZeroPoint] : nullptr; + matMulDesc.BiasTensor = inputDescs[DmlInputIndex::dmlBias].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBias] : nullptr; + matMulDesc.OutputTensor = &outputDescs[0]; + + DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(MatMulIntegerToFloat, DmlOperatorMatMulIntegerToFloat); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 9c136ed8c9484..f08151b61197a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -503,6 +503,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul); DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat); DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear); DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger); +DML_OP_EXTERN_CREATION_FUNCTION(MatMulIntegerToFloat); DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger); DML_OP_EXTERN_CREATION_FUNCTION(Trilu); @@ -622,6 +623,13 @@ constexpr static std::array supportedTypeListQLinea SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; + +constexpr static std::array supportedTypeListMatMulIntegerToFloat = { + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Float16to32 +}; + constexpr static std::array supportedTypeListQLinearConv = { SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, @@ -1083,6 +1091,7 @@ constexpr static 
OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, MatMulIntegerToFloat, typeNameListThree, supportedTypeListMatMulIntegerToFloat, DmlGraphSupport::Supported)}, {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, {REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 1b2521a86613f..06bacc1b28c99 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -870,7 +870,6 @@ class QLinearMatMulHelper : public MatMulHelperBase QLinearMatMulHelper(const Info_t& info, const Shape_t& shape) : MatMulHelperBase(info, shape, 0, 3) {} }; - class TopKHelper { void Initialize( @@ -1776,6 +1775,7 @@ using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MatMul = MatMulHelper; using ShapeInferenceHelper_MatMulInteger = MatMulHelper; +using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulHelper; using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper; using ShapeInferenceHelper_QLinearAdd = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_DynamicQuantizeLinear = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index e725ba085113d..d081aa2e29148 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -449,6 +449,7 @@ namespace OperatorHelper static const int sc_sinceVer_FusedMatMulActivation = 1; static const int sc_sinceVer_QLinearSigmoid = 1; static const int sc_sinceVer_Attention = 1; + static const int sc_sinceVer_MatMulIntegerToFloat = 1; static const int sc_sinceVer_MultiHeadAttention = 1; static const int sc_sinceVer_SkipLayerNormalization = 1; static const int sc_sinceVer_EmbedLayerNormalization = 1; diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc index 26ce5272d25ee..6f3ca7e239671 100644 --- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc @@ -23,135 +23,407 @@ using namespace std; namespace onnxruntime { namespace test { -template -void TestMatMulIntegerToFloat(const std::vector& A_dims, - std::vector B_dims, - const std::string& reference_model, - bool is_matrix_b_constant, +template +static void CalculateMatMulIntegerToFloat(const int64_t M, const int64_t N, const int64_t K, + const std::vector& A_data, const std::vector& 
A_scale, + const std::vector& A_zero_point, const std::vector& B_data, + std::vector& B_scale, std::vector& B_zero_point, + const std::vector& Bias, std::vector& Y_data, + bool per_column, bool has_zp, bool has_bias) { + if (!per_column) { + B_zero_point.resize(N, B_zero_point[0]); + B_scale.resize(N, B_scale[0]); + } + + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + float A_dequantized = has_zp ? (static_cast(A_data[m * K + k]) - static_cast(A_zero_point[0])) * A_scale[0] : A_data[m * K + k] * A_scale[0]; + float B_dequantized = has_zp ? (static_cast(B_data[k * N + n]) - static_cast(B_zero_point[n])) * B_scale[n] : B_data[k * N + n] * B_scale[n]; + + sum += A_dequantized * B_dequantized; + } + if (has_bias) { + sum += Bias[n]; + } + Y_data[m * N + n] = static_cast(sum); + } + } +} + +template +void TestMatMulIntegerToFloat(bool is_matrix_b_constant, bool per_column = false, bool has_zp = true, bool has_bias = false) { // create rand inputs RandomValueGenerator random{}; - + int64_t M = 4; + int64_t N = 128; + int64_t K = 128; + std::vector A_dims{M, K}; + std::vector B_dims{K, N}; + std::vector Y_dims{M, K}; std::vector A_data; - std::vector tmp_A_data = random.Uniform(A_dims, - std::numeric_limits::lowest(), - std::numeric_limits::max()); - std::transform(tmp_A_data.begin(), tmp_A_data.end(), std::back_inserter(A_data), [](int32_t v) -> WType { + std::vector tmp_A_data = random.Uniform(A_dims, + std::numeric_limits::lowest(), + std::numeric_limits::max()); + std::transform(tmp_A_data.begin(), tmp_A_data.end(), std::back_inserter(A_data), [](int32_t v) -> IType { return static_cast(v); }); std::vector B_data; - std::vector tmp_B_data = random.Uniform(B_dims, - std::numeric_limits::lowest(), - std::numeric_limits::max()); + + std::vector tmp_B_data; + tmp_B_data = random.Uniform(B_dims, + std::is_signed::value ? std::numeric_limits::lowest() / 2 : std::numeric_limits::lowest(), + std::numeric_limits::max() / 2); std::transform(tmp_B_data.begin(), tmp_B_data.end(), std::back_inserter(B_data), [](int32_t v) -> WType { return static_cast(v); }); - std::vector A_scale = random.Uniform(AsSpan({1}), -0.1f, 0.1f); + std::vector A_scale = random.Uniform(AsSpan({1}), -0.1f, 0.1f); std::vector A_zero_point{(std::numeric_limits::lowest() + std::numeric_limits::max() + IType(2)) / 2}; int64_t b_scale_zp_size = per_column ? 
B_dims.back() : 1; - std::vector B_scale = random.Uniform(AsSpan({b_scale_zp_size}), -0.1f, 0.1f); + std::vector B_scale = random.Uniform(AsSpan({b_scale_zp_size}), -0.1f, 0.1f); std::vector B_zero_point(b_scale_zp_size); std::for_each(B_zero_point.begin(), B_zero_point.end(), [&random](WType& zp) { - zp = static_cast(random.Uniform(std::array{1}, - std::numeric_limits::lowest(), - std::numeric_limits::max())[0]); + zp = static_cast(random.Uniform(std::array{1}, + std::numeric_limits::lowest(), + std::numeric_limits::max())[0]); }); - std::vector Bias = random.Uniform(AsSpan({B_dims.back()}), -0.1f, 0.1f); + std::vector Bias = random.Uniform(AsSpan({B_dims.back()}), -0.1f, 0.1f); OpTester test("MatMulIntegerToFloat", 1, onnxruntime::kMSDomain); test.AddInput("A", A_dims, A_data); test.AddInput("B", B_dims, B_data, is_matrix_b_constant); - test.AddInput("a_scale", {1}, A_scale); - test.AddInput("b_scale", {b_scale_zp_size}, B_scale); + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {b_scale_zp_size}, B_scale); if (has_zp) { test.AddInput("a_zero_point", {1}, A_zero_point); test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point); } else { - test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); test.AddOptionalInputEdge(); } if (has_bias) { - test.AddInput("bias", {B_dims.back()}, Bias); + test.AddInput("bias", {B_dims.back()}, Bias); } else { - test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); } - test.AddReferenceOutputs(reference_model); - test.SetOutputRelErr("Y", 1e-4f); - test.Run(); -} + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, Bias, Y_data, + per_column, has_zp, has_bias); -template -void RunMatMulIntegerToFloatTest(const string& model_path) { - std::vector A_dims{4, 128}; - std::vector B_dims{128, 128}; - std::vector Y_dims{4, 128}; + if (std::is_same_v) { + test.AddOutput("Y", {M, N}, Y_data); + test.SetOutputRelErr("Y", 0.02f); + } else { + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + test.SetOutputAbsErr("Y", 0.5f); + } - TestMatMulIntegerToFloat(A_dims, - B_dims, - model_path, - false, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ + // Only DML EP supports these data type combinations for now + if (std::is_same_v || + (std::is_same_v && + std::is_same_v && + std::is_same_v)) { + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } else { + test.Run(); + } +} + +template +void RunMatMulIntegerToFloatTest() { + TestMatMulIntegerToFloat( + false, /*is_matrix_b_constant*/ + false, /*per_column*/ + HasZeroPoint, /*has_zp*/ + HasBias /*has_bias*/ ); - TestMatMulIntegerToFloat(A_dims, - B_dims, - model_path, - true, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ + TestMatMulIntegerToFloat( + true, /*is_matrix_b_constant*/ + false, /*per_column*/ + HasZeroPoint, /*has_zp*/ + HasBias /*has_bias*/ ); - TestMatMulIntegerToFloat(A_dims, - B_dims, - model_path, - false, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ + TestMatMulIntegerToFloat( + false, /*is_matrix_b_constant*/ + true, /*per_column*/ + HasZeroPoint, /*has_zp*/ + HasBias /*has_bias*/ ); - TestMatMulIntegerToFloat(A_dims, - B_dims, - model_path, - true, /*is_matrix_b_constant*/ - true, /*per_column*/ - 
HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ + TestMatMulIntegerToFloat( + true, /*is_matrix_b_constant*/ + true, /*per_column*/ + HasZeroPoint, /*has_zp*/ + HasBias /*has_bias*/ ); } -TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8X8) { - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8.onnx"); - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_uint8.onnx"); +TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8S8) { + RunMatMulIntegerToFloatTest(); } -TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8X8) { - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8_bias.onnx"); - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_uint8_bias.onnx"); +TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8S8) { + RunMatMulIntegerToFloatTest(); } -TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8S8) { - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8_int8.onnx"); +TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_S8S8) { + RunMatMulIntegerToFloatTest(); } -TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8S8) { - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8_int8_bias.onnx"); +TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_S8S8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_U8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_U8X8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8S8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8S8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_U8S8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_U8S8) { + RunMatMulIntegerToFloatTest(); +} + +// DML EP supports Float16 output type and Signed A Matrix and Unsigned B Matric for Float32 output +#if defined(USE_DML) + +TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_S8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_S8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_U8U8) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 5; + int64_t N = 5; + int64_t K = 2; + + std::vector A_data = {1, 5, 2, 1, 9, + 1, 1, 3, 7, 2}; + std::vector B_data = {3, 7, 2, 1, 1, + 2, 1, 9, 1, 1}; + std::vector A_scale = ToFloat16({3.0f}); + std::vector B_scale = ToFloat16({2.0f}); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + std::vector A_zero_point = {1}; + std::vector B_zero_point = {1}; + + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, {}, Y_data, + false, true, false); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + 
std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_U8S8) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 5; + int64_t N = 5; + int64_t K = 2; + + std::vector A_data = {3, 7, 2, 1, 1, + 2, 1, 9, 1, 1}; + std::vector B_data = {2, -1, -9, 1, 1, + -1, 0, -3, 1, -4}; + std::vector A_scale = ToFloat16({-4.0f}); + std::vector B_scale = ToFloat16({2.0f}); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + std::vector A_zero_point = {1}; + std::vector B_zero_point = {3}; + std::vector Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f}); + + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, {}, Y_data, + false, true, false); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_S8S8) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 5; + int64_t N = 5; + int64_t K = 2; + + std::vector A_data = {3, 7, -2, 1, 1, + 2, -1, -9, 1, 1}; + std::vector B_data = {2, -1, -9, 1, 1, + -1, 0, -3, 1, -4}; + std::vector A_scale = ToFloat16({-4.0f}); + std::vector B_scale = ToFloat16({2.0f}); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + std::vector A_zero_point = {-1}; + std::vector B_zero_point = {3}; + std::vector Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f}); + + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + test.AddInput("bias", {N}, Bias); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, Bias, Y_data, + false, true, true); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_S8U8) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 5; + int64_t N = 5; + int64_t K = 2; + + std::vector A_data = {3, 7, -2, 1, 1, + 2, -1, -9, 1, 1}; + std::vector B_data = {3, 7, 2, 1, 1, + 2, 1, 9, 1, 1}; + std::vector A_scale = ToFloat16({-4.0f}); + std::vector B_scale = ToFloat16({2.0f}); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + std::vector A_zero_point = {-1}; + std::vector B_zero_point = {1}; + std::vector Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f}); + + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + test.AddInput("bias", {N}, Bias); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + 
B_data, B_scale, B_zero_point, Bias, Y_data, + false, true, true); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 2; + int64_t N = 2; + int64_t K = 3; + + std::vector A_data = {11, -2, 5, + -1, 3, 10}; + std::vector B_data = {-13, -2, + 9, 55, + -1, 23}; + std::vector A_scale = ToFloat16({0.910f}); + std::vector B_scale = ToFloat16({1.10f, 1.123f}); + + std::vector A_zero_point = {113}; + std::vector B_zero_point = {98, 71}; + + std::vector Bias = ToFloat16({0.10f, 1.123f}); + + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + + test.AddInput("a_scale", {}, {A_scale}); + test.AddInput("b_scale", {N}, B_scale); + test.AddInput("a_zero_point", {}, {A_zero_point}); + test.AddInput("b_zero_point", {N}, B_zero_point); + test.AddInput("bias", {N}, Bias); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, Bias, Y_data, + true, true, true); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + test.SetOutputRelErr("Y", 2e-2f); + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +#endif TEST(MatMulIntegerToFloat, MatMulInteger_With_ZeroPoint) { auto test_case = [&](const std::vector& input_shape, diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 16f38bac62713..1535e2b60a3bd 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -5679,6 +5679,24 @@ TEST_F(GraphTransformationTests, MatMulIntegerToFloatTest) { EXPECT_EQ(op_to_count["Add"], 1); } +#ifdef USE_DML +TEST_F(GraphTransformationTests, MatMulIntegerToFloat16Test) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_integer_to_float16_int8.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + for (auto& node : graph.Nodes()) { + node.SetExecutionProviderType(kDmlExecutionProvider); + } + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); +} +#endif // USE_DML + #endif #ifndef DISABLE_CONTRIB_OPS diff --git a/onnxruntime/test/testdata/matmul_integer_to_float.py b/onnxruntime/test/testdata/matmul_integer_to_float.py index b898390044cf4..e6c51009018f9 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float.py +++ b/onnxruntime/test/testdata/matmul_integer_to_float.py @@ -4,7 +4,7 @@ from onnx import TensorProto, helper -def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: N802 +def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bias=False): # noqa: N802 nodes = [ # subgraph helper.make_node( "MatMulInteger", @@ -13,7 +13,13 @@ def 
GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: "MatMulInteger", ), helper.make_node("Mul", ["a_scale", "b_scale"], ["multiplier"], "mul_right"), - helper.make_node("Cast", ["matmul_output_int32"], ["matmul_output_float"], "cast", to=1), + helper.make_node( + "Cast", + ["matmul_output_int32"], + ["matmul_output_float"], + "cast", + to=TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, + ), helper.make_node( "Mul", ["matmul_output_float", "multiplier"], @@ -25,8 +31,8 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: inputs = [ # inputs helper.make_tensor_value_info("A", TensorProto.INT8 if sign_i else TensorProto.UINT8, ["M", "K"]), helper.make_tensor_value_info("B", TensorProto.INT8 if sign_w else TensorProto.UINT8, ["K", "N"]), - helper.make_tensor_value_info("a_scale", TensorProto.FLOAT, [1]), - helper.make_tensor_value_info("b_scale", TensorProto.FLOAT, ["C"]), + helper.make_tensor_value_info("a_scale", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("b_scale", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["C"]), ] if has_zp: @@ -48,14 +54,22 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: if bias: nodes.extend([helper.make_node("Add", ["mul_bottom_output", "bias"], ["Y"], "add")]) - inputs.extend([helper.make_tensor_value_info("bias", TensorProto.FLOAT, ["N"])]) + inputs.extend( + [ + helper.make_tensor_value_info( + "bias", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["N"] + ) + ] + ) graph = helper.make_graph( nodes, "DynamicQuantizeMatMul_fusion", # name inputs, [ # outputs - helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["M", "N"]), + helper.make_tensor_value_info( + "Y", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["M", "N"] + ), ], ) @@ -64,10 +78,32 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: if __name__ == "__main__": - GenerateModel("matmul_integer_to_float_int8.onnx", False, True) - GenerateModel("matmul_integer_to_float_uint8.onnx", False, False) - GenerateModel("matmul_integer_to_float_int8_bias.onnx", False, True, False, True) - GenerateModel("matmul_integer_to_float_uint8_bias.onnx", False, False, False, True) + GenerateModel("matmul_integer_to_float16_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=True) + GenerateModel("matmul_integer_to_float_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=False) + GenerateModel("matmul_integer_to_float_uint8.onnx", sign_i=False, sign_w=False, output_type_fp16=False) + GenerateModel( + "matmul_integer_to_float_int8_bias.onnx", + sign_i=False, + sign_w=True, + output_type_fp16=False, + has_zp=False, + bias=True, + ) + GenerateModel( + "matmul_integer_to_float_uint8_bias.onnx", + sign_i=False, + sign_w=False, + output_type_fp16=False, + has_zp=False, + bias=True, + ) - GenerateModel("matmul_integer_to_float_int8_int8.onnx", True, True) - GenerateModel("matmul_integer_to_float_int8_int8_bias.onnx", True, True, False, True) + GenerateModel("matmul_integer_to_float_int8_int8.onnx", sign_i=True, sign_w=True, output_type_fp16=False) + GenerateModel( + "matmul_integer_to_float_int8_int8_bias.onnx", + sign_i=True, + sign_w=True, + output_type_fp16=False, + has_zp=False, + bias=True, + ) diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx index 9f4465a914963..906dec542a4fa 100644 --- 
a/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx @@ -1,4 +1,4 @@ -:Ì + :Ì U A B @@ -44,4 +44,4 @@ mul_bottom"MulDynamicQuantizeMatMul_fusionZ  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx index 01b7e15aa4a1f..16cdf03c7ae59 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä 9 A Bmatmul_output_int32 MatMulInteger" MatMulInteger @@ -41,4 +41,4 @@ mul_bottom"Mul  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx index 9d38828e25d6a..55102757a0b57 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx @@ -1,4 +1,4 @@ -:Ì + :Ì U A B @@ -44,4 +44,4 @@ mul_bottom"MulDynamicQuantizeMatMul_fusionZ  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx index 4d9a55af50a87..d9d7222a1acaa 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä 9 A Bmatmul_output_int32 MatMulInteger" MatMulInteger @@ -41,4 +41,4 @@ mul_bottom"Mul  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx index a4c6d20d59be8..5373ce145688e 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx @@ -1,4 +1,4 @@ -:Ì + :Ì U A B @@ -44,4 +44,4 @@ mul_bottom"MulDynamicQuantizeMatMul_fusionZ  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx index a5be0c63f4dcb..e407414b23b24 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä 9 A Bmatmul_output_int32 MatMulInteger" MatMulInteger @@ -41,4 +41,4 @@ mul_bottom"Mul  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx index 7ea69c580ee435be09f12b949f14fdb2efe3d403..aa8e67bcbc59e53d3418000c23ef35c75dfd76c6 100644 GIT binary patch delta 13 Ucmeys{ehc_gL5O(TUJJ403a9x!vFvP delta 13 Ucmeys{ehc_gMA~@TUJIM03ZVcx&QzG diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx new file mode 100644 index 0000000000000..22293b0d10756 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx @@ -0,0 +1,51 @@ + :Ì +U +A +B + a_zero_point + 
b_zero_pointmatmul_output_int32 MatMulInteger" MatMulInteger +. +a_scale +b_scale +multiplier mul_right"Mul +A +matmul_output_int32matmul_output_floatcast"Cast* +to +  +5 +matmul_output_float + +multiplierY +mul_bottom"MulDynamicQuantizeMatMul_fusionZ +A + + +M +KZ +B + + +K +NZ +a_scale + + + +Z +b_scale +  + +CZ + a_zero_point + + +Z + b_zero_point +  +Cb +Y + + + +M +NB \ No newline at end of file From 0cdf36faeb4eafcf543bd84dd6f543a55df738c1 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 4 Mar 2024 13:46:51 -0800 Subject: [PATCH 104/237] Expose SessionOtions.DisablePerSessionThreads (#19730) ### Description ### Motivation and Context ML.NET needs to run mltiple sessions on a single threadpool. --- .../src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs | 5 +++++ .../Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs | 9 +++++++++ .../InferenceTest.cs | 5 ++++- 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 4128524b30483..8a8426a0b3054 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -362,6 +362,7 @@ static NativeMethods() OrtDisableMemPattern = (DOrtDisableMemPattern)Marshal.GetDelegateForFunctionPointer(api_.DisableMemPattern, typeof(DOrtDisableMemPattern)); OrtEnableCpuMemArena = (DOrtEnableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.EnableCpuMemArena, typeof(DOrtEnableCpuMemArena)); OrtDisableCpuMemArena = (DOrtDisableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.DisableCpuMemArena, typeof(DOrtDisableCpuMemArena)); + OrtDisablePerSessionThreads = (DOrtDisablePerSessionThreads)Marshal.GetDelegateForFunctionPointer(api_.DisablePerSessionThreads, typeof(DOrtDisablePerSessionThreads)); OrtSetSessionLogId = (DOrtSetSessionLogId)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogId, typeof(DOrtSetSessionLogId)); OrtSetSessionLogVerbosityLevel = (DOrtSetSessionLogVerbosityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogVerbosityLevel, typeof(DOrtSetSessionLogVerbosityLevel)); OrtSetSessionLogSeverityLevel = (DOrtSetSessionLogSeverityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogSeverityLevel, typeof(DOrtSetSessionLogSeverityLevel)); @@ -992,6 +993,10 @@ IntPtr[] outputValues /* An array of output value pointers. 
Array must be alloca public delegate IntPtr /*(OrtStatus*)*/ DOrtDisableCpuMemArena(IntPtr /* OrtSessionOptions* */ options); public static DOrtDisableCpuMemArena OrtDisableCpuMemArena; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtDisablePerSessionThreads(IntPtr /* OrtSessionOptions* */ options); + public static DOrtDisablePerSessionThreads OrtDisablePerSessionThreads; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */ logId); public static DOrtSetSessionLogId OrtSetSessionLogId; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 7a68246c9b67a..30d005b3c4236 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -696,6 +696,15 @@ public bool EnableCpuMemArena } private bool _enableCpuMemArena = true; + /// + /// Disables the per session threads. Default is true. + /// This makes all sessions in the process use a global TP. + /// + public void DisablePerSessionThreads() + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtDisablePerSessionThreads(handle)); + } + /// /// Log Id to be used for the session. Default is empty string. /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index fd8feda359f90..d6a6b9627f418 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -55,6 +55,9 @@ public void TestSessionOptions() Assert.Equal(0, opt.InterOpNumThreads); Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_ALL, opt.GraphOptimizationLevel); + // No get, so no verify + opt.DisablePerSessionThreads(); + // try setting options opt.ExecutionMode = ExecutionMode.ORT_PARALLEL; Assert.Equal(ExecutionMode.ORT_PARALLEL, opt.ExecutionMode); @@ -98,7 +101,7 @@ public void TestSessionOptions() Assert.Contains("[ErrorCode:InvalidArgument] Config key is empty", ex.Message); // SessionOptions.RegisterOrtExtensions can be manually tested by referencing the - // Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw. + // Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw. 
ex = Assert.Throws(() => { opt.RegisterOrtExtensions(); }); Assert.Contains("Microsoft.ML.OnnxRuntime.Extensions NuGet package must be referenced", ex.Message); From 2a5c9b86ebbdba8fb76f79de26524a2fdd2e5c2a Mon Sep 17 00:00:00 2001 From: zhijiang <43435212+zhijxu-MS@users.noreply.github.com> Date: Tue, 5 Mar 2024 10:11:19 +0800 Subject: [PATCH 105/237] Zhijxu/fix conv1d replacement (#19758) remove the constraint - "group number should be less than 3"; add more condition to make sure the conv1d replacement only happens on conv1d instead of conv2d/conv3d; add more tests; --- .../core/optimizer/conv1d_replacement.cc | 63 +++++++++++------- .../test/optimizer/graph_transform_test.cc | 64 ++++++++++++++++--- 2 files changed, 96 insertions(+), 31 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc index 0412000e04e1b..ff220fcb067b8 100644 --- a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc @@ -42,30 +42,45 @@ */ namespace onnxruntime { bool NodeCanBeReplacedByMatmul(const Node& node) { - // If node type is Conv, and attr "dilations" is 1, "kernel_shape" is 1, "stride" is 1, group is 1 or 2, - // then it can be replaced by MatMul - // Kernel_shape is 1 means it is conv1d + /* + If node type is Conv, and satisfy the following conditions then it can be replaced by MatMul: + - not bias as input which means only has 2 inputs: input and weight + - "dilations" should be [1] + size 1 means conv1d + - "strides" should be [1] + - "pads" should be [0,0] + - "autopad" should be "NOTSET" + - "kernel_shape" should be [1] + */ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11})) { return false; } - const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations"); - const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape"); - const auto* stride = graph_utils::GetNodeAttribute(node, "strides"); - const auto* group = graph_utils::GetNodeAttribute(node, "group"); - if (dilations == nullptr || kernel_shape == nullptr || stride == nullptr || group == nullptr) { + + // TODO: bias input can also be supported if needed + if (node.InputDefs().size() != 2) { return false; } - if ((dilations->ints_size() && dilations->ints(0) != 1) || - (kernel_shape->ints_size() && kernel_shape->ints(0) != 1) || - (stride->ints_size() && stride->ints(0) != 1) || - group->i() >= 3) { + + const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations"); + const auto* strides = graph_utils::GetNodeAttribute(node, "strides"); + const auto* pads = graph_utils::GetNodeAttribute(node, "pads"); + const auto* autopad = graph_utils::GetNodeAttribute(node, "auto_pad"); + const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape"); + if (dilations == nullptr || strides == nullptr || pads == nullptr || autopad == nullptr || kernel_shape == nullptr) { return false; } - return true; + if ((dilations->ints_size() == 1 && dilations->ints(0) == 1) && + (strides->ints_size() == 1 && strides->ints(0) == 1) && + (autopad->s() == "NOTSET") && + (pads->ints_size() == 2 && pads->ints(0) == 0 && pads->ints(1) == 0) && + (kernel_shape->ints_size() == 1 && kernel_shape->ints(0) == 1)) { + return true; + } + return false; } -void Conv1dToMatmul(Graph& graph, Node& conv) { +void Conv1dToMatmul(Graph& graph, Node& conv, const std::string transformer_name) { // Shape of conv1d input: [batch_size, in_channels, 
in_length] // Shape of conv1d weight:[output_channels, input_channels/group, kernel_shape], kernel_shape is 1 // We need to split the input into "group", and squeeze&split the weight, and then do MatMul @@ -83,7 +98,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { conv1d_input_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( graph.GenerateNodeArgName("input_split_output"), nullptr)); } - auto& input_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, {conv1d_input}, + auto& input_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description, {conv1d_input}, {conv1d_input_splitted_outputs}); input_split.SetExecutionProviderType(execution_provider_type); input_split.AddAttribute("axis", int64_t(1)); @@ -93,23 +108,25 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { } // 2. Squeeze conv weight auto conv1d_weight = conv.MutableInputDefs()[1]; + // auto con1d_bias = xx; auto weight_squeeze_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("weight_squeeze_output"), nullptr); - auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName("WeightSqueeze"), "Squeeze", + auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName(transformer_name + "WeightSqueeze"), "Squeeze", node_description, {conv1d_weight}, {weight_squeeze_output}); + int64_t weight_squeeze_axis = 2; if (onnx_opset_version > 12) { // After onnx version 12, squeeze node has axes as input instead of attribute ONNX_NAMESPACE::TensorProto initializer_proto; - initializer_proto.set_name(graph.GenerateNodeName("ConstAsInitializer")); + initializer_proto.set_name(graph.GenerateNodeName(transformer_name + "ConstAsInitializer")); initializer_proto.add_dims(static_cast(1)); initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - InlinedVector initializer_proto_value{2}; + InlinedVector initializer_proto_value{weight_squeeze_axis}; initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t)); auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto); // Squeeze node doesn't have opschema here, so we need to set input args count manually weight_squeeze.MutableInputArgsCount().resize(2); graph_utils::AddNodeInput(weight_squeeze, 1, axes_input); } else { - weight_squeeze.AddAttribute("axes", std::vector{2}); + weight_squeeze.AddAttribute("axes", std::vector{weight_squeeze_axis}); } weight_squeeze.SetExecutionProviderType(execution_provider_type); // 3. 
Split conv weight @@ -118,7 +135,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { conv1d_weight_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( graph.GenerateNodeArgName("weight_split_output"), nullptr)); } - auto& weight_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, + auto& weight_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description, {weight_squeeze_output}, {conv1d_weight_splitted_outputs}); weight_split.AddAttribute("axis", int64_t(0)); weight_split.SetExecutionProviderType(execution_provider_type); @@ -130,13 +147,13 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { for (int i = 0; i < group_num; i++) { auto matmul_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("matmul_output"), nullptr); matmul_outputs.push_back(matmul_output); - auto& matmul = graph.AddNode(graph.GenerateNodeName("Matmul"), "MatMul", node_description, + auto& matmul = graph.AddNode(graph.GenerateNodeName(transformer_name + "Matmul"), "MatMul", node_description, {conv1d_weight_splitted_outputs[i], conv1d_input_splitted_outputs[i]}, {matmul_output}); matmul.SetExecutionProviderType(execution_provider_type); } // 5. Concat matmul outputs - auto& concat_node = graph.AddNode(graph.GenerateNodeName("Concat"), "Concat", node_description, + auto& concat_node = graph.AddNode(graph.GenerateNodeName(transformer_name + "Concat"), "Concat", node_description, matmul_outputs, {}); concat_node.SetExecutionProviderType(execution_provider_type); concat_node.AddAttribute("axis", int64_t(1)); @@ -155,7 +172,7 @@ Status Conv1dReplacement::ApplyImpl(Graph& graph, bool& modified, int graph_leve ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); if (NodeCanBeReplacedByMatmul(node)) { LOGS(logger, VERBOSE) << "lora conv1d replacement, node name: " + node.Name(); - Conv1dToMatmul(graph, node); + Conv1dToMatmul(graph, node, Name()); modified = true; } } diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index bab7c09839273..109937ff96d1d 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -1200,7 +1200,7 @@ TEST_P(QDQFusionTestsParameterized, CheckModelComposition) { ASSERT_EQ(op_to_count_post_fusion["com.microsoft.FakeQuant"], 1); } -TEST_F(GraphTransformationTests, Conv1dReplacement) { +TEST_F(GraphTransformationTests, Conv1dReplacement_TakeEffect) { auto pre_graph_checker = [&](Graph& graph) { auto op_count_map = CountOpsInGraph(graph); TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); @@ -1208,7 +1208,7 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) { }; for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { - for (auto group : {1, 2}) { + for (auto group : {1, 2, 4}) { auto build_test_case = [&](ModelTestBuilder& builder) { auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); auto out_channel = 64; @@ -1222,6 +1222,8 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) { conv_node.AddAttribute("kernel_shape", std::vector{1}); conv_node.AddAttribute("strides", std::vector{1}); conv_node.AddAttribute("group", static_cast(group)); + conv_node.AddAttribute("pads", std::vector{0, 0}); + conv_node.AddAttribute("auto_pad", "NOTSET"); }; auto post_graph_checker = [&](Graph& graph) { @@ -1243,28 +1245,64 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) { } } -TEST_F(GraphTransformationTests, 
Conv1dReplacement_NoTakeEffect) { +// node has bias input so conv not replaced +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect1) { auto pre_graph_checker = [&](Graph& graph) { auto op_count_map = CountOpsInGraph(graph); TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); return Status::OK(); }; - // "group" is 3 so conv not replaced for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { auto build_test_case = [&](ModelTestBuilder& builder) { auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); auto out_channel = 64; auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); - auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / 3, 1}, {-1.0f, 1.0f}); + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); + auto* bias_arg = builder.MakeInitializer({out_channel}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg, bias_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{0, 0}); + conv_node.AddAttribute("auto_pad", "NOTSET"); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } +} + +// "auto_pad " is not NOTSET so conv not replaced +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect2) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; + + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); auto* conv_output = builder.MakeOutput(); auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); conv_node.AddAttribute("dilations", std::vector{1}); conv_node.AddAttribute("kernel_shape", std::vector{1}); conv_node.AddAttribute("strides", std::vector{1}); - conv_node.AddAttribute("group", static_cast(3)); + conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{0, 0}); + conv_node.AddAttribute("auto_pad", "VALID"); }; std::unique_ptr transformer = std::make_unique(); @@ -1272,8 +1310,16 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) { TransformerLevel::Level1, 1, pre_graph_checker, pre_graph_checker)); } +} + +// pads is not all zero, so conv not replaced +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect3) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; - // "kernel_shape" is not 1 so conv not replaced for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { auto build_test_case = [&](ModelTestBuilder& builder) { auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); @@ -1285,9 +1331,11 @@ TEST_F(GraphTransformationTests, 
Conv1dReplacement_NoTakeEffect) { auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); conv_node.AddAttribute("dilations", std::vector{1}); - conv_node.AddAttribute("kernel_shape", std::vector{2}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); conv_node.AddAttribute("strides", std::vector{1}); conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{1, 0}); + conv_node.AddAttribute("auto_pad", "NOTSET"); }; std::unique_ptr transformer = std::make_unique(); From 7e613ee821405b1192d0b71b9434a4f94643f1e4 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Tue, 5 Mar 2024 11:45:45 +0800 Subject: [PATCH 106/237] [quant] supports act_order inputs in Matmulnbits and new quantization algorithm "hqq" (#19106) ### Description 1. Support quantized GPTQ weight in huggingface like [TheBloke/Llama-2-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ) 2. Support Act_order for GPTQ 3. Support [HQQ](https://mobiusml.github.io/hqq_blog/) algorithm to quantize matmul weight and add quant script ### Motivation and Context --- docs/ContribOperators.md | 43 +- docs/OperatorKernels.md | 4 +- .../cpu/quantization/matmul_nbits.cc | 105 ++++- .../cpu/quantization/matmul_nbits_impl.cc | 108 +++++ .../cpu/quantization/matmul_nbits_impl.h | 23 ++ .../cuda/quantization/dequantize_blockwise.cu | 159 ++++++-- .../quantization/dequantize_blockwise.cuh | 6 +- .../cuda/quantization/matmul_nbits.cc | 170 ++++---- .../cuda/quantization/matmul_nbits.h | 41 ++ .../core/graph/contrib_ops/contrib_defs.cc | 38 +- .../quantization/matmul_4bits_quantizer.py | 379 ++++++++++++++++-- .../test/contrib_ops/matmul_4bits_test.cc | 78 +++- .../test/python/quantization/op_test_utils.py | 3 +- .../quantization/test_op_matmul_4bits.py | 19 +- 14 files changed, 942 insertions(+), 234 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc create mode 100644 onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h create mode 100644 onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e295dfa203ae5..5f0100fad95a2 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2808,22 +2808,23 @@ This version of the operator has been available since version 1 of the 'com.micr And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. - Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - - n_blocks_per_col = (K + block_size - 1) / block_size - - blob_size = block_size / 8 * bits + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) + For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. + - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. + 4bit example: + |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) + - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. + 3bit example: + |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. + The last uint_8 may have some bits unused. - For a block blob. 
It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] - Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: - - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - - [N * n_blocks_per_col] if bits > 4 - + Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. + - [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)] + If zero_points has same type as A, it's not packed and has the same shape as Scales. #### Version @@ -2844,17 +2845,19 @@ This version of the operator has been available since version 1 of the 'com.micr
size of each quantization group used for weight quantization (default 128). It needs to be a power of 2 and not smaller than 16.
-#### Inputs (3 - 4) +#### Inputs (3 - 5)
A : T1
The input tensor, not quantized
B : T2
-
1-dimensional data blob
+
1 or 2 dimensional data blob
scales : T1
quantization scale
-
zero_points (optional) : T2
+
zero_points (optional) : T3
quantization zero points
+
g_idx (optional) : T4
+
group_idx
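
The packing rules described above are easier to see in code. Below is a minimal NumPy sketch (illustrative only, not part of the operator specification) that dequantizes a single column of B for the common bits=4 case; the helper name `dequantize_column` and the simplified per-column `zero_points` layout are assumptions made just for this example:

```python
import numpy as np

def dequantize_column(blob, scales, zero_points, block_size, K):
    """Recover one column of B (K float values) from its 4-bit blob.

    blob:        uint8, shape [n_blocks_per_col, block_size // 2] (two 4-bit weights per byte)
    scales:      float, shape [n_blocks_per_col]
    zero_points: uint8, two 4-bit zero points packed per byte (simplified per-column view)
    """
    n_blocks_per_col = (K + block_size - 1) // block_size
    out = np.zeros(n_blocks_per_col * block_size, dtype=np.float32)
    for blk in range(n_blocks_per_col):
        packed = blob[blk]
        lo = (packed & 0x0F).astype(np.int32)   # even elements of the block
        hi = (packed >> 4).astype(np.int32)     # odd elements of the block
        q = np.empty(block_size, dtype=np.int32)
        q[0::2] = lo
        q[1::2] = hi
        # even block index uses the low nibble, odd block index the high nibble
        zp_byte = int(zero_points[blk // 2])
        zp = (zp_byte >> 4) if (blk & 1) else (zp_byte & 0x0F)
        out[blk * block_size:(blk + 1) * block_size] = (q - zp) * scales[blk]
    return out[:K]
```

A full implementation would also handle the packed `[CeilDiv((N * n_blocks_per_col + 1) * bits, 8)]` zero-point layout across all columns, other bit widths, and the optional `g_idx` reordering listed in the inputs above.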
#### Outputs @@ -2869,8 +2872,12 @@ This version of the operator has been available since version 1 of the 'com.micr
T1 : tensor(float), tensor(float16)
Constrain input and output types to float/half_float tensors.
-
T2 : tensor(uint8)
-
Constrain quantized weight types to uint8.
+
T2 : tensor(uint8), tensor(int32)
+
Constrain quantized weight types to uint8/int32.
+
T3 : tensor(uint8), tensor(int32), tensor(float16), tensor(float)
+
Constrain quantized zero point types to uint8/int32/float16/float.
+
T4 : tensor(int32)
+
the index tensor.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 0e60b4622f2fb..71b0def659741 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -470,7 +470,7 @@ Do not modify directly.* |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(uint8)
**T4** = tensor(int32)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| @@ -855,7 +855,7 @@ Do not modify directly.* |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 166f5c8f52f54..602dd98d8c0d6 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -1,6 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" + +#include +#include + +#include "core/common/common.h" #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/framework/op_kernel.h" @@ -50,6 +56,17 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level } } // namespace +bool GetType(const NodeArg& node_arg, int32_t& type) { + type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) { + return false; + } + + type = type_proto->tensor_type().elem_type(); + return true; +} + class MatMulNBits final : public OpKernel { public: MatMulNBits(const OpKernelInfo& info) @@ -59,6 +76,17 @@ class MatMulNBits final : public OpKernel { block_size_{narrow(info.GetAttr("block_size"))}, nbits_{narrow(info.GetAttr("bits"))}, accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { + const auto& node = info.node(); + auto input_defs = node.InputDefs(); + // g_idx + if (input_defs.size() > 4) { + act_order_ = true; + } + int32_t type; + if (input_defs.size() > 3 && GetType(*input_defs[3], type)) { + zero_point_is_not_quant_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8; + } + ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); #ifdef ORT_NEURAL_SPEED @@ -88,6 +116,8 @@ class MatMulNBits final : public OpKernel { const size_t N_; const size_t block_size_; const size_t nbits_; + bool act_order_{false}; + bool zero_point_is_not_quant_{false}; const int64_t accuracy_level_; const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; @@ -105,7 +135,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; - + if (act_order_ || zero_point_is_not_quant_) { + return Status::OK(); + } #if defined(ORT_NEURAL_SPEED) if (!all_constant_) { @@ -212,7 +244,6 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep Status MatMulNBits::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); - const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); @@ -257,11 +288,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { #endif // defined(ORT_NEURAL_SPEED) const Tensor* scales = ctx->Input(2); - const Tensor* zero_points = ctx->Input(3); + const Tensor* zero_points = ctx->InputCount() > 3 ? ctx->Input(3) : nullptr; + const Tensor* reorder_idx = ctx->InputCount() > 4 ? ctx->Input(4) : nullptr; + const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); TensorShape b_shape({static_cast(N_), static_cast(K_)}); + const auto* reorder_idx_data = reorder_idx == nullptr ? 
nullptr : reorder_idx->Data(); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); @@ -281,8 +315,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), - [](size_t offset) { return offset == 0; }); + const bool has_single_b_matrix = + (!act_order_) && (!zero_point_is_not_quant_) && + std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; }); if (has_single_b_matrix) { const auto compute_type = static_cast(accuracy_level_); @@ -328,22 +363,50 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const uint8_t* b_data = b->Data(); const size_t ldb = helper.Ldb(true); - AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - // dequantize b, only 4b quantization is supported for now - MlasDequantizeBlockwise( - tmp_b_data_ptr.get(), // dequantized output - b_data, // quantized input - scales_data, // quantization scales - zero_points_data, // quantization zero points - static_cast(block_size_), // quantization block size - column_wise_quant_, // columnwise quantization or row-wise - static_cast(K_), // number of rows in quantized input - static_cast(N_), // number of columns in quantized input - thread_pool); - + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { + // dequantize b, only 4b quantization is supported for now + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { + ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); + // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! 
+ if ((zero_points && zero_points->IsDataType())) { + DequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + reorder_idx_data, + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { + DequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + reorder_idx_data, + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } + } #if 0 // for debug auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); @@ -374,7 +437,9 @@ ONNX_OPERATOR_KERNEL_EX( kCpuExecutionProvider, KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), MatMulNBits); } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc new file mode 100644 index 0000000000000..f92e59e990ba5 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" + +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/float16.h" +#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +template +void Dequantize4BitsKernelReOrder( + T* output, const uint8_t* quant_data, const T* scale_data, + const zeroT* zero_points, const int32_t* reorder_idx, int block_size, + int groups_per_threadblock, int total_groups, int out_rows, int out_cols, + int blockIdx_x, int threadIdx_x) { + const int group_id = blockIdx_x * groups_per_threadblock + ((threadIdx_x * 8) / block_size); + if (group_id >= total_groups) { + return; + } + const int scales_shape_x = (out_cols + block_size - 1) / block_size; + const int zero_point_shape_x = (scales_shape_x + 1) / 2; + + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx_x * 8) & (block_size - 1)); + + const int out_x = element_offset % (scales_shape_x * block_size); + const int out_y = element_offset / (scales_shape_x * block_size); + if (out_y >= out_rows || out_x >= out_cols) { + return; + } + T* output_i = output + out_y * out_cols + out_x; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + const int remain_x = std::min(8, out_cols - out_x); + for (int i = 0; i < remain_x; i++) { + int32_t rid = reorder_idx ? 
reorder_idx[kb_idx * block_size + i] : kb_idx; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + float zp_f = 8; + if (zero_points) { + if constexpr (std::is_same_v) { + zp_f = *(zero_points + n_idx * scales_shape_x + rid); + } else { + uint8_t zp = 8; + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * MLFloat16(zp_f); + output_i[i] = static_cast((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * zp_f; + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +void DequantizeBlockwise( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + const int32_t* reorder_idx, // reorder_idx for groupwise quantization + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* pool) { + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + constexpr int element_per_thread = 8; + int groups_per_threadblock = 256 * element_per_thread / block_size; + int groups_per_K = ceildiv(K, block_size); + int total_groups = N * groups_per_K; // total elemenets in quant_data + int blocks_per_grid = static_cast(ceildiv(total_groups, groups_per_threadblock)); + concurrency::ThreadPool::TrySimpleParallelFor( + pool, static_cast(blocks_per_grid), + [&](std::ptrdiff_t block_id) { + for (int j = 0; j < 256; j++) { + Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points, + reorder_idx, block_size, groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); + } + }); +} + +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const float* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h new file mode 100644 index 0000000000000..5061ac5c800a6 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +template +void DequantizeBlockwise( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + const int32_t* reorder_idx, // quantization zero points + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* thread_pool); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 6b66f1d84e221..cd6593352008b 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -2,10 +2,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include #include #include +#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" @@ -56,41 +58,94 @@ __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, f } template -__global__ void Dequantize4BitsKernel( +__global__ void Dequantize4BitsKernelReOrder( T* output, const uint8_t* quant_data, const T* scale_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int block_size, - int blocks_per_K, - int blocks_per_threadblock, - int total_blks, - int shift) { - int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); - if (block_id >= total_blks) { + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (group_id >= total_groups) { return; } - int n_idx = block_id / blocks_per_K; - int kb_idx = block_id % blocks_per_K; - int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); + // T __shared__ zero_points_after_reorder[];//K + // T __shared__ scales_after_reorder[]; // K + // const int num_r_per_thread = k / 256; + + const int zero_point_shape_x = (groups_per_K + 1) / 2; + const int scales_shape_x = groups_per_K; + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); + T* output_i = output + element_offset; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + for (int i = 0; i < 8; i++) { + int32_t rid = reorder_idx[kb_idx * block_size + i]; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? 
(zp >> 4) : (zp & 0x0f); + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * __short2half_rn(zp); + output_i[i] = __uint2half_rn((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * T(zp); + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +__global__ void Dequantize4BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const ZeroT* zero_points, + int block_size, + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int block_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (block_id >= total_groups) { + return; + } + int element_offset = block_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); T scale = *(scale_data + block_id); - uint8_t zp = 8; - if (zero_points) { - zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2]; - zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + T zero_point_value; + if constexpr (std::is_same_v) { + const int scales_shape_x = groups_per_K; + const int zero_point_shape_x = (groups_per_K + 1) / 2; + int kb_idx = block_id % scales_shape_x; + int n_idx = block_id / scales_shape_x; + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + kb_idx / 2]; + zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + zero_point_value = static_cast(zp); + } else { + zero_point_value = zero_points? *(zero_points + block_id):static_cast(8); } output = output + element_offset; - DequantizeEightElements(quant_value, scale, static_cast(zp), output); + DequantizeEightElements(quant_value, scale, zero_point_value, output); } -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2] + const ZeroT* zero_points, // shape: [N, (block_per_K + 1)/2] + const int32_t* reorder_idx, int k, int n, int block_size, @@ -98,47 +153,79 @@ Status Dequantize4Bits( // k is padded and equal to block_per_K * block_size ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); constexpr int element_per_thread = 8; - int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; - int blocks_per_K = k / block_size; - int total_blks = n * blocks_per_K; - int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); - int shift = static_cast(log2f(float(block_size))); - - Dequantize4BitsKernel<<>>( - output, - quant_data, - scales_data, - zero_points, - block_size, - blocks_per_K, - blocks_per_threadblock, - total_blks, - shift); + int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_K = k / block_size; + int total_groups = n * groups_per_K; // total elemenets in quant_data + int groups_per_grid = static_cast(CeilDiv(total_groups, groups_per_threadblock)); + if (!reorder_idx) { + Dequantize4BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } else { + // static_assert(std::is_same_v, "ZeroT must be uint8_t"); + Dequantize4BitsKernelReOrder<<>>( + output, + quant_data, + scales_data, + (const uint8_t*)zero_points, + reorder_idx, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } return Status::OK(); } -template Status Dequantize4Bits( +template Status Dequantize4Bits( 
float* output, const uint8_t* quant_data, const float* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); -template Status Dequantize4Bits( +template Status Dequantize4Bits( half* output, const uint8_t* quant_data, const half* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); +template Status Dequantize4Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const float* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - +template Status Dequantize4Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const half* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); /////////////////////////////////////////////////////////////////////////////// // A more general block-wise dequantization implementation that supports // different block sizes and block orientations (row-wise/column-wise). diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index f9c09c55fd893..580b5087f3fa3 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -7,18 +7,18 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, + const ZeroT* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - /** * @brief Dequantize a block-wise quantized matrix, and store the result in a * column major matrix for use in subsequent GEMM. This implementation supports diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 015df70c8ec3c..1cec6f6a12f1c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -1,15 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-// -// This module define MatMulFp32Q4 operator, it is basically -// matmul float32 with right hand side being a 2-D matrix -// pre-packed and block-compacted into int4 -// - -#include "core/common/safeint.h" -#include "core/providers/cuda/cuda_kernel.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.h" + +#include + +#include "core/common/status.h" +#include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" #include "matmul_nbits.cuh" #include "dequantize_blockwise.cuh" @@ -19,40 +16,19 @@ namespace contrib { namespace cuda { using namespace onnxruntime::cuda; -template -class MatMulNBits final : public CudaKernel { - public: - MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { - ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); - ORT_ENFORCE(nbits_ == 4, - "Only 4b quantization is supported for MatMulNBits op," - " additional bits support is planned."); - } - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - int64_t K_; - int64_t N_; - int64_t block_size_; - int64_t nbits_; - bool column_wise_quant_blk_{true}; -}; - template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const Tensor* b = ctx->Input(1); const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); + const Tensor* reorder_idx = ctx->Input(4); const auto* a_data = a->Data(); const uint8_t* blob_data = b->Data(); const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const auto* reorder_idx_data = reorder_idx == nullptr ? 
nullptr : reorder_idx->Data(); typedef typename ToCudaType::MappedType CudaT; @@ -67,77 +43,99 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - bool is_4bit_done = TryMatMul4Bits( - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - blob_data, - reinterpret_cast(scales_data), - zero_points_data, - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), - SafeInt(block_size_), - SafeInt(GetDeviceProp().sharedMemPerBlock), - static_cast(ctx->GetComputeStream()->GetHandle())); - if (!is_4bit_done) { - int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; - IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); - auto* b_data = b_data_ptr.get(); - if (column_wise_quant_blk_) { - // column-wise block + bool is_4bit_done = (reorder_idx_data == nullptr) && + (!zero_points || !zero_points->IsDataType()) && + TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + static_cast(zero_points_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + SafeInt(GetDeviceProp().sharedMemPerBlock), + static_cast(ctx->GetComputeStream()->GetHandle())); + + if (is_4bit_done) { + return Status::OK(); + } + + int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; + IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); + auto* b_data = b_data_ptr.get(); + if (column_wise_quant_blk_) { + if (reorder_idx) { + ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); + } + // column-wise block + if ((zero_points && zero_points->IsDataType())) { ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, + (const CudaT*)zero_points_data, + reorder_idx_data, SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), static_cast(ctx->GetComputeStream()->GetHandle()))); } else { - // row-wise block - K_padded = K_; - - ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, - SafeInt(block_size_), - column_wise_quant_blk_, - SafeInt(K_), + (const uint8_t*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), SafeInt(N_), + SafeInt(block_size_), static_cast(ctx->GetComputeStream()->GetHandle()))); } + } else { + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const uint8_t*)zero_points_data, + SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } #if 0 - cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); - T* b_data_cpu = new T[K_ * N_]; - cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); - delete[] b_data_cpu; +cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); +T* b_data_cpu = new T[K_ * N_]; +cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); +delete[] b_data_cpu; #endif - const CudaT alpha = ToCudaType::FromFloat(1.f); - const CudaT zero = ToCudaType::FromFloat(0.f); - - if (helper.OutputOffsets().size() == 1) { - CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( - 
GetCublasHandle(ctx), - CUBLAS_OP_T, - CUBLAS_OP_N, - SafeInt(helper.N()), - SafeInt(helper.M()), - SafeInt(helper.K()), - &alpha, - reinterpret_cast(b_data), - SafeInt(K_padded), - reinterpret_cast(a_data), - helper.Lda(transa), - &zero, - reinterpret_cast(Y->MutableData()), - helper.Ldc(), - GetDeviceProp(), - UseTF32())); - } + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + if (helper.OutputOffsets().size() == 1) { + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_data), + SafeInt(K_padded), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp(), + UseTF32())); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..f5c2c6c4e4fdf --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define MatMulNBits operator, it is basically +// matmul float with right hand side being a 2-D matrix +// pre-packed and block-compacted into int4 +// +#pragma once +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulNBits final : public CudaKernel { + public: + MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; + bool column_wise_quant_blk_{true}; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index e33ce20737f80..f06a3785f362d 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3343,22 +3343,23 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7 And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. -Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: -- n_blocks_per_col = (K + block_size - 1) / block_size -- blob_size = block_size / 8 * bits - - For a block blob. 
It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) + For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. + - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. + 4bit example: + |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) + - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. + 3bit example: + |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. + The last uint_8 may have some bits unused. -Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] -Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: - - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - - [N * n_blocks_per_col] if bits > 4 +Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] +Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. + - [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)] + If zero_points has same type as A, it's not packed and has the same shape as Scales. )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) @@ -3377,12 +3378,15 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored "type T1.", AttributeProto::INT, static_cast(0)) .Input(0, "A", "The input tensor, not quantized", "T1") - .Input(1, "B", "1-dimensional data blob", "T2") + .Input(1, "B", "1 or 2 dimensional data blob", "T2") .Input(2, "scales", "quantization scale", "T1") - .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional) + .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional) + .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional) .Output(0, "Y", "tensor. The output tensor has the same rank as the input. 
", "T1") .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") - .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.") + .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.") + .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference propagateElemTypeFromInputToOutput(ctx, 0, 0); diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index eb7bbec997d59..a1916e806c5c0 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -65,7 +65,7 @@ def __init__( self, calibration_data_reader: CalibrationDataReader, percdamp=0.01, - blocksize=128, + block_size=128, actorder=False, mse=False, perchannel=True, @@ -79,7 +79,7 @@ def __init__( a calibration data reader. It enumerates calibration data and generates inputs for the original model. percdamp: percent of the average Hessian diagonal to use for dampening. - blocksize (int, optional): + block_size (int, optional): channel number in one block to execute a GPTQ quantization iteration. actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value. @@ -93,42 +93,285 @@ def __init__( ) self.calibration_data_reader = calibration_data_reader self.percdamp = percdamp - self.blocksize = blocksize + self.block_size = block_size self.actorder = actorder self.mse = mse self.perchannel = perchannel -class MatMul4BitsQuantizer: - """Perform 4b quantization of constant MatMul weights""" +class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig): + def __init__( + self, + block_size=128, + bits=4, + axis=1, + ): + """ + This is a class for HQQ algorithm Weight Only Quant Configuration. + HQQ algorithm quant weight without needing calibrate data. + + Args: + block_size (int, optional): + channel number in one block to execute a GPTQ quantization iteration. + bits (int, optional): + how many bits to represent weight. + axis (int, optional): + 0 or 1. which axis to quantize. 
https://arxiv.org/pdf/2309.15531.pdf + """ + super().__init__( + algorithm="HQQ", + ) + self.block_size = block_size + self.bits = bits + self.axis = axis + +class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, - model: ModelProto | str, - block_size: int, - is_symmetric: bool, + block_size: int = 128, + is_symmetric: bool = False, accuracy_level: int | None = None, - nodes_to_exclude=None, - algo_config: WeightOnlyQuantConfig = None, ): - if nodes_to_exclude is None: - nodes_to_exclude = [] - self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) - self.model_path = model if isinstance(model, str) else None + super().__init__(algorithm="DEFAULT") self.block_size = block_size self.is_symmetric = is_symmetric + self.bits = 4 self.accuracy_level = accuracy_level - self.nodes_to_exclude = set(nodes_to_exclude) - self.algo_config = algo_config + + +def is_divisible(val1, val2): + return int(val2 * np.ceil(val1 / val2)) == val1 + + +class HQQWeightOnlyQuantizer: + def __init__( + self, + config: HQQWeightOnlyQuantConfig, + ): + self.config = config + + # Proximal solver || weight - dequantize(quantize(weight))||_p^p + @staticmethod + def optimize_weights( + tensor, + scale, + zero, + min_max: list[int], + axis: int = 0, + opt_params: dict = None, # noqa: RUF013 + verbose=False, + ): + import torch + + opt_params = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} if opt_params is None else opt_params + lp_norm, beta, kappa, iters = ( + opt_params["lp_norm"], + opt_params["beta"], + opt_params["kappa"], + opt_params["iters"], + ) + + dtype = torch.float16 if tensor.is_cuda else torch.float32 + w_f = tensor.to(dtype) + scale = scale.to(dtype) + zero = zero.to(dtype) + + if lp_norm == 1: + + def shrink_op(x, beta): + return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta) + + else: + + def shrink_op(x, beta, p=lp_norm): + return torch.sign(x) * torch.nn.functional.relu( + torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x) + 1e-8, p - 1) + ) + + best_error = 1e4 + for i in range(iters): + w_q = torch.round(w_f * scale + zero).clamp(min_max[0], min_max[1]) + w_r = (w_q - zero) / scale + w_e = shrink_op(w_f - w_r, beta) + zero = torch.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdim=True) + beta *= kappa + + current_error = float(torch.abs(w_f - w_r).mean()) + if verbose: + print(i, np.round(current_error, 6)) + if current_error < best_error: + best_error = current_error + else: + break + + del w_f, w_q, w_r, w_e + + return scale, zero @staticmethod - def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: - for gid in range(len(graph_path) - 1, -1, -1): - graph = graph_path[gid] - for tensor in graph.initializer: - if tensor.name == name: - return tensor, graph - return None, None + def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits): + if pack_tensor.shape[0] == ori_int_tensor.shape[0]: + ori_int_tensor = ori_int_tensor.T + pack_tensor = pack_tensor.T + if bits in [2, 4, 8]: + compress_ratio = pack_tensor.element_size() * 8 // bits + for j in range(0, compress_ratio): + pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j)) + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + # from Official implementation of Half-Quadratic Quantization (HQQ) + def quantize_internal( + self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1 + ): + import torch + + weight = tensor.float() + ori_shape = 
weight.shape + + pad_len = (group_size - ori_shape[axis] % group_size) % group_size + if axis == 1: + weight = torch.nn.functional.pad(weight, (0, pad_len), "constant", 0) + else: + weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_len), "constant", 0) + shape = weight.shape + + # Reshape for grouping + if (group_size is not None) and channel_wise: + weight = weight.reshape([-1, group_size]) if (axis == 1) else weight.reshape([group_size, -1]) + + # Get min/max values + if channel_wise is False: + _min, _max = weight.min(), weight.max() + optimize = False + else: + _min = weight.min(axis=axis, keepdim=True)[0] + _max = weight.max(axis=axis, keepdim=True)[0] + + max_v = 2**bits - 1 + min_v = 0 + min_max = [min_v, max_v] + + # Note: here we work with the inverse of the scale to avoid division and quantize instead via weight*scale + zero, the scale is inverted later on. + # clamp to avoid half-precision problems + scale = (max_v / (_max - _min)).clamp(max=2e4) + #!!!!!!!!!!!!!!! + min_max_axis = _max - _min + if (min_max_axis == 0).sum().item() > 0: + min_max_axis[min_max_axis == 0] = max_v + scale = (max_v / min_max_axis).clamp(max=2e4) + zero = -_min * scale + + if round_zero: + zero = torch.round(zero) + + # Fine-tune weights + if optimize: + scale, zero = self.optimize_weights(tensor=weight, scale=scale, zero=zero, min_max=min_max, axis=axis) + + # Quantize + # Necessary for fake quantization backprop + w_q = torch.round(weight * scale + zero).clamp(min_max[0], min_max[1]) + w_q = w_q.reshape(shape).int() + + scale = 1.0 / scale + if axis == 1: + scale = scale.reshape(shape[0], -1) + zero = zero.reshape(shape[0], -1) + else: + scale = scale.reshape(-1, shape[-1]) + zero = zero.reshape(-1, shape[-1]) + # cleanup + del weight, _min, _max + + return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype) + + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): + """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + if node.op_type != "MatMul": + return node # only care about MatMul for now + import torch + + logger.info(f"start to quantize {node.name} ...") + inputB = node.input[1] # noqa: N806 + b_pb, bs_graph = get_initializer(inputB, graph_stack) + if b_pb is None: + logger.info("MatMul doesn't have const weight. Skip to quantize") + return node # only care about constant weight + + b_array = onnx.numpy_helper.to_array(b_pb) + if len(b_array.shape) != 2: + logger.info("MatMul weight is not 2D. 
Skip to quantize") + return node # can only process 2-D matrix + b_array_torch = torch.from_numpy(b_array) + if torch.cuda.is_available(): + b_array_torch = b_array_torch.cuda() + quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal( + b_array_torch.T, bits=self.config.bits, group_size=self.config.block_size + ) + quant_weight_torch = quant_weight_torch.contiguous() + scales_torch = scales_torch.contiguous() + zero_points_torch = zero_points_torch.contiguous() + + packed_torch = torch.zeros( + (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // 2), + dtype=torch.uint8, + device=quant_weight_torch.device, + ) + self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, self.config.bits) + scales = scales_torch.cpu().numpy() + zero_points = zero_points_torch.cpu().numpy() + b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy()) + b_quant.name = b_pb.name + "_Q4" + for input in bs_graph.input: + if input.name == inputB: + bs_graph.input.remove(input) + break + + scales_tensor = onnx.numpy_helper.from_array(scales) + scales_tensor.name = b_pb.name + "_scales" + bs_graph.initializer.extend([b_quant, scales_tensor]) + + input_names = [node.input[0], b_quant.name, scales_tensor.name] + zp_tensor = onnx.numpy_helper.from_array(zero_points) + zp_tensor.name = b_pb.name + "_zero_points" + bs_graph.initializer.extend([zp_tensor]) + input_names.append(zp_tensor.name) + + kwargs = {} + rows, cols = b_array.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["bits"] = self.config.bits + kwargs["block_size"] = self.config.block_size + + matmul_q4_node = onnx.helper.make_node( + "MatMulNBits", + inputs=input_names, + outputs=[node.output[0]], + name=node.name + "_Q4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) + + logger.info(f"complete quantization of {node.name} ...") + + return matmul_q4_node + + +def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: + for gid in range(len(graph_path) - 1, -1, -1): + graph = graph_path[gid] + for tensor in graph.initializer: + if tensor.name == name: + return tensor, graph + return None, None + + +class DefaultWeightOnlyQuantizer: + def __init__(self, config: DefaultWeightOnlyQuantConfig): + self.config = config def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: """4b quantize fp32 weight to a blob""" @@ -137,7 +380,7 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: raise ValueError("Current int4 block quantization only supports 2D tensors!") rows, cols = fp32weight.shape - block_size = self.block_size + block_size = self.config.block_size blob_size = block_size // 2 k_blocks = (rows + block_size - 1) // block_size padded_rows = k_blocks * block_size @@ -149,23 +392,19 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") - quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric) + quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric) return (packed, scales, zero_point) - def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: """If the node is MatMul with fp32 const weight, quantize the 
weight with int4, and return the new node""" if node.op_type != "MatMul": return node # only care about MatMul for now logger.info(f"start to quantize {node.name} ...") - if node.name in self.nodes_to_exclude: - logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") - return node - inputB = node.input[1] # noqa: N806 - B, Bs_graph = MatMul4BitsQuantizer.__get_initializer(inputB, graph_stack) # noqa: N806 + B, Bs_graph = get_initializer(inputB, graph_stack) # noqa: N806 if B is None: logger.info("MatMul doesn't have const weight. Skip to quantize") return node # only care about constant weight @@ -188,7 +427,7 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) Bs_graph.initializer.extend([B_quant, scales_tensor]) input_names = [node.input[0], B_quant.name, scales_tensor.name] - if not self.is_symmetric: + if not self.config.is_symmetric: zp_tensor = onnx.numpy_helper.from_array(zero_points) zp_tensor.name = B.name + "_zero_points" Bs_graph.initializer.extend([zp_tensor]) @@ -199,8 +438,8 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) kwargs["K"] = rows kwargs["N"] = cols kwargs["bits"] = 4 - kwargs["block_size"] = self.block_size - if self.accuracy_level is not None: + kwargs["block_size"] = self.config.block_size + if self.config.accuracy_level is not None: kwargs["accuracy_level"] = self.accuracy_level matmul_q4_node = onnx.helper.make_node( @@ -216,6 +455,38 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) return matmul_q4_node + +class MatMul4BitsQuantizer: + """Perform 4b quantization of constant MatMul weights""" + + def __init__( + self, + model: ModelProto | str, + block_size: int = 128, + is_symmetric: bool = False, + accuracy_level: int | None = None, + nodes_to_exclude=None, + algo_config: WeightOnlyQuantConfig = None, + ): + if nodes_to_exclude is None: + nodes_to_exclude = [] + self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) + self.model_path = model if isinstance(model, str) else None + self.block_size = block_size + self.is_symmetric = is_symmetric + self.accuracy_level = accuracy_level + self.nodes_to_exclude = set(nodes_to_exclude) + self.node_quantizer = None + if algo_config is None: + algo_config = DefaultWeightOnlyQuantConfig( + block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level + ) + self.algo_config = algo_config + if algo_config.algorithm == "HQQ": + self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config) + elif algo_config.algorithm == "DEFAULT": + self.node_quantizer = DefaultWeightOnlyQuantizer(self.algo_config) + def _process_subgraph(self, graph_stack: list[GraphProto]): new_nodes = [] graph = graph_stack[-1] @@ -246,8 +517,15 @@ def _process_subgraph(self, graph_stack: list[GraphProto]): node = onnx.helper.make_node( # noqa: PLW2901 node.op_type, node.input, node.output, name=node.name, **kwargs ) - - new_nodes.append(self._q4_matmul_node_weight(node, graph_stack)) + out_node = None + if node.name in self.nodes_to_exclude: + logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") + out_node = node + elif self.algo_config is not None and self.algo_config.algorithm == "HQQ": + out_node = self.node_quantizer.quantize(node, graph_stack) + else: + out_node = self.node_quantizer.quantize(node, graph_stack) + new_nodes.append(out_node) graph.ClearField("node") graph.node.extend(new_nodes) @@ -300,7 +578,7 @@ def inc_dataloader(): 
from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize kwargs["percdamp"] = self.algo_config.percdamp - kwargs["blocksize"] = self.algo_config.blocksize + kwargs["blocksize"] = self.algo_config.block_size kwargs["actorder"] = self.algo_config.actorder kwargs["mse"] = self.algo_config.mse kwargs["perchannel"] = self.algo_config.perchannel @@ -316,7 +594,7 @@ def inc_dataloader(): logger.info(f"complete quantization of model with {algorithm} algorithm.") def process(self): - if self.algo_config is None: + if self.algo_config.algorithm in ["HQQ", "DEFAULT"]: # use a stack to keep track of sub-graphs graph_stack = [self.model.graph()] opset_import = self.model.opset_import() @@ -327,7 +605,6 @@ def process(self): has_ms_domain = True if not has_ms_domain: opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) - self._process_subgraph(graph_stack) self.model.clean_initializers() else: @@ -366,6 +643,14 @@ def parse_args(): parser.add_argument("--input_model", required=True, help="Path to the input model file") parser.add_argument("--output_model", required=True, help="Path to the output model file") parser.add_argument("--block_size", required=False, default=32, type=int, help="Block size for quantization") + parser.add_argument( + "--quant_method", + default="default", + type=str, + choices=["default", "hqq"], + help="the algorithm used to quantize weight", + ) + parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight") parser.add_argument( "--symmetric", required=False, @@ -411,12 +696,24 @@ def parse_args(): raise Exception(f"file {output_model_path} already exists") model = onnx.load(input_model_path) + if args.quant_method == "hqq": + quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits) + elif args.quant_method == "default": + quant_config = DefaultWeightOnlyQuantConfig( + block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level + ) + elif args.quant_method == "rtn": + quant_config = RTNWeightOnlyQuantConfig() + elif args.quant_method == "gptq": + quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size) + else: + raise ValueError(f"Unsupported quantization method: {args.quant_method}") + quant = MatMul4BitsQuantizer( model=model, - block_size=args.block_size, - is_symmetric=args.symmetric, accuracy_level=args.accuracy_level, nodes_to_exclude=args.nodes_to_exclude, + algo_config=quant_config, ) quant.process() quant.model.save_model_to_file(output_model_path, True) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 2ad20eafc2ef1..d294fd4e2b0e0 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
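(A minimal sketch of driving the reworked Python quantizer above with an explicit algorithm config, mirroring the test and CLI wiring; the model file names are placeholders, and the HQQ path additionally requires torch to be installed.)

```python
import onnx
from onnxruntime.quantization import matmul_4bits_quantizer

model = onnx.load("model_fp32.onnx")  # placeholder path

# default 4-bit weight-only config ...
algo_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(block_size=32, is_symmetric=True)
# ... or the data-free HQQ config added in this change (torch required)
# algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=32, bits=4)

quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=algo_config)
quant.process()
quant.model.save_model_to_file("model_int4.onnx", True)  # second argument mirrors the CLI's save call above
```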
#ifndef ORT_MINIMAL_BUILD +#include #include "core/common/span_utils.h" #include "core/framework/tensor.h" @@ -66,7 +67,9 @@ void QuantizeDequantize(std::vector& raw_vals, } void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level, - bool has_zeropoint, bool use_float16, float fp16_abs_error = 0.02f) { + bool has_zeropoint, bool use_float16, bool has_g_idx = false, + bool zp_is_4bit = true, float fp16_abs_error = 0.02f) { + zp_is_4bit = zp_is_4bit | has_g_idx; RandomValueGenerator random{1234}; std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f)); std::vector input1_f_vals(random.Gaussian(std::vector({K, N}), 0.0f, 0.25f)); @@ -113,12 +116,40 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura test.AddAttribute("block_size", block_size); test.AddAttribute("bits", QBits); test.AddAttribute("accuracy_level", accuracy_level); + auto ceildiv = [](int64_t a, int64_t b) { return (a + b - 1) / b; }; + if (use_float16) { test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); test.AddInput("B", {q_cols, q_rows}, input1_vals, true); test.AddInput("scales", {static_cast(q_scale_size)}, ToFloat16(scales), true); if (has_zeropoint) { - test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + if (zp_is_4bit) { + test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + } else { + std::vector zp_f; + zp_f.reserve(q_zp_size_in_bytes * 2); + for (size_t i = 0; i < zp.size(); i++) { + zp_f.push_back(static_cast(zp[i] & 0xf)); + zp_f.push_back(static_cast((zp[i] >> 4) & 0xf)); + } + size_t ind = zp_f.size() - 1; + while (zp_f.size() != q_scale_size) { + zp_f.erase(zp_f.begin() + ind); + ind -= q_scale_size / N + 1; + } + + test.AddInput("zero_points", {static_cast(q_scale_size)}, ToFloat16(zp_f), true); + } + } else { + test.AddInput("", {0}, {}); + } + if (has_g_idx) { + int K_pad = gsl::narrow(ceildiv(K, block_size) * block_size); + std::vector g_idx(K_pad); + for (int64_t i = 0; i < K_pad; i++) { + g_idx[i] = gsl::narrow(i / block_size); + } + test.AddInput("g_idx", {static_cast(K_pad)}, g_idx, true); } test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); @@ -132,9 +163,34 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura test.AddInput("B", {q_cols, q_rows}, input1_vals, true); test.AddInput("scales", {static_cast(q_scale_size)}, scales, true); if (has_zeropoint) { - test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); - } + if (zp_is_4bit) { + test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + } else { + std::vector zp_f; + zp_f.reserve(q_zp_size_in_bytes * 2); + for (size_t i = 0; i < zp.size(); i++) { + zp_f.push_back(static_cast(zp[i] & 0xf)); + zp_f.push_back(static_cast((zp[i] >> 4) & 0xf)); + } + size_t ind = zp_f.size() - 1; + while (zp_f.size() != q_scale_size) { + zp_f.erase(zp_f.begin() + ind); + ind -= q_scale_size / N + 1; + } + test.AddInput("zero_points", {static_cast(q_scale_size)}, zp_f, true); + } + } else { + test.AddInput("", {0}, {}); + } + if (has_g_idx) { + int K_pad = gsl::narrow(ceildiv(K, block_size) * block_size); + std::vector g_idx(K_pad); + for (int64_t i = 0; i < K_pad; i++) { + g_idx[i] = gsl::narrow(i / block_size); + } + test.AddInput("g_idx", {static_cast(K_pad)}, g_idx, true); + } test.AddOutput("Y", {M, N}, expected_vals); if (accuracy_level == 4) { test.SetOutputAbsErr("Y", 0.1f); @@ -158,6 +214,8 @@ TEST(MatMulNBits, Float32) { for (auto accuracy_level : 
{0}) { RunTest(M, N, K, block_size, accuracy_level, false, false); RunTest(M, N, K, block_size, accuracy_level, true, false); + RunTest(M, N, K, block_size, accuracy_level, false, false, true); + RunTest(M, N, K, block_size, accuracy_level, true, false, false, false); } #endif } @@ -172,8 +230,10 @@ TEST(MatMulNBits, Float16) { for (auto N : {1, 2, 32, 288}) { for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { - RunTest(M, N, K, block_size, 0, false, true); - RunTest(M, N, K, block_size, 0, true, true); + for (auto has_gidx : {true, false}) { + RunTest(M, N, K, block_size, 0, false, true, has_gidx); + RunTest(M, N, K, block_size, 0, true, true, has_gidx, false); + } } } } @@ -183,9 +243,9 @@ TEST(MatMulNBits, Float16) { TEST(MatMulNBits, Float16Large) { for (auto block_size : {16, 32, 64, 128}) { for (auto symmetric : {false, true}) { - RunTest(1, 4096, 4096, block_size, 0, symmetric, true, 0.05f); - RunTest(1, 4096, 11008, block_size, 0, symmetric, true, 0.05f); - RunTest(1, 11008, 4096, block_size, 0, symmetric, true, 0.05f); + RunTest(1, 4096, 4096, block_size, 0, symmetric, true, false, true, 0.05f); + RunTest(1, 4096, 11008, block_size, 0, symmetric, true, false, true, 0.05f); + RunTest(1, 11008, 4096, block_size, 0, symmetric, true, false, true, 0.05f); } } } diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index c1bbb49f10c7e..b30282f2ab41f 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -358,6 +358,7 @@ def check_model_correctness( model_onnx = onnx.load(f) ops_set = set(node.op_type for node in model_onnx.graph.node) check_reference_evaluator = not (ops_set & {"EmbedLayerNormalization", "Conv", "Attention", "Transpose"}) + check_target_evaluator = False with open(model_path_to_check, "rb") as f: model_check = onnx.load(f) @@ -413,7 +414,7 @@ def check_model_correctness( check_sign_f8_quantization(model_path_origin, model_path_to_check) # Verifies the expected outputs. 
- if check_reference_evaluator and onnx_recent_enough: + if check_target_evaluator and onnx_recent_enough: if op_matmul: reference_new_ops = [QLinearMatMul] else: diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 73dae08af8ece..88e5052db4e2e 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -125,7 +125,10 @@ def quant_test( from onnxruntime.quantization import matmul_4bits_quantizer model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) - quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric) + quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig( + block_size=block_size, is_symmetric=is_symmetric + ) + quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config) quant.process() quant.model.save_model_to_file(model_int4_path, False) @@ -165,6 +168,9 @@ def quant_test_with_algo( elif algorithm == "GPTQ": # test GPTQ algorithm algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader) + elif algorithm == "HQQ": + # test HQQ algorithm + algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=block_size) model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config) @@ -227,6 +233,17 @@ def test_quantize_matmul_int4_using_gptq_algo(self): data_reader = self.input_feeds(1, {"input": [100, 52]}) self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False) + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_using_hqq_algo(self): + if not find_spec("torch"): + self.skipTest("skip test_hqq_quant since torch is not installed") + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=False) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test_with_algo("HQQ", model_fp32_path, data_reader, 32, False) + if __name__ == "__main__": unittest.main() From cd56ea4a74ee41c040899d702667d2c86bee4ef0 Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Tue, 5 Mar 2024 13:15:30 +0800 Subject: [PATCH 107/237] enable embedding sparse optimization by default (#19714) --- docs/ORTModule_Training_Guidelines.md | 2 +- .../training/ortmodule/_graph_execution_manager.py | 14 +++++++++----- .../python/training/ortmodule/options.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index f50b18b736936..84631bd1f6555 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -246,7 +246,7 @@ to standard outputs. #### ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER - **Feature Area**: *ORTMODULE/Optimizations* -- **Description**: By default, this is disabled. This env var can be used for enabling or disabling the embedding input +- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the embedding input data sparsity based performance optimizations. 
```bash diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index fda6e345da235..e189ffff9cc7f 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -681,11 +681,15 @@ def _enable_conditional_optimizations( ) if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0: - graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys()) - self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results) - self._runtime_options.embed_sparsity_ratio = ",".join( - [f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()] - ) + if detected_device.type == "cuda": + # Embedding sparsity optimization is only supported on CUDA devices. + graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys()) + self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results) + self._runtime_options.embed_sparsity_ratio = ",".join( + [f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()] + ) + else: + self._logger.info("Embedding sparsity-based optimization is not supported on non-CUDA devices.") # If users don't want to print input density, disable the input density observer to avoid overhead # when looping through inputs during training. diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 539859a0d58a6..93d24a34df6bd 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -271,7 +271,7 @@ def __init__(self, logger: Logger): self.enable_sparse_optimizer = True self.label_sparsity_ratio = "" self.embed_sparsity_ratio = "" - self.enable_embedding_sparse_optimizer = False # TODO(pengwa): remove once validation on more models are done. + self.enable_embedding_sparse_optimizer = True # Configuration for memory optimization. self.memory_optimization_level = ( From bdf678df93cb257e311de3fa82fe6409be2854ff Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Tue, 5 Mar 2024 17:09:42 +0100 Subject: [PATCH 108/237] Fix CUDA BatchNorm bugs and add support for NHWC (#19742) ### Description - Fix incorrect running_mean / running_var in training mode due to incorrect momentum and missing input mean/var. runnig_var could be correct, but has a too high epsilon. 
- Fix incorrect checks when using NHWC - Pass NHWC flag to NormalizeDims to get correct new dimensions from x_shape - Register missing double operations to get parity between NHWC/NCHW --- .../core/providers/cpu/nn/batch_norm_helper.h | 41 +++++++++++++------ .../providers/cuda/cuda_execution_provider.cc | 18 +++++--- .../core/providers/cuda/cuda_nhwc_kernels.cc | 16 ++++++++ .../core/providers/cuda/nn/batch_norm.cc | 11 ++++- .../providers/cpu/nn/batch_norm_op_test.cc | 1 + 5 files changed, 66 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h index a5d46aff83b50..ccecbabfa3db3 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h @@ -25,6 +25,8 @@ class BatchNormHelper { const Tensor* var, bool is_spatial = true, bool is_nhwc = false) { + // NHWC dependent shape: X + // All other shapes are assumed to be in NCHW layout? const auto& x_dims = X->Shape().GetDims(); // If x_dims size < 2, num_channels defaults to 1. @@ -48,16 +50,22 @@ class BatchNormHelper { // validate 'scales' shape const auto& scale_dims = scale->Shape().GetDims(); if (static_cast(scale_dims.size()) != kNumInputScaleDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions); } if (scale_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: 0th dimension != ", num_channels); } + // N & C do not belong to features + // skip the first element for NHWC and the first two elements for NCHW. + int feature_offset = is_nhwc ? 
1 : 2; + // in non-spatial cases - the other dims of 'scale' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (scale_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (scale_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } @@ -65,7 +73,8 @@ class BatchNormHelper { // validate 'B' shape const auto& B_dims = B->Shape().GetDims(); if (static_cast(B_dims.size()) != kNumInputBiasDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions); } if (B_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: 0th dimension != ", num_channels); @@ -73,8 +82,9 @@ class BatchNormHelper { // in non-spatial cases - the other dims of 'B' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (B_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (B_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } @@ -82,16 +92,19 @@ class BatchNormHelper { // validate 'mean' shape const auto& mean_dims = mean->Shape().GetDims(); if (static_cast(mean_dims.size()) != kNumInputMeanDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions); } if (mean_dims[0] != num_channels) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: 0th dimension != ", num_channels); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input mean: 0th dimension != ", num_channels); } // in non-spatial cases - the other dims of 'mean' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (mean_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (mean_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } @@ -99,7 +112,8 @@ class BatchNormHelper { // validate 'var' shape const auto& var_dims = var->Shape().GetDims(); if (static_cast(var_dims.size()) != kNumInputVarianceDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions); } if (var_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: 0th 
dimension != ", num_channels); @@ -107,8 +121,9 @@ class BatchNormHelper { // in non-spatial cases - the other dims of 'var' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (var_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (var_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 1ce089fd93044..8ba282031a5d4 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1202,9 +1202,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); @@ -2107,9 +2110,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc index f416caecd115f..64edc319e15ac 100644 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -18,10 +18,14 @@ namespace onnxruntime::cuda { class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, 
BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, @@ -72,10 +76,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalN class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, double, + BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16, BatchNormalization); @@ -86,18 +94,26 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo::ComputeInternal(OpKernelContext* p_op_kernel_context) CudnnTensor data_desc; vector new_dims; - BatchNormHelper::NormalizeDims(x_shape, new_dims); + BatchNormHelper::NormalizeDims(x_shape, new_dims, NHWC); ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType(), NHWC)); // For half data type, the alpha, beta, scale, B, mean, var need to be float type @@ -137,6 +137,12 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) auto saved_mean_data = reinterpret_cast(saved_mean->MutableData()); auto saved_inv_var_data = reinterpret_cast(saved_var->MutableData()); + auto stream = static_cast(p_op_kernel_context->GetComputeStream()->GetHandle()); + CUDA_RETURN_IF_ERROR( + cudaMemcpyAsync(running_mean_data, mean_data, mean->SizeInBytes(), cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR( + cudaMemcpyAsync(running_var_data, var_data, var->SizeInBytes(), cudaMemcpyDeviceToDevice, stream)); + CUDNN_RETURN_IF_ERROR(BatchNormalizationForwardTrainingHelper( GetCudnnHandle(p_op_kernel_context), cudnn_batch_norm_mode_, @@ -149,7 +155,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) bn_tensor_desc, scale_data, b_data, - momentum_, + 1.0 - momentum_, running_mean_data, running_var_data, epsilon_, @@ -186,6 +192,7 @@ SPECIALIZED_COMPUTE(MLFloat16, kOnnxDomain, false) #ifdef ENABLE_CUDA_NHWC_OPS SPECIALIZED_COMPUTE(float, kMSInternalNHWCDomain, true) +SPECIALIZED_COMPUTE(double, kMSInternalNHWCDomain, true) 
SPECIALIZED_COMPUTE(MLFloat16, kMSInternalNHWCDomain, true) #endif } // namespace cuda diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index cbb4531a50b7c..54e5c71bd753a 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -916,6 +916,7 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) { // exclude CUDA Execution Provider due to flakiness // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm() test.Run(OpTester::ExpectResult::kExpectSuccess, "", + // TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1 {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } From 06e684c9f2f8495de5259967cc12bab24da3d522 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Tue, 5 Mar 2024 09:37:45 -0800 Subject: [PATCH 109/237] Adding cuda kernel (optimized for sm80) for block-wise 4b quantized float 16 GEMM. (#18619) ### Description Adds a CUDA kernel for block-wise 4b quantized float16 GEMM, specifically optimized for Nvidia Ampere GPUs. ### Motivation and Context Improve quantized LLM inference performance on Nvidia Ampere GPUs. ### Note: This is implemented by extending CUTLASS, so it has a hard dependency on CUTLASS. However, in the current build system, loading of the CUTLASS dependency is guarded with: (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) If both of these options are turned off, compilation will fail. Why is the CUTLASS dependency guarded at all? It's a header-only library that does not introduce any binary if it is not instantiated. What's the downside of removing all the guards and just including CUTLASS unconditionally?
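To make "block-wise 4b quantized" concrete, here is a minimal host-side sketch of the data this kind of kernel consumes: each column of a K x N weight matrix is split into blocks of `block_k` rows, every block shares one scale and, optionally, one uint8 zero point, and two 4-bit weights are packed into each byte. The packing order, block shape, and function name below are illustrative assumptions for this example, not the prepacked SM80 layout produced by the code in this patch; the kernel itself keeps B packed and dequantizes tiles on the fly inside the GEMM main loop.

```
// Illustrative sketch only: expand block-wise 4b quantized weights back to float on the host.
// Layout assumptions (for this sketch, not the kernel's prepacked format):
//   - B is K x N, column-major, two 4-bit values packed per byte (row k in the low
//     nibble, row k+1 in the high nibble of the same byte).
//   - one scale (and optionally one zero point) per block_k x 1 block of a column.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> DequantizeBlockwise4b(const std::vector<uint8_t>& packed,  // (K/2) * N bytes
                                         const std::vector<float>& scales,    // (K/block_k) * N
                                         const uint8_t* zero_points,          // same size as scales, or nullptr
                                         int K, int N, int block_k) {
  std::vector<float> out(static_cast<size_t>(K) * N);  // column-major, K x N
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K; ++k) {
      const uint8_t byte = packed[(static_cast<size_t>(n) * K + k) / 2];
      const int q = (k % 2 == 0) ? (byte & 0x0F) : (byte >> 4);          // unpack one 4-bit weight
      const size_t block = static_cast<size_t>(n) * (K / block_k) + k / block_k;
      const int zp = zero_points ? zero_points[block] : 8;               // 8 = symmetric default for 4-bit
      out[static_cast<size_t>(n) * K + k] = scales[block] * static_cast<float>(q - zp);
    }
  }
  return out;
}
```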
--- .lintrunner.toml | 1 + cmake/CMakeLists.txt | 5 +- cmake/onnxruntime_providers_cuda.cmake | 2 +- cmake/onnxruntime_unittests.cmake | 1 + onnxruntime/core/mickey/README.md | 4 + .../core/mickey/blk_q4/f16_gemm_sm80.h | 208 +++ .../{prepack_sm80.h => f16_prepack_sm80.h} | 2 +- .../cutlass_ext/q4gemm/device/quantb_gemm.h | 481 ++++++ .../q4gemm/kernel/default_quantb_gemm.h | 255 ++++ .../cutlass_ext/q4gemm/kernel/quantb_gemm.h | 462 ++++++ .../q4gemm/threadblock/default_quantb_mma.h | 248 ++++ .../threadblock/default_quantb_mma_core.h | 340 +++++ .../optional_predicated_tile_access_iter.h | 314 ++++ .../optional_regular_tile_access_iter.h | 224 +++ .../threadblock/quantb_mma_multistage.h | 1290 +++++++++++++++++ .../warp/default_quantb_mma_tensor_op.h | 112 ++ .../quantb_meta_mma_tensor_op_tile_iterator.h | 883 +++++++++++ .../q4gemm/warp/quantb_mma_tensor_op.h | 361 +++++ onnxruntime/core/util/matrix_layout.h | 1 - .../test/cuda_host/blkq4_fp16_quant_sm80.h | 203 +++ .../cuda/test_cases/blkq4_fp16_gemm_sm80.h | 188 +++ .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 330 +++++ .../test_cases/blkq4_fp16_gemm_sm80_testcu.cu | 344 +++++ .../blkq4_fp16_sm80_prepack_test.cc | 507 ------- .../cuda_execution_provider_test.cc | 4 +- 25 files changed, 6257 insertions(+), 513 deletions(-) create mode 100644 onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h rename onnxruntime/core/mickey/blk_q4/{prepack_sm80.h => f16_prepack_sm80.h} (99%) create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h create mode 100644 onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h create mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h create mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc create mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu delete mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc diff --git a/.lintrunner.toml b/.lintrunner.toml index 4e5d077b08ff4..be95e03479cf9 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -132,6 +132,7 @@ exclude_patterns = [ 'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code 'onnxruntime/core/graph/contrib_ops/quantization_defs.cc', 'onnxruntime/core/mlas/**', # Contains assembly code + 'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS lib recommends NO automatic code formatting 'winml/lib/Api.Image/shaders/**', # Contains data chunks ] command = [ diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 
8453da19ce3a6..0d55d4cab9826 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -727,6 +727,9 @@ if (onnxruntime_USE_CUDA) set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) + message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4") + endif() else() set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) @@ -747,8 +750,8 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1) endif() - endif() + if (onnxruntime_USE_VITISAI) list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VITISAI=1) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 0f6d48bdb6ec8..7f295a59a0931 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -201,7 +201,7 @@ endif() include(cutlass) - target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) + target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 88f662075e177..b004054c616a5 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -774,6 +774,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) + target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut) endif() diff --git a/onnxruntime/core/mickey/README.md b/onnxruntime/core/mickey/README.md index 7e8d30cd1805b..735ec4b80daf3 100644 --- a/onnxruntime/core/mickey/README.md +++ b/onnxruntime/core/mickey/README.md @@ -4,3 +4,7 @@ Playful name for a template library of high performance cuda code that are often shared by various AI operators. The intention is to make this header files only, with no binary impact unless it is instantiated where it is needed. + +Currently cuda code are scattered in multiple locations in the repo. +Hopefully this can be the starting point of consolidating all cuda +code. diff --git a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h new file mode 100644 index 0000000000000..52bff7e40dbe3 --- /dev/null +++ b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h @@ -0,0 +1,208 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blk_q4/f16_gemm_sm80.h + * + * Abstract: + * Entry point for Q4F16 GEMM kernel for SM80 devices. 
+ */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass_ext/q4gemm/device/quantb_gemm.h" + +namespace onnxruntime { +namespace cuda { + +// +// This is the implementation of the quantized GEMM kernel for 16b float x blocked quantized 4b data type +// +template < + typename ElementDequant_, // <- data type of dequantized elements for gemm, fp16 or bf16 + typename QuantBlocking_, // <- weights block per scale, cutlass::MatrixShape + bool SmallM, // <- true if M <= 16 + bool kHasQuantOffset> +struct BlkQ4F16GemmImpl { + // + // Type definitions + // + + using ElementDequant = ElementDequant_; + using QuantBlocking = QuantBlocking_; + + static_assert(sizeof(ElementDequant) == 2, "q4f16gemm kerenl only support 16b operands!"); + + // Data types that are fixed for this kernel + using ElementAccumulator = float; + using ElementComputeEpilogue = ElementAccumulator; + using ElementInputA = ElementDequant; + using ElementOutput = ElementDequant; + + using ElementW = uint8_t; // <- Weight is int4, uint8 for two of them + + // We pack 4 weights into one 16b element, so as to leverage cutlass tile iterators + // for async shared memory loading and minimize bank conflict + using ElementWPack = ElementDequant; + + using ElementQScale = ElementDequant; // <- data type of quantization scale + using ElementQOffset = uint8_t; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputWPack = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + // Layout of quantization scale and offset, oriented to be loaded using less instructions + // in a warp tile + using LayoutInputQScale = + typename std::conditional::type; // <- layout of quantization scale + + using ShapeMMAThreadBlock = + typename std::conditional, + cutlass::gemm::GemmShape<128, 256, 64>>::type; + + static constexpr int MinN = QuantBlocking::kColumn > 32 ? QuantBlocking::kColumn : 32; + using ShapeMMAWarp = + typename std::conditional, + cutlass::gemm::GemmShape<64, 64, 64>>::type; + + using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + + // This code section describes the epilogue part of the kernel + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. 
This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + + // Number of pipelines you want to use + static constexpr int NumStages = 3; + + using Gemm = cutlass::gemm::device::QuantBGemm< + ElementInputA, + LayoutInputA, + ElementWPack, + LayoutInputWPack, + ElementQScale, + typename std::conditional::type, + LayoutInputQScale, + QuantBlocking, + ElementOutput, + LayoutOutput, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ShapeMMAThreadBlock, + ShapeMMAWarp, + ShapeMMAOp, + EpilogueOp, + SwizzleThreadBlock, + NumStages>; + + using Arguments = typename Gemm::Arguments; + + // Invoke gemm kernel (the version with quantization offset) + static cutlass::Status run( + cudaStream_t stream, + const cutlass::gemm::GemmCoord& problem_size_, + cutlass::TensorRef ref_A_, + cutlass::TensorRef ref_B_, + cutlass::TensorRef ref_Qscale_, + cutlass::TensorRef ref_Qoffset_, + cutlass::TensorRef ref_C_, + cutlass::TensorRef ref_D_, + typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) { + if constexpr (!kHasQuantOffset) { + return cutlass::Status::kErrorNotSupported; + } else { + if constexpr (ShapeMMAThreadBlock::kM == 16) { + if (problem_size_.m() > 16) { + // For M > 16, the caller should have picked the + // kernel with bigger M + return cutlass::Status::kErrorNotSupported; + } + } + + // Construct Gemm arguments + Arguments args{ + problem_size_, + ref_A_, + ref_B_, + ref_Qscale_, + ref_Qoffset_, + ref_C_, + ref_D_, + epilogue_}; + + Gemm gemm_op; + + // Check if this GEMM can be run or not + cutlass::Status status = gemm_op.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Launch the CUTLASS GEMM kernel. + return gemm_op(args, nullptr, stream); + } + } + + // Invoke gemm kernel (the version without quantization offset) + static cutlass::Status run( + cudaStream_t stream, + const cutlass::gemm::GemmCoord& problem_size_, + cutlass::TensorRef ref_A_, + cutlass::TensorRef ref_B_, + cutlass::TensorRef ref_Qscale_, + cutlass::TensorRef ref_C_, + cutlass::TensorRef ref_D_, + typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) { + if constexpr (kHasQuantOffset) { + return cutlass::Status::kErrorNotSupported; + } else { + if constexpr (ShapeMMAThreadBlock::kM == 16) { + if (problem_size_.m() > 16) { + // For M > 16, the caller should have picked the + // kernel with bigger M + return cutlass::Status::kErrorNotSupported; + } + } + + // Construct Gemm arguments + Arguments args{ + problem_size_, + ref_A_, + ref_B_, + ref_Qscale_, + ref_C_, + ref_D_, + epilogue_}; + + Gemm gemm_op; + + // Check if this GEMM can be run or not + cutlass::Status status = gemm_op.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Launch the CUTLASS GEMM kernel. + return gemm_op(args, nullptr, stream); + } + } +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h similarity index 99% rename from onnxruntime/core/mickey/blk_q4/prepack_sm80.h rename to onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h index e291ab39e8aa3..a08cfb97eed4a 100644 --- a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h @@ -3,7 +3,7 @@ * Licensed under the MIT License. 
* * Module Name: - * prepack_sm80.h + * blk_q4/f16_prepack_sm80.h * * Abstract: * Prepack weights and quantization parameters (scales and offsets) for diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h new file mode 100644 index 0000000000000..38795291b0328 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h @@ -0,0 +1,481 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_gemm.h + * @brief Modified from cutlass/gemm/device/gemm.h, boilerplate code passing input pointers to the kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm.h" + +#include "cutlass_ext/q4gemm/kernel/default_quantb_gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! A specialized GEMM operator for quantized B GEMM. + + It is modified from cutlass::gemm::device::Gemm. Both this class and the original Gemm class + are pretty much boilerplate code that construct the Gemm kernel class, and pass parameters + and controls to it. 
The only difference is that this class has a few more template parameters + to support quantization. + + This implementation pretty much follows the design of cutlass. But this class seems to be + just a wrapper of the Gemm kernel class. Consider combining them in future iterations. + +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout type for quant scales and offsets + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute> +class QuantBGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using 
ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + // Quantization Parameters + static_assert(std::is_same::value, + "LayoutB, i.e. packed weights must appear ColumnMajor."); + static_assert(InstructionShape::kK == 16, + "InstructionShape::kK must be a multiple of 16 (2 tiles), required by 4b weight packing layout."); + using ElementQScale = ElementQScale_; + using ElementQOffset = ElementQOffset_; + using LayoutQMeta = LayoutQMeta_; + using QuantBlocking = QuantBlocking_; + static constexpr bool kHasQOffset = !(std::is_same::value); + + // TODO(chenfucn): consider moving to uint4_t or smaller for QOffset + static_assert(!kHasQOffset || std::is_same::value, "QOffset must be uint8_t"); + + /// Define the kernel + using GemmKernel = typename kernel::DefaultQuantBGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementQScale, + ElementQOffset, + LayoutQMeta, + QuantBlocking, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + GatherA, + GatherB, + ScatterD, + PermuteDLayout + >::GemmKernel; + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + TensorRef ref_Qscale; + TensorRef ref_Qoffset; + + typename EpilogueOutputOp::Params epilogue; + + // split-K parallelism (etc.) are not yet supported, keeping this for future extension + int split_k_slices{1}; + // For gather+scatter operations + int const *gather_A_indices{nullptr}; + int const *gather_B_indices{nullptr}; + int const *scatter_D_indices{nullptr}; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0) {} + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(!kHasQOffset); + } + + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_Qoffset_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_Qoffset(ref_Qoffset_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(kHasQOffset); + } + }; + + private: + /// Kernel parameters object + typename GemmKernel::Params params_; + + public: + /// Constructs the GEMM. 
+ QuantBGemm() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.epilogue, + static_cast(workspace), + args.gather_A_indices, + args.gather_B_indices, + args.scatter_D_indices + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_Qscale.reset(args.ref_Qscale.non_const_ref().data()); + params_.ref_Qoffset.reset(args.ref_Qoffset.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.output_op = args.epilogue; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + std::cerr << "Failed to obtain maximum shared memory size " << smem_size << " for kernel: " + << cudaGetErrorString(result) << "\n"; + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h new file mode 100644 index 0000000000000..2f4460bb59e9f --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h @@ -0,0 +1,255 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_gemm.h + * @brief Modified from cutlass/gemm/kernel/default_gemm.h. templates for combining + * threadblock-scoped matrix multiply-add with the appropriate + * threadblock-scoped epilogue. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass_ext/q4gemm/kernel/quantb_gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#include "cutlass/layout/permute.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" +#endif //CUTLASS_ARCH_WMMA_ENABLED + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout type for quant scales and offsets + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Access granularity of quant scales in units of elements + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + 
bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute, + /// + typename Enable = void +> +struct DefaultQuantBGemm; + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale, + /// Element type for quant offsets + typename ElementQOffset, + /// Layout type for quant scales + typename LayoutQMeta, + /// Blocking dimensions for quantization + typename QuantBlocking, + /// Access granularity of quant scales in units of elements + typename ElementC, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout +> +struct DefaultQuantBGemm { + + static_assert((platform::is_same::value + || platform::is_same>::value), + "Epilogue in the kernel level must be row major"); + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultQuantBMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator, false, GatherA, GatherB, + PermuteALayout, PermuteBLayout>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using RegularEpilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; + + using Affine2Epilogue = + typename 
cutlass::epilogue::threadblock::DefaultEpilogueTensorOpAffineRankN< + 2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + using Epilogue = typename platform::conditional::value, + RegularEpilogue, + Affine2Epilogue>::type; + + /// Define the kernel-level GEMM operator. + using GemmKernel = kernel::QuantBGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h new file mode 100644 index 0000000000000..6e5ad8f406147 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h @@ -0,0 +1,462 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_gemm.h + * @brief Modified from cutlass/gemm/kernel/gemm.h. + * Template for a pipelined GEMM kernel. Does not compute batching or support split-K. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! 
Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. +> +struct QuantBGemm { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using OutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + static constexpr bool kHasQOffset = Mma::kHasQOffset; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorQScale::Params params_QScale; + typename Mma::IteratorQScale::TensorRef ref_QScale; + typename Mma::IteratorQOffset::Params params_QOffset; + typename Mma::IteratorQOffset::TensorRef ref_QOffset; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename OutputOp::Params output_op; + int *semaphore; + int gemm_k_size; // how many k vectors are processed by this threadblock + // For gather+scatter operations + int const *gather_A_indices; + int const *gather_B_indices; + int const *scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + typename OutputOp::Params output_op = typename OutputOp::Params(), + int *workspace = nullptr, + int const *gather_A_indices = nullptr, + int const *gather_B_indices = nullptr, + int const *scatter_D_indices = nullptr + ): + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(ref_A.layout()), + ref_A(ref_A), + params_B(ref_B.layout()), + ref_B(ref_B), + params_QScale(ref_QScale.layout()), + ref_QScale(ref_QScale), + params_QOffset(ref_QOffset.layout()), + ref_QOffset(ref_QOffset), + params_C(ref_C.layout()), + ref_C(ref_C), + params_D(ref_D.layout()), + ref_D(ref_D), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) { + int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // 
+ + CUTLASS_HOST_DEVICE + QuantBGemm() { } + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D) { + + // TODO check problem_size K, N must be multiple of QuantBlocking + + static int const kAlignmentA = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (problem_size.k() % Mma::Shape::kK != 0) { + // Currently we don't support this case due to the way + // predicate iterator works, it loads the partial tile + // in the first iteration and then the full tile in the + // remaining iterations. This will cause the blockwise + // quantization parameters to go out of step with the + // weights. We can fix this by adding a predicate iterator + // that loads the full tile in the first iterations and + // then the partial tile in the last iteration. + return Status::kErrorInvalidProblem; + } + + int qscale_k = problem_size.k() / Mma::QuantBlocking::kRow; + int qscale_n = problem_size.n() / Mma::QuantBlocking::kColumn; + if ((qscale_k == 0) || (qscale_k * Mma::QuantBlocking::kRow != problem_size.k())) { + // partial block not supported + return Status::kErrorInvalidProblem; + } + if ((qscale_n == 0) || (qscale_n * Mma::QuantBlocking::kColumn != problem_size.n())) { + // partial block not supported + return Status::kErrorInvalidProblem; + } + + if (!TensorRef_aligned(ref_QScale, Mma::IteratorQScale::AccessType::kElements)) { + return Status::kErrorMisalignedOperand; + } + + if constexpr(kHasQOffset) { + if (!TensorRef_aligned(ref_QOffset, Mma::IteratorQOffset::AccessType::kElements)) { + return Status::kErrorMisalignedOperand; + } + } + + if (!TensorRef_aligned(ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{ + (threadblock_tile_offset.k() * params.gemm_k_size) / 
2, + (threadblock_tile_offset.n() * Mma::Shape::kN) / 2 + }; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min( + params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k/2, params.problem_size.n()/2}, + thread_idx, + tb_offset_B, + params.gather_B_indices); + + const int qscale_k = problem_size_k / Mma::QuantBlocking::kRow; + const int qscale_n = params.problem_size.n() / Mma::QuantBlocking::kColumn; + + // should have been verified by can_implement() + assert((qscale_k > 0) && (qscale_k * Mma::QuantBlocking::kRow == problem_size_k)); + assert((qscale_n > 0) && (qscale_n * Mma::QuantBlocking::kColumn == params.problem_size.n())); + + cutlass::MatrixCoord tb_offset_QScale{ + threadblock_tile_offset.k() * (params.gemm_k_size/Mma::QuantBlocking::kRow), + threadblock_tile_offset.n() * (Mma::Shape::kN/Mma::QuantBlocking::kColumn) + }; + + typename Mma::IteratorQScale iterator_QScale( + params.params_QScale, + params.ref_QScale.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale, + nullptr); + + typename Mma::IteratorQOffset iterator_QOffset( + params.params_QOffset, + params.ref_QOffset.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + const int warp_idx = canonical_warp_idx(); + const int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_QScale, iterator_QOffset, accumulators); + } + + // + // Epilogue + // + + OutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. 
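For readers of this hunk, here is a minimal host-side sketch of the quantization-metadata bookkeeping performed above: the scale-tensor extent, the rejection of partial quantization blocks, and each threadblock's starting offset in the scale tensor. It is not part of the patch; all constants and names are made up for illustration.

```
// Illustrative stand-alone sketch; numbers are arbitrary, not a real kernel config.
#include <cassert>
#include <cstdio>

int main() {
  const int quant_block_row = 64, quant_block_col = 1;   // QuantBlocking::kRow / kColumn
  const int problem_k = 4096, problem_n = 8192;          // problem_size.k() / .n()
  const int gemm_k_size = 1024;                          // K range handled per split-K slice
  const int tb_shape_n = 128;                            // threadblock tile N (Mma::Shape::kN)
  const int tile_k = 2, tile_n = 3;                      // threadblock_tile_offset.k() / .n()

  // Extent of the scale tensor; partial quantization blocks are not supported.
  const int qscale_k = problem_k / quant_block_row;
  const int qscale_n = problem_n / quant_block_col;
  assert(qscale_k > 0 && qscale_k * quant_block_row == problem_k);
  assert(qscale_n > 0 && qscale_n * quant_block_col == problem_n);

  // Starting coordinates of this threadblock's scale tile (mirrors tb_offset_QScale).
  const int offset_row = tile_k * (gemm_k_size / quant_block_row);
  const int offset_col = tile_n * (tb_shape_n / quant_block_col);

  std::printf("scales: %d x %d, threadblock offset: (%d, %d)\n",
              qscale_k, qscale_n, offset_row, offset_col);
  return 0;
}
```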
+ typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h new file mode 100644 index 0000000000000..0af604f090e1f --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h @@ -0,0 +1,248 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma.h + * @brief Modified from cutlass/gemm/threadblock/default_mma.h. + * Defining global memory data layout and iterators, combinging with mma core and + * pipelined GEMM kernel. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h" +#include "cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout for quant scales and offsets + typename LayoutQMeta_, + /// Blocking size for quantization + typename QuantBlocking_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute + > +struct DefaultQuantBMma; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale, + /// Element type for quant offsets + typename ElementQOffset, + /// Layout for quant scales and offsets + typename LayoutQMeta, + /// Blocking size for quantization + typename QuantBlocking, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout + > +struct DefaultQuantBMma { + + static_assert(platform::is_same::value + || platform::is_same>::value, + "simt epilogue must be row major"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultQuantBMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA, PermuteALayout>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>; + + // Define iterators over tiles from the quant scales + using ThreadMapQScale = typename MmaCore::IteratorThreadMapQScale; + using AccessTypeQScale = + cutlass::Array; + using IteratorQScale = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + typename MmaCore::ThreadblockQShape, + ElementQScale, LayoutQMeta, 0, ThreadMapQScale, AccessTypeQScale>; + + using ThreadMapQOffset = typename MmaCore::IteratorThreadMapQOffset; + using AccessTypeQOffset = + cutlass::Array; + using IteratorQOffset = + cutlass::transform::threadblock::OptionalPredicatedTileAccessIterator< + typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQMeta, + 0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::QuantBMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, IteratorQScale, typename MmaCore::SmemIteratorQScale, + cutlass::arch::CacheOperation::Global, IteratorQOffset, + typename MmaCore::SmemIteratorQOffset, cutlass::arch::CacheOperation::Global, + ElementAccumulator, LayoutC, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h new file mode 100644 index 0000000000000..ad322f6505200 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma_core.h + * @brief Modified from cutlass/gemm/threadblock/default_mma_core.h. + * Defining data layout in shared memory, and its iterators. + */ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" + +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" +#include "cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template defininng default matrix multiply operators inferred from threadblock tile size, +/// global memory data layout, and target math instruction. 
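Before the template definition that follows, a small standalone sketch of the blocking rules this core enforces with static_assert further down: the quantization blocking must be a single row or a single column, it must divide the threadblock tile evenly, and the metadata tile is simply the weight tile divided by the blocking. The struct and parameter names here are simplified stand-ins, not the CUTLASS types used in this file.

```
#include <cstdio>

template <int Row, int Col>
struct Blocking { static constexpr int kRow = Row, kColumn = Col; };

template <int K, int N>
struct Tile { static constexpr int kK = K, kN = N; };

template <typename TileShape, typename QuantBlocking>
struct ThreadblockQShapeSketch {
  static_assert(QuantBlocking::kRow == 1 || QuantBlocking::kColumn == 1,
                "only single-row or single-column quantize blocking is supported");
  static_assert(TileShape::kK % QuantBlocking::kRow == 0, "K must be a multiple of the blocking");
  static_assert(TileShape::kN % QuantBlocking::kColumn == 0, "N must be a multiple of the blocking");
  static constexpr int kRow = TileShape::kK / QuantBlocking::kRow;       // scales along K
  static constexpr int kColumn = TileShape::kN / QuantBlocking::kColumn; // scales along N
};

int main() {
  // e.g. block size 64 along K, one scale per column of the weight tile
  using QShape = ThreadblockQShapeSketch<Tile<64, 128>, Blocking<64, 1>>;
  std::printf("quant metadata tile per threadblock: %d x %d\n", QShape::kRow, QShape::kColumn);
  return 0;
}
```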
+template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Element data type of quant scale + typename ElementQScale, + /// Element data type of quant offset + typename ElementQOffset, + /// Layout of quant scale + typename LayoutQMeta, + /// Blocking dimensions for quantization + typename QuantBlocking, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DefaultQuantBMmaCore; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Element data type of quant scale + typename ElementQScale_, + /// Element data type of quant offset + typename ElementQOffset_, + /// Layout of quant scale + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultQuantBMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = 
layout::ColumnMajor; + + using ElementQScale = ElementQScale_; + using ElementQOffset = ElementQOffset_; + using LayoutQMeta = LayoutQMeta_; + using QuantBlocking = QuantBlocking_; + + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + static int const kWarpThreadArrangementContiguousB = + (Shape::kK / 2) / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK/2>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + using SmemLayoutQScale = LayoutQMeta; + using SmemLayoutQOffset = LayoutQMeta; + + /// Threadblock-level quantization meta data shape + using ThreadblockQShape = MatrixShape; + static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow"); + static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn"); + static_assert(ThreadblockQShape::kCount > 0, "QuantBlocking too big to fit in a thread block!"); + static_assert(QuantBlocking::kRow == 1 || QuantBlocking::kColumn == 1, + "Only support single column or row quantize blocking!"); + static_assert(QuantBlocking::kColumn != 1 || std::is_same::value, + "Quant scale matrix's major dimension must have more elements, to facilitate fast loading!"); + + /// Threadblock-level quantization meta data shape in pitch-linear layout + using TBQPitchLinearShape = typename std::conditional< + std::is_same::value, + layout::PitchLinearShape, + 
layout::PitchLinearShape>::type; + + /// By default we would like to use 128b load. However, we can't load more than + /// a column at a time in a column major layout. + static int const kElementsPerAccessQScale = + (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous + ? TBQPitchLinearShape::kContiguous + : (kAccessSizeInBits / sizeof_bits::value); + + /// quant scale is tiny. Not all threads are needed. + static int const kAccessCntQScale = ThreadblockQShape::kCount / kElementsPerAccessQScale; + static int const kThreadsQScale = (kAccessCntQScale > kThreads) ? kThreads : kAccessCntQScale; + + using IteratorThreadMapQScale = transform::PitchLinearStripminedThreadMap< + TBQPitchLinearShape, kThreadsQScale, kElementsPerAccessQScale>; + + using SmemIteratorQScale = transform::threadblock::RegularTileAccessIterator< + ThreadblockQShape, ElementQScale, SmemLayoutQScale, 1, IteratorThreadMapQScale>; + + static int const kElementsPerAccessQOffset = + (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous + ? TBQPitchLinearShape::kContiguous + : (kAccessSizeInBits / sizeof_bits::value); + static int const kAccessCntQOffset = ThreadblockQShape::kCount / kElementsPerAccessQOffset; + static int const kThreadsQOffset = (kAccessCntQOffset > kThreads) ? kThreads : kAccessCntQOffset; + + using IteratorThreadMapQOffset = transform::PitchLinearStripminedThreadMap< + TBQPitchLinearShape, kThreadsQOffset, kElementsPerAccessQOffset>; + + using SmemIteratorQOffset = transform::threadblock::OptionalRegularTileAccessIterator< + ThreadblockQShape, ElementQOffset, SmemLayoutQOffset, 1, IteratorThreadMapQOffset, kThreads>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultQuantBMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementQScale, SmemLayoutQScale, ElementQOffset, SmemLayoutQScale, QuantBlocking, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h new file mode 100644 index 0000000000000..6f27a692a3a2e --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h @@ -0,0 +1,314 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file optional_predicated_tile_access_iter.h + * @brief Templates for loading and storing optional tiles of matrix data. + * This iterator is just a wrapper of PredicatedTileAccessIterator, with + * the option to turn it off at compile time and minimize its runtime + * footprint. Also, it utilize the higher numbered threads in the + * threadblock when the iterator can not utilize all the threads. 
+ */ + +#pragma once + +#include + +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + + +//////////////////////////////////////////////////////////////////////////////// + +/// Optional 2-D matrix data loader, when element is std::monostate, the +/// iterator becomes no-op with minimal runtime footprint. Also, it utilize the +/// higher numbered threads in the threadblock when the iterator can not utilize +/// all the threads. +/// +template < + /// Tile shape of the iterator + typename Shape_, + /// Element data type of the iterator, no-op when it is std::monostate + typename Element_, + /// Layout of the source matrix + typename Layout_, + int AdvanceRank_, + typename ThreadMap_, + typename AccessType_, + /// Number of threads in the threadblock, when provided, the iterator + /// will utilize the higher numbered threads + int kThreadBlockSize_ = -1> +class OptionalPredicatedTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + static constexpr int kAdvanceRank = AdvanceRank_; + static constexpr int kThreadblockSize = kThreadBlockSize_; + + static_assert(!std::is_same::value, + "Disabled Iterator failed to match the specialized version below."); + static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + + using Base = PredicatedTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using Mask = typename Base::Mask; + using TensorCoord = typename Base::TensorCoord; + using TensorRef = typename Base::TensorRef; + using Params = typename Base::Params; + using Pointer = typename Base::Pointer; + + static constexpr int kAccessesPerVector = Base::kAccessesPerVector; + + CUTLASS_HOST_DEVICE + static int flip_thread_id(int thread_id){ + if constexpr (kThreadblockSize > 0) { + return kThreadblockSize - 1 - thread_id; + } + return thread_id; + } + + public: + Base base_; + + /// Default constructor + OptionalPredicatedTileAccessIterator(): base_() {}; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : base_(params, pointer, extent, flip_thread_id(thread_id), threadblock_offset) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : OptionalPredicatedTileAccessIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + base_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + 
base_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + base_.add_tile_offset(tile_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return base_.get(); + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator &operator++() { + ++base_; + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator operator++(int) { + OptionalPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + base_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + base_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + base_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + base_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return base_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for the disabled version +/// Reduce runtime overhead +/// +template < + /// Tile shape of the iterator + typename Shape_, + typename Layout_, + int AdvanceRank_, + typename ThreadMap_, + typename AccessType_, + int kThreadBlockSize_> +class OptionalPredicatedTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = std::monostate; + using Layout = Layout_; + static int const kAdvanceRank = AdvanceRank_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + static constexpr int kThreadblockSize = kThreadBlockSize_; + + using Base = PredicatedTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using Mask = typename Base::Mask; + using TensorCoord = typename Base::TensorCoord; + using TensorRef = typename Base::TensorRef; + using Params = typename Base::Params; + using Pointer = typename Base::Pointer; + + static constexpr int kAccessesPerVector = Base::kAccessesPerVector; + + public: + std::monostate base_; + + /// Default constructor + OptionalPredicatedTileAccessIterator(): base_() {}; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : base_() {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : base_() {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void 
add_pointer_offset(LongIndex pointer_offset) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) {} + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return nullptr; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator &operator++() { + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator operator++(int) { + return *this; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) {} + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() {} + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) {} + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) {} + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return false; } +}; + +//////////////////////////////////////////////////////////////////////////////// +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h new file mode 100644 index 0000000000000..4b0ae5317f8bb --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h @@ -0,0 +1,224 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file optional_regular_tile_access_iter.h + * @brief Templates implementing the address computation of storing of tiles + * from pitch-linear rank=2 tensors. + * + * This iterator is just a wrapper of RegularTileAccessIterator, with the + * option to turn it off at compile time and minimize its runtime footprint. + * Also, it utilize the higher numbered threads in the threadblock when the + * iterator can not utilize all the threads. + * + * Must be used in conjunction with OptionalPredicatedTileAccessIterator, + * with the same template parameters. + */ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Optional 2-D tile iterator, when element is std::monostate, the iterator +/// becomes no-op with minimal runtime footprint. Also, it utilize the higher +/// numbered threads in the threadblock when the iterator can not utilize all +/// the threads. 
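Both optional iterators in these two headers rely on the same idea: a primary template that forwards to a real implementation, and a std::monostate specialization that keeps the same interface but compiles to a no-op, so kernels built without quant offsets carry no extra state or work. A minimal sketch of that pattern, with illustrative names only:

```
#include <variant>
#include <cstdio>

template <typename Element>
struct OptionalLoader {
  Element last{};
  void load(const Element* ptr) { last = *ptr; }   // real work in the enabled case
  bool valid() const { return true; }
};

template <>
struct OptionalLoader<std::monostate> {            // disabled version: same interface, no-op
  void load(const std::monostate*) {}
  bool valid() const { return false; }
};

int main() {
  OptionalLoader<float> enabled;
  float x = 3.5f;
  enabled.load(&x);

  OptionalLoader<std::monostate> disabled;         // zero-cost stand-in
  disabled.load(nullptr);

  std::printf("enabled=%d disabled=%d last=%.1f\n", enabled.valid(), disabled.valid(), enabled.last);
  return 0;
}
```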
+/// +template < + /// Tile shape of the iterator + typename Shape_, + typename Element_, + typename Layout_, + int AdvanceRank, + typename ThreadMap_, + /// Number of threads in the threadblock, when not -1, the iterator + /// will utilize the higher numbered threads + int ThreadblockSize_ = -1, + int Alignment = + sizeof_bits::value * ThreadMap_::kElementsPerAccess / 8> +class OptionalRegularTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + static constexpr int kAlignment = Alignment; + static constexpr int kThreadblockSize = ThreadblockSize_; + + static_assert(!std::is_same::value, + "Disabled Iterator failed to match the specialized template"); + static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + + using Base = RegularTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using TensorRef = typename Base::TensorRef; + using TensorCoord = typename Base::TensorCoord; + using AccessType = typename Base::AccessType; + + CUTLASS_HOST_DEVICE + static int flip_thread_id(int thread_id){ + if constexpr (kThreadblockSize > 0) { + return kThreadblockSize - 1 - thread_id; + } + return thread_id; + } + + private: + + Base base_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : base_(ref, flip_thread_id(thread_id)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + base_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + base_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + return base_.get(); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator &operator++() { + ++base_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. 
+ CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + base_.add_tile_offset(coord); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization when Element is std::monostate, the iterator becomes no-op +/// +template < + typename Shape_, + typename Layout_, + int AdvanceRank, + typename ThreadMap_, + int ThreadblockSize_, + int Alignment> +class OptionalRegularTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = std::monostate; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + static constexpr int kAlignment = Alignment; + static constexpr int kThreadblockSize = ThreadblockSize_; + + using Base = RegularTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using TensorRef = typename Base::TensorRef; + using TensorCoord = typename Base::TensorCoord; + using AccessType = typename Base::AccessType; + + private: + + std::monostate base_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : base_() {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) {} + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + return nullptr; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator &operator++() { + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator operator++(int) { + return *this; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) {} +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h new file mode 100644 index 0000000000000..8b6bac8c5099a --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h @@ -0,0 +1,1290 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_mma_multistage.h + * @brief Modified from cutlass/gemm/threadblock/mma_multistage.h. + * Added the quantized data memory pipeline, dequantization, and feeding + * to tensor cores. Mainloop pipeline is heavily modified. + */ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// +namespace{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Utilities for printing layout for the prepacked weights and quantization parameters +/// +template< + /// Data type of the prepacked weights + typename ElementWeight, + /// Data type of the quant scales + typename ElementQScale, + /// Data type of the quant offsets + typename ElementQOffset> +struct QuantBLayoutDebug{ + static constexpr bool debug_smem = true; + static constexpr bool debug_fragment = true; + ElementWeight* smem_b_ptr_; + ElementQScale* smem_qscale_ptr_; + ElementQOffset* smem_qoffset_ptr_; + int warp_id_; + int lane_id_; + int block_id_; + + template + CUTLASS_DEVICE + static void print_fragment(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + static_assert(Size % 4 == 0, "Size must be multiple of 4"); + if constexpr (debug_fragment){ + if (block_id == 1 && warp_id == 0){ + const Element* ptr = reinterpret_cast(&frag); + for (int i = 0; i < Size/4; i++, ptr+=4){ + if constexpr(std::is_integral::value){ + printf("T%.2d%c%d, %3d, %3d, %3d, %3d\n", + threadIdx.x, label, i, + ptr[0], ptr[1], ptr[2], ptr[3]); + } else { + printf("T%.2d%c%d, %.3f, %.3f, %.3f, %.3f\n", + threadIdx.x, label, i, + float(ptr[0]), float(ptr[1]), float(ptr[2]), float(ptr[3])); + } + } + } + } + } + + template + CUTLASS_DEVICE + static void 
print_as_int4(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + constexpr int I8Size = Size * cutlass::sizeof_bits::value / 8; + static_assert(I8Size % 2 == 0, "Size must be multiple of 4"); + if constexpr (debug_fragment){ + if (block_id == 1 && warp_id == 0){ + const uint8_t* ptr = reinterpret_cast(&frag); + for (int i = 0; i < I8Size/2; i++, ptr+=2){ + printf("T%.2dW%d, %d, %d, %d, %d\n", threadIdx.x, i, ptr[0] & 0x0f, ptr[0] >> 4, ptr[1] & 0x0f, ptr[1] >> 4); + } + } + } + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dummy type when quant offset is not used, to avoid compilation error, +/// and reduce runtime footprint +/// +struct DummyType{ + std::monostate dummy_; + public: + DummyType() = default; + + CUTLASS_HOST_DEVICE + void* data() const { + return nullptr; + } + + CUTLASS_HOST_DEVICE + std::monostate& operator[](int idx) { + return dummy_; + } +}; + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class QuantBMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
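As context for the quantized-weight handling in this mainloop file, a host-side sketch of the 4-bit packing convention assumed by the debug printer earlier in this file (two weights per byte, low nibble first), combined with the usual blockwise dequantization scale * (q - offset). The in-kernel dequantization is performed by the warp-level operator; this snippet is illustration only and is not taken from the patch.

```
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<uint8_t> packed = {0x21, 0x43};      // weights 1,2,3,4 packed as nibbles
  float scale = 0.5f;                              // one scale per quantization block
  int offset = 0;                                  // optional zero point

  for (uint8_t b : packed) {
    int lo = b & 0x0f;                             // first weight in the byte
    int hi = b >> 4;                               // second weight in the byte
    std::printf("%g %g ", scale * (lo - offset), scale * (hi - offset));
  }
  std::printf("\n");                               // prints: 0.5 1 1.5 2
  return 0;
}
```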
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + static constexpr bool kHasQOffset = !std::is_same::value; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the prepacked weights + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // Tensor reference to the quantization scales + using TensorRefQScale = TensorRef; + using TensorRefQOffset = TensorRef; + + // Block size of the quantization (one set of quantization parameters per block of weights) + using QuantBlocking = typename Operator::QuantBlocking; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the prepacked weights in shared memory + using ShapeB = + MatrixShape; + + /// Shape of the quantization parameter matrix in shared memory + /// Validation done in mma core class ThreadblockQShape + using ShapeQScale = + MatrixShape<(Shape::kK / QuantBlocking::kRow) * kStages, + Shape::kN / QuantBlocking::kColumn>; + + using BufTypeQOffset = std::conditional_t, + DummyType>; + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for prepacked weights + AlignedBuffer operand_B; + + /// Buffer for quantization scales + AlignedBuffer operand_QScale; + + /// Buffer for quantization offsets + BufTypeQOffset operand_QOffset; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + CUTLASS_HOST_DEVICE + static typename Operator::SmemLayoutQScale LayoutQMeta() { + return Operator::SmemLayoutQScale::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); + } + + CUTLASS_HOST_DEVICE + static typename Operator::SmemLayoutQOffset LayoutQOffset() { + return Operator::SmemLayoutQOffset::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the prepacked weights + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + + /// Returns a TensorRef to the quantization scales + CUTLASS_HOST_DEVICE + TensorRefQScale operand_QScale_ref() { + return TensorRefQScale{operand_QScale.data(), LayoutQMeta()}; + } + + CUTLASS_HOST_DEVICE + TensorRefQOffset operand_QOffset_ref() { + if constexpr (!kHasQOffset){ + return TensorRefQOffset(); + } else { + return TensorRefQOffset{operand_QOffset.data(), LayoutQOffset()}; + } + } + }; + + protected: + + // + // Data members + // + 
+ /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + /// Iterator to load a warp-scoped tile of quant scales from shared memory + typename Operator::IteratorQMeta warp_tile_iterator_QScale_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + QuantBMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), + warp_tile_iterator_QScale_(shared_storage.operand_QScale_ref(), + shared_storage.operand_QOffset_ref(), lane_idx) + {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over tiles of quant scales in global memory + typename IteratorQScale_, + /// Iterators over tiles of quant scales in shared memory + typename SmemIteratorQScale_, + /// Cache operation for quant scales + cutlass::arch::CacheOperation::Kind CacheOpQScale, + /// Iterators over tiles of quant scales in global memory + typename IteratorQOffset_, + /// Iterators over tiles of quant scales in shared memory + typename SmemIteratorQOffset_, + /// Cache operation for quant scales + cutlass::arch::CacheOperation::Kind CacheOpQOffset, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class QuantBMmaMultistage : + public QuantBMmaBase { +public: + ///< Base class + using Base = QuantBMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using 
SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using IteratorQScale = IteratorQScale_; + using IteratorQOffset = IteratorQOffset_; + using SmemIteratorQScale = SmemIteratorQScale_; + using SmemIteratorQOffset = SmemIteratorQOffset_; + using QuantBlocking = typename Base::QuantBlocking; + + static cutlass::arch::CacheOperation::Kind const kCacheOpQScale = CacheOpQScale; + static cutlass::arch::CacheOperation::Kind const kCacheOpQOffset = CacheOpQOffset; + static constexpr bool kHasQOffset = Base::kHasQOffset; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of packed weights + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const AsyncCopyIterationsPerStageQScale = + IteratorQScale::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of quant scale + static int const kAccessesPerGroupQScale = + (AsyncCopyIterationsPerStageQScale + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const AsyncCopyIterationsPerStageQOffset = + IteratorQOffset::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of quant offset + static int const kAccessesPerGroupQOffset = + (AsyncCopyIterationsPerStageQOffset + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical + // accuracy, where each mainloop iteration first accumulates into a temporary + // set of freshly-cleared accumulators, which are subsequently added to the + // final accumulator set. 
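+    // Illustrative sketch of the staged flow used by the mainloop below when
+    // kStagedAccumulation is true (comment only):
+    //
+    //   tmp_accum.clear();
+    //   for each warp-level k-iteration:          tmp_accum += A_frag * B_frag;
+    //   once per threadblock mainloop iteration:  accum += tmp_accum; tmp_accum.clear();
+    //
+    // The short-lived tmp_accum keeps round-off from compounding across the
+    // whole K dimension, which is the accuracy benefit alluded to above.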
+ static bool const kStagedAccumulation = arch::UseStagedAccumulation::value; + }; + + private: + + + // Structure encapsulating pipeline state live from one iteration to the next + struct PipeState { + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + /// Temporary accumulator to facilitate staged-accumulation + FragmentC tmp_accum_; + + /// Pair of A fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentA warp_loaded_frag_A_[2]; + + /// Pair of B fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentB warp_loaded_frag_B_; + WarpTransformedFragmentB warp_transformed_frag_B_[2]; + + using WarpLoadedFragmentQScale = typename Operator::FragmentQScale; + WarpLoadedFragmentQScale warp_loaded_frag_QScale_; + + using WarpLoadedFragmentQOffset = typename std::conditional::type; + WarpLoadedFragmentQOffset warp_loaded_frag_QOffset_; + }; + + + private: + + // + // Data members + // + + /// Warp-level MMA operator + Operator warp_mma_; + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of quant meta data to shared memory + SmemIteratorQScale smem_iterator_QScale_; + SmemIteratorQOffset smem_iterator_QOffset_; + + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + /// very small meta data tensor require less threads to load + bool const should_load_qscale_; + bool const should_load_qoffset_; + + /// Shared memory pointers for debug dumping + static constexpr bool debug_layout = false; + using LayoutDebugType = typename std::conditional, + std::monostate>::type; + LayoutDebugType layout_debug_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + QuantBMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_QScale_(shared_storage.operand_QScale_ref(), thread_idx), + smem_iterator_QOffset_(shared_storage.operand_QOffset_ref(), thread_idx), + should_load_qscale_(thread_idx < IteratorQScale::ThreadMap::kThreads), + should_load_qoffset_(thread_idx >= IteratorQOffset::kThreadblockSize - IteratorQOffset::ThreadMap::kThreads), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + if constexpr(debug_layout){ + layout_debug_.smem_b_ptr_ = shared_storage.operand_B_ref().data(); + layout_debug_.smem_qscale_ptr_ = shared_storage.operand_QScale_ref().data(); + if 
constexpr(kHasQOffset){ + layout_debug_.smem_qoffset_ptr_ = shared_storage.operand_QOffset_ref().data(); + } else { + layout_debug_.smem_qoffset_ptr_ = nullptr; + } + layout_debug_.warp_id_ = warp_idx; + layout_debug_.lane_id_ = lane_idx; + layout_debug_.block_id_ = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + } + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_QScale_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Advance shared memory read-iterators to the next stage + CUTLASS_DEVICE + void advance_smem_read_stage() + { + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + this->warp_tile_iterator_QScale_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + + smem_read_stage_idx_ = 0; + } + } + + /// Advance global memory read-iterators and shared memory write-iterators to the stage + CUTLASS_DEVICE + void advance_smem_write_stage( + IteratorA &iterator_A, + IteratorB &iterator_B, + IteratorQScale &iterator_QScale, + IteratorQOffset &iterator_QOffset) + { + // Advance global iterators + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + iterator_QScale.add_tile_offset({1, 0}); + + // Advance shared iterators + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + smem_iterator_QScale_.add_tile_offset({1, 0}); + + if constexpr (kHasQOffset) { + iterator_QOffset.add_tile_offset({1, 0}); + smem_iterator_QOffset_.add_tile_offset({1, 0}); + } + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_iterator_QScale_.add_tile_offset({-Base::kStages, 0}); + if constexpr (kHasQOffset) { + smem_iterator_QOffset_.add_tile_offset({-Base::kStages, 0}); + } + smem_write_stage_idx_ = 0; + } + } + + CUTLASS_DEVICE + void copy_qscale_tiles(IteratorQScale &iterator_QScale){ + // Quant scale matrix is 1/block_size of the B matrix, for a 64x64 warp tile, + // it's only 64x64/block_size elements. For blocking size 16 ~ 64, it only + // takes 4 ~ 16 cp.async instructions to load. One warp has 32 threads, so + // it should be loaded in less than one cp.async instruction per thread. + // Even less for quant offset matrix. 
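+    // As a result, the whole per-stage scale tile is issued as a single cp.async
+    // per participating thread (the static_asserts below enforce this), and only
+    // the first IteratorQScale::ThreadMap::kThreads threads of the threadblock
+    // take part in the copy (see should_load_qscale_ in the constructor).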
+ static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, + "Quant scale should be loaded in one shot!"); + static_assert(IteratorQScale::kAccessesPerVector == 1, + "Quant scale should 1 access per vector!"); + + // Async Copy for quantization scale + typename IteratorQScale::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QScale_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQScale::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QScale.get(), iterator_QScale.valid()); + } + + CUTLASS_DEVICE + void copy_qoffset_tiles(IteratorQOffset & iterator_QOffset) { + static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, + "Quant offset should be loaded in one shot!"); + static_assert(IteratorQOffset::kAccessesPerVector == 1, + "Quant offset should 1 access per vector!"); + + if constexpr(kHasQOffset) { + // Async Copy for quantization offset + typename IteratorQOffset::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QOffset_.get()); + + constexpr int kSrcBytes = sizeof_bits::value * + IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, + int group_start = 0) { + auto group_start_A = group_start * Detail::kAccessesPerGroupA; + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + auto group_start_B = group_start * Detail::kAccessesPerGroupB; + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// GEMM prologue. 
Bootstrap the global->shared memory pipeline by fetching + /// the global fragments needed by the first kStages-1 threadblock mainloop iterations + CUTLASS_DEVICE + void prologue( + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Async Copy for quantization scale + static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, "Quant scale should be loaded in one shot!"); + static_assert(IteratorQScale::kAccessesPerVector == 1, "Quant scale should 1 access per vector!"); + + typename IteratorQScale::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QScale_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQScale::ThreadMap::kElementsPerAccess / 8; + + auto gmem_ptr = iterator_QScale.get(); + + cutlass::arch::cp_async( + dst_ptr, gmem_ptr, iterator_QScale.valid()); + + if constexpr (kHasQOffset) { + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + + // Async Copy for quantization offset + static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, "Quant offset should be loaded in one shot!"); + static_assert(IteratorQOffset::kAccessesPerVector == 1, "Quant offset should 1 access per vector!"); + typename IteratorQOffset::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QOffset_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + 
IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + } + + // Move to the next write stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + + /// Wait until we have at least one completed global fetch stage + CUTLASS_DEVICE + void gmem_wait() + { + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + if constexpr(debug_layout) { + if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1) { + if (threadIdx.x == 0){ + printf("stage: %d\n", smem_write_stage_idx_); + } + cutlass::debug::dump_shmem(layout_debug_.smem_qscale_ptr_, Base::SharedStorage::ShapeQScale::kCount); + if constexpr(kHasQOffset){ + cutlass::debug::dump_shmem(layout_debug_.smem_qoffset_ptr_, Base::SharedStorage::ShapeQScale::kCount); + } + } + } + } + + /// Perform a threadblock mainloop iteration of matrix multiply-accumulate + CUTLASS_DEVICE + void mac_loop_iter( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Loading next warp-level tiles from shared memory. 
This can be skipped on the very + // last iteration where: + // (gemm_k_iterations == (1 - Base::kStages)) && (warp_mma_k == (Base::kWarpGemmIterations - 1)) + // However, evaluating this condition seems more expensive than simply loading the tiles + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // All warp-tiles issue their share of global->shared fragment copies + copy_tiles_and_advance( + iterator_A, + iterator_B, + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + if (warp_mma_k == 0) { + copy_qscale_tiles(iterator_QScale); + } + if (warp_mma_k == 1) { + copy_qoffset_tiles(iterator_QOffset); + } + + // The second-to-last warp-tile also moves to the next global fetch stage + if (warp_mma_k == Base::kWarpGemmIterations - 2) { + // Inserts a memory fence between stages of cp.async instructions. 
+ cutlass::arch::cp_async_fence(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + } + + } + } + + /// Specialized mainloop iteration of matrix multiply-accumulate, for small M + CUTLASS_DEVICE + void mac_loop_iter_small_m( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // In the case of small M, memory latency dominates. We try to move uses far + // from their definitions to hide latency. + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + + // Loading next warp-level tiles from shared memory. 
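+      // Note that in this small-M variant the loads are issued after the
+      // transform of the previously loaded fragments above; this is the
+      // "move uses far from their definitions" reordering described at the
+      // top of this loop.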
+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // All warp-tiles issue their share of global->shared fragment copies + copy_tiles_and_advance( + iterator_A, + iterator_B, + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + // The second-to-last warp-tile also moves to the next global fetch stage + if (warp_mma_k == Base::kWarpGemmIterations - 2) { + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + copy_qscale_tiles(iterator_QScale); + copy_qoffset_tiles(iterator_QOffset); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + } + + } + } + + + /// Perform the specified number of threadblock mainloop iterations of matrix + /// multiply-accumulate. Assumes prologue has been initiated. 
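+  /// Two loop variants are dispatched on the threadblock tile's M dimension:
+  /// Shape::kM > 32 uses mac_loop_iter, while smaller M tiles use
+  /// mac_loop_iter_small_m, which reorders the shared-memory loads and the
+  /// quant meta data copies to better hide memory latency when there is little
+  /// math work to overlap with.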
+ CUTLASS_DEVICE + void gemm_iters( + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over QScale operand in global memory + IteratorQOffset &iterator_QOffset) ///< [in|out] iterator over QOffset operand in global memory + { + PipeState pipe_state; + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset) { + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + // Load first warp-tile's B fragment from shared memory + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + // Load first warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); + ++this->warp_tile_iterator_A_; + + copy_tiles_and_advance(iterator_A, iterator_B, 0); + + if constexpr(Shape::kM > 32) { + // the case of bigger m + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, 0); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[0], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[0], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } else { + // the case of small m + copy_qscale_tiles(iterator_QScale); + copy_qoffset_tiles(iterator_QOffset); + } + + if (Detail::kStagedAccumulation) { + pipe_state.tmp_accum_.clear(); + } + + // Mainloop + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + if constexpr(Shape::kM > 32) { + mac_loop_iter( + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); + } else { + mac_loop_iter_small_m( + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); + } + } + + if (Detail::kStagedAccumulation) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + } + + // Commit and drain all pending and predicated cp.async pnz from the 
GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } + + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over quant scales in global memory + IteratorQScale iterator_QScale, + ///< Iterator over quant offsets in global memory + IteratorQOffset iterator_QOffset, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // Prologue (start fetching iterations of global fragments into shared memory) + prologue(iterator_A, iterator_B, iterator_QScale, iterator_QOffset, gemm_k_iterations); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Initialize destination accumulators with source accumulators + accum = src_accum; + + // Perform the MAC-iterations + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h new file mode 100644 index 0000000000000..2c49888c94504 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h @@ -0,0 +1,112 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma_tensor_op.h + * @brief Modified from cutlass/gemm/warp/default_mma_tensor_op.h + * Default warp-level GEMM operators selected by data type, size, and layouts of operands. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Data type of quant scales + typename ElementQScale, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale, + /// Data type of quant offsets + typename ElementQOffset, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset, + /// Blocking size of quantization + typename QuantBlocking, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Operator describing the tensor operation + typename Operator_ = arch::OpMultiplyAdd, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false> +struct DefaultQuantBMmaTensorOp { + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::QuantBMmaTensorOp< + WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, QuantBlocking, ElementC, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h new file mode 100644 index 0000000000000..4ba39dda3db8d --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -0,0 +1,883 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file quantb_meta_mma_tensor_op_tile_iterator.h + * @brief Templates for loading quantization meta data for operand B + * from shared memory to fragments. This is meant to be used in + * lock step with the operand B tile iterator. Containing logic + * to figure out the operand B layout in the tensor core, + * and deliver each meta data element to its corresponding + * operand B element for dequantization. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/fast_math.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace{ + +struct b32_pair{ + uint32_t a; + uint32_t b; +}; + +struct fp16_quad{ + cutlass::half_t a; + cutlass::half_t b; + cutlass::half_t c; + cutlass::half_t d; +}; + +struct b16_quad{ + int16_t a; + int16_t b; + int16_t c; + int16_t d; +}; + +union b64 { + uint64_t single; + b32_pair pair; + b16_quad quard; + fp16_quad fp16_quad; +}; + +static_assert(sizeof(b64) == 8, "b64 should be 64 bits"); + +/// Convert packed 4b weights into fp16(weight + 16) +/// Current bit hacking only supports fp16, need to add bf16 later. 
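+/// The conversion relies on the fp16 encoding: 0x4c00 is 16.0 (exponent 2^4,
+/// zero mantissa), and a 4-bit weight w placed in mantissa bits [6:10) adds
+/// (w << 6) / 1024 * 16 = w, so (0x4c00 | (w << 6)) encodes 16 + w exactly.
+/// Illustrative example: w = 5 -> (5 << 6) & 0x03c0 = 0x0140, and
+/// 0x0140 | 0x4c00 = 0x4d40, which is fp16 21.0 = 16 + 5.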
+/// +template +CUTLASS_DEVICE +void weights2Half(cutlass::Array const &weights, + cutlass::Array& dest) +{ + static_assert(Size % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); + uint32_t* dest_pair = reinterpret_cast(dest.data()); + const uint32_t* w_oct = reinterpret_cast(weights.data()); + + CUTLASS_PRAGMA_UNROLL + for (int oct_idx = 0; oct_idx < Size/8; oct_idx++, w_oct++, dest_pair += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + // static_cast(16 + weight) + // 4b weights are prepacked into [0, 2, 4, 6, 1, 3, 5, 7], so that adjacent weights + // are in different 16b half words, making it easier to convert to fp16. + asm volatile( + "{\n\t" + " shl.b32 %0, %4, 6;\n" + " shl.b32 %1, %4, 2;\n" + " shr.u32 %2, %4, 2;\n" + " shr.u32 %3, %4, 6;\n" + " lop3.b32 %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0x4c00 + " lop3.b32 %1, %1, 0x03c003c0, 0x4c004c00, 0xea;\n" + " lop3.b32 %2, %2, 0x03c003c0, 0x4c004c00, 0xea;\n" + " lop3.b32 %3, %3, 0x03c003c0, 0x4c004c00, 0xea;\n" + "}\n" + : "=r"(dest_pair[0]), "=r"(dest_pair[1]), + "=r"(dest_pair[2]), "=r"(dest_pair[3]) + : "r"(*w_oct)); +#else + assert(0); +#endif + } + +} + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +// Traits to describe the layout of quantization meta data layout in a MMA fragment +// Since operand B is quantized on a per block basis, it's one meta data per block. + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTile{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ArchMmaOperator = ArchMmaOperator_; + + static_assert(Threads == 32, "This iterator should work in a warp only."); + + /// Shape of the curresponding operand B tile iterator + using TileShapeB = MatrixShape; + + // Tensor core operand B layout is a column major 4x8 tile, divided + // into 32 threads (T0 ~ T31) as shown below. Each element of the tile is 32b, + // so for fp16 it becomes 8 x 8, and int8 it becomes 16 x 8. 
+ // T0 | T4 | T8 | T12 | T16 | T20 | T24 | T28 + // T1 | T5 | T9 | T13 | T17 | T21 | T25 | T29 + // T2 | T6 | T10 | T14 | T18 | T22 | T26 | T30 + // T3 | T7 | T11 | T15 | T19 | T23 | T27 | T31 + using CoreTile = layout::PitchLinearShape<4, 8>; + + /// Each thread holds a 32b fragment per tile: for half precision, it's 2 elements, 4 elements for int8 + static int const kNumBsPerCoreTileFragement = 32 / sizeof_bits::value; + + /// Each mma instruction can process either 1 or 2 tensor core operand B tiles (stacked on the k dimension) + static int const kBTilesPerMma = + sizeof_bits::value * ArchMmaOperator::FragmentB::kElements / 32; + static_assert(kBTilesPerMma == 1 || kBTilesPerMma == 2, "Only support 1 or 2 operand B tiles per mma."); + + /// Each operand B tile iterator load covers a number of mma instructions + static int const kMmaIterationsB = WarpShapeB::kColumn / ArchMmaOperator::Shape::kN; + + /// Number of B elements a fragment of meta data should cover + static int const kExpandedSize = kNumBsPerCoreTileFragement * kBTilesPerMma * kMmaIterationsB; + + // Now we figure out how many meta data elements to load for each TileShapeB + + /// Number of meta elements per CoreTile. + static int const kCoreTileFragementSize = (kNumBsPerCoreTileFragement + BlockingShape::kRow - 1) / BlockingShape::kRow; + + /// Number of core tiles per mma instruction, different from kBTilesPerMma when blocking size on K dimension + /// exceeds the tile depth, so two tiles share the same meta data + static int const kTilesPerMma = ((kBTilesPerMma == 2) && + (BlockingShape::kRow <= kNumBsPerCoreTileFragement * CoreTile::kContiguous)) + ? 2 : 1; + + /// stride to reach the meta data for the next CoreTile on the K dimension + static int const kKTileStride = (kNumBsPerCoreTileFragement * CoreTile::kContiguous + BlockingShape::kRow - 1) / BlockingShape::kRow; + + /// Stride on N dimension should be the tile width, shrunk by blocking size on this dimension. + static int const kNStride = (CoreTile::kStrided + BlockingShape::kColumn - 1) / BlockingShape::kColumn; + + /// On N dimension, how many tiles share the same meta data + static int const kNRepeats = (BlockingShape::kColumn + CoreTile::kStrided - 1) / CoreTile::kStrided; + + /// Each fragment should cover kMmaIterationsB number of mma intructions on the N dimension. + /// When blocking size on this dimension exceeds the tile width, multiple iterations + /// would share the same data. + static int const kMmaIterations = (kMmaIterationsB + kNRepeats - 1) / kNRepeats; + + static int const kFragementSize = kCoreTileFragementSize * kTilesPerMma * kMmaIterations; + + CUTLASS_DEVICE + static MatrixCoord lane_position(int lane_id) { + if constexpr(kNumBsPerCoreTileFragement == 2 + && kBTilesPerMma == 2 + && BlockingShape::kRow == 1){ + // Optimize for a special case of: + // 16b gemm (kNumBsPerCoreTileFragement == 2) + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking + // The scale and offset tensors are prepacked to reduce the number of load instructions. + return make_Coord((lane_id % CoreTile::kContiguous) * 4, + lane_id / CoreTile::kContiguous); + } else { + return make_Coord((lane_id % CoreTile::kContiguous) * kNumBsPerCoreTileFragement, + lane_id / CoreTile::kContiguous); + } + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is to load quantization meta data for operand B from +/// shared memory to fragments (hopefully allocated to registers by compilers). 
+/// Examples of meta data include scale or offsets. The operand B matrix is +/// quantized on a per block basis, meaning one element of meta data per block. +/// +/// This is meant to be used in lock step with the operand B tile iterator. +/// So all parameters are logical positions in the operand B tiles. +/// The goal here is to deliver each meta data element to its corresponding +/// operand B element for dequantization. As a result, we need to figure +/// out the operand B layout in the tensor core. +/// +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the quant scales + typename ElementScale_, + /// Layout of the quant scales + typename LayoutScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Layout of quant offsets + typename LayoutOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads, + /// Number of partitions along K dimension + int PartitionsK_ = 1> +class QuantBMetaMmaTensorOpTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column major layout + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTensorOpTileIterator{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ElementScale = ElementScale_; + using Layout = cutlass::layout::ColumnMajor; + using ElementOffset = ElementOffset_; + using ArchMmaOperator = ArchMmaOperator_; + + static constexpr bool kHasOffset = !(std::is_same::value); + + static_assert(BlockingShape::kRow == 1 && BlockingShape::kColumn > 1, + "Only support row blocking for column major layout"); + + using MetaTile = QuantBMetaMmaTile; + + /// Number of MMA instructions for this tile + static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB; + + /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8 + static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement; + + /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension) + static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma; + + /// Number of B elements a fragment of meta data should cover + static constexpr int kExpandedSize = MetaTile::kExpandedSize; + + /// Number of meta elements per core tile fragment + static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize; + + /// stride for reaching the next core tile (if there is one) on the K dimension + static constexpr int kKTileStride = MetaTile::kKTileStride; + + /// do we need to load meta data for the next core tile on the K dimension? 
+ static constexpr int kTilesPerMma = MetaTile::kTilesPerMma; + + static constexpr int kNStride = MetaTile::kNStride; + static constexpr int kNRepeats = MetaTile::kNRepeats; + static constexpr int kMmaIterations = MetaTile::kMmaIterations; + + using TensorRefScale = TensorRef; + using TensorRefOffset = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using FragmentScale = Array; + using FragmentOffset = typename std::conditional, + std::monostate>::type; + + using AccessTypeScale = Array; + using AccessTypeOffset = Array; + +private: + + ElementScale *pointer_; + Layout layout_; + + ElementOffset *pointer_offset_; + Layout layout_offset_; + + TensorCoord lane_position_; + +public: + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator() { } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator( + TensorRefScale const &ref, + TensorRefOffset const &ref_offset, + int lane_idx + ): + pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)){} + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(FragmentScale &frag, FragmentOffset &frag_offset) { + if constexpr(kNumBsPerCoreTileFragement == 2 + && kBTilesPerMma == 2){ + // Optimize for a special case of: + // 16b gemm (kNumBsPerCoreTileFragement == 2) + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking (BlockingShape::kRow == 1) + // The scale and offset tensors are prepacked to reduce the number of load instructions needed + const int row = lane_position_.row(); + const int column = lane_position_.column() / BlockingShape::kColumn; + + Array *dst_ptr = reinterpret_cast*>(frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + Array *src_ptr = reinterpret_cast*>(pointer_ + layout_({row, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + + if constexpr(kHasOffset){ + Array *dst_ptr_offset = reinterpret_cast*>(frag_offset.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + Array *src_ptr_offset = reinterpret_cast*>(pointer_offset_ + layout_offset_({row, c})); + *dst_ptr_offset = *src_ptr_offset; + dst_ptr_offset++; + } + } + + } else { + // Other cases, offsets and scales are not prepacked. 
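+      // Each meta element covers a BlockingShape::kRow x BlockingShape::kColumn
+      // block of B, so the lane position is divided by the blocking shape and
+      // one AccessType is gathered per (mma iteration, k-tile) pair below.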
+ + const int row = lane_position_.row() / BlockingShape::kRow; + const int column = lane_position_.column() / BlockingShape::kColumn; + + AccessTypeScale* dst_ptr = reinterpret_cast(frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + AccessTypeScale* src_ptr = reinterpret_cast(pointer_ + layout_({r, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + } + + if constexpr(kHasOffset){ + AccessTypeOffset* dst_ptr = reinterpret_cast(frag_offset.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + AccessTypeOffset* src_ptr = reinterpret_cast(pointer_offset_ + layout_offset_({r, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + } + } + } + } + + template + CUTLASS_HOST_DEVICE + static Array debug_expand(Array const &frag){ + Array ret; + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; + ret[out_idx] = frag[idx]; + out_idx++; + } + } + } + return ret; + } + + CUTLASS_HOST_DEVICE + static void dequant(FragmentScale const &scales, + FragmentOffset const &offsets, + Array const &weights, + Array& dest){ + static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm."); + static_assert(kExpandedSize % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); + + // First convert 4b weight into fp16(weight + 16) + weights2Half(weights, dest); + + if constexpr(kBTilesPerMma == 2){ + // Optimize for a special case of: + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking (BlockingShape::kRow == 1) + + uint32_t* dest_pair = reinterpret_cast(dest.data()); + const b64* scales_ptr = reinterpret_cast(scales.data()); + const ElementOffset* offsets_ptr = nullptr; + if constexpr(kHasOffset) { offsets_ptr = offsets.data(); } + + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + // dequantize: d = scale * (weight - offset) + // to use FMA, d = scale * weight + (scale * (-offset)) + + b64 offsets; + if constexpr(kHasOffset){ + const uint32_t* p = reinterpret_cast(offsets_ptr); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, %4, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, %4, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(offsets.pair.a), "=r"(offsets.pair.b) + : "r"(scales_ptr->pair.a), 
"r"(scales_ptr->pair.b), + "r"(p[0])); +#else + assert(0); +#endif + + offsets_ptr += 4; + } else { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(offsets.pair.a), "=r"(offsets.pair.b) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); +#else + offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast(-16-8); + offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast(-16-8); + offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast(-16-8); + offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast(-16-8); +#endif + } + + CUTLASS_PRAGMA_UNROLL + for (int n_r = 0; n_r < kNRepeats; n_r++){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " fma.rn.f16x2 %0, %2, %0, %4;\n" // dest = scale * (16 + weight) + (scale * (-16 - offset)) + " fma.rn.f16x2 %1, %3, %1, %5;\n" + "}\n" + : "+r"(dest_pair[0]), "+r"(dest_pair[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(offsets.pair.a), "r"(offsets.pair.b)); +#else + assert(0); +#endif + dest_pair += 2; + } + scales_ptr++; + } + + } else { + // unoptiomized path for other cases, very slow + int out_idx = 0; + ElementScale offset; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; + ElementScale s = scales[idx]; + if constexpr(kHasOffset){ + offset = s * static_cast(-16 - int(offsets[idx])); + } else { + offset = s * static_cast(-16-8); + } + dest[out_idx] = s * dest[out_idx] + offset; + out_idx++; + } + } + } + + } + + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator++() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; + int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn; + lane_position_ += TensorCoord(rows, columns); + return *this; + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row major layout + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. 
So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTensorOpTileIterator{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ElementScale = ElementScale_; + using ElementOffset = ElementOffset_; + using Layout = cutlass::layout::RowMajor; + using ArchMmaOperator = ArchMmaOperator_; + + static constexpr bool kHasOffset = !(std::is_same::value); + + static_assert(BlockingShape::kColumn == 1 && BlockingShape::kRow > 1, + "Only support column blocking for row major layout"); + + using MetaTile = QuantBMetaMmaTile; + + /// Number of MMA instructions for this tile + static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB; + + /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8 + static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement; + + /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension) + static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma; + + /// Number of B elements a fragment of meta data should cover + static constexpr int kExpandedSize = MetaTile::kExpandedSize; + + /// Number of meta elements per core tile fragment + static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize; + + /// stride for reaching the next core tile (if there is one) on the K dimension + static constexpr int kKTileStride = MetaTile::kKTileStride; + + /// do we need to load meta data for the next core tile on the K dimension? 
+ static constexpr int kTilesPerMma = MetaTile::kTilesPerMma; + + static constexpr int kNStride = MetaTile::kNStride; + static constexpr int kNRepeats = MetaTile::kNRepeats; + static constexpr int kMmaIterations = MetaTile::kMmaIterations; + + using TensorRefScale = TensorRef; + using TensorRefOffset = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using FragmentScale = Array; + using FragmentOffset = typename std::conditional, + std::monostate>::type; + +private: + + ElementScale *pointer_; + Layout layout_; + + ElementOffset *pointer_offset_; + Layout layout_offset_; + + TensorCoord lane_position_; + +public: + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator() { } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator( + TensorRefScale const &ref, + TensorRefOffset const &ref_offset, + int lane_idx + ): + pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)) + {} + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(FragmentScale &frag, FragmentOffset &frag_offset) { + const int row = lane_position_.row() / BlockingShape::kRow; + const int column = lane_position_.column() / BlockingShape::kColumn; + static_assert(kTilesPerMma * kCoreTileFragementSize == 1, "Only support one meta data per core tile"); + + ElementScale* src_ptr = pointer_ + layout_({row, column}); + ElementScale* dst_ptr = frag.data(); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + dst_ptr[n_idx] = src_ptr[n_idx * kNStride]; + } + + if constexpr(kHasOffset){ + ElementOffset* src_ptr_offset = pointer_offset_ + layout_offset_({row, column}); + ElementOffset* dst_ptr_offset = frag_offset.data(); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + dst_ptr_offset[n_idx] = src_ptr_offset[n_idx * kNStride]; + } + } + } + + template + CUTLASS_HOST_DEVICE + static Array debug_expand(Array const &frag){ + Array ret; + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int col = elem_idx + mma_tile_idx * kCoreTileFragementSize; + int idx = col * kMmaIterations + n_idx; + ret[out_idx] = frag[idx]; + out_idx++; + } + } + } + return ret; + } + + CUTLASS_HOST_DEVICE + static void dequant(FragmentScale const &scales, + FragmentOffset const &offsets, + Array const &weights, + Array& dest){ + static_assert(kNRepeats == 1, "This is implied by BlockingShape::kColumn == 1"); + static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm now."); + + // First convert 4b weight into fp16(weight + 16) + weights2Half(weights, dest); + + ElementScale addon[kMmaIterationsB]; + if constexpr (kMmaIterationsB % 4 == 0) { + const b64* scales_ptr = reinterpret_cast(scales.data()); + uint32_t* addon_ptr = reinterpret_cast(addon); + if constexpr(kHasOffset){ + const uint32_t* p = reinterpret_cast(offsets.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < 
kMmaIterationsB; n_idx += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [d, c, b, a] --> [d, b, c, a] + " prmt.b32 rb2, %4, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, rb2, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(p[0])); +#else + assert(0); +#endif + scales_ptr++; + p++; + addon_ptr += 2; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); +#else + assert(0); +#endif + scales_ptr++; + addon_ptr += 2; + } + } + } else if constexpr (kMmaIterationsB % 2 == 0) { + const uint32_t* scales_ptr = reinterpret_cast(scales.data()); + uint32_t* addon_ptr = reinterpret_cast(addon); + + if constexpr (kHasOffset){ + // possible buffer over read 2 bytes here. + const uint32_t* p = reinterpret_cast(offsets.data()); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [?, ?, b, a] --> [?, b, ?, a] + " prmt.b32 rb2, %2, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - offset) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0]) + "r"(p[0])); +#else + assert(0); +#endif + } else { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - 8) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0])); +#else + assert(0); +#endif + } + } else { + // kMmaIterationsB == 1 + if constexpr(kHasOffset){ + uint8_t zp = offsets[0]; + addon[0] = scales[0] * static_cast(-16 - static_cast(zp)); + } else { + addon[0] = scales[0] * static_cast(-16-8); + } + } + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + dest[out_idx] = scales[n_out] * dest[out_idx] + addon[n_out]; + dest[out_idx + 1] = scales[n_out] * dest[out_idx + 1] + addon[n_out]; + out_idx += 2; + } + } + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator++() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; + int columns = tile_offset.column() * 
MetaTile::TileShapeB::kColumn; + lane_position_ += TensorCoord(rows, columns); + return *this; + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// +} // namespace warp +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h new file mode 100644 index 0000000000000..f29cedf326a44 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h @@ -0,0 +1,361 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_mma_tensor_op.h + * @brief Modified from cutlass/gemm/warp/mma_tensor_op.h + * Templates implementing warp-level matrix multiply-accumulate operations + * targeting tensor cores. 
+ */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Data type of quant scales + typename ElementQScale_, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale_, + /// Data type of quant offsets + typename ElementQOffset_, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset_, + /// Blocking dimensions of quantization + typename QuantBlocking_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool +> +class QuantBMmaTensorOp { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + +public: + + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kA, ElementA, LayoutA, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kB, ElementB, LayoutB, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + // warp B MatrixShape<64, 64>, + // layout B cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<16, 64>, + // instruction op shape cutlass::MatrixShape<16, 8>, + // kPartitionsK 1 + // FragmentB::kElements 32 + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; // cutlass::Array + + /// Storage for transformed B tile + /// When loading weights, we packed 4 int4 weights into one 2-byte-element, when expanded + /// we multiply the number of elements by 4. 
+ /// TODO: make sure ArchMmaOperator::ElementB same as dequantized ElementB + /// and change the transform function below to perform dequantization + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator< + MatrixShape, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + using ElementQScale = ElementQScale_; + using SmemLayoutQScale = SmemLayoutQScale_; + using QuantBlocking = QuantBlocking_; + + using ElementQOffset = ElementQOffset_; + using SmemLayoutQOffset = SmemLayoutQOffset_; + + /// Iterates over the quantization parameters in memory + using WarpQScaleShape = MatrixShape<(Shape::kK / QuantBlocking::kRow), (Shape::kN / QuantBlocking::kColumn)>; + static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow"); + static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn"); + static_assert(WarpQScaleShape::kCount > 0, "QuantBlocking too big to fit in a warp block!"); + + // TODO This is an expanding iterator, it needs to replicate the quantization parameters + // to all threads in the warp. + using IteratorQMeta = QuantBMetaMmaTensorOpTileIterator< + MatrixShape, QuantBlocking, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, + ArchMmaOperator, kThreadCount, kPartitionsK>; + + using FragmentQScale = typename IteratorQMeta::FragmentScale; + using FragmentQOffset = typename IteratorQMeta::FragmentOffset; + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN + >; + +public: + + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + QuantBMmaTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + TransformedFragmentA const &A, + TransformedFragmentB const &B, + FragmentC const &C + ) const { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + D = C; + + MmaOperandA const *ptr_A = reinterpret_cast(&A); + MmaOperandB const *ptr_B = reinterpret_cast(&B); + MmaOperandC *ptr_D = reinterpret_cast(&D); + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + // The visitation order is like + // _ + // | | | | + // | | | | + // |_| |_| + // + // Down Up Down Up + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + int m_serpentine = ((n % 2) ? 
(MmaIterations::kRow - 1 - m) : m); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma( + ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } + #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + // The visitation order is like + // _________ + // _________| + // |_________ + // __________| + // + // Right Left Right Left + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } + #else + assert(0); + #endif + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentB &dst_B, + FragmentB const &B, + FragmentQScale const &scales, + FragmentQOffset const &offsets) const { + + Array const *ptr_B = + reinterpret_cast const *>(&B); + IteratorQMeta::dequant(scales, offsets, *ptr_B, dst_B); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h index a0405e32034ae..783a29d8a2055 100644 --- a/onnxruntime/core/util/matrix_layout.h +++ b/onnxruntime/core/util/matrix_layout.h @@ -17,7 +17,6 @@ #include #include "core/common/gsl.h" -// TODO!! Already have this in cuda, what about cpu code though? #if defined(_MSC_VER) #define ORT_FORCEINLINE __forceinline #else diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h new file mode 100644 index 0000000000000..6ea8b55505214 --- /dev/null +++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h @@ -0,0 +1,203 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_quant_sm80.h + * + * Abstract: + * Oracle computation for blockwise 4b quantization for fp16 + * gemm kernel specifically for Ampere GPUs. 
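The serpentine loops in `QuantBMmaTensorOp::operator()` above reverse the traversal direction on every other pass, so the operand fragment used at the end of one pass is reused immediately at the start of the next. A tiny standalone sketch of the index pattern for the SM80 branch, with an assumed 2x4 `MmaIterations` grid purely for illustration:

```
#include <cstdio>

int main() {
  const int kRow = 2, kColumn = 4;  // assumed MmaIterations shape, illustration only
  for (int m = 0; m < kRow; ++m) {
    for (int n = 0; n < kColumn; ++n) {
      const int n_serpentine = (m % 2) ? (kColumn - 1 - n) : n;
      std::printf("(m=%d, n=%d) ", m, n_serpentine);
    }
  }
  std::printf("\n");
  // Prints: (m=0, n=0) (m=0, n=1) (m=0, n=2) (m=0, n=3)
  //         (m=1, n=3) (m=1, n=2) (m=1, n=1) (m=1, n=0)
  // Column 3 is visited twice in a row at the row boundary, so the B fragment
  // indexed by ptr_B[n_serpentine] does not need to be refetched there.
  return 0;
}
```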
This is used for + * testing the cuda kernel implementation in + * (test/providers/cuda/test_cases) + * and for testing the cuda op prepack code in (test/optimizer) + */ + +#pragma once + +#include "core/util/matrix_layout.h" +#include "core/common/common.h" + +namespace onnxruntime { +namespace test { + +static inline void sm80_prepack_weights_ref( + int rows, + int columns, + const MatrixRef& tensor_weight, + const MatrixRef& tensor_weight_prepacked) { + ORT_ENFORCE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns, + "Unexpected tensor_weight shape! Expected: (", rows / 2, ", ", columns, "), Got: (", + tensor_weight.shape()[0], ", ", tensor_weight.shape()[1], ")."); + ORT_ENFORCE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2, + "tensor_weight_prepacked shape is not compatible with prepacked weight shape"); + + auto t0_base = make_Position(0, 0); + auto t1_base = make_Position(4, 0); + auto t2_base = make_Position(0, 8); + auto t3_base = make_Position(4, 8); + for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { + for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { + // Packing from a 8x16 tile to a 16x8 tile + auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); + auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); + for (int col = 0; col < 8; ++col) { + for (int row = 0; row < 4; ++row) { + auto cord = make_Position(row, col); + auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 + uint8_t buf[4]; + buf[0] = tensor_weight.at(dtile_base + t0_base + cord); + buf[1] = tensor_weight.at(dtile_base + t1_base + cord); + buf[2] = tensor_weight.at(dtile_base + t2_base + cord); + buf[3] = tensor_weight.at(dtile_base + t3_base + cord); + + // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights + // are in different b16 register at the same positions. This makes it easier to convert to + // fp16x2 format in a b32 register + + tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); + tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); + } + } + } + } +} + +template < + typename ScaleElementT, + typename Layout, + typename QuantBlocking> +inline void sm80_prepack_quant_scales_ref( + int rows, + int columns, + const MatrixRef& tensor_scale, + const MatrixRef& tensor_scale_prepacked) { + ORT_ENFORCE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn), + "Unexpected tensor_scale shape! 
Expected: (", + rows / QuantBlocking::kRow, ", ", columns / QuantBlocking::kColumn, ")"); + ORT_ENFORCE(tensor_scale_prepacked.shape() == tensor_scale.shape()); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ScaleElementT) != 2 || QuantBlocking::kRow != 1) { + ORT_THROW("sm80_prepack_quant_scales_ref should only be called for row-wise block quantization on 16b float values."); + } + + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); + tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); + tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); + tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); + } + } + } +} + +template +inline void sm80_prepack_quant_offsets_ref( + int rows, + int columns, + MatrixRef tensor_offset, + MatrixRef tensor_offset_prepacked) { + const auto meta_shape = make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + ORT_ENFORCE(tensor_offset_prepacked.shape() == meta_shape, + "Unexpected tensor_offset_prepacked shape (", + tensor_offset_prepacked.shape()[0], ",", tensor_offset_prepacked.shape()[1], + ")! Expected: (", meta_shape[0], ", ", meta_shape[1], ")"); + ORT_ENFORCE(tensor_offset.shape() == zp_shape, + "Unexpected tensor_offset shape (", + tensor_offset.shape()[0], ",", tensor_offset.shape()[1], + ")! 
Expected: (", zp_shape[0], ", ", zp_shape[1], ")"); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (QuantBlocking::kRow != 1) { + ORT_THROW("sm80_prepack_quant_offsets_ref should only be called for row-wise block quantization."); + } + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + if (tensor_offset_prepacked.good()) { + for (int col = 0; col < tensor_offset_prepacked.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_offset_prepacked.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + uint8_t pair01 = tensor_offset.at(src_idx / 2, col); + uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col); + tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf; + tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf; + tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4; + tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4; + } + } + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h new file mode 100644 index 0000000000000..bbe370675fc48 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h @@ -0,0 +1,188 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_gemm_sm80.h + * + * Abstract: + * Bridge between gtest code and gemm kernel implementation. + * Gemm kernel requires CUTLASS header files, which causes strange + * compilation errors with RE2 header files, which are required + * by gtest. 
+ */ + +#pragma once + +#include + +#include "core/util/matrix_layout.h" +#include "core/common/common.h" +#include "core/mickey/blk_q4/f16_prepack_sm80.h" +#include "test/cuda_host/blkq4_fp16_quant_sm80.h" + +namespace onnxruntime { +namespace cuda { +namespace test { + +Status sm80_supported(); + +/** + * @brief Generate a set of quantized weights, scales and offsets + * and dequantized weights for testing quantization and + * dequantization. All outputs are column major layout. + * + * @tparam ElementT The type of the dequantized weights. + * @tparam block_size The block size of the quantization. + * @tparam col_blocking Whether to use column blocking (all elements of + * a block comes from a single column) or row blocking + * @tparam has_offsets Whether to generate offsets. + * + * @param[in] rows The number of rows of the weight matrix. + * @param[in] columns The number of columns of the weight matrix. + * @param[out] dequants The dequantized weights, column major layout. + * @param[out] q_weights The quantized weights, column major layout. + * @param[out] q_scales The scales, column major layout. + * @param[out] q_zp The zero points, column major layout. + */ +template +inline void blkq4_weights_gen( + int rows, int columns, + std::vector& dequants, + std::vector& q_weights, + std::vector& q_scales, + std::vector& q_zp) { + using Base = onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + col_blocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + + static_assert(std::is_same::value); + static_assert(std::is_same::value); + static_assert(std::is_same::value); + + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution dis(0, 8192); + + const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); + const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + + // + // For testing quantization and dequantization, it is not straight + // forward to avoid flaky tests due to rounding errors. The way we + // try to achieve this is to: + // 1. Generate a set of quantized weights, scales and offsets + // 2. Dequantize the weights + // 3. Quantize the dequantized weights + // 4. Compare the dequantied-and-then-quantized weights with + // the original quantized weights + // + // Random filling of the initial values are key to get this right. + // For weights, we must ensure each block gets a full range of + // values, i.e. must contain 0 and 15. And for scales, they must + // all be positive. 
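The scale generator just below draws each scale as `m / 2^(2+e)` with `m` in `[1, 63]` and `e` in `[0, 3]`, so every scale is positive and needs at most six significand bits, i.e. it is exactly representable in fp16, which helps keep the quantize/dequantize round trip described above free of rounding flakiness. A standalone sketch of that recipe (plain `float` in place of `MLFloat16`; the RNG setup mirrors the generator code):

```
#include <cassert>
#include <cstdint>
#include <random>

int main() {
  std::seed_seq seq{28571u};
  std::mt19937 gen(seq);
  std::uniform_int_distribution<uint32_t> dis(0, 8192);

  for (int i = 0; i < 1000; ++i) {
    const uint32_t v = dis(gen);
    const uint32_t m = (v % 63) + 1;   // 1..63, at most 6 significand bits
    const uint32_t e = (v >> 6) % 4;   // divisor 4, 8, 16 or 32
    const float scale = static_cast<float>(m) / static_cast<float>(1u << (2 + e));
    assert(scale > 0.0f);                                                         // always positive
    assert(scale * static_cast<float>(1u << (2 + e)) == static_cast<float>(m));   // no rounding
  }
  return 0;
}
```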
+ // + + q_weights.resize(q_weight_shape.product()); + MatrixRef tensor_q_weight( + q_weights, make_Position(rows / 2, columns)); + int v = 7; + for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { + for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + + tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); + } + } + + q_scales.resize(meta_shape.product()); + for (size_t i = 0; i < q_scales.size(); i++) { + uint32_t v = dis(gen); + uint32_t m = (v % 63) + 1; + uint32_t e = (v >> 6) % 4; + q_scales[i] = ElementT(m / static_cast(1 << (2 + e))); + } + MatrixRef tensor_scale( + q_scales, meta_shape); + + MatrixRef tensor_offset; + if constexpr (has_offsets) { + q_zp.resize(zp_shape.product()); + tensor_offset = MatrixRef( + q_zp, zp_shape); + for (int c = 0; c < zp_shape[1]; c++) { + for (int r = 0; r < zp_shape[0]; ++r) { + uint8_t v0 = dis(gen) % 16; + uint8_t v1 = 8; + if (r * 2 + 1 < meta_shape[0]) { + v1 = dis(gen) % 16; + } + tensor_offset.at(r, c) = static_cast(v0 | (v1 << 4)); + } + } + } + + dequants.resize(rows * columns); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + + // Dequantize weights and save into matrix B + for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { + for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { + auto weight_cord = make_Position(row / 2, col); + auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); + uint8_t offset = 8; + if constexpr (has_offsets) { + if (scale_cord[0] % 2 == 0) { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) & 0x0f; + } else { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) >> 4; + } + } + int w = 0; + if (row % 2 == 0) { + w = int(tensor_q_weight.at(weight_cord) & 0x0f); + } else { + w = int(tensor_q_weight.at(weight_cord) >> 4); + } + float scale = float(tensor_scale.at(scale_cord)); + float dequant = scale * float(w - offset); + tensor_dequant.at(row, col) = ElementT(dequant); + // Prints for help debugging in case of test failure + // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); + } + } +} + +template < + int block_size, + bool column_wise_blocking, + bool small_m, + bool has_offsets> +void run_blkq4_gemm(int m, int n, int k); + +} // namespace test +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc new file mode 100644 index 0000000000000..e687ae73e66f2 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -0,0 +1,330 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_gemm_sm80_test.cc + * + * Abstract: + * Test code for block-wise quantized 4b GEMM kernels. + * This part requires gtest header files, which do not play + * well with CUTLASS headers. 
+ */ + +#include + +#include "core/framework/float16.h" +#include "core/mlas/inc/mlas_q4.h" + +#include "blkq4_fp16_gemm_sm80.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +template +void testPrepack(int rows, int columns) { + using ElementT = MLFloat16; + constexpr int block_size = 32; + using Base = onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + col_blocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + using LayoutQmeta = typename Base::LayoutQmeta; + + const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); + const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + + std::vector q_weights; + std::vector q_scales; + std::vector q_zp; + std::vector dequants; + onnxruntime::cuda::test::blkq4_weights_gen( + rows, columns, dequants, q_weights, q_scales, q_zp); + + // for quantization tool, the input is row major, all outputs are column major + MatrixRef tensor_q_weight( + q_weights, make_Position(rows / 2, columns)); + MatrixRef tensor_scale( + q_scales, meta_shape); + MatrixRef tensor_offset; + if constexpr (has_offset) { + tensor_offset = MatrixRef(q_zp, zp_shape); + } + + // for quantization tool, the input is row major, test weight gen output is column major + std::vector dequants_transposed(dequants.size()); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + MatrixRef tensor_dequant_transposed(dequants_transposed, make_Position(rows, columns)); + for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { + for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { + tensor_dequant_transposed.at(row, col) = tensor_dequant.at(row, col); + } + } + + int q_rows, q_cols; + MlasBlockwiseQuantizedShape( + block_size, col_blocking, rows, columns, q_rows, q_cols); + // to be exact, q_rows are padded to multiple of block_size, deal with it when we care about strange shapes + EXPECT_EQ(q_rows, q_weight_shape[0]); + EXPECT_EQ(q_cols, q_weight_shape[1]); + + // + // Quantization tool outputs: + // + std::vector o_elements(q_rows * q_cols); + MatrixRef tensor_o_elements(o_elements, q_weight_shape); + + std::vector o_scales(meta_shape.product()); + MatrixRef tensor_o_scales(o_scales, meta_shape); + + std::vector o_zp(zp_shape.product()); + MatrixRef tensor_o_zp(o_zp, zp_shape); + + MlasQuantizeBlockwise(o_elements.data(), o_scales.data(), has_offset ? o_zp.data() : nullptr, + dequants_transposed.data(), block_size, + col_blocking, rows, columns, columns, nullptr); + for (int col = 0; col < tensor_q_weight.shape()[1]; ++col) { + for (int row = 0; row < tensor_q_weight.shape()[0]; ++row) { + EXPECT_EQ(tensor_o_elements.at(row, col), tensor_q_weight.at(row, col)) + << "quantized value mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? 
"Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + if (has_offset) { + uint8_t pair01 = tensor_o_zp.at(row / 2, col); + uint8_t expected_pair01 = tensor_offset.at(row / 2, col); + EXPECT_EQ(expected_pair01 & 0xf, pair01 & 0xf) + << "quantized offset mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + if (row + 1 < meta_shape[0]) { + EXPECT_EQ(expected_pair01 >> 4, pair01 >> 4) + << "quantized offset mismatch at [" << row + 1 << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + EXPECT_EQ(tensor_scale.at(row + 0, col), tensor_o_scales.at(row + 0, col)) + << "quantized scale mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + if (row + 1 < meta_shape[0]) { + EXPECT_EQ(tensor_scale.at(row + 1, col), tensor_o_scales.at(row + 1, col)) + << "quantized scale mismatch at [" << row + 1 << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + } + + // + // Now we just setup quantized weights tensor_q_weight, quantization scale tensor_scale + // and quantization offset tensor_offset. The above tests just make sure our setup is + // consistent with quantization tool output. + // + // Next we test the prepack code + // + + std::vector packed_w_ref(q_weight_shape.product()); + MatrixRef tensor_packed_w_ref( + packed_w_ref, make_Position(rows, columns / 2)); + onnxruntime::test::sm80_prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref); + + std::vector packed_w(q_weight_shape.product()); + MatrixRef tensor_packed_w( + packed_w, make_Position(rows, columns / 2)); + Base::prepack_weights(rows, columns, o_elements, packed_w); + + for (int col = 0; col < tensor_packed_w.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_w.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_w_ref.at(row, col), tensor_packed_w.at(row, col)) + << "prepacked weights mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? 
"Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + std::vector packed_scales_ref(meta_shape.product()); + MatrixRef tensor_packed_s_ref = + make_MatrixRef(packed_scales_ref, meta_shape); + if constexpr (Base::ShouldRearrangeMeta) { + onnxruntime::test::sm80_prepack_quant_scales_ref( + rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); + } else { + for (int col = 0; col < tensor_packed_s_ref.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_s_ref.shape()[0]; ++row) { + tensor_packed_s_ref.at(row, col) = tensor_scale.at(row, col); + } + } + } + + std::vector packed_scales(meta_shape.product()); + MatrixRef tensor_packed_s( + packed_scales, meta_shape); + Base::prepack_quant_scales(rows, columns, o_scales, packed_scales); + + for (int col = 0; col < tensor_packed_s.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_s.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_s_ref.at(row, col), tensor_packed_s.at(row, col)) + << "prepacked scales mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + if (has_offset) { + std::vector packed_zp_ref(meta_shape.product()); + MatrixRef tensor_packed_zp_ref = + make_MatrixRef(packed_zp_ref, meta_shape); + if constexpr (Base::ShouldRearrangeMeta) { + onnxruntime::test::sm80_prepack_quant_offsets_ref( + rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); + } else { + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + uint8_t pair01 = tensor_offset.at(row / 2, col); + tensor_packed_zp_ref.at(row, col) = pair01 & 0xf; + if (row + 1 < meta_shape[0]) { + tensor_packed_zp_ref.at(row + 1, col) = pair01 >> 4; + } + } + } + } + + std::vector packed_zp(meta_shape.product()); + MatrixRef tensor_packed_zp( + packed_zp, meta_shape); + Base::prepack_quant_offsets(rows, columns, o_zp, packed_zp); + + for (int col = 0; col < tensor_packed_zp.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_zp.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_zp_ref.at(row, col), tensor_packed_zp.at(row, col)) + << "prepacked offsets mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? 
"Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + } +} + +// TODO: code runs on CPU, but this is for sm80 only, maybe enable only when test on sm80 +TEST(BlkQ4_GEMM, PrepackSm80Test) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + + testPrepack(32, 32); + testPrepack(32, 32); + testPrepack(32, 32); + testPrepack(32, 32); + testPrepack(32, 64); + testPrepack(32, 128); + testPrepack(32, 256); + testPrepack(64, 32); + testPrepack(128, 32); + testPrepack(256, 32); + testPrepack(256, 256); + testPrepack(32, 128); + testPrepack(128, 32); + testPrepack(256, 256); + testPrepack(32, 64); + testPrepack(32, 128); + testPrepack(32, 256); + testPrepack(64, 32); + testPrepack(128, 32); + testPrepack(256, 32); + testPrepack(256, 256); + testPrepack(32, 128); + testPrepack(128, 32); + testPrepack(256, 256); +} + +TEST(BlkQ4_GEMM, Sm80RowBlockingTest) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 32, 64); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 32, 64); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 96, 64); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 96, 64); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 96, 192); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 96, 192); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(256, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(256, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(512, 2048 + 32, 960); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(512, 2048 + 32, 960); + + onnxruntime::cuda::test::run_blkq4_gemm<16, false, false, false>(256, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, false, false, true>(256, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, false>(256, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, true>(256, 1024, 576); +} + +TEST(BlkQ4_GEMM, Sm80ColBlockingTest) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, false>(64, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, true>(64, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, false>(256, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, true>(256, 1024, 576); +} + +TEST(BlkQ4_GEMM, Sm80SmallMTest) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + + // // small m + onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, false>(16, 704, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, true>(16, 704, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, false>(16, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, true>(16, 1024, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<16, true, true, false>(16, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, true, true, true>(16, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, true, true, false>(16, 
1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, true, true, true>(16, 1024, 576); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu new file mode 100644 index 0000000000000..69c929d446ce4 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -0,0 +1,344 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_gemm_sm80_testcu.cu + * + * Abstract: + * Test code for invoking block-wise quantized 4b GEMM kernels. + * This part requires CUTLASS header files, which do not play + * well with gtest headers. + */ + +#include +#include +#include + +#include "core/mickey/blk_q4/f16_gemm_sm80.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "core/common/common.h" + +#include "blkq4_fp16_gemm_sm80.h" + +namespace onnxruntime { +namespace cuda{ +namespace test{ + +Status sm80_supported(){ + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::ostringstream ss; + ss << "Unable to obtain GPU device properties: " << cudaGetErrorString(error); + return Status(common::ONNXRUNTIME, common::ENGINE_ERROR, ss.str()); + } + + if (!((props.major * 10 + props.minor) >= 80)) { + std::ostringstream ss; + ss << "Device compute capability mismatch, desired 8.0, actual " << props.major << "." << props.minor; + return Status(common::ONNXRUNTIME, common::ENGINE_ERROR, ss.str()); + } + return Status::OK(); +} + +/** + * @brief Reference implementation of GEMM + * Copied directly from cutlass util/reference/device/gemm.h + * for the strange reason that compiler insists on asking + * for explicit stream argument in kernel launch. +*/ +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType +> +void compute_gemm_ref( + cutlass::gemm::GemmCoord problem_size, + ScalarType alpha, + cutlass::TensorRef tensor_a, + cutlass::TensorRef tensor_b, + ScalarType beta, + cutlass::TensorRef tensor_c, + cutlass::TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. 
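With the blocking described here, each thread of the reference kernel produces a 4x4 tile of the output, so a 16x8 thread block covers a 64x32 region of C/D. A quick sketch of the resulting grid size for one of the problem shapes exercised by the tests (the numbers mirror the `OutputTile` and `block` settings just below):

```
#include <cstdio>

int main() {
  const int tile_m = 4, tile_n = 4;          // OutputTile
  const int block_x = 16, block_y = 8;       // threads per block
  const int per_block_m = block_x * tile_m;  // 64 output rows per block
  const int per_block_n = block_y * tile_n;  // 32 output columns per block

  const int m = 256, n = 672;  // one of the test problem sizes
  std::printf("grid = (%d, %d)\n",
              (m + per_block_m - 1) / per_block_m,   // 4
              (n + per_block_n - 1) / per_block_n);  // 21
  return 0;
}
```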
+ using OutputTile = cutlass::MatrixShape<4, 4>; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) + ); + + // Launch a GEMM kernel + cutlass::reference::device::kernel::Gemm< + cutlass::TensorRef, + cutlass::TensorRef, + cutlass::TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + cutlass::multiply_add, + cutlass::NumericConverter + ><<>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Converting cutlass tensor to MatrixRef +// + +template < + typename Element, + typename LayoutCutlass, + typename Layout = std::conditional_t::value, ColumnMajorLayout, RowMajorLayout> + > +__forceinline__ +MatrixRef make_MatrixRef(cutlass::HostTensor const& tensor) { + static_assert(std::is_same::value + || std::is_same::value); + auto shape = make_Position(tensor.extent().row(), tensor.extent().column()); + auto* ptr = const_cast::type *>(tensor.host_data()); + return MatrixRef(ptr, tensor.capacity(), shape); +} + +template < + typename Element, + typename LayoutCutlass, + typename Layout = std::conditional_t::value, ColumnMajorLayout, RowMajorLayout> + > +__forceinline__ +MatrixRef make_ConstMatrixRef(cutlass::HostTensor const& tensor) { + static_assert(std::is_same::value + || std::is_same::value); + auto shape = make_Position(tensor.extent().row(), tensor.extent().column()); + return MatrixRef(tensor.host_data(), tensor.capacity(), shape); +} + +// +// Invoking the kernel +// + +template< + int block_size, + bool column_wise_blocking, + bool small_m, + bool has_offsets> +void run_blkq4_gemm(int m, int n, int k) { + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution<> dis(0, 8192); + + using ElementDequant = cutlass::half_t; + using QuantBlocking = + typename std::conditional, + cutlass::MatrixShape<1, block_size>>::type; + + using GemmRunner = BlkQ4F16GemmImpl; + + using ElementAccumulator = typename GemmRunner::ElementAccumulator; + using ElementComputeEpilogue = typename GemmRunner::ElementComputeEpilogue; + using ElementInputA = typename GemmRunner::ElementInputA; + using ElementOutput = typename GemmRunner::ElementOutput; + using ElementW = typename GemmRunner::ElementW; + using ElementWPack = typename GemmRunner::ElementWPack; + using ElementQScale = typename GemmRunner::ElementQScale; + using ElementQOffset = typename GemmRunner::ElementQOffset; + + using LayoutInputA = typename GemmRunner::LayoutInputA; + using LayoutOutput = typename GemmRunner::LayoutOutput; + using LayoutInputWPack = typename GemmRunner::LayoutInputWPack; + using LayoutInputQScale = typename GemmRunner::LayoutInputQScale; + + const cutlass::gemm::GemmCoord problem_size = {m, n, k}; + const auto q_weight_shape = cutlass::make_Coord(problem_size.k()/2, problem_size.n()); + const auto meta_shape = cutlass::make_Coord(problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn); + + // + // Generate quantized and dequantizeed input matrix B [K, N] + // + static_assert(std::is_same::value); + std::vector q_weights; + std::vector q_scales; + std::vector q_zp; + std::vector dequants; + onnxruntime::cuda::test::blkq4_weights_gen( + problem_size.k(), problem_size.n(), dequants, 
q_weights, q_scales, q_zp); + + using PrepackT = onnxruntime::cuda::BlockwiseQuantization< + ElementDequant, + block_size, + 4, + column_wise_blocking>; + + std::vector packed_w(q_weight_shape.product()); + PrepackT::prepack_weights(problem_size.k(), problem_size.n(), q_weights, packed_w); + std::vector packed_scales(meta_shape.product()); + PrepackT::prepack_quant_scales(problem_size.k(), problem_size.n(), q_scales, packed_scales); + std::vector packed_zp; + if constexpr (has_offsets) { + packed_zp.resize(meta_shape.product()); + PrepackT::prepack_quant_offsets(problem_size.k(), problem_size.n(), q_zp, packed_zp); + } + + cutlass::HostTensor tensor_a( + problem_size.mk()); // <- Create matrix A with dimensions M x K + cutlass::HostTensor tensor_c( + problem_size.mn()); // <- Create matrix C with dimensions M x N + cutlass::HostTensor tensor_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // CUTLASS kernel + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(4), + ElementInputA(-4), + 2); // <- Fill matrix A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(4), + ElementOutput(-4), + 0); // <- Fill matrix C on host with uniform-distribution random data + cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + + // + // Copy data from host to GPU... + // + thrust::device_vector d_packed_w(packed_w); + cutlass::TensorRef ref_W( + reinterpret_cast(d_packed_w.data().get()), + LayoutInputWPack::packed({problem_size.k()/2, problem_size.n()/2})); + + thrust::device_vector d_packed_scales(packed_scales); + cutlass::TensorRef ref_scales( + d_packed_scales.data().get(), LayoutInputQScale::packed(meta_shape)); + + thrust::device_vector d_packed_zp(packed_zp); + cutlass::TensorRef ref_zp( + d_packed_zp.data().get(), LayoutInputQScale::packed(meta_shape)); + + tensor_a.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + + // run GEMM + cutlass::Status status; + if constexpr (has_offsets){ + status = GemmRunner::run( + nullptr, problem_size, tensor_a.device_ref(), ref_W, + ref_scales, ref_zp, + tensor_c.device_ref(), tensor_d.device_ref()); + } else { + status = GemmRunner::run( + nullptr, problem_size, tensor_a.device_ref(), ref_W, + ref_scales, + tensor_c.device_ref(), tensor_d.device_ref()); + } + ORT_ENFORCE(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status)); + + // Running reference kernel + using ElementInputB = ElementInputA; + using LayoutInputB = cutlass::layout::ColumnMajor; + thrust::device_vector d_dequants(dequants); + cutlass::TensorRef ref_B( + d_dequants.data().get(), LayoutInputB::packed(problem_size.kn())); + cutlass::HostTensor tensor_ref_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // reference kernel + + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + tensor_ref_d.sync_device(); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + compute_gemm_ref( + problem_size, + alpha, + tensor_a.device_ref(), + ref_B, + beta, + tensor_c.device_ref(), + 
tensor_ref_d.device_ref()); + + // Wait for kernels to finish + cudaDeviceSynchronize(); + + // Copy output data from CUTLASS and reference kernel to host for comparison + tensor_d.sync_host(); + tensor_ref_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + ORT_ENFORCE(passed, "Gemm kernel result wrong!"); +} + +template void run_blkq4_gemm<16, true, false, true>(int m, int n, int k); +template void run_blkq4_gemm<16, true, false, false>(int m, int n, int k); +template void run_blkq4_gemm<32, true, false, true>(int m, int n, int k); +template void run_blkq4_gemm<32, true, false, false>(int m, int n, int k); +template void run_blkq4_gemm<64, true, false, true>(int m, int n, int k); +template void run_blkq4_gemm<64, true, false, false>(int m, int n, int k); +template void run_blkq4_gemm<16, false, false, true>(int m, int n, int k); +template void run_blkq4_gemm<16, false, false, false>(int m, int n, int k); +template void run_blkq4_gemm<32, false, false, true>(int m, int n, int k); +template void run_blkq4_gemm<32, false, false, false>(int m, int n, int k); +template void run_blkq4_gemm<64, false, false, true>(int m, int n, int k); +template void run_blkq4_gemm<64, false, false, false>(int m, int n, int k); +template void run_blkq4_gemm<16, true, true, true>(int m, int n, int k); +template void run_blkq4_gemm<16, true, true, false>(int m, int n, int k); +template void run_blkq4_gemm<32, true, true, true>(int m, int n, int k); +template void run_blkq4_gemm<32, true, true, false>(int m, int n, int k); +template void run_blkq4_gemm<64, true, true, true>(int m, int n, int k); +template void run_blkq4_gemm<64, true, true, false>(int m, int n, int k); +template void run_blkq4_gemm<16, false, true, true>(int m, int n, int k); +template void run_blkq4_gemm<16, false, true, false>(int m, int n, int k); +template void run_blkq4_gemm<32, false, true, true>(int m, int n, int k); +template void run_blkq4_gemm<32, false, true, false>(int m, int n, int k); +template void run_blkq4_gemm<64, false, true, true>(int m, int n, int k); +template void run_blkq4_gemm<64, false, true, false>(int m, int n, int k); + +} // namespace test +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc deleted file mode 100644 index aba2b0b2cb4a4..0000000000000 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc +++ /dev/null @@ -1,507 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include - -#include "core/framework/float16.h" -#include "core/mickey/blk_q4/prepack_sm80.h" -#include "core/mlas/inc/mlas_q4.h" - -#include "gtest/gtest.h" - -namespace onnxruntime { -namespace test { - -void prepack_weights_ref( - int rows, - int columns, - const MatrixRef& tensor_weight, - const MatrixRef& tensor_weight_prepacked) { - EXPECT_TRUE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns); - EXPECT_TRUE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2); - - auto t0_base = make_Position(0, 0); - auto t1_base = make_Position(4, 0); - auto t2_base = make_Position(0, 8); - auto t3_base = make_Position(4, 8); - for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { - for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { - // Packing from a 8x16 tile to a 16x8 tile - auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); - auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); - for (int col = 0; col < 8; ++col) { - for (int row = 0; row < 4; ++row) { - auto cord = make_Position(row, col); - auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 - uint8_t buf[4]; - buf[0] = tensor_weight.at(dtile_base + t0_base + cord); - buf[1] = tensor_weight.at(dtile_base + t1_base + cord); - buf[2] = tensor_weight.at(dtile_base + t2_base + cord); - buf[3] = tensor_weight.at(dtile_base + t3_base + cord); - - // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights - // are in different b16 register at the same positions. This makes it easier to convert to - // fp16x2 format in a b32 register - - tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); - tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); - tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); - tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); - } - } - } - } -} - -template < - typename ScaleElementT, - typename Layout, - typename QuantBlocking> -void prepack_quant_scales_ref( - int rows, - int columns, - const MatrixRef& tensor_scale, - const MatrixRef& tensor_scale_prepacked) { - EXPECT_TRUE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn)); - EXPECT_TRUE(tensor_scale_prepacked.shape() == tensor_scale.shape()); - - // Only prepacking scale and offset tensors for a often used special case: - // 16b gemm (2 elements per 32b register, operand tile shape 8x8) - // 2 B operand tiles per mma instruction stacked on k dimension - // (1,n) quantization blocking - if constexpr (sizeof(ScaleElementT) == 2 && QuantBlocking::kRow == 1) { - // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread - // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use - // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, - // as shown below (T stands for thread): - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // - // We need to deliver quantization scale and offset elements to the corresponding threads, - // so we can perform dequantization efficiently. 
With a column major layout, each thread - // needs two separate loads for a mma instruction, due to the tile fragment layout shown - // above. To reduce the number of loads, we rearrange each column as below, so we can use - // a single load to load fragments for two tiles: - // T0 T0 - // T1 T0 - // T2 T1 - // T3 => T1 - // T0 T2 - // T1 T2 - // T2 T3 - // T3 T3 - - for (int col = 0; col < tensor_scale.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { - for (int thread_id = 0; thread_id < 4; thread_id++) { - const int dst_idx = row_blk + thread_id * 4; - const int src_idx = row_blk + thread_id * 2; - tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); - tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); - tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); - tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); - } - } - } - } else { - // In all other cases, we don't prepack scale or offset - FAIL() << "Scale prepack only supported for 16b gemm with (1,n) quantization blocking"; - } -} - -template -void prepack_quant_offsets_ref( - size_t rows, - size_t columns, - MatrixRef tensor_offset, - MatrixRef tensor_offset_prepacked) { - // EXPECT_TRUE(tensor_offset.shape()[0] == (rows / QuantBlocking::kRow) && tensor_offset.shape()[1] == (columns / QuantBlocking::kColumn)); - EXPECT_TRUE(tensor_offset_prepacked.shape() == tensor_offset.shape()); - - // Only prepacking scale and offset tensors for a often used special case: - // 16b gemm (2 elements per 32b register, operand tile shape 8x8) - // 2 B operand tiles per mma instruction stacked on k dimension - // (1,n) quantization blocking - if constexpr (QuantBlocking::kRow != 1) { - FAIL() << "Offsets prepack only supported for 16b gemm with (1,n) quantization blocking"; - } - // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread - // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use - // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, - // as shown below (T stands for thread): - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // - // We need to deliver quantization scale and offset elements to the corresponding threads, - // so we can perform dequantization efficiently. With a column major layout, each thread - // needs two separate loads for a mma instruction, due to the tile fragment layout shown - // above. 
To reduce the number of loads, we rearrange each column as below, so we can use - // a single load to load fragments for two tiles: - // T0 T0 - // T1 T0 - // T2 T1 - // T3 => T1 - // T0 T2 - // T1 T2 - // T2 T3 - // T3 T3 - if (tensor_offset_prepacked.good()) { - for (int col = 0; col < tensor_offset.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_offset.shape()[0]; row_blk += 16) { - for (int thread_id = 0; thread_id < 4; thread_id++) { - const int dst_idx = row_blk + thread_id * 4; - const int src_idx = row_blk + thread_id * 2; - // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own - // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to - // convert to fp16x2 format in a b32 register - tensor_offset_prepacked.at(dst_idx + 0, col) = tensor_offset.at(src_idx + 0, col); - tensor_offset_prepacked.at(dst_idx + 1, col) = tensor_offset.at(src_idx + 8, col); - tensor_offset_prepacked.at(dst_idx + 2, col) = tensor_offset.at(src_idx + 1, col); - tensor_offset_prepacked.at(dst_idx + 3, col) = tensor_offset.at(src_idx + 9, col); - } - } - } - } -} - -template -void testPrepack(int rows, int columns, bool has_offset = true) { - using ElementT = MLFloat16; - constexpr int block_size = 32; - using Base = onnxruntime::cuda::BlockwiseQuantization< - ElementT, - block_size, - 4, - ColumnMajorQuantBlocking>; - - using QuantBlocking = typename Base::QuantBlocking; - using ElementW = typename Base::ElementW; - using LayoutWPack = typename Base::LayoutWPack; - using ElementQOffset = typename Base::ElementQOffset; - using LayoutQmeta = typename Base::LayoutQmeta; - - unsigned int seed = 28571; // Replace with desired seed value - std::seed_seq seq{seed}; - std::mt19937 gen(seq); - std::uniform_int_distribution<> dis(0, 8192); - - const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); - const auto meta_shape = Base::get_quant_meta_shape(rows, columns); - - // - // For testing quantization and dequantization, it is not straight - // forward to avoid flaky tests due to rounding errors. The way we - // try to achieve this is to: - // 1. Generate a set of quantized weights, scales and offsets - // 2. Dequantize the weights - // 3. Quantize the dequantized weights - // 4. Compare the dequantied-and-then-quantized weights with - // the original quantized weights - // - // Random filling of the initial values are key to get this right. - // For weights, we must ensure each block gets a full range of - // values, i.e. must contain 0 and 15. And for scales, they must - // all be positive. 
- // - - std::vector q_weights(q_weight_shape.product()); - MatrixRef tensor_q_weight( - q_weights, make_Position(rows / 2, columns)); - int v = 7; - for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { - for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - - tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); - } - } - - std::vector q_scales(meta_shape.product()); - for (size_t i = 0; i < q_scales.size(); i++) { - q_scales[i] = ElementT(((dis(gen) % 127) + 1) / 32.0f); - } - MatrixRef tensor_scale( - q_scales, meta_shape); - - std::vector q_zp(meta_shape.product()); - for (size_t i = 0; i < q_zp.size(); i++) { - q_zp[i] = dis(gen) % 16; - } - MatrixRef tensor_offset( - q_zp, meta_shape); - -#if 0 // debug - // Fill tensor_q_weight with the patterned data, easier to debug with print - int loop_val = 0; - int offset = 3; - for (int col_tile = 0; col_tile < tensor_q_weight.extent().column()/8; ++col_tile) { - for (int row_tile = 0; row_tile < tensor_q_weight.extent().row()/4; ++row_tile) { - for (int col = 0; col < 8; ++col) { - for (int row = 0; row < 4; ++row) { - auto weight_cord = cutlass::make_Coord(row_tile * 4 + row, col_tile * 8 + col); - auto val = (loop_val + offset) % 256; - tensor_q_weight.at(weight_cord) = ElementW(val); - loop_val++; - if (loop_val == 256) { - loop_val = 0; - offset += 11; - } - } - } - } - } - for (int col = 0; col < tensor_scale.extent().column(); ++col){ - int c = col * QuantBlocking::kColumn; - for (int row = 0; row < tensor_scale.extent().row(); ++row){ - int r = row * QuantBlocking::kRow; - auto weight_cord = cutlass::make_Coord(r/2, c); - int w = 0; - if (r % 2 == 0) { - w = int(tensor_q_weight.at(weight_cord) & 0x0f); - } else { - w = int(tensor_q_weight.at(weight_cord) >> 4); - } - tensor_scale.at({row, col}) = w; - tensor_offset.at({row, col}) = ElementQOffset(w); - } - } - - int fill_val = -512; - int factor = 1; - for (int col = 0; col < tensor_scale.extent().column(); ++col){ - for (int row = 0; row < tensor_scale.extent().row(); ++row){ - tensor_scale.at({row, col}) = ElementQScale((float)fill_val * float(factor)); - fill_val++; - if (fill_val == 512) { - fill_val = -512; - factor += 1; - } - } - } - -#endif // debug - - std::vector dequants(rows * columns); - MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); - - // Dequantize weights and save into matrix B for reference - for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { - for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { - auto weight_cord = make_Position(row / 2, col); - auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); - const uint8_t offset = has_offset ? 
tensor_offset.at(scale_cord) : 8; - int w = 0; - if (row % 2 == 0) { - w = int(tensor_q_weight.at(weight_cord) & 0x0f); - } else { - w = int(tensor_q_weight.at(weight_cord) >> 4); - } - float scale = float(tensor_scale.at(scale_cord)); - float dequant = scale * float(w - offset); - tensor_dequant.at(row, col) = ElementT(dequant); - // Prints for help debugging in case of test failure - // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); - } - } - - int q_rows, q_cols; - MlasBlockwiseQuantizedShape( - block_size, ColumnMajorQuantBlocking, rows, columns, q_rows, q_cols); - // to be exact, q_rows are padded to multiple of block_size, deal with it when we care about strange shapes - EXPECT_EQ(q_rows, q_weight_shape[0]); - EXPECT_EQ(q_cols, q_weight_shape[1]); - - // - // Quantization tool outputs: - // - std::vector o_elements(q_rows * q_cols); - MatrixRef tensor_o_elements(o_elements, q_weight_shape); - - std::vector o_scales(meta_shape.product()); - MatrixRef tensor_o_scales(o_scales, meta_shape); - - std::vector o_zp(((meta_shape[0] + 1) / 2) * meta_shape[1], true); - MatrixRef tensor_o_zp( - o_zp, make_Position((meta_shape[0] + 1) / 2, meta_shape[1])); - - MlasQuantizeBlockwise(o_elements.data(), o_scales.data(), has_offset ? o_zp.data() : nullptr, - tensor_dequant.data().data(), block_size, - ColumnMajorQuantBlocking, rows, columns, columns, nullptr); - for (int col = 0; col < tensor_q_weight.shape()[1]; ++col) { - for (int row = 0; row < tensor_q_weight.shape()[0]; ++row) { - EXPECT_EQ(tensor_o_elements.at(row, col), tensor_q_weight.at(row, col)) - << "quantized value mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - - for (int col = 0; col < meta_shape[1]; ++col) { - for (int row = 0; row < meta_shape[0]; row += 2) { - if (has_offset) { - uint8_t pair01 = tensor_o_zp.at(row / 2, col); - EXPECT_EQ(tensor_offset.at(row + 0, col), pair01 & 0xf) - << "quantized offset mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - if (row + 1 < meta_shape[0]) { - EXPECT_EQ(tensor_offset.at(row + 1, col), pair01 >> 4) - << "quantized offset mismatch at [" << row + 1 << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - - EXPECT_EQ(tensor_scale.at(row + 0, col), tensor_o_scales.at(row + 0, col)) - << "quantized scale mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - if (row + 1 < meta_shape[0]) { - EXPECT_EQ(tensor_scale.at(row + 1, col), tensor_o_scales.at(row + 1, col)) - << "quantized scale mismatch at [" << row + 1 << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - } - - // - // Now we just setup fp16 weights tensor_dequant, quantized weights tensor_q_weight, - // quantization scale tensor_scale and quantization offset tensor_offset. The above - // testing just make sure our test setup is consistent with quantization tool output. 
- // - // Next we test the prepack code - // - - std::vector packed_w_ref(q_weight_shape.product()); - MatrixRef tensor_packed_w_ref( - packed_w_ref, make_Position(rows, columns / 2)); - prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref); - - std::vector packed_w(q_weight_shape.product()); - MatrixRef tensor_packed_w( - packed_w, make_Position(rows, columns / 2)); - Base::prepack_weights(rows, columns, o_elements, packed_w); - - for (int col = 0; col < tensor_packed_w.shape()[1]; ++col) { - for (int row = 0; row < tensor_packed_w.shape()[0]; ++row) { - EXPECT_EQ(tensor_packed_w_ref.at(row, col), tensor_packed_w.at(row, col)) - << "prepacked weights mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - - std::vector packed_scales_ref(meta_shape.product()); - MatrixRef tensor_packed_s_ref = - Base::ShouldRearrangeMeta ? make_MatrixRef(packed_scales_ref, meta_shape) - : tensor_scale; - if (Base::ShouldRearrangeMeta) { - prepack_quant_scales_ref( - rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); - } - - std::vector packed_scales(meta_shape.product()); - MatrixRef tensor_packed_s( - packed_scales, meta_shape); - Base::prepack_quant_scales(rows, columns, o_scales, packed_scales); - - for (int col = 0; col < tensor_packed_s.shape()[1]; ++col) { - for (int row = 0; row < tensor_packed_s.shape()[0]; ++row) { - EXPECT_EQ(tensor_packed_s_ref.at(row, col), tensor_packed_s.at(row, col)) - << "prepacked scales mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - - if (has_offset) { - std::vector packed_zp_ref(meta_shape.product()); - MatrixRef tensor_packed_zp_ref = - Base::ShouldRearrangeMeta ? make_MatrixRef(packed_zp_ref, meta_shape) - : tensor_offset; - if (Base::ShouldRearrangeMeta) { - prepack_quant_offsets_ref( - rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); - } - - std::vector packed_zp(meta_shape.product()); - MatrixRef tensor_packed_zp( - packed_zp, meta_shape); - Base::prepack_quant_offsets(rows, columns, o_zp, packed_zp); - - for (int col = 0; col < tensor_packed_zp.shape()[1]; ++col) { - for (int row = 0; row < tensor_packed_zp.shape()[0]; ++row) { - EXPECT_EQ(tensor_packed_zp_ref.at(row, col), tensor_packed_zp.at(row, col)) - << "prepacked offsets mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? 
"Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - } -} - -// TODO: code runs on CPU, but this is for sm80 only, maybe enable only when test on sm80 -TEST(BlkQ4_GEMM, PrepackSm80Test) { - testPrepack(32, 32); - testPrepack(32, 32, false); - testPrepack(32, 32); - testPrepack(32, 32, false); - testPrepack(32, 64); - testPrepack(32, 128); - testPrepack(32, 256); - testPrepack(64, 32); - testPrepack(128, 32); - testPrepack(256, 32); - testPrepack(256, 256); - testPrepack(32, 128, false); - testPrepack(128, 32, false); - testPrepack(256, 256, false); - testPrepack(32, 64); - testPrepack(32, 128); - testPrepack(32, 256); - testPrepack(64, 32); - testPrepack(128, 32); - testPrepack(256, 32); - testPrepack(256, 256); - testPrepack(32, 128, false); - testPrepack(128, 32, false); - testPrepack(256, 256, false); -} - -} // namespace test -} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index 5505d689381c9..8dfaaedcbb378 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -29,7 +29,7 @@ TEST(TestDeferredRelease, WithArena) { AllocatorPtr cpu_pinned_alloc = ep.CreatePreferredAllocators()[1]; // let the CudaStream instance "own" the default stream, so we can avoid the // work to initialize cublas/cudnn/... It is ok since it is just a customized unit test. - CudaStream stream(nullptr, gpu_alloctor->Info().device, cpu_pinned_alloc, false, true, nullptr, nullptr); + CudaStream stream(nullptr, gpu_alloctor->Info().device, cpu_pinned_alloc, false, true, nullptr, nullptr, info); // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; @@ -71,7 +71,7 @@ TEST(TestDeferredRelease, WithoutArena) { // For details, see CUDAPinnedAllocator in cuda_allocator.cc. // let the CudaStream instance "own" the default stream, so we can avoid the // work to initialize cublas/cudnn/... It is ok since it is just a customized unit test. - CudaStream stream(nullptr, gpu_alloctor->Info().device, cuda_pinned_alloc, false, true, nullptr, nullptr); + CudaStream stream(nullptr, gpu_alloctor->Info().device, cuda_pinned_alloc, false, true, nullptr, nullptr, info); // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; From 1e78bcea6011ac43093bb08a647cf3717d73047a Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 5 Mar 2024 13:33:01 -0800 Subject: [PATCH 110/237] Implement CUDA IsInf-10,20 (#19772) ### Description Implment IsInf-10,20 for CUDA. Add FP16 types also on CPU. ### Motivation and Context Certain models lag in performance due to IsInf not available on CUDA. 
--- docs/OperatorKernels.md | 4 +- .../core/framework/data_types_internal.h | 2 +- .../core/providers/cpu/tensor/isinf.cc | 64 ++++++++++--- .../core/providers/cuda/cu_inc/common.cuh | 94 +++++++++++++++++++ onnxruntime/core/providers/cuda/cuda_common.h | 18 ++++ .../providers/cuda/cuda_execution_provider.cc | 5 + .../cuda/math/unary_elementwise_ops.cc | 38 ++++++++ .../cuda/math/unary_elementwise_ops.h | 12 +++ .../cuda/math/unary_elementwise_ops_impl.cu | 38 ++++++++ .../cuda/math/unary_elementwise_ops_impl.h | 15 +++ .../core/providers/rocm/cu_inc/common.cuh | 94 +++++++++++++++++++ .../providers/rocm/rocm_execution_provider.cc | 9 ++ .../test/providers/cpu/tensor/isinf_test.cc | 42 +++++++++ 13 files changed, 420 insertions(+), 15 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 71b0def659741..4514a85531d6b 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -160,7 +160,7 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**<br/>
*out* output:**T**|1+|**T** = tensor(float)| |InstanceNormalization|*in* input:**T**<br/>
*in* scale:**T**<br/>
*in* B:**T**<br/>
*out* output:**T**|6+|**T** = tensor(float)| -|IsInf|*in* X:**T1**<br/>
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/>
**T2** = tensor(bool)| +|IsInf|*in* X:**T1**<br/>
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/>
**T2** = tensor(bool)| |||[10, 19]|**T1** = tensor(double), tensor(float)<br/>
**T2** = tensor(bool)| |IsNaN|*in* X:**T1**<br/>
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/>
**T2** = tensor(bool)| |||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)<br/>
**T2** = tensor(bool)| @@ -631,6 +631,8 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)<br/>
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**<br/>
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |InstanceNormalization|*in* input:**T**<br/>
*in* scale:**T**<br/>
*in* B:**T**<br/>
*out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| +|IsInf|*in* X:**T1**<br/>
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/>
**T2** = tensor(bool)| +|||[10, 19]|**T1** = tensor(double), tensor(float)<br/>
**T2** = tensor(bool)| |LRN|*in* X:**T**<br/>
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |LSTM|*in* X:**T**<br/>
*in* W:**T**<br/>
*in* R:**T**<br/>
*in* B:**T**<br/>
*in* sequence_lens:**T1**<br/>
*in* initial_h:**T**<br/>
*in* initial_c:**T**<br/>
*in* P:**T**<br/>
*out* Y:**T**<br/>
*out* Y_h:**T**<br/>
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)<br/>
**T1** = tensor(int32)| diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index fbeee8a2aedc5..3a3b5cb6888f2 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -305,7 +305,7 @@ class CallableDispatchableHelper { return 0; } - void CheckCalledOnce() { + void CheckCalledOnce() const { ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_); } }; diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc index 1b449f46927a2..9d18d1fa62288 100644 --- a/onnxruntime/core/providers/cpu/tensor/isinf.cc +++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc @@ -23,7 +23,9 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( using IsInfTypesOpset20 = TypeList< float, - double + double, + MLFloat16, + BFloat16 #if !defined(DISABLE_FLOAT8_TYPES) , Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ @@ -76,10 +78,8 @@ ONNX_CPU_OPERATOR_KERNEL( IsInf); IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) { - Status status = info.GetAttr("detect_positive", &detect_positive_); - ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive"); - status = info.GetAttr("detect_negative", &detect_negative_); - ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative"); + detect_positive_ = info.GetAttrOrDefault("detect_positive", 1); + detect_negative_ = info.GetAttrOrDefault("detect_negative", 1); opset_ = info.node().SinceVersion(); } @@ -87,29 +87,67 @@ namespace isinf_internal { template struct ComputeDispatchTarget { void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { - const auto total_items = X.Shape().Size(); + auto input_data = X.DataAsSpan(); auto output_data = Y.MutableData(); if (detect_positive && detect_negative) { EigenMap(Y) = EigenMap(X).array().isInf(); } else if (detect_positive) { - auto input_data = X.Data(); - auto end_data = input_data + total_items; std::transform( - input_data, end_data, output_data, [](T v) { + input_data.begin(), input_data.end(), output_data, [](T v) { return (v == std::numeric_limits::infinity()); }); } else if (detect_negative) { - auto input_data = X.Data(); - auto end_data = input_data + total_items; std::transform( - input_data, end_data, output_data, [](T v) { + input_data.begin(), input_data.end(), output_data, [](T v) { return (v == -std::numeric_limits::infinity()); }); } else { // all false - memset(output_data, false, onnxruntime::narrow(total_items)); + memset(output_data, false, input_data.size()); + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { + auto output_data = Y.MutableData(); + auto input_data = X.DataAsSpan(); + if (detect_positive && detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsInfinity(); }); + } else if (detect_positive) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsPositiveInfinity(); }); + } else if (detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsNegativeInfinity(); }); + } else { + // all false + memset(output_data, false, input_data.size()); + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool 
detect_negative) const { + auto output_data = Y.MutableData(); + auto input_data = X.DataAsSpan(); + if (detect_positive && detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsInfinity(); }); + } else if (detect_positive) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsPositiveInfinity(); }); + } else if (detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsNegativeInfinity(); }); + } else { + // all false + memset(output_data, false, input_data.size()); } } }; diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 66794f88d8670..bba9178348132 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -438,6 +438,100 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) { return fmodf((float)a, (float)b); } +namespace isinf_details { +template +struct IsInfTyped { + static __device__ __inline__ bool IsInf(T a) { + // cast is needed because on non MS compilers, + // because there isinf() returns int + // and we want to avoid stupid warnings + return static_cast(isinf(a)); + } + static __device__ __inline__ bool IsInfPos(T a) { + return a == std::numeric_limits::infinity(); + } + static __device__ __inline__ bool IsInfNeg(T a) { + return a == -std::numeric_limits::infinity(); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(half a) { + return MLFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(half a) { + return MLFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(half a) { + return MLFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(BFloat16 a) { + return BFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template +struct ReturnFalse { + constexpr static bool __device__ __inline__ IsInf(T) { return false; } + constexpr static bool __device__ __inline__ IsInfPos(T) { return false; } + constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(Float8E5M2 a) { + return a.val == 0b01111100 || a.val == 0b11111100; + } + static __device__ __inline__ bool IsInfPos(Float8E5M2 a) { + return a.val == 0b01111100; + } + static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) { + return a.val == 0b11111100; + } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +#endif +} // namespace isinf_details + +template +struct _IsInf { + __device__ __inline__ bool operator()(T a) const { + if constexpr (detect_positive && detect_negative) { + return isinf_details::IsInfTyped::IsInf(a); + } else if constexpr (detect_positive) { + return 
isinf_details::IsInfTyped::IsInfPos(a); + } else if constexpr (detect_negative) { + return isinf_details::IsInfTyped::IsInfNeg(a); + } else { + return false; + } + } +}; + // We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. #ifndef CUDA_LONG diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 41c999bacee13..61da125b40953 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -70,6 +70,15 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef Float8E4M3FNUZ MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + template <> class ToCudaType { public: @@ -79,6 +88,15 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef Float8E5M2FNUZ MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + #endif inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 8ba282031a5d4..3c0930638a205 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -830,6 +830,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 19, IsInf); // opset 11 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress); @@ -1342,6 +1343,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -1739,6 +1741,8 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 11 BuildKernelCreateInfo, @@ -2250,6 +2254,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index fd8b69d7bd2f5..00de1b37f3302 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -71,6 +71,44 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa return Status::OK(); \ } +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + IsInf, + 
kOnnxDomain, + 10, + 19, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsInf); + +ONNX_OPERATOR_KERNEL_EX( + IsInf, + kOnnxDomain, + 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsInf); + +IsInf::IsInf(const OpKernelInfo& info) : UnaryElementwise(info) { + detect_positive_ = static_cast(info.GetAttrOrDefault("detect_positive", 1)); + detect_negative_ = static_cast(info.GetAttrOrDefault("detect_negative", 1)); + opset_ = info.node().SinceVersion(); +} + +Status IsInf::ComputeInternal(OpKernelContext* context) const { + UnaryElementwisePreparation p; + ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); + + Explicit_Impl_IsInf(Stream(context), opset_, detect_positive_, detect_negative_, + p.input_tensor->GetElementType(), p.input_tensor->DataRaw(), + p.output_tensor->MutableData(), + p.input_tensor->Shape().Size()); + return Status::OK(); +} + #define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \ UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 775b78c43a736..3b7d6df7221b7 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once + #include "core/providers/cuda/cuda_kernel.h" namespace onnxruntime { @@ -119,5 +120,16 @@ class Sign final : public UnaryElementwise { Status ComputeInternal(OpKernelContext* context) const override; }; +class IsInf final : public UnaryElementwise { + public: + explicit IsInf(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + bool detect_positive_{true}; + bool detect_negative_{true}; + int opset_; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 73c5ac80756be..fd8f7929d4426 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -11,6 +11,7 @@ #endif namespace onnxruntime { + namespace cuda { #define OP(name, expr) \ @@ -284,5 +285,42 @@ EXPLICIT_IMPL_CASTSAT(__nv_bfloat16, Float8E5M2) #endif +namespace isinf_details { +template +struct IsInf_DispFunc { + void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, + bool detect_positive, bool detect_negative, size_t count) const { + using CudaType = typename ToCudaType::MappedType; + const auto* input_data = reinterpret_cast(input_raw); + if (detect_positive && detect_negative) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } else if (detect_positive) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } else if (detect_negative) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } else { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } + } +}; + +} // namespace isinf_details + +void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, + bool detect_positive, bool detect_negative, + int32_t 
input_data_type, + const void* input_raw, bool* output_data, + size_t count) { + if (op_set < 20) { + utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, + detect_positive, detect_negative, count); + } else { + utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, + detect_positive, detect_negative, count); + } +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index 608a81a24cf4f..a606d479bc79b 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -137,5 +137,20 @@ void Impl_CastSat( #endif +// IsInf + +#if !defined(DISABLE_FLOAT8_TYPES) +#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \ + Float8E5M2FNUZ +#else +#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16 +#endif + +void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, + bool detect_positive, bool detect_negative, + int32_t input_data_type, + const void* input_raw, bool* output_data, + size_t count); } // namespace cuda + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index 5f966ac746fcb..f3685606c17f5 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -335,6 +335,100 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) { return fmodf((float)a, (float)b); } +namespace isinf_details { +template +struct IsInfTyped { + static __device__ __inline__ bool IsInf(T a) { + // cast is needed because on non MS compilers, + // because there isinf() returns int + // and we want to avoid stupid warnings + return static_cast(isinf(a)); + } + static __device__ __inline__ bool IsInfPos(T a) { + return a == std::numeric_limits::infinity(); + } + static __device__ __inline__ bool IsInfNeg(T a) { + return a == -std::numeric_limits::infinity(); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(half a) { + return MLFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(half a) { + return MLFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(half a) { + return MLFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(BFloat16 a) { + return BFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template +struct ReturnFalse { + constexpr static bool __device__ __inline__ IsInf(T) { return false; } + constexpr static bool __device__ __inline__ IsInfPos(T) { return false; } + constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped : 
ReturnFalse {}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(Float8E5M2 a) { + return a.val == 0b01111100 || a.val == 0b11111100; + } + static __device__ __inline__ bool IsInfPos(Float8E5M2 a) { + return a.val == 0b01111100; + } + static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) { + return a.val == 0b11111100; + } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +#endif +} // namespace isinf_details + +template +struct _IsInf { + __device__ __inline__ bool operator()(T a) const { + if constexpr (detect_positive && detect_negative) { + return isinf_details::IsInfTyped::IsInf(a); + } else if constexpr (detect_positive) { + return isinf_details::IsInfTyped::IsInfPos(a); + } else if constexpr (detect_negative) { + return isinf_details::IsInfTyped::IsInfNeg(a); + } else { + return false; + } + } +}; + // We would like to use 64-bit integer to support large matrices. However, ROCM seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. #ifndef HIP_LONG diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 0265c06b9a938..4a679b790ee40 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -793,6 +793,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, TopK); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, Mod); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 19, IsInf); // opset 11 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax); @@ -1342,6 +1343,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, R class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Shape); +// Opset 20 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf); + template <> KernelCreateInfo BuildKernelCreateInfo() { return {}; @@ -1738,6 +1742,8 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 11 BuildKernelCreateInfo, @@ -2294,6 +2300,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // opset 20 + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc index 2e583c5d2547b..bd97306142f18 100644 --- a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc @@ -99,6 +99,48 @@ TEST(IsInfTest, test_isinf_negative_double20) { run_is_inf_test(20, 0, 1, input, output); } +TEST(IsInfTest, test_isinf_mlfloat16) { + std::initializer_list input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16, + MLFloat16::NegativeInfinity, MLFloat16::Infinity}; + 
std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(20, 1, 1, input, output); +} + +TEST(IsInfTest, test_isinf_positive_mlfloat16) { + std::initializer_list input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16, + MLFloat16::NegativeInfinity, MLFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} + +TEST(IsInfTest, test_isinf_negative_mlfloat16) { + std::initializer_list input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16, + MLFloat16::NegativeInfinity, MLFloat16::Infinity}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(20, 0, 1, input, output); +} + +TEST(IsInfTest, test_isinf_bfloat16) { + std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, + BFloat16::NegativeInfinity, BFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(20, 1, 1, input, output); +} + +TEST(IsInfTest, test_isinf_positive_bfloat16) { + std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, + BFloat16::NegativeInfinity, BFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} + +TEST(IsInfTest, test_isinf_negative_bfloat16) { + std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, + BFloat16::NegativeInfinity, BFloat16::Infinity}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(20, 0, 1, input, output); +} + #if !defined(DISABLE_FLOAT8_TYPES) TEST(IsInfTest, test_Float8E4M3FN) { std::initializer_list input = { From d9730c7f43437070eba28d8dcdd9f94c102265ab Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Tue, 5 Mar 2024 14:39:36 -0800 Subject: [PATCH 111/237] [TensorRT EP] Fix bug for DDS output handling for empty tensor (#19575) When the DDS output is empty tensor (i.e. any of the dimension is 0), TRT EP won't perform either cudaMemcpyAsync() nor cuda::Impl_Cast(), to prevent accidentally overwriting other location that might belong to other tensors. This PR also refactors the code to only allocate single bytes for all empty tensors. 
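A minimal sketch of the binding pattern this change converges on, with assumed placeholder names (the real code does this through the CASE_* binding macros in the diff below and the EP's IAllocator scratch buffers): copies and casts are guarded by the element count, and empty tensors get a one-byte dummy allocation so every TensorRT binding still has a distinct, non-null address.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Placeholder for the EP's per-run scratch allocations (assumed name).
using ScratchBuffers = std::vector<std::unique_ptr<char[]>>;

// Returns the address to bind for an output: the real buffer when the tensor
// is non-empty, otherwise a freshly allocated dummy byte that is never written to.
void* BindOutputAddress(void* tensor_data, size_t elem_cnt, ScratchBuffers& scratch) {
  if (tensor_data != nullptr && elem_cnt > 0) {
    return tensor_data;
  }
  scratch.push_back(std::make_unique<char[]>(1));  // one dummy byte per empty tensor
  return scratch.back().get();
}

// Post-enqueue guard: skip the device-to-device copy (or cast) entirely for empty tensors.
bool ShouldCopyOutput(const void* dst, size_t elem_cnt) {
  return dst != nullptr && elem_cnt > 0;
}
```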
#TODO: add unit tests to cover the DDS code paths or doing more testing with concurrent,sequential, threaded faster-rcnn using onnx_test_runner and verifying outputs --------- Co-authored-by: Chi Lo --- cmake/deps.txt | 4 +- .../tensorrt/tensorrt_execution_provider.cc | 465 ++++++------------ 2 files changed, 160 insertions(+), 309 deletions(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index 9cba25b00157d..9630b6185fcf6 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -37,8 +37,8 @@ mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 -#use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 +#use the commit of Final DDS removal. DDS output is now supported by ORT TRT. +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/bacfaaa951653cd4e72efe727a543567cb38f7de.zip;26434329612e804164ab7baa6ae629ada56c1b26 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 81346671f2aad..157cd0a200b35 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -717,6 +717,77 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector(); \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + data = const_cast(input_tensor_ptr); \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_CAST_INPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto input_tensor_ptr = input_tensor.GetTensorData(); \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + data = scratch_buffers.back().get(); \ + cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + buffers[output_name] = output_tensor_ptr; \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define 
CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = static_cast(elem_cnt); \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = 1; \ + } \ + break; \ + } + +#define CASE_COPY_TENSOR(DATA_TYPE, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(DstT), cudaMemcpyDeviceToDevice, stream)); \ + } \ + break; \ + } + +#define CASE_CAST_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ + } \ + break; \ + } + /* * Set TensorRT execution context input. * @@ -737,6 +808,17 @@ Status BindContextInput(Ort::KernelContext& ctx, auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shapes = tensor_info.GetShape(); const auto tensor_type = tensor_info.GetElementType(); + /* + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ */ + const auto elem_cnt = tensor_info.GetElementCount(); if (trt_engine->isShapeInferenceIO(input_name)) { // Get the shape value of "shape tensor" @@ -765,113 +847,24 @@ Status BindContextInput(Ort::KernelContext& ctx, ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'")); } - // Bind "execution tensor" input buffers + + // Bind "execution tensor" input buffer + // + // Note: If an engine binding is an empty tensor, it still needs a non-null memory address, and different tensors should have different addresses. + // Therefore, in the case of empty tensor, TRT EP always allocates a dummy byte. + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#empty-tensors void* data = nullptr; switch (tensor_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint16_t))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(bool))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int8_t))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint8_t))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - data = scratch_buffers.back().get(); - } else { - SafeInt input_dim_size = 1; - for (int j = 0, end = nb_dims; j < end; ++j) { - if (tensor_shapes[j] == 0) { - input_dim_size = 1; - break; - } else { - input_dim_size *= tensor_shapes[j]; - } - } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, input_dim_size * 
sizeof(int32_t))); - data = scratch_buffers.back().get(); - cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), input_dim_size); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - // Cast DOUBLE input to FLOAT because TensorRT doesn't fully support INT64 - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - data = scratch_buffers.back().get(); - } else { - SafeInt input_dim_size = 1; - for (int j = 0, end = nb_dims; j < end; ++j) { - if (tensor_shapes[j] == 0) { - input_dim_size = 1; - break; - } else { - input_dim_size *= tensor_shapes[j]; - } - } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, input_dim_size * sizeof(float))); - data = scratch_buffers.back().get(); - cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), input_dim_size); - } - break; - } + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) + // Cast int64 input to int32 input because TensorRT doesn't support int64 + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) + // Cast double input to float because TensorRT doesn't support double + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); @@ -884,7 +877,7 @@ Status BindContextInput(Ort::KernelContext& ctx, } /* - * Set TensorRT execution context output. + * Bind TensorRT execution context output. * * Please note that the "data-depedent shape" output needs corresponding allocator provided. * @@ -912,7 +905,6 @@ Status BindContextOutput(Ort::KernelContext& ctx, size_t i, std::unordered_map& output_tensors, std::unordered_map& output_dim_sizes, - std::unordered_set& dds_output_set, DDSOutputAllocatorMap& dds_output_allocator_map, std::vector>& scratch_buffers, OrtAllocator* alloc, @@ -920,142 +912,47 @@ Status BindContextOutput(Ort::KernelContext& ctx, // Get output shape nvinfer1::Dims dims = trt_context->getTensorShape(output_name); int nb_dims = dims.nbDims; - bool is_dds_output = false; + bool is_DDS = false; std::vector output_shapes(nb_dims); for (int j = 0, end = nb_dims; j < end; ++j) { // data-dependent shape if (dims.d[j] == -1) { - is_dds_output = true; - dds_output_set.emplace(output_name); + is_DDS = true; break; } output_shapes[j] = dims.d[j]; } + auto known_DDS = dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end(); + // If the output tensor has data-dependent shape, TRT EP will provide an IOutputAllocator for enqueueV3 to dynamically allocate memory buffer. // Once enqueueV3 returns, TRT EP will then bind the output allocation to ORT kernel context output. 
// (Please note that we take strategy A mentioned in https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#dynamic-shaped-output, // which we defer allocation until the size is known and don't call IExecution::setTensorAddress) // // Otherwise, if the shape of the output tensor is known prior to the runtime, ORT will pre-allocate memory buffer for the output tensor for enqueueV3. - if (is_dds_output) { - if (dds_output_allocator_map.find(output_name) == dds_output_allocator_map.end()) { + if (is_DDS || known_DDS) { + if (!known_DDS) { auto allocatorPtr = std::make_unique(); trt_context->setOutputAllocator(output_name, allocatorPtr.get()); dds_output_allocator_map[output_name] = std::move(allocatorPtr); - } else { - trt_context->setOutputAllocator(output_name, dds_output_allocator_map[output_name].get()); } } else { output_tensors[i] = ctx.GetOutput(output_index, output_shapes); auto& output_tensor = output_tensors[i]; + const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + switch (output_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint16_t))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(bool))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int8_t))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint8_t))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - // Allocate INT32 CUDA memory for INT64 output type because TensorRT doesn't fully support INT64 - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - 
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - buffers[output_name] = scratch_buffers.back().get(); - output_dim_sizes[i] = 1; - } else { - SafeInt output_dim_size(1); - for (int j = 0, end = nb_dims; j < end; ++j) { - if (dims.d[j] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= dims.d[j]; - } - } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(int32_t))); - buffers[output_name] = scratch_buffers.back().get(); - output_dim_sizes[i] = output_dim_size; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - // Allocate FLOAT CUDA memory for DOUBLE output type because TensorRT doesn't fully support DOUBLE - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - buffers[output_name] = scratch_buffers.back().get(); - output_dim_sizes[i] = 1; - } else { - SafeInt output_dim_size(1); - for (int j = 0, end = nb_dims; j < end; ++j) { - if (dims.d[j] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= dims.d[j]; - } - } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(float))); - buffers[output_name] = scratch_buffers.back().get(); - output_dim_sizes[i] = output_dim_size; - } - break; - } + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) + // Allocate int32 CUDA memory for int64 output type because TensorRT doesn't support int64 + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) + // Allocate float CUDA memory for double output type because TensorRT doesn't support double + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported."); @@ -1068,10 +965,13 @@ Status BindContextOutput(Ort::KernelContext& ctx, } /* - * Set ORT kernel context Output. + * Bind ORT kernel context Output. * - * Note: In the case of DDS (data-dependent shape) output, TRT requires a provided allocator to allocate memory during runtime. + * In the case of DDS (data-dependent shape) output, TRT requires a provided allocator to allocate memory during runtime. * Once the output has been put in the allocation buffer, ORT calls this function to bind the allocation to ORT kernel context output. + * + * Note: Current approach of setting the ORT kernel context output is copying the output data from allocation buffer to ORT context output address which is not optimal, + * we are waiting for ORT core to support "assign" memory address to ORT context output. Some works need to be done in ORT memory planner to be aware of this memory support. 
*/ Status BindKernelOutput(Ort::KernelContext& ctx, OrtMemoryInfo* mem_info, @@ -1083,93 +983,46 @@ Status BindKernelOutput(Ort::KernelContext& ctx, auto allocator = allocator_map[output_name].get(); auto& shape = allocator->getOutputShape(); auto output_tensor = ctx.GetOutput(output_index, shape); + + /* + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ */ auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + /* + * Copy output data from allocation buffer to ORT kernel context output location or + * cast (int32 or float) -> (int64 or double) to ORT kernel context output location. + * + * Note: + * 1. If the output tensor is empty tensor (i.e. any of the dimension is 0) which means element count is 0, + * TRT EP does not perform cuda memory copy nor cuda cast to prevent overwriting other location that might belong to other tensors. + * 2. The cudaMemcpyAsync() and cuda::Impl_Cast() (implemented as _UnaryElementWise() in cuda ep) are all async, but we + * don't need to explicitly call cudaStreamSynchronize() after those APIs due to CUDA EP and TRT EP uses same stream, + * and within the same stream, operations are guaranteed to be executed in order. + */ switch (output_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(float), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(uint16_t), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(bool), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(int8_t), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(uint8_t), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - // The allocation buffer holds the INT32 output data since TRT doesn't support INT64 but INT32. - // So, we need to cast the data from INT32 to INT64 and then set INT64 output data to kernel context. - SafeInt output_dim_size(1); - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= shape[i]; - } - } - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), output_dim_size); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - // The allocation buffer holds the FLOAT output data since TRT doesn't support DOUBLE but FLOAT. 
- // So, we need to cast the data from FLOAT to DOUBEL and then set DOUBLE output data to kernel context. - SafeInt output_dim_size(1); - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= shape[i]; - } - } - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), output_dim_size); - } - break; - } + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) + // The allocation buffer holds the int32 output data since TRT doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output. + CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int32_t, int64_t) + // The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output. + CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); return Status::OK(); } @@ -3513,7 +3366,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView output_tensors.reserve(num_outputs); std::unordered_map output_dim_sizes; output_dim_sizes.reserve(num_outputs); - std::unordered_set dds_output_set; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -3531,7 +3383,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_set, dds_output_allocator_map, scratch_buffers, alloc, buffers); + dds_output_allocator_map, scratch_buffers, alloc, buffers); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } @@ -3590,7 +3442,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView output_type = iter->second; } - if (dds_output_set.find(output_name) != dds_output_set.end()) { + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { size_t output_index = 0; const auto& index_iter = output_indexes.find(output_name); if (index_iter != output_indexes.end()) { @@ -3806,7 +3658,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con output_tensors.reserve(num_outputs); std::unordered_map output_dim_sizes; output_dim_sizes.reserve(num_outputs); - std::unordered_set dds_output_set; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -3824,7 +3675,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con } Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_set, dds_output_allocator_map, scratch_buffers, alloc, buffers); + 
dds_output_allocator_map, scratch_buffers, alloc, buffers); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } @@ -3883,7 +3734,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con output_type = iter->second; } - if (dds_output_set.find(output_name) != dds_output_set.end()) { + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { size_t output_index = 0; const auto& index_iter = output_indexes.find(output_name); if (index_iter != output_indexes.end()) { From d10256975527e8e041cedb19227cb5f207087c42 Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 6 Mar 2024 10:06:25 +0800 Subject: [PATCH 112/237] Fix seed for recomputed Dropout (#19715) ### Fix seed for recomputed Dropout If Dropout node is recomputed in the backward, we should make sure its execution is same as the run in the forward. If we don't set seed attribute, then this cannot be guaranteed. Add ` export ORTMODULE_MEMORY_OPT_LEVEL=2` to enabled per layer recompute with compromised recomputable subgraphs. --- docs/Memory_Optimizer.md | 1 + docs/ORTModule_Training_Guidelines.md | 5 ++- onnxruntime/core/common/string_utils.h | 12 +++++++ .../memory_optimizer/memory_insight.cc | 6 +++- .../memory_optimizer/memory_optimizer.cc | 34 +++++++++++++++++-- .../memory_optimizer/memory_optimizer.h | 1 + .../ortmodule/_graph_execution_manager.py | 7 +++- .../training/ortmodule/_runtime_inspector.py | 34 ++++++++++++++----- .../python/training/ortmodule/options.py | 18 ++++++++-- 9 files changed, 101 insertions(+), 17 deletions(-) diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md index 97f7e7ff2c14b..eaa48c9da0609 100644 --- a/docs/Memory_Optimizer.md +++ b/docs/Memory_Optimizer.md @@ -51,6 +51,7 @@ There are two modes to enable the memory optimizations: - Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1 ``` 3. As shown above, `Config` is a string representative for a re-computable subgraph. All are enabled for recompute in this case. +4. By `export ORTMODULE_MEMORY_OPT_LEVEL=2`, all plans including compromised recomptable subgraphs will also be enabled. ### Mode 2 - Advanced Usage (User Selected Subgraph Recompute) diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 84631bd1f6555..54137937ad56d 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -287,7 +287,10 @@ A classical usage of disabling the deep copy: when the deep copy before module e #### ORTMODULE_MEMORY_OPT_LEVEL - **Feature Area**: *ORTMODULE/Optimizations* -- **Description**: By default, the level is 0. This env var can be used for enabling recomputation for reducing memory peak requirement. Setting the level to be 0 means all detected subgraphs with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. When level is not 0, check Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details. +- **Description**: By default, the level is 0. This env var can be used for enabling recomputation for reducing memory peak requirement. + - Setting the level to be 1 means all detected recomputable subgraphs (NOT including compromised recomputable graphs) with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. 
+ - Setting the level to be 2 means all detected recomputable subgraphs (including compromised recomputable graphs) with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. + - When the level is 0, check Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details. ```bash export ORTMODULE_MEMORY_OPT_LEVEL=0 diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h index eca1221e84cb8..03e94cefd0564 100644 --- a/onnxruntime/core/common/string_utils.h +++ b/onnxruntime/core/common/string_utils.h @@ -65,5 +65,17 @@ inline std::string TrimString(std::string s) { return s; } +/** + * So use this simple hash to generate unique int by given string input. + */ +inline uint32_t GetHashFromString(const std::string& str_value) { + uint32_t hash = 0; + for (char const& c : str_value) { + hash = hash * 101 + c; + } + + return hash; +} + } // namespace utils } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 3fbdd5da7b768..08c402bf669c8 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -9,6 +9,8 @@ #include #include +#include "core/common/string_utils.h" +#include "core/framework/random_seed.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "orttraining/core/optimizer/memory_optimizer/common.h" @@ -284,7 +286,9 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, memory_opt_planner.AddNodeOptimizationPlan(p_node, std::move(recompute_plan)); } - if (can_compromise_stashed_activation) { + // Only detect compromise recompute when recompute is not found, in case there are multiple recompute plans + // for the same named activations, then user might enable those conflicting recompute plans by mistakes. + if (recompute_plan == nullptr && can_compromise_stashed_activation) { MO_LOG_DEBUG_INFO(logger, "Searching Node " + p_node->Name() + "(" + p_node->OpType() + ") for compromised recompute"); // If the subgraph recompute can save memory by comprising the assumption - recompute graphs' input must exist diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index 49e026ca86bd3..525e3b4b8de35 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -28,6 +28,29 @@ constexpr bool IsForwardPassOperator(ptrdiff_t op_order_in_topological_sort, return op_order_in_topological_sort <= boundary_op_order_in_topological_sort; } +// Reset seed attribute for the dropout node if the seed is not set. +bool SetSeedForDropoutNode(Node& node) { + // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support. + // TODO(pengwa): add the opset check in GetAllowedRecomputeOps. 
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {12, 13}, kOnnxDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "BitmaskDropout", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "BiasDropout", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "BitmaskBiasDropout", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "BiasSoftmaxDropout", {1}, kMSDomain)) { + auto& attrs = node.GetAttributes(); + if (attrs.count("seed")) { + return false; + } + + int64_t seed = static_cast(utils::GetHashFromString(node.OutputDefs()[0]->Name())) + + utils::GetRandomSeed(); + node.AddAttribute("seed", seed); + return true; + } + + return false; +} + } // namespace Status MemoryOptimizer::ParseOptimizationConfigFromString(const std::string& memory_optimizer_config, @@ -74,7 +97,7 @@ bool MemoryOptimizer::ModifyGraph(Graph& graph, optimizer::memory_optimizer::NodeRecomputePlan* recompute_plan = dynamic_cast(node_plan.get()); ORT_ENFORCE(recompute_plan != nullptr); - ORT_ENFORCE(CreateRecomputeGraph(graph, recompute_plan->GetNodesInTopoOrder(), replacement_node_ptr).IsOK()); + ORT_ENFORCE(CreateRecomputeGraph(graph, recompute_plan->GetNodesInTopoOrder(), logger, replacement_node_ptr).IsOK()); } else { ORT_THROW("unsupported optimization type found."); } @@ -93,7 +116,7 @@ bool MemoryOptimizer::ModifyGraph(Graph& graph, auto tid = node_index_to_its_order_in_topological_sort_map.find(it->GetNode().Index()); // It is possible the consumer node is newly added as the recompute node, so we need a check here. - // For those kind of ops, we can treat them as backward ops. + // For those kinds of ops, we can treat them as backward ops. if (tid == node_index_to_its_order_in_topological_sort_map.end() || !IsForwardPassOperator(node_index_to_its_order_in_topological_sort_map.at(tid->first), boundary_op_order_in_topological_sort)) { @@ -223,6 +246,7 @@ void MemoryOptimizer::PrintSummary(const optimizer::memory_optimizer::MemoryOpti Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph, const InlinedVector& nodes_in_topological_order, + const logging::Logger& logger, Node*& new_output_node_ptr) const { InlinedHashMap self_contained_outputs_map; for (size_t i = 0; i < nodes_in_topological_order.size(); ++i) { @@ -236,6 +260,12 @@ Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph, continue; } + bool seed_reset = SetSeedForDropoutNode(*node_to_duplicate); + if (seed_reset) { + LOGS(logger, VERBOSE) << "Set seed for Node " << node_to_duplicate->Name() << "(" << node_to_duplicate->OpType() + << ")."; + } + InlinedVector new_input_args; new_input_args.reserve(node_to_duplicate->MutableInputDefs().size()); for (NodeArg* input_arg : node_to_duplicate->MutableInputDefs()) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h index b3e05fd334e48..1d837038e76c1 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h @@ -94,6 +94,7 @@ class MemoryOptimizer : public GraphTransformer { */ Status CreateRecomputeGraph(Graph& graph, const InlinedVector& nodes_in_topological_order, + const logging::Logger& logger, Node*& recompute_subgraph_output_node) const; /************************************************** diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py 
b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index e189ffff9cc7f..c67b05758c5aa 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -754,6 +754,11 @@ def _add_record(tbl, columns): if self._runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: opt_config_to_display = "ALL_RECOMPUTE_FOR_EACH_LAYER" + elif ( + self._runtime_options.memory_optimization_level + == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE + ): + opt_config_to_display = "ALL_RECOMPUTE_FOR_EACH_LAYER_WITH_COMPROMISE" else: opt_config_to_display = self._runtime_options.memory_optimizer_config @@ -766,7 +771,7 @@ def _add_record(tbl, columns): f"Memory Optimization Level: [{_MemoryOptimizationLevel.to_string(self._runtime_options.memory_optimization_level)}], " f"Optimization Config: [{opt_config_to_display}]" if len(self._runtime_options.memory_optimizer_config) > 0 - else "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=,,..." + else "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1/2 or ORTMODULE_MEMORY_OPT_CONFIG=,,..." ), ], ) diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index 772b9bd9e31ae..22e31466887a6 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -545,7 +545,10 @@ def find_memory_optimization_opportunity(self, execution_agent: TrainingAgent, r # If the memory optimization level is aggressive, we will first collect all # recompute subgraph by passing empty memory_optimizer_config to get_serialized_ortmodule_memory_stat. - if runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: + if runtime_options.memory_optimization_level in [ + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, + ]: memory_optimizer_config = "" ( @@ -581,16 +584,27 @@ def find_memory_optimization_opportunity(self, execution_agent: TrainingAgent, r self.cluster_id_combination_to_saving_symbolics_map[cluster_id] = values # For aggressive memory optimization, we update the memory_optimizer_config using all. - if runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: + if runtime_options.memory_optimization_level > 0: recompute_configs = [] for cluster_id in self.cluster_id_combination_to_saving_symbolics_map: config_values = cluster_id.split(":") opt_type = int(config_values[1]) - # TODO(pengwa): use enum instead of 1 here. 
- if opt_type != 1: - continue - - recompute_configs.append(cluster_id) + if ( + runtime_options.memory_optimization_level + == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE + and opt_type == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE + ): + recompute_configs.append(cluster_id) + elif ( + runtime_options.memory_optimization_level + == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE + and opt_type + in [ + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, + ] + ): + recompute_configs.append(cluster_id) runtime_options.memory_optimizer_config = ",".join(recompute_configs) @@ -699,14 +713,16 @@ def _get_user_config_without_freq(configs: str): notes = [] if details: notes.append( - "[Memory Optimizer] Use ORTMODULE_MEMORY_OPT_LEVEL=1 to enable all recomputable subgraphs per transformer layer." + "[Memory Optimizer] Use ORTMODULE_MEMORY_OPT_LEVEL=1/2 to enable all recomputable subgraphs per transformer layer." ) saving_recommendation = "[Memory Optimizer] Or use comma as a delimiter to selectively enable multiple memory optimization plans:\n" saving_recommendation += " export ORTMODULE_MEMORY_OPT_CONFIG=,,..." notes.append(saving_recommendation) - saving_recommendation = "memory saving is calculated based on the 1st batch symbolic dim values:\n" + saving_recommendation = ( + "[Memory Optimizer] memory saving is calculated based on the 1st batch symbolic dim values:\n" + ) for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items(): saving_recommendation += f" {dim_param}={dim_value}," notes.append(saving_recommendation) diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 93d24a34df6bd..7263a5719e262 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -196,7 +196,10 @@ class _MemoryOptimizationLevel(IntFlag): """Enumeration to specify memory optimization level""" USER_SPECIFIED = 0 # Fully respect user-specified config - TRANSFORMER_LAYERWISE_RECOMPUTE = 1 # Enable all recomputable subgraphs per layer + TRANSFORMER_LAYERWISE_RECOMPUTE = ( + 1 # Enable all recomputable subgraphs (excluding compromised recomptable graphs) per layer + ) + TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE = 2 # Enable all recomputable subgraphs per layer @staticmethod def to_string(memory_optimization_level): @@ -206,6 +209,9 @@ def to_string(memory_optimization_level): if memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: return "TRANSFORMER_LAYERWISE_RECOMPUTE" + if memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE: + return "TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE" + return "" @@ -344,7 +350,10 @@ def _override_from_env_vars(self): self.memory_optimization_level = int(os.getenv("ORTMODULE_MEMORY_OPT_LEVEL", self.memory_optimization_level)) user_given_memory_optimizer_config = os.getenv("ORTMODULE_MEMORY_OPT_CONFIG", self.memory_optimizer_config) self.memory_optimizer_config = ",".join([c for c in user_given_memory_optimizer_config.split(",") if c]) - if self.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: + if self.memory_optimization_level in [ + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, + 
_MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, + ]: # For transformer layer-wise recompute, we enable layer boundary when detecting subgraphs. # Then all detected subgraphs will not cross different layers. self.recompute_probe_config = "1:1" @@ -419,7 +428,10 @@ def memory_optimizer_is_enabled(self) -> bool: """Check whether memory optimizer is enabled.""" if self.memory_optimization_level == _MemoryOptimizationLevel.USER_SPECIFIED: return len(self.memory_optimizer_config) > 0 - elif self.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: + elif self.memory_optimization_level in [ + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, + ]: return True return False From 1bfc26685b51522395e136a606005a72997e6bff Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 6 Mar 2024 10:11:46 +0800 Subject: [PATCH 113/237] ATen Op Supports Int Return Type and CPU Tensor Arguments (#19773) This PR: - add support for int as return type, will create a CPU scalar tensor for it. - add attributes to specify which arguments or returns are CPU tensors. - adjust ATen efficient attn to match latest PyTorch native function. - a Triton codegen bugfix by the way. --- .../cpu/aten_ops/aten_op_executor.h | 16 +- onnxruntime/core/framework/utils.cc | 24 ++- .../core/graph/contrib_ops/contrib_defs.cc | 2 + .../python/onnxruntime_pybind_state.cc | 10 +- .../aten_op_executor/__init__.py | 2 +- .../aten_op_executor/aten_op_executor.cc | 62 ++++--- .../ort_torch_ext/__init__.py | 4 +- .../python/training/ort_triton/_ir.py | 3 + .../ortmodule/graph_optimizers/__init__.py | 2 +- .../ortmodule/graph_optimizers/_aten_attn.py | 169 +++--------------- 10 files changed, 96 insertions(+), 198 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h index d72868cd8fa9f..56c8e2911e280 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace aten_ops { -typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); +typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size, DLManagedTensor** dlpack_inputs, size_t output_size, DLManagedTensor** dlpack_outputs); @@ -22,17 +22,17 @@ class ATenOperatorExecutor { return instance; } - void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) { - ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw); - p_is_cpu_argument_func_ = reinterpret_cast(p_is_cpu_argument_func_raw); + void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { + ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw); + p_is_tensor_argument_func_ = reinterpret_cast(p_is_tensor_argument_func_raw); p_execute_aten_op_func_ = reinterpret_cast(p_execute_aten_op_func_raw); } bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; } - bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { - ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized."); - return 
p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); + bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { + ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized."); + return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); } void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size, @@ -43,7 +43,7 @@ class ATenOperatorExecutor { } private: - IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr; + IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr; ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr; }; diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 23fe5e1cd3d96..b737d735b977b 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -1015,9 +1015,19 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) } #ifdef ENABLE_ATEN + // For ATen node, we assume that all tensor inputs are on device, all non-tensor inputs are on CPU, + // except those specified in attribute cpu_input_args; if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && node.Domain() == kPytorchAtenDomain) { const auto& attrs = node.GetAttributes(); + if (auto entry = attrs.find("cpu_input_args"); entry != attrs.end()) { + const auto& attr = entry->second; + if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(), + [index](int64_t arg) { return static_cast(index) == arg; })) { + return true; + } + } + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); std::string op_name = attrs.at("operator").s(); std::string overload_name = ""; @@ -1025,7 +1035,7 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) overload_name = attrs.at("overload_name").s(); } - return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true); + return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, true); } #else ORT_UNUSED_PARAMETER(node); @@ -1040,9 +1050,19 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index } #ifdef ENABLE_ATEN + // For ATen node, we assume that all tensor outputs are on device, all non-tensor outputs are on CPU, + // except those specified in attribute cpu_output_args; if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && node.Domain() == kPytorchAtenDomain) { const auto& attrs = node.GetAttributes(); + if (auto entry = attrs.find("cpu_output_args"); entry != attrs.end()) { + const auto& attr = entry->second; + if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(), + [index](int64_t arg) { return static_cast(index) == arg; })) { + return true; + } + } + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); std::string op_name = attrs.at("operator").s(); std::string overload_name = ""; @@ -1050,7 +1070,7 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index overload_name = attrs.at("overload_name").s(); } - return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false); + return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, false); } #else ORT_UNUSED_PARAMETER(node); diff --git 
a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index f06a3785f362d..6709398c788f0 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3474,6 +3474,8 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 /*min_arity*/ 1) .Attr("operator", "Name of ATen operator.", AttributeProto::STRING) .Attr("overload_name", "Overload name of ATen operator.", AttributeProto::STRING, false) + .Attr("cpu_input_args", "CPU input argument indices.", AttributeProto::INTS, false) + .Attr("cpu_output_args", "CPU output argument indices.", AttributeProto::INTS, false) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor."); #endif diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 9c36eb635ffcf..e5e0e81cb7da8 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1327,14 +1327,14 @@ void addGlobalMethods(py::module& m) { #ifdef ENABLE_ATEN m.def("register_aten_op_executor", - [](const std::string& is_cpu_argument_address_str, const std::string& aten_op_executor_address_str) -> void { - size_t is_cpu_argument_address_int, aten_op_executor_address_int; + [](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void { + size_t is_tensor_argument_address_int, aten_op_executor_address_int; ORT_THROW_IF_ERROR( - ParseStringWithClassicLocale(is_cpu_argument_address_str, is_cpu_argument_address_int)); + ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int)); ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int)); - void* p_is_cpu_argument = reinterpret_cast(is_cpu_argument_address_int); + void* p_is_tensor_argument = reinterpret_cast(is_tensor_argument_address_int); void* p_aten_op_executor = reinterpret_cast(aten_op_executor_address_int); - contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_cpu_argument, p_aten_op_executor); + contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor); }); #endif } diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py index 8bf7cbf80eb37..9dee6564509d5 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py @@ -29,5 +29,5 @@ def load_aten_op_executor_cpp_extension(): from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor _C.register_aten_op_executor( - str(aten_op_executor.is_cpu_argument_address()), str(aten_op_executor.execute_aten_operator_address()) + str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address()) ) diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc index 903a394a06ef3..e8be98cbfc0e4 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc @@ -34,18 +34,23 @@ struct ATenOperator { std::vector is_optional_arguments; std::vector> 
default_values; size_t return_size; + std::vector ret_kinds; c10::IValue ToIValueArgument(const DLManagedTensor* dlpack, size_t index) const { TORCH_INTERNAL_ASSERT(index < argument_size); bool is_optional = is_optional_arguments[index]; - TORCH_INTERNAL_ASSERT(dlpack || is_optional || default_values[index]); + TORCH_INTERNAL_ASSERT(dlpack || is_optional || default_values[index] || + elem_kinds[index] == c10::TypeKind::TensorType); if (!dlpack) { if (is_optional) { // Optional argument always has no default value. return c10::IValue(c10::nullopt); } - - return *default_values[index]; + if (default_values[index]) { + return *default_values[index]; + } + // Fow bw func, it's possible that input is an undefined tensor from fw outputs, dlpack is nullptr for such case. + return c10::IValue(at::Tensor()); } bool is_list = is_list_arguments[index]; @@ -142,7 +147,10 @@ class ATenOperatorCache { } aten_op.return_size = schema.returns().size(); for (const auto& ret : schema.returns()) { - TORCH_INTERNAL_ASSERT(ret.type()->kind() == c10::TypeKind::TensorType); + c10::TypeKind ret_type = ret.type()->kind(); + // Support tensor or int only for now. + TORCH_INTERNAL_ASSERT(ret_type == c10::TypeKind::TensorType || ret_type == c10::TypeKind::IntType); + aten_op.ret_kinds.emplace_back(ret_type); } ops_.emplace(key, aten_op); } @@ -154,32 +162,15 @@ class ATenOperatorCache { std::unordered_map, ATenOperator, PairHash> ops_; }; -const std::unordered_map> kCpuTensorInputsMap = { - {"_efficient_attention_forward", {4, 5, 11, 12}}, {"_efficient_attention_backward", {6, 7, 12, 13}}}; - -const std::unordered_map> kCpuTensorOutputsMap = { - {"_efficient_attention_forward", {2, 3}}}; - -// Backend uses this function to check if an argument is CPU input or not. -bool IsCpuArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) { +// Backend uses this function to check if an argument is tensor type or not. +bool IsTensorArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) { + const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); if (is_input) { - // If the argument is non-tensor type, it's CPU argument. - const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); TORCH_INTERNAL_ASSERT(index < aten_op.argument_size); - if (aten_op.elem_kinds[index] != c10::TypeKind::TensorType) { - return true; - } - } - - std::string full_name = std::string(op_name); - std::string overload_name_str = std::string(overload_name); - if (overload_name_str != "") { - full_name += ("." + overload_name_str); + return aten_op.elem_kinds[index] == c10::TypeKind::TensorType; } - - const auto& cpu_tensors_map = is_input ? kCpuTensorInputsMap : kCpuTensorOutputsMap; - return cpu_tensors_map.find(full_name) != cpu_tensors_map.end() && - cpu_tensors_map.at(full_name).find(index) != cpu_tensors_map.at(full_name).end(); + TORCH_INTERNAL_ASSERT(index < aten_op.return_size); + return aten_op.ret_kinds[index] == c10::TypeKind::TensorType; } void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size, @@ -216,16 +207,23 @@ void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t TORCH_INTERNAL_ASSERT(output_size == aten_op.return_size); size_t output_index = 0; for (const auto& ret : torch::jit::pop(stack, output_size)) { - const auto& tensor = ret.toTensor(); - dlpack_outputs[output_index++] = - tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? 
tensor : tensor.contiguous()) : nullptr; + if (ret.isTensor()) { + const auto& tensor = ret.toTensor(); + dlpack_outputs[output_index++] = + tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr; + } else if (ret.isInt()) { + at::Tensor scalar = at::scalar_to_tensor(at::Scalar(ret.toInt())); + dlpack_outputs[output_index++] = at::toDLPack(scalar); + } else { + TORCH_INTERNAL_ASSERT(false); + } } } -size_t is_cpu_argument_address() { return reinterpret_cast(&IsCpuArgument); } +size_t is_tensor_argument_address() { return reinterpret_cast(&IsTensorArgument); } size_t execute_aten_operator_address() { return reinterpret_cast(&ExecuteATenOperator); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("is_cpu_argument_address", &is_cpu_argument_address, "Address of tensor argument check."); + m.def("is_tensor_argument_address", &is_tensor_argument_address, "Address of tensor argument check."); m.def("execute_aten_operator_address", &execute_aten_operator_address, "Address of Aten operator executor"); } diff --git a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py index 329fba5aa670a..7d5716b85db30 100644 --- a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py @@ -5,7 +5,7 @@ from onnxruntime.capi import _pybind_state as _C -from .aten_op_executor import execute_aten_operator_address, is_cpu_argument_address +from .aten_op_executor import execute_aten_operator_address, is_tensor_argument_address def run_once_aten_op_executor(f): @@ -30,7 +30,7 @@ def aten_op_executor_wrapper(*args, **kwargs): @run_once_aten_op_executor def load_aten_op_executor_cpp_extension(): - _C.register_aten_op_executor(str(is_cpu_argument_address()), str(execute_aten_operator_address())) + _C.register_aten_op_executor(str(is_tensor_argument_address()), str(execute_aten_operator_address())) def init_aten_op_executor(): diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py index a2b8407645c46..a963d30a9e6e7 100644 --- a/orttraining/orttraining/python/training/ort_triton/_ir.py +++ b/orttraining/orttraining/python/training/ort_triton/_ir.py @@ -392,5 +392,8 @@ def __init__( for ir_node in kernel.sub_nodes: if isinstance(ir_node, DropoutNode): ir_node.global_offset = running_offset + kernel.offset_calc.symbolic_shape_variables.update( + [symbol.name for symbol in running_offset.free_symbols] + ) running_offset = running_offset + sympy.prod(ir_node.outputs[0].shape) self.has_dropout = True diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py index 3d3538a62da61..368d1b238fd9e 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py @@ -13,7 +13,7 @@ if ( "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1 - and Version(torch.__version__) >= Version("2.1.1") + and Version(torch.__version__) >= Version("2.3.0") ): from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401 diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py 
b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py index b1e8809f03fc0..c1fb6e68568f5 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py @@ -5,9 +5,12 @@ """ PyTorch's _efficient_attention_forward/_efficient_attention_backward APIs is keep changing. Current implementation -is tested well on version 2.2.0.dev20231010+cu121, and should be run well since official version 2.2.0. If may fail to +is tested well on version 2.3.0.dev20240221+cu118, and should be run well since official version 2.3.0. If may fail to run is you are using PyTorch with older versions. +This file is more like an example of how to add a new graph optimizer. Ideally user can add graph optimizer according +to the specific model they are using on their own instead of putting every possible graph optimizer here. + PyTorch also has API for flash attention (currently doesn't support random attention mask or Dropout), we can add support if we want to try in the future. """ @@ -40,13 +43,14 @@ def _make_efficient_attention_nodes( scale_node = make_constant_node("scale_" + str(idx), TensorProto.FLOAT, [], [scale]) dropout_ratio_node = make_constant_node("dropout_ratio_" + str(idx), TensorProto.FLOAT, [], [dropout_ratio]) causal_node = make_constant_node("causal_" + str(idx), TensorProto.INT64, [], [1 if causal else 0]) - int_zero_node = make_constant_node("int_zero_" + str(idx), TensorProto.INT64, [], [0]) - true_node = make_constant_node("true_" + str(idx), TensorProto.BOOL, [], [True]) - false_node = make_constant_node("false_" + str(idx), TensorProto.BOOL, [], [False]) + one_node = make_constant_node("one_" + str(idx), TensorProto.INT64, [], [1]) + zero_node = make_constant_node("zero_" + str(idx), TensorProto.INT64, [], [0]) logsumexp = helper.make_tensor_value_info("logsumexp" + str(idx), TensorProto.FLOAT, []) seed = helper.make_tensor_value_info("seed" + str(idx), TensorProto.INT64, []) offset = helper.make_tensor_value_info("offset" + str(idx), TensorProto.INT64, []) - new_value_infos = [logsumexp, seed, offset] + msb_q = helper.make_tensor_value_info("msb_q_" + str(idx), TensorProto.INT64, []) + msb_k = helper.make_tensor_value_info("msb_k_" + str(idx), TensorProto.INT64, []) + new_value_infos = [logsumexp, seed, offset, msb_q, msb_k] if expand_bias: shape_0 = helper.make_node("Shape", [q], ["shape_0_" + str(idx)], start=0, end=1) shape_1 = helper.make_node("Shape", [q], ["shape_1_" + str(idx)], start=2, end=3) @@ -54,13 +58,13 @@ def _make_efficient_attention_nodes( shape_3 = helper.make_node("Shape", [k], ["shape_3_" + str(idx)], start=1, end=2) concat = helper.make_node( "Concat", - ["shape_0_" + str(idx), "shape_1_" + str(idx), "shape_2_" + str(idx), "shape_3_" + str(idx)], + [shape_0.output[0], shape_1.output[0], shape_2.output[0], shape_3.output[0]], ["concated_shape_" + str(idx)], axis=0, ) - expand = helper.make_node("Expand", [bias, "concated_shape_" + str(idx)], ["expanded_bias_" + str(idx)]) + expand = helper.make_node("Expand", [bias, concat.output[0]], ["expanded_bias_" + str(idx)]) nodes_to_add.extend([shape_0, shape_1, shape_2, shape_3, concat, expand]) - bias = "expanded_bias_" + str(idx) + bias = expand.output[0] fwd_node = helper.make_node( "ATen", [ @@ -71,18 +75,21 @@ def _make_efficient_attention_nodes( "", "", "", + "", dropout_ratio_node.output[0], causal_node.output[0], - true_node.output[0], + one_node.output[0], scale_node.output[0], 
"", "", ], - [y, logsumexp.name, seed.name, offset.name], + [y, logsumexp.name, seed.name, offset.name, msb_q.name, msb_k.name], "efficient_attention_forward_" + str(idx), None, "org.pytorch.aten", operator="_efficient_attention_forward", + cpu_input_args=[4, 5, 12, 13], + cpu_output_args=[2, 3, 4, 5], ) bwd_node = helper.make_node( "ATen", @@ -95,14 +102,14 @@ def _make_efficient_attention_nodes( y, "", "", - int_zero_node.output[0], - int_zero_node.output[0], + msb_q.name, + msb_k.name, logsumexp.name, dropout_ratio_node.output[0], seed.name, offset.name, causal_node.output[0], - false_node.output[0], + zero_node.output[0], scale_node.output[0], "", ], @@ -111,10 +118,9 @@ def _make_efficient_attention_nodes( None, "org.pytorch.aten", operator="_efficient_attention_backward", + cpu_input_args=[6, 7, 12, 13], ) - nodes_to_add.extend( - [scale_node, dropout_ratio_node, causal_node, int_zero_node, true_node, false_node, fwd_node, bwd_node] - ) + nodes_to_add.extend([scale_node, dropout_ratio_node, causal_node, one_node, zero_node, fwd_node, bwd_node]) return nodes_to_add, new_value_infos @@ -240,140 +246,9 @@ def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodePro return nodes, nodes_to_add, new_value_infos -# No causal mask, no attention mask, without Dropout. -_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ - ("MatMul", False, []), # 0 - ("Mul", True, [(0, 0, 0)]), # 1 - ("Mul", True, [(0, 0, 1)]), # 2 - ("Transpose", True, [(1, 0, 0)]), # 3 - ("Transpose", True, [(2, 0, 0)]), # 4 - ("Softmax", False, [(0, 0, 0)]), # 5 - ("MatMul", False, [(5, 0, 0)]), # 6 - ("Transpose", True, [(6, 0, 1)]), # 7 - ("Transpose", False, [(6, 0, 0)]), # 8 - ("FusedMatMul", False, [(7, 0, 1)]), # 9 - ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10 - ("FusedMatMul", False, [(2, 0, 1), (10, 0, 0)]), # 11 - ("FusedMatMul", False, [(1, 0, 0), (10, 0, 1)]), # 12 - ("Mul", False, [(11, 0, 0)]), # 13 - ("Mul", False, [(12, 0, 0)]), # 14 - ("Identity", False, [(13, 0, 0)]), # 15 - ("Identity", False, [(14, 0, 0)]), # 16 - ("Transpose", False, [(15, 0, 0)]), # 17 - ("Transpose", False, [(16, 0, 0)]), # 18 - ("FusedMatMul", False, [(5, 0, 0)]), # 19 - ("Transpose", True, [(19, 0, 1)]), # 20 - ("Transpose", False, [(19, 0, 0)]), # 21 -] - - -def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): - # Check forward only as the backward is expected to be consistent if it's built correctly. - scale_value_1 = matcher.get_constant_value(nodes[1].input[1]) - scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1 - scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) - scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 - if not ( - check_attribute_value(nodes[3], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1]) - and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) - and scale_value_1 == scale_value_2 - ): - return [], [], [] - - nodes_to_add, new_value_infos = _make_efficient_attention_nodes( - idx, - nodes[3].input[0], - nodes[4].input[0], - nodes[7].input[0], - nodes[8].output[0], - nodes[20].input[0], - nodes[17].output[0], - nodes[18].output[0], - nodes[21].output[0], - "", - False, - scale_value_1, - 0.0, - False, - ) - return nodes, nodes_to_add, new_value_infos - - -# Has causal mask, no attention mask, without Dropout. 
-_PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ - ("MatMul", False, []), # 0 - ("Mul", True, [(0, 0, 0)]), # 1 - ("Mul", True, [(0, 0, 1)]), # 2 - ("Transpose", True, [(1, 0, 0)]), # 3 - ("Transpose", True, [(2, 0, 0)]), # 4 - ("Add", False, [(0, 0, 0)]), # 5 - ("Slice", True, [(5, 0, 1)]), # 6 - ("Slice", True, [(6, 0, 0)]), # 7 - ("Unsqueeze", True, [(6, 0, 2)]), # 8 - ("Gather", True, [(8, 0, 0)]), # 9 - ("Shape", True, [(9, 0, 0)]), # 10 - ("Softmax", False, [(5, 0, 0)]), # 11 - ("MatMul", False, [(11, 0, 0)]), # 12 - ("Transpose", True, [(12, 0, 1)]), # 13 - ("Transpose", False, [(12, 0, 0)]), # 14 - ("FusedMatMul", False, [(13, 0, 1)]), # 15 - ("SoftmaxGrad_13", False, [(15, 0, 0), (11, 0, 1)]), # 16 - ("Identity", False, [(16, 0, 0)]), # 17 - ("FusedMatMul", False, [(2, 0, 1), (17, 0, 0)]), # 18 - ("FusedMatMul", False, [(1, 0, 0), (17, 0, 1)]), # 19 - ("Mul", False, [(18, 0, 0)]), # 20 - ("Mul", False, [(19, 0, 0)]), # 21 - ("Identity", False, [(20, 0, 0)]), # 22 - ("Identity", False, [(21, 0, 0)]), # 23 - ("Transpose", False, [(22, 0, 0)]), # 24 - ("Transpose", False, [(23, 0, 0)]), # 25 - ("FusedMatMul", False, [(11, 0, 0)]), # 26 - ("Transpose", True, [(26, 0, 1)]), # 27 - ("Transpose", False, [(26, 0, 0)]), # 28 -] - - -def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): - # Check forward only as the backward is expected to be consistent if it's built correctly. - scale_value_1 = matcher.get_constant_value(nodes[1].input[1]) - scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1 - scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) - scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 - if not ( - check_attribute_value(nodes[3], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1]) - and check_attribute_value(nodes[13], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[14], "perm", [0, 2, 1, 3]) - and scale_value_1 == scale_value_2 - ): - return [], [], [] - - nodes_to_add, new_value_infos = _make_efficient_attention_nodes( - idx, - nodes[3].input[0], - nodes[4].input[0], - nodes[13].input[0], - nodes[14].output[0], - nodes[27].input[0], - nodes[24].output[0], - nodes[25].output[0], - nodes[28].output[0], - "", - False, - scale_value_1, - 0.0, - True, - ) - return nodes, nodes_to_add, new_value_infos - - _PATTERNS = [ (_PATTERN_0, _optimize_for_pattern_0), (_PATTERN_1, _optimize_for_pattern_1), - (_PATTERN_2, _optimize_for_pattern_2), - (_PATTERN_3, _optimize_for_pattern_3), ] From a788514027c3a6ee5f284c965ccffcb8805302a5 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 5 Mar 2024 18:27:26 -0800 Subject: [PATCH 114/237] [js/web] dump debug logs for karma for diagnose purpose (#19785) ### Description dump debug logs for karma for diagnose purpose. This is for debugging the CI issue of Chrome launch failure and considered temporary. 
--- js/web/script/test-runner-cli.ts | 3 +++ .../github/azure-pipelines/templates/win-web-ci.yml | 12 ++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index 59bd0d5f6313a..ace64e9532b12 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -569,6 +569,9 @@ async function main() { if (webnn) { chromiumFlags.push('--enable-experimental-web-platform-features'); } + if (process.argv.includes('--karma-debug')) { + karmaArgs.push('--log-level debug'); + } karmaArgs.push(`--bundle-mode=${args.bundleMode}`); karmaArgs.push(...chromiumFlags.map(flag => `--chromium-flags=${flag}`)); if (browser.startsWith('Edge')) { diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index b882d6fb167fd..9553bc1bc3547 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -153,31 +153,31 @@ jobs: errorActionPreference: stop displayName: 'Pack NPM packages' - script: | - npm test -- -e=chrome -b=webgl,wasm + npm test -- -e=chrome -b=webgl,wasm --karma-debug workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (wasm,webgl backend)' condition: eq('${{ parameters.RunWebGpuTests }}', 'false') - script: | - npm test -- -e=chrome -b=webgl,wasm,webgpu $(webgpuCommandlineExtraFlags) + npm test -- -e=chrome -b=webgl,wasm,webgpu --karma-debug $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (ALL backends)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - script: | - npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags) + npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-tensor --karma-debug $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-tensor)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - script: | - npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags) + npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-location --karma-debug $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-location)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - script: | - npm test -- --webgl-texture-pack-mode -b=webgl -e=chrome + npm test -- --webgl-texture-pack-mode -b=webgl -e=chrome --karma-debug workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests - WebGL: packed mode' - script: | - npm test -- --wasm-enable-proxy -b=wasm -e=chrome + npm test -- --wasm-enable-proxy -b=wasm -e=chrome --karma-debug workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests - WebAssembly: proxy' condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) From db59cec82f226dbba3ce7c5b03db35b0fe07fb60 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 6 Mar 2024 15:03:55 +1000 Subject: [PATCH 115/237] Don't reduce warning level for CUDA build on Windows (#19663) ### Description Address warnings so all the ORT projects build with /W4 on Windows. Mainly - unused parameters - variables shadowing other ones ### Motivation and Context #19588 started on this. 
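As a rough standalone illustration of the three recurring /W4 fixes applied in the hunks below (a minimal sketch with made-up names, not ONNX Runtime code): unreferenced parameters are silenced by commenting out the parameter name or marking it `[[maybe_unused]]`, locals that shadow an outer name are renamed, and branches that are constant at compile time move to `if constexpr` so the "conditional expression is constant" warning (C4127) is not raised.

```
// Standalone C++17 sketch of the warning-fix patterns; all names here are illustrative.
#include <cstdio>
#include <type_traits>

// C4100 (unreferenced formal parameter): comment out the name or use [[maybe_unused]].
void Launch(int grid, [[maybe_unused]] int debug_flag, const char* /*debug_name*/) {
  std::printf("grid=%d\n", grid);
}

// Shadowing warnings: rename the inner variable instead of reusing the outer name,
// in the same spirit as the seq_len / channels_per_group_in renames below.
int SumUpTo(int seq_len) {
  int total = 0;
  for (int inner = 0; inner < seq_len; ++inner) {  // not another "seq_len"
    total += inner;
  }
  return total;
}

// C4127 (conditional expression is constant): use if constexpr when the branch depends
// only on a template parameter, as in inverse.cc and cuda_execution_provider.h below.
template <typename T>
const char* TypeName() {
  if constexpr (std::is_same<T, float>::value) {
    return "float";
  } else {
    return "other";
  }
}

int main() {
  Launch(8, 0, "ignored");
  std::printf("%d %s\n", SumUpTo(4), TypeName<float>());
  return 0;
}
```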
--- cmake/CMakeLists.txt | 6 +-- cmake/onnxruntime_providers_cuda.cmake | 13 ++++- .../core/providers/cuda/cuda_context.h | 2 +- .../cuda/bert/add_bias_transpose.cu | 10 ++-- .../contrib_ops/cuda/bert/attention_impl.cu | 20 +++---- .../cuda/bert/attention_prepare_qkv.cu | 4 +- .../bert/cutlass_fmha/fmha_launch_template.h | 8 +-- .../cuda/bert/decoder_attention_impl.cu | 2 +- .../cuda/bert/group_query_attention_impl.cu | 4 +- .../cuda/bert/packed_attention_impl.cu | 2 +- .../bert/packed_multihead_attention_impl.cu | 4 +- .../contrib_ops/cuda/bert/rotary_embedding.cc | 2 - .../cuda/bert/rotary_embedding_impl.cu | 2 +- .../mha_runner.cu | 54 +++++++++---------- .../cuda/diffusion/group_norm_common_base.h | 6 +-- onnxruntime/contrib_ops/cuda/inverse.cc | 8 +-- .../contrib_ops/cuda/math/complex_mul_impl.cu | 4 +- .../contrib_ops/cuda/math/gemm_float8.cu | 2 +- .../cuda/moe/ft_moe/moe_cutlass_kernel.h | 2 +- .../moe/ft_moe/moe_gemm_kernels_template.h | 29 ++++++---- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 4 +- .../cuda/moe/ft_moe/moe_problem_visitor.h | 8 +-- .../quantization/attention_quantization.cc | 2 +- .../qordered_ops/qordered_attention.cc | 2 +- .../qordered_ops/qordered_attention_impl.cu | 2 +- .../qordered_ops/qordered_qdq_impl.cu | 2 +- .../cuda/transformers/generation_cuda_impl.cu | 17 ++++-- .../providers/cuda/cuda_execution_provider.h | 20 +++---- .../core/providers/cuda/cudnn_common.cc | 1 - .../cuda/math/unary_elementwise_ops_impl.cu | 7 +-- onnxruntime/core/providers/cuda/nn/conv.cc | 20 ++++--- onnxruntime/core/providers/cuda/nn/conv.h | 2 +- .../core/providers/cuda/nn/layer_norm.h | 2 - .../core/providers/cuda/nn/layer_norm_impl.cu | 2 - .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 1 - .../cuda/tensor/gelu_approximate_impl.cu | 6 +-- .../cuda/tensor/resize_antialias_impl.cu | 20 +++---- .../core/providers/cuda/tensor/resize_impl.cu | 2 +- .../providers/cuda/tensor/transpose_impl.cu | 6 +-- .../core/providers/cuda/triton_kernel.cu | 50 ++++++++++------- .../core/providers/tensorrt/nv_includes.h | 20 +++++++ .../tensorrt/onnx_ctx_model_helper.h | 2 +- .../tensorrt/tensorrt_execution_provider.cc | 48 ++++++++++------- .../tensorrt/tensorrt_execution_provider.h | 5 +- .../tensorrt_execution_provider_custom_ops.cc | 5 +- .../tensorrt_execution_provider_custom_ops.h | 23 +++++--- ...oder_masked_multihead_attention_op_test.cc | 12 ++--- .../providers/cpu/generator/random_test.cc | 4 +- onnxruntime/test/unittest_main/test_main.cc | 17 +++++- .../training_ops/cuda/cross_entropy_test.cc | 10 ++-- .../training_ops/cuda/nn/conv_shared.cc | 11 ++-- .../cuda/nn/conv_transpose_grad.cc | 2 - .../training_ops/cuda/nn/layer_norm_impl.cu | 2 - .../training_ops/cuda/optimizer/lamb_impl.cu | 2 +- .../templates/jobs/win-ci-prebuild-steps.yml | 11 +++- 55 files changed, 315 insertions(+), 219 deletions(-) create mode 100644 onnxruntime/core/providers/tensorrt/nv_includes.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0d55d4cab9826..3f919d7bf6e18 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1274,11 +1274,7 @@ endif() #Dependencies end. 
In the next we'll enable "treat warning as error" #Adjust warning flags -if (onnxruntime_USE_CUDA) - set_msvc_c_cpp_compiler_warning_level(3) -else() - set_msvc_c_cpp_compiler_warning_level(4) -endif() +set_msvc_c_cpp_compiler_warning_level(4) set(onnxruntime_DELAYLOAD_FLAGS "") diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 7f295a59a0931..aeeac10ead27d 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -141,18 +141,22 @@ if (HAS_GUARD_CF) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /guard:cf>") endif() + if (HAS_QSPECTRE) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /Qspectre>") endif() + foreach(ORT_FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler \"${ORT_FLAG}\">") endforeach() + # CUDA 11.3+ supports parallel compilation # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver-threads if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.3) option(onnxruntime_NVCC_THREADS "Number of threads that NVCC can use for compilation." 1) target_compile_options(${target} PRIVATE "$<$:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">") endif() + if (UNIX) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-reorder>" "$<$>:-Wno-reorder>") @@ -162,6 +166,13 @@ #mutex.cuh(91): warning C4834: discarding return value of function with 'nodiscard' attribute target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4834>") target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4127>") + if (MSVC) + # the VS warnings for 'Conditional Expression is Constant' are spurious as they don't handle multiple conditions + # e.g. `if (std::is_same_v && not_a_const)` will generate the warning even though constexpr cannot + # be used due to `&& not_a_const`. This affects too many places for it to be reasonable to disable at a finer + # granularity. 
+ target_compile_options(${target} PRIVATE "$<$:/wd4127>") + endif() endif() onnxruntime_add_include_to_target(${target} onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers) @@ -187,7 +198,7 @@ target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) endif() endif() - + if (onnxruntime_USE_TRITON_KERNEL) # compile triton kernel, generate .a and .h files include(onnxruntime_compile_triton_kernel.cmake) diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 108173474db46..7104e70c3a8a9 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -58,7 +58,7 @@ struct CudaContext : public CustomOpContext { template T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) { - if (sizeof(T) > sizeof(void*)) { + if constexpr (sizeof(T) > sizeof(void*)) { ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT); } const auto& ort_api = Ort::GetApi(); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 1ea2540db486f..9e6752b451868 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -843,11 +843,11 @@ void InvokeAddBiasTransposeTrt( template <> void LaunchAddBiasTransposeTrt( - cudaStream_t stream, const int max_threads_per_block, - const int batch_size, const int sequence_length, - const int num_heads, const int head_size, - const float* biases, const float* query, const float* key, const float* value, float* output, - bool is_cross_attention, int kv_sequence_length) { + cudaStream_t /*stream*/, const int /*max_threads_per_block*/, + const int /*batch_size*/, const int /*sequence_length*/, + const int /*num_heads*/, const int /*head_size*/, + const float* /*biases*/, const float* /*query*/, const float* /*key*/, const float* /*value*/, float* /*output*/, + bool /*is_cross_attention*/, int /*kv_sequence_length*/) { ORT_ENFORCE(false, "Shall not call this since fused kernel does not support float input."); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index c20f42c4d06bc..a93fdf74dc28c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -58,12 +58,12 @@ size_t AlignSize(size_t bytes) { return bytesAligned; } -void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) { - if (this->sequence_length != sequence_length) { +void CumulatedSequenceLengthCache::Initialize(int32_t seq_length, cudaStream_t stream) { + if (this->sequence_length != seq_length) { ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, - this->max_batch_size, sequence_length, stream); - this->sequence_length = sequence_length; + this->max_batch_size, seq_length, stream); + this->sequence_length = seq_length; } } @@ -213,9 +213,9 @@ Status FusedTrtCrossAttention( template <> Status FusedTrtCrossAttention( - cudaStream_t stream, - contrib::AttentionParameters& parameters, - AttentionData& data) { + cudaStream_t /*stream*/, + contrib::AttentionParameters& /*parameters*/, + AttentionData& /*data*/) { return 
ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "Trt fused cross attention does not support float tensor"); } @@ -276,9 +276,9 @@ Status FusedTrtSelfAttention( // Template Specialization for float type template <> Status FusedTrtSelfAttention( - cudaStream_t stream, - contrib::AttentionParameters& parameters, - AttentionData& data) { + cudaStream_t /*stream*/, + contrib::AttentionParameters& /*parameters*/, + AttentionData& /*data*/) { return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "Trt fused attention does not support float tensor"); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index a513d9e8d2211..b843966d88e85 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -231,7 +231,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + T* /*q*/, T* /*k*/, T* /*v*/, AttentionQkvFormat& qkv_format) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; @@ -279,7 +279,7 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + T* /*q*/, T* k, T* /*v*/, AttentionQkvFormat& qkv_format) { const int batch_size = parameters.batch_size; const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index db78722cc0e4c..c12cb374d9adf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -242,18 +242,18 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) { using AlignedAK = AttentionKernel; #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) -#pragma warning(disable : 6287) +#pragma warning(disable : 6287 4189) // kAligned is used via capture so 4189 warning seems incorrect #endif // Run a more efficient kernel with `isAligned=True` when memory is correctly aligned. 
bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 && params.qk_head_size % AlignedAK::kAlignmentK == 0 && params.v_head_size % AlignedAK::kAlignmentV == 0; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() { LaunchCutlassFmha(params); })); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif } template diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu index e24d9da94c964..c0b1996789183 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -17,7 +17,7 @@ Status DecoderQkvToContext( const cudaDeviceProp& device_prop, Stream* ort_stream, cublasHandle_t& cublas, - const size_t element_size, + const size_t /*element_size*/, const int batch_size, const int sequence_length, const int kv_sequence_length, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index d88e9a49fb5ee..cb5631542c113 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, // Convert Past to Total sequence length tensor Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int threads_per_block) { + const int /*threads_per_block*/) { if (parameters.is_prompt) { return Status::OK(); } @@ -655,7 +655,7 @@ Status EfficientAttention( template Status QkvToContext( const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, + cublasHandle_t& /*cublas*/, Stream* ort_stream, contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data) { diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index ce7ac3796dbe1..a84a310b46ca0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -440,7 +440,7 @@ Status LaunchTransposeRemovePadding( template Status FusedScaledDotProductAttention( - const cudaDeviceProp& device_prop, + const cudaDeviceProp& /*device_prop*/, cudaStream_t stream, PackedAttentionParameters& parameters, PackedAttentionData& data) { diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 49029da12a308..982c7eaa2cb2c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -381,7 +381,7 @@ void InvokeTranspose( const T* query, const T* key, const T* value, const T* bias, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, - AttentionQkvFormat source_format, AttentionQkvFormat target_format, + [[maybe_unused]] AttentionQkvFormat source_format, AttentionQkvFormat target_format, const int32_t* token_offset, int32_t token_count, cudaStream_t stream) { if (key != nullptr && value != nullptr) { @@ -551,7 +551,7 @@ void LaunchTranspose( template Status FusedAttentionTrt( - const cudaDeviceProp& device_prop, + const cudaDeviceProp& 
/*device_prop*/, cudaStream_t stream, PackedAttentionParameters& parameters, PackedMultiHeadAttentionData& data) { diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index 9de7ba3885c3c..ab7479f2938fe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -82,8 +82,6 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { interleaved, device_prop.maxThreadsPerBlock, parameters.transposed); - - return Status::OK(); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index c6637041f05bd..3a14161f29e9f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -93,7 +93,7 @@ Status LaunchRotaryEmbeddingKernel( const int num_heads, const int head_size, const int rotary_embedding_dim, - const int max_sequence_length, + const int /*max_sequence_length*/, const int position_ids_format, const bool interleaved, const int max_threads_per_block, diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu index 8fb6575d27cc0..4a4e3eeecf642 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu @@ -53,9 +53,9 @@ class FusedMHARunnerFP16v2::mhaImpl { ~mhaImpl() {} - void setup(const int S, const int B) { + void setup(const int seq_len, const int B) { // For bert and vit, use flash attention when sequence length is larger than the threshold. - use_flash_attention = is_flash_attention(S); + use_flash_attention = is_flash_attention(seq_len); params.force_unroll = use_flash_attention; @@ -68,26 +68,26 @@ class FusedMHARunnerFP16v2::mhaImpl { warps_n = 1; } else { if (sm == 70) { - if (S == 64 || S == 96) { + if (seq_len == 64 || seq_len == 96) { warps_m = 2; warps_n = 2; - } else if (S == 128) { + } else if (seq_len == 128) { warps_m = 1; warps_n = 4; - } else if (S == 256 || S == 384) { + } else if (seq_len == 256 || seq_len == 384) { warps_m = 1; warps_n = 8; } else { ORT_ENFORCE(false, "Unsupported sequence length"); } } else { - if (S == 32 || S == 64 || S == 96 || S == 128) { + if (seq_len == 32 || seq_len == 64 || seq_len == 96 || seq_len == 128) { warps_m = 2; warps_n = 2; - } else if (S == 192 || S == 256) { + } else if (seq_len == 192 || seq_len == 256) { warps_m = 1; warps_n = 4; - } else if (S == 384) { + } else if (seq_len == 384) { warps_m = 1; warps_n = 8; } else { @@ -99,7 +99,7 @@ class FusedMHARunnerFP16v2::mhaImpl { // The number of threads per CTA. threads_per_cta = warps_m * warps_n * warps_k * 32; // The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension. 
- xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m); + xmmas_m = (seq_len + 16 * warps_m - 1) / (16 * warps_m); const float scale_bmm1 = interface->mScale; const float scale_softmax = 1.f; // Seems to be only required for int8 @@ -111,7 +111,7 @@ class FusedMHARunnerFP16v2::mhaImpl { params.b = B; params.h = interface->mNumHeads; - params.s = S; + params.s = seq_len; params.d = interface->mHeadSize; params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half); @@ -121,7 +121,7 @@ class FusedMHARunnerFP16v2::mhaImpl { has_causal_mask = false; } - void setup_causal_masked_fmha(const int S, const int B) { + void setup_causal_masked_fmha(const int seq_len, const int B) { const float scale_bmm1 = interface->mScale; const float scale_softmax = 1.f; // Seems to be only required for int8 const float scale_bmm2 = 1.f; @@ -132,7 +132,7 @@ class FusedMHARunnerFP16v2::mhaImpl { params.b = B; params.h = interface->mNumHeads; - params.s = S; + params.s = seq_len; params.d = interface->mHeadSize; params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half); @@ -182,30 +182,30 @@ class FusedMHARunnerFP16v2::mhaImpl { return max_seq_len; } - int S = max_seq_len; + int seq_len = max_seq_len; if (max_seq_len <= 32) { - S = (sm == 70) ? 64 : 32; + seq_len = (sm == 70) ? 64 : 32; } else if (max_seq_len <= 64) { - S = 64; + seq_len = 64; } else if (max_seq_len <= 96) { - S = 96; + seq_len = 96; } else if (max_seq_len <= 128) { - S = 128; + seq_len = 128; } else if (max_seq_len <= 192) { - S = (sm == 70) ? 256 : 192; + seq_len = (sm == 70) ? 256 : 192; } else if (max_seq_len <= 256) { - S = 256; + seq_len = 256; } else if (max_seq_len <= 384) { - S = 384; + seq_len = 384; } - return S; + return seq_len; } protected: - bool is_flash_attention(const int S) const { + bool is_flash_attention(const int seq_len) const { ORT_ENFORCE(interface->mHasCausalMask == false); - return interface->mEnableFlashAttention && S >= kMinSequenceLengthFlashAttention; + return interface->mEnableFlashAttention && seq_len >= kMinSequenceLengthFlashAttention; } private: @@ -232,12 +232,12 @@ FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads, pimpl(new mhaImpl(this)) { } -void FusedMHARunnerFP16v2::setup(const int S, const int B) { - MHARunner::setup(S, B); +void FusedMHARunnerFP16v2::setup(const int seq_len, const int B) { + MHARunner::setup(seq_len, B); if (mHasCausalMask) { - pimpl->setup_causal_masked_fmha(S, B); + pimpl->setup_causal_masked_fmha(seq_len, B); } else { - pimpl->setup(S, B); + pimpl->setup(seq_len, B); } } diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h index ea87d0c29111e..a80584d3293a0 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h @@ -136,10 +136,10 @@ struct GroupNormNHWCParams { bool use_silu, bool broadcast_skip, int channels_per_block) { - int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_group_in = num_channels / num_groups; // channels_per_block is computed in PrePack. // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. 
- if (channels_per_block < channels_per_group) { + if (channels_per_block < channels_per_group_in) { channels_per_block = GetChannelsPerBlock(num_channels, num_groups); } @@ -167,7 +167,7 @@ struct GroupNormNHWCParams { this->hw_per_block = DivUp(this->hw, blocks_per_hw); this->channels_per_block = channels_per_block; - this->channels_per_group = channels_per_group; + this->channels_per_group = channels_per_group_in; this->hwc = this->hw * this->c; this->inv_hw_channels_per_group = 1.F / (float)(this->hw * this->channels_per_group); this->groups_per_block = channels_per_block / this->channels_per_group; diff --git a/onnxruntime/contrib_ops/cuda/inverse.cc b/onnxruntime/contrib_ops/cuda/inverse.cc index 81e161e60642c..9075dda26f86b 100644 --- a/onnxruntime/contrib_ops/cuda/inverse.cc +++ b/onnxruntime/contrib_ops/cuda/inverse.cc @@ -78,9 +78,9 @@ struct Inverse::ComputeImpl { cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; // Make a copy of the input which will serve as a workspace as well. - if (std::is_same::value || std::is_same::value) { + if constexpr (std::is_same::value || std::is_same::value) { IAllocatorUniquePtr input_workspace = inst->GetScratchBuffer(input_count, ort_stream); - if (std::is_same::value) { + if constexpr (std::is_same::value) { // Convert from MLFloat16(half) to float Impl_Cast(stream, reinterpret_cast(input.Data()), input_workspace.get(), input_count); } else { @@ -96,7 +96,7 @@ struct Inverse::ComputeImpl { // Need to compute ptrs for output buffers // Output for MLFloat IAllocatorUniquePtr output_ptrs = inst->GetScratchBuffer(n_batches, ort_stream); - if (std::is_same::value) { + if constexpr (std::is_same::value) { IAllocatorUniquePtr ml_float_output = inst->GetScratchBuffer(input_count, ort_stream); ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(stream, ml_float_output.get(), num_batches, rows, output_ptrs)); // Do the inverse @@ -112,7 +112,7 @@ struct Inverse::ComputeImpl { ORT_RETURN_IF_ERROR(CheckForSingularity(stream, info, info_cpu, num_batches)); // We are done here } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { IAllocatorUniquePtr input_workspace = inst->GetScratchBuffer(static_cast(input_count), ort_stream); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_workspace.get(), input.Data(), sizeof(double) * input_count, cudaMemcpyDeviceToDevice, stream)); diff --git a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu index ca94477114ee2..47a64502b3480 100644 --- a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu +++ b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu @@ -97,8 +97,8 @@ void ComplexMul_Impl( const TArray* rhs_padded_strides, const T* rhs_data, const TArray* fdm_output_strides, - const onnxruntime::cuda::fast_divmod& fdm_H, - const onnxruntime::cuda::fast_divmod& fdm_C, + const onnxruntime::cuda::fast_divmod& /*fdm_H*/, + const onnxruntime::cuda::fast_divmod& /*fdm_C*/, T* output_data, int64_t count, int64_t lhs_size, diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 064b6dd392437..28ab27ee33d10 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -174,7 +174,7 @@ Status GemmFloat8::ComputeGemm( int32_t dtype_A, int32_t dtype_B, int32_t dtype_C, int32_t dtype_Y, const TensorShape& shape_A, const TensorShape& shape_B, - const TensorShape& shape_C, const TensorShape& shape_Y, + 
const TensorShape& shape_C, const TensorShape& /*shape_Y*/, bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b, const void* p_input_c, const void* p_scale_a, const void* p_scale_b, const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h index bfe30b71170d8..cfe306c2482a5 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h @@ -202,7 +202,7 @@ struct MoeFCGemm { total_rows_before_expert(total_rows_before_expert), gemm_n(gemm_n), gemm_k(gemm_k), - host_problem_sizes(nullptr) { + host_problem_sizes(host_problem_sizes) { if (platform::is_same::value || platform::is_same::value) { assert(weight_scales); } diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index 66950c9b65970..a3dcf0da16b98 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -20,6 +20,12 @@ #pragma GCC diagnostic ignored "-Wstrict-aliasing" #endif +// Ignore CUTLASS warning C4100: unreferenced formal parameter +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + #include "cutlass/array.h" #include "cutlass/numeric_conversion.h" #include "cutlass/layout/matrix.h" @@ -36,6 +42,10 @@ #include "layout_traits_helper.h" #include "moe_cutlass_kernel.h" +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + #ifdef __GNUC__ #pragma GCC diagnostic pop #endif @@ -149,10 +159,10 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w template struct dispatch_stages { - static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, - CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) { + static void dispatch(const T* /*A*/, const WeightType* /*B*/, const T* /*weight_scales*/, const T* /*biases*/, + T* /*C*/, int64_t* /*total_rows_before_expert*/, int64_t /*gemm_n*/, int64_t /*gemm_k*/, + int /*num_experts*/, CutlassGemmConfig /*gemm_config*/, int /*multi_processor_count*/, + cudaStream_t /*stream*/, [[maybe_unused]] int* occupancy = nullptr) { std::string err_msg = "Cutlass fpA_intB gemm. 
Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); ORT_THROW("[FT Error][dispatch_stages::dispatch] " + err_msg); @@ -221,9 +231,10 @@ template < typename T, typename WeightType, typename arch, typename EpilogueTag, typename std::enable_if::value && std::is_same::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, int sm_version, - int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + int64_t* total_rows_before_expert, int64_t /*total_rows*/, + int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, + int /*sm_version*/, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: dispatch_gemm_config, @@ -300,8 +311,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig template ::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index f4f2b49032d23..a5b47bcddefbc 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -370,7 +370,7 @@ struct TopkConstants { template void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row, - int num_rows, int num_experts, int k, cudaStream_t stream) { + int num_rows, int /*num_experts*/, int k, cudaStream_t stream) { static constexpr unsigned long MAX_BYTES_PER_LDG = 16; static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); @@ -599,7 +599,7 @@ void CutlassMoeFCRunner::run_moe_fc( static constexpr bool scales_required = std::is_same::value || std::is_same::value; - if (scales_required) { + if constexpr (scales_required) { if (fc1_scales == nullptr) { ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for first matmul is a null pointer"); } else if (fc2_scales == nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h index 00f977c615df6..1de8f6b69642c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h @@ -276,13 +276,13 @@ struct MoeProblemVisitor::ComputeInternal(OpKernelContext* context) const { CudaT dequant_scale; CudaT input_scale = *(reinterpret_cast(input_scale_tensor->Data())); CudaT weight_scale = *(reinterpret_cast(weight_scale_tensor->Data())); - if (sizeof(T) == 2) { + if constexpr (sizeof(T) == 2) { 
dequant_scale = __float2half(__half2float(input_scale) * __half2float(weight_scale)); } else { dequant_scale = input_scale * weight_scale; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 3cecebedae2f0..12835978536e1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -142,7 +142,7 @@ inline void debug_print([[maybe_unused]] const T* arr, std::cout << "========" << name << std::endl; for (size_t i = 0; i < sz; i++) { if (i % w == 0) std::cout << std::endl; - if (std::is_same().value) { + if constepxr (std::is_same::value) { std::cout << (int)buf[i] << ", "; } else { std::cout << buf[i] << ", "; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu index f4d5a7b404a62..fd4b51f40fb4f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu @@ -151,7 +151,7 @@ QOrderBatchInt8MatrixTransposeKernel(const int8_t* src, const int8_t* dst, const } } -Status QOrderBatchTransposeInt8Matrix(cudaStream_t stream, const cudaDeviceProp& device_prop, +Status QOrderBatchTransposeInt8Matrix(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/, const int batch_size, const int rows, const int cols, const int8_t* input, int8_t* output) { ORT_ENFORCE(rows % 4 == 0 && cols % 4 == 0, "Matrix rows and cols must be divisible by 4!"); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu index baff8e76ec73b..e6ac0bc8a5171 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu @@ -389,7 +389,7 @@ QOrderDequantizeKernel_Strict(const int8_t* __restrict__ src, const __half* __re } } -Status QOrderDequantize_Strict(cudaStream_t stream, const cudaDeviceProp& device_prop, +Status QOrderDequantize_Strict(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/, const int8_t* src, __half* dst, float scale, size_t N) { ORT_RETURN_IF(N & 0x3LL, "N can not divide by 4!"); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index a39abefed9cd0..eb1943b59d976 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -1,11 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+ +// cub.cuh includes device/dispatch_radix_sort.cuh which has assignment in conditional expressions +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4706) +#endif +#include +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#include + #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" -#include "cub/util_type.cuh" -#include -#include + #include "contrib_ops/cuda/bert/utils.cuh" #include "contrib_ops/cuda/transformers/generation_cuda_impl.h" diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 5f62f313b86a2..75fe1dff7c4a4 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -131,41 +131,33 @@ class CUDAExecutionProvider : public IExecutionProvider { template const T* GetConstOnes(size_t count, cudaStream_t stream) { - constexpr bool is_float = std::is_same::value; - constexpr bool is_double = std::is_same::value; - constexpr bool is_half = std::is_same::value; - constexpr bool is_BFloat16 = std::is_same::value; -#if !defined(DISABLE_FLOAT8_TYPES) - constexpr bool is_Float8E4M3FN = std::is_same::value; - constexpr bool is_Float8E5M2 = std::is_same::value; -#endif - if (is_float) { + if constexpr (std::is_same::value) { if (!constant_ones_float_) { constant_ones_float_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_float_->GetBuffer(stream, count)); - } else if (is_double) { + } else if constexpr (std::is_same::value) { if (!constant_ones_double_) { constant_ones_double_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_double_->GetBuffer(stream, count)); - } else if (is_half) { + } else if constexpr (std::is_same::value) { if (!constant_ones_half_) { constant_ones_half_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_half_->GetBuffer(stream, count)); - } else if (is_BFloat16) { + } else if constexpr (std::is_same::value) { if (!constant_ones_bfloat16_) { constant_ones_bfloat16_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_bfloat16_->GetBuffer(stream, count)); #if !defined(DISABLE_FLOAT8_TYPES) - } else if (is_Float8E4M3FN) { + } else if constexpr (std::is_same::value) { if (!constant_ones_float8e4m3fn_) { constant_ones_float8e4m3fn_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_float8e4m3fn_->GetBuffer(stream, count)); - } else if (is_Float8E5M2) { + } else if constexpr (std::is_same::value) { if (!constant_ones_float8e5m2_) { constant_ones_float8e5m2_ = cuda::CreateConstantOnes(); } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index c850f7b583bfc..39b73163794f0 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -160,7 +160,6 @@ cudnnDataType_t CudnnTensor::GetDataType() { template <> cudnnDataType_t CudnnTensor::GetDataType() { ORT_THROW("cuDNN doesn't support BFloat16."); - return CUDNN_DATA_FLOAT; } template <> diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index fd8f7929d4426..554d5908cf854 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -127,9 +127,10 @@ struct OP_Cast { UnaryElementWiseImpl(stream, 
input_data, output_data, OP_Cast(), count); \ } -#define IMPL_CAST_IMPL_THROW(InT, OutT) \ - void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ - ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ +#define IMPL_CAST_IMPL_THROW(InT, OutT) \ + void Explicit_Impl_Cast(cudaStream_t /*stream*/, const InT* /*input_data*/, OutT* /*output_data*/, \ + size_t /*count*/) { \ + ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ } #if !defined(DISABLE_FLOAT8_TYPES) diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index a417be5a86c32..e05786248cbcf 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -97,11 +97,11 @@ Status SliceOutUnwantedOutputSection(cudaStream_t stream, template Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) { + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack - if (NHWC && is_nhwc_domain_) { // InputTensors::IN_W - if (input_idx == 1) { + if constexpr (NHWC) { + if (is_nhwc_domain_ && input_idx == 1) { // InputTensors::IN_W // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); @@ -123,6 +123,10 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); is_packed = true; } + } else { + ORT_UNUSED_PARAMETER(tensor); + ORT_UNUSED_PARAMETER(input_idx); + ORT_UNUSED_PARAMETER(alloc); } return Status::OK(); @@ -149,8 +153,11 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC. 
constexpr bool channels_last = NHWC; - if (channels_last && (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + if constexpr (channels_last) { + if (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + } } // set B @@ -403,7 +410,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) default: perf.algo = kDefaultConvAlgo; CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); - if (std::is_same::value) { + + if constexpr (std::is_same::value) { perf.mathType = CUDNN_TENSOR_OP_MATH; } else if (std::is_same::value && !UseTF32()) { perf.mathType = CUDNN_FMA_MATH; diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index 181fbc99fd8e9..3aec654224e39 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -195,7 +195,7 @@ class Conv : public CudaKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; + bool& is_packed, PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm.h b/onnxruntime/core/providers/cuda/nn/layer_norm.h index ff231f4f1ad5c..c021d3ffe63a2 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm.h @@ -7,8 +7,6 @@ namespace onnxruntime { namespace cuda { -using namespace onnxruntime::cuda; - // NOTE: This was originally a contrib op with 3 type constraints. The ONNX spec merges 'T' and 'V'. // the kernel is templatized on all three for backwards compatibility, but in ONNX usage T == V. 
template diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index 679b8b6b78886..b9e8b45307079 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -29,8 +29,6 @@ namespace onnxruntime { namespace cuda { -using namespace onnxruntime::cuda; - template __device__ void cuWelfordOnlineSum( const U curr, diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index b61b104790fe5..6476364a211fd 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -305,7 +305,6 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { if (!weight_cached_) { const Tensor& W = *ctx->Input(RNN_Input_Index::W); const Tensor& R = *ctx->Input(RNN_Input_Index::R); - const Tensor* B = ctx->Input(RNN_Input_Index::B); ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc, rnn_desc, ctx->GetComputeStream())); } diff --git a/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu index 3292650584de8..7a27b7af33137 100644 --- a/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu @@ -62,7 +62,7 @@ __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int } template <> -Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length, const float* input, const float* bias, float* output, bool /*use_half2*/) { constexpr int blockSize = 256; const int gridSize = (input_length + blockSize - 1) / blockSize; @@ -73,7 +73,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int } template <> -Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length, const double* input, const double* bias, double* output, bool /*use_half2*/) { constexpr int blockSize = 256; const int gridSize = (input_length + blockSize - 1) / blockSize; @@ -108,7 +108,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int } template <> -Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length, const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) { constexpr int blockSize = 256; diff --git a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu index 56b7c3f499303..d56e4bc53874d 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu @@ -680,10 +680,10 @@ template void ResizeTrilinearUpsample( cudaStream_t stream, int rank, - const UpsampleMode upsample_mode, + const UpsampleMode /*upsample_mode*/, ResizeCoordinateTransformationMode coordinate_transform_mode, - gsl::span input_shape, - gsl::span output_shape, + 
gsl::span /*input_shape*/, + gsl::span /*output_shape*/, int64_t batch_size, int64_t num_channels, std::tuple inferred_input_dims, std::tuple inferred_output_dims, @@ -832,11 +832,11 @@ void ResizeTrilinearUpsample( template void ResizeBiLinearUpsample(cudaStream_t stream, int rank, - const UpsampleMode upsample_mode, + const UpsampleMode /*upsample_mode*/, ResizeCoordinateTransformationMode coordinate_transform_mode, - gsl::span input_shape, - gsl::span output_shape, - int64_t batch_size, int64_t num_channels, + gsl::span /*input_shape*/, + gsl::span /*output_shape*/, + int64_t /*batch_size*/, int64_t num_channels, std::tuple inferred_input_dims, std::tuple inferred_output_dims, std::tuple inferred_dim_rscales, @@ -959,10 +959,10 @@ void ResizeBiLinearUpsample(cudaStream_t stream, template void ResizeBicubicUpsample(cudaStream_t stream, int rank, - const UpsampleMode upsample_mode, + const UpsampleMode /*upsample_mode*/, ResizeCoordinateTransformationMode coordinate_transform_mode, - gsl::span input_shape, - gsl::span output_shape, + gsl::span /*input_shape*/, + gsl::span /*output_shape*/, int64_t batch_size, int64_t num_channels, std::tuple inferred_input_dims, std::tuple inferred_output_dims, diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu index 0cde0ed8e8681..e788f24052985 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu @@ -609,7 +609,7 @@ void ResizeNearestImpl( const size_t N, bool extrapolation_enabled, const T extrapolation_value, - float cubic_coeff_a, + float /*cubic_coeff_a*/, ResizeCoordinateTransformationMode transform_coordinate, ResizeNearestMode calc_nearest_pixel, int64_t* /* prefix_dim_sum */, diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu index 9f9c365d2a53d..6344845359b32 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu @@ -80,7 +80,7 @@ bool CanDoTranspose3D(const cudaDeviceProp& prop, size_t rank, const gsl::span& input_shape, - const TArray& input_strides, const void* input_data, void* output_data, int64_t N, + const TArray& input_strides, const void* input_data, void* output_data, int64_t /*N*/, const dim3& grid_size, const dim3& block_size) { switch (element_size) { HANDLE_TRANSPOSE_3D_TILE_DIM(int8_t); @@ -248,10 +248,10 @@ __global__ void Transpose4DKernelParallelizeOneElementPerThread( } bool CanDoTranspose4DParallelizeOneElementPerThread(const cudaDeviceProp& prop, - size_t element_size, + size_t /*element_size*/, int32_t rank, const gsl::span& input_dims, - const gsl::span& permutations, + const gsl::span& /*permutations*/, dim3& grid_size, dim3& block_size) { if (rank == 4) { // dims[3]: block.x diff --git a/onnxruntime/core/providers/cuda/triton_kernel.cu b/onnxruntime/core/providers/cuda/triton_kernel.cu index 6ffbf0420a15f..b42dbd0291b7a 100644 --- a/onnxruntime/core/providers/cuda/triton_kernel.cu +++ b/onnxruntime/core/providers/cuda/triton_kernel.cu @@ -130,27 +130,11 @@ void LoadOrtTritonKernel() { std::call_once(load_ort_triton_kernel_flag, TryToLoadKernel); } -Status LaunchTritonKernel(cudaStream_t stream, std::string fname, - int grid0, int grid1, int grid2, void* args, size_t args_size) { -#ifdef USE_TRITON_KERNEL - if (ort_triton_kernel_map.count(fname) == 0) { - // Return unsupported status if function name not found in registry. 
- // This error status will be used by TunableOp - std::ostringstream message_stream; - message_stream << "Can't find ort triton kernel name: " << fname; - std::string message = message_stream.str(); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(true, message); - } - auto idx = ort_triton_kernel_map[fname]; - return LaunchTritonKernel(stream, idx, grid0, grid1, grid2, args, args_size); -#else - return Status::OK(); -#endif -} -Status LaunchTritonKernel(cudaStream_t stream, size_t idx, - int grid0, int grid1, int grid2, void* args, size_t args_size) { + #ifdef USE_TRITON_KERNEL +Status LaunchTritonKernel(cudaStream_t stream, size_t idx, int grid0, int grid1, int grid2, + void* args, size_t args_size) { if (idx >= ort_triton_kernel_metadata.size()) { // Return unsupported status when idx exceeds the size of ort_triton_kernel_metadata. // This error status will be used by TunableOp @@ -181,11 +165,37 @@ Status LaunchTritonKernel(cudaStream_t stream, size_t idx, nullptr, (void**)&config), "Launching kernel failed."); -#endif return Status::OK(); } +Status LaunchTritonKernel(cudaStream_t stream, std::string fname, int grid0, int grid1, int grid2, + void* args, size_t args_size) { + if (ort_triton_kernel_map.count(fname) == 0) { + // Return unsupported status if function name not found in registry. + // This error status will be used by TunableOp + std::ostringstream message_stream; + message_stream << "Can't find ort triton kernel name: " << fname; + std::string message = message_stream.str(); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(true, message); + } + auto idx = ort_triton_kernel_map[fname]; + return LaunchTritonKernel(stream, idx, grid0, grid1, grid2, args, args_size); +} + +#else +Status LaunchTritonKernel(cudaStream_t /*stream*/, std::string /*fname*/, int /*grid0*/, int /*grid1*/, int /*grid2*/, + void* /*args*/, size_t /*args_size*/) { + return Status::OK(); +} + +Status LaunchTritonKernel(cudaStream_t /*stream*/, size_t /*idx*/, int /*grid0*/, int /*grid1*/, int /*grid2*/, + void* /*args*/, size_t /*args_size*/) { + return Status::OK(); +} +#endif + + const TritonKernelMetaData* GetOrtTritonKernelMetadata(size_t idx) { if (idx >= ort_triton_kernel_metadata.size()) { return nullptr; diff --git a/onnxruntime/core/providers/tensorrt/nv_includes.h b/onnxruntime/core/providers/tensorrt/nv_includes.h new file mode 100644 index 0000000000000..c3e9f7a3a2a77 --- /dev/null +++ b/onnxruntime/core/providers/tensorrt/nv_includes.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +// File to include the required TRT headers with workarounds for warnings we can't fix. 
+ +// Ignore warning C4100: unreferenced formal parameter +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h index bf3bf9e3495d7..9f1e5178428e7 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h @@ -6,7 +6,7 @@ #include #include -#include "NvInfer.h" +#include "core/providers/tensorrt/nv_includes.h" #include "core/providers/shared_library/provider_api.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 157cd0a200b35..e521640681a77 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -7,6 +7,7 @@ #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/common/common.h" +#include "core/common/narrow.h" #include "core/common/safeint.h" #include "tensorrt_execution_provider.h" #include "tensorrt_execution_provider_utils.h" @@ -137,10 +138,10 @@ std::vector SplitToStringVec(std::string const& s, char separator) return splitted; } -nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_sting) { +nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { nvinfer1::TacticSources disabledTactics = 0; nvinfer1::TacticSources enabledTactics = 0; - std::vector tacticList = SplitToStringVec(tactic_sting, ','); + std::vector tacticList = SplitToStringVec(tactic_string, ','); for (auto& t : tacticList) { bool enable{false}; if (t.front() == '+') { @@ -151,8 +152,8 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_sting) { t.erase(0, 1); const auto toUpper = [](std::string& sourceName) { - std::transform( - sourceName.begin(), sourceName.end(), sourceName.begin(), [](char c) { return std::toupper(c); }); + std::transform(sourceName.begin(), sourceName.end(), sourceName.begin(), + [](char c) { return onnxruntime::narrow(std::toupper(c)); }); return sourceName; }; @@ -288,7 +289,8 @@ void CudaCall(cudnnStatus_t retCode, const char* exprString return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); } -void* OutputAllocator::reallocateOutput(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept { +void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, + uint64_t /*alignment*/) noexcept { // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr // even for empty tensors, so allocate a dummy byte. 
size = std::max(size, static_cast(1)); @@ -304,7 +306,7 @@ void* OutputAllocator::reallocateOutput(char const* tensorName, void* currentMem return outputPtr; } -void OutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept { +void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept { output_shapes.clear(); output_shapes.reserve(dims.nbDims); for (int i = 0; i < dims.nbDims; i++) { @@ -613,20 +615,22 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector(shape_size); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData(), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream)); + auto input_shape = std::make_unique(shape_size); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_shape.get(), input_tensor.GetTensorData(), + shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream)); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); for (int j = 0; j < shape_size; ++j) { - tensor_shape_values[input_name][j] = input[j]; + tensor_shape_values[input_name][j] = input_shape[j]; } break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - auto input = std::make_unique(shape_size); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData(), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream)); + auto input_shape = std::make_unique(shape_size); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_shape.get(), input_tensor.GetTensorData(), + shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream)); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); for (int j = 0; j < shape_size; ++j) { - tensor_shape_values[input_name][j] = static_cast(input[j]); + tensor_shape_values[input_name][j] = static_cast(input_shape[j]); } break; } @@ -974,7 +978,7 @@ Status BindContextOutput(Ort::KernelContext& ctx, * we are waiting for ORT core to support "assign" memory address to ORT context output. Some works need to be done in ORT memory planner to be aware of this memory support. 
*/ Status BindKernelOutput(Ort::KernelContext& ctx, - OrtMemoryInfo* mem_info, + OrtMemoryInfo* /*mem_info*/, DDSOutputAllocatorMap& allocator_map, char const* output_name, size_t output_index, @@ -1143,7 +1147,8 @@ TensorrtExecutionProvider::PerThreadContext& TensorrtExecutionProvider::GetPerTh // get or create a context if (context_state_.retired_context_pool.empty()) { - context = std::make_shared(info_.device_id, info_.has_user_compute_stream, stream_); + context = std::make_shared(narrow(info_.device_id), + info_.has_user_compute_stream, stream_); } else { context = context_state_.retired_context_pool.back(); context_state_.retired_context_pool.pop_back(); @@ -1163,7 +1168,11 @@ TensorrtExecutionProvider::PerThreadContext& TensorrtExecutionProvider::GetPerTh } TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info), device_id_(info.device_id) { + : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, + narrow(info.device_id))}, + info_(info), + device_id_(info.device_id) { InitProviderOrtApi(); CUDA_CALL_THROW(cudaSetDevice(device_id_)); @@ -1655,7 +1664,8 @@ void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { std::vector TensorrtExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return CreateCUDAAllocator(device_id, onnxruntime::CUDA); }, device_id_); + [](OrtDevice::DeviceId device_id) { return CreateCUDAAllocator(device_id, onnxruntime::CUDA); }, + narrow(device_id_)); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { @@ -3036,7 +3046,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView std::unordered_set input_names; std::unordered_map> tensor_shape_values; - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_); + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_); if (alloc_ == nullptr) { Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); } @@ -3603,7 +3614,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con // int num_inputs = static_cast(input_indexes.size()); int num_outputs = static_cast(output_indexes.size()); - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_); + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_); if (alloc_ == nullptr) { Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 26f6b2dcc3020..339c45a8742d2 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -5,8 +5,9 @@ #include #include #include -#include "NvInfer.h" -#include "NvOnnxParser.h" + +#include 
"core/providers/tensorrt/nv_includes.h" + #include "core/platform/ort_mutex.h" #include "core/providers/cuda/cuda_graph.h" #include "tensorrt_execution_provider_info.h" diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index eb340ba1e64b6..b4f348159440f 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -1,12 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/framework/provider_options.h" #include "tensorrt_execution_provider_custom_ops.h" #include "tensorrt_execution_provider.h" -#include -#include -#include namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h index b19d9ab0f66d0..54212d34aa2ce 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h @@ -13,7 +13,8 @@ using namespace onnxruntime; namespace onnxruntime { common::Status LoadDynamicLibrary(onnxruntime::PathString library_name); -common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths); +common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, + const std::string extra_plugin_lib_paths); common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info); void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain); void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list); @@ -23,16 +24,22 @@ struct TensorRTCustomKernel { : compute_stream_(compute_stream) { } - void Compute(OrtKernelContext* context){}; // The implementation is in TensorRT plugin. No need to implement it here. + void Compute(OrtKernelContext* /*context*/){ + // The implementation is in TensorRT plugin. No need to implement it here. 
+ }; private: void* compute_stream_; }; struct TensorRTCustomOp : Ort::CustomOpBase { - explicit TensorRTCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {} + explicit TensorRTCustomOp(const char* provider, void* compute_stream) : provider_(provider), + compute_stream_(compute_stream) { + } - void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { return new TensorRTCustomKernel(info, compute_stream_); }; + void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { + return new TensorRTCustomKernel(info, compute_stream_); + }; const char* GetName() const { return name_; }; @@ -46,7 +53,9 @@ struct TensorRTCustomOp : Ort::CustomOpBase QK_Transpose(MLFloat16* q_matrix, MLFloat16* k_transpose_ // Softmax_QK_Transpose template -std::vector Softmax_QK_Transpose(T* qk_transpose_matrix, - int batch_size, int num_heads, int sequence_length, int total_sequence_length, int head_size); +std::vector Softmax_QK_Transpose(T* qk_transpose_matrix, int batch_size, int num_heads, + int sequence_length, int total_sequence_length, int head_size); template <> -std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, - int batch_size, int num_heads, int sequence_length, int total_sequence_length, int head_size) { +std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, int batch_size, int num_heads, + int sequence_length, int total_sequence_length, int /*head_size*/) { if (sequence_length != 1) { throw std::runtime_error("Not supported"); } @@ -506,8 +506,8 @@ std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, } template <> -std::vector Softmax_QK_Transpose(MLFloat16* qk_transpose_matrix, - int batch_size, int num_heads, int sequence_length, int total_sequence_length, int head_size) { +std::vector Softmax_QK_Transpose(MLFloat16* qk_transpose_matrix, int batch_size, int num_heads, + int sequence_length, int total_sequence_length, int /*head_size*/) { if (sequence_length != 1) { throw std::runtime_error("Not supported"); } diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc b/onnxruntime/test/providers/cpu/generator/random_test.cc index 16582696a81d4..532b98317405f 100644 --- a/onnxruntime/test/providers/cpu/generator/random_test.cc +++ b/onnxruntime/test/providers/cpu/generator/random_test.cc @@ -380,7 +380,7 @@ void RunRandomNormalGpuTest(const std::vector dims, const float mean, c test.AddOutput("Y", dims, fp16_data); } - auto output_verifier = [&](const std::vector& fetches, const std::string& provider_type) { + auto output_verifier = [&](const std::vector& fetches, const std::string& /*provider_type*/) { // Only one output, and mean of output values are near attribute mean. ASSERT_EQ(fetches.size(), 1u); const auto& output_tensor = fetches[0].Get(); @@ -472,7 +472,7 @@ void RunRandomUniformGpuTest(const std::vector dims, const float low, c test.AddOutput("Y", dims, fp16_data); } - auto output_verifier = [&](const std::vector& fetches, const std::string& provider_type) { + auto output_verifier = [&](const std::vector& fetches, const std::string& /*provider_type*/) { // Only one output. Each value in output tensoer is between low and high. // Mean of output values are near attribute mean of low and high. 
ASSERT_EQ(fetches.size(), 1u); diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc index 4c38c90c2b418..d7e8bf9063645 100644 --- a/onnxruntime/test/unittest_main/test_main.cc +++ b/onnxruntime/test/unittest_main/test_main.cc @@ -32,17 +32,30 @@ void ortenv_setup() { } #ifdef USE_TENSORRT + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) // Ignore warning C4100: unreferenced format parameter. +#endif + // TensorRT will load/unload libraries as builder objects are created and torn down. This will happen for // every single unit test, which leads to excessive test execution time due to that overhead. // Nvidia suggests to keep a placeholder builder object around to avoid this. #include "NvInfer.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + class DummyLogger : public nvinfer1::ILogger { public: - DummyLogger(Severity verbosity) {} - void log(Severity severity, const char* msg) noexcept override {} + DummyLogger(Severity /*verbosity*/) {} + void log(Severity /*severity*/, const char* /*msg*/) noexcept override {} }; DummyLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING); + auto const placeholder = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + #endif #define TEST_MAIN main diff --git a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc index d9800ce0e0d3e..d36f9b307ec70 100644 --- a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc @@ -311,11 +311,9 @@ template static std::vector RunSCELossWithEP(const char* op, int opset_version, const char* domain, - std::function()> - ep_creator, + std::function()> ep_creator, const std::string& reduction, const std::int64_t ignore_index, - const double error_tolerance, const std::vector* X_dims, const std::vector* index_dims, const std::vector* weight_dims, @@ -403,7 +401,7 @@ static void TestSCELoss(const char* op, int opset_version, cpu_fetches = RunSCELossWithEP( op, opset_version, domain, []() -> std::unique_ptr { return DefaultCpuExecutionProvider(); }, - reduction, ignore_index, error_tolerance, + reduction, ignore_index, X_dims, index_dims, weight_dims, Y_dims, log_prob_dims, X_data_temp, index_data, weight_data_temp); @@ -411,7 +409,7 @@ static void TestSCELoss(const char* op, int opset_version, cpu_fetches = RunSCELossWithEP( op, opset_version, domain, []() -> std::unique_ptr { return DefaultCpuExecutionProvider(); }, - reduction, ignore_index, error_tolerance, + reduction, ignore_index, X_dims, index_dims, weight_dims, Y_dims, log_prob_dims, X_data, index_data, weight_data); @@ -429,7 +427,7 @@ static void TestSCELoss(const char* op, int opset_version, return DefaultRocmExecutionProvider(); #endif }, - reduction, ignore_index, error_tolerance, + reduction, ignore_index, X_dims, index_dims, weight_dims, Y_dims, log_prob_dims, X_data, index_data, weight_data); diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc index d23905496c9bb..9b30bd128b161 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc @@ -105,7 +105,8 @@ struct AlgoSearch { CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, 
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms."); + static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos, + "Missing cuDNN convolution backward data algorithms."); int perf_count; std::unique_ptr candidates = std::make_unique(num_algos); if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { @@ -146,7 +147,9 @@ struct AlgoSearch { // NOTE: - 1 because ALGO_WINOGRAD is not implemented. static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos, + "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates = std::make_unique(num_algos); int perf_count; if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { @@ -188,7 +191,9 @@ struct AlgoSearch { }; static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos, + "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates = std::make_unique(num_algos); int perf_count; if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc index d3f5a89434a48..5d12e0ac312c0 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc @@ -53,7 +53,6 @@ Status ConvTransposeGrad::ComputeInputGradient(onnxruntime::Stream* stream, c algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.y_tensor, args.y_data)); return Status::OK(); }); - return Status::OK(); } template @@ -71,7 +70,6 @@ Status ConvTransposeGrad::ComputeWeightGradient(onnxruntime::Stream* stream, algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.w_desc, args.dw_data)); return Status::OK(); }); - return Status::OK(); } template diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu index 2d89ed05712e0..ad577afa06c18 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu @@ -30,8 +30,6 @@ namespace onnxruntime { namespace cuda { -using namespace onnxruntime::cuda; - namespace { // This is the un-specialized struct. Note that we prevent instantiation of this // struct by putting an undefined symbol in the function body so it won't compile. 
diff --git a/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu b/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu index c90809eb2fdcc..fd55f7c30ff75 100644 --- a/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu @@ -619,7 +619,7 @@ CudaKernel::CudaAsyncBuffer compute_tensor_rang template void LambMultiTensorReductionFunctor::operator()( - cudaStream_t stream, + cudaStream_t /*stream*/, ChunkGroup<4> chunk_group, const CudaKernel& kernel, void* reduction_buffer, diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml index 9516753d50113..864513bc4d671 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml @@ -93,8 +93,17 @@ steps: $ccache_parent_dir = (Split-Path -parent $ccache_path) Copy-Item "C:\ProgramData\chocolatey\lib\ccache\tools\ccache-4.7.4-windows-x86_64\ccache.exe" -Destination "C:\ProgramData\chocolatey\bin\cl.exe" Get-ChildItem $ccache_parent_dir - ccache --version } + + "ccache info:" + ccache --version + ccache --show-config + + "cl.exe from path: $((Get-Command cl).Path). Version:" + (cl.exe -?) -match 'Compiler Version' + "C:\ProgramData\chocolatey\bin\cl.exe version:" + (C:\ProgramData\chocolatey\bin\cl.exe -?) -match 'Compiler Version' + displayName: Install ccache and update PATH to use linked versions of gcc, cc, etc - ${{ if eq(parameters.WITHCACHE, true) }}: From e93a860819545ea64acfe36e19e2b954389d48bf Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Tue, 5 Mar 2024 21:54:48 -0800 Subject: [PATCH 116/237] Remove arm build for training (#19788) We no longer support Win arm 32 so removing the associated build and packaging job. 
--- .../ondevice-training-cpu-packaging-pipeline.yml | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml index cf39be23cbdaf..b3faaf2a7f1a6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml @@ -61,21 +61,6 @@ stages: buildJava: false buildNodejs: false -- template: win-ci.yml - parameters: - DoCompliance: ${{ parameters.DoCompliance }} - DoEsrp: ${{ parameters.DoEsrp }} - stage_name_suffix: Training_CPU_arm_${{ parameters.BuildVariant }} - artifact_name_suffix: -training - buildArch: x64 - msbuildPlatform: arm - packageName: arm - buildparameter: --arm ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe - runTests: false - buildJava: false - buildNodejs: false - ort_build_pool_name: onnxruntime-Win-CPU-2022 - - template: win-ci.yml parameters: DoCompliance: ${{ parameters.DoCompliance }} @@ -127,7 +112,6 @@ stages: - Linux_C_API_Packaging_Training_CPU - Windows_Packaging_Training_CPU_x86_${{ parameters.BuildVariant }} - Windows_Packaging_Training_CPU_x64_${{ parameters.BuildVariant }} - - Windows_Packaging_Training_CPU_arm_${{ parameters.BuildVariant }} - Windows_Packaging_Training_CPU_arm64_${{ parameters.BuildVariant }} - Android_Java_API_AAR_Packaging_Training_Full condition: succeeded() From d9bf85613d7171b54a6ece45fc0f241b008a1fd8 Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 6 Mar 2024 21:54:16 +0800 Subject: [PATCH 117/237] Adapt memory optimizer to fit PHI2 (#19757) ### Adapt memory optimizer to fit PHI2 A few improvements and bug fixes: 1. Fix a bug related to transformer layer detection. 2. Use the default topological order, reversed, to create recompute nodes, so that leaf nodes are not handled too late and left with the lowest execution priority. 3. Add an early stop when an activation's element count is constant and the total element count is < 1M. This avoids the overhead of searching subgraphs. Using `export ORTMODULE_MEMORY_OPT_LEVEL=1` to enable layerwise recompute, memory consumption on the given recipe dropped from ~22GB to ~13GB.
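As a minimal usage sketch (not part of this patch; it assumes an `onnxruntime-training` install where `ORTModule` is available, and the toy model below is a stand-in rather than the actual PHI-2 recipe), layerwise recompute is driven by the environment variable mentioned above:

```python
import os

# Assumption: the variable is read when ORTModule exports the model, so set it
# before the first forward pass. "1" enables layerwise recompute per this patch.
os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "1"

import torch
from onnxruntime.training.ortmodule import ORTModule

# Stand-in model; a real run would wrap the PHI-2 training module instead.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU(), torch.nn.Linear(16, 16))
model = ORTModule(model)  # memory optimizer is applied inside ORTModule

out = model(torch.randn(4, 16))
out.sum().backward()
```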
--- .../memory_optimizer/memory_insight.cc | 3 +- .../memory_optimizer/memory_optimizer.cc | 37 +++++++++++++++- .../memory_optimizer/recompute_analysis.cc | 18 +++++++- .../memory_optimizer/transformer_specific.cc | 42 +++++++++++++++++-- .../memory_optimizer/transformer_specific.h | 3 ++ 5 files changed, 95 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 08c402bf669c8..54c49db0597c7 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -258,7 +258,8 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, logger)); InlinedHashSet layer_boundary_ln_nodes; - FindLayerBoundaryLayerNormNodes(graph_viewer, logger, layer_boundary_ln_nodes); + FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map, + yield_op_order_in_topological_sort, layer_boundary_ln_nodes); // The first pass - find the candidate subgraphs. for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index 525e3b4b8de35..40fa2fc5cc737 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -190,11 +190,44 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve .IsOK()); // The second pass - apply the transformation. - // Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. + // Note 1: Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. // The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended // earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier // layers. - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + // + // Note 2: Here we use default typo order (which tries to BFS from the outputs, + // so the nearest node to graph output will be visited last). So in reversed default typo order, + // the neareast node to graph output will be visited first. + // Imagine there is a such subgraph + // input1 input2 input3 + // \ | / + // multiple layers + // | + // node M + // labels-------|----- + // \ | | + // node1 | | + // \ | | + // node2 / | + // \ / | + // node loss / + // | / + // YieldOp node1_recompute + // | / + // \ node2 recompute + // \ / + // node loss_grad + // | + // critical grad path + // + // In PriorityBased order, node1 will be visited first, so it's recompute node node1_recompute will be added + // at last because we do this following reversed topological order. Then node1_recompute node will have lowest + // priority to execute, as a result, if at this time, the queue to visit contains only recompute nodes, then + // node1_recompute will be run at last, affecting the backward critical path, which is not what we want. + // Current workaround is to use default order, which will execute node1_recompute earlier than other recompute nodes + // in this case. 
+ + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { Node* p_node = graph.GetNode(node_ids[i]); if (p_node == nullptr) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 12c83591c0036..76b3325f36116 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -19,7 +19,7 @@ namespace onnxruntime::optimizer::memory_optimizer { namespace { -constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 15; +constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 50; static size_t GetElementSize(const ONNX_NAMESPACE::DataType& tensor_type) { const ONNX_NAMESPACE::TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); @@ -291,6 +291,22 @@ Status SelectRecomputeSubgraph(const Node& entry_node, const auto current_node_input_index = input_edge.GetDstArgIndex(); if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) != input_arg_indices.end()) { + // If the tensor size is constant and very small (Now < 1M), we stop adding the input edge into queue. + auto output_shape = parent_node.OutputDefs()[parent_node_output_index]->Shape(); + if (output_shape) { + bool all_constant_dim = true; + int64_t num_elem = 1; + for (int k = 0, dim_size = output_shape->dim_size(); k < dim_size; ++k) { + if (!output_shape->dim(k).has_dim_value()) { + all_constant_dim = false; + num_elem *= output_shape->dim(k).dim_value(); + } + } + if (all_constant_dim && num_elem < 1 * 1024 * 1024) { + // Skip this input index. + continue; + } + } NodeOutputPort next_p = std::make_pair(&parent_node, parent_node_output_index); MO_LOG_DEBUG_INFO(logger, "Node " + parent_node.Name() + "(" + parent_node.OpType() + ")'s " + diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index 04f2679ac774f..c88a0f05d36b8 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -19,6 +19,9 @@ namespace onnxruntime::optimizer::memory_optimizer { void FindLayerBoundaryLayerNormNodes( const GraphViewer& graph_viewer, const logging::Logger&, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const ptrdiff_t& yield_op_order_in_topological_sort, InlinedHashSet& layer_boundary_ln_nodes) { // Loop all nodes to find LayerNormalization nodes. // For each LayerNormalization node, keep checking its output nodes, @@ -40,9 +43,16 @@ void FindLayerBoundaryLayerNormNodes( std::deque nodes_to_check; std::set visited_nodes; for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) { - nodes_to_check.push_back(&(*node_it)); + // Ignore those nodes after YieldOp. 
+ if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) < yield_op_order_in_topological_sort) { + nodes_to_check.push_back(&(*node_it)); + } } + bool unexpected_failure = false; + bool found_softmax = false; + bool found_layernorm = false; + ptrdiff_t next_layernorm_execution_oder = -1; while (!nodes_to_check.empty()) { const Node* next_node = nodes_to_check.front(); nodes_to_check.pop_front(); @@ -53,16 +63,40 @@ void FindLayerBoundaryLayerNormNodes( visited_nodes.insert(next_node); if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) { - layer_boundary_ln_nodes.insert(&node); - break; + found_softmax = true; } else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) { - break; + if (found_layernorm) { + // If we found another LayerNormalization node, we would report as warning, and do nothing for layer boundary detection. + unexpected_failure = true; + break; + } + found_layernorm = true; // don't trace further + next_layernorm_execution_oder = node_index_to_its_order_in_topological_sort_map.at(next_node->Index()); + continue; } else { for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { + // Stop if the node is after next Layernorm node in execution order. + if (found_layernorm && + node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= next_layernorm_execution_oder) { + continue; + } nodes_to_check.push_back(&(*node_it)); } } } + + if (unexpected_failure) { + layer_boundary_ln_nodes.clear(); + break; + } + + if (found_softmax) { + layer_boundary_ln_nodes.insert(&node); + } else if (!found_layernorm) { + // If no Softmax found, and no other LayerNormalization found, this should be the last LayerNormalization node, + // we also consider it as boundary node. + layer_boundary_ln_nodes.insert(&node); + } } } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h index f2cfd640b0840..b58d822124f43 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h @@ -20,6 +20,9 @@ namespace onnxruntime::optimizer::memory_optimizer { void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer, const logging::Logger& logger, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const ptrdiff_t& yield_op_order_in_topological_sort, InlinedHashSet& layer_boundary_ln_nodes); } // namespace onnxruntime::optimizer::memory_optimizer From f9a92e589ad8588424725a91bbd0683a63bda950 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 6 Mar 2024 09:10:35 -0800 Subject: [PATCH 118/237] Upgrade the Windows SDK version that is used in WindowsAI Nuget Packaging pipeline (#19786) ### Description 1. Upgrade the version from 10.0.19041.0 to 10.0.22621.0. The old one misses some macros that are needed by PyTorch's CPUINFO 2. Also update cmake. ### Motivation and Context In PR #19655 I added CPUINFO to all Windows builds, but forgot to test this pipeline. 
--- .pipelines/windowsai-steps.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pipelines/windowsai-steps.yml b/.pipelines/windowsai-steps.yml index ff5179e6135c2..855573de753b0 100644 --- a/.pipelines/windowsai-steps.yml +++ b/.pipelines/windowsai-steps.yml @@ -80,11 +80,11 @@ jobs: # must call vsdevcmd first to add cmake to PATH - script: | - curl -O -L https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-windows-x86_64.zip - 7z x cmake-3.26.3-windows-x86_64.zip + curl -O -L https://github.com/Kitware/CMake/releases/download/v3.28.3/cmake-3.28.3-windows-x86_64.zip + 7z x cmake-3.28.3-windows-x86_64.zip set PYTHONHOME=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools set PYTHONPATH=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools - $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe + $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" --cmake_path $(Build.BinariesDirectory)\cmake-3.28.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.28.3-windows-x86_64\bin\ctest.exe workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Generate cmake config' From db8d0c8e06fd030da6b7bf00cf3fb20661dd13b8 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Wed, 6 Mar 2024 11:21:19 -0800 Subject: [PATCH 119/237] reset dcvsEnable for different HTP performance mode (#19728) reset dcvsEnable for different HTP performance mode --- .../qnn/builder/qnn_backend_manager.cc | 80 ++++++++++--------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index e354bf6562722..6bb57b6a3e56c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -678,13 +678,13 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, dcvs_v3.setSleepDisable = 0; dcvs_v3.sleepDisable = 0; dcvs_v3.setDcvsEnable = 1; - dcvs_v3.dcvsEnable = kDcvsDisable; dcvs_v3.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_PERFORMANCE_MODE; // choose performance mode switch (htp_performance_mode) { case 
HtpPerformanceMode::kHtpBurst: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMinLatency; + dcvs_v3.dcvsEnable = kDcvsDisable; dcvs_v3.setBusParams = 1; dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_MAX_VOLTAGE_CORNER; dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_MAX_VOLTAGE_CORNER; @@ -698,6 +698,7 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, case HtpPerformanceMode::kHtpHighPerformance: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepLowLatency; + dcvs_v3.dcvsEnable = kDcvsDisable; dcvs_v3.setBusParams = 1; dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_TURBO; dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_TURBO; @@ -707,33 +708,36 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_TURBO; dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_TURBO; break; - case HtpPerformanceMode::kHtpPowerSaver: + case HtpPerformanceMode::kHtpBalanced: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMediumLatency; + dcvs_v3.dcvsEnable = kDcvsEnable; dcvs_v3.setBusParams = 1; - dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS; - dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS; - dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS; + dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM_PLUS; + dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM_PLUS; + dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS; dcvs_v3.setCoreParams = 1; - dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS; - dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS; - dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS; + dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM_PLUS; + dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM_PLUS; + dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS; break; - case HtpPerformanceMode::kHtpLowPowerSaver: + case HtpPerformanceMode::kHtpLowBalanced: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMediumLatency; + dcvs_v3.dcvsEnable = kDcvsEnable; dcvs_v3.setBusParams = 1; - dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS2; - dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS2; - dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS2; + dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM; + dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM; + dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM; dcvs_v3.setCoreParams = 1; - dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS2; - dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS2; - dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS2; + dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM; + dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM; + dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM; break; case HtpPerformanceMode::kHtpHighPowerSaver: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMediumLatency; + dcvs_v3.dcvsEnable = kDcvsEnable; dcvs_v3.setBusParams = 1; dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS_PLUS; dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS_PLUS; @@ -743,41 +747,45 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS_PLUS; dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS_PLUS; break; - case HtpPerformanceMode::kHtpExtremePowerSaver: + case 
HtpPerformanceMode::kHtpPowerSaver: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMediumLatency; + dcvs_v3.dcvsEnable = kDcvsEnable; dcvs_v3.setBusParams = 1; - dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_CORNER_DISABLE; - dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_CORNER_DISABLE; - dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_CORNER_DISABLE; + dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS; + dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS; + dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS; dcvs_v3.setCoreParams = 1; - dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_CORNER_DISABLE; - dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_CORNER_DISABLE; - dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_CORNER_DISABLE; + dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS; + dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS; + dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS; break; - case HtpPerformanceMode::kHtpLowBalanced: + case HtpPerformanceMode::kHtpLowPowerSaver: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMediumLatency; + dcvs_v3.dcvsEnable = kDcvsEnable; dcvs_v3.setBusParams = 1; - dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM; - dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM; - dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM; + dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS2; + dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS2; + dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS2; dcvs_v3.setCoreParams = 1; - dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM; - dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM; - dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM; + dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS2; + dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS2; + dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS2; break; - case HtpPerformanceMode::kHtpBalanced: + case HtpPerformanceMode::kHtpExtremePowerSaver: + dcvs_v3.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_POWER_SAVER_MODE; dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMediumLatency; + dcvs_v3.dcvsEnable = kDcvsEnable; dcvs_v3.setBusParams = 1; - dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM_PLUS; - dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM_PLUS; - dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS; + dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_CORNER_DISABLE; + dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_CORNER_DISABLE; + dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_CORNER_DISABLE; dcvs_v3.setCoreParams = 1; - dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM_PLUS; - dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM_PLUS; - dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS; + dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_CORNER_DISABLE; + dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_CORNER_DISABLE; + dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_CORNER_DISABLE; break; default: ORT_THROW("Invalid performance profile %d", static_cast(htp_performance_mode)); From 8bd1335d00375179fa9cdccf1c6fbda8c04304df Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Wed, 6 Mar 2024 12:34:33 -0800 Subject: [PATCH 120/237] Fix GQA Rotary Embedding sequence length (#19801) ### Description Previously, GQA incorrectly enforced rotary cos and sin cache to be of sequence length equal to present sequence length. 
Now it enforces that the cache length be greater than or equal to the present sequence length, since to match the Rotary Embedding Op it should be of max_sequence_length. ### Motivation and Context Fixes an issue with fusing Rotary Embedding and GQA for certain models which prefer this optimization. --- .../cuda/bert/group_query_attention_helper.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 853e1a710cb24..6fa11200fd5be 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -214,13 +214,13 @@ Status CheckInputs(const Tensor* query, "head_size shall be a multiple of 16. Got head_size % 16 == ", head_size % 16); } - if (cos_dims[0] != present_sequence_length) { + if (cos_dims[0] < present_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 must be of present_sequence_length."); + "cos_cache dimension 0 should be of max_sequence_length."); } - if (sin_dims[0] != present_sequence_length) { + if (sin_dims[0] < present_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 must be of present_sequence_length."); + "sin_cache dimension 0 should be of max_sequence_length."); } if (cos_dims[1] != (head_size / 16) * 8) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, From f2dc725b3355ec25e61d6970b6c030c68f9d3ac4 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Wed, 6 Mar 2024 21:35:55 +0100 Subject: [PATCH 121/237] Add SpaceToDepth and DepthToSpace CUDA NHWC Ops (#19646) ### Description - Add CUDA NHWC support for SpaceToDepth and DepthToSpace - Add a new test which verifies that the SpaceToDepth swizzling for the H axis is correct. - If CUDA NHWC is enabled, run all tests on the CUDA EP with NHWC as well. ### Motivation and Context Adding more NHWC operations avoids layout transformations when using the CUDA EP, improving efficiency.
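For reference, the NHWC kernels added below express SpaceToDepth as a reshape to a virtual 6-D shape followed by a transpose with permutation {0, 1, 3, 2, 4, 5}. A minimal NumPy sketch of that swizzle (illustrative only, not part of the patch; the function name is made up):

```python
import numpy as np

def space_to_depth_nhwc_reference(x: np.ndarray, block: int) -> np.ndarray:
    """Reference SpaceToDepth on an NHWC tensor, mirroring the kernel's
    virtual 6-D reshape followed by a {0, 1, 3, 2, 4, 5} transpose."""
    n, h, w, c = x.shape
    assert h % block == 0 and w % block == 0
    # Virtual input shape: [N, H/b, b, W/b, b, C]
    x = x.reshape(n, h // block, block, w // block, block, c)
    # Move the two block axes next to the channel axis: [N, H/b, W/b, b, b, C]
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # Collapse (b, b, C) into the output depth.
    return x.reshape(n, h // block, w // block, block * block * c)

# Tiny usage example: a 1x4x4x1 input with block size 2 becomes 1x2x2x4.
sample = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
print(space_to_depth_nhwc_reference(sample, 2).shape)  # (1, 2, 2, 4)
```

DepthToSpace in the patch follows the same reshape/transpose pattern, with different virtual shapes and permutations for the DCR and CRD modes.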
--- include/onnxruntime/core/graph/constants.h | 1 + .../contrib_ops/internal_nhwc_onnx_schemas.cc | 1 + .../layout_transformation.cc | 3 +- .../providers/cpu/tensor/space_depth_ops.h | 16 +- .../core/providers/cuda/cuda_nhwc_kernels.cc | 16 ++ .../providers/cuda/tensor/space_depth_ops.cc | 196 +++++++++++++----- .../providers/cuda/tensor/space_depth_ops.h | 2 + .../test/contrib_ops/gridsample_test.cc | 17 +- onnxruntime/test/providers/base_tester.cc | 7 + .../providers/cpu/generator/random_test.cc | 12 +- .../providers/cpu/nn/batch_norm_op_test.cc | 6 +- .../test/providers/cpu/nn/conv_op_test.cc | 2 + .../cpu/nn/conv_transpose_op_test.cc | 15 +- .../test/providers/cpu/nn/pool_op_test.cc | 86 ++++---- .../cpu/reduction/reduction_ops_test.cc | 3 + .../test/providers/cpu/rnn/rnn_op_test.cc | 7 +- .../cpu/tensor/gather_elements_op_test.cc | 2 +- .../providers/cpu/tensor/resize_op_test.cc | 22 +- .../providers/cpu/tensor/scatter_op_test.cc | 7 +- .../cpu/tensor/space_depth_ops_test.cc | 47 +++++ .../providers/cpu/tensor/upsample_op_test.cc | 6 +- 21 files changed, 345 insertions(+), 129 deletions(-) diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 9b26ba914c7dd..8e04050d089a0 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -31,6 +31,7 @@ constexpr size_t kMaxExecutionProviderNameLen = 30; constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; +constexpr const char* kCudaNHWCExecutionProvider = "CUDANHWCExecutionProvider"; constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider"; diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index c8960578f9e3d..6bf19654a3ce9 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -106,6 +106,7 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function& GetCUDALayoutSensitiveOps() { "GlobalAveragePool", "AveragePool", "GridSample", - }; + "DepthToSpace", + "SpaceToDepth"}; }(); return cuda_nhwc_ops; } diff --git a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h index 7d117317ba172..3218c8952d6ec 100644 --- a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h +++ b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h @@ -14,6 +14,7 @@ class SpaceDepthBase { "Attribute blocksize is not set."); } + template Status InputValidationsAndOutputDimsCalc(const Tensor& input, int64_t& batch, int64_t& input_depth, int64_t& input_height, int64_t& input_width, @@ -27,9 +28,15 @@ class SpaceDepthBase { } batch = input_shape[0]; - input_depth = input_shape[1]; - input_height = input_shape[2]; - input_width = input_shape[3]; + if constexpr (IsNHWC) { + input_depth = input_shape[3]; + input_height = input_shape[1]; + input_width = input_shape[2]; + } else { + input_depth = input_shape[1]; + input_height = input_shape[2]; + input_width = input_shape[3]; + } if (is_space_to_depth) { // SpaceToDepth op if ((input_height % this->blocksize_) != 0) { @@ -46,7 +53,8 @@ class SpaceDepthBase { } else { // DepthToSpace op if ((input_depth % 
(blocksize_ * blocksize_) != 0)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "DepthToSpace requires input depth to be a multiple of (block_size * blok_size)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DepthToSpace requires input depth to be a multiple of (block_size * block_size)"); } output_depth = input_depth / blocksize_ / blocksize_; diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc index 64edc319e15ac..da7802fe8d5dc 100644 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -86,6 +86,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalN BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, DepthToSpace); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, SpaceToDepth); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, SpaceToDepth); Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn nhwc_function_table[] = { @@ -171,6 +176,17 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, ConvTranspose)>, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : nhwc_function_table) { diff --git a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc index 407a2ef3981f1..aaaf3600b676e 100644 --- a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc +++ b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc @@ -20,7 +20,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - SpaceToDepth); + SpaceToDepth); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + SpaceToDepth, + kMSInternalNHWCDomain, + 1, + 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + SpaceToDepth); +#endif ONNX_OPERATOR_KERNEL_EX( SpaceToDepth, @@ -32,7 +47,21 @@ ONNX_OPERATOR_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - SpaceToDepth); + SpaceToDepth); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_KERNEL_EX( + SpaceToDepth, + kMSInternalNHWCDomain, + 13, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + SpaceToDepth); +#endif ONNX_OPERATOR_VERSIONED_KERNEL_EX( DepthToSpace, @@ -45,7 +74,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - DepthToSpace); + DepthToSpace); + +#ifdef 
ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + DepthToSpace, + kMSInternalNHWCDomain, + 1, + 10, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + DepthToSpace); +#endif ONNX_OPERATOR_VERSIONED_KERNEL_EX( DepthToSpace, @@ -58,7 +102,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - DepthToSpace); + DepthToSpace); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + DepthToSpace, + kMSInternalNHWCDomain, + 11, + 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + DepthToSpace); +#endif ONNX_OPERATOR_KERNEL_EX( DepthToSpace, @@ -70,23 +129,35 @@ ONNX_OPERATOR_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - DepthToSpace); + DepthToSpace); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_KERNEL_EX( + DepthToSpace, + kMSInternalNHWCDomain, + 13, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + DepthToSpace); +#endif static Status SpaceDepthOpCudaImpl(const cudaDeviceProp& prop, cudaStream_t stream, const cublasHandle_t cublas_handle, const Tensor& input, Tensor& output, const std::vector& permutation, - const int64_t batch_size, - const int64_t in_dim1, const int64_t in_dim2, const int64_t in_dim3, - const int64_t in_dim4, const int64_t in_dim5, + const TensorShape& virtual_input_shape, const TensorShape& virtual_output_shape) { - TensorShape virtual_input_shape{batch_size, in_dim1, in_dim2, in_dim3, in_dim4, in_dim5}; return Transpose::DoTranspose(prop, stream, cublas_handle, permutation, input, output, &virtual_input_shape, &virtual_output_shape); } -Status SpaceToDepth::ComputeInternal(OpKernelContext* context) const { +template +Status SpaceToDepth::ComputeInternal(OpKernelContext* context) const { const auto* tensor_pointer = context->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const Tensor& input = *tensor_pointer; @@ -101,29 +172,44 @@ Status SpaceToDepth::ComputeInternal(OpKernelContext* context) const { int64_t output_height = -1; int64_t output_width = -1; - ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc(input, - batch, - input_depth, input_height, input_width, - output_depth, output_height, output_width, - true)); + ORT_RETURN_IF_ERROR( + InputValidationsAndOutputDimsCalc(input, + batch, + input_depth, input_height, input_width, + output_depth, output_height, output_width, + true)); // We use the "actual" output shape to construct the output tensor - Tensor& output = *context->Output(0, {batch, output_depth, output_height, output_width}); + Tensor& output = (Layout == LAYOUT_NCHW) + ? *context->Output(0, {batch, output_depth, output_height, output_width}) + : *context->Output(0, {batch, output_height, output_width, output_depth}); + + TensorShape virtual_input_shape = (Layout == LAYOUT_NCHW) + ? 
TensorShape{batch, input_depth, input_height / blocksize_, + blocksize_, input_width / blocksize_, blocksize_} + : TensorShape{batch, input_height / blocksize_, blocksize_, + input_width / blocksize_, blocksize_, input_depth}; // We will pass in the "virtual" output shape to be used by DoTranspose() in SpaceDepthOpCudaImpl(...) - TensorShape virtual_output_shape{batch, blocksize_, blocksize_, input_depth, - input_height / blocksize_, input_width / blocksize_}; + TensorShape virtual_output_shape = (Layout == LAYOUT_NCHW) + ? TensorShape{batch, blocksize_, blocksize_, input_depth, + input_height / blocksize_, input_width / blocksize_} + : TensorShape{batch, input_height / blocksize_, input_width / blocksize_, + blocksize_, blocksize_, input_depth}; - std::vector permutation = {0, 3, 5, 1, 2, 4}; + std::vector permutation = (Layout == LAYOUT_NCHW) + ? std::vector{0, 3, 5, 1, 2, 4} + : std::vector{0, 1, 3, 2, 4, 5}; - ORT_RETURN_IF_ERROR(SpaceDepthOpCudaImpl(GetDeviceProp(), Stream(context), GetCublasHandle(context), input, output, permutation, batch, - input_depth, input_height / blocksize_, blocksize_, input_width / blocksize_, blocksize_, - virtual_output_shape)); + ORT_RETURN_IF_ERROR( + SpaceDepthOpCudaImpl(GetDeviceProp(), Stream(context), GetCublasHandle(context), input, output, permutation, + virtual_input_shape, virtual_output_shape)); return Status::OK(); } -Status DepthToSpace::ComputeInternal(OpKernelContext* context) const { +template +Status DepthToSpace::ComputeInternal(OpKernelContext* context) const { const auto* tensor_pointer = context->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const Tensor& input = *tensor_pointer; @@ -138,46 +224,56 @@ Status DepthToSpace::ComputeInternal(OpKernelContext* context) const { int64_t output_height = -1; int64_t output_width = -1; - ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc(input, - batch, - input_depth, input_height, input_width, - output_depth, output_height, output_width, - false)); + ORT_RETURN_IF_ERROR( + InputValidationsAndOutputDimsCalc(input, + batch, + input_depth, input_height, input_width, + output_depth, output_height, output_width, + false)); // We use the "actual" output shape to construct the output tensor - Tensor& output = *context->Output(0, {batch, output_depth, output_height, output_width}); + Tensor& output = (Layout == LAYOUT_NCHW) + ? *context->Output(0, {batch, output_depth, output_height, output_width}) + : *context->Output(0, {batch, output_height, output_width, output_depth}); + + int64_t virtual_input_depth = input_depth / blocksize_ / blocksize_; + TensorShape virtual_input_shape; + + // cdr only here! + if (is_dcr_) { + virtual_input_shape = (Layout == LAYOUT_NCHW) + ? TensorShape{batch, blocksize_, blocksize_, + virtual_input_depth, input_height, input_width} + : TensorShape{batch, input_height, input_width, + blocksize_, blocksize_, virtual_input_depth}; + } else { + virtual_input_shape = (Layout == LAYOUT_NCHW) + ? TensorShape{batch, virtual_input_depth, blocksize_, + blocksize_, input_height, input_width} + : TensorShape{batch, input_height, input_width, + virtual_input_depth, blocksize_, blocksize_}; + } // We will pass in the "virtual" output shape to be used by DoTranspose() in SpaceDepthOpCudaImpl(...) - TensorShape virtual_output_shape{batch, input_depth / blocksize_ / blocksize_, - input_height, blocksize_, input_width, blocksize_}; + TensorShape virtual_output_shape = (Layout == LAYOUT_NCHW) + ? 
TensorShape{batch, virtual_input_depth, input_height, + blocksize_, input_width, blocksize_} + : TensorShape{batch, input_height, blocksize_, + input_width, blocksize_, virtual_input_depth}; std::vector permutation; - permutation.reserve(6); - permutation.push_back(0); if (is_dcr_) { - permutation.push_back(3); - permutation.push_back(4); - permutation.push_back(1); - permutation.push_back(5); - permutation.push_back(2); + permutation = (Layout == LAYOUT_NCHW) + ? std::vector({0, 3, 4, 1, 5, 2}) + : std::vector({0, 1, 3, 2, 4, 5}); } else { - permutation.push_back(1); - permutation.push_back(4); - permutation.push_back(2); - permutation.push_back(5); - permutation.push_back(3); + permutation = std::vector({0, 1, 4, 2, 5, 3}); } - int64_t dim1 = is_dcr_ ? blocksize_ : input_depth / blocksize_ / blocksize_; - int64_t dim3 = is_dcr_ ? input_depth / blocksize_ / blocksize_ : blocksize_; - ORT_RETURN_IF_ERROR(SpaceDepthOpCudaImpl(GetDeviceProp(), Stream(context), GetCublasHandle(context), input, output, - permutation, - batch, - dim1, blocksize_, dim3, input_height, input_width, - virtual_output_shape)); + permutation, virtual_input_shape, virtual_output_shape)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h index 57b85556f1dbe..8780d9b365005 100644 --- a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h +++ b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h @@ -9,6 +9,7 @@ namespace onnxruntime { namespace cuda { +template class SpaceToDepth final : public CudaKernel, SpaceDepthBase { public: explicit SpaceToDepth(const OpKernelInfo& info) : CudaKernel(info), SpaceDepthBase(info) { @@ -17,6 +18,7 @@ class SpaceToDepth final : public CudaKernel, SpaceDepthBase { Status ComputeInternal(OpKernelContext* context) const override; }; +template class DepthToSpace final : public CudaKernel, SpaceDepthBase { public: explicit DepthToSpace(const OpKernelInfo& info) : CudaKernel(info), SpaceDepthBase(info) { diff --git a/onnxruntime/test/contrib_ops/gridsample_test.cc b/onnxruntime/test/contrib_ops/gridsample_test.cc index 1f31c2bd21f14..46ed04301a9e8 100644 --- a/onnxruntime/test/contrib_ops/gridsample_test.cc +++ b/onnxruntime/test/contrib_ops/gridsample_test.cc @@ -32,7 +32,7 @@ TEST(GridsampleContribOpTest, gridsample_default) { 3.8000f, 7.9000f, 8.7000f, 9.5000f, 10.3000f, 5.3000f, 5.4000f, 11.1000f, 11.9000f, 12.7000f, 13.5000f, 6.9000f, 3.0000f, 6.1500f, 6.5500f, 6.9500f, 7.3500f, 3.7500f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(GridsampleContribOpTest, gridsample_paddingmode_zeros) { @@ -45,7 +45,7 @@ TEST(GridsampleContribOpTest, gridsample_paddingmode_zeros) { 5.0000f, 5.0000f, 10.0000f, 10.0000f}); test.AddAttribute("padding_mode", "zeros"); test.AddOutput("Y", {1, 1, 2, 4}, {0.0000f, 0.0000f, 1.7000f, 0.0000f, 0.0000f, 1.7000f, 0.0000f, 0.0000f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(GridsampleContribOpTest, gridsample_paddingmode_border) { @@ -58,7 +58,7 @@ TEST(GridsampleContribOpTest, gridsample_paddingmode_border) { 5.0000f, 5.0000f, 10.0000f, 10.0000f}); test.AddAttribute("padding_mode", "border"); test.AddOutput("Y", {1, 1, 2, 4}, {0.0000f, 0.0000f, 1.7000f, 5.0000f, 5.0000f, 1.7000f, 5.0000f, 5.0000f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } 
TEST(GridsampleContribOpTest, gridsample_paddingmode_reflection) { @@ -71,7 +71,8 @@ TEST(GridsampleContribOpTest, gridsample_paddingmode_reflection) { 5.0000f, 5.0000f, 10.0000f, 10.0000f}); test.AddAttribute("padding_mode", "reflection"); test.AddOutput("Y", {1, 1, 2, 4}, {2.5000f, 0.0000f, 1.7000f, 2.5000f, 2.5000f, 1.7000f, 5.0000f, 2.5000f}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); // Accuracy issue for QNN + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kQnnExecutionProvider}); // Accuracy issue for QNN } TEST(GridsampleContribOpTest, gridsample_aligncorners_true) { @@ -86,7 +87,7 @@ TEST(GridsampleContribOpTest, gridsample_aligncorners_true) { test.AddAttribute("mode", "bilinear"); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", {1, 1, 2, 4}, {0.0000f, 1.2500f, 2.0000f, 2.5000f, 2.5000f, 2.0000f, 3.7500f, 5.0000f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(GridsampleContribOpTest, gridsample_mode_bilinear) { @@ -99,7 +100,7 @@ TEST(GridsampleContribOpTest, gridsample_mode_bilinear) { 0.5000f, 0.5000f, 1.0000f, 1.0000f}); test.AddAttribute("mode", "bilinear"); test.AddOutput("Y", {1, 1, 2, 4}, {0.0000f, 0.5000f, 1.7000f, 2.5000f, 2.5000f, 1.7000f, 4.5000f, 1.2500f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(GridsampleContribOpTest, gridsample_mode_nearest) { @@ -112,7 +113,7 @@ TEST(GridsampleContribOpTest, gridsample_mode_nearest) { 0.5000f, 0.5000f, 1.0000f, 1.0000f}); test.AddAttribute("mode", "nearest"); test.AddOutput("Y", {1, 1, 2, 4}, {0.f, 0.f, 2.f, 2.f, 2.f, 2.f, 5.f, 0.f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(GridsampleContribOpTest, gridsample_mode_bicubic) { @@ -125,7 +126,7 @@ TEST(GridsampleContribOpTest, gridsample_mode_bicubic) { 0.5000f, 0.5000f, 1.0000f, 1.0000f}); test.AddAttribute("mode", "bicubic"); test.AddOutput("Y", {1, 1, 2, 4}, {-0.1406f, 0.3828f, 1.7556f, 2.9688f, 2.9688f, 1.7556f, 5.1445f, 1.3906f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } } // namespace test diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 16cce85f7cb0a..84cb663a2984a 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -622,6 +622,9 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, static const std::string all_provider_types[] = { kCpuExecutionProvider, kCudaExecutionProvider, +#ifdef ENABLE_CUDA_NHWC_OPS + kCudaNHWCExecutionProvider, +#endif kDnnlExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, @@ -650,6 +653,10 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, execution_provider = DefaultCpuExecutionProvider(); else if (provider_type == onnxruntime::kCudaExecutionProvider) execution_provider = DefaultCudaExecutionProvider(); +#ifdef ENABLE_CUDA_NHWC_OPS + else if (provider_type == onnxruntime::kCudaNHWCExecutionProvider) + execution_provider = DefaultCudaNHWCExecutionProvider(); +#endif else if (provider_type == onnxruntime::kDnnlExecutionProvider) execution_provider = DefaultDnnlExecutionProvider(); else if (provider_type == onnxruntime::kOpenVINOExecutionProvider) diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc 
b/onnxruntime/test/providers/cpu/generator/random_test.cc index 532b98317405f..be049d1cf0ce3 100644 --- a/onnxruntime/test/providers/cpu/generator/random_test.cc +++ b/onnxruntime/test/providers/cpu/generator/random_test.cc @@ -36,7 +36,8 @@ TEST(Random, RandomNormal2DDouble) { // The expected_output is generated using std lib, which is used by CPU kernel only. // So we need to exclude other EPs here. Ditto for other places. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); } void RunRandomNormalLike3DFloat(bool infer_dtype = false) { @@ -72,7 +73,8 @@ void RunRandomNormalLike3DFloat(bool infer_dtype = false) { test.AddOutput("Y", dims, expected_output); // TensorRT does not support manual seed overrides and there will be result mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); } TEST(Random, RandomNormalLike3DDouble) { @@ -109,7 +111,8 @@ TEST(Random, RandomUniform1DFloat) { test.AddOutput("Y", dims, expected_output); // TensorRT does not support manual seed overrides and there will be result mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); } void RunRandomUniformLikeTest(bool infer_dtype = false) { @@ -142,7 +145,8 @@ void RunRandomUniformLikeTest(bool infer_dtype = false) { test.AddOutput("Y", dims, expected_output); // TensorRT does not support seed parameter and there will be result mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); } TEST(Random, RandomUniformLike2DDouble) { diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index 54e5c71bd753a..3d30fc62a945d 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -917,7 +917,7 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) { // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm() test.Run(OpTester::ExpectResult::kExpectSuccess, "", // TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1 - {kCudaExecutionProvider, kRocmExecutionProvider, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } @@ -945,7 +945,7 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) { // exclude CUDA Execution Provider due to flakiness // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm() test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, 
kRocmExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } @@ -972,7 +972,7 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) { // Same exclusions as the opset 14 test test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } #endif // BATCHNORM_INCLUDE_TRAINING_SUPPORT diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index dede278b7274f..0efa78af2795c 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -59,6 +59,8 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, std::unordered_set excluded_providers(attributes.excluded_providers); // Disable TensorRT because weight as input is not supported excluded_providers.insert(kTensorrtExecutionProvider); + // Disable CUDA NHWC execution provider as it is currently flaky + excluded_providers.insert(kCudaNHWCExecutionProvider); // QNN SDK 2.10.0 has a bug that breaks support for dynamic bias inputs. excluded_providers.insert(kQnnExecutionProvider); diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 472f841aa8565..ec93dc249eeb2 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -75,7 +75,8 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, const vector& expected_output_shape, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", - const std::unordered_set& excluded_provider_types = {kTensorrtExecutionProvider, kQnnExecutionProvider}) { + const std::unordered_set& excluded_provider_types = + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}) { std::unordered_set extra_exclude_openvino_for_initializer_filter = excluded_provider_types; extra_exclude_openvino_for_initializer_filter.insert(kOpenVINOExecutionProvider); TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape, @@ -409,7 +410,8 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_2) { vector Y_shape = {1, 1, 1, 14}; auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, - OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kQnnExecutionProvider}); + OpTester::ExpectResult::kExpectSuccess, "", + {kOpenVINOExecutionProvider, kCudaNHWCExecutionProvider, kQnnExecutionProvider}); } TEST(ConvTransposeTest, ConvTranspose_2D_OutputShapeWithBatchSize) { @@ -434,7 +436,8 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShapeWithBatchSize) { auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, - OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kQnnExecutionProvider}); + OpTester::ExpectResult::kExpectSuccess, "", + 
{kOpenVINOExecutionProvider, kCudaNHWCExecutionProvider, kQnnExecutionProvider}); } TEST(ConvTransposeTest, ConvTranspose_InvalidKernelShape) { @@ -871,7 +874,8 @@ TEST(ConvTransposeTest, DimWithZero) { TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", - {kTensorrtExecutionProvider, kAclExecutionProvider, kQnnExecutionProvider}); + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, + kAclExecutionProvider, kQnnExecutionProvider}); } TEST(ConvTransposeTest, ConvTranspose_3D) { @@ -1005,7 +1009,8 @@ TEST(ConvTransposeTest, ConvTranspose_3D) { TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", - {kTensorrtExecutionProvider, kCudaExecutionProvider, kQnnExecutionProvider}); + {kTensorrtExecutionProvider, kCudaExecutionProvider, + kCudaNHWCExecutionProvider, kQnnExecutionProvider}); } TEST(ConvTransposeTest, ConvTranspose_1D_AsymmetricPads) { diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 4b194ec18b31b..e24cda17166ed 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -57,7 +57,8 @@ TEST(PoolTest, MaxPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: result differs + // TensorRT: result differs + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } // Only CUDA kernel has float 16 support @@ -115,7 +116,8 @@ TEST(PoolTest, MaxPool_F16) { test.AddInput("X", x_dims, f_X); test.AddOutput("Y", expected_dims, f_Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Assertion `!attrs.count("pads")' failed + // TensorRT: Assertion `!attrs.count("pads")' failed + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } #endif @@ -167,7 +169,9 @@ static void MaxPool_8_WithIndexTest(bool has_index, int64_t storage_order = 0) { storage_order == 0 ? 
test.AddOutput("Indices", expected_dims, expected_indices_row) : test.AddOutput("Indices", expected_dims, expected_indices_col); } - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDnnlExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider, kArmNNExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kDnnlExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, + kAclExecutionProvider, kArmNNExecutionProvider, kOpenVINOExecutionProvider}); } TEST(PoolTest, MaxPool_8_With_Index) { @@ -196,7 +200,7 @@ TEST(PoolTest, MaxPool1D) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } static void MaxPool1D_8_WithIndexTest(int64_t storage_order) { @@ -217,7 +221,8 @@ static void MaxPool1D_8_WithIndexTest(int64_t storage_order) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.AddOutput("Indices", expected_dims, expected_indices); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool1D_8_With_Index) { @@ -243,7 +248,8 @@ static void MaxPool1D_12_WithIndexTest_int8(int64_t storage_order) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.AddOutput("Indices", expected_dims, expected_indices); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } static void MaxPool1D_12_WithIndexTest_uint8(int64_t storage_order) { @@ -264,7 +270,8 @@ static void MaxPool1D_12_WithIndexTest_uint8(int64_t storage_order) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.AddOutput("Indices", expected_dims, expected_indices); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool1D_12_With_Index_8bits) { @@ -302,9 +309,9 @@ TEST(PoolTest, MaxPool2D_uint8) { test.AddOutput("Output", output_shape, output); #if defined(OPENVINO_CONFIG_GPU_FP32) || defined(OPENVINO_CONFIG_GPU_FP16) - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kOpenVINOExecutionProvider}); #else - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); #endif } @@ -330,7 +337,7 @@ TEST(PoolTest, MaxPool_10_Dilation_1d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_DefaultDilations) { @@ -350,7 +357,7 @@ TEST(PoolTest, MaxPool_DefaultDilations) { test.AddInput("X", 
x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_DefaultDilations_int8) { @@ -370,7 +377,7 @@ TEST(PoolTest, MaxPool_DefaultDilations_int8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_DefaultDilations_uint8) { @@ -390,7 +397,7 @@ TEST(PoolTest, MaxPool_DefaultDilations_uint8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_10_DilationPadding_1d) { @@ -416,7 +423,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_1d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_2d) { @@ -444,7 +451,7 @@ TEST(PoolTest, MaxPool_10_Dilation_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_2d_int8) { @@ -472,7 +479,7 @@ TEST(PoolTest, MaxPool_10_Dilation_2d_int8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_10_DilationPadding_2d) { @@ -500,7 +507,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) { @@ -528,7 +535,8 @@ TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_12_Dilation_Ceil0_2d_int8) { @@ -556,7 +564,8 @@ TEST(PoolTest, MaxPool_12_Dilation_Ceil0_2d_int8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_Ceil1_2d) { @@ -585,7 +594,8 @@ TEST(PoolTest, MaxPool_10_Dilation_Ceil1_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_10_DilationPadding_3d) { @@ -621,7 +631,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_3d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(PoolTest, GlobalMaxPool) { @@ -697,7 +707,7 @@ TEST(PoolTest, GlobalMaxPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(PoolTest, GlobalMaxPool3D) { @@ -773,7 +783,7 @@ TEST(PoolTest, GlobalMaxPool3D) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, AveragePool) { @@ -854,7 +864,7 @@ TEST(PoolTest, AveragePool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, AveragePool_IncludePadPixel) { @@ -878,7 +888,7 @@ TEST(PoolTest, AveragePool_IncludePadPixel) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } // test 'strides' attribute not specified @@ -897,7 +907,7 @@ TEST(PoolTest, AveragePool_DefaultStrides) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, AveragePool_10_ceil1_2d) { @@ -920,7 +930,8 @@ TEST(PoolTest, AveragePool_10_ceil1_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, AveragePool_19_dilation_2d) { @@ -944,7 +955,7 @@ TEST(PoolTest, AveragePool_19_dilation_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", 
{kTensorrtExecutionProvider, kAclExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider, kOpenVINOExecutionProvider}); } TEST(PoolTest, GlobalAveragePool) { @@ -1020,7 +1031,7 @@ TEST(PoolTest, GlobalAveragePool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(PoolTest, GlobalAveragePool_Large_128) { @@ -1033,7 +1044,7 @@ TEST(PoolTest, GlobalAveragePool_Large_128) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals, /*sort_output=*/false, /*rel_error=*/1e-3f, /*abs_error=*/1e-2f); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(PoolTest, GlobalAveragePool_Large_256) { @@ -1046,7 +1057,7 @@ TEST(PoolTest, GlobalAveragePool_Large_256) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals, /*sort_output=*/false, /*rel_error=*/1e-3f, /*abs_error=*/1e-2f); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(PoolTest, LpPool) { @@ -1353,7 +1364,7 @@ TEST(PoolTest, LpPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } // test data generated with lp_pool_test_generator.py @@ -1385,7 +1396,7 @@ TEST(PoolTest, LpPool1d) { // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a94f434942252e6d98ac17705c06ce060 // TensorRT does not support 1d pooling - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); y_count++; } } @@ -1417,7 +1428,7 @@ TEST(PoolTest, LpPool2d) { test.AddAttribute("kernel_shape", kernel_sizes[kernel_size_count]); test.AddOutput("Y", y_sizes[y_count], ys[y_count]); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); y_count++; } } @@ -1435,7 +1446,7 @@ TEST(PoolTest, LpPoolCeilMode) { // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a94f434942252e6d98ac17705c06ce060 // TensorRT does not support 1d pooling - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, GlobalLpPool) { @@ -1690,7 +1701,7 @@ TEST(PoolTest, GlobalLpPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(PoolTest, MaxPoolDimWithZeroForN) { @@ -1707,7 +1718,8 @@ TEST(PoolTest, MaxPoolDimWithZeroForN) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}); } } // namespace test diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc 
b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index b0e0a0dd0d564..2902995df1e71 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -3541,6 +3541,7 @@ TEST(ReductionOpTest, ReduceDimWithZero1) { { kCoreMLExecutionProvider, kCudaExecutionProvider, + kCudaNHWCExecutionProvider, kDnnlExecutionProvider, kMIGraphXExecutionProvider, kOpenVINOExecutionProvider, @@ -3591,6 +3592,7 @@ TEST(ReductionOpTest, ReduceDimWithZero2) { { kCoreMLExecutionProvider, kCudaExecutionProvider, + kCudaNHWCExecutionProvider, kDnnlExecutionProvider, kMIGraphXExecutionProvider, kOpenVINOExecutionProvider, @@ -5779,6 +5781,7 @@ void test_empty_set(const std::string& op, int opset, bool axes_as_input, float { kCoreMLExecutionProvider, kCudaExecutionProvider, + kCudaNHWCExecutionProvider, kDmlExecutionProvider, kDnnlExecutionProvider, kMIGraphXExecutionProvider, diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index 1a31743e2f7e7..38734ab9f668f 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -744,7 +744,9 @@ TEST(RNNTest, RNN_invalid_sequence_lens) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); // the CUDA RNN version allows the invalid sequence lengths, so disable testing on CUDA and TensorRT - test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, + kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); }; // should batch batch_size to be valid @@ -842,7 +844,8 @@ TEST(RNNTest, RNN_bidirectional_with_sequence_lens) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(RNNTest, RNN_with_invalid_activation_load_failure) { diff --git a/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc index 8a8bc5560c084..b4bd3fca7b712 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc @@ -383,7 +383,7 @@ TEST(GatherElementsOpTest, IndicesOutOfBounds) { // skip openvino which will not throw error message but will ensure no out-of-bound access // skip TensorRT because it doesn't support out of bounds indices test.Run(OpTester::ExpectResult::kExpectFailure, "", - {kCudaExecutionProvider, kRocmExecutionProvider, kOpenVINOExecutionProvider, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kOpenVINOExecutionProvider, kTensorrtExecutionProvider, kDmlExecutionProvider}); } diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 5addb5dd9ce46..062f25b989a70 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -102,7 +102,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr // TensorRT: results mismatch // ROCm: results mismatch 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_uint8) { @@ -132,7 +132,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_int8) { @@ -192,7 +193,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e // DML: results mismatch test.Run( OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8) { @@ -267,7 +268,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear) { // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { @@ -291,7 +292,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { @@ -439,7 +441,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_uin test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); }; run_test(false); @@ -539,7 +542,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe // ROCm: results mismatch // DML: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) { @@ -650,7 +653,8 @@ TEST(ResizeOpTest, 
NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) { Y, false, .0f, 1.0f); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); }; run_test(false); @@ -1913,6 +1917,8 @@ void TestAntialiasing(std::map attributes, }); // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accuracy issue. excluded_eps.insert(kTensorrtExecutionProvider); + // Test is flaky on kCudaNHWCExecutionProvider + excluded_eps.insert(kCudaNHWCExecutionProvider); test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_eps); } diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 30e27bb15fa57..b1dfec7951338 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -268,7 +268,7 @@ static void scatter_invalid_index(const char* op_name, int op_version) { test.AddOutput("y", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f}); test.Run(OpTester::ExpectResult::kExpectFailure, "indices element out of data bounds, idx=4 must be within the inclusive range [-4,3]", - {kCudaExecutionProvider, kTensorrtExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(Scatter, InvalidIndex) { @@ -291,9 +291,10 @@ static void scatter_bool_with_axis_tests(const char* op_name, int op_version) { test.AddOutput("y", {1, 5}, {false, true, false, false, false}); #if defined(OPENVINO_CONFIG_GPU_FP32) || defined(OPENVINO_CONFIG_GPU_FP16) test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kOpenVINOExecutionProvider}); // OpenVINO: Disabled due to failure for GPU + {kCudaNHWCExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Disabled due to failure for GPU #else - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider}); // OpenVINO: Disabled due to failure for GPU #endif } diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc index 63b92cfc187bd..5222380d9ca56 100644 --- a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc @@ -108,6 +108,53 @@ TEST(TensorOpTest, SpaceToDepthTest_2) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } +TEST(TensorOpTest, SpaceToDepthTest_3) { + // Test swizzling with H_output > 1 + OpTester test("SpaceToDepth"); + constexpr int64_t blocksize = 2; + test.AddAttribute("blocksize", blocksize); + constexpr int64_t N = 1, C = 2, H = 4, W = 8; + + const std::vector X = { + 0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, + 1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, + + 2.0f, 2.1f, 2.2f, 2.3f, 2.4f, 2.5f, 2.6f, 2.7f, + 3.0f, 3.1f, 3.2f, 3.3f, 3.4f, 3.5f, 3.6f, 3.7f, + + 4.0f, 4.1f, 4.2f, 4.3f, 4.4f, 4.5f, 4.6f, 4.7f, + 5.0f, 5.1f, 5.2f, 5.3f, 5.4f, 5.5f, 5.6f, 5.7f, + 6.0f, 6.1f, 6.2f, 6.3f, 6.4f, 6.5f, 6.6f, 6.7f, + 7.0f, 7.1f, 7.2f, 7.3f, 7.4f, 7.5f, 7.6f, 7.7f}; + + test.AddInput("input", {N, C, H, W}, X); + + const std::vector result = { + 0.0f, 0.2f, 0.4f, 0.6f, + 2.0f, 2.2f, 2.4f, 2.6f, + 4.0f, 4.2f, 4.4f, 4.6f, + 6.0f, 6.2f, 
6.4f, 6.6f, + + 0.1f, 0.3f, 0.5f, 0.7f, + 2.1f, 2.3f, 2.5f, 2.7f, + 4.1f, 4.3f, 4.5f, 4.7f, + 6.1f, 6.3f, 6.5f, 6.7f, + + 1.0f, 1.2f, 1.4f, 1.6f, + 3.0f, 3.2f, 3.4f, 3.6f, + 5.0f, 5.2f, 5.4f, 5.6f, + 7.0f, 7.2f, 7.4f, 7.6f, + + 1.1f, 1.3f, 1.5f, 1.7f, + 3.1f, 3.3f, 3.5f, 3.7f, + 5.1f, 5.3f, 5.5f, 5.7f, + 7.1f, 7.3f, 7.5f, 7.7f}; + + test.AddOutput("output", {N, C * blocksize * blocksize, H / blocksize, W / blocksize}, result); + + test.Run(); +} + TEST(TensorOpTest, DepthToSpaceTest_1) { OpTester test("DepthToSpace", 7); // create an opset 7 model constexpr int64_t blocksize = 2; diff --git a/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc b/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc index 72cb84d50f078..188532cfa350a 100644 --- a/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc @@ -692,7 +692,7 @@ TEST(UpsampleOpTest, NhwcUpsampleOp4D1CBilinearTest) { // TensorRT: results mismatch // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(UpsampleOpTest, NhwcUpsampleOp4DBilinearTest) { @@ -766,7 +766,7 @@ TEST(UpsampleOpTest, NhwcUpsampleOp4DBilinearTest) { // TensorRT: results mismatch // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOp2DBilinearTest) { @@ -886,7 +886,7 @@ TEST(UpsampleOpTest, NhwcUpsampleOp4DBilinearTest_int32) { // TensorRT: results mismatch // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOpNearestTest_1D) { From 1ce5bfb0ecc94a4a98eb093a53cd248ab6b7167b Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 7 Mar 2024 08:19:59 +0800 Subject: [PATCH 122/237] [WebNN EP] Make sure optional input is provided (#19686) Some optional input is presented as empty string, we should not only check if the input size is correct, but also check if the optional input is not empty. e.g. 
the Pad node has an empty optional input in the sam-b-encoder.onnx model. --- .../core/providers/webnn/builders/impl/pad_op_builder.cc | 6 +++--- .../providers/webnn/builders/impl/reduction_op_builder.cc | 2 +- .../core/providers/webnn/builders/impl/split_op_builder.cc | 2 +- .../webnn/builders/impl/squeeze_unsqueeze_op_builder.cc | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc index 52b5518857773..9852db0abc9d2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc @@ -88,15 +88,15 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& pads_tensor = *initializers.at(input_defs[1]->Name()); ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(pads_tensor, pads, logger), "Error while read pads tensor"); - // Constant value and axes are optional. - if (input_defs.size() >= 3) { + // Constant value and axes are optional. Make sure they are not empty. + if (!GetTensorName(input_defs, 2).empty()) { const auto value_tensor = *initializers.at(input_defs[2]->Name()); emscripten::val value = emscripten::val::object(); ORT_RETURN_IF_NOT(ReadScalarTensorData(value_tensor, value, logger), "Cannot read constant value"); options.set("value", value); } - if (input_defs.size() == 4) { + if (!GetTensorName(input_defs, 3).empty()) { const auto input_rank = input_shape.size(); std::vector axes; const auto& axes_tensor = *initializers.at(input_defs[3]->Name()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index f446a7b81d1c0..c0954f7cf6fb1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -65,7 +65,7 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (opset >= 18 || (op_type == "ReduceSum" && opset >= 13)) { // 'axes' is an optional input. const auto noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0); - if (input_defs.size() > 1) { + if (!GetTensorName(input_defs, 1).empty()) { // Optional input axes is provided, use axes initializer data.
const auto& initializers(model_builder.GetInitializerTensors()); const auto& axes_tensor = *initializers.at(input_defs[1]->Name()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index 91f21b196be54..9819e4ce7ac5b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -57,7 +57,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, axis = SafeInt(HandleNegativeAxis(axis, rank)); options.set("axis", axis); - if (input_defs.size() == 2) { + if (!GetTensorName(input_defs, 1).empty()) { // Inputs contains optional 'split' input std::vector splits; const auto& initializers(model_builder.GetInitializerTensors()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 15149bd8fe821..8e6feb62fa8c4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -58,7 +58,7 @@ Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil std::vector axes_data; auto rank = input_rank; - if (node.SinceVersion() >= 13 && input_defs.size() > 1) { + if (node.SinceVersion() >= 13 && !GetTensorName(input_defs, 1).empty()) { // Input axes is provided, use axes initializer data. const auto& initializers = model_builder.GetInitializerTensors(); const auto& axes_tensor = *initializers.at(input_defs[1]->Name()); From 5c5d6e99ce8deac2f68167173736735a77fa53b2 Mon Sep 17 00:00:00 2001 From: pengwa Date: Thu, 7 Mar 2024 09:12:12 +0800 Subject: [PATCH 123/237] Define recomputable op list with domain/opset (#19722) ### Define recomputable op list with domain/opset Originally, we only checked the OpType to decide whether a node is recomputable. In this PR, a few improvements are made: 1. [Op type search] Domain + OpType are used to check whether the op is supported for recompute. 2. [Opset search] Then, node.SinceVersion() is searched in the supported opsets. 3. During subgraph detection, if the node's opset is supported, get the ignorable input indices, i.e. inputs we do not follow in the bottom-up search. This saves time during subgraph detection. ### Motivation and Context --- onnxruntime/core/common/string_utils.h | 9 +- .../compute_optimizer/upstream_gather.cc | 25 +- .../compute_optimizer/upstream_reshape.cc | 15 +- .../upstream_transformer_base.cc | 3 +- .../upstream_transformer_base.h | 7 - .../memory_optimizer/recompute_analysis.cc | 414 +++++++++++++++--- 6 files changed, 382 insertions(+), 91 deletions(-) diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h index 03e94cefd0564..716eed1afec51 100644 --- a/onnxruntime/core/common/string_utils.h +++ b/onnxruntime/core/common/string_utils.h @@ -66,7 +66,14 @@ inline std::string TrimString(std::string s) { } /** - * So use this simple hash to generate unique int by given string input. + * @brief A consistent way to construct the full qualified op name. + */ +inline std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) { + return MakeString(domain, "::", op_type); +} + +/** + * Use this simple hash to generate unique int by given string input.
*/ inline uint32_t GetHashFromString(const std::string& str_value) { uint32_t hash = 0; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc index 9c98ed6d3e114..1516fb37a7e9f 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc @@ -4,6 +4,7 @@ #ifdef ENABLE_TRAINING #include +#include "core/common/string_utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" @@ -26,38 +27,38 @@ UpStreamGatherGraphTransformer::UpStreamGatherGraphTransformer( // 2. Whether the outputs have the same dim changes if the Gather node moves before that operator. // 3. Should all inputs be allowed when tracking back further (bottom-up); // if not, add the input index restriction as MatMul did. - {GetFullQualifiedOpName("Add", kOnnxDomain), + {utils::GetFullQualifiedOpName("Add", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_14_13_7_6_1)}, - {GetFullQualifiedOpName("BiasGelu", kMSDomain), + {utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), OpPassThroughConfig(std::make_shared>(), opset_1)}, - {GetFullQualifiedOpName("Cast", kOnnxDomain), + {utils::GetFullQualifiedOpName("Cast", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_19_13_9_6_1)}, - {GetFullQualifiedOpName("Div", kOnnxDomain), + {utils::GetFullQualifiedOpName("Div", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_14_13_7_6_1)}, - {GetFullQualifiedOpName("Dropout", kOnnxDomain), + {utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_13_12_10_7_6_1)}, - {GetFullQualifiedOpName("Gelu", kMSDomain), + {utils::GetFullQualifiedOpName("Gelu", kMSDomain), OpPassThroughConfig(std::make_shared>(), opset_1)}, {// Be noted, this is our own implementation of ONNX domain op. 
- GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_1)}, - {GetFullQualifiedOpName("MatMul", kOnnxDomain), + {utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_13_9_1)}, - {GetFullQualifiedOpName("Reshape", kOnnxDomain), + {utils::GetFullQualifiedOpName("Reshape", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_19_14_13_5_1)}, - {GetFullQualifiedOpName("Softmax", kOnnxDomain), + {utils::GetFullQualifiedOpName("Softmax", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_13_11_1)}, - {GetFullQualifiedOpName("Transpose", kOnnxDomain), + {utils::GetFullQualifiedOpName("Transpose", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_13_1)}, }); @@ -69,7 +70,7 @@ bool UpStreamGatherGraphTransformer::UpStreamInternal( const OpPassThroughConfig& pass_through_config, const logging::Logger& logger) const { Node& slice_node = *info.node_ptr; - const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); std::unordered_map propagate_input_indices; std::unordered_map> all_input_cmp_rets; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc index f7b48de2caaf5..716988e93312c 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc @@ -4,6 +4,7 @@ #ifdef ENABLE_TRAINING #include "core/framework/tensorprotoutils.h" +#include "core/common/string_utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/utils.h" #include "core/optimizer/compute_optimizer/upstream_reshape_actors.h" @@ -21,23 +22,23 @@ UpStreamReshapeGraphTransformer::UpStreamReshapeGraphTransformer( // If optype is not enough to guarantee the equivalence, we need to add a customized pre-check function. // 2. Should all inputs be allowed when tracking back further (bottom-up); // if not, add the input index restriction. - {GetFullQualifiedOpName("Add", kOnnxDomain), + {utils::GetFullQualifiedOpName("Add", kOnnxDomain), OpPassThroughConfig( std::make_shared>(), opset_14_13_7_6_1)}, - {GetFullQualifiedOpName("BiasGelu", kMSDomain), + {utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), OpPassThroughConfig( std::make_shared>(), opset_1)}, - {GetFullQualifiedOpName("Cast", kOnnxDomain), + {utils::GetFullQualifiedOpName("Cast", kOnnxDomain), OpPassThroughConfig( std::make_shared>(), opset_19_13_9_6_1)}, - {GetFullQualifiedOpName("Dropout", kOnnxDomain), + {utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), OpPassThroughConfig( std::make_shared>(), opset_13_12_10_7_6_1)}, {// Be noted, this is our own implementation of ONNX domain op. 
- GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), OpPassThroughConfig( std::make_shared(), opset_1)}, - {GetFullQualifiedOpName("MatMul", kOnnxDomain), + {utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), OpPassThroughConfig( std::make_shared(), opset_13_9_1)}, }); @@ -47,7 +48,7 @@ bool UpStreamReshapeGraphTransformer::UpStreamInternal( Graph& graph, std::deque& queue, Node& current_node, ReshapeInfo& info, const OpPassThroughConfig& pass_through_config, const logging::Logger& logger) const { - const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); std::vector propagate_input_indices; std::unordered_map> all_input_cmp_rets; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc index f08e37296d259..4582f26a7dc68 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc @@ -5,6 +5,7 @@ #include #include "core/common/safeint.h" +#include "core/common/string_utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" @@ -130,7 +131,7 @@ template bool UpStreamGraphTransformerBase::Upstream(Graph& graph, std::deque& queue, Node& current_node, T1& info, const logging::Logger& logger) const { - const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); if (allowed_passthrough_ops_.count(op_type)) { auto& pass_through_config = allowed_passthrough_ops_.at(op_type); LOG_DEBUG_INFO(logger, "Enter reorder handle for node " + current_node.Name() + "(" + op_type + ")"); diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h index 6e22fc791ade3..d848a03c555bb 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h @@ -72,13 +72,6 @@ class UpStreamGraphTransformerBase : public GraphTransformer { const OpPassThroughConfig& pass_through_config, const logging::Logger& logger) const = 0; - /** - * @brief A consistent way to construct the full qualified op name. - */ - std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) const { - return domain + "::" + op_type; - } - std::unordered_map> allowed_passthrough_ops_; private: diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 76b3325f36116..b421eb2ab32da 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -48,75 +48,352 @@ float InputOutputSizeRatio(const Node* node) { return 1.0f; } +using IgnorableInputIndices = InlinedVector; +using OpsetToIgnorableIndicesMap = InlinedHashMap; + /** - * @brief Used to define per-op recompute config. 
+ * @brief Get the Allowed Recompute Ops object + * + * The supported op types are predefined. + * Most recent revisited for ONNX v1.15.0 release - https://github.com/onnx/onnx/blob/b86cc54efce19530fb953e4b21f57e6b3888534c/docs/Operators.md * + * We defined supported list explicitly instead of using a excluding list for the following reasons: + * 1. Some ops generate indeterministic results (for example using random number generator). We need evaluate whether + * this is a problem for recompute before adding the support, instead of fixing this after we find and try to + * fix convergence issues (which will be very hard if we have multiple indeterministic operators by default supported.) + * 2. Some ops schema will be changed in new opsets, we need also check manually whether it is applicable to recompute + * or not. + * 3. Some ops are not supported in older opsets, we need to check whether it is applicable to recompute or not. */ -struct AllowedRecomputeNodeConfig { - InlinedVector input_arg_indices; // input index to iterate further (bottom up) -}; - -// The supported op types are predefined. - -const InlinedHashMap& GetAllowedRecomputeOps(int probe_op_level) { - static InlinedHashMap> recomputable_op_table_map; +const InlinedHashMap& GetAllowedRecomputeOps(int probe_op_level) { + static InlinedHashMap> recomputable_op_table_map; if (recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end()) { return recomputable_op_table_map.at(probe_op_level); } - recomputable_op_table_map.insert({probe_op_level, InlinedHashMap()}); + recomputable_op_table_map.insert({probe_op_level, InlinedHashMap()}); auto& recomputable_op_table = recomputable_op_table_map.at(probe_op_level); if (probe_op_level >= static_cast(ProbeLevel::Basic)) { recomputable_op_table.insert({ - // Binary elementwise - {"Add", AllowedRecomputeNodeConfig{{0, 1}}}, - {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Div", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Equal", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Mul", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Sub", AllowedRecomputeNodeConfig{{0, 1}}}, - - // Data layout - /// The shape input is trivial whether it exists or not in backward. - {"Reshape", AllowedRecomputeNodeConfig{{0}}}, - {"Shape", AllowedRecomputeNodeConfig{{0}}}, - {"Squeeze", AllowedRecomputeNodeConfig{{0}}}, - {"Transpose", AllowedRecomputeNodeConfig{{0}}}, - {"Unsqueeze", AllowedRecomputeNodeConfig{{0}}}, - - // Unary elementwise - {"Dropout", AllowedRecomputeNodeConfig{{0}}}, - {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}}, - /// The ratio and mode input are trivial whether they exist or not in backward - {"BitmaskDropout", AllowedRecomputeNodeConfig{{0}}}, - /// The axis input is trivial whether it exists or not in backward - {"CumSum", AllowedRecomputeNodeConfig{{0}}}, - {"Expand", AllowedRecomputeNodeConfig{{0}}}, - {"FastGelu", AllowedRecomputeNodeConfig{{0}}}, - {"Gelu", AllowedRecomputeNodeConfig{{0}}}, - {"QuickGelu", AllowedRecomputeNodeConfig{{0}}}, - - // Ternary elementwise - {"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}}, - - // Data copy - {"Tile", AllowedRecomputeNodeConfig{{0}}}, - {"Cast", AllowedRecomputeNodeConfig{{0}}}, - {"ConcatTraining", AllowedRecomputeNodeConfig{{0, 1}}}, // Input could be more than 2. But mostly 2. 
- {"Slice", AllowedRecomputeNodeConfig{{0}}}, - {"Split", AllowedRecomputeNodeConfig{{0}}}, - {"Gather", AllowedRecomputeNodeConfig{{0}}}, + { + utils::GetFullQualifiedOpName("Add", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {9, {}}, + {14, {}}, + {15, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BiasDropout", kMSDomain), + { + {1, {3, 4}}, // ignore ratio (optional) and training mode (optional) + }, + }, + { + utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain), + { + {1, {3, 4}}, // ignore ratio (optional) and training mode (optional) + }, + }, + { + utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain), + { + {1, {1, 2}}, // ignore ratio (optional) and training mode (optional) + }, + }, + { + utils::GetFullQualifiedOpName("Cast", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {9, {}}, + {13, {}}, + {19, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain), + { + {1, {}}, + + }, + }, + { + utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain), + { + {9, {0}}, // ignore the `input`, e.g. the shape of the expected output tensor + {20, {0}}, + }, + }, + { + utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), + { + // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support. + {12, {1, 2}}, // ignore ratio and training_mode + {13, {1, 2}}, + }, + }, + { + utils::GetFullQualifiedOpName("Div", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Expand", kOnnxDomain), + { + {8, {1}}, // Ignore the shape. + {13, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Cos", kOnnxDomain), + { + {7, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("CumSum", kOnnxDomain), + { + // The axis input is trivial + {11, {1}}, + {14, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Einsum", kOnnxDomain), + { + {12, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Equal", kOnnxDomain), + { + {1, {}}, + {7, {}}, + {11, {}}, + {13, {}}, + {19, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("FastGelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Gather", kOnnxDomain), + { + {1, {1}}, // ignore the indices + {11, {1}}, + {13, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Gelu", kOnnxDomain), + { + {20, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Gelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Less", kOnnxDomain), + { + {1, {}}, + {7, {}}, + {9, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Mul", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Range", kOnnxDomain), + { + {11, {0, 1, 2}}, // ignore start, end, delta, because they are scalars. + }, + }, + { + utils::GetFullQualifiedOpName("Reshape", kOnnxDomain), + { + {1, {}}, + {5, {}}, // ignore the shape. 
+ {13, {}}, + {14, {}}, + {19, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Sin", kOnnxDomain), + { + {7, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Slice", kOnnxDomain), + { + {1, {}}, + {10, {1, 2, 3, 4}}, // ignore starts, ends, axes (optional) and steps (optional) + {11, {1, 2, 3, 4}}, + {13, {1, 2, 3, 4}}, + }, + }, + { + utils::GetFullQualifiedOpName("Split", kOnnxDomain), + { + {1, {1}}, // ignore split (optional) + {2, {}}, + {11, {}}, + {13, {1}}, // ignore the split (optional) + {18, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain), + { + {1, {}}, + {11, {}}, + {13, {1}}, // ignore the axes (optional) + }, + }, + { + utils::GetFullQualifiedOpName("Sub", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Tile", kOnnxDomain), + { + {1, {1, 2}}, + {6, {1}}, + {13, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Transpose", kOnnxDomain), + { + {1, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Trilu", kOnnxDomain), + { + {14, {1}}, // ignore k (optional) + }, + }, + { + utils::GetFullQualifiedOpName("QuickGelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain), + { + {1, {}}, + {11, {}}, + {13, {1}}, // ignore the axes (optional) + }, + }, + { + utils::GetFullQualifiedOpName("Where", kOnnxDomain), + { + {9, {}}, + {16, {}}, + }, + }, + }); } if (probe_op_level >= static_cast(ProbeLevel::Advanced)) { recomputable_op_table.insert({ - {"LayerNormalization", AllowedRecomputeNodeConfig{{0, 1, 2}}}, - {"MatMul", AllowedRecomputeNodeConfig{{0, 1}}}, - {"FusedMatMul", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Softmax", AllowedRecomputeNodeConfig{{0}}}, - {"BiasSoftmax", AllowedRecomputeNodeConfig{{0, 1}}}, - {"BiasSoftmaxDropout", AllowedRecomputeNodeConfig{{0, 1}}}, + { + utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain), + { + {1, {2}}, // ignore ratio (optional) + }, + }, + { + utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + { + // Opset 1 in ONNX official does not have LayerNormalization, + // while our contrib op defined LayerNormalization in opset 1 in ONNX domain. + {1, {}}, + {17, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), + { + {1, {}}, + {9, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Softmax", kOnnxDomain), + { + {1, {}}, + {11, {}}, + {13, {}}, + }, + }, }); } @@ -127,8 +404,20 @@ const InlinedHashMap& GetAllowedRecompu * @brief Check whether a node is a recomputable node at given probe level. 
*/ bool IsRecomputable(const Node& node, ProbeLevel probe_level) { - const auto& op_table = GetAllowedRecomputeOps(static_cast(probe_level)); - return op_table.find(node.OpType()) != op_table.end(); + const InlinedHashMap& op_table = GetAllowedRecomputeOps(static_cast(probe_level)); + auto it = op_table.find(utils::GetFullQualifiedOpName(node.OpType(), node.Domain())); + if (it == op_table.end()) { + return false; + } + return it->second.count(node.SinceVersion()); +} + +const InlinedVector& GetIgnorableInputIndices(const Node& node, ProbeLevel probe_level) { + const InlinedHashMap& op_table = GetAllowedRecomputeOps(static_cast(probe_level)); + auto it = op_table.find(utils::GetFullQualifiedOpName(node.OpType(), node.Domain())); + ORT_ENFORCE(it != op_table.end(), "Cannot get ignorable indices since the node type is supported in the list."); + ORT_ENFORCE(it->second.count(node.SinceVersion()) > 0, "Cannot get ignorable indices since the opset is supported"); + return it->second.at(node.SinceVersion()); } /** @@ -163,7 +452,6 @@ Status SelectRecomputeSubgraph(const Node& entry_node, bool& can_compromise_stashed_activation, float& save_ratio) { const ProbeLevel probe_level = probe_config.probe_level; - const auto& recomputable_op_table = GetAllowedRecomputeOps(static_cast(probe_level)); can_compromise_stashed_activation = false; @@ -213,7 +501,7 @@ Status SelectRecomputeSubgraph(const Node& entry_node, // If current op is NOT in allowed list: // 1). the output does not exist in backward, we cannot find a good solution for so, the search terminates. // 2). the output is used in backward, we don't need to trace back further, so continue searching. - auto op_recompute_config_it = recomputable_op_table.find(curr_node->OpType()); + bool is_recomputable = IsRecomputable(*curr_node, probe_level); auto cur_output_arg_name = curr_node->OutputDefs()[p.second]->Name(); if (is_first_queue_scan) { // We handle the entry node outputs differently because, we don't want this case falls into and succeed one of @@ -221,14 +509,14 @@ Status SelectRecomputeSubgraph(const Node& entry_node, // 1. "op is not in recompute op list, but its output is used in backward" // 2. "op is in recompute op list, but its output is used in backward" // (either of the above checks is true for entry node outputs) - if (op_recompute_config_it == recomputable_op_table.end()) { + if (!is_recomputable) { early_stop = true; MO_LOG_DEBUG_INFO(logger, "Entry Node " + curr_node->Name() + "(" + curr_node->OpType() + ") is **NOT** in recompute op list, search terminates."); break; } } else { - if (op_recompute_config_it == recomputable_op_table.end()) { + if (!is_recomputable) { if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { MO_LOG_DEBUG_INFO(logger, "Node " + curr_node->Name() + "(" + curr_node->OpType() + ") is **NOT** in recompute op list, but its output [" + @@ -283,14 +571,14 @@ Status SelectRecomputeSubgraph(const Node& entry_node, } // Iterate all input nodes according to allowed input arg index of the entry node. 
- const auto& input_arg_indices = op_recompute_config_it->second.input_arg_indices; + const auto& igorable_input_arg_indices = GetIgnorableInputIndices(*curr_node, probe_level); for (auto it = curr_node->InputEdgesBegin(), end = curr_node->InputEdgesEnd(); it != end; ++it) { const Node::EdgeEnd& input_edge = *it; const auto& parent_node = input_edge.GetNode(); const auto parent_node_output_index = input_edge.GetSrcArgIndex(); const auto current_node_input_index = input_edge.GetDstArgIndex(); - if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) != - input_arg_indices.end()) { + if (std::find(igorable_input_arg_indices.begin(), igorable_input_arg_indices.end(), current_node_input_index) == + igorable_input_arg_indices.end()) { // If the tensor size is constant and very small (Now < 1M), we stop adding the input edge into queue. auto output_shape = parent_node.OutputDefs()[parent_node_output_index]->Shape(); if (output_shape) { From bff4f8bf75562704720624fac63b149d10042ac8 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 6 Mar 2024 17:47:17 -0800 Subject: [PATCH 124/237] Update tolerance of provider tests to fix flaky tests (#19792) ### Description Check float/double/float16/bfloat16 tensors are close like [numpy.isclose](https://numpy.org/doc/stable/reference/generated/numpy.isclose.html). ``` absolute(a - b) <= (atol + rtol * absolute(b)) ``` The default tolerance thresholds: - float: atol=1e-5 and rtol=1e-4 - float16: atol=0.0025 and rtol=0.001 - bfloat16: atol=0.02 and rtol=0.01 ### Motivation and Context Current pipeline has frequent failure due to using only relative tolerance in https://github.com/microsoft/onnxruntime/pull/19608: [ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8S8 1: C:\a\_work\1\s\onnxruntime\test\providers\checkers.cc(272): error: The difference between cur_expected[i] and cur_actual[i] is 1.3113021850585938e-06, which exceeds *(params.relative_error) * std::abs(cur_expected[i]), where 1: cur_expected[i] evaluates to -1.3113021850585938e-06, 1: cur_actual[i] evaluates to 0, and 1: *(params.relative_error) * std::abs(cur_expected[i]) evaluates to 2.6226043559063328e-08. It is not reasonable to use relative tolerance for a small value very close to 0. Combining relative tolerance with a positive absolute tolerance could avoid such issue. 
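As a rough illustration (not part of this patch), the combined check behaves like the Python sketch below; the 0.0001 absolute / 0.02 relative thresholds are the values this patch sets for the MatMulIntegerToFloat test, and the near-zero value is taken from the failing log above:
```
import numpy as np

# Sketch of the numpy.isclose-style check described above:
#   absolute(a - b) <= (atol + rtol * absolute(b))
def is_close(actual, expected, atol, rtol):
    return np.abs(actual - expected) <= (atol + rtol * np.abs(expected))

expected = np.float32(-1.3113021850585938e-06)  # near-zero value from the failing test log
actual = np.float32(0.0)

# Relative tolerance alone rejects the result: 1.31e-06 > 0.02 * |expected| ~= 2.62e-08
print(is_close(actual, expected, atol=0.0, rtol=0.02))     # False
# Adding a small absolute term accepts it: 1.31e-06 <= 0.0001 + 2.62e-08
print(is_close(actual, expected, atol=0.0001, rtol=0.02))  # True
```
The `get_tolerance` helpers added to checkers.cc below apply the same formula in C++.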
--- .../matmul_integer_to_float_test.cc | 1 + onnxruntime/test/providers/checkers.cc | 159 +++++++++--------- 2 files changed, 83 insertions(+), 77 deletions(-) diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc index 6f3ca7e239671..72a5ba4dcefbf 100644 --- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc @@ -127,6 +127,7 @@ void TestMatMulIntegerToFloat(bool is_matrix_b_constant, if (std::is_same_v) { test.AddOutput("Y", {M, N}, Y_data); + test.SetOutputAbsErr("Y", 0.0001f); test.SetOutputRelErr("Y", 0.02f); } else { test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index 85ccb8f175f62..c97e6d9de4911 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -14,6 +14,54 @@ namespace onnxruntime { namespace test { namespace { + +template +struct DefaultTolerance; + +template <> +struct DefaultTolerance { + static constexpr float absolute = 1e-6f; + static constexpr float relative = 1e-5f; +}; + +template <> +struct DefaultTolerance { + static constexpr float absolute = 1e-5f; + static constexpr float relative = 1e-4f; +}; + +template <> +struct DefaultTolerance { + // The thresholds are estimated with PyTorch script like the following: + // x = torch.rand(1000, 1000) + // absolute = ((x + 1e-6).to(torch.float16) - x).abs().max() * 10 + // x[abs(x) < absolute] = absolute + // relative = ((x - x.to(torch.float16)) / x).abs().max() * 2 + static constexpr float absolute = 0.0025f; + static constexpr float relative = 0.001f; +}; + +template <> +struct DefaultTolerance { + static constexpr float absolute = 0.02f; + static constexpr float relative = 0.01f; +}; + +template +T get_tolerance(float absolute, float relative, T expected_value) { + static_assert(std::is_floating_point::value, "T must be a floating point type"); + + // The formula is similar to numpy.isclose: https://numpy.org/doc/stable/reference/generated/numpy.isclose.html + return static_cast(absolute) + static_cast(relative) * std::abs(expected_value); +} + +template // D is the original data type +T get_tolerance(const ValidateOutputParams& params, T expected_value) { + float absolute = (params.absolute_error.has_value() ? *(params.absolute_error) : DefaultTolerance::absolute); + float relative = (params.relative_error.has_value() ? *(params.relative_error) : DefaultTolerance::relative); + return get_tolerance(absolute, relative, expected_value); +} + template Tensor copy_sort(const Tensor& src, const AllocatorPtr& allocator) { Tensor result(src.DataType(), src.Shape(), allocator); @@ -67,7 +115,7 @@ struct TensorCheck { cur_actual = actual.Data(); } - for (int i = 0; i < size; ++i) { + for (int64_t i = 0; i < size; ++i) { EXPECT_EQ(cur_expected[i], cur_actual[i]) << "i:" << i; } } @@ -111,7 +159,7 @@ struct TensorCheck { double threshold = has_abs_err ? 
*(params.absolute_error) : 0.0; - for (int i = 0; i < size; ++i) { + for (int64_t i = 0; i < size; ++i) { if (has_rel_err) { EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * cur_expected[i]) // expected[i] is unsigned, can't be negative @@ -121,7 +169,7 @@ struct TensorCheck { } } } else { - for (int i = 0; i < size; ++i) { + for (int64_t i = 0; i < size; ++i) { EXPECT_EQ(cur_expected[i], cur_actual[i]) << "i:" << i; } } @@ -157,11 +205,11 @@ struct TensorCheck { if (has_abs_err) { double threshold = *(params.absolute_error); - for (int i = 0; i < size; ++i) { + for (int64_t i = 0; i < size; ++i) { EXPECT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i; } } else { - for (int i = 0; i < size; ++i) { + for (int64_t i = 0; i < size; ++i) { EXPECT_EQ(cur_expected[i], cur_actual[i]) << "i:" << i; } } @@ -176,8 +224,7 @@ struct TensorCheck { const std::string& /*provider_type*/) const { auto size = actual.Shape().Size(); - bool has_abs_err = params.absolute_error.has_value(); - bool has_rel_err = params.relative_error.has_value(); + const bool has_tolerance = params.absolute_error.has_value() || params.relative_error.has_value(); // deal with rare cases in which order of output data from a kernel MAY be // undefined @@ -198,7 +245,7 @@ struct TensorCheck { threshold = 0.005; #endif - for (int i = 0; i < size; ++i) { + for (int64_t i = 0; i < size; ++i) { // NOTE: Check isnan first to work around MSVC linker bug when /LTCG:incremental is specified. // If the isinf check is first the isnan check and branch gets omitted if (std::isnan(cur_expected[i])) { @@ -206,44 +253,33 @@ struct TensorCheck { } else if (std::isinf(cur_expected[i])) { // Test infinity for equality EXPECT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i; } else { - if (!has_abs_err && !has_rel_err) { - // the default for existing tests - EXPECT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i; - } else { - if (has_abs_err) { - EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.absolute_error)) << "i:" << i; - } - if (has_rel_err) { - EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * std::abs(cur_expected[i])) - << "i:" << i; - } - } + double tolerance = has_tolerance ? 
get_tolerance(params, cur_expected[i]) : threshold; + EXPECT_NEAR(cur_expected[i], cur_actual[i], tolerance) << "i:" << i; } } } }; -template +template void InternalNumericalCheck(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, const std::string& /*provider_type*/) { - const bool has_abs_err = params.absolute_error.has_value(); - const bool has_rel_err = params.relative_error.has_value(); + const bool has_tolerance = params.absolute_error.has_value() || params.relative_error.has_value(); // deal with rare cases in which order of output data from a kernel MAY be // undefined Tensor expected_sorted, actual_sorted; - const TypeToCheck* cur_expected; - const TypeToCheck* cur_actual; + const T* cur_expected; + const T* cur_actual; auto size = actual.Shape().Size(); if (params.sort_output) { - sort_expected_and_actual_buffers(expected, expected_sorted, actual, actual_sorted); - cur_expected = expected_sorted.Data(); - cur_actual = actual_sorted.Data(); + sort_expected_and_actual_buffers(expected, expected_sorted, actual, actual_sorted); + cur_expected = expected_sorted.Data(); + cur_actual = actual_sorted.Data(); } else { - cur_expected = expected.Data(); - cur_actual = actual.Data(); + cur_expected = expected.Data(); + cur_actual = actual.Data(); } #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) @@ -252,7 +288,7 @@ void InternalNumericalCheck(const Tensor& expected, constexpr float threshold = 0.0001f; #endif - for (int i = 0; i < size; ++i) { + for (int64_t i = 0; i < size; ++i) { // NOTE: Check isnan first to work around MSVC linker bug when /LTCG:incremental is specified. // If the isinf check is first the isnan check and branch gets omitted if (std::isnan(cur_expected[i])) { @@ -260,19 +296,8 @@ void InternalNumericalCheck(const Tensor& expected, } else if (std::isinf(cur_expected[i])) { // Test infinity for equality EXPECT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i; } else { - if (!has_abs_err && !has_rel_err) { - // the default for existing tests - EXPECT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i; - } else { - if (has_abs_err) { - EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.absolute_error)) - << "i:" << i; - } - if (has_rel_err) { - EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * std::abs(cur_expected[i])) - << "i:" << i; - } - } + T tolerance = has_tolerance ? get_tolerance(params, cur_expected[i]) : threshold; + EXPECT_NEAR(cur_expected[i], cur_actual[i], tolerance) << "i:" << i; } } } @@ -308,8 +333,7 @@ struct TensorCheck { sort_expected_and_actual_buffers(f_expected, f_actual); } - const bool has_abs_err = params.absolute_error.has_value(); - const bool has_rel_err = params.relative_error.has_value(); + const bool has_tolerance = params.absolute_error.has_value() || params.relative_error.has_value(); float threshold = 0.001f; #if defined(USE_TENSORRT) || defined(ENABLE_TRAINING_CORE) || defined(USE_CUDA) || defined(USE_ROCM) @@ -317,25 +341,14 @@ struct TensorCheck { #elif defined(USE_DML) threshold = 0.02f; #endif - for (int i = 0; i < size; ++i) { + for (int64_t i = 0; i < size; ++i) { if (std::isnan(f_expected[i])) { EXPECT_TRUE(std::isnan(f_expected[i])) << "Expected NaN. i:" << i; } else if (std::isinf(f_expected[i])) { // Test infinity for equality EXPECT_EQ(f_expected[i], f_actual[i]) << "Expected infinity. 
i:" << i; } else { - if (!has_abs_err && !has_rel_err) { - // the default for existing tests - EXPECT_NEAR(f_expected[i], f_actual[i], threshold) << "i:" << i; - } else { - if (has_abs_err) { - EXPECT_NEAR(f_expected[i], f_actual[i], *(params.absolute_error)) - << "i:" << i; - } - if (has_rel_err) { - EXPECT_NEAR(f_expected[i], f_actual[i], *(params.relative_error) * std::abs(static_cast(cur_expected[i]))) - << "i:" << i; - } - } + float tolerance = has_tolerance ? get_tolerance(params, f_expected[i]) : threshold; + EXPECT_NEAR(f_expected[i], f_actual[i], tolerance) << "i:" << i; } } } @@ -362,32 +375,24 @@ struct TensorCheck { sort_expected_and_actual_buffers(f_expected, f_actual); } - /// XXX: May need to adjust threshold as BFloat is coarse + const bool has_tolerance = params.absolute_error.has_value() || params.relative_error.has_value(); + float abs_threshold = 0.0001f; - float threshold = 0.001f; + float rel_threshold = 0.001f; #if defined(USE_TENSORRT) || defined(ENABLE_TRAINING_CORE) || defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_DNNL) - threshold = 0.05f; // expect at least 95% close + rel_threshold = 0.05f; // expect at least 95% close #endif - for (int i = 0; i < size; ++i) { + for (int64_t i = 0; i < size; ++i) { if (std::isnan(f_expected[i])) { EXPECT_TRUE(std::isnan(f_expected[i])) << "Expected NaN. i:" << i; } else if (std::isinf(f_expected[i])) { // Test infinity for equality EXPECT_EQ(f_expected[i], f_actual[i]) << "Expected infinity. i:" << i; } else { - // the default for existing tests - const float max_value = fmax(fabs(f_expected[i]), fabs(f_actual[i])); - if (max_value != 0) { // max_value = 0 means output and expected are 0s. - const float abs_error = fabs(f_expected[i] - f_actual[i]); - if (abs_error <= abs_threshold) { - // if the absolute error is small enough, then no need to calculate realative error - EXPECT_NEAR(0, abs_error, abs_threshold); - } else { - // default for existing tests. - const float rel_error = abs_error / max_value; - EXPECT_NEAR(0, rel_error, threshold); - } - } + float tolerance = has_tolerance + ? get_tolerance(params, f_expected[i]) + : get_tolerance(abs_threshold, rel_threshold, f_expected[i]); + EXPECT_NEAR(f_expected[i], f_actual[i], tolerance) << "i:" << i; } } } From 72ce4de07df91b43d36d5c475a609095bde50a53 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Thu, 7 Mar 2024 18:15:18 +0000 Subject: [PATCH 125/237] cuda graph enhancement (#19636) ### Description 1. add a config key in run_options to control cuda graph in runtime. 2. enhance cuda graph class to support mutiple graph saving and retrieving in one ORT session 3. provide model modification/inference example on Phi2 4. benchmark shows an average of 13% latency reduction in token generation. limitation: TRT ep and ROCM ep hasn't applied this feature. we can revisit this in the future. 
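For context, a hypothetical sketch of how a caller might use the new run option from the Python API (not part of this patch; the model path, tensor names, shapes and annotation ids are made up, and CUDA graph capture still requires I/O binding so that input/output addresses stay fixed across runs):
```
# Hypothetical sketch of driving the new "gpu_graph_id" run option from Python.
# Model path, tensor names and shapes are placeholders.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "model.onnx",
    providers=[("CUDAExecutionProvider", {"enable_cuda_graph": "1"})],
)

# CUDA graph capture/replay needs stable device addresses, so bind preallocated
# GPU buffers once and reuse them for every run.
x = ort.OrtValue.ortvalue_from_numpy(np.ones((1, 4), dtype=np.float32), "cuda", 0)
y = ort.OrtValue.ortvalue_from_numpy(np.zeros((1, 4), dtype=np.float32), "cuda", 0)
binding = sess.io_binding()
binding.bind_ortvalue_input("X", x)
binding.bind_ortvalue_output("Y", y)

ro = ort.RunOptions()
ro.add_run_config_entry("gpu_graph_id", "1")  # tag these runs with annotation id 1
sess.run_with_iobinding(binding, ro)          # warm-up/capture for graph 1
sess.run_with_iobinding(binding, ro)          # later runs replay the captured graph

skip = ort.RunOptions()
skip.add_run_config_entry("gpu_graph_id", "-1")  # -1 disables capture/replay for this run
sess.run_with_iobinding(binding, skip)
```
Passing "-1" skips capture/replay for that run, and omitting the entry falls back to the one-graph-per-session behavior, per the kOrtRunOptionsConfigCudaGraphAnnotation comment in the diff below.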
### Motivation and Context --- .../core/framework/execution_provider.h | 14 +- .../onnxruntime_run_options_config_keys.h | 7 + .../providers/cuda/cuda_execution_provider.cc | 74 ++++-- .../providers/cuda/cuda_execution_provider.h | 17 +- onnxruntime/core/providers/cuda/cuda_graph.cc | 89 +++++-- onnxruntime/core/providers/cuda/cuda_graph.h | 48 +++- .../providers/js/js_execution_provider.cc | 10 +- .../core/providers/js/js_execution_provider.h | 4 +- .../providers/rocm/rocm_execution_provider.cc | 35 +-- .../providers/rocm/rocm_execution_provider.h | 12 +- .../providers/shared_library/provider_api.h | 1 + .../shared_library/provider_interfaces.h | 3 + .../shared_library/provider_wrappedtypes.h | 8 + .../tensorrt/tensorrt_execution_provider.cc | 38 +-- .../tensorrt/tensorrt_execution_provider.h | 16 +- onnxruntime/core/session/inference_session.cc | 22 +- onnxruntime/core/session/inference_session.h | 16 +- .../core/session/provider_bridge_ort.cc | 4 + .../models/phi2/convert_to_onnx.py | 79 +++++- .../models/phi2/inference_example.py | 236 ++++++++++++++++-- .../onnxruntime_test_python_cudagraph.py | 61 ++++- onnxruntime/test/shared_lib/test_inference.cc | 149 +++++++++++ onnxruntime/test/testdata/mul_1_dynamic.onnx | Bin 0 -> 142 bytes 23 files changed, 766 insertions(+), 177 deletions(-) create mode 100644 onnxruntime/test/testdata/mul_1_dynamic.onnx diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index c1cc69edc17d8..40ca96a19aef1 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -202,21 +202,21 @@ class IExecutionProvider { /** Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for - the provider. Currently only CUDA execution provider supports it. + the provider. */ virtual bool IsGraphCaptureEnabled() const { return false; } /** - Indicate whether the graph has been captured and instantiated. Currently - only CUDA execution provider supports it. + Indicate whether the graph has been captured and instantiated. */ - virtual bool IsGraphCaptured() const { return false; } + virtual bool IsGraphCaptured(int /*graph_annotation_id*/) const { return false; } /** - Run the instantiated graph. Currently only CUDA execution provider supports - it. + Run the instantiated graph. */ - virtual common::Status ReplayGraph() { return Status::OK(); } + virtual common::Status ReplayGraph(int /*graph_annotation_id*/) { + return Status::OK(); + } /** Called when session creation is complete diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h index b0a17e175fef3..c80b8c0c164b6 100644 --- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h @@ -42,3 +42,10 @@ static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_ // Set RPC control latency for QNN HTP backend static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency"; + +// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true. +// The value should be an integer. If the value is not set, the default value is 0 and +// ORT session only captures one cuda graph before another capture is requested. +// If the value is set to -1, cuda graph capture/replay is disabled in that run. 
+// User are not expected to set the value to 0 as it is reserved for internal use. +static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id"; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 3c0930638a205..bade2faf8f2e2 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -3,6 +3,7 @@ // Licensed under the MIT License. #include "core/common/inlined_containers.h" +#include "core/common/parse_string.h" #include "core/providers/shared_library/provider_api.h" #include "core/platform/env_var_utils.h" #include "core/providers/cuda/cuda_execution_provider.h" @@ -11,6 +12,7 @@ #include "core/providers/cuda/cuda_fwd.h" #include "core/providers/cuda/gpu_data_transfer.h" #include "core/providers/cuda/cuda_profiler.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #ifndef USE_CUDA_MINIMAL #ifndef DISABLE_CONTRIB_OPS @@ -190,27 +192,46 @@ CUDAExecutionProvider::PerThreadContext::~PerThreadContext() { #endif } -bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { - return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed( + CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_ && + IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id); } -void CUDAExecutionProvider::PerThreadContext::CaptureBegin() { - cuda_graph_.Reset(); - cuda_graph_.CaptureBegin(); +bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun( + CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_.IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id); } -void CUDAExecutionProvider::PerThreadContext::CaptureEnd() { - cuda_graph_.CaptureEnd(); - is_graph_captured_ = true; +CudaGraphAnnotation_t CUDAExecutionProvider::PerThreadContext::GetCudaGraphAnnotationId( + const onnxruntime::RunOptions& run_options) const { + auto graph_annotation_str = + run_options.GetConfigOptions().GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); + // If graph annotation is not provided, fall back to the one cuda graph per session behavior + CudaGraphAnnotation_t cuda_graph_annotation_id = 0; + if (graph_annotation_str.has_value()) { + ORT_ENFORCE(TryParseStringWithClassicLocale(*graph_annotation_str, cuda_graph_annotation_id), + "Failed to parse the cuda graph annotation id: ", + *graph_annotation_str); + } + + return cuda_graph_annotation_id; +} + +void CUDAExecutionProvider::PerThreadContext::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cuda_graph_.CaptureBegin(cuda_graph_annotation_id); +} + +void CUDAExecutionProvider::PerThreadContext::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cuda_graph_.CaptureEnd(cuda_graph_annotation_id); } -bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured() const { - return is_graph_captured_; +bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const { + return cuda_graph_.IsGraphCaptured(graph_annotation_id); } -Status CUDAExecutionProvider::PerThreadContext::ReplayGraph() { - ORT_ENFORCE(IsGraphCaptured()); - return cuda_graph_.Replay(); +Status CUDAExecutionProvider::PerThreadContext::ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) { + 
return cuda_graph_.Replay(graph_annotation_id); } void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { @@ -386,23 +407,26 @@ Status CUDAExecutionProvider::Sync() const { return Status::OK(); } -Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { +Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { // always set CUDA device when session::Run() in case it runs in a worker thread CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId())); - if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options); + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id) && + GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) { LOGS(*GetLogger(), INFO) << "Capturing the cuda graph for this model"; - GetPerThreadContext().CaptureBegin(); + GetPerThreadContext().CaptureBegin(cuda_graph_annotation_id); } return Status::OK(); } -Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { - if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { - if (GetPerThreadContext().IsGraphCaptureAllowed()) { - GetPerThreadContext().CaptureEnd(); +Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) { + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options); + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) { + if (GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) { + GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id); // CUDA work issued to a capturing stream doesn’t actually run on the GPU, // so run the captured graph here to actually execute the work. 
- ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph()); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id)); } else { GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(); } @@ -433,12 +457,12 @@ bool CUDAExecutionProvider::IsGraphCaptureEnabled() const { return info_.enable_cuda_graph; } -bool CUDAExecutionProvider::IsGraphCaptured() const { - return GetPerThreadContext().IsGraphCaptured(); +bool CUDAExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { + return GetPerThreadContext().IsGraphCaptured(graph_annotation_id); } -Status CUDAExecutionProvider::ReplayGraph() { - return GetPerThreadContext().ReplayGraph(); +Status CUDAExecutionProvider::ReplayGraph(int graph_annotation_id) { + return GetPerThreadContext().ReplayGraph(graph_annotation_id); } namespace cuda { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 75fe1dff7c4a4..6c70e6abc4fdf 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -92,8 +92,8 @@ class CUDAExecutionProvider : public IExecutionProvider { std::unique_ptr GetProfiler() override; bool IsGraphCaptureEnabled() const override; - bool IsGraphCaptured() const override; - Status ReplayGraph() override; + bool IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const override; + Status ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; @@ -168,11 +168,13 @@ class CUDAExecutionProvider : public IExecutionProvider { } } - bool IsGraphCaptureAllowed() const; - void CaptureBegin(); - void CaptureEnd(); - bool IsGraphCaptured() const; - Status ReplayGraph(); + bool IsGraphCaptureAllowed(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id); + void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id); + bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + CudaGraphAnnotation_t GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const; + Status ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id); void IncrementRegularRunCountBeforeGraphCapture(); private: @@ -192,7 +194,6 @@ class CUDAExecutionProvider : public IExecutionProvider { // Cuda graph with multi threads will be supported in the future, so cuda_graph_ // is put under PerThreadContext. 
CUDAGraph cuda_graph_; - bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; // There is chance that the second regular run allocates GPU memory for causes like: diff --git a/onnxruntime/core/providers/cuda/cuda_graph.cc b/onnxruntime/core/providers/cuda/cuda_graph.cc index 230d664391611..8353c654681fc 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.cc +++ b/onnxruntime/core/providers/cuda/cuda_graph.cc @@ -9,17 +9,44 @@ namespace onnxruntime { -CUDAGraph::CUDAGraph(cudaStream_t stream) : stream_(stream) { +CudaGraphSet::~CudaGraphSet() { + Clear(); } -void CUDAGraph::SetStream(cudaStream_t stream) { +void CudaGraphSet::Clear() { + for (auto& it : cuda_graphs_) { + CUDA_CALL_THROW(cudaGraphExecDestroy(it.second)); + } + cuda_graphs_.clear(); +} + +bool CudaGraphSet::Contains(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graphs_.find(cuda_graph_annotation_id) != cuda_graphs_.end(); +} + +void CudaGraphSet::Put(CudaGraphAnnotation_t cuda_graph_annotation_id, cudaGraphExec_t graph_exec) { + ORT_ENFORCE(!Contains(cuda_graph_annotation_id)); + cuda_graphs_.emplace(cuda_graph_annotation_id, graph_exec); +} + +cudaGraphExec_t CudaGraphSet::Get(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + ORT_ENFORCE(Contains(cuda_graph_annotation_id)); + return cuda_graphs_.at(cuda_graph_annotation_id); +} + +CUDAGraphManager::CUDAGraphManager(cudaStream_t stream) : stream_(stream) { +} + +void CUDAGraphManager::SetStream(cudaStream_t stream) { stream_ = stream; } -void CUDAGraph::CaptureBegin() { - ORT_ENFORCE(!has_graph_exec_, - "This cuda graph has already captured a graph. " - "Create a new instance to capture a new graph."); +void CUDAGraphManager::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) { + ORT_ENFORCE(IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id)); + + ORT_ENFORCE(!cuda_graph_set_.Contains(cuda_graph_annotation_id), + "Trying to capture a graph with annotation id ", cuda_graph_annotation_id, + " that already used. Please use a different annotation id."); CUDA_CALL_THROW(cudaStreamSynchronize(stream_)); // For now cuda graph can only work with a single thread. 
In the future, we @@ -29,40 +56,48 @@ void CUDAGraph::CaptureBegin() { CUDA_CALL_THROW(cudaStreamBeginCapture(stream_, cudaStreamCaptureModeGlobal)); } -void CUDAGraph::CaptureEnd() { - CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph_)); - if (graph_ == NULL) { +void CUDAGraphManager::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cudaGraph_t graph = NULL; + CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph)); + if (graph == NULL) { ORT_THROW("CUDAGraph::CaptureEnd: graph_ is NULL"); } - has_graph_ = true; - CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0)); - has_graph_exec_ = true; - CUDA_CALL_THROW(cudaGraphDestroy(graph_)); - has_graph_ = false; + cudaGraphExec_t graph_exec = NULL; + CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0)); + CUDA_CALL_THROW(cudaGraphDestroy(graph)); + + // Currently all the captured graphs will be tied to the session's lifecycle + // TODO(wy): Addd an interface to free captured graphs + cuda_graph_set_.Put(cuda_graph_annotation_id, graph_exec); } -Status CUDAGraph::Replay() { +Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) { // Although this function is not thread safe, the lock is not needed here because // CUDA EP maintains a separate cuda graph per thread - LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_; - CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec_, stream_)); + LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_ << " with cuda_graph_annotation_id " + << cuda_graph_annotation_id; + + cudaGraphExec_t graph_exec = cuda_graph_set_.Get(cuda_graph_annotation_id); + CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec, stream_)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); return Status::OK(); } -void CUDAGraph::Reset() { - if (has_graph_) { - CUDA_CALL_THROW(cudaGraphDestroy(graph_)); - has_graph_ = false; - } - if (has_graph_exec_) { - CUDA_CALL_THROW(cudaGraphExecDestroy(graph_exec_)); - has_graph_exec_ = false; - } +bool CUDAGraphManager::IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_annotation_id != kCudaGraphAnnotationSkip; +} + +bool CUDAGraphManager::IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_set_.Contains(cuda_graph_annotation_id); +} + +void CUDAGraphManager::Reset() { + cuda_graph_set_.Clear(); } -CUDAGraph::~CUDAGraph() { +CUDAGraphManager::~CUDAGraphManager() { Reset(); } diff --git a/onnxruntime/core/providers/cuda/cuda_graph.h b/onnxruntime/core/providers/cuda/cuda_graph.h index 9bcefcc64ea77..064994c1f14ae 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.h +++ b/onnxruntime/core/providers/cuda/cuda_graph.h @@ -3,33 +3,55 @@ #pragma once +#include + #include "core/common/common.h" #include "core/platform/ort_mutex.h" #include "core/providers/cuda/cuda_pch.h" namespace onnxruntime { -using CaptureId_t = unsigned long long; +using CudaGraphAnnotation_t = int; +using CudaGraphSet_t = std::unordered_map; + +constexpr CudaGraphAnnotation_t kCudaGraphAnnotationSkip = -1; +constexpr CudaGraphAnnotation_t kCudaGraphAnnotationDefault = 0; + +struct CudaGraphSet { + CudaGraphSet(){}; + ~CudaGraphSet(); -struct CUDAGraph { - CUDAGraph(){}; - CUDAGraph(cudaStream_t stream); - ~CUDAGraph(); + void Clear(); + bool Contains(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + void Put(CudaGraphAnnotation_t cuda_graph_annotation_id, cudaGraphExec_t graph_exec); + cudaGraphExec_t 
Get(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + + private: + CudaGraphSet_t cuda_graphs_; +}; + +struct CUDAGraphManager { + CUDAGraphManager(){}; + CUDAGraphManager(cudaStream_t stream); + ~CUDAGraphManager(); void SetStream(cudaStream_t stream); - void CaptureBegin(); - void CaptureEnd(); - Status Replay(); + void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id); + void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id); + Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id); + void Reset(); - private: - cudaGraph_t graph_ = NULL; - cudaGraphExec_t graph_exec_ = NULL; + bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const; - bool has_graph_ = false; - bool has_graph_exec_ = false; + private: + CudaGraphSet cuda_graph_set_; + CudaGraphAnnotation_t cuda_graph_annotation_id_ = kCudaGraphAnnotationDefault; cudaStream_t stream_ = nullptr; // Does not own the stream }; +using CUDAGraph = CUDAGraphManager; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 62c3981682cfc..2d2c89f36f1a7 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -757,7 +757,7 @@ JsExecutionProvider::~JsExecutionProvider() { } Status JsExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { - if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; EM_ASM({ Module.jsepCaptureBegin(); }); } @@ -765,7 +765,7 @@ Status JsExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_opti } Status JsExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { - if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { if (IsGraphCaptureAllowed()) { EM_ASM({ Module.jsepCaptureEnd(); }); is_graph_captured_ = true; @@ -781,12 +781,12 @@ bool JsExecutionProvider::IsGraphCaptureEnabled() const { return enable_graph_capture_; } -bool JsExecutionProvider::IsGraphCaptured() const { +bool JsExecutionProvider::IsGraphCaptured(int) const { return is_graph_captured_; } -Status JsExecutionProvider::ReplayGraph() { - ORT_ENFORCE(IsGraphCaptured()); +Status JsExecutionProvider::ReplayGraph(int) { + ORT_ENFORCE(IsGraphCaptured(0)); EM_ASM({ Module.jsepReplay(); }); return Status::OK(); } diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index b4518c67d1e60..efacf510e75df 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -63,8 +63,8 @@ class JsExecutionProvider : public IExecutionProvider { Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; bool IsGraphCaptureEnabled() const override; - bool IsGraphCaptured() const override; - Status ReplayGraph() override; + bool IsGraphCaptured(int graph_annotation_id) const override; + Status ReplayGraph(int graph_annotation_id) override; private: bool IsGraphCaptureAllowed() const; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc 
b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 4a679b790ee40..32be74550951e 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -183,23 +183,24 @@ bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_; } -void ROCMExecutionProvider::PerThreadContext::CaptureBegin() { +void ROCMExecutionProvider::PerThreadContext::CaptureBegin(int) { hip_graph_.Reset(); - hip_graph_.CaptureBegin(); + hip_graph_.CaptureBegin(0); } -void ROCMExecutionProvider::PerThreadContext::CaptureEnd() { - hip_graph_.CaptureEnd(); +void ROCMExecutionProvider::PerThreadContext::CaptureEnd(int) { + hip_graph_.CaptureEnd(0); is_graph_captured_ = true; } -bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured() const { +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured(int) const { return is_graph_captured_; } -Status ROCMExecutionProvider::PerThreadContext::ReplayGraph() { - ORT_ENFORCE(IsGraphCaptured()); - return hip_graph_.Replay(); +Status ROCMExecutionProvider::PerThreadContext::ReplayGraph(int graph_annotation_id) { + ORT_ENFORCE(IsGraphCaptured(graph_annotation_id)); + + return hip_graph_.Replay(graph_annotation_id); } void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { @@ -356,20 +357,20 @@ Status ROCMExecutionProvider::Sync() const { Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); - if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { + if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured(0)) { LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model"; - GetPerThreadContext().CaptureBegin(); + GetPerThreadContext().CaptureBegin(0); } return Status::OK(); } Status ROCMExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { - if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(0)) { if (GetPerThreadContext().IsGraphCaptureAllowed()) { - GetPerThreadContext().CaptureEnd(); + GetPerThreadContext().CaptureEnd(0); // HIP work issued to a capturing stream doesn’t actually run on the GPU, // so run the captured graph here to actually execute the work. 
- ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph()); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(0)); } else { GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(); } @@ -400,12 +401,12 @@ bool ROCMExecutionProvider::IsGraphCaptureEnabled() const { return info_.enable_hip_graph; } -bool ROCMExecutionProvider::IsGraphCaptured() const { - return GetPerThreadContext().IsGraphCaptured(); +bool ROCMExecutionProvider::IsGraphCaptured(int) const { + return GetPerThreadContext().IsGraphCaptured(0); } -Status ROCMExecutionProvider::ReplayGraph() { - return GetPerThreadContext().ReplayGraph(); +Status ROCMExecutionProvider::ReplayGraph(int /*graph_annotation_id*/) { + return GetPerThreadContext().ReplayGraph(0); } namespace rocm { diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index da671d9e863bb..6d6c05027e7bd 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -75,8 +75,8 @@ class ROCMExecutionProvider : public IExecutionProvider { std::unique_ptr GetProfiler() override; bool IsGraphCaptureEnabled() const override; - bool IsGraphCaptured() const override; - Status ReplayGraph() override; + bool IsGraphCaptured(int graph_annotation_id) const override; + Status ReplayGraph(int graph_annotation_id) override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; @@ -139,10 +139,10 @@ class ROCMExecutionProvider : public IExecutionProvider { } bool IsGraphCaptureAllowed() const; - void CaptureBegin(); - void CaptureEnd(); - bool IsGraphCaptured() const; - Status ReplayGraph(); + void CaptureBegin(int graph_annotation_id); + void CaptureEnd(int graph_annotation_id); + bool IsGraphCaptured(int graph_annotation_id) const; + Status ReplayGraph(int graph_annotation_id); void IncrementRegularRunCountBeforeGraphCapture(); private: diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index b78279040acb6..1cebe4a256fd4 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -159,6 +159,7 @@ class OpKernel; struct OpKernelContext; struct OpKernelInfo; struct PrimitiveDataTypeBase; +struct OrtRunOptions; struct Tensor; struct SparseTensor; class TensorSeq; diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index f5a8327443864..0b8551e0c5a66 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -481,6 +481,9 @@ struct ProviderHost { // ConfigOptions virtual std::optional ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) = 0; + // OrtRunOptions + virtual const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) = 0; + // ComputeCapability virtual std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) = 0; virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h 
index dde4005c80b9d..dc2b79015d95e 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -393,6 +393,14 @@ struct ConfigOptions final { PROVIDER_DISALLOW_ALL(ConfigOptions) }; +struct OrtRunOptions final { + const ConfigOptions& GetConfigOptions() const { + return g_host->RunOptions__GetConfigOptions(this); + } + + PROVIDER_DISALLOW_ALL(OrtRunOptions) +}; + struct ComputeCapability final { static std::unique_ptr Create(std::unique_ptr t_sub_graph) { return g_host->ComputeCapability__construct(std::move(t_sub_graph)); } static void operator delete(void* p) { g_host->ComputeCapability__operator_delete(reinterpret_cast(p)); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index e521640681a77..632d521dc21a8 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1633,26 +1633,26 @@ bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const { return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; } -void TensorrtExecutionProvider::CaptureBegin() { +void TensorrtExecutionProvider::CaptureBegin(int) { cuda_graph_.Reset(); - cuda_graph_.CaptureBegin(); + cuda_graph_.CaptureBegin(0); } -void TensorrtExecutionProvider::CaptureEnd() { - cuda_graph_.CaptureEnd(); +void TensorrtExecutionProvider::CaptureEnd(int) { + cuda_graph_.CaptureEnd(0); is_graph_captured_ = true; } -bool TensorrtExecutionProvider::IsGraphCaptured() const { +bool TensorrtExecutionProvider::IsGraphCaptured(int) const { return is_graph_captured_; } -Status TensorrtExecutionProvider::ReplayGraph() { - ORT_ENFORCE(IsGraphCaptured()); +Status TensorrtExecutionProvider::ReplayGraph(int) { + ORT_ENFORCE(IsGraphCaptured(0)); // Please note that CUDAGraph::Replay() is not thread safe. - // ORT TRT calls ReplayGraph() in compute_func() where synchromization is enforced due to lock_guard(), + // ORT TRT calls ReplayGraph() in compute_func() where synchronization is enforced due to lock_guard(), // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe. - return cuda_graph_.Replay(); + return cuda_graph_.Replay(0); } void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { @@ -3412,10 +3412,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Start CUDA graph capture. // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; cuda_graph_.SetStream(stream); - CaptureBegin(); + CaptureBegin(0); } // Run TRT inference @@ -3483,12 +3483,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. 
// It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - if (cuda_graph_enable_ && !IsGraphCaptured()) { + if (cuda_graph_enable_ && !IsGraphCaptured(0)) { if (IsGraphCaptureAllowed()) { - CaptureEnd(); + CaptureEnd(0); // CUDA work issued to a capturing stream doesn’t actually run on the GPU, // so run the captured graph here to actually execute the work. - ORT_RETURN_IF_ERROR(ReplayGraph()); + ORT_RETURN_IF_ERROR(ReplayGraph(0)); } else { IncrementRegularRunCountBeforeGraphCapture(); } @@ -3705,10 +3705,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con // Start CUDA graph capture. // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; cuda_graph_.SetStream(stream); - CaptureBegin(); + CaptureBegin(0); } // Run TRT inference @@ -3776,12 +3776,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - if (cuda_graph_enable_ && !IsGraphCaptured()) { + if (cuda_graph_enable_ && !IsGraphCaptured(0)) { if (IsGraphCaptureAllowed()) { - CaptureEnd(); + CaptureEnd(0); // CUDA work issued to a capturing stream doesn’t actually run on the GPU, // so run the captured graph here to actually execute the work. 
- ORT_RETURN_IF_ERROR(ReplayGraph()); + ORT_RETURN_IF_ERROR(ReplayGraph(0)); } else { IncrementRegularRunCountBeforeGraphCapture(); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 339c45a8742d2..f73031eaefceb 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -250,8 +250,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::vector CreatePreferredAllocators() override; bool IsGraphCaptureEnabled() const override; - bool IsGraphCaptured() const override; - Status ReplayGraph() override; + bool IsGraphCaptured(int graph_annotation_id) const override; + Status ReplayGraph(int graph_annotation_id) override; private: mutable TensorrtExecutionProviderInfo info_; @@ -373,10 +373,10 @@ class TensorrtExecutionProvider : public IExecutionProvider { void InitCUDAGraph(); void SetGraphStream(cudaStream_t stream); bool IsGraphCaptureAllowed() const; - void CaptureBegin(); - void CaptureEnd(); - bool IsGraphCaptured() const; - Status ReplayGraph(); + void CaptureBegin(int graph_annotation_id); + void CaptureEnd(int graph_annotation_id); + bool IsGraphCaptured(int graph_annotation_id) const; + Status ReplayGraph(int graph_annotation_id); void IncrementRegularRunCountBeforeGraphCapture(); private: @@ -540,8 +540,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::vector& node_compute_funcs); bool IsGraphCaptureAllowed() const; - void CaptureBegin(); - void CaptureEnd(); + void CaptureBegin(int graph_annotation_id); + void CaptureEnd(int graph_annotation_id); void IncrementRegularRunCountBeforeGraphCapture(); /** diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5fd66c459d382..684f390857d0b 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2383,21 +2383,32 @@ Status InferenceSession::Run(const RunOptions& run_options, Status retval = Status::OK(); const Env& env = Env::Default(); + int graph_annotation_id = 0; + const std::string& graph_annotation_str = + run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigCudaGraphAnnotation, ""); + if (!graph_annotation_str.empty()) { + if (!TryParseStringWithClassicLocale(graph_annotation_str, graph_annotation_id)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to parse the cuda graph annotation id: ", + graph_annotation_str); + } + } + // Increment/decrement concurrent_num_runs_ and control // session threads spinning as configured. Do nothing for graph replay except the counter. const bool control_spinning = use_per_session_threads_ && force_spinning_stop_between_runs_ && - !cached_execution_provider_for_graph_replay_.IsGraphCaptured(); + !cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id); auto* intra_tp = (control_spinning) ? thread_pool_.get() : nullptr; auto* inter_tp = (control_spinning) ? inter_op_thread_pool_.get() : nullptr; ThreadPoolSpinningSwitch runs_refcounter_and_tp_spin_control(intra_tp, inter_tp, current_num_runs_); // Check if this Run() is simply going to be a CUDA Graph replay. 
- if (cached_execution_provider_for_graph_replay_.IsGraphCaptured()) { + if (cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id)) { LOGS(*session_logger_, INFO) << "Replaying the captured " << cached_execution_provider_for_graph_replay_.Type() - << " CUDA Graph for this model with tag: " << run_options.run_tag; - ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph()); + << " CUDA Graph for this model with tag: " << run_options.run_tag + << " with graph annotation id: " << graph_annotation_id; + ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph(graph_annotation_id)); } else { InlinedVector exec_providers_to_stop; exec_providers_to_stop.reserve(execution_providers_.NumProviders()); @@ -2559,7 +2570,8 @@ Status InferenceSession::Run(const RunOptions& run_options, // N is defined in min_num_runs_before_hip_graph_capture_ for ROCM EP, // and the value could be different for other EP. if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() && - !cached_execution_provider_for_graph_replay_.IsGraphCaptured()) { + cached_execution_provider_for_graph_replay_.AllowGraphCaptureOnRun(graph_annotation_id) && + !cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id)) { LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture."; ORT_RETURN_IF_ERROR(Run(run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info)); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index f8211bfd2dd4e..3038c8d22ec80 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -675,7 +675,6 @@ class InferenceSession { * If we encounter an invalid request, we return an error * back to the user. */ - [[nodiscard]] common::Status ValidateAndParseShrinkArenaString(const std::string& ort_device_list, /*out*/ InlinedVector& arenas_to_shrink) const; @@ -867,14 +866,17 @@ class InferenceSession { return cached_execution_provider_for_graph_replay_ != nullptr && cached_execution_provider_for_graph_replay_->IsGraphCaptureEnabled(); } - bool IsGraphCaptured() const { - return cached_execution_provider_for_graph_replay_ != nullptr && cached_execution_provider_for_graph_replay_->IsGraphCaptured(); + bool IsGraphCaptured(int graph_annotation_id) const { + return cached_execution_provider_for_graph_replay_ != nullptr && cached_execution_provider_for_graph_replay_->IsGraphCaptured(graph_annotation_id); + } + + bool AllowGraphCaptureOnRun(int graph_annotation_id) const { + return cached_execution_provider_for_graph_replay_ != nullptr && graph_annotation_id != kGraphAnnotationSkip; } - Status ReplayGraph() { - ORT_ENFORCE(IsGraphCaptured()); + Status ReplayGraph(int graph_annotation_id) { if (cached_execution_provider_for_graph_replay_) { - return cached_execution_provider_for_graph_replay_->ReplayGraph(); + return cached_execution_provider_for_graph_replay_->ReplayGraph(graph_annotation_id); } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cached EP instance for graph replay is not set yet before calling ReplayGraph()"); } @@ -884,6 +886,8 @@ class InferenceSession { } IExecutionProvider* cached_execution_provider_for_graph_replay_ = nullptr; + // TODO(wy): Same as kCudaGraphAnnotationSkip in cuda_graph.h. Move to a common place. 
+ constexpr static int kGraphAnnotationSkip = -1; }; CachedExecutionProviderForGraphReplay cached_execution_provider_for_graph_replay_; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 3bec9aa146f76..d6797512d9e47 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -14,6 +14,7 @@ #include "core/framework/execution_provider.h" #include "core/framework/kernel_registry.h" #include "core/framework/provider_shutdown.h" +#include "core/framework/run_options.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/TensorSeq.h" #include "core/framework/provider_options.h" @@ -676,6 +677,9 @@ struct ProviderHostImpl : ProviderHost { return p->GetConfigEntry(config_key); } + // OrtRunOptions (wrapped) + const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) override { return p->config_options; } + // ComputeCapability (wrapped) std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) override { return std::make_unique(std::move(t_sub_graph)); } void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; } diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py index 796d6ec55ef80..8083778423241 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -13,6 +13,7 @@ import torch from benchmark_helper import Precision from fusion_options import AttentionOpType +from onnx_model import OnnxModel from transformers import AutoConfig, AutoModelForCausalLM from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer @@ -168,6 +169,58 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str): quant.process() quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True) + # This function currently only works for phi2 model + def convert_to_use_cuda_graph(self, in_onnx_path: str, out_onnx_path: str): + onnx_model = OnnxModel(onnx.load(in_onnx_path, load_external_data=True)) + + from onnx import TensorProto, helper + + graph = onnx_model.graph() + new_inputs = [] + for vi in graph.input: + if "attention_mask" in vi.name: + vi_seqlen_k = helper.make_tensor_value_info( + "seqlens_k", + elem_type=TensorProto.INT32, + shape=["batch_size"], + ) + vi_total_seq_len = helper.make_tensor_value_info( + "total_sequence_length", + elem_type=TensorProto.INT32, + shape=[1], + ) + new_inputs.extend([vi_seqlen_k, vi_total_seq_len]) + else: + new_inputs.append(vi) + + graph.ClearField("input") + graph.input.extend(new_inputs) + + gqas = onnx_model.get_nodes_by_op_type("GroupQueryAttention") + gqa = gqas[0] + seqlens_path = onnx_model.match_parent_path( + gqa, + ["Cast", "Sub", "ReduceSum", "Cast"], + [5, 0, 0, 0], + ) + if seqlens_path is None: + raise RuntimeError("Failed to find seqlens path for GroupQueryAttention node.") + total_seq_len_path = onnx_model.match_parent_path( + gqa, + ["Cast", "Gather", "Shape"], + [6, 0, 0], + ) + if total_seq_len_path is None: + raise RuntimeError("Failed to find total_seq_len path for GroupQueryAttention node.") + onnx_model.remove_nodes(seqlens_path) + onnx_model.remove_nodes(total_seq_len_path) + + for gqa in gqas: + gqa.input[5] = "seqlens_k" + gqa.input[6] = "total_sequence_length" + + onnx_model.save(onnx_model.model, out_onnx_path, 
save_as_external_data=True) + def parse_arguments(): parser = argparse.ArgumentParser() @@ -235,6 +288,13 @@ def parse_arguments(): help="Generate int4 ONNX model for ORT VLLM", ) + parser.add_argument( + "--use_cuda_graph", + required=False, + action="store_true", + help="Use CUDA Graph in decoding process", + ) + parser.add_argument( "--overwrite", required=False, @@ -265,6 +325,13 @@ def parse_arguments(): help="Run ORT inference example", ) + parser.add_argument( + "--run_benchmark", + required=False, + action="store_true", + help="Run ORT benchmark", + ) + parser.add_argument( "--skip_export", required=False, @@ -375,6 +442,9 @@ def run_optimize_phi2_onnx( ): converter.init_attn_type_and_precision(attention_type, precision) converter.optimize_phi2_onnx(original_onnx_path, optimized_onnx_path) + if args.use_cuda_graph: + assert args.fp16_gpu_sm8x or args.int4_gpu_sm8x + converter.convert_to_use_cuda_graph(optimized_onnx_path, optimized_onnx_path) processes = [] if args.fp32_cpu: @@ -447,7 +517,7 @@ def run_optimize_phi2_onnx( [p.start() for p in processes] [p.join() for p in processes] - if args.run_example: + if args.run_example or args.run_benchmark: from inference_example import run_phi2 if args.fp16_gpu_sm8x: @@ -457,6 +527,8 @@ def run_optimize_phi2_onnx( use_buffer_share=True, device_id=args.device_id, use_step=True, + use_cuda_graph=args.use_cuda_graph, + run_benchmark=args.run_benchmark, ) if args.int4_gpu_sm8x: logging.info("Running int4_gpu_sm8x example...") @@ -465,6 +537,8 @@ def run_optimize_phi2_onnx( use_buffer_share=True, device_id=args.device_id, use_step=True, + use_cuda_graph=args.use_cuda_graph, + run_benchmark=args.run_benchmark, ) if args.fp32_gpu: logging.info("Running fp32_gpu example...") @@ -474,6 +548,7 @@ def run_optimize_phi2_onnx( device_id=args.device_id, packed_kv=True, use_fp16=False, + run_benchmark=args.run_benchmark, ) if args.fp16_gpu: logging.info("Running fp16_gpu example...") @@ -482,6 +557,7 @@ def run_optimize_phi2_onnx( use_buffer_share=False, device_id=args.device_id, packed_kv=True, + run_benchmark=args.run_benchmark, ) if args.int4_gpu: logging.info("Running int4_gpu example...") @@ -490,6 +566,7 @@ def run_optimize_phi2_onnx( use_buffer_share=False, device_id=args.device_id, packed_kv=True, + run_benchmark=args.run_benchmark, ) if args.fp32_cpu or args.int4_cpu or args.fp16_vllm or args.int4_vllm: raise NotImplementedError("CPU/vllm inference example is not implemented yet.") diff --git a/onnxruntime/python/tools/transformers/models/phi2/inference_example.py b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py index 28828ffb853cb..829334b46b469 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/inference_example.py +++ b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py @@ -17,6 +17,17 @@ } +def cuda_memcpy(dst, src): + from cuda import cudart + + cudart.cudaMemcpy( + dst.data_ptr(), + src.data_ptr(), + src.element_size() * src.nelement(), + cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, + ) + + class ORTGenerator: def __init__(self, decoder_path): self.onnx_decoder_path = decoder_path @@ -24,13 +35,68 @@ def __init__(self, decoder_path): self.head_size = 80 self.num_layers = 32 self.max_sequence_length = 2048 + self.device_id = 0 + self.use_cuda_graph = False + self.use_traced_inputs = False + self.static_inputs_map = {} + + def append_static_inputs(self, batch_size): + # Only use this function with GQA and with use_cuda_graph=True + if batch_size in self.static_inputs_map: + return + 
+ cpu_device = torch.device("cpu") + cuda_device = torch.device("cuda", self.device_id) + + static_io = {} + static_io["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int32, device=cuda_device) + static_io["step"] = torch.tensor([0], dtype=torch.int64, device=cuda_device) + static_io["seqlens_k"] = torch.tensor(batch_size * [0], dtype=torch.int32, device=cuda_device) + static_io["total_sequence_length"] = torch.tensor([0], dtype=torch.int32, device=cpu_device) + + cache_shape = (batch_size, self.num_heads, self.max_sequence_length, self.head_size) + for i in range(self.num_layers): + cache = torch.zeros(cache_shape, device=cuda_device, dtype=torch.float16) + static_io.update({f"past_key_{i}": cache.contiguous(), f"past_value_{i}": cache.clone().contiguous()}) + + static_io["logits"] = torch.zeros((batch_size, 1, 51200), dtype=torch.float16, device=cuda_device) + + self.static_inputs_map[batch_size] = static_io def get_initial_inputs_and_outputs(self, encodings_dict): self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32 input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32) attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32) - step = torch.tensor([0], device=self.device, dtype=torch.int64) + + batch_size, sequence_length = input_ids.shape + + self.use_traced_inputs = ( + self.use_cuda_graph + and (batch_size in self.static_inputs_map) + and self.use_buffer_share + and not self.packed_kv + ) + + step = ( + torch.tensor([0], device=self.device, dtype=torch.int64) + if not self.use_traced_inputs + else self.static_inputs_map[batch_size]["step"] + ) + + seqlens_k = ( + torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32) + if not self.use_traced_inputs + else self.static_inputs_map[batch_size]["seqlens_k"] + ) + cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32)) + + total_seq_length = ( + torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32) + if not self.use_traced_inputs + else self.static_inputs_map[batch_size]["total_sequence_length"] + ) + total_seq_length[0] = sequence_length inputs = { "input_ids": input_ids.contiguous(), @@ -40,7 +106,10 @@ def get_initial_inputs_and_outputs(self, encodings_dict): if self.use_step: inputs["step"] = step.contiguous() - batch_size, sequence_length = input_ids.shape + if self.use_cuda_graph: + inputs["seqlens_k"] = seqlens_k.contiguous() + inputs["total_sequence_length"] = total_seq_length.contiguous() + del inputs["attention_mask"] past_seq_length = self.max_sequence_length if self.use_buffer_share else 0 past_shape = ( @@ -48,11 +117,21 @@ def get_initial_inputs_and_outputs(self, encodings_dict): if self.packed_kv else (batch_size, self.num_heads, past_seq_length, self.head_size) ) - for i in range(self.num_layers): - past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype) - inputs.update( - {f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()} - ) if not self.packed_kv else inputs.update({f"past_{i}": past.contiguous()}) + + if not self.use_traced_inputs: + for i in range(self.num_layers): + past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype) + inputs.update( + {f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()} + ) if not self.packed_kv else inputs.update({f"past_{i}": past.contiguous()}) + else: + for i in range(self.num_layers): + inputs.update( + { + f"past_key_{i}": 
self.static_inputs_map[batch_size][f"past_key_{i}"].contiguous(), + f"past_value_{i}": self.static_inputs_map[batch_size][f"past_value_{i}"].contiguous(), + } + ) logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype) outputs = {"logits": logits.contiguous()} @@ -111,12 +190,23 @@ def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: d return io_binding - def create_session(self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False): + def create_session( + self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False, use_cuda_graph=False + ): + self.device_id = device_id sess_options = ort.SessionOptions() - ep = ("CUDAExecutionProvider", {"device_id": device_id}) if device_id >= 0 else "CPUExecutionProvider" + sess_options.log_verbosity_level = 4 + sess_options.log_severity_level = 4 + self.use_cuda_graph = use_cuda_graph + ep = ( + ("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph}) + if self.device_id >= 0 + else "CPUExecutionProvider" + ) self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep]) + self.ro = ort.RunOptions() - self.device = torch.device("cuda", device_id) if torch.cuda.is_available() else torch.device("cpu") + self.device = torch.device("cuda", self.device_id) if torch.cuda.is_available() else torch.device("cpu") self.use_fp16 = use_fp16 self.use_buffer_share = use_buffer_share self.packed_kv = packed_kv @@ -125,9 +215,7 @@ def create_session(self, device_id, use_fp16=True, use_buffer_share=True, packed self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) self.tokenizer.pad_token = "[PAD]" - def generate(self, prompt, max_length): - encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True) - + def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False): inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict) all_token_ids = inputs["input_ids"].clone() @@ -136,13 +224,38 @@ def generate(self, prompt, max_length): current_length = sequence_length has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool) + if benchmark: + import time + + latency = [] + + prompt_run = True while current_length < max_length: io_binding = self.apply_io_binding(self.sess, inputs, outputs) + if benchmark: + start = time.time() + io_binding.synchronize_inputs() - self.sess.run_with_iobinding(io_binding) + if prompt_run: + if self.use_cuda_graph: + # Disable CUDA graph for the prompt run + self.ro.add_run_config_entry("gpu_graph_id", "-1") + self.sess.run_with_iobinding(io_binding, self.ro) + if self.use_cuda_graph: + # Enable CUDA graph for the decoding run + self.ro.add_run_config_entry( + "gpu_graph_id", str(cuda_graph_annotation) if self.use_traced_inputs else "-1" + ) + prompt_run = False + else: + self.sess.run_with_iobinding(io_binding, self.ro) io_binding.synchronize_outputs() + if benchmark: + end = time.time() + latency.append(end - start) + # Sample with argmax (greedy search) next_token_logits = outputs["logits"][:, -1, :] next_tokens = torch.argmax(next_token_logits, dim=-1) @@ -161,16 +274,37 @@ def generate(self, prompt, max_length): # Update inputs for next inference run current_length += 1 + inputs["input_ids"] = tokens_to_add.to(torch.int32) + if self.use_traced_inputs: + cuda_memcpy(self.static_inputs_map[batch_size]["input_ids"], inputs["input_ids"]) + 
inputs["input_ids"] = self.static_inputs_map[batch_size]["input_ids"] + if self.use_step: inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64) - inputs["attention_mask"] = torch.cat([inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1).to( - torch.int32 - ) + if self.use_traced_inputs: + cuda_memcpy(self.static_inputs_map[batch_size]["step"], inputs["step"]) + inputs["step"] = self.static_inputs_map[batch_size]["step"] + + if self.use_cuda_graph: + previous_seqlens_k = inputs["seqlens_k"] + inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32) + inputs["total_sequence_length"][0] = current_length + if self.use_traced_inputs: + cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"]) + inputs["seqlens_k"] = self.static_inputs_map[batch_size]["seqlens_k"] + self.static_inputs_map[batch_size]["total_sequence_length"][0] = inputs["total_sequence_length"][0] + inputs["total_sequence_length"] = self.static_inputs_map[batch_size]["total_sequence_length"] + else: + inputs["attention_mask"] = torch.cat( + [inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1 + ).to(torch.int32) # Set logits to zeros for next inference run and re-use memory buffer if outputs["logits"].shape[1] != 1: outputs["logits"] = outputs["logits"][:, :1, :].contiguous() + if self.use_traced_inputs: + outputs["logits"] = self.static_inputs_map[batch_size]["logits"] outputs["logits"].zero_() if not self.use_buffer_share: @@ -193,11 +327,59 @@ def generate(self, prompt, max_length): {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.clone().contiguous()} ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) + if benchmark: + print( + f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}" + ) + print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms") + return + texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True) return texts + def generate(self, prompt, max_length, cuda_graph_annotation): + encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True) + + return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation) + + def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation): + batch_size, sequence_length = prompt_shape + max_length = sequence_length + token_num + + encodings_dict = {} + encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist() + encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist() + + # Warm up run + self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False) + + # Benchmark run + self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True) + + +def run_phi2( + onnx_model_path, + use_buffer_share, + device_id, + packed_kv=False, + use_fp16=True, + use_step=False, + use_cuda_graph=False, + run_benchmark=False, +): + generator = ORTGenerator(onnx_model_path) + generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step, use_cuda_graph) + + def simple_run(prompt): + example_batch_size = len(prompt) + if use_cuda_graph: + generator.append_static_inputs(batch_size=example_batch_size) + texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size) + + for i in 
range(len(texts)): + print("Prompt: ", prompt[i]) + print("Texts: ", texts[i]) -def run_phi2(onnx_model_path, use_buffer_share, device_id, packed_kv=False, use_fp16=True, use_step=False): prompt = [ '''```python def print_prime(n): @@ -206,10 +388,14 @@ def print_prime(n): """''' ] - generator = ORTGenerator(onnx_model_path) - generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step) - texts = generator.generate(prompt, max_length=200) - - for i in range(len(texts)): - print("Prompt: ", prompt[i]) - print("Texts: ", texts[i]) + if not run_benchmark: + simple_run(prompt) + + # Run simple benchmark. Time the decoder only. + if run_benchmark: + token_num = 32 + for batch_size in [1, 2, 4, 8]: + generator.append_static_inputs(batch_size) + for sequence_length in [16, 512]: + prompt_shape = (batch_size, sequence_length) + generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size) diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py index c4e13e773535d..ce04dff2aecb0 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py +++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py @@ -84,6 +84,7 @@ def test_select_ep_to_run_cuda_graph(self): elif "CUDAExecutionProvider" in onnxrt.get_available_providers(): providers = [("CUDAExecutionProvider", {"enable_cuda_graph": True})] self.run_model_with_cuda_graph(providers) + self.run_model_with_cuda_graph_annotation(providers) def run_model_with_cuda_graph(self, providers): INPUT_SIZE = 1280 # noqa: N806 @@ -100,13 +101,15 @@ def run_model_with_cuda_graph(self, providers): io_binding.bind_ortvalue_input("X", x_ortvalue) io_binding.bind_ortvalue_output("Y", y_ortvalue) + ro = onnxrt.RunOptions() + # One regular run for the necessary memory allocation and cuda graph capturing - session.run_with_iobinding(io_binding) + session.run_with_iobinding(io_binding, ro) expected_y = np.array([[5.0], [11.0], [17.0]] * INPUT_SIZE, dtype=np.float32) np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) # After capturing, CUDA graph replay happens from this Run onwards - session.run_with_iobinding(io_binding) + session.run_with_iobinding(io_binding, ro) np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) # Update input and then replay CUDA graph @@ -116,7 +119,7 @@ def run_model_with_cuda_graph(self, providers): dtype=np.float32, ) ) - session.run_with_iobinding(io_binding) + session.run_with_iobinding(io_binding, ro) np.testing.assert_allclose( np.array([[50.0], [110.0], [170.0]] * INPUT_SIZE, dtype=np.float32), y_ortvalue.numpy(), @@ -124,6 +127,58 @@ def run_model_with_cuda_graph(self, providers): atol=1e-05, ) + def run_model_with_cuda_graph_annotation(self, providers): + INPUT_SIZE = 1280 # noqa: N806 + + x_base = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]] + y_base = [[0.0], [0.0], [0.0], [0.0]] + expected_y_base = [[5.0], [11.0], [17.0], [23.0]] + + x_base_mul_10 = [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0], [70.0, 80.0]] + expected_y_base_mul_10 = [[50.0], [110.0], [170.0], [230.0]] + + test_num = 4 + + x_ortvalues = [] + y_ortvalues = [] + for i in range(test_num): + x = np.array(x_base[: i + 1][:] * INPUT_SIZE, dtype=np.float32) + y = np.array(y_base[: i + 1][:] * INPUT_SIZE, dtype=np.float32) + x_ortvalues.append(onnxrt.OrtValue.ortvalue_from_numpy(x, "cuda", 0)) + y_ortvalues.append(onnxrt.OrtValue.ortvalue_from_numpy(y, "cuda", 0)) 
+ + onnxrt.set_default_logger_severity(0) + session = onnxrt.InferenceSession(get_name("matmul_2.onnx"), providers=providers) + io_bindings = [session.io_binding()] * test_num + ro = onnxrt.RunOptions() + + # Regular run to capture CUDA graph + for i in range(test_num): + io_bindings[i].bind_ortvalue_input("X", x_ortvalues[i]) + io_bindings[i].bind_ortvalue_output("Y", y_ortvalues[i]) + # TODO: Temporarily remove the default cuda graph capture test for the first regular run + # because it fails on a training CI. Need to investigate the root cause. + ro.add_run_config_entry("gpu_graph_id", str(i + 1)) + io_bindings[i].synchronize_inputs() + session.run_with_iobinding(io_bindings[i], ro) + io_bindings[i].synchronize_outputs() + expected_y = np.array(expected_y_base[: i + 1][:] * INPUT_SIZE, dtype=np.float32) + np.testing.assert_allclose(expected_y, y_ortvalues[i].numpy(), rtol=1e-05, atol=1e-05) + + del ro + ro = onnxrt.RunOptions() + + # After capturing, CUDA graph replay happens from this Run onwards + for i in range(test_num): + # Update input and then replay CUDA graph + x_ortvalues[i].update_inplace(np.array(x_base_mul_10[: i + 1][:] * INPUT_SIZE, dtype=np.float32)) + ro.add_run_config_entry("gpu_graph_id", str(i + 1)) + io_bindings[i].synchronize_inputs() + session.run_with_iobinding(io_bindings[i], ro) + io_bindings[i].synchronize_outputs() + expected_y = np.array(expected_y_base_mul_10[: i + 1][:] * INPUT_SIZE, dtype=np.float32) + np.testing.assert_allclose(expected_y, y_ortvalues[i].numpy(), rtol=1e-05, atol=1e-05) + def test_arena_with_cuda_graph(self): if "CUDAExecutionProvider" in onnxrt.get_available_providers(): # To test cuda graph catpure, we set Arena extend strategy to be SameAsRequested so as to detect any diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 8dad2c8e2d10d..453b5fdd360bf 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -180,6 +180,9 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod } static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx"); +#if defined(USE_CUDA) +static constexpr PATH_TYPE CUDA_GRAPH_ANNOTATION_MODEL_URI = TSTR("testdata/mul_1_dynamic.onnx"); +#endif static constexpr PATH_TYPE MATMUL_MODEL_URI = TSTR("testdata/matmul_1.onnx"); #ifndef ORT_NO_RTTI static constexpr PATH_TYPE SEQUENCE_MODEL_URI = TSTR("testdata/sequence_length.onnx"); @@ -2082,6 +2085,152 @@ TEST(CApiTest, basic_cuda_graph) { #endif } +#if defined(USE_CUDA) +struct CudaGraphInputOutputData_0 { + const std::array x_shape = {3, 2}; + std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + const std::array expected_y_shape = {3, 2}; + std::array expected_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + std::array y_values; + std::array new_x_values = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f}; + std::array new_expected_y = {10.0f, 40.0f, 90.0f, 160.0f, 250.0f, 360.0f}; +} cg_data_0; + +struct CudaGraphInputOutputData_1 { + const std::array x_shape = {3, 1}; + std::array x_values = {1.0f, 3.0f, 5.0f}; + const std::array expected_y_shape = {3, 2}; + std::array expected_y = {1.0f, 2.0f, 9.0f, 12.0f, 25.0f, 30.0f}; + + std::array y_values; + std::array new_x_values = {10.0f, 30.0f, 50.0f}; + std::array new_expected_y = {10.0f, 20.0f, 90.0f, 120.0f, 250.0f, 300.0f}; +} cg_data_1; + +struct CudaGraphInputOutputData_2 { + const std::array x_shape = {1, 2}; + std::array x_values = {1.0f, 2.0f}; + const std::array 
expected_y_shape = {3, 2}; + std::array expected_y = {1.0f, 4.0f, 3.0f, 8.0f, 5.0f, 12.0f}; + + std::array y_values; + std::array new_x_values = {10.0f, 20.0f}; + std::array new_expected_y = {10.0f, 40.0f, 30.0f, 80.0f, 50.0f, 120.0f}; +} cg_data_2; + +template +static void RunWithCudaGraphAnnotation(T& cg_data, + Ort::Session& session, + Ort::MemoryInfo& info_mem, + Ort::MemoryAllocation& input_data, + Ort::MemoryAllocation& output_data, + const char* cuda_graph_annotation) { + (void)cudaMemcpy(input_data.get(), + cg_data.x_values.data(), + sizeof(float) * cg_data.x_values.size(), + cudaMemcpyHostToDevice); + + // Create an OrtValue tensor backed by data on CUDA memory + Ort::Value bound_x = Ort::Value::CreateTensor(info_mem, + reinterpret_cast(input_data.get()), + cg_data.x_values.size(), + cg_data.x_shape.data(), + cg_data.x_shape.size()); + + // Create an OrtValue tensor backed by data on CUDA memory + Ort::Value bound_y = Ort::Value::CreateTensor(info_mem, + reinterpret_cast(output_data.get()), + cg_data.expected_y.size(), + cg_data.expected_y_shape.data(), + cg_data.expected_y_shape.size()); + + // Create IoBinding for inputs and outputs. + Ort::IoBinding binding(session); + binding.BindInput("X", bound_x); + binding.BindOutput("Y", bound_y); + + Ort::RunOptions run_option; + if (cuda_graph_annotation != nullptr) { + run_option.AddConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation, cuda_graph_annotation); + } + + // One regular run for necessary memory allocation and graph capturing + session.Run(run_option, binding); + + // Check the values against the bound raw memory (needs copying from device to host first) + (void)cudaMemcpy(cg_data.y_values.data(), + output_data.get(), + sizeof(float) * cg_data.y_values.size(), + cudaMemcpyDeviceToHost); + ASSERT_THAT(cg_data.y_values, ::testing::ContainerEq(cg_data.expected_y)); + + // Replay the captured CUDA graph + session.Run(run_option, binding); + (void)cudaMemcpy(cg_data.y_values.data(), + output_data.get(), + sizeof(float) * cg_data.y_values.size(), + cudaMemcpyDeviceToHost); + ASSERT_THAT(cg_data.y_values, ::testing::ContainerEq(cg_data.expected_y)); + + // Change the input and replay the CUDA graph again. + (void)cudaMemcpy(input_data.get(), + cg_data.new_x_values.data(), + sizeof(float) * cg_data.new_x_values.size(), + cudaMemcpyHostToDevice); + binding.SynchronizeInputs(); + + session.Run(run_option, binding); + (void)cudaMemcpy(cg_data.y_values.data(), + output_data.get(), + sizeof(float) * cg_data.y_values.size(), + cudaMemcpyDeviceToHost); + ASSERT_THAT(cg_data.y_values, ::testing::ContainerEq(cg_data.new_expected_y)); + + // Clean up + binding.ClearBoundInputs(); + binding.ClearBoundOutputs(); +} + +TEST(CApiTest, basic_cuda_graph_with_annotation) { + const auto& api = Ort::GetApi(); + Ort::SessionOptions session_options; + + // Enable cuda graph in cuda provider option. 
+ OrtCUDAProviderOptionsV2* cuda_options = nullptr; + ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); + std::unique_ptr + rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); + std::vector keys{"enable_cuda_graph"}; + std::vector values{"1"}; + ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr); + + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( + static_cast(session_options), + rel_cuda_options.get()) == nullptr); + + Ort::Session session(*ort_env, CUDA_GRAPH_ANNOTATION_MODEL_URI, session_options); + Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); + + Ort::Allocator allocator(session, info_mem); + auto allocator_info = allocator.GetInfo(); + ASSERT_TRUE(info_mem == allocator_info); + + size_t max_input_size = 6; + size_t max_output_size = 6; + + auto input_data = allocator.GetAllocation(max_input_size * sizeof(float)); + auto output_data = allocator.GetAllocation(max_output_size * sizeof(float)); + + ASSERT_NE(input_data.get(), nullptr); + ASSERT_NE(output_data.get(), nullptr); + + RunWithCudaGraphAnnotation(cg_data_0, session, info_mem, input_data, output_data, nullptr); + RunWithCudaGraphAnnotation(cg_data_1, session, info_mem, input_data, output_data, "1"); + RunWithCudaGraphAnnotation(cg_data_2, session, info_mem, input_data, output_data, "2"); +} +#endif + // The following test uses some ops not supported in the reduced ops build #ifndef REDUCED_OPS_BUILD #if defined(USE_CUDA) || defined(USE_TENSORRT) diff --git a/onnxruntime/test/testdata/mul_1_dynamic.onnx b/onnxruntime/test/testdata/mul_1_dynamic.onnx new file mode 100644 index 0000000000000000000000000000000000000000..fb7822498b0048716e701f4c23846d30ae36a6dc GIT binary patch literal 142 zcmd;J7Gg`zNX;urw5s6}b^xTQ7&8C> literal 0 HcmV?d00001 From 3dfce2f1cd9776f312f68f1cfc0d826875adcb67 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Thu, 7 Mar 2024 11:31:34 -0800 Subject: [PATCH 126/237] Fix argparser in `matmul_bnb4_quantizer` (#19812) ### Description The argparser had incorrectly used `description` and `options` instead of `help` and `choices`. ### Motivation and Context Fixes: #19751 --- .../python/tools/quantization/matmul_bnb4_quantizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py index 951746a089305..2bf47fe1680e9 100644 --- a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py @@ -199,14 +199,14 @@ def parse_args(): "--quant_type", required=False, default=1, - options=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4], + choices=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4], help="Quantization data type. 0: FP4, 1: NF4", ) parser.add_argument( "--block_size", required=False, default=64, - description="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64", + help="Block size for blockwise quantization. 
Note: bnb.nn.Linear4bit only uses block_size=64", ) parser.add_argument("-v", "--verbose", required=False, action="store_true") parser.set_defaults(verbose=False) From 33578cc76efc19b50c9fc011215b2777de193cd1 Mon Sep 17 00:00:00 2001 From: Yi-Hong Lyu Date: Thu, 7 Mar 2024 13:54:16 -0800 Subject: [PATCH 127/237] Remove memset for the case no any mask (#19823) Improved OCR model speed by 1.034 end-to-end, by eliminating unnecessary memset when no mask is present. --- .../contrib_ops/cpu/bert/attention_cpu_base.h | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index b761b1afd8529..c617533319a18 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -140,17 +140,6 @@ class AttentionCPUBase : public AttentionBase { if (mask_data != nullptr) { PrepareMask(mask_index, mask_index_dims, mask_data, causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_); - } else { // no any mask - const int memset_loop_len = batch_size * num_heads_; - const double memset_cost = static_cast(sequence_length) * total_sequence_length; - - ThreadPool::TryParallelFor(tp, memset_loop_len, memset_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - for (std::ptrdiff_t i = begin; i != end; ++i) { - const int output_offset = static_cast(i) * sequence_length * total_sequence_length; - T* output = attention_probs + output_offset; - memset(output, 0, static_cast(sequence_length) * total_sequence_length * sizeof(T)); - } - }); } const int loop_len = batch_size * num_heads_; @@ -188,7 +177,7 @@ class AttentionCPUBase : public AttentionBase { // B: K' (B x N x) T x H (B x N x) H x T H x T // C: attention_probs (B x N x) S x T (B x N x) S x T S x T math::Gemm(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha, - Q + q_input_chunk_length * i, k, 1.0, + Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f, output, nullptr); if (relative_position_bias_data != nullptr) { From 296435264182e09cc37cfd981b012854226ddd2c Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 7 Mar 2024 15:46:11 -0800 Subject: [PATCH 128/237] Implement IsNaN-9,13,20 for CUDA along with tests (#19807) ### Description ### Motivation and Context Some models require IsNan CUDA along with training --- docs/OperatorKernels.md | 5 +- .../providers/cpu/cpu_execution_provider.cc | 4 ++ .../core/providers/cpu/tensor/isnan.cc | 19 +++++- .../core/providers/cuda/cu_inc/common.cuh | 59 ++++++++++++++++++- .../providers/cuda/cuda_execution_provider.cc | 7 ++- .../cuda/math/unary_elementwise_ops.cc | 44 ++++++++++++++ .../cuda/math/unary_elementwise_ops.h | 6 ++ .../cuda/math/unary_elementwise_ops_impl.cu | 24 +++++++- .../cuda/math/unary_elementwise_ops_impl.h | 14 +++++ .../core/providers/rocm/cu_inc/common.cuh | 57 ++++++++++++++++++ .../providers/rocm/rocm_execution_provider.cc | 6 ++ .../test/providers/cpu/tensor/isnan_test.cc | 16 ++++- 12 files changed, 252 insertions(+), 9 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 4514a85531d6b..9f5cd4cc842dc 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -162,7 +162,7 @@ Do not modify directly.* |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)| |IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| -|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)| @@ -633,6 +633,9 @@ Do not modify directly.* |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| |IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| +|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[13, 19]|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| +|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 7e0f919deb0a7..c3d5a51b636ef 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -714,6 +714,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, BFloat16, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero); @@ -1023,6 +1024,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, BFloat16, IsNaN); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN); @@ -2553,6 +2555,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo::Compute(OpKernelContext* context) const { template <> Status IsNaN::Compute(OpKernelContext* context) const { const auto* X_ptr = context->Input(0); - if (!X_ptr) { - return Status(common::ONNXRUNTIME, common::FAIL, "Null input ptr"); - } + auto X_data = X_ptr->Data(); auto& dims = X_ptr->Shape(); auto shape_size = dims.Size(); @@ -91,6 +91,19 @@ Status IsNaN::Compute(OpKernelContext* context) const { return Status::OK(); } +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X_ptr = context->Input(0); + + auto X_data = X_ptr->DataAsSpan(); + auto& Y = *context->Output(0, X_ptr->Shape()); + + std::transform(X_data.begin(), X_data.end(), Y.MutableData(), + [](BFloat16 x) { return x.IsNaN(); }); + + return Status::OK(); +} + #if !defined(DISABLE_FLOAT8_TYPES) template <> Status IsNaN::Compute(OpKernelContext* context) const { diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index bba9178348132..bed2f677166d6 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -485,7 +485,7 @@ struct IsInfTyped { #if !defined(DISABLE_FLOAT8_TYPES) -template +template struct ReturnFalse { constexpr static bool __device__ __inline__ IsInf(T) { return false; } constexpr static bool __device__ __inline__ IsInfPos(T) { return 
false; } @@ -532,6 +532,63 @@ struct _IsInf { } }; +// float and double +template +struct _IsNan { + __device__ __inline__ bool operator()(T a) const { + return isnan(a); + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(half a) const { + return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) + > MLFloat16::kPositiveInfinityBits; + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(BFloat16 a) const { + return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) + > BFloat16::kPositiveInfinityBits; + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template<> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E4M3FN a) const { + return (*reinterpret_cast(&a) & 0x7f) == 0x7f; + } +}; + +template<> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const { + return *reinterpret_cast(&a) == 0x80; + } +}; + +template<> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E5M2 a) const { + uint8_t c = *reinterpret_cast(&a); + return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); + } +}; + +template<> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const { + return *reinterpret_cast(&a) == 0x80; + } +}; + +#endif + // We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. #ifndef CUDA_LONG diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index bade2faf8f2e2..18c7334af6611 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -746,6 +746,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, bool, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 12, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, float, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, double, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 2, 10, MLFloat16, Pad); @@ -938,7 +939,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom // OpSet 12 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Clip); - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, float, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, double, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool); @@ -1087,6 +1087,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, U class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Concat); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Gather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, 
kOnnxDomain, 13, GatherElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 19, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); @@ -1368,6 +1369,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -1553,6 +1555,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1979,6 +1982,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2279,6 +2283,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 00de1b37f3302..24593b255371c 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -109,6 +109,50 @@ Status IsInf::ComputeInternal(OpKernelContext* context) const { return Status::OK(); } +// IsNan +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + IsNaN, + kOnnxDomain, + 9, + 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsNaN); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + IsNaN, + kOnnxDomain, + 13, + 19, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsNaN); + +ONNX_OPERATOR_KERNEL_EX( + IsNaN, + kOnnxDomain, + 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsNaN); + +Status IsNaN::ComputeInternal(OpKernelContext* context) const { + UnaryElementwisePreparation p; + ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); + + Explicit_Impl_IsNan(Stream(context), p.input_tensor->GetElementType(), p.input_tensor->DataRaw(), + p.output_tensor->MutableData(), + p.input_tensor->Shape().Size()); + + return Status::OK(); +} + #define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \ UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 3b7d6df7221b7..95d68b5e1d534 100644 --- 
a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -131,5 +131,11 @@ class IsInf final : public UnaryElementwise { int opset_; }; +class IsNaN : public UnaryElementwise { + public: + explicit IsNaN(const OpKernelInfo& info) : UnaryElementwise(info) {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 554d5908cf854..2cdfcda5be26a 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -315,13 +315,33 @@ void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, if (op_set < 20) { utils::MLTypeCallDispatcher dispatcher{input_data_type}; dispatcher.Invoke(stream, input_raw, output_data, - detect_positive, detect_negative, count); + detect_positive, detect_negative, count); } else { utils::MLTypeCallDispatcher dispatcher{input_data_type}; dispatcher.Invoke(stream, input_raw, output_data, - detect_positive, detect_negative, count); + detect_positive, detect_negative, count); } } +// IsNan + +namespace isnan_details { +template +struct IsNan_Disp { + void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, size_t count) const { + using CudaType = typename ToCudaType::MappedType; + const auto* input_data = reinterpret_cast(input_raw); + UnaryElementWiseImpl(stream, input_data, output_data, _IsNan{}, count); + } +}; +} // namespace isnan_details + +void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type, + const void* input_raw, bool* output_data, size_t count) { + // KernelDef constraints would ensure only subset of datatypes is used. 
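+  // Dispatch on the tensor's runtime element type: IsNan_Disp reinterprets the raw input as the
+  // corresponding CUDA type and applies the elementwise _IsNan functor defined in common.cuh.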
+ utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, count); +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index a606d479bc79b..2588f56e32c12 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -151,6 +151,20 @@ void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, int32_t input_data_type, const void* input_raw, bool* output_data, size_t count); + +// IsNan +#define ISNAN_OPSET9_FLOATS float, double, MLFloat16 +#define ISNAN_OPSET13_FLOATS float, double, MLFloat16, BFloat16 +#if !defined(DISABLE_FLOAT8_TYPES) +#define ISNAN_OPSET20_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \ + Float8E5M2FNUZ +#else +#define ISNAN_OPSET20_FLOATS ISNAN_OPSET13_FLOATS +#endif + +void Explicit_Impl_IsNan(cudaStream_t stream, int32_t input_data_type, + const void* input_raw, bool* output_data, size_t count); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index f3685606c17f5..1698e5ca8478c 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -429,6 +429,63 @@ struct _IsInf { } }; +// float and double +template +struct _IsNan { + __device__ __inline__ bool operator()(T a) const { + return isnan(a); + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(half a) const { + return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) + > MLFloat16::kPositiveInfinityBits; + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(BFloat16 a) const { + return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) + > BFloat16::kPositiveInfinityBits; + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E4M3FN a) const { + return (*reinterpret_cast(&a) & 0x7f) == 0x7f; + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const { + return *reinterpret_cast(&a) == 0x80; + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E5M2 a) const { + uint8_t c = *reinterpret_cast(&a); + return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); + } +}; + +template <> +struct _IsNan { + __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const { + return *reinterpret_cast(&a) == 0x80; + } +}; + +#endif + // We would like to use 64-bit integer to support large matrices. However, ROCM seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. 
#ifndef HIP_LONG diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 32be74550951e..87daaeea969ac 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -734,6 +734,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, Shrink); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, Shrink); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Less); @@ -1067,6 +1068,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, bool, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 19, IsNaN); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size); @@ -1346,6 +1348,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, S // Opset 20 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsNaN); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -1531,6 +1534,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -1941,6 +1945,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2304,6 +2309,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // opset 20 BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc index 0f1e5c07cdd9b..3cf99fde2cce7 100644 --- a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc @@ -38,9 +38,23 @@ TEST(IsNaNOpTest, IsNaNFloat16_9) { run_is_nan_test(9, dims, input, output); } +TEST(IsNaNOpTest, IsNaNFloat16_13) { + std::vector dims{2, 2}; + std::initializer_list input = 
{MLFloat16::One, MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(13, dims, input, output); +} + TEST(IsNaNOpTest, IsNaNFloat16_20) { std::vector dims{2, 2}; - std::initializer_list input = {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}; + std::initializer_list input = {MLFloat16::One, MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaNBFloat16_20) { + std::vector dims{2, 2}; + std::initializer_list input = {BFloat16::One, BFloat16::NaN, BFloat16(2.0f), BFloat16::NaN}; std::initializer_list output = {false, true, false, true}; run_is_nan_test(20, dims, input, output); } From 6c3bed674008694847374a59c9057a640cdd40e2 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 8 Mar 2024 12:50:13 +1000 Subject: [PATCH 129/237] Run CoreML EP with NeuralNetwork and ML Program in CI unit tests (#19796) ### Description Add synthetic CoreML EP name to the list of providers so we test with NeuralNetwork and MLProgram model types. ### Motivation and Context Automatically test new MLProgram support in CI --- onnxruntime/test/providers/base_tester.cc | 11 +++++++++++ .../test/providers/coreml/coreml_basic_test.cc | 6 +++++- onnxruntime/test/util/default_providers.cc | 6 +++--- onnxruntime/test/util/include/default_providers.h | 2 +- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 84cb663a2984a..e94f8c2673be3 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -613,6 +613,9 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, number_of_pre_packed_weights_counter, number_of_shared_pre_packed_weights_counter); } else { + // synthetic EP name for testing CoreML EP with ML Program + constexpr const char* kCoreMLExecutionProviderMLProgram = "CoreMLExecutionProvider_MLProgram"; + #ifdef USE_TENSORRT // only run trt ep to reduce test time static const std::string all_provider_types[] = { @@ -634,10 +637,16 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, kNnapiExecutionProvider, kRocmExecutionProvider, kCoreMLExecutionProvider, + kCoreMLExecutionProviderMLProgram, kQnnExecutionProvider, kSnpeExecutionProvider, kXnnpackExecutionProvider, }; + + // need to special case any synthetic EP names in the exclude list + if (ctx_.excluded_provider_types.count(kCoreMLExecutionProvider) > 0) { + ctx_.excluded_provider_types.insert(kCoreMLExecutionProviderMLProgram); + } #endif bool has_run = false; @@ -675,6 +684,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, execution_provider = DefaultRocmExecutionProvider(); else if (provider_type == onnxruntime::kCoreMLExecutionProvider) execution_provider = DefaultCoreMLExecutionProvider(); + else if (provider_type == kCoreMLExecutionProviderMLProgram) + execution_provider = DefaultCoreMLExecutionProvider(/*use_mlprogram*/ true); else if (provider_type == onnxruntime::kSnpeExecutionProvider) execution_provider = DefaultSnpeExecutionProvider(); else if (provider_type == onnxruntime::kQnnExecutionProvider) diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index 94817158017bd..0f068ba48d3d8 100644 --- 
a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -192,8 +192,10 @@ TEST(CoreMLExecutionProviderTest, TestOrtFormatModel) { #endif } -// Test that we fix invalid names in model inputs, initializers and outputs. +#if defined(COREML_ENABLE_MLPROGRAM) // Names in CoreML cannot start with [0-9] or contain anything but "[a-z][A-Z][0-9]_" +// Test that we fix invalid names in model inputs, initializers and outputs. +// This is only enforced for ML Program, so we only do name sanitization when creating an ML Program format model. TEST(CoreMLExecutionProviderTest, TestNameSanitization) { OpTester test("Clip", 11); @@ -212,5 +214,7 @@ TEST(CoreMLExecutionProviderTest, TestNameSanitization) { // TensorRT does not support Clip opset 11 yet. test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +#endif + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index b404c12db3582..c12a52c4356aa 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -223,21 +223,21 @@ std::unique_ptr DefaultRocmExecutionProvider(bool test_tunab return nullptr; } -std::unique_ptr DefaultCoreMLExecutionProvider() { +std::unique_ptr DefaultCoreMLExecutionProvider(bool use_mlprogram) { // To manually test CoreML model generation on a non-macOS platform, comment out the `&& defined(__APPLE__)` below. // The test will create a model but execution of it will obviously fail. - // To test creating an ML Program, set the environment variable COREML_EP_TEST_MLPROGRAM to any value. #if defined(USE_COREML) && defined(__APPLE__) // We want to run UT on CPU only to get output value without losing precision uint32_t coreml_flags = 0; coreml_flags |= COREML_FLAG_USE_CPU_ONLY; - if (!Env::Default().GetEnvironmentVar("COREML_EP_TEST_MLPROGRAM").empty()) { + if (use_mlprogram) { coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM; } return CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider(); #else + ORT_UNUSED_PARAMETER(use_mlprogram); return nullptr; #endif } diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 738fc66d775c6..ae8e89c386994 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -54,7 +54,7 @@ std::unique_ptr DefaultRknpuExecutionProvider(); std::unique_ptr DefaultAclExecutionProvider(bool enable_arena = true); std::unique_ptr DefaultArmNNExecutionProvider(bool enable_arena = true); std::unique_ptr DefaultRocmExecutionProvider(bool test_tunable_op = false); -std::unique_ptr DefaultCoreMLExecutionProvider(); +std::unique_ptr DefaultCoreMLExecutionProvider(bool use_mlprogram = false); std::unique_ptr DefaultSnpeExecutionProvider(); std::unique_ptr DefaultQnnExecutionProvider(); std::unique_ptr QnnExecutionProviderWithOptions(const ProviderOptions& options, From 24b72d26134a5b8d841588efc8dff7579241b0ce Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Thu, 7 Mar 2024 19:07:49 -0800 Subject: [PATCH 130/237] [JS/WebGPU] Preserve zero size input tensor dims. (#19737) ### Description For Concat operation, the zero-size input tensor shape need to be preserved and, unlike non-zero tensors, the dims are not constrained to match other input tensors' dims. 
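To make the new rule concrete, here is a minimal standalone TypeScript sketch (not the PR's code, which additionally validates ranks and types) of how the output shape is now derived: the concat axis is the sum of every input's size on that axis, zero-sized inputs included, and zero-sized inputs are only dropped afterwards, just before the GPU program is dispatched. It matches the test cases added in concat_zero-sized.jsonc below:
```
// Sketch only: assumes the non-concat dims already match; only the concat axis is summed.
const concatOutputShape = (shapes: number[][], axis: number): number[] => {
  const output = shapes[0].slice();
  output[axis] = shapes.reduce((sum, s) => sum + s[axis], 0);
  return output;
};

console.log(concatOutputShape([[0, 0], [0, 1], [0, 2], [0, 3]], 1)); // [ 0, 6 ]
console.log(concatOutputShape([[0, 1], [1, 1]], 0));                 // [ 1, 1 ]
```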
### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 146 +++++++++---------- js/web/test/data/ops/concat_zero-sized.jsonc | 80 ++++++++++ 2 files changed, 149 insertions(+), 77 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index b142a82e551a7..010ee589c44fa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -13,25 +13,32 @@ export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; } -const validateInputs = (inputs: readonly TensorView[]): void => { +const validateInputs = (inputs: readonly TensorView[], axis: number): void => { if (!inputs || inputs.length < 1) { throw new Error('too few inputs'); } - - const inputType = inputs[0].dataType; - const inputDimensionality = inputs[0].dims.length; - - for (const input of inputs) { + const referenceIndex = 0; + const referenceInput = inputs[referenceIndex]; + const inputType = referenceInput.dataType; + const inputRank = referenceInput.dims.length; + inputs.forEach((input, i) => { + if (i === referenceIndex) { + return; + } // make sure types of all inputs match if (input.dataType !== inputType) { throw new Error('input tensors should be one type'); } - // make sure the dimensionality of all inputs are the same - if (input.dims.length !== inputDimensionality) { + if (input.dims.length !== inputRank) { throw new Error('input tensors should have the same shape'); } - } + input.dims.forEach((dim, i) => { + if (i !== axis && dim !== referenceInput.dims[i]) { + throw new Error('non concat dimensions must match'); + } + }); + }); }; const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => ` @@ -64,65 +71,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe return codeLines.join('\n'); }; -const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => { - const inputShape = inputs[0].dims.slice(); - if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { - throw new Error('axis specified for concat doesn\'t match input dimensionality'); - } - const adjustedAxis = (axis < 0) ? 
inputShape.length + axis : axis; - // ensure all of the non-concatenated axes match each other - // calculate the shape of the output tensor while we do that - const outputShape = inputShape.slice(0); - for (let i = 1; i < inputs.length; i++) { - const dataNShape = inputs[i].dims.slice(); - for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) { - // add to the placeholder for computing output shape - if (axisIndex === adjustedAxis) { - outputShape[adjustedAxis] += dataNShape[axisIndex]; +const createConcatProgramInfo = + (inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => { + const outputSize = ShapeUtil.size(outputShape); + + const sizeInConcatAxis = new Array(inputs.length); + const inputVars = new Array(inputs.length); + + let previousSum = 0; + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + const inputRanks = []; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; + for (let i = 0; i < inputs.length; ++i) { + previousSum += inputs[i].dims[adjustedAxis]; + sizeInConcatAxis[i] = previousSum; + inputRanks.push(inputs[i].dims.length); + inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); + inputDependencies.push('rank'); + programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); } - // ensure all non-cancatenated axes match each other - else if (inputShape[axisIndex] !== dataNShape[axisIndex]) { - throw new Error('non concat dimensions must match'); + for (let i = 0; i < inputs.length; ++i) { + programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); } - } - } - - const outputSize = ShapeUtil.size(outputShape); - - const sizeInConcatAxis = new Array(inputs.length); - const inputVars = new Array(inputs.length); - const dataType = inputs[0].dataType; - - let previousSum = 0; - const inputDependencies: ProgramInputTensorInfoDependency[] = []; - const inputRanks = []; - const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; - for (let i = 0; i < inputs.length; ++i) { - previousSum += inputs[i].dims[adjustedAxis]; - sizeInConcatAxis[i] = previousSum; - inputRanks.push(inputs[i].dims.length); - inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); - inputDependencies.push('rank'); - programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); - } - for (let i = 0; i < inputs.length; ++i) { - programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); - } - programUniforms.push(...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(outputShape)); - const output = outputVariable('output', dataType, outputShape.length); - const indicesAxis = output.indicesGet('indices', adjustedAxis); - const sizeInConcatAxisStr = - Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const output = outputVariable('output', dataType, outputShape.length); + const indicesAxis = output.indicesGet('indices', adjustedAxis); + const sizeInConcatAxisStr = + Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${(() => { - shaderHelper.registerUniform('outputSize', 'u32'); - for (let i = 0; i < inputs.length; i++) { - shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); - } - return 
shaderHelper.declareVariables(...inputVars, output); - })()} + shaderHelper.registerUniform('outputSize', 'u32'); + for (let i = 0; i < inputs.length; i++) { + shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); + } + return shaderHelper.declareVariables(...inputVars, output); + })()} ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)} @@ -140,23 +125,30 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P ${assignOutputData(inputVars, output)} }`; - return { - name: 'Concat', - shaderCache: {hint: `${axis}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms, - }), - getShaderSource, - }; -}; + return { + name: 'Concat', + shaderCache: {hint: `${adjustedAxis}`, inputDependencies}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms, + }), + getShaderSource, + }; + }; export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { - validateInputs(context.inputs); + const inputs = context.inputs; + const inputShape = inputs[0].dims; + const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); + validateInputs(inputs, adjustedAxis); + const outputShape = inputShape.slice(); + outputShape[adjustedAxis] = + inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0); // 0 length tensors are valid for concat, remove them - const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0); - context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs}); + const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0); + context.compute( + createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs}); }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes => diff --git a/js/web/test/data/ops/concat_zero-sized.jsonc b/js/web/test/data/ops/concat_zero-sized.jsonc index 7be8e8c1cc602..be9625145d157 100644 --- a/js/web/test/data/ops/concat_zero-sized.jsonc +++ b/js/web/test/data/ops/concat_zero-sized.jsonc @@ -557,5 +557,85 @@ ] } ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 0, + "type": "int" + } + ], + "cases": [ + { + "name": "Some but not all input tensors are zero-sized", + "inputs": [ + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1], + "dims": [1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1; Preserve dims", + "operator": "Concat", + "attributes": [ + { + "name": "axis", + "data": 1, + "type": "int" + } + ], + "cases": [ + { + "name": "All input tensors are zero-sized", + "inputs": [ + { + "data": [], + "dims": [0, 0], + "type": "float32" + }, + { + "data": [], + "dims": [0, 1], + "type": "float32" + }, + { + "data": [], + "dims": [0, 2], + "type": "float32" + }, + { + "data": [], + "dims": [0, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [0, 6], + "type": "float32" + } + ] + } + ] } ] From 01c376a0b9ebd251d5712fa14a448335a2bde780 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 8 Mar 2024 
17:52:47 +1000 Subject: [PATCH 131/237] Update script to run CIs for a branch. (#19797) ### Description - Support multiple include/exclude values. - e.g. can now run with `-i MacOS -i iOS` to run CIs for both Apple platforms. - Default to current branch if run from directory in repo. - make lazier usage possible ### Motivation and Context Improve tools. --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- tools/python/run_CIs_for_branch.py | 55 +++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/tools/python/run_CIs_for_branch.py b/tools/python/run_CIs_for_branch.py index c507cae0d9f43..975ea2b988d75 100644 --- a/tools/python/run_CIs_for_branch.py +++ b/tools/python/run_CIs_for_branch.py @@ -13,13 +13,20 @@ from util.platform_helpers import is_windows +class DefaultArgsRawHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter): + pass + + def _parse_args(): parser = argparse.ArgumentParser( os.path.basename(__file__), - formatter_class=argparse.RawDescriptionHelpFormatter, + formatter_class=DefaultArgsRawHelpFormatter, description="""Run the CIs used to validate PRs for the specified branch. + If not specified, the branch will be inferred (if possible) by running `git branch --show-current`. + If specified, the `--include` filter is applied first, followed by any `--exclude` filter. + `--include` and `--exclude` can be specified multiple times to accumulate values to include/exclude. Requires the Azure CLI with DevOps extension to be installed. Azure CLI: https://learn.microsoft.com/en-us/cli/azure/install-azure-cli @@ -44,12 +51,30 @@ def _parse_args(): """, ) - parser.add_argument("-i", "--include", type=str, help="Include CIs that match this string. Case insensitive.") - parser.add_argument("-e", "--exclude", type=str, help="Exclude CIs that match this string. Case insensitive.") + current_branch = None + get_branch_result = subprocess.run(["git", "branch", "--show-current"], capture_output=True, text=True, check=False) + if get_branch_result.returncode == 0: + current_branch = get_branch_result.stdout.strip() + + parser.add_argument( + "-i", "--include", action="append", type=str, help="Include CIs that match this string. Case insensitive." + ) + parser.add_argument( + "-e", "--exclude", action="append", type=str, help="Exclude CIs that match this string. Case insensitive." + ) parser.add_argument("--dry-run", action="store_true", help="Print selected CIs but do not run them.") - parser.add_argument("branch", type=str, help="Specify the branch to run.") + parser.add_argument( + "branch", + type=str, + nargs="?", + default=current_branch, + help="Specify the branch to run. 
Default is current branch if available.", + ) args = parser.parse_args() + if not args.branch: + raise ValueError("Branch was unable to be inferred and must be specified") + return args @@ -77,25 +102,37 @@ def main(): pipelines = get_pipeline_names() pipelines_to_run = [] if args.include: - value = args.include.lower().strip() + values = [i.lower().strip() for i in args.include] for p in pipelines: - if value in p.lower(): + include = False + for value in values: + if value in p.lower(): + include = True + break + + if include: print(f"Including {p}") pipelines_to_run.append(p) else: pipelines_to_run = pipelines if args.exclude: - value = args.exclude.lower().strip() + values = [e.lower().strip() for e in args.exclude] cur_pipelines = pipelines_to_run pipelines_to_run = [] for p in cur_pipelines: - if value in p.lower(): + exclude = False + for value in values: + if value in p.lower(): + exclude = True + break + + if exclude: print(f"Excluding {p}") else: pipelines_to_run.append(p) - print("Pipelines to run:") + print(f"Pipelines to run for {args.branch}:") for p in pipelines_to_run: print(f"\t{p}") From 3170a48e60979ce1fb0d391cab7b0572bab90fff Mon Sep 17 00:00:00 2001 From: Yifan Li <109183385+yf711@users.noreply.github.com> Date: Fri, 8 Mar 2024 10:24:36 -0800 Subject: [PATCH 132/237] [EP Perf] Add tag to indicate which TRT parser is using (#19784) ### Description * Add tag to distinguish if TRT `builtin` or `oss` parser is being used * `oss` tag will be inserted with onnx-tensorrt commit id, to indicate which version oss parser is ### Validate DB entry before/after this PR (during test, `builtin` or `oss_{commit_id}` tag was inserted in the database entries): ### Motivation and Context To distinguish perf results using builtin/oss parser in the database, this parser tag is needed. In future, results using different parsers will be listed in different Perf Dashboard pages. --- .../python/tools/tensorrt/perf/post.py | 25 ++++++++++++++++--- ...linux-gpu-tensorrt-daily-perf-pipeline.yml | 6 ++++- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/tensorrt/perf/post.py b/onnxruntime/python/tools/tensorrt/perf/post.py index 0f5614bd5160f..363fa3a96d283 100644 --- a/onnxruntime/python/tools/tensorrt/perf/post.py +++ b/onnxruntime/python/tools/tensorrt/perf/post.py @@ -56,6 +56,7 @@ def parse_arguments(): parser.add_argument("-b", "--branch", help="Branch", required=True) parser.add_argument("--kusto_conn", help="Kusto connection URL", required=True) parser.add_argument("--database", help="Database name", required=True) + parser.add_argument("--use_tensorrt_oss_parser", help="Use TensorRT OSS parser", required=False) parser.add_argument( "-d", "--commit_datetime", @@ -370,7 +371,7 @@ def write_table( ingest_client.ingest_from_dataframe(table, ingestion_properties=ingestion_props) -def get_identifier(commit_datetime, commit_hash, trt_version, branch): +def get_identifier(commit_datetime, commit_hash, trt_version, branch, use_tensorrt_oss_parser): """ Returns an identifier that associates uploaded data with an ORT commit/date/branch and a TensorRT version. 
@@ -383,7 +384,23 @@ def get_identifier(commit_datetime, commit_hash, trt_version, branch): """ date = str(commit_datetime.date()) # extract date only - return date + "_" + commit_hash + "_" + trt_version + "_" + branch + if use_tensorrt_oss_parser: + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_dir = os.path.abspath(os.path.join(current_dir, "../../../../..")) + deps_txt_path = os.path.join(root_dir, "cmake", "deps.txt") + commit_head = "" + with open(deps_txt_path) as file: + for line in file: + parts = line.split(";") + if parts[0] == "onnx_tensorrt": + url = parts[1] + commit = url.split("/")[-1] + commit_head = commit[:6] + break + parser = f"oss_{commit_head}" + else: + parser = "builtin" + return "_".join([date, commit_hash, trt_version, parser, branch]) def main(): @@ -396,7 +413,9 @@ def main(): # connect to database kcsb_ingest = KustoConnectionStringBuilder.with_az_cli_authentication(args.kusto_conn) ingest_client = QueuedIngestClient(kcsb_ingest) - identifier = get_identifier(args.commit_datetime, args.commit_hash, args.trt_version, args.branch) + identifier = get_identifier( + args.commit_datetime, args.commit_hash, args.trt_version, args.branch, args.use_tensorrt_oss_parser + ) upload_time = datetime.datetime.now(tz=datetime.timezone.utc).replace(microsecond=0) try: diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index eaadc6ad728c0..9f3a127262bb1 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -76,6 +76,10 @@ jobs: - name: image value: ort-image-$(Build.BuildId) + + - name: parser + ${{ if eq(parameters.UseTensorrtOssParser, true) }}: + value: --use_tensorrt_oss_parser $(parameters.UseTensorrtOssParser) }} steps: - ${{ if and(eq(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, false)) }}: @@ -155,7 +159,7 @@ jobs: inlineScript: | short_hash=$(git rev-parse --short HEAD) && commit_date=$(git log -1 --date=iso-strict --pretty=format:%cd) && - python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/post.py -r $(Build.SourcesDirectory)/Artifact/result -c $short_hash -d $commit_date -u "$(reportUrl)?buildId=$(Build.BuildId)" -t $(trtVersion) -b $(branchName) --kusto_conn $(kustoConn) --database $(database) + python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/post.py -r $(Build.SourcesDirectory)/Artifact/result -c $short_hash -d $commit_date -u "$(reportUrl)?buildId=$(Build.BuildId)" -t $(trtVersion) -b $(branchName) --kusto_conn $(kustoConn) --database $(database) $(parser) - template: templates/component-governance-component-detection-steps.yml parameters : From 069d2d6f54f5cfa49e2ddfea4542150b88f47a55 Mon Sep 17 00:00:00 2001 From: Yifan Li <109183385+yf711@users.noreply.github.com> Date: Fri, 8 Mar 2024 13:58:22 -0800 Subject: [PATCH 133/237] [EP Perf] Update EP Perf dockerfiles with cuda12/cudnn9 (#19781) ### Description * Update name of existing dockerfiles and add support to test latest TensorRT EA binary located in the image * Add cuda 12.3/cuDNN 9/TensorRT 8.6 dockerfile * Add detail to CI prompts and configs Instruction to test latest TRT via BIN: 1. Select `BIN` in TensorRT Version 2. In Variables, update related tarCudaVersion, **clear** tarCudnnVersion (not required in latest TRT tar binary) , and path to binary. 
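Outside the pipeline, the BIN flow reduces to roughly the same command the CI step below runs, e.g. `python3 onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r <repo root> -i <image name> -b <branch> -t BIN -a 75 --install_bin --tar_cuda_version=<tar CUDA version> --trt_bins_dir=.`, with the TensorRT `.tar.gz` copied into the build directory first; the placeholder values here correspond to the pipeline variables described above.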
--- .../tools/tensorrt/perf/build/build_image.py | 37 +++---- ...linux-gpu-tensorrt-daily-perf-pipeline.yml | 17 ++-- .../Dockerfile.ubuntu_cuda12_3_tensorrt8_6 | 96 +++++++++++++++++++ .../docker/Dockerfile.ubuntu_tensorrt_bin | 93 +++++++++++++----- 4 files changed, 183 insertions(+), 60 deletions(-) create mode 100644 tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 diff --git a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py index 2ae64a72d08fe..b95ad3c0a55ef 100644 --- a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py +++ b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py @@ -14,9 +14,10 @@ from typing import List, Optional TRT_DOCKER_FILES = { - "8.4": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4", - "8.5": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5", - "8.6": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6", + "8.4.cuda_11_6_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4", + "8.5.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5", + "8.6.cuda_11_8_cudnn_8": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6", + "8.6.cuda_12_3_cudnn_9": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6", "BIN": "tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin", } @@ -99,18 +100,11 @@ def docker_build_trt(args: argparse.Namespace): :param args: The arguments to this script. """ - if not is_valid_ver_str(args.trt_version, min_comps=2, max_comps=4): - print(f"[ERROR]: Invalid TensorRT version '{args.trt_version}'", file=sys.stderr) - sys.exit(1) - - vers_comps = args.trt_version.split(".") - trt_ver_key = f"{vers_comps[0]}.{vers_comps[1]}" - - if trt_ver_key not in TRT_DOCKER_FILES: + if args.trt_version not in TRT_DOCKER_FILES: print(f"[ERROR]: TensorRT version '{args.trt_version}' is currently unsupported", file=sys.stderr) sys.exit(1) - docker_file = TRT_DOCKER_FILES[trt_ver_key] + docker_file = TRT_DOCKER_FILES[args.trt_version] docker_file_path = os.path.normpath(os.path.join(args.repo_path, docker_file)) if not os.path.isfile(docker_file_path): @@ -144,11 +138,7 @@ def docker_build_trt_bin(args: argparse.Namespace): sys.exit(1) if not is_valid_ver_str(args.tar_cuda_version, 2, 2): - print("[ERROR]: Must specify a valid CUDA version for binary TensorRT installs (e.g., 11.x)", file=sys.stderr) - sys.exit(1) - - if not is_valid_ver_str(args.tar_cudnn_version, 2, 2): - print("[ERROR]: Must specify a valid cuDNN version for binary TensorRT installs (e.g., 8.x)", file=sys.stderr) + print("[ERROR]: Must specify a valid CUDA version for binary TensorRT installs (e.g., 12.4)", file=sys.stderr) sys.exit(1) if not os.path.isfile(docker_file_path): @@ -170,8 +160,6 @@ def docker_build_trt_bin(args: argparse.Namespace): "--build-arg", f"TAR_CUDA_VERSION={args.tar_cuda_version}", "--build-arg", - f"TAR_CUDNN_VERSION={args.tar_cudnn_version}", - "--build-arg", f"TRT_BINS_DIR={args.trt_bins_dir}", "-f", f"{docker_file_path}", @@ -195,7 +183,9 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument("-r", "--repo_path", required=True, help="Path to the onnxruntime repository") parser.add_argument("-i", "--image_name", required=True, help="The resulting Docker image name") parser.add_argument("-b", "--branch", default="main", help="Name of the 
onnxruntime git branch to checkout") - parser.add_argument("-t", "--trt_version", default="8.6.1.6", help="TensorRT version (e.g., 8.6.1.6)") + parser.add_argument( + "-t", "--trt_version", default="8.6.cuda_11_8_cudnn_8", help="TensorRT version (e.g., 8.6.cuda_11_8_cudnn_8)" + ) parser.add_argument("-a", "--cuda_arch", default="75", help="CUDA architecture (e.g., 75)") # Command-line options for installing TensorRT from binaries. @@ -208,12 +198,7 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument( "--tar_cuda_version", default="", - help="CUDA version (e.g., 11.8) used to find TensorRT EA binary tar.gz package", - ) - parser.add_argument( - "--tar_cudnn_version", - default="", - help="CUDA version (e.g., 8.6) used to find TensorRT EA binary tar.gz package", + help="CUDA version (e.g., 12.4) used to find TensorRT EA binary tar.gz package", ) parser.add_argument("--trt_bins_dir", default="", help="Directory containing TensorRT tar.gz package") parser.add_argument( diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index 9f3a127262bb1..15f558e6f9ef0 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -8,15 +8,16 @@ parameters: - name: TrtVersion displayName: TensorRT Version type: string - default: 8.6.1.6 + default: 8.6.cuda_11_8_cudnn_8 values: - - 8.4.1.5 - - 8.5.1.1 - - 8.6.1.6 + - 8.4.cuda_11_6_cudnn_8 + - 8.5.cuda_11_8_cudnn_8 + - 8.6.cuda_11_8_cudnn_8 + - 8.6.cuda_12_3_cudnn_9 - BIN - name: UseTensorrtOssParser - displayName: Use TensorRT-OSS Parser + displayName: Use TensorRT-OSS Parser (not compatible with BIN) type: boolean default: false @@ -86,11 +87,11 @@ jobs: - script: 'ls -al $(trtBinsDir)' displayName: 'Show available TensorRT .tar.gz packages' - - script: 'cp $(trtBinsDir)/TensorRT-$(trtVersion).Linux.x86_64-gnu.cuda-$(tarCudaVersion).cudnn$(tarCudnnVersion).tar.gz $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/' + - script: 'cp $(trtBinsDir)/TensorRT-$(trtVersion).Linux.x86_64-gnu.cuda-$(tarCudaVersion).tar.gz $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/' displayName: 'Copy TensorRT .tar.gz package into Docker build directory' - - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75 --install_bin --tar_cuda_version=$(tarCudaVersion) --tar_cudnn_version=$(tarCudnnVersion) --trt_bins_dir=.' - displayName: 'Install TensorRT from binaries and build latest ORT Image' + - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75 --install_bin --tar_cuda_version=$(tarCudaVersion) --trt_bins_dir=.' 
+ displayName: 'Install TensorRT $(tarTrtVersion) from binaries and build latest ORT Image' workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build' # Build ORT with TensorRT built-in parser diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 new file mode 100644 index 0000000000000..9493480784e81 --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_3_tensorrt8_6 @@ -0,0 +1,96 @@ +# -------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------- +# Dockerfile to run ONNXRuntime with TensorRT integration + +# Build base image with required system packages +FROM nvidia/cuda:12.3.1-devel-ubuntu20.04 AS base + +# The local directory into which to build and install CMAKE +ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code + +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update &&\ + apt-get install -y sudo git bash unattended-upgrades wget +RUN unattended-upgrade + +# Install python3 +RUN apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + python3-dev \ + python3-wheel &&\ + cd /usr/local/bin &&\ + ln -s /usr/bin/python3 python &&\ + ln -s /usr/bin/pip3 pip; + +RUN pip install --upgrade pip +RUN pip install setuptools>=68.2.2 + +# Install cuDNN v9 +RUN apt-get -y install cudnn9-cuda-12 + +# Install TensorRT +RUN v="8.6.1.6-1+cuda12.0" &&\ + apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ + apt-get update &&\ + sudo apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} libnvinfer-lean8=${v} libnvinfer-vc-plugin8=${v} libnvinfer-dispatch8=${v}\ + libnvinfer-headers-dev=${v} libnvinfer-headers-plugin-dev=${v} libnvinfer-dev=${v} libnvonnxparsers-dev=${v} libnvparsers-dev=${v} libnvinfer-plugin-dev=${v} libnvinfer-lean-dev=${v} libnvinfer-vc-plugin-dev=${v} libnvinfer-dispatch-dev=${v}\ + python3-libnvinfer=${v} libnvinfer-samples=${v} tensorrt-dev=${v} tensorrt-libs=${v} + +# Compile trtexec +RUN cd /usr/src/tensorrt/samples/trtexec && make + +# Install Valgrind +RUN apt-get install -y valgrind + +# Build final image from base. Builds ORT. +FROM base as final +ARG BUILD_USER=onnxruntimedev +ARG BUILD_UID=1000 +RUN adduser --gecos 'onnxruntime Build User' --disabled-password $BUILD_USER --uid $BUILD_UID +USER $BUILD_USER + +# ONNX Runtime arguments + +# URL to the github repo from which to clone ORT. +ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime + +# The local directory into which to clone ORT. +ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code + +# The git branch of ORT to checkout and build. +ARG ONNXRUNTIME_BRANCH=main + +# Optional. The specific commit to pull and build from. If not set, the latest commit is used. 
+ARG ONNXRUNTIME_COMMIT_ID + +# The supported CUDA architecture +ARG CMAKE_CUDA_ARCHITECTURES=75 + +WORKDIR ${ONNXRUNTIME_LOCAL_CODE_DIR} + +# Clone ORT repository with branch +RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ + /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh + +WORKDIR ${ONNXRUNTIME_LOCAL_CODE_DIR}/onnxruntime + +# Reset to a specific commit if specified by build args. +RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIME_BRANCH}" ;\ + else echo "Building branch ${ONNXRUNTIME_BRANCH} @ commit ${ONNXRUNTIME_COMMIT_ID}" &&\ + git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi + +# Build ORT +ENV CUDA_MODULE_LOADING "LAZY" +ARG PARSER_CONFIG="" +RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' + +# Switch to root to continue following steps of CI +USER root + +# Intall ORT wheel +RUN pip install ${ONNXRUNTIME_LOCAL_CODE_DIR}/onnxruntime/build/Linux/Release/dist/*.whl \ No newline at end of file diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin index 21b09b2d8978e..a26bf88fbbdf6 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin @@ -4,29 +4,15 @@ # -------------------------------------------------------------- # Dockerfile to run ONNXRuntime with TensorRT installed from provided binaries -FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04 +# Build base image with required system packages +FROM nvidia/cuda:12.3.1-devel-ubuntu20.04 AS base +# The local directory into which to build and install CMAKE +ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code -# ONNX Runtime Variables -ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime -ARG ONNXRUNTIME_BRANCH=main -ARG CMAKE_CUDA_ARCHITECTURES=37;50;52;60;61;70;75;80 - -# Must provide version numbers used to build the name of the tar file containing TensorRT binaries. -# See: https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html#installing-tar -ARG TAR_TRT_VERSION -ARG TAR_CUDA_VERSION -ARG TAR_CUDNN_VERSION - -# Directory containing TensorRT tar.gz installation package -ARG TRT_BINS_DIR=. - -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/code/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} - +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive -COPY ${TRT_BINS_DIR}/TensorRT-${TAR_TRT_VERSION}.Linux.x86_64-gnu.cuda-${TAR_CUDA_VERSION}.cudnn${TAR_CUDNN_VERSION}.tar.gz /TensorRT-${TAR_TRT_VERSION}.tar.gz - RUN apt-get update &&\ apt-get install -y sudo git bash unattended-upgrades wget RUN unattended-upgrade @@ -44,22 +30,77 @@ RUN apt-get install -y --no-install-recommends \ RUN pip install --upgrade pip RUN pip install setuptools>=68.2.2 +# Install cuDNN v9 +RUN apt-get -y install cudnn9-cuda-12 + +# Install TensorRT +# Must provide version numbers used to build the name of the tar file containing TensorRT binaries. 
+# See: https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html#installing-tar +ARG TAR_TRT_VERSION +ARG TAR_CUDA_VERSION + +# Directory containing TensorRT tar.gz installation package +ARG TRT_BINS_DIR=. +COPY ${TRT_BINS_DIR}/TensorRT-${TAR_TRT_VERSION}.Linux.x86_64-gnu.cuda-${TAR_CUDA_VERSION}.tar.gz /TensorRT-${TAR_TRT_VERSION}.tar.gz + # Install TensorRT from tar.gz RUN tar -xzvf /TensorRT-${TAR_TRT_VERSION}.tar.gz RUN cd /TensorRT-${TAR_TRT_VERSION}/python &&\ - python3 -m pip install tensorrt-${TAR_TRT_VERSION}-cp38-none-linux_x86_64.whl + python3 -m pip install tensorrt*cp38*.whl RUN cp -r /TensorRT-${TAR_TRT_VERSION}/lib/* /usr/lib/x86_64-linux-gnu/ RUN cp /TensorRT-${TAR_TRT_VERSION}/include/* /usr/local/include/ RUN cp /TensorRT-${TAR_TRT_VERSION}/bin/* /usr/local/bin/ -WORKDIR /code +# Install Valgrind +RUN apt-get install -y valgrind + +# Build final image from base. Builds ORT. +FROM base as final +ARG BUILD_USER=onnxruntimedev +ARG BUILD_UID=1000 +RUN adduser --gecos 'onnxruntime Build User' --disabled-password $BUILD_USER --uid $BUILD_UID +USER $BUILD_USER + +# ONNX Runtime arguments + +# URL to the github repo from which to clone ORT. +ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime + +# The local directory into which to clone ORT. +ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code + +# The git branch of ORT to checkout and build. +ARG ONNXRUNTIME_BRANCH=main + +# Optional. The specific commit to pull and build from. If not set, the latest commit is used. +ARG ONNXRUNTIME_COMMIT_ID + +# The supported CUDA architecture +ARG CMAKE_CUDA_ARCHITECTURES=75 # Prepare onnxruntime repository & build onnxruntime with TensorRT +WORKDIR ${ONNXRUNTIME_LOCAL_CODE_DIR} + +# Clone ORT repository with branch RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ - /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ - cd onnxruntime &&\ - /bin/sh build.sh --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' &&\ - pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\ - cd .. + /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh + +WORKDIR ${ONNXRUNTIME_LOCAL_CODE_DIR}/onnxruntime + +# Reset to a specific commit if specified by build args. 
+RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIME_BRANCH}" ;\
+ else echo "Building branch ${ONNXRUNTIME_BRANCH} @ commit ${ONNXRUNTIME_COMMIT_ID}" &&\
+ git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi
+
+# Build ORT
+ENV CUDA_MODULE_LOADING "LAZY"
+ARG PARSER_CONFIG=""
+RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"'
+
+# Switch to root to continue following steps of CI
+USER root
+
+# Install ORT wheel
+RUN pip install ${ONNXRUNTIME_LOCAL_CODE_DIR}/onnxruntime/build/Linux/Release/dist/*.whl
\ No newline at end of file
From 7deee944c0daa9950167f6ac399c52c00c907924 Mon Sep 17 00:00:00 2001
From: Sheil Kumar
Date: Fri, 8 Mar 2024 15:02:58 -0800
Subject: [PATCH 134/237] Implement STFT Decomposition transformer (#19725)
Implement STFT Decomposition transformer. Certain hardware does not support DXIL, and therefore the existing operator should be mapped to hardware-supported functions. Optimized convolution can be used to implement STFT.
---------
Co-authored-by: Sheil Kumar
---
 .../core/optimizer/stft_decomposition.cc | 381 ++++++++++++
 .../core/optimizer/stft_decomposition.h | 30 +
 onnxruntime/core/providers/cpu/signal/dft.cc | 2 +-
 .../src/ExecutionProvider.cpp | 10 +-
 .../src/Operators/GeneratedShaders/stockham.h | 588 +++++++++---------
 .../GeneratedShaders/stockham_fp16.h | 257 ++++----
 .../src/Operators/Shaders/stockham.hlsl | 21 +-
 onnxruntime/core/session/inference_session.cc | 9 +
 8 files changed, 864 insertions(+), 434 deletions(-)
 create mode 100644 onnxruntime/core/optimizer/stft_decomposition.cc
 create mode 100644 onnxruntime/core/optimizer/stft_decomposition.h
diff --git a/onnxruntime/core/optimizer/stft_decomposition.cc b/onnxruntime/core/optimizer/stft_decomposition.cc
new file mode 100644
index 0000000000000..a54904ff15e1e
--- /dev/null
+++ b/onnxruntime/core/optimizer/stft_decomposition.cc
@@ -0,0 +1,381 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#include + +#include "core/optimizer/stft_decomposition.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/optimizer_execution_frame.h" +#include "core/optimizer/utils.h" +#include "core/framework/op_kernel.h" +#include "core/framework/tensorprotoutils.h" + +using namespace onnxruntime::common; + +namespace onnxruntime { + +STFTDecomposition::STFTDecomposition(const InlinedHashSet& compatible_execution_providers) noexcept + : GraphTransformer("STFTDecomposition", compatible_execution_providers) { +} + +template +constexpr static ONNX_NAMESPACE::TensorProto_DataType GetDataType() { + if constexpr (std::is_same::value) { + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + } else if constexpr (std::is_same::value) { + return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; + } else if constexpr (std::is_same::value) { + return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; + } else if constexpr (std::is_same::value) { + return ONNX_NAMESPACE::TensorProto_DataType_INT64; + } else { + throw std::logic_error("Invalid data type requested for STFT decomposition"); + } +} + +template +NodeArg* AddInitializer(Graph& graph, const char* name, const int64_t (&shape)[TDims], const TDataType* begin) { + ONNX_NAMESPACE::TensorProto proto; + proto.set_name(graph.GenerateNodeArgName(name)); + proto.set_data_type(GetDataType()); + int64_t element_count = 1; + for (size_t i = 0; i < TDims; i++) { + element_count *= shape[i]; + proto.add_dims(shape[i]); + } + proto.set_raw_data(begin, element_count * sizeof(TDataType)); + return &graph_utils::AddInitializer(graph, proto); +} + +template +NodeArg* AddShapeInitializer(Graph& graph, const char* name, const int64_t (&shape)[TDims]) { + int64_t shape_shape[] = {TDims}; + return AddInitializer(graph, name, shape_shape, shape); +} + +std::pair AddNode(Graph& graph, + const char* op_type, + ProviderType execution_provider_type, + gsl::span inputs) { + auto def_name = graph.GenerateNodeArgName(op_type); + auto node_arg = &graph.GetOrCreateNodeArg(def_name, nullptr); + Node& node = graph.AddNode(graph.GenerateNodeName(op_type), + op_type, + "", + inputs, + {node_arg}); + node.SetExecutionProviderType(execution_provider_type); + return std::make_pair(&node, node_arg); +} + +std::pair AddNodeCast(Graph& graph, NodeArg* in, + ONNX_NAMESPACE::TensorProto_DataType data_type) { + auto def_name = graph.GenerateNodeArgName("Cast"); + auto node_arg = &graph.GetOrCreateNodeArg(def_name, nullptr); + Node& node = graph.AddNode(graph.GenerateNodeName("Cast"), + "Cast", + "", + {in}, + {node_arg}); + node.AddAttribute("to", static_cast(data_type)); + node.SetExecutionProviderType(kCpuExecutionProvider); + return std::make_pair(&node, node_arg); +} + +#define CONTINUE_IF_NO_DIM_VALUE(dim) \ + if (!dim.has_dim_value()) { \ + continue; \ + } +#define CONTINUE_IF_NULL(x) \ + if (x == nullptr) { \ + continue; \ + } + +/* + This function decomposes a STFT node into a subgraph. + The decomposition requires that: + 1) The signal input is real valued and not complex valued! + 2) Both (frame_step) *and* either (window or frame_length) inputs must be constant. + Otherwise the transform will not be applied. 
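+
+  Why convolution suffices: for a real-valued signal, DFT bin k of the frame that starts at
+  sample s is sum_n window[n] * signal[s + n] * (cos(theta) + i*sin(theta)), theta = -2*pi*k*n/dft_size.
+  Each bin is therefore a dot product of the windowed frame with a fixed cos/sin row, and taking
+  that dot product at every frame start (s = 0, frame_step, 2*frame_step, ...) is a strided 1-D
+  convolution. The subgraphs below build exactly that: one Conv whose weights are the cos rows
+  (real part) and one whose weights are the sin rows (imaginary part), with the window, when
+  present, folded into the weights via Mul.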
+ + Subgraph pattern 1: STFT with optional Window parameter set + [root]--(signal)--------------------+ + [root]--(frame_step)---------------+| + [root]--(window)------------------+|| + [root]--(frame_length) ----------+||| + |||| + vvvv + [STFT]--(output)--> + After Fusion: + [root]--(signal)-------------------------+ + [root] | + [root]--(window)--+ | + [root] | | + v v + (only for non-fp32) [Cast] +--[Reshape] + | | | + v | v + [Reshape]-->[Mul]---|-->[Conv]-------+ + | | | + | +-----| | + | v v + +------>[Mul]------>[Conv]-->[Concat]-->[Reshape]-->[Transpose]--(output)--> + + + Subgraph pattern 2: STFT without optional Window parameter set + [root]--(signal)-------------------+ + [root]--(frame_step)--------------+| + [root] | + [root]--(frame_length) ----------+|| + ||| + vvv + [STFT]--(output)--> + After Fusion: + [root]--(signal)-->[Reshape]-->[Conv] + [root] | | + [root] | v + [root] +------>[Conv]-->[Concat]-->[Reshape]-->[Transpose]--(output)--> +*/ +Status STFTDecomposition::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + auto& order = graph_viewer.GetNodesInTopologicalOrder(); + + for (NodeIndex i : order) { + auto node = graph.GetNode(i); + CONTINUE_IF_NULL(node); + ORT_RETURN_IF_ERROR(Recurse(*node, modified, graph_level, logger)); + + if (node->OpType() != "STFT") { + continue; + } + + Node& stft = *node; + auto signal = stft.MutableInputDefs()[0]; + auto frame_step = stft.MutableInputDefs()[1]; + auto window = stft.MutableInputDefs()[2]; + auto frame_length = stft.MutableInputDefs()[3]; + + // If the signal has free dimensions, do not transform... + auto batch_size_dim = signal->Shape()->dim(0); + auto signal_length_dim = signal->Shape()->dim(1); + auto signal_components_dim = signal->Shape()->dim(2); + CONTINUE_IF_NO_DIM_VALUE(signal_length_dim); + CONTINUE_IF_NO_DIM_VALUE(signal_components_dim); + + auto batch_size = batch_size_dim.has_dim_value() ? batch_size_dim.dim_value() : static_cast(-1); + auto signal_length = signal_length_dim.dim_value(); + auto is_real = signal_components_dim.dim_value() == 1; + auto data_type = static_cast(signal->TypeAsProto()->tensor_type().elem_type()); + + auto frame_step_initializer = graph_utils::GetConstantInitializer(graph, frame_step->Name()); + auto window_initializer = graph_utils::GetConstantInitializer(graph, window->Name()); + auto frame_length_initializer = graph_utils::GetConstantInitializer(graph, frame_length->Name()); + CONTINUE_IF_NULL(frame_step_initializer); + if (!frame_length_initializer && !window_initializer) { + continue; + } + + auto read_int64_initializer = [](Graph& graph, const ONNX_NAMESPACE::TensorProto* initializer) { + return *Initializer(*initializer, graph.ModelPath()).data(); + }; + auto frame_step_value = read_int64_initializer(graph, frame_step_initializer); + + // Get DFT Size + int64_t dft_size = 0; + if (frame_length_initializer) { + dft_size = read_int64_initializer(graph, frame_length_initializer); + } + if (dft_size == 0 && window_initializer) { + auto window_length_dim = window->Shape()->dim(0); + CONTINUE_IF_NO_DIM_VALUE(window_length_dim); + dft_size = window_length_dim.dim_value(); + } + + bool is_onesided = true; + auto& attrs = stft.GetAttributes(); + if (attrs.find("onesided") != attrs.end()) { + auto& onesided_attr = attrs.at("onesided"); + if (utils::HasInt(onesided_attr)) { + is_onesided = static_cast(onesided_attr.i()); + } + } + + auto dft_unique_bins = is_onesided ? 
((dft_size >> 1) + 1) : dft_size; + + Node* signal_recipient = nullptr; + Node* window_recipient = nullptr; + Node* stft_producer = nullptr; + if (is_real) { + auto output_num_frames = stft.MutableOutputDefs()[0]->Shape()->dim(1).dim_value(); + auto output_frame_length = stft.MutableOutputDefs()[0]->Shape()->dim(2).dim_value(); + auto weight_size = static_cast(dft_unique_bins * dft_size); + auto real_weights_data = std::vector(weight_size); + auto imag_weights_data = std::vector(weight_size); + + // Populate weights + for (size_t k = 0; k < static_cast(dft_unique_bins); k++) { + for (size_t n = 0; n < static_cast(dft_size); n++) { + auto index = static_cast(k * dft_size + n); + auto theta = -2 * M_PI * k * n / static_cast(dft_size); + real_weights_data[index] = static_cast(cos(theta)); + imag_weights_data[index] = static_cast(sin(theta)); + } + } + + const int64_t weight_shape[] = {dft_unique_bins, 1, 1, dft_size}; + auto real_weights = AddInitializer(graph, "stft_real_conv_weights", weight_shape, real_weights_data.data()); + auto imaginary_weights = AddInitializer(graph, "stft_imaginary_conv_weights", weight_shape, imag_weights_data.data()); + + const int64_t signal_reshaped[] = {batch_size, 1, 1, signal_length}; + auto signal_shape = AddShapeInitializer(graph, "stft_signal_shape", signal_reshaped); + + const int64_t unsqueezed_output_shape[] = {2, batch_size, output_frame_length, output_num_frames}; + auto unsqueezed_shape = AddShapeInitializer(graph, "stft_output_reshaped", unsqueezed_output_shape); + + NodeArg* signal_reshaped_inputs[] = {signal, signal_shape}; + Node* reshape_signal_node = nullptr; + NodeArg* reshape_output = nullptr; + std::tie(reshape_signal_node, reshape_output) = + AddNode(graph, "Reshape", stft.GetExecutionProviderType(), signal_reshaped_inputs); + + NodeArg* real_weights_final = real_weights; + NodeArg* imag_weights_final = imaginary_weights; + if (!window->Exists()) { + // When we are missing a window function + if (real_weights_final->TypeAsProto()->tensor_type().elem_type() != data_type) { + std::tie(std::ignore, real_weights_final) = + AddNodeCast(graph, real_weights_final, data_type); + } + if (imag_weights_final->TypeAsProto()->tensor_type().elem_type() != data_type) { + std::tie(std::ignore, imag_weights_final) = + AddNodeCast(graph, imag_weights_final, data_type); + } + } else { + // When we have a window function + const int64_t window_reshaped_shape[] = {1, 1, 1, dft_size}; + auto window_shape = AddShapeInitializer(graph, "stft_window_shape", window_reshaped_shape); + + auto window_final = window; + if (window->TypeAsProto()->tensor_type().elem_type() != GetDataType()) { + Node* window_cast_node = nullptr; + std::tie(window_cast_node, window_final) = + AddNodeCast(graph, window, GetDataType()); + window_recipient = window_cast_node; + } + + NodeArg* window_reshaped_inputs[] = {window_final, window_shape}; + Node* window_reshape_node; + NodeArg* window_reshaped = nullptr; + std::tie(window_reshape_node, window_reshaped) = + AddNode(graph, "Reshape", kCpuExecutionProvider, window_reshaped_inputs); + if (!window_recipient) { + window_recipient = window_reshape_node; + } + + NodeArg* scale_real_weights_inputs[] = {real_weights, window_reshaped}; + NodeArg* windowed_real_weights_output = nullptr; + std::tie(std::ignore, windowed_real_weights_output) = + AddNode(graph, "Mul", kCpuExecutionProvider, scale_real_weights_inputs); + + NodeArg* scale_imag_weights_inputs[] = {imaginary_weights, window_reshaped}; + NodeArg* windowed_imag_weights_output = nullptr; + 
std::tie(std::ignore, windowed_imag_weights_output) = + AddNode(graph, "Mul", kCpuExecutionProvider, scale_imag_weights_inputs); + + std::tie(std::ignore, real_weights_final) = + AddNodeCast(graph, windowed_real_weights_output, data_type); + std::tie(std::ignore, imag_weights_final) = + AddNodeCast(graph, windowed_imag_weights_output, data_type); + } + + // Add Convolution (reals) + NodeArg* conv_real_inputs[] = {reshape_output, real_weights_final}; + Node* real_conv_node = nullptr; + NodeArg* real_conv_output = nullptr; + std::tie(real_conv_node, real_conv_output) = + AddNode(graph, "Conv", stft.GetExecutionProviderType(), conv_real_inputs); + real_conv_node->AddAttribute("strides", std::vector{1, frame_step_value}); + + // Add Convolution (imaginary) + NodeArg* conv_imag_inputs[] = {reshape_output, imag_weights_final}; + Node* imag_conv_node = nullptr; + NodeArg* imag_conv_output = nullptr; + std::tie(imag_conv_node, imag_conv_output) = + AddNode(graph, "Conv", stft.GetExecutionProviderType(), conv_imag_inputs); + imag_conv_node->AddAttribute("strides", std::vector{1, frame_step_value}); + + // Concatenate + NodeArg* concatenate_inputs[] = {real_conv_output, imag_conv_output}; + Node* concat_node = nullptr; + NodeArg* concatenated_conv_output = nullptr; + std::tie(concat_node, concatenated_conv_output) = + AddNode(graph, "Concat", stft.GetExecutionProviderType(), concatenate_inputs); + concat_node->AddAttribute("axis", static_cast(0)); + + // Unsqueeze Reshape + NodeArg* unsqueeze_reshape_inputs[] = {concatenated_conv_output, unsqueezed_shape}; + NodeArg* unsqueezed_output = nullptr; + std::tie(std::ignore, unsqueezed_output) = + AddNode(graph, "Reshape", stft.GetExecutionProviderType(), unsqueeze_reshape_inputs); + + // Transpose + NodeArg* transpose_inputs[] = {unsqueezed_output}; + Node* transpose_node = nullptr; + NodeArg* transpose_output = nullptr; + std::tie(transpose_node, transpose_output) = + AddNode(graph, "Transpose", stft.GetExecutionProviderType(), transpose_inputs); + transpose_node->AddAttribute("perm", std::vector{1, 3, 2, 0}); + + signal_recipient = reshape_signal_node; + stft_producer = transpose_node; + } else { + continue; + } + + auto input_edges = graph_utils::GraphEdge::GetNodeInputEdges(stft); + auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(stft); + + // Copy inputs + auto signal_target_idx = signal_recipient->Index(); + auto window_target_idx = window_recipient->Index(); + for (auto cur = input_edges.cbegin(), end = input_edges.cend(); cur != end; ++cur) { + const graph_utils::GraphEdge& edge = *cur; + NodeIndex target_idx = 0; + Node* recipient = nullptr; + switch (cur->dst_arg_index) { + case 0: + target_idx = signal_target_idx; + recipient = signal_recipient; + break; + case 2: + target_idx = window_target_idx; + recipient = window_recipient; + break; + } + + if (!recipient) { + continue; + } + + auto arg_index = graph_utils::GetNodeInputIndexFromInputName(*recipient, edge.arg_name); + graph.AddEdge(edge.src_node, target_idx, edge.src_arg_index, arg_index); + } + + // Copy STFT outputs to stft_producer + stft_producer->MutableOutputDefs() = stft.MutableOutputDefs(); + auto stft_producer_target_idx = stft_producer->Index(); + for (auto cur = output_edges.cbegin(), end = output_edges.cend(); cur != end; ++cur) { + graph.AddEdge(stft_producer_target_idx, cur->dst_node, cur->src_arg_index, cur->dst_arg_index); + } + + graph_utils::GraphEdge::RemoveGraphEdges(graph, input_edges); + graph_utils::GraphEdge::RemoveGraphEdges(graph, output_edges); + 
graph.RemoveNode(stft.Index()); + + modified = true; + } + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/stft_decomposition.h b/onnxruntime/core/optimizer/stft_decomposition.h new file mode 100644 index 0000000000000..cac058474375e --- /dev/null +++ b/onnxruntime/core/optimizer/stft_decomposition.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" +#include "core/framework/ort_value.h" +#include +#include "core/framework/execution_provider.h" + +namespace onnxruntime { + +/** +@class STFTDecomposition + +Transformer that traverses the graph top-down and decomposes +STFT into convolution. +*/ +class STFTDecomposition : public GraphTransformer { + public: + /*! STFT decomposition . + \param execution_provider Execution provider instance to execute constant folding. + */ + STFTDecomposition(const InlinedHashSet& compatible_execution_providers = {}) noexcept; + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc index 15bf633579e5f..50fe7d1344eaf 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -506,7 +506,7 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside // Calculate the window size with preference to the window input. const auto window_size = window ? window->Shape()[0] : frame_length; - ORT_ENFORCE(window_size < signal_size, "Ensure that the dft size is smaller than the signal."); + ORT_ENFORCE(window_size <= signal_size, "Ensure that the dft size is smaller than the signal."); // Calculate the number of dfts to run const auto n_dfts = diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 8a32d06534dda..6c347ebdca7c1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -771,8 +771,14 @@ namespace Dml !native16BitShaderOpsSupported && IsCustomOpShader(node)) { - nodeContainsSupportedDataTypes = false; - return; + // STFT is a special case since it has a dml ep registered + // graph transformation that will decompose fp16 STFT into convolution + // and so it is OK to register for fp16. + if (strcmp("STFT", node.OpType().c_str()) != 0) + { + nodeContainsSupportedDataTypes = false; + return; + } } // Allow nodeArgs that are SequenceTensor when they are actually implemented by CPU Kernels. 
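Editor's note (not part of the patch): the decomposition above relies on the identity that, for a real-valued signal, each one-sided STFT frame is a set of dot products between that frame and cosine/sine DFT basis rows, which is exactly what a strided `Conv` against the generated `stft_real_conv_weights` / `stft_imaginary_conv_weights` initializers computes. The standalone NumPy sketch below checks that equivalence against `np.fft.rfft`; the `stft_via_conv` helper and array names are illustrative assumptions, not ONNX Runtime APIs.

```python
# Illustrative sketch only: a one-sided STFT of a real signal equals a strided
# correlation with cos/sin DFT-basis kernels (the role of the Conv nodes above).
import numpy as np

def stft_via_conv(signal, frame_length, frame_step):
    bins = frame_length // 2 + 1                      # one-sided unique bins
    n = np.arange(frame_length)
    k = np.arange(bins).reshape(-1, 1)
    theta = -2.0 * np.pi * k * n / frame_length
    real_w = np.cos(theta)                            # analogous to stft_real_conv_weights
    imag_w = np.sin(theta)                            # analogous to stft_imaginary_conv_weights
    n_frames = (len(signal) - frame_length) // frame_step + 1
    frames = np.stack([signal[i * frame_step: i * frame_step + frame_length]
                       for i in range(n_frames)])     # framing == the Conv stride
    # Each output element is the dot product of one frame with one basis row.
    return frames @ real_w.T + 1j * (frames @ imag_w.T)  # [n_frames, bins]

rng = np.random.default_rng(0)
x = rng.standard_normal(256).astype(np.float32)
frame_length, frame_step = 32, 16
ours = stft_via_conv(x, frame_length, frame_step)
ref = np.stack([np.fft.rfft(x[i * frame_step: i * frame_step + frame_length])
                for i in range(ours.shape[0])])
assert np.allclose(ours, ref, atol=1e-3)
```

With a window present, the same identity holds after scaling each basis row by the window values, which corresponds to the two `Mul` nodes that feed the `Conv` nodes in subgraph pattern 1.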
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham.h index 9c03b7f6de639..1bfd6e6c6068d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham.h @@ -21,7 +21,7 @@ dcl_uav_structured u0, 4 dcl_uav_structured u1, 4 dcl_uav_structured u2, 4 dcl_input vThreadID.x -dcl_temps 6 +dcl_temps 5 dcl_thread_group 64, 1, 1 iadd r0.x, vThreadID.x, cb0[0].x ult r0.y, r0.x, cb0[0].y @@ -40,66 +40,57 @@ if_nz r0.y ieq r1.y, cb0[7].x, l(1) ult r1.z, r0.w, cb0[5].z and r1.z, r1.z, r1.y - if_nz r1.z - imul null, r1.z, r0.w, cb0[6].z - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r1.z, l(0), u2.xxxx - imad r1.z, r0.w, cb0[6].z, cb0[6].w - ieq r1.w, cb0[5].w, l(2) - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r1.z, r1.z, l(0), u2.xxxx - and r4.y, r1.z, r1.w + imul null, r1.w, r0.w, cb0[6].z + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r1.w, l(0), u2.xxxx + ieq r1.w, cb0[5].w, l(2) + if_nz r1.w + imad r2.y, r0.w, cb0[6].z, cb0[6].w + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.y, r2.y, l(0), u2.xxxx else - mov r4.xy, l(1.000000,0,0,0) + mov r4.y, l(0) endif + movc r2.yz, r1.zzzz, r4.yyxy, l(0,0,1.000000,0) ult r1.z, r0.w, cb0[1].y - if_nz r1.z - imul null, r0.w, r0.w, cb0[2].y - imad r0.w, r1.x, cb0[2].x, r0.w - imad r0.w, r3.x, cb0[2].z, r0.w - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r5.x, r0.w, l(0), u0.xxxx - ieq r1.z, cb0[1].w, l(2) - if_nz r1.z - iadd r0.w, r0.w, cb0[2].w - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r5.y, r0.w, l(0), u0.xxxx - else - mov r5.y, l(0) - endif + imul null, r1.x, r1.x, cb0[2].x + imad r0.w, r0.w, cb0[2].y, r1.x + imad r0.w, r3.x, cb0[2].z, r0.w + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r0.w, l(0), u0.xxxx + ieq r2.w, cb0[1].w, l(2) + if_nz r2.w + iadd r0.w, r0.w, cb0[2].w + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.y, r0.w, l(0), u0.xxxx else - mov r5.xy, l(0,0,0,0) + mov r4.y, l(0) endif - mul r0.w, r4.y, r5.y - mad r0.w, r5.x, r4.x, -r0.w - dp2 r1.z, r5.yxyy, r4.xyxx - ult r1.w, r0.y, cb0[5].z - and r1.y, r1.w, r1.y - if_nz r1.y - imul null, r1.y, r0.y, cb0[6].z - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r1.y, l(0), u2.xxxx - imad r1.y, r0.y, cb0[6].z, cb0[6].w - ieq r1.w, cb0[5].w, l(2) - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r1.y, r1.y, l(0), u2.xxxx - and r4.y, r1.y, r1.w + and r3.yz, r1.zzzz, r4.xxyx + mul r0.w, r2.y, r3.z + mad r0.w, r3.y, r2.z, -r0.w + dp2 r1.z, r3.yzyy, r2.yzyy + ult r2.y, r0.y, cb0[5].z + and r1.y, r1.y, r2.y + imul null, r2.y, r0.y, cb0[6].z + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.x, r2.y, l(0), u2.xxxx + if_nz r1.w + imad r1.w, r0.y, cb0[6].z, cb0[6].w + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r4.y, r1.w, l(0), u2.xxxx else - mov r4.xy, l(1.000000,0,0,0) + mov r4.y, l(0) endif - ult r1.y, r0.y, cb0[1].y - if_nz r1.y - imul null, r0.y, r0.y, cb0[2].y - imad r0.y, r1.x, cb0[2].x, r0.y - imad 
r0.y, r3.x, cb0[2].z, r0.y - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r1.x, r0.y, l(0), u0.xxxx - ieq r1.w, cb0[1].w, l(2) - if_nz r1.w - iadd r0.y, r0.y, cb0[2].w - ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r1.y, r0.y, l(0), u0.xxxx - else - mov r1.y, l(0) - endif + movc r1.yw, r1.yyyy, r4.yyyx, l(0,0,0,1.000000) + ult r2.y, r0.y, cb0[1].y + imad r0.y, r0.y, cb0[2].y, r1.x + imad r0.y, r3.x, cb0[2].z, r0.y + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r3.x, r0.y, l(0), u0.xxxx + if_nz r2.w + iadd r0.y, r0.y, cb0[2].w + ld_structured_indexable(structured_buffer, stride=4)(mixed,mixed,mixed,mixed) r3.y, r0.y, l(0), u0.xxxx else - mov r1.xy, l(0,0,0,0) + mov r3.y, l(0) endif - mul r0.y, r4.y, r1.y - mad r0.y, r1.x, r4.x, -r0.y - dp2 r1.x, r1.yxyy, r4.xyxx + and r2.yz, r2.yyyy, r3.xxyx + mul r0.y, r1.y, r2.z + mad r0.y, r2.y, r1.w, -r0.y + dp2 r1.x, r2.yzyy, r1.ywyy udiv null, r1.y, r2.x, r0.z ieq r1.w, cb0[0].w, l(1) movc r1.w, r1.w, l(6.283185), l(-6.283185) @@ -117,17 +108,22 @@ if_nz r0.y mad r0.y, r3.x, r1.x, r0.y add r0.y, r0.y, r1.z mul r0.yw, r0.yyyw, cb0[7].zzzz - ne r1.x, cb0[7].y, l(0.000000) - mul r1.y, r1.y, r1.y - mul r1.y, r1.y, l(3.141593) - div r1.y, r1.y, cb0[7].y - sincos r2.x, r3.x, r1.y - mov r2.y, r3.x - movc r1.xy, r1.xxxx, r2.xyxx, l(0,1.000000,0,0) - mul r1.zw, r0.yyyy, r1.xxxy - mad r0.y, r0.w, r1.y, -r1.z - store_structured u1.x, r0.z, l(0), r0.y - mad r0.y, r0.w, r1.x, r1.w + eq r1.x, cb0[7].y, l(0.000000) + if_nz r1.x + mov r1.x, r0.w + else + ne r1.z, cb0[7].y, l(0.000000) + mul r1.y, r1.y, r1.y + mul r1.y, r1.y, l(3.141593) + div r1.y, r1.y, cb0[7].y + sincos r2.x, r3.x, r1.y + mov r2.y, r3.x + movc r1.yz, r1.zzzz, r2.xxyx, l(0,0,1.000000,0) + mul r2.xy, r0.yyyy, r1.yzyy + mad r1.x, r0.w, r1.z, -r2.x + mad r0.y, r0.w, r1.y, r2.y + endif + store_structured u1.x, r0.z, l(0), r1.x store_structured u1.x, r0.x, l(0), r0.y endif ret @@ -136,11 +132,11 @@ ret const BYTE g_DFT[] = { - 68, 88, 66, 67, 222, 156, - 188, 133, 179, 57, 118, 25, - 122, 216, 102, 13, 91, 242, - 99, 27, 1, 0, 0, 0, - 172, 12, 0, 0, 3, 0, + 68, 88, 66, 67, 63, 188, + 200, 227, 206, 73, 64, 21, + 140, 126, 47, 226, 169, 81, + 175, 134, 1, 0, 0, 0, + 112, 12, 0, 0, 3, 0, 0, 0, 44, 0, 0, 0, 60, 0, 0, 0, 76, 0, 0, 0, 73, 83, 71, 78, @@ -149,8 +145,8 @@ const BYTE g_DFT[] = 79, 83, 71, 78, 8, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 83, 72, - 69, 88, 88, 12, 0, 0, - 80, 0, 5, 0, 22, 3, + 69, 88, 28, 12, 0, 0, + 80, 0, 5, 0, 7, 3, 0, 0, 106, 8, 0, 1, 89, 0, 0, 4, 70, 142, 32, 0, 0, 0, 0, 0, @@ -164,7 +160,7 @@ const BYTE g_DFT[] = 17, 0, 2, 0, 0, 0, 4, 0, 0, 0, 95, 0, 0, 2, 18, 0, 2, 0, - 104, 0, 0, 2, 6, 0, + 104, 0, 0, 2, 5, 0, 0, 0, 155, 0, 0, 4, 64, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, @@ -256,11 +252,9 @@ const BYTE g_DFT[] = 16, 0, 1, 0, 0, 0, 42, 0, 16, 0, 1, 0, 0, 0, 26, 0, 16, 0, - 1, 0, 0, 0, 31, 0, - 4, 3, 42, 0, 16, 0, 1, 0, 0, 0, 38, 0, 0, 9, 0, 208, 0, 0, - 66, 0, 16, 0, 1, 0, + 130, 0, 16, 0, 1, 0, 0, 0, 58, 0, 16, 0, 0, 0, 0, 0, 42, 128, 32, 0, 0, 0, 0, 0, @@ -268,221 +262,203 @@ const BYTE g_DFT[] = 0, 139, 2, 35, 0, 128, 131, 153, 25, 0, 18, 0, 16, 0, 4, 0, 0, 0, - 42, 0, 16, 0, 1, 0, + 58, 0, 16, 0, 1, 0, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 6, 224, 17, 0, 2, 0, 0, 0, - 35, 0, 0, 11, 66, 0, + 32, 0, 0, 8, 130, 0, 16, 0, 1, 0, 0, 0, - 58, 0, 16, 0, 0, 0, - 0, 0, 42, 128, 32, 0, - 0, 0, 0, 0, 6, 0, - 0, 0, 58, 128, 32, 0, - 0, 0, 0, 0, 6, 0, - 0, 0, 32, 0, 0, 8, - 130, 0, 16, 0, 1, 0, - 0, 0, 58, 
128, 32, 0, - 0, 0, 0, 0, 5, 0, - 0, 0, 1, 64, 0, 0, - 2, 0, 0, 0, 167, 0, + 58, 128, 32, 0, 0, 0, + 0, 0, 5, 0, 0, 0, + 1, 64, 0, 0, 2, 0, + 0, 0, 31, 0, 4, 3, + 58, 0, 16, 0, 1, 0, + 0, 0, 35, 0, 0, 11, + 34, 0, 16, 0, 2, 0, + 0, 0, 58, 0, 16, 0, + 0, 0, 0, 0, 42, 128, + 32, 0, 0, 0, 0, 0, + 6, 0, 0, 0, 58, 128, + 32, 0, 0, 0, 0, 0, + 6, 0, 0, 0, 167, 0, 0, 139, 2, 35, 0, 128, - 131, 153, 25, 0, 66, 0, - 16, 0, 1, 0, 0, 0, - 42, 0, 16, 0, 1, 0, + 131, 153, 25, 0, 34, 0, + 16, 0, 4, 0, 0, 0, + 26, 0, 16, 0, 2, 0, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 6, 224, 17, 0, 2, 0, 0, 0, - 1, 0, 0, 7, 34, 0, - 16, 0, 4, 0, 0, 0, - 42, 0, 16, 0, 1, 0, - 0, 0, 58, 0, 16, 0, - 1, 0, 0, 0, 18, 0, - 0, 1, 54, 0, 0, 8, - 50, 0, 16, 0, 4, 0, + 18, 0, 0, 1, 54, 0, + 0, 5, 34, 0, 16, 0, + 4, 0, 0, 0, 1, 64, + 0, 0, 0, 0, 0, 0, + 21, 0, 0, 1, 55, 0, + 0, 12, 98, 0, 16, 0, + 2, 0, 0, 0, 166, 10, + 16, 0, 1, 0, 0, 0, + 86, 4, 16, 0, 4, 0, 0, 0, 2, 64, 0, 0, - 0, 0, 128, 63, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 21, 0, - 0, 1, 79, 0, 0, 8, - 66, 0, 16, 0, 1, 0, - 0, 0, 58, 0, 16, 0, - 0, 0, 0, 0, 26, 128, - 32, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 31, 0, - 4, 3, 42, 0, 16, 0, - 1, 0, 0, 0, 38, 0, - 0, 9, 0, 208, 0, 0, - 130, 0, 16, 0, 0, 0, - 0, 0, 58, 0, 16, 0, - 0, 0, 0, 0, 26, 128, - 32, 0, 0, 0, 0, 0, - 2, 0, 0, 0, 35, 0, - 0, 10, 130, 0, 16, 0, - 0, 0, 0, 0, 10, 0, + 0, 0, 0, 0, 128, 63, + 0, 0, 0, 0, 79, 0, + 0, 8, 66, 0, 16, 0, + 1, 0, 0, 0, 58, 0, + 16, 0, 0, 0, 0, 0, + 26, 128, 32, 0, 0, 0, + 0, 0, 1, 0, 0, 0, + 38, 0, 0, 9, 0, 208, + 0, 0, 18, 0, 16, 0, + 1, 0, 0, 0, 10, 0, 16, 0, 1, 0, 0, 0, 10, 128, 32, 0, 0, 0, 0, 0, 2, 0, 0, 0, + 35, 0, 0, 10, 130, 0, + 16, 0, 0, 0, 0, 0, 58, 0, 16, 0, 0, 0, - 0, 0, 35, 0, 0, 10, - 130, 0, 16, 0, 0, 0, + 0, 0, 26, 128, 32, 0, + 0, 0, 0, 0, 2, 0, 0, 0, 10, 0, 16, 0, - 3, 0, 0, 0, 42, 128, + 1, 0, 0, 0, 35, 0, + 0, 10, 130, 0, 16, 0, + 0, 0, 0, 0, 10, 0, + 16, 0, 3, 0, 0, 0, + 42, 128, 32, 0, 0, 0, + 0, 0, 2, 0, 0, 0, + 58, 0, 16, 0, 0, 0, + 0, 0, 167, 0, 0, 139, + 2, 35, 0, 128, 131, 153, + 25, 0, 18, 0, 16, 0, + 4, 0, 0, 0, 58, 0, + 16, 0, 0, 0, 0, 0, + 1, 64, 0, 0, 0, 0, + 0, 0, 6, 224, 17, 0, + 0, 0, 0, 0, 32, 0, + 0, 8, 130, 0, 16, 0, + 2, 0, 0, 0, 58, 128, 32, 0, 0, 0, 0, 0, - 2, 0, 0, 0, 58, 0, + 1, 0, 0, 0, 1, 64, + 0, 0, 2, 0, 0, 0, + 31, 0, 4, 3, 58, 0, + 16, 0, 2, 0, 0, 0, + 30, 0, 0, 8, 130, 0, 16, 0, 0, 0, 0, 0, - 167, 0, 0, 139, 2, 35, - 0, 128, 131, 153, 25, 0, - 18, 0, 16, 0, 5, 0, - 0, 0, 58, 0, 16, 0, - 0, 0, 0, 0, 1, 64, - 0, 0, 0, 0, 0, 0, - 6, 224, 17, 0, 0, 0, - 0, 0, 32, 0, 0, 8, - 66, 0, 16, 0, 1, 0, + 58, 0, 16, 0, 0, 0, 0, 0, 58, 128, 32, 0, - 0, 0, 0, 0, 1, 0, - 0, 0, 1, 64, 0, 0, - 2, 0, 0, 0, 31, 0, - 4, 3, 42, 0, 16, 0, - 1, 0, 0, 0, 30, 0, - 0, 8, 130, 0, 16, 0, - 0, 0, 0, 0, 58, 0, + 0, 0, 0, 0, 2, 0, + 0, 0, 167, 0, 0, 139, + 2, 35, 0, 128, 131, 153, + 25, 0, 34, 0, 16, 0, + 4, 0, 0, 0, 58, 0, 16, 0, 0, 0, 0, 0, - 58, 128, 32, 0, 0, 0, - 0, 0, 2, 0, 0, 0, - 167, 0, 0, 139, 2, 35, - 0, 128, 131, 153, 25, 0, - 34, 0, 16, 0, 5, 0, - 0, 0, 58, 0, 16, 0, - 0, 0, 0, 0, 1, 64, - 0, 0, 0, 0, 0, 0, - 6, 224, 17, 0, 0, 0, - 0, 0, 18, 0, 0, 1, - 54, 0, 0, 5, 34, 0, - 16, 0, 5, 0, 0, 0, 1, 64, 0, 0, 0, 0, - 0, 0, 21, 0, 0, 1, - 18, 0, 0, 1, 54, 0, - 0, 8, 50, 0, 16, 0, - 5, 0, 0, 0, 2, 64, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 21, 0, 0, 1, 56, 0, - 0, 7, 130, 0, 16, 0, - 0, 0, 0, 0, 26, 0, - 16, 0, 4, 0, 0, 0, - 26, 0, 16, 0, 5, 0, - 0, 0, 50, 0, 0, 10, - 130, 0, 16, 0, 0, 0, - 0, 0, 10, 0, 16, 0, - 5, 0, 0, 0, 10, 0, + 0, 0, 6, 224, 17, 
0, + 0, 0, 0, 0, 18, 0, + 0, 1, 54, 0, 0, 5, + 34, 0, 16, 0, 4, 0, + 0, 0, 1, 64, 0, 0, + 0, 0, 0, 0, 21, 0, + 0, 1, 1, 0, 0, 7, + 98, 0, 16, 0, 3, 0, + 0, 0, 166, 10, 16, 0, + 1, 0, 0, 0, 6, 1, 16, 0, 4, 0, 0, 0, - 58, 0, 16, 128, 65, 0, - 0, 0, 0, 0, 0, 0, - 15, 0, 0, 7, 66, 0, - 16, 0, 1, 0, 0, 0, - 22, 5, 16, 0, 5, 0, - 0, 0, 70, 0, 16, 0, - 4, 0, 0, 0, 79, 0, - 0, 8, 130, 0, 16, 0, + 56, 0, 0, 7, 130, 0, + 16, 0, 0, 0, 0, 0, + 26, 0, 16, 0, 2, 0, + 0, 0, 42, 0, 16, 0, + 3, 0, 0, 0, 50, 0, + 0, 10, 130, 0, 16, 0, + 0, 0, 0, 0, 26, 0, + 16, 0, 3, 0, 0, 0, + 42, 0, 16, 0, 2, 0, + 0, 0, 58, 0, 16, 128, + 65, 0, 0, 0, 0, 0, + 0, 0, 15, 0, 0, 7, + 66, 0, 16, 0, 1, 0, + 0, 0, 150, 5, 16, 0, + 3, 0, 0, 0, 150, 5, + 16, 0, 2, 0, 0, 0, + 79, 0, 0, 8, 34, 0, + 16, 0, 2, 0, 0, 0, + 26, 0, 16, 0, 0, 0, + 0, 0, 42, 128, 32, 0, + 0, 0, 0, 0, 5, 0, + 0, 0, 1, 0, 0, 7, + 34, 0, 16, 0, 1, 0, + 0, 0, 26, 0, 16, 0, 1, 0, 0, 0, 26, 0, + 16, 0, 2, 0, 0, 0, + 38, 0, 0, 9, 0, 208, + 0, 0, 34, 0, 16, 0, + 2, 0, 0, 0, 26, 0, 16, 0, 0, 0, 0, 0, 42, 128, 32, 0, 0, 0, - 0, 0, 5, 0, 0, 0, - 1, 0, 0, 7, 34, 0, - 16, 0, 1, 0, 0, 0, - 58, 0, 16, 0, 1, 0, + 0, 0, 6, 0, 0, 0, + 167, 0, 0, 139, 2, 35, + 0, 128, 131, 153, 25, 0, + 18, 0, 16, 0, 4, 0, 0, 0, 26, 0, 16, 0, - 1, 0, 0, 0, 31, 0, - 4, 3, 26, 0, 16, 0, - 1, 0, 0, 0, 38, 0, - 0, 9, 0, 208, 0, 0, - 34, 0, 16, 0, 1, 0, + 2, 0, 0, 0, 1, 64, + 0, 0, 0, 0, 0, 0, + 6, 224, 17, 0, 2, 0, + 0, 0, 31, 0, 4, 3, + 58, 0, 16, 0, 1, 0, + 0, 0, 35, 0, 0, 11, + 130, 0, 16, 0, 1, 0, 0, 0, 26, 0, 16, 0, 0, 0, 0, 0, 42, 128, + 32, 0, 0, 0, 0, 0, + 6, 0, 0, 0, 58, 128, 32, 0, 0, 0, 0, 0, 6, 0, 0, 0, 167, 0, 0, 139, 2, 35, 0, 128, - 131, 153, 25, 0, 18, 0, + 131, 153, 25, 0, 34, 0, 16, 0, 4, 0, 0, 0, - 26, 0, 16, 0, 1, 0, + 58, 0, 16, 0, 1, 0, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 6, 224, 17, 0, 2, 0, 0, 0, - 35, 0, 0, 11, 34, 0, - 16, 0, 1, 0, 0, 0, - 26, 0, 16, 0, 0, 0, - 0, 0, 42, 128, 32, 0, - 0, 0, 0, 0, 6, 0, - 0, 0, 58, 128, 32, 0, - 0, 0, 0, 0, 6, 0, - 0, 0, 32, 0, 0, 8, - 130, 0, 16, 0, 1, 0, - 0, 0, 58, 128, 32, 0, - 0, 0, 0, 0, 5, 0, - 0, 0, 1, 64, 0, 0, - 2, 0, 0, 0, 167, 0, - 0, 139, 2, 35, 0, 128, - 131, 153, 25, 0, 34, 0, + 18, 0, 0, 1, 54, 0, + 0, 5, 34, 0, 16, 0, + 4, 0, 0, 0, 1, 64, + 0, 0, 0, 0, 0, 0, + 21, 0, 0, 1, 55, 0, + 0, 12, 162, 0, 16, 0, + 1, 0, 0, 0, 86, 5, 16, 0, 1, 0, 0, 0, - 26, 0, 16, 0, 1, 0, - 0, 0, 1, 64, 0, 0, - 0, 0, 0, 0, 6, 224, - 17, 0, 2, 0, 0, 0, - 1, 0, 0, 7, 34, 0, - 16, 0, 4, 0, 0, 0, - 26, 0, 16, 0, 1, 0, - 0, 0, 58, 0, 16, 0, - 1, 0, 0, 0, 18, 0, - 0, 1, 54, 0, 0, 8, - 50, 0, 16, 0, 4, 0, + 86, 1, 16, 0, 4, 0, 0, 0, 2, 64, 0, 0, - 0, 0, 128, 63, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 21, 0, - 0, 1, 79, 0, 0, 8, - 34, 0, 16, 0, 1, 0, - 0, 0, 26, 0, 16, 0, - 0, 0, 0, 0, 26, 128, - 32, 0, 0, 0, 0, 0, - 1, 0, 0, 0, 31, 0, - 4, 3, 26, 0, 16, 0, - 1, 0, 0, 0, 38, 0, - 0, 9, 0, 208, 0, 0, - 34, 0, 16, 0, 0, 0, - 0, 0, 26, 0, 16, 0, - 0, 0, 0, 0, 26, 128, - 32, 0, 0, 0, 0, 0, - 2, 0, 0, 0, 35, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, 128, 63, 79, 0, + 0, 8, 34, 0, 16, 0, + 2, 0, 0, 0, 26, 0, + 16, 0, 0, 0, 0, 0, + 26, 128, 32, 0, 0, 0, + 0, 0, 1, 0, 0, 0, + 35, 0, 0, 10, 34, 0, + 16, 0, 0, 0, 0, 0, + 26, 0, 16, 0, 0, 0, + 0, 0, 26, 128, 32, 0, + 0, 0, 0, 0, 2, 0, + 0, 0, 10, 0, 16, 0, + 1, 0, 0, 0, 35, 0, 0, 10, 34, 0, 16, 0, 0, 0, 0, 0, 10, 0, - 16, 0, 1, 0, 0, 0, - 10, 128, 32, 0, 0, 0, + 16, 0, 3, 0, 0, 0, + 42, 128, 32, 0, 0, 0, 0, 0, 2, 0, 0, 0, 26, 0, 16, 0, 0, 0, - 0, 0, 35, 0, 0, 10, - 34, 0, 16, 0, 0, 0, - 0, 0, 10, 0, 16, 0, - 3, 0, 0, 0, 42, 128, - 
32, 0, 0, 0, 0, 0, - 2, 0, 0, 0, 26, 0, + 0, 0, 167, 0, 0, 139, + 2, 35, 0, 128, 131, 153, + 25, 0, 18, 0, 16, 0, + 3, 0, 0, 0, 26, 0, 16, 0, 0, 0, 0, 0, - 167, 0, 0, 139, 2, 35, - 0, 128, 131, 153, 25, 0, - 18, 0, 16, 0, 1, 0, - 0, 0, 26, 0, 16, 0, - 0, 0, 0, 0, 1, 64, - 0, 0, 0, 0, 0, 0, - 6, 224, 17, 0, 0, 0, - 0, 0, 32, 0, 0, 8, - 130, 0, 16, 0, 1, 0, - 0, 0, 58, 128, 32, 0, - 0, 0, 0, 0, 1, 0, - 0, 0, 1, 64, 0, 0, - 2, 0, 0, 0, 31, 0, + 1, 64, 0, 0, 0, 0, + 0, 0, 6, 224, 17, 0, + 0, 0, 0, 0, 31, 0, 4, 3, 58, 0, 16, 0, - 1, 0, 0, 0, 30, 0, + 2, 0, 0, 0, 30, 0, 0, 8, 34, 0, 16, 0, 0, 0, 0, 0, 26, 0, 16, 0, 0, 0, 0, 0, @@ -490,39 +466,37 @@ const BYTE g_DFT[] = 0, 0, 2, 0, 0, 0, 167, 0, 0, 139, 2, 35, 0, 128, 131, 153, 25, 0, - 34, 0, 16, 0, 1, 0, + 34, 0, 16, 0, 3, 0, 0, 0, 26, 0, 16, 0, 0, 0, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 6, 224, 17, 0, 0, 0, 0, 0, 18, 0, 0, 1, 54, 0, 0, 5, 34, 0, - 16, 0, 1, 0, 0, 0, + 16, 0, 3, 0, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, 21, 0, 0, 1, - 18, 0, 0, 1, 54, 0, - 0, 8, 50, 0, 16, 0, - 1, 0, 0, 0, 2, 64, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 21, 0, 0, 1, 56, 0, + 1, 0, 0, 7, 98, 0, + 16, 0, 2, 0, 0, 0, + 86, 5, 16, 0, 2, 0, + 0, 0, 6, 1, 16, 0, + 3, 0, 0, 0, 56, 0, 0, 7, 34, 0, 16, 0, 0, 0, 0, 0, 26, 0, - 16, 0, 4, 0, 0, 0, - 26, 0, 16, 0, 1, 0, + 16, 0, 1, 0, 0, 0, + 42, 0, 16, 0, 2, 0, 0, 0, 50, 0, 0, 10, 34, 0, 16, 0, 0, 0, - 0, 0, 10, 0, 16, 0, - 1, 0, 0, 0, 10, 0, - 16, 0, 4, 0, 0, 0, + 0, 0, 26, 0, 16, 0, + 2, 0, 0, 0, 58, 0, + 16, 0, 1, 0, 0, 0, 26, 0, 16, 128, 65, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 7, 18, 0, 16, 0, 1, 0, 0, 0, - 22, 5, 16, 0, 1, 0, - 0, 0, 70, 0, 16, 0, - 4, 0, 0, 0, 78, 0, + 150, 5, 16, 0, 2, 0, + 0, 0, 214, 5, 16, 0, + 1, 0, 0, 0, 78, 0, 0, 8, 0, 208, 0, 0, 34, 0, 16, 0, 1, 0, 0, 0, 10, 0, 16, 0, @@ -610,65 +584,77 @@ const BYTE g_DFT[] = 16, 0, 0, 0, 0, 0, 166, 138, 32, 0, 0, 0, 0, 0, 7, 0, 0, 0, - 57, 0, 0, 8, 18, 0, + 24, 0, 0, 8, 18, 0, 16, 0, 1, 0, 0, 0, 26, 128, 32, 0, 0, 0, 0, 0, 7, 0, 0, 0, 1, 64, 0, 0, 0, 0, + 0, 0, 31, 0, 4, 3, + 10, 0, 16, 0, 1, 0, + 0, 0, 54, 0, 0, 5, + 18, 0, 16, 0, 1, 0, + 0, 0, 58, 0, 16, 0, + 0, 0, 0, 0, 18, 0, + 0, 1, 57, 0, 0, 8, + 66, 0, 16, 0, 1, 0, + 0, 0, 26, 128, 32, 0, + 0, 0, 0, 0, 7, 0, + 0, 0, 1, 64, 0, 0, + 0, 0, 0, 0, 56, 0, + 0, 7, 34, 0, 16, 0, + 1, 0, 0, 0, 26, 0, + 16, 0, 1, 0, 0, 0, + 26, 0, 16, 0, 1, 0, 0, 0, 56, 0, 0, 7, 34, 0, 16, 0, 1, 0, 0, 0, 26, 0, 16, 0, - 1, 0, 0, 0, 26, 0, - 16, 0, 1, 0, 0, 0, - 56, 0, 0, 7, 34, 0, + 1, 0, 0, 0, 1, 64, + 0, 0, 219, 15, 73, 64, + 14, 0, 0, 8, 34, 0, 16, 0, 1, 0, 0, 0, 26, 0, 16, 0, 1, 0, - 0, 0, 1, 64, 0, 0, - 219, 15, 73, 64, 14, 0, - 0, 8, 34, 0, 16, 0, - 1, 0, 0, 0, 26, 0, + 0, 0, 26, 128, 32, 0, + 0, 0, 0, 0, 7, 0, + 0, 0, 77, 0, 0, 7, + 18, 0, 16, 0, 2, 0, + 0, 0, 18, 0, 16, 0, + 3, 0, 0, 0, 26, 0, 16, 0, 1, 0, 0, 0, - 26, 128, 32, 0, 0, 0, - 0, 0, 7, 0, 0, 0, - 77, 0, 0, 7, 18, 0, + 54, 0, 0, 5, 34, 0, 16, 0, 2, 0, 0, 0, - 18, 0, 16, 0, 3, 0, - 0, 0, 26, 0, 16, 0, - 1, 0, 0, 0, 54, 0, - 0, 5, 34, 0, 16, 0, - 2, 0, 0, 0, 10, 0, - 16, 0, 3, 0, 0, 0, - 55, 0, 0, 12, 50, 0, - 16, 0, 1, 0, 0, 0, - 6, 0, 16, 0, 1, 0, - 0, 0, 70, 0, 16, 0, - 2, 0, 0, 0, 2, 64, + 10, 0, 16, 0, 3, 0, + 0, 0, 55, 0, 0, 12, + 98, 0, 16, 0, 1, 0, + 0, 0, 166, 10, 16, 0, + 1, 0, 0, 0, 6, 1, + 16, 0, 2, 0, 0, 0, + 2, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 63, 0, 0, - 0, 0, 0, 0, 0, 0, - 56, 0, 0, 7, 194, 0, + 0, 0, 56, 0, 0, 7, + 50, 0, 16, 0, 2, 0, + 0, 0, 86, 5, 16, 0, + 0, 0, 0, 0, 150, 5, 16, 0, 1, 0, 0, 0, - 86, 5, 16, 0, 0, 0, - 0, 0, 6, 4, 
16, 0, - 1, 0, 0, 0, 50, 0, - 0, 10, 34, 0, 16, 0, + 50, 0, 0, 10, 18, 0, + 16, 0, 1, 0, 0, 0, + 58, 0, 16, 0, 0, 0, + 0, 0, 42, 0, 16, 0, + 1, 0, 0, 0, 10, 0, + 16, 128, 65, 0, 0, 0, + 2, 0, 0, 0, 50, 0, + 0, 9, 34, 0, 16, 0, 0, 0, 0, 0, 58, 0, 16, 0, 0, 0, 0, 0, 26, 0, 16, 0, 1, 0, - 0, 0, 42, 0, 16, 128, - 65, 0, 0, 0, 1, 0, - 0, 0, 168, 0, 0, 9, + 0, 0, 26, 0, 16, 0, + 2, 0, 0, 0, 21, 0, + 0, 1, 168, 0, 0, 9, 18, 224, 17, 0, 1, 0, 0, 0, 42, 0, 16, 0, 0, 0, 0, 0, 1, 64, 0, 0, 0, 0, 0, 0, - 26, 0, 16, 0, 0, 0, - 0, 0, 50, 0, 0, 9, - 34, 0, 16, 0, 0, 0, - 0, 0, 58, 0, 16, 0, - 0, 0, 0, 0, 10, 0, - 16, 0, 1, 0, 0, 0, - 58, 0, 16, 0, 1, 0, + 10, 0, 16, 0, 1, 0, 0, 0, 168, 0, 0, 9, 18, 224, 17, 0, 1, 0, 0, 0, 10, 0, 16, 0, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham_fp16.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham_fp16.h index 988c0aa66ade2..56ce759875687 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham_fp16.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/stockham_fp16.h @@ -15,7 +15,7 @@ ; Name Index Mask Register SysValue Format Used ; -------------------- ----- ------ -------- -------- ------- ------ ; no parameters -; shader hash: e08f21199c48b0db30bf21bd8c5b80dc +; shader hash: 6a1d88feb14177832f5ee49ca330c549 ; ; Pipeline Runtime Information: ; @@ -125,7 +125,7 @@ define void @DFT() { %47 = fpext half %46 to float %48 = extractvalue %dx.types.CBufRet.i32 %37, 3 %49 = icmp eq i32 %48, 2 - br i1 %49, label %50, label %56 + br i1 %49, label %50, label %56, !dx.controlflow.hints !15 ;