!7317 update transformer scripts

Merge pull request !7317 from panfengfeng/update_transformer_scripts
This commit is contained in:
mindspore-ci-bot 2020-10-15 09:24:14 +08:00 committed by Gitee
commit 839b4eb486
2 changed files with 9 additions and 4 deletions

View File

@ -17,10 +17,10 @@
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as deC
from .config import transformer_net_cfg
from .config import transformer_net_cfg, transformer_net_cfg_gpu
de.config.set_seed(1)
def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", dataset_path=None,
bucket_boundaries=None):
bucket_boundaries=None, device_target="Ascend"):
"""create dataset"""
def batch_per_bucket(bucket_len, dataset_path):
dataset_path = dataset_path + "_" + str(bucket_len) + "_00"
@ -38,7 +38,11 @@ def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle
ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")
# apply batch operations
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
if device_target == "Ascend":
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
else:
ds = ds.batch(transformer_net_cfg_gpu.batch_size, drop_remainder=True)
ds = ds.repeat(epoch_count)
return ds

View File

@ -146,7 +146,8 @@ def run_transformer_train():
dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num,
rank_id=rank_id, do_shuffle=args.do_shuffle,
dataset_path=args.data_path,
bucket_boundaries=args.bucket_boundaries)
bucket_boundaries=args.bucket_boundaries,
device_target=args.device_target)
if args.device_target == "Ascend":
netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
else: