support if by if grad parameter

add join for ref

adjust env eliminate to eliminate all env ops

add partial app cache

resolve while endless

fix env eliminate

support for "for while" cases

fix join shape error
This commit is contained in:
huangdongrun 2020-07-09 08:29:07 +08:00
parent eb4749642d
commit 2a6d346d2f
27 changed files with 1259 additions and 112 deletions

View File

@ -330,7 +330,8 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_gra
} }
oss << "SymInst(%para" << idx << ")"; oss << "SymInst(%para" << idx << ")";
} else { } else {
MS_LOG(EXCEPTION) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString(); MS_LOG(WARNING) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString();
oss << "SymInst(cnode_" << sym_node->ToString() << ")";
} }
return oss.str(); return oss.str();

View File

@ -189,6 +189,11 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
if (!morph->isa<CNode>()) { if (!morph->isa<CNode>()) {
return nullptr; return nullptr;
} }
// for free variable, which may be handled in MapValueObject, just return it
auto node_adjoint_found = anfnode_to_adjoin_.find(morph);
if (node_adjoint_found != anfnode_to_adjoin_.end()) {
return node_adjoint_found->second;
}
ScopeGuard scope_guard(morph->scope()); ScopeGuard scope_guard(morph->scope());
auto cnode_morph = morph->cast<CNodePtr>(); auto cnode_morph = morph->cast<CNodePtr>();
@ -502,7 +507,7 @@ void DFunctor::MapFvObject() {
if (parent_adjoint != nullptr) { if (parent_adjoint != nullptr) {
adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_); adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
} else { } else {
if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) { if (is_top_ || node->isa<Parameter>()) {
// Out of ad scope, add adjoint for free variables. // Out of ad scope, add adjoint for free variables.
adjoint = std::make_shared<Adjoint>(node, node, tape_); adjoint = std::make_shared<Adjoint>(node, node, tape_);
UpdateAdjoint(adjoint); UpdateAdjoint(adjoint);

View File

@ -87,10 +87,12 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
env_get_item_eliminate_ = env_get_item_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem); MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem); new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_ = incorporate_env_getitem_bypass_recursive_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem); MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(true), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(), incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
"incorporate_env_getitem_switch", prim::kPrimEnvGetItem); "incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
incorporate_env_getitem_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
// Ref eliminate // Ref eliminate
make_ref_eliminate_ = make_ref_eliminate_ =
@ -122,6 +124,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// inline // inline
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph); inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph);
replace_applicator_ = replace_applicator_ =
MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>); MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
specialize_transform_ = specialize_transform_ =

View File

@ -55,6 +55,7 @@ class OptimizeIRPassLib {
SubstitutionPtr env_get_item_eliminate_; SubstitutionPtr env_get_item_eliminate_;
SubstitutionPtr new_env_get_item_; SubstitutionPtr new_env_get_item_;
SubstitutionPtr incorporate_env_getitem_; SubstitutionPtr incorporate_env_getitem_;
SubstitutionPtr incorporate_env_getitem_bypass_recursive_;
SubstitutionPtr incorporate_env_getitem_switch_; SubstitutionPtr incorporate_env_getitem_switch_;
// Ref eliminate // Ref eliminate
@ -80,6 +81,7 @@ class OptimizeIRPassLib {
// inline // inline
SubstitutionPtr inline_; SubstitutionPtr inline_;
SubstitutionPtr inline_without_move_;
SubstitutionPtr replace_applicator_; SubstitutionPtr replace_applicator_;
SubstitutionPtr specialize_transform_; SubstitutionPtr specialize_transform_;
@ -193,6 +195,16 @@ inline bool IsCNodeDup(const AnfNodePtr &node) {
auto inp0 = node->cast<CNodePtr>()->input(0); auto inp0 = node->cast<CNodePtr>()->input(0);
return (inp0 != nullptr) && inp0->isa<CNode>(); return (inp0 != nullptr) && inp0->isa<CNode>();
} }
// check if the cnode is a switch cnode
inline bool IsCNodeSwitch(const AnfNodePtr &node) {
if (node != nullptr) {
if (node->isa<CNode>()) {
return IsPrimitiveCNode(node, prim::kPrimSwitch);
}
}
return false;
}
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -29,6 +29,7 @@
#include "frontend/optimizer/anf_visitor.h" #include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/optimizer.h"
#include "utils/symbolic.h" #include "utils/symbolic.h"
@ -59,8 +60,13 @@ class EnvGetitemTransform {
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
// {prim::kPrimEnvSetItem, env, symbolickey, value} // {prim::kPrimEnvSetItem, env, symbolickey, value}
auto &inputs = env->cast<CNodePtr>()->inputs(); auto &inputs = env->cast<CNodePtr>()->inputs();
if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) { if (inputs.size() != 4) {
MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance."; MS_LOG(WARNING) << "Input size should be 4";
return nullptr;
}
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
return nullptr;
} }
env = inputs[1]; env = inputs[1];
@ -91,33 +97,12 @@ class EnvGetitemTransform {
class NewEnvGetItem : public AnfVisitor { class NewEnvGetItem : public AnfVisitor {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset(); PatternNode c1, c2, y;
auto gety = [this](const AnfNodePtr &node) -> bool { MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvGetItem, c1, c2, y), y,
this->y_ = node; (IsValueNode<EnvInstance>(c1.GetNode(node)) && IsVNode(c2.GetNode(node)) &&
return true; (GetValueNode<EnvInstancePtr>(c1.GetNode(node)))->Len() == 0));
};
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode<EnvInstance>, IsVNode, gety})(node);
if (env_ != nullptr && env_->Len() == 0) {
return y_;
}
return nullptr; return nullptr;
} }
void Visit(const ValueNodePtr &vnode) override {
if (env_ == nullptr) {
env_ = GetValueNode<EnvInstancePtr>(vnode);
}
}
void Reset() {
y_ = nullptr;
env_ = nullptr;
}
private:
AnfNodePtr y_{nullptr};
EnvInstancePtr env_{nullptr};
}; };
// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} -> // {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} ->
@ -205,8 +190,13 @@ class EnvGetSetItem : public AnfVisitor {
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
// {prim::kPrimEnvSetItem, env, symbolickey, value} // {prim::kPrimEnvSetItem, env, symbolickey, value}
auto &inputs = env->cast<CNodePtr>()->inputs(); auto &inputs = env->cast<CNodePtr>()->inputs();
if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) { if (inputs.size() != 4) {
MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance."; MS_LOG(WARNING) << "Input size should be 4";
return nullptr;
}
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
return nullptr;
} }
env = inputs[1]; env = inputs[1];
@ -257,7 +247,8 @@ class EnvGetItemEliminater : public OptimizerCaller {
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y} // {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
class IncorporateEnvGetitem : public AnfVisitor { class IncorporateEnvGetitem : public AnfVisitor {
public: public:
IncorporateEnvGetitem() : env_get_item_transform_() {} explicit IncorporateEnvGetitem(bool bypass_recursive = false)
: env_get_item_transform_(), bypass_recursive_(bypass_recursive) {}
~IncorporateEnvGetitem() override = default; ~IncorporateEnvGetitem() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -285,7 +276,13 @@ class IncorporateEnvGetitem : public AnfVisitor {
auto inputs = inp1->inputs(); auto inputs = inp1->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]); auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
auto new_fg = env_get_item_transform_(fg, key, default_v); auto new_fg = env_get_item_transform_(fg, key, default_v);
if (fg->recursive() && bypass_recursive_) {
MS_LOG(DEBUG) << "Bypass env_get_item transform for recursive fg=" << fg->ToString();
return nullptr;
}
if (new_fg == nullptr) {
return nullptr;
}
std::vector<AnfNodePtr> args; std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(new_fg)); args.push_back(NewValueNode(new_fg));
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
@ -298,6 +295,7 @@ class IncorporateEnvGetitem : public AnfVisitor {
private: private:
bool is_match_{false}; bool is_match_{false};
internal::EnvGetitemTransform env_get_item_transform_; internal::EnvGetitemTransform env_get_item_transform_;
bool bypass_recursive_;
}; };
// {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y} // {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y}
@ -342,7 +340,9 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
auto g2 = GetValueNode<FuncGraphPtr>(sw->input(3)); auto g2 = GetValueNode<FuncGraphPtr>(sw->input(3));
auto new_g1 = env_get_item_transform_(g1, key, default_v); auto new_g1 = env_get_item_transform_(g1, key, default_v);
auto new_g2 = env_get_item_transform_(g2, key, default_v); auto new_g2 = env_get_item_transform_(g2, key, default_v);
if (new_g1 == nullptr || new_g2 == nullptr) {
return nullptr;
}
auto fg = node->func_graph(); auto fg = node->func_graph();
auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)}); auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)});

