Updating to PyTorch 1.6 (Docker, mixed precision); updating omegaconf.
jonasteuwen committed Aug 1, 2020
1 parent d654b33 commit 9efde83
Showing 8 changed files with 91 additions and 70 deletions.
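
The substantive change is the move from apex.amp to the native mixed-precision API introduced in PyTorch 1.6 (torch.cuda.amp), which the engine changes below follow: the forward pass and loss run under autocast, the loss is scaled through a GradScaler, and gradients are explicitly unscaled before clipping. A minimal, self-contained sketch of that pattern, using a placeholder model and data rather than code from this repository:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    use_mixed_precision = torch.cuda.is_available()
    device = 'cuda' if use_mixed_precision else 'cpu'
    model = torch.nn.Linear(16, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = GradScaler(enabled=use_mixed_precision)  # no-op when disabled

    for _ in range(10):
        data = torch.randn(8, 16, device=device)
        target = torch.randn(8, 1, device=device)
        optimizer.zero_grad()
        # Forward pass and loss run in (possibly) reduced precision.
        with autocast(enabled=use_mixed_precision):
            loss = torch.nn.functional.mse_loss(model(data), target)
        # Scale the loss so that small float16 gradients do not underflow.
        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradient values.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)  # skips the update if gradients contain inf/NaN
        scaler.update()  # adapts the loss scale for the next iteration
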
5 changes: 3 additions & 2 deletions direct/config/defaults.py
@@ -7,7 +7,7 @@
from direct.config import BaseConfig
from direct.data.datasets_config import DatasetConfig

from typing import Optional, List, Callable
from typing import Optional, List


@dataclass
@@ -67,11 +67,12 @@ class TrainingConfig(BaseConfig):
# Metrics
metrics: List[str] = field(default_factory=lambda: [])


@dataclass
class ValidationConfig(BaseConfig):
datasets: List[DatasetConfig] = field(default_factory=lambda: [DatasetConfig()])
batch_size: int = 8
metrics: Optional[List[str]] = None
metrics: List[str] = field(default_factory=lambda: [])


@dataclass
80 changes: 45 additions & 35 deletions direct/engine.py
@@ -9,14 +9,14 @@
import numpy as np
import warnings

from typing import Optional, Dict, Tuple, List, Callable
from apex import amp
from typing import Optional, Dict, Tuple, List
from abc import abstractmethod, ABC

from torch import nn
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, Dataset, Sampler, BatchSampler
from torch.cuda.amp import GradScaler

from direct.data.mri_transforms import AddNames
from direct.data import sampler
@@ -50,6 +50,7 @@ def __init__(self, cfg: BaseConfig,
# TODO: This might not be needed, if these objects are changed in-place
self.__optimizer = None
self.__lr_scheduler = None
self._scaler = GradScaler(enabled=self.mixed_precision)
self.__writers = None
self.__bind_sigint_signal()

@@ -71,6 +72,11 @@ def build_metrics(metrics_list) -> Dict:

@abstractmethod
def _do_iteration(self, *args, **kwargs) -> Tuple[torch.Tensor, Dict]:
"""
This is a placeholder for the iteration function, which needs to perform the backward pass.
If mixed precision is used, the forward pass also has to be wrapped in `autocast` inside this function.
It is recommended to raise an error if `self.mixed_precision` is true but mixed precision is not available.
"""
pass

@torch.no_grad()
@@ -172,30 +178,28 @@ def training_loop(self,

# Gradient accumulation
if (iter_idx + 1) % self.cfg.training.gradient_steps == 0: # type: ignore
# TODO: Is this slow? This is a generator, so should be cheap.
# parameter_list = self.model.parameters() if not self.mixed_precision else \
# amp.master_params(self.__optimizer)
if self.cfg.training.gradient_steps > 1: # type: ignore
warnings.warn('Gradient accumulation set. Currently not implemented. '
'This message will only be displayed once.')
# for parameter in parameter_list:
# if parameter.grad is not None:
# # In-place division
# parameter.grad.div_(self.cfg.training.gradient_steps) # type: ignore
for parameter in self.model.parameters():
if parameter.grad is not None:
# In-place division
parameter.grad.div_(self.cfg.training.gradient_steps) # type: ignore
if self.cfg.training.gradient_clipping > 0.0: # type: ignore
warnings.warn('Gradient clipping set. Currently not implemented. '
'This message will only be displayed once.')
# torch.nn.utils.clip_grad_norm_(parameter_list, self.cfg.training.gradient_clipping) # type: ignore
#
self._scaler.unscale_(self.__optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.training.gradient_clipping)

# Gradient norm
if self.cfg.training.gradient_debug: # type: ignore
warnings.warn(f'gradient debug set. Currently not implemented. '
warnings.warn(f'Gradient debug set. This will affect training performance. Only use for debugging. '
f'This message will only be displayed once.')
# parameters = list(filter(lambda p: p.grad is not None, parameter_list))
# gradient_norm = sum([parameter.grad.data ** 2 for parameter in parameters]).sqrt() # typing: ignore
# storage.add_scalar('gradient_norm', gradient_norm)
parameters = list(filter(lambda p: p.grad is not None, self.model.parameters()))
gradient_norm = sum([(parameter.grad.data ** 2).sum() for parameter in parameters]).sqrt()  # type: ignore
storage.add_scalar('train/gradient_norm', gradient_norm)

# Same as self.__optimizer.step() for mixed precision.
self._scaler.step(self.__optimizer)
# Updates the scale for next iteration.
self._scaler.update()

self.__optimizer.step() # type: ignore
# Incorrect inference by mypy and pyflake
self.__lr_scheduler.step() # type: ignore # noqa
storage.add_scalar('lr', self.__optimizer.param_groups[0]['lr'], smoothing_hint=False)
@@ -321,27 +325,29 @@ def train(self,

# Mixed precision setup. This requires the model to be on the gpu.
git_hash = direct.utils.git_hash()
extra_checkpointing = {'__author__': git_hash if git_hash else 'N/A'}
if self.mixed_precision > 0:
opt_level = f'O{self.mixed_precision}'
self.logger.info(f'Using apex level {opt_level}.')
self.model, self.__optimizer = amp.initialize(self.model, self.__optimizer, opt_level=opt_level)
extra_checkpointing['amp'] = amp
extra_checkpointing['opt_level'] = opt_level
extra_checkpointing = {'__author__': git_hash if git_hash else 'N/A',
'__mixed_precision__': self.mixed_precision,
}
if self.mixed_precision:
# TODO(jt): Check if on GPU
self.logger.info(f'Using mixed precision training.')

self.checkpointer = Checkpointer(
self.model, experiment_directory,
save_to_disk=communication.is_main_process(),
optimizer=optimizer, lr_scheduler=lr_scheduler, **self.models, **extra_checkpointing)
optimizer=optimizer,
lr_scheduler=lr_scheduler,
scaler=self._scaler,
**self.models,
**extra_checkpointing)

# Load checkpoint
start_iter = 0
checkpoint = {}
if resume:
self.logger.info('Attempting to resume...')
# This changes the model inplace
checkpoint = self.checkpointer.load(
iteration='latest', checkpointable_objects=['amp'] if self.mixed_precision > 0 else [])
checkpoint = self.checkpointer.load(iteration='latest')
if not checkpoint:
self.logger.info('No checkpoint found. Starting from scratch.')
else:
@@ -364,17 +370,21 @@ def train(self,

if '__datetime__' in checkpoint:
self.logger.info(f"Checkpoint created at: {checkpoint['__datetime__']}.")
if 'opt_level' in checkpoint:
if checkpoint['opt_level'] != opt_level:
self.logger.warning(f"Mixed precision opt-levels do not match. "
f"Requested {opt_level} got {checkpoint['opt_level']} from checkpoint. "
if '__mixed_precision__' in checkpoint:
if (not self.mixed_precision) and checkpoint['__mixed_precision__']:
self.logger.warning(f'Mixed precision training is not enabled, yet the saved checkpoint requests it. '
f'Will now enable mixed precision.')
self.mixed_precision = True
elif not checkpoint['__mixed_precision__'] and self.mixed_precision:
self.logger.warning(f"Mixed precision levels of training and loading checkpoint do not match. "
f"Requested mixed precision but checkpoint is saved without. "
f"This will almost surely lead to performance degradation.")

self.logger.info(f'World size: {communication.get_world_size()}.')
self.logger.info(f'Device count: {torch.cuda.device_count()}.')
if communication.get_world_size() > 1:
self.model = DistributedDataParallel(
self.model, device_ids=[communication.get_rank()], broadcast_buffers=False)
self.model, device_ids=[communication.get_local_rank()], broadcast_buffers=False)

# World size > 1 if distributed mode, else allow a DataParallel fallback, can be convenient for debugging.
elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
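
Because the GradScaler is now handed to the Checkpointer (scaler=self._scaler above), its state travels with the model and optimizer across resumes. The Checkpointer class itself is not part of this diff, so the following save/load round trip is only an assumed sketch (save_checkpoint and load_checkpoint are hypothetical helpers) built on the standard GradScaler state_dict API:

    import torch

    def save_checkpoint(path, model, optimizer, scaler):
        # Persist the dynamic loss-scale state alongside the usual model/optimizer state.
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict(),
        }, path)

    def load_checkpoint(path, model, optimizer, scaler):
        checkpoint = torch.load(path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Restoring the scaler keeps the loss scale consistent when training resumes.
        scaler.load_state_dict(checkpoint['scaler'])
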
4 changes: 2 additions & 2 deletions direct/environment.py
@@ -18,7 +18,7 @@
logger = logging.getLogger(__name__)


def setup_environment(run_name, base_directory, cfg_filename, device, machine_rank, debug=False):
def setup_environment(run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision, debug=False):
experiment_dir = base_directory / run_name

if communication.get_local_rank() == 0:
@@ -86,7 +86,7 @@ def setup_environment(run_name, base_directory, cfg_filename, device, machine_ra
logger.error(f'Engine does not exist for {cfg_from_file.model_name} (err = {e}).')
sys.exit(-1)

engine = engine_class(cfg, model, device=device)
engine = engine_class(cfg, model, device=device, mixed_precision=mixed_precision)

return cfg, experiment_dir, forward_operator, backward_operator, engine

44 changes: 20 additions & 24 deletions direct/nn/rim/rim_engine.py
@@ -5,10 +5,10 @@

import torch

from apex import amp
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast

from direct.config import BaseConfig
from direct.data.mri_transforms import AddNames
@@ -20,7 +20,6 @@
from direct.functionals import SSIM



class RIMEngine(Engine):
def __init__(self, cfg: BaseConfig,
model: nn.Module,
@@ -54,30 +53,27 @@ def _do_iteration(self,
sensitivity_map_norm = modulus(sensitivity_map).sum('coil')
data['sensitivity_map'] = safe_divide(sensitivity_map, sensitivity_map_norm)

reconstruction_iter, hidden_state = self.model(
**data,
input_image=input_image,
hidden_state=hidden_state,
)
# TODO: Unclear why this refining is needed.
output_image = reconstruction_iter[-1].refine_names('batch', 'complex', 'height', 'width')

loss_dict = {k: torch.tensor([0.], dtype=target.dtype).to(self.device) for k in loss_fns.keys()}
for output_image_iter in reconstruction_iter:
for k, v in loss_dict.items():
loss_dict[k] = v + loss_fns[k](
output_image_iter.rename(None), target.rename(None), reduction='mean'
)

loss_dict = {k: v / len(reconstruction_iter) for k, v in loss_dict.items()}
loss = sum(loss_dict.values())
with autocast(enabled=self.mixed_precision):
reconstruction_iter, hidden_state = self.model(
**data,
input_image=input_image,
hidden_state=hidden_state,
)
# TODO: Unclear why this refining is needed.
output_image = reconstruction_iter[-1].refine_names('batch', 'complex', 'height', 'width')

loss_dict = {k: torch.tensor([0.], dtype=target.dtype).to(self.device) for k in loss_fns.keys()}
for output_image_iter in reconstruction_iter:
for k, v in loss_dict.items():
loss_dict[k] = v + loss_fns[k](
output_image_iter.rename(None), target.rename(None), reduction='mean'
)

loss_dict = {k: v / len(reconstruction_iter) for k, v in loss_dict.items()}
loss = sum(loss_dict.values())

if self.model.training:
if self.mixed_precision:
with amp.scale_loss(loss, self.__optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward() # type: ignore
self._scaler.scale(loss).backward()

# Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
hidden_state = hidden_state.detach()
15 changes: 15 additions & 0 deletions direct/nn/unet/config.py
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors
from dataclasses import dataclass
from typing import Tuple, List, Any

from direct.config.defaults import ModelConfig


@dataclass
class Unet2DConfig(ModelConfig):
in_channels: int = 1
out_channels: int = 1
num_filters: int = 16
num_pool_layers: int = 4
dropout_probability: float = 0.0
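
The new Unet2DConfig is a plain dataclass; with omegaconf 2.0, which this commit also updates to, such dataclasses are typically consumed as structured configs. A small illustrative sketch, not part of the commit, using a local stand-in for the config class so it runs on its own:

    from dataclasses import dataclass
    from omegaconf import OmegaConf

    @dataclass
    class Unet2DConfig:  # stand-in for direct.nn.unet.config.Unet2DConfig
        in_channels: int = 1
        out_channels: int = 1
        num_filters: int = 16
        num_pool_layers: int = 4
        dropout_probability: float = 0.0

    # Build a typed config and override a field; wrongly typed overrides raise a validation error.
    cfg = OmegaConf.structured(Unet2DConfig)
    cfg = OmegaConf.merge(cfg, OmegaConf.create({'num_filters': 32}))
    print(OmegaConf.to_yaml(cfg))
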
4 changes: 1 addition & 3 deletions docker/Dockerfile
@@ -49,15 +49,13 @@ RUN python -m pip install opencv-python simpleitk h5py -q
RUN python -m pip install runstats -q
RUN python -m pip install tb-nightly -q
RUN python -m pip install --pre omegaconf -q
RUN python -m pip pyxb
RUN python -m pip install pyxb
RUN python -m pip install git+https://github.com/ismrmrd/ismrmrd-python.git

USER root
# Create directories for input and output
RUN mkdir /data
RUN mkdir /direct && chmod 777 /direct
# Directory to copy or symlink data into. Needs to be 777 as DIRECT runs typically without root permissions.
RUN mkdir /input && chmod 777 /input

USER direct

1 change: 0 additions & 1 deletion requirements_dev.txt
@@ -3,7 +3,6 @@ mock>=4.0.2
recommonmark>=0.6.0
numpy>=1.18.1
omegaconf>=2.0.0rc24
apex>=0.1
h5py>=2.10.0
simpleitk>=1.2.4
ipython>=7.11.1
8 changes: 5 additions & 3 deletions tools/train_rim.py
@@ -21,10 +21,10 @@


def setup_train(run_name, training_root, validation_root, base_directory,
cfg_filename, checkpoint, device, num_workers, resume, machine_rank, debug):
cfg_filename, checkpoint, device, num_workers, resume, machine_rank, mixed_precision, debug):

cfg, experiment_directory, forward_operator, backward_operator, engine = \
setup_environment(run_name, base_directory, cfg_filename, device, machine_rank, debug=debug)
setup_environment(run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision, debug=debug)

# Create training and validation data
# Transforms configuration
@@ -109,6 +109,8 @@ def setup_train(run_name, training_root, validation_root, base_directory,
'this flag is ignored.'
)
parser.add_argument('--resume', help='Resume training if possible.', action='store_true')
parser.add_argument('--mixed-precision', help='Use mixed precision training.', action='store_true')


args = parser.parse_args()

@@ -123,5 +125,5 @@ def setup_train(run_name, training_root, validation_root, base_directory,
run_name, args.training_root, args.validation_root, args.experiment_directory,
args.cfg_file, args.initialization_checkpoint,
args.device, args.num_workers, args.resume,
args.machine_rank, args.debug)
args.machine_rank, args.mixed_precision, args.debug)
