diff --git a/direct/config/defaults.py b/direct/config/defaults.py index e276d019..17f280b9 100644 --- a/direct/config/defaults.py +++ b/direct/config/defaults.py @@ -7,7 +7,7 @@ from direct.config import BaseConfig from direct.data.datasets_config import DatasetConfig -from typing import Optional, List, Callable +from typing import Optional, List @dataclass @@ -67,11 +67,12 @@ class TrainingConfig(BaseConfig): # Metrics metrics: List[str] = field(default_factory=lambda: []) + @dataclass class ValidationConfig(BaseConfig): datasets: List[DatasetConfig] = field(default_factory=lambda: [DatasetConfig()]) batch_size: int = 8 - metrics: Optional[List[str]] = None + metrics: List[str] = field(default_factory=lambda: []) @dataclass diff --git a/direct/engine.py b/direct/engine.py index 6ae49519..a52dd4dc 100644 --- a/direct/engine.py +++ b/direct/engine.py @@ -9,14 +9,14 @@ import numpy as np import warnings -from typing import Optional, Dict, Tuple, List, Callable -from apex import amp +from typing import Optional, Dict, Tuple, List from abc import abstractmethod, ABC from torch import nn from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader, Dataset, Sampler, BatchSampler +from torch.cuda.amp import GradScaler from direct.data.mri_transforms import AddNames from direct.data import sampler @@ -50,6 +50,7 @@ def __init__(self, cfg: BaseConfig, # TODO: This might not be needed, if these objects are changed in-place self.__optimizer = None self.__lr_scheduler = None + self._scaler = GradScaler(enabled=self.mixed_precision) self.__writers = None self.__bind_sigint_signal() @@ -71,6 +72,11 @@ def build_metrics(metrics_list) -> Dict: @abstractmethod def _do_iteration(self, *args, **kwargs) -> Tuple[torch.Tensor, Dict]: + """ + This is a placeholder for the iteration function. This needs to perform the backward pass. + If using mixed-precision you need to implement `autocast` as well in this function. 
+ It is recommended you raise an error if `self.mixed_precision` is true but mixed precision is not available. + """ pass @torch.no_grad() @@ -172,30 +178,28 @@ def training_loop(self, # Gradient accumulation if (iter_idx + 1) % self.cfg.training.gradient_steps == 0: # type: ignore - # TODO: Is this slow? This is a generator, so should be cheap. - # parameter_list = self.model.parameters() if not self.mixed_precision else \ - # amp.master_params(self.__optimizer) if self.cfg.training.gradient_steps > 1: # type: ignore - warnings.warn('Gradient accumulation set. Currently not implemented. ' - 'This message will only be displayed once.') - # for parameter in parameter_list: - # if parameter.grad is not None: - # # In-place division - # parameter.grad.div_(self.cfg.training.gradient_steps) # type: ignore + for parameter in self.model.parameters(): + if parameter.grad is not None: + # In-place division + parameter.grad.div_(self.cfg.training.gradient_steps) # type: ignore if self.cfg.training.gradient_clipping > 0.0: # type: ignore - warnings.warn('Gradient clipping set. Currently not implemented. ' - 'This message will only be displayed once.') - # torch.nn.utils.clip_grad_norm_(parameter_list, self.cfg.training.gradient_clipping) # type: ignore - # + self._scaler.unscale_(self.__optimizer) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.training.gradient_clipping) + # Gradient norm if self.cfg.training.gradient_debug: # type: ignore - warnings.warn(f'gradient debug set. Currently not implemented. ' + warnings.warn(f'Gradient debug set. This will affect training performance. Only use for debugging.' 
f' This message will only be displayed once.') - # parameters = list(filter(lambda p: p.grad is not None, parameter_list)) - # gradient_norm = sum([parameter.grad.data ** 2 for parameter in parameters]).sqrt() # typing: ignore - # storage.add_scalar('gradient_norm', gradient_norm) + parameters = list(filter(lambda p: p.grad is not None, self.model.parameters())) + gradient_norm = sum([parameter.grad.data ** 2 for parameter in parameters]).sqrt() # type: ignore + storage.add_scalar('train/gradient_norm', gradient_norm) + + # Same as self.__optimizer.step() for mixed precision. + self._scaler.step(self.__optimizer) + # Updates the scale for next iteration. + self._scaler.update() - self.__optimizer.step() # type: ignore # Incorrect inference by mypy and pyflake self.__lr_scheduler.step() # type: ignore # noqa storage.add_scalar('lr', self.__optimizer.param_groups[0]['lr'], smoothing_hint=False) @@ -321,18 +325,21 @@ def train(self, # Mixed precision setup. This requires the model to be on the gpu. 
git_hash = direct.utils.git_hash() - extra_checkpointing = {'__author__': git_hash if git_hash else 'N/A'} - if self.mixed_precision > 0: - opt_level = f'O{self.mixed_precision}' - self.logger.info(f'Using apex level {opt_level}.') - self.model, self.__optimizer = amp.initialize(self.model, self.__optimizer, opt_level=opt_level) - extra_checkpointing['amp'] = amp - extra_checkpointing['opt_level'] = opt_level + extra_checkpointing = {'__author__': git_hash if git_hash else 'N/A', + '__mixed_precision__': self.mixed_precision, + } + if self.mixed_precision: + # TODO(jt): Check if on GPU + self.logger.info(f'Using mixed precision training.') self.checkpointer = Checkpointer( self.model, experiment_directory, save_to_disk=communication.is_main_process(), - optimizer=optimizer, lr_scheduler=lr_scheduler, **self.models, **extra_checkpointing) + optimizer=optimizer, + lr_scheduler=lr_scheduler, + scaler=self._scaler, + **self.models, + **extra_checkpointing) # Load checkpoint start_iter = 0 @@ -340,8 +347,7 @@ def train(self, if resume: self.logger.info('Attempting to resume...') # This changes the model inplace - checkpoint = self.checkpointer.load( - iteration='latest', checkpointable_objects=['amp'] if self.mixed_precision > 0 else []) + checkpoint = self.checkpointer.load(iteration='latest') if not checkpoint: self.logger.info('No checkpoint found. Starting from scratch.') else: @@ -364,17 +370,21 @@ def train(self, if '__datetime__' in checkpoint: self.logger.info(f"Checkpoint created at: {checkpoint['__datetime__']}.") - if 'opt_level' in checkpoint: - if checkpoint['opt_level'] != opt_level: - self.logger.warning(f"Mixed precision opt-levels do not match. " - f"Requested {opt_level} got {checkpoint['opt_level']} from checkpoint. 
" + if '__mixed_precision__' in checkpoint: + if (not self.mixed_precision) and checkpoint['__mixed_precision__']: + self.logger.warning(f'Mixed precision training is not enabled, yet the saved checkpoint requests this. ' + f'Will now enable mixed precision.') + self.mixed_precision = True + elif not checkpoint['__mixed_precision__'] and self.mixed_precision: + self.logger.warning(f"Mixed precision levels of training and loading checkpoint do not match. " + f"Requested mixed precision but checkpoint is saved without. " f"This will almost surely lead to performance degradation.") self.logger.info(f'World size: {communication.get_world_size()}.') self.logger.info(f'Device count: {torch.cuda.device_count()}.') if communication.get_world_size() > 1: self.model = DistributedDataParallel( - self.model, device_ids=[communication.get_rank()], broadcast_buffers=False) + self.model, device_ids=[communication.get_local_rank()], broadcast_buffers=False) # World size > 1 if distributed mode, else allow a DataParallel fallback, can be convenient for debugging. 
elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1: diff --git a/direct/environment.py b/direct/environment.py index 599fdd87..82b26a70 100644 --- a/direct/environment.py +++ b/direct/environment.py @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) -def setup_environment(run_name, base_directory, cfg_filename, device, machine_rank, debug=False): +def setup_environment(run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision, debug=False): experiment_dir = base_directory / run_name if communication.get_local_rank() == 0: @@ -86,7 +86,7 @@ def setup_environment(run_name, base_directory, cfg_filename, device, machine_ra logger.error(f'Engine does not exist for {cfg_from_file.model_name} (err = {e}).') sys.exit(-1) - engine = engine_class(cfg, model, device=device) + engine = engine_class(cfg, model, device=device, mixed_precision=mixed_precision) return cfg, experiment_dir, forward_operator, backward_operator, engine diff --git a/direct/nn/rim/rim_engine.py b/direct/nn/rim/rim_engine.py index 3d8272a6..d0524edf 100644 --- a/direct/nn/rim/rim_engine.py +++ b/direct/nn/rim/rim_engine.py @@ -5,10 +5,10 @@ import torch -from apex import amp from torch import nn from torch.nn import functional as F from torch.utils.data import DataLoader +from torch.cuda.amp import autocast from direct.config import BaseConfig from direct.data.mri_transforms import AddNames @@ -20,7 +20,6 @@ from direct.functionals import SSIM - class RIMEngine(Engine): def __init__(self, cfg: BaseConfig, model: nn.Module, @@ -54,30 +53,27 @@ def _do_iteration(self, sensitivity_map_norm = modulus(sensitivity_map).sum('coil') data['sensitivity_map'] = safe_divide(sensitivity_map, sensitivity_map_norm) - reconstruction_iter, hidden_state = self.model( - **data, - input_image=input_image, - hidden_state=hidden_state, - ) - # TODO: Unclear why this refining is needed. 
- output_image = reconstruction_iter[-1].refine_names('batch', 'complex', 'height', 'width') - - loss_dict = {k: torch.tensor([0.], dtype=target.dtype).to(self.device) for k in loss_fns.keys()} - for output_image_iter in reconstruction_iter: - for k, v in loss_dict.items(): - loss_dict[k] = v + loss_fns[k]( - output_image_iter.rename(None), target.rename(None), reduction='mean' - ) - - loss_dict = {k: v / len(reconstruction_iter) for k, v in loss_dict.items()} - loss = sum(loss_dict.values()) + with autocast(enabled=self.mixed_precision): + reconstruction_iter, hidden_state = self.model( + **data, + input_image=input_image, + hidden_state=hidden_state, + ) + # TODO: Unclear why this refining is needed. + output_image = reconstruction_iter[-1].refine_names('batch', 'complex', 'height', 'width') + + loss_dict = {k: torch.tensor([0.], dtype=target.dtype).to(self.device) for k in loss_fns.keys()} + for output_image_iter in reconstruction_iter: + for k, v in loss_dict.items(): + loss_dict[k] = v + loss_fns[k]( + output_image_iter.rename(None), target.rename(None), reduction='mean' + ) + + loss_dict = {k: v / len(reconstruction_iter) for k, v in loss_dict.items()} + loss = sum(loss_dict.values()) if self.model.training: - if self.mixed_precision: - with amp.scale_loss(loss, self.__optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss.backward() # type: ignore + self._scaler.scale(loss).backward() # Detach hidden state from computation graph, to ensure loss is only computed per RIM block. 
hidden_state = hidden_state.detach() diff --git a/direct/nn/unet/config.py b/direct/nn/unet/config.py new file mode 100644 index 00000000..81f5769d --- /dev/null +++ b/direct/nn/unet/config.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright (c) DIRECT Contributors +from dataclasses import dataclass +from typing import Tuple, List, Any + +from direct.config.defaults import ModelConfig + + +@dataclass +class Unet2DConfig(ModelConfig): + in_channels: int = 1 + out_channels: int = 1 + num_filters: int = 16 + num_pool_layers: int = 4 + dropout_probability: float = 0.0 diff --git a/docker/Dockerfile b/docker/Dockerfile index bda5cb33..9eb5db3f 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -49,15 +49,13 @@ RUN python -m pip install opencv-python simpleitk h5py -q RUN python -m pip install runstats -q RUN python -m pip install tb-nightly -q RUN python -m pip install --pre omegaconf -q -RUN python -m pip pyxb +RUN python -m pip install pyxb RUN python -m pip install git+https://github.com/ismrmrd/ismrmrd-python.git USER root # Create directories for input and output RUN mkdir /data RUN mkdir /direct && chmod 777 /direct -# Directory to copy or symlink data into. Needs to be 777 as DIRECT runs typically without root permissions. 
-RUN mkdir /input && chmod 777 /input USER direct diff --git a/requirements_dev.txt b/requirements_dev.txt index 8e091a1d..86ee4852 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -3,7 +3,6 @@ mock>=4.0.2 recommonmark>=0.6.0 numpy>=1.18.1 omegaconf>=2.0.0rc24 -apex>=0.1 h5py>=2.10.0 simpleitk>=1.2.4 ipython>=7.11.1 diff --git a/tools/train_rim.py b/tools/train_rim.py index 189a97b6..46c92cfd 100644 --- a/tools/train_rim.py +++ b/tools/train_rim.py @@ -21,10 +21,10 @@ def setup_train(run_name, training_root, validation_root, base_directory, - cfg_filename, checkpoint, device, num_workers, resume, machine_rank, debug): + cfg_filename, checkpoint, device, num_workers, resume, machine_rank, mixed_precision, debug): cfg, experiment_directory, forward_operator, backward_operator, engine = \ - setup_environment(run_name, base_directory, cfg_filename, device, machine_rank, debug=debug) + setup_environment(run_name, base_directory, cfg_filename, device, machine_rank, mixed_precision, debug=debug) # Create training and validation data # Transforms configuration @@ -109,6 +109,8 @@ def setup_train(run_name, training_root, validation_root, base_directory, 'this flag is ignored.' ) parser.add_argument('--resume', help='Resume training if possible.', action='store_true') + parser.add_argument('--mixed-precision', help='Use mixed precision training.', action='store_true') + args = parser.parse_args() @@ -123,5 +125,5 @@ def setup_train(run_name, training_root, validation_root, base_directory, run_name, args.training_root, args.validation_root, args.experiment_directory, args.cfg_file, args.initialization_checkpoint, args.device, args.num_workers, args.resume, - args.machine_rank, args.debug) + args.machine_rank, args.mixed_precision, args.debug)