forked from mindspore-Ecosystem/mindspore
fix ge dynamic shape input st
This commit is contained in:
parent
bfe7af1f8a
commit
1875ee3280
|
@ -147,13 +147,13 @@ def create_dataset(batch_size=32):
|
|||
class LossCallBack(LossMonitor):
|
||||
def __init__(self):
|
||||
super(LossCallBack, self).__init__()
|
||||
self.last_5_losses = []
|
||||
self.last_10_losses = []
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
loss = cb_params.net_outputs
|
||||
loss = np.mean(loss.asnumpy())
|
||||
self.last_5_losses = self.last_5_losses[-4:] + [loss]
|
||||
self.last_10_losses = self.last_10_losses[-9:] + [loss]
|
||||
|
||||
|
||||
def train(batch_size, lr, momentum, epochs, dataset_sink_mode):
|
||||
|
@ -173,11 +173,11 @@ def train(batch_size, lr, momentum, epochs, dataset_sink_mode):
|
|||
model.train(epochs, dummy_dataset, callbacks=[loss_callback],
|
||||
sink_size=dummy_dataset.get_dataset_size(), dataset_sink_mode=dataset_sink_mode)
|
||||
|
||||
avg_loss = np.mean(loss_callback.last_5_losses)
|
||||
avg_loss = np.min(loss_callback.last_10_losses)
|
||||
return avg_loss
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
|
|
Loading…
Reference in New Issue