support if by if grad parameter

add join for ref

adjust env eliminate to eliminate all env ops

add partial app cache

resolve while endless

fix env eliminate

support for "for while" cases

fix join shape error
This commit is contained in:
huangdongrun 2020-07-09 08:29:07 +08:00
parent eb4749642d
commit 2a6d346d2f
27 changed files with 1259 additions and 112 deletions

View File

@ -330,7 +330,8 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_gra
}
oss << "SymInst(%para" << idx << ")";
} else {
MS_LOG(EXCEPTION) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString();
MS_LOG(WARNING) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString();
oss << "SymInst(cnode_" << sym_node->ToString() << ")";
}
return oss.str();

View File

@ -189,6 +189,11 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
if (!morph->isa<CNode>()) {
return nullptr;
}
// for free variable, which may be handled in MapValueObject, just return it
auto node_adjoint_found = anfnode_to_adjoin_.find(morph);
if (node_adjoint_found != anfnode_to_adjoin_.end()) {
return node_adjoint_found->second;
}
ScopeGuard scope_guard(morph->scope());
auto cnode_morph = morph->cast<CNodePtr>();
@ -502,7 +507,7 @@ void DFunctor::MapFvObject() {
if (parent_adjoint != nullptr) {
adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
} else {
if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) {
if (is_top_ || node->isa<Parameter>()) {
// Out of ad scope, add adjoint for free variables.
adjoint = std::make_shared<Adjoint>(node, node, tape_);
UpdateAdjoint(adjoint);

View File

@ -87,10 +87,12 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
env_get_item_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_bypass_recursive_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(true), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
"incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
incorporate_env_getitem_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
// Ref eliminate
make_ref_eliminate_ =
@ -122,6 +124,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// inline
inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph);
replace_applicator_ =
MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
specialize_transform_ =

View File

@ -55,6 +55,7 @@ class OptimizeIRPassLib {
SubstitutionPtr env_get_item_eliminate_;
SubstitutionPtr new_env_get_item_;
SubstitutionPtr incorporate_env_getitem_;
SubstitutionPtr incorporate_env_getitem_bypass_recursive_;
SubstitutionPtr incorporate_env_getitem_switch_;
// Ref eliminate
@ -80,6 +81,7 @@ class OptimizeIRPassLib {
// inline
SubstitutionPtr inline_;
SubstitutionPtr inline_without_move_;
SubstitutionPtr replace_applicator_;
SubstitutionPtr specialize_transform_;
@ -193,6 +195,16 @@ inline bool IsCNodeDup(const AnfNodePtr &node) {
auto inp0 = node->cast<CNodePtr>()->input(0);
return (inp0 != nullptr) && inp0->isa<CNode>();
}
// Returns true only when `node` is non-null, is a CNode, and IsPrimitiveCNode
// reports it as an application of prim::kPrimSwitch; returns false otherwise.
inline bool IsCNodeSwitch(const AnfNodePtr &node) {
  const bool is_cnode = (node != nullptr) && node->isa<CNode>();
  return is_cnode && IsPrimitiveCNode(node, prim::kPrimSwitch);
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -29,6 +29,7 @@
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/inline.h"
#include "frontend/optimizer/optimizer.h"
#include "utils/symbolic.h"
@ -59,8 +60,13 @@ class EnvGetitemTransform {
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
// {prim::kPrimEnvSetItem, env, symbolickey, value}
auto &inputs = env->cast<CNodePtr>()->inputs();
if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) {
MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance.";
if (inputs.size() != 4) {
MS_LOG(WARNING) << "Input size should be 4";
return nullptr;
}
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
return nullptr;
}
env = inputs[1];
@ -91,33 +97,12 @@ class EnvGetitemTransform {
class NewEnvGetItem : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
auto gety = [this](const AnfNodePtr &node) -> bool {
this->y_ = node;
return true;
};
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode<EnvInstance>, IsVNode, gety})(node);
if (env_ != nullptr && env_->Len() == 0) {
return y_;
}
PatternNode c1, c2, y;
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvGetItem, c1, c2, y), y,
(IsValueNode<EnvInstance>(c1.GetNode(node)) && IsVNode(c2.GetNode(node)) &&
(GetValueNode<EnvInstancePtr>(c1.GetNode(node)))->Len() == 0));
return nullptr;
}
void Visit(const ValueNodePtr &vnode) override {
if (env_ == nullptr) {
env_ = GetValueNode<EnvInstancePtr>(vnode);
}
}
void Reset() {
y_ = nullptr;
env_ = nullptr;
}
private:
AnfNodePtr y_{nullptr};
EnvInstancePtr env_{nullptr};
};
// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} ->
@ -205,8 +190,13 @@ class EnvGetSetItem : public AnfVisitor {
while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
// {prim::kPrimEnvSetItem, env, symbolickey, value}
auto &inputs = env->cast<CNodePtr>()->inputs();
if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) {
MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance.";
if (inputs.size() != 4) {
MS_LOG(WARNING) << "Input size should be 4";
return nullptr;
}
if (!IsValueNode<SymbolicKeyInstance>(inputs[2])) {
MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?";
return nullptr;
}
env = inputs[1];
@ -257,7 +247,8 @@ class EnvGetItemEliminater : public OptimizerCaller {
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
class IncorporateEnvGetitem : public AnfVisitor {
public:
IncorporateEnvGetitem() : env_get_item_transform_() {}
explicit IncorporateEnvGetitem(bool bypass_recursive = false)
: env_get_item_transform_(), bypass_recursive_(bypass_recursive) {}
~IncorporateEnvGetitem() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -285,7 +276,13 @@ class IncorporateEnvGetitem : public AnfVisitor {
auto inputs = inp1->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
auto new_fg = env_get_item_transform_(fg, key, default_v);
if (fg->recursive() && bypass_recursive_) {
MS_LOG(DEBUG) << "Bypass env_get_item transform for recursive fg=" << fg->ToString();
return nullptr;
}
if (new_fg == nullptr) {
return nullptr;
}
std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(new_fg));
(void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
@ -298,6 +295,7 @@ class IncorporateEnvGetitem : public AnfVisitor {
private:
bool is_match_{false};
internal::EnvGetitemTransform env_get_item_transform_;
bool bypass_recursive_;
};
// {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y}
@ -342,7 +340,9 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
auto g2 = GetValueNode<FuncGraphPtr>(sw->input(3));
auto new_g1 = env_get_item_transform_(g1, key, default_v);
auto new_g2 = env_get_item_transform_(g2, key, default_v);
if (new_g1 == nullptr || new_g2 == nullptr) {
return nullptr;
}
auto fg = node->func_graph();
auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)});

View File

@ -93,10 +93,22 @@ bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); }
bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; }
// Inline criterion: holds when `fg` is recursive, has a parent graph, that parent
// is the graph owning `node`, and IsUniqueUse(fg, nullptr) is true.
bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) {
  const bool single_use = IsUniqueUse(fg, nullptr);
  const bool recursive = fg->recursive();
  // Without a parent, or for a non-recursive graph, the criterion never applies.
  if (fg->parent() == nullptr || !recursive) {
    return false;
  }
  return fg->parent() == node->func_graph() && single_use;
}
// {G, Xs}
class InlinerBase : public AnfVisitor {
public:
explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions) : criterions_(criterions) {}
explicit InlinerBase(std::vector<std::pair<CriterionFuncType, bool>> criterions, bool use_move = true)
: use_move_(use_move), criterions_(criterions) {}
~InlinerBase() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!node->isa<CNode>()) {
@ -113,6 +125,7 @@ class InlinerBase : public AnfVisitor {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
return nullptr;
}
// Do not inline GraphKernel to Cell.
if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
// If the GraphKernel only contains a return node, we make it inlined.
@ -142,8 +155,12 @@ class InlinerBase : public AnfVisitor {
std::vector<AnfNodePtr> params;
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
if (IsUniqueUse(fg, nullptr)) {
// compare size to avoid the case that the function has default value after grad.
// for which after renormalize, the function default value will be an input
if (fg->parameters().size() != params.size()) {
return nullptr;
}
if (use_move_ && IsUniqueUse(fg, nullptr)) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, params, fg);
@ -183,21 +200,36 @@ class InlinerBase : public AnfVisitor {
private:
bool is_checked_{false}, is_recursive_{false};
bool use_move_;
std::vector<std::pair<CriterionFuncType, bool>> criterions_;
};
class Inliner : public InlinerBase {
public:
Inliner()
: InlinerBase({
{IsUniqueUse, true},
{IsTrivial, false},
{IsInside, false},
{IsCore, false},
{NoCriterion, true},
}) {}
explicit Inliner(bool use_move = true)
: InlinerBase(
{
{IsUniqueUse, true},
{IsTrivial, false},
{IsInside, false},
{IsCore, false},
{IsDirectParentCall, false},
{NoCriterion, true},
},
use_move) {}
~Inliner() override = default;
};
// Inliner variant whose criterion list contains a single entry, IsDirectParentCall.
// NOTE(review): the meaning of the bool paired with each criterion is defined by
// InlinerBase, which is not fully visible here — confirm before changing it.
class DirectInliner : public InlinerBase {
public:
// `use_move` is forwarded unchanged to InlinerBase; it defaults to true.
explicit DirectInliner(bool use_move = true)
: InlinerBase(
{
{IsDirectParentCall, false},
},
use_move) {}
~DirectInliner() override = default;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -26,6 +26,30 @@
namespace mindspore {
namespace opt {
namespace irpass {
namespace internal {
// Functor that rewrites a call node `{G, Xs}` so that the called graph itself
// yields GetRefValue of its former output.
//
// The callee graph is cloned, the clone's output is wrapped in
// prim::kPrimGetRefValue, and a new call node to the clone (same arguments) is
// created in the caller's graph. A recursive callee is left untouched and the
// input `node` is returned as is.
class GetRefValueTransform {
 public:
  GetRefValueTransform() = default;
  ~GetRefValueTransform() = default;

  // `node` is expected to be a CNode whose input 0 is a FuncGraph value node;
  // neither condition is verified here.
  AnfNodePtr operator()(const AnfNodePtr &node) {
    CNodePtr call = node->cast<CNodePtr>();
    // Work on a copy of the inputs so the original call node is not modified.
    auto call_inputs = call->inputs();
    auto callee = GetValueNode(call_inputs[0])->cast<FuncGraphPtr>();
    if (callee->recursive()) {
      MS_LOG(DEBUG) << "Get refvalue by pass recursive:" << callee->ToString();
      return node;
    }

    auto cloned = TransformableClone(callee, std::make_shared<TraceTransform>("GetRefValue"));
    auto previous_output = cloned->output();
    auto wrapped_output = cloned->NewCNode({NewValueNode(prim::kPrimGetRefValue), previous_output});
    cloned->set_output(wrapped_output);

    call_inputs[0] = NewValueNode(cloned);
    return call->func_graph()->NewCNode(call_inputs);
  }
};
} // namespace internal
// {prim::kPrimMakeRef, X, Y, Z} -> Y
class MakeRefEliminater : public OptimizerCaller {
public:
@ -48,13 +72,23 @@ class GetRefParamEliminater : public OptimizerCaller {
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefValue, {prim::switch, cond, t, f}} -> {prim::switch, cond, t, f}
// Simplifies GetRefKey / GetRefValue applications:
//   GetRefKey(MakeRef(a, b, c))   -> a
//   GetRefValue(MakeRef(a, b, c)) -> b
//   GetRefValue(s), s a switch CNode      -> s
//   GetRefValue(call), call a graph call  -> result of internal::GetRefValueTransform
// Returns nullptr when none of the patterns yields a replacement.
class GetMakeRefEliminater : public OptimizerCaller {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    // `first` is deliberately reused by every pattern below.
    PatternNode<AnfNodePtr> first, second, third;
    MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, first, second, third)), first);
    MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, first, second, third)),
                  second);
    MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, first), first, first.CheckFunc(IsCNodeSwitch, node));

    internal::GetRefValueTransform transform;
    // A null result from the transform is passed through unchanged.
    auto push_into_callee = [&transform, &first, &node]() -> AnfNodePtr { return transform(first.GetNode(node)); };
    MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetRefValue, first), push_into_callee,
                            first.CheckFunc(IsCNodeGraph, node));
    return nullptr;
  }
};

