!19172 Add DataQueue Control For PanGu Model
Merge pull request !19172 from huangxinjing/add_data_queue
This commit is contained in:
commit
e89b075a0b
|
@ -85,6 +85,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_
|
|||
dataset_restore: the dataset for training or evaluating
|
||||
"""
|
||||
ds.config.set_seed(1)
|
||||
# Control the size of data queue in the consideration of the memory
|
||||
ds.config.set_prefetch_size(1)
|
||||
|
||||
# Get path for source data files
|
||||
home_path = os.path.join(os.getcwd(), data_path)
|
||||
|
|
|
@ -336,9 +336,10 @@ def add_training_params(opt):
|
|||
opt.add_argument("--sink_size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The sink size of the training")
|
||||
help="The sink size of the training. default is 2")
|
||||
opt.add_argument("--full_batch",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Import the full size of a batch for each card, default is 1")
|
||||
opt.add_argument("--optimizer_shard",
|
||||
type=int,
|
||||
|
@ -347,11 +348,11 @@ def add_training_params(opt):
|
|||
opt.add_argument("--per_batch_size",
|
||||
type=int,
|
||||
default=6,
|
||||
help="The batch size for each data parallel way. default 32")
|
||||
help="The batch size for each data parallel way. default 6")
|
||||
opt.add_argument("--start_lr",
|
||||
type=float,
|
||||
default=5e-5,
|
||||
help="The start learning rate. default 1e-5")
|
||||
help="The start learning rate. default 5e-5")
|
||||
opt.add_argument("--end_lr",
|
||||
type=float,
|
||||
default=1e-6,
|
||||
|
@ -377,12 +378,12 @@ def get_args(inference=False):
|
|||
parser.add_argument("--device_num",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Use device nums, default is 1.")
|
||||
help="Use device nums, default is 128.")
|
||||
parser.add_argument("--distribute",
|
||||
type=str,
|
||||
default="true",
|
||||
choices=["true", "false"],
|
||||
help="Run distribute, default is false.")
|
||||
help="Run distribute, default is true.")
|
||||
parser.add_argument("--load_ckpt_name",
|
||||
type=str,
|
||||
default='PANGUALPHA3.ckpt',
|
||||
|
@ -408,7 +409,7 @@ def get_args(inference=False):
|
|||
type=str,
|
||||
default="2.6B",
|
||||
choices=["200B", "13B", "2.6B", "self_define"],
|
||||
help="The train/eval mode")
|
||||
help="The scale of the model parameters")
|
||||
parser.add_argument("--strategy_load_ckpt_path",
|
||||
type=str,
|
||||
default="",
|
||||
|
|
Loading…
Reference in New Issue