add balanced sampling #2

Open
wants to merge 5 commits into master
1 change: 1 addition & 0 deletions cfg/default.yaml
@@ -14,6 +14,7 @@ training:
   batch_size: 64
   amp_level: "O1"
   precision: 16
+  balanced: False
 
 ckpt:
   resume_from: null
18 changes: 16 additions & 2 deletions dataloading/LitDataloader.py
@@ -2,6 +2,7 @@
 import sys
 import random
 import pandas as pd
+import numpy as np
 from typing import Optional, Generator, Union, IO, Dict, Callable
 from pathlib import Path
 from braceexpand import braceexpand
@@ -13,6 +14,8 @@
 from dataloading.decoders import decoder2in_chans
 from dataloading.augmentations import get_train_transforms, get_valid_transforms
 from torch.utils.data import DataLoader
+from torch.utils.data.sampler import WeightedRandomSampler
+from dataloading.samplers import DistributedProxySampler
 import torch
 
 def cat_collate_fn(data):
@@ -93,6 +96,16 @@ def setup(self, stage: Optional[str] = None):
 
         # Shuffle
         self.dataset = self.dataset.sample(frac=1).reset_index(drop=True)
+
+        # make sampler for balanced training
+        if self.args.training.balanced:
+            assert self.args.dataset.pair_constraint == False, 'pair constraint is always balanced, set it to False or null'
+            count = self.dataset.loc[self.dataset.fold == 1].groupby(['label']).count()['image_name'].to_dict()
+            balance_weight = np.array([1./count[k] for k in self.dataset.loc[self.dataset.fold == 1, 'label']])
+            balance_weight = torch.from_numpy(balance_weight)
+            self.sampler = WeightedRandomSampler(balance_weight.type('torch.DoubleTensor'), len(balance_weight))
+            if len(self.args.training.gpus or '') > 1:
+                self.sampler = DistributedProxySampler(self.sampler)
 
         self.train_dataset = retriever(
             data_path=self.args.dataset.data_path,
@@ -145,8 +158,9 @@ def train_dataloader(self):
             drop_last=True,
             batch_size=self.args.training.batch_size,
             num_workers=self.args.dataset.num_workers,
+            sampler=self.sampler if self.args.training.balanced else None,
             collate_fn=cat_collate_fn if self.args.dataset.pair_constraint else None,
-            shuffle=True)
+            shuffle=False if self.args.training.balanced else True)
         return loader
 
     def val_dataloader(self):
@@ -163,4 +177,4 @@ def test_dataloader(self):
             num_workers=self.args.dataset.num_workers,
             collate_fn=cat_collate_fn if self.args.dataset.pair_constraint else None,
             shuffle=False)
-        return loader
+        return loader
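
To make the intent of the setup() change concrete: every fold-1 row is weighted by the inverse frequency of its label, and WeightedRandomSampler then draws indices with replacement, so all classes are visited at roughly the same rate. A minimal, self-contained sketch of that idea with toy data (the DataFrame and its values are made up; only the weighting recipe mirrors the diff):

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data.sampler import WeightedRandomSampler

# Toy, imbalanced "dataset": four rows of class 0 and two rows of class 1.
df = pd.DataFrame({
    "image_name": [f"im_{i}.png" for i in range(6)],
    "label": [0, 0, 0, 0, 1, 1],
})

# Per-class counts -> per-row inverse-frequency weights, as in setup() above.
count = df.groupby(["label"]).count()["image_name"].to_dict()   # {0: 4, 1: 2}
weights = torch.from_numpy(np.array([1.0 / count[k] for k in df["label"]]))

# Sampling with replacement using these weights yields roughly 50/50 labels per epoch.
sampler = WeightedRandomSampler(weights.double(), num_samples=len(weights))
drawn = [int(df.loc[int(i), "label"]) for i in sampler]
print(drawn)   # e.g. [1, 0, 1, 0, 0, 1]
```
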
1 change: 1 addition & 0 deletions dataloading/decoders.py
@@ -76,6 +76,7 @@ def rjca_decode(path):
 def gray_spatial_decode(path):
     image = load_or_pass(path, np.ndarray, cv2.imread)
     image = image[:,:,:1].astype(np.float32)
+    image /= 255.0
     return image
 
 def cost_map_decode(path, cover_path, payload):
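
The one added line simply rescales the single-channel output from 8-bit [0, 255] values to [0, 1] floats. A tiny illustration of the effect, with a dummy array standing in for the cv2.imread result (load_or_pass is bypassed here for brevity):

```python
import numpy as np

fake_bgr = np.full((4, 4, 3), 255, dtype=np.uint8)   # stands in for a cv2.imread result
image = fake_bgr[:, :, :1].astype(np.float32)         # keep one channel, as in gray_spatial_decode
image /= 255.0                                        # the added line: rescale to [0, 1]
print(image.shape, image.dtype, image.max())          # (4, 4, 1) float32 1.0
```
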
41 changes: 41 additions & 0 deletions dataloading/samplers.py
@@ -0,0 +1,41 @@
+"""
+From https://github.com/pytorch/pytorch/issues/23430#issuecomment-562350407
+"""
+import torch
+from torch.utils.data.distributed import DistributedSampler
+
+class DistributedProxySampler(DistributedSampler):
+    """Sampler that restricts data loading to a subset of input sampler indices.
+    It is especially useful in conjunction with
+    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+    process can pass a DistributedSampler instance as a DataLoader sampler,
+    and load a subset of the original dataset that is exclusive to it.
+    .. note::
+        Input sampler is assumed to be of constant size.
+    Arguments:
+        sampler: Input data sampler.
+        num_replicas (optional): Number of processes participating in
+            distributed training.
+        rank (optional): Rank of the current process within num_replicas.
+    """
+
+    def __init__(self, sampler, num_replicas=None, rank=None):
+        super(DistributedProxySampler, self).__init__(sampler, num_replicas=num_replicas, rank=rank, shuffle=False)
+        self.sampler = sampler
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        torch.manual_seed(self.epoch)
+        indices = list(self.sampler)
+
+        # add extra samples to make it evenly divisible
+        indices += indices[:(self.total_size - len(indices))]
+        if len(indices) != self.total_size:
+            raise RuntimeError("{} vs {}".format(len(indices), self.total_size))
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        if len(indices) != self.num_samples:
+            raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))
+
+        return iter(indices)
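
What the wrapper buys under DDP: every rank seeds the global RNG with the current epoch, draws the same index sequence from the wrapped sampler, and keeps only its own stride of it. A small sketch with two simulated ranks; num_replicas and rank are passed explicitly here so no process group is needed (in the PR they are left to default to the DDP environment):

```python
import torch
from torch.utils.data.sampler import WeightedRandomSampler
from dataloading.samplers import DistributedProxySampler

# A base sampler over 10 items (uniform weights just to keep the example small).
base = WeightedRandomSampler(torch.ones(10, dtype=torch.double), num_samples=10)

shards = []
for rank in (0, 1):                      # simulate two DDP processes
    proxy = DistributedProxySampler(base, num_replicas=2, rank=rank)
    proxy.set_epoch(0)                   # both ranks seed with the same epoch ...
    shards.append(list(proxy))           # ... draw the same sequence, then take their own stride

print(shards[0])                         # positions 0, 2, 4, ... of the shared draw
print(shards[1])                         # positions 1, 3, 5, ... of the shared draw
```
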
1 change: 1 addition & 0 deletions tests/test_config.yaml
@@ -14,6 +14,7 @@ training:
   batch_size: 8
   amp_level: "O1"
   precision: 32
+  balanced: False
 
 ckpt:
   resume_from: null
1 change: 1 addition & 0 deletions train_lit_model.py
@@ -39,6 +39,7 @@ def main(args):
         accelerator='ddp' if len(args.training.gpus or '') > 1 else None,
         benchmark=True,
         sync_batchnorm=len(args.training.gpus or '') > 1,
+        replace_sampler_ddp=False if args.training.balanced else True,
         resume_from_checkpoint=args.ckpt.resume_from)
 
     trainer.logger.log_hyperparams(model.hparams)
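
The replace_sampler_ddp flag ties the pieces together: by default Lightning replaces the dataloader's sampler with its own DistributedSampler under DDP, which would silently discard the balanced sampler built in setup(). Turning it off when training.balanced is set keeps the WeightedRandomSampler / DistributedProxySampler pair in place. A minimal sketch of the relevant Trainer arguments (PyTorch Lightning 1.x-era names, matching those used in the diff; the gpus value is illustrative):

```python
import pytorch_lightning as pl

# Sketch only: with balanced sampling on and more than one GPU, the custom sampler
# must survive, so Lightning's automatic sampler replacement is switched off.
trainer = pl.Trainer(
    gpus=2,
    accelerator='ddp',
    replace_sampler_ddp=False,   # keep the WeightedRandomSampler / DistributedProxySampler pair
)
```
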