diff --git a/pyblock2/driver/core.py b/pyblock2/driver/core.py index 969646ab..ed98e1fa 100644 --- a/pyblock2/driver/core.py +++ b/pyblock2/driver/core.py @@ -2836,6 +2836,8 @@ def dmrg( store_wfn_spectra=True, spectra_with_multiplicity=False, lowmem_noise=False, + sweep_start=0, + forward=None, ): bw = self.bw if bond_dims is None: @@ -2893,15 +2895,18 @@ def dmrg( if n_sweeps == -1: return None me.init_environments(iprint >= 2) + if forward is None: + forward = ket.center == 0 if twosite_to_onesite is None: - ener = dmrg.solve(n_sweeps, ket.center == 0, tol) + ener = dmrg.solve(n_sweeps, forward, tol, sweep_start) else: assert twosite_to_onesite < n_sweeps - ener = dmrg.solve(twosite_to_onesite, ket.center == 0, 0) - dmrg.me.dot = 1 - for ext_me in dmrg.ext_mes: - ext_me.dot = 1 - ener = dmrg.solve(n_sweeps, ket.center == 0, tol, twosite_to_onesite) + if sweep_start < twosite_to_onesite: + ener = dmrg.solve(twosite_to_onesite, forward, 0, sweep_start) + dmrg.me.dot = 1 + for ext_me in dmrg.ext_mes: + ext_me.dot = 1 + ener = dmrg.solve(n_sweeps, forward, tol, twosite_to_onesite) ket.dot = 1 if self.mpi is not None: self.mpi.barrier() @@ -3788,6 +3793,7 @@ def align_mps_center(self, ket, ref): if self.mpi is not None: self.mpi.barrier() + # if restarting from the middle, this method should not be used def adjust_mps(self, ket, dot=None): if dot is None: dot = ket.dot diff --git a/src/dmrg/sweep_algorithm.hpp b/src/dmrg/sweep_algorithm.hpp index 33587c99..8594fa15 100644 --- a/src/dmrg/sweep_algorithm.hpp +++ b/src/dmrg/sweep_algorithm.hpp @@ -113,6 +113,7 @@ template struct DMRG { int davidson_def_max_size = 50; double tprt = 0, teig = 0, teff = 0, tmve = 0, tblk = 0, tdm = 0, tsplt = 0, tsvd = 0, torth = 0; + double accumulated_elapsed_time = 0; bool print_connection_time = false; // store all wfn singular values (for analysis) at each site bool store_wfn_spectra = false; @@ -2332,7 +2333,7 @@ template struct DMRG { Timer start, current; start.get_time(); current.get_time(); - energies.resize(sweep_start); + energies.resize(sweep_start, vector(1, (FPLS)(FPS)0.0)); discarded_weights.resize(sweep_start); mps_quanta.resize(sweep_start); bool converged; @@ -2370,7 +2371,9 @@ template struct DMRG { double tswp = current.get_time(); if (iprint >= 1) { cout << "Time elapsed = " << fixed << setw(10) - << setprecision(3) << current.current - start.current; + << setprecision(3) + << current.current - start.current + + accumulated_elapsed_time; cout << fixed << setprecision(10); if (get<0>(sweep_results).size() == 1) cout << " | E = " << setw(18) << get<0>(sweep_results)[0]; @@ -2488,6 +2491,7 @@ template struct DMRG { break; } this->forward = forward; + accumulated_elapsed_time += current.current - start.current; if (!converged && iprint > 0 && tol != 0) cout << "ATTENTION: DMRG is not converged to desired tolerance of " << scientific << tol << endl;