From e5bf7e206533443f95fd3cade67eec400b5629dc Mon Sep 17 00:00:00 2001 From: Martin Ganahl Date: Tue, 5 Oct 2021 08:31:19 +0200 Subject: [PATCH] Fix symmetric svd decomp (#944) * modify decompositions.svd to allow truncating all singular values * modify decompositions.svd to allow truncating all singular values * updated test_max_truncation_error * remove unusedvariable --- tensornetwork/backends/symmetric/decompositions.py | 8 +------- tensornetwork/backends/symmetric/decompositions_test.py | 5 +++-- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/tensornetwork/backends/symmetric/decompositions.py b/tensornetwork/backends/symmetric/decompositions.py index 8b49e6341..3bd437aef 100644 --- a/tensornetwork/backends/symmetric/decompositions.py +++ b/tensornetwork/backends/symmetric/decompositions.py @@ -84,10 +84,6 @@ def svd( #sort singular values inds = np.argsort(extended_flat_singvals, kind='stable') discarded_inds = np.zeros(0, dtype=SIZE_T) - if inds.shape[0] > 0: - maxind = inds[-1] - else: - maxind = 0 if max_truncation_error is not None: if relative and (len(singvals) > 0): max_truncation_error = max_truncation_error * np.max( @@ -114,9 +110,7 @@ def svd( if len(inds) == 0: #special case of truncation to 0 dimension; - warnings.warn("svd_decomposition truncated to 0 dimensions. " - "Adjusting to `max_singular_values = 1`") - inds = np.asarray([maxind]) + warnings.warn("svd_decomposition truncated to 0 dimensions.") if extended_singvals.shape[1] > 0: #pylint: disable=no-member diff --git a/tensornetwork/backends/symmetric/decompositions_test.py b/tensornetwork/backends/symmetric/decompositions_test.py index 181481dad..d3e243f5a 100644 --- a/tensornetwork/backends/symmetric/decompositions_test.py +++ b/tensornetwork/backends/symmetric/decompositions_test.py @@ -87,8 +87,9 @@ def test_max_singular_values(dtype, R, R1, num_charges): @pytest.mark.parametrize("dtype", np_dtypes) @pytest.mark.parametrize("num_charges", [1, 2, 3]) -def test_max_truncation_error(dtype, num_charges): - np.random.seed(10) +@pytest.mark.parametrize("seed", np.arange(20)) +def test_max_truncation_error(dtype, num_charges, seed): + np.random.seed(seed) R = 2 D = 30 charges = [