From 6696a7ec80f8445ff923d45c90f4896544e8f4d7 Mon Sep 17 00:00:00 2001 From: Misko Date: Mon, 16 Sep 2024 14:00:11 -0700 Subject: [PATCH] Bump torch2.4.1 and pyg (#845) * bump torch and pyg * fix exported_prog call according to torch error message * add .module() * missed one last one * try with torch 2.4.1 * update yml configs --- packages/env.cpu.yml | 15 +++++++++------ packages/env.gpu.yml | 17 ++++++++++------- packages/requirements-optional.txt | 4 ++-- packages/requirements.txt | 2 +- tests/core/models/test_escn_compiles.py | 6 +++--- 5 files changed, 25 insertions(+), 19 deletions(-) diff --git a/packages/env.cpu.yml b/packages/env.cpu.yml index 467268f70..970b20fff 100644 --- a/packages/env.cpu.yml +++ b/packages/env.cpu.yml @@ -1,15 +1,10 @@ channels: - pytorch -- pyg - conda-forge - defaults dependencies: - cpuonly -- pytorch>=2 -- pyg -- pytorch-scatter -- pytorch-sparse -- pytorch-cluster +- pytorch>=2.4 - ase - e3nn>=0.5 - numpy >=1.25.0,<2.0.0 @@ -17,6 +12,14 @@ dependencies: - numba - orjson - pip +- pip: + - --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html + - torch_cluster==1.6.3+pt24cpu + - torch_geometric==2.5.3 + - pyg-lib==0.4.0+pt24cpu + - torch_scatter==2.1.2+pt24cpu + - torch_sparse==0.6.18+pt24cpu + - torch_spline_conv==1.2.2+pt24cpu - pyyaml - tqdm - python-lmdb diff --git a/packages/env.gpu.yml b/packages/env.gpu.yml index ff3d8ecaf..2b86c2218 100644 --- a/packages/env.gpu.yml +++ b/packages/env.gpu.yml @@ -2,15 +2,10 @@ channels: - pytorch - nvidia - conda-forge -- pyg - defaults dependencies: -- pytorch-cuda=11.8 -- pytorch>=2 -- pytorch-scatter -- pytorch-sparse -- pytorch-cluster -- pyg +- pytorch-cuda=12.1 +- pytorch>=2.4 - ase - e3nn>=0.5 - numpy >=1.25.0,<2.0.0 @@ -18,6 +13,14 @@ dependencies: - numba - orjson - pip +- pip: + - --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html + - torch_cluster==1.6.3+pt24cu121 + - torch_geometric==2.5.3 + - pyg-lib==0.4.0+pt24cu121 + - torch_scatter==2.1.2+pt24cu121 + - torch_sparse==0.6.18+pt24cu121 + - torch_spline_conv==1.2.2+pt24cu121 - pyyaml - tqdm - python-lmdb diff --git a/packages/requirements-optional.txt b/packages/requirements-optional.txt index 3e9a634b4..f5c645ba4 100644 --- a/packages/requirements-optional.txt +++ b/packages/requirements-optional.txt @@ -1,5 +1,5 @@ -torch_geometric==2.3.0 --f https://data.pyg.org/whl/torch-2.2.0+cpu.html +torch_geometric==2.5.3 +-f https://data.pyg.org/whl/torch-2.4.0+cpu.html torch_scatter==2.1.2 torch_sparse==0.6.18 torch_cluster==1.6.3 diff --git a/packages/requirements.txt b/packages/requirements.txt index e62ab1130..7fe73256d 100644 --- a/packages/requirements.txt +++ b/packages/requirements.txt @@ -1,3 +1,3 @@ -torch==2.2.0 +torch==2.4.1 numpy==1.23.5 ase==3.23.0 diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 269433d4d..447522b8b 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -241,7 +241,7 @@ def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None: compiled_output = compiled_model(*args[0]) exported_prog = export(message_block, args=args[0]) - exported_output = exported_prog(*args[0]) + exported_output = exported_prog.module()(*args[0]) regular_out = message_block(*args[0]) assert torch.allclose(compiled_output, regular_out, atol=tol) @@ -302,7 +302,7 @@ def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None: } exported_prog = export(layer_block, args=run_args[0], dynamic_shapes=dynamic_shapes1) for run_arg in run_args: - exported_output = exported_prog(*run_arg) + exported_output = exported_prog.module()(*run_arg) compiled_model = torch.compile(layer_block, dynamic=True) compiled_output = compiled_model(*run_arg) regular_out = layer_block(*run_arg) @@ -343,7 +343,7 @@ def test_full_escn_exports(self): # print(explained_output) # TODO: add dynamic shapes exported_prog = export(exportable_model, args=(export_data,)) - export_output = exported_prog(export_data) + export_output = exported_prog.module()(export_data) expected_output = escn_model(regular_data) assert torch.allclose(export_output["energy"], expected_output["energy"]) assert torch.allclose(export_output["forces"].mean(0), expected_output["forces"].mean(0))