Lift free variables (fv) before grad, except weight parameters; then convert:

switch(cond, partial(g1, Xs), partial(g2, Ys))(Zs)
 to
switch(cond, g1, g2)(Xs, Ys, Zs)

switch_layer(index, make_tuple(partial(g1, Xs), partial(g2, Ys)))(Zs)
 to
switch_layer(index, make_tuple(g1, g2))(Xs, Ys, Zs)
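As a toy illustration of this rewrite (plain C++ functions standing in for funcgraphs; g1/g2 and the unified signatures are hypothetical, not MindSpore IR): once the partials are peeled off, both branches can share one parameter list Xs ∪ Ys ∪ Zs, each branch simply ignoring the other's arguments, with Zs kept last.

#include <iostream>

// Toy analogy of the switch-partial elimination above; not MindSpore code.
// Before: switch(cond, partial(g1, x), partial(g2, y))(z)
// After:  switch(cond, g1_unified, g2_unified)(x, y, z)
int g1(int x, int z) { return x + z; }
int g2(int y, int z) { return y * z; }

// Unified signatures: each branch takes the full Xs/Ys/Zs list and ignores
// the parameters that belong to the other branch; Zs stay last.
int g1_unified(int x, int /*y*/, int z) { return g1(x, z); }
int g2_unified(int /*x*/, int y, int z) { return g2(y, z); }

using Unified = int (*)(int, int, int);

int main() {
  bool cond = true;
  int x = 1, y = 2, z = 3;
  // The selection no longer closes over anything, so it yields a plain
  // function and all arguments are passed at the call site.
  Unified branch = cond ? g1_unified : g2_unified;
  std::cout << branch(x, y, z) << "\n";  // prints 4 (g1: x + z)
  return 0;
}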

Put Zs last when unifying parameters, since Zs may carry a U-monad or IO-monad.

Use the joined args rather than the broadened ones, since an extra parameter that is not a parameter of while_header can be added to while_body.
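For intuition on joined vs. broadened args, a rough standalone sketch follows; the Abstract struct and Join/Broaden below are deliberately simplified stand-ins, not the real mindspore::abstract API. Broadening only erases values, while joining compares two argument lists, which is what lets the analysis notice loop-variant arguments between while_header and while_body.

#include <iostream>
#include <optional>
#include <string>

// Simplified stand-in for an abstract value: a possibly-known constant plus
// a shape string. Not the real mindspore::abstract types.
struct Abstract {
  std::optional<int> value;  // nullopt models kAnyValue
  std::string shape;
};

// Broaden: forget the concrete value, keep the shape untouched.
Abstract Broaden(const Abstract &a) { return {std::nullopt, a.shape}; }

// Join: if both sides agree, keep the value; otherwise widen the value and
// keep the common shape. Loop analysis compares join(prev, cur) against cur
// to detect loop-variant arguments.
Abstract Join(const Abstract &a, const Abstract &b) {
  if (a.value == b.value && a.shape == b.shape) return a;
  return {std::nullopt, a.shape == b.shape ? a.shape : "joined"};
}

int main() {
  Abstract header_arg{1, "[2,3]"};  // argument as seen at while_header
  Abstract body_arg{2, "[2,3]"};    // same argument after one while_body pass
  Abstract joined = Join(header_arg, body_arg);
  std::cout << (joined.value ? "const" : "any") << " " << joined.shape << "\n";
  return 0;
}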

Forcibly inline fprop_switch.

Reorder the parameters when incorporating a call if one of the parameters is a Monad.
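A condensed sketch of just this ordering decision (toy types; IsMonad stands in for the abstract()->isa<abstract::AbstractMonad>() check used by IncorporateCall below): if the last of Xs is a monad, Ys are inserted first so the monad stays at the end of the new call.

#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins: a node is just a name, and IsMonad models the
// AbstractMonad test from the real pass.
using Node = std::string;
bool IsMonad(const Node &n) { return n == "U" || n == "IO"; }

// If the last of Xs is a monad, emit Ys before Xs so the monad stays last;
// otherwise keep the usual Xs-then-Ys order.
std::vector<Node> BuildArgs(const std::vector<Node> &xs, const std::vector<Node> &ys) {
  bool xs_first = xs.empty() || !IsMonad(xs.back());
  std::vector<Node> args;
  const auto &first = xs_first ? xs : ys;
  const auto &second = xs_first ? ys : xs;
  args.insert(args.end(), first.begin(), first.end());
  args.insert(args.end(), second.begin(), second.end());
  return args;
}

int main() {
  // Xs ends with a U monad, so the result is y1 y2 x1 U.
  for (const auto &a : BuildArgs({"x1", "U"}, {"y1", "y2"})) std::cout << a << " ";
  std::cout << "\n";
  return 0;
}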

Incorporate switch tuple_getitem if item 0 of the tuple is an EnvInstance or item 1 of the tuple is a bprop function.

AddN with shape() and shape(1).

Remove context_ from FuncGraphEvaluator to make it re-entrant, resolving the evaluator-stuck issue caused by re-entry of the same FuncGraphEvaluator.
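The re-entrancy fix follows a standard pattern, sketched generically below (hypothetical classes, not the evaluator API): state that used to live in a member is threaded through parameters instead, so overlapping evaluations of the same evaluator no longer clobber each other's context.

#include <string>

struct Context { std::string name; };

// Before: the context is member state; a re-entrant second call to Eval
// overwrites context_ while the first call is still using it.
class StatefulEvaluator {
 public:
  std::string Eval(const Context &ctx) {
    context_ = ctx;
    return Body();  // Body reads the member, which may have been clobbered
  }
 private:
  std::string Body() { return context_.name; }
  Context context_;
};

// After: the context is a parameter, so each entry keeps its own copy and
// the evaluator is safely re-entrant.
class ReentrantEvaluator {
 public:
  std::string Eval(const Context &ctx) { return Body(ctx); }
 private:
  std::string Body(const Context &ctx) { return ctx.name; }
};

int main() {
  ReentrantEvaluator e;
  return e.Eval({"root"}) == "root" ? 0 : 1;
}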
zhousiyi 2021-04-30 06:55:11 +00:00 committed by Zhang Qinghua
parent af8c15265d
commit a5b4e7bbf8
26 changed files with 784 additions and 175 deletions

View File

@ -101,12 +101,15 @@ void DumpInferStack(std::ostringstream &oss) {
infer_vec.clear();
break;
}
auto graph_context = graph_infer->context();
auto graph_context = graph_infer->parent_context();
if (graph_context == nullptr) {
MS_LOG(INFO) << "Null context continue";
continue;
}
auto graph = graph_context->func_graph();
if (graph == nullptr) {
continue;
}
auto args_spec_list = graph_context->args_spec_list();
oss << " #" << index++ << " " << GetGraphParamString(graph, args_spec_list);
}
@ -264,7 +267,7 @@ std::vector<AnalysisContextPtr> AnalyzedFuncGraphExporter::ProcessFuncGraphCall(
}
auto base_fg_evaluator = dyn_cast<abstract::BaseFuncGraphEvaluator>(evaluator);
auto ctx = base_fg_evaluator->context();
auto ctx = base_fg_evaluator->parent_context();
if (ctx != nullptr && context_map_.insert({ctx, false}).second) {
MS_LOG(DEBUG) << "Add new context, ctx.addr = " << ctx.get() << "ctx = " << ctx->ToString();
context_vec_.push_back(ctx);
@ -506,7 +509,7 @@ void GetEvalStackInfo(std::ostringstream &oss) {
return;
}
static int fileNumber = 0;
string file_name = "analyze_fail" + std::to_string(fileNumber++) + ".dat";
string file_name = "analyze_fail_" + std::to_string(fileNumber++) + ".dat";
auto ms_om_path = common::GetEnv("MS_OM_PATH");
if (!ms_om_path.empty()) {
auto path = ms_om_path + "/" + file_name;

View File

@ -37,6 +37,8 @@ namespace ad {
std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
bool lift_fv_before_grad = true;
DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
: primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
{
@ -73,6 +75,10 @@ void DFunctor::Clear() {
}
void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
if (lift_fv_before_grad) {
MS_LOG(EXCEPTION) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv "
<< fv->func_graph()->ToString() << " " << fv->ToString() << ".";
}
auto fv_adjoint = anfnode_to_adjoin_.find(fv);
if (fv_adjoint == anfnode_to_adjoin_.end()) {
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
@ -437,6 +443,13 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
AnfNodePtr new_grad_fv = grad_fv;
// Add grads wrt fv.
const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
if (!is_top_ && free_variables_nodes.size() != 0) {
if (lift_fv_before_grad) {
MS_LOG(EXCEPTION) << "direct fv size is: " << free_variables_nodes.size() << " in " << primal_graph_->ToString()
<< ".";
}
}
for (auto &fv : free_variables_nodes) {
auto fv_adjoint = anfnode_to_adjoin_.find(fv);
if (fv_adjoint == anfnode_to_adjoin_.end()) {
@ -460,6 +473,10 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
}
AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
if (lift_fv_before_grad) {
MS_LOG(EXCEPTION) << "Lift free variable case: AttachIndirectFvDoutToTape backprop indirect fv "
<< grad_fv->ToString() << " " << primal_graph_->ToString() << ".";
}
AnfNodePtr new_grad_fv = grad_fv;
// Add indirect fv bprop.
for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) {
@ -497,7 +514,12 @@ void DFunctor::MapMorphism() {
output_adjoint->second->AccumulateDout(dout_);
// Set output for tape closure.
auto grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
AnfNodePtr grad_fv;
if (lift_fv_before_grad) {
grad_fv = AttachFvDoutToTape(NewValueNode(newenv));
} else {
grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
}
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
// Add grads wrt inputs.

View File

@ -43,6 +43,9 @@ extern KPrim g_k_prims;
class DFunctor;
using DFunctorPtr = std::shared_ptr<DFunctor>;
// Flag to control if fv should be lifted before grad. If this lift_fv feature is mature, then this flag can be removed.
extern bool lift_fv_before_grad;
// D Functor's rules to map closure object and morphisms.
class DFunctor : public std::enable_shared_from_this<DFunctor> {
public:

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,12 +16,52 @@
#include "frontend/optimizer/ad/grad.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/irpass.h"
#include "ir/func_graph_cloner.h"
#include "utils/ms_context.h"
#include "utils/symbolic.h"
namespace mindspore {
namespace ad {
namespace {
FuncGraphPtr PartialEliminateOptPass(const ResourcePtr &resource, const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(resource);
opt::irpass::OptimizeIRPassLib irpass;
opt::OptPassConfig partial_eliminate_opt_ = opt::OptPassConfig(
{irpass.partial_eliminate_, irpass.switch_partial_eliminater_, irpass.switch_layer_partial_eliminater_});
opt::OptPassGroupMap map({{"partial_eliminate_", partial_eliminate_opt_}});
auto after_lift_opt = opt::Optimizer::MakeOptimizer("partial_eliminate", resource, map);
FuncGraphPtr opt_fg = nullptr;
WITH(MsProfile::GetProfile()->Step("partial_eliminate_before_grad"))[&after_lift_opt, func_graph, &opt_fg]() {
opt_fg = after_lift_opt->step(func_graph, true);
};
return opt_fg;
}
FuncGraphPtr LiftFv(const pipeline::ResourceBasePtr &resource, const FuncGraphPtr &func_graph) {
bool save_graphs_flag = MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
if (save_graphs_flag) {
DumpIR("before_lift_" + func_graph->ToString() + ".ir", func_graph);
}
FuncGraphPtr new_fg = LiftingClone(func_graph);
if (save_graphs_flag) {
DumpIR("after_lift_" + new_fg->ToString() + ".ir", new_fg);
}
auto new_res = std::dynamic_pointer_cast<pipeline::Resource>(resource);
if (new_res == nullptr) {
MS_LOG(EXCEPTION) << "Parameter resources is not a pipeline::Resource";
}
auto opt_fg = PartialEliminateOptPass(new_res, new_fg);
if (save_graphs_flag) {
DumpIR("after_opt_" + opt_fg->ToString() + ".ir", opt_fg);
}
return opt_fg;
}
} // namespace
FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) {
MS_EXCEPTION_IF_NULL(func_graph);
auto gradkv = func_graph->transforms().find("grad");
@ -33,6 +73,11 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
MS_EXCEPTION_IF_NULL(manager_ptr);
manager_ptr->AddFuncGraph(func_graph);
FuncGraphPtr grad_fg = func_graph;
lift_fv_before_grad = (common::GetEnv("ENV_DONT_LIFT_FV_BEFORE_GRAD") != "1");
if (func_graph->func_graphs_used().size() != 0 && lift_fv_before_grad) {
grad_fg = LiftFv(resources, func_graph);
}
auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
@ -41,8 +86,8 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
}
};
auto f = std::make_shared<DFunctor>(func_graph, resources);
auto user_defined = f->KUserDefined(func_graph);
auto f = std::make_shared<DFunctor>(grad_fg, resources);
auto user_defined = f->KUserDefined(grad_fg);
if (user_defined != nullptr) {
multi_graph_sink(user_defined);
if (is_top) {
@ -62,6 +107,9 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
}
multi_graph_sink(res);
if (func_graph != grad_fg) {
(void)func_graph->transforms().insert(std::make_pair("grad", FuncGraphTransform(res)));
}
return res;
}

View File

@ -327,6 +327,11 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
<< " prim bprop function to J expanded func graph. NodeInfo: "
<< trace::GetDebugInfo(bprop_fg->debug_info());
}
if (lift_fv_before_grad && IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
// Inline fprop_switch before renormalize;
expanded_fg->set_flag(FUNC_GRAPH_FLAG_FORCE_INLINE, true);
MS_LOG(DEBUG) << "set force_inline for fg: " << expanded_fg->ToString();
}
return expanded_fg;
}

View File

@ -162,6 +162,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
exchange_switch_depend_value_ =
MakeSubstitution(std::make_shared<ExchangeSwitchDependValue>(), "exchange_switch_depend_value", prim::kPrimSwitch);
switch_partial_eliminater_ =
MakeSubstitution(std::make_shared<SwitchPartialEliminater>(), "eliminate_switch_partial_", IsCNodeDup);
switch_layer_partial_eliminater_ =
MakeSubstitution(std::make_shared<SwitchLayerPartialEliminater>(), "eliminate_switch_layer_partial_", IsCNodeDup);
// Addn
merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);

View File

@ -86,6 +86,9 @@ class OptimizeIRPassLib {
SubstitutionPtr convert_switch_replacement_;
SubstitutionPtr exchange_switch_depend_value_;
SubstitutionPtr switch_partial_eliminater_;
SubstitutionPtr switch_layer_partial_eliminater_;
// AddN
SubstitutionPtr merge_addn_;
SubstitutionPtr addn_zero_filter_;

View File

@ -20,6 +20,7 @@
#include <vector>
#include <algorithm>
#include <unordered_map>
#include <utility>
#include <memory>
#include "frontend/optimizer/irpass.h"
@ -38,29 +39,36 @@ class CallOutputTransform {
CallOutputTransform() : cache_() {}
~CallOutputTransform() = default;
FuncGraphPtr operator()(const FuncGraphPtr &fg, size_t nargs) {
FuncGraphPtr operator()(const FuncGraphPtr &fg, size_t nargs, bool xs_first) {
if (cache_.find(fg) == cache_.end()) {
cache_[fg] = {};
}
auto &cache = cache_[fg];
if (cache.find(nargs) == cache.end()) {
auto key = std::make_pair(nargs, xs_first);
if (cache.find(key) == cache.end()) {
FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("call"));
std::vector<AnfNodePtr> new_items;
new_items.push_back(new_fg->output());
for (size_t i = 0; i < nargs; i++) {
new_items.push_back(new_fg->add_parameter());
if (xs_first) {
for (size_t i = 0; i < nargs; i++) {
new_items.push_back(new_fg->add_parameter());
}
} else {
for (size_t i = 0; i < nargs; i++) {
new_items.push_back(new_fg->InsertFrontParameter());
}
}
new_fg->set_output(new_fg->NewCNode(new_items));
cache[nargs] = new_fg;
cache[key] = new_fg;
}
return cache[nargs];
return cache[key];
}
private:
std::unordered_map<FuncGraphPtr, std::unordered_map<size_t, FuncGraphPtr>> cache_;
std::unordered_map<FuncGraphPtr, std::unordered_map<std::pair<size_t, bool>, FuncGraphPtr, PairHasher>> cache_;
};
} // namespace internal
@ -88,20 +96,35 @@ class IncorporateCall : public AnfVisitor {
auto xs_size = Xs_.size();
auto ys_size = inputs.size() - 1;
auto new_fg = call_output_transform_(fg_, ys_size);
bool xs_first = true;
if ((xs_size > 0) && (Xs_[xs_size - 1]->abstract() != nullptr) &&
(Xs_[xs_size - 1]->abstract()->isa<abstract::AbstractMonad>())) {
xs_first = false;
}
auto new_fg = call_output_transform_(fg_, ys_size, xs_first);
std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(new_fg));
if (xs_size > 0) {
(void)args.insert(args.end(), Xs_.begin(), Xs_.end());
if (xs_first) {
if (xs_size > 0) {
(void)args.insert(args.end(), Xs_.begin(), Xs_.end());
}
if (ys_size > 0) {
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
}
} else {
if (ys_size > 0) {
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
}
if (xs_size > 0) {
(void)args.insert(args.end(), Xs_.begin(), Xs_.end());
}
}
if (ys_size > 0) {
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
}
return node->func_graph()->NewCNode(args);
auto new_node = node->func_graph()->NewCNode(args);
new_node->set_abstract(node->abstract());
return new_node;
}
void Visit(const CNodePtr &cnode) override {
@ -159,19 +182,35 @@ class IncorporateCallSwitch : public AnfVisitor {
auto fg = node->func_graph();
auto xs_size = inputs_x.size() - 1;
auto ys_size = inputs.size() - 1;
auto new_g1 = call_output_transform_(g1_, ys_size);
auto new_g2 = call_output_transform_(g2_, ys_size);
bool xs_first = true;
if ((xs_size > 0) && (inputs_x[xs_size - 1]->abstract() != nullptr) &&
(inputs_x[xs_size - 1]->abstract()->isa<abstract::AbstractMonad>())) {
xs_first = false;
}
auto new_g1 = call_output_transform_(g1_, ys_size, xs_first);
auto new_g2 = call_output_transform_(g2_, ys_size, xs_first);
auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)});
std::vector<AnfNodePtr> args{sw_node};
if (xs_size > 0) {
(void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end());
}
if (ys_size > 0) {
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
if (xs_first) {
if (xs_size > 0) {
(void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end());
}
if (ys_size > 0) {
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
}
} else {
if (ys_size > 0) {
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
}
if (xs_size > 0) {
(void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end());
}
}
return fg->NewCNode(args);
auto new_node = fg->NewCNode(args);
new_node->set_abstract(node->abstract());
return new_node;
}
void Visit(const AnfNodePtr &node) override {

View File

@ -164,6 +164,7 @@ class IncorporateGetitem : public AnfVisitor {
}
}
}
new_node->set_abstract(node->abstract());
return new_node;
}
@ -228,6 +229,7 @@ class IncorporateGetitemDepend : public AnfVisitor {
new_depend_cnode =
node->func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), new_fg_cnode, depend_2nd_input_});
}
new_depend_cnode->set_abstract(node->abstract());
return new_depend_cnode;
}
@ -294,8 +296,7 @@ class IncorporateGetitemSwitch : public AnfVisitor {
is_in_get_ = false;
auto fg = node->func_graph();
if (idx_ == -1 || switch_ == nullptr || fg == nullptr ||
(fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) && !ExistEnvNode(fg))) {
if (idx_ == -1 || switch_ == nullptr || fg == nullptr) {
return nullptr;
}
@ -308,10 +309,27 @@ class IncorporateGetitemSwitch : public AnfVisitor {
}
auto tuple_getitem = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem);
bool has_env_type = false;
if (tuple_getitem->input(1)->abstract() && tuple_getitem->input(1)->abstract()->isa<abstract::AbstractTuple>()) {
const auto &abs_tuple = *(tuple_getitem->input(1)->abstract()->cast<abstract::AbstractTuplePtr>());
// eliminate (envinstance, value1, value2, ...) built by bprop func_graph()
if (abs_tuple.size() >= 1) {
// Value may be kAnyValue, so check the type track;
if (abs_tuple[0]->isa<abstract::AbstractScalar>() && abs_tuple[0]->GetTypeTrack()->isa<EnvType>()) {
has_env_type = true;
}
}
// eliminate (value, bprop_func) built by fprop func_graph
if (abs_tuple.size() >= 2) {
if (abs_tuple[1]->isa<abstract::AbstractFunction>()) {
has_env_type = true;
}
}
}
// If env_getitem/env_setitem exists in this funcgraph, or
// if g1_/g2_ is a fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem;
if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) &&
!ExistEnvNodeInTupleItem(g2_)) {
!ExistEnvNodeInTupleItem(g2_) && !has_env_type) {
return nullptr;
}
auto new_g1 = getitem_transform_(g1_, idx_);
@ -319,7 +337,9 @@ class IncorporateGetitemSwitch : public AnfVisitor {
auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)});
(void)args_.insert(args_.begin(), sw_node);
return fg->NewCNode(args_);
auto new_node = fg->NewCNode(args_);
new_node->set_abstract(node->abstract());
return new_node;
}
void Visit(const AnfNodePtr &node) override {
@ -398,8 +418,8 @@ class IncorporateGetitemSwitch : public AnfVisitor {
}
const auto &cnode = output->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
return std::any_of(inputs.cbegin() + 1, inputs.cend(), [](const auto &node) {
auto sub_fg = GetValueNode<FuncGraphPtr>(node);
return std::any_of(inputs.cbegin() + 1, inputs.cend(), [](const auto &input) {
auto sub_fg = GetValueNode<FuncGraphPtr>(input);
if (sub_fg != nullptr && ExistEnvNode(sub_fg)) {
return true;
}

View File

@ -85,6 +85,9 @@ bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node);
bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node);
bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &);
bool IsForceInline(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
return fg->has_flag(FUNC_GRAPH_FLAG_FORCE_INLINE);
}
// {G, Xs}
class InlinerBase : public AnfVisitor {
@ -127,6 +130,14 @@ class InlinerBase : public AnfVisitor {
return nullptr;
}
if (IsForceInline(this, fg, node)) {
if (IsUniqueUse(nullptr, fg, nullptr)) {
return InlineMove(node, fg, args, inputs);
} else {
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
}
}
if (IsUniqueUse(nullptr, fg, nullptr)) {
// For the single used fg, including non-after and after not matched above,
// we move the whole fg nodes.
@ -151,15 +162,20 @@ class InlinerBase : public AnfVisitor {
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
}
AnfNodePtr InlineMove(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
const std::vector<AnfNodePtr> &inputs) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, args, fg);
auto out_node = fg->output();
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
return out_node;
}
AnfNodePtr InlineForUniqueUse(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
const std::vector<AnfNodePtr> &inputs) {
if (use_move_) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, args, fg);
auto out_node = fg->output();
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
return out_node;
return InlineMove(node, fg, args, inputs);
}
// The other branch calling the last after block.
@ -366,6 +382,7 @@ class DirectInliner : public InlinerBase {
: InlinerBase(
// Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}.
{
{IsForceInline},
{IsDirectParentCall},
},
use_move) {}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,9 +17,11 @@
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
@ -72,6 +74,341 @@ class PartialEliminater : public AnfVisitor {
private:
std::vector<AnfNodePtr> Xs_{};
};
class ChoicePartialEliminater : public AnfVisitor {
public:
virtual ~ChoicePartialEliminater() = default;
protected:
AnfNodePtrList fg_list_{};
std::vector<AnfNodePtrList> args_list_{};
void Visit(const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimPartial)) {
if (IsValueNode<FuncGraph>(node)) {
fg_list_.push_back(node);
args_list_.push_back(AnfNodePtrList{});
}
return;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
// {prim::kPrimPartial, G, Xs}
if (inputs.size() < 3) {
MS_LOG(EXCEPTION) << "Node should be Partial CNode, but: " << node->DebugString();
return;
}
if (IsValueNode<FuncGraph>(inputs[1])) {
fg_list_.push_back(inputs[1]);
AnfNodePtrList args;
(void)std::copy(inputs.begin() + 2, inputs.end(), std::back_inserter(args));
args_list_.push_back(args);
}
return;
}
// return value: true -- continue replace; false -- return nullptr;
bool CheckFuncGraphAndArgs() {
// Either one should be {Partial, G, X}
auto has_partial_args =
std::any_of(args_list_.cbegin(), args_list_.cend(), [](auto &args) { return args.size() != 0; });
if (!has_partial_args) {
return false;
}
// Check that each funcgraph is used only once.
for (auto &fg_node : fg_list_) {
auto fg = GetValueNode<FuncGraphPtr>(fg_node);
MS_EXCEPTION_IF_NULL(fg);
if (fg->func_graph_cnodes_index().size() != 1) {
for (auto iter : fg->func_graph_cnodes_index()) {
MS_LOG(ERROR) << "fg user: " << iter.first->first->DebugString(1) << ", index: " << iter.first->second;
}
MS_LOG(EXCEPTION) << "fg is used multiple times: " << fg->ToString();
}
}
return true;
}
// f(x1, x2, x3, z1, z2)
// g(x4, x2, z1, z2)
// h(x5, x2, x7, x8, z1, z2)
// --> anchor_fg = h
// h(x5, x2, x7, x8, x1, x3, x4, z1, z2)
// f(x5, x2, x7, x8, x1, x3, x4, z1, z2)
// g(x5, x2, x7, x8, x1, x3, x4, z1, z2)
// as z1, z2 may be U or IO monad.
AnfNodePtrList UnifyParameters(const size_t &anchor_index, const AnfNodePtrList &fg_list,
const std::vector<AnfNodePtrList> &args_list) {
std::vector<std::vector<size_t>> inputs_index_list(args_list.size());
size_t extra_input_counter = 0;
AnfNodePtrList extra_inputs;
const auto &anchor_args = args_list[anchor_index];
size_t anchor_args_size = anchor_args.size();
auto anchor_fg = GetValueNode<FuncGraphPtr>(fg_list[anchor_index]);
MS_EXCEPTION_IF_NULL(anchor_fg);
// Find the new location of the old_inputs except Zs;
for (size_t i = 0; i < args_list.size(); ++i) {
if (i == anchor_index) {
continue;
}
const auto &another_args = args_list[i];
auto &curr_inputs_index = inputs_index_list[i];
for (size_t j = 0; j < another_args.size(); ++j) {
size_t k;
for (k = 0; k < anchor_args_size; ++k) {
if (another_args[j] == anchor_args[k]) {
curr_inputs_index.push_back(k);
break;
}
}
if (k == anchor_args_size) {
// check if used by another func_graph;
for (k = 0; k < extra_input_counter; ++k) {
if (another_args[j] == extra_inputs[k]) {
curr_inputs_index.push_back(anchor_args_size + k);
break;
}
}
if (k == extra_input_counter) {
extra_inputs.push_back(another_args[j]);
curr_inputs_index.push_back(anchor_args_size + extra_input_counter);
extra_input_counter++;
}
}
}
}
auto manager = anchor_fg->manager();
auto txn = manager->Transact();
size_t anchor_params_size = anchor_fg->parameters().size();
const auto &anchor_fg_params = anchor_fg->parameters();
for (size_t i = 0; i < args_list.size(); ++i) {
if (i == anchor_index) {
continue;
}
AnfNodePtrList new_params;
new_params.resize(anchor_params_size + extra_input_counter);
const auto &curr_inputs_index = inputs_index_list[i];
auto another_fg = GetValueNode<FuncGraphPtr>(fg_list[i]);
MS_EXCEPTION_IF_NULL(another_fg);
const auto &old_params = another_fg->parameters();
const auto &old_args = args_list[i];
for (size_t j = 0; j < old_args.size(); j++) {
new_params[curr_inputs_index[j]] = old_params[j];
}
// Zs_
for (size_t j = old_args.size(), k = 0; j < old_params.size(); ++j, ++k) {
new_params[anchor_args_size + extra_input_counter + k] = old_params[j];
}
// unused inputs
for (size_t j = 0; j < anchor_args_size; ++j) {
if (new_params[j] == nullptr) {
TraceGuard guard(std::make_shared<TraceCopy>(anchor_fg_params[j]->debug_info()));
ParameterPtr param = std::make_shared<Parameter>(another_fg);
new_params[j] = param;
}
}
// extra inputs used by another func_graph;
for (size_t j = 0; j < extra_inputs.size(); ++j) {
if (new_params[anchor_args_size + j] == nullptr) {
TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[j]->debug_info()));
ParameterPtr param = std::make_shared<Parameter>(another_fg);
new_params[anchor_args_size + j] = param;
}
}
// set the new parameters for another_fg, replacing its old ones;
txn.SetParameters(another_fg, new_params);
}
// Reorder Zs_ and add extra parameters for anchor_fg;
// add extra parameter for anchor_fg;
AnfNodePtrList new_params;
new_params.reserve(anchor_params_size + extra_input_counter);
// reuse parameters for anchor_args;
std::copy(anchor_fg_params.cbegin(), anchor_fg_params.cbegin() + anchor_args_size, std::back_inserter(new_params));
// Extra parameters;
for (size_t i = 0; i < extra_inputs.size(); ++i) {
TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[i]->debug_info()));
ParameterPtr param = std::make_shared<Parameter>(anchor_fg);
new_params.push_back(param);
}
// Reorder Zs_ to last;
for (size_t i = anchor_args_size; i < anchor_params_size; ++i) {
new_params.push_back(anchor_fg_params[i]);
}
txn.SetParameters(anchor_fg, new_params);
txn.Commit();
return extra_inputs;
}
};
// {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}, Zs} ->
// {{prim::kPrimSwitch, cond, G1, G2}, Xs Union Ys Union Zs}
// {{prim::kPrimSwitch, cond, {G1}, {prim::kPrimPartial, G2, Ys}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Ys Union
// Zs}
// {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {G2}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Xs Union
// Zs}
class SwitchPartialEliminater : public ChoicePartialEliminater {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
if (!IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) {
return nullptr;
}
auto input0_cnode = cnode->input(0)->cast<CNodePtr>();
if (input0_cnode->size() != 4) {
return nullptr;
}
fg_list_.clear();
args_list_.clear();
auto &maybe_partial_1 = input0_cnode->input(2);
Visit(maybe_partial_1);
auto &maybe_partial_2 = input0_cnode->input(3);
Visit(maybe_partial_2);
// Either one should be {Partial, G, X}
if (fg_list_.size() != 2 && args_list_.size() != 2) {
return nullptr;
}
// Should not continue;
if (!CheckFuncGraphAndArgs()) {
return nullptr;
}
if (args_list_[0] == args_list_[1]) {
auto new_node =
BuildNewSwitchNode(cnode, input0_cnode, fg_list_[0], fg_list_[1], args_list_[0], AnfNodePtrList{});
return new_node;
} else {
// find partial funcgraph with the longest args as anchor;
size_t max_args_pos = 0;
if (args_list_[0].size() > args_list_[1].size()) {
max_args_pos = 0;
} else {
max_args_pos = 1;
}
auto extra_inputs = UnifyParameters(max_args_pos, fg_list_, args_list_);
auto new_node =
BuildNewSwitchNode(cnode, input0_cnode, fg_list_[0], fg_list_[1], args_list_[max_args_pos], extra_inputs);
return new_node;
}
}
private:
AnfNodePtr BuildNewSwitchNode(const CNodePtr &old_cnode, const CNodePtr &input0_cnode, const AnfNodePtr &G1,
const AnfNodePtr &G2, const AnfNodePtrList &partial_args,
const AnfNodePtrList &extra_args) {
TraceGuard guard1(std::make_shared<TraceCopy>(input0_cnode->debug_info()));
// {Switch, cond, G1, G2}
auto switch_cnode = old_cnode->func_graph()->NewCNode({input0_cnode->input(0), input0_cnode->input(1), G1, G2});
AnfNodePtrList args{switch_cnode};
(void)std::copy(partial_args.begin(), partial_args.end(), std::back_inserter(args));
(void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
// Zs
if (old_cnode->size() >= 2) {
(void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
}
TraceGuard guard2(std::make_shared<TraceCopy>(old_cnode->debug_info()));
auto new_node = old_cnode->func_graph()->NewCNode(args);
return new_node;
}
};
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}}, Zs} ->
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}, Xs Union Ys Union Zs}
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{G1}, {prim::kPrimPartial, G2, Ys}}}, Zs} ->
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Ys Union Zs}
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {G2}}{}, Zs} ->
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Xs Union Zs}
class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
// {SwitchLayer{}, Zs}
if (!IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitchLayer)) {
return nullptr;
}
auto switch_layer_cnode = cnode->input(0)->cast<CNodePtr>();
// {SwitchLayer, cond, MakeTuple{}}
if (switch_layer_cnode->size() != 3) {
return nullptr;
}
if (!IsPrimitiveCNode(switch_layer_cnode->input(2), prim::kPrimMakeTuple)) {
return nullptr;
}
auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>();
if (make_tuple_cnode->size() < 2) {
return nullptr;
}
fg_list_.clear();
args_list_.clear();
// Build funcgraph list and args list;
for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
Visit(make_tuple_cnode->input(i));
}
if (!CheckFuncGraphAndArgs()) {
return nullptr;
}
// All have the same args;
auto args_equal =
std::all_of(args_list_.cbegin() + 1, args_list_.cend(), [this](auto &args) { return args == args_list_[0]; });
if (args_equal) {
auto new_node = BuildNewSwitchLayerNode(cnode, switch_layer_cnode, args_list_[0], AnfNodePtrList{});
return new_node;
} else {
// find partial funcgraph with the longest args as anchor;
size_t max_args_pos = 0, max_args_len = 0;
for (size_t i = 0; i < args_list_.size(); ++i) {
if (max_args_len < args_list_[i].size()) {
max_args_len = args_list_[i].size();
max_args_pos = i;
}
}
auto extra_inputs = UnifyParameters(max_args_pos, fg_list_, args_list_);
auto new_node = BuildNewSwitchLayerNode(cnode, switch_layer_cnode, args_list_[max_args_pos], extra_inputs);
return new_node;
}
}
private:
AnfNodePtr BuildNewSwitchLayerNode(const CNodePtr &old_cnode, const CNodePtr &switch_layer_cnode,
const AnfNodePtrList &anchor_partial_args, const AnfNodePtrList &extra_args) {
auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>();
AnfNodePtrList make_tuple_args{make_tuple_cnode->input(0)};
make_tuple_args.insert(make_tuple_args.end(), fg_list_.begin(), fg_list_.end());
TraceGuard guard1(std::make_shared<TraceCopy>(make_tuple_cnode->debug_info()));
// {MakeTuple, G1, G2, ...}
auto new_make_tuple_cnode = old_cnode->func_graph()->NewCNode(make_tuple_args);
TraceGuard guard2(std::make_shared<TraceCopy>(switch_layer_cnode->debug_info()));
// {SwitchLayer, cond, MakeTuple{}}
auto new_switch_layer_cnode = old_cnode->func_graph()->NewCNode(
{switch_layer_cnode->input(0), switch_layer_cnode->input(1), new_make_tuple_cnode});
AnfNodePtrList args{new_switch_layer_cnode};
(void)std::copy(anchor_partial_args.begin(), anchor_partial_args.end(), std::back_inserter(args));
(void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
// Zs
if (old_cnode->size() >= 2) {
(void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
}
TraceGuard guard3(std::make_shared<TraceCopy>(old_cnode->debug_info()));
auto new_node = old_cnode->func_graph()->NewCNode(args);
return new_node;
}
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -109,11 +109,12 @@ void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &engine,
}
// Start running stack frames in an Evaluator.
AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) {
AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
const AnalysisContextPtr &context) {
EvalResultPtr eval_result = nullptr;
AbstractBasePtr res_base = nullptr;
std::stack<StackFramePtr> stack_frames;
auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context_, parent_context_);
auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context, parent_context_);
MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame;
stack_frames.push(current_stack_frame);
while (true) {
@ -155,7 +156,8 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr
return res_base;
}
AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) {
AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
const AnalysisContextPtr &context) {
const AnfNodePtr &func_node = fg->get_return();
const auto &all_nodes = TopoSort(func_node, SuccIncoming, [](const AnfNodePtr &node) -> IncludeType {
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
@ -165,7 +167,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine
});
AbstractBasePtr res_base = nullptr;
for (const auto &node : all_nodes) {
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context_);
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
<< ", node_conf: " << node_conf->ToString();
auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf);
@ -214,24 +216,30 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
<< parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":"
<< fg->ToString() << "();";
}
context_ = parent_context_->NewFuncGraphContext(fg, args_abs_list);
auto context = parent_context_->NewFuncGraphContext(fg, args_abs_list);
auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>());
if (func_graph_evaluator != nullptr) {
if (engine->root_func_graph() == func_graph_evaluator->func_graph()) {
engine->set_root_context(context);
}
}
const auto &parameters = fg->parameters();
for (size_t i = 0; i < nargs; i++) {
const auto &arg = args_abs_list[i];
const auto &node = parameters[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, context_);
AnfNodeConfigPtr conf = engine->MakeConfig(node, context);
engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr));
MS_LOG(DEBUG) << GetInferThread() << "Set Param: " << conf->ToString() << " = " << arg->ToString();
}
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
<< ", context: " << context_->ToString() << ", return node: " << fg->get_return()->DebugString()
<< ", context: " << context->ToString() << ", return node: " << fg->get_return()->DebugString()
<< ", parent: " << (parent_context_->func_graph() ? parent_context_->func_graph()->ToString() : "NULL")
<< ", current function call depth: " << engine->function_call_depth();
AbstractBasePtr res_base = nullptr;
if (engine->enable_recursive_eval()) {
res_base = LaunchRecursiveEval(engine, fg);
res_base = LaunchRecursiveEval(engine, fg, context);
} else {
res_base = LaunchStackFrame(engine, fg);
res_base = LaunchStackFrame(engine, fg, context);
}
MS_EXCEPTION_IF_NULL(res_base);
@ -250,28 +258,27 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
return res;
}
void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args) {
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(*broaded_args),
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg);
// Only broaden scalars whose data type is a number, such as float16, int32, and so on.
auto type = arg->BuildType()->type_id();
if (arg->isa<AbstractScalar>() && type > kNumberTypeBegin && type < kNumberTypeEnd) {
auto config = abstract::AbstractBase::kBroadenScalarParameterOnly;
return arg->Broaden(config);
} else if (arg->GetValueTrack() != kAnyValue) {
return arg->Broaden();
}
return arg;
});
}
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
MS_EXCEPTION_IF_NULL(func_graph_);
if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
AbstractBasePtrList broaded_list;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list),
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg);
// Only broaden scalar that data type is number, such as float16,int32 and so on.
auto type = arg->BuildType()->type_id();
if (arg->isa<AbstractScalar>() && type > kNumberTypeBegin && type < kNumberTypeEnd) {
auto config = abstract::AbstractBase::kBroadenScalarParameterOnly;
return arg->Broaden(config);
} else if (arg->GetValueTrack() != kAnyValue) {
return arg->Broaden();
}
return arg;
});
if (func_graph_->joined_shapes_.size() == broaded_list.size()) {
for (size_t i = 0; i < broaded_list.size(); ++i) {
broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
}
}
BroadenArgs(args_spec_list, &broaded_list);
MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
<< ", broaded: " << mindspore::ToString(broaded_list);
@ -287,55 +294,11 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
}
if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) {
if (parent_context_) {
MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString()
<< ", context: " << parent_context_->ToString();
auto last_context = parent_context_->FindParentContext(func_graph_);
if (last_context && last_context->func_graph() == func_graph_) {
MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString();
MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list());
// Join the last eval arguments and current arguments to check if there are loop variant.
auto joined_args_spec_list_1 = AbstractJoin(args_spec_list, last_context->args_spec_list());
MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list_1);
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
if (!(joined_args_spec_list_1 == args_spec_list)) {
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
func_graph_->joined_shapes_.clear();
std::transform(joined_args_spec_list_1.begin(), joined_args_spec_list_1.end(),
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
MS_EXCEPTION_IF_NULL(arg_spec);
return arg_spec->GetShapeTrack();
});
joined_args_spec_list_1 = NormalizeArgs(joined_args_spec_list_1);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
}
return joined_args_spec_list_1;
}
}
if (!trace_.empty()) {
MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back());
// Join the last eval arguments and current arguments to check if there are loop variant.
auto joined_args_spec_list_2 = AbstractJoin(args_spec_list, trace_.back());
// If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation.
if (!(joined_args_spec_list_2 == args_spec_list)) {
trace_.push_back(joined_args_spec_list_2);
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
func_graph_->joined_shapes_.clear();
std::transform(joined_args_spec_list_2.begin(), joined_args_spec_list_2.end(),
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
MS_EXCEPTION_IF_NULL(arg_spec);
return arg_spec->GetShapeTrack();
});
joined_args_spec_list_2 = NormalizeArgs(joined_args_spec_list_2);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
}
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list_2);
return joined_args_spec_list_2;
} else {
trace_.push_back(args_spec_list);
}
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
auto normalized_args_spec_list = NormalizeArgs(args_spec_list);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
MS_LOG(DEBUG) << "Normalized args " << mindspore::ToString(normalized_args_spec_list);
return normalized_args_spec_list;
}
return args_spec_list;
}
@ -388,16 +351,6 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
// The evaluator can't reenter at sametime.
std::unique_lock<std::recursive_timed_mutex> eval_lock(eval_lock_, std::try_to_lock);
if (!eval_lock.owns_lock()) {
// Release GIL
pybind11::gil_scoped_release infer_gil_release;
// Check if enter endless loop
HealthPointScopedDrop health_point_check;
eval_lock.lock();
}
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {

View File

@ -210,8 +210,6 @@ class BaseFuncGraphEvaluator : public Evaluator {
AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list);
AnalysisContextPtr context() const { return context_; }
void set_context(const AnalysisContextPtr &context) { context_ = context; }
AnalysisContextPtr parent_context() const { return parent_context_; }
void set_parent_context(const AnalysisContextPtr &parent_context) { parent_context_ = parent_context; }
@ -219,14 +217,14 @@ class BaseFuncGraphEvaluator : public Evaluator {
AnalysisContextPtr parent_context_;
private:
AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
const AnalysisContextPtr &context);
// Add functions for stack frame routine.
AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg);
AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
const AnalysisContextPtr &context);
static void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
const StackFramePtr &new_stack_frame);
static void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame);
AnalysisContextPtr context_;
};
class FuncGraphEvaluator : public BaseFuncGraphEvaluator {
@ -353,6 +351,8 @@ class JEvaluator : public Evaluator {
EvaluatorPtr evaluator_;
AbstractFunctionPtr orig_func_;
};
void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args);
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_

View File

@ -23,6 +23,7 @@
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/do_signature.h"
#include "abstract/abstract_function.h"
#include "abstract/utils.h"
#include "ir/graph_utils.h"
#include "utils/log_adapter.h"
#include "debug/trace.h"
@ -545,32 +546,37 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
EvalResultPtr ret = nullptr;
AbstractBasePtrList broaded_argvals;
EvalResultCache &cache = evalcaches_[eval]->GetCache();
for (auto &argvals_map : cache) {
std::vector<AbstractBasePtrList> args_vector;
auto &origin_eval_cache = evalcaches_[eval]->GetCache();
for (auto &argvals_map : origin_eval_cache) {
auto argvals = argvals_map.first;
args_vector.push_back(argvals);
broaded_argvals.clear();
(void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals),
[](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
BroadenArgs(argvals, &broaded_argvals);
(void)choices.insert(broaded_argvals);
MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals);
}
if (choices.size() == 1) {
ConfigPtrList args_conf_list;
(void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list),
[](AbstractBasePtr v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
// If broaden return null
ret = eval->SingleRun(engine_, args_conf_list, nullptr);
if (args_vector.size() < 2) {
MS_LOG(EXCEPTION) << "Should have 2 more choices, but: " << args_vector.size();
}
AbstractBasePtrList joined_argvals = args_vector[0];
for (size_t i = 1; i < args_vector.size(); ++i) {
joined_argvals = abstract::AbstractJoin(joined_argvals, args_vector[i]);
}
MS_LOG(DEBUG) << "Joined argvals: " << joined_argvals.size() << ", " << ::mindspore::ToString(joined_argvals);
EvaluatorCacheMgrPtr real = std::make_shared<EvaluatorCacheMgr>();
real->SetValue(broaded_argvals, ret);
evalcaches_[eval] = real;
return std::make_pair(broaded_argvals, ret->abstract());
} else {
MS_LOG(DEBUG) << "Choices.size: " << choices.size();
return std::make_pair(AbstractBasePtrList(), nullptr);
auto joined_eval_result = origin_eval_cache.get(joined_argvals);
if (joined_eval_result != nullptr) {
MS_LOG(DEBUG) << "Find unique Choices in original eval cache, so use it: " << joined_eval_result->ToString();
real->SetValue(joined_argvals, joined_eval_result);
evalcaches_[eval] = real;
return std::make_pair(joined_argvals, joined_eval_result->abstract());
}
}
MS_LOG(DEBUG) << "Choices.size: " << choices.size();
return std::make_pair(AbstractBasePtrList(), nullptr);
}
void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {

View File

@ -96,7 +96,6 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr
}
// Create a new stack frame and set arguments for it.
fg_evaluator->set_context(new_context);
auto new_stack_frame = std::make_shared<StackFrame>(fg_evaluator, fg, new_context, parent_context);
new_stack_frame->set_args_abs_list(std::move(args_abs_list));
return new_stack_frame;

View File

@ -80,6 +80,8 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac
MS_EXCEPTION_IF_NULL(func_graph_manager_);
func_graph_manager_->AddFuncGraph(func_graph);
root_func_graph_ = func_graph;
AnalysisContextPtr empty_context = AnalysisContext::DummyContext();
// Running the analyzer.
@ -105,7 +107,7 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
const ConfigPtrList &args_conf_list) {
std::shared_ptr<FuncGraphEvaluator> eval = std::make_shared<FuncGraphEvaluator>(func_graph, context);
(void)eval->Run(shared_from_this(), args_conf_list, nullptr);
return eval->context();
return root_context_;
}
void AnalysisEngine::SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
@ -213,11 +215,12 @@ void AnalysisEngine::CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf) {
return;
}
auto top_evaluator = infer_stack.top().first;
if (!top_evaluator->isa<BaseFuncGraphEvaluator>()) {
// Top or root func_graph must be a FuncGraph rather than a MetaFuncGraph;
if (!top_evaluator->isa<FuncGraphEvaluator>()) {
MS_LOG(EXCEPTION) << "Top evaluator is " << top_evaluator->ToString();
}
auto top_fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(top_evaluator);
auto top_context_fg = top_fg_evaluator->context()->func_graph();
auto top_fg_evaluator = dyn_cast<FuncGraphEvaluator>(top_evaluator);
auto top_context_fg = top_fg_evaluator->func_graph();
if (current_cnode_fg != top_context_fg) { // Ignore FV call.
return;
}
@ -339,6 +342,8 @@ void AnalysisEngine::Clear() {
evaluators_.clear();
constructors_app_.clear();
continued_evals_.clear();
root_func_graph_ = nullptr;
root_context_ = nullptr;
}
namespace {
@ -578,7 +583,7 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
#endif
}
bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) {
static std::mutex fg_lock;
std::lock_guard<std::mutex> infer_lock(fg_lock);
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
@ -591,10 +596,17 @@ bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
auto undetermined_fgs = fg->recursive();
if (undetermined_fgs) {
auto fg_parent = fg->parent();
MS_EXCEPTION_IF_NULL(fg_parent);
fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString();
return true;
if (fg_parent != nullptr) {
fg_parent->set_flag(kFuncGraphFlagUndetermined, true);
MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString() << " for fg: " << fg->ToString();
return true;
} else if (possible_parent_fg != nullptr) {
possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
MS_LOG(DEBUG) << "Set graph undetermined: " << possible_parent_fg->ToString() << " for fg: " << fg->ToString();
return true;
} else {
MS_LOG(EXCEPTION) << "cannot find parent for fg: " << fg->ToString();
}
}
return false;
}
@ -810,8 +822,11 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
AsyncAbstractResultPtr asyncResult1 = std::make_shared<AsyncAbstractResult>();
AsyncAbstractResultPtr asyncFirstRunResult = std::make_shared<AsyncAbstractResult>();
bool firstRun = !SetUndeterminedFlag(evaluators[0]);
(void)SetUndeterminedFlag(evaluators[1]);
MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
auto possible_parent_fg = out_conf->node()->func_graph();
bool firstRun = !SetUndeterminedFlag(evaluators[0], possible_parent_fg);
(void)SetUndeterminedFlag(evaluators[1], possible_parent_fg);
std::string threadId = AnalysisResultCacheMgr::GetThreadid();
MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString();
@ -891,8 +906,12 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
MS_EXCEPTION_IF_NULL(conf);
return conf->ObtainEvalResult()->abstract();
});
for (const auto &eval : evaluators) {
(void)SetUndeterminedFlag(eval);
MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
auto possible_parent_fg = out_conf->node()->func_graph();
for (auto eval : evaluators) {
(void)SetUndeterminedFlag(eval, possible_parent_fg);
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.

View File

@ -248,6 +248,10 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf);
const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
FuncGraphPtr root_func_graph() const { return root_func_graph_; }
AnalysisContextPtr root_context() const { return root_context_; }
void set_root_context(const AnalysisContextPtr &context) { root_context_ = context; }
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
void ResetFunctionCallDepth() {
@ -292,7 +296,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node);
private:
bool SetUndeterminedFlag(const EvaluatorPtr &evaluator);
bool SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg);
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
bool *continue_flag);
@ -308,6 +312,9 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
std::list<EvaluatorArgs> eval_trace_;
std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> continued_evals_;
// root or top func_graph for static analysis;
FuncGraphPtr root_func_graph_{nullptr};
AnalysisContextPtr root_context_{nullptr};
AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
const ConfigPtrList &args_conf_list);

View File

@ -98,6 +98,21 @@ void FuncGraph::add_parameter(const ParameterPtr &p) {
}
}
ParameterPtr FuncGraph::InsertFrontParameter() {
FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
InsertFrontParameter(p);
return p;
}
void FuncGraph::InsertFrontParameter(const ParameterPtr &p) {
if (manager_.lock()) {
manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), p);
} else {
PrependParameter(p);
}
}
ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
ParameterPtr p = std::make_shared<Parameter>(this_graph);

View File

@ -83,6 +83,7 @@ const char FUNC_GRAPH_FLAG_CORE[] = "core";
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
const char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";
const char FUNC_GRAPH_FLAG_FORCE_INLINE[] = "force_inline";
const char kFuncGraphFlagUndetermined[] = "Undeterminate";
const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";
@ -169,9 +170,14 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
void set_output(const AnfNodePtr &value, bool force_new_ret = false);
const std::vector<AnfNodePtr> &parameters() const { return parameters_; }
// Append
virtual ParameterPtr add_parameter();
void add_parameter(const ParameterPtr &p);
void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); }
// Prepend
virtual ParameterPtr InsertFrontParameter();
void InsertFrontParameter(const ParameterPtr &p);
void PrependParameter(const ParameterPtr &p) { parameters_.insert(parameters_.begin(), p); }
void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
// Add a weight parameter with specific name.
ParameterPtr AddWeightParameter(const std::string &name);
@ -354,7 +360,6 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
void add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); }
std::unordered_map<std::string, ValuePtr> attrs_;
std::vector<BaseShapePtr> joined_shapes_;
std::unordered_map<std::string, FuncGraphTransform> transforms_;
// Parameter default value.
std::map<std::string, AnfNodePtr> parameter_default_value_;

View File

@ -224,7 +224,6 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
TraceGuard trace_guard(func_graph->debug_info(), target_relation_);
*target_func_graph = std::make_shared<FuncGraph>();
(*target_func_graph)->set_attrs(func_graph->attrs());
(*target_func_graph)->joined_shapes_ = func_graph->joined_shapes_;
(*target_func_graph)->set_transforms(func_graph->transforms());
(*target_func_graph)->set_has_vararg(func_graph->has_vararg());
(*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());
@ -254,9 +253,24 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
return;
}
CloneInfo item = todo_.back();
auto lift_top_func_graph = item.origin;
for (auto &fv_map : iter->second) {
auto &free_var = fv_map.first;
if (utils::isa<AnfNodePtr>(free_var)) {
auto free_var_node = utils::cast<AnfNodePtr>(free_var);
// Don't lift weight parameter to top func_graph.
if (func_graph == lift_top_func_graph) {
if (free_var_node->isa<Parameter>()) {
auto free_var_param = free_var_node->cast<ParameterPtr>();
if (free_var_param->has_default()) {
MS_LOG(DEBUG) << "Bypass weight param: " << free_var_param->ToString()
<< " for top_func_graph: " << lift_top_func_graph->ToString();
continue;
}
}
}
MS_LOG(DEBUG) << "Gen param: " << free_var_node->ToString() << " for func_graph: " << func_graph->ToString();
repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
}
}
@ -301,6 +315,8 @@ void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList
}
}
AnfNodePtr new_param = nullptr;
CloneInfo item = todo_.back();
auto lift_top_func_graph = item.origin;
for (auto &param : params) {
auto old_param = repl_node_[param];
if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {
@@ -314,6 +330,14 @@ void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList
input_params->push_back(new_param);
continue;
}
if (lift_top_func_graph == func_graph) {
// Don't lift parameters of used graphs into the top func_graph's parameter list.
repl_node_[old_param] = old_param;
input_params->push_back(old_param);
MS_LOG(DEBUG) << "Bypass param: " << old_param->ToString()
<< " for top_func_graph: " << lift_top_func_graph->ToString();
continue;
}
new_param = AddParameter(func_graph, old_param, false);
parameters.push_back(new_param);
lift_params->push_back(new_param);

View File

@@ -445,6 +445,12 @@ void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &pa
tr.Commit();
}
void FuncGraphManager::InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter) {
auto tr = Transact();
tr.InsertFrontParameter(fg, parameter);
tr.Commit();
}
bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
auto func_graph = old_node->func_graph();
auto tr = Transact();
@@ -599,6 +605,13 @@ void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupl
auto param_node = param.param->cast<ParameterPtr>();
param.func_graph->append_parameter(param_node);
} break;
case Change::kTxInsertFrontParam: {
auto param = args.cast<ArgsOfInsertFrontParam>();
MS_EXCEPTION_IF_NULL(param.func_graph);
(*adds)[param.param] += 1;
auto param_node = param.param->cast<ParameterPtr>();
param.func_graph->PrependParameter(param_node);
} break;
default:
break;
}
@@ -665,6 +678,10 @@ void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr &param
changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param});
}
void FuncGraphTransaction::InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr &param) {
changes_.emplace_back(Change::kTxInsertFrontParam, ArgsOfInsertFrontParam{fg, param});
}
bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(new_node);

View File

@@ -312,6 +312,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
void RemoveRoots();
void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters);
void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
void InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
@@ -406,6 +407,7 @@ class FuncGraphTransaction {
// set parameters of a func graph
void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params);
void AddParameter(FuncGraphPtr fg, const AnfNodePtr &param);
void InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr &param);
// replace old_node with new_node
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
@@ -447,6 +449,18 @@ struct ArgsOfAddParam {
}
};
// args for InsertFront param
struct ArgsOfInsertFrontParam {
FuncGraphPtr func_graph;
AnfNodePtr param;
bool operator==(const ArgsOfInsertFrontParam &other) const { return &other == this; }
friend std::ostream &operator<<(std::ostream &os, const ArgsOfInsertFrontParam &) {
os << "[ArgsOfInsertFrontParam]";
return os;
}
};
// args for set edge
struct ArgsOfSetEdge {
CNodePtr root_node;
@@ -473,7 +487,7 @@ struct ArgsOfAddEdge {
};
struct Change {
-enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam, kTxAddEdge };
+enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam, kTxInsertFrontParam, kTxAddEdge };
OpName op;
Any args;
Change(OpName name, const Any &para) : op(name), args(para) {}
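Each transaction op is stored as one Change record whose Any payload is cast back when the change list is parsed. A minimal sketch of the round trip for the new op, using only the names from the hunks above (the cast call mirrors the one in ParseChanges):

  Change change(Change::kTxInsertFrontParam, ArgsOfInsertFrontParam{fg, param});
  if (change.op == Change::kTxInsertFrontParam) {
    auto args = change.args.cast<ArgsOfInsertFrontParam>();
    args.func_graph->PrependParameter(args.param->cast<ParameterPtr>());
  }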

View File

@@ -40,6 +40,16 @@ abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vect
auto element0_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(shape_0);
for (size_t i = 0; i < elements.size(); ++i) {
auto shape = elements[i]->BuildShape();
if (shape->isa<abstract::Shape>() && shape_0->isa<abstract::Shape>()) {
const auto &shape_vec = shape->cast<abstract::ShapePtr>()->shape();
const auto &shape_0_vec = shape_0->cast<abstract::ShapePtr>()->shape();
if ((shape_vec == ShapeVector({1}) && shape_0_vec == ShapeVector()) ||
(shape_vec == ShapeVector() && shape_0_vec == ShapeVector({1}))) {
MS_LOG(DEBUG) << primitive->name() << ": shape of input[" << i << "]: " << shape->ToString()
<< " is consistent with the shape of input[0]: " << shape_0->ToString();
continue;
}
}
if (*shape != *shape_0) {
MS_EXCEPTION(ValueError) << primitive->name() << ": shape of input[" << i << "]: " << shape->ToString()
<< " is not consistent with the shape of input[0]: " << shape_0->ToString();

View File

@@ -676,7 +676,7 @@ class SideEffectControlFlowAssignDependWhileNet(Cell):
return grad_out
-@pytest.mark.level0
+@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu

View File

@@ -37,11 +37,13 @@ class TestAD : public UT::Common {
public:
UT::PyFuncGraphFetcher getPyFun;
-pipeline::ResourceBasePtr resourcePtr = std::make_shared<pipeline::ResourceBase>();
+pipeline::ResourcePtr resourcePtr = std::make_shared<pipeline::Resource>();
protected:
void AssertExpect(const std::string& testCase) {
FuncGraphPtr g = getPyFun(testCase);
resourcePtr->manager()->RemoveRoots();
resourcePtr->manager()->AddFuncGraph(g, true);
FuncGraphPtr dg = Grad(g, resourcePtr);
AssertExpect(testCase, dg);
}

View File

@@ -20,9 +20,10 @@ from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common.parameter import Parameter, ParameterTuple
grad_all = C.GradOperation(get_all=True)
grad_by_list = C.GradOperation(get_by_list=True)
class CropAndResizeNet(nn.Cell):
def __init__(self, crop_size):
@@ -138,3 +139,38 @@ def test_ad_fv_cnode_order():
net.add_flags_recursive(defer_inline=True)
grad_net = grad_all(net)
grad_net(input_x, input_y)
# The true and false branches of the switch have different numbers of parameters.
def test_if_branch_with_different_params():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.weight1 = Parameter(Tensor(np.array([1.0], dtype=np.float32)), name="weight1")
self.weight2 = Parameter(Tensor(np.array([2.0], dtype=np.float32)), name="weight2")
def construct(self, idx, end, x):
out = x
if idx < end:
out = out + self.weight1 * self.weight2
else:
out = out + self.weight1
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, idx, end, x):
return grad_by_list(self.net, self.weights)(idx, end, x)
idx = Tensor(np.array((0), dtype=np.int32))
end = Tensor(np.array((3), dtype=np.int32))
x = Tensor(np.array([2.0], dtype=np.float32))
net = Net()
grad_net = GradNet(net)
grad_net(idx, end, x)