View File

@ -314,6 +314,7 @@ bool TaskEmitAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
auto context_ptr = MsContext::GetInstance();
std::string backend = MsContext::GetInstance()->backend_policy();
MS_EXCEPTION_IF_NULL(context_ptr);
if (CompileGraphs::ContainMixedTarget(func_graph)) {
bc_ptr->set_is_multi_graph_sink(false);
@ -321,13 +322,13 @@ bool TaskEmitAction(const ResourcePtr &res) {
context_ptr->set_loop_sink_flag(false);
} else if (context_ptr->execution_mode() != kPynativeMode) {
std::string device_target = context_ptr->device_target();
if (device_target == kAscendDevice) {
if (device_target == kAscendDevice && backend != kMsVm) {
bc_ptr->set_is_multi_graph_sink(true);
context_ptr->set_is_multi_graph_sink(true);
}
}
if (IsCtrlSink()) {
if (IsCtrlSink() && backend == kMsConvert) {
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
return true;
}
@ -344,8 +345,8 @@ bool ExecuteAction(const ResourcePtr &res) {
if (res->results().count(kOutput) == 0) {
MS_LOG(EXCEPTION) << "Execute args error";
}
if (IsCtrlSink()) {
std::string backend = MsContext::GetInstance()->backend_policy();
if (IsCtrlSink() && backend == kMsConvert) {
if (!res->results()[kOutput].is<GraphId>()) {
MS_LOG(EXCEPTION) << "Execute args error";
}

View File

@ -30,6 +30,7 @@
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/validator.h"
#include "pipeline/jit/remove_value_node_dup.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/cse.h"
#include "frontend/optimizer/graph_kernel_reuse.h"
@ -127,11 +128,14 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.incorporate_getitem_set_,
irpass.incorporate_call_,
irpass.incorporate_call_switch_,
irpass.incorporate_env_getitem_,
irpass.incorporate_env_getitem_bypass_recursive_,
irpass.incorporate_env_getitem_switch_,
irpass.new_env_get_item_,
irpass.depend_value_elim_,
});
opt::OptPassConfig a_after_grad = opt::OptPassConfig({
irpass.inline_without_move_,
});
opt::OptPassConfig a_3 = opt::OptPassConfig({
irpass.arithmetic_simplify2_,
irpass.same_eliminate_,
@ -154,6 +158,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
{"virtual_dataset", virtual_dataset},
{"grad", grad},
{"resolve", resolve_pass},
{"a_after_grad", a_after_grad},
{"renormalize", opt::OptPassConfig::Renormalize()},
{"cse", opt::OptPassConfig(opt::CSE(false))},
{"a_3", a_3}});
@ -161,11 +166,24 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
return map_a;
}
// Builds the pass groups for the "opt_after_cconv" optimizer: group "c_1"
// (inline_ and partial_eliminate_) followed by a "renormalize" group.
OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
  // Safe inlining
  opt::OptPassConfig inline_and_partial = opt::OptPassConfig({irpass.inline_, irpass.partial_eliminate_});
  return OptPassGroupMap({{"c_1", inline_and_partial}, {"renormalize", opt::OptPassConfig::Renormalize()}});
}
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig b_1 =
opt::OptPassConfig({irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_,
irpass.get_make_ref_eliminate_, irpass.value_based_eliminate_});
opt::OptPassConfig b_1 = opt::OptPassConfig(
{irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.value_based_eliminate_});
opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_,
@ -244,6 +262,8 @@ void InitOpt(const ResourcePtr &res) {
opt::irpass::OptimizeIRPassLib irpass;
g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
g_pass_opts["opt_after_cconv"] =
Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
g_pass_opts["opt_graph_kernel_a"] =
Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
g_pass_opts["opt_graph_kernel_b"] =
@ -288,6 +308,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
@ -311,6 +332,33 @@ bool AddControlDependPass(const ResourcePtr &res) {
return true;
}
// Pipeline pass wrapper around MergeDuplicateGraphs.
// Raises if the resource has no func graph or no manager. When the manager holds
// at most one graph there is nothing to merge and the pass succeeds immediately;
// otherwise the result of MergeDuplicateGraphs is returned.
bool MergeDupGraphPass(const ResourcePtr &res) {
  FuncGraphPtr func_graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(res->manager());
  const bool has_multiple_graphs = res->manager()->func_graphs().size() > 1;
  return has_multiple_graphs ? MergeDuplicateGraphs(res->manager()) : true;
}
bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) {
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Remove value node duplications error.";
}
auto manager = res->manager();
HashCache hash_cache;
HashValue hashes;
// Remove duplicated value nodes across all graphs in manager
for (auto &fg : manager->func_graphs()) {
auto value_nodes = fg->value_nodes();
for (const auto &value_pair : value_nodes) {
TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
}
}
return true;
}
bool CconvPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
@ -340,6 +388,8 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{"clean_after_opta", CleanAfterOptAPass},
{"opt_b", OptPassBGroup},
{"cconv", CconvPass},
{"opt_after_cconv", OptPassAfterCconvGroup},
{"remove_dup_value", RemoveValueNodeDuplicationsPass},
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_control_depend", AddControlDependPass}};

View File

@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "pipeline/jit/remove_value_node_dup.h"
#include "ir/anf.h"
@ -70,5 +71,108 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has
// Meet for the first time, append node to bucket.
bucket.emplace_back(node);
}
// Computes a structural hash of `fg` by hashing its nodes in topological order
// (starting from the return node) and returning the hash of the return node.
//
// - The graph's own parameters are hashed by position ("param<i>"), not identity.
// - Value nodes: FuncGraph values use the value hash; tensors mix in the value
//   pointer; every other value mixes in the hash of the node's abstract.
// - CNodes combine the hashes of their inputs in order.
// - Any other Parameter (not one of `fg`'s own) uses the node hash.
// Nodes of an unknown kind are logged and hashed as 0.
size_t HashOfGraph(const FuncGraphPtr &fg) {
  MS_EXCEPTION_IF_NULL(fg);
  std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
  MS_LOG(DEBUG) << "TopSort for:" << fg->ToString();
  std::unordered_map<AnfNodePtr, std::size_t> hashes;
  auto &params = fg->parameters();
  for (size_t i = 0; i < params.size(); i++) {
    hashes[params[i]] = std::hash<std::string>{}("param" + std::to_string(i));
  }
  for (const auto &node : toposet) {
    MS_EXCEPTION_IF_NULL(node);
    // Already hashed (e.g. one of the graph's own parameters).
    if (hashes.find(node) != hashes.end()) {
      continue;
    }

    std::size_t h = 0;
    if (node->isa<ValueNode>()) {
      ValueNodePtr value_node = node->cast<ValueNodePtr>();
      auto value = value_node->value();
      MS_EXCEPTION_IF_NULL(value);
      if (IsValueNode<FuncGraph>(value_node)) {
        h = value->hash();
      } else if (IsValueNode<tensor::Tensor>(value_node)) {
        // Tensors holding the same value have already been merged by the
        // duplicate-value pass, so the value pointer is used as an identifier.
        h = hash_combine(value->hash(), std::hash<Value *>{}(value.get()));
      } else {
        h = hash_combine(value->hash(), (opt::AbsOf(value_node)->hash()));
      }
    } else if (node->isa<CNode>()) {
      auto cnode = node->cast<CNodePtr>();
      // Fold the input hashes in order, starting from 0. operator[] is kept on
      // purpose: an input that has no hash yet is recorded with hash 0.
      for (const auto &node_in : cnode->inputs()) {
        h = hash_combine(h, hashes[node_in]);
      }
    } else if (node->isa<Parameter>()) {
      h = node->hash();
    } else {
      MS_LOG(ERROR) << "Unknow node type";
    }
    hashes[node] = h;
  }
  return hashes[fg->get_return()];
}
bool IsCNodeGraph(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) {
return false;
}
auto inp0 = node->cast<CNodePtr>()->input(0);
return IsValueNode<FuncGraph>(inp0);
}
bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager) {
std::unordered_map<size_t, std::vector<FuncGraphPtr>> hash_graphs;
std::unordered_map<FuncGraphPtr, size_t> graph_hash;
for (auto fg : manager->func_graphs()) {
size_t h = HashOfGraph(fg);
graph_hash[fg] = h;
if (hash_graphs.find(h) == hash_graphs.end()) {
hash_graphs[h] = {fg};
} else {
hash_graphs[h].push_back(fg);
}
}
FuncGraphPairMapEquiv equiv_graph;
NodeMapEquiv equiv_node;
for (auto &fg : manager->func_graphs()) {
MS_LOG(DEBUG) << "Try Merge Graph:" << fg->ToString();
for (auto &item : fg->nodes()) {
if (!item->isa<CNode>()) {
continue;
}
auto &inputs = item->cast<CNodePtr>()->inputs();
for (size_t i = 0; i < inputs.size(); i++) {
if (!inputs[i]->isa<ValueNode>()) {
continue;
}
auto value_ptr = GetValueNode(inputs[i]);
auto v_fg = value_ptr->cast<FuncGraphPtr>();
if (v_fg == nullptr) {
continue;
}
auto &fg_vec = hash_graphs[graph_hash[v_fg]];
if (fg_vec.size() > 1) {
if (v_fg != fg_vec[0]) {
bool is_morphic = Isomorphic(v_fg, fg_vec[0], &equiv_graph, &equiv_node);
if (is_morphic) {
auto new_node = NewValueNode(fg_vec[0]);
MS_LOG(DEBUG) << "Replace graph node :" << inputs[i]->ToString() << " with:" << new_node->ToString();
manager->Replace(inputs[i], new_node);
}
}
}
}
}
}
return true;
}
} // namespace pipeline
} // namespace mindspore

