Add primal_attr to link between forward and backward node attr

This commit is contained in:
l00591931 2021-04-28 14:28:58 +08:00
parent 2de5bbadb5
commit befc7a9dea
9 changed files with 95 additions and 24 deletions

View File

@ -341,7 +341,6 @@ void DumpCNodeAttrs(const CNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &g
return;
}
if (op->attrs().empty()) {
gsub->buffer << std::endl;
return;
}
@ -360,6 +359,32 @@ void DumpCNodeAttrs(const CNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &g
}
}
gsub->buffer << "}";
}
void DumpCNodePrimalAttrs(const CNodePtr &op, const std::shared_ptr<SubGraphIRInfo> &gsub) {
if (op == nullptr || gsub == nullptr) {
return;
}
if (op->primal_attrs().empty()) {
gsub->buffer << std::endl;
return;
}
auto primal_attrs = op->primal_attrs();
gsub->buffer << " cnode_primal_attrs: {";
int i = 0;
for (const auto &attr : primal_attrs) {
if (i++ != 0) {
gsub->buffer << ", ";
}
gsub->buffer << attr.first << ": ";
if (attr.second == nullptr) {
gsub->buffer << "null";
} else {
gsub->buffer << attr.second->ToString();
}
}
gsub->buffer << "}";
gsub->buffer << std::endl;
}
@ -415,6 +440,9 @@ void DumpCNode(const CNodePtr &nd, const FuncGraphPtr &sub_graph, OrderedMap<Anf
// print cnode attrs
DumpCNodeAttrs(nd, gsub);
// print cnode primal attrs
DumpCNodePrimalAttrs(nd, gsub);
// print parallel info
DumpParallelInfo(nd, gsub);

View File

@ -965,7 +965,6 @@ void DFunctor::EliminatePrimalGraph() {
// Replace primal graph with k graph.
auto k_vnode = NewValueNode(k_graph_);
auto primal_abs = primal_user->abstract();
primal_user->set_input(0, k_vnode);
primal_user->set_abstract(j_user->abstract());
@ -984,7 +983,7 @@ void DFunctor::EliminatePrimalGraph() {
auto idx0 = NewValueNode(SizeToLong(0));
idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0});
getitem0->set_abstract(primal_abs);
getitem0->CloneCNodeInfo(primal_user);
(void)manager->Replace(primal_user, getitem0);
}
}

View File

@ -230,7 +230,7 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, cons
TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode->debug_info()));
out_value = outer->NewCNode(transf_args);
if constexpr (std::is_same<T, PrimitivePtr>::value) {
out_value->CloneUserData(cnode);
out_value->CloneCNodeInfo(cnode);
}
} else {
out_value = outer->NewCNode(transf_args);

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -29,6 +29,7 @@
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/composite.h"
#include "utils/utils.h"
#include "utils/symbolic.h"
#include "utils/primitive_utils.h"
#include "utils/ms_context.h"
@ -178,6 +179,17 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
}
}
}
if (cnode != nullptr) {
Manage({cnode->func_graph(), bprop_fg}, false);
const auto &bprop_fg_cnodes = bprop_fg->GetOrderedCnodes();
const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
for (auto &bprop_fg_cnode : bprop_fg_cnodes) {
bprop_fg_cnode->set_primal_attrs(cnode->primal_attrs());
bprop_fg_cnode->AddPrimalAttr(kPrimalAttrForwardNodeName, MakeValue(forward_node_primal_attr));
}
}
AdjustForAutoMonad(prim, bprop_fg);
auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode);
if (expanded_fg == nullptr) {

View File

@ -41,6 +41,7 @@
#include "abstract/param_validator.h"
#include "utils/ms_utils.h"
#include "utils/shape_utils.h"
#include "utils/parallel_node_check.h"
namespace mindspore {
namespace abstract {
@ -86,16 +87,22 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
}
ScopeGuard scope_guard(scope);
AnfNodePtr new_cnode = nullptr;
AnfNodePtr new_node = nullptr;
if (bound_node() != nullptr) {
TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
args_inputs);
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
args_inputs);
} else {
new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
args_inputs);
new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list,
args_inputs);
}
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
if (out_node->isa<CNode>()) {
auto out_cnode = out_node->cast<CNodePtr>();
auto new_cnode = new_node->cast<CNodePtr>();
new_cnode->CloneCNodeInfo(out_cnode);
}
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context());
return engine->ForwardConfig(out_conf, fn_conf);
}
@ -256,6 +263,10 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph);
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
if (new_node->isa<CNode>()) {
auto new_cnode = new_node->cast<CNodePtr>();
new_cnode->CloneCNodeInfo(out_node);
}
return engine->ForwardConfig(out_conf, fn_conf);
}

