
Dev (#136)
* comment improvements

* add __init__.py

* add disabling pin memory on fast dev run

* rename template_utils.py to utils.py

* improve README.md

* add num_classes property to mnist datamodule

* add sh to requirements.txt

* bump package versions
Łukasz Zalewski authored May 3, 2021
1 parent d5da445 commit 023d4bd
Showing 9 changed files with 27 additions and 28 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -96,7 +96,7 @@ The directory structure of new project looks like this:
│ ├── models <- Lightning models
│ ├── utils <- Utility scripts
│ │ ├── inference_example.py <- Example of inference with trained model
│ │ └── template_utils.py <- Extra features for the template
│ │ └── utils.py <- Extra features for the template
│ │
│ └── train.py <- Training pipeline
@@ -742,7 +742,7 @@ You can run DDP on mnist example with 4 GPUs like this:
python run.py trainer.gpus=4 +trainer.accelerator="ddp"
```
⚠️ When using DDP you have to be careful how you write your models - learn more [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html).
<br>
<br><br>
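
One concrete example of the kind of care DDP requires (a minimal hedged sketch, not part of this commit): values logged from validation or test steps should be synchronized across processes, e.g. with `sync_dist=True`.
```python
# Minimal sketch of a DDP-aware logging pattern (assumes the standard
# LightningModule API; the model body is omitted for brevity).
import torch.nn.functional as F
from pytorch_lightning import LightningModule


class DDPAwareModel(LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # sync_dist=True reduces the value across all DDP processes,
        # so each rank logs the same, globally averaged number.
        self.log("val/loss", loss, sync_dist=True)
        return loss
```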


### Extra Features
@@ -1062,7 +1062,7 @@ some_parameter: ${datamodule: some_param}
```
When later accessing this field, say in your lightning model, it will get automatically resolved based on all resolvers that are registered. Remember not to access this field before datamodule is initialized. **You also need to set resolve to false in print_config() in [run.py](run.py) method or it will throw errors!**
```python
template_utils.print_config(config, resolve=False)
utils.print_config(config, resolve=False)
```
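
For context, registering such a resolver might look like the sketch below. This is an illustrative assumption, not part of this commit; on older OmegaConf versions the call is `register_resolver` instead of `register_new_resolver`.
```python
# Hypothetical sketch: expose datamodule attributes to the config through a
# custom "datamodule" resolver, so `${datamodule: some_param}` can be resolved
# once the datamodule is initialized. The cache dict is illustrative only.
from omegaconf import OmegaConf

DATAMODULE_PARAMS = {}  # to be filled after the datamodule is initialized

OmegaConf.register_new_resolver(
    "datamodule", lambda name: DATAMODULE_PARAMS[name]
)
```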

</details>
6 changes: 3 additions & 3 deletions configs/hparams_search/mnist_optuna.yaml
@@ -1,9 +1,9 @@
# @package _global_

# example hyperparameter optimization of some experiment with Optuna:
# python run.py -m hparams_search=mnist_optuna experiment=exp_example_simple
# python run.py -m hparams_search=mnist_optuna experiment=exp_example_simple hydra.sweeper.n_trials=30
# python run.py -m hparams_search=mnist_optuna experiment=exp_example_simple logger=wandb
# python run.py -m hparams_search=mnist_optuna experiment=example_simple
# python run.py -m hparams_search=mnist_optuna experiment=example_simple hydra.sweeper.n_trials=30
# python run.py -m hparams_search=mnist_optuna experiment=example_simple logger=wandb

defaults:
- override /hydra/sweeper: optuna
5 changes: 3 additions & 2 deletions requirements.txt
@@ -5,14 +5,14 @@ pytorch-lightning>=1.2.10
torchmetrics

# --------- hydra --------- #
hydra-core==1.1.0.dev5
hydra-core==1.1.0.dev6
hydra-colorlog==1.1.0.dev1
hydra-optuna-sweeper==1.1.0.dev1
# hydra-ray-launcher==0.1.2
# hydra-submitit-launcher==1.1.0

# --------- loggers --------- #
wandb>=0.10.26
wandb>=0.10.28
# neptune-client
# comet-ml
# mlflow
@@ -21,6 +21,7 @@ wandb>=0.10.26
rich
python-dotenv
pytest
sh
pre-commit
scikit-learn
pandas
6 changes: 3 additions & 3 deletions run.py
@@ -13,19 +13,19 @@ def main(config: DictConfig):
# Imports should be nested inside @hydra.main to optimize tab completion
# Read more here: https://github.com/facebookresearch/hydra/issues/934
from src.train import train
from src.utils import template_utils
from src.utils import utils

# A couple of optional utilities:
# - disabling python warnings
# - easier access to debug mode
# - forcing debug friendly configuration
# - forcing multi-gpu friendly configuration
# You can safely get rid of this line if you don't want those
template_utils.extras(config)
utils.extras(config)

# Pretty print config using Rich library
if config.get("print_config"):
template_utils.print_config(config, resolve=True)
utils.print_config(config, resolve=True)

# Train model
return train(config)
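
For reference, the decorator and entry point sitting above this hunk are not shown in the diff; they typically look like the hedged sketch below, where the exact `config_path`/`config_name` values are assumptions.
```python
# Hedged sketch of the hydra entry point wrapping the main() shown above;
# config_path and config_name are assumed, not taken from this diff.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="configs/", config_name="config")
def main(config: DictConfig):
    from src.train import train
    from src.utils import utils

    utils.extras(config)
    if config.get("print_config"):
        utils.print_config(config, resolve=True)
    return train(config)


if __name__ == "__main__":
    main()
```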
1 change: 1 addition & 0 deletions src/__init__.py
@@ -0,0 +1 @@
#
2 changes: 2 additions & 0 deletions src/callbacks/wandb_callbacks.py
@@ -13,6 +13,8 @@


def get_wandb_logger(trainer: Trainer) -> WandbLogger:
"""Safely get Weights&Biases logger from Trainer."""

if isinstance(trainer.logger, WandbLogger):
return trainer.logger

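
As an example of how this helper is typically consumed, here is a hedged sketch; the callback below is illustrative and not part of this diff.
```python
# Hypothetical callback built on top of get_wandb_logger (defined above in
# this module): ask wandb to watch the model's gradients when training starts.
from pytorch_lightning import Callback, LightningModule, Trainer


class WatchModel(Callback):
    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        logger = get_wandb_logger(trainer)
        logger.watch(model=pl_module, log="gradients", log_freq=100)
```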
4 changes: 4 additions & 0 deletions src/datamodules/mnist_datamodule.py
@@ -52,6 +52,10 @@ def __init__(
self.data_val: Optional[Dataset] = None
self.data_test: Optional[Dataset] = None

@property
def num_classes(self) -> int:
return 10

def prepare_data(self):
"""Download data if needed. This method is called only from a single GPU.
Do not use it to assign state (self.x = y)."""
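
A hedged sketch of how the new `num_classes` property might be consumed; the `MNISTDataModule` class name and constructor arguments are assumptions inferred from the file path, not shown in this hunk.
```python
# Illustrative only: size a classifier head from the datamodule instead of
# hard-coding the number of MNIST classes.
from torch import nn

from src.datamodules.mnist_datamodule import MNISTDataModule  # assumed class name

datamodule = MNISTDataModule(data_dir="data/")  # assumed constructor signature
classifier_head = nn.Linear(in_features=64, out_features=datamodule.num_classes)
```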
10 changes: 5 additions & 5 deletions src/train.py
@@ -7,9 +7,9 @@
import hydra
from omegaconf import DictConfig

from src.utils import template_utils
from src.utils import utils

log = template_utils.get_logger(__name__)
log = utils.get_logger(__name__)


def train(config: DictConfig) -> Optional[float]:
@@ -59,7 +59,7 @@ def train(config: DictConfig) -> Optional[float]:

# Send some parameters from config to all lightning loggers
log.info("Logging hyperparameters!")
template_utils.log_hyperparameters(
utils.log_hyperparameters(
config=config,
model=model,
datamodule=datamodule,
@@ -79,7 +79,7 @@ def train(config: DictConfig) -> Optional[float]:

# Make sure everything closed properly
log.info("Finalizing!")
template_utils.finish(
utils.finish(
config=config,
model=model,
datamodule=datamodule,
@@ -91,7 +91,7 @@ def train(config: DictConfig) -> Optional[float]:
# Print path to best checkpoint
log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}")

# Return metric score for Optuna optimization
# Return metric score for hyperparameter optimization
optimized_metric = config.get("optimized_metric")
if optimized_metric:
return trainer.callback_metrics[optimized_metric]
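
The returned value is what a Hydra sweeper (e.g. Optuna) optimizes during multiruns. Below is a hedged, self-contained illustration of that contract; the metric names and values are made up.
```python
# Stand-in illustration of the return contract: `optimized_metric` must match a
# key that the model logged, because it is looked up in trainer.callback_metrics.
from omegaconf import OmegaConf

config = OmegaConf.create({"optimized_metric": "val/acc"})
callback_metrics = {"val/acc": 0.97, "val/loss": 0.12}  # stand-in for trainer.callback_metrics

optimized_metric = config.get("optimized_metric")
score = callback_metrics[optimized_metric] if optimized_metric else None
print(score)  # 0.97 -> the value the sweeper would optimize
```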
15 changes: 3 additions & 12 deletions src/utils/template_utils.py → src/utils/utils.py
@@ -42,7 +42,7 @@ def extras(config: DictConfig) -> None:
# enable adding new keys to config
OmegaConf.set_struct(config, False)

# disable python warnings if <config.disable_warnings=True>
# disable python warnings if <config.ignore_warnings=True>
if config.get("ignore_warnings"):
log.info(f"Disabling python warnings! <config.ignore_warnings=True>")
warnings.filterwarnings("ignore")
@@ -58,6 +58,8 @@ def extras(config: DictConfig) -> None:
# Debuggers don't like GPUs or multiprocessing
if config.trainer.get("gpus"):
config.trainer.gpus = 0
if config.datamodule.get("pin_memory"):
config.datamodule.pin_memory = False
if config.datamodule.get("num_workers"):
config.datamodule.num_workers = 0
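
The guard enclosing these lines is outside the shown context; based on the commit message ("add disabling pin memory on fast dev run"), the whole debug-friendly block plausibly looks like this hedged sketch.
```python
# Assumed shape of the "forcing debug friendly configuration" part of extras();
# the fast_dev_run check is inferred from the commit message, not the diff.
from omegaconf import DictConfig


def force_debug_friendly_config(config: DictConfig) -> None:
    if config.trainer.get("fast_dev_run"):
        # Debuggers don't like GPUs or multiprocessing
        if config.trainer.get("gpus"):
            config.trainer.gpus = 0
        if config.datamodule.get("pin_memory"):
            config.datamodule.pin_memory = False
        if config.datamodule.get("num_workers"):
            config.datamodule.num_workers = 0
```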

@@ -128,7 +130,6 @@ def log_hyperparameters(
"""This method controls which parameters from Hydra config are saved by Lightning loggers.
Additionally saves:
- sizes of train, val, test dataset
- number of trainable model parameters
"""

@@ -143,16 +144,6 @@
if "callbacks" in config:
hparams["callbacks"] = config["callbacks"]

# save sizes of each dataset
# (requires calling `datamodule.setup()` first to initialize datasets)
# datamodule.setup()
# if hasattr(datamodule, "data_train") and datamodule.data_train:
# hparams["datamodule/train_size"] = len(datamodule.data_train)
# if hasattr(datamodule, "data_val") and datamodule.data_val:
# hparams["datamodule/val_size"] = len(datamodule.data_val)
# if hasattr(datamodule, "data_test") and datamodule.data_test:
# hparams["datamodule/test_size"] = len(datamodule.data_test)

# save number of model parameters
hparams["model/params_total"] = sum(p.numel() for p in model.parameters())
hparams["model/params_trainable"] = sum(
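
The trainable-parameter line above is cut off by the collapsed diff; a standard formulation (assumed here, hedged) filters on `requires_grad`.
```python
# Hedged completion sketch, not a verbatim copy of the file: count total and
# trainable parameters of a model the way the hunk above appears to.
from torch import nn


def count_parameters(model: nn.Module) -> dict:
    return {
        "model/params_total": sum(p.numel() for p in model.parameters()),
        "model/params_trainable": sum(
            p.numel() for p in model.parameters() if p.requires_grad
        ),
    }
```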
