!17395 Change location to add primal information for back-propagation node.

From: @liangzhibo
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-06-01 14:38:11 +08:00 committed by Gitee
commit ad165deb15
3 changed files with 34 additions and 28 deletions

View File

@ -33,6 +33,7 @@
#include "frontend/optimizer/ad/adjoint.h"
#include "frontend/operator/ops.h"
#include "debug/trace.h"
#include "utils/utils.h"
namespace mindspore {
namespace ad {
@ -142,8 +143,7 @@ class KPrim {
FuncGraphPtr GetPossibleBprop(const PrimitivePtr &prim);
private:
FuncGraphPtr GetBprop(const PrimitivePtr &prim, const std::unordered_map<std::string, ValuePtr> &primal_attrs,
const std::vector<NodeDebugInfoPtr> &primal_debug_infos);
FuncGraphPtr GetBprop(const PrimitivePtr &prim);
FuncGraphPtr GetFprop(const PrimitivePtr &prim);
FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
@ -152,7 +152,8 @@ class KPrim {
// Refer the comment in KUserDefinedCellBprop.
template <typename T>
FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const FuncGraphPtr &current_primal_fg,
const CNodePtr &cnode);
const CNodePtr &cnode, const std::unordered_map<std::string, ValuePtr> &primal_attrs,
const std::vector<NodeDebugInfoPtr> &primal_debug_infos);
AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg);
void TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
const PrimitivePtr &primitive, const FuncGraphPtr &outer,
@ -169,15 +170,20 @@ class KPrim {
template <typename T>
FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg,
const CNodePtr &cnode) {
const CNodePtr &cnode, const std::unordered_map<std::string, ValuePtr> &primal_attrs,
const std::vector<NodeDebugInfoPtr> &primal_debug_infos) {
MS_EXCEPTION_IF_NULL(primal);
MS_EXCEPTION_IF_NULL(bprop_fg);
CheckBprop(bprop_fg, primal->ToString());
auto debug_info = std::make_shared<GraphDebugInfo>();
debug_info->set_name(primal->ToString());
auto cloned_bprop_fg = BasicClone(bprop_fg);
FuncGraphPtr cloned_bprop_fg;
{
PrimalAttrGuard primal_attr_guard(primal_attrs);
PrimalDebugInfoGuard primal_debug_info_guard(primal_debug_infos);
cloned_bprop_fg = BasicClone(bprop_fg);
}
MS_EXCEPTION_IF_NULL(cloned_bprop_fg);
cloned_bprop_fg->debug_info()->set_name("");

View File

@ -40,8 +40,7 @@ namespace mindspore {
namespace ad {
KPrim g_k_prims;
FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const std::unordered_map<std::string, ValuePtr> &primal_attrs,
const std::vector<NodeDebugInfoPtr> &primal_debug_infos) {
FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
// Set a child scope named "grad'PrimitiveName'" for the bprop function,
// and add "Gradients" to the front.
static const std::string gradients_scope = "Gradients/";
@ -50,8 +49,6 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const std::unordered_map<
auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
grad_op_child_scope_prefix + prim->name());
ScopeGuard scope_guard(scope);
PrimalAttrGuard primal_attr_guard(primal_attrs);
PrimalDebugInfoGuard primal_debug_info_guard(primal_debug_infos);
py::function fn;
if (prim->is_base()) {
@ -87,7 +84,7 @@ FuncGraphPtr KPrim::GetPossibleBprop(const PrimitivePtr &prim) {
}
if (bprop_fg == nullptr) {
bprop_fg = GetBprop(prim, {}, {});
bprop_fg = GetBprop(prim);
if (bprop_fg != nullptr) {
// Set bprop_g graph cache
bprop_registry_[prim] = bprop_fg;
@ -190,20 +187,7 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
}
if (bprop_fg == nullptr) {
std::unordered_map<std::string, ValuePtr> primal_attrs;
std::vector<NodeDebugInfoPtr> primal_debug_infos;
if (resources != nullptr) {
auto manager = resources->manager();
auto &users = manager->node_users()[value_node];
for (auto user_iter = users.begin(); user_iter != users.end(); user_iter++) {
primal_debug_infos.push_back(user_iter->first->debug_info());
}
}
if (cnode != nullptr) {
const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr);
}
bprop_fg = GetBprop(prim, primal_attrs, primal_debug_infos);
bprop_fg = GetBprop(prim);
if (bprop_fg != nullptr) {
// Set bprop_g graph cache
bprop_registry_[prim] = bprop_fg;
@ -214,7 +198,21 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
}
AdjustForAutoMonad(prim, bprop_fg);
auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode);
std::unordered_map<std::string, ValuePtr> primal_attrs;
std::vector<NodeDebugInfoPtr> primal_debug_infos;
if (resources != nullptr) {
auto manager = resources->manager();
auto &users = manager->node_users()[value_node];
for (auto user_iter = users.begin(); user_iter != users.end(); user_iter++) {
primal_debug_infos.push_back(user_iter->first->debug_info());
}
}
if (cnode != nullptr) {
primal_attrs = cnode->primal_attrs();
const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr);
}
auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode, primal_attrs, primal_debug_infos);
if (expanded_fg == nullptr) {
MS_LOG(EXCEPTION) << "Failed convert " << prim->name()
<< " prim bprop function to J expanded func graph. NodeInfo: "
@ -376,7 +374,7 @@ FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const Fu
// primal_fg is FuncGraph just after convert. Refer ConvertCellObjToFuncGraph.
// current_primal_fg is specalized and AutoMoaded primal_fg;
auto primal_fg = bprop_fg->transforms().find("primal")->second.func_graph();
auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr);
auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr, {}, {});
if (expanded_fg == nullptr) {
MS_LOG(EXCEPTION) << "Failed convert " << primal_fg->ToString()
<< " Cell bprop function to K expanded func graph. NodeInfo: "

View File

@ -312,7 +312,9 @@ class CNode : public AnfNode, public EffectInfoHolder {
std::vector<NodeDebugInfoPtr> primal_debug_infos() { return primal_debug_infos_; }
void set_primal_debug_infos(const std::vector<NodeDebugInfoPtr> &debug_infos) { primal_debug_infos_ = debug_infos; }
void set_primal_debug_infos(const std::vector<NodeDebugInfoPtr> &debug_infos) {
primal_debug_infos_.insert(primal_debug_infos_.end(), debug_infos.begin(), debug_infos.end());
}
void AddPrimalDebugInfo(const NodeDebugInfoPtr debug_info) {
if (std::find(primal_debug_infos_.begin(), primal_debug_infos_.end(), debug_info) != primal_debug_infos_.end()) {