View File

@ -93,10 +93,22 @@ bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); }
bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) {
bool unique_use = IsUniqueUse(fg, nullptr);
bool is_recursive = fg->recursive();
if (fg->parent() != nullptr && is_recursive) {
if (fg->parent() == node->func_graph() && unique_use) {
return true;
}
}
return false;
}
// {G, Xs} // {G, Xs}
class InlinerBase : public AnfVisitor { class InlinerBase : public AnfVisitor {
public: public:
explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions) : criterions_(criterions) {} explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions, bool use_move = true)
: use_move_(use_move), criterions_(criterions) {}
~InlinerBase() override = default; ~InlinerBase() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
@ -113,6 +125,7 @@ class InlinerBase : public AnfVisitor {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
return nullptr; return nullptr;
} }
// Do not inline GraphKernel to Cell. // Do not inline GraphKernel to Cell.
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
// If the GraphKernel only contains a return node, we make it inlined. // If the GraphKernel only contains a return node, we make it inlined.
@ -142,8 +155,12 @@ class InlinerBase : public AnfVisitor {
std::vector<AnfNodePtr> params; std::vector<AnfNodePtr> params;
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params)); (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
// compare size to avoid the case that the function has default value after grad.
if (IsUniqueUse(fg, nullptr)) { // for which after renormalize, the function default value will be an input
if (fg->parameters().size() != params.size()) {
return nullptr;
}
if (use_move_ && IsUniqueUse(fg, nullptr)) {
auto mng = fg->manager(); auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, params, fg); ReplaceParams(mng, params, fg);
@ -183,21 +200,36 @@ class InlinerBase : public AnfVisitor {
private: private:
bool is_checked_{false}, is_recursive_{false}; bool is_checked_{false}, is_recursive_{false};
bool use_move_;
std::vector<std::pair<CriterionFuncType, bool>> criterions_; std::vector<std::pair<CriterionFuncType, bool>> criterions_;
}; };
class Inliner : public InlinerBase { class Inliner : public InlinerBase {
public: public:
Inliner() explicit Inliner(bool use_move = true)
: InlinerBase({ : InlinerBase(
{
{IsUniqueUse, true}, {IsUniqueUse, true},
{IsTrivial, false}, {IsTrivial, false},
{IsInside, false}, {IsInside, false},
{IsCore, false}, {IsCore, false},
{IsDirectParentCall, false},
{NoCriterion, true}, {NoCriterion, true},
}) {} },
use_move) {}
~Inliner() override = default; ~Inliner() override = default;
}; };
class DirectInliner : public InlinerBase {
public:
explicit DirectInliner(bool use_move = true)
: InlinerBase(
{
{IsDirectParentCall, false},
},
use_move) {}
~DirectInliner() override = default;
};
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -26,6 +26,30 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
namespace internal {
class GetRefValueTransform {
public:
GetRefValueTransform() {}
~GetRefValueTransform() = default;
AnfNodePtr operator()(const AnfNodePtr &node) {
CNodePtr cnode = node->cast<CNodePtr>();
auto inputs = cnode->inputs();
auto fg = GetValueNode(inputs[0])->cast<FuncGraphPtr>();
if (fg->recursive()) {
MS_LOG(DEBUG) << "Get refvalue by pass recursive:" << fg->ToString();
return node;
}
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("GetRefValue"));
auto output = new_fg->output();
new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimGetRefValue), output}));
inputs[0] = NewValueNode(new_fg);
auto ret_node = cnode->func_graph()->NewCNode(inputs);
return ret_node;
}
};
} // namespace internal
// {prim::kPrimMakeRef, X, Y, Z} -> Y // {prim::kPrimMakeRef, X, Y, Z} -> Y
class MakeRefEliminater : public OptimizerCaller { class MakeRefEliminater : public OptimizerCaller {
public: public:
@ -48,13 +72,23 @@ class GetRefParamEliminater : public OptimizerCaller {
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefValue, {prim::switch, cond, t, f}} -> {prim::switch, cond, t, f}
class GetMakeRefEliminater : public OptimizerCaller { class GetMakeRefEliminater : public OptimizerCaller {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x, y, z; PatternNode<AnfNodePtr> x, y, z;
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsCNodeSwitch, node));
internal::GetRefValueTransform trans;
auto GetRefLambda = [&trans, &x, &node]() -> AnfNodePtr {
auto rep = trans(x.GetNode(node));
if (rep != nullptr) {
return rep;
}
return nullptr;
};
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetRefValue, x), GetRefLambda, x.CheckFunc(IsCNodeGraph, node));
return nullptr; return nullptr;
} }
}; };

View File

