forked from mindspore-Ecosystem/mindspore
parent 17c9b7397d
commit 7f8c5b10da
@@ -291,7 +291,7 @@ class DistributedGradReducer(Cell):
         ValueError: If degree is not a int or less than 0.

     Supported Platforms:
-        ``Ascend``, ``GPU``
+        ``Ascend`` ``GPU``

     Examples:
         >>> # This example should be run with multiple processes.
@@ -140,9 +140,13 @@ class Primitive(Primitive_):
         Note:
             It is valid only in semi auto parallel.
             In other parallel modes, please set it to be 0.

         Args:
             stage (int): The stage id for the current operation.
+        Example:
+            >>> from mindspore.ops import operations as P
+            >>> add = P.Add()
+            >>> print(add.set_stage(0))
+            Prim[Add]<stage=0>
         """
         self.add_prim_attr("stage", stage)
         return self
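
The new set_stage doctest tags only a single Add primitive. As a rough sketch of how the attribute is typically used across a pipeline-parallel model (the MatMul layers and the two-stage split below are an assumption for illustration, not part of this commit), operators belonging to different pipeline stages each receive their own stage id under semi auto parallel:

    from mindspore.ops import operations as P

    # Hypothetical two-stage split: the two MatMul layers run on different device groups.
    matmul_stage0 = P.MatMul().set_stage(0)   # set_stage returns the primitive itself
    matmul_stage1 = P.MatMul().set_stage(1)
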
@@ -157,6 +161,11 @@ class Primitive(Primitive_):

         Args:
             strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
+        Example:
+            >>> from mindspore.ops import operations as P
+            >>> add = P.Add()
+            >>> print(add.shard(((1, 1), (1, 1))))
+            Prim[Add]<strategy=((1, 1), (1, 1))>
         """
         mode = context.get_auto_parallel_context("parallel_mode")
         if strategy is not None:
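
For context on the strategy argument in the doctest above: it holds one tuple per input tensor, and each element gives the number of slices along the matching dimension, so ((1, 1), (1, 1)) leaves both inputs unsplit. A minimal sketch of a non-trivial layout follows; the 4-device split is an assumed configuration, not part of this commit:

    from mindspore.ops import operations as P

    # Split the first input of MatMul into 4 slices along dimension 0 across devices,
    # keep the second input unsplit; assumes a 4-device semi auto parallel setup.
    matmul = P.MatMul().shard(((4, 1), (1, 1)))
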
@@ -95,6 +95,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, full_
         os.path.join(home_path, name) for name in files
         if not name.endswith(".db")
     ]
+    # Ensure the order of mindrecord files is the same on all machines; otherwise the loss may fail to converge.
+    data.sort()

     # Load data files and preprocess
     dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False)
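
The sort added above matters because the file list comes from a directory listing, and listing order is not guaranteed to match across machines; every worker must see the same ordering before the dataset is sharded by rank. A small illustration under an assumed directory (the path below is hypothetical):

    import os

    home_path = "/path/to/mindrecord_dir"   # hypothetical directory holding the mindrecord files
    data = [
        os.path.join(home_path, name) for name in os.listdir(home_path)
        if not name.endswith(".db")
    ]
    data.sort()   # identical order on every machine, so each rank reads a consistent shard
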
@@ -279,20 +279,16 @@ def run_train_pipeline(args_opt):
         optimizer = nn.Lamb(group_params, learning_rate=lr)
     else:
         optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
-    if context.get_auto_parallel_context("full_batch"):
-        ds = create_dataset(config.batch_size, data_path=cache_url, eod_reset=True,
-                            data_start_index=0, full_batch=True, column_name=args_opt.data_column_name)
-    else:
-        if batch_size % stage_device_num != 0:
-            raise ValueError("Batch_size should be divisible by device_num")
-        ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
-                            rank=rank_id, eod_reset=True, data_start_index=0, full_batch=False,
-                            column_name=args_opt.data_column_name)
+    ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
+                        rank=rank_id % stage_device_num, eod_reset=True, data_start_index=0,
+                        full_batch=context.get_auto_parallel_context("full_batch"),
+                        column_name=args_opt.data_column_name)
     epoch_num = args_opt.epoch_size
     step_per_epoch = ds.get_dataset_size()
     callback_size = args_opt.sink_size
     actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
-    callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, config.stage_num, config.micro_size)]
+    callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, micro_size=config.micro_size)]
     loss_scale_value = math.pow(2, 32)
     update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
     pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
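
The rewritten call above folds the full_batch branch into a single create_dataset invocation and passes rank_id % stage_device_num, i.e. (assuming pipeline stages are assigned contiguous blocks of ranks) the rank's position inside its own stage rather than its global id. A worked example under an assumed topology of 8 devices split into 2 pipeline stages of 4 devices each, which the diff itself does not specify:

    stage_device_num = 4                     # assumed devices per pipeline stage
    for rank_id in range(8):                 # global ranks 0..7
        local_rank = rank_id % stage_device_num
        print(rank_id, "->", local_rank)     # ranks 4..7 map to 0..3, the same shards as ranks 0..3
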