Skip to content

Commit

Permalink
context multi two dot
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Jun 26, 2024
1 parent d31c2bd commit d081c81
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 67 deletions.
11 changes: 9 additions & 2 deletions src/core/iterative_matrix_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ template <typename FL> struct IterativeMatrixFunctions : GMatrixFunctions<FL> {
shared_ptr<VectorAllocator<FP>> x_alloc =
make_shared<VectorAllocator<FP>>();
int k = (int)vs.size(), nor = (int)ors.size(), nwg = 0;
int orig_k = k;
assert(!(davidson_type & DavidsonTypes::Exact));
assert(!(davidson_type & DavidsonTypes::NonHermitian));
// if proj_weights is empty or ElementProj, then projection is done by
Expand Down Expand Up @@ -579,7 +580,12 @@ template <typename FL> struct IterativeMatrixFunctions : GMatrixFunctions<FL> {
<< " for Davidson unitary to all given states (you are "
"possibly targeting a global symmetry sector with no "
"states or MPS has zero norm)!";
throw runtime_error(ss.str());
ss << "Space size = " << bs[i].size();
if (i > 0) {
k = i;
break;
} else
throw runtime_error(ss.str());
}
iscale(bs[i], (FP)1.0 / normx);
}
Expand All @@ -595,6 +601,7 @@ template <typename FL> struct IterativeMatrixFunctions : GMatrixFunctions<FL> {
<< " for Davidson unitary to all given states (you are "
"possibly targeting a global symmetry sector with no "
"states or MPS has zero norm)!";
ss << "Space size = " << bs[i].size();
throw runtime_error(ss.str());
}
iscale(bs[i], (FP)1.0 / normx);
Expand Down Expand Up @@ -830,7 +837,7 @@ template <typename FL> struct IterativeMatrixFunctions : GMatrixFunctions<FL> {
break;
}
if (xiter == soft_max_iter)
eigvals.resize(k, 0);
eigvals.resize(orig_k, 0);
if (xiter == max_iter) {
cout << "Error : only " << ck << " converged!" << endl;
assert(false);
Expand Down
22 changes: 16 additions & 6 deletions src/dmrg/effective_hamiltonian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,7 @@ struct EffectiveHamiltonian<S, FL, MultiMPS<S, FL>> {
shared_ptr<SparseMatrixGroup<S, FL>> diag;
vector<shared_ptr<SparseMatrixGroup<S, FL>>> bra, ket;
shared_ptr<SparseMatrixGroup<S, FL>> cmat, vmat;
vector<shared_ptr<SparseMatrixGroup<S, FL>>> context_mask;
shared_ptr<TensorFunctions<S, FL>> tf;
shared_ptr<SymbolicColumnVector<S>> hop_mat;
// Delta quantum of effective H
Expand Down Expand Up @@ -1584,6 +1585,9 @@ struct EffectiveHamiltonian<S, FL, MultiMPS<S, FL>> {
int ndav = 0;
assert(compute_diag);
GDiagonalMatrix<FL> aa(diag->data, (MKL_INT)diag->total_memory);
GMatrix<FL> cmask(nullptr, (MKL_INT)ket[0]->total_memory, 1);
if (this->context_mask.size() != 0)
cmask.data = this->context_mask[0]->data;
vector<GMatrix<FL>> bs;
for (int i = 0; i < (int)min((MKL_INT)ket.size(), (MKL_INT)aa.n); i++)
bs.push_back(
Expand All @@ -1604,12 +1608,15 @@ struct EffectiveHamiltonian<S, FL, MultiMPS<S, FL>> {
tf->opf->seq->cumulative_nflop = 0;
precompute();
const function<void(const GMatrix<FL> &, const GMatrix<FL> &)> &f =
[this](const GMatrix<FL> &a, const GMatrix<FL> &b) {
[this, &cmask](const GMatrix<FL> &a, const GMatrix<FL> &b) {
if (this->tf->opf->seq->mode == SeqTypes::Auto ||
(this->tf->opf->seq->mode & SeqTypes::Tasked))
return this->tf->operator()(a, b, (FL)1.0);
this->tf->operator()(a, b, (FL)1.0);
else
return (*this)(a, b, 0, (FL)1.0);
(*this)(a, b, 0, (FL)1.0);
if (cmask.data != nullptr)
GMatrixFunctions<FL>::elementwise("*", (FL)1.0, cmask,
(FL)1.0, b, b, (FL)0.0);
};
vector<FP> xeners;
if (metric == nullptr)
Expand All @@ -1621,12 +1628,15 @@ struct EffectiveHamiltonian<S, FL, MultiMPS<S, FL>> {
else {
metric->precompute();
const function<void(const GMatrix<FL> &, const GMatrix<FL> &)> &mf =
[metric](const GMatrix<FL> &a, const GMatrix<FL> &b) {
[metric, &cmask](const GMatrix<FL> &a, const GMatrix<FL> &b) {
if (metric->tf->opf->seq->mode == SeqTypes::Auto ||
(metric->tf->opf->seq->mode & SeqTypes::Tasked))
return metric->tf->operator()(a, b, (FL)1.0);
metric->tf->operator()(a, b, (FL)1.0);
else
return (*metric)(a, b, 0, (FL)1.0);
(*metric)(a, b, 0, (FL)1.0);
if (cmask.data != nullptr)
GMatrixFunctions<FL>::elementwise(
"*", (FL)1.0, cmask, (FL)1.0, b, b, (FL)0.0);
};
xeners = IterativeMatrixFunctions<FL>::davidson_generalized(
f, mf, aa, bs, shift, davidson_type, ndav, iprint,
Expand Down
147 changes: 102 additions & 45 deletions src/dmrg/moving_environment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2038,6 +2038,12 @@ template <typename S, typename FL, typename FLS> struct MovingEnvironment {
bool infer_info,
shared_ptr<SparseMatrix<S, FLS>> ket = nullptr,
shared_ptr<SparseMatrix<S, FLS>> cket = nullptr) {
if (mps->get_type() & MPSTypes::MultiWfn)
mps->info->target =
dynamic_pointer_cast<MultiMPSInfo<S>>(mps->info)->targets[0];
if (cmps->get_type() & MPSTypes::MultiWfn)
cmps->info->target =
dynamic_pointer_cast<MultiMPSInfo<S>>(cmps->info)->targets[0];
return symm_context_convert_impl(
i, mps->info, cmps->info, dot, fuse_left, mask, forward,
is_wfn, infer_info, false,
Expand All @@ -2049,16 +2055,40 @@ template <typename S, typename FL, typename FLS> struct MovingEnvironment {
nullptr, nullptr)
.first;
}
static vector<shared_ptr<SparseMatrixGroup<S, FLS>>>
symm_context_convert_group(int i, const shared_ptr<MultiMPS<S, FLS>> &mps,
const shared_ptr<MultiMPS<S, FLS>> &cmps,
int dot, bool fuse_left, bool mask, bool forward,
bool is_wfn, bool infer_info) {
const size_t nw =
mask ? 1 : (forward ? mps->wfns.size() : cmps->wfns.size());
vector<shared_ptr<SparseMatrixGroup<S, FLS>>> rwfns(nw);
for (int iw = 0; iw < (int)nw; iw++) {
rwfns[iw] =
symm_context_convert_impl(
i, mps->info, cmps->info, dot, fuse_left, mask, forward,
is_wfn, infer_info, false, nullptr, nullptr,
!(!forward && infer_info) ? mps->wfns[iw] : nullptr,
!(forward && infer_info) ? cmps->wfns[iw] : nullptr)
.second;
}
return rwfns;
}
static shared_ptr<SparseMatrixGroup<S, FLS>>
symm_context_convert_perturbative(
int i, const shared_ptr<MPS<S, FLS>> &mps,
const shared_ptr<MPS<S, FLS>> &cmps, int dot, bool fuse_left, bool mask,
bool forward, bool is_wfn, bool infer_info,
const shared_ptr<SparseMatrixGroup<S, FLS>> &pket) {
return symm_context_convert_impl(i, mps->info, cmps->info, dot,
fuse_left, mask, forward, is_wfn,
infer_info, true, mps->tensors[i],
cmps->tensors[i], pket, nullptr)
if (mps->get_type() & MPSTypes::MultiWfn)
mps->info->target =
dynamic_pointer_cast<MultiMPSInfo<S>>(mps->info)->targets[0];
if (cmps->get_type() & MPSTypes::MultiWfn)
cmps->info->target =
dynamic_pointer_cast<MultiMPSInfo<S>>(cmps->info)->targets[0];
return symm_context_convert_impl(
i, mps->info, cmps->info, dot, fuse_left, mask, forward,
is_wfn, infer_info, true, nullptr, nullptr, pket, nullptr)
.second;
}
// forward = proj to high symmetry
Expand Down Expand Up @@ -2164,13 +2194,9 @@ template <typename S, typename FL, typename FLS> struct MovingEnvironment {
}
gr_wfn->allocate(infos);
gr_wfn->clear();
} else if (!is_group && infer_info) {
shared_ptr<SparseMatrixInfo<S>> xinfo =
make_shared<SparseMatrixInfo<S>>(i_alloc);
} else if (infer_info) {
shared_ptr<StateInfo<S>> xll = forward ? llu : ll;
shared_ptr<StateInfo<S>> xrr = forward ? rru : rr;
S xdq = is_wfn ? (forward ? cinfo->target : info->target)
: (forward ? cinfo->vacuum : info->vacuum);
if (fuse_left) {
xrr = forward ? TransStateInfo<S, S>::forward(rr, refu)
: TransStateInfo<S, S>::forward(rru, ref);
Expand All @@ -2186,49 +2212,40 @@ template <typename S, typename FL, typename FLS> struct MovingEnvironment {
else
ll = xll, l = *xll;
}
xinfo->initialize(*xll, *xrr, xdq, false, is_wfn);
r_wfn->allocate(xinfo);
r_wfn->clear();
if (is_group) {
int nxw = forward ? pket->n : cpket->n;
vector<shared_ptr<SparseMatrixInfo<S>>> infos(nxw);
for (int iw = 0; iw < nxw; iw++) {
S xdq =
is_wfn
? (forward
? dynamic_pointer_cast<MultiMPSInfo<S>>(
cinfo)
->targets[iw]
: dynamic_pointer_cast<MultiMPSInfo<S>>(info)
->targets[iw])
: (forward ? cinfo->vacuum : info->vacuum);
infos[iw] = make_shared<SparseMatrixInfo<S>>(i_alloc);
infos[iw]->initialize(*xll, *xrr, xdq, false, is_wfn);
}
gr_wfn->allocate(infos);
gr_wfn->clear();
} else {
S xdq = is_wfn ? (forward ? cinfo->target : info->target)
: (forward ? cinfo->vacuum : info->vacuum);
shared_ptr<SparseMatrixInfo<S>> xinfo =
make_shared<SparseMatrixInfo<S>>(i_alloc);
xinfo->initialize(*xll, *xrr, xdq, false, is_wfn);
r_wfn->allocate(xinfo);
r_wfn->clear();
}
} else if (is_group) {
gr_wfn->allocate(forward ? cpket->infos : pket->infos);
gr_wfn->clear();
} else {
r_wfn->allocate(forward ? cket->info : ket->info);
r_wfn->clear();
}
S cptu = cinfo->target, cpt = info->target;
shared_ptr<StateInfo<S>> cplu =
is_wfn || fuse_left ? make_shared<StateInfo<S>>(lu)
: make_shared<StateInfo<S>>(
StateInfo<S>::complementary(lu, cptu));
shared_ptr<StateInfo<S>> cpl =
is_wfn || fuse_left ? make_shared<StateInfo<S>>(l)
: make_shared<StateInfo<S>>(
StateInfo<S>::complementary(l, cpt));
shared_ptr<StateInfo<S>> cpru =
is_wfn || !fuse_left ? make_shared<StateInfo<S>>(
StateInfo<S>::complementary(ru, cptu))
: make_shared<StateInfo<S>>(ru);
shared_ptr<StateInfo<S>> cpr =
is_wfn || !fuse_left
? make_shared<StateInfo<S>>(StateInfo<S>::complementary(r, cpt))
: make_shared<StateInfo<S>>(r);
shared_ptr<StateInfo<S>> conn_l =
TransStateInfo<S, S>::backward_connection(cplu, cpl);
shared_ptr<StateInfo<S>> conn_lm =
dot == 2 || (dot != 0 && fuse_left)
? TransStateInfo<S, S>::backward_connection(
make_shared<StateInfo<S>>(mlu),
make_shared<StateInfo<S>>(ml))
: nullptr;
shared_ptr<StateInfo<S>> conn_mr =
dot == 2 || (dot != 0 && !fuse_left)
? TransStateInfo<S, S>::backward_connection(
make_shared<StateInfo<S>>(mru),
make_shared<StateInfo<S>>(mr))
: nullptr;
shared_ptr<StateInfo<S>> conn_r =
TransStateInfo<S, S>::backward_connection(cpru, cpr);
map<array<S, 2>, pair<FLS *, size_t>> mp0;
map<array<S, 3>, pair<FLS *, size_t>> mp;
map<array<S, 4>, pair<FLS *, size_t>> mp2;
Expand Down Expand Up @@ -2315,6 +2332,46 @@ template <typename S, typename FL, typename FLS> struct MovingEnvironment {
}
nxw = is_group ? (forward ? gr_wfn->n : cpket->n) : 1;
for (int iw = 0; iw < nxw; iw++) {
S cptu = cinfo->target, cpt = info->target;
if (is_group && !is_pert) {
cptu =
dynamic_pointer_cast<MultiMPSInfo<S>>(cinfo)->targets[iw];
cpt = dynamic_pointer_cast<MultiMPSInfo<S>>(info)->targets[iw];
}
shared_ptr<StateInfo<S>> cplu =
is_wfn || fuse_left
? make_shared<StateInfo<S>>(lu)
: make_shared<StateInfo<S>>(
StateInfo<S>::complementary(lu, cptu));
shared_ptr<StateInfo<S>> cpl =
is_wfn || fuse_left ? make_shared<StateInfo<S>>(l)
: make_shared<StateInfo<S>>(
StateInfo<S>::complementary(l, cpt));
shared_ptr<StateInfo<S>> cpru =
is_wfn || !fuse_left
? make_shared<StateInfo<S>>(
StateInfo<S>::complementary(ru, cptu))
: make_shared<StateInfo<S>>(ru);
shared_ptr<StateInfo<S>> cpr =
is_wfn || !fuse_left ? make_shared<StateInfo<S>>(
StateInfo<S>::complementary(r, cpt))
: make_shared<StateInfo<S>>(r);
shared_ptr<StateInfo<S>> conn_l =
TransStateInfo<S, S>::backward_connection(cplu, cpl);
shared_ptr<StateInfo<S>> conn_lm =
dot == 2 || (dot != 0 && fuse_left)
? TransStateInfo<S, S>::backward_connection(
make_shared<StateInfo<S>>(mlu),
make_shared<StateInfo<S>>(ml))
: nullptr;
shared_ptr<StateInfo<S>> conn_mr =
dot == 2 || (dot != 0 && !fuse_left)
? TransStateInfo<S, S>::backward_connection(
make_shared<StateInfo<S>>(mru),
make_shared<StateInfo<S>>(mr))
: nullptr;
shared_ptr<StateInfo<S>> conn_r =
TransStateInfo<S, S>::backward_connection(cpru, cpr);
shared_ptr<SparseMatrix<S, FLS>> cwfn =
forward ? (is_group ? (*gr_wfn)[iw] : r_wfn)
: (is_group ? (*cpket)[iw] : cket);
Expand Down
Loading

0 comments on commit d081c81

Please sign in to comment.