@ -314,6 +314,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr func_graph = res->func_graph();
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
std::string backend = MsContext::GetInstance()->backend_policy();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (CompileGraphs::ContainMixedTarget(func_graph)) { if (CompileGraphs::ContainMixedTarget(func_graph)) {
bc_ptr->set_is_multi_graph_sink(false); bc_ptr->set_is_multi_graph_sink(false);
@ -321,13 +322,13 @@ bool TaskEmitAction(const ResourcePtr &res) {
context_ptr->set_loop_sink_flag(false); context_ptr->set_loop_sink_flag(false);
} else if (context_ptr->execution_mode() != kPynativeMode) { } else if (context_ptr->execution_mode() != kPynativeMode) {
std::string device_target = context_ptr->device_target(); std::string device_target = context_ptr->device_target();
if (device_target == kAscendDevice) { if (device_target == kAscendDevice && backend != kMsVm) {
bc_ptr->set_is_multi_graph_sink(true); bc_ptr->set_is_multi_graph_sink(true);
context_ptr->set_is_multi_graph_sink(true); context_ptr->set_is_multi_graph_sink(true);
} }
} }
if (IsCtrlSink()) { if (IsCtrlSink() && backend == kMsConvert) {
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
return true; return true;
} }
@ -344,8 +345,8 @@ bool ExecuteAction(const ResourcePtr &res) {
if (res->results().count(kOutput) == 0) { if (res->results().count(kOutput) == 0) {
MS_LOG(EXCEPTION) << "Execute args error"; MS_LOG(EXCEPTION) << "Execute args error";
} }
std::string backend = MsContext::GetInstance()->backend_policy();
if (IsCtrlSink()) { if (IsCtrlSink() && backend == kMsConvert) {
if (!res->results()[kOutput].is<GraphId>()) { if (!res->results()[kOutput].is<GraphId>()) {
MS_LOG(EXCEPTION) << "Execute args error"; MS_LOG(EXCEPTION) << "Execute args error";
} }

View File

@ -30,6 +30,7 @@
#include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/resource.h" #include "pipeline/jit/resource.h"
#include "pipeline/jit/validator.h" #include "pipeline/jit/validator.h"
#include "pipeline/jit/remove_value_node_dup.h"
#include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/cse.h" #include "frontend/optimizer/cse.h"
#include "frontend/optimizer/graph_kernel_reuse.h" #include "frontend/optimizer/graph_kernel_reuse.h"
@ -127,11 +128,14 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.incorporate_getitem_set_, irpass.incorporate_getitem_set_,
irpass.incorporate_call_, irpass.incorporate_call_,
irpass.incorporate_call_switch_, irpass.incorporate_call_switch_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_bypass_recursive_,
irpass.incorporate_env_getitem_switch_, irpass.incorporate_env_getitem_switch_,
irpass.new_env_get_item_, irpass.new_env_get_item_,
irpass.depend_value_elim_, irpass.depend_value_elim_,
}); });
opt::OptPassConfig a_after_grad = opt::OptPassConfig({
irpass.inline_without_move_,
});
opt::OptPassConfig a_3 = opt::OptPassConfig({ opt::OptPassConfig a_3 = opt::OptPassConfig({
irpass.arithmetic_simplify2_, irpass.arithmetic_simplify2_,
irpass.same_eliminate_, irpass.same_eliminate_,
@ -154,6 +158,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
{"virtual_dataset", virtual_dataset}, {"virtual_dataset", virtual_dataset},
{"grad", grad}, {"grad", grad},
{"resolve", resolve_pass}, {"resolve", resolve_pass},
{"a_after_grad", a_after_grad},
{"renormalize", opt::OptPassConfig::Renormalize()}, {"renormalize", opt::OptPassConfig::Renormalize()},
{"cse", opt::OptPassConfig(opt::CSE(false))}, {"cse", opt::OptPassConfig(opt::CSE(false))},
{"a_3", a_3}}); {"a_3", a_3}});
@ -161,11 +166,24 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
return map_a; return map_a;
} }
OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig c_1 = opt::OptPassConfig({
// Safe inlining
irpass.inline_,
irpass.partial_eliminate_,
});
OptPassGroupMap map_a({{"c_1", c_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
return map_a;
}
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig b_1 = opt::OptPassConfig b_1 = opt::OptPassConfig(
opt::OptPassConfig({irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.get_make_ref_eliminate_, irpass.value_based_eliminate_}); irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.value_based_eliminate_});
opt::OptPassConfig b_2 = opt::OptPassConfig({ opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_, irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_, irpass.make_ref_eliminate_,
@ -244,6 +262,8 @@ void InitOpt(const ResourcePtr &res) {
opt::irpass::OptimizeIRPassLib irpass; opt::irpass::OptimizeIRPassLib irpass;
g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
g_pass_opts["opt_after_cconv"] =
Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
g_pass_opts["opt_graph_kernel_a"] = g_pass_opts["opt_graph_kernel_a"] =
Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true); Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
g_pass_opts["opt_graph_kernel_b"] = g_pass_opts["opt_graph_kernel_b"] =
@ -288,6 +308,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); } bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
@ -311,6 +332,33 @@ bool AddControlDependPass(const ResourcePtr &res) {
return true; return true;
} }
bool MergeDupGraphPass(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(res->manager());
if (res->manager()->func_graphs().size() <= 1) {
return true;
}
return MergeDuplicateGraphs(res->manager());
}
bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) {
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Remove value node duplications error.";
}
auto manager = res->manager();
HashCache hash_cache;
HashValue hashes;
// Remove duplicated value nodes across all graphs in manager
for (auto &fg : manager->func_graphs()) {
auto value_nodes = fg->value_nodes();
for (const auto &value_pair : value_nodes) {
TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
}
}
return true;
}
bool CconvPass(const ResourcePtr &res) { bool CconvPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph()); MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr func_graph = res->func_graph();
@ -340,6 +388,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{"clean_after_opta", CleanAfterOptAPass}, {"clean_after_opta", CleanAfterOptAPass},
{"opt_b", OptPassBGroup}, {"opt_b", OptPassBGroup},
{"cconv", CconvPass}, {"cconv", CconvPass},
{"opt_after_cconv", OptPassAfterCconvGroup},
{"remove_dup_value", RemoveValueNodeDuplicationsPass},
{"opt_graph_kernel_a", OptPassGraphKernelGroupA}, {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB}, {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_control_depend", AddControlDependPass}}; {"add_control_depend", AddControlDependPass}};

View File

@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <string>
#include "pipeline/jit/remove_value_node_dup.h" #include "pipeline/jit/remove_value_node_dup.h"
#include "ir/anf.h" #include "ir/anf.h"
@ -70,5 +71,108 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has
// Meet for the first time, append node to bucket. // Meet for the first time, append node to bucket.
bucket.emplace_back(node); bucket.emplace_back(node);
} }
size_t HashOfGraph(const FuncGraphPtr &fg) {
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
MS_LOG(DEBUG) << "TopSort for:" << fg->ToString();
std::unordered_map<AnfNodePtr, std::size_t> hashes;
auto &params = fg->parameters();
for (size_t i = 0; i < params.size(); i++) {
hashes[params[i]] = std::hash<std::string>{}("param" + std::to_string(i));
}
for (auto node : toposet) {
MS_EXCEPTION_IF_NULL(node);
if (hashes.find(node) != hashes.end()) {
continue;
}
std::size_t h = 0;
if (node->isa<ValueNode>()) {
ValueNodePtr value_node = node->cast<ValueNodePtr>();
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (IsValueNode<FuncGraph>(value_node)) {
auto v_fg = value->cast<FuncGraphPtr>();
h = value->hash();
} else if (IsValueNode<tensor::Tensor>(value_node)) {
// the tensor has same value has been replaced in duplicate value pass,
// so we use the value pointer here as an identifier
h = hash_combine(value->hash(), std::hash<Value *>{}(value.get()));
} else {
h = hash_combine(value->hash(), (opt::AbsOf(value_node)->hash()));
}
} else if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
size_t init = 0;
h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
return hash_combine(hash, hashes[node_in]);
});
} else if (node->isa<Parameter>()) {
h = node->hash();
} else {
MS_LOG(ERROR) << "Unknow node type";
}
hashes[node] = h;
}
return hashes[fg->get_return()];
}
bool IsCNodeGraph(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) {
return false;
}
auto inp0 = node->cast<CNodePtr>()->input(0);
return IsValueNode<FuncGraph>(inp0);
}
bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager) {
std::unordered_map<size_t, std::vector<FuncGraphPtr>> hash_graphs;
std::unordered_map<FuncGraphPtr, size_t> graph_hash;
for (auto fg : manager->func_graphs()) {
size_t h = HashOfGraph(fg);
graph_hash[fg] = h;
if (hash_graphs.find(h) == hash_graphs.end()) {
hash_graphs[h] = {fg};
} else {
hash_graphs[h].push_back(fg);
}
}
FuncGraphPairMapEquiv equiv_graph;
NodeMapEquiv equiv_node;
for (auto &fg : manager->func_graphs()) {
MS_LOG(DEBUG) << "Try Merge Graph:" << fg->ToString();
for (auto &item : fg->nodes()) {
if (!item->isa<CNode>()) {
continue;
}
auto &inputs = item->cast<CNodePtr>()->inputs();
for (size_t i = 0; i < inputs.size(); i++) {
if (!inputs[i]->isa<ValueNode>()) {
continue;
}
auto value_ptr = GetValueNode(inputs[i]);
auto v_fg = value_ptr->cast<FuncGraphPtr>();
if (v_fg == nullptr) {
continue;
}
auto &fg_vec = hash_graphs[graph_hash[v_fg]];
if (fg_vec.size() > 1) {
if (v_fg != fg_vec[0]) {
bool is_morphic = Isomorphic(v_fg, fg_vec[0], &equiv_graph, &equiv_node);
if (is_morphic) {
auto new_node = NewValueNode(fg_vec[0]);
MS_LOG(DEBUG) << "Replace graph node :" << inputs[i]->ToString() << " with:" << new_node->ToString();
manager->Replace(inputs[i], new_node);
}
}
}
}
}
}
return true;
}
} // namespace pipeline } // namespace pipeline
} // namespace mindspore } // namespace mindspore

