From c17e6e60f982f19eaddea44f5bb661ac601a79e7 Mon Sep 17 00:00:00 2001 From: Penporn Koanantakool Date: Sat, 29 Jun 2024 01:40:50 -0700 Subject: [PATCH] [xla:cpu] Add FFI custom call thunk runtime support to PJRT CPU client. Also add a benchmark that uses PJRT CPU Client. PiperOrigin-RevId: 647916282 --- third_party/xla/xla/pjrt/cpu/BUILD | 1 - third_party/xla/xla/pjrt/cpu/cpu_client.cc | 23 +++++-- .../xla/xla/service/cpu/benchmarks/BUILD | 17 ++++++ .../benchmarks/custom_call_benchmark_test.cc | 60 +++++++++++++++++++ .../service/cpu/runtime/custom_call_thunk.cc | 3 + 5 files changed, 99 insertions(+), 5 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/benchmarks/custom_call_benchmark_test.cc diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index 6d77c95109eb75..79a244cd51e8bb 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -163,7 +163,6 @@ cc_library( "//xla/pjrt:semaphore", "//xla/pjrt:transpose", "//xla/pjrt:utils", - "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:buffer_assignment", "//xla/service:compiler", "//xla/service:computation_placer_hdr", diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index 3e3feb7ffa7a82..4ee4a2a91a083c 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -1588,10 +1588,18 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( cpu::Thunk::CollectiveExecuteParams collective_params, cpu::Thunk::CollectiveExecuteParams::Create(&run_options)); + // TODO(penporn): Consolidate with other thunk parameter set up calls. + TF_ASSIGN_OR_RETURN( + cpu::Thunk::CustomCallExecuteParams custom_call_execute_params, + cpu::Thunk::CustomCallExecuteParams::Create(&run_options)); + cpu::Thunk::ExecuteParams execute_params = { - &cpu_executable->host_kernels(), &allocations, + &cpu_executable->host_kernels(), + &allocations, cpu::runtime::GetXfeedManager(run_options.device_ordinal()), - run_options.intra_op_thread_pool(), &collective_params}; + run_options.intra_op_thread_pool(), + &collective_params, + &custom_call_execute_params}; auto execute_event = cpu_executable->thunks().Execute( execute_params, [&](cpu::ThunkExecutor::Task task) { @@ -1714,11 +1722,18 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( collective_params = cpu::Thunk::CollectiveExecuteParams::Create(&run_options); + absl::StatusOr + custom_call_params = + cpu::Thunk::CustomCallExecuteParams::Create(&run_options); + if (collective_params.ok()) { cpu::Thunk::ExecuteParams execute_params = { - &cpu_executable->host_kernels(), &allocations, + &cpu_executable->host_kernels(), + &allocations, cpu::runtime::GetXfeedManager(run_options.device_ordinal()), - run_options.intra_op_thread_pool(), &*collective_params}; + run_options.intra_op_thread_pool(), + &*collective_params, + &*custom_call_params}; auto execute_event = cpu_executable->thunks().Execute( execute_params, [&](cpu::ThunkExecutor::Task task) { diff --git a/third_party/xla/xla/service/cpu/benchmarks/BUILD b/third_party/xla/xla/service/cpu/benchmarks/BUILD index 5b2594f5cc32d0..cec059c47b3855 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/BUILD +++ b/third_party/xla/xla/service/cpu/benchmarks/BUILD @@ -148,6 +148,23 @@ xla_cc_test( ], ) +xla_cc_test( + name = "custom_call_benchmark_test", + srcs = ["custom_call_benchmark_test.cc"], + deps = [ + ":hlo_benchmark_runner", + "//xla/ffi", + "//xla/ffi:ffi_api", + "//xla/tests:hlo_test_base", + "//xla/tests:test_macros_header", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + ], +) + xla_cc_test( name = "gather_benchmark_test", srcs = ["gather_benchmark_test.cc"], diff --git a/third_party/xla/xla/service/cpu/benchmarks/custom_call_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/custom_call_benchmark_test.cc new file mode 100644 index 00000000000000..29d32f9792e501 --- /dev/null +++ b/third_party/xla/xla/service/cpu/benchmarks/custom_call_benchmark_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xla/ffi/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/service/cpu/benchmarks/hlo_benchmark_runner.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/test_benchmark.h" + +namespace xla::cpu { +namespace { + +static absl::Status Minimal( + ffi::Result> unused) { + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER( + kMinimal, Minimal, + ffi::Ffi::Bind() + .Ret>()); // Unused out buffer + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_bm$$minimal", "Host", + kMinimal); + +static void BM_CustomCall_Minimal(benchmark::State& state) { + const char* kModuleStr = R"( + HloModule module + + ENTRY custom_call { + ROOT custom-call = f32[] custom-call(), + custom_call_target="__xla_bm$$minimal", + api_version=API_VERSION_TYPED_FFI + } + )"; + CHECK_OK(RunHloBenchmark(state, kModuleStr, /*args=*/{}, + /*replacements=*/{})); + state.SetItemsProcessed(state.iterations()); +} + +BENCHMARK(BM_CustomCall_Minimal)->MeasureProcessCPUTime(); + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc index 3f290e01491c2e..5914d75df67d6e 100644 --- a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc @@ -90,6 +90,9 @@ tsl::AsyncValueRef CustomCallThunk::CallTypedFFI( "No registered implementation for FFI custom call to %s for Host", target_name_); } + if (params.custom_call_params == nullptr) { + return Internal("CustomCallExecuteParams cannot be nullptr."); + } // Build the FFI call frame. ffi::CallFrameBuilder builder(