-
Notifications
You must be signed in to change notification settings - Fork 11
/
pretrain_glm.py
194 lines (155 loc) · 6.89 KB
/
pretrain_glm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain GLM"""
import torch
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt_dataset import build_dataset_group
from megatron.model import GPTModel, GPTModelPipe
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices
from megatron.utils import average_losses_across_data_parallel_group
from glm import build_train_valid_test_datasets, build_mask_matrix
import deepspeed
from deepspeed.runtime.utils import see_memory_usage
import os
try:
    from torch.distributed.elastic.multiprocessing.errors import record
except ImportError:
    def record(fn):
        """Fallback decorator used when the torch elastic ``record`` import fails.

        Hands back ``fn`` untouched, so ``@record`` becomes a no-op.
        """
        return fn
def model_provider(pre_process=True, post_process=True):
    """Build the model.

    The model is constructed inside a ``deepspeed.zero.Init`` context, which
    is enabled only when ``args.zero_stage == 3``.  A ``GPTModelPipe`` is
    created when DeepSpeed is on and ``args.pipeline_model_parallel_size``
    is greater than one; otherwise a ``GPTModel`` is created.

    Args:
        pre_process: Forwarded to ``GPTModel`` (unused on the pipeline path).
        post_process: Forwarded to ``GPTModel`` (unused on the pipeline path).

    Returns:
        The constructed model instance.
    """
    print_rank_0('building GPT model ...')
    see_memory_usage("Before Building Model", force=True)

    args = get_args()
    zero_init = deepspeed.zero.Init(
        data_parallel_group=mpu.get_data_parallel_group(),
        # The literal string 'none' on the command line means "no remote device".
        remote_device=None if args.remote_device == 'none' else args.remote_device,
        config_dict_or_path=args.deepspeed_config,
        enabled=args.zero_stage == 3,
        mpu=mpu,
    )
    with zero_init:
        use_pipeline = args.deepspeed and args.pipeline_model_parallel_size > 1
        if not use_pipeline:
            model = GPTModel(
                num_tokentypes=0,
                parallel_output=True,
                pre_process=pre_process,
                post_process=post_process,
            )
        else:
            # GLM has dynamic attn_mask, so don't set args.attn_mask
            model = GPTModelPipe(num_tokentypes=0, parallel_output=True)
            # Hack: stash get_batch_pipe on the model so that training.py can
            # reach it; model.set_batch_fn has to be called after
            # deepspeed.initialize.
            model._megatron_batch_fn = get_batch_pipe

    see_memory_usage("After Building Model", force=True)
    return model
def process_data(data):
    """Broadcast one raw batch and unpack it into typed tensors.

    The fields ``text``, ``loss_mask``, ``target``, ``attention_mask`` and
    ``position_id`` are always broadcast (as int64); ``task_type`` is added
    when ``args.finetune`` is unset and ``args.deepspeed`` is set.  The
    broadcast ``attention_mask`` is passed through ``build_mask_matrix``
    together with the two leading sizes of ``tokens`` and cast to bool.

    Args:
        data: The raw batch handed to ``mpu.broadcast_data`` (may be None).

    Returns:
        ``(tokens, labels, loss_mask, attention_mask, position_ids)`` and,
        only when ``task_type`` was broadcast, ``task_type`` as a sixth entry.
    """
    args = get_args()

    # Items and their type.
    keys = ['text', 'loss_mask', 'target', 'attention_mask', 'position_id']
    with_task_type = not args.finetune and args.deepspeed
    if with_task_type:
        keys += ['task_type']
    data_b = mpu.broadcast_data(keys, data, torch.int64)

    # Unpack.
    tokens = data_b['text'].long()
    labels = data_b['target'].long()
    raw_mask = data_b['attention_mask'].long()
    loss_mask = data_b['loss_mask'].float()
    position_ids = data_b['position_id'].long()
    if 'task_type' in data_b:
        task_type = data_b['task_type'].long()
    else:
        task_type = None

    attention_mask = build_mask_matrix(
        raw_mask, tokens.size(0), tokens.size(1)).to(torch.bool)

    batch = (tokens, labels, loss_mask, attention_mask, position_ids)
    return batch if task_type is None else batch + (task_type,)
def get_batch(data_iterator):
    """Generate a batch.

    Pulls the next item from ``data_iterator`` (or uses ``None`` when no
    iterator is given) and returns whatever ``process_data`` produces for it.
    """
    # Broadcast data.
    data = None if data_iterator is None else next(data_iterator)
    return process_data(data)
def get_batch_pipe(data):
    """Variant of `get_batch` that takes `next(data_iterator)` rather than the iterator.

    Returns:
        ``((tokens, position_ids, attention_mask), (labels, loss_mask, task_type))``.
    """
    args = get_args()
    # NOTE(review): process_data returns only five values when task_type was
    # not broadcast; this six-way unpack would then raise ValueError - confirm
    # this path is only reached with task_type present.
    batch = process_data(data)
    tokens, labels, loss_mask, attention_mask, position_ids, task_type = batch

    if args.curriculum_learning:
        seqlen = args.curriculum_seqlen
        if seqlen < tokens.size(1):
            # seqlen-based curriculum learning
            # tokens, position_ids, labels, loss_mask have size [batch size, seqlen]
            tokens, position_ids, labels, loss_mask = (
                t[:, :seqlen].contiguous()
                for t in (tokens, position_ids, labels, loss_mask)
            )

    inputs = (tokens, position_ids, attention_mask)
    targets = (labels, loss_mask, task_type)
    return inputs, targets
def loss_func(loss_mask, output_tensor):
    """Compute the masked mean of the per-token losses.

    Args:
        loss_mask: Mask tensor; flattened and used as per-element weights.
        output_tensor: Per-token losses; cast to float and flattened.

    Returns:
        ``(loss, {'lm loss': averaged})`` where ``averaged`` is the first
        entry returned by ``average_losses_across_data_parallel_group([loss])``.
    """
    flat_losses = output_tensor.float().view(-1)
    flat_mask = loss_mask.view(-1).float()
    loss = (flat_losses * flat_mask).sum() / flat_mask.sum()

    # Reduce loss for logging.
    reduced = average_losses_across_data_parallel_group([loss])
    return loss, {'lm loss': reduced[0]}
def forward_step(data_iterator, model):
    """Forward step.

    Fetches one batch, runs the model on it and returns the model output
    together with the loss function bound to this batch's loss mask.

    Args:
        data_iterator: Iterator passed to ``get_batch`` (may be None).
        model: Callable invoked as
            ``model(tokens, position_ids, attention_mask, labels=labels)``.

    Returns:
        ``(output_tensor, loss_fn)`` where ``loss_fn`` is
        ``partial(loss_func, loss_mask)``.
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    # The batch has five entries, plus a trailing task_type only when
    # process_data broadcast it (args.finetune unset and args.deepspeed set).
    # task_type is not used here, so take the first five and ignore the rest;
    # a fixed six-way unpack raises ValueError on five-entry batches.
    tokens, labels, loss_mask, attention_mask, position_ids, *_ = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    output_tensor = model(tokens, position_ids, attention_mask,
                          labels=labels)

    # Under seqlen-based curriculum learning, cut the loss mask down to the
    # current curriculum length before binding it to the loss function.
    if args.curriculum_learning and args.curriculum_seqlen < args.seq_length:
        loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()

    return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets.

    Args:
        train_val_test_num_samples: Passed on to
            ``build_train_valid_test_datasets`` as
            ``train_valid_test_num_samples``.

    Returns:
        ``(train_ds, valid_ds, test_ds)``.

    Raises:
        NotImplementedError: If neither ``args.data_path`` nor
            ``args.multitask_data_path`` is set (this includes the
            ``--(train|valid|test)-weighted-split-paths`` option).
    """
    args = get_args()

    print_rank_0('> building train, validation, and test datasets for GPT ...')

    if not (args.data_path or args.multitask_data_path):
        # Option 2 of data loading using --(train|valid|test)-weighted-split-paths
        if args.train_weighted_split_paths:
            raise NotImplementedError()
        raise NotImplementedError("No dataloading argument passed")

    # Option 1 of data loading using --data-path
    dataset_kwargs = dict(
        data_prefix=args.data_path,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        length_per_sample=args.length_per_sample,
        aggregated_samples_per_sequence=args.aggregated_samples_per_sequence,
        args=args,
    )
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**dataset_kwargs)

    print_rank_0("> finished creating GPT datasets ...")
    return train_ds, valid_ds, test_ds
@record
def main():
    """Script entry point: start pretraining with the providers defined in this file."""
    defaults = {'tokenizer_type': 'IceTokenizer'}
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        args_defaults=defaults,
    )


if __name__ == "__main__":
    main()