Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Balanced batch sampler+base dataset #753

Merged
merged 67 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from 64 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
ae4add3
Update BalancedBatchSampler to use datasets' `data_sizes` method
nimashoghi Aug 24, 2023
01fe2b4
Remove python 3.10 syntax
nimashoghi Aug 24, 2023
2bf8213
Documentation
nimashoghi Aug 24, 2023
7ba5b8a
Added set_epoch method
Aug 28, 2023
a367d1e
Format
Aug 30, 2023
46e3c57
Changed "resolved dataset" message to be a debug log to reduce log spam
Aug 30, 2023
845bce3
update with main
lbluque Apr 30, 2024
86da069
clean up batchsampler and tests
lbluque Apr 30, 2024
ff628dd
base dataset class
lbluque May 1, 2024
122197f
move lin_ref to base dataset
lbluque May 1, 2024
fb4ce16
inherit basedataset for ase dataset
lbluque May 1, 2024
c9e1759
filter indices prop
lbluque May 3, 2024
95d3e6f
added create_dataset fn
wood-b May 6, 2024
b6c640e
yaml load fix
lbluque May 7, 2024
7fa1904
create dataset function instead of filtering in base
lbluque May 7, 2024
04c96bf
remove filtered_indices
lbluque May 7, 2024
ea35b57
make create_dataset and LMDBDatabase importable from datasets
lbluque May 8, 2024
dc98285
create_dataset cleanup
lbluque May 8, 2024
2339916
test create_dataset
lbluque May 8, 2024
9b58cc7
use metadata.natoms directly and add it to subset
lbluque May 10, 2024
63c03fc
use self.indices to handle shard
lbluque May 10, 2024
76322aa
rename _data_sizes
lbluque May 14, 2024
bb41b13
merge with main-legacy + no more data_sizes
lbluque May 15, 2024
b4e22bc
fix Subset of metadata
lbluque May 15, 2024
29b6e68
minor change to metadata, added full path option
wood-b May 17, 2024
f9b15cd
Merge branch 'main' into balanced-batch-sampler+base-dataset
wood-b May 18, 2024
dc59f96
import updates
wood-b May 18, 2024
64b8df2
implement get_metadata for datasets; add tests for max_atoms and bala…
misko May 20, 2024
80fea27
a[:len(a)+1] does not throw error, change to check for this
misko May 21, 2024
7447bdb
off by one fix
misko May 22, 2024
b4f272d
Merge branch 'main' into balanced-batch-sampler+base-dataset
misko Jul 11, 2024
9a4af0c
Merge branch 'main' into balanced-batch-sampler+base-dataset
misko Jul 11, 2024
48e1d4f
fixing tests
misko Jul 11, 2024
70508c4
plug create_dataset into trainer
misko Jul 12, 2024
6b3c012
remove datasetwithsizes; fix base dataset integration; replace close_…
misko Jul 15, 2024
5f89b7d
merge main
misko Jul 15, 2024
c50f8a6
lint
misko Jul 15, 2024
a641e9d
add/fix test;
misko Jul 15, 2024
5743a59
adding new notebook for using fairchem models with NEBs without CatTS…
brookwander Jul 16, 2024
ad6a20b
Add extra test case for local batch size = 1
misko Jul 17, 2024
81f448b
fix example
misko Jul 17, 2024
6ce72fc
Merge branch 'main' into balanced-batch-sampler+base-dataset
misko Jul 17, 2024
bf0e9d8
fix test case
misko Jul 17, 2024
30f468d
reorg changes
misko Jul 17, 2024
9e6de95
remove metadata_has_sizes in favor of basedataset function metadata_h…
misko Jul 17, 2024
aa710bd
fix data_parallel typo
misko Jul 17, 2024
78c03cc
fix up some tests
misko Jul 17, 2024
fc74fa8
Merge branch 'main' into balanced-batch-sampler+base-dataset
mshuaibii Jul 19, 2024
c032641
Merge branch 'main' into balanced-batch-sampler+base-dataset
misko Jul 19, 2024
c73a885
more general evaluator
mshuaibii Jul 11, 2023
c534968
rename get_metadata to sample_property_metadata
misko Jul 23, 2024
65b84ec
Merge branch 'main' into balanced-batch-sampler+base-dataset
misko Jul 23, 2024
0219e36
add slow get_metadata for ase; add tests for get_metadata (ase+lmdb);…
misko Jul 23, 2024
2014801
add support for different backends and ddp in pytest
misko Jul 23, 2024
fbfe627
fix tests and balanced batch sampler
misko Jul 23, 2024
69b5492
make default dataset lmdb
misko Jul 23, 2024
2bb2a19
lint
misko Jul 23, 2024
1a8fb90
fix tests
misko Jul 23, 2024
43da6fc
test with world_size=0 by default
misko Jul 23, 2024
8508e2f
fix tests
misko Jul 23, 2024
4426870
fix tests..
misko Jul 24, 2024
1e2e4aa
remove subsample from oc22 dataset
misko Jul 29, 2024
795689f
remove old datasets; add test for noddp
misko Jul 29, 2024
57d86ab
remove load balancing from docs
misko Jul 30, 2024
e57d701
fix docs; add train_split_settings and test for this
misko Aug 1, 2024
79f4132
Merge branch 'main' into balanced-batch-sampler+base-dataset
misko Aug 1, 2024
57d668d
Merge branch 'main' into balanced-batch-sampler+base-dataset
misko Aug 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/core/fine-tuning/fine-tuning-oxides.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ from fairchem.core.common.tutorial_utils import generate_yml_config
yml = generate_yml_config(checkpoint_path, 'config.yml',
delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes',
'optim.loss_force', # the checkpoint setting causes an error
'optim.load_balancing',
'dataset', 'test_dataset', 'val_dataset'],
update={'gpus': 1,
'optim.eval_every': 10,
Expand Down
236 changes: 119 additions & 117 deletions src/fairchem/core/common/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,23 @@

import heapq
import logging
from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Any, Literal

import numba
import numpy as np
import numpy.typing as npt
import torch
from torch.utils.data import BatchSampler, DistributedSampler, Sampler
import torch.distributed
from torch.utils.data import BatchSampler, Dataset, DistributedSampler
from typing_extensions import override

from fairchem.core.common import distutils, gp_utils
from fairchem.core.datasets import data_list_collater
from fairchem.core.datasets.base_dataset import (
UnsupportedDatasetError,
)

if TYPE_CHECKING:
from pathlib import Path

from numpy.typing import NDArray
from torch_geometric.data import Batch, Data


Expand All @@ -35,30 +38,24 @@ def __call__(self, data_list: list[Data]) -> Batch:


@numba.njit
def balanced_partition(sizes: npt.NDArray[np.int_], num_parts: int):
def _balanced_partition(sizes: NDArray[np.int_], num_parts: int):
"""
Greedily partition the given set by always inserting
the largest element into the smallest partition.
"""
sort_idx = np.argsort(-sizes) # Sort in descending order
heap: list[tuple[list[int], list[int]]] = [
(sizes[idx], [idx]) for idx in sort_idx[:num_parts]
]
heap = [(sizes[idx], [idx]) for idx in sort_idx[:num_parts]]
heapq.heapify(heap)
for idx in sort_idx[num_parts:]:
smallest_part = heapq.heappop(heap)
new_size = smallest_part[0] + sizes[idx]
new_idx = smallest_part[1] + [idx]
new_idx = smallest_part[1] + [
idx
] # TODO should this be append to save time/space
heapq.heappush(heap, (new_size, new_idx))
return [part[1] for part in heap]


@runtime_checkable
class _HasMetadata(Protocol):
@property
def metadata_path(self) -> Path: ...


class StatefulDistributedSampler(DistributedSampler):
"""
More fine-grained state DataSampler that uses training iteration and epoch
Expand Down Expand Up @@ -105,149 +102,154 @@ def set_epoch_and_start_iteration(self, epoch, start_iter):
self.start_iter = start_iter


class BalancedBatchSampler(Sampler):
def _load_dataset(self, dataset, mode: Literal["atoms", "neighbors"]):
errors: list[str] = []
if not isinstance(dataset, _HasMetadata):
errors.append(f"Dataset {dataset} does not have a metadata_path attribute.")
return None, errors
if not dataset.metadata_path.exists():
errors.append(f"Metadata file {dataset.metadata_path} does not exist.")
return None, errors
def _ensure_supported(dataset: Any):
if not isinstance(dataset, Dataset):
raise UnsupportedDatasetError("BalancedBatchSampler requires a dataset.")

if not dataset.metadata_hasattr("natoms"):
raise UnsupportedDatasetError(
"BalancedBatchSampler requires a dataset that has a metadata attributed with number of atoms."
)

key = {"atoms": "natoms", "neighbors": "neighbors"}[mode]
sizes = np.load(dataset.metadata_path)[key]
logging.debug(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
return dataset

return sizes, errors

class BalancedBatchSampler(BatchSampler):
def __init__(
self,
dataset,
dataset: Dataset,
*,
batch_size: int,
num_replicas: int,
rank: int,
device: torch.device,
seed: int,
mode: str | bool = "atoms",
mode: bool | Literal["atoms"] = "atoms",
shuffle: bool = True,
on_error: Literal["warn_and_balance", "warn_and_no_balance", "raise"] = "raise",
drop_last: bool = False,
force_balancing: bool = False,
throw_on_error: bool = False,
) -> None:
if mode is True:
mode = "atoms"

if isinstance(mode, str):
mode = mode.lower()
if mode not in ("atoms", "neighbors"):
raise ValueError(
f"Invalid mode {mode}. Must be one of 'atoms', 'neighbors', or a boolean."
)
):
"""
Initializes a BalancedBatchSampler object.

self.dataset = dataset
self.batch_size = batch_size
self.num_replicas = num_replicas
self.rank = rank
self.device = device
self.mode = mode
self.shuffle = shuffle
self.drop_last = drop_last
Args:
dataset (Dataset): The dataset to sample from.
batch_size (int): The size of each batch.
num_replicas (int): The number of processes participating in distributed training.
rank (int): The rank of the current process in distributed training.
device (torch.device): The device to use for the batches.
mode (str or bool, optional): The mode to use for balancing the batches. Defaults to "atoms".
shuffle (bool, optional): Whether to shuffle the samples. Defaults to True.
on_error (Literal["warn_and_balance", "warn_and_no_balance", "raise"], optional): The action to take when an error occurs (i.e., when we have an invalid dataset). Defaults to "raise".
- "warn_and_balance": Raise a warning and balance the batch by manually loading the data samples and counting the number of nodes (this is slow).
- "warn_and_no_balance": Raise a warning and do not do any balancing.
- "raise": Raise an error.
drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to False.
"""
self.disabled = False
self.on_error = on_error

if mode is False:
logging.warning(f"Disabled BalancedBatchSampler because {mode=}.")
self.disabled = True
elif mode.lower() != "atoms":
raise ValueError(
f"Only mode='atoms' or mode=True is supported, got {mode=}."
)
elif num_replicas == 1:
logging.warning(f"Disabled BalancedBatchSampler because {num_replicas=}.")
self.disabled = True

try:
dataset = _ensure_supported(dataset)
except UnsupportedDatasetError as error:
if self.on_error == "raise":
raise error
if self.on_error == "warn_and_balance":
logging.warning(
f"Failed to get data sizes from metadata, loading data to get sizes (THIS IS SLOW). {error}"
)
elif self.on_error == "warn_and_no_balance":
logging.warning(
f"Failed to get data sizes, falling back to uniform partitioning. {error}"
)
else:
raise ValueError(f"Unknown on_error={self.on_error}") from error

self.single_sampler = StatefulDistributedSampler(
self.dataset,
sampler = StatefulDistributedSampler(
dataset,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
drop_last=drop_last,
batch_size=batch_size,
seed=seed,
)
self.batch_sampler = BatchSampler(
self.single_sampler,
batch_size,
drop_last=drop_last,
)

self.sizes = None
self.balance_batches = False

if self.num_replicas <= 1:
logging.info("Batch balancing is disabled for single GPU training.")
return

if self.mode is False:
logging.info(
"Batch balancing is disabled because `optim.load_balancing` is `False`"
)
return

self.sizes, errors = self._load_dataset(dataset, self.mode)
if self.sizes is None:
self.balance_batches = force_balancing
if force_balancing:
errors.append(
"BalancedBatchSampler has to load the data to determine batch sizes, which incurs significant overhead! "
"You can disable balancing by setting `optim.load_balancing` to `False`."
)
else:
errors.append(
"Batches will not be balanced, which can incur significant overhead!"
)
else:
self.balance_batches = True

if errors:
msg = "BalancedBatchSampler: " + " ".join(errors)
if throw_on_error:
raise RuntimeError(msg)
super().__init__(sampler, batch_size=batch_size, drop_last=drop_last)
self.device = device

logging.warning(msg)
logging.info(
f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}"
)

def __len__(self) -> int:
return len(self.batch_sampler)
def _get_natoms(self, batch_idx: list[int]):
if self.sampler.dataset.metadata_hasattr("natoms"):
return np.array(
self.sampler.dataset.get_metadata("natoms", batch_idx)
).reshape(-1)
if self.on_error == "warn_and_balance":
return np.array([self.sampler.dataset[idx].num_nodes for idx in batch_idx])
return None

def set_epoch_and_start_iteration(self, epoch: int, start_iteration: int) -> None:
if not hasattr(self.single_sampler, "set_epoch_and_start_iteration"):
if not isinstance(self.sampler, StatefulDistributedSampler):
if start_iteration != 0:
raise NotImplementedError(
f"{type(self.single_sampler)} does not support resuming from a nonzero step."
)
self.single_sampler.set_epoch(epoch)
self.sampler.set_epoch(epoch)
else:
self.single_sampler.set_epoch_and_start_iteration(epoch, start_iteration)
self.sampler.set_epoch_and_start_iteration(epoch, start_iteration)

def set_epoch(self, epoch: int) -> None:
if isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch)

@staticmethod
def _dist_enabled():
return torch.distributed.is_available() and torch.distributed.is_initialized()

@override
def __iter__(self):
if not self.balance_batches:
yield from self.batch_sampler
if self.disabled or not self._dist_enabled():
yield from super().__iter__()
return

for batch_idx in self.batch_sampler:
if self.sizes is None:
# Unfortunately, we need to load the data to know the image sizes
data_list = [self.dataset[idx] for idx in batch_idx]

if self.mode == "atoms":
sizes = [data.num_nodes for data in data_list]
elif self.mode == "neighbors":
sizes = [data.edge_index.shape[1] for data in data_list]
else:
raise NotImplementedError(
f"Unknown load balancing mode: {self.mode}"
)
else:
sizes = [self.sizes[idx] for idx in batch_idx]

idx_sizes = torch.stack([torch.tensor(batch_idx), torch.tensor(sizes)])
for batch_idx in super().__iter__():
sizes = self._get_natoms(batch_idx)
if sizes is None: # on_error == "warn_and_no_balance" is set
yield batch_idx
continue

idx_sizes = torch.stack(
mshuaibii marked this conversation as resolved.
Show resolved Hide resolved
[
torch.tensor(batch_idx, device=self.device),
torch.tensor(sizes, device=self.device),
]
)
idx_sizes_all = distutils.all_gather(idx_sizes, device=self.device)
idx_sizes_all = torch.cat(idx_sizes_all, dim=-1).cpu()
if gp_utils.initialized():
idx_sizes_all = torch.unique(input=idx_sizes_all, dim=1)
idx_all = idx_sizes_all[0]
sizes_all = idx_sizes_all[1]

local_idx_balanced = balanced_partition(
sizes_all.numpy(), num_parts=self.num_replicas
local_idx_balanced = _balanced_partition(
sizes_all.numpy(),
num_parts=self.sampler.num_replicas,
)
# Since DistributedSampler pads the last batch
# this should always have an entry for each replica.
yield idx_all[local_idx_balanced[self.rank]]
yield idx_all[local_idx_balanced[self.sampler.rank]]
6 changes: 3 additions & 3 deletions src/fairchem/core/common/distutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def setup(config) -> None:
)
else:
config["local_rank"] = int(os.environ.get("LOCAL_RANK", config["local_rank"]))
dist.init_process_group(backend="nccl")
dist.init_process_group(backend=config.get("backend", "nccl"))


def cleanup() -> None:
Expand Down Expand Up @@ -144,7 +144,7 @@ def all_reduce(
if not isinstance(data, torch.Tensor):
tensor = torch.tensor(data)
if device is not None:
tensor = tensor.cuda(device)
tensor = tensor.to(device)
dist.all_reduce(tensor, group=group)
if average:
tensor /= get_world_size()
Expand All @@ -162,7 +162,7 @@ def all_gather(data, group=dist.group.WORLD, device=None) -> list[torch.Tensor]:
if not isinstance(data, torch.Tensor):
tensor = torch.tensor(data)
if device is not None:
tensor = tensor.cuda(device)
tensor = tensor.to(device)
tensor_list = [tensor.new_zeros(tensor.shape) for _ in range(get_world_size())]
dist.all_gather(tensor_list, tensor, group=group)
if not isinstance(data, torch.Tensor):
Expand Down
Loading
Loading