diff --git a/model_zoo/official/nlp/pangu_alpha/src/dataset.py b/model_zoo/official/nlp/pangu_alpha/src/dataset.py index 8e7707510d0..6f6fc147b31 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/dataset.py +++ b/model_zoo/official/nlp/pangu_alpha/src/dataset.py @@ -85,6 +85,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_ dataset_restore: the dataset for training or evaluating """ ds.config.set_seed(1) + # Control the size of data queue in the consideration of the memory + ds.config.set_prefetch_size(1) # Get path for source data files home_path = os.path.join(os.getcwd(), data_path) diff --git a/model_zoo/official/nlp/pangu_alpha/src/utils.py b/model_zoo/official/nlp/pangu_alpha/src/utils.py index 96b44acc83c..7ef9133f536 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/utils.py +++ b/model_zoo/official/nlp/pangu_alpha/src/utils.py @@ -336,9 +336,10 @@ def add_training_params(opt): opt.add_argument("--sink_size", type=int, default=2, - help="The sink size of the training") + help="The sink size of the training. default is 2") opt.add_argument("--full_batch", default=1, + type=int, help="Import the full size of a batch for each card, default is 1") opt.add_argument("--optimizer_shard", type=int, @@ -347,11 +348,11 @@ def add_training_params(opt): opt.add_argument("--per_batch_size", type=int, default=6, - help="The batch size for each data parallel way. default 32") + help="The batch size for each data parallel way. default 6") opt.add_argument("--start_lr", type=float, default=5e-5, - help="The start learning rate. default 1e-5") + help="The start learning rate. default 5e-5") opt.add_argument("--end_lr", type=float, default=1e-6, @@ -377,12 +378,12 @@ def get_args(inference=False): parser.add_argument("--device_num", type=int, default=128, - help="Use device nums, default is 1.") + help="Use device nums, default is 128.") parser.add_argument("--distribute", type=str, default="true", choices=["true", "false"], - help="Run distribute, default is false.") + help="Run distribute, default is true.") parser.add_argument("--load_ckpt_name", type=str, default='PANGUALPHA3.ckpt', @@ -408,7 +409,7 @@ def get_args(inference=False): type=str, default="2.6B", choices=["200B", "13B", "2.6B", "self_define"], - help="The train/eval mode") + help="The scale of the model parameters") parser.add_argument("--strategy_load_ckpt_path", type=str, default="",