Skip to content

Commit

Permalink
sany det coeffs
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Jun 22, 2024
1 parent 4c195e6 commit f987afc
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 27 deletions.
21 changes: 18 additions & 3 deletions pyblock2/driver/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5546,6 +5546,18 @@ def get_csf_coefficients(
if iprint:
print("mps center changed (temporarily)")

if iprint and SymmetryTypes.SAny in bw.symm_type:
print("basis mapping:")
kk = 0
for j in range(ket.info.basis[0].n):
for jm in range(ket.info.basis[0].quanta[j].multiplicity):
for jj in range(ket.info.basis[0].n_states[j]):
print(
" [%2d] %24s : m=%2d state=%2d"
% (kk, ket.info.basis[0].quanta[j], jm, jj)
)
kk += 1

tx = time.perf_counter()
dtrie = bw.bs.DeterminantTRIE(ket.n_sites, True)
ddstr = "0+-2" if SymmetryTypes.SU2 in bw.symm_type else "0ab2"
Expand Down Expand Up @@ -5575,11 +5587,14 @@ def get_csf_coefficients(
"Sum of weights of included %s = %20.15f\n" % (dname, (dvals**2).sum())
)
for ii, idx in enumerate(gidx):
arr = np.array(dtrie[idx])
if self.reorder_idx is not None:
rev_idx = np.argsort(self.reorder_idx)
det = "".join([ddstr[x] for x in np.array(dtrie[idx])[rev_idx]])
arr = arr[rev_idx]
if SymmetryTypes.SAny in bw.symm_type:
det = "".join(["%s" % x for x in arr])
else:
det = "".join([ddstr[x] for x in np.array(dtrie[idx])])
det = "".join([ddstr[x] for x in arr])
val = dvals[idx]
print(dname, "%10d" % ii, det, " = %20.15f" % val)
if len(dvals) > max_print:
Expand Down Expand Up @@ -6935,7 +6950,7 @@ def get_mps_from_csf_coefficients(
dtrie = bw.bs.DeterminantTRIE(self.n_sites, True)
ddstr = "0+-2" if SymmetryTypes.SU2 in bw.symm_type else "0ab2"

if iprint:
if iprint and SymmetryTypes.SAny in bw.symm_type:
print("basis mapping:")
kk = 0
for j in range(self.ghamil.basis[0].n):
Expand Down
24 changes: 13 additions & 11 deletions src/core/complex_matrix_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,9 @@ struct GMatrixFunctions<FL, typename enable_if<is_complex<FL>::value>::type> {
// conj can be 0 (no conj no trans), 1 (trans), 3 (conj trans)
static void multiply(const GMatrix<FL> &a, uint8_t conja,
const GMatrix<FL> &b, uint8_t conjb,
const GMatrix<FL> &c, FL scale, FL cfactor) {
const GMatrix<FL> &c, FL scale, FL cfactor,
MKL_INT ldb = 0) {
ldb = ldb ? ldb : b.n;
static const char ntxc[5] = "ntxc";
// if assertion failes here, check whether it is the case
// where different bra and ket are used with the transpose rule
Expand All @@ -951,48 +953,48 @@ struct GMatrixFunctions<FL, typename enable_if<is_complex<FL>::value>::type> {
if (!(conja & 1) && !(conjb & 1)) {
assert(a.n >= b.m && c.m == a.m && c.n >= b.n);
xgemm<FL>(ntxc + conjb, ntxc + conja, &b.n, &c.m, &b.m, &scale,
b.data, &b.n, a.data, &a.n, &cfactor, c.data, &c.n);
b.data, &ldb, a.data, &a.n, &cfactor, c.data, &c.n);
} else if (!(conja & 1) && (conjb & 1)) {
assert(a.n >= b.n && c.m == a.m && c.n >= b.m);
xgemm<FL>(ntxc + conjb, ntxc + conja, &b.m, &c.m, &b.n, &scale,
b.data, &b.n, a.data, &a.n, &cfactor, c.data, &c.n);
b.data, &ldb, a.data, &a.n, &cfactor, c.data, &c.n);
} else if ((conja & 1) && !(conjb & 1)) {
assert(a.m == b.m && c.m <= a.n && c.n >= b.n);
xgemm<FL>(ntxc + conjb, ntxc + conja, &b.n, &c.m, &b.m, &scale,
b.data, &b.n, a.data, &a.n, &cfactor, c.data, &c.n);
b.data, &ldb, a.data, &a.n, &cfactor, c.data, &c.n);
} else if ((conja & 1) && (conjb & 1)) {
assert(a.m == b.n && c.m <= a.n && c.n >= b.m);
xgemm<FL>(ntxc + conjb, ntxc + conja, &b.m, &c.m, &b.n, &scale,
b.data, &b.n, a.data, &a.n, &cfactor, c.data, &c.n);
b.data, &ldb, a.data, &a.n, &cfactor, c.data, &c.n);
}
#else
if (!conja && !conjb) {
assert(a.n >= b.m && c.m == a.m && c.n >= b.n);
xgemm<FL>(ntxc + conjb, ntxc + conja, &b.n, &c.m, &b.m, &scale,
b.data, &b.n, a.data, &a.n, &cfactor, c.data, &c.n);
b.data, &ldb, a.data, &a.n, &cfactor, c.data, &c.n);
} else if (!conja && conjb != 2) {
assert(a.n >= b.n && c.m == a.m && c.n >= b.m);
xgemm<FL>(ntxc + conjb, ntxc + conja, &b.m, &c.m, &b.n, &scale,
b.data, &b.n, a.data, &a.n, &cfactor, c.data, &c.n);
b.data, &ldb, a.data, &a.n, &cfactor, c.data, &c.n);
} else if (conja != 2 && !conjb) {
assert(a.m == b.m && c.m <= a.n && c.n >= b.n);
xgemm<FL>(ntxc + conjb, ntxc + conja, &b.n, &c.m, &b.m, &scale,
b.data, &b.n, a.data, &a.n, &cfactor, c.data, &c.n);
b.data, &ldb, a.data, &a.n, &cfactor, c.data, &c.n);
} else if (conja != 2 && conjb != 2) {
assert(a.m == b.n && c.m <= a.n && c.n >= b.m);
xgemm<FL>(ntxc + conjb, ntxc + conja, &b.m, &c.m, &b.n, &scale,
b.data, &b.n, a.data, &a.n, &cfactor, c.data, &c.n);
b.data, &ldb, a.data, &a.n, &cfactor, c.data, &c.n);
} else if (conja == 2 && conjb != 2) {
const MKL_INT one = 1;
for (MKL_INT k = 0; k < c.m; k++)
xgemm<FL>(ntxc + conjb, "c", (conjb & 1) ? &b.m : &b.n, &one,
(conjb & 1) ? &b.n : &b.m, &scale, b.data, &b.n,
(conjb & 1) ? &b.n : &b.m, &scale, b.data, &ldb,
&a(k, 0), &one, &cfactor, &c(k, 0), &c.n);
} else if (conja != 3 && conjb == 2) {
const MKL_INT one = 1;
for (MKL_INT k = 0; k < c.m; k++)
xgemm<FL>(ntxc + (conja ^ 1), "c", &one, &b.n, &b.m, &scale,
(conja & 1) ? &a(0, k) : &a(k, 0), &a.n, b.data, &b.n,
(conja & 1) ? &a(0, k) : &a(k, 0), &a.n, b.data, &ldb,
&cfactor, &c(k, 0), &one);
} else
assert(false);
Expand Down
15 changes: 12 additions & 3 deletions src/core/iterative_matrix_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,16 @@ template <typename FL> struct IterativeMatrixFunctions : GMatrixFunctions<FL> {
for (int i = 0; i < k; i++) {
for (int j = 0; j < i; j++)
iadd(bs[i], bs[j], -complex_dot(bs[j], bs[i]));
iscale(bs[i], (FP)1.0 / norm(bs[i]));
FL normx = norm(bs[i]);
if (abs(normx * normx) < 1E-14) {
stringstream ss;
ss << "Cannot generate initial guess " << i
<< " for Davidson unitary to all given states (you are "
"possibly targeting a global symmetry sector with no "
"states or MPS has zero norm)!";
throw runtime_error(ss.str());
}
iscale(bs[i], (FP)1.0 / normx);
}
for (int i = 0; i < k; i++) {
for (int j = 0; j < nor; j++)
Expand All @@ -585,7 +594,7 @@ template <typename FL> struct IterativeMatrixFunctions : GMatrixFunctions<FL> {
ss << "Cannot generate initial guess " << i
<< " for Davidson unitary to all given states (you are "
"possibly targeting a global symmetry sector with no "
"states)!";
"states or MPS has zero norm)!";
throw runtime_error(ss.str());
}
iscale(bs[i], (FP)1.0 / normx);
Expand Down Expand Up @@ -925,7 +934,7 @@ template <typename FL> struct IterativeMatrixFunctions : GMatrixFunctions<FL> {
ss << "Cannot generate initial guess " << i
<< " for Davidson unitary to all given states (you are "
"possibly targeting a global symmetry sector with no "
"states)!";
"states or MPS has zero norm)!";
throw runtime_error(ss.str());
}
iscale(bs[i], (FP)1.0 / normx);
Expand Down
12 changes: 7 additions & 5 deletions src/core/matrix_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -939,25 +939,27 @@ struct GMatrixFunctions<
// c.n is used for ldc; a.n is used for lda
static void multiply(const GMatrix<FL> &a, uint8_t conja,
const GMatrix<FL> &b, uint8_t conjb,
const GMatrix<FL> &c, FL scale, FL cfactor) {
const GMatrix<FL> &c, FL scale, FL cfactor,
MKL_INT ldb = 0) {
ldb = ldb ? ldb : b.n;
// if assertion fails here, check whether it is the case
// where different bra and ket are used with the transpose rule
// use no-transpose-rule to fix it
if (!(conja & 1) && !(conjb & 1)) {
assert(a.n >= b.m && c.m == a.m && c.n >= b.n);
xgemm<FL>("n", "n", &b.n, &c.m, &b.m, &scale, b.data, &b.n, a.data,
xgemm<FL>("n", "n", &b.n, &c.m, &b.m, &scale, b.data, &ldb, a.data,
&a.n, &cfactor, c.data, &c.n);
} else if (!(conja & 1) && (conjb & 1)) {
assert(a.n >= b.n && c.m == a.m && c.n >= b.m);
xgemm<FL>("t", "n", &b.m, &c.m, &b.n, &scale, b.data, &b.n, a.data,
xgemm<FL>("t", "n", &b.m, &c.m, &b.n, &scale, b.data, &ldb, a.data,
&a.n, &cfactor, c.data, &c.n);
} else if ((conja & 1) && !(conjb & 1)) {
assert(a.m == b.m && c.m <= a.n && c.n >= b.n);
xgemm<FL>("n", "t", &b.n, &c.m, &b.m, &scale, b.data, &b.n, a.data,
xgemm<FL>("n", "t", &b.n, &c.m, &b.m, &scale, b.data, &ldb, a.data,
&a.n, &cfactor, c.data, &c.n);
} else {
assert(a.m == b.n && c.m <= a.n && c.n >= b.m);
xgemm<FL>("t", "t", &b.m, &c.m, &b.n, &scale, b.data, &b.n, a.data,
xgemm<FL>("t", "t", &b.m, &c.m, &b.n, &scale, b.data, &ldb, a.data,
&a.n, &cfactor, c.data, &c.n);
}
}
Expand Down
Loading

0 comments on commit f987afc

Please sign in to comment.