From 93e0a151177ad8222c2c95f814342bfa27f0a64d Mon Sep 17 00:00:00 2001
From: Prathik Rao
Date: Tue, 11 Oct 2022 10:11:19 -0700
Subject: [PATCH] implement cos gradient as a function op (#13227)

### Description
Implemented gradient of cos as per the function below.

![image](https://user-images.githubusercontent.com/31260940/193900310-b62a3e77-06d5-45af-ad28-a1d41920bad0.png)

### Motivation and Context
Cos gradient required for [huggingface's diffusers library](https://github.com/huggingface/diffusers)

### Testing
built ORT from source: `./build.sh --config RelWithDebInfo --enable_training --use_cuda --cuda_home /usr/local/cuda --cudnn_home /usr/local/cuda --build_wheel --parallel --skip_tests`
tested CosGrad implementation: `cd build/Linux/RelWithDebInfo/ && ./onnxruntime_test_all --gtest_filter=GradientCheckerTest.CosGrad`

Co-authored-by: Prathik Rao
---
 .../orttraining/core/graph/gradient_builder.cc | 11 +++++++++++
 orttraining/orttraining/core/graph/gradient_builder.h | 1 +
 .../core/graph/gradient_builder_registry.cc | 1 +
 .../orttraining/test/gradient/gradient_ops_test.cc | 2 ++
 4 files changed, 15 insertions(+)

diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index cad54acc2d185..3242893bc89da 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -82,6 +82,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetSinGradient) {
   return result;
 }
 
+IMPLEMENT_GRADIENT_BUILDER(GetCosGradient) {
+  std::vector<NodeDef> result;
+  NodeDef zero_constant_node = ZeroConstantNode(IElemType(0));
+  ArgDef zero = zero_constant_node.output_args[0];
+  result.push_back(zero_constant_node);
+  result.push_back(NodeDef("Sin", {I(0)}, {IA("Sin_O0")}));
+  result.push_back(NodeDef("Sub", {zero, IA("Sin_O0")}, {IA("NegSin_O0")}));
+  result.push_back(NodeDef("Mul", {GO(0), IA("NegSin_O0")}, {GI(0)}));
+  return result;
+}
+
 IMPLEMENT_GRADIENT_BUILDER(GetLogGradient) {
   return std::vector<NodeDef>{
       NodeDef("Div",
diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h
index 9edccb02cbb5a..174c650201361 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.h
+++ b/orttraining/orttraining/core/graph/gradient_builder.h
@@ -14,6 +14,7 @@ namespace training {
 
 DECLARE_GRADIENT_BUILDER(GetCastGradient)
 DECLARE_GRADIENT_BUILDER(GetSinGradient)
+DECLARE_GRADIENT_BUILDER(GetCosGradient)
 DECLARE_GRADIENT_BUILDER(GetLogGradient)
 DECLARE_GRADIENT_BUILDER(GetTanhGradient)
 DECLARE_GRADIENT_BUILDER(GetSqrtGradient)
diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
index f67d9cf7fdd43..35da30c0be048 100755
--- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
@@ -45,6 +45,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
   // Register gradient builders here.
   REGISTER_GRADIENT_BUILDER("Cast", GetCastGradient);
   REGISTER_GRADIENT_BUILDER("Sin", GetSinGradient);
+  REGISTER_GRADIENT_BUILDER("Cos", GetCosGradient);
   REGISTER_GRADIENT_BUILDER("Log", GetLogGradient);
   REGISTER_GRADIENT_BUILDER("Tanh", GetTanhGradient);
   REGISTER_GRADIENT_BUILDER("Sqrt", GetSqrtGradient);
diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
index aef434cf569dc..6f1b50fd43cf9 100644
--- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -424,6 +424,8 @@ TEST(GradientCheckerTest, MatMulGrad) {
 
 TEST(GradientCheckerTest, SinGrad) { UnaryOpGradientTest("Sin"); }
 
+TEST(GradientCheckerTest, CosGrad) { UnaryOpGradientTest("Cos"); }
+
 TEST(GradientCheckerTest, NegGrad) { UnaryOpGradientTest("Neg"); }
 
 TEST(GradientCheckerTest, AbsGrad) {