!45647 Support exporting GatherD to ONNX

Merge pull request !45647 from fangzehua/onnx_gatherd
i-robot 2022-11-18 07:49:25 +00:00 committed by Gitee
commit 26aef12ff9
3 changed files with 56 additions and 0 deletions


@@ -1172,6 +1172,8 @@ class OnnxExporter {
                         std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimExpandDims(const FuncGraphPtr &func_graph, const CNodePtr &node,
                            std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimGatherD(const FuncGraphPtr &func_graph, const CNodePtr &node,
                         std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimPad(const FuncGraphPtr &func_graph, const CNodePtr &node,
                     std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimBatchMatMul(const FuncGraphPtr &func_graph, const CNodePtr &node,
@@ -1914,6 +1916,25 @@ void OnnxExporter::ExportPrimExpandDims(const FuncGraphPtr &, const CNodePtr &no
  node_proto->add_input(name_shape);
}

// MindSpore GatherD -> ONNX GatherElements
void OnnxExporter::ExportPrimGatherD(const FuncGraphPtr &, const CNodePtr &node,
                                     std::map<AnfNodePtr, std::string> *node_map_ptr,
                                     onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto axis = GetInt64Value(node->input(kTwoNum));
  auto input_indices = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("GatherElements");
  node_proto->add_output(node_name);
  node_proto->add_input(input_x);
  node_proto->add_input(input_indices);
  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
  attr_proto->set_name("axis");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  attr_proto->set_i(static_cast<::google::protobuf::int64>(axis));
}

// MindSpore Pad -> ONNX Pad
void OnnxExporter::ExportPrimPad(const FuncGraphPtr &, const CNodePtr &node,
                                 std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto) {
@@ -3422,6 +3443,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
    {prim::kPrimLessEqual, &OnnxExporter::ExportPrimLessEqual},
    {prim::kPrimSqueeze, &OnnxExporter::ExportPrimSqueeze},
    {prim::kPrimExpandDims, &OnnxExporter::ExportPrimExpandDims},
    {prim::kPrimGatherD, &OnnxExporter::ExportPrimGatherD},
    {prim::kPrimPad, &OnnxExporter::ExportPrimPad},
    {prim::kPrimBatchMatMul, &OnnxExporter::ExportPrimBatchMatMul},
    {prim::kPrimBroadcastTo, &OnnxExporter::ExportPrimBroadcastTo},
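
Note: MindSpore's GatherD(x, dim, index) and ONNX GatherElements(data, indices, axis) share the same element-wise gather rule, which is why the exporter above only forwards the two tensor inputs and turns the dim operand into the axis attribute. A minimal NumPy sketch of that shared semantics (not part of the change; the sample values mirror the new test case below, and np.take_along_axis computes the same rule):

import numpy as np

# Shared rule of GatherD / GatherElements on a 2-D input:
#   axis == 0: out[i][j] = x[index[i][j]][j]
#   axis == 1: out[i][j] = x[i][index[i][j]]
x = np.array([[1, 2], [3, 4]], dtype=np.float32)
index = np.array([[0, 0], [1, 0]], dtype=np.int64)
out = np.take_along_axis(x, index, axis=1)
print(out)  # [[1. 1.]
            #  [4. 3.]]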


@@ -12,5 +12,6 @@ packaging >= 20.0
pycocotools >= 2.0.2 # for st test
tables >= 3.6.1 # for st test
easydict >= 1.9 # for st test
onnxruntime >= 1.6.0 # for st test
psutil >= 5.7.0
astunparse >= 1.6.3


@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
import os
import stat
import numpy as np
import pytest
@@ -21,6 +23,7 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.train.serialization import export
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@@ -138,3 +141,33 @@ def test_gatherd_cpu_dynamic_shape():
    output = gatherd(Tensor(x, mindspore.float32), Tensor(y, mindspore.int32))
    expect_shape = (5, 5, 8)
    assert output.asnumpy().shape == expect_shape


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gatherd_cpu_onnx():
    """
    Feature: test GatherD op on CPU.
    Description: test exporting the op to ONNX.
    Expectation: the onnxruntime output matches the MindSpore output.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    dim = 1
    net = NetGatherD(dim)
    data = np.array([[1, 2], [3, 4]], dtype=np.float32)
    indices = np.array([[0, 0], [1, 0]], dtype=np.int32)
    out_ms = net(Tensor(data), Tensor(indices)).asnumpy()

    file = 'gatherd.onnx'
    export(net, Tensor(data), Tensor(indices), file_name=file, file_format="ONNX")
    assert os.path.exists(file)

    import onnxruntime
    sess = onnxruntime.InferenceSession(file)
    input_x = sess.get_inputs()[0].name
    input_indices = sess.get_inputs()[1].name
    result = sess.run([], {input_x: data, input_indices: indices})[0]
    assert np.all(out_ms == result)

    os.chmod(file, stat.S_IWRITE)
    os.remove(file)
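
Besides running the model, the exported graph can also be checked structurally. A small sketch (an illustration, not part of the change) that inspects a freshly exported gatherd.onnx with the onnx Python package, run before the cleanup above deletes the file; the expected axis value 1 comes from the dim used in the test:

import onnx

# Load the exported model and look for the GatherElements node emitted by the exporter.
model = onnx.load('gatherd.onnx')
gather_nodes = [n for n in model.graph.node if n.op_type == 'GatherElements']
assert gather_nodes, 'expected a GatherElements node in the exported graph'

# The exporter stores the GatherD dim as the integer "axis" attribute.
axis_attr = [a for a in gather_nodes[0].attribute if a.name == 'axis']
assert axis_attr and axis_attr[0].i == 1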