From a109bc3d96d421c84259242b3c18c17bd59acb90 Mon Sep 17 00:00:00 2001 From: chenzupeng Date: Thu, 30 Apr 2020 12:59:39 +0800 Subject: [PATCH] add onnx op converter: ReLU6, DepthwiseConv2dNative --- mindspore/ccsrc/onnx/onnx_exporter.cc | 123 ++++++++++++++++++++++++ mindspore/model_zoo/mobilenet.py | 3 +- tests/ut/python/utils/test_serialize.py | 48 +++++++++ 3 files changed, 173 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/onnx/onnx_exporter.cc b/mindspore/ccsrc/onnx/onnx_exporter.cc index 1c5a7b93c34..d53d1f63eda 100644 --- a/mindspore/ccsrc/onnx/onnx_exporter.cc +++ b/mindspore/ccsrc/onnx/onnx_exporter.cc @@ -330,6 +330,10 @@ class OnnxExporter { onnx::GraphProto *graph_proto); void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimReLU6(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *graph_proto); @@ -711,6 +715,115 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CN node_proto->add_input(input_slope); } +void OnnxExporter::ExportPrimReLU6(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Clip"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_x); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_proto->set_name("min"); + attr_proto->set_f(0.f); + attr_proto = node_proto->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_proto->set_name("max"); + attr_proto->set_f(6.f); +} + +void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_w = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); + auto x_shape = dyn_cast(node->input(1)->Shape()); + auto w_shape = dyn_cast(node->input(2)->Shape()); + MS_EXCEPTION_IF_NULL(x_shape); + MS_EXCEPTION_IF_NULL(w_shape); + if (x_shape->shape().size() != 4 || w_shape->shape().size() != 4) { + MS_LOG(EXCEPTION) << "DepthwiseConv2d input shape should be 4d."; + } + if (w_shape->shape()[0] != 1 && w_shape->shape()[1] != 1) { + MS_LOG(EXCEPTION) << "DepthwiseConv2d weight shape[0] != 1 and shape[1] != 1, cannot reshape"; + } + // create w_shape constant node + auto node_idx = AllocateNodeIndex(); + onnx::NodeProto *node_proto = graph_proto->add_node(); + std::string name_w_shape = std::to_string(node_idx); + node_proto->add_output(name_w_shape); + node_proto->set_op_type("Constant"); + // create Value Tensor + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("value"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + tensor_proto->add_dims(static_cast<::google::protobuf::int64>(w_shape->shape().size())); + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + // reshape + tensor_proto->add_int64_data(w_shape->shape()[1]); + tensor_proto->add_int64_data(w_shape->shape()[0]); + tensor_proto->add_int64_data(w_shape->shape()[2]); + tensor_proto->add_int64_data(w_shape->shape()[3]); + + // add reshape node + node_idx = AllocateNodeIndex(); + node_proto = graph_proto->add_node(); + node_proto->set_op_type(prim::kPrimReshape->name()); + node_proto->add_input(input_w); + node_proto->add_input(name_w_shape); + input_w = std::to_string(node_idx); + node_proto->add_output(input_w); + + // add conv node + node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + node_proto = graph_proto->add_node(); + node_proto->set_op_type("Conv"); + node_proto->add_input(input_x); + node_proto->add_input(input_w); + node_proto->add_output(std::to_string(node_idx)); + // set attributes + AnfNodePtr op = node->input(0); + auto op_value = dyn_cast(op); + auto prim = dyn_cast(op_value->value()); + // set dilations + onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("dilations"); + SetAttrTupleValueToProto<2>(prim->GetAttr("dilation"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, + prim); + // set group + onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("group"); + onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + onnx_attr_proto->set_i(x_shape->shape()[1]); + // set kernel_shape + onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("kernel_shape"); + SetAttrTupleValueToProto<0>(prim->GetAttr("kernel_size"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, + prim); + + // set pad + onnx_attr_proto = node_proto->add_attribute(); + auto attr_value = GetValue(prim->GetAttr("pad_mode")); + onnx_attr_proto->set_name("auto_pad"); + onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); + if (attr_value == "valid") { + onnx_attr_proto->set_s("VALID"); + } else if (attr_value == "same") { + onnx_attr_proto->set_s("SAME_UPPER"); + } else { + onnx_attr_proto->set_name("pads"); + SetAttrTupleValueToProto(prim->GetAttr("pads"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); + } + // set strides + onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("strides"); + SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); +} + void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { // Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert @@ -732,6 +845,16 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n return ExportPrimPReLU(func_graph, node, node_map_ptr, graph_proto); } + // MindSpore ReLU6(x) --> ONNX Clip[min=0.f, max=6.f](x) + if (node->IsApply(std::make_shared("ReLU6"))) { + return ExportPrimReLU6(func_graph, node, node_map_ptr, graph_proto); + } + + // MindSpore DepthwiseConv2dNative --> ONNX Conv(x, reshape(w)) + if (node->IsApply(std::make_shared("DepthwiseConv2dNative"))) { + return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto); + } + auto inputs = node->inputs(); if (inputs.size() < 1) { MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; diff --git a/mindspore/model_zoo/mobilenet.py b/mindspore/model_zoo/mobilenet.py index 1d4f1b10b58..6539c3e2690 100644 --- a/mindspore/model_zoo/mobilenet.py +++ b/mindspore/model_zoo/mobilenet.py @@ -89,7 +89,8 @@ class DepthwiseConv(nn.Cell): self.channel_multiplier = channel_multiplier self.out_channels = in_planes * channel_multiplier self.kernel_size = (kernel_size, kernel_size) - self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size, + self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, + kernel_size=self.kernel_size, stride=stride, pad_mode=pad_mode, pad=pad) self.bias_add = P.BiasAdd() weight_shape = [channel_multiplier, in_planes, *self.kernel_size] diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index 59a4b938336..2cb27cadfd1 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -323,6 +323,22 @@ class BatchNormTester(nn.Cell): return self.bn(x) +class DepthwiseConv2dAndReLU6(nn.Cell): + "Net for testing DepthwiseConv2d and ReLU6" + + def __init__(self, input_channel, kernel_size): + super(DepthwiseConv2dAndReLU6, self).__init__() + weight_shape = [1, input_channel, kernel_size, kernel_size] + from mindspore.common.initializer import initializer + self.weight = Parameter(initializer('ones', weight_shape), name='weight') + self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=(kernel_size, kernel_size)) + self.relu6 = nn.ReLU6() + + def construct(self, x): + x = self.depthwise_conv(x, self.weight) + x = self.relu6(x) + return x + def test_batchnorm_train_onnx_export(): input = Tensor(np.ones([1, 3, 32, 32]).astype(np.float32) * 0.01) net = BatchNormTester(3) @@ -421,6 +437,38 @@ def test_lenet5_onnx_load_run(): print(outputs[0]) +@run_on_onnxruntime +def test_depthwiseconv_relu6_onnx_load_run(): + onnx_file = 'depthwiseconv_relu6.onnx' + input_channel = 3 + input = Tensor(np.ones([1, input_channel, 32, 32]).astype(np.float32) * 0.01) + net = DepthwiseConv2dAndReLU6(input_channel, kernel_size=3) + export(net, input, file_name=onnx_file, file_format='ONNX') + + import onnx + import onnxruntime as ort + + print('--------------------- onnx load ---------------------') + # Load the ONNX model + model = onnx.load(onnx_file) + # Check that the IR is well formed + onnx.checker.check_model(model) + # Print a human readable representation of the graph + g = onnx.helper.printable_graph(model.graph) + print(g) + + print('------------------ onnxruntime run ------------------') + ort_session = ort.InferenceSession(onnx_file) + input_map = {'x' : input.asnumpy()} + # provide only input x to run model + outputs = ort_session.run(None, input_map) + print(outputs[0]) + # overwrite default weight to run model + for item in net.trainable_params(): + input_map[item.name] = np.ones(item.default_input.asnumpy().shape, dtype=np.float32) + outputs = ort_session.run(None, input_map) + print(outputs[0]) + def teardown_module(): files = ['parameters.ckpt', 'new_ckpt.ckpt', 'lenet5.onnx', 'batch_norm.onnx', 'empty.ckpt'] for item in files: