
Commit

Merge pull request #192 from CarryFun/fanruikai/doc_update
[WIP] Update doc and notes for BMTrain.
MayDomine committed Jun 11, 2024
2 parents 22a42af + 3d7d7d9 commit 3d35a26
Showing 56 changed files with 2,590 additions and 1,106 deletions.
315 changes: 207 additions & 108 deletions bmtrain/block_layer.py

Large diffs are not rendered by default.

32 changes: 23 additions & 9 deletions bmtrain/hook_func.py
@@ -2,58 +2,70 @@
from .global_var import config
from .zero_context import ZeroContext


def zero_pre_forward(module, inputs):
"""Helper function for using ZeroContext to gather parmas before forward."""
enter = True
pipe = False
if module._mode == "PIPE":
enter = module._micro_idx == 0
pipe = True
if enter:
        zero_level = module._zero_level
forward_flag = 1 if zero_level == 2 else 0
if zero_level == 2 and not module._need_release:
            forward_flag = 2  # repeating forward in same layer
        if module.all_param_no_grad:  # only forward
forward_flag = 0
module._forward_block_ctx = ZeroContext(module, module._layer_dict, pipe=pipe)
module._forward_block_ctx.enter(forward_flag)


def zero_post_forward(module, inputs, outputs):
"""Helper function for module _forwar_block_ctx weather exits after forward."""
forward_flag = 1 if module._zero_level == 2 else 0
if module.all_param_no_grad:
forward_flag = 0
exit = True
if module._mode == "PIPE":
        exit = module._micro_idx == config["micros"] - 1

if exit:
module._forward_block_ctx.exit(forward_flag)


def zero_pre_backward(module, grad_outputs):
"""Helper function for using ZeroContext to init grad buffer before backward."""
backward_flag = 2 if module._zero_level == 2 else 0
if module._mode != "PIPE":
module._backward_block_ctx = ZeroContext(module, module._layer_dict)
module._backward_block_ctx.enter(backward_flag, True)
module.release_next_module(backward_flag)
else:
        if module._micro_idx == config["micros"] - 1:
            module._backward_block_ctx = ZeroContext(
                module, module._layer_dict, pipe=True
            )
module._backward_block_ctx.enter(backward_flag, True)


def zero_post_backward(module, grad_inputs, grad_outputs):
"""Helper function for module weather release after backward."""
backward_flag = 2 if module._zero_level == 2 else 0
if module._mode != "PIPE":
        if module._is_first_layer:
module.release(backward_flag)
else:
if module._micro_idx == 0:
module.release(backward_flag)
module._micro_idx -= 1


class OneStepNoGradFunc(torch.autograd.Function):
"""
    Requires_grad = False for all inputs.
"""

@staticmethod
def forward(ctx, module, placeholder, *x):
ctx.x = x
@@ -80,7 +92,8 @@ def backward(ctx, grads):
grads = []
for _ in x:
grads.append(None)
        return None, None, *grads


class PreHookFunc(torch.autograd.Function):
@staticmethod
@@ -94,6 +107,7 @@ def backward(ctx, *grads):
zero_post_backward(ctx.module, grads, None)
return None, *grads


class PostHookFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, module, *out):
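As context for the hooks above: the four zero_* helpers follow the standard PyTorch module-hook signatures (pre-forward takes (module, inputs), post-forward takes (module, inputs, outputs), and the backward pair mirrors them with gradients), even though this file dispatches them through the PreHookFunc/PostHookFunc autograd functions rather than torch's hook registry. A minimal runnable sketch of that signature contract, assuming a toy nn.Linear and print statements as illustrative stand-ins for the ZeroContext gather/release work (not BMTrain's actual wiring):

import torch
import torch.nn as nn

def pre_forward(module, inputs):
    # Same (module, inputs) signature as zero_pre_forward.
    print(f"[pre]  gather params for {module.__class__.__name__}")

def post_forward(module, inputs, outputs):
    # Same (module, inputs, outputs) signature as zero_post_forward.
    print(f"[post] release params for {module.__class__.__name__}")

layer = nn.Linear(4, 4)  # illustrative stand-in for a checkpointed block
layer.register_forward_pre_hook(pre_forward)
layer.register_forward_hook(post_forward)
layer(torch.randn(2, 4))  # prints [pre], runs the layer, then prints [post]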
