forked from mindspore-Ecosystem/mindspore
ONNX adapter for the MaxPoolWithArgmax
This commit is contained in:
parent
7a95393780
commit
cd899fba0a
|
@ -29,11 +29,12 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
enum OpMergeMode {
|
enum OpMergeMode {
|
||||||
OP_MERGE_UNDEFINED = 0, // undefined behavior
|
OP_MERGE_UNDEFINED = 0, // undefined behavior
|
||||||
OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list
|
OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list
|
||||||
OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv`
|
OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv`
|
||||||
OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm`
|
OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm`
|
||||||
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization`
|
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization`
|
||||||
|
OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool`
|
||||||
};
|
};
|
||||||
|
|
||||||
struct OpMergedInfo {
|
struct OpMergedInfo {
|
||||||
|
@ -233,6 +234,13 @@ OPERATOR_ONNX_CONVERT_DEFINE(
|
||||||
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
|
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
|
||||||
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
|
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
|
||||||
|
|
||||||
|
OPERATOR_ONNX_CONVERT_DEFINE(
|
||||||
|
MaxPoolWithArgmax, MaxPool,
|
||||||
|
OpNameInfo()
|
||||||
|
.Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
|
||||||
|
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
|
||||||
|
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
|
||||||
|
|
||||||
OPERATOR_ONNX_CONVERT_DEFINE(
|
OPERATOR_ONNX_CONVERT_DEFINE(
|
||||||
AvgPool, AveragePool,
|
AvgPool, AveragePool,
|
||||||
OpNameInfo()
|
OpNameInfo()
|
||||||
|
@ -254,6 +262,7 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
|
||||||
|
|
||||||
fn(OP_CONVERT_FUNCTION_NAME(Flatten)());
|
fn(OP_CONVERT_FUNCTION_NAME(Flatten)());
|
||||||
fn(OP_CONVERT_FUNCTION_NAME(MaxPool)());
|
fn(OP_CONVERT_FUNCTION_NAME(MaxPool)());
|
||||||
|
fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)());
|
||||||
fn(OP_CONVERT_FUNCTION_NAME(AvgPool)());
|
fn(OP_CONVERT_FUNCTION_NAME(AvgPool)());
|
||||||
|
|
||||||
fn(OP_CONVERT_FUNCTION_NAME(Squeeze)());
|
fn(OP_CONVERT_FUNCTION_NAME(Squeeze)());
|
||||||
|
@ -328,6 +337,8 @@ class OnnxExporter {
|
||||||
onnx::GraphProto *graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||||
|
void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||||
|
|
||||||
void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto *graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
|
@ -516,6 +527,12 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vecto
|
||||||
op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM;
|
op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM;
|
||||||
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
|
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
|
||||||
op_merged_infos[cnode->input(1)].referred_count -= 1;
|
op_merged_infos[cnode->input(1)].referred_count -= 1;
|
||||||
|
} else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
|
||||||
|
IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("MaxPoolWithArgmax")) &&
|
||||||
|
GetInt32Value(cnode->input(2)) == 0) {
|
||||||
|
op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX;
|
||||||
|
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
|
||||||
|
op_merged_infos[cnode->input(1)].referred_count -= 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -563,6 +580,9 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP
|
||||||
case OP_MERGE_BATCH_NORM:
|
case OP_MERGE_BATCH_NORM:
|
||||||
ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto);
|
ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto);
|
||||||
break;
|
break;
|
||||||
|
case OP_MERGE_MAXPOOL_WITH_ARGMAX:
|
||||||
|
ExportMergeMaxPoolWithArgmax(func_graph, cnode, node_map_ptr, graph_proto);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
ExportCNode(func_graph, cnode, node_map_ptr, graph_proto);
|
ExportCNode(func_graph, cnode, node_map_ptr, graph_proto);
|
||||||
break;
|
break;
|
||||||
|
@ -811,6 +831,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CN
|
||||||
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
|
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
|
std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
|
onnx::GraphProto *const graph_proto) {
|
||||||
|
auto maxpool_with_argmax_node = dyn_cast<CNode>(node->input(1));
|
||||||
|
|
||||||
|
PrimitivePtr prim_maxpool_with_argmax =
|
||||||
|
dyn_cast<Primitive>((dyn_cast<ValueNode>(maxpool_with_argmax_node->input(0)))->value());
|
||||||
|
std::vector<AnfNodePtr> inputs;
|
||||||
|
for (size_t i = 1; i < maxpool_with_argmax_node->inputs().size(); i++) {
|
||||||
|
inputs.push_back(maxpool_with_argmax_node->input(i));
|
||||||
|
}
|
||||||
|
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto);
|
||||||
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
if (node->inputs().size() != 2) {
|
if (node->inputs().size() != 2) {
|
||||||
|
|
|
@ -362,6 +362,31 @@ def test_lenet5_onnx_export():
|
||||||
net = LeNet5()
|
net = LeNet5()
|
||||||
export(net, input, file_name='lenet5.onnx', file_format='ONNX')
|
export(net, input, file_name='lenet5.onnx', file_format='ONNX')
|
||||||
|
|
||||||
|
class DefinedNet(nn.Cell):
|
||||||
|
"""simple Net definition with maxpoolwithargmax."""
|
||||||
|
def __init__(self, num_classes=10):
|
||||||
|
super(DefinedNet, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
|
||||||
|
self.bn1 = nn.BatchNorm2d(64)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=2, strides=2)
|
||||||
|
self.flatten = nn.Flatten()
|
||||||
|
self.fc = nn.Dense(int(56*56*64), num_classes)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x, argmax = self.maxpool(x)
|
||||||
|
x = self.flatten(x)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def test_net_onnx_maxpoolwithargmax_export():
|
||||||
|
input = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32) * 0.01)
|
||||||
|
net = DefinedNet()
|
||||||
|
export(net, input, file_name='definedNet.onnx', file_format='ONNX')
|
||||||
|
|
||||||
|
|
||||||
@run_on_onnxruntime
|
@run_on_onnxruntime
|
||||||
def test_lenet5_onnx_load_run():
|
def test_lenet5_onnx_load_run():
|
||||||
|
|
Loading…
Reference in New Issue