forked from mindspore-Ecosystem/mindspore
!7317 update transformer scripts
Merge pull request !7317 from panfengfeng/update_transformer_scripts
commit 839b4eb486
@@ -17,10 +17,10 @@
 import mindspore.common.dtype as mstype
 import mindspore.dataset as de
 import mindspore.dataset.transforms.c_transforms as deC
-from .config import transformer_net_cfg
+from .config import transformer_net_cfg, transformer_net_cfg_gpu
 de.config.set_seed(1)
 def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", dataset_path=None,
-                               bucket_boundaries=None):
+                               bucket_boundaries=None, device_target="Ascend"):
     """create dataset"""
     def batch_per_bucket(bucket_len, dataset_path):
         dataset_path = dataset_path + "_" + str(bucket_len) + "_00"
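For context, the GPU variant is imported next to the existing Ascend config. A minimal sketch of what the two entries in .config could look like (the edict layout and the batch-size values are assumptions for illustration; the diff only requires that each config expose a batch_size attribute):

from easydict import EasyDict as edict

# Hypothetical sketch of the two configs named in the import above.
# Only batch_size matters for this change; other fields are omitted.
transformer_net_cfg = edict({'batch_size': 96})      # Ascend default (assumed value)
transformer_net_cfg_gpu = edict({'batch_size': 32})  # smaller batch for GPU memory (assumed value)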
@@ -38,7 +38,11 @@ def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle
     ds = ds.map(operations=type_cast_op, input_columns="target_eos_mask")

     # apply batch operations
-    ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
+    if device_target == "Ascend":
+        ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
+    else:
+        ds = ds.batch(transformer_net_cfg_gpu.batch_size, drop_remainder=True)

     ds = ds.repeat(epoch_count)
     return ds
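The new branch only selects the per-device batch size; an equivalent, more compact form of the same logic (purely illustrative, not what the commit itself writes):

# Pick the config by device, then batch once; behavior matches the diff above.
cfg = transformer_net_cfg if device_target == "Ascend" else transformer_net_cfg_gpu
ds = ds.batch(cfg.batch_size, drop_remainder=True)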
@@ -146,7 +146,8 @@ def run_transformer_train():
     dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num,
                                          rank_id=rank_id, do_shuffle=args.do_shuffle,
                                          dataset_path=args.data_path,
-                                         bucket_boundaries=args.bucket_boundaries)
+                                         bucket_boundaries=args.bucket_boundaries,
+                                         device_target=args.device_target)
     if args.device_target == "Ascend":
         netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
     else:
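With the call site updated, the device choice flows from the script arguments into the dataset pipeline. A hypothetical single-device call mirroring the new signature (keyword names come from the diff; the dataset path and bucket list are placeholder values, not taken from this commit):

# Hypothetical call; path and bucket boundaries are example values only.
dataset = create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0,
                                     do_shuffle="true",
                                     dataset_path="/path/to/ende-l128-mindrecord",
                                     bucket_boundaries=[16, 32, 48, 64, 128],
                                     device_target="GPU")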