Use torch distributed #123

Open · wants to merge 2 commits into base: pytorch-v1.1
8 changes: 4 additions & 4 deletions README.md
@@ -135,30 +135,30 @@ python -m torch.distributed.launch --nproc_per_node=4 tools/train.py --cfg exper

 For example, evaluating our model on the Cityscapes validation set with multi-scale and flip testing:
 ````bash
-python tools/test.py --cfg experiments/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml \
+python -m torch.distributed.launch --nproc_per_node=4 tools/test.py --cfg experiments/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml \
 TEST.MODEL_FILE hrnet_w48_cityscapes_cls19_1024x2048_trainset.pth \
 TEST.SCALE_LIST 0.5,0.75,1.0,1.25,1.5,1.75 \
 TEST.FLIP_TEST True
 ````
 Evaluating our model on the Cityscapes test set with multi-scale and flip testing:
 ````bash
-python tools/test.py --cfg experiments/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml \
+python -m torch.distributed.launch --nproc_per_node=4 tools/test.py --cfg experiments/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml \
 DATASET.TEST_SET list/cityscapes/test.lst \
 TEST.MODEL_FILE hrnet_w48_cityscapes_cls19_1024x2048_trainset.pth \
 TEST.SCALE_LIST 0.5,0.75,1.0,1.25,1.5,1.75 \
 TEST.FLIP_TEST True
 ````
 Evaluating our model on the PASCAL-Context validation set with multi-scale and flip testing:
 ````bash
-python tools/test.py --cfg experiments/pascal_ctx/seg_hrnet_w48_cls59_480x480_sgd_lr4e-3_wd1e-4_bs_16_epoch200.yaml \
+python -m torch.distributed.launch --nproc_per_node=4 tools/test.py --cfg experiments/pascal_ctx/seg_hrnet_w48_cls59_480x480_sgd_lr4e-3_wd1e-4_bs_16_epoch200.yaml \
 DATASET.TEST_SET testval \
 TEST.MODEL_FILE hrnet_w48_pascal_context_cls59_480x480.pth \
 TEST.SCALE_LIST 0.5,0.75,1.0,1.25,1.5,1.75,2.0 \
 TEST.FLIP_TEST True
 ````
 Evaluating our model on the LIP validation set with flip testing:
 ````bash
-python tools/test.py --cfg experiments/lip/seg_hrnet_w48_473x473_sgd_lr7e-3_wd5e-4_bs_40_epoch150.yaml \
+python -m torch.distributed.launch --nproc_per_node=4 tools/test.py --cfg experiments/lip/seg_hrnet_w48_473x473_sgd_lr7e-3_wd5e-4_bs_40_epoch150.yaml \
 DATASET.TEST_SET list/lip/testvalList.txt \
 TEST.MODEL_FILE hrnet_w48_lip_cls20_473x473.pth \
 TEST.FLIP_TEST True \
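For context, `python -m torch.distributed.launch --nproc_per_node=4 ...` spawns one worker process per GPU, passes each worker a `--local_rank` argument, and exports the rendezvous environment variables (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`) that `init_method="env://"` reads. The changes to `tools/test.py` below add that `--local_rank` flag so the evaluation script can be started through the launcher.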
14 changes: 12 additions & 2 deletions tools/test.py
@@ -37,6 +37,7 @@ def parse_args():
                         help='experiment configure file name',
                         required=True,
                         type=str)
+    parser.add_argument("--local_rank", type=int, default=0)
     parser.add_argument('opts',
                         help="Modify config options using the command-line",
                         default=None,
@@ -60,6 +61,9 @@ def main():
     cudnn.benchmark = config.CUDNN.BENCHMARK
     cudnn.deterministic = config.CUDNN.DETERMINISTIC
     cudnn.enabled = config.CUDNN.ENABLED
+    gpus = list(config.GPUS)
+    distributed = len(gpus) > 1
+    device = torch.device('cuda:{}'.format(args.local_rank))
 
     # build model
     model = eval('models.'+config.MODEL.NAME +
@@ -87,8 +91,14 @@ def main():
     model_dict.update(pretrained_dict)
     model.load_state_dict(model_dict)
 
-    gpus = list(config.GPUS)
-    model = nn.DataParallel(model, device_ids=gpus).cuda()
+    model.to(device)
+    if True:  # original code: `if distributed:`
+        torch.cuda.set_device(args.local_rank)
+        torch.distributed.init_process_group(
+            backend="nccl", init_method="env://",
+        )
+        model = nn.parallel.DistributedDataParallel(
+            model, device_ids=[args.local_rank], output_device=args.local_rank)
 
     # prepare data
     test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
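Taken together, the `tools/test.py` changes follow the usual launch-based DistributedDataParallel initialization: parse `--local_rank`, bind the process to its GPU, create the NCCL process group from the environment, then wrap the model. Below is a minimal, self-contained sketch of that pattern; it is not the project's code (the `nn.Linear` stands in for the HRNet model, and the script is assumed to be started with `python -m torch.distributed.launch --nproc_per_node=4 sketch.py`).

````python
# Minimal sketch of the DDP setup used in tools/test.py above. Assumption:
# launched via `python -m torch.distributed.launch --nproc_per_node=4 sketch.py`,
# which supplies --local_rank and the MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE
# environment variables that init_method="env://" reads.
import argparse

import torch
import torch.distributed as dist
import torch.nn as nn


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    # Bind this process to its own GPU before creating the process group.
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")

    device = torch.device("cuda:{}".format(args.local_rank))
    model = nn.Linear(8, 8).to(device)  # stand-in for the HRNet segmentation model

    # One replica per process; each replica lives on the GPU given by local_rank.
    model = nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank)

    # Inference only: no gradients needed.
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(4, 8, device=device))
    print("rank {} produced output of shape {}".format(dist.get_rank(), tuple(out.shape)))


if __name__ == "__main__":
    main()
````

In the full script, each process would presumably also evaluate its own shard of the test set (e.g. via a distributed sampler) in the data-loading code that the diff truncates.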
9 changes: 4 additions & 5 deletions tools/train.py
@@ -93,11 +93,10 @@ def main():
             shutil.rmtree(models_dst_dir)
         shutil.copytree(os.path.join(this_dir, '../lib/models'), models_dst_dir)
 
-    if distributed:
-        torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(
-            backend="nccl", init_method="env://",
-        )
+    torch.cuda.set_device(args.local_rank)
+    torch.distributed.init_process_group(
+        backend="nccl", init_method="env://",
+    )
 
     # prepare data
     crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
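Note on the `tools/train.py` change: with the `if distributed:` guard removed, `init_process_group(init_method="env://")` runs unconditionally, so the training script now expects the `MASTER_ADDR`/`MASTER_PORT`/`RANK`/`WORLD_SIZE` variables that `torch.distributed.launch` exports. Invoking `python tools/train.py` directly, without the launcher, would fail at process-group initialization even for a single-GPU run.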