forked from OSSInnovation/mindspore
!7444 mode_export_modelzoo
Merge pull request !7444 from baiyangfan/mode_export_modelzoo
This commit is contained in:
commit
8d4a127bd4
|
@ -29,10 +29,9 @@ from mindspore.common.initializer import initializer
|
|||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore._checkparam import check_input_data
|
||||
from mindspore._checkparam import check_input_data, Validator
|
||||
from mindspore.train.quant import quant
|
||||
import mindspore.context as context
|
||||
from .._checkparam import Validator
|
||||
|
||||
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
|
||||
"build_searched_strategy", "merge_sliced_parameter"]
|
||||
|
|
|
@ -23,7 +23,7 @@ import mindspore
|
|||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import mnist_cfg as cfg
|
||||
from src.lenet_fusion import LeNet5 as LeNet5Fusion
|
||||
|
@ -52,4 +52,4 @@ if __name__ == "__main__":
|
|||
|
||||
# export network
|
||||
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
|
||||
quant.export(network, inputs, file_name="lenet_quant", file_format='AIR')
|
||||
export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO')
|
||||
|
|
|
@ -20,7 +20,7 @@ import numpy as np
|
|||
import mindspore
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore.train.quant import quant
|
||||
|
||||
from src.mobilenetV2 import mobilenetV2
|
||||
|
@ -50,5 +50,5 @@ if __name__ == '__main__':
|
|||
# export network
|
||||
print("============== Starting export ==============")
|
||||
inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
|
||||
quant.export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR')
|
||||
export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR', quant_mode='AUTO')
|
||||
print("============== End export ==============")
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype
|
|||
import mindspore.nn as nn
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore.train import Model
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
|
@ -136,7 +136,7 @@ def export_lenet():
|
|||
|
||||
# export network
|
||||
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
|
||||
quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR')
|
||||
export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO')
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
Loading…
Reference in New Issue