Skip to content

Commit

Permalink
Torch compile + export escn (#826)
Browse files Browse the repository at this point in the history
* update

* update

* export so2

* move mappingReduced to member

* compile works, guard failures still

* escn so2 exports

* add gpu test

* pass cuda test

* layer block

* switch to separate export file

* message block fails export due to SO3Rotation input

* message block compiles and exports

* layer block compiles and exports

* remove most of lmax_list and mmax_list

* remove eqv2 stuff from this branch

* compile works

* update

* remove some files from main

* lint

* ruff

* lint

* revert base trainer changes

* cleanup a2g

* address comments

* cleanup
  • Loading branch information
rayg1234 committed Sep 10, 2024
1 parent 100e9aa commit eddb484
Show file tree
Hide file tree
Showing 6 changed files with 1,580 additions and 2 deletions.
10 changes: 10 additions & 0 deletions src/fairchem/core/common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,13 @@ def spawn_multi_process(
)

return [mp_output_dict[i] for i in range(config.world_size)]

def init_local_distributed_process_group(backend="nccl"):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
dist.init_process_group(
rank=0,
world_size=1,
backend=backend,
timeout=timedelta(seconds=10), # setting up timeout for distributed collectives
)
5 changes: 4 additions & 1 deletion src/fairchem/core/datasets/lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def sample_property_metadata(self, num_samples: int = 100):
}


def data_list_collater(data_list: list[BaseData], otf_graph: bool = False) -> BaseData:
def data_list_collater(data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False) -> BaseData | dict[str, torch.Tensor]:
batch = Batch.from_data_list(data_list)

if not otf_graph:
Expand All @@ -226,4 +226,7 @@ def data_list_collater(data_list: list[BaseData], otf_graph: bool = False) -> Ba
"LMDB does not contain edge index information, set otf_graph=True"
)

if to_dict:
batch = dict(batch.items())

return batch
Loading

0 comments on commit eddb484

Please sign in to comment.