View File

@ -28,6 +28,10 @@ using HashCache = std::unordered_map<std::size_t, std::vector<AnfNodePtr>>;
using HashValue = std::unordered_map<AnfNodePtr, std::size_t>;
void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value);
size_t HashOfGraph(const FuncGraphPtr &fg);
bool IsCNodeGraph(const AnfNodePtr &node);
bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager);
} // namespace pipeline
} // namespace mindspore

View File

@ -113,17 +113,18 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
}
const AnfNodePtr &func_node = fg->get_return();
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString()
MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString()
<< ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString();
AbstractBasePtr ret_base = nullptr;
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
const auto &node = *it;
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString()
<< ", node_conf: " << node_conf->ToString();
ret_base = engine->GetEvaluatedValue(node_conf)->abstract();
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
<< ", abstract: " << ret_base->ToString();
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString()
<< ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
}
MS_EXCEPTION_IF_NULL(ret_base);
@ -142,16 +143,17 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list),
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg);
return arg->Broaden();
if (arg->GetValueTrack() != kAnyValue) {
return arg->Broaden();
}
return arg;
});
if (func_graph_->joined_shapes_.size() != broaded_list.size()) {
MS_EXCEPTION(ValueError) << "Number of input arguments " << broaded_list.size()
<< " does not equal to number of original buffer arguments "
<< func_graph_->joined_shapes_.size();
}
for (size_t i = 0; i < broaded_list.size(); ++i) {
broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
if (func_graph_->joined_shapes_.size() == broaded_list.size()) {
for (size_t i = 0; i < broaded_list.size(); ++i) {
broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]);
}
}
MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list)
<< ", broaded: " << mindspore::ToString(broaded_list);
return broaded_list;
@ -181,8 +183,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
func_graph_->joined_shapes_.clear();
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
std::back_inserter(func_graph_->joined_shapes_),
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
if (arg_spec->isa<AbstractRef>()) {
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
}
return arg_spec->GetShapeTrack();
});
joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
}
return joined_args_spec_list;
@ -199,8 +206,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
func_graph_->joined_shapes_.clear();
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
std::back_inserter(func_graph_->joined_shapes_),
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
if (arg_spec->isa<AbstractRef>()) {
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
}
return arg_spec->GetShapeTrack();
});
joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
}
MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);

