Skip to content

Commit

Permalink
restarting
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Dec 18, 2023
1 parent ff8bff1 commit a823c32
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
18 changes: 12 additions & 6 deletions pyblock2/driver/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2836,6 +2836,8 @@ def dmrg(
store_wfn_spectra=True,
spectra_with_multiplicity=False,
lowmem_noise=False,
sweep_start=0,
forward=None,
):
bw = self.bw
if bond_dims is None:
Expand Down Expand Up @@ -2893,15 +2895,18 @@ def dmrg(
if n_sweeps == -1:
return None
me.init_environments(iprint >= 2)
if forward is None:
forward = ket.center == 0
if twosite_to_onesite is None:
ener = dmrg.solve(n_sweeps, ket.center == 0, tol)
ener = dmrg.solve(n_sweeps, forward, tol, sweep_start)
else:
assert twosite_to_onesite < n_sweeps
ener = dmrg.solve(twosite_to_onesite, ket.center == 0, 0)
dmrg.me.dot = 1
for ext_me in dmrg.ext_mes:
ext_me.dot = 1
ener = dmrg.solve(n_sweeps, ket.center == 0, tol, twosite_to_onesite)
if sweep_start < twosite_to_onesite:
ener = dmrg.solve(twosite_to_onesite, forward, 0, sweep_start)
dmrg.me.dot = 1
for ext_me in dmrg.ext_mes:
ext_me.dot = 1
ener = dmrg.solve(n_sweeps, forward, tol, twosite_to_onesite)
ket.dot = 1
if self.mpi is not None:
self.mpi.barrier()
Expand Down Expand Up @@ -3788,6 +3793,7 @@ def align_mps_center(self, ket, ref):
if self.mpi is not None:
self.mpi.barrier()

# if restarting from the middle, this method should not be used
def adjust_mps(self, ket, dot=None):
if dot is None:
dot = ket.dot
Expand Down
8 changes: 6 additions & 2 deletions src/dmrg/sweep_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ template <typename S, typename FL, typename FLS> struct DMRG {
int davidson_def_max_size = 50;
double tprt = 0, teig = 0, teff = 0, tmve = 0, tblk = 0, tdm = 0, tsplt = 0,
tsvd = 0, torth = 0;
double accumulated_elapsed_time = 0;
bool print_connection_time = false;
// store all wfn singular values (for analysis) at each site
bool store_wfn_spectra = false;
Expand Down Expand Up @@ -2332,7 +2333,7 @@ template <typename S, typename FL, typename FLS> struct DMRG {
Timer start, current;
start.get_time();
current.get_time();
energies.resize(sweep_start);
energies.resize(sweep_start, vector<FPLS>(1, (FPLS)(FPS)0.0));
discarded_weights.resize(sweep_start);
mps_quanta.resize(sweep_start);
bool converged;
Expand Down Expand Up @@ -2370,7 +2371,9 @@ template <typename S, typename FL, typename FLS> struct DMRG {
double tswp = current.get_time();
if (iprint >= 1) {
cout << "Time elapsed = " << fixed << setw(10)
<< setprecision(3) << current.current - start.current;
<< setprecision(3)
<< current.current - start.current +
accumulated_elapsed_time;
cout << fixed << setprecision(10);
if (get<0>(sweep_results).size() == 1)
cout << " | E = " << setw(18) << get<0>(sweep_results)[0];
Expand Down Expand Up @@ -2488,6 +2491,7 @@ template <typename S, typename FL, typename FLS> struct DMRG {
break;
}
this->forward = forward;
accumulated_elapsed_time += current.current - start.current;
if (!converged && iprint > 0 && tol != 0)
cout << "ATTENTION: DMRG is not converged to desired tolerance of "
<< scientific << tol << endl;
Expand Down

0 comments on commit a823c32

Please sign in to comment.