!6483 remove parameter broadcast
Merge pull request !6483 from gziyan/rm_parameter_broadcast
Commit cdff9412dc
@@ -15,12 +15,14 @@
 """Utils of auto parallel"""

 import numpy as np
+from mindspore import log as logger
 from mindspore._c_expression import reset_op_id
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype
 from mindspore.common import dtype as mstype
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.common.seed import get_seed


 def _get_parallel_mode():
@@ -136,16 +138,11 @@ def _get_global_rank():

 def _get_parameter_broadcast():
     """Get the parameter broadcast."""
     parallel_mode = auto_parallel_context().get_parallel_mode()
-    if parallel_mode == "stand_alone":
-        parameter_broadcast = False
-        return parameter_broadcast
+    parameter_broadcast = auto_parallel_context().get_parameter_broadcast()

-    if auto_parallel_context().get_parameter_broadcast_is_set() is True:
-        parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
-    elif parallel_mode in ("data_parallel", "hybrid_parallel"):
-        parameter_broadcast = True
-    else:
-        parameter_broadcast = False
+    if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed() is None:
+        logger.warning("You are suggested to use mindspore.common.set_seed() to share"
+                       " parameters among devices.")

     return parameter_broadcast

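The warning added above steers users toward seeding instead of broadcast. A minimal sketch of that workflow, assuming mindspore.common.set_seed is importable as shown and with placeholder model construction:

from mindspore.common import set_seed

# Sketch: fix the global seed before building the network so every device
# initializes identical weights, rather than relying on parameter broadcast.
set_seed(1)
# net = build_network(...)    # placeholder: construct the model after seeding
# opt = build_optimizer(net)  # placeholder: optimizer construction follows
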
@@ -268,7 +268,7 @@ def train(cloud_args=None):
     loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)

     context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
-                                      parameter_broadcast=True, gradients_mean=True)
+                                      gradients_mean=True)
     model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager, amp_level="O3")

     # checkpoint save

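Scripts that still want broadcast can opt back in explicitly, since parameter_broadcast remains a valid argument of set_auto_parallel_context (it is the argument being dropped from the calls above). A minimal sketch, assuming an Ascend target and a hypothetical 8-device job:

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")  # assumed target
init()
# Explicit opt-in replaces the default that this change removes from the model scripts.
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  device_num=8,
                                  gradients_mean=True,
                                  parameter_broadcast=True)
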
@@ -54,7 +54,7 @@ if __name__ == '__main__':
         rank = args_opt.rank_id
         device_num = args_opt.device_num
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True, parameter_broadcast=True)
+                                          gradients_mean=True)
         init()
     else:
         rank = 0

@@ -58,7 +58,7 @@ if __name__ == '__main__':
         cfg.group_size = get_group_size()
         parallel_mode = ParallelMode.DATA_PARALLEL
         context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
-                                          parameter_broadcast=True, gradients_mean=True)
+                                          gradients_mean=True)
     else:
         cfg.rank = 0
         cfg.group_size = 1

@@ -59,7 +59,7 @@ if __name__ == '__main__':
         rank = args_opt.rank_id
         device_num = args_opt.device_num
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True, parameter_broadcast=True)
+                                          gradients_mean=True)
         init()
     else:
         rank = 0

@@ -49,7 +49,7 @@ def context_device_init(config):
     if config.run_distribute:
         context.set_auto_parallel_context(device_num=config.rank_size,
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          parameter_broadcast=True, gradients_mean=True,
+                                          gradients_mean=True,
                                           all_reduce_fusion_config=[140])
         init()
     else:

@@ -76,7 +76,6 @@ def train_on_ascend():
     if run_distribute:
         context.set_auto_parallel_context(device_num=rank_size,
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          parameter_broadcast=True,
                                           gradients_mean=True)
         init()

@@ -74,7 +74,6 @@ if __name__ == '__main__':
     if run_distribute:
         context.set_auto_parallel_context(device_num=rank_size,
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          parameter_broadcast=True,
                                           gradients_mean=True)
         init()
     context.set_auto_parallel_context(device_num=args_opt.device_num,

@@ -178,7 +178,7 @@ def test(cloud_args=None):
     if args.is_distributed:
         parallel_mode = ParallelMode.DATA_PARALLEL
         context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
-                                          parameter_broadcast=True, gradients_mean=True)
+                                          gradients_mean=True)

     args.logger.save_args(args)

@@ -200,7 +200,7 @@ def train(cloud_args=None):
     if args.is_distributed:
         parallel_mode = ParallelMode.DATA_PARALLEL
         context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
-                                          parameter_broadcast=True, gradients_mean=True)
+                                          gradients_mean=True)
     # dataloader
     de_dataset = classification_dataset(args.data_dir, args.image_size,
                                         args.per_batch_size, 1,

@@ -51,7 +51,6 @@ def train_net(data_dir,
     parallel_mode = ParallelMode.DATA_PARALLEL
     context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                       device_num=group_size,
-                                      parameter_broadcast=True,
                                       gradients_mean=False)
     net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])

@@ -139,7 +139,7 @@ if __name__ == '__main__':
         device_num = args.group_size
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          parameter_broadcast=True, gradients_mean=True)
+                                          gradients_mean=True)
     else:
         context.set_context(device_id=args.device_id)
     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

@@ -254,7 +254,6 @@ def _setup_parallel_env(platform):
     context.set_auto_parallel_context(
         parallel_mode=ParallelMode.DATA_PARALLEL,
         device_num=MultiAscend.get_group_size(),
-        parameter_broadcast=True,
         gradients_mean=True
     )

@@ -123,7 +123,7 @@ def run_transformer_train():
         device_num = args.device_num
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
-                                          parameter_broadcast=True, device_num=device_num)
+                                          device_num=device_num)
         D.init()
         rank_id = args.device_id % device_num
         save_ckpt_path = os.path.join(args.save_checkpoint_path, 'ckpt_' + str(get_rank()) + '/')
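For reference, a quick way to confirm what the framework now decides when nothing is set explicitly; a sketch, assuming the context getter/setter API used in the scripts above:

from mindspore import context

# Sketch: with the mode-based default removed, parameter_broadcast stays False
# unless it is set explicitly on the auto-parallel context.
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=8)
print(context.get_auto_parallel_context("parameter_broadcast"))  # expected: False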