Skip to content

Commit

Permalink
[REF] Move example, polish rst
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Oct 16, 2023
1 parent b1a3888 commit 42ea553
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions curvlinops/trace/hutchinson.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,26 @@ class HutchinsonTraceEstimator:
matrix for laplacian smoothing splines. Communication in Statistics---Simulation
and Computation.
Example:
>>> from numpy import trace, mean
>>> from numpy.random import rand, seed
>>> seed(0) # make deterministic
>>> A = rand(10, 10)
>>> tr_A = trace(A) # exact trace as reference
>>> estimator = HutchinsonTraceEstimator(A)
>>> # one- and multi-sample approximations
>>> tr_A_low_precision = estimator.sample()
>>> tr_A_high_precision = mean([estimator.sample() for _ in range(1_000)])
>>> assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision)
>>> tr_A, tr_A_low_precision, tr_A_high_precision
(4.457529730942303, 6.679568384120655, 4.388630875995861)
Attributes:
SUPPORTED_SAMPLINGS: Dictionary mapping supported distributions to their
SUPPORTED_DISTRIBUTIONS: Dictionary mapping supported distributions to their
sampling functions.
"""

SUPPORTED_SAMPLINGS: Dict[str, Callable[[int], ndarray]] = {
SUPPORTED_DISTRIBUTIONS: Dict[str, Callable[[int], ndarray]] = {
"rademacher": rademacher,
"normal": normal,
}
Expand All @@ -48,38 +62,24 @@ def sample(self, distribution: str = "rademacher") -> float:
Args:
distribution: Distribution of the vector along which the linear operator
will be evaluated. Either `'rademacher'` or `'normal'`.
Default is `'rademacher'`.
will be evaluated. Either ``'rademacher'`` or ``'normal'``.
Default is ``'rademacher'``.
Returns:
Sample from the trace estimator.
Raises:
ValueError: If the distribution is not supported.
Example:
>>> from numpy import trace, mean
>>> from numpy.random import rand, seed
>>> seed(0) # make deterministic
>>> A = rand(10, 10)
>>> tr_A = trace(A) # exact trace as reference
>>> estimator = HutchinsonTraceEstimator(A)
>>> # one- and multi-sample approximations
>>> tr_A_low_precision = estimator.sample()
>>> tr_A_high_precision = mean([estimator.sample() for _ in range(1_000)])
>>> assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision)
>>> tr_A, tr_A_low_precision, tr_A_high_precision
(4.457529730942303, 6.679568384120655, 4.388630875995861)
"""
dim = self._A.shape[1]

if distribution not in self.SUPPORTED_SAMPLINGS:
if distribution not in self.SUPPORTED_DISTRIBUTIONS:
raise ValueError(
f"Unsupported distribution '{distribution}'. "
f"Supported distributions are {list(self.SUPPORTED_SAMPLINGS)}."
f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}."
)

v = self.SUPPORTED_SAMPLINGS[distribution](dim)
v = self.SUPPORTED_DISTRIBUTIONS[distribution](dim)
Av = self._A @ v

return dot(v, Av)
File renamed without changes.

0 comments on commit 42ea553

Please sign in to comment.