Add pipeline dataset imports

Fix document
This commit is contained in:
huangxinjing 2021-07-09 15:05:10 +08:00
parent 17c9b7397d
commit 7f8c5b10da
4 changed files with 19 additions and 12 deletions

View File

@ -291,7 +291,7 @@ class DistributedGradReducer(Cell):
ValueError: If degree is not a int or less than 0.
Supported Platforms:
``Ascend``, ``GPU``
``Ascend`` ``GPU``
Examples:
>>> # This example should be run with multiple processes.

View File

@ -140,9 +140,13 @@ class Primitive(Primitive_):
Note:
It is valid only in semi auto parallel.
In other parallel modes, please set it to be 0.
Args:
stage (int): The stage id for the current operation.
Example:
>>> from mindspore.ops import operations as P
>>> add = P.Add()
>>> print(add.set_stage(0))
Prim[Add]<stage=0>
"""
self.add_prim_attr("stage", stage)
return self
@ -157,6 +161,11 @@ class Primitive(Primitive_):
Args:
strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
Example:
>>> from mindspore.ops import operations as P
>>> add = P.Add()
>>> print(add.shard(((1, 1), (1, 1))))
Prim[Add]<strategy=((1, 1), (1, 1))>
"""
mode = context.get_auto_parallel_context("parallel_mode")
if strategy is not None:

View File

@ -95,6 +95,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_
os.path.join(home_path, name) for name in files
if not name.endswith(".db")
]
# Ensure the order of mindrecords is same in all machines, otherwise it will meet loss converge problem.
data.sort()
# Load data files and preprocess
dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False)

View File

@ -279,20 +279,16 @@ def run_train_pipeline(args_opt):
optimizer = nn.Lamb(group_params, learning_rate=lr)
else:
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
if context.get_auto_parallel_context("full_batch"):
ds = create_dataset(config.batch_size, data_path=cache_url, eod_reset=True,
data_start_index=0, full_batch=True, column_name=args_opt.data_column_name)
else:
if batch_size % stage_device_num != 0:
raise ValueError("Batch_size should be divisible by device_num")
ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
rank=rank_id, eod_reset=True, data_start_index=0, full_batch=False,
column_name=args_opt.data_column_name)
ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0,
full_batch=context.get_auto_parallel_context("full_batch"),
column_name=args_opt.data_column_name)
epoch_num = args_opt.epoch_size
step_per_epoch = ds.get_dataset_size()
callback_size = args_opt.sink_size
actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, config.stage_num, config.micro_size)]
callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, micro_size=config.micro_size)]
loss_scale_value = math.pow(2, 32)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(