optimize bprop for inputs that do not need grad

This commit is contained in:
wangchangheng 2022-08-08 20:44:23 +08:00
parent 4aa9e4043c
commit 3d09b8fbc2
11 changed files with 340 additions and 35 deletions
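
For context, a minimal PyNative sketch of the scenario this commit optimizes (illustrative only, not part of the commit; the network, names, and values below are made up, and the standard MindSpore Python API is assumed). When a forward input such as a Parameter created with requires_grad=False feeds the network, its gradient branch in the bprop graph can be replaced by a ZerosLike node instead of being computed in full, which is what the new pynative_no_grad_eliminate pass does.

    import mindspore as ms
    from mindspore import context, nn, ops, Parameter, Tensor

    context.set_context(mode=context.PYNATIVE_MODE)

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            # 'w' is used in the forward pass but does not require grad, so the
            # optimized bprop can feed its branch a ZerosLike instead of a real gradient.
            self.w = Parameter(Tensor([2.0], ms.float32), name="w", requires_grad=False)

        def construct(self, x):
            return self.w * x * x

    net = Net()
    grad_fn = ops.GradOperation()(net)  # gradient w.r.t. the first input only
    print(grad_fn(Tensor([3.0], ms.float32)))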

View File

@ -256,6 +256,8 @@ class PynativeAdjoint {
AnfNodePtr k_node() const { return k_node_; }
void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }
AnfNodePtr get_dout() const { return dout_; }
private:
const FuncGraphPtr tape_;
AnfNodePtr dout_{nullptr};
@ -320,6 +322,8 @@ class KPynativeCellImpl : public KPynativeCell {
bool BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, const FuncGraphPtr &fg,
const PynativeAdjoint::FuncGraphType fg_type = PynativeAdjoint::FuncGraphType::kBackwardPropagate);
void BuildAdjointForInput(const CNodePtr &cnode, const ValuePtrList &op_args);
bool IsCNodeNeedGrad(const AnfNodePtr &node_ptr) const;
std::vector<bool> GetNeedGradFlags(const CNodePtr &cnode);
void PropagateStopGradient();
bool AllReferencesStopped(const CNodePtr &curr_cnode);
OrderedMap<AnfNodePtr, PynativeAdjointPtr>::reverse_iterator GetLastNodeReverseIter();
@ -652,8 +656,51 @@ void KPynativeCellImpl::BuildAdjointForInput(const CNodePtr &cnode, const ValueP
}
}
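// A node needs grad if it is a CNode without the kAttrIsCNodeNeedGrad attr (or with the attr set to true),
// or a Parameter whose param_info is null or whose requires_grad() is true; value nodes never need grad.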
bool KPynativeCellImpl::IsCNodeNeedGrad(const AnfNodePtr &node_ptr) const {
if (node_ptr->isa<CNode>()) {
const auto &cnode = node_ptr->cast<CNodePtr>();
if (cnode == nullptr || !cnode->HasAttr(kAttrIsCNodeNeedGrad)) {
return true;
}
return GetValue<bool>(cnode->GetAttr(kAttrIsCNodeNeedGrad));
}
auto param_ptr = node_ptr->cast<ParameterPtr>();
if (param_ptr == nullptr) {
// Value nodes fall through to here and do not need grad.
return false;
}
auto param_value = param_ptr->param_info();
if (param_value == nullptr) {
// If the node is a parameter but its param_info is null, treat it as needing grad.
return true;
}
return param_value->requires_grad();
}
std::vector<bool> KPynativeCellImpl::GetNeedGradFlags(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
std::vector<bool> need_grad_flag_of_inputs;
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
need_grad_flag_of_inputs.emplace_back(IsCNodeNeedGrad(cnode->input(i)));
}
return need_grad_flag_of_inputs;
}
bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
const FuncGraphPtr &fg, const PynativeAdjoint::FuncGraphType fg_type) {
auto need_grad_flag_of_inputs = GetNeedGradFlags(cnode);
size_t need_grad_input_num = std::count(need_grad_flag_of_inputs.begin(), need_grad_flag_of_inputs.end(), true);
cnode->AddAttr(kAttrIsCNodeNeedGrad, MakeValue(need_grad_input_num != 0));
if (need_grad_input_num != need_grad_flag_of_inputs.size()) {
cnode->AddAttr(kAttrNeedGradFlagOfInputs, MakeValue(need_grad_flag_of_inputs));
} else if (cnode->HasAttr(kAttrNeedGradFlagOfInputs)) {
cnode->EraseAttr(kAttrNeedGradFlagOfInputs);
}
fg->set_attr(kAttrNeedGradFlagOfInputs, MakeValue(need_grad_flag_of_inputs));
// Optimize the bprop_fg based on value.
// Clone op_args and out, so the address of tensor data can be reset to nullptr if the value of tensor
// is not used in bprop_fg;
@ -765,6 +812,12 @@ const AnfNodePtrList KPynativeCellImpl::BuildKNodeListFromPrimalCNode(const CNod
bool KPynativeCellImpl::BackPropagateOneCNodeWithBPropFuncGraph(const CNodePtr &cnode,
const PynativeAdjointPtr &adjoint,
const FuncGraphPtr &bprop_fg, bool by_value) {
if (adjoint->get_dout() == nullptr) {
// If dout is null, the node does not need grad.
MS_LOG(DEBUG) << "Node dout is null, node: " << cnode->DebugString();
return true;
}
AnfNodePtrList node_list;
abstract::AbstractBasePtr bprop_output_abs;
@ -815,6 +868,12 @@ bool KPynativeCellImpl::BackPropagateOneCNodeWithFPropFuncGraph(const CNodePtr &
const FuncGraphPtr &fprop_fg, bool by_value) {
MS_LOG(DEBUG) << "BackPropagate for CNode: " << cnode->DebugString();
if (adjoint->get_dout() == nullptr) {
// If dout is null, the node does not need grad.
MS_LOG(DEBUG) << "Node dout is null, node: " << cnode->DebugString();
return true;
}
AnfNodePtrList node_list;
CNodePtr bprop_cnode;
if (by_value) {

View File

@ -200,11 +200,17 @@ FuncGraphPtr PrimBpropOptimizer::OptimizeBPropFuncGraph(const FuncGraphPtr &bpro
return GenSpecOptBprop(bprop_fg, op_args, out, prim, hookback_flg);
}
return GetOptBpropFromCache(bprop_fg, op_args, out, prim);
std::vector<bool> need_grad_flags;
if (c_node->HasAttr(kAttrNeedGradFlagOfInputs)) {
need_grad_flags = GetValue<std::vector<bool>>(c_node->GetAttr(kAttrNeedGradFlagOfInputs));
}
return GetOptBpropFromCache(bprop_fg, op_args, out, prim, need_grad_flags);
}
FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
const ValuePtr &out, const PrimitivePtr &prim) {
const ValuePtr &out, const PrimitivePtr &prim,
const std::vector<bool> &need_grad_flags) {
MS_EXCEPTION_IF_NULL(bprop_fg);
abstract::AbstractBasePtrList abs_list;
ArgsToAbs(prim, op_args, &abs_list);
@ -217,7 +223,12 @@ FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_
if (cache_res == E_LEVEL_2) {
MS_LOG(DEBUG) << "Level 2 cache matched, prim: " << prim->ToString();
level_2_graph_info->TryFreeArgsValue(op_args, out);
return BasicClone(level_2_graph_info->opt_func_graph());
auto level2_graph_clone = BasicClone(level_2_graph_info->opt_func_graph());
if (!need_grad_flags.empty()) {
return GetBpropGraphWithNoGradInput(level_2_graph_info, AddOutToAbsList(out, abs_list), need_grad_flags, op_args,
out);
}
return level2_graph_clone;
}
// do step1 opt
@ -235,6 +246,12 @@ FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_
auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
if (enable_grad_cache) {
level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info;
}
if (!need_grad_flags.empty()) {
return GetBpropGraphWithNoGradInput(level_2_graph_info, new_abs_list, need_grad_flags, op_args, out);
}
if (enable_grad_cache) {
return BasicClone(level_2_graph_info->opt_func_graph());
}
return level_2_graph_info->opt_func_graph();
@ -294,7 +311,8 @@ void PrimBpropOptimizer::BindAbsToParameters(const FuncGraphPtr &bprop_fg,
}
PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(
const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input) const {
const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input,
const std::vector<bool> &need_grad_flags) const {
opt::irpass::OptimizeIRPassLib irpass;
BindAbsToParameters(bprop_fg, abs_list_input);
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
@ -307,12 +325,23 @@ PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(
}
}
manager->AddFuncGraph(bprop_fg);
auto opt_bprop_fg = PrimBpOptPassStep2(irpass, resource);
auto opt_bprop_fg = PrimBpOptPassStep2(irpass, resource, need_grad_flags);
auto level_2_graph_info = std::make_shared<PrimBpropOptGraphLevel2Info>(opt_bprop_fg);
level_2_graph_info->AnalysisArgUsingInfo(manager);
return level_2_graph_info;
}
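// Build a bprop graph for the case where some inputs do not need grad: clone the level-2 graph, attach the
// need-grad flags as a graph attr, and rerun the step-2 pass so that pynative_no_grad_eliminate can rewrite it.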
FuncGraphPtr PrimBpropOptimizer::GetBpropGraphWithNoGradInput(const PrimBpropOptGraphLevel2InfoPtr &level_2_graph_info,
const abstract::AbstractBasePtrList &abs_list,
const std::vector<bool> &need_grad_flags,
const ValuePtrList &op_args, const ValuePtr &out) {
auto level2_graph_clone = BasicClone(level_2_graph_info->opt_func_graph());
level2_graph_clone->set_attr(kAttrNeedGradFlagOfInputs, MakeValue(need_grad_flags));
auto no_grad_graph_info = PrimBpropOptStep2(level2_graph_clone, abs_list, need_grad_flags);
no_grad_graph_info->TryFreeArgsValue(op_args, out);
return no_grad_graph_info->opt_func_graph();
}
FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const {
MS_EXCEPTION_IF_NULL(res);
auto after_opt_bg = BpropGraphFinalOptPass(res);

View File

@ -21,7 +21,7 @@
#include <utility>
#include <memory>
#include <unordered_map>
#include <string>
#include "utils/hash_map.h"
#include "frontend/optimizer/irpass.h"
#include "ir/func_graph.h"
@ -164,13 +164,19 @@ class PrimBpropOptimizer {
PrimBpropOptGraphInfoPtr PrimBpropOptStep1(const FuncGraphPtr &bprop_fg) const;
// do opt with input info
PrimBpropOptGraphLevel2InfoPtr PrimBpropOptStep2(const FuncGraphPtr &bprop_fg,
const abstract::AbstractBasePtrList &abs_list_input) const;
PrimBpropOptGraphLevel2InfoPtr PrimBpropOptStep2(
const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input,
const std::vector<bool> &need_grad_flags = std::vector<bool>()) const;
FuncGraphPtr GetBpropGraphWithNoGradInput(const PrimBpropOptGraphLevel2InfoPtr &level_2_graph_info,
const abstract::AbstractBasePtrList &abs_list,
const std::vector<bool> &need_grad_flags, const ValuePtrList &op_args,
const ValuePtr &out);
void BindAbsToParameters(const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input) const;
FuncGraphPtr GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out,
const PrimitivePtr &prim);
const PrimitivePtr &prim, const std::vector<bool> &need_grad_flags);
FuncGraphPtr GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out,
const PrimitivePtr &prim, bool hook_flg);

View File

@ -43,6 +43,7 @@
#include "frontend/optimizer/irpass/tile_eliminate.h"
#include "frontend/optimizer/irpass/transpose_eliminate.h"
#include "frontend/optimizer/irpass/value_based_eliminate.h"
#include "frontend/optimizer/irpass/pynative_no_grad_eliminate.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/irpass/row_tensor_eliminate.h"
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
@ -75,6 +76,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "ad_related_special_op_eliminate",
{prim::kPrimMirror, prim::kPrimVirtualDiv});
pynative_eliminate_ = MakeSubstitution(std::make_shared<PynativeEliminater>(), "pynative_eliminate", IsCNodeDup);
pynative_no_grad_eliminate_ =
MakeSubstitution(std::make_shared<PynativeNoGradEliminater>(), "pynative_no_grad_eliminate", prim::kPrimMakeTuple);
zero_like_fill_zero_ =
MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
adjust_all_reduce_mul_add_ =

View File

@ -168,6 +168,9 @@ class OptimizeIRPassLib {
// Pynative Eliminate
SubstitutionPtr pynative_eliminate_;
// Pynative no-need-grad eliminate
SubstitutionPtr pynative_no_grad_eliminate_;
// Recompute
SubstitutionPtr set_cell_output_no_recompute_;

View File

@ -0,0 +1,114 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PYNATIVE_NO_GRAD_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PYNATIVE_NO_GRAD_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include <string>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace opt {
namespace irpass {
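// Rewrite the MakeTuple returned by a bprop graph: for each output whose need-grad flag is false, replace the
// computed gradient with ZerosLike of the corresponding graph input.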
class PynativeNoGradEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
if (!IsNeedOptimize(optimizer, node)) {
return nullptr;
}
if (!node->isa<CNode>()) {
return nullptr;
}
auto &node_inputs = node->cast_ptr<CNode>()->inputs();
if (need_grad_flag_of_inputs_.size() != node_inputs.size() - 1) {
return nullptr;
}
const auto &graph_inputs = func_graph_->get_inputs();
if (graph_inputs.size() < node_inputs.size() - 1) {
return nullptr;
}
auto manager = optimizer->manager();
MS_EXCEPTION_IF_NULL(manager);
for (size_t i = 1; i < node_inputs.size(); ++i) {
if (!need_grad_flag_of_inputs_[i - 1] && node_inputs[i]->isa<CNode>() &&
!IsPrimitiveCNode(node_inputs[i], prim::kPrimZerosLike)) {
const auto &graph_input_type = graph_inputs[i - 1]->Type();
if (graph_input_type == nullptr || !graph_input_type->isa<TensorType>()) {
// If the input is not a tensor, it cannot be the input of kPrimZerosLike.
continue;
}
AnfNodePtrList new_inputs = {NewValueNode(prim::kPrimZerosLike), graph_inputs[i - 1]};
auto zeros_like_node = node->func_graph()->NewCNode(new_inputs);
MS_EXCEPTION_IF_NULL(zeros_like_node);
zeros_like_node->set_abstract(graph_inputs[i - 1]->abstract());
if (!manager->Replace(node_inputs[i], zeros_like_node)) {
MS_LOG(EXCEPTION) << node_inputs[i]->DebugString() << ", replace node failed.";
}
}
}
return node;
}
private:
bool IsNeedOptimize(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
return false;
}
const auto &resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
MS_EXCEPTION_IF_NULL(resource);
func_graph_ = resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph_);
if (!func_graph_->has_attr(kAttrNeedGradFlagOfInputs)) {
return false;
}
const size_t ret_input_size = 2;
const auto &return_node = func_graph_->get_return();
MS_EXCEPTION_IF_NULL(return_node);
if (return_node->size() != ret_input_size) {
// The return node has two inputs: the Return op and one value.
return false;
}
if (return_node->input(1) != node) {
// Only optimize the MakeTuple node that the graph returns.
return false;
}
need_grad_flag_of_inputs_ = GetValue<std::vector<bool>>(func_graph_->get_attr(kAttrNeedGradFlagOfInputs));
return true;
}
std::vector<bool> need_grad_flag_of_inputs_;
FuncGraphPtr func_graph_;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PYNATIVE_NO_GRAD_ELIMINATE_H_

View File

@ -637,6 +637,8 @@ constexpr auto kAttrZeroInfinity = "zero_infinity";
constexpr auto kAttrBlank = "blank";
constexpr auto kAttrUpdateSlots = "update_slots";
constexpr auto kAttrLr = "lr";
constexpr auto kAttrNeedGradFlagOfInputs = "need_grad_flag_of_inputs";
constexpr auto kAttrIsCNodeNeedGrad = "is_cnode_need_grad";
// FuncGraph Flags
constexpr auto kFlagsIsCutGraph = "is_cut_graph";

View File

@ -164,28 +164,40 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co
return func_graph;
}
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource) {
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource,
const std::vector<bool> &need_grad_flags) {
MS_EXCEPTION_IF_NULL(resource);
MS_EXCEPTION_IF_NULL(resource->func_graph());
opt::OptPassConfig special_op_simplify = opt::OptPassConfig({
irpass.switch_simplify_,
irpass.reduce_eliminate_,
irpass.tile_eliminate_,
irpass.arithmetic_simplify_,
irpass.make_sparse_tensor_to_make_tuple_,
});
OptPassGroupMap map;
if (need_grad_flags.empty()) {
opt::OptPassConfig special_op_simplify = opt::OptPassConfig({
irpass.switch_simplify_,
irpass.reduce_eliminate_,
irpass.tile_eliminate_,
irpass.arithmetic_simplify_,
irpass.make_sparse_tensor_to_make_tuple_,
});
opt::OptPassConfig inline_opt = opt::OptPassConfig({
irpass.inline_,
});
opt::OptPassConfig inline_opt = opt::OptPassConfig({
irpass.inline_,
});
auto re_auto_monadwrapper = [](const FuncGraphPtr &root, const opt::OptimizerPtr &) -> bool {
return ReAutoMonad(root);
};
OptPassGroupMap map({{"ad_renormalize", opt::OptPassConfig::Renormalize()},
{"ad_inline", inline_opt},
{"ad_special_op_simplify", special_op_simplify},
{"auto_monad_grad", opt::OptPassConfig(re_auto_monadwrapper)}});
auto re_auto_monadwrapper = [](const FuncGraphPtr &root, const opt::OptimizerPtr &) -> bool {
return ReAutoMonad(root);
};
map.push_back({"ad_renormalize", opt::OptPassConfig::Renormalize()});
map.push_back({"ad_inline", inline_opt});
map.push_back({"ad_special_op_simplify", special_op_simplify});
map.push_back({"auto_monad_grad", opt::OptPassConfig(re_auto_monadwrapper)});
} else {
// If the func graph does not have the need_grad_flag_of_inputs attr, the pass below skips it.
opt::OptPassConfig pynative_no_grad_eliminate = opt::OptPassConfig({
irpass.pynative_no_grad_eliminate_,
});
map.push_back({"pynative_no_grad_eliminate", pynative_no_grad_eliminate});
}
auto prim_bprop_opt_step_2 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_2", resource, map);
FuncGraphPtr func_graph = resource->func_graph();

View File

@ -52,7 +52,8 @@ bool PynativeOptPass(const ResourcePtr &resource);
bool EliminateAdRelatedSpecialOpOptPass(const ResourcePtr &resource);
bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph);
FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource);
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource);
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource,
const std::vector<bool> &need_grad_flags);
FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource);
} // namespace pipeline
} // namespace mindspore

View File

@ -27,6 +27,7 @@
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/arithmetic_simplify.h"
#include "frontend/optimizer/irpass/pynative_no_grad_eliminate.h"
#include "pipeline/jit/action.h"
#include "include/common/debug/draw.h"
@ -85,17 +86,14 @@ class TestOptOpt : public UT::Common {
elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
pynative_no_grad_elim = MakeSubstitution(std::make_shared<irpass::PynativeNoGradEliminater>(),
"pynative_no_grad_eliminate", prim::kPrimMakeTuple);
}
bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {
equiv_node.clear();
equiv_graph.clear();
FuncGraphPtr graph_after_trans = TransformGraph(gbefore, transform);
FuncGraphPtr gbefore_clone = BasicClone(gbefore);
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", std::make_shared<pipeline::Resource>());
transform(gbefore_clone, optimizer);
return Isomorphic(gbefore_clone, gafter, &equiv_graph, &equiv_node);
return Isomorphic(graph_after_trans, gafter, &equiv_graph, &equiv_node);
}
bool CheckOpt(FuncGraphPtr before, FuncGraphPtr after, std::vector<SubstitutionPtr> opts = {}) {
@ -103,6 +101,24 @@ class TestOptOpt : public UT::Common {
return CheckTransform(before, after, eq);
}
FuncGraphPtr TransformGraph(FuncGraphPtr gbefore, const SubstitutionList &transform) {
equiv_node.clear();
equiv_graph.clear();
FuncGraphPtr gbefore_clone = BasicClone(gbefore);
pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
MS_EXCEPTION_IF_NULL(resource);
resource->set_func_graph(gbefore_clone);
auto manager = resource->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(gbefore_clone, true);
OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", resource);
transform(gbefore_clone, optimizer);
return gbefore_clone;
}
public:
UT::PyFuncGraphFetcher getPyFun;
@ -119,6 +135,7 @@ class TestOptOpt : public UT::Common {
SubstitutionPtr elim_R;
SubstitutionPtr idempotent_P;
SubstitutionPtr Qct_to_P;
SubstitutionPtr pynative_no_grad_elim;
SubstitutionPtr tuple_flatten = irpass_lib.call_graph_tuple_transform_;
};
@ -213,6 +230,46 @@ TEST_F(TestOptOpt, CSE) {
ASSERT_EQ(manager2->all_nodes().size(), 12);
}
/// Feature: PyNative no-need-grad elimination.
/// Description: Run pynative_no_grad_eliminate on a graph whose second input does not need grad.
/// Expectation: The Mul feeding the no-grad output is replaced by ZerosLike, so the Mul count drops from 2 to 1.
TEST_F(TestOptOpt, PynativeNoGradElim) {
FuncGraphPtr test_graph1 = getPyFun.CallAndParseRet("test_no_grad", "test_f1");
ASSERT_TRUE(nullptr != test_graph1);
auto all_nodes1 = TopoSort(test_graph1->return_node(), SuccDeeperSimple, AlwaysInclude);
auto mul_node_num1 = std::count_if(all_nodes1.begin(), all_nodes1.end(),
[](AnfNodePtr node) { return IsPrimitiveCNode(node, prim::kPrimMul); });
ASSERT_EQ(mul_node_num1, 2);
FuncGraphPtr test_graph2 = getPyFun.CallAndParseRet("test_no_grad", "test_f1");
ASSERT_TRUE(nullptr != test_graph2);
std::vector<bool> need_grad_flags{true, false};
test_graph2->set_attr(kAttrNeedGradFlagOfInputs, MakeValue(need_grad_flags));
auto tmp_substitution = std::vector<SubstitutionPtr>({pynative_no_grad_elim});
SubstitutionList substitution_list(tmp_substitution);
std::vector<int64_t> shape_vec = {1};
AbstractBasePtr abs = std::make_shared<abstract::AbstractTensor>(kTensorType, shape_vec);
auto graph_params = test_graph2->parameters();
for (auto graph_input : graph_params) {
graph_input->set_abstract(abs);
}
auto test_graph2_after_opt = TransformGraph(test_graph2, substitution_list);
ASSERT_TRUE(nullptr != test_graph2_after_opt);
auto all_nodes2 = TopoSort(test_graph2_after_opt->return_node(), SuccDeeperSimple, AlwaysInclude);
auto mul_node_num2 = std::count_if(all_nodes2.begin(), all_nodes2.end(),
[](AnfNodePtr node) { return IsPrimitiveCNode(node, prim::kPrimMul); });
ASSERT_EQ(mul_node_num2, 1);
}
size_t TupleArgAndParamSum(const FuncGraphPtr &func_graph) {
// Check tuple params and tuple args.
auto all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);

View File

@ -374,6 +374,25 @@ def test_cse(tag):
return fns[tag]
def test_no_grad(tag):
"""
Feature: PyNative no-need-grad elimination.
Description: Build the test graphs used by the PynativeNoGradElim UT.
Expectation: No exception.
"""
fns = FnDict()
mul = Primitive('Mul')
make_tuple = Primitive('MakeTuple')
@fns
def test_f1(x, y):
x1 = mul(x, 2)
y1 = mul(y, 2)
return make_tuple(x1, y1)
return fns[tag]
def test_arithmetic(tag):
""" test_arithmetic """
fns = FnDict()