forked from mindspore-Ecosystem/mindspore
!45647 support GatherD onnx
Merge pull request !45647 from fangzehua/onnx_gatherd
This commit is contained in:
commit
26aef12ff9
|
@ -1172,6 +1172,8 @@ class OnnxExporter {
|
|||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimExpandDims(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimGatherD(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimPad(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimBatchMatMul(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
|
@ -1914,6 +1916,25 @@ void OnnxExporter::ExportPrimExpandDims(const FuncGraphPtr &, const CNodePtr &no
|
|||
node_proto->add_input(name_shape);
|
||||
}
|
||||
|
||||
// MindSpore GatherD -> ONNX GatherElements
|
||||
void OnnxExporter::ExportPrimGatherD(const FuncGraphPtr &, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr,
|
||||
onnx::GraphProto *const graph_proto) {
|
||||
auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
|
||||
auto axis = GetInt64Value(node->input(kTwoNum));
|
||||
auto input_indices = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);
|
||||
auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
node_proto->set_op_type("GatherElements");
|
||||
node_proto->add_output(node_name);
|
||||
node_proto->add_input(input_x);
|
||||
node_proto->add_input(input_indices);
|
||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_name("axis");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||
attr_proto->set_i(static_cast<::google::protobuf::int64>(axis));
|
||||
}
|
||||
|
||||
// MindSpore Pad -> ONNX Pad
|
||||
void OnnxExporter::ExportPrimPad(const FuncGraphPtr &, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||
|
@ -3422,6 +3443,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
|||
{prim::kPrimLessEqual, &OnnxExporter::ExportPrimLessEqual},
|
||||
{prim::kPrimSqueeze, &OnnxExporter::ExportPrimSqueeze},
|
||||
{prim::kPrimExpandDims, &OnnxExporter::ExportPrimExpandDims},
|
||||
{prim::kPrimGatherD, &OnnxExporter::ExportPrimGatherD},
|
||||
{prim::kPrimPad, &OnnxExporter::ExportPrimPad},
|
||||
{prim::kPrimBatchMatMul, &OnnxExporter::ExportPrimBatchMatMul},
|
||||
{prim::kPrimBroadcastTo, &OnnxExporter::ExportPrimBroadcastTo},
|
||||
|
|
|
@ -12,5 +12,6 @@ packaging >= 20.0
|
|||
pycocotools >= 2.0.2 # for st test
|
||||
tables >= 3.6.1 # for st test
|
||||
easydict >= 1.9 # for st test
|
||||
onnxruntime >= 1.6.0 # for st test
|
||||
psutil >= 5.7.0
|
||||
astunparse >= 1.6.3
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import stat
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
@ -21,6 +23,7 @@ import mindspore.context as context
|
|||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train.serialization import export
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
@ -138,3 +141,33 @@ def test_gatherd_cpu_dynamic_shape():
|
|||
output = gatherd(Tensor(x, mindspore.float32), Tensor(y, mindspore.int32))
|
||||
expect_shape = (5, 5, 8)
|
||||
assert output.asnumpy().shape == expect_shape
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_gatherd_cpu_onnx():
|
||||
"""
|
||||
Feature: test GatherD op in cpu.
|
||||
Description: test the ops export onnx.
|
||||
Expectation: expect correct shape result.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
dim = 1
|
||||
net = NetGatherD(dim)
|
||||
data = np.array([[1, 2], [3, 4]], dtype=np.float32)
|
||||
indices = np.array([[0, 0], [1, 0]], dtype=np.int32)
|
||||
out_ms = net(Tensor(data), Tensor(indices)).asnumpy()
|
||||
file = 'gatherd.onnx'
|
||||
export(net, Tensor(data), Tensor(indices), file_name=file, file_format="ONNX")
|
||||
assert os.path.exists(file)
|
||||
|
||||
import onnxruntime
|
||||
sess = onnxruntime.InferenceSession(file)
|
||||
input_x = sess.get_inputs()[0].name
|
||||
input_indices = sess.get_inputs()[1].name
|
||||
result = sess.run([], {input_x: data, input_indices: indices})[0]
|
||||
assert np.all(out_ms == result)
|
||||
|
||||
os.chmod(file, stat.S_IWRITE)
|
||||
os.remove(file)
|
||||
|
|
Loading…
Reference in New Issue