forked from mindspore-Ecosystem/mindspore
[ms][pynative][lenet]fix training bug
[ms][pynative][lenet]fix training bug [ms][pynative][lenet]fix training bug [ms][pynative][lenet]fix training bug [ms][pynative][lenet]fix training bug [ms][pynative][lenet]fix training bug [ms][pynative][lenet]fix training bug [ms][pynative][lenet]fix training bug [ms][pynative][lenet]fix training bug [ms][pynative][lenet]fix training bug [ms][pynative][lenet]fix training bug [ms][pynative][lenet]fix training bug [ms][pynative][lenet]fix training bug [ms][pynative][lenet]fix training bug
This commit is contained in:
parent
53aa304232
commit
c4f7aad9db
|
@ -144,7 +144,7 @@ def fake_quant_perchannel_grad_param(x, min_val, max_val, channel_axis,
|
|||
|
||||
shape_c = [1] * len(x_shape)
|
||||
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
|
||||
if x_format == "NC1HWC0" and channel_axis_ == 1:
|
||||
if shape_c[channel_axis_] != x_shape[channel_axis_]:
|
||||
shape_c = min_val.get("shape")
|
||||
return x_shape, shape_c, x_dtype
|
||||
|
||||
|
|
|
@ -182,3 +182,18 @@ def test_lenet_quant_ascend():
|
|||
train_lenet_quant(optim_option="LEARNED_SCALE")
|
||||
eval_quant(optim_option="LEARNED_SCALE")
|
||||
export_lenet(optim_option="LEARNED_SCALE", file_format="AIR")
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_lenet_quant_ascend_pynative():
|
||||
"""
|
||||
test_lenet_quant_ascend_pynative
|
||||
Features: test_lenet_quant_ascend_pynative
|
||||
Description: test_lenet_quant_ascend_pynative pynative mode
|
||||
Expectation: None
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
train_lenet_quant(optim_option="QAT")
|
||||
|
|
|
@ -55,13 +55,8 @@ config_ascend_quant = ed({
|
|||
dataset_path = "/home/workspace/mindspore_dataset/cifar-10-batches-bin/"
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_single
|
||||
def test_mobilenetv2_quant():
|
||||
set_seed(1)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
def train():
|
||||
"""train"""
|
||||
config = config_ascend_quant
|
||||
print("training configure: {}".format(config))
|
||||
|
||||
|
@ -121,5 +116,38 @@ def test_mobilenetv2_quant():
|
|||
assert avg_step_loss < expect_avg_step_loss
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_single
|
||||
def test_mobilenetv2_quant():
|
||||
"""
|
||||
test_mobilenetv2_quant
|
||||
Features: test_mobilenetv2_quant
|
||||
Description: test_mobilenetv2_quant graph mode
|
||||
Expectation: None
|
||||
"""
|
||||
set_seed(1)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
train()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_single
|
||||
def test_mobilenetv2_quant_pynative():
|
||||
"""
|
||||
test_mobilenetv2_quant_pynative
|
||||
Features: test_mobilenetv2_quant_pynative
|
||||
Description: test_mobilenetv2_quant_pynative pynative mode
|
||||
Expectation: None
|
||||
"""
|
||||
set_seed(1)
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_mobilenetv2_quant()
|
||||
test_mobilenetv2_quant_pynative()
|
||||
|
|
Loading…
Reference in New Issue