add onnx op converter: ReLU6, DepthwiseConv2dNative

This commit is contained in:
chenzupeng 2020-04-30 12:59:39 +08:00
parent 0edf22e68c
commit a109bc3d96
3 changed files with 173 additions and 1 deletions

View File

@ -330,6 +330,10 @@ class OnnxExporter {
onnx::GraphProto *graph_proto);
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportPrimReLU6(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);
@ -711,6 +715,115 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CN
node_proto->add_input(input_slope);
}
void OnnxExporter::ExportPrimReLU6(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
onnx::NodeProto *node_proto = graph_proto->add_node();
node_proto->set_op_type("Clip");
node_proto->add_output(std::to_string(node_idx));
node_proto->add_input(input_x);
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT);
attr_proto->set_name("min");
attr_proto->set_f(0.f);
attr_proto = node_proto->add_attribute();
attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT);
attr_proto->set_name("max");
attr_proto->set_f(6.f);
}
void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *const graph_proto) {
auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
auto input_w = GetNodeInputName(node->input(2), node_map_ptr, graph_proto);
auto x_shape = dyn_cast<abstract::Shape>(node->input(1)->Shape());
auto w_shape = dyn_cast<abstract::Shape>(node->input(2)->Shape());
MS_EXCEPTION_IF_NULL(x_shape);
MS_EXCEPTION_IF_NULL(w_shape);
if (x_shape->shape().size() != 4 || w_shape->shape().size() != 4) {
MS_LOG(EXCEPTION) << "DepthwiseConv2d input shape should be 4d.";
}
if (w_shape->shape()[0] != 1 && w_shape->shape()[1] != 1) {
MS_LOG(EXCEPTION) << "DepthwiseConv2d weight shape[0] != 1 and shape[1] != 1, cannot reshape";
}
// create w_shape constant node
auto node_idx = AllocateNodeIndex();
onnx::NodeProto *node_proto = graph_proto->add_node();
std::string name_w_shape = std::to_string(node_idx);
node_proto->add_output(name_w_shape);
node_proto->set_op_type("Constant");
// create Value Tensor
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name("value");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
tensor_proto->add_dims(static_cast<::google::protobuf::int64>(w_shape->shape().size()));
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
// reshape
tensor_proto->add_int64_data(w_shape->shape()[1]);
tensor_proto->add_int64_data(w_shape->shape()[0]);
tensor_proto->add_int64_data(w_shape->shape()[2]);
tensor_proto->add_int64_data(w_shape->shape()[3]);
// add reshape node
node_idx = AllocateNodeIndex();
node_proto = graph_proto->add_node();
node_proto->set_op_type(prim::kPrimReshape->name());
node_proto->add_input(input_w);
node_proto->add_input(name_w_shape);
input_w = std::to_string(node_idx);
node_proto->add_output(input_w);
// add conv node
node_idx = AllocateNodeIndex();
(*node_map_ptr)[node] = node_idx;
node_proto = graph_proto->add_node();
node_proto->set_op_type("Conv");
node_proto->add_input(input_x);
node_proto->add_input(input_w);
node_proto->add_output(std::to_string(node_idx));
// set attributes
AnfNodePtr op = node->input(0);
auto op_value = dyn_cast<ValueNode>(op);
auto prim = dyn_cast<Primitive>(op_value->value());
// set dilations
onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
onnx_attr_proto->set_name("dilations");
SetAttrTupleValueToProto<2>(prim->GetAttr("dilation"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto,
prim);
// set group
onnx_attr_proto = node_proto->add_attribute();
onnx_attr_proto->set_name("group");
onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
onnx_attr_proto->set_i(x_shape->shape()[1]);
// set kernel_shape
onnx_attr_proto = node_proto->add_attribute();
onnx_attr_proto->set_name("kernel_shape");
SetAttrTupleValueToProto<0>(prim->GetAttr("kernel_size"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto,
prim);
// set pad
onnx_attr_proto = node_proto->add_attribute();
auto attr_value = GetValue<std::string>(prim->GetAttr("pad_mode"));
onnx_attr_proto->set_name("auto_pad");
onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
if (attr_value == "valid") {
onnx_attr_proto->set_s("VALID");
} else if (attr_value == "same") {
onnx_attr_proto->set_s("SAME_UPPER");
} else {
onnx_attr_proto->set_name("pads");
SetAttrTupleValueToProto(prim->GetAttr("pads"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
}
// set strides
onnx_attr_proto = node_proto->add_attribute();
onnx_attr_proto->set_name("strides");
SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
}
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
@ -732,6 +845,16 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
return ExportPrimPReLU(func_graph, node, node_map_ptr, graph_proto);
}
// MindSpore ReLU6(x) --> ONNX Clip[min=0.f, max=6.f](x)
if (node->IsApply(std::make_shared<Primitive>("ReLU6"))) {
return ExportPrimReLU6(func_graph, node, node_map_ptr, graph_proto);
}
// MindSpore DepthwiseConv2dNative --> ONNX Conv(x, reshape(w))
if (node->IsApply(std::make_shared<Primitive>("DepthwiseConv2dNative"))) {
return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto);
}
auto inputs = node->inputs();
if (inputs.size() < 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";

View File

@ -89,7 +89,8 @@ class DepthwiseConv(nn.Cell):
self.channel_multiplier = channel_multiplier
self.out_channels = in_planes * channel_multiplier
self.kernel_size = (kernel_size, kernel_size)
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size,
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
kernel_size=self.kernel_size,
stride=stride, pad_mode=pad_mode, pad=pad)
self.bias_add = P.BiasAdd()
weight_shape = [channel_multiplier, in_planes, *self.kernel_size]

View File

@ -323,6 +323,22 @@ class BatchNormTester(nn.Cell):
return self.bn(x)
class DepthwiseConv2dAndReLU6(nn.Cell):
"Net for testing DepthwiseConv2d and ReLU6"
def __init__(self, input_channel, kernel_size):
super(DepthwiseConv2dAndReLU6, self).__init__()
weight_shape = [1, input_channel, kernel_size, kernel_size]
from mindspore.common.initializer import initializer
self.weight = Parameter(initializer('ones', weight_shape), name='weight')
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=(kernel_size, kernel_size))
self.relu6 = nn.ReLU6()
def construct(self, x):
x = self.depthwise_conv(x, self.weight)
x = self.relu6(x)
return x
def test_batchnorm_train_onnx_export():
input = Tensor(np.ones([1, 3, 32, 32]).astype(np.float32) * 0.01)
net = BatchNormTester(3)
@ -421,6 +437,38 @@ def test_lenet5_onnx_load_run():
print(outputs[0])
@run_on_onnxruntime
def test_depthwiseconv_relu6_onnx_load_run():
onnx_file = 'depthwiseconv_relu6.onnx'
input_channel = 3
input = Tensor(np.ones([1, input_channel, 32, 32]).astype(np.float32) * 0.01)
net = DepthwiseConv2dAndReLU6(input_channel, kernel_size=3)
export(net, input, file_name=onnx_file, file_format='ONNX')
import onnx
import onnxruntime as ort
print('--------------------- onnx load ---------------------')
# Load the ONNX model
model = onnx.load(onnx_file)
# Check that the IR is well formed
onnx.checker.check_model(model)
# Print a human readable representation of the graph
g = onnx.helper.printable_graph(model.graph)
print(g)
print('------------------ onnxruntime run ------------------')
ort_session = ort.InferenceSession(onnx_file)
input_map = {'x' : input.asnumpy()}
# provide only input x to run model
outputs = ort_session.run(None, input_map)
print(outputs[0])
# overwrite default weight to run model
for item in net.trainable_params():
input_map[item.name] = np.ones(item.default_input.asnumpy().shape, dtype=np.float32)
outputs = ort_session.run(None, input_map)
print(outputs[0])
def teardown_module():
files = ['parameters.ckpt', 'new_ckpt.ckpt', 'lenet5.onnx', 'batch_norm.onnx', 'empty.ckpt']
for item in files: