!43397 Fix bug in msfunction exportation

Merge pull request !43397 from shaojunsong/fix/export
This commit is contained in:
i-robot 2022-10-10 01:50:42 +00:00 committed by Gitee
commit c39958e4f8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 8 additions and 5 deletions

View File

@ -1,5 +1,5 @@
mindspore.export
=================
================
.. py:function:: mindspore.export(net, *inputs, file_name, file_format, **kwargs)
@ -8,9 +8,11 @@ mindspore.export
.. note::
- 当导出文件格式为AIR、ONNX时单个Tensor的大小不能超过2GB。
- 当 `file_name` 没有后缀时,系统会根据 `file_format` 自动添加后缀。
- 现已支持将Mindspore function (ms_function) 导出成MINDIR格式文件。
- 当导出ms_function时函数内不能包含有类属性参与的计算。
参数:
- **net** (Cell) - MindSpore网络结构。
- **net** (Union[Cell, ms_function]) - MindSpore网络结构。
- **inputs** (Union[Tensor, Dataset, List, Tuple, Number, Bool]) - 网络的输入,如果网络有多个输入,需要一同传入。当传入的类型为 `Dataset`将会把数据预处理行为同步保存起来。需要手动调整batch的大小当前仅支持获取 `Dataset``image` 列。
- **file_name** (str) - 导出模型的文件名称。
- **file_format** (str) - MindSpore目前支持导出"AIR""ONNX"和"MINDIR"格式的模型。

View File

@ -27,7 +27,6 @@ import threading
from threading import Thread, Lock
from collections import defaultdict, OrderedDict
from io import BytesIO
from inspect import isfunction
import math
import sys
@ -851,9 +850,11 @@ def export(net, *inputs, file_name, file_format, **kwargs):
Note:
1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
2. When file_name does not have a suffix, the system will automatically add one according to the file_format.
3. Mindspore functions (ms_function) export as mindir format is enabled.
4. When export ms_function, the function should not involve class properties in calculations.
Args:
net (Cell): MindSpore network.
net (Union[Cell, ms_function]): MindSpore network.
inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]): It represents the inputs
of the `net`, if the network has multiple inputs, set them together. While its type is Dataset,
it represents the preprocess behavior of the `net`, data preprocess operations will be serialized.
@ -1157,7 +1158,7 @@ def _cell_info(net, *inputs):
def _save_mindir(net, file_name, *inputs, **kwargs):
"""Save MindIR format file."""
model = mindir_model()
if isfunction(net):
if not isinstance(net, nn.Cell):
mindir_stream, net_dict = _msfunc_info(net, *inputs)
else:
mindir_stream, net_dict = _cell_info(net, *inputs)