Skip to content

Commit

Permalink
TL: improved Z2N and N2N switch, added SDelta feature
Browse files Browse the repository at this point in the history
  • Loading branch information
tlunet committed Jul 14, 2024
1 parent dc6743b commit e48429e
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 27 deletions.
17 changes: 10 additions & 7 deletions qmat/qcoeff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,16 @@ def hCoeffs(self):
approx = LagrangeApproximation(self.nodes)
return approx.getInterpolationMatrix([1]).ravel()

def genCoeffs(self, withS=False, hCoeffs=False, embedded=False):
out = [self.nodes, self.weights, self.Q]

def genCoeffs(self, form="Z2N", hCoeffs=False, embedded=False):
if form == "Z2N":
mat = self.Q
elif form == "N2N":
mat = self.S
else:
raise ValueError(f"form must be Z2N or N2N, not {form}")
out = [self.nodes, self.weights, mat]
if embedded:
out[1] = np.vstack([out[1], self.weightsEmbedded])
if withS:
out.append(self.S)
if hCoeffs:
out.append(self.hCoeffs)
return out
Expand Down Expand Up @@ -141,13 +144,13 @@ def register(cls:QGenerator)->QGenerator:
storeClass(cls, Q_GENERATORS)
return cls

def genQCoeffs(qType, withS=False, hCoeffs=False, embedded=False, **params):
def genQCoeffs(qType, form="Z2N", hCoeffs=False, embedded=False, **params):
try:
Generator = Q_GENERATORS[qType]
except KeyError:
raise ValueError(f"{qType=!r} is not available")
gen = Generator(**params)
return gen.genCoeffs(withS, hCoeffs, embedded)
return gen.genCoeffs(form, hCoeffs, embedded)


# Import all local submodules
Expand Down
31 changes: 23 additions & 8 deletions qmat/qdelta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def size(self):
def zeros(self):
M = self.size
return np.zeros((M, M), dtype=float)

def computeQDelta(self, k=None) -> np.ndarray:
"""Compute and returns the QDelta matrix"""
raise NotImplementedError("mouahahah")
Expand All @@ -41,15 +41,28 @@ def getQDelta(self, k=None, copy=True):
raise Exception("some very weird bug happened ... did you do fishy stuff ?")
return QDelta.copy() if copy else QDelta

def getSDelta(self, k=None):
QDelta = self.getQDelta(k)
M = QDelta.shape[0]
T = np.eye(M)
T[1:,:-1][np.diag_indices(M-1)] = -1
return T @ QDelta

@property
def dTau(self):
return np.zeros(self.size, dtype=float)

def genCoeffs(self, k=None, dTau=False):
def genCoeffs(self, k=None, form="Z2N", dTau=False):
if form == "Z2N":
gen = lambda k, copy=False: self.getQDelta(k, copy)
elif form == "N2N":
gen = lambda k, copy=None: self.getSDelta(k)
else:
raise ValueError(f"form must be Z2N or N2N, not {form}")
if isinstance(k, list):
out = [np.array([self.getQDelta(_k, copy=False) for _k in k])]
out = [np.array([gen(_k, copy=False) for _k in k])]
else:
out = [self.getQDelta(k)]
out = [gen(k)]
if dTau:
out += [self.dTau]
return out if len(out) > 1 else out[0]
Expand All @@ -71,7 +84,8 @@ def register(cls:QDeltaGenerator)->QDeltaGenerator:
storeClass(cls, QDELTA_GENERATORS)
return cls

def genQDeltaCoeffs(qDeltaType, nSweeps=None, dTau=False, **params):

def genQDeltaCoeffs(qDeltaType, nSweeps=None, form="Z2N", dTau=False, **params):

# Check arguments
if isinstance(qDeltaType, str):
Expand Down Expand Up @@ -103,7 +117,7 @@ def genQDeltaCoeffs(qDeltaType, nSweeps=None, dTau=False, **params):
raise ValueError(f"qDeltaType={qDeltaType} is not available")

gen = Generator(**params)
return gen.genCoeffs(dTau=dTau)
return gen.genCoeffs(form=form, dTau=dTau)

else: # Multiple matrices return
try:
Expand All @@ -113,12 +127,13 @@ def genQDeltaCoeffs(qDeltaType, nSweeps=None, dTau=False, **params):

if len(qDeltaType) == 1: # Single QDelta generator
gen = Generators[0](**params)
return gen.genCoeffs(k=[k+1 for k in range(nSweeps)], dTau=dTau)
return gen.genCoeffs(
k=[k+1 for k in range(nSweeps)], form=form, dTau=dTau)

else: # Multiple QDelta generators
gens = [Gen(**params) for Gen in Generators]
out = [np.array(
[gen.getQDelta(k+1) for k, gen in enumerate(gens)]
[gen.genCoeffs(k+1, form) for k, gen in enumerate(gens)]
)]
if dTau:
out += [gens[0].dTau]
Expand Down
11 changes: 5 additions & 6 deletions tests/test_qcoeff/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,21 @@ def testAdditionalCoeffs(name):
f"hCoeffs for {name} has inconsistent size : {h1.size}"

