[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug

[ms][pynative][lenet]fix training bug
This commit is contained in:
xiongkun 2022-03-25 11:18:39 +08:00
parent 53aa304232
commit c4f7aad9db
3 changed files with 51 additions and 8 deletions

View File

@ -144,7 +144,7 @@ def fake_quant_perchannel_grad_param(x, min_val, max_val, channel_axis,
shape_c = [1] * len(x_shape)
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis_ == 1:
if shape_c[channel_axis_] != x_shape[channel_axis_]:
shape_c = min_val.get("shape")
return x_shape, shape_c, x_dtype

View File

@ -182,3 +182,18 @@ def test_lenet_quant_ascend():
train_lenet_quant(optim_option="LEARNED_SCALE")
eval_quant(optim_option="LEARNED_SCALE")
export_lenet(optim_option="LEARNED_SCALE", file_format="AIR")
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet_quant_ascend_pynative():
"""
test_lenet_quant_ascend_pynative
Features: test_lenet_quant_ascend_pynative
Description: test_lenet_quant_ascend_pynative pynative mode
Expectation: None
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
train_lenet_quant(optim_option="QAT")

View File

@ -55,13 +55,8 @@ config_ascend_quant = ed({
dataset_path = "/home/workspace/mindspore_dataset/cifar-10-batches-bin/"
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_mobilenetv2_quant():
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def train():
"""train"""
config = config_ascend_quant
print("training configure: {}".format(config))
@ -121,5 +116,38 @@ def test_mobilenetv2_quant():
assert avg_step_loss < expect_avg_step_loss
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_mobilenetv2_quant():
"""
test_mobilenetv2_quant
Features: test_mobilenetv2_quant
Description: test_mobilenetv2_quant graph mode
Expectation: None
"""
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
train()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_mobilenetv2_quant_pynative():
"""
test_mobilenetv2_quant_pynative
Features: test_mobilenetv2_quant_pynative
Description: test_mobilenetv2_quant_pynative pynative mode
Expectation: None
"""
set_seed(1)
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
train()
if __name__ == '__main__':
test_mobilenetv2_quant()
test_mobilenetv2_quant_pynative()