-
Notifications
You must be signed in to change notification settings - Fork 6
/
main_training.py
711 lines (552 loc) · 21.3 KB
/
main_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
# main hq file for t5 training and prediction
import os
import argparse
from datasets_grammar.grammar_dataset import grammar
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# for grammar correction
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# for generation
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import DataCollatorForSeq2Seq
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
CPUOffload,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
from policies import mixed_precision
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from pathlib import Path
from torch.utils.data import DataLoader
import performance
from ChildTuningOptimizer import ChildTuningAdamW
from sklearn.model_selection import train_test_split
import time
from datetime import datetime
# local imports
import verify
import policies
import datasets_grammar as dg
import tqdm
import numpy as np
from statistics import stdev
# config
import config
import model_checkpoints
from collections import deque
from madgrad import MirrorMADGRAD as mirror
from utils.calculations_utils import calc_flop
# some globals
g_gigabyte = 1 << 30  # number of bytes in one GiB (== 1024**3)
def _is_rank_0():
return 0 == os.getenv("RANK")
def get_date_of_run():
    """Build a timestamp string used to make saved file names unique.

    The format is ``YYYY-MM-DD-HH:MM:SS_<AM|PM>`` on a 12-hour clock,
    e.g. ``2022-05-07-08:31:12_PM``. The stamp is also echoed to stdout.
    """
    stamp_format = "%Y-%m-%d-%I:%M:%S_%p"
    now = datetime.now()
    stamp = now.strftime(stamp_format)
    print(f"--> current date and time of run = {stamp}")
    return stamp
def parse_args(argv=None):
    """Parse command-line options for the FSDP T5 training script.

    No options are currently defined; training settings come from
    ``config.train_config()`` instead.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before; passing an explicit
            list lets tests and notebooks call this without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options (currently empty).
    """
    parser = argparse.ArgumentParser(description="PyTorch fsdp T5.11 Example")
    # Currently disabled options, kept for reference:
    #   --save-dir   type=str, default="/model_chkpt"
    #   --save-model action="store_true", default=False,
    #                help="For Saving the current Model"
    return parser.parse_args(argv)
# ---------------- Main functions --------------------
def get_policies(cfg, fsdp_unit_params=1000000):
    """Choose the mixed-precision and FSDP auto-wrap policies.

    Args:
        cfg: training config; only ``cfg.use_mixed_precision`` is read here.
        fsdp_unit_params: accepted for call compatibility; not used by the
            current (transformer-layer based) wrapping policy.

    Returns:
        ``(mixed_precision_policy, wrapping_policy)``. The first is
        ``policies.bfSixteen`` when mixed precision is requested and
        ``verify.bf16_ready`` is truthy, otherwise ``None``. The second is
        always ``policies.get_t5_wrapper()``.
    """
    precision_policy = None

    if cfg.use_mixed_precision:
        # NOTE(review): if verify.bf16_ready is a function rather than a
        # bool, this test is always truthy -- confirm it is a flag.
        if verify.bf16_ready:
            precision_policy = policies.bfSixteen
            print("bFloat16 enabled for mixed precision - using bfSixteen policy")
        else:
            # An fp16 fallback (policies.fpSixteen) is left disabled.
            print("bFloat16 support not present. Not using for mixed precision")

    # Alternative size-based policy: policies.get_size_policy(10e8)
    wrap_policy = policies.get_t5_wrapper()
    return precision_policy, wrap_policy
def setup(rank, world_size, cfg):
    """Create the default NCCL process group.

    No init method, rank or world size is passed, so torch.distributed uses
    its default ``env://`` initialization (the variables torchrun exports).
    ``rank``, ``world_size`` and ``cfg`` are accepted for call compatibility
    but not used.
    """
    # MASTER_ADDR / MASTER_PORT are expected to be provided by the launcher.
    dist.init_process_group(backend="nccl")
def setup_environ_flags(cfg, rank):
    """Export debugging-related environment variables.

    C++ stack traces are always enabled. NCCL async error handling and
    detailed torch.distributed debug output are enabled only when
    ``cfg.nccl_debug_handler`` / ``cfg.distributed_debug`` are set; rank 0
    announces the distributed-debug mode.
    """
    os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"

    if cfg.nccl_debug_handler:
        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"

    if not cfg.distributed_debug:
        return

    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    if rank == 0:
        print("--> running with torch dist debug set to detail")
def cleanup():
    """Destroy the torch.distributed process group(s) created by ``setup``."""
    # group=None is the default and targets the default (WORLD) group.
    dist.destroy_process_group(group=None)
def clear_gpu_cache(rank=None):
    """Log the calling rank, then release unoccupied cached CUDA memory.

    ``torch.cuda.empty_cache`` returns unused blocks held by the caching
    allocator; tensors that are still referenced are not affected.
    """
    message = f"clearing cache for rank {rank}"
    print(message)
    torch.cuda.empty_cache()
def setup_tasks(rank, world_size, cfg):
    """Run basic per-process setup: export env flags, then join the process group.

    The environment flags are exported *before* ``setup`` creates the
    process group. Some of them -- notably ``NCCL_ASYNC_ERROR_HANDLING`` --
    are read when the NCCL process group is constructed, so exporting them
    afterwards (the previous order) did not affect the current run.
    """
    setup_environ_flags(cfg, rank)
    setup(rank, world_size, cfg)
    # clear_gpu_cache() is deferred until torch.cuda.set_device(...) has run.
def format_metrics_to_gb(item):
    """Convert a byte count to GiB (via ``g_gigabyte``), rounded to 4 decimals."""
    return round(item / g_gigabyte, ndigits=4)
def format_stats(item, rounding=8):
    """Round ``item`` to ``rounding`` decimal places (8 by default)."""
    ndigits = rounding
    return round(item, ndigits)
# ---------- Training ----------------------------------------------------------
def train(
    args,
    model,
    local_rank,
    rank,
    world_size,
    train_loader,
    optimizer,
    epoch,
    sampler=None,
    profiler=None,
):
    """Run one training epoch and return the global sample-weighted mean loss.

    Each batch is a dict of tensors containing at least ``source_ids``,
    ``source_mask`` and ``target_ids``; every entry is moved to
    ``local_rank`` before the forward pass. The model must return a mapping
    with a ``"loss"`` entry.

    The per-rank ``(loss_sum, sample_count)`` pair is summed over all ranks
    with ``dist.all_reduce``, so every rank must call this function and all
    ranks receive the same value.

    Args:
        args: parsed CLI args (unused; kept for call compatibility).
        model: module to train.
        local_rank: device that batches and the loss accumulator are moved to.
        rank: global rank; rank 0 drives the progress bar and prints the loss.
        world_size: unused; kept for call compatibility.
        train_loader: iterable of batch dicts (``len()`` is taken on rank 0).
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, passed to ``sampler.set_epoch`` and logged.
        sampler: optional sampler reseeded each epoch via ``set_epoch``.
        profiler: optional profiler stepped once per batch.

    Returns:
        0-d tensor holding the global mean training loss (the caller stores
        it under the historical name ``train_accuracy``).
    """
    model.train()
    # [0] = sum of (batch loss * batch size), [1] = number of samples seen.
    ddp_loss = torch.zeros(2).to(local_rank)
    if sampler:
        sampler.set_epoch(epoch)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )
    for batch in train_loader:
        batch = {key: value.to(local_rank) for key, value in batch.items()}
        # batch is a dict, so len(batch) is its number of KEYS; count samples
        # from the leading dimension of the inputs instead.
        num_samples = batch["source_ids"].size(0)

        optimizer.zero_grad()
        output = model(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            labels=batch["target_ids"],
        )
        loss = output["loss"]
        loss.backward()
        optimizer.step()

        # Assumes the loss is a per-batch mean (as with HF seq2seq LM heads),
        # so weight it by batch size. Accumulating on-device avoids a host
        # sync (.item()) on every step.
        ddp_loss[0] += loss.detach().float() * num_samples
        ddp_loss[1] += num_samples

        if rank == 0:
            inner_pbar.update(1)
        if profiler:
            profiler.step()

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    train_loss = ddp_loss[0] / ddp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Train Epoch: \t{epoch}, Loss: \t{train_loss:.4f}")
    return train_loss
# ---- Validation ---------------
def validation(model, local_rank, rank, world_size, test_loader):
    """Evaluate ``model`` on ``test_loader``; return the global mean loss.

    Runs under ``torch.no_grad()`` after ``model.eval()`` (the model is left
    in eval mode). Each batch is a dict of tensors containing at least
    ``source_ids``, ``source_mask`` and ``target_ids``, moved to
    ``local_rank``; the model must return a mapping with a ``"loss"`` entry.
    The ``(loss_sum, sample_count)`` pair is summed over all ranks with
    ``dist.all_reduce``, so every rank must call this and receives the same
    value.

    Args:
        model: module to evaluate.
        local_rank: device that batches and the accumulator are moved to.
        rank: global rank; rank 0 drives the progress bar and prints the loss.
        world_size: unused; kept for call compatibility.
        test_loader: iterable of batch dicts (``len()`` is taken on rank 0).

    Returns:
        0-d tensor with the global sample-weighted mean validation loss.
    """
    model.eval()
    # [0] = sum of (batch loss * batch size), [1] = number of samples seen.
    ddp_loss = torch.zeros(2).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(test_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in test_loader:
            batch = {key: value.to(local_rank) for key, value in batch.items()}
            # len(batch) would count dict keys, not samples.
            num_samples = batch["source_ids"].size(0)
            output = model(
                input_ids=batch["source_ids"],
                attention_mask=batch["source_mask"],
                labels=batch["target_ids"],
            )
            # Assumes a per-batch mean loss; weight by batch size and keep the
            # running sum on-device (no per-step host sync).
            ddp_loss[0] += output["loss"].float() * num_samples
            ddp_loss[1] += num_samples
            if rank == 0:
                inner_pbar.update(1)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    val_loss = ddp_loss[0] / ddp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss
def sync_all_device():
    """Wait for all queued CUDA work on this process's current device.

    ``setup()`` does not restrict CUDA_VISIBLE_DEVICES, so under torchrun
    every process can see every GPU on the node. Synchronizing each visible
    device (the previous behavior) made every rank create a CUDA context on
    every GPU, wasting memory on all of them. This process's work runs on
    the device chosen with ``torch.cuda.set_device``, so only that device
    is synchronized. No-op when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # device=None -> current device
# ---- fsdp main ------------------------------------------------------------
def _build_optimizer(cfg, model, rank, lr, weight_decay):
    """Create the optimizer selected by ``cfg.optimizer_type``.

    ``"int8"`` and ``"anyprecision"``/``"AnyPrecision"`` select the
    AnyPrecisionAdamW variants; ``"childtuning"`` together with
    ``cfg.use_task_free`` selects task-free ChildTuningAdamW. Every other
    combination falls back to plain AdamW, so no configuration leaves the
    optimizer undefined.
    """
    opt_type = cfg.optimizer_type
    if opt_type == "int8":
        from anyprecision_quant import AnyPrecisionAdamW

        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay,
            momentum_dtype=torch.float32,
            variance_dtype=torch.uint8,
            use_kahan_summation=False,
        )
        if rank == 0:
            print(f"--> AnyPrecision INT8 optimizer running, variance_type = UINT8 bq")
    elif opt_type in ("anyprecision", "AnyPrecision"):
        from anyprecision_optimizer import AnyPrecisionAdamW

        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay,
            momentum_dtype=cfg.momentum_dtype,
            variance_dtype=cfg.variance_dtype,
            use_kahan_summation=cfg.use_kahan,
        )
        if rank == 0:
            print(
                f"--> AnyPrecision optimizer running, momentum = {cfg.momentum_dtype}, variance type = {cfg.variance_dtype}, kahan = {cfg.use_kahan} "
            )
    elif opt_type == "childtuning" and cfg.use_task_free:
        optimizer = ChildTuningAdamW(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay,
            reserve_p=cfg.percent_F,
            mode="taskfree",
        )
        if rank == 0:
            print(f"--> child free tuning with {cfg.percent_F} percentage ")
    else:
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        if rank == 0:
            print(f"--> AdamW whole model tuning with AdamW")
    return optimizer


def _write_sharding_plan(model, model_name, printable_model_name, fsdp_unit_params):
    """Write the wrapped model's layout to ``<printable_model_name>-sharded_layout.txt``."""
    print(f"model ")
    fn = printable_model_name + "-sharded_layout.txt"
    with open(fn, "w") as external_file:
        header_text = (
            f"model = {model_name}, sharded with {fsdp_unit_params} parameters\n"
        )
        print(header_text, file=external_file)
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # NOTE(review): after FSDP wrapping this counts the parameters visible
        # on this rank, and "* 4 / 1e6" looks like fp32 megabytes rather than
        # "Million params" -- confirm the intended metric/label.
        milli_params = total_params * 4 / 1e6
        print(
            f"\n--> {model_name} has {milli_params} Million params\n",
            file=external_file,
        )
        print(f"model wrapping = \n{model}\n\n", file=external_file)


def _save_model_checkpoint(
    model, rank, epoch, curr_val_loss, file_save_name, time_of_run, dq, cfg
):
    """Gather a full (unsharded) state dict and save it from rank 0.

    ``model.state_dict()`` under FULL_STATE_DICT gathers shards across ranks,
    so every rank must call this together. Only rank 0 writes the file,
    records it in ``dq`` and prunes old checkpoints; ``time_of_run`` and
    ``dq`` are only used on rank 0.
    """
    if rank == 0:
        print(f"--> entering save model state...")
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state = model.state_dict()
    print(f"saving process: rank {rank} done w state_dict")

    if rank == 0:
        print(f"--> saving model ...")
        currEpoch = (
            "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
        )
        save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
        torch.save(cpu_state, save_name)
        print(f"--> saved {save_name} to disk")
        dq.append(save_name)
        # keep only a rolling number of checkpoint files to bound disk use
        model_checkpoints.prune_checkpoints(rank, dq, cfg)


def _mean_and_stdev(values):
    """Return ``(mean, sample stdev)`` of ``values``; stdev is 0.0 for < 2 items.

    ``statistics.stdev`` raises StatisticsError for a single value, which
    would crash rank 0 on single-GPU runs before the final barrier/cleanup.
    """
    values = [float(v) for v in values]
    mean = sum(values) / len(values)
    spread = stdev(values) if len(values) > 1 else 0.0
    return mean, spread


def fsdp_main(args):
    """Main per-rank entry point (one process per GPU, launched by torchrun).

    Reads settings from ``config.train_config()`` and rank information from
    the LOCAL_RANK / RANK / WORLD_SIZE environment variables. It then:

    - builds the tokenizer, datasets and an FSDP-wrapped seq2seq model;
    - trains for ``cfg.num_epochs`` epochs, optionally validating and
      checkpointing each new best validation loss;
    - reports timing and FLOP statistics from rank 0;
    - tears down the process group.
    """
    cfg = config.train_config()  # loads from defaults

    torch.cuda.manual_seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # torchrun specific
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Reject mutually exclusive options before any expensive work (model
    # download, process-group creation, FSDP wrapping).
    if cfg.fsdp_activation_checkpointing and cfg.hf_activation_checkpointing:
        print(
            f"*** Bad config - both hf and fsdp checkpointing enabled. Must be mutually exclusive...aborting"
        )
        return

    time_of_run = None  # only set (and used) on rank 0
    if rank == 0:
        print(f"--> running with these defaults {cfg}")
        time_of_run = get_date_of_run()
        print(f"\n**--> Torch Version = {torch.__version__}\n")

    # Bind this process to its GPU before the NCCL process group is created,
    # so NCCL setup and later collectives (e.g. the final dist.barrier())
    # target this rank's device.
    torch.cuda.set_device(local_rank)
    setup_tasks(rank, world_size, cfg)

    fsdp_unit_params = cfg.fsdp_unit_size

    batch_size = cfg.batch_size
    if rank == 0:
        print(f"\n BatchSize = {batch_size}\n")
    val_batch_size = cfg.val_batch_size

    mp_policy, wrapping_policy = get_policies(cfg, fsdp_unit_params)

    # Model names noted by the authors: t5-base, google/t5-v1_1-small,
    # google/t5-v1_1-base, google/t5-v1_1-large, google/t5-v1_1-xl (3b),
    # google/t5-v1_1-xxl (11b).
    model_name = cfg.model_name
    if rank == 0:
        print(f"--> training for model {model_name}")

    printable_model_name = str.replace(model_name, "/", "=")
    file_save_name = "ModelCheckpoint-"

    # grammar correction: tokenizer + seq2seq LM
    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer, model_max_length=512)

    # HF's generation cache is disabled when HF activation checkpointing is on.
    hf_cache = not cfg.hf_activation_checkpointing
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, use_cache=hf_cache)

    if rank == 0:
        print(f"--> Training for {model_name}")
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\n--> {model_name} has {total_params/1e6} Million params\n")

    # ____________ create batch dataset
    train_name = cfg.dataset_train or None
    train_dataset = dg.get_dataset(tokenizer, train_name, 512, 512, True)
    # Use the integer rank: os.getenv("RANK") is a string, so comparing it
    # with 0 is never true.
    if rank == 0:
        print(f"--> Training Set Len = {len(train_dataset)}")
        print(f"using dataset {train_name}")

    val_dataset = dg.get_dataset(tokenizer, cfg.dataset_test, 512, 512, True)
    if rank == 0:
        print(f"--> Validation set len = {len(val_dataset)}")
        print(f"using dataset {cfg.dataset_test}")

    sampler1 = DistributedSampler(
        train_dataset, rank=rank, num_replicas=world_size, shuffle=True
    )
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    print(f"batch size = {batch_size}")

    # shuffle stays False: ordering is delegated to the distributed samplers.
    cuda_kwargs = {
        "num_workers": cfg.num_workers_dataloader,
        "pin_memory": False,
        "shuffle": False,
    }
    train_kwargs = {"batch_size": batch_size, "sampler": sampler1, **cuda_kwargs}
    test_kwargs = {"batch_size": val_batch_size, "sampler": sampler2, **cuda_kwargs}

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    clear_gpu_cache(local_rank)

    if cfg.hf_activation_checkpointing and not cfg.fsdp_activation_checkpointing:
        model.gradient_checkpointing_enable()
        if rank == 0:
            print(f"HF Activation checkpointing enabled\n")

    model_config = model.config
    embedding_size = model.state_dict()["shared.weight"].shape[1]
    FLOP = calc_flop(cfg, model_config, cfg.model_max_length, embedding_size)

    # --- sharding policy: use config, default to FULL_SHARD if unset
    model_sharding_strategy = cfg.sharding_strategy or ShardingStrategy.FULL_SHARD
    if rank == 0:
        print(f"Sharding strategy = {model_sharding_strategy}")

    backward_policy = cfg.backward_policy
    if rank == 0:
        print(f"additional settings: ")
        print(f"Backward Policy = {backward_policy}")
        print(f"Using Rate Limiter = {cfg.use_rate_limiter}")

    if cfg.model_in_bf16:
        model.to(torch.bfloat16)
        mp_policy = None  # weights already bf16; no mixed-precision policy
        if rank == 0:
            print(f"Model in BF16, all training in BF16")

    if model_sharding_strategy == ShardingStrategy.HYBRID_SHARD:
        # Every rank must call new_subgroups().
        # NOTE(review): the subgroup is only logged, never passed to FSDP --
        # confirm whether it is still needed.
        subgroup, _ = dist.new_subgroups()
        if rank == 0:
            print(f"--> HSDP active - subgroup created - {subgroup=}")

    model = FSDP(
        model,
        auto_wrap_policy=wrapping_policy,
        mixed_precision=mp_policy,
        sharding_strategy=model_sharding_strategy,
        backward_prefetch=backward_policy,
        device_id=torch.cuda.current_device(),  # streaming init
        limit_all_gathers=cfg.use_rate_limiter,
    )

    # initializing memory stat tracker
    if rank == 0:
        memmax = performance.Memory_Maximizer()

    # fsdp activation checkpointing must be applied after sharding
    if cfg.fsdp_activation_checkpointing:
        policies.apply_fsdp_checkpointing(model)
        if rank == 0:
            print(f"--> fsdp activation checkpointing enabled...")

    if rank == 0 and cfg.print_sharding_plan:
        _write_sharding_plan(model, model_name, printable_model_name, fsdp_unit_params)

    lr = 0.0008
    gamma = 0.85
    weight_decay = 0.005
    optimizer = _build_optimizer(cfg, model, rank, lr, weight_decay)

    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    epochs = cfg.num_epochs
    if rank == 0:
        print(f"Training for {epochs} epochs")

    best_val_loss = float("inf")
    curr_val_loss = float("inf")

    # --- main training loop - todo, this needs to be modularized
    dq = None  # checkpoint queue; only used on rank 0
    if rank == 0:
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        dq = deque(maxlen=cfg.checkpoint_max_save_count + 1)
        memmax.start()
        training_start_time = time.time()

    # Per-step profiling is off; a torch.profiler.profile(...) context (CPU +
    # CUDA activities, tensorboard trace handler) can be passed here instead.
    torch_profiler = None

    if rank == 0 and cfg.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    start_training_time = time.time()
    for epoch in range(1, epochs + 1):
        if rank == 0:
            print(f"\n--> Starting Epoch {epoch}")
            t0 = time.time()
        train_accuracy = train(
            args,
            model,
            local_rank,
            rank,
            world_size,
            train_loader,
            optimizer,
            epoch,
            sampler=sampler1,
            profiler=torch_profiler,
        )
        if rank == 0:
            memmax.update()

        if cfg.run_validation:
            curr_val_loss = validation(model, local_rank, rank, world_size, test_loader)

        scheduler.step()

        if rank == 0:
            print(f"--> epoch {epoch} completed...entering save and stats zone")
            total_epoch_time = time.time() - t0
            print(f"epoch_time = {total_epoch_time}")
            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if cfg.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if cfg.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            net_flops = format_stats(FLOP / 10**12 / total_epoch_time, rounding=6)
            print(f"TFLOP/s/GPU: {net_flops}")

        # curr_val_loss comes from an all-reduce, so this comparison agrees on
        # every rank as long as best_val_loss is also updated on every rank.
        # (Updating it only on rank 0 sent the other ranks into the collective
        # state_dict() gather alone, hanging training.) Without
        # cfg.run_validation, curr_val_loss stays inf and nothing is saved.
        is_new_best = bool(curr_val_loss < best_val_loss)
        if cfg.save_model and is_new_best:
            _save_model_checkpoint(
                model, rank, epoch, curr_val_loss, file_save_name, time_of_run, dq, cfg
            )

        # announce new val loss record:
        if is_new_best:
            best_val_loss = curr_val_loss
            if rank == 0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    sync_all_device()
    end_training_time = time.time()
    delays = [None for _ in range(world_size)]
    dist.all_gather_object(
        delays, (end_training_time - start_training_time) / epochs
    )
    delays = [round(item, 4) for item in delays]

    if rank == 0:
        print("Flops cnt and delays", FLOP, delays)
        gflops_gpu = FLOP / 10**9 * np.reciprocal(np.array(delays))
        tflops_gpu = FLOP / 10**12 * np.reciprocal(np.array(delays))
        print(f"gflops per gpu={gflops_gpu}")

        total_training_time = time.time() - training_start_time
        print(f"Total training time = {total_training_time:.2f}")
        print("Times per epoch:")
        for i, val in enumerate(dur):
            print(f"epoch {i}, time {val:.2f}")
        print()

        # memory
        if cfg.track_memory:
            print(f"total memory reserved: {mem_reserved_tracker}")
            print(f"total memory allocated: {mem_alloc_tracker}")
        memmax.stop()

        print(f"Training accuracy: {train_acc_tracking}")
        if cfg.run_validation:
            print(f"Validation accuracy: {val_acc_tracking}")
            print(f"\n Best Val accuracy: {best_val_loss}")

        print(f" Settings again ===")
        print(f"Backward Policy = {backward_policy}")
        print(f"Using Rate Limiter = {cfg.use_rate_limiter}")
        if cfg.use_rate_limiter:
            print(f"Rate Limit = {cfg.inflight_max}\n")
        print(f"Batch size = {cfg.batch_size}")

        gflops_mean, gflops_std = _mean_and_stdev(gflops_gpu)
        tflops_mean, tflops_std = _mean_and_stdev(tflops_gpu)
        print(
            f"gflops/gpu = {gflops_mean:.2f} ({gflops_std:.2f})\n"
            f"Tflops/gpu = {tflops_mean:.2f} ({tflops_std:.2f})\n"
        )

    # memory summary
    if cfg.memory_report and rank == 0:
        print(
            f"CUDA Memory Summary After Last training:\n {torch.cuda.memory_summary()}"
        )

    dist.barrier()
    cleanup()
# ------------------ Main functions above ------------
if __name__ == "__main__":
    # Launched once per rank (e.g. by torchrun); each copy parses its own CLI
    # args and runs the full per-rank training flow.
    args = parse_args()
    gpus_per_node = torch.cuda.device_count()  # informational; currently unused
    fsdp_main(args)