Skip to content

Commit

Permalink
[stream_executor:gpu] Force serial execution order between all record…
Browse files Browse the repository at this point in the history
…ed commands tensorflow#6542

This is required for correctness until we have a proper way to model dependencies with barriers and scopes

PiperOrigin-RevId: 576741499
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Oct 26, 2023
1 parent 3b26963 commit fd98106
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/stream_executor/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ cc_library(
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:status",
Expand Down
19 changes: 14 additions & 5 deletions third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
Expand Down Expand Up @@ -140,6 +141,11 @@ tsl::Status GpuCommandBuffer::Trace(
return tsl::OkStatus();
}

absl::Span<GpuGraphNodeHandle> GpuCommandBuffer::GetDependencies() {
return nodes_.empty() ? absl::Span<GpuGraphNodeHandle>()
: absl::Span<GpuGraphNodeHandle>(&nodes_.back(), 1);
}

tsl::Status GpuCommandBuffer::CheckNotFinalized() {
if (state_ == State::kFinalized)
return absl::InternalError(
Expand Down Expand Up @@ -167,11 +173,12 @@ tsl::Status GpuCommandBuffer::Launch(const ThreadDim& threads,

// Adds a new kernel node to the graph under construction.
if (state_ == State::kCreate) {
absl::Span<GpuGraphNodeHandle> deps = GetDependencies();
GpuGraphNodeHandle* node = &nodes_.emplace_back();
return GpuDriver::GraphAddKernelNode(
node, graph_, {}, kernel.name(), gpu_func, blocks.x, blocks.y, blocks.z,
threads.x, threads.y, threads.z, args.number_of_shared_bytes(),
kernel_params, /*extra=*/nullptr);
node, graph_, deps, kernel.name(), gpu_func, blocks.x, blocks.y,
blocks.z, threads.x, threads.y, threads.z,
args.number_of_shared_bytes(), kernel_params, /*extra=*/nullptr);
}

// Updates kernel node in the executable graph.
Expand All @@ -193,9 +200,10 @@ tsl::Status GpuCommandBuffer::AddNestedCommandBuffer(

// Adds a child graph node to the graph under construction.
if (state_ == State::kCreate) {
absl::Span<GpuGraphNodeHandle> deps = GetDependencies();
GpuGraphNodeHandle* node = &nodes_.emplace_back();
return GpuDriver::GraphAddChildNode(
node, graph_, {}, GpuCommandBuffer::Cast(&nested)->graph());
node, graph_, deps, GpuCommandBuffer::Cast(&nested)->graph());
}

return UnsupportedStateError(state_);
Expand All @@ -208,9 +216,10 @@ tsl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst,

// Adds a new memcpy node to the graph under construction.
if (state_ == State::kCreate) {
absl::Span<GpuGraphNodeHandle> deps = GetDependencies();
GpuGraphNodeHandle* node = &nodes_.emplace_back();
return GpuDriver::GraphAddMemcpyD2DNode(parent_->gpu_context(), node,
graph_, {}, AsDevicePtr(*dst),
graph_, deps, AsDevicePtr(*dst),
AsDevicePtr(src), size);
}

Expand Down
6 changes: 6 additions & 0 deletions third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <vector>

#include "absl/functional/any_invocable.h"
#include "absl/types/span.h"
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/gpu_types.h"
Expand Down Expand Up @@ -80,6 +81,11 @@ class GpuCommandBuffer : public internal::CommandBufferInterface {
}

private:
// TODO(ezhulenev): Currently we serialize all Gpu nodes by adding a
// dependency between all nodes added to a command buffer. We need a concept
// of a barrier at a command buffer level.
absl::Span<GpuGraphNodeHandle> GetDependencies();

// Returns OK status if command buffer is not finalized and it is still
// possible to add new commands to it, otherwise returns internal error.
tsl::Status CheckNotFinalized();
Expand Down

0 comments on commit fd98106

Please sign in to comment.