Skip to content

Commit

Permalink
add splu cache to depletion solver
Browse files Browse the repository at this point in the history
  • Loading branch information
eepeterson committed Nov 28, 2023
1 parent 7dbd5e5 commit e0063e9
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 23 deletions.
20 changes: 14 additions & 6 deletions openmc/deplete/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,8 +671,8 @@ def solver(self, func):
return

# Inspect arguments
if len(sig.parameters) != 3:
raise ValueError("Function {} does not support three arguments: "
if len(sig.parameters) < 3:
raise ValueError("Function {} does not support less than three arguments: "
"{!s}".format(func, sig))

for ix, param in enumerate(sig.parameters.values()):
Expand All @@ -683,15 +683,16 @@ def solver(self, func):

self._solver = func

def _timed_deplete(self, n, rates, dt, matrix_func=None):
def _timed_deplete(self, n, rates, dt, matrix_func=None,
use_cache=False):
start = time.time()
results = deplete(
self._solver, self.chain, n, rates, dt, matrix_func,
self.transfer_rates)
self.transfer_rates, use_cache=use_cache)
return time.time() - start, results

@abstractmethod
def __call__(self, n, rates, dt, source_rate, i):
def __call__(self, n, rates, dt, source_rate, i, use_cache=False):
"""Perform the integration across one time step
Parameters
Expand Down Expand Up @@ -781,7 +782,10 @@ def integrate(self, final_step=True, output=True):
n = self.operator.initial_condition()
t, self._i_res = self._get_start_data()

prev_dt = None
prev_source_rate = None
for i, (dt, source_rate) in enumerate(self):
use_cache = (prev_dt == dt) and (prev_source_rate == source_rate)
if output and comm.rank == 0:
print(f"[openmc.deplete] t={t} s, dt={dt} s, source={source_rate}")

Expand All @@ -792,7 +796,9 @@ def integrate(self, final_step=True, output=True):
n, res = self._get_bos_data_from_restart(i, source_rate, n)

# Solve Bateman equations over time interval
proc_time, n_list, res_list = self(n, res.rates, dt, source_rate, i)
proc_time, n_list, res_list = self(n, res.rates, dt,
source_rate, i,
use_cache=use_cache)

# Insert BOS concentration, transport results
n_list.insert(0, n)
Expand All @@ -804,6 +810,8 @@ def integrate(self, final_step=True, output=True):
StepResult.save(self.operator, n_list, res_list, [t, t + dt],
source_rate, self._i_res + i, proc_time)

prev_dt = dt
prev_source_rate = source_rate
t += dt

# Final simulation -- in the case that final_step is False, a zero
Expand Down
18 changes: 13 additions & 5 deletions openmc/deplete/cram.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ def __init__(self, alpha, theta, alpha0):
self.alpha = alpha
self.theta = theta
self.alpha0 = alpha0
self._splu_cache = []

def __call__(self, A, n0, dt):
def __call__(self, A, n0, dt, use_cache=False):
"""Solve depletion equations using IPF CRAM
Parameters
Expand All @@ -75,11 +76,18 @@ def __call__(self, A, n0, dt):
Final compositions after ``dt``
"""
A = dt * sp.csc_matrix(A, dtype=np.float64)
y = n0.copy()
ident = sp.eye(A.shape[0], format='csc')
for alpha, theta in zip(self.alpha, self.theta):
y += 2*np.real(alpha*sla.splu(A - theta*ident).solve(y))
if use_cache:
for alpha, splu in zip(self.alpha, self._splu_cache):
y += 2*np.real(alpha*splu.solve(y))
else:
A = dt * sp.csc_matrix(A, dtype=np.float64)
ident = sp.eye(A.shape[0], format='csc')
self._splu_cache = []
for alpha, theta in zip(self.alpha, self.theta):
splu = sla.splu(A - theta*ident)
self._splu_cache.append(splu)
y += 2*np.real(alpha*splu.solve(y))
return y * self.alpha0


Expand Down
19 changes: 10 additions & 9 deletions openmc/deplete/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class PredictorIntegrator(Integrator):
"""
_num_stages = 1

