forked from mindspore-Ecosystem/mindspore
lift fv before grad except weight, then convert

- Convert switch(cond, partial(g1, xs), partial(g2, ys))(Zs) to switch(cond, g1, g2)(Xs, Ys, Zs), and
  switch_layer(index, make_tuple(partial(g1, xs), partial(g2, ys)))(Zs) to switch_layer(index, make_tuple(g1, g2))(Xs, Ys, Zs).
- Put Zs last when unifying parameters, as they may carry a U-monad or IO-monad.
- Use the joined args rather than the broadened ones, as an extra parameter that is not a parameter of while_header can be added to while_body.
- Inline fprop_switch forcibly.
- Reorder the parameters if one of them is a Monad when incorporating a call or incorporating a switch.
- Incorporate tuple_getitem if item 0 of the tuple is an EnvInstance or item 1 of the tuple is a bprop function.
- Handle AddN with shape() and shape(1).
- Remove context_ from FuncGraphEvaluator to make it re-entrant, resolving the evaluator-stuck issue caused by re-entry of the same FuncGraphEvaluator.
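As an illustration of the switch conversion, here is a minimal C++ analogy (not MindSpore IR; g1/g2 are stand-in functions): both callees are widened to accept the union of the argument lists, so the switch selects between plain functions instead of partially-applied closures.

// Minimal C++ analogy (not MindSpore IR) of rewriting
//   switch(cond, partial(g1, xs), partial(g2, ys))(zs)
// into
//   switch(cond, g1, g2)(xs, ys, zs).
#include <functional>
#include <iostream>

int g1(int x, int /*unused_y*/, int z) { return x + z; }  // only uses its own xs slot
int g2(int /*unused_x*/, int y, int z) { return y * z; }  // only uses its own ys slot

int main() {
  bool cond = true;
  int xs = 1, ys = 2, zs = 3;
  // Before: select a partially-applied closure, then call it with zs only.
  std::function<int(int)> before = cond ? std::function<int(int)>([xs](int z) { return g1(xs, 0, z); })
                                        : std::function<int(int)>([ys](int z) { return g2(0, ys, z); });
  // After: select a plain function, then call it with the unified (xs, ys, zs).
  auto after = cond ? &g1 : &g2;
  std::cout << before(zs) << " == " << after(xs, ys, zs) << std::endl;
  return 0;
}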
This commit is contained in:
parent af8c15265d
commit a5b4e7bbf8
@@ -101,12 +101,15 @@ void DumpInferStack(std::ostringstream &oss) {
      infer_vec.clear();
      break;
    }
    auto graph_context = graph_infer->context();
    auto graph_context = graph_infer->parent_context();
    if (graph_context == nullptr) {
      MS_LOG(INFO) << "Null context continue";
      continue;
    }
    auto graph = graph_context->func_graph();
    if (graph == nullptr) {
      continue;
    }
    auto args_spec_list = graph_context->args_spec_list();
    oss << " #" << index++ << " " << GetGraphParamString(graph, args_spec_list);
  }

@@ -264,7 +267,7 @@ std::vector<AnalysisContextPtr> AnalyzedFuncGraphExporter::ProcessFuncGraphCall(
  }

  auto base_fg_evaluator = dyn_cast<abstract::BaseFuncGraphEvaluator>(evaluator);
  auto ctx = base_fg_evaluator->context();
  auto ctx = base_fg_evaluator->parent_context();
  if (ctx != nullptr && context_map_.insert({ctx, false}).second) {
    MS_LOG(DEBUG) << "Add new context, ctx.addr = " << ctx.get() << "ctx = " << ctx->ToString();
    context_vec_.push_back(ctx);

@@ -506,7 +509,7 @@ void GetEvalStackInfo(std::ostringstream &oss) {
    return;
  }
  static int fileNumber = 0;
  string file_name = "analyze_fail" + std::to_string(fileNumber++) + ".dat";
  string file_name = "analyze_fail_" + std::to_string(fileNumber++) + ".dat";
  auto ms_om_path = common::GetEnv("MS_OM_PATH");
  if (!ms_om_path.empty()) {
    auto path = ms_om_path + "/" + file_name;
@@ -37,6 +37,8 @@ namespace ad {
std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;

bool lift_fv_before_grad = true;

DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
    : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
  {

@@ -73,6 +75,10 @@ void DFunctor::Clear() {
}

void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
  if (lift_fv_before_grad) {
    MS_LOG(EXCEPTION) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv "
                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  }
  auto fv_adjoint = anfnode_to_adjoin_.find(fv);
  if (fv_adjoint == anfnode_to_adjoin_.end()) {
    MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()

@@ -437,6 +443,13 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
  AnfNodePtr new_grad_fv = grad_fv;
  // Add grads wrt fv.
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  if (!is_top_ && free_variables_nodes.size() != 0) {
    if (lift_fv_before_grad) {
      MS_LOG(EXCEPTION) << "direct fv size is: " << free_variables_nodes.size() << " in " << primal_graph_->ToString()
                        << ".";
    }
  }

  for (auto &fv : free_variables_nodes) {
    auto fv_adjoint = anfnode_to_adjoin_.find(fv);
    if (fv_adjoint == anfnode_to_adjoin_.end()) {

@@ -460,6 +473,10 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
}

AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
  if (lift_fv_before_grad) {
    MS_LOG(EXCEPTION) << "Lift free variable case: AttachIndirectFvDoutToTape backprop indirect fv "
                      << grad_fv->ToString() << " " << primal_graph_->ToString() << ".";
  }
  AnfNodePtr new_grad_fv = grad_fv;
  // Add indirect fv bprop.
  for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) {

@@ -497,7 +514,12 @@ void DFunctor::MapMorphism() {
  output_adjoint->second->AccumulateDout(dout_);

  // Set output for tape closure.
  auto grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
  AnfNodePtr grad_fv;
  if (lift_fv_before_grad) {
    grad_fv = AttachFvDoutToTape(NewValueNode(newenv));
  } else {
    grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
  }

  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
  // Add grads wrt inputs.
@@ -43,6 +43,9 @@ extern KPrim g_k_prims;
class DFunctor;
using DFunctorPtr = std::shared_ptr<DFunctor>;

// Flag to control if fv should be lifted before grad. If this lift_fv feature is mature, then this flag can be removed.
extern bool lift_fv_before_grad;

// D Functor's rules to map closure object and morphisms.
class DFunctor : public std::enable_shared_from_this<DFunctor> {
 public:
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -16,12 +16,52 @@

#include "frontend/optimizer/ad/grad.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/irpass.h"
#include "ir/func_graph_cloner.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"

namespace mindspore {
namespace ad {
namespace {
FuncGraphPtr PartialEliminateOptPass(const ResourcePtr &resource, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(resource);

  opt::irpass::OptimizeIRPassLib irpass;
  opt::OptPassConfig partial_eliminate_opt_ = opt::OptPassConfig(
    {irpass.partial_eliminate_, irpass.switch_partial_eliminater_, irpass.switch_layer_partial_eliminater_});
  opt::OptPassGroupMap map({{"partial_eliminate_", partial_eliminate_opt_}});

  auto after_lift_opt = opt::Optimizer::MakeOptimizer("partial_eliminate", resource, map);

  FuncGraphPtr opt_fg = nullptr;
  WITH(MsProfile::GetProfile()->Step("partial_eliminate_before_grad"))[&after_lift_opt, func_graph, &opt_fg]() {
    opt_fg = after_lift_opt->step(func_graph, true);
  };
  return opt_fg;
}

FuncGraphPtr LiftFv(const pipeline::ResourceBasePtr &resource, const FuncGraphPtr &func_graph) {
  bool save_graphs_flag = MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs_flag) {
    DumpIR("before_lift_" + func_graph->ToString() + ".ir", func_graph);
  }
  FuncGraphPtr new_fg = LiftingClone(func_graph);
  if (save_graphs_flag) {
    DumpIR("after_lift_" + new_fg->ToString() + ".ir", new_fg);
  }
  auto new_res = std::dynamic_pointer_cast<pipeline::Resource>(resource);
  if (new_res == nullptr) {
    MS_LOG(EXCEPTION) << "Parameter resources is not a pipeline::Resource";
  }
  auto opt_fg = PartialEliminateOptPass(new_res, new_fg);
  if (save_graphs_flag) {
    DumpIR("after_opt_" + opt_fg->ToString() + ".ir", opt_fg);
  }
  return opt_fg;
}
}  // namespace

FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto gradkv = func_graph->transforms().find("grad");

@@ -33,6 +73,11 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
  MS_EXCEPTION_IF_NULL(manager_ptr);
  manager_ptr->AddFuncGraph(func_graph);

  FuncGraphPtr grad_fg = func_graph;
  lift_fv_before_grad = (common::GetEnv("ENV_DONT_LIFT_FV_BEFORE_GRAD") != "1");
  if (func_graph->func_graphs_used().size() != 0 && lift_fv_before_grad) {
    grad_fg = LiftFv(resources, func_graph);
  }
  auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
    if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
      if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {

@@ -41,8 +86,8 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
    }
  };

  auto f = std::make_shared<DFunctor>(func_graph, resources);
  auto user_defined = f->KUserDefined(func_graph);
  auto f = std::make_shared<DFunctor>(grad_fg, resources);
  auto user_defined = f->KUserDefined(grad_fg);
  if (user_defined != nullptr) {
    multi_graph_sink(user_defined);
    if (is_top) {

@@ -62,6 +107,9 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
  }

  multi_graph_sink(res);
  if (func_graph != grad_fg) {
    (void)func_graph->transforms().insert(std::make_pair("grad", FuncGraphTransform(res)));
  }
  return res;
}
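LiftingClone above rewrites nested graphs so their free variables become explicit parameters before differentiation. A hedged C++ analogy of that lambda-lifting idea (not the actual pass):

// Hedged analogy of free-variable lifting (not the real LiftingClone): a
// closure over `w` is rewritten into a function that takes `w` explicitly,
// so the caller passes it as an ordinary argument.
#include <iostream>

int main() {
  double w = 2.0;  // plays the role of a free variable captured by an inner graph
  // Before lifting: the inner function captures w from the enclosing scope.
  auto inner_with_fv = [&w](double x) { return w * x; };
  // After lifting: w is an explicit parameter; no captures remain.
  auto inner_lifted = [](double w_param, double x) { return w_param * x; };
  std::cout << inner_with_fv(3.0) << " == " << inner_lifted(w, 3.0) << std::endl;
  return 0;
}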
@@ -327,6 +327,11 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
                      << " prim bprop function to J expanded func graph. NodeInfo: "
                      << trace::GetDebugInfo(bprop_fg->debug_info());
  }
  if (lift_fv_before_grad && IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
    // Inline fprop_switch before renormalize;
    expanded_fg->set_flag(FUNC_GRAPH_FLAG_FORCE_INLINE, true);
    MS_LOG(DEBUG) << "set force_inline for fg: " << expanded_fg->ToString();
  }

  return expanded_fg;
}
@@ -162,6 +162,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
  exchange_switch_depend_value_ =
    MakeSubstitution(std::make_shared<ExchangeSwitchDependValue>(), "exchange_switch_depend_value", prim::kPrimSwitch);

  switch_partial_eliminater_ =
    MakeSubstitution(std::make_shared<SwitchPartialEliminater>(), "eliminate_switch_partial_", IsCNodeDup);
  switch_layer_partial_eliminater_ =
    MakeSubstitution(std::make_shared<SwitchLayerPartialEliminater>(), "eliminate_switch_layer_partial_", IsCNodeDup);

  // Addn
  merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
  addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
@@ -86,6 +86,9 @@ class OptimizeIRPassLib {
  SubstitutionPtr convert_switch_replacement_;
  SubstitutionPtr exchange_switch_depend_value_;

  SubstitutionPtr switch_partial_eliminater_;
  SubstitutionPtr switch_layer_partial_eliminater_;

  // AddN
  SubstitutionPtr merge_addn_;
  SubstitutionPtr addn_zero_filter_;
@@ -20,6 +20,7 @@
#include <vector>
#include <algorithm>
#include <unordered_map>
#include <utility>
#include <memory>

#include "frontend/optimizer/irpass.h"

@@ -38,29 +39,36 @@ class CallOutputTransform {
  CallOutputTransform() : cache_() {}
  ~CallOutputTransform() = default;

  FuncGraphPtr operator()(const FuncGraphPtr &fg, size_t nargs) {
  FuncGraphPtr operator()(const FuncGraphPtr &fg, size_t nargs, bool xs_first) {
    if (cache_.find(fg) == cache_.end()) {
      cache_[fg] = {};
    }

    auto &cache = cache_[fg];
    if (cache.find(nargs) == cache.end()) {
    auto key = std::make_pair(nargs, xs_first);
    if (cache.find(key) == cache.end()) {
      FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("call"));

      std::vector<AnfNodePtr> new_items;
      new_items.push_back(new_fg->output());
      for (size_t i = 0; i < nargs; i++) {
        new_items.push_back(new_fg->add_parameter());
      if (xs_first) {
        for (size_t i = 0; i < nargs; i++) {
          new_items.push_back(new_fg->add_parameter());
        }
      } else {
        for (size_t i = 0; i < nargs; i++) {
          new_items.push_back(new_fg->InsertFrontParameter());
        }
      }
      new_fg->set_output(new_fg->NewCNode(new_items));

      cache[nargs] = new_fg;
      cache[key] = new_fg;
    }
    return cache[nargs];
    return cache[key];
  }

 private:
  std::unordered_map<FuncGraphPtr, std::unordered_map<size_t, FuncGraphPtr>> cache_;
  std::unordered_map<FuncGraphPtr, std::unordered_map<std::pair<size_t, bool>, FuncGraphPtr, PairHasher>> cache_;
};
}  // namespace internal
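The new cache key is a std::pair, which std::unordered_map cannot hash by default, so the PairHasher named in the declaration above must supply the hash. A plausible sketch (the actual definition is not part of this diff and may differ):

// Hypothetical PairHasher sketch for the cache above; the real definition
// lives elsewhere in the codebase and may differ.
#include <cstddef>
#include <functional>
#include <utility>

struct PairHasherSketch {
  std::size_t operator()(const std::pair<std::size_t, bool> &key) const {
    // boost-style hash_combine of the two member hashes.
    std::size_t seed = std::hash<std::size_t>()(key.first);
    seed ^= std::hash<bool>()(key.second) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
  }
};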
@@ -88,20 +96,35 @@ class IncorporateCall : public AnfVisitor {

    auto xs_size = Xs_.size();
    auto ys_size = inputs.size() - 1;
    auto new_fg = call_output_transform_(fg_, ys_size);
    bool xs_first = true;
    if ((xs_size > 0) && (Xs_[xs_size - 1]->abstract() != nullptr) &&
        (Xs_[xs_size - 1]->abstract()->isa<abstract::AbstractMonad>())) {
      xs_first = false;
    }
    auto new_fg = call_output_transform_(fg_, ys_size, xs_first);

    std::vector<AnfNodePtr> args;
    args.push_back(NewValueNode(new_fg));

    if (xs_size > 0) {
      (void)args.insert(args.end(), Xs_.begin(), Xs_.end());
    if (xs_first) {
      if (xs_size > 0) {
        (void)args.insert(args.end(), Xs_.begin(), Xs_.end());
      }
      if (ys_size > 0) {
        (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
      }
    } else {
      if (ys_size > 0) {
        (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
      }
      if (xs_size > 0) {
        (void)args.insert(args.end(), Xs_.begin(), Xs_.end());
      }
    }

    if (ys_size > 0) {
      (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
    }

    return node->func_graph()->NewCNode(args);
    auto new_node = node->func_graph()->NewCNode(args);
    new_node->set_abstract(node->abstract());
    return new_node;
  }

  void Visit(const CNodePtr &cnode) override {
@@ -159,19 +182,35 @@ class IncorporateCallSwitch : public AnfVisitor {
    auto fg = node->func_graph();
    auto xs_size = inputs_x.size() - 1;
    auto ys_size = inputs.size() - 1;
    auto new_g1 = call_output_transform_(g1_, ys_size);
    auto new_g2 = call_output_transform_(g2_, ys_size);
    bool xs_first = true;
    if ((xs_size > 0) && (inputs_x[xs_size - 1]->abstract() != nullptr) &&
        (inputs_x[xs_size - 1]->abstract()->isa<abstract::AbstractMonad>())) {
      xs_first = false;
    }
    auto new_g1 = call_output_transform_(g1_, ys_size, xs_first);
    auto new_g2 = call_output_transform_(g2_, ys_size, xs_first);
    auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)});

    std::vector<AnfNodePtr> args{sw_node};
    if (xs_size > 0) {
      (void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end());
    }
    if (ys_size > 0) {
      (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
    if (xs_first) {
      if (xs_size > 0) {
        (void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end());
      }
      if (ys_size > 0) {
        (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
      }
    } else {
      if (ys_size > 0) {
        (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
      }
      if (xs_size > 0) {
        (void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end());
      }
    }

    return fg->NewCNode(args);
    auto new_node = fg->NewCNode(args);
    new_node->set_abstract(node->abstract());
    return new_node;
  }

  void Visit(const AnfNodePtr &node) override {
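Both rewrites above check whether the last of Xs is a monad and, if so, append Xs after Ys so the U/IO monad argument stays in the final position. A hedged sketch of that ordering decision (plain C++ with strings standing in for IR nodes, not MindSpore types):

// Hedged sketch of the argument-ordering rule: if the last of xs is a
// "monad", xs goes after ys so the monad argument stays last.
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> xs = {"x1", "x2", "U"};  // "U" stands in for a monad arg
  std::vector<std::string> ys = {"y1"};
  bool xs_first = !(!xs.empty() && xs.back() == "U");

  std::vector<std::string> args;
  const auto &first = xs_first ? xs : ys;
  const auto &second = xs_first ? ys : xs;
  args.insert(args.end(), first.begin(), first.end());
  args.insert(args.end(), second.begin(), second.end());
  for (const auto &a : args) std::cout << a << ' ';  // prints: y1 x1 x2 U
  std::cout << std::endl;
  return 0;
}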
@@ -164,6 +164,7 @@ class IncorporateGetitem : public AnfVisitor {
        }
      }
    }
    new_node->set_abstract(node->abstract());
    return new_node;
  }

@@ -228,6 +229,7 @@ class IncorporateGetitemDepend : public AnfVisitor {
      new_depend_cnode =
        node->func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), new_fg_cnode, depend_2nd_input_});
    }
    new_depend_cnode->set_abstract(node->abstract());
    return new_depend_cnode;
  }

@@ -294,8 +296,7 @@ class IncorporateGetitemSwitch : public AnfVisitor {
    is_in_get_ = false;

    auto fg = node->func_graph();
    if (idx_ == -1 || switch_ == nullptr || fg == nullptr ||
        (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) && !ExistEnvNode(fg))) {
    if (idx_ == -1 || switch_ == nullptr || fg == nullptr) {
      return nullptr;
    }

@@ -308,10 +309,27 @@ class IncorporateGetitemSwitch : public AnfVisitor {
    }
    auto tuple_getitem = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(tuple_getitem);
    bool has_env_type = false;
    if (tuple_getitem->input(1)->abstract() && tuple_getitem->input(1)->abstract()->isa<abstract::AbstractTuple>()) {
      const auto &abs_tuple = *(tuple_getitem->input(1)->abstract()->cast<abstract::AbstractTuplePtr>());
      // eliminate (envinstance, value1, value2, ...) built by bprop func_graph()
      if (abs_tuple.size() >= 1) {
        // Value maybe kAnyValue, so check the type track;
        if (abs_tuple[0]->isa<abstract::AbstractScalar>() && abs_tuple[0]->GetTypeTrack()->isa<EnvType>()) {
          has_env_type = true;
        }
      }
      // eliminate (value, bprop_func) built by fprop func_graph
      if (abs_tuple.size() >= 2) {
        if (abs_tuple[1]->isa<abstract::AbstractFunction>()) {
          has_env_type = true;
        }
      }
    }
    // If exist env_getitem/env_setitem in this funcgraph or
    // if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem;
    if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) &&
        !ExistEnvNodeInTupleItem(g2_)) {
        !ExistEnvNodeInTupleItem(g2_) && !has_env_type) {
      return nullptr;
    }
    auto new_g1 = getitem_transform_(g1_, idx_);

@@ -319,7 +337,9 @@ class IncorporateGetitemSwitch : public AnfVisitor {
    auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)});
    (void)args_.insert(args_.begin(), sw_node);

    return fg->NewCNode(args_);
    auto new_node = fg->NewCNode(args_);
    new_node->set_abstract(node->abstract());
    return new_node;
  }

  void Visit(const AnfNodePtr &node) override {

@@ -398,8 +418,8 @@ class IncorporateGetitemSwitch : public AnfVisitor {
    }
    const auto &cnode = output->cast<CNodePtr>();
    const auto &inputs = cnode->inputs();
    return std::any_of(inputs.cbegin() + 1, inputs.cend(), [](const auto &node) {
      auto sub_fg = GetValueNode<FuncGraphPtr>(node);
    return std::any_of(inputs.cbegin() + 1, inputs.cend(), [](const auto &input) {
      auto sub_fg = GetValueNode<FuncGraphPtr>(input);
      if (sub_fg != nullptr && ExistEnvNode(sub_fg)) {
        return true;
      }
@@ -85,6 +85,9 @@ bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node);
bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node);
bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &);
bool IsForceInline(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
  return fg->has_flag(FUNC_GRAPH_FLAG_FORCE_INLINE);
}

// {G, Xs}
class InlinerBase : public AnfVisitor {

@@ -127,6 +130,14 @@ class InlinerBase : public AnfVisitor {
      return nullptr;
    }

    if (IsForceInline(this, fg, node)) {
      if (IsUniqueUse(nullptr, fg, nullptr)) {
        return InlineMove(node, fg, args, inputs);
      } else {
        return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
      }
    }

    if (IsUniqueUse(nullptr, fg, nullptr)) {
      // For the single used fg, including non-after and after not matched above,
      // we move the whole fg nodes.

@@ -151,15 +162,20 @@ class InlinerBase : public AnfVisitor {
    return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
  }

  AnfNodePtr InlineMove(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
                        const std::vector<AnfNodePtr> &inputs) {
    auto mng = fg->manager();
    MS_EXCEPTION_IF_NULL(mng);
    ReplaceParams(mng, args, fg);
    auto out_node = fg->output();
    mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
    return out_node;
  }

  AnfNodePtr InlineForUniqueUse(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
                                const std::vector<AnfNodePtr> &inputs) {
    if (use_move_) {
      auto mng = fg->manager();
      MS_EXCEPTION_IF_NULL(mng);
      ReplaceParams(mng, args, fg);
      auto out_node = fg->output();
      mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
      return out_node;
      return InlineMove(node, fg, args, inputs);
    }

    // The other branch calling the last after block.

@@ -366,6 +382,7 @@ class DirectInliner : public InlinerBase {
      : InlinerBase(
          // Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}.
          {
            {IsForceInline},
            {IsDirectParentCall},
          },
          use_move) {}
@@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -17,9 +17,11 @@
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_

#include <vector>
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"

@@ -72,6 +74,341 @@ class PartialEliminater : public AnfVisitor {
 private:
  std::vector<AnfNodePtr> Xs_{};
};

class ChoicePartialEliminater : public AnfVisitor {
 public:
  virtual ~ChoicePartialEliminater() = default;

 protected:
  AnfNodePtrList fg_list_{};
  std::vector<AnfNodePtrList> args_list_{};

  void Visit(const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimPartial)) {
      if (IsValueNode<FuncGraph>(node)) {
        fg_list_.push_back(node);
        args_list_.push_back(AnfNodePtrList{});
      }
      return;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    // {prim::kPrimPartial, G, Xs}
    if (inputs.size() < 3) {
      MS_LOG(EXCEPTION) << "Node should be Partial CNode, but: " << node->DebugString();
      return;
    }
    if (IsValueNode<FuncGraph>(inputs[1])) {
      fg_list_.push_back(inputs[1]);
      AnfNodePtrList args;
      (void)std::copy(inputs.begin() + 2, inputs.end(), std::back_inserter(args));
      args_list_.push_back(args);
    }
    return;
  }

  // return value: true -- continue replace; false -- return nullptr;
  bool CheckFuncGraphAndArgs() {
    // Either one should be {Partial, G, X}
    auto has_partial_args =
      std::any_of(args_list_.cbegin(), args_list_.cend(), [](auto &args) { return args.size() != 0; });
    if (!has_partial_args) {
      return false;
    }

    // check funcgraph should be used once only.
    for (auto &fg_node : fg_list_) {
      auto fg = GetValueNode<FuncGraphPtr>(fg_node);
      MS_EXCEPTION_IF_NULL(fg);
      if (fg->func_graph_cnodes_index().size() != 1) {
        for (auto iter : fg->func_graph_cnodes_index()) {
          MS_LOG(ERROR) << "fg user: " << iter.first->first->DebugString(1) << ", index: " << iter.first->second;
        }
        MS_LOG(EXCEPTION) << "fg is used multiple times: " << fg->ToString();
      }
    }
    return true;
  }

  // f(x1, x2, x3, z1, z2)
  // g(x4, x2, z1, z2)
  // h(x5, x2, x7, x8, z1, z2)
  // --> anchor_fg = h
  // h(x5, x2, x7, x8, x1, x3, x4, z1, z2)
  // f(x5, x2, x7, x8, x1, x3, x4, z1, z2)
  // g(x5, x2, x7, x8, x1, x3, x4, z1, z2)
  // as z1, z2 maybe U or IO monad.
  AnfNodePtrList UnifyParameters(const size_t &anchor_index, const AnfNodePtrList &fg_list,
                                 const std::vector<AnfNodePtrList> args_list) {
    std::vector<size_t> inputs_index_list[args_list.size()];
    size_t extra_input_counter = 0;
    AnfNodePtrList extra_inputs;
    const auto &anchor_args = args_list[anchor_index];
    size_t anchor_args_size = anchor_args.size();
    auto anchor_fg = GetValueNode<FuncGraphPtr>(fg_list[anchor_index]);
    MS_EXCEPTION_IF_NULL(anchor_fg);
    // Find the new location of the old_inputs except Zs;
    for (size_t i = 0; i < args_list.size(); ++i) {
      if (i == anchor_index) {
        continue;
      }
      const auto &another_args = args_list[i];
      auto &curr_inputs_index = inputs_index_list[i];
      for (size_t j = 0; j < another_args.size(); ++j) {
        size_t k;
        for (k = 0; k < anchor_args_size; ++k) {
          if (another_args[j] == anchor_args[k]) {
            curr_inputs_index.push_back(k);
            break;
          }
        }
        if (k == anchor_args_size) {
          // check if used by another func_graph;
          for (k = 0; k < extra_input_counter; ++k) {
            if (another_args[j] == extra_inputs[k]) {
              curr_inputs_index.push_back(anchor_args_size + k);
              break;
            }
          }
          if (k == extra_input_counter) {
            extra_inputs.push_back(another_args[j]);
            curr_inputs_index.push_back(anchor_args_size + extra_input_counter);
            extra_input_counter++;
          }
        }
      }
    }

    auto manager = anchor_fg->manager();
    auto txn = manager->Transact();

    size_t anchor_params_size = anchor_fg->parameters().size();
    const auto &anchor_fg_params = anchor_fg->parameters();
    for (size_t i = 0; i < args_list.size(); ++i) {
      if (i == anchor_index) {
        continue;
      }
      AnfNodePtrList new_params;
      new_params.resize(anchor_params_size + extra_input_counter);

      const auto &curr_inputs_index = inputs_index_list[i];
      auto another_fg = GetValueNode<FuncGraphPtr>(fg_list[i]);
      MS_EXCEPTION_IF_NULL(another_fg);
      const auto &old_params = another_fg->parameters();
      const auto &old_args = args_list[i];
      for (size_t j = 0; j < old_args.size(); j++) {
        new_params[curr_inputs_index[j]] = old_params[j];
      }
      // Zs_
      for (size_t j = old_args.size(), k = 0; j < old_params.size(); ++j, ++k) {
        new_params[anchor_args_size + extra_input_counter + k] = old_params[j];
      }
      // unused inputs
      for (size_t j = 0; j < anchor_args_size; ++j) {
        if (new_params[j] == nullptr) {
          TraceGuard guard(std::make_shared<TraceCopy>(anchor_fg_params[j]->debug_info()));
          ParameterPtr param = std::make_shared<Parameter>(another_fg);
          new_params[j] = param;
        }
      }
      // extra inputs used by another func_graph;
      for (size_t j = 0; j < extra_inputs.size(); ++j) {
        if (new_params[anchor_args_size + j] == nullptr) {
          TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[j]->debug_info()));
          ParameterPtr param = std::make_shared<Parameter>(another_fg);
          new_params[anchor_args_size + j] = param;
        }
      }
      // set the parameter for another_fg and replace it's parameters;
      txn.SetParameters(another_fg, new_params);
    }
    // Reorder Zs_ and add extra parameters for anchor_fg;
    // add extra parameter for anchor_fg;
    AnfNodePtrList new_params;
    new_params.reserve(anchor_params_size + extra_input_counter);
    // reuse parameters for anchor_args;
    std::copy(anchor_fg_params.cbegin(), anchor_fg_params.cbegin() + anchor_args_size, std::back_inserter(new_params));
    // Extra parameters;
    for (size_t i = 0; i < extra_inputs.size(); ++i) {
      TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[i]->debug_info()));
      ParameterPtr param = std::make_shared<Parameter>(anchor_fg);
      new_params.push_back(param);
    }
    // Reorder Zs_ to last;
    for (size_t i = anchor_args_size; i < anchor_params_size; ++i) {
      new_params.push_back(anchor_fg_params[i]);
    }
    txn.SetParameters(anchor_fg, new_params);
    txn.Commit();

    return extra_inputs;
  }
};

// {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}, Zs} ->
// {{prim::kPrimSwitch, cond, G1, G2}, Xs Union Ys Union Zs}
// {{prim::kPrimSwitch, cond, {G1}, {prim::kPrimPartial, G2, Ys}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Ys Union
// Zs}
// {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {G2}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Xs Union
// Zs}
class SwitchPartialEliminater : public ChoicePartialEliminater {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) {
      return nullptr;
    }
    auto input0_cnode = cnode->input(0)->cast<CNodePtr>();
    if (input0_cnode->size() != 4) {
      return nullptr;
    }

    fg_list_.clear();
    args_list_.clear();
    auto &maybe_partial_1 = input0_cnode->input(2);
    Visit(maybe_partial_1);
    auto &maybe_partial_2 = input0_cnode->input(3);
    Visit(maybe_partial_2);

    // Either one should be {Partial, G, X}
    if (fg_list_.size() != 2 && args_list_.size() != 2) {
      return nullptr;
    }
    // Should not continue;
    if (!CheckFuncGraphAndArgs()) {
      return nullptr;
    }

    if (args_list_[0] == args_list_[1]) {
      auto new_node =
        BuildNewSwitchNode(cnode, input0_cnode, fg_list_[0], fg_list_[1], args_list_[0], AnfNodePtrList{});
      return new_node;
    } else {
      // find partial funcgraph with the longest args as anchor;
      size_t max_args_pos = 0;
      if (args_list_[0].size() > args_list_[1].size()) {
        max_args_pos = 0;
      } else {
        max_args_pos = 1;
      }

      auto extra_inputs = UnifyParameters(max_args_pos, fg_list_, args_list_);
      auto new_node =
        BuildNewSwitchNode(cnode, input0_cnode, fg_list_[0], fg_list_[1], args_list_[max_args_pos], extra_inputs);
      return new_node;
    }
  }

 private:
  AnfNodePtr BuildNewSwitchNode(const CNodePtr &old_cnode, const CNodePtr input0_cnode, const AnfNodePtr &G1,
                                const AnfNodePtr &G2, const AnfNodePtrList &partial_args,
                                const AnfNodePtrList &extra_args) {
    TraceGuard guard1(std::make_shared<TraceCopy>(input0_cnode->debug_info()));
    // {Switch, cond, G1, G2}
    auto switch_cnode = old_cnode->func_graph()->NewCNode({input0_cnode->input(0), input0_cnode->input(1), G1, G2});
    AnfNodePtrList args{switch_cnode};
    (void)std::copy(partial_args.begin(), partial_args.end(), std::back_inserter(args));
    (void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
    // Zs
    if (old_cnode->size() >= 2) {
      (void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
    }
    TraceGuard guard2(std::make_shared<TraceCopy>(old_cnode->debug_info()));
    auto new_node = old_cnode->func_graph()->NewCNode(args);
    return new_node;
  }
};

// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}}, Zs} ->
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}, Xs Union Ys Union Zs}
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{G1}, {prim::kPrimPartial, G2, Ys}}}, Zs} ->
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Ys Union Zs}
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {G2}}{}, Zs} ->
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Xs Union Zs}
class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    // {SwitchLayer{}, Zs}
    if (!IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitchLayer)) {
      return nullptr;
    }
    auto switch_layer_cnode = cnode->input(0)->cast<CNodePtr>();
    // {SwitchLayer, cond, MakeTuple{}}
    if (switch_layer_cnode->size() != 3) {
      return nullptr;
    }
    if (!IsPrimitiveCNode(switch_layer_cnode->input(2), prim::kPrimMakeTuple)) {
      return nullptr;
    }
    auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>();
    if (make_tuple_cnode->size() < 2) {
      return nullptr;
    }

    fg_list_.clear();
    args_list_.clear();
    // Build funcgraph list and args list;
    for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
      Visit(make_tuple_cnode->input(i));
    }

    if (!CheckFuncGraphAndArgs()) {
      return nullptr;
    }
    // All have the same args;
    auto args_equal =
      std::all_of(args_list_.cbegin() + 1, args_list_.cend(), [this](auto &args) { return args == args_list_[0]; });
    if (args_equal) {
      auto new_node = BuildNewSwitchLayerNode(cnode, switch_layer_cnode, args_list_[0], AnfNodePtrList{});
      return new_node;
    } else {
      // find partial funcgraph with the longest args as anchor;
      size_t max_args_pos = 0, max_args_len = 0;
      for (size_t i = 0; i < args_list_.size(); ++i) {
        if (max_args_len < args_list_[i].size()) {
          max_args_len = args_list_[i].size();
          max_args_pos = i;
        }
      }
      auto extra_inputs = UnifyParameters(max_args_pos, fg_list_, args_list_);
      auto new_node = BuildNewSwitchLayerNode(cnode, switch_layer_cnode, args_list_[max_args_pos], extra_inputs);
      return new_node;
    }
  }

 private:
  AnfNodePtr BuildNewSwitchLayerNode(const CNodePtr &old_cnode, const CNodePtr switch_layer_cnode,
                                     const AnfNodePtrList &anchor_partial_args, const AnfNodePtrList &extra_args) {
    auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>();
    AnfNodePtrList make_tuple_args{make_tuple_cnode->input(0)};
    make_tuple_args.insert(make_tuple_args.end(), fg_list_.begin(), fg_list_.end());
    TraceGuard guard1(std::make_shared<TraceCopy>(make_tuple_cnode->debug_info()));
    // {MakeTuple, G1, G2, ...}
    auto new_make_tuple_cnode = old_cnode->func_graph()->NewCNode(make_tuple_args);

    TraceGuard guard2(std::make_shared<TraceCopy>(switch_layer_cnode->debug_info()));
    // {SwitchLayer, cond, MakeTuple{}}
    auto new_switch_layer_cnode = old_cnode->func_graph()->NewCNode(
      {switch_layer_cnode->input(0), switch_layer_cnode->input(1), new_make_tuple_cnode});
    AnfNodePtrList args{new_switch_layer_cnode};
    (void)std::copy(anchor_partial_args.begin(), anchor_partial_args.end(), std::back_inserter(args));
    (void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
    // Zs
    if (old_cnode->size() >= 2) {
      (void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
    }
    TraceGuard guard3(std::make_shared<TraceCopy>(old_cnode->debug_info()));
    auto new_node = old_cnode->func_graph()->NewCNode(args);
    return new_node;
  }
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
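A small standalone sketch of the index bookkeeping UnifyParameters performs, using the f/g/h example from the comment in the diff above (strings stand in for parameter nodes; hedged, not the real pass):

// Hedged sketch of the UnifyParameters bookkeeping: the unified list is the
// anchor's args, then the extra args of the other branches, with Zs
// (possibly U/IO monads) kept at the end.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  using Args = std::vector<std::string>;
  Args anchor = {"x5", "x2", "x7", "x8"};          // h's partial args (longest, the anchor)
  std::vector<Args> others = {{"x1", "x2", "x3"},  // f's partial args
                              {"x4", "x2"}};       // g's partial args
  Args extra;                                      // args not already covered by the anchor
  for (const auto &args : others) {
    for (const auto &a : args) {
      if (std::find(anchor.begin(), anchor.end(), a) == anchor.end() &&
          std::find(extra.begin(), extra.end(), a) == extra.end()) {
        extra.push_back(a);
      }
    }
  }
  Args unified = anchor;
  unified.insert(unified.end(), extra.begin(), extra.end());
  unified.push_back("z1");  // Zs stay last, as they may be U or IO monads
  unified.push_back("z2");
  for (const auto &p : unified) std::cout << p << ' ';  // x5 x2 x7 x8 x1 x3 x4 z1 z2
  std::cout << std::endl;
  return 0;
}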
@@ -109,11 +109,12 @@ void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &engine,
}

// Start running stack frames in a Evaluator.
AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) {
AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
                                                         const AnalysisContextPtr &context) {
  EvalResultPtr eval_result = nullptr;
  AbstractBasePtr res_base = nullptr;
  std::stack<StackFramePtr> stack_frames;
  auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context_, parent_context_);
  auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context, parent_context_);
  MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame;
  stack_frames.push(current_stack_frame);
  while (true) {

@@ -155,7 +156,8 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
  return res_base;
}

AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) {
AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
                                                            const AnalysisContextPtr &context) {
  const AnfNodePtr &func_node = fg->get_return();
  const auto &all_nodes = TopoSort(func_node, SuccIncoming, [](const AnfNodePtr &node) -> IncludeType {
    if (node->isa<ValueNode>() || node->isa<Parameter>()) {

@@ -165,7 +167,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
  });
  AbstractBasePtr res_base = nullptr;
  for (const auto &node : all_nodes) {
    AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context_);
    AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context);
    MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
                  << ", node_conf: " << node_conf->ToString();
    auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);

@@ -214,24 +216,30 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
                  << parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":"
                  << fg->ToString() << "();";
  }
  context_ = parent_context_->NewFuncGraphContext(fg, args_abs_list);
  auto context = parent_context_->NewFuncGraphContext(fg, args_abs_list);
  auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>());
  if (func_graph_evaluator != nullptr) {
    if (engine->root_func_graph() == func_graph_evaluator->func_graph()) {
      engine->set_root_context(context);
    }
  }
  const auto &parameters = fg->parameters();
  for (size_t i = 0; i < nargs; i++) {
    const auto &arg = args_abs_list[i];
    const auto &node = parameters[i];
    AnfNodeConfigPtr conf = engine->MakeConfig(node, context_);
    AnfNodeConfigPtr conf = engine->MakeConfig(node, context);
    engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr));
    MS_LOG(DEBUG) << GetInferThread() << "Set Param: " << conf->ToString() << " = " << arg->ToString();
  }
  MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
                << ", context: " << context_->ToString() << ", return node: " << fg->get_return()->DebugString()
                << ", context: " << context->ToString() << ", return node: " << fg->get_return()->DebugString()
                << ", parent: " << (parent_context_->func_graph() ? parent_context_->func_graph()->ToString() : "NULL")
                << ", current function call depth: " << engine->function_call_depth();
  AbstractBasePtr res_base = nullptr;
  if (engine->enable_recursive_eval()) {
    res_base = LaunchRecursiveEval(engine, fg);
    res_base = LaunchRecursiveEval(engine, fg, context);
  } else {
    res_base = LaunchStackFrame(engine, fg);
    res_base = LaunchStackFrame(engine, fg, context);
  }

  MS_EXCEPTION_IF_NULL(res_base);

@@ -250,28 +258,27 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
  return res;
}

void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args) {
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(*broaded_args),
                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         // Only broaden scalar that data type is number, such as float16,int32 and so on.
                         auto type = arg->BuildType()->type_id();
                         if (arg->isa<AbstractScalar>() && type > kNumberTypeBegin && type < kNumberTypeEnd) {
                           auto config = abstract::AbstractBase::kBroadenScalarParameterOnly;
                           return arg->Broaden(config);
                         } else if (arg->GetValueTrack() != kAnyValue) {
                           return arg->Broaden();
                         }
                         return arg;
                       });
}

AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
  MS_EXCEPTION_IF_NULL(func_graph_);
  if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
    AbstractBasePtrList broaded_list;
    (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list),
                         [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                           MS_EXCEPTION_IF_NULL(arg);
                           // Only broaden scalar that data type is number, such as float16,int32 and so on.
                           auto type = arg->BuildType()->type_id();
                           if (arg->isa<AbstractScalar>() && type > kNumberTypeBegin && type < kNumberTypeEnd) {
                             auto config = abstract::AbstractBase::kBroadenScalarParameterOnly;
                             return arg->Broaden(config);
                           } else if (arg->GetValueTrack() != kAnyValue) {
                             return arg->Broaden();
                           }
                           return arg;
                         });
    if (func_graph_->joined_shapes_.size() == broaded_list.size()) {
      for (size_t i = 0; i < broaded_list.size(); ++i) {
        broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
      }
    }
    BroadenArgs(args_spec_list, &broaded_list);

    MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
                  << ", broaded: " << mindspore::ToString(broaded_list);

@@ -287,55 +294,11 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
  }

  if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) {
    if (parent_context_) {
      MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString()
                    << ", context: " << parent_context_->ToString();
      auto last_context = parent_context_->FindParentContext(func_graph_);
      if (last_context && last_context->func_graph() == func_graph_) {
        MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString();
        MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
        MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list());
        // Join the last eval arguments and current arguments to check if there are loop variant.
        auto joined_args_spec_list_1 = AbstractJoin(args_spec_list, last_context->args_spec_list());
        MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list_1);
        // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
        if (!(joined_args_spec_list_1 == args_spec_list)) {
          func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
          func_graph_->joined_shapes_.clear();
          std::transform(joined_args_spec_list_1.begin(), joined_args_spec_list_1.end(),
                         std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
                           MS_EXCEPTION_IF_NULL(arg_spec);
                           return arg_spec->GetShapeTrack();
                         });
          joined_args_spec_list_1 = NormalizeArgs(joined_args_spec_list_1);
          MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
        }
        return joined_args_spec_list_1;
      }
    }
    if (!trace_.empty()) {
      MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
      MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back());
      // Join the last eval arguments and current arguments to check if there are loop variant.
      auto joined_args_spec_list_2 = AbstractJoin(args_spec_list, trace_.back());
      // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
      if (!(joined_args_spec_list_2 == args_spec_list)) {
        trace_.push_back(joined_args_spec_list_2);
        func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
        func_graph_->joined_shapes_.clear();
        std::transform(joined_args_spec_list_2.begin(), joined_args_spec_list_2.end(),
                       std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
                         MS_EXCEPTION_IF_NULL(arg_spec);
                         return arg_spec->GetShapeTrack();
                       });
        joined_args_spec_list_2 = NormalizeArgs(joined_args_spec_list_2);
        MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
      }
      MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list_2);
      return joined_args_spec_list_2;
    } else {
      trace_.push_back(args_spec_list);
    }
    func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
    auto normalized_args_spec_list = NormalizeArgs(args_spec_list);
    MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
    MS_LOG(DEBUG) << "Normalized args " << mindspore::ToString(normalized_args_spec_list);
    return normalized_args_spec_list;
  }
  return args_spec_list;
}

@@ -388,16 +351,6 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons

EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                             const AnfNodeConfigPtr &out_conf) {
  // The evaluator can't reenter at sametime.
  std::unique_lock<std::recursive_timed_mutex> eval_lock(eval_lock_, std::try_to_lock);
  if (!eval_lock.owns_lock()) {
    // Release GIL
    pybind11::gil_scoped_release infer_gil_release;
    // Check if enter endless loop
    HealthPointScopedDrop health_point_check;
    eval_lock.lock();
  }

  AbstractBasePtrList args_spec_list;
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &conf) -> AbstractBasePtr {
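The context_ member removed in this file was shared across invocations of the same evaluator; threading the context through as an argument is what makes Eval re-entrant. A hedged illustration of why per-call member state breaks re-entry (not the actual evaluator classes):

// Hedged illustration (not the actual evaluator classes) of the re-entrancy
// fix: per-call state kept in a member is clobbered when the same object is
// entered again before the first call finishes; a local/parameter is not.
#include <cassert>

struct MemberState {
  int context_ = 0;                         // shared across calls: unsafe on re-entry
  int Eval(int ctx, bool reenter) {
    context_ = ctx;
    if (reenter) { Eval(ctx + 1, false); }  // nested call overwrites context_
    return context_;                        // no longer this call's context
  }
};

struct ParamState {
  int Eval(int ctx, bool reenter) {
    int context = ctx;                      // local per call: safe on re-entry
    if (reenter) { Eval(ctx + 1, false); }
    return context;
  }
};

int main() {
  assert(MemberState{}.Eval(1, true) == 2);  // wrong: the nested call leaked in
  assert(ParamState{}.Eval(1, true) == 1);   // right: each call keeps its own context
  return 0;
}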
@@ -210,8 +210,6 @@ class BaseFuncGraphEvaluator : public Evaluator {

  AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list);

  AnalysisContextPtr context() const { return context_; }
  void set_context(const AnalysisContextPtr &context) { context_ = context; }
  AnalysisContextPtr parent_context() const { return parent_context_; }
  void set_parent_context(const AnalysisContextPtr &parent_context) { parent_context_ = parent_context; }

@@ -219,14 +217,14 @@ class BaseFuncGraphEvaluator : public Evaluator {
  AnalysisContextPtr parent_context_;

 private:
  AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
  AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
                                      const AnalysisContextPtr &context);
  // Add functions for stack frame routine.
  AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
  AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
                                   const AnalysisContextPtr &context);
  static void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
                              const StackFramePtr &new_stack_frame);
  static void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame);

  AnalysisContextPtr context_;
};

class FuncGraphEvaluator : public BaseFuncGraphEvaluator {

@@ -353,6 +351,8 @@ class JEvaluator : public Evaluator {
  EvaluatorPtr evaluator_;
  AbstractFunctionPtr orig_func_;
};

void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args);
}  // namespace abstract
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_
@@ -23,6 +23,7 @@
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/do_signature.h"
#include "abstract/abstract_function.h"
#include "abstract/utils.h"
#include "ir/graph_utils.h"
#include "utils/log_adapter.h"
#include "debug/trace.h"

@@ -545,32 +546,37 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
  std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
  EvalResultPtr ret = nullptr;
  AbstractBasePtrList broaded_argvals;
  EvalResultCache &cache = evalcaches_[eval]->GetCache();
  for (auto &argvals_map : cache) {
  std::vector<AbstractBasePtrList> args_vector;
  auto &origin_eval_cache = evalcaches_[eval]->GetCache();
  for (auto &argvals_map : origin_eval_cache) {
    auto argvals = argvals_map.first;
    args_vector.push_back(argvals);
    broaded_argvals.clear();

    (void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals),
                         [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
    BroadenArgs(argvals, &broaded_argvals);
    (void)choices.insert(broaded_argvals);
    MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals);
  }

  if (choices.size() == 1) {
    ConfigPtrList args_conf_list;
    (void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list),
                         [](AbstractBasePtr v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });

    // If broaden return null
    ret = eval->SingleRun(engine_, args_conf_list, nullptr);
    if (args_vector.size() < 2) {
      MS_LOG(EXCEPTION) << "Should have 2 more choices, but: " << args_vector.size();
    }
    AbstractBasePtrList joined_argvals = args_vector[0];
    for (size_t i = 1; i < args_vector.size(); ++i) {
      joined_argvals = abstract::AbstractJoin(joined_argvals, args_vector[i]);
    }
    MS_LOG(DEBUG) << "Joined argvals: " << joined_argvals.size() << ", " << ::mindspore::ToString(joined_argvals);
    EvaluatorCacheMgrPtr real = std::make_shared<EvaluatorCacheMgr>();
    real->SetValue(broaded_argvals, ret);
    evalcaches_[eval] = real;
    return std::make_pair(broaded_argvals, ret->abstract());
  } else {
    MS_LOG(DEBUG) << "Choices.size: " << choices.size();
    return std::make_pair(AbstractBasePtrList(), nullptr);
    auto joined_eval_result = origin_eval_cache.get(joined_argvals);
    if (joined_eval_result != nullptr) {
      MS_LOG(DEBUG) << "Find unique Choices in original eval cache, so use it: " << joined_eval_result->ToString();

      real->SetValue(joined_argvals, joined_eval_result);
      evalcaches_[eval] = real;
      return std::make_pair(joined_argvals, joined_eval_result->abstract());
    }
  }
  MS_LOG(DEBUG) << "Choices.size: " << choices.size();
  return std::make_pair(AbstractBasePtrList(), nullptr);
}

void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
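BuildFromBroadedArgsVal now joins all cached argument lists with AbstractJoin and reuses a matching cache entry instead of re-running the evaluator on broadened values. A hedged sketch of such a join on a tiny value lattice (equal values stay concrete, differing values widen; not the actual AbstractJoin):

// Hedged sketch of a join on a tiny value lattice (not the real AbstractJoin):
// equal abstract values stay concrete; differing ones widen to "any value",
// mirroring how joined argvals generalize the cached evaluation entries.
#include <iostream>
#include <optional>
#include <vector>

using Abstract = std::optional<int>;  // nullopt plays the role of kAnyValue

Abstract Join(const Abstract &a, const Abstract &b) {
  return (a && b && *a == *b) ? a : std::nullopt;  // widen on disagreement
}

// Joins element-wise; assumes both argument lists have the same arity.
std::vector<Abstract> JoinArgs(const std::vector<Abstract> &x, const std::vector<Abstract> &y) {
  std::vector<Abstract> out;
  for (size_t i = 0; i < x.size(); ++i) { out.push_back(Join(x[i], y[i])); }
  return out;
}

int main() {
  std::vector<Abstract> call1 = {1, 7};  // first cached argvals
  std::vector<Abstract> call2 = {1, 9};  // second cached argvals
  auto joined = JoinArgs(call1, call2);  // {1, any}
  std::cout << (joined[0] ? "1" : "any") << ", " << (joined[1] ? "?" : "any") << std::endl;
  return 0;
}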
@@ -96,7 +96,6 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
  }

  // Create a new stack frame and set arguments for it.
  fg_evaluator->set_context(new_context);
  auto new_stack_frame = std::make_shared<StackFrame>(fg_evaluator, fg, new_context, parent_context);
  new_stack_frame->set_args_abs_list(std::move(args_abs_list));
  return new_stack_frame;
@ -80,6 +80,8 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
|
|||
MS_EXCEPTION_IF_NULL(func_graph_manager_);
|
||||
func_graph_manager_->AddFuncGraph(func_graph);
|
||||
|
||||
root_func_graph_ = func_graph;
|
||||
|
||||
AnalysisContextPtr empty_context = AnalysisContext::DummyContext();
|
||||
|
||||
// Running the analyzer.
|
||||
|
@ -105,7 +107,7 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
|
|||
const ConfigPtrList &args_conf_list) {
|
||||
std::shared_ptr<FuncGraphEvaluator> eval = std::make_shared<FuncGraphEvaluator>(func_graph, context);
|
||||
(void)eval->Run(shared_from_this(), args_conf_list, nullptr);
|
||||
return eval->context();
|
||||
return root_context_;
|
||||
}
|
||||
|
||||
void AnalysisEngine::SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
|
||||
|
@ -213,11 +215,12 @@ void AnalysisEngine::CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf) {
|
|||
return;
|
||||
}
|
||||
auto top_evaluator = infer_stack.top().first;
|
||||
if (!top_evaluator->isa<BaseFuncGraphEvaluator>()) {
|
||||
// Top or root func_graph must be FuncGraph other than MetaFuncGraph;
|
||||
if (!top_evaluator->isa<FuncGraphEvaluator>()) {
|
||||
MS_LOG(EXCEPTION) << "Top evaluator is " << top_evaluator->ToString();
|
||||
}
|
||||
auto top_fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(top_evaluator);
|
||||
auto top_context_fg = top_fg_evaluator->context()->func_graph();
|
||||
auto top_fg_evaluator = dyn_cast<FuncGraphEvaluator>(top_evaluator);
|
||||
auto top_context_fg = top_fg_evaluator->func_graph();
|
||||
if (current_cnode_fg != top_context_fg) { // Ignore FV call.
|
||||
return;
|
||||
}
|
||||
|
@ -339,6 +342,8 @@ void AnalysisEngine::Clear() {
|
|||
evaluators_.clear();
|
||||
constructors_app_.clear();
|
||||
continued_evals_.clear();
|
||||
root_func_graph_ = nullptr;
|
||||
root_context_ = nullptr;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -578,7 +583,7 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
#endif
}

bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) {
  static std::mutex fg_lock;
  std::lock_guard<std::mutex> infer_lock(fg_lock);
  auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();

@ -591,10 +596,17 @@ bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
  auto undetermined_fgs = fg->recursive();
  if (undetermined_fgs) {
    auto fg_parent = fg->parent();
    MS_EXCEPTION_IF_NULL(fg_parent);
    fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
    MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
    return true;
    if (fg_parent != nullptr) {
      fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
      MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString() << " for fg: " << fg->ToString();
      return true;
    } else if (possible_parent_fg != nullptr) {
      possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
      MS_LOG(DEBUG) << "Set graph undetermined: " << possible_parent_fg->ToString() << " for fg: " << fg->ToString();
      return true;
    } else {
      MS_LOG(EXCEPTION) << "cannot find parent for fg: " << fg->ToString();
    }
  }
  return false;
}

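Reading note: the reworked SetUndeterminedFlag resolves which graph to flag in a fixed order: the func_graph's own parent first, then the caller-supplied possible_parent_fg, and an exception only when both are missing. A minimal self-contained sketch of that fallback order (the toy Graph type and MarkUndetermined name are ours, not MindSpore API):

#include <iostream>
#include <stdexcept>

// Toy stand-in for a func_graph with an optional parent; illustrative only.
struct Graph {
  Graph *parent = nullptr;
  bool undetermined = false;
};

// Mirrors the resolution order added above: own parent first,
// then the possible parent supplied by the caller, else error out.
bool MarkUndetermined(Graph *fg, Graph *possible_parent) {
  if (fg->parent != nullptr) {
    fg->parent->undetermined = true;
    return true;
  }
  if (possible_parent != nullptr) {
    possible_parent->undetermined = true;
    return true;
  }
  throw std::runtime_error("cannot find parent for fg");
}

int main() {
  Graph top, child;                 // child has no parent recorded yet
  MarkUndetermined(&child, &top);   // falls back to the caller's graph
  std::cout << std::boolalpha << top.undetermined << '\n';  // true
}

The call sites in the next two hunks pass out_conf->node()->func_graph() as that fallback.
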
@ -810,8 +822,11 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
  AsyncAbstractResultPtr asyncResult1 = std::make_shared<AsyncAbstractResult>();
  AsyncAbstractResultPtr asyncFirstRunResult = std::make_shared<AsyncAbstractResult>();

  bool firstRun = !SetUndeterminedFlag(evaluators[0]);
  (void)SetUndeterminedFlag(evaluators[1]);
  MS_EXCEPTION_IF_NULL(out_conf);
  MS_EXCEPTION_IF_NULL(out_conf->node());
  auto possible_parent_fg = out_conf->node()->func_graph();
  bool firstRun = !SetUndeterminedFlag(evaluators[0], possible_parent_fg);
  (void)SetUndeterminedFlag(evaluators[1], possible_parent_fg);
  std::string threadId = AnalysisResultCacheMgr::GetThreadid();

  MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString();

@ -891,8 +906,12 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
    MS_EXCEPTION_IF_NULL(conf);
    return conf->ObtainEvalResult()->abstract();
  });
  for (const auto &eval : evaluators) {
    (void)SetUndeterminedFlag(eval);
  MS_EXCEPTION_IF_NULL(out_conf);
  MS_EXCEPTION_IF_NULL(out_conf->node());
  auto possible_parent_fg = out_conf->node()->func_graph();
  for (auto eval : evaluators) {
    (void)SetUndeterminedFlag(eval, possible_parent_fg);

    const auto current_inf = EvaluatorArgs(eval, args_spec_list);
    MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
    // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.

@ -248,6 +248,10 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
  EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf);
  const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }

  FuncGraphPtr root_func_graph() const { return root_func_graph_; }
  AnalysisContextPtr root_context() const { return root_context_; }
  void set_root_context(const AnalysisContextPtr &context) { root_context_ = context; }

  std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;

  void ResetFunctionCallDepth() {

@ -292,7 +296,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
  static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node);

 private:
  bool SetUndeterminedFlag(const EvaluatorPtr &evaluator);
  bool SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg);
  EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
                                     const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
                                     bool *continue_flag);

@ -308,6 +312,9 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
  std::list<EvaluatorArgs> eval_trace_;
  std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
  std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> continued_evals_;
  // root or top func_graph for static analysis;
  FuncGraphPtr root_func_graph_{nullptr};
  AnalysisContextPtr root_context_{nullptr};

  AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
                         const ConfigPtrList &args_conf_list);

@ -98,6 +98,21 @@ void FuncGraph::add_parameter(const ParameterPtr &p) {
  }
}

ParameterPtr FuncGraph::InsertFrontParameter() {
  FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
  ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
  InsertFrontParameter(p);
  return p;
}

void FuncGraph::InsertFrontParameter(const ParameterPtr &p) {
  if (manager_.lock()) {
    manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), p);
  } else {
    PrependParameter(p);
  }
}

ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
  FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
  ParameterPtr p = std::make_shared<Parameter>(this_graph);

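Reading note: InsertFrontParameter is the prepend counterpart of add_parameter. It routes through the manager as a transaction when one is attached, and falls back to PrependParameter on the bare parameter list otherwise. At the list level the difference reduces to insert-at-begin versus push-back, sketched here on a plain vector (illustrative names only, not the MindSpore API):

#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> params = {"x", "y"};
  params.push_back("z");                    // append_parameter analogue
  params.insert(params.begin(), "front");   // PrependParameter analogue
  for (const auto &p : params) std::cout << p << ' ';  // front x y z
  std::cout << '\n';
}

The manager round-trip (Transact, InsertFrontParameter, Commit) exists so the add/drop bookkeeping stays consistent, as the ParseChanges hunk further below shows.
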
@ -83,6 +83,7 @@ const char FUNC_GRAPH_FLAG_CORE[] = "core";
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
const char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";
const char FUNC_GRAPH_FLAG_FORCE_INLINE[] = "force_inline";

const char kFuncGraphFlagUndetermined[] = "Undeterminate";
const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";

@ -169,9 +170,14 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
  void set_output(const AnfNodePtr &value, bool force_new_ret = false);

  const std::vector<AnfNodePtr> &parameters() const { return parameters_; }
  // Append
  virtual ParameterPtr add_parameter();
  void add_parameter(const ParameterPtr &p);
  void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); }
  // Prepend
  virtual ParameterPtr InsertFrontParameter();
  void InsertFrontParameter(const ParameterPtr &p);
  void PrependParameter(const ParameterPtr &p) { parameters_.insert(parameters_.begin(), p); }
  void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
  // Add a weight parameter with specific name.
  ParameterPtr AddWeightParameter(const std::string &name);

@ -354,7 +360,6 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
  void add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); }

  std::unordered_map<std::string, ValuePtr> attrs_;
  std::vector<BaseShapePtr> joined_shapes_;
  std::unordered_map<std::string, FuncGraphTransform> transforms_;
  // Parameter default value.
  std::map<std::string, AnfNodePtr> parameter_default_value_;

@ -224,7 +224,6 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
  TraceGuard trace_guard(func_graph->debug_info(), target_relation_);
  *target_func_graph = std::make_shared<FuncGraph>();
  (*target_func_graph)->set_attrs(func_graph->attrs());
  (*target_func_graph)->joined_shapes_ = func_graph->joined_shapes_;
  (*target_func_graph)->set_transforms(func_graph->transforms());
  (*target_func_graph)->set_has_vararg(func_graph->has_vararg());
  (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());

@ -254,9 +253,24 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
    return;
  }

  CloneInfo item = todo_.back();
  auto lift_top_func_graph = item.origin;
  for (auto &fv_map : iter->second) {
    auto &free_var = fv_map.first;
    if (utils::isa<AnfNodePtr>(free_var)) {
      auto free_var_node = utils::cast<AnfNodePtr>(free_var);
      // Don't lift weight parameter to top func_graph.
      if (func_graph == lift_top_func_graph) {
        if (free_var_node->isa<Parameter>()) {
          auto free_var_param = free_var_node->cast<ParameterPtr>();
          if (free_var_param->has_default()) {
            MS_LOG(DEBUG) << "Bypass weight param: " << free_var_param->ToString()
                          << " for top_func_graph: " << lift_top_func_graph->ToString();
            continue;
          }
        }
      }
      MS_LOG(DEBUG) << "Gen param: " << free_var_node->ToString() << " for func_graph: " << func_graph->ToString();
      repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
    }
  }

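Reading note: the GenParameters change amounts to one filter: when lifting free variables into the top func_graph, a Parameter that carries a default value (a weight) is bypassed instead of being lifted again. A self-contained restatement of that predicate (toy FreeVar type, not the real AnfNode hierarchy):

#include <iostream>
#include <vector>

// Toy free-variable record; has_default marks a weight parameter.
struct FreeVar {
  bool is_parameter = false;
  bool has_default = false;
};

// True if the free variable should become a new parameter of the graph
// being processed when it is the top graph of the lift; weights are bypassed.
bool ShouldLift(const FreeVar &fv, bool lifting_into_top) {
  if (lifting_into_top && fv.is_parameter && fv.has_default) {
    return false;  // bypass weight param, as in GenParameters above
  }
  return true;
}

int main() {
  std::vector<FreeVar> fvs = {{true, true}, {true, false}, {false, false}};
  for (const auto &fv : fvs) std::cout << ShouldLift(fv, true) << ' ';  // 0 1 1
  std::cout << '\n';
}
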
@ -301,6 +315,8 @@ void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList
    }
  }
  AnfNodePtr new_param = nullptr;
  CloneInfo item = todo_.back();
  auto lift_top_func_graph = item.origin;
  for (auto &param : params) {
    auto old_param = repl_node_[param];
    if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {

@ -314,6 +330,14 @@ void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList
      input_params->push_back(new_param);
      continue;
    }
    if (lift_top_func_graph == func_graph) {
      // Don't lift parameter from used_graphs to my parameter if I am the top;
      repl_node_[old_param] = old_param;
      input_params->push_back(old_param);
      MS_LOG(DEBUG) << "Bypass param: " << old_param->ToString()
                    << " for top_func_graph: " << lift_top_func_graph->ToString();
      continue;
    }
    new_param = AddParameter(func_graph, old_param, false);
    parameters.push_back(new_param);
    lift_params->push_back(new_param);

@ -445,6 +445,12 @@ void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &pa
  tr.Commit();
}

void FuncGraphManager::InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter) {
  auto tr = Transact();
  tr.InsertFrontParameter(fg, parameter);
  tr.Commit();
}

bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
  auto func_graph = old_node->func_graph();
  auto tr = Transact();

@ -599,6 +605,13 @@ void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupl
      auto param_node = param.param->cast<ParameterPtr>();
      param.func_graph->append_parameter(param_node);
    } break;
    case Change::kTxInsertFrontParam: {
      auto param = args.cast<ArgsOfInsertFrontParam>();
      MS_EXCEPTION_IF_NULL(param.func_graph);
      (*adds)[param.param] += 1;
      auto param_node = param.param->cast<ParameterPtr>();
      param.func_graph->PrependParameter(param_node);
    } break;
    default:
      break;
  }

@ -665,6 +678,10 @@ void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr &param
  changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param});
}

void FuncGraphTransaction::InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr &param) {
  changes_.emplace_back(Change::kTxInsertFrontParam, ArgsOfInsertFrontParam{fg, param});
}

bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
  MS_EXCEPTION_IF_NULL(old_node);
  MS_EXCEPTION_IF_NULL(new_node);

@ -312,6 +312,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
  void RemoveRoots();
  void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters);
  void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
  void InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
  void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
  bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
  void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);

@ -406,6 +407,7 @@ class FuncGraphTransaction {
  // set parameters of a func graph
  void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params);
  void AddParameter(FuncGraphPtr fg, const AnfNodePtr &param);
  void InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr &param);

  // replace old_node with new_node
  bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);

@ -447,6 +449,18 @@ struct ArgsOfAddParam {
  }
};

// args for InsertFront param
struct ArgsOfInsertFrontParam {
  FuncGraphPtr func_graph;
  AnfNodePtr param;
  bool operator==(const ArgsOfInsertFrontParam &other) const { return &other == this; }

  friend std::ostream &operator<<(std::ostream &os, const ArgsOfInsertFrontParam &) {
    os << "[ArgsOfInsertFrontParam]";
    return os;
  }
};

// args for set edge
struct ArgsOfSetEdge {
  CNodePtr root_node;

@ -473,7 +487,7 @@ struct ArgsOfAddEdge {
};

struct Change {
  enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam, kTxAddEdge };
  enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam, kTxInsertFrontParam, kTxAddEdge };
  OpName op;
  Any args;
  Change(OpName name, const Any &para) : op(name), args(para) {}

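Reading note: taken together, the manager pieces above form a record-then-replay protocol. FuncGraphTransaction::InsertFrontParameter records a Change tagged kTxInsertFrontParam with an ArgsOfInsertFrontParam payload, and ParseChanges applies it via PrependParameter at commit time. A compact sketch of that shape with simplified types (a closure stands in for the Any-typed args and the op switch; illustrative only):

#include <deque>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Graph { std::vector<std::string> params; };

// A recorded change: the op tag plus a replay action, standing in for
// Change{kTxInsertFrontParam, ArgsOfInsertFrontParam{fg, param}}.
struct ChangeRec {
  std::string op;
  std::function<void(Graph &)> apply;
};

struct Transaction {
  std::deque<ChangeRec> changes;
  void InsertFrontParameter(const std::string &p) {
    changes.push_back({"kTxInsertFrontParam",
                       [p](Graph &g) { g.params.insert(g.params.begin(), p); }});
  }
  void Commit(Graph &g) {
    for (auto &c : changes) c.apply(g);  // ParseChanges analogue
    changes.clear();
  }
};

int main() {
  Graph g{{"x"}};
  Transaction tr;
  tr.InsertFrontParameter("front");
  tr.Commit(g);
  std::cout << g.params.front() << '\n';  // front
}

The real code dispatches on Change::op and casts the Any payload back to the matching Args struct, which is why ArgsOfInsertFrontParam and the new enum value travel together.
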
@ -40,6 +40,16 @@ abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vect
  auto element0_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(shape_0);
  for (size_t i = 0; i < elements.size(); ++i) {
    auto shape = elements[i]->BuildShape();
    if (shape->isa<abstract::Shape>() && shape_0->isa<abstract::Shape>()) {
      const auto &shape_vec = shape->cast<abstract::ShapePtr>()->shape();
      const auto &shape_0_vec = shape_0->cast<abstract::ShapePtr>()->shape();
      if ((shape_vec == ShapeVector({1}) && shape_0_vec == ShapeVector()) ||
          (shape_vec == ShapeVector() && shape_0_vec == ShapeVector({1}))) {
        MS_LOG(DEBUG) << primitive->name() << "Shape of input[" << i << "]: " << shape->ToString()
                      << " are consistent with the shape of input[0]" << shape_0->ToString();
        continue;
      }
    }
    if (*shape != *shape_0) {
      MS_EXCEPTION(ValueError) << primitive->name() << "Shape of input[" << i << "]: " << shape->ToString()
                               << " are not consistent with the shape of input[0]" << shape_0->ToString();

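Reading note: the new branch makes AddN shape inference accept mixing a scalar input (shape ()) with a one-element input (shape (1)) where the strict equality check would previously raise ValueError. The added condition reduces to the predicate below (plain int64 vectors in place of abstract::Shape; the function name is ours):

#include <cstdint>
#include <iostream>
#include <vector>

using ShapeVector = std::vector<int64_t>;

// True when the two shapes are () and (1) in either order, the pair the
// new AddNInferShape branch now treats as consistent.
bool ScalarAndOneElementCompatible(const ShapeVector &a, const ShapeVector &b) {
  return (a == ShapeVector{} && b == ShapeVector{1}) ||
         (a == ShapeVector{1} && b == ShapeVector{});
}

int main() {
  std::cout << ScalarAndOneElementCompatible({}, {1}) << '\n';   // 1
  std::cout << ScalarAndOneElementCompatible({2}, {1}) << '\n';  // 0
}
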
@ -676,7 +676,7 @@ class SideEffectControlFlowAssignDependWhileNet(Cell):
        return grad_out


@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu

@ -37,11 +37,13 @@ class TestAD : public UT::Common {

 public:
  UT::PyFuncGraphFetcher getPyFun;
  pipeline::ResourceBasePtr resourcePtr = std::make_shared<pipeline::ResourceBase>();
  pipeline::ResourcePtr resourcePtr = std::make_shared<pipeline::Resource>();

 protected:
  void AssertExpect(const std::string& testCase) {
    FuncGraphPtr g = getPyFun(testCase);
    resourcePtr->manager()->RemoveRoots();
    resourcePtr->manager()->AddFuncGraph(g, true);
    FuncGraphPtr dg = Grad(g, resourcePtr);
    AssertExpect(testCase, dg);
  }

@ -20,9 +20,10 @@ from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.parameter import Parameter, ParameterTuple

grad_all = C.GradOperation(get_all=True)

grad_by_list = C.GradOperation(get_by_list=True)

class CropAndResizeNet(nn.Cell):
    def __init__(self, crop_size):

@ -138,3 +139,38 @@ def test_ad_fv_cnode_order():
    net.add_flags_recursive(defer_inline=True)
    grad_net = grad_all(net)
    grad_net(input_x, input_y)


# True and False branch of switch have different number of parameters.
def test_if_branch_with_different_params():
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.weight1 = Parameter(Tensor(np.array([1.0], dtype=np.float32)), name="weight1")
            self.weight2 = Parameter(Tensor(np.array([2.0], dtype=np.float32)), name="weight2")

        def construct(self, idx, end, x):
            out = x
            if idx < end:
                out = out + self.weight1 * self.weight2
            else:
                out = out + self.weight1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, idx, end, x):
            return grad_by_list(self.net, self.weights)(idx, end, x)

    idx = Tensor(np.array((0), dtype=np.int32))
    end = Tensor(np.array((3), dtype=np.int32))
    x = Tensor(np.array([2.0], dtype=np.float32))

    net = Net()
    grad_net = GradNet(net)
    grad_net(idx, end, x)