forked from mindspore-Ecosystem/mindspore
!17395 Change location to add primal information for back-propagation node.
From: @liangzhibo Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
ad165deb15
|
@ -33,6 +33,7 @@
|
|||
#include "frontend/optimizer/ad/adjoint.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "debug/trace.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ad {
|
||||
|
@ -142,8 +143,7 @@ class KPrim {
|
|||
FuncGraphPtr GetPossibleBprop(const PrimitivePtr &prim);
|
||||
|
||||
private:
|
||||
FuncGraphPtr GetBprop(const PrimitivePtr &prim, const std::unordered_map<std::string, ValuePtr> &primal_attrs,
|
||||
const std::vector<NodeDebugInfoPtr> &primal_debug_infos);
|
||||
FuncGraphPtr GetBprop(const PrimitivePtr &prim);
|
||||
FuncGraphPtr GetFprop(const PrimitivePtr &prim);
|
||||
FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
|
||||
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
|
||||
|
@ -152,7 +152,8 @@ class KPrim {
|
|||
// Refer the comment in KUserDefinedCellBprop.
|
||||
template <typename T>
|
||||
FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const FuncGraphPtr ¤t_primal_fg,
|
||||
const CNodePtr &cnode);
|
||||
const CNodePtr &cnode, const std::unordered_map<std::string, ValuePtr> &primal_attrs,
|
||||
const std::vector<NodeDebugInfoPtr> &primal_debug_infos);
|
||||
AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg);
|
||||
void TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
|
||||
const PrimitivePtr &primitive, const FuncGraphPtr &outer,
|
||||
|
@ -169,15 +170,20 @@ class KPrim {
|
|||
|
||||
template <typename T>
|
||||
FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg,
|
||||
const CNodePtr &cnode) {
|
||||
const CNodePtr &cnode, const std::unordered_map<std::string, ValuePtr> &primal_attrs,
|
||||
const std::vector<NodeDebugInfoPtr> &primal_debug_infos) {
|
||||
MS_EXCEPTION_IF_NULL(primal);
|
||||
MS_EXCEPTION_IF_NULL(bprop_fg);
|
||||
CheckBprop(bprop_fg, primal->ToString());
|
||||
|
||||
auto debug_info = std::make_shared<GraphDebugInfo>();
|
||||
debug_info->set_name(primal->ToString());
|
||||
|
||||
auto cloned_bprop_fg = BasicClone(bprop_fg);
|
||||
FuncGraphPtr cloned_bprop_fg;
|
||||
{
|
||||
PrimalAttrGuard primal_attr_guard(primal_attrs);
|
||||
PrimalDebugInfoGuard primal_debug_info_guard(primal_debug_infos);
|
||||
cloned_bprop_fg = BasicClone(bprop_fg);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(cloned_bprop_fg);
|
||||
|
||||
cloned_bprop_fg->debug_info()->set_name("");
|
||||
|
|
|
@ -40,8 +40,7 @@ namespace mindspore {
|
|||
namespace ad {
|
||||
KPrim g_k_prims;
|
||||
|
||||
FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const std::unordered_map<std::string, ValuePtr> &primal_attrs,
|
||||
const std::vector<NodeDebugInfoPtr> &primal_debug_infos) {
|
||||
FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
|
||||
// Set a child scope named "grad'PrimitiveName'" for the bprop function,
|
||||
// and add "Gradients" to the front.
|
||||
static const std::string gradients_scope = "Gradients/";
|
||||
|
@ -50,8 +49,6 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const std::unordered_map<
|
|||
auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
|
||||
grad_op_child_scope_prefix + prim->name());
|
||||
ScopeGuard scope_guard(scope);
|
||||
PrimalAttrGuard primal_attr_guard(primal_attrs);
|
||||
PrimalDebugInfoGuard primal_debug_info_guard(primal_debug_infos);
|
||||
|
||||
py::function fn;
|
||||
if (prim->is_base()) {
|
||||
|
@ -87,7 +84,7 @@ FuncGraphPtr KPrim::GetPossibleBprop(const PrimitivePtr &prim) {
|
|||
}
|
||||
|
||||
if (bprop_fg == nullptr) {
|
||||
bprop_fg = GetBprop(prim, {}, {});
|
||||
bprop_fg = GetBprop(prim);
|
||||
if (bprop_fg != nullptr) {
|
||||
// Set bprop_g graph cache
|
||||
bprop_registry_[prim] = bprop_fg;
|
||||
|
@ -190,20 +187,7 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
|
|||
}
|
||||
|
||||
if (bprop_fg == nullptr) {
|
||||
std::unordered_map<std::string, ValuePtr> primal_attrs;
|
||||
std::vector<NodeDebugInfoPtr> primal_debug_infos;
|
||||
if (resources != nullptr) {
|
||||
auto manager = resources->manager();
|
||||
auto &users = manager->node_users()[value_node];
|
||||
for (auto user_iter = users.begin(); user_iter != users.end(); user_iter++) {
|
||||
primal_debug_infos.push_back(user_iter->first->debug_info());
|
||||
}
|
||||
}
|
||||
if (cnode != nullptr) {
|
||||
const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
|
||||
primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr);
|
||||
}
|
||||
bprop_fg = GetBprop(prim, primal_attrs, primal_debug_infos);
|
||||
bprop_fg = GetBprop(prim);
|
||||
if (bprop_fg != nullptr) {
|
||||
// Set bprop_g graph cache
|
||||
bprop_registry_[prim] = bprop_fg;
|
||||
|
@ -214,7 +198,21 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
|
|||
}
|
||||
|
||||
AdjustForAutoMonad(prim, bprop_fg);
|
||||
auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode);
|
||||
std::unordered_map<std::string, ValuePtr> primal_attrs;
|
||||
std::vector<NodeDebugInfoPtr> primal_debug_infos;
|
||||
if (resources != nullptr) {
|
||||
auto manager = resources->manager();
|
||||
auto &users = manager->node_users()[value_node];
|
||||
for (auto user_iter = users.begin(); user_iter != users.end(); user_iter++) {
|
||||
primal_debug_infos.push_back(user_iter->first->debug_info());
|
||||
}
|
||||
}
|
||||
if (cnode != nullptr) {
|
||||
primal_attrs = cnode->primal_attrs();
|
||||
const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
|
||||
primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr);
|
||||
}
|
||||
auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode, primal_attrs, primal_debug_infos);
|
||||
if (expanded_fg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed convert " << prim->name()
|
||||
<< " prim bprop function to J expanded func graph. NodeInfo: "
|
||||
|
@ -376,7 +374,7 @@ FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const Fu
|
|||
// primal_fg is FuncGraph just after convert. Refer ConvertCellObjToFuncGraph.
|
||||
// current_primal_fg is specalized and AutoMoaded primal_fg;
|
||||
auto primal_fg = bprop_fg->transforms().find("primal")->second.func_graph();
|
||||
auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr);
|
||||
auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr, {}, {});
|
||||
if (expanded_fg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed convert " << primal_fg->ToString()
|
||||
<< " Cell bprop function to K expanded func graph. NodeInfo: "
|
||||
|
|
|
@ -312,7 +312,9 @@ class CNode : public AnfNode, public EffectInfoHolder {
|
|||
|
||||
std::vector<NodeDebugInfoPtr> primal_debug_infos() { return primal_debug_infos_; }
|
||||
|
||||
void set_primal_debug_infos(const std::vector<NodeDebugInfoPtr> &debug_infos) { primal_debug_infos_ = debug_infos; }
|
||||
void set_primal_debug_infos(const std::vector<NodeDebugInfoPtr> &debug_infos) {
|
||||
primal_debug_infos_.insert(primal_debug_infos_.end(), debug_infos.begin(), debug_infos.end());
|
||||
}
|
||||
|
||||
void AddPrimalDebugInfo(const NodeDebugInfoPtr debug_info) {
|
||||
if (std::find(primal_debug_infos_.begin(), primal_debug_infos_.end(), debug_info) != primal_debug_infos_.end()) {
|
||||
|
|
Loading…
Reference in New Issue