Fix full_op bug for gradient merge #68391

Open · wants to merge 2 commits into base: develop
python/paddle/distributed/passes/auto_parallel_gradient_merge.py (27 changes: 19 additions & 8 deletions)
@@ -340,13 +340,18 @@ def _pir_append_gradient_merge_backward_op(
         paddle.pir.set_insertion_point_after(opt_op)
         allreduce_sum_out = opt_op.result(0)
 
-        scale = paddle.full([], 0.5)
         scale_out = paddle._C_ops.scale_(
-            allreduce_sum_out, scale, 0.0, False
+            allreduce_sum_out, 0.5, 0.0, False
         )
 
-        scale.get_defining_op().op_role = int(OpRole.Optimize)
-        scale_out.get_defining_op().op_role = int(OpRole.Optimize)
+        scale_op = scale_out.get_defining_op()
+        scale_op.op_role = int(OpRole.Optimize)
+
+        full_op = scale_op.operand_source(1).get_defining_op()
+        assert (
+            full_op.name() == "pd_op.full"
+        ), f"The defining op of the scale value should be `pd_op.full`, but got {full_op.name()}"
+        full_op.op_role = int(OpRole.Optimize)
 
     # reset gradient merge var to zero after finishing optimization
     paddle.pir.set_insertion_point_to_block_end(main_block)
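
In both hunks the explicitly built paddle.full([], ...) scalar is dropped: the float is passed straight to paddle._C_ops.scale_, which under PIR materializes it as a pd_op.full op feeding the scale op's second operand. The change then recovers that implicit full op via operand_source(1).get_defining_op() and marks it with OpRole.Optimize alongside the scale op itself, so the scalar-producing op stays in the optimize stage. A minimal sketch of a helper that factors out this repeated pattern is shown below; the helper name is hypothetical, and it assumes the paddle and OpRole names already imported by this pass module.

def _mark_scale_chain_as_optimize(scaled_value):
    # Tag the scale_ op that produced this value.
    scale_op = scaled_value.get_defining_op()
    scale_op.op_role = int(OpRole.Optimize)
    # Operand 1 is the scalar multiplier; when a Python float is passed to
    # paddle._C_ops.scale_, it is produced by an implicit pd_op.full op.
    full_op = scale_op.operand_source(1).get_defining_op()
    assert (
        full_op.name() == "pd_op.full"
    ), f"expected pd_op.full, got {full_op.name()}"
    full_op.op_role = int(OpRole.Optimize)

With such a helper, both call sites in this diff would reduce to _mark_scale_chain_as_optimize(scale_out) and _mark_scale_chain_as_optimize(new_grad).
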
@@ -680,10 +685,16 @@ def _pir_parse_program(
             paddle.pir.set_insertion_point_after(op)
             break
     for _, new_grad in new_params_to_grads:
-        scale = paddle.full([], 1.0 / k_steps)
-        new_grad = paddle._C_ops.scale_(new_grad, scale, 0.0, False)
-        new_grad.get_defining_op().op_role = int(OpRole.Optimize)
-        scale.get_defining_op().op_role = int(OpRole.Optimize)
+        new_grad = paddle._C_ops.scale_(new_grad, 1.0 / k_steps, 0.0, False)
+
+        scale_op = new_grad.get_defining_op()
+        scale_op.op_role = int(OpRole.Optimize)
+
+        full_op = scale_op.operand_source(1).get_defining_op()
+        assert (
+            full_op.name() == "pd_op.full"
+        ), f"The defining op of the scale value should be `pd_op.full`, but got {full_op.name()}"
+        full_op.op_role = int(OpRole.Optimize)
 
 
 @register_pass("auto_parallel_gradient_merge_pass")
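
A quick way to confirm the effect of the pass is to walk the main block and check that every pd_op.full feeding a scale op carries the same role as the scale op itself. The sketch below is not part of the PR; it only relies on the PIR Python calls already used in the diff (global_block().ops, op.name(), operand_source, op_role), and the op names "pd_op.scale"/"pd_op.scale_" are assumed spellings for the scale kernel.

def check_scale_full_roles(main_program):
    # Sanity check: the scalar-producing full op of each scale op should share
    # the scale op's role, otherwise stage-based passes may separate them.
    for op in main_program.global_block().ops:
        if op.name() not in ("pd_op.scale", "pd_op.scale_"):
            continue
        full_op = op.operand_source(1).get_defining_op()
        if full_op is not None and full_op.name() == "pd_op.full":
            assert full_op.op_role == op.op_role, (
                f"op_role mismatch: scale={op.op_role}, full={full_op.op_role}"
            )

Checking for role agreement rather than a hard-coded OpRole.Optimize keeps the check usable for scale ops created in other stages as well.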