forked from mindspore-Ecosystem/mindspore
Add AdamApplyOne fusion pass
This commit is contained in:
parent
80333e9f55
commit
5eb5379889
|
@ -42,17 +42,69 @@ AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_g
|
|||
|
||||
const BaseRef AdamApplyOneFusion::DefinePattern() const {
|
||||
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||
const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||
VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
|
||||
VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
|
||||
VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
|
||||
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
|
||||
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
|
||||
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||
VectorRef true_div0 = VectorRef({prim_deal_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
|
||||
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
|
||||
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
|
||||
}
|
||||
|
||||
const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const {
|
||||
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||
VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
|
||||
VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
|
||||
VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
|
||||
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
|
||||
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
|
||||
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
|
||||
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
|
||||
}
|
||||
|
||||
const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const {
|
||||
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||
VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
|
||||
VectorRef mul3 = VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[0]}), mul_x_input_vars_[3]});
|
||||
VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
|
||||
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
|
||||
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
|
||||
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
|
||||
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
|
||||
}
|
||||
|
||||
const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const {
|
||||
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||
VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
|
||||
VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
|
||||
VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
|
||||
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
|
||||
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
|
||||
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
|
||||
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
|
||||
}
|
||||
|
||||
const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const {
|
||||
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||
VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
|
||||
VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
|
||||
VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
|
||||
VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
|
||||
VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
|
||||
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
|
||||
VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})});
|
||||
return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})});
|
||||
}
|
||||
|
||||
const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
|
|
@ -18,21 +18,23 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
constexpr size_t kAdamApplyOneInputNum = 5;
|
||||
constexpr size_t kAdamApplyOneMulInputNum = 4;
|
||||
constexpr size_t kAdamApplyOneInputVarNum = 5;
|
||||
constexpr size_t kAdamApplyOneMulInputVarNum = 4;
|
||||
|
||||
class AdamApplyOneFusion : public PatternProcessPass {
|
||||
public:
|
||||
explicit AdamApplyOneFusion(bool multigraph = true) : PatternProcessPass("adam_apply_one_fusion", multigraph) {
|
||||
for (size_t i = 0; i < kAdamApplyOneInputNum; ++i) {
|
||||
explicit AdamApplyOneFusion(const std::string &name = "adam_apply_one_fusion", bool multigraph = true)
|
||||
: PatternProcessPass(name, multigraph) {
|
||||
for (size_t i = 0; i < kAdamApplyOneInputVarNum; ++i) {
|
||||
input_vars_.push_back(std::make_shared<Var>());
|
||||
}
|
||||
for (size_t i = 0; i < kAdamApplyOneMulInputNum; ++i) {
|
||||
for (size_t i = 0; i < kAdamApplyOneMulInputVarNum; ++i) {
|
||||
mul_x_input_vars_.push_back(std::make_shared<Var>());
|
||||
}
|
||||
add2_y_ = std::make_shared<Var>();
|
||||
|
@ -44,7 +46,7 @@ class AdamApplyOneFusion : public PatternProcessPass {
|
|||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
protected:
|
||||
AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const;
|
||||
std::vector<VarPtr> input_vars_;
|
||||
std::vector<VarPtr> mul_x_input_vars_;
|
||||
|
@ -52,6 +54,42 @@ class AdamApplyOneFusion : public PatternProcessPass {
|
|||
VarPtr add0_var_;
|
||||
VarPtr add1_var_;
|
||||
};
|
||||
|
||||
class AdamApplyOneCond1Fusion : public AdamApplyOneFusion {
|
||||
public:
|
||||
explicit AdamApplyOneCond1Fusion(bool multigraph = true)
|
||||
: AdamApplyOneFusion("adam_apply_one_cond1_fusion", multigraph) {}
|
||||
|
||||
~AdamApplyOneCond1Fusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
};
|
||||
|
||||
class AdamApplyOneCond2Fusion : public AdamApplyOneFusion {
|
||||
public:
|
||||
explicit AdamApplyOneCond2Fusion(bool multigraph = true)
|
||||
: AdamApplyOneFusion("adam_apply_one_cond2_fusion", multigraph) {}
|
||||
|
||||
~AdamApplyOneCond2Fusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
};
|
||||
|
||||
class AdamApplyOneCond3Fusion : public AdamApplyOneFusion {
|
||||
public:
|
||||
explicit AdamApplyOneCond3Fusion(bool multigraph = true)
|
||||
: AdamApplyOneFusion("adam_apply_one_cond3_fusion", multigraph) {}
|
||||
|
||||
~AdamApplyOneCond3Fusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
};
|
||||
|
||||
class AdamApplyOneCond4Fusion : public AdamApplyOneFusion {
|
||||
public:
|
||||
explicit AdamApplyOneCond4Fusion(bool multigraph = true)
|
||||
: AdamApplyOneFusion("adam_apply_one_cond4_fusion", multigraph) {}
|
||||
|
||||
~AdamApplyOneCond4Fusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_
|
||||
|
|
|
@ -66,5 +66,156 @@ TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_fusion) {
|
|||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond1_fusion) {
|
||||
/*
|
||||
* def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||
* square0 = Square(input0)
|
||||
* mul1 = Mul(mul1_x, input0)
|
||||
* mul0 = Mul(mul0_x, input2)
|
||||
* mul2 = Mul(mul2_x, input1)
|
||||
* mul3 = Mul(mul3_x, square0)
|
||||
* add0 = Add(mul0, mul1)
|
||||
* add1 = Add(mul2, mul3)
|
||||
* sqrt0 = Sqrt(add1)
|
||||
* add2 = Add(add2_y, sqrt0)
|
||||
* true_div0 = RealDiv(add0, add2)
|
||||
* mul4 = Mul(input4, true_div0)
|
||||
* sub0 = Sub(input3, mul4)
|
||||
* outputs = make_tuple(add1, add0, sub0)
|
||||
* output = tuple_getitem(outputs, 0)
|
||||
* return output
|
||||
*/
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond1");
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 10; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto fg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AdamApplyOneCond1Fusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond2_fusion) {
|
||||
/*
|
||||
* def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||
* square0 = Square(input0)
|
||||
* mul1 = Mul(mul1_x, input0)
|
||||
* mul0 = Mul(mul0_x, input2)
|
||||
* mul2 = Mul(mul2_x, input1)
|
||||
* mul3 = Mul(square0, mul3_x)
|
||||
* add0 = Add(mul0, mul1)
|
||||
* add1 = Add(mul2, mul3)
|
||||
* sqrt0 = Sqrt(add1)
|
||||
* add2 = Add(sqrt0, add2_y)
|
||||
* true_div0 = RealDiv(add0, add2)
|
||||
* mul4 = Mul(true_div0, input4)
|
||||
* sub0 = Sub(input3, mul4)
|
||||
* outputs = make_tuple(add1, add0, sub0)
|
||||
* output = tuple_getitem(outputs, 0)
|
||||
* return output
|
||||
*/
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond2");
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 10; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto fg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AdamApplyOneCond2Fusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond3_fusion) {
|
||||
/*
|
||||
* def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||
* square0 = Square(input0)
|
||||
* mul1 = Mul(mul1_x, input0)
|
||||
* mul0 = Mul(mul0_x, input2)
|
||||
* mul2 = Mul(mul2_x, input1)
|
||||
* mul3 = Mul(mul3_x, square0)
|
||||
* add0 = Add(mul0, mul1)
|
||||
* add1 = Add(mul2, mul3)
|
||||
* sqrt0 = Sqrt(add1)
|
||||
* add2 = Add(sqrt0, add2_y)
|
||||
* true_div0 = RealDiv(add0, add2)
|
||||
* mul4 = Mul(true_div0, input4)
|
||||
* sub0 = Sub(input3, mul4)
|
||||
* outputs = make_tuple(add1, add0, sub0)
|
||||
* output = tuple_getitem(outputs, 0)
|
||||
* return output
|
||||
*/
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond3");
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 10; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto fg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AdamApplyOneCond3Fusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWAdamApplyOneFusion, test_adam_apply_one_cond4_fusion) {
|
||||
/*
|
||||
* def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||
* square0 = Square(input0)
|
||||
* mul1 = Mul(mul1_x, input0)
|
||||
* mul0 = Mul(mul0_x, input2)
|
||||
* mul2 = Mul(mul2_x, input1)
|
||||
* mul3 = Mul(mul3_x, square0)
|
||||
* add0 = Add(mul0, mul1)
|
||||
* add1 = Add(mul2, mul3)
|
||||
* sqrt0 = Sqrt(add1)
|
||||
* add2 = Add(add2_y, sqrt0)
|
||||
* true_div0 = RealDiv(add0, add2)
|
||||
* mul4 = Mul(true_div0, input4)
|
||||
* sub0 = Sub(input3, mul4)
|
||||
* outputs = make_tuple(add1, add0, sub0)
|
||||
* output = tuple_getitem(outputs, 0)
|
||||
* return output
|
||||
*/
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "before_cond4");
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 10; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto fg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::AdamApplyOneCond4Fusion>());
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||
|
||||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_fusion", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -58,6 +58,78 @@ def test_adam_apply_one_fusion(tag):
|
|||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def before_cond1(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||
square0 = Square(input0)
|
||||
mul1 = Mul(mul1_x, input0)
|
||||
mul0 = Mul(mul0_x, input2)
|
||||
mul2 = Mul(mul2_x, input1)
|
||||
mul3 = Mul(mul3_x, square0)
|
||||
add0 = Add(mul0, mul1)
|
||||
add1 = Add(mul2, mul3)
|
||||
sqrt0 = Sqrt(add1)
|
||||
add2 = Add(add2_y, sqrt0)
|
||||
true_div0 = RealDiv(add0, add2)
|
||||
mul4 = Mul(input4, true_div0)
|
||||
sub0 = Sub(input3, mul4)
|
||||
outputs = make_tuple(add1, add0, sub0)
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def before_cond2(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||
square0 = Square(input0)
|
||||
mul1 = Mul(mul1_x, input0)
|
||||
mul0 = Mul(mul0_x, input2)
|
||||
mul2 = Mul(mul2_x, input1)
|
||||
mul3 = Mul(square0, mul3_x)
|
||||
add0 = Add(mul0, mul1)
|
||||
add1 = Add(mul2, mul3)
|
||||
sqrt0 = Sqrt(add1)
|
||||
add2 = Add(sqrt0, add2_y)
|
||||
true_div0 = RealDiv(add0, add2)
|
||||
mul4 = Mul(true_div0, input4)
|
||||
sub0 = Sub(input3, mul4)
|
||||
outputs = make_tuple(add1, add0, sub0)
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def before_cond3(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||
square0 = Square(input0)
|
||||
mul1 = Mul(mul1_x, input0)
|
||||
mul0 = Mul(mul0_x, input2)
|
||||
mul2 = Mul(mul2_x, input1)
|
||||
mul3 = Mul(mul3_x, square0)
|
||||
add0 = Add(mul0, mul1)
|
||||
add1 = Add(mul2, mul3)
|
||||
sqrt0 = Sqrt(add1)
|
||||
add2 = Add(sqrt0, add2_y)
|
||||
true_div0 = RealDiv(add0, add2)
|
||||
mul4 = Mul(true_div0, input4)
|
||||
sub0 = Sub(input3, mul4)
|
||||
outputs = make_tuple(add1, add0, sub0)
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def before_cond4(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||
square0 = Square(input0)
|
||||
mul1 = Mul(mul1_x, input0)
|
||||
mul0 = Mul(mul0_x, input2)
|
||||
mul2 = Mul(mul2_x, input1)
|
||||
mul3 = Mul(mul3_x, square0)
|
||||
add0 = Add(mul0, mul1)
|
||||
add1 = Add(mul2, mul3)
|
||||
sqrt0 = Sqrt(add1)
|
||||
add2 = Add(add2_y, sqrt0)
|
||||
true_div0 = RealDiv(add0, add2)
|
||||
mul4 = Mul(true_div0, input4)
|
||||
sub0 = Sub(input3, mul4)
|
||||
outputs = make_tuple(add1, add0, sub0)
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y):
|
||||
adam_apply_one = AdamApplyOne(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, add2_y)
|
||||
|
|
Loading…
Reference in New Issue