forked from mindspore-Ecosystem/mindspore
access_control
This commit is contained in:
parent
86729985df
commit
763a7bd3aa
|
@ -19,6 +19,8 @@ train and infer lenet quantization network
|
||||||
import os
|
import os
|
||||||
import pytest
|
import pytest
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore.nn.metrics import Accuracy
|
from mindspore.nn.metrics import Accuracy
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
|
@ -30,6 +32,7 @@ from dataset import create_dataset
|
||||||
from config import nonquant_cfg, quant_cfg
|
from config import nonquant_cfg, quant_cfg
|
||||||
from lenet import LeNet5
|
from lenet import LeNet5
|
||||||
from lenet_fusion import LeNet5 as LeNet5Fusion
|
from lenet_fusion import LeNet5 as LeNet5Fusion
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
device_target = 'GPU'
|
device_target = 'GPU'
|
||||||
data_path = "/home/workspace/mindspore_dataset/mnist"
|
data_path = "/home/workspace/mindspore_dataset/mnist"
|
||||||
|
@ -122,6 +125,19 @@ def eval_quant():
|
||||||
print("============== {} ==============".format(acc))
|
print("============== {} ==============".format(acc))
|
||||||
assert acc['Accuracy'] > 0.98
|
assert acc['Accuracy'] > 0.98
|
||||||
|
|
||||||
|
def export_lenet():
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||||
|
cfg = quant_cfg
|
||||||
|
# define fusion network
|
||||||
|
network = LeNet5Fusion(cfg.num_classes)
|
||||||
|
# convert fusion network to quantization aware network
|
||||||
|
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
|
||||||
|
per_channel=[True, False], symmetric=[True, False])
|
||||||
|
|
||||||
|
# export network
|
||||||
|
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
|
||||||
|
quant.export(network, inputs, file_name="lenet_quant", file_format='MINDIR')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
@pytest.mark.platform_x86_gpu_training
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@ -130,6 +146,7 @@ def test_lenet_quant():
|
||||||
train_lenet()
|
train_lenet()
|
||||||
train_lenet_quant()
|
train_lenet_quant()
|
||||||
eval_quant()
|
eval_quant()
|
||||||
|
export_lenet()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue