forked from mindspore-Ecosystem/mindspore
!9611 add export file name check
From: @changzherui Reviewed-by: @kingxian Signed-off-by: @kingxian
This commit is contained in:
commit
9af8abaab9
|
@ -427,6 +427,16 @@ class Validator:
|
|||
target, prim_name, reg, flag))
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
|
||||
if reg is None:
|
||||
reg = r"^[0-9a-zA-Z\_\.\/\\]*$"
|
||||
if re.match(reg, target, flag) is None:
|
||||
prim_name = f'in `{prim_name}`' if prim_name else ""
|
||||
raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format(
|
||||
target, prim_name, reg, flag))
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def check_pad_value_by_mode(pad_mode, padding, prim_name):
|
||||
"""Validates value of padding according to pad_mode"""
|
||||
|
|
|
@ -530,6 +530,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
|
|||
if not isinstance(file_name, str):
|
||||
raise ValueError("Args file_name {} must be string, please check it".format(file_name))
|
||||
|
||||
Validator.check_file_name_by_regular(file_name)
|
||||
net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
|
||||
_export(net, file_name, file_format, *inputs)
|
||||
|
||||
|
@ -552,14 +553,14 @@ def _export(net, file_name, file_format, *inputs):
|
|||
is_dump_onnx_in_training = net.training and file_format == 'ONNX'
|
||||
if is_dump_onnx_in_training:
|
||||
net.set_train(mode=False)
|
||||
# export model
|
||||
|
||||
net.init_parameters_data()
|
||||
if file_format == 'AIR':
|
||||
phase_name = 'export.air'
|
||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
|
||||
file_name += ".air"
|
||||
_executor.export(file_name, graph_id)
|
||||
elif file_format == 'ONNX': # file_format is 'ONNX'
|
||||
elif file_format == 'ONNX':
|
||||
phase_name = 'export.onnx'
|
||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
||||
onnx_stream = _executor._get_func_graph_proto(net, graph_id)
|
||||
|
@ -567,7 +568,7 @@ def _export(net, file_name, file_format, *inputs):
|
|||
with open(file_name, 'wb') as f:
|
||||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||
f.write(onnx_stream)
|
||||
elif file_format == 'MINDIR': # file_format is 'MINDIR'
|
||||
elif file_format == 'MINDIR':
|
||||
phase_name = 'export.mindir'
|
||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
||||
onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
|
||||
|
@ -575,7 +576,7 @@ def _export(net, file_name, file_format, *inputs):
|
|||
with open(file_name, 'wb') as f:
|
||||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||
f.write(onnx_stream)
|
||||
# restore network training mode
|
||||
|
||||
if is_dump_onnx_in_training:
|
||||
net.set_train(mode=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue