rectification init
This commit is contained in:
parent
49aa4b7686
commit
d3e55b543e
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
"""Communication management API"""
|
||||
import os
|
||||
from mindspore import context
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
|
||||
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
|
||||
|
@ -45,7 +46,7 @@ class GlobalComm:
|
|||
WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP
|
||||
|
||||
|
||||
def init(backend_name="hccl"):
|
||||
def init(backend_name=None):
|
||||
"""
|
||||
Init distributed backend, e.g., hccl/nccl, it is required before communication service can be used.
|
||||
|
||||
|
@ -57,11 +58,20 @@ def init(backend_name="hccl"):
|
|||
backend_name (str): Backend.
|
||||
|
||||
Raises:
|
||||
TypeError: If backend name is not a string.
|
||||
TypeError: If backen_name is not a string.
|
||||
RuntimeError: If device target is invalid.
|
||||
RuntimeError: If backend is invalid or distributed init fails.
|
||||
"""
|
||||
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
|
||||
return
|
||||
if backend_name is None:
|
||||
device_target = context.get_context("device_target")
|
||||
if device_target == "Ascend":
|
||||
backend_name = "hccl"
|
||||
elif device_target == "GPU":
|
||||
backend_name = "nccl"
|
||||
else:
|
||||
raise RuntimeError("Device target {} is not supported.".format(device_target))
|
||||
if not isinstance(backend_name, str):
|
||||
raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))
|
||||
|
||||
|
|
|
@ -73,7 +73,7 @@ class AllReduce(PrimitiveWithInfer):
|
|||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.ops.operations as P
|
||||
>>>
|
||||
>>> init('nccl')
|
||||
>>> init()
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
|
@ -136,7 +136,7 @@ class AllGather(PrimitiveWithInfer):
|
|||
>>> from mindspore.communication import init
|
||||
>>> from mindspore import Tensor
|
||||
>>>
|
||||
>>> init('nccl')
|
||||
>>> init()
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
|
@ -246,7 +246,7 @@ class ReduceScatter(PrimitiveWithInfer):
|
|||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.ops.operations as P
|
||||
>>>
|
||||
>>> init('nccl')
|
||||
>>> init()
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
|
@ -360,7 +360,7 @@ class Broadcast(PrimitiveWithInfer):
|
|||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.ops.operations as P
|
||||
>>>
|
||||
>>> init('nccl')
|
||||
>>> init()
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
|
|
|
@ -81,7 +81,7 @@ if __name__ == '__main__':
|
|||
mirror_mean=True)
|
||||
init()
|
||||
elif device_target == "GPU":
|
||||
init("nccl")
|
||||
init()
|
||||
|
||||
if device_num > 1:
|
||||
context.reset_auto_parallel_context()
|
||||
|
|
|
@ -57,10 +57,7 @@ if __name__ == '__main__':
|
|||
cfg = config_ascend if args_opt.platform == 'Ascend' else config_gpu
|
||||
# init distributed
|
||||
if args_opt.is_distributed:
|
||||
if args_opt.platform == "Ascend":
|
||||
init()
|
||||
else:
|
||||
init("nccl")
|
||||
init()
|
||||
cfg.rank = get_rank()
|
||||
cfg.group_size = get_group_size()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
|
|
|
@ -64,7 +64,7 @@ elif args_opt.device_target == "GPU":
|
|||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="GPU",
|
||||
save_graphs=False)
|
||||
init("nccl")
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=get_group_size(),
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
|
|
|
@ -57,7 +57,7 @@ if args_opt.device_target == "Ascend":
|
|||
device_target="Ascend",
|
||||
device_id=device_id, save_graphs=False)
|
||||
elif args_opt.device_target == "GPU":
|
||||
init("nccl")
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=get_group_size(),
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
|
|
|
@ -54,7 +54,7 @@ if args_opt.device_target == "GPU":
|
|||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="GPU",
|
||||
save_graphs=False)
|
||||
init("nccl")
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=get_group_size(),
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
|
|
|
@ -38,7 +38,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|||
if target == "Ascend":
|
||||
device_num, rank_id = _get_rank_info()
|
||||
else:
|
||||
init("nccl")
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
device_num = get_group_size()
|
||||
|
||||
|
@ -93,7 +93,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|||
if target == "Ascend":
|
||||
device_num, rank_id = _get_rank_info()
|
||||
else:
|
||||
init("nccl")
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
device_num = get_group_size()
|
||||
|
||||
|
|
|
@ -85,7 +85,7 @@ if __name__ == '__main__':
|
|||
init()
|
||||
# GPU target
|
||||
else:
|
||||
init("nccl")
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
if args_opt.net == "resnet50":
|
||||
|
|
|
@ -46,7 +46,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
|
|||
device_num = int(os.getenv("RANK_SIZE"))
|
||||
rank_id = int(os.getenv("RANK_ID"))
|
||||
else:
|
||||
init("nccl")
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
device_num = get_group_size()
|
||||
|
||||
|
@ -114,7 +114,7 @@ def create_dataset_py(dataset_path, do_train, repeat_num=1, batch_size=32, targe
|
|||
device_num = int(os.getenv("RANK_SIZE"))
|
||||
rank_id = int(os.getenv("RANK_ID"))
|
||||
else:
|
||||
init("nccl")
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
device_num = get_group_size()
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
|
|||
if target == "Ascend":
|
||||
device_num, rank_id = _get_rank_info()
|
||||
else:
|
||||
init("nccl")
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
device_num = get_group_size()
|
||||
|
||||
|
|
|
@ -106,7 +106,7 @@ if __name__ == '__main__':
|
|||
init()
|
||||
# GPU target
|
||||
else:
|
||||
init("nccl")
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
|
||||
|
|
|
@ -112,10 +112,7 @@ def test(cloud_args=None):
|
|||
|
||||
# init distributed
|
||||
if args.is_distributed:
|
||||
if args.platform == "Ascend":
|
||||
init()
|
||||
elif args.platform == "GPU":
|
||||
init("nccl")
|
||||
init()
|
||||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
|
|
|
@ -172,10 +172,7 @@ def train(cloud_args=None):
|
|||
|
||||
# init distributed
|
||||
if args.is_distributed:
|
||||
if args.platform == "Ascend":
|
||||
init()
|
||||
else:
|
||||
init("nccl")
|
||||
init()
|
||||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
|
|
|
@ -135,7 +135,7 @@ if __name__ == '__main__':
|
|||
init()
|
||||
context.set_context(device_id=args.device_id)
|
||||
elif args.device_target == "GPU":
|
||||
init("nccl")
|
||||
init()
|
||||
|
||||
args.rank = get_rank()
|
||||
args.group_size = get_group_size()
|
||||
|
|
|
@ -60,7 +60,7 @@ if __name__ == '__main__':
|
|||
device_num = int(os.environ.get("RANK_SIZE"))
|
||||
rank = int(os.environ.get("RANK_ID"))
|
||||
else:
|
||||
init('nccl')
|
||||
init()
|
||||
lr_scale = 0.5
|
||||
device_num = get_group_size()
|
||||
rank = get_rank()
|
||||
|
|
|
@ -70,11 +70,11 @@ def run_pretrain():
|
|||
ckpt_save_dir = args_opt.save_checkpoint_path
|
||||
if args_opt.distribute == "true":
|
||||
if args_opt.device_target == 'Ascend':
|
||||
D.init('hccl')
|
||||
D.init()
|
||||
device_num = args_opt.device_num
|
||||
rank = args_opt.device_id % device_num
|
||||
else:
|
||||
D.init('nccl')
|
||||
D.init()
|
||||
device_num = D.get_group_size()
|
||||
rank = D.get_rank()
|
||||
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
|
||||
|
|
|
@ -73,11 +73,11 @@ def run_pretrain():
|
|||
ckpt_save_dir = args_opt.save_checkpoint_path
|
||||
if args_opt.distribute == "true":
|
||||
if args_opt.device_target == 'Ascend':
|
||||
D.init('hccl')
|
||||
D.init()
|
||||
device_num = args_opt.device_num
|
||||
rank = args_opt.device_id % device_num
|
||||
else:
|
||||
D.init('nccl')
|
||||
D.init()
|
||||
device_num = D.get_group_size()
|
||||
rank = D.get_rank()
|
||||
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
|
||||
|
|
|
@ -227,10 +227,7 @@ def _build_training_pipeline(config: TransformerConfig,
|
|||
|
||||
def _setup_parallel_env(platform):
|
||||
context.reset_auto_parallel_context()
|
||||
if platform == "GPU":
|
||||
MultiAscend.init("nccl")
|
||||
else:
|
||||
MultiAscend.init()
|
||||
MultiAscend.init()
|
||||
context.set_auto_parallel_context(
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
device_num=MultiAscend.get_group_size(),
|
||||
|
|
|
@ -67,11 +67,11 @@ def run_general_distill():
|
|||
|
||||
if args_opt.distribute == "true":
|
||||
if args_opt.device_target == 'Ascend':
|
||||
D.init('hccl')
|
||||
D.init()
|
||||
device_num = args_opt.device_num
|
||||
rank = args_opt.device_id % device_num
|
||||
else:
|
||||
D.init('nccl')
|
||||
D.init()
|
||||
device_num = D.get_group_size()
|
||||
rank = D.get_rank()
|
||||
save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
|
||||
|
|
|
@ -59,7 +59,7 @@ if __name__ == '__main__':
|
|||
init()
|
||||
rank_id = int(os.environ.get('RANK_ID'))
|
||||
elif args_opt.device_target == "GPU":
|
||||
init("nccl")
|
||||
init()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=get_group_size(),
|
||||
|
|
|
@ -128,10 +128,7 @@ if __name__ == "__main__":
|
|||
context.set_context(variable_memory_max_size="24GB")
|
||||
context.set_context(enable_sparse=True)
|
||||
set_multi_subgraphs()
|
||||
if wide_deep_config.device_target == "Ascend":
|
||||
init("hccl")
|
||||
elif wide_deep_config.device_target == "GPU":
|
||||
init("nccl")
|
||||
init()
|
||||
if wide_deep_config.host_device_mix == 1:
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
|
||||
else:
|
||||
|
|
|
@ -122,10 +122,7 @@ if __name__ == "__main__":
|
|||
wide_deep_config.argparse_init()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
|
||||
if wide_deep_config.device_target == "Ascend":
|
||||
init("hccl")
|
||||
elif wide_deep_config.device_target == "GPU":
|
||||
init("nccl")
|
||||
init()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
|
||||
device_num=get_group_size())
|
||||
|
||||
|
|
|
@ -119,10 +119,7 @@ if __name__ == "__main__":
|
|||
wide_deep_config.argparse_init()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target)
|
||||
if wide_deep_config.device_target == "Ascend":
|
||||
init("hccl")
|
||||
elif wide_deep_config.device_target == "GPU":
|
||||
init("nccl")
|
||||
init()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
|
||||
device_num=get_group_size())
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.ops import operations as P
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
init('nccl')
|
||||
init()
|
||||
rank = get_rank()
|
||||
size = get_group_size()
|
||||
x = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1)
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.ops import operations as P
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
init('nccl')
|
||||
init()
|
||||
rank = get_rank()
|
||||
size = get_group_size()
|
||||
x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1)
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.ops import operations as P
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
init('nccl')
|
||||
init()
|
||||
rank = get_rank()
|
||||
size = get_group_size()
|
||||
x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1)
|
||||
|
|
|
@ -25,7 +25,7 @@ from mindspore.nn.optim import Momentum
|
|||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
init('nccl')
|
||||
init()
|
||||
|
||||
epoch = 5
|
||||
total = 5000
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.ops import operations as P
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
init('nccl')
|
||||
init()
|
||||
rank = get_rank()
|
||||
size = get_group_size()
|
||||
x = np.ones([size, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1)
|
||||
|
|
|
@ -30,7 +30,7 @@ args, _ = parser.parse_known_args()
|
|||
device_target = args.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
if device_target == "GPU":
|
||||
init('nccl')
|
||||
init()
|
||||
|
||||
|
||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
|
|
|
@ -75,7 +75,7 @@ def test_dataset_iter_normal():
|
|||
|
||||
@pytest.mark.skipif('not context.get_context("enable_ge")')
|
||||
def test_dataset_iter_ge():
|
||||
init()
|
||||
init("hccl")
|
||||
dataset = get_dataset(32)
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
|
||||
count = 0
|
||||
|
@ -87,7 +87,7 @@ def test_dataset_iter_ge():
|
|||
|
||||
@pytest.mark.skipif('context.get_context("enable_ge")')
|
||||
def test_dataset_iter_ms_loop_sink():
|
||||
init()
|
||||
init("hccl")
|
||||
context.set_context(enable_loop_sink=True)
|
||||
dataset = get_dataset(32)
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
|
||||
|
@ -101,7 +101,7 @@ def test_dataset_iter_ms_loop_sink():
|
|||
|
||||
@pytest.mark.skipif('context.get_context("enable_ge")')
|
||||
def test_dataset_iter_ms():
|
||||
init()
|
||||
init("hccl")
|
||||
context.set_context(enable_loop_sink=False)
|
||||
dataset = get_dataset(32)
|
||||
DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
|
||||
|
|
Loading…
Reference in New Issue