!44923 vmap cleancode

Merge pull request !44923 from Erpim/1101
This commit is contained in:
i-robot 2022-11-01 14:59:48 +00:00 committed by Gitee
commit 157179dd32
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 12 additions and 2 deletions

View File

@ -59,6 +59,7 @@ void GenerateFuncGraphAllNone(const FuncGraphPtr &fg, const AnfNodePtr &prim, in
auto prim_output_cnode = fg->NewCNode(prim_output_cnode_inputs);
const py::function bind_all_none_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_bind_all_none");
auto bind_all_none_fg = parse::ParsePythonCode(bind_all_none_fn);
MS_EXCEPTION_IF_NULL(bind_all_none_fg);
auto bind_all_none_cnode = fg->NewCNode({NewValueNode(bind_all_none_fg), prim_output_cnode});
if (bind) {
auto output_cnode = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(true), bind_all_none_cnode});
@ -123,6 +124,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement(
if (inputs_abstract_elements_end->isa<abstract::AbstractNone>()) {
const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis");
auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis);
MS_EXCEPTION_IF_NULL(broadcast_by_axis_fg);
(void)out_cnode_inputs.emplace_back(NewValueNode(broadcast_by_axis_fg));
(void)out_cnode_inputs.emplace_back(value_cnode);
(void)out_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(0)));
@ -135,6 +137,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement(
auto dim_cnode = fg_->NewCNode(dim_cnode_inputs);
const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis");
auto move_axis_fg = parse::ParsePythonCode(move_axis);
MS_EXCEPTION_IF_NULL(move_axis_fg);
(void)out_cnode_inputs.emplace_back(NewValueNode(move_axis_fg));
(void)out_cnode_inputs.emplace_back(value_cnode);
(void)out_cnode_inputs.emplace_back(dim_cnode);
@ -224,6 +227,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inpu
if (src_abstract->isa<abstract::AbstractNone>() && !dst_abstract->isa<abstract::AbstractNone>()) {
const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis");
auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis);
MS_EXCEPTION_IF_NULL(broadcast_by_axis_fg);
out_cnode = fg_->NewCNode({NewValueNode(broadcast_by_axis_fg), val_cnode, dst_cnode, axis_size});
} else if (!src_abstract->isa<abstract::AbstractNone>() && dst_abstract->isa<abstract::AbstractNone>()) {
MS_LOG(EXCEPTION) << "It is invalid that source is not None and dst is None.";
@ -232,6 +236,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inpu
} else {
const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis");
auto move_axis_fg = parse::ParsePythonCode(move_axis);
MS_EXCEPTION_IF_NULL(move_axis_fg);
out_cnode = fg_->NewCNode({NewValueNode(move_axis_fg), val_cnode, src_cnode, dst_cnode});
}
(void)vals_out_tuple_cnode_inputs.emplace_back(out_cnode);
@ -249,6 +254,10 @@ FuncGraphPtr VmapMatchOutAxis::GenerateFuncGraph(const AbstractBasePtrList &args
auto inputs_abstract = args_spec_list[kIndex0];
auto out_axes_abstract = args_spec_list[kIndex1];
auto axis_size_abstract = args_spec_list[kIndex2];
MS_EXCEPTION_IF_NULL(inputs_abstract);
MS_EXCEPTION_IF_NULL(out_axes_abstract);
MS_EXCEPTION_IF_NULL(axis_size_abstract);
if (!inputs_abstract->isa<abstract::AbstractTuple>()) {
MS_LOG(EXCEPTION) << "The first input to VmapMatchOutAxis is vmap_inputs and should be a tuple but got "
<< inputs_abstract->ToString() << ".";
@ -402,6 +411,7 @@ CNodeInpusList VmapGeneralRule::ConstructMapInput(const InputsAbstractList &tupl
fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cur_arg_node, NewValueNode(kDimIndex)});
const py::function unstack_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_unstack");
auto unstack_fg_ = parse::ParsePythonCode(unstack_fn);
MS_EXCEPTION_IF_NULL(unstack_fg_);
auto out_cnode = fg_->NewCNode({NewValueNode(unstack_fg_), dim_cnode, val_cnode});
for (int64_t m = 0; m < axis_size_; ++m) {
auto out_element_cnode = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), out_cnode, NewValueNode(m)});
@ -506,6 +516,7 @@ FuncGraphPtr VmapGeneralRule::GenerateFuncGraph(const AbstractBasePtrList &args_
const py::function vmap_general_output_process_fn =
python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_general_output_process");
auto vmap_general_output_process_fg_ = parse::ParsePythonCode(vmap_general_output_process_fn);
MS_EXCEPTION_IF_NULL(vmap_general_output_process_fg_);
auto vmap_general_output = fg_->NewCNode({NewValueNode(vmap_general_output_process_fg_), output_cnode});
fg_->set_output(vmap_general_output);
return fg_;

View File

@ -484,8 +484,7 @@ AnfNodePtr ExpandVmapPrimitive(const AnfNodePtr &vnode, const pipeline::Resource
}
AnfNodePtr prim_vmap_rule = GetVmapRule(prim, resource, axis_size);
if (prim_vmap_rule == nullptr) {
MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " transform to VmapRule failed. NodeInfo: "
<< trace::GetDebugInfo(prim_vmap_rule->debug_info()) << ".";
MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " transform to VmapRule failed.";
}
return prim_vmap_rule;
}