View File

@ -427,6 +427,9 @@ constexpr auto kAttrRecursiveEnd = "recursive_end";
constexpr auto kAttrRecursive = "recursive";
constexpr auto kAttrMultiCallEnd = "multicall_end";
// primal attr key name
constexpr auto kPrimalAttrForwardNodeName = "forward_node_name";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";
constexpr auto kValueTargetOther = "target_other";

View File

@ -281,9 +281,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
attrs_[attr.first] = attr.second;
}
attrs_.insert(attrs.cbegin(), attrs.cend());
}
void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; }
@ -294,6 +292,31 @@ class CNode : public AnfNode, public EffectInfoHolder {
}
bool HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.cend(); }
ssize_t input_tensor_num() const { return input_tensor_num_; }
const std::unordered_map<std::string, ValuePtr> &primal_attrs() const { return primal_attrs_; }
void set_primal_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
primal_attrs_.insert(attrs.cbegin(), attrs.cend());
}
void AddPrimalAttr(const std::string &name, const ValuePtr &attr) { primal_attrs_[name] = attr; }
void ErasePrimalAttr(const std::string &name) { (void)primal_attrs_.erase(name); }
ValuePtr GetPrimalAttr(const std::string &name) const {
auto iter = primal_attrs_.find(name);
return iter == primal_attrs_.cend() ? nullptr : iter->second;
}
bool HasPrimalAttr(const std::string &name) const { return primal_attrs_.find(name) != attrs_.cend(); }
void CloneCNodeInfo(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
set_abstract(node->abstract());
set_forward(node->forward().first, node->forward().second);
set_inputs_value(node->inputs_value());
set_attrs(node->attrs());
set_primal_attrs(node->primal_attrs());
set_load_flag(node->get_load_flag());
CloneUserData(node);
set_kernel_info(node->kernel_info_ptr());
}
void set_input_tensor_num(ssize_t input_tensor_num) { input_tensor_num_ = input_tensor_num; }
// Is effect have been handled.
@ -314,6 +337,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
std::pair<ValuePtr, std::string> output_value_;
std::unordered_map<std::string, ValuePtr> attrs_;
std::unordered_map<std::string, ValuePtr> primal_attrs_;
ssize_t input_tensor_num_ = -1;
};

View File

@ -87,18 +87,12 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
TraceGuard trace_guard(node->debug_info(), relation_);
CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
auto old_node = node->cast<CNodePtr>();
new_node->set_abstract(old_node->abstract());
new_node->set_forward(old_node->forward().first, old_node->forward().second);
new_node->set_inputs_value(old_node->inputs_value());
new_node->set_attrs(old_node->attrs());
new_node->set_load_flag(old_node->get_load_flag());
new_node->CloneCNodeInfo(old_node);
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_node->set_scope(scope);
new_node->CloneUserData(old_node);
if (IsParallelConsiderCNode(old_node) && new_node->scope() == kDefaultScope) {
new_node->set_fullname_with_scope(old_node->fullname_with_scope());
}
new_node->set_kernel_info(old_node->kernel_info_ptr());
repl_node_[old_node] = new_node;
nodes_.emplace_back(old_node, new_node);
}

View File

@ -127,12 +127,12 @@ def test_sit_auto_mix_precision_model_o0():
model = Model(net, loss, opt, amp_level="O0")
model.train(1, dataset1, dataset_sink_mode=False)
contend = read_validateir_file('./test_amp_o0')
castnum = re.findall("Cast", contend)
castnum = re.findall(r"Cast\(", contend)
assert len(castnum) == 5
clean_all_ir_files('./test_amp_o0')
model.predict(Tensor(input_data))
contend = read_validateir_file('./test_amp_o0')
castnum = re.findall("Cast", contend)
castnum = re.findall(r"Cast\(", contend)
assert len(castnum) == 11
clean_all_ir_files('./test_amp_o0')
@ -163,7 +163,7 @@ def test_sit_auto_mix_precision_model_o2():
model = Model(net, loss, opt, amp_level="O2")
model.train(1, dataset1, dataset_sink_mode=False)
contend = read_validateir_file('./test_amp_o2')
castnum = re.findall("Cast", contend)
castnum = re.findall(r"Cast\(", contend)
assert len(castnum) == 14
clean_all_ir_files('./test_amp_o2')
out_graph = model.predict(Tensor(input_data))