def __call__(self, n, rates, dt, source_rate, _i=None):
def __call__(self, n, rates, dt, source_rate, _i=None, use_cache=False):
"""Perform the integration across one time step
Parameters
Expand Down Expand Up @@ -54,7 +54,8 @@ def __call__(self, n, rates, dt, source_rate, _i=None):
with predictor
"""
proc_time, n_end = self._timed_deplete(n, rates, dt)
proc_time, n_end = self._timed_deplete(n, rates, dt,
use_cache=use_cache)
return proc_time, [n_end], []


Expand All @@ -78,7 +79,7 @@ class CECMIntegrator(Integrator):
"""
_num_stages = 2

def __call__(self, n, rates, dt, source_rate, _i=None):
def __call__(self, n, rates, dt, source_rate, _i=None, use_cache=False):
"""Integrate using CE/CM
Parameters
Expand Down Expand Up @@ -142,7 +143,7 @@ class CF4Integrator(Integrator):
"""
_num_stages = 4

def __call__(self, n_bos, bos_rates, dt, source_rate, _i=None):
def __call__(self, n_bos, bos_rates, dt, source_rate, _i=None, use_cache=False):
"""Perform the integration across one time step
Parameters
Expand Down Expand Up @@ -220,7 +221,7 @@ class CELIIntegrator(Integrator):
"""
_num_stages = 2

def __call__(self, n_bos, rates, dt, source_rate, _i=None):
def __call__(self, n_bos, rates, dt, source_rate, _i=None, use_cache=False):
"""Perform the integration across one time step
Parameters
Expand Down Expand Up @@ -286,7 +287,7 @@ class EPCRK4Integrator(Integrator):
"""
_num_stages = 4

def __call__(self, n, rates, dt, source_rate, _i=None):
def __call__(self, n, rates, dt, source_rate, _i=None, use_cache=False):
"""Perform the integration across one time step
Parameters
Expand Down Expand Up @@ -368,7 +369,7 @@ class LEQIIntegrator(Integrator):
"""
_num_stages = 2

def __call__(self, n_bos, bos_rates, dt, source_rate, i):
def __call__(self, n_bos, bos_rates, dt, source_rate, i, use_cache=False):
"""Perform the integration across one time step
Parameters
Expand Down Expand Up @@ -450,7 +451,7 @@ class SICELIIntegrator(SIIntegrator):
"""
_num_stages = 2

def __call__(self, n_bos, bos_rates, dt, source_rate, _i=None):
def __call__(self, n_bos, bos_rates, dt, source_rate, _i=None, use_cache=False):
"""Perform the integration across one time step
Parameters
Expand Down Expand Up @@ -516,7 +517,7 @@ class SILEQIIntegrator(SIIntegrator):
"""
_num_stages = 2

def __call__(self, n_bos, bos_rates, dt, source_rate, i):
def __call__(self, n_bos, bos_rates, dt, source_rate, i, use_cache=False):
"""Perform the integration across one time step
Parameters
Expand Down
6 changes: 3 additions & 3 deletions openmc/deplete/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _distribute(items):
j += chunk_size

def deplete(func, chain, n, rates, dt, matrix_func=None, transfer_rates=None,
*matrix_args):
*matrix_args, use_cache=False):
"""Deplete materials using given reaction rates for a specified time
Parameters
Expand Down Expand Up @@ -140,7 +140,7 @@ def deplete(func, chain, n, rates, dt, matrix_func=None, transfer_rates=None,

# Concatenate vectors of nuclides in one
n_multi = np.concatenate(n)
n_result = func(matrix, n_multi, dt)
n_result = func(matrix, n_multi, dt, use_cache=use_cache)

# Split back the nuclide vector result into the original form
n_result = np.split(n_result, np.cumsum([len(i) for i in n])[:-1])
Expand All @@ -155,7 +155,7 @@ def deplete(func, chain, n, rates, dt, matrix_func=None, transfer_rates=None,

return n_result

inputs = zip(matrices, n, repeat(dt))
inputs = zip(matrices, n, repeat(dt), repeat(use_cache))

if USE_MULTIPROCESSING:
with Pool(NUM_PROCESSES) as pool:
Expand Down

0 comments on commit e0063e9

Please sign in to comment.