forked from mindspore-Ecosystem/mindspore
ONNX adapter for the MaxPoolWithArgmax
This commit is contained in:
parent
7a95393780
commit
cd899fba0a
|
@ -29,11 +29,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
enum OpMergeMode {
|
||||
OP_MERGE_UNDEFINED = 0, // undefined behavior
|
||||
OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list
|
||||
OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv`
|
||||
OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm`
|
||||
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization`
|
||||
OP_MERGE_UNDEFINED = 0, // undefined behavior
|
||||
OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list
|
||||
OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv`
|
||||
OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm`
|
||||
OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization`
|
||||
OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool`
|
||||
};
|
||||
|
||||
struct OpMergedInfo {
|
||||
|
@ -233,6 +234,13 @@ OPERATOR_ONNX_CONVERT_DEFINE(
|
|||
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
|
||||
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
|
||||
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(
|
||||
MaxPoolWithArgmax, MaxPool,
|
||||
OpNameInfo()
|
||||
.Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
|
||||
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
|
||||
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
|
||||
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(
|
||||
AvgPool, AveragePool,
|
||||
OpNameInfo()
|
||||
|
@ -254,6 +262,7 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
|
|||
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Flatten)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(MaxPool)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(AvgPool)());
|
||||
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Squeeze)());
|
||||
|
@ -328,6 +337,8 @@ class OnnxExporter {
|
|||
onnx::GraphProto *graph_proto);
|
||||
void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
|
||||
void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
onnx::GraphProto *graph_proto);
|
||||
|
@ -516,6 +527,12 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vecto
|
|||
op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM;
|
||||
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
|
||||
op_merged_infos[cnode->input(1)].referred_count -= 1;
|
||||
} else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
|
||||
IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("MaxPoolWithArgmax")) &&
|
||||
GetInt32Value(cnode->input(2)) == 0) {
|
||||
op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX;
|
||||
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
|
||||
op_merged_infos[cnode->input(1)].referred_count -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -563,6 +580,9 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodeP
|
|||
case OP_MERGE_BATCH_NORM:
|
||||
ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto);
|
||||
break;
|
||||
case OP_MERGE_MAXPOOL_WITH_ARGMAX:
|
||||
ExportMergeMaxPoolWithArgmax(func_graph, cnode, node_map_ptr, graph_proto);
|
||||
break;
|
||||
default:
|
||||
ExportCNode(func_graph, cnode, node_map_ptr, graph_proto);
|
||||
break;
|
||||
|
@ -811,6 +831,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CN
|
|||
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
onnx::GraphProto *const graph_proto) {
|
||||
auto maxpool_with_argmax_node = dyn_cast<CNode>(node->input(1));
|
||||
|
||||
PrimitivePtr prim_maxpool_with_argmax =
|
||||
dyn_cast<Primitive>((dyn_cast<ValueNode>(maxpool_with_argmax_node->input(0)))->value());
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
for (size_t i = 1; i < maxpool_with_argmax_node->inputs().size(); i++) {
|
||||
inputs.push_back(maxpool_with_argmax_node->input(i));
|
||||
}
|
||||
(*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||
if (node->inputs().size() != 2) {
|
||||
|
|
|
@ -362,6 +362,31 @@ def test_lenet5_onnx_export():
|
|||
net = LeNet5()
|
||||
export(net, input, file_name='lenet5.onnx', file_format='ONNX')
|
||||
|
||||
class DefinedNet(nn.Cell):
|
||||
"""simple Net definition with maxpoolwithargmax."""
|
||||
def __init__(self, num_classes=10):
|
||||
super(DefinedNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU()
|
||||
self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=2, strides=2)
|
||||
self.flatten = nn.Flatten()
|
||||
self.fc = nn.Dense(int(56*56*64), num_classes)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x, argmax = self.maxpool(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
def test_net_onnx_maxpoolwithargmax_export():
|
||||
input = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32) * 0.01)
|
||||
net = DefinedNet()
|
||||
export(net, input, file_name='definedNet.onnx', file_format='ONNX')
|
||||
|
||||
|
||||
@run_on_onnxruntime
|
||||
def test_lenet5_onnx_load_run():
|
||||
|
|
Loading…
Reference in New Issue