Skip to content

Commit

Permalink
Merge branch 'develop' into warning
Browse files Browse the repository at this point in the history
  • Loading branch information
hongriTianqi authored Jan 12, 2024
2 parents eec7d4a + 8a969c7 commit 55a819f
Show file tree
Hide file tree
Showing 36 changed files with 1,042 additions and 390 deletions.
2 changes: 2 additions & 0 deletions source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ list(APPEND device_srcs
module_hamilt_pw/hamilt_pwdft/kernels/wf_op.cpp
module_hamilt_pw/hamilt_pwdft/kernels/vnl_op.cpp
module_base/kernels/math_op.cpp
module_hamilt_general/module_xc/kernels/xc_functional_op.cpp
)

if(USE_CUDA)
Expand All @@ -57,6 +58,7 @@ if(USE_CUDA)
module_hamilt_pw/hamilt_pwdft/kernels/cuda/wf_op.cu
module_hamilt_pw/hamilt_pwdft/kernels/cuda/vnl_op.cu
module_base/kernels/cuda/math_op.cu
module_hamilt_general/module_xc/kernels/cuda/xc_functional_op.cu
)
endif()

Expand Down
2 changes: 2 additions & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ VPATH=./src_global:\
./module_hamilt_general/module_ewald:\
./module_hamilt_general/module_surchem:\
./module_hamilt_general/module_xc:\
./module_hamilt_general/module_xc/kernels:\
./module_hamilt_pw/hamilt_pwdft:\
./module_hamilt_pw/hamilt_ofdft:\
./module_hamilt_pw/hamilt_stodft:\
Expand Down Expand Up @@ -392,6 +393,7 @@ OBJS_SYMMETRY=symm_other.o\
symmetry.o\

OBJS_XC=xc_functional.o\
xc_functional_op.o\
xc_functional_vxc.o\
xc_functional_gradcorr.o\
xc_functional_wrapper_xc.o\
Expand Down
34 changes: 34 additions & 0 deletions source/module_base/module_container/ATen/core/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,40 @@ class Tensor {
return output;
}

/**
* @brief Copies data from a given device to the current tensor object.
*
* This function is designed to copy a given number of elements from a device-specific memory location
* to the memory associated with this object. It ensures that the size of the data being copied does not exceed
* the size of the destination tensor.
*
* @tparam DEVICE The device type from which the data will be copied.
* @tparam T The data type of the elements being copied.
*
* @param data Pointer to the data array in the device memory that needs to be copied.
* @param num_elements The number of elements to copy.
*
* @pre The number of elements to copy (`num_elements`) must be less than or equal to the number of elements
* in the destination tensor (`this->shape_.num_elements()`). If this condition is not met, the function
* will trigger an error through `REQUIRES_OK`.
*
* @note The function uses a template specialization `TEMPLATE_CZ_2` to handle the copying of memory
* based on the data type `T` and the device type `DEVICE`. It utilizes the `kernels::cast_memory`
* method to perform the actual memory copy operation.
*/
template <typename DEVICE, typename T>
void copy_from_device(const T* data, int64_t num_elements = -1) {
if (num_elements == -1) {
num_elements = this->NumElements();
}
REQUIRES_OK(this->shape_.NumElements() >= num_elements,
"The number of elements of the input data must match the number of elements of the tensor.")

TEMPLATE_CZ_2(this->data_type_, this->device_,
kernels::cast_memory<T_, T, DEVICE_, DEVICE>()(
this->data<T_>(), data, num_elements))
}

/**
* @brief Method to transform data from a given tensor object to the output tensor with a given data type
*
Expand Down
14 changes: 14 additions & 0 deletions source/module_base/module_container/ATen/core/tensor_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,20 @@ struct PsiToContainer<psi::DEVICE_GPU> {
using type = container::DEVICE_GPU; /**< The return type specialization for std::complex<double>. */
};

template <typename T>
struct ContainerToPsi {
using type = T; /**< The return type based on the input type. */
};

template <>
struct ContainerToPsi<container::DEVICE_CPU> {
using type = psi::DEVICE_CPU; /**< The return type specialization for std::complex<float>. */
};

template <>
struct ContainerToPsi<container::DEVICE_GPU> {
using type = psi::DEVICE_GPU; /**< The return type specialization for std::complex<double>. */
};


/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include <ATen/core/tensor.h>
#include <ATen/kernels/blas.h>
#include <test/test_utils.h>
#include <base/utils/gtest.h>

namespace container {
namespace kernels {
Expand All @@ -11,14 +11,14 @@ template <typename T>
class BlasTest : public testing::Test {
public:
BlasTest() {
test_utils::init_blas_handle();
base::utils::init_blas_handle();
}
~BlasTest() override {
test_utils::delete_blas_handle();
base::utils::delete_blas_handle();
}
};

TYPED_TEST_SUITE(BlasTest, test_utils::Types);
TYPED_TEST_SUITE(BlasTest, base::utils::Types);

TYPED_TEST(BlasTest, Dot) {
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include <ATen/core/tensor.h>
#include <ATen/kernels/lapack.h>
#include <test/test_utils.h>
#include <base/utils/gtest.h>

namespace container {
namespace kernels {
Expand All @@ -11,16 +11,16 @@ template <typename T>
class LapackTest : public testing::Test {
public:
LapackTest() {
test_utils::init_blas_handle();
test_utils::init_cusolver_handle();
base::utils::init_blas_handle();
base::utils::init_cusolver_handle();
}
~LapackTest() override {
test_utils::delete_blas_handle();
test_utils::delete_cusolver_handle();
base::utils::delete_blas_handle();
base::utils::delete_cusolver_handle();
}
};

TYPED_TEST_SUITE(LapackTest, test_utils::Types);
TYPED_TEST_SUITE(LapackTest, base::utils::Types);

TYPED_TEST(LapackTest, Trtri) {
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include <ATen/core/tensor.h>
#include <ATen/kernels/linalg.h>
#include <test/test_utils.h>
#include <base/utils/gtest.h>

namespace container {
namespace kernels {
Expand All @@ -15,7 +15,7 @@ class LinalgTest : public testing::Test {
~LinalgTest() override = default;
};

TYPED_TEST_SUITE(LinalgTest, test_utils::Types);
TYPED_TEST_SUITE(LinalgTest, base::utils::Types);

TYPED_TEST(LinalgTest, Add) {
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <ATen/core/tensor.h>
#include <ATen/core/tensor_map.h>
#include <ATen/kernels/memory.h>
#include <test/test_utils.h>
#include <base/utils/gtest.h>

namespace container {
namespace kernels {
Expand All @@ -15,7 +15,7 @@ class MemoryTest : public testing::Test {
~MemoryTest() override = default;
};

TYPED_TEST_SUITE(MemoryTest, test_utils::Types);
TYPED_TEST_SUITE(MemoryTest, base::utils::Types);

TYPED_TEST(MemoryTest, ResizeAndSynchronizeMemory) {
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

#include <ATen/core/tensor.h>
#include <ATen/ops/einsum_op.h>
#include <test/test_utils.h>
#include <base/utils/gtest.h>


namespace container {
namespace op {
Expand All @@ -11,16 +12,16 @@ template <typename T>
class EinsumOpTest : public testing::Test {
public:
EinsumOpTest() {
test_utils::init_blas_handle();
test_utils::init_cusolver_handle();
base::utils::init_blas_handle();
base::utils::init_cusolver_handle();
}
~EinsumOpTest() override {
test_utils::delete_blas_handle();
test_utils::delete_cusolver_handle();
base::utils::delete_blas_handle();
base::utils::delete_cusolver_handle();
}
};

TYPED_TEST_SUITE(EinsumOpTest, test_utils::Types);
TYPED_TEST_SUITE(EinsumOpTest, base::utils::Types);

TYPED_TEST(EinsumOpTest, Transform) {
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include <ATen/core/tensor.h>
#include <ATen/ops/linalg_op.h>
#include <test/test_utils.h>
#include <base/utils/gtest.h>

namespace container {
namespace op {
Expand All @@ -14,7 +14,7 @@ class LinalgOpTest : public testing::Test {
~LinalgOpTest() override = default;
};

TYPED_TEST_SUITE(LinalgOpTest, test_utils::Types);
TYPED_TEST_SUITE(LinalgOpTest, base::utils::Types);

TYPED_TEST(LinalgOpTest, Add) {
using Type = typename std::tuple_element<0, decltype(TypeParam())>::type;
Expand Down
57 changes: 57 additions & 0 deletions source/module_base/module_container/base/utils/gtest.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#ifndef BASE_UTILS_GTEST_H_
#define BASE_UTILS_GTEST_H_
#include <gtest/gtest.h>
#include <ATen/kernels/blas.h>
#include <ATen/kernels/lapack.h>

namespace base {
namespace utils {

#if __CUDA || __ROCM
using ComplexTypes = ::testing::Types<
std::tuple<std::complex<float>, ct::DEVICE_CPU>, std::tuple<std::complex<float>, ct::DEVICE_GPU>,
std::tuple<std::complex<double>, ct::DEVICE_CPU>, std::tuple<std::complex<double>, ct::DEVICE_GPU>>;
using Types = ::testing::Types<
std::tuple<float, ct::DEVICE_CPU>, std::tuple<float, ct::DEVICE_GPU>,
std::tuple<double, ct::DEVICE_CPU>, std::tuple<double, ct::DEVICE_GPU>,
std::tuple<std::complex<float>, ct::DEVICE_CPU>, std::tuple<std::complex<float>, ct::DEVICE_GPU>,
std::tuple<std::complex<double>, ct::DEVICE_CPU>, std::tuple<std::complex<double>, ct::DEVICE_GPU>>;
#else
using ComplexTypes = ::testing::Types<
std::tuple<std::complex<float>, ct::DEVICE_CPU>,
std::tuple<std::complex<double>, ct::DEVICE_CPU>>;
using Types = ::testing::Types<
std::tuple<float, ct::DEVICE_CPU>,
std::tuple<double, ct::DEVICE_CPU>,
std::tuple<std::complex<float>, ct::DEVICE_CPU>,
std::tuple<std::complex<double>, ct::DEVICE_CPU>>;
#endif

static inline void init_blas_handle() {
#if __CUDA || __ROCM
ct::kernels::createGpuBlasHandle();
#endif
}

static inline void delete_blas_handle() {
#if __CUDA || __ROCM
ct::kernels::destroyGpuBlasHandle();
#endif
}

static inline void init_cusolver_handle() {
#if __CUDA || __ROCM
ct::kernels::createGpuSolverHandle();
#endif
}

static inline void delete_cusolver_handle() {
#if __CUDA || __ROCM
ct::kernels::destroyGpuSolverHandle();
#endif
}

} // namespace utils
} // namespace base

#endif // BASE_UTILS_GTEST_H_
57 changes: 0 additions & 57 deletions source/module_base/module_container/test/test_utils.h

This file was deleted.

Loading

0 comments on commit 55a819f

Please sign in to comment.