From 34250c29b5f979ff87de87fc79150ad66655ffa5 Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Wed, 29 Dec 2021 20:16:13 +0800 Subject: [PATCH] Fix jvp parameter issue: Specialize the inner J parameter. --- .../ccsrc/frontend/operator/composite/composite.cc | 5 +++-- .../ccsrc/frontend/operator/composite/composite.h | 2 +- .../frontend/optimizer/irpass/gradient_eliminate.cc | 4 ++++ .../jit/static_analysis/program_specialize.cc | 13 ++++++++++--- .../jit/static_analysis/program_specialize.h | 2 +- mindspore/core/ir/func_graph.h | 1 + 6 files changed, 20 insertions(+), 7 deletions(-) diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index eba446de4c4..159ffcd67af 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -663,11 +663,12 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_ } } -FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights, const AnfNodePtr &position, +FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position, const std::vector &forward_graph_params, bool enable_tuple_grad, const std::vector &weight_args) { FuncGraphPtr k_child = std::make_shared(); k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true); + k_child->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, true); AnfNodePtr weights_node = nullptr; AnfNodePtr position_node = nullptr; @@ -681,7 +682,7 @@ FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weigh } std::vector inputs; - inputs.push_back(k); + inputs.push_back(j); for (size_t i = 0; i < forward_graph_params.size(); ++i) { inputs.push_back(k_child->add_parameter()); } diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.h b/mindspore/ccsrc/frontend/operator/composite/composite.h index 8281666af6a..3623b08eaaf 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.h +++ b/mindspore/ccsrc/frontend/operator/composite/composite.h @@ -150,7 +150,7 @@ class GradOperation : public MetaFuncGraph { ~GradOperation() override = default; MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) - FuncGraphPtr GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights, const AnfNodePtr &position, + FuncGraphPtr GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position, const std::vector &forward_graph_params, bool enable_tuple_grad, const std::vector &weight_args = {}); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc index 16d7a6e0b27..eb1ce24141c 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc @@ -127,6 +127,10 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr for (auto &j_node : todo) { auto expanded_j = internal::ExpandJ(j_node->input(1)->cast(), optimizer); manager->Replace(j_node, expanded_j); + if (j_node->func_graph()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) { + MS_LOG(DEBUG) << j_node->func_graph()->ToString() << " has FUNC_GRAPH_FLAG_K_GRAPH flag."; + j_node->func_graph()->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, false); + } change = true; } return change; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc index cf7cf6b0306..3fe253a825e 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc @@ -310,7 +310,9 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { partial_abstract->set_node(new_node); } } - MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString(); + MS_LOG(DEBUG) << "Set new_node: " << new_node->DebugString() << ", abstract as: " << new_node->abstract()->ToString() + << ", func_graph_: " << func_graph_->ToString() + << ", specialized_func_graph_: " << specialized_func_graph_->ToString(); if (node->isa()) { auto attrs = conf->ObtainEvalResult()->attribute(); @@ -325,7 +327,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { AbstractBasePtr ival = GetEvaluatedValue(iconf); // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node. - AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs); + AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs, node); if (replace_node == nullptr) { replace_node = BuildReplacedNode(iconf); replace_node->set_abstract(ival); @@ -833,7 +835,7 @@ static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, c } AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, - const AttrValueMapPtr &attrs) { + const AttrValueMapPtr &attrs, const AnfNodePtr &cnode) { MS_EXCEPTION_IF_NULL(origin_node); MS_EXCEPTION_IF_NULL(ival); @@ -865,6 +867,11 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin if (!value->isa() || value->cast()->parent() == nullptr || (IsValueNode(origin_node) && IsVisible(func_graph_, value->cast()->parent()))) { return BuildValueNode(value, ival); + } else if (IsPrimitiveCNode(cnode, prim::kPrimJ) && origin_node->isa() && + !value->cast()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) { + // Only if J(Parameter=func_graph) and func_graph(aka 'value') is not K graph. + MS_LOG(DEBUG) << "Specialize the parameter used by J CNode, cnode: " << cnode->DebugString(); + return BuildValueNode(value, ival); } else { return nullptr; } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h index dcbdce9d9ce..7cd2768236b 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h @@ -120,7 +120,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_thisnode; it may be a replicated forwarded CNode in static analysis or just a // replicated node. AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf); diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index 7ea6e09183c..1fff40dc419 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -82,6 +82,7 @@ const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block"; const char FUNC_GRAPH_FLAG_CORE[] = "core"; +const char FUNC_GRAPH_FLAG_K_GRAPH[] = "k_graph"; const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel"; const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; const char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";