View File

@ -28,6 +28,10 @@ using HashCache = std::unordered_map<std::size_t, std::vector<AnfNodePtr>>;
using HashValue = std::unordered_map<AnfNodePtr, std::size_t>; using HashValue = std::unordered_map<AnfNodePtr, std::size_t>;
void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value); void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value);
size_t HashOfGraph(const FuncGraphPtr &fg);
bool IsCNodeGraph(const AnfNodePtr &node);
bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager);
} // namespace pipeline } // namespace pipeline
} // namespace mindspore } // namespace mindspore

View File

@ -113,17 +113,18 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
} }
const AnfNodePtr &func_node = fg->get_return(); const AnfNodePtr &func_node = fg->get_return();
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString() MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString()
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString(); << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString();
AbstractBasePtr ret_base = nullptr; AbstractBasePtr ret_base = nullptr;
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node); std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
const auto &node = *it; const auto &node = *it;
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString()
<< ", node_conf: " << node_conf->ToString();
ret_base = engine->GetEvaluatedValue(node_conf)->abstract(); ret_base = engine->GetEvaluatedValue(node_conf)->abstract();
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString()
<< ", abstract: " << ret_base->ToString(); << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
} }
MS_EXCEPTION_IF_NULL(ret_base); MS_EXCEPTION_IF_NULL(ret_base);
@ -142,16 +143,17 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list), (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list),
[](const AbstractBasePtr &arg) -> AbstractBasePtr { [](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
if (arg->GetValueTrack() != kAnyValue) {
return arg->Broaden(); return arg->Broaden();
});
if (func_graph_->joined_shapes_.size() != broaded_list.size()) {
MS_EXCEPTION(ValueError) << "Number of input arguments " << broaded_list.size()
<< " does not equal to number of original buffer arguments "
<< func_graph_->joined_shapes_.size();
} }
return arg;
});
if (func_graph_->joined_shapes_.size() == broaded_list.size()) {
for (size_t i = 0; i < broaded_list.size(); ++i) { for (size_t i = 0; i < broaded_list.size(); ++i) {
broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]); broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
} }
}
MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list) MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
<< ", broaded: " << mindspore::ToString(broaded_list); << ", broaded: " << mindspore::ToString(broaded_list);
return broaded_list; return broaded_list;
@ -181,8 +183,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
func_graph_->joined_shapes_.clear(); func_graph_->joined_shapes_.clear();
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
std::back_inserter(func_graph_->joined_shapes_), std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); if (arg_spec->isa<AbstractRef>()) {
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
}
return arg_spec->GetShapeTrack();
});
joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
} }
return joined_args_spec_list; return joined_args_spec_list;
@ -199,8 +206,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
func_graph_->joined_shapes_.clear(); func_graph_->joined_shapes_.clear();
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
std::back_inserter(func_graph_->joined_shapes_), std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); if (arg_spec->isa<AbstractRef>()) {
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
}
return arg_spec->GetShapeTrack();
});
joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
} }
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);

View File

