From bf33c04ff76b7f61bbff895e369beec3e8f76584 Mon Sep 17 00:00:00 2001 From: liuyang_655 Date: Tue, 1 Jun 2021 20:01:17 +0800 Subject: [PATCH] Add onnx export --- .../transform/express_ir/onnx_exporter.cc | 272 +++++++++++++++++- mindspore/core/base/core_ops.h | 1 + 2 files changed, 267 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc index e5100cad09e..b1b8c927844 100644 --- a/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc @@ -262,10 +262,16 @@ OPERATOR_ONNX_CONVERT_DEFINE( OPERATOR_ONNX_CONVERT_DEFINE(Gather, Gather, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(MakeTuple, SequenceConstruct, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Maximum, Max, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Transpose, Transpose, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(StridedSlice, Slice, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Exp, Exp, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(ResizeNearestNeighbor, Resize, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Softplus, Softplus, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Tanh, Tanh, OpNameInfo()) #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name @@ -289,10 +295,14 @@ void RegisterOpConverters(const std::function &fn) { fn(OP_CONVERT_FUNCTION_NAME(MatMul)()); fn(OP_CONVERT_FUNCTION_NAME(MakeTuple)()); - fn(OP_CONVERT_FUNCTION_NAME(Concat)()); fn(OP_CONVERT_FUNCTION_NAME(RealDiv)()); fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)()); fn(OP_CONVERT_FUNCTION_NAME(Sub)()); + fn(OP_CONVERT_FUNCTION_NAME(Maximum)()); + fn(OP_CONVERT_FUNCTION_NAME(Exp)()); + fn(OP_CONVERT_FUNCTION_NAME(ResizeNearestNeighbor)()); + fn(OP_CONVERT_FUNCTION_NAME(Softplus)()); + fn(OP_CONVERT_FUNCTION_NAME(Tanh)()); } class OpConvertRegistry { @@ -351,6 +361,14 @@ class OnnxExporter { std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimResizeNearestNeighbor(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimConcat(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, @@ -592,12 +610,25 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map op_merged_infos; MatchAndMark(func_graph, nodes, &op_merged_infos); - + int count = -1; for (const AnfNodePtr &node : nodes) { + count++; if (!node->isa()) { continue; } auto cnode = node->cast(); + if (cnode->IsApply(prim::kPrimMakeTuple)) { + int i = count + 1; + while (!nodes[i]->isa()) { + i++; + } + auto nextCNode = nodes[i]->cast(); + if (nextCNode->IsApply(prim::kPrimUpdateState) && + IsPrimitiveCNode(nextCNode->input(2), std::make_shared("MakeTuple"))) { + continue; + } + } + auto iter = op_merged_infos.find(cnode); // the node is not referenced by any other nodes, skip it if (iter == op_merged_infos.end()) { @@ -700,6 +731,224 @@ void OnnxExporter::ExportPrimReduce(const FuncGraphPtr &, const CNodePtr &node, } } +void OnnxExporter::ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_perm = node->input(2); + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + auto name = prim::kPrimTranspose->name(); + node_proto->set_op_type(name); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_data); + + if (input_perm->isa()) { + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("perm"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); + auto perm_value = dyn_cast(input_perm)->value(); + auto int_ptr = dyn_cast(perm_value); + if (int_ptr == nullptr) { + auto tuple_ptr = dyn_cast(perm_value); + MS_EXCEPTION_IF_NULL(tuple_ptr); + for (size_t i = 0; i < tuple_ptr->size(); ++i) { + attr_proto->add_ints(GetValue((*tuple_ptr)[i])); + } + } else { + attr_proto->add_ints(int_ptr->value()); + } + } else { + MS_LOG(EXCEPTION) << "The input input_perm of Transpose is not a ValueNode! " + << "Need to insert op convert variable from tuple to attributes for " << name; + } +} + +void OnnxExporter::ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto begin = node->input(2); + auto name = prim::kPrimStridedSlice->name(); + std::string name_begin; + if (begin->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[begin] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_begin = std::to_string(const_node_idx); + node_proto->add_output(name_begin); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("starts"); + + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(begin)->value(), attr_proto->mutable_t()); + } else { + MS_LOG(EXCEPTION) << "The input begin of StridedSlice is not a ValueNode! " + << "Need to insert op convert variable from tuple to tensor for " << name; + } + + auto end = node->input(3); + std::string name_end; + if (end->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[end] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_end = std::to_string(const_node_idx); + node_proto->add_output(name_end); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("ends"); + + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(end)->value(), attr_proto->mutable_t()); + } else { + MS_LOG(EXCEPTION) << "The input end of StridedSlice is not a ValueNode! " + << "Need to insert op convert variable from tuple to tensor for " << name; + } + + auto x_shape = dyn_cast(node->input(1)->Shape()); + int size = x_shape->shape().size(); + std::vector axes_value; + ValuePtr axes_value_ptr = nullptr; + for (int i = 0; i < size; ++i) { + axes_value.push_back(i); + } + axes_value_ptr = MakeValue>(axes_value); + auto axes = NewValueNode(axes_value_ptr)->cast(); + std::string name_axes; + auto const_node_idx_axes = AllocateNodeIndex(); + (*node_map_ptr)[axes] = const_node_idx_axes; + onnx::NodeProto *node_proto_axes = graph_proto->add_node(); + name_axes = std::to_string(const_node_idx_axes); + node_proto_axes->add_output(name_axes); + node_proto_axes->set_op_type("Constant"); + onnx::AttributeProto *attr_proto_axes = node_proto_axes->add_attribute(); + attr_proto_axes->set_name("axes"); + attr_proto_axes->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(axes)->value(), attr_proto_axes->mutable_t()); + + auto strides = node->input(4); + std::string name_strides; + if (strides->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[strides] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_strides = std::to_string(const_node_idx); + node_proto->add_output(name_strides); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto_steps = node_proto->add_attribute(); + attr_proto_steps->set_name("steps"); + attr_proto_steps->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(strides)->value(), attr_proto_steps->mutable_t()); + } else { + MS_LOG(EXCEPTION) << "The input strides of StridedSlice is not a ValueNode! " + << "Need to insert op convert variable from tuple to tensor for " << name; + } + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Slice"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_data); + node_proto->add_input(name_begin); + node_proto->add_input(name_end); + node_proto->add_input(name_axes); + node_proto->add_input(name_strides); +} + +void OnnxExporter::ExportPrimResizeNearestNeighbor(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto x_shape = dyn_cast(node->input(1)->Shape()); + + AnfNodePtr op = node->input(0); + auto op_value = dyn_cast(op); + auto prim = dyn_cast(op_value->value()); + std::vector resize_size; + + auto tuple_ptr = dyn_cast(prim->GetAttr("size")); + + for (size_t i = 0; i < x_shape->shape().size() - 2; i++) { + resize_size.push_back(x_shape->shape()[i]); + } + for (size_t i = 0; i < tuple_ptr->size(); i++) { + ValuePtr elem = (*tuple_ptr)[i]; + resize_size.push_back(dyn_cast(elem)->value()); + } + auto resize_size_ptr = MakeValue>(resize_size); + auto size = NewValueNode(resize_size_ptr)->cast(); + std::string name_size; + + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[size] = const_node_idx; + onnx::NodeProto *node_proto_size = graph_proto->add_node(); + name_size = std::to_string(const_node_idx); + node_proto_size->add_output(name_size); + node_proto_size->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto_size->add_attribute(); + attr_proto->set_name("sizes"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(resize_size_ptr, attr_proto->mutable_t()); + + auto node_idx = AllocateNodeIndex(); + + onnx::TensorProto *roi_initializer_proto = graph_proto->add_initializer(); + auto roi_name = std::to_string(node_idx) + "roi_initializer"; + roi_initializer_proto->set_name(roi_name); + roi_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32)); + roi_initializer_proto->add_dims(0); + + onnx::TensorProto *scales_initializer_proto = graph_proto->add_initializer(); + auto scales_name = std::to_string(node_idx) + "scales_initializer"; + scales_initializer_proto->set_name(scales_name); + scales_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32)); + scales_initializer_proto->add_dims(0); + + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + + node_proto->set_op_type("Resize"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_data); + node_proto->add_input(roi_name); + node_proto->add_input(scales_name); + node_proto->add_input(name_size); +} + +void OnnxExporter::ExportPrimConcat(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + + AnfNodePtr op = node->input(0); + auto op_value = dyn_cast(op); + auto prim = dyn_cast(op_value->value()); + auto input_node = node->input(1)->cast(); + + if (input_node->IsApply(prim::kPrimMakeTuple)) { + node_proto->set_op_type("ConcatFromSequence"); + } else { + node_proto->set_op_type("Concat"); + } + + // set attr axis + onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("axis"); + SetAttrValueToProto(prim->GetAttr("axis"), onnx::AttributeProto_AttributeType_INT, onnx_attr_proto, prim); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_data); +} + void OnnxExporter::ExportPrimCast(const FuncGraphPtr &, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); @@ -959,6 +1208,19 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto); } + if (node->IsApply(prim::kPrimTranspose)) { + return ExportPrimTranspose(func_graph, node, node_map_ptr, graph_proto); + } + if (node->IsApply(prim::kPrimStridedSlice)) { + return ExportPrimStridedSlice(func_graph, node, node_map_ptr, graph_proto); + } + if (node->IsApply(prim::kPrimResizeNearestNeighbor)) { + return ExportPrimResizeNearestNeighbor(func_graph, node, node_map_ptr, graph_proto); + } + if (node->IsApply(prim::kPrimConcat)) { + return ExportPrimConcat(func_graph, node, node_map_ptr, graph_proto); + } + // MindSpore Cast(x, T) --> ONNX Cast[to=T](x) if (node->IsApply(prim::kPrimCast)) { return ExportPrimCast(func_graph, node, node_map_ptr, graph_proto); @@ -1016,9 +1278,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n MS_LOG(EXCEPTION) << "Need to support node op type " << op_value->value()->type_name(); } - if (!IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) { - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); - } + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); } size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr &, std::map *node_map_ptr, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index e580471ae20..4c45d03f4f0 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -212,6 +212,7 @@ inline const PrimitivePtr kPrimReverseSequence = std::make_shared("Re inline const PrimitivePtr kPrimRank = std::make_shared("Rank"); inline const PrimitivePtr kPrimResizeBilinear = std::make_shared("ResizeBilinear"); inline const PrimitivePtr kPrimResizeGrad = std::make_shared("ResizeGrad"); +inline const PrimitivePtr kPrimResizeNearestNeighbor = std::make_shared("ResizeNearestNeighbor"); inline const PrimitivePtr kPrimSort = std::make_shared("Sort"); // NN