Skip to content

Commit

Permalink
implement cos gradient as a function op (#13227)
Browse files Browse the repository at this point in the history
### Description
Implemented gradient of cos as per the function below.

![image](https://user-images.githubusercontent.com/31260940/193900310-b62a3e77-06d5-45af-ad28-a1d41920bad0.png)

### Motivation and Context
Cos gradient required for [huggingface's diffusers
library](https://github.com/huggingface/diffusers)

### Testing
built ORT from source: `./build.sh --config RelWithDebInfo
--enable_training --use_cuda --cuda_home /usr/local/cuda --cudnn_home
/usr/local/cuda --build_wheel --parallel --skip_tests`
tested CosGrad implementation: `cd build/Linux/RelWithDebInfo/ &&
./onnxruntime_test_all --gtest_filter=GradientCheckerTest.CosGrad`

Co-authored-by: Prathik Rao <[email protected]>
  • Loading branch information
prathikr and Prathik Rao authored Oct 11, 2022
1 parent 05acd20 commit 93e0a15
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 0 deletions.
11 changes: 11 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetSinGradient) {
return result;
}

IMPLEMENT_GRADIENT_BUILDER(GetCosGradient) {
std::vector<NodeDef> result;
NodeDef zero_constant_node = ZeroConstantNode(IElemType(0));
ArgDef zero = zero_constant_node.output_args[0];
result.push_back(zero_constant_node);
result.push_back(NodeDef("Sin", {I(0)}, {IA("Sin_O0")}));
result.push_back(NodeDef("Sub", {zero, IA("Sin_O0")}, {IA("NegSin_O0")}));
result.push_back(NodeDef("Mul", {GO(0), IA("NegSin_O0")}, {GI(0)}));
return result;
}

IMPLEMENT_GRADIENT_BUILDER(GetLogGradient) {
return std::vector<NodeDef>{
NodeDef("Div",
Expand Down
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace training {

DECLARE_GRADIENT_BUILDER(GetCastGradient)
DECLARE_GRADIENT_BUILDER(GetSinGradient)
DECLARE_GRADIENT_BUILDER(GetCosGradient)
DECLARE_GRADIENT_BUILDER(GetLogGradient)
DECLARE_GRADIENT_BUILDER(GetTanhGradient)
DECLARE_GRADIENT_BUILDER(GetSqrtGradient)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
// Register gradient builders here.
REGISTER_GRADIENT_BUILDER("Cast", GetCastGradient);
REGISTER_GRADIENT_BUILDER("Sin", GetSinGradient);
REGISTER_GRADIENT_BUILDER("Cos", GetCosGradient);
REGISTER_GRADIENT_BUILDER("Log", GetLogGradient);
REGISTER_GRADIENT_BUILDER("Tanh", GetTanhGradient);
REGISTER_GRADIENT_BUILDER("Sqrt", GetSqrtGradient);
Expand Down
2 changes: 2 additions & 0 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,8 @@ TEST(GradientCheckerTest, MatMulGrad) {

TEST(GradientCheckerTest, SinGrad) { UnaryOpGradientTest("Sin"); }

TEST(GradientCheckerTest, CosGrad) { UnaryOpGradientTest("Cos"); }

TEST(GradientCheckerTest, NegGrad) { UnaryOpGradientTest("Neg"); }

TEST(GradientCheckerTest, AbsGrad) {
Expand Down

0 comments on commit 93e0a15

Please sign in to comment.