!19172 Add DataQueue Control For PanGu Model

Merge pull request !19172 from huangxinjing/add_data_queue
This commit is contained in:
i-robot 2021-07-01 07:29:11 +00:00 committed by Gitee
commit e89b075a0b
2 changed files with 9 additions and 6 deletions

View File

@@ -85,6 +85,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_
dataset_restore: the dataset for training or evaluating
"""
ds.config.set_seed(1)
# Limit the size of the data queue (prefetch size) to keep memory usage down
ds.config.set_prefetch_size(1)
# Get path for source data files
home_path = os.path.join(os.getcwd(), data_path)

View File

@@ -336,9 +336,10 @@ def add_training_params(opt):
opt.add_argument("--sink_size",
type=int,
default=2,
help="The sink size of the training")
help="The sink size of the training. default is 2")
opt.add_argument("--full_batch",
default=1,
type=int,
help="Import the full size of a batch for each card, default is 1")
opt.add_argument("--optimizer_shard",
type=int,
@@ -347,11 +348,11 @@ def add_training_params(opt):
opt.add_argument("--per_batch_size",
type=int,
default=6,
help="The batch size for each data parallel way. default 32")
help="The batch size for each data parallel way. default 6")
opt.add_argument("--start_lr",
type=float,
default=5e-5,
help="The start learning rate. default 1e-5")
help="The start learning rate. default 5e-5")
opt.add_argument("--end_lr",
type=float,
default=1e-6,
@@ -377,12 +378,12 @@ def get_args(inference=False):
parser.add_argument("--device_num",
type=int,
default=128,
help="Use device nums, default is 1.")
help="Use device nums, default is 128.")
parser.add_argument("--distribute",
type=str,
default="true",
choices=["true", "false"],
help="Run distribute, default is false.")
help="Run distribute, default is true.")
parser.add_argument("--load_ckpt_name",
type=str,
default='PANGUALPHA3.ckpt',
@@ -408,7 +409,7 @@ def get_args(inference=False):
type=str,
default="2.6B",
choices=["200B", "13B", "2.6B", "self_define"],
help="The train/eval mode")
help="The scale of the model parameters")
parser.add_argument("--strategy_load_ckpt_path",
type=str,
default="",