enable dynamic shape wenet testcase to level0

This commit is contained in:
zhengzuohe 2023-02-03 11:03:38 +08:00
parent ce53db6836
commit f851af4f80
1 changed files with 7 additions and 7 deletions

View File

@ -1752,11 +1752,11 @@ def get_train_loss(train_dataset, run_mode):
steps_size = train_dataset.get_dataset_size()
logging.warning("Training dataset has %d steps in each epoch.", steps_size)
# define network
net_with_loss = init_asr_model(mb, VOCAB_SIZE)
weights = ParameterTuple(net_with_loss.trainable_params())
logging.info("Total parameter of ASR model: %s.",
get_parameter_numel(net_with_loss))
# define wenet network
wenet_with_loss = init_asr_model(mb, VOCAB_SIZE)
weights = ParameterTuple(wenet_with_loss.trainable_params())
logging.info("Total parameter of WeNet-ASR model: %s.",
get_parameter_numel(wenet_with_loss))
lr_schedule = ASRWarmupLR(
learninig_rate=OPTIM_LR,
@ -1768,7 +1768,7 @@ def get_train_loss(train_dataset, run_mode):
loss_scale_value=1024, scale_factor=2, scale_window=1000)
train_net = TrainAccumulationAllReduceEachWithLossScaleCell(
net_with_loss, optimizer, update_cell, accumulation_steps=ACCUM_GRAD
wenet_with_loss, optimizer, update_cell, accumulation_steps=ACCUM_GRAD
)
callback = TimeMonitor(steps_size)
@ -1797,7 +1797,7 @@ def get_train_loss(train_dataset, run_mode):
return callback.loss
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard