From 26a52971ecc171bebdfc1d773d25506b6c2abd27 Mon Sep 17 00:00:00 2001 From: Filippo Bigi <98903385+frostedoyster@users.noreply.github.com> Date: Wed, 24 Jan 2024 16:13:16 +0100 Subject: [PATCH] Compile slice and join (#36) --- src/metatensor/models/utils/data/dataset.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/metatensor/models/utils/data/dataset.py b/src/metatensor/models/utils/data/dataset.py index 0fbd40f4a..b0144a5bd 100644 --- a/src/metatensor/models/utils/data/dataset.py +++ b/src/metatensor/models/utils/data/dataset.py @@ -6,6 +6,10 @@ from metatensor.torch.atomistic import ModelCapabilities, System +compiled_slice = torch.jit.script(metatensor.torch.slice) +compiled_join = torch.jit.script(metatensor.torch.join) + + class Dataset(torch.utils.data.Dataset): def __init__(self, structures: List[System], targets: Dict[str, TensorMap]): """ @@ -52,7 +56,7 @@ def __getitem__(self, index): targets = {} for name, tensor_map in self.targets.items(): - targets[name] = metatensor.torch.slice( + targets[name] = compiled_slice( tensor_map, "samples", structure_index_samples ) @@ -130,9 +134,7 @@ def collate_fn(batch): structures = [sample[0] for sample in batch] targets = {} for name in batch[0][1].keys(): - targets[name] = metatensor.torch.join( - [sample[1][name] for sample in batch], "samples" - ) + targets[name] = compiled_join([sample[1][name] for sample in batch], "samples") return structures, targets