forked from mindspore-Ecosystem/mindspore

remove parameter broadcast

parent 8346da267b
commit cc131193ec
@@ -15,12 +15,14 @@
 """Utils of auto parallel"""

 import numpy as np
+from mindspore import log as logger
 from mindspore._c_expression import reset_op_id
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype
 from mindspore.common import dtype as mstype
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.common.seed import get_seed


 def _get_parallel_mode():
@@ -136,16 +138,11 @@ def _get_global_rank():

 def _get_parameter_broadcast():
     """Get the parameter broadcast."""
     parallel_mode = auto_parallel_context().get_parallel_mode()
-    if parallel_mode == "stand_alone":
-        parameter_broadcast = False
-        return parameter_broadcast
+    parameter_broadcast = auto_parallel_context().get_parameter_broadcast()

-    if auto_parallel_context().get_parameter_broadcast_is_set() is True:
-        parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
-    elif parallel_mode in ("data_parallel", "hybrid_parallel"):
-        parameter_broadcast = True
-    else:
-        parameter_broadcast = False
+    if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed() is None:
+        logger.warning("You are suggested to use mindspore.common.set_seed() to share"
+                       " parameters among devices.")

     return parameter_broadcast
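With this hunk, _get_parameter_broadcast() no longer switches broadcasting on implicitly for data_parallel and hybrid_parallel: it returns whatever was set on the auto-parallel context and only warns when broadcasting is off and no global seed has been set. A minimal sketch of the new behaviour, assuming the helper stays importable from mindspore.parallel._utils (the seed value 1 is purely illustrative):

from mindspore import context
from mindspore.common import set_seed
from mindspore.parallel._utils import _get_parameter_broadcast

# Broadcasting is no longer enabled automatically: in data_parallel mode,
# with no explicit parameter_broadcast=True and no global seed, the helper
# returns False and logs the "use mindspore.common.set_seed()" warning.
context.set_auto_parallel_context(parallel_mode="data_parallel")
print(_get_parameter_broadcast())   # False, warning is logged

# Setting a global seed keeps every device's initial weights identical,
# so the warning is no longer logged.
set_seed(1)
print(_get_parameter_broadcast())   # False, no warning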
@@ -268,7 +268,7 @@ def train(cloud_args=None):
     loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)

     context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
-                                      parameter_broadcast=True, gradients_mean=True)
+                                      gradients_mean=True)
     model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager, amp_level="O3")

     # checkpoint save
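This hunk and the script hunks that follow make the same mechanical change: parameter_broadcast=True is dropped from context.set_auto_parallel_context(), so identical initial weights across ranks come from seeding rather than from a broadcast at startup. A hedged sketch of the preamble these scripts converge on (device_num=8 is a placeholder for the script's own group-size variable):

from mindspore import context
from mindspore.common import set_seed
from mindspore.communication.management import init
from mindspore.context import ParallelMode

set_seed(1)  # same seed on every rank, so parameters start out identical without broadcasting
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  device_num=8,  # placeholder for args.group_size / rank_size
                                  gradients_mean=True)
init()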
@@ -54,7 +54,7 @@ if __name__ == '__main__':
         rank = args_opt.rank_id
         device_num = args_opt.device_num
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True, parameter_broadcast=True)
+                                          gradients_mean=True)
         init()
     else:
         rank = 0
@@ -58,7 +58,7 @@ if __name__ == '__main__':
         cfg.group_size = get_group_size()
         parallel_mode = ParallelMode.DATA_PARALLEL
         context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
-                                          parameter_broadcast=True, gradients_mean=True)
+                                          gradients_mean=True)
     else:
         cfg.rank = 0
         cfg.group_size = 1
@@ -59,7 +59,7 @@ if __name__ == '__main__':
         rank = args_opt.rank_id
         device_num = args_opt.device_num
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          gradients_mean=True, parameter_broadcast=True)
+                                          gradients_mean=True)
         init()
     else:
         rank = 0
@@ -49,7 +49,7 @@ def context_device_init(config):
     if config.run_distribute:
         context.set_auto_parallel_context(device_num=config.rank_size,
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          parameter_broadcast=True, gradients_mean=True,
+                                          gradients_mean=True,
                                           all_reduce_fusion_config=[140])
         init()
     else:
@@ -76,7 +76,6 @@ def train_on_ascend():
     if run_distribute:
         context.set_auto_parallel_context(device_num=rank_size,
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          parameter_broadcast=True,
                                           gradients_mean=True)
         init()

@@ -74,7 +74,6 @@ if __name__ == '__main__':
     if run_distribute:
         context.set_auto_parallel_context(device_num=rank_size,
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          parameter_broadcast=True,
                                           gradients_mean=True)
         init()
     context.set_auto_parallel_context(device_num=args_opt.device_num,
@@ -178,7 +178,7 @@ def test(cloud_args=None):
     if args.is_distributed:
         parallel_mode = ParallelMode.DATA_PARALLEL
         context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
-                                          parameter_broadcast=True, gradients_mean=True)
+                                          gradients_mean=True)

     args.logger.save_args(args)

@@ -200,7 +200,7 @@ def train(cloud_args=None):
     if args.is_distributed:
         parallel_mode = ParallelMode.DATA_PARALLEL
         context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
-                                          parameter_broadcast=True, gradients_mean=True)
+                                          gradients_mean=True)
     # dataloader
     de_dataset = classification_dataset(args.data_dir, args.image_size,
                                         args.per_batch_size, 1,
@@ -51,7 +51,6 @@ def train_net(data_dir,
     parallel_mode = ParallelMode.DATA_PARALLEL
     context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                       device_num=group_size,
-                                      parameter_broadcast=True,
                                       gradients_mean=False)
     net = UNet(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])

@@ -140,7 +140,7 @@ if __name__ == '__main__':
         device_num = args.group_size
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          parameter_broadcast=True, gradients_mean=True)
+                                          gradients_mean=True)
     else:
         context.set_context(device_id=args.device_id)
     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
@@ -254,7 +254,6 @@ def _setup_parallel_env(platform):
     context.set_auto_parallel_context(
         parallel_mode=ParallelMode.DATA_PARALLEL,
         device_num=MultiAscend.get_group_size(),
-        parameter_broadcast=True,
         gradients_mean=True
     )

@@ -123,7 +123,7 @@ def run_transformer_train():
         device_num = args.device_num
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
-                                          parameter_broadcast=True, device_num=device_num)
+                                          device_num=device_num)
         D.init()
         rank_id = args.device_id % device_num
         save_ckpt_path = os.path.join(args.save_checkpoint_path, 'ckpt_' + str(get_rank()) + '/')
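Note that only the scripts' hard-coded default is removed here; the option itself remains, and _get_parameter_broadcast() still reads it from the auto-parallel context. A script that prefers broadcasting over seeding can therefore opt back in explicitly, roughly like this (device_num=8 is illustrative):

from mindspore import context
from mindspore.context import ParallelMode

# Explicitly requesting broadcast again; with the flag set, the new
# set_seed() warning in _get_parameter_broadcast() is not triggered.
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  device_num=8,
                                  gradients_mean=True,
                                  parameter_broadcast=True)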