diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 25f312d37db..49299245bdd 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -521,6 +521,8 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs): """ logger.info("exporting model file:%s format:%s.", file_name, file_format) check_input_data(*inputs, data_class=Tensor) + if not isinstance(file_name, str): + raise ValueError("Args file_name {} must be string, please check it".format(file_name)) net = _quant_export(net, *inputs, file_format=file_format, **kwargs) _export(net, file_name, file_format, *inputs) @@ -549,11 +551,13 @@ def _export(net, file_name, file_format, *inputs): if file_format == 'AIR': phase_name = 'export.air' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) + file_name += ".air" _executor.export(file_name, graph_id) elif file_format == 'ONNX': # file_format is 'ONNX' phase_name = 'export.onnx' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) onnx_stream = _executor._get_func_graph_proto(graph_id) + file_name += ".onnx" with open(file_name, 'wb') as f: os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) f.write(onnx_stream) @@ -561,6 +565,7 @@ def _export(net, file_name, file_format, *inputs): phase_name = 'export.mindir' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) onnx_stream = _executor._get_func_graph_proto(graph_id, 'mind_ir') + file_name += ".mindir" with open(file_name, 'wb') as f: os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) f.write(onnx_stream) diff --git a/tests/st/export/test_export.py b/tests/st/export/test_export.py index 4734c8c0644..b0de52d2daa 100644 --- a/tests/st/export/test_export.py +++ b/tests/st/export/test_export.py @@ -329,7 +329,8 @@ def resnet50(num_classes): def test_export_resnet_air(): net = resnet50(10) inputs = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32) * 0.01) - file_name = "resnet.air" + file_name = "resnet" export(net, inputs, file_name=file_name, file_format='AIR') + file_name += ".air" assert os.path.exists(file_name) os.remove(file_name) diff --git a/tests/st/model_zoo_tests/maskrcnn/test_maskrcnn.py b/tests/st/model_zoo_tests/maskrcnn/test_maskrcnn.py index 6b4e1662fbf..f39a70a7ef2 100644 --- a/tests/st/model_zoo_tests/maskrcnn/test_maskrcnn.py +++ b/tests/st/model_zoo_tests/maskrcnn/test_maskrcnn.py @@ -45,12 +45,11 @@ def test_maskrcnn_export(): gt_mask = Tensor(np.zeros([bs, 128], np.bool)) input_data = [img, img_metas, gt_bboxes, gt_labels, gt_num, gt_mask] + export(net, *input_data, file_name="maskrcnn", file_format="AIR") file_name = "maskrcnn.air" - - export(net, *input_data, file_name=file_name, file_format="AIR") - assert os.path.exists(file_name) os.remove(file_name) + if __name__ == '__main__': test_maskrcnn_export() diff --git a/tests/st/serving/generate_model.py b/tests/st/serving/generate_model.py index 142ff91587f..162c1a4ea40 100644 --- a/tests/st/serving/generate_model.py +++ b/tests/st/serving/generate_model.py @@ -54,7 +54,7 @@ def export_bert_model(): input_mask = np.zeros((2, 32), dtype=np.int32) net = BertModel(bert_net_cfg, False) export(net, Tensor(input_ids), Tensor(segment_ids), Tensor(input_mask), - file_name='bert.mindir', file_format='MINDIR') + file_name='bert', file_format='MINDIR') if __name__ == '__main__': export_bert_model() diff --git a/tests/ut/python/onnx/test_onnx.py b/tests/ut/python/onnx/test_onnx.py index 73dea6b6ffa..40e7abd3f6f 100644 --- a/tests/ut/python/onnx/test_onnx.py +++ b/tests/ut/python/onnx/test_onnx.py @@ -38,10 +38,6 @@ def is_enable_onnxruntime(): run_on_onnxruntime = pytest.mark.skipif(not is_enable_onnxruntime(), reason="Only support running on onnxruntime") -def setup_module(): - pass - - def teardown_module(): cur_dir = os.path.dirname(os.path.realpath(__file__)) for filename in os.listdir(cur_dir): @@ -52,7 +48,7 @@ def teardown_module(): class BatchNormTester(nn.Cell): - "used to test exporting network in training mode in onnx format" + """used to test exporting network in training mode in onnx format""" def __init__(self, num_features): super(BatchNormTester, self).__init__() @@ -63,21 +59,22 @@ class BatchNormTester(nn.Cell): def test_batchnorm_train_onnx_export(): - "test onnx export interface does not modify trainable flag of a network" + """test onnx export interface does not modify trainable flag of a network""" input_ = Tensor(np.ones([1, 3, 32, 32]).astype(np.float32) * 0.01) net = BatchNormTester(3) net.set_train() if not net.training: raise ValueError('netowrk is not in training mode') - onnx_file = 'batch_norm.onnx' + onnx_file = 'batch_norm' export(net, input_, file_name=onnx_file, file_format='ONNX') if not net.training: raise ValueError('netowrk is not in training mode') - # check existence of exported onnx file and delete it - assert os.path.exists(onnx_file) - os.chmod(onnx_file, stat.S_IWRITE) - os.remove(onnx_file) + + file_name = "batch_norm.onnx" + assert os.path.exists(file_name) + os.chmod(file_name, stat.S_IWRITE) + os.remove(file_name) class LeNet5(nn.Cell): @@ -127,8 +124,7 @@ class DefinedNet(nn.Cell): class DepthwiseConv2dAndReLU6(nn.Cell): - "Net for testing DepthwiseConv2d and ReLU6" - + """Net for testing DepthwiseConv2d and ReLU6""" def __init__(self, input_channel, kernel_size): super(DepthwiseConv2dAndReLU6, self).__init__() weight_shape = [1, input_channel, kernel_size, kernel_size] @@ -142,9 +138,9 @@ class DepthwiseConv2dAndReLU6(nn.Cell): x = self.relu6(x) return x + class DeepFMOpNet(nn.Cell): """Net definition with Gatherv2 and Tile and Square.""" - def __init__(self): super(DeepFMOpNet, self).__init__() self.gather = P.GatherV2() @@ -157,12 +153,11 @@ class DeepFMOpNet(nn.Cell): x = self.gather(x, y, 0) return x -# generate mindspore Tensor by shape and numpy datatype + def gen_tensor(shape, dtype=np.float32): return Tensor(np.ones(shape).astype(dtype)) -# ut configs in triple: (ut_name, network, network-input) net_cfgs = [ ('lenet', LeNet5(), gen_tensor([1, 1, 32, 32])), ('maxpoolwithargmax', DefinedNet(), gen_tensor([1, 3, 224, 224])), @@ -179,23 +174,21 @@ def get_id(cfg): # use `pytest test_onnx.py::test_onnx_export[name]` or `pytest test_onnx.py::test_onnx_export -k name` to run single ut @pytest.mark.parametrize('name, net, inp', net_cfgs, ids=get_id(net_cfgs)) def test_onnx_export(name, net, inp): - onnx_file = name + ".onnx" if isinstance(inp, (tuple, list)): - export(net, *inp, file_name=onnx_file, file_format='ONNX') + export(net, *inp, file_name=name, file_format='ONNX') else: - export(net, inp, file_name=onnx_file, file_format='ONNX') + export(net, inp, file_name=name, file_format='ONNX') - # check existence of exported onnx file and delete it - assert os.path.exists(onnx_file) - os.chmod(onnx_file, stat.S_IWRITE) - os.remove(onnx_file) + file_file = name + ".onnx" + assert os.path.exists(file_file) + os.chmod(file_file, stat.S_IWRITE) + os.remove(file_file) @run_on_onnxruntime @pytest.mark.parametrize('name, net, inp', net_cfgs, ids=get_id(net_cfgs)) def test_onnx_export_load_run(name, net, inp): - onnx_file = name + ".onnx" - export(net, inp, file_name=onnx_file, file_format='ONNX') + export(net, inp, file_name=name, file_format='ONNX') import onnx import onnxruntime as ort @@ -222,7 +215,7 @@ def test_onnx_export_load_run(name, net, inp): outputs = ort_session.run(None, input_map) print(outputs[0]) - # check existence of exported onnx file and delete it - assert os.path.exists(onnx_file) - os.chmod(onnx_file, stat.S_IWRITE) - os.remove(onnx_file) + file_name = name + ".onnx" + assert os.path.exists(file_name) + os.chmod(file_name, stat.S_IWRITE) + os.remove(file_name) diff --git a/tests/ut/python/utils/test_export.py b/tests/ut/python/utils/test_export.py index 199e0838cc0..dfbffe3c72f 100644 --- a/tests/ut/python/utils/test_export.py +++ b/tests/ut/python/utils/test_export.py @@ -91,7 +91,8 @@ def test_export_lenet_grad_mindir(): predict = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01) label = Tensor(np.zeros([32, 10]).astype(np.float32)) net = TrainOneStepCell(WithLossCell(network)) - file_name = "lenet_grad.mindir" + file_name = "lenet_grad" export(net, predict, label, file_name=file_name, file_format='MINDIR') - assert os.path.exists(file_name) - os.remove(file_name) + verify_name = file_name + ".mindir" + assert os.path.exists(verify_name) + os.remove(verify_name)