commit
157179dd32
|
@ -59,6 +59,7 @@ void GenerateFuncGraphAllNone(const FuncGraphPtr &fg, const AnfNodePtr &prim, in
|
|||
auto prim_output_cnode = fg->NewCNode(prim_output_cnode_inputs);
|
||||
const py::function bind_all_none_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_bind_all_none");
|
||||
auto bind_all_none_fg = parse::ParsePythonCode(bind_all_none_fn);
|
||||
MS_EXCEPTION_IF_NULL(bind_all_none_fg);
|
||||
auto bind_all_none_cnode = fg->NewCNode({NewValueNode(bind_all_none_fg), prim_output_cnode});
|
||||
if (bind) {
|
||||
auto output_cnode = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(true), bind_all_none_cnode});
|
||||
|
@ -123,6 +124,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement(
|
|||
if (inputs_abstract_elements_end->isa<abstract::AbstractNone>()) {
|
||||
const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis");
|
||||
auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis);
|
||||
MS_EXCEPTION_IF_NULL(broadcast_by_axis_fg);
|
||||
(void)out_cnode_inputs.emplace_back(NewValueNode(broadcast_by_axis_fg));
|
||||
(void)out_cnode_inputs.emplace_back(value_cnode);
|
||||
(void)out_cnode_inputs.emplace_back(NewValueNode(static_cast<int64_t>(0)));
|
||||
|
@ -135,6 +137,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerSingleElement(
|
|||
auto dim_cnode = fg_->NewCNode(dim_cnode_inputs);
|
||||
const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis");
|
||||
auto move_axis_fg = parse::ParsePythonCode(move_axis);
|
||||
MS_EXCEPTION_IF_NULL(move_axis_fg);
|
||||
(void)out_cnode_inputs.emplace_back(NewValueNode(move_axis_fg));
|
||||
(void)out_cnode_inputs.emplace_back(value_cnode);
|
||||
(void)out_cnode_inputs.emplace_back(dim_cnode);
|
||||
|
@ -224,6 +227,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inpu
|
|||
if (src_abstract->isa<abstract::AbstractNone>() && !dst_abstract->isa<abstract::AbstractNone>()) {
|
||||
const py::function broadcast_by_axis = python_adapter::GetPyFn(kVmapFunctionModelName, "_broadcast_by_axis");
|
||||
auto broadcast_by_axis_fg = parse::ParsePythonCode(broadcast_by_axis);
|
||||
MS_EXCEPTION_IF_NULL(broadcast_by_axis_fg);
|
||||
out_cnode = fg_->NewCNode({NewValueNode(broadcast_by_axis_fg), val_cnode, dst_cnode, axis_size});
|
||||
} else if (!src_abstract->isa<abstract::AbstractNone>() && dst_abstract->isa<abstract::AbstractNone>()) {
|
||||
MS_LOG(EXCEPTION) << "It is invalid that source is not None and dst is None.";
|
||||
|
@ -232,6 +236,7 @@ CNodePtr VmapMatchOutAxis::GenerateFuncGraphInnerAllTuple(const AnfNodePtr &inpu
|
|||
} else {
|
||||
const py::function move_axis = python_adapter::GetPyFn(kNumpyModelName, "moveaxis");
|
||||
auto move_axis_fg = parse::ParsePythonCode(move_axis);
|
||||
MS_EXCEPTION_IF_NULL(move_axis_fg);
|
||||
out_cnode = fg_->NewCNode({NewValueNode(move_axis_fg), val_cnode, src_cnode, dst_cnode});
|
||||
}
|
||||
(void)vals_out_tuple_cnode_inputs.emplace_back(out_cnode);
|
||||
|
@ -249,6 +254,10 @@ FuncGraphPtr VmapMatchOutAxis::GenerateFuncGraph(const AbstractBasePtrList &args
|
|||
auto inputs_abstract = args_spec_list[kIndex0];
|
||||
auto out_axes_abstract = args_spec_list[kIndex1];
|
||||
auto axis_size_abstract = args_spec_list[kIndex2];
|
||||
MS_EXCEPTION_IF_NULL(inputs_abstract);
|
||||
MS_EXCEPTION_IF_NULL(out_axes_abstract);
|
||||
MS_EXCEPTION_IF_NULL(axis_size_abstract);
|
||||
|
||||
if (!inputs_abstract->isa<abstract::AbstractTuple>()) {
|
||||
MS_LOG(EXCEPTION) << "The first input to VmapMatchOutAxis is vmap_inputs and should be a tuple but got "
|
||||
<< inputs_abstract->ToString() << ".";
|
||||
|
@ -402,6 +411,7 @@ CNodeInpusList VmapGeneralRule::ConstructMapInput(const InputsAbstractList &tupl
|
|||
fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cur_arg_node, NewValueNode(kDimIndex)});
|
||||
const py::function unstack_fn = python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_unstack");
|
||||
auto unstack_fg_ = parse::ParsePythonCode(unstack_fn);
|
||||
MS_EXCEPTION_IF_NULL(unstack_fg_);
|
||||
auto out_cnode = fg_->NewCNode({NewValueNode(unstack_fg_), dim_cnode, val_cnode});
|
||||
for (int64_t m = 0; m < axis_size_; ++m) {
|
||||
auto out_element_cnode = fg_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), out_cnode, NewValueNode(m)});
|
||||
|
@ -506,6 +516,7 @@ FuncGraphPtr VmapGeneralRule::GenerateFuncGraph(const AbstractBasePtrList &args_
|
|||
const py::function vmap_general_output_process_fn =
|
||||
python_adapter::GetPyFn(kVmapFunctionModelName, "vmap_general_output_process");
|
||||
auto vmap_general_output_process_fg_ = parse::ParsePythonCode(vmap_general_output_process_fn);
|
||||
MS_EXCEPTION_IF_NULL(vmap_general_output_process_fg_);
|
||||
auto vmap_general_output = fg_->NewCNode({NewValueNode(vmap_general_output_process_fg_), output_cnode});
|
||||
fg_->set_output(vmap_general_output);
|
||||
return fg_;
|
||||
|
|
|
@ -484,8 +484,7 @@ AnfNodePtr ExpandVmapPrimitive(const AnfNodePtr &vnode, const pipeline::Resource
|
|||
}
|
||||
AnfNodePtr prim_vmap_rule = GetVmapRule(prim, resource, axis_size);
|
||||
if (prim_vmap_rule == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " transform to VmapRule failed. NodeInfo: "
|
||||
<< trace::GetDebugInfo(prim_vmap_rule->debug_info()) << ".";
|
||||
MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " transform to VmapRule failed.";
|
||||
}
|
||||
return prim_vmap_rule;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue