forked from mindspore-Ecosystem/mindspore
!18026 add test_lenet_quant_ascend st
Merge pull request !18026 from Erpim/master
This commit is contained in:
commit
255959406a
|
@ -34,12 +34,10 @@ from config import quant_cfg
|
|||
from lenet_fusion import LeNet5 as LeNet5Fusion
|
||||
import numpy as np
|
||||
|
||||
device_target = 'GPU'
|
||||
data_path = "/home/workspace/mindspore_dataset/mnist"
|
||||
lenet_ckpt_path = "/home/workspace/mindspore_dataset/checkpoint/lenet/ckpt_lenet_noquant-10_1875.ckpt"
|
||||
|
||||
def train_lenet_quant(optim_option="QAT"):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
cfg = quant_cfg
|
||||
ckpt_path = lenet_ckpt_path
|
||||
ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
|
||||
|
@ -90,7 +88,6 @@ def train_lenet_quant(optim_option="QAT"):
|
|||
|
||||
|
||||
def eval_quant(optim_option="QAT"):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
cfg = quant_cfg
|
||||
ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1)
|
||||
ckpt_path = './ckpt_lenet_quant'+optim_option+'-10_937.ckpt'
|
||||
|
@ -135,8 +132,7 @@ def eval_quant(optim_option="QAT"):
|
|||
assert acc['Accuracy'] > 0.98
|
||||
|
||||
|
||||
def export_lenet(optim_option="QAT"):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
def export_lenet(optim_option="QAT", file_format="MINDIR"):
|
||||
cfg = quant_cfg
|
||||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
|
@ -161,13 +157,14 @@ def export_lenet(optim_option="QAT"):
|
|||
|
||||
# export network
|
||||
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
|
||||
export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO')
|
||||
export(network, inputs, file_name="lenet_quant", file_format=file_format, quant_mode='AUTO')
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_lenet_quant():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
train_lenet_quant()
|
||||
eval_quant()
|
||||
export_lenet()
|
||||
|
@ -176,5 +173,12 @@ def test_lenet_quant():
|
|||
export_lenet(optim_option="LEARNED_SCALE")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_lenet_quant()
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_lenet_quant_ascend():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
train_lenet_quant(optim_option="LEARNED_SCALE")
|
||||
eval_quant(optim_option="LEARNED_SCALE")
|
||||
export_lenet(optim_option="LEARNED_SCALE", file_format="AIR")
|
||||
|
|
Loading…
Reference in New Issue