Skip to content

Function of branch cuda11

Asuna981002 edited this page Nov 1, 2021 · 2 revisions

hamilt.cu

Call CG on device

(device: GPU)

(host: CPU)

(cudaMalloc->cudaMemcpyH2D->CG_on_GPU->cudaMemcpyD2H)

{
    // Stage the CG inputs on the GPU, run the double-precision CG solver,
    // then copy the results back to the host.

    // to gpu
    // NOTE(review): the allocations/copies for d_wf_evc and d_wf_ekb are
    // commented out here, so those buffers are presumably allocated and
    // filled elsewhere before this block runs — confirm.
    // CHECK_CUDA(cudaMalloc((void**)&d_wf_evc, GlobalV::NBANDS * DIM_CG_CUDA2 * sizeof(double2)));
    // CHECK_CUDA(cudaMalloc((void**)&d_wf_ekb, GlobalV::NBANDS * sizeof(double)));
    // Device buffer for the preconditioner: DIM_CG_CUDA2 doubles.
    CHECK_CUDA(cudaMalloc((void**)&d_precondition, DIM_CG_CUDA2 * sizeof(double)));

    // CHECK_CUDA(cudaMemcpy(d_wf_evc, GlobalC::wf.evc[ik0].c, GlobalV::NBANDS * DIM_CG_CUDA2 * sizeof(double2), cudaMemcpyHostToDevice));
    // CHECK_CUDA(cudaMemcpy(d_wf_ekb, GlobalC::wf.ekb[ik], GlobalV::NBANDS * sizeof(double), cudaMemcpyHostToDevice));
    // Upload the host preconditioner into the device buffer.
    CHECK_CUDA(cudaMemcpy(d_precondition, precondition, DIM_CG_CUDA2 * sizeof(double), cudaMemcpyHostToDevice));
    // do things
    // Create a 3D double-complex (Z2Z) cuFFT plan of size nx * ny * nz in the
    // shared handle; it is destroyed again right after diag() returns.
    CHECK_CUFFT(cufftPlan3d(&GlobalC::UFFT.fft_handle, GlobalC::pw.nx, GlobalC::pw.ny, GlobalC::pw.nz, CUFFT_Z2Z));

    // Run the CG diagonalization on the device buffers.
    d_cg_cuda.diag(d_wf_evc, d_wf_ekb, d_vkb_c, DIM_CG_CUDA2, DIM_CG_CUDA2,
        GlobalV::NBANDS, d_precondition, GlobalV::ETHR,
        GlobalV::DIAGO_CG_MAXITER, reorder, notconv, avg);
    CHECK_CUFFT(cufftDestroy(GlobalC::UFFT.fft_handle));

    // to cpu
    // Copy NBANDS * DIM_CG_CUDA2 double2 values and NBANDS doubles back into
    // the host-side wf.evc[ik0] and wf.ekb[ik] storage.
    CHECK_CUDA(cudaMemcpy(GlobalC::wf.evc[ik0].c, d_wf_evc, GlobalV::NBANDS * DIM_CG_CUDA2 * sizeof(double2), cudaMemcpyDeviceToHost));
    CHECK_CUDA(cudaMemcpy(GlobalC::wf.ekb[ik], d_wf_ekb, GlobalV::NBANDS * sizeof(double), cudaMemcpyDeviceToHost));

    // CHECK_CUDA(cudaFree(d_wf_evc));
    // CHECK_CUDA(cudaFree(d_wf_ekb));
    // Only the buffer allocated in this block is released here.
    CHECK_CUDA(cudaFree(d_precondition));
}

Hybrid precision conversion interface

double(host) -> double(device) -> float(device) -> Call CG(float) -> float result(device) -> double result(device) -> double result(host)

{
    // Hybrid-precision path: cast the double-precision device inputs to
    // single precision, run the float CG solver, then cast the results back
    // into the double-precision device buffers.
    CHECK_CUDA(cudaMalloc((void**)&f_wf_evc, GlobalV::NBANDS * GlobalC::wf.npwx * sizeof(float2)));
    CHECK_CUDA(cudaMalloc((void**)&f_wf_ekb, GlobalV::NBANDS * sizeof(float)));
    CHECK_CUDA(cudaMalloc((void**)&f_precondition, DIM_CG_CUDA * sizeof(float)));

    // add vkb_c parameter
    CHECK_CUDA(cudaMalloc((void**)&f_vkb_c, GlobalC::wf.npwx*nkb*sizeof(float2)));

    // Launch configuration: "count / thread + 1" always yields at least one
    // block (and one spare block when count is a multiple of thread).
    // NOTE(review): this relies on hamilt_cast_d2f / hamilt_cast_f2d
    // bounds-checking against the size argument — confirm.
    int thread = 512;
    int block = GlobalV::NBANDS * GlobalC::wf.npwx / thread + 1;
    int block2 = GlobalV::NBANDS / thread + 1;
    int block3 = DIM_CG_CUDA / thread + 1;
    int block4 = GlobalC::wf.npwx*nkb / thread + 1;

    // double -> float casts of the CG inputs. Kernel launches do not return
    // a status, so each one is followed by cudaGetLastError() to surface
    // launch-configuration errors immediately.
    hamilt_cast_d2f<<<block, thread>>>(f_wf_evc, d_wf_evc, GlobalV::NBANDS * GlobalC::wf.npwx);
    CHECK_CUDA(cudaGetLastError());
    hamilt_cast_d2f<<<block3, thread>>>(f_precondition, d_precondition, DIM_CG_CUDA);
    CHECK_CUDA(cudaGetLastError());
    // add vkb_c parameter
    hamilt_cast_d2f<<<block4, thread>>>(f_vkb_c, d_vkb_c, GlobalC::wf.npwx*nkb);
    CHECK_CUDA(cudaGetLastError());
    // CHECK_CUFFT(cufftPlan3d(&GlobalC::UFFT.fft_handle, GlobalC::pw.nx, GlobalC::pw.ny, GlobalC::pw.nz, CUFFT_C2C));
    // cout<<"Do float CG ..."<<endl;
    f_cg_cuda.diag(f_wf_evc, f_wf_ekb, f_vkb_c, DIM_CG_CUDA, GlobalC::wf.npwx,
        GlobalV::NBANDS, f_precondition, GlobalV::ETHR,
        GlobalV::DIAGO_CG_MAXITER, reorder, notconv, avg);

    // float -> double casts of the CG results.
    hamilt_cast_f2d<<<block, thread>>>(d_wf_evc, f_wf_evc, GlobalV::NBANDS * GlobalC::wf.npwx);
    CHECK_CUDA(cudaGetLastError());
    hamilt_cast_f2d<<<block2, thread>>>(d_wf_ekb, f_wf_ekb, GlobalV::NBANDS);
    CHECK_CUDA(cudaGetLastError());

    CHECK_CUDA(cudaFree(f_vkb_c));
    // NOTE(review): f_wf_evc, f_wf_ekb and f_precondition are allocated above
    // but not freed in this block — confirm they are released elsewhere,
    // otherwise they leak on every call.
}

Initialize CUFFT handle

How to call a CUFFT operation:

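  1. cufftHandle handle;
  2. cufftCreate(&handle);
  3. cufftMakePlan1d(handle, ...);  (alternatively cufftPlan1d(&handle, ...), which creates the handle itself, so step 2 is not needed)
  4. cufftExecZ2Z(handle, ...);  (or the cufftExec* variant that matches the transform type)
  5. cufftDestroy(handle);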

Steps 3 and 5 (creating and destroying the plan) are the most time-consuming, so the handle is stored as a member (fft_handle) of class Use_FFT and reused across several files.

// in hamilt.cu
{
    // Create the 3D double-complex (Z2Z) plan in the shared Use_FFT handle
    // before the solver runs, so code called from diag() can use
    // GlobalC::UFFT.fft_handle without creating its own plan.
    CHECK_CUFFT(cufftPlan3d(&GlobalC::UFFT.fft_handle, GlobalC::pw.nx, GlobalC::pw.ny, GlobalC::pw.nz, CUFFT_Z2Z));
    d_cg_cuda.diag(d_wf_evc, d_wf_ekb, d_vkb_c, DIM_CG_CUDA2, DIM_CG_CUDA2,
        GlobalV::NBANDS, d_precondition, GlobalV::ETHR,
        GlobalV::DIAGO_CG_MAXITER, reorder, notconv, avg);
    // Release the plan once the whole diagonalization has finished.
    CHECK_CUFFT(cufftDestroy(GlobalC::UFFT.fft_handle));
}
// in use_fft.h
// Members below are only compiled when __CUDA is defined.
#ifdef __CUDA
	// cuFFT plan handle, stored as a member so the same plan can be reused.
	cufftHandle fft_handle;
	// Single-precision overload: forwards all four arguments unchanged to
	// RoundTrip_kernel.
	void RoundTrip(const float2 *psi, const float *vr, const int *fft_index, float2 *psic)
	{
		RoundTrip_kernel(psi, vr, fft_index, psic);
	}
	// Double-precision overload: forwards all four arguments unchanged to
	// RoundTrip_kernel.
	void RoundTrip(const double2 *psi, const double *vr, const int *fft_index, double2 *psic)
	{
		RoundTrip_kernel(psi, vr, fft_index, psic);
	}
#endif

diagH_subspace_cuda

  1. hm.diagH_subspace_cuda
  2. hpw.diagH_subspace_cuda
  3. hm.diagH_LAPACK / hm.diagH_CUSOLVER (in hamilt.cu)
// in hamilt_pw.cu
// Method1 : Do with diagH_LAPACK
// Host-side staging matrices for the LAPACK path: two nstart x nstart inputs
// and one nstart x n_band output.
ModuleBase::ComplexMatrix h_hc(nstart, nstart);
ModuleBase::ComplexMatrix h_sc(nstart, nstart);
ModuleBase::ComplexMatrix h_hvec(nstart,n_band);

// Host buffer of n_band doubles; released with delete[] below.
double *h_en = new double[n_band];

// Download hc and sc from the device into the host staging matrices.
CHECK_CUDA(cudaMemcpy(h_hc.c, hc, nstart*nstart*sizeof(double2), cudaMemcpyDeviceToHost));
CHECK_CUDA(cudaMemcpy(h_sc.c, sc, nstart*nstart*sizeof(double2), cudaMemcpyDeviceToHost));

// Solve on the host, then upload the outputs (h_hvec, h_en) back into the
// device buffers hvec and en.
GlobalC::hm.diagH_LAPACK(nstart, n_band, h_hc, h_sc, nstart, h_en, h_hvec);
CHECK_CUDA(cudaMemcpy(hvec, h_hvec.c, nstart*n_band*sizeof(double2), cudaMemcpyHostToDevice));
CHECK_CUDA(cudaMemcpy(en, h_en, n_band*sizeof(double), cudaMemcpyHostToDevice));
delete [] h_en;

// Method2 : Do with diagH_CUSOLVER
// Operates directly on the device pointers, so no host staging is involved.
// GlobalC::hm.diagH_CUSOLVER(nstart, n_band, hc, sc, nstart, en, hvec);

diagH_CUSOLVER (diagH_LAPACK's alternative)

// How to call a cuSOLVER (dense) API
cusolverDnHandle_t handle;
cusolverDnCreate(&handle);
int lwork;
cusolverDnZhegvd_bufferSize(handle, ..., &lwork);
cudaMalloc(&buffer, lwork * sizeof(cuDoubleComplex));
cusolverDnZhegvd(handle, ..., buffer, lwork, devInfo);
cudaFree(buffer);
cusolverDnDestroy(handle);
// in diagH_CUSOLVER
// Step 1: query the workspace size (in double2 elements) that Zhegvd needs
// for this problem; the result is written to cusolver_lwork.
int cusolver_lwork = 0;
CHECK_CUSOLVER(cusolverDnZhegvd_bufferSize(
    cusolver_handle,
    // TODO : handle
    CUSOLVER_EIG_TYPE_1,
    CUSOLVER_EIG_MODE_VECTOR,
    CUBLAS_FILL_MODE_UPPER,
    nstart,
    hvec,
    ldh,
    sdum,
    ldh,
    e,
    &cusolver_lwork));
// cout<<"work_space: "<<cusolver_lwork<<endl;
// Step 2: allocate the device workspace of cusolver_lwork double2 elements.
double2 *cusolver_work;
CHECK_CUDA(cudaMalloc((void**)&cusolver_work, cusolver_lwork*sizeof(double2)));
// Step 3: run the generalized Hermitian eigensolver with the same arguments
// used for the size query, plus the workspace and the device_info pointer.
// NOTE(review): device_info is not inspected in this excerpt — confirm it
// is checked elsewhere.
CHECK_CUSOLVER(cusolverDnZhegvd(
    cusolver_handle,
    CUSOLVER_EIG_TYPE_1,
    CUSOLVER_EIG_MODE_VECTOR,
    CUBLAS_FILL_MODE_UPPER,
    nstart,
    hvec,
    ldh,
    sdum,
    ldh,
    e,
    cusolver_work,
    cusolver_lwork,
    device_info));
// Step 4: release the temporary workspace.
CHECK_CUDA(cudaFree(cusolver_work));

hamilt_pw.cu

GR_index & d_GR_index

  1. Get GR_index on CPU;
  2. MemcpyH2D GR_index to d_GR_index;
  3. Use d_GR_index to do the FFT on the GPU.
// (6) The index of plane waves.
// int *GR_index_tmp = new int[GlobalC::pw.nrxx];
// For each of the npw plane waves of k-point ik, look up its entry in
// pw.ig2fftw and store it in GR_index[ig].
for (int ig = 0;ig < GlobalC::wf.npw;ig++)
{
    GR_index[ig] = GlobalC::pw.ig2fftw[ GlobalC::wf.igk(ik, ig) ];
}
#ifdef __CUDA
    // Upload the index table to its device copy (GR_index_d).
    // NOTE(review): the loop above writes npw entries, but nrxx ints are
    // copied here — confirm GR_index holds at least nrxx ints.
    CHECK_CUDA(cudaMemcpy(this->GR_index_d, GR_index, GlobalC::pw.nrxx*sizeof(int), cudaMemcpyHostToDevice));
#endif

hpw_handle

cuBLAS APIs are used in hamilt_pw for matrix operations, and every cuBLAS call requires a cublasHandle_t (stored here as hpw_handle).

{
    // Build hc and sc with cuBLAS: a gemm into the scratch matrix tmp_hc,
    // followed by a geam that writes the transpose into the destination.
    cublasOperation_t trans1 = CUBLAS_OP_C;
	cublasOperation_t trans2 = CUBLAS_OP_N;
	// tmp_hc (nstart x nstart) = ONE * psi_c^H * aux + ZERO * tmp_hc, with
	// inner dimension dmin and leading dimension dmax for both inputs.
	CHECK_CUBLAS(cublasZgemm(hpw_handle, trans1, trans2, nstart, nstart, dmin, &ONE, psi_c, dmax, aux, dmax, &ZERO, tmp_hc, nstart));
	// hc=transpose(hc,false); // TODO: transpose
	
	// use 'geam' API to do transpose.
	// With t_alpha = 1 and t_beta = 0, geam computes
	// hc = 1 * tmp_hc^T + 0 * tmp_hc^T, i.e. an out-of-place transpose.
	double2 t_alpha, t_beta;
	t_alpha.y = t_beta.x = t_beta.y = 0.0;
	t_alpha.x = 1.0;
	CHECK_CUBLAS(cublasZgeam(hpw_handle, CUBLAS_OP_T, CUBLAS_OP_T, nstart, nstart, &t_alpha, tmp_hc, nstart, &t_beta, tmp_hc, nstart, hc, nstart));

	// Same two steps for sc: tmp_hc = psi_c^H * psi_c, then sc = tmp_hc^T.
	CHECK_CUBLAS(cublasZgemm(hpw_handle, trans1, trans2, nstart, nstart, dmin, &ONE, psi_c, dmax, psi_c, dmax, &ZERO, tmp_hc, nstart));
	// sc=transpose(sc,false); // TODO: transpose
	CHECK_CUBLAS(cublasZgeam(hpw_handle, CUBLAS_OP_T, CUBLAS_OP_T, nstart, nstart, &t_alpha, tmp_hc, nstart, &t_beta, tmp_hc, nstart, sc, nstart));
}

diagH_subspace_cuda

  • Method1: Do with diagH_LAPACK
  • Method2: Do with diagH_CUSOLVER

s_psi_cuda

A simple memory copy. Could be used in atomic orbital algorithms in the future.

// Copies `dim` float2 elements from `psi` to `spsi` with a device-to-device
// cudaMemcpy; no other transformation is applied.
//   dim  : number of float2 elements to copy
//   psi  : source buffer
//   spsi : destination buffer
void Hamilt_PW::s_1psi_cuda(const int dim, const float2 *psi, float2 *spsi)
{
    const size_t n_bytes = dim * sizeof(float2);
    CHECK_CUDA(cudaMemcpy(spsi, psi, n_bytes, cudaMemcpyDeviceToDevice));
}

hpsi_cuda (calculate H*psi)

Part-I Kinetic

  1. Memory copy *g2kin H2D;
  2. Calculate with cuda kernel.
// Kinetic part of H*psi: upload g2kin to the device, then for each of the m
// bands launch kernel_get_tmhpsi over the first npw coefficients.
// The buffer is sized with npwx but only the npw entries in use are copied.
double* d_g2kin;
CHECK_CUDA(cudaMalloc((void**)&d_g2kin, GlobalC::wf.npwx*sizeof(double)));
CHECK_CUDA(cudaMemcpy(d_g2kin, GlobalC::wf.g2kin, GlobalC::wf.npw*sizeof(double), cudaMemcpyHostToDevice));

// The launch configuration does not depend on the band index, so it is
// computed once outside the loop.
const int kin_threads = 512;
const int kin_blocks = (GlobalC::wf.npw + kin_threads - 1) / kin_threads;
for(int ib = 0 ; ib < m; ++ib)
{
    // cout<<"in hpsi-Kinetic, iband = "<<ib<<endl;

    // A zero-block grid is an invalid launch configuration, so skip the
    // launch when there are no plane waves to process.
    if(kin_blocks > 0)
    {
        kernel_get_tmhpsi<double, double2><<<kin_blocks, kin_threads>>>(GlobalC::wf.npw, tmhpsi, tmpsi_in, d_g2kin);
        // Kernel launches return no status; surface launch errors here.
        CHECK_CUDA(cudaGetLastError());
    }
    // Bands are stored with a stride of npwx elements.
    tmhpsi += GlobalC::wf.npwx;
    tmpsi_in += GlobalC::wf.npwx;
}
CHECK_CUDA(cudaFree(d_g2kin));
// Element-wise scaling kernel: for every i < size,
//   dst[i].x = src[i].x * g2kin[i]
//   dst[i].y = src[i].y * g2kin[i]
// Each thread handles the single element at its flat 1D global index;
// threads whose index is >= size do nothing.
template<class T, class T2>
__global__ void kernel_get_tmhpsi(int size, T2 *dst, const T2 *src, T *g2kin)
{
    const int i = blockDim.x * blockIdx.x + threadIdx.x;
    if(i >= size)
    {
        return;
    }
    dst[i].x = src[i].x * g2kin[i];
    dst[i].y = src[i].y * g2kin[i];
}

Part-II Vloc

  1. Memory copy vr_eff1 H2D;
  2. UFFT.RoundTrip(psi, d_GR_index, d_vr_eff);
    • FFT_forward(psi);
    • psi *= d_vr_eff;
    • FFT_backward(psi);
    • Normalization(psi);
// Local-potential part of H*psi: kernel_set over npw coefficients using
// fft_index, an in-place CUFFT_INVERSE transform of psic, kernel_roundtrip
// with vr over nrxx points, an in-place CUFFT_FORWARD transform, and finally
// kernel_normalization with nrxx as the factor.
int thread = 512;
int block = (GlobalC::wf.npw + thread - 1) / thread;
int block2 = (GlobalC::pw.nrxx + thread - 1) / thread;
// A zero-block grid is an invalid launch configuration, so skip the launch
// when there are no plane waves to place.
if(block > 0)
{
    kernel_set<double2><<<block, thread>>>(GlobalC::wf.npw, psic, psi, fft_index);
    // Kernel launches return no status; surface launch errors here.
    CHECK_CUDA(cudaGetLastError());
}

CHECK_CUFFT(cufftExecZ2Z(GlobalC::UFFT.fft_handle, psic, psic, CUFFT_INVERSE));
kernel_roundtrip<double, double2><<<block2, thread>>>(GlobalC::pw.nrxx, psic, vr);
CHECK_CUDA(cudaGetLastError());

CHECK_CUFFT(cufftExecZ2Z(GlobalC::UFFT.fft_handle, psic, psic, CUFFT_FORWARD));
// The synchronization from the original code is kept, but its return value
// is now checked so asynchronous execution errors are not silently dropped.
CHECK_CUDA(cudaDeviceSynchronize());

int block3 = (GlobalC::pw.nrxx + thread - 1) / thread;
kernel_normalization<double, double2><<<block3, thread>>>(GlobalC::pw.nrxx, psic, (double)(GlobalC::pw.nrxx));
CHECK_CUDA(cudaGetLastError());

Part-III Vnl

  • Gemm + Gemv (with CUBLAS)
  • Add nonlocal pp (Add vkb_device as a parameter to give better performance.)
// in class Hamilt_PW
// Double-precision H*psi entry point (presumably the hpsi_cuda routine
// described on this page — confirm).
//   psi   : input wavefunction(s), read-only
//   hpsi  : output buffer
//   vkb_c : vkb buffer passed in by the caller; presumably a device pointer
//           — confirm
//   m     : defaults to 1; presumably the number of bands — confirm
void h_psi_cuda(
    const double2 *psi,
    double2 *hpsi,
    double2 *vkb_c,
    const int m = 1);

// Nonlocal pseudopotential contribution.
//   hpsi_in : buffer the contribution is applied to
//   becp    : read-only input
//   d_vkb_c : read-only vkb buffer (device pointer, judging by the d_ prefix
//             — confirm)
//   m       : presumably the number of bands — confirm
void add_nonlocal_pp_cuda(
    double2 *hpsi_in,
    const double2 *becp,
    const double2 *d_vkb_c,
    const int m);
Clone this wiki locally