fix modelzoo network tinybert accuracy waving

This commit is contained in:
anzhengqi 2021-07-26 11:00:27 +08:00
parent cdc48dfdb4
commit f45d499cbf
1 changed files with 4 additions and 1 deletions

View File

@ -19,7 +19,9 @@ import os
import time
import re
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.callback import TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
@ -43,7 +45,8 @@ if not os.path.exists(td_phase1_save_ckpt_dir):
if not os.path.exists(td_phase2_save_ckpt_dir):
os.makedirs(td_phase2_save_ckpt_dir)
enable_loss_scale = True
set_seed(123)
ds.config.set_seed(12345)
if args_opt.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD