ONNX adapter for the MaxPoolWithArgmax

This commit is contained in:
meixiaowei 2020-04-28 14:25:59 +08:00
parent 7a95393780
commit cd899fba0a
2 changed files with 64 additions and 5 deletions

View File

@ -29,11 +29,12 @@
namespace mindspore { namespace mindspore {
enum OpMergeMode { enum OpMergeMode {
OP_MERGE_UNDEFINED = 0, // undefined behavior OP_MERGE_UNDEFINED = 0, // undefined behavior
OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list
OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv` OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv`
OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm` OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm`
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization` OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization`
OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool`
}; };
struct OpMergedInfo { struct OpMergedInfo {
@ -233,6 +234,13 @@ OPERATOR_ONNX_CONVERT_DEFINE(
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
OPERATOR_ONNX_CONVERT_DEFINE(
MaxPoolWithArgmax, MaxPool,
OpNameInfo()
.Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
OPERATOR_ONNX_CONVERT_DEFINE( OPERATOR_ONNX_CONVERT_DEFINE(
AvgPool, AveragePool, AvgPool, AveragePool,
OpNameInfo() OpNameInfo()
@ -254,6 +262,7 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
fn(OP_CONVERT_FUNCTION_NAME(Flatten)()); fn(OP_CONVERT_FUNCTION_NAME(Flatten)());
fn(OP_CONVERT_FUNCTION_NAME(MaxPool)()); fn(OP_CONVERT_FUNCTION_NAME(MaxPool)());
fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)());
fn(OP_CONVERT_FUNCTION_NAME(AvgPool)()); fn(OP_CONVERT_FUNCTION_NAME(AvgPool)());
fn(OP_CONVERT_FUNCTION_NAME(Squeeze)()); fn(OP_CONVERT_FUNCTION_NAME(Squeeze)());
@ -328,6 +337,8 @@ class OnnxExporter {
onnx::GraphProto *graph_proto); onnx::GraphProto *graph_proto);
void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto); onnx::GraphProto *graph_proto);
@ -516,6 +527,12 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vecto
op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM; op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM;
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
op_merged_infos[cnode->input(1)].referred_count -= 1; op_merged_infos[cnode->input(1)].referred_count -= 1;
} else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("MaxPoolWithArgmax")) &&
GetInt32Value(cnode->input(2)) == 0) {
op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX;
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
op_merged_infos[cnode->input(1)].referred_count -= 1;
} }
} }
} }
@ -563,6 +580,9 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP
case OP_MERGE_BATCH_NORM: case OP_MERGE_BATCH_NORM:
ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto); ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto);
break; break;
case OP_MERGE_MAXPOOL_WITH_ARGMAX:
ExportMergeMaxPoolWithArgmax(func_graph, cnode, node_map_ptr, graph_proto);
break;
default: default:
ExportCNode(func_graph, cnode, node_map_ptr, graph_proto); ExportCNode(func_graph, cnode, node_map_ptr, graph_proto);
break; break;
@ -811,6 +831,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CN
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
} }
void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto maxpool_with_argmax_node = dyn_cast<CNode>(node->input(1));
PrimitivePtr prim_maxpool_with_argmax =
dyn_cast<Primitive>((dyn_cast<ValueNode>(maxpool_with_argmax_node->input(0)))->value());
std::vector<AnfNodePtr> inputs;
for (size_t i = 1; i < maxpool_with_argmax_node->inputs().size(); i++) {
inputs.push_back(maxpool_with_argmax_node->input(i));
}
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto);
}
void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
if (node->inputs().size() != 2) { if (node->inputs().size() != 2) {

View File

@ -362,6 +362,31 @@ def test_lenet5_onnx_export():
net = LeNet5() net = LeNet5()
export(net, input, file_name='lenet5.onnx', file_format='ONNX') export(net, input, file_name='lenet5.onnx', file_format='ONNX')
class DefinedNet(nn.Cell):
"""simple Net definition with maxpoolwithargmax."""
def __init__(self, num_classes=10):
super(DefinedNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=2, strides=2)
self.flatten = nn.Flatten()
self.fc = nn.Dense(int(56*56*64), num_classes)
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x, argmax = self.maxpool(x)
x = self.flatten(x)
x = self.fc(x)
return x
def test_net_onnx_maxpoolwithargmax_export():
input = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32) * 0.01)
net = DefinedNet()
export(net, input, file_name='definedNet.onnx', file_format='ONNX')
@run_on_onnxruntime @run_on_onnxruntime
def test_lenet5_onnx_load_run(): def test_lenet5_onnx_load_run():