Skip to content

Commit

Permalink
refactor TNCudaGateCache & add appendGateTensorOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
multiphaseCFD committed May 8, 2024
1 parent 4b4af4f commit c85c2ff
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 714 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ option(PL_DISABLE_CUDA_SAFETY "Build without CUDA call safety checks" OFF)

# Link ${PL_TENSOR} (PUBLIC) against the shared compile options, the external
# libraries, and the backend's gates, tensor, tensornetBase and utils targets.
# NOTE(review): `${PL_BACKEND}_utils` is listed twice below; this looks like the
# removed and added lines of the diff rendered together -- confirm against the
# actual CMakeLists.txt before relying on the ordering.
target_link_libraries(${PL_TENSOR} PUBLIC lightning_compile_options
lightning_external_libs
${PL_BACKEND}_utils
${PL_BACKEND}_gates
${PL_BACKEND}_tensor
${PL_BACKEND}_tensornetBase
${PL_BACKEND}_utils
)

# Expose this directory's headers to every target that links ${PL_TENSOR}.
target_include_directories(${PL_TENSOR} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <cuda.h>
#include <cutensornet.h>

#include "TNCudaGateCache.hpp"
#include "TensorBase.hpp"
#include "TensorCuda.hpp"
#include "TensornetBase.hpp"
Expand All @@ -47,6 +48,7 @@ namespace Pennylane::LightningTensor::TNCuda {
template <class Precision, class Derived>
class TNCudaBase : public TensornetBase<Precision, Derived> {
private:
using CFP_t = decltype(cuUtil::getCudaType(Precision{}));
using BaseType = TensornetBase<Precision, Derived>;
SharedTNCudaHandle handle_;
cudaDataType_t typeData_;
Expand All @@ -57,13 +59,16 @@ class TNCudaBase : public TensornetBase<Precision, Derived> {
CUTENSORNET_STATE_PURITY_PURE; // Only supports pure tensor network
// states as of v24.03

std::shared_ptr<TNCudaGateCache<Precision>> gate_cache_;

public:
TNCudaBase() = delete;

explicit TNCudaBase(const std::size_t numQubits, int device_id = 0,
cudaStream_t stream_id = 0)
: BaseType(numQubits), handle_(make_shared_tncuda_handle()),
dev_tag_({device_id, stream_id}) {
dev_tag_({device_id, stream_id}),
gate_cache_(std::make_shared<TNCudaGateCache<Precision>>(dev_tag_)) {
// TODO: this code block could be moved to the base class and needs to be
// revisited when working on the copy ctor
if constexpr (std::is_same_v<Precision, double>) {
Expand All @@ -87,7 +92,8 @@ class TNCudaBase : public TensornetBase<Precision, Derived> {

explicit TNCudaBase(const std::size_t numQubits, DevTag<int> dev_tag)
: BaseType(numQubits), handle_(make_shared_tncuda_handle()),
dev_tag_(dev_tag) {
dev_tag_(dev_tag),
gate_cache_(std::make_shared<TNCudaGateCache<Precision>>(dev_tag_)) {
// TODO: this code block could be moved to the base class and needs to be
// revisited when working on the copy ctor
if constexpr (std::is_same_v<Precision, double>) {
Expand Down Expand Up @@ -141,6 +147,51 @@ class TNCudaBase : public TensornetBase<Precision, Derived> {
return dev_tag_;
}

void appendGateTensorOperator(
const std::string &opName, const std::vector<size_t> &wires,
bool adjoint = false, const std::vector<Precision> &params = {0.0},
[[maybe_unused]] const std::vector<CFP_t> &gate_matrix = {}) {
auto &&par = (params.empty()) ? std::vector<Precision>{0.0} : params;

DataBuffer<Precision, int> dummy_device_data(
Pennylane::Util::exp2(wires.size()), getDevTag());

int64_t id;
std::vector<int32_t> stateModes(wires.size());
std::transform(
wires.begin(), wires.end(), stateModes.begin(), [&](size_t x) {
return static_cast<int32_t>(BaseType::getNumQubits() - 1 - x);
});

// Note adjoint indicates whether or not all tensor elements of the
// tensor operator will be complex conjugated adjoint in the following
// API is not equivalent to inverse in the lightning context
// NOTE: cutensornetStateApplyTensorOperator doesn't update the quantum
// state but only appends gate tensor operator to the graph.
PL_CUTENSORNET_IS_SUCCESS(cutensornetStateApplyTensorOperator(
/* const cutensornetHandle_t */ getTNCudaHandle(),
/* cutensornetState_t */ getQuantumState(),
/* int32_t numStateModes */ stateModes.size(),
/* const int32_t * stateModes */ stateModes.data(),
/* void * */ static_cast<void *>(dummy_device_data.getData()),
/* const int64_t *tensorModeStrides */ nullptr,
/* const int32_t immutable */ 1,
/* const int32_t adjoint */ adjoint,
/* const int32_t unitary */ 1,
/* int64_t * */ &id));

gate_cache_->add_gate(static_cast<size_t>(id), opName, par[0]);

PL_CUTENSORNET_IS_SUCCESS(cutensornetStateUpdateTensorOperator(
/* const cutensornetHandle_t */ getTNCudaHandle(),
/* cutensornetState_t */ getQuantumState(),
/* int64_t tensorId*/ id,
/* void* */
static_cast<void *>(
gate_cache_->get_gate_device_ptr(static_cast<size_t>(id))),
/* int32_t unitary*/ 1));
}

protected:
/**
* @brief Returns the workspace size.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ add_library(${PL_BACKEND}_gates INTERFACE)

# Expose this directory's headers to consumers of the ${PL_BACKEND}_gates
# INTERFACE library.
target_include_directories(${PL_BACKEND}_gates INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})

# NOTE(review): the two target_link_libraries calls below look like the removed
# and added lines of the diff rendered together; the second one additionally
# links ${PL_BACKEND}_tensor -- confirm which one is in the actual file.
target_link_libraries(${PL_BACKEND}_gates INTERFACE ${PL_BACKEND}_utils )
target_link_libraries(${PL_BACKEND}_gates INTERFACE ${PL_BACKEND}_utils ${PL_BACKEND}_tensor)

# Set the POSITION_INDEPENDENT_CODE property to ON for the gates target.
set_property(TARGET ${PL_BACKEND}_gates PROPERTY POSITION_INDEPENDENT_CODE ON)

Expand Down
Loading

0 comments on commit c85c2ff

Please sign in to comment.