access_control

This commit is contained in:
bai-yangfan 2020-10-13 17:14:55 +08:00
parent 86729985df
commit 763a7bd3aa
1 changed files with 17 additions and 0 deletions

View File

@ -19,6 +19,8 @@ train and infer lenet quantization network
import os
import pytest
from mindspore import context
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
@ -30,6 +32,7 @@ from dataset import create_dataset
from config import nonquant_cfg, quant_cfg
from lenet import LeNet5
from lenet_fusion import LeNet5 as LeNet5Fusion
import numpy as np
device_target = 'GPU'
data_path = "/home/workspace/mindspore_dataset/mnist"
@ -122,6 +125,19 @@ def eval_quant():
print("============== {} ==============".format(acc))
assert acc['Accuracy'] > 0.98
def export_lenet():
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
cfg = quant_cfg
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
per_channel=[True, False], symmetric=[True, False])
# export network
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -130,6 +146,7 @@ def test_lenet_quant():
train_lenet()
train_lenet_quant()
eval_quant()
export_lenet()
if __name__ == "__main__":