diff --git a/pymc_marketing/clv/utils.py b/pymc_marketing/clv/utils.py index 99f91eb6..4bb83e98 100644 --- a/pymc_marketing/clv/utils.py +++ b/pymc_marketing/clv/utils.py @@ -79,7 +79,12 @@ def _squeeze_dims(x: xarray.DataArray): x = x.squeeze(dims_to_squeeze) return x - steps = np.arange(1, time + 1) + if discount_rate == 0.0: + # no discount rate: just compute a single time step from 0 to `time` + steps = np.arange(time, time + 1) + else: + steps = np.arange(1, time + 1) + factor = {"W": 4.345, "M": 1.0, "D": 30, "H": 30 * 24}[freq] # Monetary value can be passed as a DataArray, with entries per chain and draw or as a simple vector