From c1f54632b0c03aacaac210b1e9a8a19f20edef4c Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Tue, 2 Jul 2024 09:18:22 -0400 Subject: [PATCH] [REF] Improve numerical stability of Frobenius estimator --- curvlinops/norm/hutchinson.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/curvlinops/norm/hutchinson.py b/curvlinops/norm/hutchinson.py index 0d5000c..4d0818b 100644 --- a/curvlinops/norm/hutchinson.py +++ b/curvlinops/norm/hutchinson.py @@ -1,8 +1,9 @@ """Hutchinson-style matrix norm estimation.""" +from numpy import dot from scipy.sparse.linalg import LinearOperator -from curvlinops.trace.hutchinson import HutchinsonTraceEstimator +from curvlinops.sampling import random_vector class HutchinsonSquaredFrobeniusNormEstimator: @@ -43,7 +44,7 @@ def __init__(self, A: LinearOperator): Args: A: Linear operator whose squared Frobenius norm will be estimated. """ - self._trace_estimator = HutchinsonTraceEstimator(A.T @ A) + self._A = A def sample(self, distribution: str = "rademacher") -> float: """Draw a sample from the squared Frobenius norm estimator. @@ -59,4 +60,7 @@ def sample(self, distribution: str = "rademacher") -> float: Returns: Sample from the squared Frobenius norm estimator. """ - return self._trace_estimator.sample(distribution=distribution) + dim = self._A.shape[1] + v = random_vector(dim, distribution) + Av = self._A @ v + return dot(Av, Av)