From befc7a9dea8e56bee52338e37949270d183f54a1 Mon Sep 17 00:00:00 2001 From: l00591931 Date: Wed, 28 Apr 2021 14:28:58 +0800 Subject: [PATCH] Add primal_attr to link between forward and backward node attr --- mindspore/ccsrc/debug/anf_ir_dump.cc | 30 ++++++++++++++++++- .../ccsrc/frontend/optimizer/ad/dfunctor.cc | 3 +- .../ccsrc/frontend/optimizer/ad/dfunctor.h | 2 +- .../ccsrc/frontend/optimizer/ad/kprim.cc | 14 ++++++++- .../pipeline/jit/static_analysis/prim.cc | 23 ++++++++++---- mindspore/ccsrc/utils/utils.h | 3 ++ mindspore/core/ir/anf.h | 30 +++++++++++++++++-- mindspore/core/ir/func_graph_cloner.cc | 8 +---- tests/st/mix_precision/test_mix_precision.py | 6 ++-- 9 files changed, 95 insertions(+), 24 deletions(-) diff --git a/mindspore/ccsrc/debug/anf_ir_dump.cc b/mindspore/ccsrc/debug/anf_ir_dump.cc index 734cd25a14a..6a93370a3d3 100644 --- a/mindspore/ccsrc/debug/anf_ir_dump.cc +++ b/mindspore/ccsrc/debug/anf_ir_dump.cc @@ -341,7 +341,6 @@ void DumpCNodeAttrs(const CNodePtr &op, const std::shared_ptr &g return; } if (op->attrs().empty()) { - gsub->buffer << std::endl; return; } @@ -360,6 +359,32 @@ void DumpCNodeAttrs(const CNodePtr &op, const std::shared_ptr &g } } gsub->buffer << "}"; +} + +void DumpCNodePrimalAttrs(const CNodePtr &op, const std::shared_ptr &gsub) { + if (op == nullptr || gsub == nullptr) { + return; + } + if (op->primal_attrs().empty()) { + gsub->buffer << std::endl; + return; + } + + auto primal_attrs = op->primal_attrs(); + gsub->buffer << " cnode_primal_attrs: {"; + int i = 0; + for (const auto &attr : primal_attrs) { + if (i++ != 0) { + gsub->buffer << ", "; + } + gsub->buffer << attr.first << ": "; + if (attr.second == nullptr) { + gsub->buffer << "null"; + } else { + gsub->buffer << attr.second->ToString(); + } + } + gsub->buffer << "}"; gsub->buffer << std::endl; } @@ -415,6 +440,9 @@ void DumpCNode(const CNodePtr &nd, const FuncGraphPtr &sub_graph, OrderedMapabstract(); primal_user->set_input(0, k_vnode); 
primal_user->set_abstract(j_user->abstract()); @@ -984,7 +983,7 @@ void DFunctor::EliminatePrimalGraph() { auto idx0 = NewValueNode(SizeToLong(0)); idx0->set_abstract(std::make_shared(imm0)); auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0}); - getitem0->set_abstract(primal_abs); + getitem0->CloneCNodeInfo(primal_user); (void)manager->Replace(primal_user, getitem0); } } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index 47497b7bea9..6ecc44a780f 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -230,7 +230,7 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, cons TraceGuard trace_guard(std::make_shared(cnode->debug_info())); out_value = outer->NewCNode(transf_args); if constexpr (std::is_same::value) { - out_value->CloneUserData(cnode); + out_value->CloneCNodeInfo(cnode); } } else { out_value = outer->NewCNode(transf_args); diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index edc5c373761..4cb75c8e788 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -1,7 +1,7 @@ /** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -29,6 +29,7 @@ #include "frontend/optimizer/ad/dfunctor.h" #include "frontend/operator/ops.h" #include "frontend/operator/composite/composite.h" +#include "utils/utils.h" #include "utils/symbolic.h" #include "utils/primitive_utils.h" #include "utils/ms_context.h" @@ -178,6 +179,17 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_ } } } + + if (cnode != nullptr) { + Manage({cnode->func_graph(), bprop_fg}, false); + const auto &bprop_fg_cnodes = bprop_fg->GetOrderedCnodes(); + const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId(); + for (auto &bprop_fg_cnode : bprop_fg_cnodes) { + bprop_fg_cnode->set_primal_attrs(cnode->primal_attrs()); + bprop_fg_cnode->AddPrimalAttr(kPrimalAttrForwardNodeName, MakeValue(forward_node_primal_attr)); + } + } + AdjustForAutoMonad(prim, bprop_fg); auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode); if (expanded_fg == nullptr) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 9e89322121b..5bd49859b95 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -41,6 +41,7 @@ #include "abstract/param_validator.h" #include "utils/ms_utils.h" #include "utils/shape_utils.h" +#include "utils/parallel_node_check.h" namespace mindspore { namespace abstract { @@ -86,16 +87,22 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt } ScopeGuard scope_guard(scope); - AnfNodePtr new_cnode = nullptr; + AnfNodePtr new_node = nullptr; if (bound_node() != nullptr) { TraceGuard trace_guard(std::make_shared(bound_node()->debug_info())); - new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list, - args_inputs); + new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list, + args_inputs); } else { - 
new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list, - args_inputs); + new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list, + args_inputs); + } + AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context()); + + if (out_node->isa()) { + auto out_cnode = out_node->cast(); + auto new_cnode = new_node->cast(); + new_cnode->CloneCNodeInfo(out_cnode); } - AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context()); return engine->ForwardConfig(out_conf, fn_conf); } @@ -256,6 +263,10 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph); AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context()); + if (new_node->isa()) { + auto new_cnode = new_node->cast(); + new_cnode->CloneCNodeInfo(out_node); + } return engine->ForwardConfig(out_conf, fn_conf); } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 4f64ff6bfcc..51a50aa2104 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -427,6 +427,9 @@ constexpr auto kAttrRecursiveEnd = "recursive_end"; constexpr auto kAttrRecursive = "recursive"; constexpr auto kAttrMultiCallEnd = "multicall_end"; +// primal attr key name +constexpr auto kPrimalAttrForwardNodeName = "forward_node_name"; + // attr value constexpr auto kValueTargetSwitch = "target_switch"; constexpr auto kValueTargetOther = "target_other"; diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index 888eadbddf5..d8a63e23f5f 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -281,9 +281,7 @@ class CNode : public AnfNode, public EffectInfoHolder { const std::unordered_map &attrs() const { return attrs_; } void set_attrs(const std::unordered_map &attrs) { - for 
(auto &attr : attrs) { - attrs_[attr.first] = attr.second; - } + attrs_.insert(attrs.cbegin(), attrs.cend()); } void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; } @@ -294,6 +292,31 @@ class CNode : public AnfNode, public EffectInfoHolder { } bool HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.cend(); } ssize_t input_tensor_num() const { return input_tensor_num_; } + + const std::unordered_map &primal_attrs() const { return primal_attrs_; } + void set_primal_attrs(const std::unordered_map &attrs) { + primal_attrs_.insert(attrs.cbegin(), attrs.cend()); + } + void AddPrimalAttr(const std::string &name, const ValuePtr &attr) { primal_attrs_[name] = attr; } + void ErasePrimalAttr(const std::string &name) { (void)primal_attrs_.erase(name); } + ValuePtr GetPrimalAttr(const std::string &name) const { + auto iter = primal_attrs_.find(name); + return iter == primal_attrs_.cend() ? nullptr : iter->second; + } + bool HasPrimalAttr(const std::string &name) const { return primal_attrs_.find(name) != primal_attrs_.cend(); } + + void CloneCNodeInfo(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + set_abstract(node->abstract()); + set_forward(node->forward().first, node->forward().second); + set_inputs_value(node->inputs_value()); + set_attrs(node->attrs()); + set_primal_attrs(node->primal_attrs()); + set_load_flag(node->get_load_flag()); + CloneUserData(node); + set_kernel_info(node->kernel_info_ptr()); + } + void set_input_tensor_num(ssize_t input_tensor_num) { input_tensor_num_ = input_tensor_num; } // Is effect have been handled. 
@@ -314,6 +337,7 @@ class CNode : public AnfNode, public EffectInfoHolder { std::vector> inputs_value_; std::pair output_value_; std::unordered_map attrs_; + std::unordered_map primal_attrs_; ssize_t input_tensor_num_ = -1; }; diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc index ab6ec02646f..bd81ecb8310 100644 --- a/mindspore/core/ir/func_graph_cloner.cc +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -87,18 +87,12 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { TraceGuard trace_guard(node->debug_info(), relation_); CNodePtr new_node = std::make_shared(AnfNodePtrList{}, target); auto old_node = node->cast(); - new_node->set_abstract(old_node->abstract()); - new_node->set_forward(old_node->forward().first, old_node->forward().second); - new_node->set_inputs_value(old_node->inputs_value()); - new_node->set_attrs(old_node->attrs()); - new_node->set_load_flag(old_node->get_load_flag()); + new_node->CloneCNodeInfo(old_node); ScopePtr scope = (node->scope() != kDefaultScope) ? 
node->scope() : this->scope(); new_node->set_scope(scope); - new_node->CloneUserData(old_node); if (IsParallelConsiderCNode(old_node) && new_node->scope() == kDefaultScope) { new_node->set_fullname_with_scope(old_node->fullname_with_scope()); } - new_node->set_kernel_info(old_node->kernel_info_ptr()); repl_node_[old_node] = new_node; nodes_.emplace_back(old_node, new_node); } diff --git a/tests/st/mix_precision/test_mix_precision.py b/tests/st/mix_precision/test_mix_precision.py index 5796efb1b22..168cd7bf2fb 100644 --- a/tests/st/mix_precision/test_mix_precision.py +++ b/tests/st/mix_precision/test_mix_precision.py @@ -127,12 +127,12 @@ def test_sit_auto_mix_precision_model_o0(): model = Model(net, loss, opt, amp_level="O0") model.train(1, dataset1, dataset_sink_mode=False) contend = read_validateir_file('./test_amp_o0') - castnum = re.findall("Cast", contend) + castnum = re.findall(r"Cast\(", contend) assert len(castnum) == 5 clean_all_ir_files('./test_amp_o0') model.predict(Tensor(input_data)) contend = read_validateir_file('./test_amp_o0') - castnum = re.findall("Cast", contend) + castnum = re.findall(r"Cast\(", contend) assert len(castnum) == 11 clean_all_ir_files('./test_amp_o0') @@ -163,7 +163,7 @@ def test_sit_auto_mix_precision_model_o2(): model = Model(net, loss, opt, amp_level="O2") model.train(1, dataset1, dataset_sink_mode=False) contend = read_validateir_file('./test_amp_o2') - castnum = re.findall("Cast", contend) + castnum = re.findall(r"Cast\(", contend) assert len(castnum) == 14 clean_all_ir_files('./test_amp_o2') out_graph = model.predict(Tensor(input_data))