Skip to content

Commit

Permalink
Implement DML copy for Lora Adapters (#22396)
Browse files Browse the repository at this point in the history
### Description
Request and create DML EP and its data transfer.
Use to copy on device.

The PR includes changes to fix issues in DML provider.

### Motivation and Context
This enables Lora users to run it with DML which is important for GenAI.

Co-authored-by: @PatriceVignola

---------

Co-authored-by: Patrice Vignola <[email protected]>
  • Loading branch information
yuslepukhin and PatriceVignola authored Oct 14, 2024
1 parent 35adba2 commit 87e8a5d
Show file tree
Hide file tree
Showing 11 changed files with 136 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ namespace Dml
}
else
{
if (!m_context->IsClosed())
if (!m_closed)
{
// Free the underlying allocation once queued work has completed.
#ifdef _GAMING_XBOX
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ namespace Dml

void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode);

void Close()
{
m_closed = true;
}

public: // onnxruntime::IAllocator
void* Alloc(size_t size, AllocatorRoundingMode roundingMode);
void* Alloc(size_t size) final;
Expand Down Expand Up @@ -83,6 +88,7 @@ namespace Dml
std::vector<Bucket> m_pool;
size_t m_currentAllocationId = 0;
uint64_t m_currentResourceId = 0;
bool m_closed = false;

// Unless specifically requested, allocation sizes are not rounded to enable pooling
// until SetDefaultRoundingMode is called. This should be done at completion of session
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ namespace Dml
// for example, an allocation from BucketizedBufferAllocator attempts to queue a reference
// to its underlying D3D resource when freed. Furthermore, these references are unnecessary
// since Close() already blocks for scheduled GPU work before clearing m_queuedReferences.
if (!m_closing)
if (!m_clearingQueue)
{
QueuedReference queuedReference = {GetLastFenceValue(), object};

Expand All @@ -70,15 +70,15 @@ namespace Dml
}
}

void CommandQueue::Close()
void CommandQueue::WaitForSignalAndClearQueue()
{
// Wait for flushed work:
assert(!m_closing);
m_closing = true;
assert(!m_clearingQueue);
m_clearingQueue = true;
GpuEvent event = GetCurrentCompletionEvent();
event.WaitForSignal(m_cpuSyncSpinningEnabled);
m_queuedReferences.clear();
m_closing = false;
m_clearingQueue = false;
}

void CommandQueue::ReleaseCompletedReferences()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace Dml
}
#endif

void Close();
void WaitForSignalAndClearQueue();
void ReleaseCompletedReferences();

private:
Expand All @@ -61,7 +61,7 @@ namespace Dml

ComPtr<ID3D12Fence> m_fence;
uint64_t m_lastFenceValue = 0;
bool m_closing = false;
bool m_clearingQueue = false;
bool m_cpuSyncSpinningEnabled = false;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@ namespace Dml
ID3D12Device* d3d12Device,
IDMLDevice* dmlDevice,
ID3D12CommandQueue* queue,
bool cpuSyncSpinningEnabled,
bool keepOpen
)
bool cpuSyncSpinningEnabled)
: m_queue(std::make_shared<CommandQueue>(queue, cpuSyncSpinningEnabled))
, m_dmlRecorder(d3d12Device, dmlDevice, m_queue)
, m_cpuSyncSpinningEnabled(cpuSyncSpinningEnabled)
, m_keepOpen(keepOpen)
{
ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf())));
}
Expand All @@ -36,8 +33,6 @@ namespace Dml
D3D12_RESOURCE_STATES srcState,
uint64_t byteCount)
{
assert(!m_closed);

SetCommandRecorder(&m_dmlRecorder);

std::vector<D3D12_RESOURCE_BARRIER> barriers;
Expand Down Expand Up @@ -84,8 +79,6 @@ namespace Dml
_Out_ uint64_t* completionValue
)
{
assert(!m_closed);

SetCommandRecorder(&m_dmlRecorder);
m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue);
}
Expand All @@ -95,7 +88,6 @@ namespace Dml
const DML_BINDING_DESC& persistentResourceBinding,
const DML_BINDING_DESC& inputArrayBinding)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);

m_dmlRecorder.InitializeOperator(op, persistentResourceBinding, inputArrayBinding);
Expand All @@ -107,31 +99,27 @@ namespace Dml
gsl::span<const DML_BINDING_DESC> inputBindings,
gsl::span<const DML_BINDING_DESC> outputBindings)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);

m_dmlRecorder.ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings);
}

void ExecutionContext::AddUAVBarrier()
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);

m_dmlRecorder.AddUAVBarrier();
}

void ExecutionContext::ResourceBarrier(gsl::span<const D3D12_RESOURCE_BARRIER> barriers)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);

m_dmlRecorder.ResourceBarrier(barriers);
}

void ExecutionContext::GetCommandListForRecordingAndInvalidateState(ID3D12GraphicsCommandList** commandList)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);