View File

@ -188,6 +188,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
trace::TraceEvalCNodeLeave();
} else {
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()
<< (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph")
<< ". NodeInfo: " << trace::GetDebugInfo(node->debug_info());
}
@ -301,6 +302,8 @@ void AnalysisEngine::Clear() {
anfnode_config_map_.clear();
eval_trace_.clear();
constructors_.clear();
constructors_app_.clear();
continued_evals_.clear();
}
namespace {
@ -426,8 +429,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstr
MS_EXCEPTION_IF_NULL(func);
AbstractFunctionPtr func_orig = func->fn();
EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig);
auto part_pair = std::make_pair(func_orig, func->args());
auto itr = constructors_app_.find(part_pair);
if (itr != constructors_app_.end()) {
return itr->second;
}
std::shared_ptr<PartialAppEvaluator> partial_evaluator =
std::make_shared<PartialAppEvaluator>(evaluator_orig, func->args());
constructors_app_[part_pair] = partial_evaluator;
return partial_evaluator;
}
@ -504,9 +513,10 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) {
if (fg_eval == nullptr) {
return;
}
auto fg = fg_eval->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto undetermined_fgs = fg->recursive_graphs();
auto undetermined_fgs = fg->recursive();
if (undetermined_fgs) {
auto fg_parent = fg->parent();
MS_EXCEPTION_IF_NULL(fg_parent);
@ -546,15 +556,19 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPt
MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
for (auto u_eval : undetermined_evals) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
MS_LOG(DEBUG) << u_eval.first->ToString() << "check undetermined.";
auto &alternate_evaluator = multi_poss_[u_eval.first];
auto &eval_cache = alternate_evaluator->cache();
if ((!undetermined_evals.count(std::make_pair(alternate_evaluator, args_spec_list))) &&
(((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) ||
(eval_cache->find(args_spec_list) == eval_cache->end()))) {
MS_LOG(DEBUG) << u_eval.first->ToString() << "has undetermined.";
has_undetermined = true;
break;
}
}
if (has_undetermined == false) {
MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
MS_LOG(DEBUG) << eval->ToString() << "has no undetermined.";
*continue_flag = true;
return latest_entry;
}
@ -597,34 +611,33 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
auto current_inf = std::make_pair(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
if (it == eval_trace_.rend()) {
eval_trace_.push_back(current_inf);
MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
MS_EXCEPTION_IF_NULL(eval);
auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString();
out_specs.push_back(eval_result->abstract());
eval_trace_.pop_back();
if (eval_trace_.empty()) {
multi_poss_.clear();
}
} else if (it != eval_trace_.rbegin()) {
} else {
bool continue_flag = false;
auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag);
if (continue_flag) {
MS_LOG(DEBUG) << "continued_evals_ add " << current_inf.first.get() << current_inf.first->ToString();
continued_evals_.insert(current_inf);
continue;
}
// Try to travel the latest undetermined.
if (latest_entry != eval_trace_.rbegin()->first) {
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString();
auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString()
MS_LOG(DEBUG) << "end Direct Evaluator " << latest_entry->ToString()
<< " return out_spec: " << eval_result->abstract()->ToString();
return eval_result;
}

View File

@ -26,6 +26,7 @@
#include <vector>
#include <utility>
#include <map>
#include <set>
#ifdef DEBUG
#include <stack>
@ -113,7 +114,8 @@ class AnfNodeConfig : public Config {
// Render this config as a single "Node: <debug string>-uid(<unique id>), Context: <context>" string.
std::string ToString() const override {
  std::ostringstream buffer;
  // Emit the node/context description exactly once; streaming both the short form and the
  // uid-bearing form would duplicate the "Node: ..., Context: ..." text in the result.
  buffer << "Node: " << node_->DebugString() << "-uid(" << node_->UniqueId()
         << "), Context: " << context_->ToString();
  return buffer.str();
}
@ -173,7 +175,13 @@ struct AnalysisResult {
};
using EvalTraceRevIter = std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>>::reverse_iterator;
// Hash functor for (function abstract, argument abstract list) pairs; used as the hasher of the
// constructors_app_ unordered_map declared in AnalysisEngine.
struct PartialAppHasher {
  // Combines the hash of the function pointer with the hash of the argument list.
  std::size_t operator()(const std::pair<AbstractFunctionPtr, AbstractBasePtrList> &p) const {
    const std::size_t func_hash = std::hash<AbstractFunctionPtr>{}(p.first);
    const std::size_t args_hash = AbstractBasePtrListHash(p.second);
    // Order-dependent mix instead of a bare XOR: XOR is symmetric and cancels to 0 whenever the
    // two partial hashes are equal, which needlessly increases bucket collisions.
    return func_hash ^ (args_hash + static_cast<std::size_t>(0x9e3779b9) + (func_hash << 6) + (func_hash >> 2));
  }
};
class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
public:
AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager)
@ -233,10 +241,13 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
const PrimEvaluatorMap &prim_constructors_;
FuncGraphManagerPtr func_graph_manager_;
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_;
std::unordered_map<std::pair<AbstractFunctionPtr, AbstractBasePtrList>, EvaluatorPtr, PartialAppHasher>
constructors_app_;
AnfNodeConfigMap anfnode_config_map_;
// Use a list to trace multiple evaluators.
std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_;
std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> continued_evals_;
AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
const ConfigPtrList &args_conf_list);

View File

@ -34,6 +34,7 @@ using mindspore::abstract::AbstractError;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractJTagged;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractRef;
using mindspore::abstract::AbstractRowTensor;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSparseTensor;
@ -83,7 +84,8 @@ void ValidateAbstract(const AnfNodePtr &node) {
// only send string in external
if (!IsValueNode<StringImm>(node)) {
// Validate a type.
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString();
MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString()
<< " for node=" << node->DebugString();
}
}
return;
@ -96,7 +98,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractRowTensor>() ||
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>() || ptrBase->isa<AbstractRef>()) {
return;
}

View File

@ -481,8 +481,10 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
}
// Isomorphism
static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
NodeMapEquiv *const equiv_node) {
static bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
NodeMapEquiv *const equiv_node);
bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph,
NodeMapEquiv *const equiv_node) {
if (equiv_node == nullptr) {
MS_LOG(ERROR) << "Invalid equiv_node";
return false;
@ -514,6 +516,9 @@ static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, Fu
MS_LOG(DEBUG) << "two parameters are not equal.";
return false;
}
if (node1->isa<CNode>() && node2->isa<CNode>()) {
return SameNode(node1, node2, equiv_func_graph, equiv_node);
}
MS_LOG(ERROR) << "type error";
return false;
}

