Skip to content

Commit

Permalink
Compile slice and join
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jan 24, 2024
1 parent 2689dc1 commit f04a91e
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/metatensor/models/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from metatensor.torch.atomistic import ModelCapabilities, System


compiled_slice = torch.jit.script(metatensor.torch.slice)
compiled_join = torch.jit.script(metatensor.torch.join)


class Dataset(torch.utils.data.Dataset):
def __init__(self, structures: List[System], targets: Dict[str, TensorMap]):
"""
Expand Down Expand Up @@ -52,7 +56,7 @@ def __getitem__(self, index):

targets = {}
for name, tensor_map in self.targets.items():
targets[name] = metatensor.torch.slice(
targets[name] = compiled_slice(
tensor_map, "samples", structure_index_samples
)

Expand Down Expand Up @@ -130,9 +134,7 @@ def collate_fn(batch):
structures = [sample[0] for sample in batch]
targets = {}
for name in batch[0][1].keys():
targets[name] = metatensor.torch.join(
[sample[1][name] for sample in batch], "samples"
)
targets[name] = compiled_join([sample[1][name] for sample in batch], "samples")

return structures, targets

Expand Down

0 comments on commit f04a91e

Please sign in to comment.