// Ensure the descriptor heap is reset to D3D as something external may change it before recording
Expand All @@ -142,8 +130,6 @@ namespace Dml

void ExecutionContext::SetCommandRecorder(ICommandRecorder* newRecorder)
{
assert(!m_closed);

// If changing which recorder is the current one, we need to flush the old one first. This is to ensure correct
// ordering of operations on the command queue.
if (m_currentRecorder != newRecorder)
Expand All @@ -160,8 +146,6 @@ namespace Dml

void ExecutionContext::Flush()
{
assert(!m_closed);

if (!m_currentRecorder || !m_currentRecorder->HasUnsubmittedWork())
{
// Nothing to flush
Expand All @@ -180,34 +164,21 @@ namespace Dml

void ExecutionContext::QueueReference(IUnknown* object)
{
assert(!m_closed);
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
// value is the one to signal completion.
bool waitForUnsubmittedWork = (m_currentRecorder != nullptr);
m_queue->QueueReference(object, waitForUnsubmittedWork);
}

void ExecutionContext::Close()
void ExecutionContext::WaitForSignalAndClearQueue()
{
assert(!m_closed);

// Discard unflushed work and clear queued references. This prevents the circular reference:
// Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel
m_queue->Close();

// Keep the execution context open when requested, e.g. when used through the python API where there's a single context
// and single command queue
if (!m_keepOpen)
{
m_currentRecorder = nullptr;
m_closed = true;
}
m_queue->WaitForSignalAndClearQueue();
}

GpuEvent ExecutionContext::GetCurrentCompletionEvent()
{
assert(!m_closed);

GpuEvent event = m_queue->GetCurrentCompletionEvent();

// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
Expand All @@ -223,13 +194,11 @@ namespace Dml

void ExecutionContext::ReleaseCompletedReferences()
{
assert(!m_closed);
m_queue->ReleaseCompletedReferences();
}

D3D12_COMMAND_LIST_TYPE ExecutionContext::GetCommandListTypeForQueue() const
{
assert(!m_closed);
return m_queue->GetType();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@ namespace Dml
ID3D12Device* d3d12Device,
IDMLDevice* dmlDevice,
ID3D12CommandQueue* queue,
bool cpuSyncSpinningEnabled,
bool keepOpen);
bool cpuSyncSpinningEnabled);

void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);

// Waits for flushed work, discards unflushed work, and discards associated references to
// prevent circular references. Must be the last call on the object before destruction.
void Close();
// prevent circular references.
void WaitForSignalAndClearQueue();

// Queues a CopyBufferRegion (see ID3D12GraphicsCommandList::CopyBufferRegion) for execution. Transition
// barriers are automatically inserted to transition the source and destination resources to COPY_SOURCE and
Expand Down Expand Up @@ -87,7 +86,6 @@ namespace Dml

D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const;
bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; }
bool IsClosed() const { return m_closed; }

private:
Microsoft::WRL::ComPtr<ID3D12Device> m_d3dDevice;
Expand All @@ -103,10 +101,6 @@ namespace Dml

bool m_closed = false;
bool m_cpuSyncSpinningEnabled = false;

// The python API has a global state used for I/O binding where the execution context is shared between session,
// so we don't want to close the context when one of the sessions is destroyed
bool m_keepOpen = false;
};

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,26 @@ namespace Dml
// Release the cached command list references before closing the context
m_capturedGraphs.clear();

m_context->Close();
// Close the allocator before clearing the command queue to stop it from
// appending resources to it in an attempt to keep them alive.
if (m_allocator)
{
m_allocator->Close();
}

// Destroy the allocators. We are closing the execution provider, so from now on the
// only thing it will be used for is doing copies via the DataTransfer, which doesn't
// require allocating any memory.
// TODO: Move the copy functions over to ExecutionContext so that we are able to cleanly
// destroy ExecutionProviderImpl, and instead have the DataTransfer keep the context alive.
m_allocator = nullptr;
m_cpuInputAllocator = nullptr;

// Wait for all pending commands to be done executing and empty the command queue. This will
// Force all kernels and resources in flight to get destroyed and, from this point forward,
// ExecutionProviderImpl will only be used to execute transfer between resources that are
// already existing via the DataTransfer;
m_context->WaitForSignalAndClearQueue();
}

void ExecutionProviderImpl::WaitForOutstandingWork()
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/dml/dml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ std::unique_ptr<IExecutionProvider> DMLProviderFactory::CreateProvider() {

// First, check if an I/O binding API that was used before this session or another session has already created a queue
if (FAILED(d3d12_device->GetPrivateData(dml_execution_context_guid, &execution_context_ptr_size, execution_context.GetAddressOf()))) {
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true, true);
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true);
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, execution_context.Get()));
}
} else {
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false);
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_);
}

auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_);
Expand Down
Loading

0 comments on commit 87e8a5d

Please sign in to comment.