@ -188,6 +188,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
trace::TraceEvalCNodeLeave(); trace::TraceEvalCNodeLeave();
} else { } else {
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()
<< (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph")
<< ". NodeInfo: " << trace::GetDebugInfo(node->debug_info()); << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info());
} }
@ -301,6 +302,8 @@ void AnalysisEngine::Clear() {
anfnode_config_map_.clear(); anfnode_config_map_.clear();
eval_trace_.clear(); eval_trace_.clear();
constructors_.clear(); constructors_.clear();
constructors_app_.clear();
continued_evals_.clear();
} }
namespace { namespace {
@ -426,8 +429,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstr
MS_EXCEPTION_IF_NULL(func); MS_EXCEPTION_IF_NULL(func);
AbstractFunctionPtr func_orig = func->fn(); AbstractFunctionPtr func_orig = func->fn();
EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig); EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
auto part_pair = std::make_pair(func_orig, func->args());
auto itr = constructors_app_.find(part_pair);
if (itr != constructors_app_.end()) {
return itr->second;
}
std::shared_ptr<PartialAppEvaluator> partial_evaluator = std::shared_ptr<PartialAppEvaluator> partial_evaluator =
std::make_shared<PartialAppEvaluator>(evaluator_orig, func->args()); std::make_shared<PartialAppEvaluator>(evaluator_orig, func->args());
constructors_app_[part_pair] = partial_evaluator;
return partial_evaluator; return partial_evaluator;
} }
@ -504,9 +513,10 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
if (fg_eval == nullptr) { if (fg_eval == nullptr) {
return; return;
} }
auto fg = fg_eval->func_graph(); auto fg = fg_eval->func_graph();
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
auto undetermined_fgs = fg->recursive_graphs(); auto undetermined_fgs = fg->recursive();
if (undetermined_fgs) { if (undetermined_fgs) {
auto fg_parent = fg->parent(); auto fg_parent = fg->parent();
MS_EXCEPTION_IF_NULL(fg_parent); MS_EXCEPTION_IF_NULL(fg_parent);
@ -546,15 +556,19 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPt
MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
for (auto u_eval : undetermined_evals) { for (auto u_eval : undetermined_evals) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; MS_LOG(DEBUG) << u_eval.first->ToString() << "check undetermined.";
if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { auto &alternate_evaluator = multi_poss_[u_eval.first];
MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; auto &eval_cache = alternate_evaluator->cache();
if ((!undetermined_evals.count(std::make_pair(alternate_evaluator, args_spec_list))) &&
(((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) ||
(eval_cache->find(args_spec_list) == eval_cache->end()))) {
MS_LOG(DEBUG) << u_eval.first->ToString() << "has undetermined.";
has_undetermined = true; has_undetermined = true;
break; break;
} }
} }
if (has_undetermined == false) { if (has_undetermined == false) {
MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; MS_LOG(DEBUG) << eval->ToString() << "has no undetermined.";
*continue_flag = true; *continue_flag = true;
return latest_entry; return latest_entry;
} }
@ -597,34 +611,33 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
auto current_inf = std::make_pair(eval, args_spec_list); auto current_inf = std::make_pair(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating. // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
if (it == eval_trace_.rend()) { if (it == eval_trace_.rend()) {
eval_trace_.push_back(current_inf); eval_trace_.push_back(current_inf);
MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
MS_EXCEPTION_IF_NULL(eval); MS_EXCEPTION_IF_NULL(eval);
auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf); auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(eval_result->abstract()); MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString();
out_specs.push_back(eval_result->abstract()); out_specs.push_back(eval_result->abstract());
eval_trace_.pop_back(); eval_trace_.pop_back();
if (eval_trace_.empty()) { if (eval_trace_.empty()) {
multi_poss_.clear(); multi_poss_.clear();
} }
} else if (it != eval_trace_.rbegin()) { } else {
bool continue_flag = false; bool continue_flag = false;
auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag); auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag);
if (continue_flag) { if (continue_flag) {
MS_LOG(DEBUG) << "continued_evals_ add " << current_inf.first.get() << current_inf.first->ToString();
continued_evals_.insert(current_inf);
continue; continue;
} }
// Try to travel the latest undetermined. // Try to travel the latest undetermined.
if (latest_entry != eval_trace_.rbegin()->first) { if (latest_entry != eval_trace_.rbegin()->first) {
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString();
auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(eval_result->abstract()); MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() MS_LOG(DEBUG) << "end Direct Evaluator " << latest_entry->ToString()
<< " return out_spec: " << eval_result->abstract()->ToString(); << " return out_spec: " << eval_result->abstract()->ToString();
return eval_result; return eval_result;
} }

View File

@ -26,6 +26,7 @@
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <map> #include <map>
#include <set>
#ifdef DEBUG #ifdef DEBUG
#include <stack> #include <stack>
@ -113,7 +114,8 @@ class AnfNodeConfig : public Config {
std::string ToString() const override { std::string ToString() const override {
std::ostringstream buffer; std::ostringstream buffer;
buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString(); buffer << "Node: " << node_->DebugString() << "-uid(" << node_->UniqueId()
<< "), Context: " << context_->ToString();
return buffer.str(); return buffer.str();
} }
@ -173,7 +175,13 @@ struct AnalysisResult {
}; };
using EvalTraceRevIter = std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>>::reverse_iterator; using EvalTraceRevIter = std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>>::reverse_iterator;
struct PartialAppHasher {
std::size_t operator()(const std::pair<AbstractFunctionPtr, AbstractBasePtrList> &p) const {
auto h1 = std::hash<AbstractFunctionPtr>{}(p.first);
auto h2 = AbstractBasePtrListHash(p.second);
return h1 ^ h2;
}
};
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
public: public:
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
@ -233,10 +241,13 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
const PrimEvaluatorMap &prim_constructors_; const PrimEvaluatorMap &prim_constructors_;
FuncGraphManagerPtr func_graph_manager_; FuncGraphManagerPtr func_graph_manager_;
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_; std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_;
std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher>
constructors_app_;
AnfNodeConfigMap anfnode_config_map_; AnfNodeConfigMap anfnode_config_map_;
// Use a list to trace multiple evaluators. // Use a list to trace multiple evaluators.
std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_; std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_; std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> continued_evals_;
AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
const ConfigPtrList &args_conf_list); const ConfigPtrList &args_conf_list);

View File

@ -34,6 +34,7 @@ using mindspore::abstract::AbstractError;
using mindspore::abstract::AbstractFunction; using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractJTagged; using mindspore::abstract::AbstractJTagged;
using mindspore::abstract::AbstractList; using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractRef;
using mindspore::abstract::AbstractRowTensor; using mindspore::abstract::AbstractRowTensor;
using mindspore::abstract::AbstractScalar; using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSparseTensor; using mindspore::abstract::AbstractSparseTensor;
@ -83,7 +84,8 @@ void ValidateAbstract(const AnfNodePtr &node) {
// only send string in external // only send string in external
if (!IsValueNode<StringImm>(node)) { if (!IsValueNode<StringImm>(node)) {
// Validate a type. // Validate a type.
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString()
<< " for node=" << node->DebugString();
} }
} }
return; return;
@ -96,7 +98,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() || if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractRowTensor>() || ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractRowTensor>() ||
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) { ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>() || ptrBase->isa<AbstractRef>()) {
return; return;
} }

View File

@ -481,7 +481,9 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
} }
// Isomorphism // Isomorphism
static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, static bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
NodeMapEquiv *const equiv_node);
bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
NodeMapEquiv *const equiv_node) { NodeMapEquiv *const equiv_node) {
if (equiv_node == nullptr) { if (equiv_node == nullptr) {
MS_LOG(ERROR) << "Invalid equiv_node"; MS_LOG(ERROR) << "Invalid equiv_node";
@ -514,6 +516,9 @@ static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, Fu
MS_LOG(DEBUG) << "two parameters are not equal."; MS_LOG(DEBUG) << "two parameters are not equal.";
return false; return false;
} }
if (node1->isa<CNode>() && node2->isa<CNode>()) {
return SameNode(node1, node2, equiv_func_graph, equiv_node);
}
MS_LOG(ERROR) << "type error"; MS_LOG(ERROR) << "type error";
return false; return false;
} }

View File

@ -116,12 +116,15 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo
} // namespace } // namespace
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) { std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
auto fg = std::make_shared<FuncGraph>();
AnfNodePtrList inputs;
AnfNodePtrToAnfNodePtrMap eqv;
if (lst.empty()) { if (lst.empty()) {
MS_LOG(EXCEPTION) << "Input anf node list is empty"; MS_LOG(EXCEPTION) << "Input anf node list is empty";
} }
TraceManager::DebugTrace(
std::make_shared<TraceSegmentTransform>(lst[0]->cast<CNodePtr>()->func_graph()->debug_info()));
auto fg = std::make_shared<FuncGraph>();
TraceManager::EndTrace();
AnfNodePtrList inputs;
AnfNodePtrToAnfNodePtrMap eqv;
// Merge CNodes into a AnfGraph that represents a linear instruction segment // Merge CNodes into a AnfGraph that represents a linear instruction segment
for (auto n : lst) { for (auto n : lst) {
if (!n->isa<CNode>()) { if (!n->isa<CNode>()) {
@ -154,7 +157,9 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
} }
TraceManager::DebugTrace(std::make_shared<TraceGetEnv>(n->debug_info()));
eqv[n] = fg->NewCNode(args); eqv[n] = fg->NewCNode(args);
TraceManager::EndTrace();
eqv[n]->set_abstract(n->abstract()); eqv[n]->set_abstract(n->abstract());
eqv[n]->set_kernel_info(n->kernel_info_ptr()); eqv[n]->set_kernel_info(n->kernel_info_ptr());
} }

View File

@ -452,6 +452,10 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
} }
auto other_tensor = dyn_cast<AbstractTensor>(other); auto other_tensor = dyn_cast<AbstractTensor>(other);
if (other_tensor == nullptr) { if (other_tensor == nullptr) {
auto ref_tensor = dyn_cast<AbstractRef>(other);
if (ref_tensor != nullptr) {
return this->Join(ref_tensor->ref());
}
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
} }
if (*this == *other) { if (*this == *other) {

View File

@ -48,7 +48,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
continue; continue;
} }
if (rank.find(node) != rank.end() && rank[node] != todo.size()) { if (rank.find(node) != rank.end() && rank[node] != todo.size()) {
MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(); MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2);
} }
rank[node] = todo.size(); rank[node] = todo.size();
bool cont = false; bool cont = false;

