remove parameter broadcast

This commit is contained in:
Ziyan 2020-09-18 11:07:58 +08:00
parent 8346da267b
commit cc131193ec
14 changed files with 15 additions and 22 deletions

View File

@ -15,12 +15,14 @@
"""Utils of auto parallel"""
import numpy as np
from mindspore import log as logger
from mindspore._c_expression import reset_op_id
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype
from mindspore.common import dtype as mstype
from mindspore.communication.management import get_group_size, get_rank
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.common.seed import get_seed
def _get_parallel_mode():
@ -136,16 +138,11 @@ def _get_global_rank():
def _get_parameter_broadcast():
"""Get the parameter broadcast."""
parallel_mode = auto_parallel_context().get_parallel_mode()
if parallel_mode == "stand_alone":
parameter_broadcast = False
return parameter_broadcast
parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
if auto_parallel_context().get_parameter_broadcast_is_set() is True:
parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
elif parallel_mode in ("data_parallel", "hybrid_parallel"):
parameter_broadcast = True
else:
parameter_broadcast = False
if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed is None:
logger.warning("You are suggested to use mindspore.common.set_seed() to share"
" parameters among devices.")
return parameter_broadcast

View File

@ -268,7 +268,7 @@ def train(cloud_args=None):
loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
parameter_broadcast=True, gradients_mean=True)
gradients_mean=True)
model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager, amp_level="O3")
# checkpoint save

View File

@ -54,7 +54,7 @@ if __name__ == '__main__':
rank = args_opt.rank_id
device_num = args_opt.device_num
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, parameter_broadcast=True)
gradients_mean=True)
init()
else:
rank = 0

View File

@ -58,7 +58,7 @@ if __name__ == '__main__':
cfg.group_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
parameter_broadcast=True, gradients_mean=True)
gradients_mean=True)
else:
cfg.rank = 0
cfg.group_size = 1

View File

@ -59,7 +59,7 @@ if __name__ == '__main__':
rank = args_opt.rank_id
device_num = args_opt.device_num
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, parameter_broadcast=True)
gradients_mean=True)
init()
else:
rank = 0

View File

@ -49,7 +49,7 @@ def context_device_init(config):
if config.run_distribute:
context.set_auto_parallel_context(device_num=config.rank_size,
parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True, gradients_mean=True,
gradients_mean=True,
all_reduce_fusion_config=[140])
init()
else:

View File

@ -76,7 +76,6 @@ def train_on_ascend():
if run_distribute:
context.set_auto_parallel_context(device_num=rank_size,
parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True,
gradients_mean=True)
init()

View File

@ -74,7 +74,6 @@ if __name__ == '__main__':
if run_distribute:
context.set_auto_parallel_context(device_num=rank_size,
parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True,
gradients_mean=True)
init()
context.set_auto_parallel_context(device_num=args_opt.device_num,

View File

@ -178,7 +178,7 @@ def test(cloud_args=None):
if args.is_distributed:
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
parameter_broadcast=True, gradients_mean=True)
gradients_mean=True)
args.logger.save_args(args)

View File

@ -200,7 +200,7 @@ def train(cloud_args=None):
if args.is_distributed:
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
parameter_broadcast=True, gradients_mean=True)
gradients_mean=True)
# dataloader
de_dataset = classification_dataset(args.data_dir, args.image_size,
args.per_batch_size, 1,

View File

@ -51,7 +51,6 @@ def train_net(data_dir,
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=group_size,
parameter_broadcast=True,
gradients_mean=False)
net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])

View File

@ -140,7 +140,7 @@ if __name__ == '__main__':
device_num = args.group_size
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True, gradients_mean=True)
gradients_mean=True)
else:
context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

View File

@ -254,7 +254,6 @@ def _setup_parallel_env(platform):
context.set_auto_parallel_context(
parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=MultiAscend.get_group_size(),
parameter_broadcast=True,
gradients_mean=True
)

View File

@ -123,7 +123,7 @@ def run_transformer_train():
device_num = args.device_num
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
parameter_broadcast=True, device_num=device_num)
device_num=device_num)
D.init()
rank_id = args.device_id % device_num
save_ckpt_path = os.path.join(args.save_checkpoint_path, 'ckpt_' + str(get_rank()) + '/')