forked from mindspore-Ecosystem/mindspore
!12216 modify transformer st
From: @yuchaojie Reviewed-by: @xsmq,@linqingke Signed-off-by: @linqingke
This commit is contained in:
commit
c2b5490375
|
@ -129,7 +129,7 @@ class TimeMonitor(Callback):
|
|||
self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)
|
||||
|
||||
|
||||
# @pytest.mark.level0
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -171,10 +171,10 @@ def test_transformer():
|
|||
|
||||
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
|
||||
loss_value = np.array(callback.loss_list)
|
||||
assert np.allclose(loss_value[0], 11.241624, 0, 0.000005)
|
||||
assert np.allclose(loss_value[0], 11.241606, 0, 0.000005)
|
||||
|
||||
expect_loss_value = [11.241624, 11.243232, 11.217465, 11.204196, 11.2138195,
|
||||
11.215386, 11.19053, 11.150403, 11.191858, 11.160057]
|
||||
expect_loss_value = [11.241606, 11.243232, 11.217459, 11.204157, 11.213804,
|
||||
11.215373, 11.190564, 11.150393, 11.191823, 11.160045]
|
||||
|
||||
print("loss value: {}".format(loss_value))
|
||||
assert np.allclose(loss_value[0:10], expect_loss_value, 0, 0.0005)
|
||||
|
@ -191,12 +191,12 @@ def test_transformer():
|
|||
assert np.allclose(loss_scale[0:10], expect_loss_scale, 0, 0)
|
||||
|
||||
epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
|
||||
expect_epoch_mseconds = 3180
|
||||
expect_epoch_mseconds = 2400
|
||||
print("epoch mseconds: {}".format(epoch_mseconds))
|
||||
assert epoch_mseconds <= expect_epoch_mseconds + 20
|
||||
|
||||
per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2]
|
||||
expect_per_step_mseconds = 318
|
||||
expect_per_step_mseconds = 240
|
||||
print("per step mseconds: {}".format(per_step_mseconds))
|
||||
assert per_step_mseconds <= expect_per_step_mseconds + 2
|
||||
|
||||
|
|
Loading…
Reference in New Issue