View File

@ -30,6 +30,7 @@
#include "base/base.h" #include "base/base.h"
#include "ir/dtype.h" #include "ir/dtype.h"
#include "ir/dtype/number.h" #include "ir/dtype/number.h"
#include "utils/hashing.h"
using std::fabs; using std::fabs;
@ -51,7 +52,7 @@ using ScalarPtr = std::shared_ptr<Scalar>;
class BoolImm : public Scalar { class BoolImm : public Scalar {
public: public:
explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = std::hash<bool>{}(v_); } explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = hash_combine({tid(), std::hash<bool>{}(v_)}); }
~BoolImm() override = default; ~BoolImm() override = default;
MS_DECLARE_PARENT(BoolImm, Scalar) MS_DECLARE_PARENT(BoolImm, Scalar)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }
@ -91,7 +92,7 @@ class IntergerImm : public Scalar {
class Int8Imm : public IntergerImm { class Int8Imm : public IntergerImm {
public: public:
Int8Imm() : IntergerImm(kInt8), v_(0) {} Int8Imm() : IntergerImm(kInt8), v_(0) {}
explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = std::hash<int>{}(v_); } explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
~Int8Imm() override = default; ~Int8Imm() override = default;
MS_DECLARE_PARENT(Int8Imm, IntergerImm) MS_DECLARE_PARENT(Int8Imm, IntergerImm)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }
@ -117,7 +118,7 @@ IMM_TRAITS(Int8ImmPtr, int8_t)
class Int16Imm : public IntergerImm { class Int16Imm : public IntergerImm {
public: public:
Int16Imm() : IntergerImm(kInt16), v_(0) {} Int16Imm() : IntergerImm(kInt16), v_(0) {}
explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = std::hash<int>{}(v_); } explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
~Int16Imm() override = default; ~Int16Imm() override = default;
MS_DECLARE_PARENT(Int16Imm, IntergerImm) MS_DECLARE_PARENT(Int16Imm, IntergerImm)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }
@ -143,7 +144,7 @@ IMM_TRAITS(Int16ImmPtr, int16_t)
class Int32Imm : public IntergerImm { class Int32Imm : public IntergerImm {
public: public:
Int32Imm() : IntergerImm(kInt32), v_(0) {} Int32Imm() : IntergerImm(kInt32), v_(0) {}
explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = std::hash<int>{}(v_); } explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
~Int32Imm() override = default; ~Int32Imm() override = default;
MS_DECLARE_PARENT(Int32Imm, IntergerImm) MS_DECLARE_PARENT(Int32Imm, IntergerImm)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }
@ -169,7 +170,7 @@ IMM_TRAITS(Int32ImmPtr, int32_t)
class Int64Imm : public IntergerImm { class Int64Imm : public IntergerImm {
public: public:
Int64Imm() : IntergerImm(kInt64), v_(0) {} Int64Imm() : IntergerImm(kInt64), v_(0) {}
explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = std::hash<int64_t>{}(v_); } explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = hash_combine({tid(), std::hash<int64_t>{}(v_)}); }
~Int64Imm() override = default; ~Int64Imm() override = default;
MS_DECLARE_PARENT(Int64Imm, IntergerImm) MS_DECLARE_PARENT(Int64Imm, IntergerImm)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }
@ -195,7 +196,9 @@ IMM_TRAITS(Int64ImmPtr, int64_t)
class UInt8Imm : public IntergerImm { class UInt8Imm : public IntergerImm {
public: public:
UInt8Imm() : IntergerImm(kUInt8), v_(0) {} UInt8Imm() : IntergerImm(kUInt8), v_(0) {}
explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) { hash_ = std::hash<unsigned int>{}(v_); } explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) {
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
}
~UInt8Imm() override = default; ~UInt8Imm() override = default;
MS_DECLARE_PARENT(UInt8Imm, IntergerImm) MS_DECLARE_PARENT(UInt8Imm, IntergerImm)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }
@ -221,7 +224,9 @@ IMM_TRAITS(UInt8ImmPtr, uint8_t);
class UInt16Imm : public IntergerImm { class UInt16Imm : public IntergerImm {
public: public:
UInt16Imm() : IntergerImm(kUInt16), v_(0) {} UInt16Imm() : IntergerImm(kUInt16), v_(0) {}
explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) { hash_ = std::hash<unsigned int>{}(v_); } explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) {
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
}
~UInt16Imm() override = default; ~UInt16Imm() override = default;
MS_DECLARE_PARENT(UInt16Imm, IntergerImm) MS_DECLARE_PARENT(UInt16Imm, IntergerImm)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }
@ -247,7 +252,9 @@ IMM_TRAITS(UInt16ImmPtr, uint16_t);
class UInt32Imm : public IntergerImm { class UInt32Imm : public IntergerImm {
public: public:
UInt32Imm() : IntergerImm(kUInt32), v_(0) {} UInt32Imm() : IntergerImm(kUInt32), v_(0) {}
explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) { hash_ = std::hash<unsigned int>{}(v_); } explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) {
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
}
~UInt32Imm() override = default; ~UInt32Imm() override = default;
MS_DECLARE_PARENT(UInt32Imm, IntergerImm) MS_DECLARE_PARENT(UInt32Imm, IntergerImm)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }
@ -273,7 +280,9 @@ IMM_TRAITS(UInt32ImmPtr, uint32_t);
class UInt64Imm : public IntergerImm { class UInt64Imm : public IntergerImm {
public: public:
UInt64Imm() : IntergerImm(kUInt64), v_(0) {} UInt64Imm() : IntergerImm(kUInt64), v_(0) {}
explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) { hash_ = std::hash<uint64_t>{}(v); } explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) {
hash_ = hash_combine({tid(), std::hash<uint64_t>{}(v)});
}
~UInt64Imm() override = default; ~UInt64Imm() override = default;
MS_DECLARE_PARENT(UInt64Imm, IntergerImm) MS_DECLARE_PARENT(UInt64Imm, IntergerImm)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }
@ -308,7 +317,7 @@ using FloatImmPtr = std::shared_ptr<FloatImm>;
class FP32Imm : public FloatImm { class FP32Imm : public FloatImm {
public: public:
FP32Imm() : FloatImm(kFloat32), v_(0.0) {} FP32Imm() : FloatImm(kFloat32), v_(0.0) {}
explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = std::hash<float>{}(v_); } explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = hash_combine({tid(), std::hash<float>{}(v_)}); }
~FP32Imm() override = default; ~FP32Imm() override = default;
MS_DECLARE_PARENT(FP32Imm, FloatImm) MS_DECLARE_PARENT(FP32Imm, FloatImm)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }
@ -334,7 +343,7 @@ IMM_TRAITS(FP32ImmPtr, float)
class FP64Imm : public FloatImm { class FP64Imm : public FloatImm {
public: public:
FP64Imm() : FloatImm(kFloat64), v_(0.0) {} FP64Imm() : FloatImm(kFloat64), v_(0.0) {}
explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = std::hash<double>{}(v_); } explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = hash_combine({tid(), std::hash<double>{}(v_)}); }
~FP64Imm() override = default; ~FP64Imm() override = default;
MS_DECLARE_PARENT(FP64Imm, FloatImm) MS_DECLARE_PARENT(FP64Imm, FloatImm)
std::size_t hash() const override { return hash_; } std::size_t hash() const override { return hash_; }

View File

@ -412,6 +412,16 @@ class TraceCombileLikeGraphs : public TraceInfo {
return std::make_shared<TraceCombileLikeGraphs>(*shared_from_base<TraceCombileLikeGraphs>()); return std::make_shared<TraceCombileLikeGraphs>(*shared_from_base<TraceCombileLikeGraphs>());
} }
}; };
class TraceSegmentTransform : public TraceInfo {
public:
explicit TraceSegmentTransform(const DebugInfoPtr &info) : TraceInfo(info, "segment_transform", "") {}
MS_DECLARE_PARENT(TraceGetEnv, TraceInfo);
~TraceSegmentTransform() override = default;
TraceInfoPtr clone() override {
return std::make_shared<TraceSegmentTransform>(*shared_from_base<TraceSegmentTransform>());
}
};
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_TRACE_INFO_H_ #endif // MINDSPORE_CORE_UTILS_TRACE_INFO_H_

View File

@ -0,0 +1,816 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test control ops """
import numpy as np
from mindspore import dtype as ms
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P
# from tests.vm_impl.math_ops_vm_impl import *
# from tests.vm_impl.vm_interface import *
# from tests.vm_impl import *
# context.set_context(save_graphs=True)
def test_while_forward():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
def construct(self, idx, end, x):
while idx < end:
part = x[idx, :, :]
max_num = self.max(part)
x[idx, :, 0:2] = max_num
idx = idx + 1
return x
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
net = MyWhileNet()
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(2), dtype=ms.int32)
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_grad():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
def construct(self, idx, end, x):
while idx < end:
part = x[idx, :, :]
max_num = self.max(part)
x[idx, :, 0:2] = max_num
idx = idx + 1
return x
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, *inputs):
return C.grad_all(self.net)(*inputs)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(2), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_with_param_forward():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
def construct(self, idx, end, x):
out = self.zero
while idx < end:
part = x[idx, :, :]
max_num = self.max(part)
x[idx, :, 0:2] = max_num
out = out + x + self.param
idx = idx + 1
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
net = MyWhileNet()
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(2), dtype=ms.int32)
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_endless_case():
"""endless case when optmization"""
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
def construct(self, idx, end, x):
out = self.zero
while idx < end:
part = x[idx, :, :]
out = out + part
idx = idx + 1
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
net = MyWhileNet()
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(2), dtype=ms.int32)
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_with_param_grad():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
def construct(self, idx, end, x):
out = self.zero
while idx < end:
part = x[idx, :, :]
max_num = self.max(part)
x[idx, :, 0:2] = max_num
out = out + x + self.param
idx = idx + 1
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(2), dtype=ms.int32)
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_with_param_forward_with_const_branch():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
self.reduce = P.ReduceSum()
def construct(self, idx, end, x):
out = self.zero
while idx < end:
if 2 > 1:
out = out + self.param
else:
out = out + idx + self.param
idx = idx + 1
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = while_net
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(4), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_opt_endless():
"""endless during optimization case"""
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
self.reduce = P.ReduceSum()
self.addn = P.AddN()
def construct(self, idx, end, x):
addn1 = self.addn((x, x, x))
out = addn1
while idx < end:
out = self.addn((out, addn1))
idx = idx + 1
out = self.addn((out, x))
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, *inputs):
return C.grad_all(self.net)(*inputs)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(4), dtype=ms.int32)
x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
net(idx, end, x)
def test_no_while_call():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
self.reduce = P.ReduceSum()
def construct(self, idx, end, x):
out = self.zero
if 2 > 1:
out = out + self.param
else:
out = out + idx + self.param
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = while_net
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(4), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_with_param_grad_with_const_branch():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
self.reduce = P.ReduceSum()
def construct(self, idx, end, x):
out = self.zero
while idx < end:
if 2 > 1:
out = out + self.param
else:
out = out + idx + self.param
idx = idx + 1
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(4), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_for_while_with_param_grad_with_const_branch():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
self.reduce = P.ReduceSum()
self.start = Tensor(np.array(0), dtype=ms.int32)
def construct(self, idx, end, x):
out = self.zero
for _ in range(0, 2):
idx = self.start
while idx < end:
if 2 > 1:
out = out + self.param
else:
out = out + idx + self.param
idx = idx + 1
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(4), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_for_while_with_param_grad_basic():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
self.reduce = P.ReduceSum()
self.start = Tensor(np.array(0), dtype=ms.int32)
def construct(self, idx, end, x):
out = self.zero
for _ in range(0, 2):
idx = self.start
while idx < end:
out = out + self.param
idx = idx + 1
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(4), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_for_while_with_param_grad_normal():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
self.reduce = P.ReduceSum()
self.start = Tensor(np.array(0), dtype=ms.int32)
def construct(self, idx, end, x):
out = x
for _ in range(0, 2):
idx = self.start
while idx < end:
out = out + self.param
idx = idx + 1
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(4), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_with_param_basic_grad():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
self.t2 = Tensor(np.array(2), dtype=ms.float32)
def construct(self, idx, end, x):
out = self.zero
while idx < end:
out = out + self.param
idx = idx + 1
return out + self.param
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(3), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_with_param_basic_grad_mul():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32)
self.t2 = Tensor(np.array(2), dtype=ms.float32)
def construct(self, idx, end, x):
out = self.zero
while idx < end:
out = out * self.param
idx = idx + 1
return out + self.param
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(3), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_with_param_basic_grad_two():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
self.t2 = Tensor(np.array(2), dtype=ms.float32)
def construct(self, idx, end, x):
out = self.zero
while idx < end:
out = out + self.param + self.weight
idx = idx + 1
return out + self.param
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(3), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_with_param_basic_grad_three():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
self.t2 = Tensor(np.array(2), dtype=ms.float32)
def construct(self, idx, end, x):
out = self.zero
while idx < end:
out = out + self.param + self.weight + self.key
idx = idx + 1
return out + self.param
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(3), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_if_with_param_grad():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
self.t2 = Tensor(np.array(2), dtype=ms.float32)
def construct(self, idx, end, x):
out = self.zero
while idx < end:
if self.max(out) < self.max(x):
out = out + self.param * 2
else:
out = out + self.param
idx = idx + 1
return out + self.param
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(3), dtype=ms.int32)
x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_while_with_param_grad_not_enter_while():
class MyWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
def construct(self, idx, end, x):
out = self.zero
while idx < end:
out = out + self.param * 3
idx = idx + 1
return out + self.param
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, a, b, c):
return C.grad_by_list(self.net, self.weights)(a, b, c)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
while_net = MyWhileNet()
net = GradNet(while_net)
idx = Tensor(np.array(3), dtype=ms.int32)
end = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_with_param_if_by_if_forward():
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
def construct(self, a, b, x):
out = self.zero
if a < b:
out = out + x + self.param
else:
out = out + x
if a == b:
out = out + x*3 + self.param
else:
out = out + x*2
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
if_net = MyIfByIfNet()
net = if_net
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(4), dtype=ms.int32)
x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_with_param_if_by_if_grad_inputs():
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
def construct(self, a, b, x):
out = self.zero
if a < b:
out = out + x + self.param * 4
if a == b:
out = out + x*3 + self.param * 3
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
def construct(self, *inputs):
return C.grad_all(self.net)(*inputs)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
if_net = MyIfByIfNet()
net = GradNet(if_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_with_param_if_by_if_grad_parameter():
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
def construct(self, a, b, x):
out = self.zero
if a < b:
out = out + x + self.param * 2
if a == b:
out = out + x*3 + self.param
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, *inputs):
return C.grad_by_list(self.net, self.weights)(*inputs)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
if_net = MyIfByIfNet()
net = GradNet(if_net)
idx = Tensor(np.array(0), dtype=ms.int32)
end = Tensor(np.array(2), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_with_param_if_by_if_grad_param_excute_null():
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
def construct(self, a, b, x):
out = self.zero
if a < b:
out = out + x + self.param * 2
return out
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, *inputs):
return C.grad_by_list(self.net, self.weights)(*inputs)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
if_net = MyIfByIfNet()
net = GradNet(if_net)
idx = Tensor(np.array(4), dtype=ms.int32)
end = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_if_by_if_return_inside_grad():
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.max = P.ReduceMax()
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
def construct(self, a, b, x):
out = self.zero
if a < b:
return out + x + self.param
if a == b:
return out + self.param * 2
return out + self.param * 3
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.weights = ParameterTuple(net.trainable_params())
def construct(self, *inputs):
return C.grad_by_list(self.net, self.weights)(*inputs)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
if_net = MyIfByIfNet()
net = GradNet(if_net)
idx = Tensor(np.array(1), dtype=ms.int32)
end = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
net(idx, end, x)
def test_if_by_if_forward():
class MyIfByIfNet(nn.Cell):
def __init__(self):
super().__init__()
self.add = P.TensorAdd()
self.sub = P.Sub()
self.mul = P.Mul()
self.div = P.RealDiv()
def construct(self, a, b, x):
if a < b:
a = self.add(a, b)
else:
a = self.sub(a, b)
if a == x:
a = self.mul(a, b)
else:
a = self.div(a, b)
if b == x:
b = self.add(a, b)
else:
b = self.add(a, x)
a = a * b
out = a + b + x
return out
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
if_net = MyIfByIfNet()
net = if_net
idx = Tensor(np.array(2), dtype=ms.float32)
end = Tensor(np.array(3), dtype=ms.float32)
x = Tensor(np.array(4), dtype=ms.float32)
net(idx, end, x)

View File

@ -58,6 +58,7 @@ add_subdirectory(serving)
file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/core/base/*.cc" "../../../mindspore/core/base/*.cc"
"../../../mindspore/core/gvar/*.cc"
"../../../mindspore/core/abstract/*.cc" "../../../mindspore/core/abstract/*.cc"
"../../../mindspore/core/ir/*.cc" "../../../mindspore/core/ir/*.cc"
"../../../mindspore/core/utils/*.cc" "../../../mindspore/core/utils/*.cc"

View File

@ -34,7 +34,6 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
from ....mindspore_test_framework.pipeline.gradient.compile_gradient \ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
class InputBackward(nn.Cell): class InputBackward(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(InputBackward, self).__init__() super(InputBackward, self).__init__()

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
CURRPATH=$(cd $(dirname $0); pwd) CURRPATH=$(cd $(dirname $0); pwd)
IGNORE_EXEC="--ignore=$CURRPATH/exec" IGNORE_EXEC="--ignore=$CURRPATH/exec"
PROJECT_PATH=$(cd ${CURRPATH}/../../..; pwd) PROJECT_PATH=$(cd ${CURRPATH}/../../..; pwd)

View File

@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
"""Generate vm_impl function for array ops""" """Generate vm_impl function for array ops"""
import numpy as np import numpy as np
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
@ -22,7 +21,6 @@ from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm from .vm_interface import vm
# pylint: disable=unused-argument # pylint: disable=unused-argument
@ -181,8 +179,7 @@ def vm_impl_tile(self):
def vm_impl(x, multiples): def vm_impl(x, multiples):
x = x.asnumpy() x = x.asnumpy()
multiples = multiples.asnumpy() out = np.tile(x, multiples)
out = vm.Tile(x, multiples)
return Tensor(out) return Tensor(out)
return vm_impl return vm_impl
@ -255,7 +252,10 @@ def vm_impl_sum(self):
def vm_impl(x, axis): def vm_impl(x, axis):
x = x.asnumpy() x = x.asnumpy()
out = vm.sum(x, axis) if axis == ():
out = np.sum(x)
else:
out = np.sum(x, axis=axis)
return Tensor(np.array(out)) return Tensor(np.array(out))
return vm_impl return vm_impl
@ -291,12 +291,14 @@ def vm_impl_square(self):
return vm_impl return vm_impl
@vm_impl_getters.register(P.ZerosLike) @vm_impl_getters.register(P.ZerosLike)
def vm_impl_zeros_like(self): def vm_impl_zeros_like(self):
"""Generate vm_impl function for ZerosLike""" """Generate vm_impl function for ZerosLike"""
def vm_impl(x): def vm_impl(x):
return Tensor(np.zeros_like(x.asnumpy())) return Tensor(np.zeros_like(x.asnumpy()))
@vm_impl_getters.register(P.Partial) @vm_impl_getters.register(P.Partial)
def vm_impl_partial(self): def vm_impl_partial(self):
"""Generate vm_impl function for Partial""" """Generate vm_impl function for Partial"""
@ -307,6 +309,7 @@ def vm_impl_partial(self):
return vm_impl return vm_impl
@vm_impl_getters.register(P.Depend) @vm_impl_getters.register(P.Depend)
def vm_impl_depend(self): def vm_impl_depend(self):
"""Generate vm_impl function for Depend""" """Generate vm_impl function for Depend"""

View File

@ -196,6 +196,18 @@ def vm_impl_reduce_mean(self):
return vm_impl return vm_impl
@vm_impl_getters.register(P.ReduceMax)
def vm_impl_reduce_max(self):
"""Generate vm_impl function for ReduceMean."""
def vm_impl(x, axis):
x = x.asnumpy()
if axis == ():
axis = None
out = np.amax(x, axis)
return Tensor(out)
return vm_impl
@vm_impl_getters.register(P.Equal) @vm_impl_getters.register(P.Equal)
def vm_impl_equal(self): def vm_impl_equal(self):