forked from mindspore-Ecosystem/mindspore
enable dynamic shape wenet testcase to level0
This commit is contained in:
parent ce53db6836
commit f851af4f80
@@ -1752,11 +1752,11 @@ def get_train_loss(train_dataset, run_mode):
     steps_size = train_dataset.get_dataset_size()
     logging.warning("Training dataset has %d steps in each epoch.", steps_size)
 
-    # define network
-    net_with_loss = init_asr_model(mb, VOCAB_SIZE)
-    weights = ParameterTuple(net_with_loss.trainable_params())
-    logging.info("Total parameter of ASR model: %s.",
-                 get_parameter_numel(net_with_loss))
+    # define wenet network
+    wenet_with_loss = init_asr_model(mb, VOCAB_SIZE)
+    weights = ParameterTuple(wenet_with_loss.trainable_params())
+    logging.info("Total parameter of WeNet-ASR model: %s.",
+                 get_parameter_numel(wenet_with_loss))
 
     lr_schedule = ASRWarmupLR(
         learninig_rate=OPTIM_LR,
@@ -1768,7 +1768,7 @@ def get_train_loss(train_dataset, run_mode):
         loss_scale_value=1024, scale_factor=2, scale_window=1000)
 
     train_net = TrainAccumulationAllReduceEachWithLossScaleCell(
-        net_with_loss, optimizer, update_cell, accumulation_steps=ACCUM_GRAD
+        wenet_with_loss, optimizer, update_cell, accumulation_steps=ACCUM_GRAD
     )
 
     callback = TimeMonitor(steps_size)
@@ -1797,7 +1797,7 @@ def get_train_loss(train_dataset, run_mode):
     return callback.loss
 
 
-@pytest.mark.level1
+@pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
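For context on the unchanged lines in the first hunk: ASRWarmupLR is defined in this test file. Below is a minimal sketch of the Transformer-style warmup rule that upstream WeNet's WarmupLR implements, assuming ASRWarmupLR follows it; the function name, default warmup_steps, and sample values are illustrative assumptions (note that learninig_rate [sic] is the keyword the test actually passes).

def warmup_lr(base_lr: float, step: int, warmup_steps: int = 25000) -> float:
    """Transformer-style warmup: linear ramp up, then inverse-sqrt decay.

    Sketch of upstream WeNet's WarmupLR rule; whether ASRWarmupLR in this
    test file matches it exactly is an assumption.
    """
    step = max(step, 1)
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5,
                                               step * warmup_steps ** -1.5)

# The schedule peaks at base_lr when step == warmup_steps and decays as
# step ** -0.5 afterwards.
assert abs(warmup_lr(1e-3, 25000) - 1e-3) < 1e-12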
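The loss-scale arguments in the second hunk (loss_scale_value=1024, scale_factor=2, scale_window=1000) match the signature of MindSpore's stock dynamic loss-scale update cell. TrainAccumulationAllReduceEachWithLossScaleCell itself is a custom cell defined in this test file; as a point of comparison, here is a minimal sketch using the stock non-accumulating cells, with a toy Dense network standing in for the WeNet ASR model (the toy network and optimizer settings are assumptions):

from mindspore import nn

# Toy stand-in for the WeNet ASR network (an assumption, not the real model).
net_with_loss = nn.WithLossCell(nn.Dense(8, 2), nn.MSELoss())
optimizer = nn.Adam(net_with_loss.trainable_params(), learning_rate=1e-3)

# Same numbers as the diff: start the scale at 1024, divide it by 2 on
# overflow, and multiply it by 2 after 1000 consecutive overflow-free steps.
update_cell = nn.DynamicLossScaleUpdateCell(
    loss_scale_value=1024, scale_factor=2, scale_window=1000)

# Stock counterpart of the custom accumulation cell used in the test: one
# optimizer step per call, gradients unscaled and overflow-checked.
train_net = nn.TrainOneStepWithLossScaleCell(
    net_with_loss, optimizer, scale_sense=update_cell)
train_net.set_train()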
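The substantive change is the marker swap from level1 to level0 in the last hunk. These are ordinary pytest marks, so the case now runs whenever CI selects the level0 suite (e.g. pytest -m level0). A minimal sketch of the selection mechanics follows; the marker names come from the diff, while the test body, stub, and file name are hypothetical placeholders:

import pytest

def get_train_loss_stub():
    # Hypothetical stand-in for get_train_loss(train_dataset, run_mode).
    return 0.0

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_wenet_dynamic_shape():
    loss = get_train_loss_stub()
    assert loss == loss  # a NaN loss would fail this check

# Run only the level0 suite:
#   pytest -m level0 test_wenet.py
# Custom marks should be registered (e.g. in pytest.ini under "markers =")
# to avoid PytestUnknownMarkWarning.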