Support exporting the MindSpore ScatterNd primitive to the ONNX ScatterND operator
This commit is contained in:
parent
57d2a1a1e5
commit
d19f82788f
|
@ -1216,6 +1216,8 @@ class OnnxExporter {
|
|||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimOnesLike(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimScatterNd(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimArgMaxWithValue(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimOneHot(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
|
@ -2797,6 +2799,50 @@ void OnnxExporter::ExportPrimOnesLike(const FuncGraphPtr &, const CNodePtr &node
|
|||
}
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimScatterNd(const FuncGraphPtr &, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr,
|
||||
onnx::GraphProto *const graph_proto) {
|
||||
auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
|
||||
auto input_indices_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
|
||||
auto input_update_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
|
||||
auto input_shape_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);
|
||||
auto node_zero_tensor_name = node_name + "_zero";
|
||||
auto dtype = node->input(kTwoNum)->Type();
|
||||
auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id();
|
||||
|
||||
onnx::TensorProto *zero_proto = AddConstantOfShapeOp(input_shape_name, node_zero_tensor_name, graph_proto);
|
||||
switch (elem_type) {
|
||||
case kNumberTypeInt32:
|
||||
zero_proto->set_data_type(onnx::TensorProto_DataType_INT32);
|
||||
zero_proto->add_int32_data(0);
|
||||
break;
|
||||
case kNumberTypeInt64:
|
||||
zero_proto->set_data_type(onnx::TensorProto_DataType_INT64);
|
||||
zero_proto->add_int64_data(0);
|
||||
break;
|
||||
case kNumberTypeFloat32:
|
||||
zero_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
|
||||
zero_proto->add_float_data(0.0f);
|
||||
break;
|
||||
case kNumberTypeFloat64:
|
||||
zero_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
|
||||
zero_proto->add_double_data(0.0);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Unsupported dtype: " << elem_type;
|
||||
}
|
||||
auto int64_indices_name = input_indices_name + "_int64";
|
||||
AddCastOp(input_indices_name, int64_indices_name, onnx::TensorProto_DataType_INT64, graph_proto);
|
||||
|
||||
// Create ScatterND node
|
||||
onnx::NodeProto *scatternd_proto = graph_proto->add_node();
|
||||
scatternd_proto->set_op_type("ScatterND");
|
||||
scatternd_proto->add_input(node_zero_tensor_name);
|
||||
scatternd_proto->add_input(int64_indices_name);
|
||||
scatternd_proto->add_input(input_update_name);
|
||||
scatternd_proto->add_output(node_name);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimArgMaxWithValue(const FuncGraphPtr &, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, std::string> *node_map_ptr,
|
||||
onnx::GraphProto *const graph_proto) {
|
||||
|
@ -3436,6 +3482,7 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
|||
{prim::kPrimROIAlign, &OnnxExporter::ExportPrimROIAlign},
|
||||
{prim::kPrimSlice, &OnnxExporter::ExportPrimSlice},
|
||||
{prim::kPrimOnesLike, &OnnxExporter::ExportPrimOnesLike},
|
||||
{prim::kPrimScatterNd, &OnnxExporter::ExportPrimScatterNd},
|
||||
{prim::kPrimArgMaxWithValue, &OnnxExporter::ExportPrimArgMaxWithValue},
|
||||
{prim::kPrimOneHot, &OnnxExporter::ExportPrimOneHot},
|
||||
{prim::kPrimConv2DTranspose, &OnnxExporter::ExportPrimConv2DTranspose},
|
||||
|
|
|
@ -342,3 +342,40 @@ def test_scatternd_functional_pynative():
|
|||
diff = output - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_scatternd_cpu_onnx():
    """
    Feature: test ScatterNd op in cpu.
    Description: test the ops export onnx.
    Expectation: expect correct shape result.
    """
    import os
    import stat
    import onnxruntime
    from mindspore.train.serialization import export

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    net = Net((4, 4, 4))
    idx_np = np.array([[0], [2]], dtype=np.int32)
    row_block = [[1, 1, 1, 1], [2, 2, 2, 2],
                 [3, 3, 3, 3], [4, 4, 4, 4]]
    upd_np = np.array([row_block, row_block], dtype=np.float32)
    expected = net(Tensor(idx_np), Tensor(upd_np)).asnumpy()

    onnx_path = 'scatternd.onnx'
    export(net, Tensor(idx_np), Tensor(upd_np), file_name=onnx_path, file_format="ONNX")
    assert os.path.exists(onnx_path)

    # Run the exported model with onnxruntime and compare against MindSpore.
    session = onnxruntime.InferenceSession(onnx_path)
    feeds = {meta.name: data
             for meta, data in zip(session.get_inputs(), (idx_np, upd_np))}
    onnx_out = session.run([], feeds)[0]
    assert np.all(expected == onnx_out)

    os.chmod(onnx_path, stat.S_IWRITE)
    os.remove(onnx_path)
|
||||
|
|
Loading…
Reference in New Issue