support onnx scatternd

This commit is contained in:
fangzehua 2022-11-18 16:50:41 +08:00
parent 57d2a1a1e5
commit d19f82788f
2 changed files with 84 additions and 0 deletions

View File

@ -1216,6 +1216,8 @@ class OnnxExporter {
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimOnesLike(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimScatterNd(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimArgMaxWithValue(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportPrimOneHot(const FuncGraphPtr &func_graph, const CNodePtr &node,
@ -2797,6 +2799,50 @@ void OnnxExporter::ExportPrimOnesLike(const FuncGraphPtr &, const CNodePtr &node
}
}
void OnnxExporter::ExportPrimScatterNd(const FuncGraphPtr &, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
auto input_indices_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
auto input_update_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
auto input_shape_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);
auto node_zero_tensor_name = node_name + "_zero";
auto dtype = node->input(kTwoNum)->Type();
auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id();
onnx::TensorProto *zero_proto = AddConstantOfShapeOp(input_shape_name, node_zero_tensor_name, graph_proto);
switch (elem_type) {
case kNumberTypeInt32:
zero_proto->set_data_type(onnx::TensorProto_DataType_INT32);
zero_proto->add_int32_data(0);
break;
case kNumberTypeInt64:
zero_proto->set_data_type(onnx::TensorProto_DataType_INT64);
zero_proto->add_int64_data(0);
break;
case kNumberTypeFloat32:
zero_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
zero_proto->add_float_data(0.0f);
break;
case kNumberTypeFloat64:
zero_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
zero_proto->add_double_data(0.0);
break;
default:
MS_LOG(EXCEPTION) << "Unsupported dtype: " << elem_type;
}
auto int64_indices_name = input_indices_name + "_int64";
AddCastOp(input_indices_name, int64_indices_name, onnx::TensorProto_DataType_INT64, graph_proto);
// Create ScatterND node
onnx::NodeProto *scatternd_proto = graph_proto->add_node();
scatternd_proto->set_op_type("ScatterND");
scatternd_proto->add_input(node_zero_tensor_name);
scatternd_proto->add_input(int64_indices_name);
scatternd_proto->add_input(input_update_name);
scatternd_proto->add_output(node_name);
}
void OnnxExporter::ExportPrimArgMaxWithValue(const FuncGraphPtr &, const CNodePtr &node,
std::map<AnfNodePtr, std::string> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
@ -3436,6 +3482,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
{prim::kPrimROIAlign, &OnnxExporter::ExportPrimROIAlign},
{prim::kPrimSlice, &OnnxExporter::ExportPrimSlice},
{prim::kPrimOnesLike, &OnnxExporter::ExportPrimOnesLike},
{prim::kPrimScatterNd, &OnnxExporter::ExportPrimScatterNd},
{prim::kPrimArgMaxWithValue, &OnnxExporter::ExportPrimArgMaxWithValue},
{prim::kPrimOneHot, &OnnxExporter::ExportPrimOneHot},
{prim::kPrimConv2DTranspose, &OnnxExporter::ExportPrimConv2DTranspose},

View File

@ -342,3 +342,40 @@ def test_scatternd_functional_pynative():
diff = output - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_scatternd_cpu_onnx():
"""
Feature: test ScatterNd op in cpu.
Description: test the ops export onnx.
Expectation: expect correct shape result.
"""
import os
import stat
import onnxruntime
from mindspore.train.serialization import export
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
shape = (4, 4, 4)
net = Net(shape)
indices = np.array([[0], [2]], dtype=np.int32)
updates = np.array([[[1, 1, 1, 1], [2, 2, 2, 2],
[3, 3, 3, 3], [4, 4, 4, 4]],
[[1, 1, 1, 1], [2, 2, 2, 2],
[3, 3, 3, 3], [4, 4, 4, 4]]], dtype=np.float32)
out_ms = net(Tensor(indices), Tensor(updates)).asnumpy()
file = 'scatternd.onnx'
export(net, Tensor(indices), Tensor(updates), file_name=file, file_format="ONNX")
assert os.path.exists(file)
sess = onnxruntime.InferenceSession(file)
input_indices = sess.get_inputs()[0].name
input_updates = sess.get_inputs()[1].name
result = sess.run([], {input_indices: indices, input_updates: updates})[0]
assert np.all(out_ms == result)
os.chmod(file, stat.S_IWRITE)
os.remove(file)