forked from mindspore-Ecosystem/mindspore
add onnx op converter: ReLU6, DepthwiseConv2dNative
This commit is contained in:
parent
0edf22e68c
commit
a109bc3d96
|
@ -330,6 +330,10 @@ class OnnxExporter {
|
||||||
onnx::GraphProto *graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto *graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
|
void ExportPrimReLU6(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
|
onnx::GraphProto *graph_proto);
|
||||||
|
void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||||
|
|
||||||
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto *graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
|
@ -711,6 +715,115 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CN
|
||||||
node_proto->add_input(input_slope);
|
node_proto->add_input(input_slope);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OnnxExporter::ExportPrimReLU6(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
||||||
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
|
auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||||
|
auto node_idx = AllocateNodeIndex();
|
||||||
|
(*node_map_ptr)[node] = node_idx;
|
||||||
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
|
node_proto->set_op_type("Clip");
|
||||||
|
node_proto->add_output(std::to_string(node_idx));
|
||||||
|
node_proto->add_input(input_x);
|
||||||
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||||
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT);
|
||||||
|
attr_proto->set_name("min");
|
||||||
|
attr_proto->set_f(0.f);
|
||||||
|
attr_proto = node_proto->add_attribute();
|
||||||
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT);
|
||||||
|
attr_proto->set_name("max");
|
||||||
|
attr_proto->set_f(6.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
||||||
|
std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
|
onnx::GraphProto *const graph_proto) {
|
||||||
|
auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||||
|
auto input_w = GetNodeInputName(node->input(2), node_map_ptr, graph_proto);
|
||||||
|
auto x_shape = dyn_cast<abstract::Shape>(node->input(1)->Shape());
|
||||||
|
auto w_shape = dyn_cast<abstract::Shape>(node->input(2)->Shape());
|
||||||
|
MS_EXCEPTION_IF_NULL(x_shape);
|
||||||
|
MS_EXCEPTION_IF_NULL(w_shape);
|
||||||
|
if (x_shape->shape().size() != 4 || w_shape->shape().size() != 4) {
|
||||||
|
MS_LOG(EXCEPTION) << "DepthwiseConv2d input shape should be 4d.";
|
||||||
|
}
|
||||||
|
if (w_shape->shape()[0] != 1 && w_shape->shape()[1] != 1) {
|
||||||
|
MS_LOG(EXCEPTION) << "DepthwiseConv2d weight shape[0] != 1 and shape[1] != 1, cannot reshape";
|
||||||
|
}
|
||||||
|
// create w_shape constant node
|
||||||
|
auto node_idx = AllocateNodeIndex();
|
||||||
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
|
std::string name_w_shape = std::to_string(node_idx);
|
||||||
|
node_proto->add_output(name_w_shape);
|
||||||
|
node_proto->set_op_type("Constant");
|
||||||
|
// create Value Tensor
|
||||||
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||||
|
attr_proto->set_name("value");
|
||||||
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||||
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
||||||
|
tensor_proto->add_dims(static_cast<::google::protobuf::int64>(w_shape->shape().size()));
|
||||||
|
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
|
||||||
|
// reshape
|
||||||
|
tensor_proto->add_int64_data(w_shape->shape()[1]);
|
||||||
|
tensor_proto->add_int64_data(w_shape->shape()[0]);
|
||||||
|
tensor_proto->add_int64_data(w_shape->shape()[2]);
|
||||||
|
tensor_proto->add_int64_data(w_shape->shape()[3]);
|
||||||
|
|
||||||
|
// add reshape node
|
||||||
|
node_idx = AllocateNodeIndex();
|
||||||
|
node_proto = graph_proto->add_node();
|
||||||
|
node_proto->set_op_type(prim::kPrimReshape->name());
|
||||||
|
node_proto->add_input(input_w);
|
||||||
|
node_proto->add_input(name_w_shape);
|
||||||
|
input_w = std::to_string(node_idx);
|
||||||
|
node_proto->add_output(input_w);
|
||||||
|
|
||||||
|
// add conv node
|
||||||
|
node_idx = AllocateNodeIndex();
|
||||||
|
(*node_map_ptr)[node] = node_idx;
|
||||||
|
node_proto = graph_proto->add_node();
|
||||||
|
node_proto->set_op_type("Conv");
|
||||||
|
node_proto->add_input(input_x);
|
||||||
|
node_proto->add_input(input_w);
|
||||||
|
node_proto->add_output(std::to_string(node_idx));
|
||||||
|
// set attributes
|
||||||
|
AnfNodePtr op = node->input(0);
|
||||||
|
auto op_value = dyn_cast<ValueNode>(op);
|
||||||
|
auto prim = dyn_cast<Primitive>(op_value->value());
|
||||||
|
// set dilations
|
||||||
|
onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
|
||||||
|
onnx_attr_proto->set_name("dilations");
|
||||||
|
SetAttrTupleValueToProto<2>(prim->GetAttr("dilation"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto,
|
||||||
|
prim);
|
||||||
|
// set group
|
||||||
|
onnx_attr_proto = node_proto->add_attribute();
|
||||||
|
onnx_attr_proto->set_name("group");
|
||||||
|
onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||||
|
onnx_attr_proto->set_i(x_shape->shape()[1]);
|
||||||
|
// set kernel_shape
|
||||||
|
onnx_attr_proto = node_proto->add_attribute();
|
||||||
|
onnx_attr_proto->set_name("kernel_shape");
|
||||||
|
SetAttrTupleValueToProto<0>(prim->GetAttr("kernel_size"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto,
|
||||||
|
prim);
|
||||||
|
|
||||||
|
// set pad
|
||||||
|
onnx_attr_proto = node_proto->add_attribute();
|
||||||
|
auto attr_value = GetValue<std::string>(prim->GetAttr("pad_mode"));
|
||||||
|
onnx_attr_proto->set_name("auto_pad");
|
||||||
|
onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
|
||||||
|
if (attr_value == "valid") {
|
||||||
|
onnx_attr_proto->set_s("VALID");
|
||||||
|
} else if (attr_value == "same") {
|
||||||
|
onnx_attr_proto->set_s("SAME_UPPER");
|
||||||
|
} else {
|
||||||
|
onnx_attr_proto->set_name("pads");
|
||||||
|
SetAttrTupleValueToProto(prim->GetAttr("pads"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
|
||||||
|
}
|
||||||
|
// set strides
|
||||||
|
onnx_attr_proto = node_proto->add_attribute();
|
||||||
|
onnx_attr_proto->set_name("strides");
|
||||||
|
SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
|
||||||
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
|
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
|
||||||
|
@ -732,6 +845,16 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
||||||
return ExportPrimPReLU(func_graph, node, node_map_ptr, graph_proto);
|
return ExportPrimPReLU(func_graph, node, node_map_ptr, graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MindSpore ReLU6(x) --> ONNX Clip[min=0.f, max=6.f](x)
|
||||||
|
if (node->IsApply(std::make_shared<Primitive>("ReLU6"))) {
|
||||||
|
return ExportPrimReLU6(func_graph, node, node_map_ptr, graph_proto);
|
||||||
|
}
|
||||||
|
|
||||||
|
// MindSpore DepthwiseConv2dNative --> ONNX Conv(x, reshape(w))
|
||||||
|
if (node->IsApply(std::make_shared<Primitive>("DepthwiseConv2dNative"))) {
|
||||||
|
return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto);
|
||||||
|
}
|
||||||
|
|
||||||
auto inputs = node->inputs();
|
auto inputs = node->inputs();
|
||||||
if (inputs.size() < 1) {
|
if (inputs.size() < 1) {
|
||||||
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
||||||
|
|
|
@ -89,7 +89,8 @@ class DepthwiseConv(nn.Cell):
|
||||||
self.channel_multiplier = channel_multiplier
|
self.channel_multiplier = channel_multiplier
|
||||||
self.out_channels = in_planes * channel_multiplier
|
self.out_channels = in_planes * channel_multiplier
|
||||||
self.kernel_size = (kernel_size, kernel_size)
|
self.kernel_size = (kernel_size, kernel_size)
|
||||||
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size,
|
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
stride=stride, pad_mode=pad_mode, pad=pad)
|
stride=stride, pad_mode=pad_mode, pad=pad)
|
||||||
self.bias_add = P.BiasAdd()
|
self.bias_add = P.BiasAdd()
|
||||||
weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
|
weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
|
||||||
|
|
|
@ -323,6 +323,22 @@ class BatchNormTester(nn.Cell):
|
||||||
return self.bn(x)
|
return self.bn(x)
|
||||||
|
|
||||||
|
|
||||||
|
class DepthwiseConv2dAndReLU6(nn.Cell):
|
||||||
|
"Net for testing DepthwiseConv2d and ReLU6"
|
||||||
|
|
||||||
|
def __init__(self, input_channel, kernel_size):
|
||||||
|
super(DepthwiseConv2dAndReLU6, self).__init__()
|
||||||
|
weight_shape = [1, input_channel, kernel_size, kernel_size]
|
||||||
|
from mindspore.common.initializer import initializer
|
||||||
|
self.weight = Parameter(initializer('ones', weight_shape), name='weight')
|
||||||
|
self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=(kernel_size, kernel_size))
|
||||||
|
self.relu6 = nn.ReLU6()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.depthwise_conv(x, self.weight)
|
||||||
|
x = self.relu6(x)
|
||||||
|
return x
|
||||||
|
|
||||||
def test_batchnorm_train_onnx_export():
|
def test_batchnorm_train_onnx_export():
|
||||||
input = Tensor(np.ones([1, 3, 32, 32]).astype(np.float32) * 0.01)
|
input = Tensor(np.ones([1, 3, 32, 32]).astype(np.float32) * 0.01)
|
||||||
net = BatchNormTester(3)
|
net = BatchNormTester(3)
|
||||||
|
@ -421,6 +437,38 @@ def test_lenet5_onnx_load_run():
|
||||||
print(outputs[0])
|
print(outputs[0])
|
||||||
|
|
||||||
|
|
||||||
|
@run_on_onnxruntime
|
||||||
|
def test_depthwiseconv_relu6_onnx_load_run():
|
||||||
|
onnx_file = 'depthwiseconv_relu6.onnx'
|
||||||
|
input_channel = 3
|
||||||
|
input = Tensor(np.ones([1, input_channel, 32, 32]).astype(np.float32) * 0.01)
|
||||||
|
net = DepthwiseConv2dAndReLU6(input_channel, kernel_size=3)
|
||||||
|
export(net, input, file_name=onnx_file, file_format='ONNX')
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
print('--------------------- onnx load ---------------------')
|
||||||
|
# Load the ONNX model
|
||||||
|
model = onnx.load(onnx_file)
|
||||||
|
# Check that the IR is well formed
|
||||||
|
onnx.checker.check_model(model)
|
||||||
|
# Print a human readable representation of the graph
|
||||||
|
g = onnx.helper.printable_graph(model.graph)
|
||||||
|
print(g)
|
||||||
|
|
||||||
|
print('------------------ onnxruntime run ------------------')
|
||||||
|
ort_session = ort.InferenceSession(onnx_file)
|
||||||
|
input_map = {'x' : input.asnumpy()}
|
||||||
|
# provide only input x to run model
|
||||||
|
outputs = ort_session.run(None, input_map)
|
||||||
|
print(outputs[0])
|
||||||
|
# overwrite default weight to run model
|
||||||
|
for item in net.trainable_params():
|
||||||
|
input_map[item.name] = np.ones(item.default_input.asnumpy().shape, dtype=np.float32)
|
||||||
|
outputs = ort_session.run(None, input_map)
|
||||||
|
print(outputs[0])
|
||||||
|
|
||||||
def teardown_module():
|
def teardown_module():
|
||||||
files = ['parameters.ckpt', 'new_ckpt.ckpt', 'lenet5.onnx', 'batch_norm.onnx', 'empty.ckpt']
|
files = ['parameters.ckpt', 'new_ckpt.ckpt', 'lenet5.onnx', 'batch_norm.onnx', 'empty.ckpt']
|
||||||
for item in files:
|
for item in files:
|
||||||
|
|
Loading…
Reference in New Issue