forked from mindspore-Ecosystem/mindspore
!7247 export access control
Merge pull request !7247 from baiyangfan/access_control
This commit is contained in:
commit
186275a517
|
@ -19,6 +19,8 @@ train and infer lenet quantization network
|
|||
import os
|
||||
import pytest
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
|
@ -30,6 +32,7 @@ from dataset import create_dataset
|
|||
from config import nonquant_cfg, quant_cfg
|
||||
from lenet import LeNet5
|
||||
from lenet_fusion import LeNet5 as LeNet5Fusion
|
||||
import numpy as np
|
||||
|
||||
device_target = 'GPU'
|
||||
data_path = "/home/workspace/mindspore_dataset/mnist"
|
||||
|
@ -122,6 +125,19 @@ def eval_quant():
|
|||
print("============== {} ==============".format(acc))
|
||||
assert acc['Accuracy'] > 0.98
|
||||
|
||||
def export_lenet():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
cfg = quant_cfg
|
||||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
|
||||
per_channel=[True, False], symmetric=[True, False])
|
||||
|
||||
# export network
|
||||
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
|
||||
quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR')
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
@ -130,6 +146,7 @@ def test_lenet_quant():
|
|||
train_lenet()
|
||||
train_lenet_quant()
|
||||
eval_quant()
|
||||
export_lenet()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue