forked from mindspore-Ecosystem/mindspore
fix bugs in bert example script
This commit is contained in:
parent
03c2555977
commit
39b446524d
|
@@ -39,6 +39,7 @@ import mindspore.dataset.engine.datasets as de
|
|||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from mindspore import context
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
|
||||
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
|
||||
|
@@ -49,9 +50,9 @@ def create_train_dataset(batch_size):
|
|||
"""create train dataset"""
|
||||
# apply repeat operations
|
||||
repeat_count = bert_train_cfg.epoch_size
|
||||
ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
|
||||
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
|
||||
ds = de.TFRecordDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
|
||||
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
|
||||
|
|
Loading…
Reference in New Issue