Commit
Merge pull request #14 from airaria/v0.1.10dev
V0.1.10dev
airaria committed Jul 16, 2020
2 parents 153be24 + 2c50bbc commit d69a96e
Showing 11 changed files with 134 additions and 31 deletions.
7 changes: 6 additions & 1 deletion README.md
@@ -28,6 +28,11 @@ Paper: [https://arxiv.org/abs/2002.12620](https://arxiv.org/abs/2002.12620)

## Update

**Jul 14, 2020**
* Updated to 0.1.10:
* Now supports mixed precision training with Apex! Just set `fp16` to `True` in `TrainingConfig`. See the documentation of `TrainingConfig` for details.
* Added a `data_parallel` option in `TrainingConfig` so that data parallel training and mixed precision training can work together (a short configuration sketch follows below).
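For illustration, a minimal sketch of how the two new options are meant to be used together (the `fp16`, `fp16_opt_level`, and `data_parallel` parameters come from this release; the top-level import and the directory values are assumptions):

```python
from textbrewer import TrainingConfig

# Mixed precision training (requires Apex) combined with data parallel training.
# With data_parallel=True, TextBrewer wraps the models in torch.nn.DataParallel
# itself, so the models should not be wrapped manually.
train_config = TrainingConfig(
    fp16=True,                 # enable Apex mixed precision training
    fp16_opt_level='O1',       # Apex optimization level: "O0", "O1", "O2", or "O3"
    data_parallel=True,        # let TextBrewer handle DataParallel wrapping
    log_dir='logs',            # assumed paths for illustration
    output_dir='saved_models'
)
```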

**Apr 26, 2020**

* Added Chinese NER task (MSRA NER) results.
@@ -116,6 +121,7 @@ See [Full Documentation](https://textbrewer.readthedocs.io/) for detailed usages
* NumPy
* tqdm
* Transformers >= 2.0 (optional, used by some examples)
* Apex == 0.1.0 (optional, mixed precision training)

* Install from PyPI

@@ -392,7 +398,6 @@ We recommend that users use pre-trained student models whenever possible to full

## Known Issues

* Compatibility with FP16 training has not been tested.
* Multi-GPU training support is only available through `DataParallel` currently.

## Citation
7 changes: 6 additions & 1 deletion README_ZH.md
@@ -27,6 +27,11 @@ Paper: [https://arxiv.org/abs/2002.12620](https://arxiv.org/abs/2002.12620)

## Update

**Jul 14, 2020**
* **Updated to 0.1.10**:
* Supports mixed precision training with Apex: enable it by setting `fp16=True` in `TrainingConfig`. See the documentation of `TrainingConfig` for details.
* Added a `data_parallel` option in `TrainingConfig` so that data parallel training and mixed precision training can be enabled together.

**Apr 26, 2020**

* Added results on the Chinese NER task (MSRA NER).
@@ -115,6 +120,7 @@ Paper: [https://arxiv.org/abs/2002.12620](https://arxiv.org/abs/2002.12620)
* NumPy
* tqdm
* Transformers >= 2.0 (optional, required by the Transformers-related examples)
* Apex == 0.1.0 (optional, for mixed precision training)

### Installation

@@ -381,7 +387,6 @@ Distillers carry out the actual distillation. The following distillers are currently implemented:

## Known Issues

* Compatibility with FP16 training has not been tested.
* Multi-GPU training is only supported via `DataParallel`.

## Citation
2 changes: 1 addition & 1 deletion setup.py
@@ -29,7 +29,7 @@

setup(
name="textbrewer",
version="0.1.9",
version="0.1.10",
author="ziqingyang",
author_email="[email protected]",
description="PyTorch-based knowledge distillation toolkit for natural language processing",
2 changes: 1 addition & 1 deletion src/textbrewer/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.1.9"
__version__ = "0.1.10"

from .distillers import BasicTrainer
from .distillers import BasicDistiller
10 changes: 9 additions & 1 deletion src/textbrewer/compatibility.py
@@ -3,4 +3,12 @@
if torch.__version__ < '1.2':
mask_dtype = torch.uint8
else:
mask_dtype = torch.bool
mask_dtype = torch.bool

def is_apex_available():
try:
from apex import amp
_has_apex = True
except ImportError:
_has_apex = False
return _has_apex
24 changes: 18 additions & 6 deletions src/textbrewer/configurations.py
@@ -43,18 +43,24 @@ class TrainingConfig(Config):
ckpt_steps (int): if *num_steps* is passed to ``distiller.train()``, saves the model every **ckpt_steps** steps and ignores `ckpt_frequency` and `ckpt_epoch_frequency`.
log_dir (str): directory to save the tensorboard log file. Set it to ``None`` to disable tensorboard.
output_dir (str): directory to save model weights.
device (str or torch.device) : training on CPU or GPU.
device (str or torch.device): training on CPU or GPU.
fp16 (bool): if ``True``, enables mixed precision training using Apex.
fp16_opt_level (str): Pure or mixed precision optimization level. Accepted values are "O0", "O1", "O2", and "O3". See the Apex documentation for details.
data_parallel (bool): If ``True``, wraps the models with ``torch.nn.DataParallel``.
Note:
* To perform data parallel training, you can either wrap the models with ``torch.nn.DataParallel`` outside TextBrewer yourself, or leave the work to TextBrewer by setting **data_parallel** to ``True``.
* To enable both data parallel training and mixed precision training, you should set **data_parallel** to ``True``, and DO NOT wrap the models by yourself.
* In some experiments we have observed a slowdown with ``torch.nn.DataParallel``. We will move to ``DistributedDataParallel`` in the future.
Example::
# Usually you only need to set log_dir and output_dir and leave the others at their defaults
train_config = TrainingConfig(log_dir=my_log_dir, output_dir=my_output_dir)
# Stores model at the end of each epoch
# Stores the model at the end of each epoch
train_config = TrainingConfig(ckpt_frequency=1, ckpt_epoch_frequency=1)
# Stores model twice (at the middle and at the end) in each epoch
# Stores the model twice (at the middle and at the end) in each epoch
train_config = TrainingConfig(ckpt_frequency=2, ckpt_epoch_frequency=1)
# Stores model once every two epochs
# Stores the model once every two epochs
train_config = TrainingConfig(ckpt_frequency=1, ckpt_epoch_frequency=2)
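# Enable mixed precision training with Apex together with data parallel training
# (illustrative combination of the new options; requires Apex to be installed)
train_config = TrainingConfig(fp16=True, data_parallel=True)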
"""
@@ -64,7 +70,10 @@ def __init__(self,gradient_accumulation_steps = 1,
ckpt_steps = None,
log_dir = None,
output_dir = './saved_models',
device = 'cuda'
device = 'cuda',
fp16 = False,
fp16_opt_level = 'O1',
data_parallel = False
):
super(TrainingConfig, self).__init__()

@@ -75,6 +84,9 @@ def __init__(self,gradient_accumulation_steps = 1,
self.log_dir = log_dir
self.output_dir = output_dir
self.device = device
self.fp16 = fp16
self.fp16_opt_level = fp16_opt_level
self.data_parallel = data_parallel

if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
3 changes: 1 addition & 2 deletions src/textbrewer/distillation.py
@@ -120,8 +120,7 @@ def __init__(self, train_config,

def save_and_callback(self,global_step, step, epoch, callback):
logger.info(f"Saving at global step {global_step}, epoch step {step + 1} epoch {epoch+1}")
coreModel = self.model_S.module if \
'DataParallel' in self.model_S.__class__.__name__ else self.model_S
coreModel = self.model_S.module if hasattr(self.model_S, "module") else self.model_S
state_dict = coreModel.state_dict()
torch.save(state_dict, os.path.join(self.t_config.output_dir, f"gs{global_step}.pkl"))
if callback is not None:
42 changes: 36 additions & 6 deletions src/textbrewer/distiller_basic.py
@@ -25,8 +25,7 @@ def __init__(self, train_config,

def save_and_callback(self,global_step, step, epoch, callback):
logger.info(f"Saving at global step {global_step}, epoch step {step + 1} epoch {epoch+1}")
coreModel = self.model_S.module if \
'DataParallel' in self.model_S.__class__.__name__ else self.model_S
coreModel = self.model_S.module if hasattr(self.model_S, "module") else self.model_S
state_dict = coreModel.state_dict()
torch.save(state_dict, os.path.join(self.t_config.output_dir, f"gs{global_step}.pkl"))
if callback is not None:
@@ -77,6 +76,23 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
# overwrite scheduler
scheduler = scheduler_class(**{'optimizer':optimizer},**scheduler_args)

if self.t_config.fp16:
if not has_apex:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
if isinstance(self.model_T,(list,tuple)):
models = [self.model_S] + list(self.model_T)
models, optimizer = amp.initialize(models, optimizer, opt_level=self.t_config.fp16_opt_level)
self.model_S = models[0]
self.model_T = models[1:]
else:
(self.model_S, self.model_T), optimizer = amp.initialize([self.model_S, self.model_T], optimizer, opt_level=self.t_config.fp16_opt_level)
if self.t_config.data_parallel:
self.model_S = torch.nn.DataParallel(self.model_S)
if isinstance(self.model_T,(list,tuple)):
self.model_T = [torch.nn.DataParallel(model_t) for model_t in self.model_T]
else:
self.model_T = torch.nn.DataParallel(self.model_T)

if num_steps is not None:
if self.d_config.is_caching_logits is True:
logger.warning("is_caching_logits is True, but num_steps is not None!")
@@ -96,14 +112,21 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
batch = batch_postprocessor(batch)
total_loss = self.train_on_batch(batch,args)
total_loss /= self.t_config.gradient_accumulation_steps
total_loss.backward()
if self.t_config.fp16:
with amp.scale_loss(total_loss,optimizer) as scaled_loss:
scaled_loss.backward()
else:
total_loss.backward()

self.write_loss(total_loss, writer_step)
writer_step += 1

if (step+1)%self.t_config.gradient_accumulation_steps == 0:
if max_grad_norm > 0:
torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
if self.t_config.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
optimizer.step()
if scheduler is not None:
scheduler.step()
@@ -153,14 +176,21 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
batch = batch_postprocessor(batch)
total_loss = self.train_on_batch(batch,args)
total_loss /= self.t_config.gradient_accumulation_steps
total_loss.backward()
if self.t_config.fp16:
with amp.scale_loss(total_loss,optimizer) as scaled_loss:
scaled_loss.backward()
else:
total_loss.backward()

self.write_loss(total_loss, writer_step)
writer_step += 1

if (step+1)%self.t_config.gradient_accumulation_steps == 0:
if max_grad_norm > 0:
torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
if self.t_config.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
optimizer.step()
if scheduler is not None:
scheduler.step()
24 changes: 21 additions & 3 deletions src/textbrewer/distiller_multitask.py
@@ -56,6 +56,18 @@ def train(self, optimizer, dataloaders, num_steps, scheduler_class=None, schedul
# overwrite scheduler
scheduler = scheduler_class(**{'optimizer':optimizer},**scheduler_args)

if self.t_config.fp16:
if not has_apex:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
tasknames, model_Ts = zip(*self.model_T.items())
models = [self.model_S] + list(model_Ts)
models, optimizer = amp.initialize(models, optimizer, opt_level=self.t_config.fp16_opt_level)
self.model_S = models[0]
self.model_T = dict(zip(tasknames,models[1:]))
if self.t_config.data_parallel:
self.model_S = torch.nn.DataParallel(self.model_S)
self.model_T = {k:torch.nn.DataParallel(v) for k,v in self.model_T.items()}

total_global_steps = num_steps
ckpt_steps =self.t_config.ckpt_steps
print_every = ckpt_steps // self.print_freq
@@ -93,13 +105,19 @@ def train(self, optimizer, dataloaders, num_steps, scheduler_class=None, schedul
batch_taskname = (batch, taskname)
total_loss = self.train_on_batch(batch_taskname, args)
total_loss /= self.t_config.gradient_accumulation_steps
total_loss.backward()

if self.t_config.fp16:
with amp.scale_loss(total_loss,optimizer) as scaled_loss:
scaled_loss.backward()
else:
total_loss.backward()
scalar_total_loss = total_loss.cpu().item() * self.t_config.gradient_accumulation_steps
self.tb_writer.add_scalar('scalar/total_loss', scalar_total_loss, writer_step)
writer_step += 1
if max_grad_norm > 0:
torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
if self.t_config.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.model_S.parameters(), max_grad_norm)
optimizer.step()
if scheduler is not None:
scheduler.step()
37 changes: 29 additions & 8 deletions src/textbrewer/distiller_train.py
@@ -41,6 +41,15 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
# overwrite scheduler
scheduler = scheduler_class(**{'optimizer':optimizer},**scheduler_args)

if self.t_config.fp16:
if not has_apex:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
self.model, optimizer = amp.initialize(self.model, optimizer, opt_level=self.t_config.fp16_opt_level)

# data parallel multi-GPU training
if self.t_config.data_parallel:
self.model = torch.nn.DataParallel(self.model)

if num_steps is not None:
total_global_steps = num_steps
ckpt_steps =self.t_config.ckpt_steps
@@ -58,15 +67,22 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
batch = batch_postprocessor(batch)
total_loss = self.train_on_batch(batch,args)
total_loss /= self.t_config.gradient_accumulation_steps
total_loss.backward()
if self.t_config.fp16:
with amp.scale_loss(total_loss,optimizer) as scaled_loss:
scaled_loss.backward()
else:
total_loss.backward()

scalar_total_loss = total_loss.cpu().item() * self.t_config.gradient_accumulation_steps
self.tb_writer.add_scalar('scalar/total_loss', scalar_total_loss, writer_step)
writer_step += 1

if (step+1)%self.t_config.gradient_accumulation_steps == 0:
if max_grad_norm > 0:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
if self.t_config.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
optimizer.step()
if scheduler is not None:
scheduler.step()
@@ -76,8 +92,7 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
logger.info(f"Global step: {global_step}, epoch step:{step+1}")
if (global_step%ckpt_steps==0) or global_step==total_global_steps:
logger.info(f"Saving at global step {global_step}")
coreModel = self.model.module if \
'DataParallel' in self.model.__class__.__name__ else self.model
coreModel = self.model.module if hasattr(self.model, "module") else self.model
state_dict = coreModel.state_dict()
torch.save(state_dict, os.path.join(self.t_config.output_dir,f"gs{global_step}.pkl"))
if callback is not None:
@@ -106,15 +121,22 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
batch = batch_postprocessor(batch)
total_loss = self.train_on_batch(batch,args)
total_loss /= self.t_config.gradient_accumulation_steps
total_loss.backward()
if self.t_config.fp16:
with amp.scale_loss(total_loss,optimizer) as scaled_loss:
scaled_loss.backward()
else:
total_loss.backward()

scalar_total_loss = total_loss.cpu().item() * self.t_config.gradient_accumulation_steps
self.tb_writer.add_scalar('scalar/total_loss', scalar_total_loss, writer_step)
writer_step += 1

if (step+1)%self.t_config.gradient_accumulation_steps == 0:
if max_grad_norm > 0:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
if self.t_config.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
optimizer.step()
if scheduler is not None:
scheduler.step()
Expand All @@ -125,8 +147,7 @@ def train(self, optimizer, dataloader, num_epochs, scheduler_class=None, schedul
if (global_step%train_steps_per_epoch in checkpoints) \
and ((current_epoch+1)%self.t_config.ckpt_epoch_frequency==0 or current_epoch+1==num_epochs):
logger.info(f"Saving at global step {global_step}, epoch step {step+1} epoch {current_epoch+1}")
coreModel = self.model.module if \
'DataParallel' in self.model.__class__.__name__ else self.model
coreModel = self.model.module if hasattr(self.model, "module") else self.model
state_dict = coreModel.state_dict()
torch.save(state_dict, os.path.join(self.t_config.output_dir,f"gs{global_step}.pkl"))
if callback is not None:
7 changes: 6 additions & 1 deletion src/textbrewer/distiller_utils.py
@@ -13,7 +13,12 @@
from .presets import *
from .configurations import TrainingConfig, DistillationConfig
import random
from .compatibility import mask_dtype
from .compatibility import mask_dtype, is_apex_available

has_apex = is_apex_available()
if has_apex:
from apex import amp


logger = logging.getLogger("Distillation")
#logger.setLevel(logging.INFO)
