forked from mindspore-Ecosystem/mindspore
add onnx ops for deepfm
This commit is contained in:
parent
a5161a969f
commit
9c5f6b9198
|
@ -249,6 +249,13 @@ OPERATOR_ONNX_CONVERT_DEFINE(
|
|||
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
|
||||
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
|
||||
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(GatherV2, Gather, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(make_tuple, SequenceConstruct, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo())
|
||||
OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo())
|
||||
|
||||
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
|
||||
|
||||
void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
|
||||
|
@ -269,6 +276,12 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
|
|||
fn(OP_CONVERT_FUNCTION_NAME(Squeeze)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(MatMul)());
|
||||
|
||||
fn(OP_CONVERT_FUNCTION_NAME(make_tuple)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Concat)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(RealDiv)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)());
|
||||
fn(OP_CONVERT_FUNCTION_NAME(Sub)());
|
||||
}
|
||||
|
||||
class OpConvertRegistry {
|
||||
|
@ -325,8 +338,8 @@ class OnnxExporter {
|
|||
|
||||
void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimReduceMean(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
onnx::GraphProto *graph_proto);
|
||||
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
|
@ -335,6 +348,12 @@ class OnnxExporter {
|
|||
onnx::GraphProto *graph_proto);
|
||||
void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
onnx::GraphProto *graph_proto);
|
||||
void ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
void ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||
|
||||
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
onnx::GraphProto *graph_proto);
|
||||
|
@ -628,16 +647,19 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const
|
|||
node_proto->add_input(name_shape);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||
onnx::GraphProto *const graph_proto) {
|
||||
void OnnxExporter::ExportPrimReduce(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||
auto input_axis = node->input(2);
|
||||
|
||||
auto node_idx = AllocateNodeIndex();
|
||||
(*node_map_ptr)[node] = node_idx;
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
node_proto->set_op_type(prim::kPrimReduceMean->name());
|
||||
auto name = prim::kPrimReduceMean->name();
|
||||
if (node->IsApply(prim::kPrimReduceSum)) {
|
||||
name = prim::kPrimReduceSum->name();
|
||||
}
|
||||
node_proto->set_op_type(name);
|
||||
node_proto->add_output(std::to_string(node_idx));
|
||||
node_proto->add_input(input_data);
|
||||
|
||||
|
@ -646,13 +668,18 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, con
|
|||
attr_proto->set_name("axes");
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
|
||||
auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
|
||||
auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
|
||||
MS_EXCEPTION_IF_NULL(tuple_ptr);
|
||||
for (size_t i = 0; i < tuple_ptr->size(); ++i) {
|
||||
attr_proto->add_ints(GetValue<int>((*tuple_ptr)[i]));
|
||||
auto int_ptr = dyn_cast<Int32Imm>(axis_value);
|
||||
if (int_ptr == nullptr) {
|
||||
auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
|
||||
MS_EXCEPTION_IF_NULL(tuple_ptr);
|
||||
for (size_t i = 0; i < tuple_ptr->size(); ++i) {
|
||||
attr_proto->add_ints(GetValue<int>((*tuple_ptr)[i]));
|
||||
}
|
||||
} else {
|
||||
attr_proto->add_ints(int_ptr->value());
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for ReduceMean.";
|
||||
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -826,6 +853,83 @@ void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/
|
|||
SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||
auto multiples = node->input(2);
|
||||
std::string name_multiples;
|
||||
if (multiples->isa<ValueNode>()) {
|
||||
auto const_node_idx = AllocateNodeIndex();
|
||||
(*node_map_ptr)[multiples] = const_node_idx;
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
name_multiples = std::to_string(const_node_idx);
|
||||
node_proto->add_output(name_multiples);
|
||||
|
||||
node_proto->set_op_type("Constant");
|
||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_name("repeat");
|
||||
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
ConvertTupleToTensor(dyn_cast<ValueNode>(multiples)->value(), attr_proto->mutable_t());
|
||||
} else {
|
||||
name_multiples = GetNodeInputName(multiples, node_map_ptr, graph_proto);
|
||||
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Tile.";
|
||||
}
|
||||
|
||||
auto node_idx = AllocateNodeIndex();
|
||||
(*node_map_ptr)[node] = node_idx;
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
node_proto->set_op_type("Tile");
|
||||
node_proto->add_output(std::to_string(node_idx));
|
||||
node_proto->add_input(name_x);
|
||||
node_proto->add_input(name_multiples);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||
std::string name_exponent;
|
||||
auto const_node_idx = AllocateNodeIndex();
|
||||
onnx::NodeProto *node_proto_exp = graph_proto->add_node();
|
||||
name_exponent = std::to_string(const_node_idx);
|
||||
node_proto_exp->add_output(name_exponent);
|
||||
|
||||
node_proto_exp->set_op_type("Constant");
|
||||
onnx::AttributeProto *attr_proto = node_proto_exp->add_attribute();
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
||||
tensor_proto->set_name("exponent");
|
||||
tensor_proto->add_dims(static_cast<::google::protobuf::int64>(1));
|
||||
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
|
||||
tensor_proto->add_int64_data(2);
|
||||
|
||||
auto node_idx = AllocateNodeIndex();
|
||||
(*node_map_ptr)[node] = node_idx;
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
node_proto->set_op_type("Pow");
|
||||
node_proto->add_output(std::to_string(node_idx));
|
||||
node_proto->add_input(name_x);
|
||||
node_proto->add_input(name_exponent);
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||
auto name_indices = GetNodeInputName(node->input(2), node_map_ptr, graph_proto);
|
||||
auto axis = node->input(3)->cast<ValueNodePtr>()->value();
|
||||
|
||||
auto node_idx = AllocateNodeIndex();
|
||||
(*node_map_ptr)[node] = node_idx;
|
||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||
node_proto->set_op_type("Gather");
|
||||
node_proto->add_output(std::to_string(node_idx));
|
||||
node_proto->add_input(name_x);
|
||||
node_proto->add_input(name_indices);
|
||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||
attr_proto->set_i(static_cast<::google::protobuf::int64>(dyn_cast<Int32Imm>(axis)->value()));
|
||||
}
|
||||
|
||||
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
|
||||
|
@ -833,8 +937,8 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
|||
return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto);
|
||||
}
|
||||
|
||||
if (node->IsApply(prim::kPrimReduceMean)) {
|
||||
return ExportPrimReduceMean(func_graph, node, node_map_ptr, graph_proto);
|
||||
if (node->IsApply(prim::kPrimReduceMean) || node->IsApply(prim::kPrimReduceSum)) {
|
||||
return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto);
|
||||
}
|
||||
|
||||
// MindSpore Cast(x, T) --> ONNX Cast[to=T](x)
|
||||
|
@ -857,6 +961,21 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
|||
return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto);
|
||||
}
|
||||
|
||||
// MindSpore Tile(x) --> ONNX Tile(x, repeat)
|
||||
if (node->IsApply(prim::kPrimTile)) {
|
||||
return ExportPrimTile(func_graph, node, node_map_ptr, graph_proto);
|
||||
}
|
||||
|
||||
// MindSpore Square(x) --> ONNX Pow(x, 2)
|
||||
if (node->IsApply(prim::kPrimSquare)) {
|
||||
return ExportPrimSquare(func_graph, node, node_map_ptr, graph_proto);
|
||||
}
|
||||
|
||||
// MindSpore GatherV2(x, indices, axis) --> ONNX Pow(x, indices)
|
||||
if (node->IsApply(prim::kPrimGatherV2)) {
|
||||
return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto);
|
||||
}
|
||||
|
||||
auto inputs = node->inputs();
|
||||
if (inputs.size() < 1) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
||||
|
@ -1054,7 +1173,30 @@ void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *cons
|
|||
node_proto->set_op_type("Constant");
|
||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_name("value");
|
||||
MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node";
|
||||
if (value->isa<Int32Imm>()) {
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||
auto casted_value = dyn_cast<Int32Imm>(value);
|
||||
if (casted_value == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed.";
|
||||
}
|
||||
auto attr_value = casted_value->value();
|
||||
attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value));
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||
} else if (value->isa<tensor::Tensor>()) {
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
||||
auto data = dyn_cast<tensor::Tensor>(value);
|
||||
tensor_proto->set_raw_data(data->data().request(true).ptr, static_cast<size_t>(data->data().nbytes()));
|
||||
auto dtype = data->data_type();
|
||||
auto shape = data->shape_c();
|
||||
|
||||
tensor_proto->set_data_type(GetOnnxDataType(dtype));
|
||||
for (const auto &dim : shape) {
|
||||
tensor_proto->add_dims(dim);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node";
|
||||
}
|
||||
}
|
||||
|
||||
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) {
|
||||
|
|
|
@ -142,6 +142,20 @@ class DepthwiseConv2dAndReLU6(nn.Cell):
|
|||
x = self.relu6(x)
|
||||
return x
|
||||
|
||||
class DeepFMOpNet(nn.Cell):
|
||||
"""Net definition with Gatherv2 and Tile and Square."""
|
||||
|
||||
def __init__(self):
|
||||
super(DeepFMOpNet, self).__init__()
|
||||
self.gather = P.GatherV2()
|
||||
self.square = P.Square()
|
||||
self.tile = P.Tile()
|
||||
|
||||
def construct(self, x, y):
|
||||
x = self.tile(x, (1000, 1))
|
||||
x = self.square(x)
|
||||
x = self.gather(x, y, 0)
|
||||
return x
|
||||
|
||||
# generate mindspore Tensor by shape and numpy datatype
|
||||
def gen_tensor(shape, dtype=np.float32):
|
||||
|
@ -153,6 +167,7 @@ net_cfgs = [
|
|||
('lenet', LeNet5(), gen_tensor([1, 1, 32, 32])),
|
||||
('maxpoolwithargmax', DefinedNet(), gen_tensor([1, 3, 224, 224])),
|
||||
('depthwiseconv_relu6', DepthwiseConv2dAndReLU6(3, kernel_size=3), gen_tensor([1, 3, 32, 32])),
|
||||
('deepfm_ops', DeepFMOpNet(), (gen_tensor([1, 1]), gen_tensor([1000, 1], dtype=np.int32)))
|
||||
]
|
||||
|
||||
|
||||
|
@ -164,7 +179,10 @@ def get_id(cfg):
|
|||
@pytest.mark.parametrize('name, net, inp', net_cfgs, ids=get_id(net_cfgs))
|
||||
def test_onnx_export(name, net, inp):
|
||||
onnx_file = name + ".onnx"
|
||||
export(net, inp, file_name=onnx_file, file_format='ONNX')
|
||||
if isinstance(inp, (tuple, list)):
|
||||
export(net, *inp, file_name=onnx_file, file_format='ONNX')
|
||||
else:
|
||||
export(net, inp, file_name=onnx_file, file_format='ONNX')
|
||||
|
||||
# check existence of exported onnx file and delete it
|
||||
assert os.path.exists(onnx_file)
|
||||
|
|
Loading…
Reference in New Issue