forked from mindspore-Ecosystem/mindspore
Add primal_attr to link between forward and backward node attr
This commit is contained in:
parent
2de5bbadb5
commit
befc7a9dea
|
@ -341,7 +341,6 @@ void DumpCNodeAttrs(const CNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &g
|
|||
return;
|
||||
}
|
||||
if (op->attrs().empty()) {
|
||||
gsub->buffer << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -360,6 +359,32 @@ void DumpCNodeAttrs(const CNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &g
|
|||
}
|
||||
}
|
||||
gsub->buffer << "}";
|
||||
}
|
||||
|
||||
void DumpCNodePrimalAttrs(const CNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &gsub) {
|
||||
if (op == nullptr || gsub == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (op->primal_attrs().empty()) {
|
||||
gsub->buffer << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
auto primal_attrs = op->primal_attrs();
|
||||
gsub->buffer << " cnode_primal_attrs: {";
|
||||
int i = 0;
|
||||
for (const auto &attr : primal_attrs) {
|
||||
if (i++ != 0) {
|
||||
gsub->buffer << ", ";
|
||||
}
|
||||
gsub->buffer << attr.first << ": ";
|
||||
if (attr.second == nullptr) {
|
||||
gsub->buffer << "null";
|
||||
} else {
|
||||
gsub->buffer << attr.second->ToString();
|
||||
}
|
||||
}
|
||||
gsub->buffer << "}";
|
||||
gsub->buffer << std::endl;
|
||||
}
|
||||
|
||||
|
@ -415,6 +440,9 @@ void DumpCNode(const CNodePtr &nd, const FuncGraphPtr &sub_graph, OrderedMap<Anf
|
|||
// print cnode attrs
|
||||
DumpCNodeAttrs(nd, gsub);
|
||||
|
||||
// print cnode primal attrs
|
||||
DumpCNodePrimalAttrs(nd, gsub);
|
||||
|
||||
// print parallel info
|
||||
DumpParallelInfo(nd, gsub);
|
||||
|
||||
|
|
|
@ -965,7 +965,6 @@ void DFunctor::EliminatePrimalGraph() {
|
|||
|
||||
// Replace primal graph with k graph.
|
||||
auto k_vnode = NewValueNode(k_graph_);
|
||||
auto primal_abs = primal_user->abstract();
|
||||
primal_user->set_input(0, k_vnode);
|
||||
primal_user->set_abstract(j_user->abstract());
|
||||
|
||||
|
@ -984,7 +983,7 @@ void DFunctor::EliminatePrimalGraph() {
|
|||
auto idx0 = NewValueNode(SizeToLong(0));
|
||||
idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
|
||||
auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0});
|
||||
getitem0->set_abstract(primal_abs);
|
||||
getitem0->CloneCNodeInfo(primal_user);
|
||||
(void)manager->Replace(primal_user, getitem0);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -230,7 +230,7 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, cons
|
|||
TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode->debug_info()));
|
||||
out_value = outer->NewCNode(transf_args);
|
||||
if constexpr (std::is_same<T, PrimitivePtr>::value) {
|
||||
out_value->CloneUserData(cnode);
|
||||
out_value->CloneCNodeInfo(cnode);
|
||||
}
|
||||
} else {
|
||||
out_value = outer->NewCNode(transf_args);
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -29,6 +29,7 @@
|
|||
#include "frontend/optimizer/ad/dfunctor.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/operator/composite/composite.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "utils/primitive_utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -178,6 +179,17 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (cnode != nullptr) {
|
||||
Manage({cnode->func_graph(), bprop_fg}, false);
|
||||
const auto &bprop_fg_cnodes = bprop_fg->GetOrderedCnodes();
|
||||
const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
|
||||
for (auto &bprop_fg_cnode : bprop_fg_cnodes) {
|
||||
bprop_fg_cnode->set_primal_attrs(cnode->primal_attrs());
|
||||
bprop_fg_cnode->AddPrimalAttr(kPrimalAttrForwardNodeName, MakeValue(forward_node_primal_attr));
|
||||
}
|
||||
}
|
||||
|
||||
AdjustForAutoMonad(prim, bprop_fg);
|
||||
auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode);
|
||||
if (expanded_fg == nullptr) {
|
||||
|
|
|
@ -41,6 +41,7 @@
|
|||
#include "abstract/param_validator.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "utils/parallel_node_check.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -86,16 +87,22 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
}
|
||||
ScopeGuard scope_guard(scope);
|
||||
|
||||
AnfNodePtr new_cnode = nullptr;
|
||||
AnfNodePtr new_node = nullptr;
|
||||
if (bound_node() != nullptr) {
|
||||
TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
|
||||
new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
|
||||
args_inputs);
|
||||
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
|
||||
args_inputs);
|
||||
} else {
|
||||
new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
|
||||
args_inputs);
|
||||
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
|
||||
args_inputs);
|
||||
}
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
|
||||
|
||||
if (out_node->isa<CNode>()) {
|
||||
auto out_cnode = out_node->cast<CNodePtr>();
|
||||
auto new_cnode = new_node->cast<CNodePtr>();
|
||||
new_cnode->CloneCNodeInfo(out_cnode);
|
||||
}
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context());
|
||||
|
||||
return engine->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
|
@ -256,6 +263,10 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
|
|||
AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph);
|
||||
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
|
||||
|
||||
if (new_node->isa<CNode>()) {
|
||||
auto new_cnode = new_node->cast<CNodePtr>();
|
||||
new_cnode->CloneCNodeInfo(out_node);
|
||||
}
|
||||
return engine->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
|
||||
|
|
|
@ -427,6 +427,9 @@ constexpr auto kAttrRecursiveEnd = "recursive_end";
|
|||
constexpr auto kAttrRecursive = "recursive";
|
||||
constexpr auto kAttrMultiCallEnd = "multicall_end";
|
||||
|
||||
// primal attr key name
|
||||
constexpr auto kPrimalAttrForwardNodeName = "forward_node_name";
|
||||
|
||||
// attr value
|
||||
constexpr auto kValueTargetSwitch = "target_switch";
|
||||
constexpr auto kValueTargetOther = "target_other";
|
||||
|
|
|
@ -281,9 +281,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
|
|||
|
||||
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
||||
void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
|
||||
for (auto &attr : attrs) {
|
||||
attrs_[attr.first] = attr.second;
|
||||
}
|
||||
attrs_.insert(attrs.cbegin(), attrs.cend());
|
||||
}
|
||||
|
||||
void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; }
|
||||
|
@ -294,6 +292,31 @@ class CNode : public AnfNode, public EffectInfoHolder {
|
|||
}
|
||||
bool HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.cend(); }
|
||||
ssize_t input_tensor_num() const { return input_tensor_num_; }
|
||||
|
||||
const std::unordered_map<std::string, ValuePtr> &primal_attrs() const { return primal_attrs_; }
|
||||
void set_primal_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
|
||||
primal_attrs_.insert(attrs.cbegin(), attrs.cend());
|
||||
}
|
||||
void AddPrimalAttr(const std::string &name, const ValuePtr &attr) { primal_attrs_[name] = attr; }
|
||||
void ErasePrimalAttr(const std::string &name) { (void)primal_attrs_.erase(name); }
|
||||
ValuePtr GetPrimalAttr(const std::string &name) const {
|
||||
auto iter = primal_attrs_.find(name);
|
||||
return iter == primal_attrs_.cend() ? nullptr : iter->second;
|
||||
}
|
||||
bool HasPrimalAttr(const std::string &name) const { return primal_attrs_.find(name) != attrs_.cend(); }
|
||||
|
||||
void CloneCNodeInfo(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
set_abstract(node->abstract());
|
||||
set_forward(node->forward().first, node->forward().second);
|
||||
set_inputs_value(node->inputs_value());
|
||||
set_attrs(node->attrs());
|
||||
set_primal_attrs(node->primal_attrs());
|
||||
set_load_flag(node->get_load_flag());
|
||||
CloneUserData(node);
|
||||
set_kernel_info(node->kernel_info_ptr());
|
||||
}
|
||||
|
||||
void set_input_tensor_num(ssize_t input_tensor_num) { input_tensor_num_ = input_tensor_num; }
|
||||
|
||||
// Is effect have been handled.
|
||||
|
@ -314,6 +337,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
|
|||
std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
|
||||
std::pair<ValuePtr, std::string> output_value_;
|
||||
std::unordered_map<std::string, ValuePtr> attrs_;
|
||||
std::unordered_map<std::string, ValuePtr> primal_attrs_;
|
||||
ssize_t input_tensor_num_ = -1;
|
||||
};
|
||||
|
||||
|
|
|
@ -87,18 +87,12 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
|
|||
TraceGuard trace_guard(node->debug_info(), relation_);
|
||||
CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
|
||||
auto old_node = node->cast<CNodePtr>();
|
||||
new_node->set_abstract(old_node->abstract());
|
||||
new_node->set_forward(old_node->forward().first, old_node->forward().second);
|
||||
new_node->set_inputs_value(old_node->inputs_value());
|
||||
new_node->set_attrs(old_node->attrs());
|
||||
new_node->set_load_flag(old_node->get_load_flag());
|
||||
new_node->CloneCNodeInfo(old_node);
|
||||
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
|
||||
new_node->set_scope(scope);
|
||||
new_node->CloneUserData(old_node);
|
||||
if (IsParallelConsiderCNode(old_node) && new_node->scope() == kDefaultScope) {
|
||||
new_node->set_fullname_with_scope(old_node->fullname_with_scope());
|
||||
}
|
||||
new_node->set_kernel_info(old_node->kernel_info_ptr());
|
||||
repl_node_[old_node] = new_node;
|
||||
nodes_.emplace_back(old_node, new_node);
|
||||
}
|
||||
|
|
|
@ -127,12 +127,12 @@ def test_sit_auto_mix_precision_model_o0():
|
|||
model = Model(net, loss, opt, amp_level="O0")
|
||||
model.train(1, dataset1, dataset_sink_mode=False)
|
||||
contend = read_validateir_file('./test_amp_o0')
|
||||
castnum = re.findall("Cast", contend)
|
||||
castnum = re.findall(r"Cast\(", contend)
|
||||
assert len(castnum) == 5
|
||||
clean_all_ir_files('./test_amp_o0')
|
||||
model.predict(Tensor(input_data))
|
||||
contend = read_validateir_file('./test_amp_o0')
|
||||
castnum = re.findall("Cast", contend)
|
||||
castnum = re.findall(r"Cast\(", contend)
|
||||
assert len(castnum) == 11
|
||||
clean_all_ir_files('./test_amp_o0')
|
||||
|
||||
|
@ -163,7 +163,7 @@ def test_sit_auto_mix_precision_model_o2():
|
|||
model = Model(net, loss, opt, amp_level="O2")
|
||||
model.train(1, dataset1, dataset_sink_mode=False)
|
||||
contend = read_validateir_file('./test_amp_o2')
|
||||
castnum = re.findall("Cast", contend)
|
||||
castnum = re.findall(r"Cast\(", contend)
|
||||
assert len(castnum) == 14
|
||||
clean_all_ir_files('./test_amp_o2')
|
||||
out_graph = model.predict(Tensor(input_data))
|
||||
|
|
Loading…
Reference in New Issue