Skip to content

Commit

Permalink
use numpy as default to fit references
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque committed Jul 23, 2024
1 parent 71d19da commit 6875b93
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
27 changes: 21 additions & 6 deletions src/fairchem/core/modules/normalization/element_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def fit_linear_references(
num_workers: int = 0,
max_num_elements: int = 118,
log_metrics: bool = True,
use_numpy: bool = True,
driver: str | None = None,
shuffle: bool = True,
seed: int = 0,
Expand All @@ -166,7 +167,8 @@ def fit_linear_references(
see function below...
max_num_elements: max number of elements in dataset. If not given will use an ambitious value of 118
log_metrics: if true will compute MAE, RMSE and R2 score of fit and log.
driver: backend used to solve linear system. See torch.linalg.lstsq docs.
use_numpy: use numpy.linalg.lstsq instead of torch. This tends to give better solutions.
driver: backend used to solve linear system. See torch.linalg.lstsq docs. Ignored if use_numpy=True
shuffle: whether to shuffle when loading the dataset
seed: random seed used to shuffle the sampler if shuffle=True
Expand Down Expand Up @@ -232,17 +234,30 @@ def fit_linear_references(
mask = composition_matrix.sum(axis=0) != 0.0
reduced_composition_matrix = composition_matrix[:, mask]
elementrefs = {}

for target in targets:
coeffs = torch.zeros(max_num_elements)
lstsq = torch.linalg.lstsq(
reduced_composition_matrix, target_vectors[target], driver=driver
)
coeffs[mask] = lstsq.solution

if use_numpy:
solution = torch.tensor(
np.linalg.lstsq(
reduced_composition_matrix.numpy(),
target_vectors[target].numpy(),
rcond=None,
)[0]
)
else:
lstsq = torch.linalg.lstsq(
reduced_composition_matrix, target_vectors[target], driver=driver
)
solution = lstsq.solution

coeffs[mask] = solution
elementrefs[target] = LinearReferences(coeffs)

if log_metrics is True:
y = target_vectors[target]
y_pred = torch.matmul(reduced_composition_matrix, lstsq.solution)
y_pred = torch.matmul(reduced_composition_matrix, solution)
y_mean = target_vectors[target].mean()
N = len(target_vectors[target])
ss_res = ((y - y_pred) ** 2).sum()
Expand Down
5 changes: 3 additions & 2 deletions tests/core/modules/test_element_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
)


@pytest.fixture(scope="session")
def element_refs(dummy_binary_dataset, max_num_elements):
@pytest.fixture(scope="session", params=(True, False))
def element_refs(dummy_binary_dataset, max_num_elements, request):
return fit_linear_references(
["energy"],
dataset=dummy_binary_dataset,
batch_size=16,
shuffle=False,
max_num_elements=max_num_elements,
seed=0,
use_numpy=request.param,
)


Expand Down

0 comments on commit 6875b93

Please sign in to comment.