try:
_, _, _, S2, h2 = genQCoeffs(name, withS=True, hCoeffs=True)
_, _, S2, h2 = genQCoeffs(name, form="N2N", hCoeffs=True)
except TypeError:
_, _, _, S2, h2 = genQCoeffs(name, withS=True, hCoeffs=True,
**GENERATORS[name].DEFAULT_PARAMS)
_, _, S2, h2 = genQCoeffs(
name, form="N2N", hCoeffs=True, **GENERATORS[name].DEFAULT_PARAMS)
assert np.allclose(S1, S2), \
f"OOP S matrix {S1} and PP S matrix {S2} are not equals for {name}"
assert np.allclose(h1, h2), \
f"OOP hCoeffs {h1} and PP hCoeffs {h2} are not equals for {name}"


try:
try:
_, b, _ = genQCoeffs(name, embedded=True)
except TypeError:
_, b, _ = genQCoeffs(name, embedded=True, **GENERATORS[name].DEFAULT_PARAMS)

_, b, _ = genQCoeffs(
name, embedded=True, **GENERATORS[name].DEFAULT_PARAMS)
assert type(b) == np.ndarray
assert b.ndim == 2
except NotImplementedError:
Expand Down
14 changes: 11 additions & 3 deletions tests/test_qdelta/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,27 @@ def testGeneration(name, nNodes):
assert np.allclose(QD1, QD2), \
f"OOP QDelta and PP QDelta are not equals for {name}"

_, dTau1 = gen.genCoeffs(dTau=True)
SD1, dTau1 = gen.genCoeffs(form="N2N", dTau=True)
assert type(dTau1) == np.ndarray, \
f"dTau for {name} is not np.ndarray but {type(dTau1)}"
assert dTau1.ndim == 1, \
f"dTau for {name} is not 1D : {dTau1}"
assert dTau1.size == nNodes, \
f"dTau for {name} has not the correct size : {dTau1}"

_, dTau2 = genQDeltaCoeffs(name, Q=Q, dTau=True)
assert SD1.ndim == 2, \
f"SDelta for {name} is not 2D : {SD1}"
assert SD1.shape == QD2.shape, \
f"SDelta for {name} has not the correct shape : {SD1}"

SD2, dTau2 = genQDeltaCoeffs(name, Q=Q, form="N2N", dTau=True)
assert np.allclose(SD1, SD2), \
f"OOP SDelta and PP SDelta are not equals for {name}"
assert np.allclose(dTau1, dTau2), \
f"OOP dTau and PP dTau are not equals for {name}"




nNodes = 4
@pytest.mark.parametrize("nSweeps", [1, 2, 3])
@pytest.mark.parametrize("name", GENERATORS.keys())
Expand Down
14 changes: 11 additions & 3 deletions tests/test_qdelta/test_timestepping.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,27 @@
def testBE(nNodes, nodeType, quadType):
coll = Collocation(nNodes, nodeType, quadType)
nodes = coll.nodes
QDelta = module.BE(nodes).getQDelta()
gen = module.BE(nodes)
QDelta = gen.getQDelta()

assert np.allclose(np.tril(QDelta), QDelta), \
"QDelta is not lower triangular"
assert np.allclose(QDelta.sum(axis=1), nodes), \
"sum over the columns is not equal to nodes"

SDelta = gen.getSDelta()
assert np.allclose(np.diag(np.diag(SDelta)), SDelta), \
"SDelta is not diagonal"


@pytest.mark.parametrize("quadType", QUAD_TYPES)
@pytest.mark.parametrize("nodeType", NODE_TYPES)
@pytest.mark.parametrize("nNodes", [2, 3, 4, 5, 6])
def testFE(nNodes, nodeType, quadType):
coll = Collocation(nNodes, nodeType, quadType)
nodes = coll.nodes
QDelta = module.FE(nodes).getQDelta()
gen = module.FE(nodes)
QDelta = gen.getQDelta()

assert np.allclose(np.tril(QDelta), QDelta), \
"QDelta is not lower triangular"
Expand All @@ -40,7 +46,7 @@ def testFE(nNodes, nodeType, quadType):
assert np.allclose(QDelta.sum(axis=1)[1:], np.cumsum(np.diff(coll.nodes))), \
"sum over the columns is not equal to cumsum of node differences"

_, dTau = module.FE(nodes).genCoeffs(dTau=True)
SDelta, dTau = module.FE(nodes).genCoeffs(form="N2N", dTau=True)
assert type(dTau) == np.ndarray, \
f"dTau is not np.ndarray but {type(dTau)}"
assert dTau.ndim == 1, \
Expand All @@ -49,6 +55,8 @@ def testFE(nNodes, nodeType, quadType):
f"dTau has not the correct size : {dTau}"
assert np.allclose(dTau, coll.nodes[0]), \
"dTau is not equal to nodes[0]"
assert np.allclose(np.diag(np.diag(SDelta, k=-1), k=-1), SDelta), \
"SDelta is not strictly lower diagonal"


@pytest.mark.parametrize("quadType", QUAD_TYPES)
Expand Down

0 comments on commit e48429e

Please sign in to comment.