Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Balanced batch sampler+base dataset #753

Merged
merged 67 commits into from
Aug 2, 2024
Merged

Conversation

misko
Copy link
Collaborator

@misko misko commented Jul 9, 2024

Add Basedataset and corresponding changes to balanced batch sampler.
Added Basedataset.metadata_hasattr() to check if the metadata associated with dataset has a specific attr. In the case of balancedbatchsampler we require 'natoms' to be present.

@misko misko requested a review from mshuaibii July 23, 2024 21:12
assert (
len(self.paths) == 1
), f"{type(self)} does not support a list of src paths."
self.path = self.paths[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is a list being fed in here in the first place?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point. I don't know, i assumed it was because some datasets might take multiple files, potentially multiple lmdb paths. But i dont think this is currently used, I will take this out.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second pass, looks like we do use multiple paths for ASE datasets,
https://github.com/FAIR-Chem/fairchem/blob/main/tests/core/datasets/test_ase_datasets.py#L89

@@ -214,7 +212,7 @@ def get_metadata(self, num_samples: int = 100):
}


class SinglePointLmdbDataset(LmdbDataset[BaseData]):
class SinglePointLmdbDataset(LmdbDataset):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the dataset changes, this may be a good time to deprecate these....

Warning - you'll break a few tests. If you break way too many you can revert and we can do this as part of BE.

@@ -114,9 +116,6 @@ def __init__(self, config, transform=None) -> None:
if self.train_on_oc20_total_energies:
with open(config["oc20_ref"], "rb") as fp:
self.oc20_ref = pickle.load(fp)
if self.config.get("lin_ref", False):
coeff = np.load(self.config["lin_ref"], allow_pickle=True)["coeff"]
self.lin_ref = torch.nn.Parameter(torch.tensor(coeff), requires_grad=False)
self.subsample = aii(self.config.get("subsample", False), bool)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this functionality is technically now supported with ur base dataset right? if so we can remove.

@@ -455,7 +456,9 @@ def load_model(self) -> None:
self.logger.log_summary({"num_params": self.model.num_params})

if distutils.initialized() and not self.config["noddp"]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this PR behave when --noddp is defined? Would be good to test that functionality.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added!

@misko misko force-pushed the balanced-batch-sampler+base-dataset branch from 235965f to 4426870 Compare July 29, 2024 17:52
@misko misko force-pushed the balanced-batch-sampler+base-dataset branch from 65b5a6f to 1e2e4aa Compare July 29, 2024 18:20
msg = msg[:-1] + f"that are smaller than the given max_atoms {max_atoms}."
raise ValueError(msg)

indices = indices[:max_index]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm testing this on my end...If i have a config as follows:

dataset:
  train:
    format: lmdb
    src: data/s2ef/all/train/
    first_n: 2000000
    key_mapping:
      y: energy
      force: forces
    transforms:
      normalizer:
        energy:
          mean: -0.7554450631141663
          stdev: 2.887317180633545
        forces:
          mean: 0
          stdev: 2.887317180633545
  val:
    src: data/s2ef/all/val_id_30k

Because we have val set up to inherit the train parameters, it will also try to grab first_n for validation/testing as well. To get around this I would need to explicitly define first_n for the validation/test splits which is not ideal.

We should have it such that something like first_n does not try to get passed down to val/test. I don't know the best way we want to tackle this, open to ideas.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same is true with metadata_path, etc. All of it is trying to be inherited

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally fair, again great catch @mshuaibii ! Thank you! Looking into it...

@misko misko added this pull request to the merge queue Aug 2, 2024
Merged via the queue into main with commit 04a69b0 Aug 2, 2024
7 checks passed
@misko misko deleted the balanced-batch-sampler+base-dataset branch August 2, 2024 18:18
lbluque pushed a commit that referenced this pull request Aug 7, 2024
* Update BalancedBatchSampler to use datasets' `data_sizes` method
Replace BalancedBatchSampler's `force_balancing` and `throw_on_error` parameters with `on_error`

* Remove python 3.10 syntax

* Documentation

* Added set_epoch method

* Format

* Changed "resolved dataset" message to be a debug log to reduce log spam

* clean up batchsampler and tests

* base dataset class

* move lin_ref to base dataset

* inherit basedataset for ase dataset

* filter indices prop

* added create_dataset fn

* yaml load fix

* create dataset function instead of filtering in base

* remove filtered_indices

* make create_dataset and LMDBDatabase importable from datasets

* create_dataset cleanup

* test create_dataset

* use metadata.natoms directly and add it to subset

* use self.indices to handle shard

* rename _data_sizes

* fix Subset of metadata

* minor change to metadata, added full path option

* import updates

* implement get_metadata for datasets; add tests for max_atoms and balanced partitioning

* a[:len(a)+1] does not throw error, change to check for this

* off by one fix

* fixing tests

* plug create_dataset into trainer

* remove datasetwithsizes; fix base dataset integration; replace close_db with __del__

* lint

* add/fix test;

* adding new notebook for using fairchem models with NEBs without CatTSunami enumeration (#764)

* adding new notebook for using fairchem models with NEBs

* adding md tutorials

* blocking code cells that arent needed or take too long

* Add extra test case for local batch size = 1

* fix example

* fix test case

* reorg changes

* remove metadata_has_sizes in favor of basedataset function metadata_hasattr

* fix data_parallel typo

* fix up some tests

* rename get_metadata to sample_property_metadata

* add slow get_metadata for ase; add tests for get_metadata (ase+lmdb); add test for make lmdb metadata sizes

* add support for different backends and ddp in pytest

* fix tests and balanced batch sampler

* make default dataset lmdb

* lint

* fix tests

* test with world_size=0 by default

* fix tests

* fix tests..

* remove subsample from oc22 dataset

* remove old datasets; add test for noddp

* remove load balancing from docs

* fix docs; add train_split_settings and test for this

---------

Co-authored-by: Nima Shoghi <[email protected]>
Co-authored-by: Nima Shoghi <[email protected]>
Co-authored-by: lbluque <[email protected]>
Co-authored-by: Brandon <[email protected]>
Co-authored-by: Brook Wander <[email protected]>
Co-authored-by: Muhammed Shuaibi <[email protected]>
Co-authored-by: Muhammed Shuaibi <[email protected]>

(cherry picked from commit 04a69b0)
@rayg1234 rayg1234 added minor Minor version release enhancement New feature or request labels Aug 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request minor Minor version release
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants