diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index cad54acc2d185..3242893bc89da 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -82,6 +82,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetSinGradient) { return result; } +IMPLEMENT_GRADIENT_BUILDER(GetCosGradient) { + std::vector result; + NodeDef zero_constant_node = ZeroConstantNode(IElemType(0)); + ArgDef zero = zero_constant_node.output_args[0]; + result.push_back(zero_constant_node); + result.push_back(NodeDef("Sin", {I(0)}, {IA("Sin_O0")})); + result.push_back(NodeDef("Sub", {zero, IA("Sin_O0")}, {IA("NegSin_O0")})); + result.push_back(NodeDef("Mul", {GO(0), IA("NegSin_O0")}, {GI(0)})); + return result; +} + IMPLEMENT_GRADIENT_BUILDER(GetLogGradient) { return std::vector{ NodeDef("Div", diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index 9edccb02cbb5a..174c650201361 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -14,6 +14,7 @@ namespace training { DECLARE_GRADIENT_BUILDER(GetCastGradient) DECLARE_GRADIENT_BUILDER(GetSinGradient) +DECLARE_GRADIENT_BUILDER(GetCosGradient) DECLARE_GRADIENT_BUILDER(GetLogGradient) DECLARE_GRADIENT_BUILDER(GetTanhGradient) DECLARE_GRADIENT_BUILDER(GetSqrtGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index f67d9cf7fdd43..35da30c0be048 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -45,6 +45,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { // Register gradient builders here. REGISTER_GRADIENT_BUILDER("Cast", GetCastGradient); REGISTER_GRADIENT_BUILDER("Sin", GetSinGradient); + REGISTER_GRADIENT_BUILDER("Cos", GetCosGradient); REGISTER_GRADIENT_BUILDER("Log", GetLogGradient); REGISTER_GRADIENT_BUILDER("Tanh", GetTanhGradient); REGISTER_GRADIENT_BUILDER("Sqrt", GetSqrtGradient); diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index aef434cf569dc..6f1b50fd43cf9 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -424,6 +424,8 @@ TEST(GradientCheckerTest, MatMulGrad) { TEST(GradientCheckerTest, SinGrad) { UnaryOpGradientTest("Sin"); } +TEST(GradientCheckerTest, CosGrad) { UnaryOpGradientTest("Cos"); } + TEST(GradientCheckerTest, NegGrad) { UnaryOpGradientTest("Neg"); } TEST(GradientCheckerTest, AbsGrad) {