Update Dropout (Dropout Reduces Underfitting)
Please check run_train_gpu.sh and ./models/dense_model/
Example
```python
from torch import nn

# Inside your LightningModule:

# 1) add an update_dropout method that sets every nn.Dropout to the given rate
def update_dropout(self, drop_rate):
    self.drop_rate = drop_rate
    for module in self.modules():
        if isinstance(module, nn.Dropout):
            module.p = drop_rate

# 2) build the schedule at train start (one dropout value per global training step)
def on_train_start(self):
    from models.dense_model.drop_scheduler import drop_scheduler

    self.drop_scheduler = {}
    if self.args.dropout_p > 0.0:
        self.drop_scheduler["do"] = drop_scheduler(
            self.args.dropout_p,
            self.args.max_epochs,
            self.trainer.num_training_batches,
            self.args.cutoff_epoch,
            self.args.drop_mode,
            self.args.drop_schedule,
        )
        print(
            "on_train_start :: Min DO = %.7f, Max DO = %.7f"
            % (min(self.drop_scheduler["do"]), max(self.drop_scheduler["do"]))
        )

# 3) finally, schedule the dropout probability in your training_step
if "do" in self.drop_scheduler:
    dropout_p = self.drop_scheduler["do"][self.trainer.global_step]
    self.update_dropout(dropout_p)
    self.log("dropout_p", dropout_p, sync_dist=(self.device.type != "cpu"))
```
Pass the scheduler arguments in your run script (e.g. run_train_gpu.sh):

```bash
--dropout_p=0.1 \
--cutoff_epoch=1 \
--drop_mode=standard \
--drop_schedule=constant
```
If you want plain dropout, just pass dropout_p and keep drop_mode=standard (default) and drop_schedule=constant (default).
You can check the dropout scheduling progress in wandb (via the logged dropout_p value).
Very simple, but writing this boilerplate by hand is tedious ⚡
If you need a feature or anything else, please open an issue (in English or Korean). I will reply and implement it ASAP!!
- DataModule, more detail: PyTorch-Lightning Dev Guide
- Model, more detail: PyTorch-Lightning Dev Guide
- Inference, more detail: PyTorch-Lightning Dev Guide
- WandB with Lightning, more detail: https://docs.wandb.ai/v/ko/quickstart
- Uses DDP, not DP or CPU
  - If you want to use DP or CPU instead, change the relevant arguments or the Python script (see more detail: PyTorch-Lightning Dev Guide)
- Optimizer: AdamW
- LearningRate Scheduler: OneCycleLR
  - see more detail: PyTorch Dev Guide (a configure_optimizers sketch follows after this list)
- Monitoring Tool: WandB
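As referenced above (AdamW + OneCycleLR), here is a minimal configure_optimizers sketch. The hyperparameter name `self.args.lr` is an assumption, not the template's exact code.

```python
import torch

def configure_optimizers(self):
    optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=self.args.lr,
        total_steps=self.trainer.estimated_stepping_batches,  # computed by Lightning
    )
    # step the scheduler every batch, as OneCycleLR expects
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }
```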
- train.py(main) -> argparse
- uses the simple_parsing library, which works like HFArgumentParser
- Trainer arguments are added with pl.Trainer.add_argparse_args (defines the argparse arguments automatically)
- def] WandbLogger, set seed(os, random, np, torch, torch.cuda)
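A rough sketch of the train.py setup described above; the dataclass fields, wandb project name, and seed value are illustrative assumptions, not the template's exact code.

```python
from dataclasses import dataclass

import pytorch_lightning as pl
import simple_parsing
from pytorch_lightning.loggers import WandbLogger

@dataclass
class TrainArgs:
    dropout_p: float = 0.1        # maximum dropout probability
    cutoff_epoch: int = 1         # epoch at which the schedule changes
    drop_mode: str = "standard"   # standard / early / late
    drop_schedule: str = "constant"

parser = simple_parsing.ArgumentParser()
parser.add_arguments(TrainArgs, dest="train")
parser = pl.Trainer.add_argparse_args(parser)  # Trainer flags added automatically
args = parser.parse_args()

wandb_logger = WandbLogger(project="pytorch-lightning-template")
pl.seed_everything(42)  # seeds random, numpy, torch (and sets PYTHONHASHSEED)
```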
- def] CustomDataModule (LightningDataModule)
  - You do not have to use a LightningDataModule, but if you implement everything inside the LightningModule, the source code gets messy.
  - The important DataModule hooks are prepare_data and setup (a minimal sketch follows after this list):
    - prepare_data runs only on CPU in a single process (warning: with distributed training, variables set here are not shared across processes). I recommend using it only to download or save datasets.
    - setup runs on GPU or CPU and on every distributed process; use it for mapping, loading, and similar work. setup receives a stage argument (fit (train), test, predict).
  - The DataModule can define a dataloader for each stage (use the defaults or customize as needed).
  - The Dataset can be defined in this section, or written as a separate Python script and simply imported and used!
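A minimal sketch of a CustomDataModule following the prepare_data / setup split described above; the dataset (MNIST), paths, and batch size are illustrative assumptions, not the template's actual data.

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

class CustomDataModule(pl.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        # single process, CPU only: download / save datasets here
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # runs on every process (CPU or GPU): build the actual Dataset objects
        if stage in (None, "fit"):
            full = MNIST(self.data_dir, train=True, transform=transforms.ToTensor())
            self.train_set, self.val_set = random_split(full, [55000, 5000])
        if stage in (None, "test"):
            self.test_set = MNIST(self.data_dir, train=False, transform=transforms.ToTensor())

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size)
```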
- def] CustomNet (LightningModule)
  - Lightning provides hooks for each step / step_end and epoch / epoch_end
  - I think using just training_step, validation_step, and validation_epoch_end is the simplest and best approach
- training_step -> forward -> configure_optimizers
  - metrics are computed in each validation_step (per-batch validation) -> gathered in validation_epoch_end (all batch results) -> logged (to wandb); see the sketch below
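A sketch of that validation flow inside the LightningModule; the metric names and the classification loss are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.cross_entropy(logits, y)
    preds = logits.argmax(dim=-1)
    # return per-batch results; Lightning gathers them for validation_epoch_end
    return {"val_loss": loss, "correct": (preds == y).sum(), "total": y.numel()}

def validation_epoch_end(self, outputs):
    # aggregate all batch results and log one epoch-level metric (goes to wandb)
    val_loss = torch.stack([o["val_loss"] for o in outputs]).mean()
    val_acc = sum(o["correct"] for o in outputs) / sum(o["total"] for o in outputs)
    self.log("val_loss", val_loss, sync_dist=True)
    self.log("val_acc", val_acc, sync_dist=True)
```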
- wandb logger additional setting
- checkpoint setting
  - the monitor name must match the log name used in your step
- learning_rate monitor setting
- ddp strategy modify
  - if your dataset is very large for DDP, increase the timeout parameter accordingly (see the sketch below)
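A sketch of the checkpoint / learning-rate-monitor / DDP settings listed above; the monitored metric and the timeout value are assumptions, and the exact DDPStrategy arguments depend on your Lightning version.

```python
from datetime import timedelta

from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy

# the monitor name must match the name you used in self.log(...)
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=3)
lr_cb = LearningRateMonitor(logging_interval="step")

# for very large datasets, raise the process-group timeout so DDP setup does not time out
ddp = DDPStrategy(timeout=timedelta(minutes=180))
```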
- this is hard to set up with HuggingFace, but with Lightning it feels effortless
- build the Trainer from your arguments
- run training and save the model!
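A sketch of building the Trainer from the parsed arguments and starting the run; the object names (args, wandb_logger, callbacks, model, datamodule) carry over from the sketches above and are assumptions, not the template's exact code.

```python
import pytorch_lightning as pl

trainer = pl.Trainer.from_argparse_args(
    args,
    logger=wandb_logger,
    callbacks=[checkpoint_cb, lr_cb],
    strategy=ddp,
)
# trains the model; checkpoints are saved by the ModelCheckpoint callback
trainer.fit(model, datamodule=datamodule)
```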
- cd to your project root (./pytorch-lightning-template)

```bash
# Don't run the script from inside the scripts folder! Check, please!
bash scripts/run_train_~~~.sh
```
- inference.py(main) -> argparse
- set seed
- load the model (the second parameter is your model's init parameters)
- simply run torch inference & done!
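A minimal sketch of the inference flow above; the checkpoint path, model class name (CustomNet), and input batch are illustrative assumptions.

```python
import pytorch_lightning as pl
import torch

pl.seed_everything(42)

# the second parameter passes your model's __init__ arguments
model = CustomNet.load_from_checkpoint("checkpoints/best.ckpt", args=args)
model.eval()

with torch.no_grad():
    preds = model(batch)  # plain torch inference on a prepared batch
```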
- cd to your project root (./pytorch-lightning-template)

```bash
# Don't run the script from inside the scripts folder! Check, please!
bash scripts/run_inference~~~.sh
```
- run pip_install_deepspeed.sh

```bash
bash pip_install_deepspeed.sh
```