[AutoParallel] Add c_embedding pass in PIR #68389
Conversation
Sorry to inform you that 7d43f21's CIs passed more than 7 days ago. To prevent PR conflicts, you need to re-run all CIs manually.
Add reasonable comments to the pass.
super().__init__()

def _check_self(self):
    return True
Returned in advance: _check_self always returns True without performing any actual check.
Thx, the code has been modified.
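For illustration, a _check_self that actually gates the pass might look like the sketch below; the attribute it inspects is hypothetical, not taken from this PR.

def _check_self(self):
    # Hypothetical precondition: only run the pass when the program
    # it rewrites has actually been attached.
    if getattr(self, "_concrete_program", None) is None:
        return False
    return True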
# update weight dims mapping
mp_axis = update_weight(op, concrete_program)

# update startup_program
Meaningless comments; delete lines 334, 347, 350, 353, etc.
Thx, the code has been modified.
SpmdInfo CEmbeddingGradInferSpmdBase(const DistMetaTensor& weight,
                                     const DistMetaTensor& x,
                                     const DistMetaTensor& out_grad,
                                     int64_t start_index);
Remove the redundant functions CEmbeddingGradInferSpmdBase and CEmbeddingInferSpmdBase.
<< str_join(out_grad_dst.dist_attr().dims_mapping()) << "]\n"
<< "Output w_grad shape: [" << str_join(phi::vectorize(w_grad.dims()))
<< "], dims_mapping: [" << str_join(w_grad.dist_attr().dims_mapping())
<< "]\n\n";
Replace lines 222-240 with the functions LOG_SPMD_INPUT and LOG_SPMD_OUTPUT.
<< str_join(weight_dims_mapping) << "], dst_dims_mapping: ["
<< str_join(weight_dist_attr_dst.dims_mapping())
<< "]\n Out dims_mapping: [" << str_join(out_dims_mapping)
<< "], partial_on_dims: [" << str_join(partial_on_dims) << "]\n\n";
Replace lines 121-130 with the functions LogInputDistAttr and LogOutputDistAttr.
// w_grad = einsum('...j, ...k->jk', onehot(x, j), out_grad)

// TODO(cxxly): Simplifies the code logic of sharding propagation using
// primitive operators.
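As a sanity check of the identity in that comment, here is a small NumPy sketch (shapes chosen arbitrarily) showing that the einsum over a one-hot encoding of x matches a scatter-add of out_grad rows into the weight gradient:

import numpy as np

vocab, hidden = 6, 3
x = np.array([1, 4, 1, 5])             # token indices, batch of 4
out_grad = np.random.randn(4, hidden)  # upstream gradient

# w_grad[j, k] = sum_b onehot(x)[b, j] * out_grad[b, k]
onehot = np.eye(vocab)[x]
w_grad = np.einsum('...j,...k->jk', onehot, out_grad)

# Reference: scatter-add each out_grad row into its token's weight row.
ref = np.zeros((vocab, hidden))
np.add.at(ref, x, out_grad)
assert np.allclose(w_grad, ref)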
Why TODO(cxxly)?
)
np.testing.assert_allclose(
    dy2static_losses_use_pass, dy_losses, atol=1e-7
)
- Require equal md5 values
- Check that the program contains the c_embedding op when using the pass (a sketch of both checks follows below)
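A minimal sketch of both suggested checks, assuming the loss lists from the test above and a PIR program object exposing global_block().ops; the helper name and parameters here are assumptions, not the test's actual code:

import hashlib

def check_pass_effect(losses_with_pass, losses_without_pass, program):
    # 1. Require bit-identical losses via md5 rather than an atol bound.
    def md5(xs):
        return hashlib.md5(str(xs).encode()).hexdigest()
    assert md5(losses_with_pass) == md5(losses_without_pass)

    # 2. The rewritten program should contain a c_embedding op.
    op_names = [op.name() for op in program.global_block().ops]
    assert any("c_embedding" in name for name in op_names)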
{"dtype": "float32", "seed": "2024"}, {"backend": ["gpu"]} | ||
) | ||
for envs in envs_list: | ||
# self._log_dir.name = "./log" |
Delete the redundant commented-out code.
x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.process_mesh = (
    process_mesh  # not setting the dims mapping is ok
)
Collapse L104-L106 to a single line: x_tensor_dist_attr.process_mesh = process_mesh
table_tensor_dist_attr = TensorDistAttr()
table_tensor_dist_attr.process_mesh = (
    process_mesh  # not setting the dims mapping is ok
)
Likewise: table_tensor_dist_attr.process_mesh = process_mesh
Thx, the code has been modified according to the suggestions above.
LGTM
PR Category
Auto Parallel

PR Types
New features

Description
Add c_embedding pass in PIR.

Pcard-70448