diff --git a/mindspore/ccsrc/frontend/operator/composite/vmap.cc b/mindspore/ccsrc/frontend/operator/composite/vmap.cc index fcbd0d44293..fea00103563 100644 --- a/mindspore/ccsrc/frontend/operator/composite/vmap.cc +++ b/mindspore/ccsrc/frontend/operator/composite/vmap.cc @@ -59,6 +59,7 @@ void GenerateFuncGraphAllNone(const FuncGraphPtr &fg, const AnfNodePtr &prim, in auto prim_output_cnode = fg->NewCNode(prim_output_cnode_inputs); const py::function bind_all_none_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_bind_all_none"); auto bind_all_none_fg = parse::ParsePythonCode(bind_all_none_fn); + MS_EXCEPTION_IF_NULL(bind_all_none_fg); auto bind_all_none_cnode = fg->NewCNode({NewValueNode(bind_all_none_fg), prim_output_cnode}); if (bind) { auto output_cnode = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(true), bind_all_none_cnode}); @@ -123,6 +124,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement( if (inputs_abstract_elements_end->isa()) { const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis"); auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis); + MS_EXCEPTION_IF_NULL(broadcast_by_axis_fg); (void)out_cnode_inputs.emplace_back(NewValueNode(broadcast_by_axis_fg)); (void)out_cnode_inputs.emplace_back(value_cnode); (void)out_cnode_inputs.emplace_back(NewValueNode(static_cast(0))); @@ -135,6 +137,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement( auto dim_cnode = fg_->NewCNode(dim_cnode_inputs); const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis"); auto move_axis_fg = parse::ParsePythonCode(move_axis); + MS_EXCEPTION_IF_NULL(move_axis_fg); (void)out_cnode_inputs.emplace_back(NewValueNode(move_axis_fg)); (void)out_cnode_inputs.emplace_back(value_cnode); (void)out_cnode_inputs.emplace_back(dim_cnode); @@ -224,6 +227,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inpu if (src_abstract->isa() && !dst_abstract->isa()) { const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis"); auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis); + MS_EXCEPTION_IF_NULL(broadcast_by_axis_fg); out_cnode = fg_->NewCNode({NewValueNode(broadcast_by_axis_fg), val_cnode, dst_cnode, axis_size}); } else if (!src_abstract->isa() && dst_abstract->isa()) { MS_LOG(EXCEPTION) << "It is invalid that source is not None and dst is None."; @@ -232,6 +236,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inpu } else { const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis"); auto move_axis_fg = parse::ParsePythonCode(move_axis); + MS_EXCEPTION_IF_NULL(move_axis_fg); out_cnode = fg_->NewCNode({NewValueNode(move_axis_fg), val_cnode, src_cnode, dst_cnode}); } (void)vals_out_tuple_cnode_inputs.emplace_back(out_cnode); @@ -249,6 +254,10 @@ FuncGraphPtr VmapMatchOutAxis::GenerateFuncGraph(const AbstractBasePtrList &args auto inputs_abstract = args_spec_list[kIndex0]; auto out_axes_abstract = args_spec_list[kIndex1]; auto axis_size_abstract = args_spec_list[kIndex2]; + MS_EXCEPTION_IF_NULL(inputs_abstract); + MS_EXCEPTION_IF_NULL(out_axes_abstract); + MS_EXCEPTION_IF_NULL(axis_size_abstract); + if (!inputs_abstract->isa()) { MS_LOG(EXCEPTION) << "The first input to VmapMatchOutAxis is vmap_inputs and should be a tuple but got " << inputs_abstract->ToString() << "."; @@ -402,6 +411,7 @@ CNodeInpusList VmapGeneralRule::ConstructMapInput(const InputsAbstractList &tupl fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cur_arg_node, NewValueNode(kDimIndex)}); const py::function unstack_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_unstack"); auto unstack_fg_ = parse::ParsePythonCode(unstack_fn); + MS_EXCEPTION_IF_NULL(unstack_fg_); auto out_cnode = fg_->NewCNode({NewValueNode(unstack_fg_), dim_cnode, val_cnode}); for (int64_t m = 0; m < axis_size_; ++m) { auto out_element_cnode = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), out_cnode, NewValueNode(m)}); @@ -506,6 +516,7 @@ FuncGraphPtr VmapGeneralRule::GenerateFuncGraph(const AbstractBasePtrList &args_ const py::function vmap_general_output_process_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_general_output_process"); auto vmap_general_output_process_fg_ = parse::ParsePythonCode(vmap_general_output_process_fn); + MS_EXCEPTION_IF_NULL(vmap_general_output_process_fg_); auto vmap_general_output = fg_->NewCNode({NewValueNode(vmap_general_output_process_fg_), output_cnode}); fg_->set_output(vmap_general_output); return fg_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/vmap_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/vmap_eliminate.cc index 8b21f44f861..be3744d599a 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/vmap_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/vmap_eliminate.cc @@ -484,8 +484,7 @@ AnfNodePtr ExpandVmapPrimitive(const AnfNodePtr &vnode, const pipeline::Resource } AnfNodePtr prim_vmap_rule = GetVmapRule(prim, resource, axis_size); if (prim_vmap_rule == nullptr) { - MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " transform to VmapRule failed. NodeInfo: " - << trace::GetDebugInfo(prim_vmap_rule->debug_info()) << "."; + MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " transform to VmapRule failed."; } return prim_vmap_rule; }