!28390 Fix jvp parameter issue: Specialize the inner J parameter.

Merge pull request !28390 from 张清华/opt_jvp_parameter
This commit is contained in:
i-robot 2021-12-31 06:24:47 +00:00 committed by Gitee
commit 17d8c5d2ef
6 changed files with 20 additions and 7 deletions

View File

@ -663,11 +663,12 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
}
}
FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights, const AnfNodePtr &position,
FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
const std::vector<AnfNodePtr> &weight_args) {
FuncGraphPtr k_child = std::make_shared<FuncGraph>();
k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
k_child->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, true);
AnfNodePtr weights_node = nullptr;
AnfNodePtr position_node = nullptr;
@ -681,7 +682,7 @@ FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weigh
}
std::vector<AnfNodePtr> inputs;
inputs.push_back(k);
inputs.push_back(j);
for (size_t i = 0; i < forward_graph_params.size(); ++i) {
inputs.push_back(k_child->add_parameter());
}

View File

@ -150,7 +150,7 @@ class GradOperation : public MetaFuncGraph {
~GradOperation() override = default;
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
FuncGraphPtr GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights, const AnfNodePtr &position,
FuncGraphPtr GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
const std::vector<AnfNodePtr> &weight_args = {});

View File

@ -127,6 +127,10 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr
for (auto &j_node : todo) {
auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer);
manager->Replace(j_node, expanded_j);
if (j_node->func_graph()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) {
MS_LOG(DEBUG) << j_node->func_graph()->ToString() << " has FUNC_GRAPH_FLAG_K_GRAPH flag.";
j_node->func_graph()->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, false);
}
change = true;
}
return change;

View File

@ -311,7 +311,9 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
partial_abstract->set_node(new_node);
}
}
MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
MS_LOG(DEBUG) << "Set new_node: " << new_node->DebugString() << ", abstract as: " << new_node->abstract()->ToString()
<< ", func_graph_: " << func_graph_->ToString()
<< ", specialized_func_graph_: " << specialized_func_graph_->ToString();
if (node->isa<CNode>()) {
auto attrs = conf->ObtainEvalResult()->attribute();
@ -326,7 +328,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
AbstractBasePtr ival = GetEvaluatedValue(iconf);
// First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
// can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs);
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs, node);
if (replace_node == nullptr) {
replace_node = BuildReplacedNode(iconf);
replace_node->set_abstract(ival);
@ -834,7 +836,7 @@ static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, c
}
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
const AttrValueMapPtr &attrs) {
const AttrValueMapPtr &attrs, const AnfNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(origin_node);
MS_EXCEPTION_IF_NULL(ival);
@ -866,6 +868,11 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
if (!value->isa<FuncGraph>() || value->cast<FuncGraphPtr>()->parent() == nullptr ||
(IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast<FuncGraphPtr>()->parent()))) {
return BuildValueNode(value, ival);
} else if (IsPrimitiveCNode(cnode, prim::kPrimJ) && origin_node->isa<Parameter>() &&
!value->cast<FuncGraphPtr>()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) {
// Only if J(Parameter=func_graph) and func_graph(aka 'value') is not K graph.
MS_LOG(DEBUG) << "Specialize the parameter used by J CNode, cnode: " << cnode->DebugString();
return BuildValueNode(value, ival);
} else {
return nullptr;
}

View File

@ -120,7 +120,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
// Build a value node if ival is constant and not any-value
AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
const AttrValueMapPtr &attrs);
const AttrValueMapPtr &attrs, const AnfNodePtr &cnode = nullptr);
// Build a replaceable node for iconf->node; it may be a replicated forwarded CNode in static analysis or just a
// replicated node.
AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf);

View File

@ -84,6 +84,7 @@ const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
const char FUNC_GRAPH_FLAG_CORE[] = "core";
const char FUNC_GRAPH_FLAG_K_GRAPH[] = "k_graph";
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
const char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";