From 353b96dcf7973efec843f9330d26f1eb6b7279e4 Mon Sep 17 00:00:00 2001 From: fangzehua Date: Wed, 16 Nov 2022 15:25:40 +0800 Subject: [PATCH] support GatherD export onnx op --- .../transform/express_ir/onnx_exporter.cc | 22 +++++++++++++ requirements.txt | 1 + tests/st/ops/cpu/test_gather_d_op.py | 33 +++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc index d1cf0e99578..ada5e2932cd 100644 --- a/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc @@ -1150,6 +1150,8 @@ class OnnxExporter { std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportPrimExpandDims(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimGatherD(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportPrimPad(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportPrimBatchMatMul(const FuncGraphPtr &func_graph, const CNodePtr &node, @@ -1888,6 +1890,25 @@ void OnnxExporter::ExportPrimExpandDims(const FuncGraphPtr &, const CNodePtr &no node_proto->add_input(name_shape); } +// MindSpore GatherD -> ONNX GatherElements +void OnnxExporter::ExportPrimGatherD(const FuncGraphPtr &, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto); + auto axis = GetInt64Value(node->input(kTwoNum)); + auto input_indices = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto); + auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr); + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("GatherElements"); + node_proto->add_output(node_name); + node_proto->add_input(input_x); + node_proto->add_input(input_indices); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("axis"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + attr_proto->set_i(static_cast<::google::protobuf::int64>(axis)); +} + // MindSpore Pad -> ONNX Pad void OnnxExporter::ExportPrimPad(const FuncGraphPtr &, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { @@ -3248,6 +3269,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n {prim::kPrimLessEqual, &OnnxExporter::ExportPrimLessEqual}, {prim::kPrimSqueeze, &OnnxExporter::ExportPrimSqueeze}, {prim::kPrimExpandDims, &OnnxExporter::ExportPrimExpandDims}, + {prim::kPrimGatherD, &OnnxExporter::ExportPrimGatherD}, {prim::kPrimPad, &OnnxExporter::ExportPrimPad}, {prim::kPrimBatchMatMul, &OnnxExporter::ExportPrimBatchMatMul}, {prim::kPrimBroadcastTo, &OnnxExporter::ExportPrimBroadcastTo}, diff --git a/requirements.txt b/requirements.txt index 96443293d36..d15962dd4b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,5 +11,6 @@ packaging >= 20.0 pycocotools >= 2.0.2 # for st test tables >= 3.6.1 # for st test easydict >= 1.9 # for st test +onnxruntime >= 1.6.0 # for st test psutil >= 5.7.0 astunparse >= 1.6.3 diff --git a/tests/st/ops/cpu/test_gather_d_op.py b/tests/st/ops/cpu/test_gather_d_op.py index a31f6de55bd..569801a4581 100644 --- a/tests/st/ops/cpu/test_gather_d_op.py +++ b/tests/st/ops/cpu/test_gather_d_op.py @@ -13,6 +13,8 @@ # limitations under the License. # ============================================================================ +import os +import stat import numpy as np import pytest @@ -21,6 +23,7 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.train.serialization import export context.set_context(mode=context.GRAPH_MODE, device_target="CPU") @@ -138,3 +141,33 @@ def test_gatherd_cpu_dynamic_shape(): output = gatherd(Tensor(x, mindspore.float32), Tensor(y, mindspore.int32)) expect_shape = (5, 5, 8) assert output.asnumpy().shape == expect_shape + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_gatherd_cpu_onnx(): + """ + Feature: test GatherD op in cpu. + Description: test the ops export onnx. + Expectation: expect correct shape result. + """ + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + dim = 1 + net = NetGatherD(dim) + data = np.array([[1, 2], [3, 4]], dtype=np.float32) + indices = np.array([[0, 0], [1, 0]], dtype=np.int32) + out_ms = net(Tensor(data), Tensor(indices)).asnumpy() + file = 'gatherd.onnx' + export(net, Tensor(data), Tensor(indices), file_name=file, file_format="ONNX") + assert os.path.exists(file) + + import onnxruntime + sess = onnxruntime.InferenceSession(file) + input_x = sess.get_inputs()[0].name + input_indices = sess.get_inputs()[1].name + result = sess.run([], {input_x: data, input_indices: indices})[0] + assert np.all(out_ms == result) + + os.chmod(file, stat.S_IWRITE) + os.remove(file)