diff --git a/src/fairchem/core/modules/normalization/element_references.py b/src/fairchem/core/modules/normalization/element_references.py index 5a0823bd4..b37b3460c 100644 --- a/src/fairchem/core/modules/normalization/element_references.py +++ b/src/fairchem/core/modules/normalization/element_references.py @@ -149,6 +149,7 @@ def fit_linear_references( num_workers: int = 0, max_num_elements: int = 118, log_metrics: bool = True, + use_numpy: bool = True, driver: str | None = None, shuffle: bool = True, seed: int = 0, @@ -166,7 +167,8 @@ def fit_linear_references( see function below... max_num_elements: max number of elements in dataset. If not given will use an ambitious value of 118 log_metrics: if true will compute MAE, RMSE and R2 score of fit and log. - driver: backend used to solve linear system. See torch.linalg.lstsq docs. + use_numpy: use numpy.linalg.lstsq instead of torch. This tends to give better solutions. + driver: backend used to solve linear system. See torch.linalg.lstsq docs. Ignored if use_numpy=True shuffle: whether to shuffle when loading the dataset seed: random seed used to shuffle the sampler if shuffle=True @@ -232,17 +234,30 @@ def fit_linear_references( mask = composition_matrix.sum(axis=0) != 0.0 reduced_composition_matrix = composition_matrix[:, mask] elementrefs = {} + for target in targets: coeffs = torch.zeros(max_num_elements) - lstsq = torch.linalg.lstsq( - reduced_composition_matrix, target_vectors[target], driver=driver - ) - coeffs[mask] = lstsq.solution + + if use_numpy: + solution = torch.tensor( + np.linalg.lstsq( + reduced_composition_matrix.numpy(), + target_vectors[target].numpy(), + rcond=None, + )[0] + ) + else: + lstsq = torch.linalg.lstsq( + reduced_composition_matrix, target_vectors[target], driver=driver + ) + solution = lstsq.solution + + coeffs[mask] = solution elementrefs[target] = LinearReferences(coeffs) if log_metrics is True: y = target_vectors[target] - y_pred = torch.matmul(reduced_composition_matrix, lstsq.solution) + y_pred = torch.matmul(reduced_composition_matrix, solution) y_mean = target_vectors[target].mean() N = len(target_vectors[target]) ss_res = ((y - y_pred) ** 2).sum() diff --git a/tests/core/modules/test_element_references.py b/tests/core/modules/test_element_references.py index fe7fd4eeb..62928b623 100644 --- a/tests/core/modules/test_element_references.py +++ b/tests/core/modules/test_element_references.py @@ -13,8 +13,8 @@ ) -@pytest.fixture(scope="session") -def element_refs(dummy_binary_dataset, max_num_elements): +@pytest.fixture(scope="session", params=(True, False)) +def element_refs(dummy_binary_dataset, max_num_elements, request): return fit_linear_references( ["energy"], dataset=dummy_binary_dataset, @@ -22,6 +22,7 @@ def element_refs(dummy_binary_dataset, max_num_elements): shuffle=False, max_num_elements=max_num_elements, seed=0, + use_numpy=request.param, )