optimize no need grad net
parent 4aa9e4043c
commit 3d09b8fbc2
@@ -256,6 +256,8 @@ class PynativeAdjoint {
  AnfNodePtr k_node() const { return k_node_; }
  void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }

  AnfNodePtr get_dout() const { return dout_; }

 private:
  const FuncGraphPtr tape_;
  AnfNodePtr dout_{nullptr};
@@ -320,6 +322,8 @@ class KPynativeCellImpl : public KPynativeCell {
  bool BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out, const FuncGraphPtr &fg,
                    const PynativeAdjoint::FuncGraphType fg_type = PynativeAdjoint::FuncGraphType::kBackwardPropagate);
  void BuildAdjointForInput(const CNodePtr &cnode, const ValuePtrList &op_args);
  bool IsCNodeNeedGrad(const AnfNodePtr &node_ptr) const;
  std::vector<bool> GetNeedGradFlags(const CNodePtr &cnode);
  void PropagateStopGradient();
  bool AllReferencesStopped(const CNodePtr &curr_cnode);
  OrderedMap<AnfNodePtr, PynativeAdjointPtr>::reverse_iterator GetLastNodeReverseIter();
@@ -652,8 +656,51 @@ void KPynativeCellImpl::BuildAdjointForInput(const CNodePtr &cnode, const ValueP
  }
}

bool KPynativeCellImpl::IsCNodeNeedGrad(const AnfNodePtr &node_ptr) const {
  if (node_ptr->isa<CNode>()) {
    const auto &cnode = node_ptr->cast<CNodePtr>();
    if (cnode == nullptr || !cnode->HasAttr(kAttrIsCNodeNeedGrad)) {
      return true;
    }

    return GetValue<bool>(cnode->GetAttr(kAttrIsCNodeNeedGrad));
  }

  auto param_ptr = node_ptr->cast<ParameterPtr>();
  if (param_ptr == nullptr) {
    // Value nodes return here.
    return false;
  }
  auto param_value = param_ptr->param_info();
  if (param_value == nullptr) {
    // If the node is a parameter but its param_info is null, the node needs grad.
    return true;
  }
  return param_value->requires_grad();
}

std::vector<bool> KPynativeCellImpl::GetNeedGradFlags(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<bool> need_grad_flag_of_inputs;
  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    need_grad_flag_of_inputs.emplace_back(IsCNodeNeedGrad(cnode->input(i)));
  }
  return need_grad_flag_of_inputs;
}
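// --- Editor's sketch (not part of this commit): a minimal, standalone model of how the two
// helpers above behave. The Node struct is a stand-in for AnfNode; only the decision logic is
// mirrored: unflagged CNodes default to "needs grad", Parameters follow requires_grad, and
// ValueNodes never need grad.
#include <algorithm>
#include <iostream>
#include <vector>

struct Node {
  bool is_cnode = false;      // intermediate result of another op
  bool is_parameter = false;  // graph parameter
  bool has_flag = false;      // CNode already carries kAttrIsCNodeNeedGrad
  bool flag = false;          // cached CNode flag, or the parameter's requires_grad
};

bool NeedsGrad(const Node &n) {
  if (n.is_cnode) {
    return n.has_flag ? n.flag : true;
  }
  if (n.is_parameter) {
    return n.flag;
  }
  return false;  // value node
}

int main() {
  // z = Mul(x, y), where x requires grad and y does not.
  Node x{false, true, false, true};
  Node y{false, true, false, false};
  std::vector<bool> flags = {NeedsGrad(x), NeedsGrad(y)};
  size_t need_grad_input_num = std::count(flags.begin(), flags.end(), true);
  std::cout << "flags = {" << flags[0] << ", " << flags[1] << "}, need = " << need_grad_input_num << std::endl;
  // Prints: flags = {1, 0}, need = 1 -- the pattern that BuildAdjoint below attaches to the
  // CNode as kAttrNeedGradFlagOfInputs.
  return 0;
}
// --- End of editor's sketch.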

bool KPynativeCellImpl::BuildAdjoint(const CNodePtr &cnode, const ValuePtrList &op_args, const ValuePtr &out,
                                     const FuncGraphPtr &fg, const PynativeAdjoint::FuncGraphType fg_type) {
  auto need_grad_flag_of_inputs = GetNeedGradFlags(cnode);
  size_t need_grad_input_num = std::count(need_grad_flag_of_inputs.begin(), need_grad_flag_of_inputs.end(), true);
  cnode->AddAttr(kAttrIsCNodeNeedGrad, MakeValue(need_grad_input_num != 0));
  if (need_grad_input_num != need_grad_flag_of_inputs.size()) {
    cnode->AddAttr(kAttrNeedGradFlagOfInputs, MakeValue(need_grad_flag_of_inputs));
  } else if (cnode->HasAttr(kAttrNeedGradFlagOfInputs)) {
    cnode->EraseAttr(kAttrNeedGradFlagOfInputs);
  }

  fg->set_attr(kAttrNeedGradFlagOfInputs, MakeValue(need_grad_flag_of_inputs));

  // Optimize the bprop_fg based on value.
  // Clone op_args and out, so the address of tensor data can be reset to nullptr if the value of tensor
  // is not used in bprop_fg;

@@ -765,6 +812,12 @@ const AnfNodePtrList KPynativeCellImpl::BuildKNodeListFromPrimalCNode(const CNod
bool KPynativeCellImpl::BackPropagateOneCNodeWithBPropFuncGraph(const CNodePtr &cnode,
                                                                const PynativeAdjointPtr &adjoint,
                                                                const FuncGraphPtr &bprop_fg, bool by_value) {
  if (adjoint->get_dout() == nullptr) {
    // If dout is null, the node does not need grad.
    MS_LOG(DEBUG) << "node dout is null, node:" << cnode->DebugString();
    return true;
  }

  AnfNodePtrList node_list;
  abstract::AbstractBasePtr bprop_output_abs;

@@ -815,6 +868,12 @@ bool KPynativeCellImpl::BackPropagateOneCNodeWithFPropFuncGraph(const CNodePtr &
                                                                const FuncGraphPtr &fprop_fg, bool by_value) {
  MS_LOG(DEBUG) << "BackPropagate for CNode: " << cnode->DebugString();

  if (adjoint->get_dout() == nullptr) {
    // If dout is null, the node does not need grad.
    MS_LOG(DEBUG) << "node dout is null, node:" << cnode->DebugString();
    return true;
  }

  AnfNodePtrList node_list;
  CNodePtr bprop_cnode;
  if (by_value) {

@@ -200,11 +200,17 @@ FuncGraphPtr PrimBpropOptimizer::OptimizeBPropFuncGraph(const FuncGraphPtr &bpro
    return GenSpecOptBprop(bprop_fg, op_args, out, prim, hookback_flg);
  }

  return GetOptBpropFromCache(bprop_fg, op_args, out, prim);
  std::vector<bool> need_grad_flags;
  if (c_node->HasAttr(kAttrNeedGradFlagOfInputs)) {
    need_grad_flags = GetValue<std::vector<bool>>(c_node->GetAttr(kAttrNeedGradFlagOfInputs));
  }

  return GetOptBpropFromCache(bprop_fg, op_args, out, prim, need_grad_flags);
}

FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args,
                                                      const ValuePtr &out, const PrimitivePtr &prim) {
                                                      const ValuePtr &out, const PrimitivePtr &prim,
                                                      const std::vector<bool> &need_grad_flags) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  abstract::AbstractBasePtrList abs_list;
  ArgsToAbs(prim, op_args, &abs_list);

@@ -217,7 +223,12 @@ FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_
  if (cache_res == E_LEVEL_2) {
    MS_LOG(DEBUG) << "Level 2 cache matched, prim: " << prim->ToString();
    level_2_graph_info->TryFreeArgsValue(op_args, out);
    return BasicClone(level_2_graph_info->opt_func_graph());
    auto level2_graph_clone = BasicClone(level_2_graph_info->opt_func_graph());
    if (!need_grad_flags.empty()) {
      return GetBpropGrahWithNoGradInput(level_2_graph_info, AddOutToAbsList(out, abs_list), need_grad_flags, op_args,
                                         out);
    }
    return level2_graph_clone;
  }

  // do step1 opt

@@ -235,6 +246,12 @@ FuncGraphPtr PrimBpropOptimizer::GetOptBpropFromCache(const FuncGraphPtr &bprop_
  auto enable_grad_cache = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  if (enable_grad_cache) {
    level_1_graph_info->graph_level_2_cache_[abs_list] = level_2_graph_info;
  }

  if (!need_grad_flags.empty()) {
    return GetBpropGrahWithNoGradInput(level_2_graph_info, new_abs_list, need_grad_flags, op_args, out);
  }
  if (enable_grad_cache) {
    return BasicClone(level_2_graph_info->opt_func_graph());
  }
  return level_2_graph_info->opt_func_graph();
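// --- Editor's sketch (not part of this commit): a standalone model of the caching decision
// above. When need_grad_flags is non-empty, the generic cached level-2 graph is not handed out
// directly; a specialized copy is produced per flag pattern while the cache keeps the generic
// version. Names here are illustrative, not MindSpore APIs.
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Graph = std::string;  // stand-in for an optimized bprop func graph

Graph Specialize(const Graph &g, const std::vector<bool> &need_grad_flags) {
  Graph out = g + " [no-grad pass:";
  for (bool f : need_grad_flags) out += f ? " 1" : " 0";
  return out + "]";
}

Graph GetOptBprop(const Graph &generic, const std::vector<bool> &need_grad_flags,
                  std::map<Graph, Graph> *cache) {
  if (cache->find(generic) == cache->end()) {
    (*cache)[generic] = generic + " (level-2 optimized)";  // cache the generic optimization once
  }
  const Graph &base = (*cache)[generic];
  if (!need_grad_flags.empty()) {
    return Specialize(base, need_grad_flags);  // fresh specialized copy; the cache stays generic
  }
  return base;
}

int main() {
  std::map<Graph, Graph> cache;
  std::cout << GetOptBprop("mul_bprop", {}, &cache) << std::endl;             // generic cached graph
  std::cout << GetOptBprop("mul_bprop", {true, false}, &cache) << std::endl;  // specialized variant
  return 0;
}
// --- End of editor's sketch.
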
@@ -294,7 +311,8 @@ void PrimBpropOptimizer::BindAbsToParameters(const FuncGraphPtr &bprop_fg,
}

PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(
  const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input) const {
  const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input,
  const std::vector<bool> &need_grad_flags) const {
  opt::irpass::OptimizeIRPassLib irpass;
  BindAbsToParameters(bprop_fg, abs_list_input);
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();

@@ -307,12 +325,23 @@ PrimBpropOptGraphLevel2InfoPtr PrimBpropOptimizer::PrimBpropOptStep2(
    }
  }
  manager->AddFuncGraph(bprop_fg);
  auto opt_bprop_fg = PrimBpOptPassStep2(irpass, resource);
  auto opt_bprop_fg = PrimBpOptPassStep2(irpass, resource, need_grad_flags);
  auto level_2_graph_info = std::make_shared<PrimBpropOptGraphLevel2Info>(opt_bprop_fg);
  level_2_graph_info->AnalysisArgUsingInfo(manager);
  return level_2_graph_info;
}

FuncGraphPtr PrimBpropOptimizer::GetBpropGrahWithNoGradInput(const PrimBpropOptGraphLevel2InfoPtr &level_2_graph_info,
                                                             const abstract::AbstractBasePtrList &abs_list,
                                                             const std::vector<bool> &need_grad_flags,
                                                             const ValuePtrList &op_args, const ValuePtr &out) {
  auto level2_graph_clone = BasicClone(level_2_graph_info->opt_func_graph());
  level2_graph_clone->set_attr(kAttrNeedGradFlagOfInputs, MakeValue(need_grad_flags));
  auto no_grad_graph_info = PrimBpropOptStep2(level2_graph_clone, abs_list, need_grad_flags);
  no_grad_graph_info->TryFreeArgsValue(op_args, out);
  return no_grad_graph_info->opt_func_graph();
}

FuncGraphPtr PrimBpropOptimizer::BpropGraphFinalOpt(const pipeline::ResourcePtr &res) const {
  MS_EXCEPTION_IF_NULL(res);
  auto after_opt_bg = BpropGraphFinalOptPass(res);

@@ -21,7 +21,7 @@
#include <utility>
#include <memory>
#include <unordered_map>

#include <string>
#include "utils/hash_map.h"
#include "frontend/optimizer/irpass.h"
#include "ir/func_graph.h"

@@ -164,13 +164,19 @@ class PrimBpropOptimizer {
  PrimBpropOptGraphInfoPtr PrimBpropOptStep1(const FuncGraphPtr &bprop_fg) const;

  // do opt with input info
  PrimBpropOptGraphLevel2InfoPtr PrimBpropOptStep2(const FuncGraphPtr &bprop_fg,
                                                   const abstract::AbstractBasePtrList &abs_list_input) const;
  PrimBpropOptGraphLevel2InfoPtr PrimBpropOptStep2(
    const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input,
    const std::vector<bool> &need_grad_flags = std::vector<bool>()) const;

  FuncGraphPtr GetBpropGrahWithNoGradInput(const PrimBpropOptGraphLevel2InfoPtr &level_2_graph_info,
                                           const abstract::AbstractBasePtrList &abs_list,
                                           const std::vector<bool> &need_grad_flags, const ValuePtrList &op_args,
                                           const ValuePtr &out);

  void BindAbsToParameters(const FuncGraphPtr &bprop_fg, const abstract::AbstractBasePtrList &abs_list_input) const;

  FuncGraphPtr GetOptBpropFromCache(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out,
                                    const PrimitivePtr &prim);
                                    const PrimitivePtr &prim, const std::vector<bool> &need_grad_flags);

  FuncGraphPtr GenSpecOptBprop(const FuncGraphPtr &bprop_fg, const ValuePtrList &op_args, const ValuePtr &out,
                               const PrimitivePtr &prim, bool hook_flg);

@@ -43,6 +43,7 @@
#include "frontend/optimizer/irpass/tile_eliminate.h"
#include "frontend/optimizer/irpass/transpose_eliminate.h"
#include "frontend/optimizer/irpass/value_based_eliminate.h"
#include "frontend/optimizer/irpass/pynative_no_grad_eliminate.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/irpass/row_tensor_eliminate.h"
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"

@@ -75,6 +76,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
    MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "ad_related_special_op_eliminate",
                     {prim::kPrimMirror, prim::kPrimVirtualDiv});
  pynative_eliminate_ = MakeSubstitution(std::make_shared<PynativeEliminater>(), "pynative_eliminate", IsCNodeDup);
  pynative_no_grad_eliminate_ =
    MakeSubstitution(std::make_shared<PynativeNoGradEliminater>(), "pynative_no_grad_eliminate", prim::kPrimMakeTuple);
  zero_like_fill_zero_ =
    MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
  adjust_all_reduce_mul_add_ =

@@ -168,6 +168,9 @@ class OptimizeIRPassLib {
  // Pynative Eliminate
  SubstitutionPtr pynative_eliminate_;

  // Pynative no need grad eliminate
  SubstitutionPtr pynative_no_grad_eliminate_;

  // Recompute
  SubstitutionPtr set_cell_output_no_recompute_;

@@ -0,0 +1,114 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PYNATIVE_NO_GRAD_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PYNATIVE_NO_GRAD_ELIMINATE_H_

#include <vector>
#include <algorithm>
#include <string>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace opt {
namespace irpass {
class PynativeNoGradEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    if (!IsNeedOptimiz(optimizer, node)) {
      return nullptr;
    }

    if (!node->isa<CNode>()) {
      return nullptr;
    }

    auto &node_inputs = node->cast_ptr<CNode>()->inputs();
    if (need_grad_flag_of_inputs_.size() != node_inputs.size() - 1) {
      return nullptr;
    }

    const auto &graph_inputs = func_graph_->get_inputs();
    if (graph_inputs.size() < node_inputs.size() - 1) {
      return nullptr;
    }

    auto manager = optimizer->manager();
    MS_EXCEPTION_IF_NULL(manager);
    for (size_t i = 1; i < node_inputs.size(); ++i) {
      if (!need_grad_flag_of_inputs_[i - 1] && node_inputs[i]->isa<CNode>() &&
          !IsPrimitiveCNode(node_inputs[i], prim::kPrimZerosLike)) {
        const auto &graph_input_type = graph_inputs[i - 1]->Type();
        if (graph_input_type == nullptr || !graph_input_type->isa<TensorType>()) {
          // If the input is not a tensor, it cannot be used as the input of kPrimZerosLike.
          continue;
        }
        AnfNodePtrList new_inputs = {NewValueNode(prim::kPrimZerosLike), graph_inputs[i - 1]};
        auto zeros_like_node = node->func_graph()->NewCNode(new_inputs);
        MS_EXCEPTION_IF_NULL(zeros_like_node);

        zeros_like_node->set_abstract(graph_inputs[i - 1]->abstract());
        if (!manager->Replace(node_inputs[i], zeros_like_node)) {
          MS_LOG(EXCEPTION) << node_inputs[i]->DebugString() << ", replace node failed.";
        }
      }
    }

    return node;
  }

 private:
  bool IsNeedOptimiz(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
    if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
      return false;
    }

    const auto &resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
    MS_EXCEPTION_IF_NULL(resource);

    func_graph_ = resource->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph_);
    if (!func_graph_->has_attr(kAttrNeedGradFlagOfInputs)) {
      return false;
    }

    const size_t ret_input_size = 2;
    const auto &return_node = func_graph_->get_return();
    MS_EXCEPTION_IF_NULL(return_node);
    if (return_node->size() != ret_input_size) {
      // The return node has exactly two inputs: the Return op and one value.
      return false;
    }

    if (return_node->input(1) != node) {
      // Only optimize the MakeTuple node that is returned by the graph.
      return false;
    }

    need_grad_flag_of_inputs_ = GetValue<std::vector<bool>>(func_graph_->get_attr(kAttrNeedGradFlagOfInputs));
    return true;
  }

  std::vector<bool> need_grad_flag_of_inputs_;
  FuncGraphPtr func_graph_;
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PYNATIVE_NO_GRAD_ELIMINATE_H_
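// --- Editor's sketch (not part of this commit): a standalone model of what the pass above does
// to a bprop output tuple. For every input flagged as "no grad", the (possibly expensive)
// gradient is replaced by a zeros-like of the corresponding forward input, so the real gradient
// computation for it can later be eliminated. Plain vectors stand in for tensors; names are
// illustrative only.
#include <iostream>
#include <vector>

using Tensor = std::vector<double>;

Tensor ZerosLike(const Tensor &t) { return Tensor(t.size(), 0.0); }

// grads[i] is the gradient for forward input i; need_grad[i] mirrors kAttrNeedGradFlagOfInputs.
std::vector<Tensor> PruneGrads(const std::vector<Tensor> &inputs, const std::vector<Tensor> &grads,
                               const std::vector<bool> &need_grad) {
  std::vector<Tensor> out = grads;
  for (size_t i = 0; i < grads.size() && i < need_grad.size(); ++i) {
    if (!need_grad[i]) {
      out[i] = ZerosLike(inputs[i]);  // keep the shape, drop the real gradient
    }
  }
  return out;
}

int main() {
  std::vector<Tensor> inputs = {{1.0, 2.0}, {3.0, 4.0}};
  std::vector<Tensor> grads = {{0.1, 0.2}, {0.3, 0.4}};
  auto pruned = PruneGrads(inputs, grads, {true, false});
  std::cout << pruned[0][0] << " " << pruned[1][0] << std::endl;  // 0.1 0 -- the second grad was zeroed
  return 0;
}
// --- End of editor's sketch.
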
@ -637,6 +637,8 @@ constexpr auto kAttrZeroInfinity = "zero_infinity";
|
|||
constexpr auto kAttrBlank = "blank";
|
||||
constexpr auto kAttrUpdateSlots = "update_slots";
|
||||
constexpr auto kAttrLr = "lr";
|
||||
constexpr auto kAttrNeedGradFlagOfInputs = "need_grad_flag_of_inputs";
|
||||
constexpr auto kAttrIsCNodeNeedGrad = "is_cnode_need_grad";
|
||||
|
||||
// FuncGraph Flags
|
||||
constexpr auto kFlagsIsCutGraph = "is_cut_graph";
|
||||
|
|
|
@@ -164,28 +164,40 @@ FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, co
  return func_graph;
}

FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource) {
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource,
                                const std::vector<bool> &need_grad_flags) {
  MS_EXCEPTION_IF_NULL(resource);
  MS_EXCEPTION_IF_NULL(resource->func_graph());
  opt::OptPassConfig special_op_simplify = opt::OptPassConfig({
    irpass.switch_simplify_,
    irpass.reduce_eliminate_,
    irpass.tile_eliminate_,
    irpass.arithmetic_simplify_,
    irpass.make_sparse_tensor_to_make_tuple_,
  });
  OptPassGroupMap map;
  if (need_grad_flags.empty()) {
    opt::OptPassConfig special_op_simplify = opt::OptPassConfig({
      irpass.switch_simplify_,
      irpass.reduce_eliminate_,
      irpass.tile_eliminate_,
      irpass.arithmetic_simplify_,
      irpass.make_sparse_tensor_to_make_tuple_,
    });

  opt::OptPassConfig inline_opt = opt::OptPassConfig({
    irpass.inline_,
  });
    opt::OptPassConfig inline_opt = opt::OptPassConfig({
      irpass.inline_,
    });

  auto re_auto_monadwrapper = [](const FuncGraphPtr &root, const opt::OptimizerPtr &) -> bool {
    return ReAutoMonad(root);
  };
  OptPassGroupMap map({{"ad_renormalize", opt::OptPassConfig::Renormalize()},
                       {"ad_inline", inline_opt},
                       {"ad_special_op_simplify", special_op_simplify},
                       {"auto_monad_grad", opt::OptPassConfig(re_auto_monadwrapper)}});
    auto re_auto_monadwrapper = [](const FuncGraphPtr &root, const opt::OptimizerPtr &) -> bool {
      return ReAutoMonad(root);
    };

    map.push_back({"ad_renormalize", opt::OptPassConfig::Renormalize()});
    map.push_back({"ad_inline", inline_opt});
    map.push_back({"ad_special_op_simplify", special_op_simplify});
    map.push_back({"auto_monad_grad", opt::OptPassConfig(re_auto_monadwrapper)});
  } else {
    // If the func graph does not have the need_grad_flag_of_inputs attr, it does not need this pass.
    opt::OptPassConfig pynative_no_grad_eliminate = opt::OptPassConfig({
      irpass.pynative_no_grad_eliminate_,
    });

    map.push_back({"pynative_no_grad_eliminate", pynative_no_grad_eliminate});
  }

  auto prim_bprop_opt_step_2 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_2", resource, map);
  FuncGraphPtr func_graph = resource->func_graph();

@@ -52,7 +52,8 @@ bool PynativeOptPass(const ResourcePtr &resource);
bool EliminateAdRelatedSpecialOpOptPass(const ResourcePtr &resource);
bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph);
FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource);
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource);
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource,
                                const std::vector<bool> &need_grad_flags);
FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &resource);
}  // namespace pipeline
}  // namespace mindspore

@@ -27,6 +27,7 @@
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/arithmetic_simplify.h"
#include "frontend/optimizer/irpass/pynative_no_grad_eliminate.h"
#include "pipeline/jit/action.h"

#include "include/common/debug/draw.h"

@@ -85,17 +86,14 @@ class TestOptOpt : public UT::Common {
    elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
    idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
    Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
    pynative_no_grad_elim = MakeSubstitution(std::make_shared<irpass::PynativeNoGradEliminater>(),
                                             "pynative_no_grad_eliminate", prim::kPrimMakeTuple);
  }

  bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {
    equiv_node.clear();
    equiv_graph.clear();
    FuncGraphPtr graph_after_trans = TransformGraph(gbefore, transform);

    FuncGraphPtr gbefore_clone = BasicClone(gbefore);
    OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", std::make_shared<pipeline::Resource>());
    transform(gbefore_clone, optimizer);

    return Isomorphic(gbefore_clone, gafter, &equiv_graph, &equiv_node);
    return Isomorphic(graph_after_trans, gafter, &equiv_graph, &equiv_node);
  }

  bool CheckOpt(FuncGraphPtr before, FuncGraphPtr after, std::vector<SubstitutionPtr> opts = {}) {

@@ -103,6 +101,24 @@ class TestOptOpt : public UT::Common {
    return CheckTransform(before, after, eq);
  }

  FuncGraphPtr TransformGraph(FuncGraphPtr gbefore, const SubstitutionList &transform) {
    equiv_node.clear();
    equiv_graph.clear();

    FuncGraphPtr gbefore_clone = BasicClone(gbefore);
    pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
    MS_EXCEPTION_IF_NULL(resource);
    resource->set_func_graph(gbefore_clone);
    auto manager = resource->manager();
    MS_EXCEPTION_IF_NULL(manager);
    manager->AddFuncGraph(gbefore_clone, true);

    OptimizerPtr optimizer = std::make_shared<Optimizer>("ut_test", resource);
    transform(gbefore_clone, optimizer);

    return gbefore_clone;
  }

 public:
  UT::PyFuncGraphFetcher getPyFun;

@@ -119,6 +135,7 @@ class TestOptOpt : public UT::Common {
  SubstitutionPtr elim_R;
  SubstitutionPtr idempotent_P;
  SubstitutionPtr Qct_to_P;
  SubstitutionPtr pynative_no_grad_elim;
  SubstitutionPtr tuple_flatten = irpass_lib.call_graph_tuple_transform_;
};

@@ -213,6 +230,46 @@ TEST_F(TestOptOpt, CSE) {
  ASSERT_EQ(manager2->all_nodes().size(), 12);
}

/// Feature: test no grad input net.
/// Description: test no grad input net.
/// Expectation: No exception.
TEST_F(TestOptOpt, PynativeNoGradElim) {
  FuncGraphPtr test_graph1 = getPyFun.CallAndParseRet("test_no_grad", "test_f1");

  ASSERT_TRUE(nullptr != test_graph1);

  auto all_nodes1 = TopoSort(test_graph1->return_node(), SuccDeeperSimple, AlwaysInclude);
  auto mul_node_num1 = std::count_if(all_nodes1.begin(), all_nodes1.end(),
                                     [](AnfNodePtr node) { return IsPrimitiveCNode(node, prim::kPrimMul); });

  ASSERT_EQ(mul_node_num1, 2);

  FuncGraphPtr test_graph2 = getPyFun.CallAndParseRet("test_no_grad", "test_f1");

  ASSERT_TRUE(nullptr != test_graph2);

  std::vector<bool> need_grad_flags{true, false};
  test_graph2->set_attr(kAttrNeedGradFlagOfInputs, MakeValue(need_grad_flags));

  auto tmp_substitution = std::vector<SubstitutionPtr>({pynative_no_grad_elim});
  SubstitutionList substitution_list(tmp_substitution);

  std::vector<int64_t> shape_vec = {1};
  AbstractBasePtr abs = std::make_shared<abstract::AbstractTensor>(kTensorType, shape_vec);
  auto graph_params = test_graph2->parameters();
  for (auto graph_input : graph_params) {
    graph_input->set_abstract(abs);
  }

  auto test_graph2_after_optmiz = TransformGraph(test_graph2, substitution_list);
  ASSERT_TRUE(nullptr != test_graph2_after_optmiz);
  auto all_nodes2 = TopoSort(test_graph2_after_optmiz->return_node(), SuccDeeperSimple, AlwaysInclude);
  auto mul_node_num2 = std::count_if(all_nodes2.begin(), all_nodes2.end(),
                                     [](AnfNodePtr node) { return IsPrimitiveCNode(node, prim::kPrimMul); });

  ASSERT_EQ(mul_node_num2, 1);
}

size_t TupleArgAndParamSum(const FuncGraphPtr &func_graph) {
  // Check tuple params and tuple args.
  auto all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);

@@ -374,6 +374,25 @@ def test_cse(tag):
    return fns[tag]


def test_no_grad(tag):
    """
    Feature: test no grad input net.
    Description: test no grad input net.
    Expectation: No exception.
    """
    fns = FnDict()
    mul = Primitive('Mul')
    make_tuple = Primitive('MakeTuple')

    @fns
    def test_f1(x, y):
        x1 = mul(x, 2)
        y1 = mul(y, 2)
        return make_tuple(x1, y1)

    return fns[tag]


def test_arithmetic(tag):
    """ test_arithmetic """
    fns = FnDict()