forked from mindspore-Ecosystem/mindspore
parent
17c9b7397d
commit
7f8c5b10da
|
@ -291,7 +291,7 @@ class DistributedGradReducer(Cell):
|
|||
ValueError: If degree is not a int or less than 0.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``, ``GPU``
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> # This example should be run with multiple processes.
|
||||
|
|
|
@ -140,9 +140,13 @@ class Primitive(Primitive_):
|
|||
Note:
|
||||
It is valid only in semi auto parallel.
|
||||
In other parallel modes, please set it to be 0.
|
||||
|
||||
Args:
|
||||
stage (int): The stage id for the current operation.
|
||||
Example:
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> add = P.Add()
|
||||
>>> print(add.set_stage(0))
|
||||
Prim[Add]<stage=0>
|
||||
"""
|
||||
self.add_prim_attr("stage", stage)
|
||||
return self
|
||||
|
@ -157,6 +161,11 @@ class Primitive(Primitive_):
|
|||
|
||||
Args:
|
||||
strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
|
||||
Example:
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> add = P.Add()
|
||||
>>> print(add.shard(((1, 1), (1, 1))))
|
||||
Prim[Add]<strategy=((1, 1), (1, 1))>
|
||||
"""
|
||||
mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if strategy is not None:
|
||||
|
|
|
@ -95,6 +95,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_
|
|||
os.path.join(home_path, name) for name in files
|
||||
if not name.endswith(".db")
|
||||
]
|
||||
# Ensure the order of mindrecords is same in all machines, otherwise it will meet loss converge problem.
|
||||
data.sort()
|
||||
|
||||
# Load data files and preprocess
|
||||
dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False)
|
||||
|
|
|
@ -279,20 +279,16 @@ def run_train_pipeline(args_opt):
|
|||
optimizer = nn.Lamb(group_params, learning_rate=lr)
|
||||
else:
|
||||
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
|
||||
if context.get_auto_parallel_context("full_batch"):
|
||||
ds = create_dataset(config.batch_size, data_path=cache_url, eod_reset=True,
|
||||
data_start_index=0, full_batch=True, column_name=args_opt.data_column_name)
|
||||
else:
|
||||
if batch_size % stage_device_num != 0:
|
||||
raise ValueError("Batch_size should be divisible by device_num")
|
||||
|
||||
ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
|
||||
rank=rank_id, eod_reset=True, data_start_index=0, full_batch=False,
|
||||
rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0,
|
||||
full_batch=context.get_auto_parallel_context("full_batch"),
|
||||
column_name=args_opt.data_column_name)
|
||||
epoch_num = args_opt.epoch_size
|
||||
step_per_epoch = ds.get_dataset_size()
|
||||
callback_size = args_opt.sink_size
|
||||
actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
|
||||
callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, config.stage_num, config.micro_size)]
|
||||
callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, micro_size=config.micro_size)]
|
||||
loss_scale_value = math.pow(2, 32)
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
|
||||
pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
|
||||
|
|
Loading…
Reference in New Issue