forked from mindspore-Ecosystem/mindspore
add onnx ops for deepfm
This commit is contained in:
parent
a5161a969f
commit
9c5f6b9198
|
@ -249,6 +249,13 @@ OPERATOR_ONNX_CONVERT_DEFINE(
|
||||||
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
|
.Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
|
||||||
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
|
.Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
|
||||||
|
|
||||||
|
OPERATOR_ONNX_CONVERT_DEFINE(GatherV2, Gather, OpNameInfo())
|
||||||
|
OPERATOR_ONNX_CONVERT_DEFINE(make_tuple, SequenceConstruct, OpNameInfo())
|
||||||
|
OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo())
|
||||||
|
OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo())
|
||||||
|
OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo())
|
||||||
|
OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo())
|
||||||
|
|
||||||
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
|
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name
|
||||||
|
|
||||||
void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
|
void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
|
||||||
|
@ -269,6 +276,12 @@ void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
|
||||||
fn(OP_CONVERT_FUNCTION_NAME(Squeeze)());
|
fn(OP_CONVERT_FUNCTION_NAME(Squeeze)());
|
||||||
fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)());
|
fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)());
|
||||||
fn(OP_CONVERT_FUNCTION_NAME(MatMul)());
|
fn(OP_CONVERT_FUNCTION_NAME(MatMul)());
|
||||||
|
|
||||||
|
fn(OP_CONVERT_FUNCTION_NAME(make_tuple)());
|
||||||
|
fn(OP_CONVERT_FUNCTION_NAME(Concat)());
|
||||||
|
fn(OP_CONVERT_FUNCTION_NAME(RealDiv)());
|
||||||
|
fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)());
|
||||||
|
fn(OP_CONVERT_FUNCTION_NAME(Sub)());
|
||||||
}
|
}
|
||||||
|
|
||||||
class OpConvertRegistry {
|
class OpConvertRegistry {
|
||||||
|
@ -325,8 +338,8 @@ class OnnxExporter {
|
||||||
|
|
||||||
void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||||
void ExportPrimReduceMean(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||||
void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto *graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
|
@ -335,6 +348,12 @@ class OnnxExporter {
|
||||||
onnx::GraphProto *graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||||
|
void ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
|
onnx::GraphProto *graph_proto);
|
||||||
|
void ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||||
|
void ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto);
|
||||||
|
|
||||||
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr,
|
||||||
onnx::GraphProto *graph_proto);
|
onnx::GraphProto *graph_proto);
|
||||||
|
@ -628,16 +647,19 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const
|
||||||
node_proto->add_input(name_shape);
|
node_proto->add_input(name_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
void OnnxExporter::ExportPrimReduce(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t> *node_map_ptr,
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
onnx::GraphProto *const graph_proto) {
|
|
||||||
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||||
auto input_axis = node->input(2);
|
auto input_axis = node->input(2);
|
||||||
|
|
||||||
auto node_idx = AllocateNodeIndex();
|
auto node_idx = AllocateNodeIndex();
|
||||||
(*node_map_ptr)[node] = node_idx;
|
(*node_map_ptr)[node] = node_idx;
|
||||||
onnx::NodeProto *node_proto = graph_proto->add_node();
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
node_proto->set_op_type(prim::kPrimReduceMean->name());
|
auto name = prim::kPrimReduceMean->name();
|
||||||
|
if (node->IsApply(prim::kPrimReduceSum)) {
|
||||||
|
name = prim::kPrimReduceSum->name();
|
||||||
|
}
|
||||||
|
node_proto->set_op_type(name);
|
||||||
node_proto->add_output(std::to_string(node_idx));
|
node_proto->add_output(std::to_string(node_idx));
|
||||||
node_proto->add_input(input_data);
|
node_proto->add_input(input_data);
|
||||||
|
|
||||||
|
@ -646,13 +668,18 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, con
|
||||||
attr_proto->set_name("axes");
|
attr_proto->set_name("axes");
|
||||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
|
||||||
auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
|
auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
|
||||||
auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
|
auto int_ptr = dyn_cast<Int32Imm>(axis_value);
|
||||||
MS_EXCEPTION_IF_NULL(tuple_ptr);
|
if (int_ptr == nullptr) {
|
||||||
for (size_t i = 0; i < tuple_ptr->size(); ++i) {
|
auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
|
||||||
attr_proto->add_ints(GetValue<int>((*tuple_ptr)[i]));
|
MS_EXCEPTION_IF_NULL(tuple_ptr);
|
||||||
|
for (size_t i = 0; i < tuple_ptr->size(); ++i) {
|
||||||
|
attr_proto->add_ints(GetValue<int>((*tuple_ptr)[i]));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
attr_proto->add_ints(int_ptr->value());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for ReduceMean.";
|
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -826,6 +853,83 @@ void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/
|
||||||
SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
|
SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OnnxExporter::ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
|
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||||
|
auto multiples = node->input(2);
|
||||||
|
std::string name_multiples;
|
||||||
|
if (multiples->isa<ValueNode>()) {
|
||||||
|
auto const_node_idx = AllocateNodeIndex();
|
||||||
|
(*node_map_ptr)[multiples] = const_node_idx;
|
||||||
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
|
name_multiples = std::to_string(const_node_idx);
|
||||||
|
node_proto->add_output(name_multiples);
|
||||||
|
|
||||||
|
node_proto->set_op_type("Constant");
|
||||||
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||||
|
attr_proto->set_name("repeat");
|
||||||
|
|
||||||
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||||
|
ConvertTupleToTensor(dyn_cast<ValueNode>(multiples)->value(), attr_proto->mutable_t());
|
||||||
|
} else {
|
||||||
|
name_multiples = GetNodeInputName(multiples, node_map_ptr, graph_proto);
|
||||||
|
MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Tile.";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto node_idx = AllocateNodeIndex();
|
||||||
|
(*node_map_ptr)[node] = node_idx;
|
||||||
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
|
node_proto->set_op_type("Tile");
|
||||||
|
node_proto->add_output(std::to_string(node_idx));
|
||||||
|
node_proto->add_input(name_x);
|
||||||
|
node_proto->add_input(name_multiples);
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnnxExporter::ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
|
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||||
|
std::string name_exponent;
|
||||||
|
auto const_node_idx = AllocateNodeIndex();
|
||||||
|
onnx::NodeProto *node_proto_exp = graph_proto->add_node();
|
||||||
|
name_exponent = std::to_string(const_node_idx);
|
||||||
|
node_proto_exp->add_output(name_exponent);
|
||||||
|
|
||||||
|
node_proto_exp->set_op_type("Constant");
|
||||||
|
onnx::AttributeProto *attr_proto = node_proto_exp->add_attribute();
|
||||||
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||||
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
||||||
|
tensor_proto->set_name("exponent");
|
||||||
|
tensor_proto->add_dims(static_cast<::google::protobuf::int64>(1));
|
||||||
|
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
|
||||||
|
tensor_proto->add_int64_data(2);
|
||||||
|
|
||||||
|
auto node_idx = AllocateNodeIndex();
|
||||||
|
(*node_map_ptr)[node] = node_idx;
|
||||||
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
|
node_proto->set_op_type("Pow");
|
||||||
|
node_proto->add_output(std::to_string(node_idx));
|
||||||
|
node_proto->add_input(name_x);
|
||||||
|
node_proto->add_input(name_exponent);
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
|
auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
|
||||||
|
auto name_indices = GetNodeInputName(node->input(2), node_map_ptr, graph_proto);
|
||||||
|
auto axis = node->input(3)->cast<ValueNodePtr>()->value();
|
||||||
|
|
||||||
|
auto node_idx = AllocateNodeIndex();
|
||||||
|
(*node_map_ptr)[node] = node_idx;
|
||||||
|
onnx::NodeProto *node_proto = graph_proto->add_node();
|
||||||
|
node_proto->set_op_type("Gather");
|
||||||
|
node_proto->add_output(std::to_string(node_idx));
|
||||||
|
node_proto->add_input(name_x);
|
||||||
|
node_proto->add_input(name_indices);
|
||||||
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||||
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||||
|
attr_proto->set_i(static_cast<::google::protobuf::int64>(dyn_cast<Int32Imm>(axis)->value()));
|
||||||
|
}
|
||||||
|
|
||||||
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
|
||||||
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
|
||||||
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
|
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
|
||||||
|
@ -833,8 +937,8 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
||||||
return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto);
|
return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node->IsApply(prim::kPrimReduceMean)) {
|
if (node->IsApply(prim::kPrimReduceMean) || node->IsApply(prim::kPrimReduceSum)) {
|
||||||
return ExportPrimReduceMean(func_graph, node, node_map_ptr, graph_proto);
|
return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
// MindSpore Cast(x, T) --> ONNX Cast[to=T](x)
|
// MindSpore Cast(x, T) --> ONNX Cast[to=T](x)
|
||||||
|
@ -857,6 +961,21 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n
|
||||||
return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto);
|
return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MindSpore Tile(x) --> ONNX Tile(x, repeat)
|
||||||
|
if (node->IsApply(prim::kPrimTile)) {
|
||||||
|
return ExportPrimTile(func_graph, node, node_map_ptr, graph_proto);
|
||||||
|
}
|
||||||
|
|
||||||
|
// MindSpore Square(x) --> ONNX Pow(x, 2)
|
||||||
|
if (node->IsApply(prim::kPrimSquare)) {
|
||||||
|
return ExportPrimSquare(func_graph, node, node_map_ptr, graph_proto);
|
||||||
|
}
|
||||||
|
|
||||||
|
// MindSpore GatherV2(x, indices, axis) --> ONNX Pow(x, indices)
|
||||||
|
if (node->IsApply(prim::kPrimGatherV2)) {
|
||||||
|
return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto);
|
||||||
|
}
|
||||||
|
|
||||||
auto inputs = node->inputs();
|
auto inputs = node->inputs();
|
||||||
if (inputs.size() < 1) {
|
if (inputs.size() < 1) {
|
||||||
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
||||||
|
@ -1054,7 +1173,30 @@ void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *cons
|
||||||
node_proto->set_op_type("Constant");
|
node_proto->set_op_type("Constant");
|
||||||
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||||
attr_proto->set_name("value");
|
attr_proto->set_name("value");
|
||||||
MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node";
|
if (value->isa<Int32Imm>()) {
|
||||||
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||||
|
auto casted_value = dyn_cast<Int32Imm>(value);
|
||||||
|
if (casted_value == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed.";
|
||||||
|
}
|
||||||
|
auto attr_value = casted_value->value();
|
||||||
|
attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value));
|
||||||
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
|
||||||
|
} else if (value->isa<tensor::Tensor>()) {
|
||||||
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
|
||||||
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
|
||||||
|
auto data = dyn_cast<tensor::Tensor>(value);
|
||||||
|
tensor_proto->set_raw_data(data->data().request(true).ptr, static_cast<size_t>(data->data().nbytes()));
|
||||||
|
auto dtype = data->data_type();
|
||||||
|
auto shape = data->shape_c();
|
||||||
|
|
||||||
|
tensor_proto->set_data_type(GetOnnxDataType(dtype));
|
||||||
|
for (const auto &dim : shape) {
|
||||||
|
tensor_proto->add_dims(dim);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) {
|
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) {
|
||||||
|
|
|
@ -142,6 +142,20 @@ class DepthwiseConv2dAndReLU6(nn.Cell):
|
||||||
x = self.relu6(x)
|
x = self.relu6(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class DeepFMOpNet(nn.Cell):
|
||||||
|
"""Net definition with Gatherv2 and Tile and Square."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(DeepFMOpNet, self).__init__()
|
||||||
|
self.gather = P.GatherV2()
|
||||||
|
self.square = P.Square()
|
||||||
|
self.tile = P.Tile()
|
||||||
|
|
||||||
|
def construct(self, x, y):
|
||||||
|
x = self.tile(x, (1000, 1))
|
||||||
|
x = self.square(x)
|
||||||
|
x = self.gather(x, y, 0)
|
||||||
|
return x
|
||||||
|
|
||||||
# generate mindspore Tensor by shape and numpy datatype
|
# generate mindspore Tensor by shape and numpy datatype
|
||||||
def gen_tensor(shape, dtype=np.float32):
|
def gen_tensor(shape, dtype=np.float32):
|
||||||
|
@ -153,6 +167,7 @@ net_cfgs = [
|
||||||
('lenet', LeNet5(), gen_tensor([1, 1, 32, 32])),
|
('lenet', LeNet5(), gen_tensor([1, 1, 32, 32])),
|
||||||
('maxpoolwithargmax', DefinedNet(), gen_tensor([1, 3, 224, 224])),
|
('maxpoolwithargmax', DefinedNet(), gen_tensor([1, 3, 224, 224])),
|
||||||
('depthwiseconv_relu6', DepthwiseConv2dAndReLU6(3, kernel_size=3), gen_tensor([1, 3, 32, 32])),
|
('depthwiseconv_relu6', DepthwiseConv2dAndReLU6(3, kernel_size=3), gen_tensor([1, 3, 32, 32])),
|
||||||
|
('deepfm_ops', DeepFMOpNet(), (gen_tensor([1, 1]), gen_tensor([1000, 1], dtype=np.int32)))
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -164,7 +179,10 @@ def get_id(cfg):
|
||||||
@pytest.mark.parametrize('name, net, inp', net_cfgs, ids=get_id(net_cfgs))
|
@pytest.mark.parametrize('name, net, inp', net_cfgs, ids=get_id(net_cfgs))
|
||||||
def test_onnx_export(name, net, inp):
|
def test_onnx_export(name, net, inp):
|
||||||
onnx_file = name + ".onnx"
|
onnx_file = name + ".onnx"
|
||||||
export(net, inp, file_name=onnx_file, file_format='ONNX')
|
if isinstance(inp, (tuple, list)):
|
||||||
|
export(net, *inp, file_name=onnx_file, file_format='ONNX')
|
||||||
|
else:
|
||||||
|
export(net, inp, file_name=onnx_file, file_format='ONNX')
|
||||||
|
|
||||||
# check existence of exported onnx file and delete it
|
# check existence of exported onnx file and delete it
|
||||||
assert os.path.exists(onnx_file)
|
assert os.path.exists(onnx_file)
|
||||||
|
|
Loading…
Reference in New Issue