fix bugs in bert example script

This commit is contained in:
yoonlee666 2020-04-14 10:18:31 +08:00
parent 03c2555977
commit 39b446524d
1 changed file with 4 additions and 3 deletions

View File

@@ -39,6 +39,7 @@ import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
from mindspore import context from mindspore import context
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
@@ -49,9 +50,9 @@ def create_train_dataset(batch_size):
"""create train dataset""" """create train dataset"""
# apply repeat operations # apply repeat operations
repeat_count = bert_train_cfg.epoch_size repeat_count = bert_train_cfg.epoch_size
ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR, ds = de.TFRecordDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"]) "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
type_cast_op = C.TypeCast(mstype.int32) type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)