forked from mindspore-Ecosystem/mindspore
modify repeatedly export quant net mindir
This commit is contained in:
parent
ccf5b3d808
commit
c1ada6a3e5
|
@ -18,6 +18,8 @@ import sys
|
|||
import stat
|
||||
import math
|
||||
import shutil
|
||||
import time
|
||||
import copy
|
||||
from threading import Thread, Lock
|
||||
import numpy as np
|
||||
|
||||
|
@ -756,6 +758,9 @@ def _quant_export(network, *inputs, file_format, **kwargs):
|
|||
supported_formats = ['AIR', 'MINDIR']
|
||||
quant_mode_formats = ['AUTO', 'MANUAL']
|
||||
|
||||
quant_net = copy.deepcopy(network)
|
||||
quant_net._create_time = int(time.time() * 1e9)
|
||||
|
||||
mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean']
|
||||
std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev']
|
||||
|
||||
|
@ -772,17 +777,17 @@ def _quant_export(network, *inputs, file_format, **kwargs):
|
|||
if file_format not in supported_formats:
|
||||
raise ValueError('Illegal file format {}.'.format(file_format))
|
||||
|
||||
network.set_train(False)
|
||||
quant_net.set_train(False)
|
||||
if file_format == "MINDIR":
|
||||
if quant_mode == 'MANUAL':
|
||||
exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
||||
exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
|
||||
else:
|
||||
exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
||||
exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
|
||||
else:
|
||||
if quant_mode == 'MANUAL':
|
||||
exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs)
|
||||
exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs)
|
||||
else:
|
||||
exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
|
||||
exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs)
|
||||
deploy_net = exporter.run()
|
||||
return deploy_net
|
||||
|
||||
|
|
Loading…
Reference in New Issue