forked from mindspore-Ecosystem/mindspore
!28390 Fix jvp parameter issue: Specialize the inner J parameter.
Merge pull request !28390 from 张清华/opt_jvp_parameter
This commit is contained in:
commit
17d8c5d2ef
|
@ -663,11 +663,12 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
|
|||
}
|
||||
}
|
||||
|
||||
FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights, const AnfNodePtr &position,
|
||||
FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
|
||||
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
|
||||
const std::vector<AnfNodePtr> &weight_args) {
|
||||
FuncGraphPtr k_child = std::make_shared<FuncGraph>();
|
||||
k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
k_child->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, true);
|
||||
|
||||
AnfNodePtr weights_node = nullptr;
|
||||
AnfNodePtr position_node = nullptr;
|
||||
|
@ -681,7 +682,7 @@ FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weigh
|
|||
}
|
||||
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(k);
|
||||
inputs.push_back(j);
|
||||
for (size_t i = 0; i < forward_graph_params.size(); ++i) {
|
||||
inputs.push_back(k_child->add_parameter());
|
||||
}
|
||||
|
|
|
@ -150,7 +150,7 @@ class GradOperation : public MetaFuncGraph {
|
|||
~GradOperation() override = default;
|
||||
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
|
||||
|
||||
FuncGraphPtr GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights, const AnfNodePtr &position,
|
||||
FuncGraphPtr GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
|
||||
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
|
||||
const std::vector<AnfNodePtr> &weight_args = {});
|
||||
|
||||
|
|
|
@ -127,6 +127,10 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr
|
|||
for (auto &j_node : todo) {
|
||||
auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer);
|
||||
manager->Replace(j_node, expanded_j);
|
||||
if (j_node->func_graph()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) {
|
||||
MS_LOG(DEBUG) << j_node->func_graph()->ToString() << " has FUNC_GRAPH_FLAG_K_GRAPH flag.";
|
||||
j_node->func_graph()->set_flag(FUNC_GRAPH_FLAG_K_GRAPH, false);
|
||||
}
|
||||
change = true;
|
||||
}
|
||||
return change;
|
||||
|
|
|
@ -311,7 +311,9 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|||
partial_abstract->set_node(new_node);
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
|
||||
MS_LOG(DEBUG) << "Set new_node: " << new_node->DebugString() << ", abstract as: " << new_node->abstract()->ToString()
|
||||
<< ", func_graph_: " << func_graph_->ToString()
|
||||
<< ", specialized_func_graph_: " << specialized_func_graph_->ToString();
|
||||
|
||||
if (node->isa<CNode>()) {
|
||||
auto attrs = conf->ObtainEvalResult()->attribute();
|
||||
|
@ -326,7 +328,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|||
AbstractBasePtr ival = GetEvaluatedValue(iconf);
|
||||
// First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
|
||||
// can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
|
||||
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs);
|
||||
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs, node);
|
||||
if (replace_node == nullptr) {
|
||||
replace_node = BuildReplacedNode(iconf);
|
||||
replace_node->set_abstract(ival);
|
||||
|
@ -834,7 +836,7 @@ static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, c
|
|||
}
|
||||
|
||||
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
|
||||
const AttrValueMapPtr &attrs) {
|
||||
const AttrValueMapPtr &attrs, const AnfNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(origin_node);
|
||||
MS_EXCEPTION_IF_NULL(ival);
|
||||
|
||||
|
@ -866,6 +868,11 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
|
|||
if (!value->isa<FuncGraph>() || value->cast<FuncGraphPtr>()->parent() == nullptr ||
|
||||
(IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast<FuncGraphPtr>()->parent()))) {
|
||||
return BuildValueNode(value, ival);
|
||||
} else if (IsPrimitiveCNode(cnode, prim::kPrimJ) && origin_node->isa<Parameter>() &&
|
||||
!value->cast<FuncGraphPtr>()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) {
|
||||
// Only if J(Parameter=func_graph) and func_graph(aka 'value') is not K graph.
|
||||
MS_LOG(DEBUG) << "Specialize the parameter used by J CNode, cnode: " << cnode->DebugString();
|
||||
return BuildValueNode(value, ival);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -120,7 +120,7 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
|
|||
|
||||
// Build a value node if ival is constant and not any-value
|
||||
AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
|
||||
const AttrValueMapPtr &attrs);
|
||||
const AttrValueMapPtr &attrs, const AnfNodePtr &cnode = nullptr);
|
||||
// Build a replaceable node for iconf->node; it may be a replicated forwarded CNode in static analysis or just a
|
||||
// replicated node.
|
||||
AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf);
|
||||
|
|
|
@ -84,6 +84,7 @@ const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
|
|||
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
|
||||
const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
|
||||
const char FUNC_GRAPH_FLAG_CORE[] = "core";
|
||||
const char FUNC_GRAPH_FLAG_K_GRAPH[] = "k_graph";
|
||||
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
|
||||
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
|
||||
const char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";
|
||||
|
|
Loading…
Reference in New Issue