!7247 export access control

Merge pull request !7247 from baiyangfan/access_control
This commit is contained in:
mindspore-ci-bot 2020-10-13 21:15:26 +08:00 committed by Gitee
commit 186275a517
1 changed files with 17 additions and 0 deletions

View File

@ -19,6 +19,8 @@ train and infer lenet quantization network
import os
import pytest
from mindspore import context
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
@ -30,6 +32,7 @@ from dataset import create_dataset
from config import nonquant_cfg, quant_cfg
from lenet import LeNet5
from lenet_fusion import LeNet5 as LeNet5Fusion
import numpy as np
device_target = 'GPU'
data_path = "/home/workspace/mindspore_dataset/mnist"
@ -122,6 +125,19 @@ def eval_quant():
print("============== {} ==============".format(acc))
assert acc['Accuracy'] > 0.98
def export_lenet():
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
cfg = quant_cfg
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
per_channel=[True, False], symmetric=[True, False])
# export network
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -130,6 +146,7 @@ def test_lenet_quant():
train_lenet()
train_lenet_quant()
eval_quant()
export_lenet()
if __name__ == "__main__":