forked from mindspore-Ecosystem/mindspore

add onnx op converter: ReLU6, DepthwiseConv2dNative

parent 0edf22e68c
commit a109bc3d96
@@ -330,6 +330,10 @@ class OnnxExporter {
                        onnx::GraphProto *graph_proto);
   void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
                        onnx::GraphProto *graph_proto);
+  void ExportPrimReLU6(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
+                       onnx::GraphProto *graph_proto);
+  void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node,
+                                 std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
 
   void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
                        onnx::GraphProto *graph_proto);
@@ -711,6 +715,115 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CN
   node_proto->add_input(input_slope);
 }
 
+void OnnxExporter::ExportPrimReLU6(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
+                                   std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
+  auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
+  auto node_idx = AllocateNodeIndex();
+  (*node_map_ptr)[node] = node_idx;
+  onnx::NodeProto *node_proto = graph_proto->add_node();
+  node_proto->set_op_type("Clip");
+  node_proto->add_output(std::to_string(node_idx));
+  node_proto->add_input(input_x);
+  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
+  attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT);
+  attr_proto->set_name("min");
+  attr_proto->set_f(0.f);
+  attr_proto = node_proto->add_attribute();
+  attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT);
+  attr_proto->set_name("max");
+  attr_proto->set_f(6.f);
+}
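ExportPrimReLU6 above emits the attribute form of Clip (valid through ONNX opset 10; from opset 11 onward min/max became inputs rather than attributes). For reference, a minimal sketch of the equivalent node built with the onnx Python helpers; the tensor names 'x' and 'y' are hypothetical:

import onnx
from onnx import helper

# Clip with min/max as float attributes, mirroring the node the exporter builds.
clip_node = helper.make_node('Clip', inputs=['x'], outputs=['y'], min=0.0, max=6.0)
print(helper.printable_node(clip_node))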
+
+void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
+                                             std::map<AnfNodePtr, size_t> *node_map_ptr,
+                                             onnx::GraphProto *const graph_proto) {
+  auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
+  auto input_w = GetNodeInputName(node->input(2), node_map_ptr, graph_proto);
+  auto x_shape = dyn_cast<abstract::Shape>(node->input(1)->Shape());
+  auto w_shape = dyn_cast<abstract::Shape>(node->input(2)->Shape());
+  MS_EXCEPTION_IF_NULL(x_shape);
+  MS_EXCEPTION_IF_NULL(w_shape);
+  if (x_shape->shape().size() != 4 || w_shape->shape().size() != 4) {
+    MS_LOG(EXCEPTION) << "DepthwiseConv2d input and weight shapes should be 4-dimensional.";
+  }
+  if (w_shape->shape()[0] != 1 && w_shape->shape()[1] != 1) {
+    MS_LOG(EXCEPTION) << "DepthwiseConv2d weight needs shape[0] == 1 or shape[1] == 1, cannot reshape otherwise.";
+  }
+  // create w_shape constant node
+  auto node_idx = AllocateNodeIndex();
+  onnx::NodeProto *node_proto = graph_proto->add_node();
+  std::string name_w_shape = std::to_string(node_idx);
+  node_proto->add_output(name_w_shape);
+  node_proto->set_op_type("Constant");
+  // create the value tensor holding the target weight shape
+  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
+  attr_proto->set_name("value");
+  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
+  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
+  tensor_proto->add_dims(static_cast<::google::protobuf::int64>(w_shape->shape().size()));
+  tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
+  // swap the first two dims: [C_mul, C_in, K, K] -> [C_in, C_mul, K, K]
+  tensor_proto->add_int64_data(w_shape->shape()[1]);
+  tensor_proto->add_int64_data(w_shape->shape()[0]);
+  tensor_proto->add_int64_data(w_shape->shape()[2]);
+  tensor_proto->add_int64_data(w_shape->shape()[3]);
+
+  // add reshape node
+  node_idx = AllocateNodeIndex();
+  node_proto = graph_proto->add_node();
+  node_proto->set_op_type(prim::kPrimReshape->name());
+  node_proto->add_input(input_w);
+  node_proto->add_input(name_w_shape);
+  input_w = std::to_string(node_idx);
+  node_proto->add_output(input_w);
+
+  // add conv node
+  node_idx = AllocateNodeIndex();
+  (*node_map_ptr)[node] = node_idx;
+  node_proto = graph_proto->add_node();
+  node_proto->set_op_type("Conv");
+  node_proto->add_input(input_x);
+  node_proto->add_input(input_w);
+  node_proto->add_output(std::to_string(node_idx));
+  // set attributes
+  AnfNodePtr op = node->input(0);
+  auto op_value = dyn_cast<ValueNode>(op);
+  auto prim = dyn_cast<Primitive>(op_value->value());
+  // set dilations
+  onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
+  onnx_attr_proto->set_name("dilations");
+  SetAttrTupleValueToProto<2>(prim->GetAttr("dilation"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto,
+                              prim);
+  // set group
+  onnx_attr_proto = node_proto->add_attribute();
+  onnx_attr_proto->set_name("group");
+  onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
+  onnx_attr_proto->set_i(x_shape->shape()[1]);
+  // set kernel_shape
+  onnx_attr_proto = node_proto->add_attribute();
+  onnx_attr_proto->set_name("kernel_shape");
+  SetAttrTupleValueToProto<0>(prim->GetAttr("kernel_size"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto,
+                              prim);
+
+  // set pad
+  onnx_attr_proto = node_proto->add_attribute();
+  auto attr_value = GetValue<std::string>(prim->GetAttr("pad_mode"));
+  onnx_attr_proto->set_name("auto_pad");
+  onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
+  if (attr_value == "valid") {
+    onnx_attr_proto->set_s("VALID");
+  } else if (attr_value == "same") {
+    onnx_attr_proto->set_s("SAME_UPPER");
+  } else {
+    onnx_attr_proto->set_name("pads");
+    SetAttrTupleValueToProto(prim->GetAttr("pads"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
+  }
+  // set strides
+  onnx_attr_proto = node_proto->add_attribute();
+  onnx_attr_proto->set_name("strides");
+  SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
+}
+
 void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
   // The 2nd input of 'Reshape' is a tuple in MindSpore but a tensor in ONNX, so it needs conversion
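Note that the exporter relies on a plain Reshape, not a Transpose, to turn the MindSpore depthwise weight layout [channel_multiplier, in_channels, K, K] into the [in_channels, channel_multiplier, K, K] layout fed to Conv with group = in_channels. A reshape only agrees with the real axis swap when one of the first two dims is 1, which is exactly what the guard above enforces. A minimal NumPy sketch of that invariant, with hypothetical sizes and channel_multiplier = 1:

import numpy as np

c_mul, c_in, k = 1, 3, 3
w_ms = np.arange(c_mul * c_in * k * k, dtype=np.float32).reshape(c_mul, c_in, k, k)
# The exporter emits Reshape to (shape[1], shape[0], K, K).
w_onnx = w_ms.reshape(c_in, c_mul, k, k)
# With c_mul == 1, the cheap reshape and the true axis swap agree element-wise.
assert np.array_equal(w_onnx, np.transpose(w_ms, (1, 0, 2, 3)))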
@@ -732,6 +845,16 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
     return ExportPrimPReLU(func_graph, node, node_map_ptr, graph_proto);
   }
 
+  // MindSpore ReLU6(x) --> ONNX Clip[min=0.f, max=6.f](x)
+  if (node->IsApply(std::make_shared<Primitive>("ReLU6"))) {
+    return ExportPrimReLU6(func_graph, node, node_map_ptr, graph_proto);
+  }
+
+  // MindSpore DepthwiseConv2dNative --> ONNX Conv(x, reshape(w))
+  if (node->IsApply(std::make_shared<Primitive>("DepthwiseConv2dNative"))) {
+    return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto);
+  }
+
   auto inputs = node->inputs();
   if (inputs.size() < 1) {
     MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
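As a quick sanity check on the ReLU6 --> Clip[min=0, max=6] mapping registered above, a small NumPy sketch:

import numpy as np

x = np.linspace(-3.0, 9.0, 7, dtype=np.float32)
# ReLU6(x) = min(max(x, 0), 6), which is exactly Clip with min=0 and max=6.
assert np.array_equal(np.clip(x, 0.0, 6.0), np.minimum(np.maximum(x, 0.0), 6.0))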
@@ -89,7 +89,8 @@ class DepthwiseConv(nn.Cell):
         self.channel_multiplier = channel_multiplier
         self.out_channels = in_planes * channel_multiplier
         self.kernel_size = (kernel_size, kernel_size)
-        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size,
+        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
+                                                      kernel_size=self.kernel_size,
                                                       stride=stride, pad_mode=pad_mode, pad=pad)
         self.bias_add = P.BiasAdd()
         weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
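The fix above passes the already-normalized 2-tuple self.kernel_size rather than the raw int, so downstream shape logic and the exporter's kernel_shape attribute always see a pair. A hedged sketch of that normalization pattern, with a hypothetical helper name:

def normalize_kernel_size(kernel_size):
    """Return kernel_size as a 2-tuple, accepting either an int or a pair."""
    if isinstance(kernel_size, int):
        return (kernel_size, kernel_size)
    return tuple(kernel_size)

assert normalize_kernel_size(3) == (3, 3)
assert normalize_kernel_size((3, 5)) == (3, 5)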
@@ -323,6 +323,22 @@ class BatchNormTester(nn.Cell):
         return self.bn(x)
 
 
+class DepthwiseConv2dAndReLU6(nn.Cell):
+    """Net for testing DepthwiseConv2dNative and ReLU6."""
+
+    def __init__(self, input_channel, kernel_size):
+        super(DepthwiseConv2dAndReLU6, self).__init__()
+        weight_shape = [1, input_channel, kernel_size, kernel_size]
+        from mindspore.common.initializer import initializer
+        self.weight = Parameter(initializer('ones', weight_shape), name='weight')
+        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=(kernel_size, kernel_size))
+        self.relu6 = nn.ReLU6()
+
+    def construct(self, x):
+        x = self.depthwise_conv(x, self.weight)
+        x = self.relu6(x)
+        return x
+
 def test_batchnorm_train_onnx_export():
     input = Tensor(np.ones([1, 3, 32, 32]).astype(np.float32) * 0.01)
     net = BatchNormTester(3)
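A minimal usage sketch of the new test cell, assuming a working MindSpore install and the imports already present in this file (np, Tensor, Parameter, P, nn):

input_channel = 3
net = DepthwiseConv2dAndReLU6(input_channel, kernel_size=3)
x = Tensor(np.ones([1, input_channel, 32, 32]).astype(np.float32) * 0.01)
y = net(x)  # depthwise conv followed by ReLU6
print(y.shape)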
@@ -421,6 +437,38 @@ def test_lenet5_onnx_load_run():
     print(outputs[0])
 
 
+@run_on_onnxruntime
+def test_depthwiseconv_relu6_onnx_load_run():
+    onnx_file = 'depthwiseconv_relu6.onnx'
+    input_channel = 3
+    input = Tensor(np.ones([1, input_channel, 32, 32]).astype(np.float32) * 0.01)
+    net = DepthwiseConv2dAndReLU6(input_channel, kernel_size=3)
+    export(net, input, file_name=onnx_file, file_format='ONNX')
+
+    import onnx
+    import onnxruntime as ort
+
+    print('--------------------- onnx load ---------------------')
+    # load the ONNX model
+    model = onnx.load(onnx_file)
+    # check that the IR is well formed
+    onnx.checker.check_model(model)
+    # print a human-readable representation of the graph
+    g = onnx.helper.printable_graph(model.graph)
+    print(g)
+
+    print('------------------ onnxruntime run ------------------')
+    ort_session = ort.InferenceSession(onnx_file)
+    input_map = {'x': input.asnumpy()}
+    # provide only input x to run the model
+    outputs = ort_session.run(None, input_map)
+    print(outputs[0])
+    # override the default weights and run the model again
+    for item in net.trainable_params():
+        input_map[item.name] = np.ones(item.default_input.asnumpy().shape, dtype=np.float32)
+    outputs = ort_session.run(None, input_map)
+    print(outputs[0])
+
+
 def teardown_module():
     files = ['parameters.ckpt', 'new_ckpt.ckpt', 'lenet5.onnx', 'batch_norm.onnx', 'empty.ckpt']
     for item in files:
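The weight-override step in the test above works because the exported graph lists the network's parameters as named graph inputs alongside 'x', so onnxruntime lets the caller feed replacement values for them (whether initializers are overridable this way can depend on the onnxruntime version). A hedged sketch of inspecting those names, assuming the file produced by the test:

import onnxruntime as ort

sess = ort.InferenceSession('depthwiseconv_relu6.onnx')
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)  # expect 'x' plus the exported 'weight'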