forked from OSSInnovation/mindspore
support if by if grad parameter
add join for ref adjust env eliminate to eliminate all env ops add partial app cache resolve while endless fix env eliminate support for "for while" cases fix join shape error
This commit is contained in:
parent
eb4749642d
commit
2a6d346d2f
|
@ -330,7 +330,8 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_gra
|
|||
}
|
||||
oss << "SymInst(%para" << idx << ")";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString();
|
||||
MS_LOG(WARNING) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString();
|
||||
oss << "SymInst(cnode_" << sym_node->ToString() << ")";
|
||||
}
|
||||
|
||||
return oss.str();
|
||||
|
|
|
@ -189,6 +189,11 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
|
|||
if (!morph->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
// for free variable, which may be handled in MapValueObject, just return it
|
||||
auto node_adjoint_found = anfnode_to_adjoin_.find(morph);
|
||||
if (node_adjoint_found != anfnode_to_adjoin_.end()) {
|
||||
return node_adjoint_found->second;
|
||||
}
|
||||
ScopeGuard scope_guard(morph->scope());
|
||||
auto cnode_morph = morph->cast<CNodePtr>();
|
||||
|
||||
|
@ -502,7 +507,7 @@ void DFunctor::MapFvObject() {
|
|||
if (parent_adjoint != nullptr) {
|
||||
adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
|
||||
} else {
|
||||
if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) {
|
||||
if (is_top_ || node->isa<Parameter>()) {
|
||||
// Out of ad scope, add adjoint for free variables.
|
||||
adjoint = std::make_shared<Adjoint>(node, node, tape_);
|
||||
UpdateAdjoint(adjoint);
|
||||
|
|
|
@ -87,10 +87,12 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
env_get_item_eliminate_ =
|
||||
MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
|
||||
new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
|
||||
incorporate_env_getitem_ =
|
||||
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
|
||||
incorporate_env_getitem_bypass_recursive_ =
|
||||
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(true), "incorporate_env_get_item", prim::kPrimEnvGetItem);
|
||||
incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
|
||||
"incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
|
||||
incorporate_env_getitem_ =
|
||||
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
|
||||
|
||||
// Ref eliminate
|
||||
make_ref_eliminate_ =
|
||||
|
@ -122,6 +124,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
|
||||
// inline
|
||||
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
|
||||
inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph);
|
||||
replace_applicator_ =
|
||||
MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
|
||||
specialize_transform_ =
|
||||
|
|
|
@ -55,6 +55,7 @@ class OptimizeIRPassLib {
|
|||
SubstitutionPtr env_get_item_eliminate_;
|
||||
SubstitutionPtr new_env_get_item_;
|
||||
SubstitutionPtr incorporate_env_getitem_;
|
||||
SubstitutionPtr incorporate_env_getitem_bypass_recursive_;
|
||||
SubstitutionPtr incorporate_env_getitem_switch_;
|
||||
|
||||
// Ref eliminate
|
||||
|
@ -80,6 +81,7 @@ class OptimizeIRPassLib {
|
|||
|
||||
// inline
|
||||
SubstitutionPtr inline_;
|
||||
SubstitutionPtr inline_without_move_;
|
||||
SubstitutionPtr replace_applicator_;
|
||||
SubstitutionPtr specialize_transform_;
|
||||
|
||||
|
@ -193,6 +195,16 @@ inline bool IsCNodeDup(const AnfNodePtr &node) {
|
|||
auto inp0 = node->cast<CNodePtr>()->input(0);
|
||||
return (inp0 != nullptr) && inp0->isa<CNode>();
|
||||
}
|
||||
|
||||
// check if the cnode is a switch cnode
|
||||
inline bool IsCNodeSwitch(const AnfNodePtr &node) {
|
||||
if (node != nullptr) {
|
||||
if (node->isa<CNode>()) {
|
||||
return IsPrimitiveCNode(node, prim::kPrimSwitch);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/irpass/inline.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "utils/symbolic.h"
|
||||
|
||||
|
@ -59,8 +60,13 @@ class EnvGetitemTransform {
|
|||
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
|
||||
// {prim::kPrimEnvSetItem, env, symbolickey, value}
|
||||
auto &inputs = env->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) {
|
||||
MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance.";
|
||||
if (inputs.size() != 4) {
|
||||
MS_LOG(WARNING) << "Input size should be 4";
|
||||
return nullptr;
|
||||
}
|
||||
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
|
||||
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
env = inputs[1];
|
||||
|
@ -91,33 +97,12 @@ class EnvGetitemTransform {
|
|||
class NewEnvGetItem : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
auto gety = [this](const AnfNodePtr &node) -> bool {
|
||||
this->y_ = node;
|
||||
return true;
|
||||
};
|
||||
|
||||
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode<EnvInstance>, IsVNode, gety})(node);
|
||||
if (env_ != nullptr && env_->Len() == 0) {
|
||||
return y_;
|
||||
}
|
||||
PatternNode c1, c2, y;
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvGetItem, c1, c2, y), y,
|
||||
(IsValueNode<EnvInstance>(c1.GetNode(node)) && IsVNode(c2.GetNode(node)) &&
|
||||
(GetValueNode<EnvInstancePtr>(c1.GetNode(node)))->Len() == 0));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const ValueNodePtr &vnode) override {
|
||||
if (env_ == nullptr) {
|
||||
env_ = GetValueNode<EnvInstancePtr>(vnode);
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
y_ = nullptr;
|
||||
env_ = nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
AnfNodePtr y_{nullptr};
|
||||
EnvInstancePtr env_{nullptr};
|
||||
};
|
||||
|
||||
// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} ->
|
||||
|
@ -205,8 +190,13 @@ class EnvGetSetItem : public AnfVisitor {
|
|||
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
|
||||
// {prim::kPrimEnvSetItem, env, symbolickey, value}
|
||||
auto &inputs = env->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) {
|
||||
MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance.";
|
||||
if (inputs.size() != 4) {
|
||||
MS_LOG(WARNING) << "Input size should be 4";
|
||||
return nullptr;
|
||||
}
|
||||
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
|
||||
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
env = inputs[1];
|
||||
|
@ -257,7 +247,8 @@ class EnvGetItemEliminater : public OptimizerCaller {
|
|||
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
|
||||
class IncorporateEnvGetitem : public AnfVisitor {
|
||||
public:
|
||||
IncorporateEnvGetitem() : env_get_item_transform_() {}
|
||||
explicit IncorporateEnvGetitem(bool bypass_recursive = false)
|
||||
: env_get_item_transform_(), bypass_recursive_(bypass_recursive) {}
|
||||
~IncorporateEnvGetitem() override = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
|
@ -285,7 +276,13 @@ class IncorporateEnvGetitem : public AnfVisitor {
|
|||
auto inputs = inp1->inputs();
|
||||
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
|
||||
auto new_fg = env_get_item_transform_(fg, key, default_v);
|
||||
|
||||
if (fg->recursive() && bypass_recursive_) {
|
||||
MS_LOG(DEBUG) << "Bypass env_get_item transform for recursive fg=" << fg->ToString();
|
||||
return nullptr;
|
||||
}
|
||||
if (new_fg == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> args;
|
||||
args.push_back(NewValueNode(new_fg));
|
||||
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
|
||||
|
@ -298,6 +295,7 @@ class IncorporateEnvGetitem : public AnfVisitor {
|
|||
private:
|
||||
bool is_match_{false};
|
||||
internal::EnvGetitemTransform env_get_item_transform_;
|
||||
bool bypass_recursive_;
|
||||
};
|
||||
|
||||
// {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y}
|
||||
|
@ -342,7 +340,9 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
|
|||
auto g2 = GetValueNode<FuncGraphPtr>(sw->input(3));
|
||||
auto new_g1 = env_get_item_transform_(g1, key, default_v);
|
||||
auto new_g2 = env_get_item_transform_(g2, key, default_v);
|
||||
|
||||
if (new_g1 == nullptr || new_g2 == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto fg = node->func_graph();
|
||||
auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)});
|
||||
|
||||
|
|
|
@ -93,10 +93,22 @@ bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); }
|
|||
|
||||
bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
|
||||
|
||||
bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) {
|
||||
bool unique_use = IsUniqueUse(fg, nullptr);
|
||||
bool is_recursive = fg->recursive();
|
||||
if (fg->parent() != nullptr && is_recursive) {
|
||||
if (fg->parent() == node->func_graph() && unique_use) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// {G, Xs}
|
||||
class InlinerBase : public AnfVisitor {
|
||||
public:
|
||||
explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions) : criterions_(criterions) {}
|
||||
explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions, bool use_move = true)
|
||||
: use_move_(use_move), criterions_(criterions) {}
|
||||
~InlinerBase() override = default;
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
if (!node->isa<CNode>()) {
|
||||
|
@ -113,6 +125,7 @@ class InlinerBase : public AnfVisitor {
|
|||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Do not inline GraphKernel to Cell.
|
||||
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
// If the GraphKernel only contains a return node, we make it inlined.
|
||||
|
@ -142,8 +155,12 @@ class InlinerBase : public AnfVisitor {
|
|||
|
||||
std::vector<AnfNodePtr> params;
|
||||
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
|
||||
|
||||
if (IsUniqueUse(fg, nullptr)) {
|
||||
// compare size to avoid the case that the function has default value after grad.
|
||||
// for which after renormalize, the function default value will be an input
|
||||
if (fg->parameters().size() != params.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (use_move_ && IsUniqueUse(fg, nullptr)) {
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
ReplaceParams(mng, params, fg);
|
||||
|
@ -183,21 +200,36 @@ class InlinerBase : public AnfVisitor {
|
|||
|
||||
private:
|
||||
bool is_checked_{false}, is_recursive_{false};
|
||||
bool use_move_;
|
||||
std::vector<std::pair<CriterionFuncType, bool>> criterions_;
|
||||
};
|
||||
|
||||
class Inliner : public InlinerBase {
|
||||
public:
|
||||
Inliner()
|
||||
: InlinerBase({
|
||||
{IsUniqueUse, true},
|
||||
{IsTrivial, false},
|
||||
{IsInside, false},
|
||||
{IsCore, false},
|
||||
{NoCriterion, true},
|
||||
}) {}
|
||||
explicit Inliner(bool use_move = true)
|
||||
: InlinerBase(
|
||||
{
|
||||
{IsUniqueUse, true},
|
||||
{IsTrivial, false},
|
||||
{IsInside, false},
|
||||
{IsCore, false},
|
||||
{IsDirectParentCall, false},
|
||||
{NoCriterion, true},
|
||||
},
|
||||
use_move) {}
|
||||
~Inliner() override = default;
|
||||
};
|
||||
|
||||
class DirectInliner : public InlinerBase {
|
||||
public:
|
||||
explicit DirectInliner(bool use_move = true)
|
||||
: InlinerBase(
|
||||
{
|
||||
{IsDirectParentCall, false},
|
||||
},
|
||||
use_move) {}
|
||||
~DirectInliner() override = default;
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,6 +26,30 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
namespace internal {
|
||||
class GetRefValueTransform {
|
||||
public:
|
||||
GetRefValueTransform() {}
|
||||
~GetRefValueTransform() = default;
|
||||
|
||||
AnfNodePtr operator()(const AnfNodePtr &node) {
|
||||
CNodePtr cnode = node->cast<CNodePtr>();
|
||||
auto inputs = cnode->inputs();
|
||||
auto fg = GetValueNode(inputs[0])->cast<FuncGraphPtr>();
|
||||
if (fg->recursive()) {
|
||||
MS_LOG(DEBUG) << "Get refvalue by pass recursive:" << fg->ToString();
|
||||
return node;
|
||||
}
|
||||
auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("GetRefValue"));
|
||||
auto output = new_fg->output();
|
||||
new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimGetRefValue), output}));
|
||||
inputs[0] = NewValueNode(new_fg);
|
||||
auto ret_node = cnode->func_graph()->NewCNode(inputs);
|
||||
return ret_node;
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
// {prim::kPrimMakeRef, X, Y, Z} -> Y
|
||||
class MakeRefEliminater : public OptimizerCaller {
|
||||
public:
|
||||
|
@ -48,13 +72,23 @@ class GetRefParamEliminater : public OptimizerCaller {
|
|||
|
||||
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
|
||||
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
|
||||
// {prim::kPrimGetRefValue, {prim::switch, cond, t, f}} -> {prim::switch, cond, t, f}
|
||||
class GetMakeRefEliminater : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> x, y, z;
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
|
||||
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsCNodeSwitch, node));
|
||||
internal::GetRefValueTransform trans;
|
||||
auto GetRefLambda = [&trans, &x, &node]() -> AnfNodePtr {
|
||||
auto rep = trans(x.GetNode(node));
|
||||
if (rep != nullptr) {
|
||||
return rep;
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetRefValue, x), GetRefLambda, x.CheckFunc(IsCNodeGraph, node));
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -314,6 +314,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
|
|||
FuncGraphPtr func_graph = res->func_graph();
|
||||
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
std::string backend = MsContext::GetInstance()->backend_policy();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
if (CompileGraphs::ContainMixedTarget(func_graph)) {
|
||||
bc_ptr->set_is_multi_graph_sink(false);
|
||||
|
@ -321,13 +322,13 @@ bool TaskEmitAction(const ResourcePtr &res) {
|
|||
context_ptr->set_loop_sink_flag(false);
|
||||
} else if (context_ptr->execution_mode() != kPynativeMode) {
|
||||
std::string device_target = context_ptr->device_target();
|
||||
if (device_target == kAscendDevice) {
|
||||
if (device_target == kAscendDevice && backend != kMsVm) {
|
||||
bc_ptr->set_is_multi_graph_sink(true);
|
||||
context_ptr->set_is_multi_graph_sink(true);
|
||||
}
|
||||
}
|
||||
|
||||
if (IsCtrlSink()) {
|
||||
if (IsCtrlSink() && backend == kMsConvert) {
|
||||
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
|
||||
return true;
|
||||
}
|
||||
|
@ -344,8 +345,8 @@ bool ExecuteAction(const ResourcePtr &res) {
|
|||
if (res->results().count(kOutput) == 0) {
|
||||
MS_LOG(EXCEPTION) << "Execute args error";
|
||||
}
|
||||
|
||||
if (IsCtrlSink()) {
|
||||
std::string backend = MsContext::GetInstance()->backend_policy();
|
||||
if (IsCtrlSink() && backend == kMsConvert) {
|
||||
if (!res->results()[kOutput].is<GraphId>()) {
|
||||
MS_LOG(EXCEPTION) << "Execute args error";
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
#include "pipeline/jit/validator.h"
|
||||
#include "pipeline/jit/remove_value_node_dup.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/cse.h"
|
||||
#include "frontend/optimizer/graph_kernel_reuse.h"
|
||||
|
@ -127,11 +128,14 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.incorporate_getitem_set_,
|
||||
irpass.incorporate_call_,
|
||||
irpass.incorporate_call_switch_,
|
||||
irpass.incorporate_env_getitem_,
|
||||
irpass.incorporate_env_getitem_bypass_recursive_,
|
||||
irpass.incorporate_env_getitem_switch_,
|
||||
irpass.new_env_get_item_,
|
||||
irpass.depend_value_elim_,
|
||||
});
|
||||
opt::OptPassConfig a_after_grad = opt::OptPassConfig({
|
||||
irpass.inline_without_move_,
|
||||
});
|
||||
opt::OptPassConfig a_3 = opt::OptPassConfig({
|
||||
irpass.arithmetic_simplify2_,
|
||||
irpass.same_eliminate_,
|
||||
|
@ -154,6 +158,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
{"virtual_dataset", virtual_dataset},
|
||||
{"grad", grad},
|
||||
{"resolve", resolve_pass},
|
||||
{"a_after_grad", a_after_grad},
|
||||
{"renormalize", opt::OptPassConfig::Renormalize()},
|
||||
{"cse", opt::OptPassConfig(opt::CSE(false))},
|
||||
{"a_3", a_3}});
|
||||
|
@ -161,11 +166,24 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
return map_a;
|
||||
}
|
||||
|
||||
OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig c_1 = opt::OptPassConfig({
|
||||
// Safe inlining
|
||||
irpass.inline_,
|
||||
irpass.partial_eliminate_,
|
||||
});
|
||||
|
||||
OptPassGroupMap map_a({{"c_1", c_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
|
||||
|
||||
return map_a;
|
||||
}
|
||||
|
||||
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig b_1 =
|
||||
opt::OptPassConfig({irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
|
||||
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_,
|
||||
irpass.get_make_ref_eliminate_, irpass.value_based_eliminate_});
|
||||
opt::OptPassConfig b_1 = opt::OptPassConfig(
|
||||
{irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
|
||||
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
|
||||
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
|
||||
irpass.value_based_eliminate_});
|
||||
opt::OptPassConfig b_2 = opt::OptPassConfig({
|
||||
irpass.replace_refkey_by_param_,
|
||||
irpass.make_ref_eliminate_,
|
||||
|
@ -244,6 +262,8 @@ void InitOpt(const ResourcePtr &res) {
|
|||
opt::irpass::OptimizeIRPassLib irpass;
|
||||
g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
|
||||
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
|
||||
g_pass_opts["opt_after_cconv"] =
|
||||
Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
|
||||
g_pass_opts["opt_graph_kernel_a"] =
|
||||
Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
|
||||
g_pass_opts["opt_graph_kernel_b"] =
|
||||
|
@ -288,6 +308,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
|
|||
|
||||
bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
|
||||
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
|
||||
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
|
||||
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
|
||||
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
|
||||
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
|
||||
|
@ -311,6 +332,33 @@ bool AddControlDependPass(const ResourcePtr &res) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool MergeDupGraphPass(const ResourcePtr &res) {
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(res->manager());
|
||||
if (res->manager()->func_graphs().size() <= 1) {
|
||||
return true;
|
||||
}
|
||||
return MergeDuplicateGraphs(res->manager());
|
||||
}
|
||||
|
||||
bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) {
|
||||
if (res->func_graph() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Remove value node duplications error.";
|
||||
}
|
||||
auto manager = res->manager();
|
||||
HashCache hash_cache;
|
||||
HashValue hashes;
|
||||
// Remove duplicated value nodes across all graphs in manager
|
||||
for (auto &fg : manager->func_graphs()) {
|
||||
auto value_nodes = fg->value_nodes();
|
||||
for (const auto &value_pair : value_nodes) {
|
||||
TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CconvPass(const ResourcePtr &res) {
|
||||
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
|
@ -340,6 +388,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
|
|||
{"clean_after_opta", CleanAfterOptAPass},
|
||||
{"opt_b", OptPassBGroup},
|
||||
{"cconv", CconvPass},
|
||||
{"opt_after_cconv", OptPassAfterCconvGroup},
|
||||
{"remove_dup_value", RemoveValueNodeDuplicationsPass},
|
||||
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
|
||||
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
|
||||
{"add_control_depend", AddControlDependPass}};
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <string>
|
||||
|
||||
#include "pipeline/jit/remove_value_node_dup.h"
|
||||
#include "ir/anf.h"
|
||||
|
@ -70,5 +71,108 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has
|
|||
// Meet for the first time, append node to bucket.
|
||||
bucket.emplace_back(node);
|
||||
}
|
||||
|
||||
size_t HashOfGraph(const FuncGraphPtr &fg) {
|
||||
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
|
||||
MS_LOG(DEBUG) << "TopSort for:" << fg->ToString();
|
||||
std::unordered_map<AnfNodePtr, std::size_t> hashes;
|
||||
auto ¶ms = fg->parameters();
|
||||
for (size_t i = 0; i < params.size(); i++) {
|
||||
hashes[params[i]] = std::hash<std::string>{}("param" + std::to_string(i));
|
||||
}
|
||||
for (auto node : toposet) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (hashes.find(node) != hashes.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::size_t h = 0;
|
||||
if (node->isa<ValueNode>()) {
|
||||
ValueNodePtr value_node = node->cast<ValueNodePtr>();
|
||||
auto value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (IsValueNode<FuncGraph>(value_node)) {
|
||||
auto v_fg = value->cast<FuncGraphPtr>();
|
||||
h = value->hash();
|
||||
} else if (IsValueNode<tensor::Tensor>(value_node)) {
|
||||
// the tensor has same value has been replaced in duplicate value pass,
|
||||
// so we use the value pointer here as an identifier
|
||||
h = hash_combine(value->hash(), std::hash<Value *>{}(value.get()));
|
||||
} else {
|
||||
h = hash_combine(value->hash(), (opt::AbsOf(value_node)->hash()));
|
||||
}
|
||||
} else if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto &inputs = cnode->inputs();
|
||||
size_t init = 0;
|
||||
h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
|
||||
return hash_combine(hash, hashes[node_in]);
|
||||
});
|
||||
} else if (node->isa<Parameter>()) {
|
||||
h = node->hash();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unknow node type";
|
||||
}
|
||||
hashes[node] = h;
|
||||
}
|
||||
return hashes[fg->get_return()];
|
||||
}
|
||||
|
||||
bool IsCNodeGraph(const AnfNodePtr &node) {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto inp0 = node->cast<CNodePtr>()->input(0);
|
||||
return IsValueNode<FuncGraph>(inp0);
|
||||
}
|
||||
|
||||
bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager) {
|
||||
std::unordered_map<size_t, std::vector<FuncGraphPtr>> hash_graphs;
|
||||
std::unordered_map<FuncGraphPtr, size_t> graph_hash;
|
||||
for (auto fg : manager->func_graphs()) {
|
||||
size_t h = HashOfGraph(fg);
|
||||
graph_hash[fg] = h;
|
||||
if (hash_graphs.find(h) == hash_graphs.end()) {
|
||||
hash_graphs[h] = {fg};
|
||||
} else {
|
||||
hash_graphs[h].push_back(fg);
|
||||
}
|
||||
}
|
||||
FuncGraphPairMapEquiv equiv_graph;
|
||||
NodeMapEquiv equiv_node;
|
||||
for (auto &fg : manager->func_graphs()) {
|
||||
MS_LOG(DEBUG) << "Try Merge Graph:" << fg->ToString();
|
||||
for (auto &item : fg->nodes()) {
|
||||
if (!item->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto &inputs = item->cast<CNodePtr>()->inputs();
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
if (!inputs[i]->isa<ValueNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto value_ptr = GetValueNode(inputs[i]);
|
||||
auto v_fg = value_ptr->cast<FuncGraphPtr>();
|
||||
if (v_fg == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto &fg_vec = hash_graphs[graph_hash[v_fg]];
|
||||
if (fg_vec.size() > 1) {
|
||||
if (v_fg != fg_vec[0]) {
|
||||
bool is_morphic = Isomorphic(v_fg, fg_vec[0], &equiv_graph, &equiv_node);
|
||||
if (is_morphic) {
|
||||
auto new_node = NewValueNode(fg_vec[0]);
|
||||
MS_LOG(DEBUG) << "Replace graph node :" << inputs[i]->ToString() << " with:" << new_node->ToString();
|
||||
manager->Replace(inputs[i], new_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,6 +28,10 @@ using HashCache = std::unordered_map<std::size_t, std::vector<AnfNodePtr>>;
|
|||
using HashValue = std::unordered_map<AnfNodePtr, std::size_t>;
|
||||
|
||||
void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value);
|
||||
size_t HashOfGraph(const FuncGraphPtr &fg);
|
||||
bool IsCNodeGraph(const AnfNodePtr &node);
|
||||
bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager);
|
||||
|
||||
} // namespace pipeline
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -113,17 +113,18 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
}
|
||||
const AnfNodePtr &func_node = fg->get_return();
|
||||
|
||||
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString()
|
||||
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString()
|
||||
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString();
|
||||
AbstractBasePtr ret_base = nullptr;
|
||||
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
|
||||
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
|
||||
const auto &node = *it;
|
||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
|
||||
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
|
||||
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString()
|
||||
<< ", node_conf: " << node_conf->ToString();
|
||||
ret_base = engine->GetEvaluatedValue(node_conf)->abstract();
|
||||
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
|
||||
<< ", abstract: " << ret_base->ToString();
|
||||
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString()
|
||||
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(ret_base);
|
||||
|
@ -142,16 +143,17 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
|
|||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list),
|
||||
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
return arg->Broaden();
|
||||
if (arg->GetValueTrack() != kAnyValue) {
|
||||
return arg->Broaden();
|
||||
}
|
||||
return arg;
|
||||
});
|
||||
if (func_graph_->joined_shapes_.size() != broaded_list.size()) {
|
||||
MS_EXCEPTION(ValueError) << "Number of input arguments " << broaded_list.size()
|
||||
<< " does not equal to number of original buffer arguments "
|
||||
<< func_graph_->joined_shapes_.size();
|
||||
}
|
||||
for (size_t i = 0; i < broaded_list.size(); ++i) {
|
||||
broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
|
||||
if (func_graph_->joined_shapes_.size() == broaded_list.size()) {
|
||||
for (size_t i = 0; i < broaded_list.size(); ++i) {
|
||||
broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
|
||||
<< ", broaded: " << mindspore::ToString(broaded_list);
|
||||
return broaded_list;
|
||||
|
@ -181,8 +183,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|||
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
func_graph_->joined_shapes_.clear();
|
||||
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
|
||||
std::back_inserter(func_graph_->joined_shapes_),
|
||||
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
|
||||
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
|
||||
if (arg_spec->isa<AbstractRef>()) {
|
||||
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
|
||||
}
|
||||
return arg_spec->GetShapeTrack();
|
||||
});
|
||||
joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
|
||||
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
|
||||
}
|
||||
return joined_args_spec_list;
|
||||
|
@ -199,8 +206,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
|
|||
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
|
||||
func_graph_->joined_shapes_.clear();
|
||||
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
|
||||
std::back_inserter(func_graph_->joined_shapes_),
|
||||
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
|
||||
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
|
||||
if (arg_spec->isa<AbstractRef>()) {
|
||||
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
|
||||
}
|
||||
return arg_spec->GetShapeTrack();
|
||||
});
|
||||
joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
|
||||
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
|
||||
}
|
||||
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
|
||||
|
|
|
@ -188,6 +188,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
|||
trace::TraceEvalCNodeLeave();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()
|
||||
<< (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph")
|
||||
<< ". NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||
}
|
||||
|
||||
|
@ -301,6 +302,8 @@ void AnalysisEngine::Clear() {
|
|||
anfnode_config_map_.clear();
|
||||
eval_trace_.clear();
|
||||
constructors_.clear();
|
||||
constructors_app_.clear();
|
||||
continued_evals_.clear();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -426,8 +429,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstr
|
|||
MS_EXCEPTION_IF_NULL(func);
|
||||
AbstractFunctionPtr func_orig = func->fn();
|
||||
EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
|
||||
auto part_pair = std::make_pair(func_orig, func->args());
|
||||
auto itr = constructors_app_.find(part_pair);
|
||||
if (itr != constructors_app_.end()) {
|
||||
return itr->second;
|
||||
}
|
||||
std::shared_ptr<PartialAppEvaluator> partial_evaluator =
|
||||
std::make_shared<PartialAppEvaluator>(evaluator_orig, func->args());
|
||||
constructors_app_[part_pair] = partial_evaluator;
|
||||
return partial_evaluator;
|
||||
}
|
||||
|
||||
|
@ -504,9 +513,10 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
|
|||
if (fg_eval == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto fg = fg_eval->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto undetermined_fgs = fg->recursive_graphs();
|
||||
auto undetermined_fgs = fg->recursive();
|
||||
if (undetermined_fgs) {
|
||||
auto fg_parent = fg->parent();
|
||||
MS_EXCEPTION_IF_NULL(fg_parent);
|
||||
|
@ -546,15 +556,19 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPt
|
|||
MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
|
||||
|
||||
for (auto u_eval : undetermined_evals) {
|
||||
MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
|
||||
if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
|
||||
MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
|
||||
MS_LOG(DEBUG) << u_eval.first->ToString() << "check undetermined.";
|
||||
auto &alternate_evaluator = multi_poss_[u_eval.first];
|
||||
auto &eval_cache = alternate_evaluator->cache();
|
||||
if ((!undetermined_evals.count(std::make_pair(alternate_evaluator, args_spec_list))) &&
|
||||
(((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) ||
|
||||
(eval_cache->find(args_spec_list) == eval_cache->end()))) {
|
||||
MS_LOG(DEBUG) << u_eval.first->ToString() << "has undetermined.";
|
||||
has_undetermined = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (has_undetermined == false) {
|
||||
MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
|
||||
MS_LOG(DEBUG) << eval->ToString() << "has no undetermined.";
|
||||
*continue_flag = true;
|
||||
return latest_entry;
|
||||
}
|
||||
|
@ -597,34 +611,33 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
|
||||
auto current_inf = std::make_pair(eval, args_spec_list);
|
||||
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
|
||||
|
||||
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
|
||||
auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
|
||||
if (it == eval_trace_.rend()) {
|
||||
eval_trace_.push_back(current_inf);
|
||||
MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
|
||||
MS_EXCEPTION_IF_NULL(eval);
|
||||
auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
|
||||
MS_EXCEPTION_IF_NULL(eval_result->abstract());
|
||||
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString();
|
||||
out_specs.push_back(eval_result->abstract());
|
||||
eval_trace_.pop_back();
|
||||
if (eval_trace_.empty()) {
|
||||
multi_poss_.clear();
|
||||
}
|
||||
} else if (it != eval_trace_.rbegin()) {
|
||||
} else {
|
||||
bool continue_flag = false;
|
||||
auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag);
|
||||
if (continue_flag) {
|
||||
MS_LOG(DEBUG) << "continued_evals_ add " << current_inf.first.get() << current_inf.first->ToString();
|
||||
continued_evals_.insert(current_inf);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Try to travel the latest undetermined.
|
||||
if (latest_entry != eval_trace_.rbegin()->first) {
|
||||
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
|
||||
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString();
|
||||
auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
|
||||
MS_EXCEPTION_IF_NULL(eval_result->abstract());
|
||||
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString()
|
||||
MS_LOG(DEBUG) << "end Direct Evaluator " << latest_entry->ToString()
|
||||
<< " return out_spec: " << eval_result->abstract()->ToString();
|
||||
return eval_result;
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include <vector>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
#ifdef DEBUG
|
||||
#include <stack>
|
||||
|
@ -113,7 +114,8 @@ class AnfNodeConfig : public Config {
|
|||
|
||||
std::string ToString() const override {
|
||||
std::ostringstream buffer;
|
||||
buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString();
|
||||
buffer << "Node: " << node_->DebugString() << "-uid(" << node_->UniqueId()
|
||||
<< "), Context: " << context_->ToString();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
|
@ -173,7 +175,13 @@ struct AnalysisResult {
|
|||
};
|
||||
|
||||
using EvalTraceRevIter = std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>>::reverse_iterator;
|
||||
|
||||
struct PartialAppHasher {
|
||||
std::size_t operator()(const std::pair<AbstractFunctionPtr, AbstractBasePtrList> &p) const {
|
||||
auto h1 = std::hash<AbstractFunctionPtr>{}(p.first);
|
||||
auto h2 = AbstractBasePtrListHash(p.second);
|
||||
return h1 ^ h2;
|
||||
}
|
||||
};
|
||||
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
||||
public:
|
||||
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
|
||||
|
@ -233,10 +241,13 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
const PrimEvaluatorMap &prim_constructors_;
|
||||
FuncGraphManagerPtr func_graph_manager_;
|
||||
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_;
|
||||
std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher>
|
||||
constructors_app_;
|
||||
AnfNodeConfigMap anfnode_config_map_;
|
||||
// Use a list to trace multiple evaluators.
|
||||
std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
|
||||
std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
|
||||
std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> continued_evals_;
|
||||
|
||||
AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
|
||||
const ConfigPtrList &args_conf_list);
|
||||
|
|
|
@ -34,6 +34,7 @@ using mindspore::abstract::AbstractError;
|
|||
using mindspore::abstract::AbstractFunction;
|
||||
using mindspore::abstract::AbstractJTagged;
|
||||
using mindspore::abstract::AbstractList;
|
||||
using mindspore::abstract::AbstractRef;
|
||||
using mindspore::abstract::AbstractRowTensor;
|
||||
using mindspore::abstract::AbstractScalar;
|
||||
using mindspore::abstract::AbstractSparseTensor;
|
||||
|
@ -83,7 +84,8 @@ void ValidateAbstract(const AnfNodePtr &node) {
|
|||
// only send string in external
|
||||
if (!IsValueNode<StringImm>(node)) {
|
||||
// Validate a type.
|
||||
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString();
|
||||
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString()
|
||||
<< " for node=" << node->DebugString();
|
||||
}
|
||||
}
|
||||
return;
|
||||
|
@ -96,7 +98,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
|
|||
|
||||
if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
|
||||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractRowTensor>() ||
|
||||
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
|
||||
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>() || ptrBase->isa<AbstractRef>()) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -481,8 +481,10 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
|
|||
}
|
||||
|
||||
// Isomorphism
|
||||
static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
|
||||
NodeMapEquiv *const equiv_node) {
|
||||
static bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
|
||||
NodeMapEquiv *const equiv_node);
|
||||
bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
|
||||
NodeMapEquiv *const equiv_node) {
|
||||
if (equiv_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Invalid equiv_node";
|
||||
return false;
|
||||
|
@ -514,6 +516,9 @@ static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, Fu
|
|||
MS_LOG(DEBUG) << "two parameters are not equal.";
|
||||
return false;
|
||||
}
|
||||
if (node1->isa<CNode>() && node2->isa<CNode>()) {
|
||||
return SameNode(node1, node2, equiv_func_graph, equiv_node);
|
||||
}
|
||||
MS_LOG(ERROR) << "type error";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -116,12 +116,15 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo
|
|||
} // namespace
|
||||
|
||||
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
|
||||
auto fg = std::make_shared<FuncGraph>();
|
||||
AnfNodePtrList inputs;
|
||||
AnfNodePtrToAnfNodePtrMap eqv;
|
||||
if (lst.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Input anf node list is empty";
|
||||
}
|
||||
TraceManager::DebugTrace(
|
||||
std::make_shared<TraceSegmentTransform>(lst[0]->cast<CNodePtr>()->func_graph()->debug_info()));
|
||||
auto fg = std::make_shared<FuncGraph>();
|
||||
TraceManager::EndTrace();
|
||||
AnfNodePtrList inputs;
|
||||
AnfNodePtrToAnfNodePtrMap eqv;
|
||||
// Merge CNodes into a AnfGraph that represents a linear instruction segment
|
||||
for (auto n : lst) {
|
||||
if (!n->isa<CNode>()) {
|
||||
|
@ -154,7 +157,9 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
|
|||
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
|
||||
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
|
||||
}
|
||||
TraceManager::DebugTrace(std::make_shared<TraceGetEnv>(n->debug_info()));
|
||||
eqv[n] = fg->NewCNode(args);
|
||||
TraceManager::EndTrace();
|
||||
eqv[n]->set_abstract(n->abstract());
|
||||
eqv[n]->set_kernel_info(n->kernel_info_ptr());
|
||||
}
|
||||
|
|
|
@ -452,6 +452,10 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
|
|||
}
|
||||
auto other_tensor = dyn_cast<AbstractTensor>(other);
|
||||
if (other_tensor == nullptr) {
|
||||
auto ref_tensor = dyn_cast<AbstractRef>(other);
|
||||
if (ref_tensor != nullptr) {
|
||||
return this->Join(ref_tensor->ref());
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
|
||||
}
|
||||
if (*this == *other) {
|
||||
|
|
|
@ -48,7 +48,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
|
|||
continue;
|
||||
}
|
||||
if (rank.find(node) != rank.end() && rank[node] != todo.size()) {
|
||||
MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString();
|
||||
MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2);
|
||||
}
|
||||
rank[node] = todo.size();
|
||||
bool cont = false;
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "base/base.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/dtype/number.h"
|
||||
#include "utils/hashing.h"
|
||||
|
||||
using std::fabs;
|
||||
|
||||
|
@ -51,7 +52,7 @@ using ScalarPtr = std::shared_ptr<Scalar>;
|
|||
|
||||
class BoolImm : public Scalar {
|
||||
public:
|
||||
explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = std::hash<bool>{}(v_); }
|
||||
explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = hash_combine({tid(), std::hash<bool>{}(v_)}); }
|
||||
~BoolImm() override = default;
|
||||
MS_DECLARE_PARENT(BoolImm, Scalar)
|
||||
std::size_t hash() const override { return hash_; }
|
||||
|
@ -91,7 +92,7 @@ class IntergerImm : public Scalar {
|
|||
class Int8Imm : public IntergerImm {
|
||||
public:
|
||||
Int8Imm() : IntergerImm(kInt8), v_(0) {}
|
||||
explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = std::hash<int>{}(v_); }
|
||||
explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
|
||||
~Int8Imm() override = default;
|
||||
MS_DECLARE_PARENT(Int8Imm, IntergerImm)
|
||||
std::size_t hash() const override { return hash_; }
|
||||
|
@ -117,7 +118,7 @@ IMM_TRAITS(Int8ImmPtr, int8_t)
|
|||
class Int16Imm : public IntergerImm {
|
||||
public:
|
||||
Int16Imm() : IntergerImm(kInt16), v_(0) {}
|
||||
explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = std::hash<int>{}(v_); }
|
||||
explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
|
||||
~Int16Imm() override = default;
|
||||
MS_DECLARE_PARENT(Int16Imm, IntergerImm)
|
||||
std::size_t hash() const override { return hash_; }
|
||||
|
@ -143,7 +144,7 @@ IMM_TRAITS(Int16ImmPtr, int16_t)
|
|||
class Int32Imm : public IntergerImm {
|
||||
public:
|
||||
Int32Imm() : IntergerImm(kInt32), v_(0) {}
|
||||
explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = std::hash<int>{}(v_); }
|
||||
explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
|
||||
~Int32Imm() override = default;
|
||||
MS_DECLARE_PARENT(Int32Imm, IntergerImm)
|
||||
std::size_t hash() const override { return hash_; }
|
||||
|
@ -169,7 +170,7 @@ IMM_TRAITS(Int32ImmPtr, int32_t)
|
|||
class Int64Imm : public IntergerImm {
|
||||
public:
|
||||
Int64Imm() : IntergerImm(kInt64), v_(0) {}
|
||||
explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = std::hash<int64_t>{}(v_); }
|
||||
explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = hash_combine({tid(), std::hash<int64_t>{}(v_)}); }
|
||||
~Int64Imm() override = default;
|
||||
MS_DECLARE_PARENT(Int64Imm, IntergerImm)
|
||||
std::size_t hash() const override { return hash_; }
|
||||
|
@ -195,7 +196,9 @@ IMM_TRAITS(Int64ImmPtr, int64_t)
|
|||
class UInt8Imm : public IntergerImm {
|
||||
public:
|
||||
UInt8Imm() : IntergerImm(kUInt8), v_(0) {}
|
||||
explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) { hash_ = std::hash<unsigned int>{}(v_); }
|
||||
explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) {
|
||||
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
|
||||
}
|
||||
~UInt8Imm() override = default;
|
||||
MS_DECLARE_PARENT(UInt8Imm, IntergerImm)
|
||||
std::size_t hash() const override { return hash_; }
|
||||
|
@ -221,7 +224,9 @@ IMM_TRAITS(UInt8ImmPtr, uint8_t);
|
|||
class UInt16Imm : public IntergerImm {
|
||||
public:
|
||||
UInt16Imm() : IntergerImm(kUInt16), v_(0) {}
|
||||
explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) { hash_ = std::hash<unsigned int>{}(v_); }
|
||||
explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) {
|
||||
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
|
||||
}
|
||||
~UInt16Imm() override = default;
|
||||
MS_DECLARE_PARENT(UInt16Imm, IntergerImm)
|
||||
std::size_t hash() const override { return hash_; }
|
||||
|
@ -247,7 +252,9 @@ IMM_TRAITS(UInt16ImmPtr, uint16_t);
|
|||
class UInt32Imm : public IntergerImm {
|
||||
public:
|
||||
UInt32Imm() : IntergerImm(kUInt32), v_(0) {}
|
||||
explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) { hash_ = std::hash<unsigned int>{}(v_); }
|
||||
explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) {
|
||||
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
|
||||
}
|
||||
~UInt32Imm() override = default;
|
||||
MS_DECLARE_PARENT(UInt32Imm, IntergerImm)
|
||||
std::size_t hash() const override { return hash_; }
|
||||
|
@ -273,7 +280,9 @@ IMM_TRAITS(UInt32ImmPtr, uint32_t);
|
|||
class UInt64Imm : public IntergerImm {
|
||||
public:
|
||||
UInt64Imm() : IntergerImm(kUInt64), v_(0) {}
|
||||
explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) { hash_ = std::hash<uint64_t>{}(v); }
|
||||
explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) {
|
||||
hash_ = hash_combine({tid(), std::hash<uint64_t>{}(v)});
|
||||
}
|
||||
~UInt64Imm() override = default;
|
||||
MS_DECLARE_PARENT(UInt64Imm, IntergerImm)
|
||||
std::size_t hash() const override { return hash_; }
|
||||
|
@ -308,7 +317,7 @@ using FloatImmPtr = std::shared_ptr<FloatImm>;
|
|||
class FP32Imm : public FloatImm {
|
||||
public:
|
||||
FP32Imm() : FloatImm(kFloat32), v_(0.0) {}
|
||||
explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = std::hash<float>{}(v_); }
|
||||
explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = hash_combine({tid(), std::hash<float>{}(v_)}); }
|
||||
~FP32Imm() override = default;
|
||||
MS_DECLARE_PARENT(FP32Imm, FloatImm)
|
||||
std::size_t hash() const override { return hash_; }
|
||||
|
@ -334,7 +343,7 @@ IMM_TRAITS(FP32ImmPtr, float)
|
|||
class FP64Imm : public FloatImm {
|
||||
public:
|
||||
FP64Imm() : FloatImm(kFloat64), v_(0.0) {}
|
||||
explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = std::hash<double>{}(v_); }
|
||||
explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = hash_combine({tid(), std::hash<double>{}(v_)}); }
|
||||
~FP64Imm() override = default;
|
||||
MS_DECLARE_PARENT(FP64Imm, FloatImm)
|
||||
std::size_t hash() const override { return hash_; }
|
||||
|
|
|
@ -412,6 +412,16 @@ class TraceCombileLikeGraphs : public TraceInfo {
|
|||
return std::make_shared<TraceCombileLikeGraphs>(*shared_from_base<TraceCombileLikeGraphs>());
|
||||
}
|
||||
};
|
||||
|
||||
class TraceSegmentTransform : public TraceInfo {
|
||||
public:
|
||||
explicit TraceSegmentTransform(const DebugInfoPtr &info) : TraceInfo(info, "segment_transform", "") {}
|
||||
MS_DECLARE_PARENT(TraceGetEnv, TraceInfo);
|
||||
~TraceSegmentTransform() override = default;
|
||||
TraceInfoPtr clone() override {
|
||||
return std::make_shared<TraceSegmentTransform>(*shared_from_base<TraceSegmentTransform>());
|
||||
}
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_UTILS_TRACE_INFO_H_
|
||||
|
|
|
@ -0,0 +1,816 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test control ops """
|
||||
import numpy as np
|
||||
|
||||
from mindspore import dtype as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore import nn
|
||||
from mindspore.common.parameter import Parameter, ParameterTuple
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
# from tests.vm_impl.math_ops_vm_impl import *
|
||||
# from tests.vm_impl.vm_interface import *
|
||||
# from tests.vm_impl import *
|
||||
# context.set_context(save_graphs=True)
|
||||
|
||||
|
||||
def test_while_forward():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
while idx < end:
|
||||
part = x[idx, :, :]
|
||||
max_num = self.max(part)
|
||||
x[idx, :, 0:2] = max_num
|
||||
idx = idx + 1
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
net = MyWhileNet()
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(2), dtype=ms.int32)
|
||||
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_grad():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
while idx < end:
|
||||
part = x[idx, :, :]
|
||||
max_num = self.max(part)
|
||||
x[idx, :, 0:2] = max_num
|
||||
idx = idx + 1
|
||||
return x
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return C.grad_all(self.net)(*inputs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(2), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_with_param_forward():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
while idx < end:
|
||||
part = x[idx, :, :]
|
||||
max_num = self.max(part)
|
||||
x[idx, :, 0:2] = max_num
|
||||
out = out + x + self.param
|
||||
idx = idx + 1
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
net = MyWhileNet()
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(2), dtype=ms.int32)
|
||||
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_endless_case():
|
||||
"""endless case when optmization"""
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
while idx < end:
|
||||
part = x[idx, :, :]
|
||||
out = out + part
|
||||
idx = idx + 1
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
net = MyWhileNet()
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(2), dtype=ms.int32)
|
||||
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_with_param_grad():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
while idx < end:
|
||||
part = x[idx, :, :]
|
||||
max_num = self.max(part)
|
||||
x[idx, :, 0:2] = max_num
|
||||
out = out + x + self.param
|
||||
idx = idx + 1
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(2), dtype=ms.int32)
|
||||
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_with_param_forward_with_const_branch():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
self.reduce = P.ReduceSum()
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
while idx < end:
|
||||
if 2 > 1:
|
||||
out = out + self.param
|
||||
else:
|
||||
out = out + idx + self.param
|
||||
idx = idx + 1
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = while_net
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(4), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_opt_endless():
|
||||
"""endless during optimization case"""
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
self.reduce = P.ReduceSum()
|
||||
self.addn = P.AddN()
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
addn1 = self.addn((x, x, x))
|
||||
out = addn1
|
||||
while idx < end:
|
||||
out = self.addn((out, addn1))
|
||||
idx = idx + 1
|
||||
out = self.addn((out, x))
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return C.grad_all(self.net)(*inputs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(4), dtype=ms.int32)
|
||||
x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_no_while_call():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
self.reduce = P.ReduceSum()
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
if 2 > 1:
|
||||
out = out + self.param
|
||||
else:
|
||||
out = out + idx + self.param
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = while_net
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(4), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_with_param_grad_with_const_branch():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
self.reduce = P.ReduceSum()
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
while idx < end:
|
||||
if 2 > 1:
|
||||
out = out + self.param
|
||||
else:
|
||||
out = out + idx + self.param
|
||||
idx = idx + 1
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(4), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_for_while_with_param_grad_with_const_branch():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
self.reduce = P.ReduceSum()
|
||||
self.start = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
for _ in range(0, 2):
|
||||
idx = self.start
|
||||
while idx < end:
|
||||
if 2 > 1:
|
||||
out = out + self.param
|
||||
else:
|
||||
out = out + idx + self.param
|
||||
idx = idx + 1
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(4), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_for_while_with_param_grad_basic():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
self.reduce = P.ReduceSum()
|
||||
self.start = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
for _ in range(0, 2):
|
||||
idx = self.start
|
||||
while idx < end:
|
||||
out = out + self.param
|
||||
idx = idx + 1
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(4), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_for_while_with_param_grad_normal():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
self.reduce = P.ReduceSum()
|
||||
self.start = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = x
|
||||
for _ in range(0, 2):
|
||||
idx = self.start
|
||||
while idx < end:
|
||||
out = out + self.param
|
||||
idx = idx + 1
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(4), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_with_param_basic_grad():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
self.t2 = Tensor(np.array(2), dtype=ms.float32)
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
while idx < end:
|
||||
out = out + self.param
|
||||
idx = idx + 1
|
||||
return out + self.param
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(3), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_with_param_basic_grad_mul():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32)
|
||||
self.t2 = Tensor(np.array(2), dtype=ms.float32)
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
while idx < end:
|
||||
out = out * self.param
|
||||
idx = idx + 1
|
||||
return out + self.param
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(3), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_with_param_basic_grad_two():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
self.t2 = Tensor(np.array(2), dtype=ms.float32)
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
while idx < end:
|
||||
out = out + self.param + self.weight
|
||||
idx = idx + 1
|
||||
return out + self.param
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(3), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_with_param_basic_grad_three():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
|
||||
self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
self.t2 = Tensor(np.array(2), dtype=ms.float32)
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
while idx < end:
|
||||
out = out + self.param + self.weight + self.key
|
||||
idx = idx + 1
|
||||
return out + self.param
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(3), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_if_with_param_grad():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
self.t2 = Tensor(np.array(2), dtype=ms.float32)
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
while idx < end:
|
||||
if self.max(out) < self.max(x):
|
||||
out = out + self.param * 2
|
||||
else:
|
||||
out = out + self.param
|
||||
idx = idx + 1
|
||||
return out + self.param
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(3), dtype=ms.int32)
|
||||
x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_while_with_param_grad_not_enter_while():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
|
||||
def construct(self, idx, end, x):
|
||||
out = self.zero
|
||||
while idx < end:
|
||||
out = out + self.param * 3
|
||||
idx = idx + 1
|
||||
return out + self.param
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(3), dtype=ms.int32)
|
||||
end = Tensor(np.array(0), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_with_param_if_by_if_forward():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
|
||||
def construct(self, a, b, x):
|
||||
out = self.zero
|
||||
if a < b:
|
||||
out = out + x + self.param
|
||||
else:
|
||||
out = out + x
|
||||
if a == b:
|
||||
out = out + x*3 + self.param
|
||||
else:
|
||||
out = out + x*2
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
if_net = MyIfByIfNet()
|
||||
net = if_net
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(4), dtype=ms.int32)
|
||||
x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_with_param_if_by_if_grad_inputs():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
|
||||
def construct(self, a, b, x):
|
||||
out = self.zero
|
||||
if a < b:
|
||||
out = out + x + self.param * 4
|
||||
if a == b:
|
||||
out = out + x*3 + self.param * 3
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return C.grad_all(self.net)(*inputs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
if_net = MyIfByIfNet()
|
||||
net = GradNet(if_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(0), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_with_param_if_by_if_grad_parameter():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
|
||||
def construct(self, a, b, x):
|
||||
out = self.zero
|
||||
if a < b:
|
||||
out = out + x + self.param * 2
|
||||
if a == b:
|
||||
out = out + x*3 + self.param
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, *inputs):
|
||||
return C.grad_by_list(self.net, self.weights)(*inputs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
if_net = MyIfByIfNet()
|
||||
net = GradNet(if_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(2), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_with_param_if_by_if_grad_param_excute_null():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
|
||||
def construct(self, a, b, x):
|
||||
out = self.zero
|
||||
if a < b:
|
||||
out = out + x + self.param * 2
|
||||
return out
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, *inputs):
|
||||
return C.grad_by_list(self.net, self.weights)(*inputs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
if_net = MyIfByIfNet()
|
||||
net = GradNet(if_net)
|
||||
idx = Tensor(np.array(4), dtype=ms.int32)
|
||||
end = Tensor(np.array(0), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_if_by_if_return_inside_grad():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max = P.ReduceMax()
|
||||
self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
|
||||
self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
|
||||
|
||||
def construct(self, a, b, x):
|
||||
out = self.zero
|
||||
if a < b:
|
||||
return out + x + self.param
|
||||
if a == b:
|
||||
return out + self.param * 2
|
||||
return out + self.param * 3
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(net.trainable_params())
|
||||
|
||||
def construct(self, *inputs):
|
||||
return C.grad_by_list(self.net, self.weights)(*inputs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
if_net = MyIfByIfNet()
|
||||
net = GradNet(if_net)
|
||||
idx = Tensor(np.array(1), dtype=ms.int32)
|
||||
end = Tensor(np.array(0), dtype=ms.int32)
|
||||
x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_if_by_if_forward():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.RealDiv()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
if a < b:
|
||||
a = self.add(a, b)
|
||||
else:
|
||||
a = self.sub(a, b)
|
||||
if a == x:
|
||||
a = self.mul(a, b)
|
||||
else:
|
||||
a = self.div(a, b)
|
||||
if b == x:
|
||||
b = self.add(a, b)
|
||||
else:
|
||||
b = self.add(a, x)
|
||||
a = a * b
|
||||
out = a + b + x
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
if_net = MyIfByIfNet()
|
||||
net = if_net
|
||||
idx = Tensor(np.array(2), dtype=ms.float32)
|
||||
end = Tensor(np.array(3), dtype=ms.float32)
|
||||
x = Tensor(np.array(4), dtype=ms.float32)
|
||||
net(idx, end, x)
|
|
@ -58,6 +58,7 @@ add_subdirectory(serving)
|
|||
|
||||
file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
"../../../mindspore/core/base/*.cc"
|
||||
"../../../mindspore/core/gvar/*.cc"
|
||||
"../../../mindspore/core/abstract/*.cc"
|
||||
"../../../mindspore/core/ir/*.cc"
|
||||
"../../../mindspore/core/utils/*.cc"
|
||||
|
|
|
@ -34,7 +34,6 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
|||
from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
|
||||
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
|
||||
|
||||
|
||||
class InputBackward(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(InputBackward, self).__init__()
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
CURRPATH=$(cd $(dirname $0); pwd)
|
||||
IGNORE_EXEC="--ignore=$CURRPATH/exec"
|
||||
PROJECT_PATH=$(cd ${CURRPATH}/../../..; pwd)
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# ============================================================================
|
||||
"""Generate vm_impl function for array ops"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
@ -22,7 +21,6 @@ from mindspore.ops.operations import _grad_ops as G
|
|||
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
|
||||
from .vm_interface import vm
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
|
||||
|
||||
|
@ -181,8 +179,7 @@ def vm_impl_tile(self):
|
|||
|
||||
def vm_impl(x, multiples):
|
||||
x = x.asnumpy()
|
||||
multiples = multiples.asnumpy()
|
||||
out = vm.Tile(x, multiples)
|
||||
out = np.tile(x, multiples)
|
||||
return Tensor(out)
|
||||
|
||||
return vm_impl
|
||||
|
@ -255,7 +252,10 @@ def vm_impl_sum(self):
|
|||
|
||||
def vm_impl(x, axis):
|
||||
x = x.asnumpy()
|
||||
out = vm.sum(x, axis)
|
||||
if axis == ():
|
||||
out = np.sum(x)
|
||||
else:
|
||||
out = np.sum(x, axis=axis)
|
||||
return Tensor(np.array(out))
|
||||
|
||||
return vm_impl
|
||||
|
@ -291,12 +291,14 @@ def vm_impl_square(self):
|
|||
|
||||
return vm_impl
|
||||
|
||||
|
||||
@vm_impl_getters.register(P.ZerosLike)
|
||||
def vm_impl_zeros_like(self):
|
||||
"""Generate vm_impl function for ZerosLike"""
|
||||
def vm_impl(x):
|
||||
return Tensor(np.zeros_like(x.asnumpy()))
|
||||
|
||||
|
||||
@vm_impl_getters.register(P.Partial)
|
||||
def vm_impl_partial(self):
|
||||
"""Generate vm_impl function for Partial"""
|
||||
|
@ -307,6 +309,7 @@ def vm_impl_partial(self):
|
|||
|
||||
return vm_impl
|
||||
|
||||
|
||||
@vm_impl_getters.register(P.Depend)
|
||||
def vm_impl_depend(self):
|
||||
"""Generate vm_impl function for Depend"""
|
||||
|
|
|
@ -196,6 +196,18 @@ def vm_impl_reduce_mean(self):
|
|||
|
||||
return vm_impl
|
||||
|
||||
@vm_impl_getters.register(P.ReduceMax)
|
||||
def vm_impl_reduce_max(self):
|
||||
"""Generate vm_impl function for ReduceMean."""
|
||||
|
||||
def vm_impl(x, axis):
|
||||
x = x.asnumpy()
|
||||
if axis == ():
|
||||
axis = None
|
||||
out = np.amax(x, axis)
|
||||
return Tensor(out)
|
||||
|
||||
return vm_impl
|
||||
|
||||
@vm_impl_getters.register(P.Equal)
|
||||
def vm_impl_equal(self):
|
||||
|
|
Loading…
Reference in New Issue