View File

@ -116,12 +116,15 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo
} // namespace
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
auto fg = std::make_shared<FuncGraph>();
AnfNodePtrList inputs;
AnfNodePtrToAnfNodePtrMap eqv;
if (lst.empty()) {
MS_LOG(EXCEPTION) << "Input anf node list is empty";
}
TraceManager::DebugTrace(
std::make_shared<TraceSegmentTransform>(lst[0]->cast<CNodePtr>()->func_graph()->debug_info()));
auto fg = std::make_shared<FuncGraph>();
TraceManager::EndTrace();
AnfNodePtrList inputs;
AnfNodePtrToAnfNodePtrMap eqv;
// Merge CNodes into a AnfGraph that represents a linear instruction segment
for (auto n : lst) {
if (!n->isa<CNode>()) {
@ -154,7 +157,9 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
}
TraceManager::DebugTrace(std::make_shared<TraceGetEnv>(n->debug_info()));
eqv[n] = fg->NewCNode(args);
TraceManager::EndTrace();
eqv[n]->set_abstract(n->abstract());
eqv[n]->set_kernel_info(n->kernel_info_ptr());
}

View File

@ -452,6 +452,10 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
}
auto other_tensor = dyn_cast<AbstractTensor>(other);
if (other_tensor == nullptr) {
auto ref_tensor = dyn_cast<AbstractRef>(other);
if (ref_tensor != nullptr) {
return this->Join(ref_tensor->ref());
}
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
}
if (*this == *other) {

View File

@ -48,7 +48,7 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
continue;
}
if (rank.find(node) != rank.end() && rank[node] != todo.size()) {
MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString();
MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2);
}
rank[node] = todo.size();
bool cont = false;

