Add pipeline dataset imports

Fix document
This commit is contained in:
huangxinjing 2021-07-09 15:05:10 +08:00
parent 17c9b7397d
commit 7f8c5b10da
4 changed files with 19 additions and 12 deletions

View File

@@ -291,7 +291,7 @@ class DistributedGradReducer(Cell):
         ValueError: If degree is not a int or less than 0.
     Supported Platforms:
-        ``Ascend``, ``GPU``
+        ``Ascend`` ``GPU``
     Examples:
         >>> # This example should be run with multiple processes.
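Since the example body is elided in this hunk, here is a rough sketch of where DistributedGradReducer typically sits in a custom train step. It is not code from this commit: the TrainStep cell, its arguments, and the use of F.depend are assumptions based on the common MindSpore pattern, and running it requires a multi-process job initialized with mindspore.communication.init() plus real network and optimizer objects.

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.ops import functional as F

class TrainStep(nn.Cell):
    """Hypothetical train step that all-reduces gradients before the update."""
    def __init__(self, network, optimizer):
        super().__init__()
        self.network = network
        self.optimizer = optimizer
        self.weights = optimizer.parameters
        self.grad = ops.GradOperation(get_by_list=True)
        # Reduces (averages) gradients across all devices in the group.
        self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters)

    def construct(self, *inputs):
        loss = self.network(*inputs)
        grads = self.grad(self.network, self.weights)(*inputs)
        grads = self.grad_reducer(grads)
        # depend ensures the optimizer update runs before loss is returned.
        loss = F.depend(loss, self.optimizer(grads))
        return loss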

View File

@@ -140,9 +140,13 @@ class Primitive(Primitive_):
         Note:
             It is valid only in semi auto parallel.
             In other parallel modes, please set it to be 0.
         Args:
             stage (int): The stage id for the current operation.
+        Example:
+            >>> from mindspore.ops import operations as P
+            >>> add = P.Add()
+            >>> print(add.set_stage(0))
+            Prim[Add]<stage=0>
         """
         self.add_prim_attr("stage", stage)
         return self
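The new Example only shows the attribute being set on a single operator. As a broader, hedged sketch, the snippet below pairs per-operator stages with the pipeline configuration in the auto parallel context; the two-stage ToyNet and its stage split are assumptions for illustration, not part of this commit, and a real run additionally needs a multi-device launch.

import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P

# Hypothetical pipeline setup: two stages in semi auto parallel mode.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", pipeline_stages=2)

class ToyNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.add1 = P.Add().set_stage(0)  # runs on the first pipeline stage
        self.add2 = P.Add().set_stage(1)  # runs on the second pipeline stage

    def construct(self, x, y):
        return self.add2(self.add1(x, y), y)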
@@ -157,6 +161,11 @@ class Primitive(Primitive_):
         Args:
             strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
+        Example:
+            >>> from mindspore.ops import operations as P
+            >>> add = P.Add()
+            >>> print(add.shard(((1, 1), (1, 1))))
+            Prim[Add]<strategy=((1, 1), (1, 1))>
         """
         mode = context.get_auto_parallel_context("parallel_mode")
         if strategy is not None:
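As a companion to the new shard example, which uses the trivial strategy ((1, 1), (1, 1)), the sketch below shows how a non-trivial strategy tuple is read; the MatMul operator and the 8-device layout are assumptions for illustration only, not part of this commit.

from mindspore import context
from mindspore.ops import operations as P

# Hypothetical 8-device layout in semi auto parallel mode.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)

# Each inner tuple describes one input: the first input is split 2-way on its first
# dimension and 4-way on its second; the second input is split 4-way on its first
# dimension (matching the contracted dimension) and kept whole on its second.
matmul = P.MatMul().shard(((2, 4), (4, 1)))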

View File

@@ -95,6 +95,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_
         os.path.join(home_path, name) for name in files
         if not name.endswith(".db")
     ]
+    # Ensure the mindrecord files are in the same order on every machine; otherwise the loss may fail to converge.
+    data.sort()
     # Load data files and preprocess
     dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False)
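The sort matters because directory listings carry no ordering guarantee, so two machines can enumerate the same mindrecord directory differently; with shuffle=False and per-rank sharding, that would make ranks read mismatched data. A standalone sketch of the fix, using a hypothetical path rather than this repository's layout:

import os

home_path = "/path/to/mindrecords"  # hypothetical location
files = os.listdir(home_path)       # order is filesystem-dependent and may differ per machine
data = [
    os.path.join(home_path, name) for name in files
    if not name.endswith(".db")
]
data.sort()                         # deterministic, identical order on every machine
print(data)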

View File

@@ -279,20 +279,16 @@ def run_train_pipeline(args_opt):
         optimizer = nn.Lamb(group_params, learning_rate=lr)
     else:
         optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
-    if context.get_auto_parallel_context("full_batch"):
-        ds = create_dataset(config.batch_size, data_path=cache_url, eod_reset=True,
-                            data_start_index=0, full_batch=True, column_name=args_opt.data_column_name)
-    else:
-        if batch_size % stage_device_num != 0:
-            raise ValueError("Batch_size should be divisible by device_num")
-        ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
-                            rank=rank_id, eod_reset=True, data_start_index=0, full_batch=False,
-                            column_name=args_opt.data_column_name)
+    ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
+                        rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0,
+                        full_batch=context.get_auto_parallel_context("full_batch"),
+                        column_name=args_opt.data_column_name)
     epoch_num = args_opt.epoch_size
     step_per_epoch = ds.get_dataset_size()
     callback_size = args_opt.sink_size
     actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
-    callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, config.stage_num, config.micro_size)]
+    callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, micro_size=config.micro_size)]
     loss_scale_value = math.pow(2, 32)
     update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
     pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
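The consolidated create_dataset call keys the data shard off the device's index inside its pipeline stage rather than its global rank. A standalone sketch of that arithmetic, with made-up device counts (the variable names mirror the diff; the numbers are hypothetical):

device_num = 16        # hypothetical total number of devices
stage_num = 4          # hypothetical number of pipeline stages
stage_device_num = device_num // stage_num

for rank_id in range(device_num):
    stage_id = rank_id // stage_device_num    # which pipeline stage this device belongs to
    shard_rank = rank_id % stage_device_num   # index passed as rank= to create_dataset
    print(f"global rank {rank_id} -> stage {stage_id}, dataset shard {shard_rank}")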