!18026 add test_lenet_quant_ascend st

Merge pull request !18026 from Erpim/master
This commit is contained in:
i-robot 2021-06-09 14:20:43 +08:00 committed by Gitee
commit 255959406a
1 changed files with 12 additions and 8 deletions

View File

@ -34,12 +34,10 @@ from config import quant_cfg
from lenet_fusion import LeNet5 as LeNet5Fusion
import numpy as np
device_target = 'GPU'
data_path = "/home/workspace/mindspore_dataset/mnist"
lenet_ckpt_path = "/home/workspace/mindspore_dataset/checkpoint/lenet/ckpt_lenet_noquant-10_1875.ckpt"
def train_lenet_quant(optim_option="QAT"):
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
cfg = quant_cfg
ckpt_path = lenet_ckpt_path
ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
@ -90,7 +88,6 @@ def train_lenet_quant(optim_option="QAT"):
def eval_quant(optim_option="QAT"):
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
cfg = quant_cfg
ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1)
ckpt_path = './ckpt_lenet_quant'+optim_option+'-10_937.ckpt'
@ -135,8 +132,7 @@ def eval_quant(optim_option="QAT"):
assert acc['Accuracy'] > 0.98
def export_lenet(optim_option="QAT"):
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
def export_lenet(optim_option="QAT", file_format="MINDIR"):
cfg = quant_cfg
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
@ -161,13 +157,14 @@ def export_lenet(optim_option="QAT"):
# export network
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO')
export(network, inputs, file_name="lenet_quant", file_format=file_format, quant_mode='AUTO')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_lenet_quant():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
train_lenet_quant()
eval_quant()
export_lenet()
@ -176,5 +173,12 @@ def test_lenet_quant():
export_lenet(optim_option="LEARNED_SCALE")
if __name__ == "__main__":
train_lenet_quant()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet_quant_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
train_lenet_quant(optim_option="LEARNED_SCALE")
eval_quant(optim_option="LEARNED_SCALE")
export_lenet(optim_option="LEARNED_SCALE", file_format="AIR")