View File

@ -30,6 +30,7 @@
#include "base/base.h"
#include "ir/dtype.h"
#include "ir/dtype/number.h"
#include "utils/hashing.h"
using std::fabs;
@ -51,7 +52,7 @@ using ScalarPtr = std::shared_ptr<Scalar>;
class BoolImm : public Scalar {
public:
explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = std::hash<bool>{}(v_); }
explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = hash_combine({tid(), std::hash<bool>{}(v_)}); }
~BoolImm() override = default;
MS_DECLARE_PARENT(BoolImm, Scalar)
std::size_t hash() const override { return hash_; }
@ -91,7 +92,7 @@ class IntergerImm : public Scalar {
class Int8Imm : public IntergerImm {
public:
Int8Imm() : IntergerImm(kInt8), v_(0) {}
explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = std::hash<int>{}(v_); }
explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
~Int8Imm() override = default;
MS_DECLARE_PARENT(Int8Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
@ -117,7 +118,7 @@ IMM_TRAITS(Int8ImmPtr, int8_t)
class Int16Imm : public IntergerImm {
public:
Int16Imm() : IntergerImm(kInt16), v_(0) {}
explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = std::hash<int>{}(v_); }
explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
~Int16Imm() override = default;
MS_DECLARE_PARENT(Int16Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
@ -143,7 +144,7 @@ IMM_TRAITS(Int16ImmPtr, int16_t)
class Int32Imm : public IntergerImm {
public:
Int32Imm() : IntergerImm(kInt32), v_(0) {}
explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = std::hash<int>{}(v_); }
explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = hash_combine({tid(), std::hash<int>{}(v_)}); }
~Int32Imm() override = default;
MS_DECLARE_PARENT(Int32Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
@ -169,7 +170,7 @@ IMM_TRAITS(Int32ImmPtr, int32_t)
class Int64Imm : public IntergerImm {
public:
Int64Imm() : IntergerImm(kInt64), v_(0) {}
explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = std::hash<int64_t>{}(v_); }
explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = hash_combine({tid(), std::hash<int64_t>{}(v_)}); }
~Int64Imm() override = default;
MS_DECLARE_PARENT(Int64Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
@ -195,7 +196,9 @@ IMM_TRAITS(Int64ImmPtr, int64_t)
class UInt8Imm : public IntergerImm {
public:
UInt8Imm() : IntergerImm(kUInt8), v_(0) {}
explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) { hash_ = std::hash<unsigned int>{}(v_); }
explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) {
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
}
~UInt8Imm() override = default;
MS_DECLARE_PARENT(UInt8Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
@ -221,7 +224,9 @@ IMM_TRAITS(UInt8ImmPtr, uint8_t);
class UInt16Imm : public IntergerImm {
public:
UInt16Imm() : IntergerImm(kUInt16), v_(0) {}
explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) { hash_ = std::hash<unsigned int>{}(v_); }
explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) {
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
}
~UInt16Imm() override = default;
MS_DECLARE_PARENT(UInt16Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
@ -247,7 +252,9 @@ IMM_TRAITS(UInt16ImmPtr, uint16_t);
class UInt32Imm : public IntergerImm {
public:
UInt32Imm() : IntergerImm(kUInt32), v_(0) {}
explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) { hash_ = std::hash<unsigned int>{}(v_); }
explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) {
hash_ = hash_combine({tid(), std::hash<unsigned int>{}(v_)});
}
~UInt32Imm() override = default;
MS_DECLARE_PARENT(UInt32Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
@ -273,7 +280,9 @@ IMM_TRAITS(UInt32ImmPtr, uint32_t);
class UInt64Imm : public IntergerImm {
public:
UInt64Imm() : IntergerImm(kUInt64), v_(0) {}
explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) { hash_ = std::hash<uint64_t>{}(v); }
explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) {
hash_ = hash_combine({tid(), std::hash<uint64_t>{}(v)});
}
~UInt64Imm() override = default;
MS_DECLARE_PARENT(UInt64Imm, IntergerImm)
std::size_t hash() const override { return hash_; }
@ -308,7 +317,7 @@ using FloatImmPtr = std::shared_ptr<FloatImm>;
class FP32Imm : public FloatImm {
public:
FP32Imm() : FloatImm(kFloat32), v_(0.0) {}
explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = std::hash<float>{}(v_); }
explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = hash_combine({tid(), std::hash<float>{}(v_)}); }
~FP32Imm() override = default;
MS_DECLARE_PARENT(FP32Imm, FloatImm)
std::size_t hash() const override { return hash_; }
@ -334,7 +343,7 @@ IMM_TRAITS(FP32ImmPtr, float)
class FP64Imm : public FloatImm {
public:
FP64Imm() : FloatImm(kFloat64), v_(0.0) {}
explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = std::hash<double>{}(v_); }
explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = hash_combine({tid(), std::hash<double>{}(v_)}); }
~FP64Imm() override = default;
MS_DECLARE_PARENT(FP64Imm, FloatImm)
std::size_t hash() const override { return hash_; }

View File

@ -412,6 +412,16 @@ class TraceCombileLikeGraphs : public TraceInfo {
return std::make_shared<TraceCombileLikeGraphs>(*shared_from_base<TraceCombileLikeGraphs>());
}
};
// Trace info tagged "segment_transform"; TransformSegmentToAnfGraph pushes it while creating the
// FuncGraph that represents a node segment.
class TraceSegmentTransform : public TraceInfo {
 public:
  explicit TraceSegmentTransform(const DebugInfoPtr &info) : TraceInfo(info, "segment_transform", "") {}
  // Declare this class under its own name. The previous TraceGetEnv argument was a copy-paste
  // slip that registered this class under another class's identity.
  // NOTE(review): assumes the first macro argument must be the declaring class — confirm against the macro.
  MS_DECLARE_PARENT(TraceSegmentTransform, TraceInfo);
  ~TraceSegmentTransform() override = default;
  // Returns a copy of this trace info held by a new shared pointer.
  TraceInfoPtr clone() override {
    return std::make_shared<TraceSegmentTransform>(*shared_from_base<TraceSegmentTransform>());
  }
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_TRACE_INFO_H_

View File

@ -0,0 +1,816 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test control ops """
import numpy as np
from mindspore import dtype as ms
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P
# from tests.vm_impl.math_ops_vm_impl import *
# from tests.vm_impl.vm_interface import *
# from tests.vm_impl import *
# context.set_context(save_graphs=True)
def test_while_forward():
    """Forward-run a cell whose while loop writes each slice's maximum back into the input."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    forward_net = MyWhileNet()
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(2), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    forward_net(loop_begin, loop_end, data)
def test_while_grad():
    """Take grad_all through a cell whose while loop writes slice maxima back into the input."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, *inputs):
            return C.grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(2), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_while_with_param_forward():
    """Forward-run a while loop that accumulates the mutated input plus a Parameter."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    forward_net = MyWhileNet()
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(2), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    forward_net(loop_begin, loop_end, data)
def test_while_endless_case():
    """Forward-run a slice-accumulating while loop (case that was endless during optimization)."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                out = out + part
                idx = idx + 1
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    forward_net = MyWhileNet()
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(2), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    forward_net(loop_begin, loop_end, data)
def test_while_with_param_grad():
    """Take grad_by_list w.r.t. the trainable Parameter used inside a while loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return C.grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(2), dtype=ms.int32)
    data = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_while_with_param_forward_with_const_branch():
    """Forward-run a while loop whose body contains an if with a compile-time constant condition."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                # The condition is deliberately a literal constant.
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    forward_net = MyWhileNet()
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    forward_net(loop_begin, loop_end, data)
def test_while_opt_endless():
    """Take grad_all through an AddN-based while loop (case that was endless during optimization)."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.addn = P.AddN()

        def construct(self, idx, end, x):
            addn1 = self.addn((x, x, x))
            out = addn1
            while idx < end:
                out = self.addn((out, addn1))
                idx = idx + 1
            out = self.addn((out, x))
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, *inputs):
            return C.grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_no_while_call():
    """Forward-run a cell with only a constant-condition if and no while loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            # The condition is deliberately a literal constant.
            if 2 > 1:
                out = out + self.param
            else:
                out = out + idx + self.param
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    forward_net = MyWhileNet()
    first = Tensor(np.array(0), dtype=ms.int32)
    second = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    forward_net(first, second, data)
def test_while_with_param_grad_with_const_branch():
    """Take grad_by_list through a while loop whose body has a constant-condition if."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                # The condition is deliberately a literal constant.
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return C.grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_for_while_with_param_grad_with_const_branch():
    """Take grad_by_list through a for loop wrapping a while loop with a constant-condition if."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                # Reset the counter so the inner while loop runs again on each outer pass.
                idx = self.start
                while idx < end:
                    if 2 > 1:
                        out = out + self.param
                    else:
                        out = out + idx + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return C.grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_for_while_with_param_grad_basic():
    """Take grad_by_list through a for loop wrapping a plain accumulating while loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                # Reset the counter so the inner while loop runs again on each outer pass.
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return C.grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_for_while_with_param_grad_normal():
    """Take grad_by_list through a for/while nest whose accumulator starts from the input x."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = x
            for _ in range(0, 2):
                # Reset the counter so the inner while loop runs again on each outer pass.
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return C.grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_while_with_param_basic_grad():
    """Take grad_by_list when the Parameter is added both inside and after the while loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return C.grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_while_with_param_basic_grad_mul():
    """Take grad_by_list when the while loop multiplies the accumulator by the Parameter."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            # Named "zero" but filled with ones, since the loop multiplies into it.
            self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out * self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return C.grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_while_with_param_basic_grad_two():
    """Take grad_by_list when two Parameters are accumulated inside the while loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return C.grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_while_with_param_basic_grad_three():
    """Take grad_by_list when three Parameters are accumulated inside the while loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight + self.key
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return C.grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_while_if_with_param_grad():
    """Take grad_by_list through a while loop whose body has a data-dependent if/else."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if self.max(out) < self.max(x):
                    out = out + self.param * 2
                else:
                    out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return C.grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(0), dtype=ms.int32)
    loop_end = Tensor(np.array(3), dtype=ms.int32)
    data = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_while_with_param_grad_not_enter_while():
    """Take grad_by_list when the loop condition is false from the start (idx=3, end=0)."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param * 3
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return C.grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyWhileNet()
    grad_net = GradNet(inner_net)
    loop_begin = Tensor(np.array(3), dtype=ms.int32)
    loop_end = Tensor(np.array(0), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(loop_begin, loop_end, data)
def test_with_param_if_by_if_forward():
    """Forward-run a cell with two consecutive if/else statements that use a Parameter."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param
            else:
                out = out + x
            if a == b:
                out = out + x*3 + self.param
            else:
                out = out + x*2
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    forward_net = MyIfByIfNet()
    lhs = Tensor(np.array(0), dtype=ms.int32)
    rhs = Tensor(np.array(4), dtype=ms.int32)
    data = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    forward_net(lhs, rhs, data)
def test_with_param_if_by_if_grad_inputs():
    """Take grad_all w.r.t. the inputs of a cell with two consecutive if statements."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 4
            if a == b:
                out = out + x*3 + self.param * 3
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, *inputs):
            return C.grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyIfByIfNet()
    grad_net = GradNet(inner_net)
    lhs = Tensor(np.array(0), dtype=ms.int32)
    rhs = Tensor(np.array(0), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(lhs, rhs, data)
def test_with_param_if_by_if_grad_parameter():
    """Take grad_by_list w.r.t. the Parameter of a cell with two consecutive if statements."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            if a == b:
                out = out + x*3 + self.param
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return C.grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    inner_net = MyIfByIfNet()
    grad_net = GradNet(inner_net)
    lhs = Tensor(np.array(0), dtype=ms.int32)
    rhs = Tensor(np.array(2), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(lhs, rhs, data)
def test_with_param_if_by_if_grad_param_excute_null():
    """Run C.grad_by_list on a net whose only `if` body is not taken for the given inputs."""

    class MyIfByIfNet(nn.Cell):
        """Cell whose construct reads self.param only inside a single `if a < b` block."""

        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            weight_init = np.arange(2 * 2 * 2).reshape((2, 2, 2))
            self.param = Parameter(Tensor(weight_init, ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            # The control flow below is the subject of the test; keep its shape as is.
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            return out

    class GradNet(nn.Cell):
        """Wrapper cell that applies C.grad_by_list with the wrapped network's trainable params."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return C.grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    grad_net = GradNet(MyIfByIfNet())
    # 4 < 0 is false, so the branch that touches the parameter is skipped.
    lhs = Tensor(np.array(4), dtype=ms.int32)
    rhs = Tensor(np.array(0), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(lhs, rhs, data)
def test_if_by_if_return_inside_grad():
    """Run C.grad_by_list on a net whose consecutive `if` blocks each return early."""

    class MyIfByIfNet(nn.Cell):
        """Cell whose construct returns from inside each `if`, with a final fall-through return."""

        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            weight_init = np.arange(2 * 2 * 2).reshape((2, 2, 2))
            self.param = Parameter(Tensor(weight_init, ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            # The control flow below is the subject of the test; keep its shape as is.
            out = self.zero
            if a < b:
                return out + x + self.param
            if a == b:
                return out + self.param * 2
            return out + self.param * 3

    class GradNet(nn.Cell):
        """Wrapper cell that applies C.grad_by_list with the wrapped network's trainable params."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return C.grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    grad_net = GradNet(MyIfByIfNet())
    # 1 vs 0: neither `a < b` nor `a == b` holds, so the last return is reached.
    lhs = Tensor(np.array(1), dtype=ms.int32)
    rhs = Tensor(np.array(0), dtype=ms.int32)
    data = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    grad_net(lhs, rhs, data)
def test_if_by_if_forward():
    """Run the forward pass of a net with three consecutive if/else blocks on scalar tensors."""

    class MyIfByIfNet(nn.Cell):
        """Cell whose construct chains three if/else statements built from arithmetic primitives."""

        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            # The control flow below is the subject of the test; keep its shape as is.
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
    forward_net = MyIfByIfNet()
    first = Tensor(np.array(2), dtype=ms.float32)
    second = Tensor(np.array(3), dtype=ms.float32)
    third = Tensor(np.array(4), dtype=ms.float32)
    forward_net(first, second, third)

View File

@ -58,6 +58,7 @@ add_subdirectory(serving)
file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/core/base/*.cc"
"../../../mindspore/core/gvar/*.cc"
"../../../mindspore/core/abstract/*.cc"
"../../../mindspore/core/ir/*.cc"
"../../../mindspore/core/utils/*.cc"

View File

@ -34,7 +34,6 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
class InputBackward(nn.Cell):
def __init__(self, network):
super(InputBackward, self).__init__()

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Absolute directory of this script. Quoting keeps paths with spaces intact, and
# `&&` avoids silently reporting the caller's working directory when `cd` fails.
CURRPATH=$(cd "$(dirname "$0")" && pwd)
# Option string naming the `exec` directory that sits next to this script.
IGNORE_EXEC="--ignore=$CURRPATH/exec"
# Directory three levels above this script.
PROJECT_PATH=$(cd "${CURRPATH}/../../.." && pwd)

View File

@ -14,7 +14,6 @@
# ============================================================================
"""Generate vm_impl function for array ops"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
@ -22,7 +21,6 @@ from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
from .vm_interface import vm
# pylint: disable=unused-argument
@ -181,8 +179,7 @@ def vm_impl_tile(self):
def vm_impl(x, multiples):
    """Tile `x` with numpy and wrap the result in a Tensor.

    Both arguments are converted with `.asnumpy()` before being handed to `np.tile`.
    """
    x = x.asnumpy()
    multiples = multiples.asnumpy()
    # Single numpy-based computation; no intermediate result is computed and discarded.
    out = np.tile(x, multiples)
    return Tensor(out)
return vm_impl
@ -255,7 +252,10 @@ def vm_impl_sum(self):
def vm_impl(x, axis):
    """Sum `x` with numpy over `axis` and wrap the result in a Tensor.

    An empty tuple for `axis` is treated as a reduction over all elements.
    """
    x = x.asnumpy()
    # np.sum(x, axis=()) would reduce over no axes at all, so the empty tuple
    # is handled explicitly as a full reduction.
    if axis == ():
        out = np.sum(x)
    else:
        out = np.sum(x, axis=axis)
    # np.array(...) turns a scalar result into an ndarray before building the Tensor.
    return Tensor(np.array(out))
return vm_impl
@ -291,12 +291,14 @@ def vm_impl_square(self):
return vm_impl
@vm_impl_getters.register(P.ZerosLike)
def vm_impl_zeros_like(self):
    """Generate vm_impl function for ZerosLike"""

    def vm_impl(x):
        # Zero-filled array with the same shape and dtype as x's numpy value.
        return Tensor(np.zeros_like(x.asnumpy()))

    # Hand the implementation back to the caller; without this the getter returns None.
    return vm_impl
@vm_impl_getters.register(P.Partial)
def vm_impl_partial(self):
"""Generate vm_impl function for Partial"""
@ -307,6 +309,7 @@ def vm_impl_partial(self):
return vm_impl
@vm_impl_getters.register(P.Depend)
def vm_impl_depend(self):
"""Generate vm_impl function for Depend"""

View File

@ -196,6 +196,18 @@ def vm_impl_reduce_mean(self):
return vm_impl
@vm_impl_getters.register(P.ReduceMax)
def vm_impl_reduce_max(self):
    """Generate vm_impl function for ReduceMax."""

    def vm_impl(x, axis):
        x = x.asnumpy()
        # An empty axis tuple selects a reduction over all elements, which
        # numpy expresses as axis=None.
        if axis == ():
            axis = None
        out = np.amax(x, axis)
        # A full reduction yields a numpy scalar; wrap it so the Tensor is
        # always built from an ndarray.
        return Tensor(np.array(out))

    return vm_impl
@vm_impl_getters.register(P.Equal)
def vm_impl_equal(self):