Skip to content

Commit

Permalink
reduce peak device memory usage of psi's constructor (#4154)
Browse files Browse the repository at this point in the history
  • Loading branch information
denghuilu authored May 13, 2024
1 parent b7e91aa commit 9abeccc
Showing 1 changed file with 41 additions and 5 deletions.
46 changes: 41 additions & 5 deletions source/module_psi/psi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "module_base/global_variable.h"
#include "module_base/tool_quit.h"
#include "module_psi/kernels/device.h"
#include <type_traits>

#include <cassert>
#include <complex>
Expand Down Expand Up @@ -163,11 +164,46 @@ Psi<T, Device>::Psi(const Psi<T_in, Device_in>& psi_in)
// This function copies psi_in.psi to this->psi regardless of the device types of the two objects.
this->device = device::get_device_type<Device>(this->ctx);
this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis());
memory::cast_memory_op<T, T_in, Device, Device_in>()(this->ctx,
psi_in.get_device(),
this->psi,
psi_in.get_pointer() - psi_in.get_psi_bias(),
psi_in.size());
// No need to cast the memory if the data types are the same.
if (std::is_same<T, T_in>::value)
{
memory::synchronize_memory_op<T, Device, Device_in>()(this->ctx,
psi_in.get_device(),
this->psi,
reinterpret_cast<T*>(psi_in.get_pointer()) - psi_in.get_psi_bias(),
psi_in.size());
}
// Special case: the Device_in type is CPU and the Device type is GPU,
// which means we need to initialize a GPU psi from a given CPU psi.
// We first allocate a temporary buffer on the CPU, then cast the data from T_in to T on the CPU.
// Finally, we synchronize the buffer from the CPU to the GPU.
// This helps to reduce the peak memory usage of the device.
else if (std::is_same<Device, DEVICE_GPU>::value &&
std::is_same<Device_in, DEVICE_CPU>::value)
{
auto * arr = (T*) malloc(sizeof(T) * psi_in.size());
// cast the memory from T_in to T in CPU
memory::cast_memory_op<T, T_in, Device_in, Device_in>()(psi_in.get_device(),
psi_in.get_device(),
arr,
psi_in.get_pointer() - psi_in.get_psi_bias(),
psi_in.size());
// synchronize the memory from CPU to GPU
memory::synchronize_memory_op<T, Device, Device_in>()(this->ctx,
psi_in.get_device(),
this->psi,
arr,
psi_in.size());
free(arr);
}
else
{
memory::cast_memory_op<T, T_in, Device, Device_in>()(this->ctx,
psi_in.get_device(),
this->psi,
psi_in.get_pointer() - psi_in.get_psi_bias(),
psi_in.size());
}
this->psi_bias = psi_in.get_psi_bias();
this->current_nbasis = psi_in.get_current_nbas();
this->psi_current = this->psi + psi_in.get_psi_bias();
Expand Down

0 comments on commit 9abeccc

Please sign in to comment.