forked from mindspore-Ecosystem/mindspore

parallel envs variable check

parent 4c1cd579a7
commit e967f1939b
@@ -84,6 +84,7 @@ class GlobalComm:
     BACKEND = DEFAULT_BACKEND
     WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
     INITED = False
+    CHECK_ENVS = True

 def is_hccl_available():
     """

@@ -24,7 +24,7 @@ HCCL_LIB_CTYPES = ""

 def check_group(group):
     """
-    A function that check if a collection communication group is leagal.
+    A function that check if a collection communication group is legal.

     Returns:
         None
@@ -39,7 +39,7 @@ def check_group(group):

 def check_rank_num(rank_num):
     """
-    A function that check if a collection communication rank number is leagal.If not raise error.
+    A function that check if a collection communication rank number is legal.If not raise error.

     Returns:
         None
@@ -53,7 +53,7 @@ def check_rank_num(rank_num):

 def check_rank_id(rank_id):
     """
-    A function that check if a collection communication rank id is leagal.If not raise error.
+    A function that check if a collection communication rank id is legal.If not raise error.

     Returns:
         None
@@ -112,7 +112,7 @@ def create_group(group, rank_num, rank_ids):
         c_group = c_str(group)
         ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
         if ret != 0:
-            raise RuntimeError('Create group error.')
+            raise RuntimeError('Create group error, the error code is ', ret)
     else:
         raise TypeError('Rank ids must be a python list.')

@@ -36,6 +36,28 @@ def _get_group(group):
     return group


+def _check_parallel_envs():
+    """
+    Check whether parallel environment variables have been exported or not.
+
+    Raises:
+        RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values.
+    """
+    if not GlobalComm.CHECK_ENVS:
+        return
+    import os
+    rank_id_str = os.getenv("RANK_ID")
+    if not rank_id_str:
+        raise RuntimeError("Environment variables RANK_ID has not been exported")
+    try:
+        int(rank_id_str)
+    except ValueError:
+        print("RANK_ID should be number")
+    rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
+    rank_table_file_str_old = os.getenv("RANK_TABLE_FILE")
+    if not rank_table_file_str and not rank_table_file_str_old:
+        raise RuntimeError("Get hccl rank_table_file failed, "
+                           "please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.")

 def init(backend_name=None):
     """
@@ -68,6 +90,7 @@ def init(backend_name=None):
     if backend_name == "hccl":
         if device_target != "Ascend":
             raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target))
+        _check_parallel_envs()
         init_hccl()
         GlobalComm.BACKEND = Backend("hccl")
         GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
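
Note: the new check only verifies that the variables exist and are well formed. A minimal sketch of a launch script that would satisfy it before calling init(); the rank value and rank-table path below are placeholders, not taken from this commit:

    import os
    from mindspore import context
    from mindspore.communication.management import init

    # Placeholder values: each device process exports its own RANK_ID, and the
    # rank table path points to a cluster-specific JSON file.
    os.environ["RANK_ID"] = "0"
    os.environ["MINDSPORE_HCCL_CONFIG_PATH"] = "/path/to/rank_table_8p.json"

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    init("hccl")  # _check_parallel_envs() passes because both variables are set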
@@ -30,13 +30,16 @@ from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
 from mindspore.ops.operations.array_ops import Gather
 import mindspore


 # pylint: disable=W0212
 # W0212: protected-access

 tag = 0

 context.set_context(device_target="Ascend")
+GlobalComm.CHECK_ENVS = False
 init("hccl")
+GlobalComm.CHECK_ENVS = True


 class AllReduceNet(nn.Cell):
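
Each test file below wraps init() in the same disable/re-enable pattern. A hedged sketch of a helper that could factor this out (not part of this commit; shown only to illustrate the pattern, and the name envs_check_disabled is hypothetical):

    from contextlib import contextmanager
    from mindspore.communication._comm_helper import GlobalComm
    from mindspore.communication.management import init

    @contextmanager
    def envs_check_disabled():
        # Temporarily skip the RANK_ID / rank-table validation, as the updated tests do.
        GlobalComm.CHECK_ENVS = False
        try:
            yield
        finally:
            GlobalComm.CHECK_ENVS = True

    # Usage inside a test body:
    # with envs_check_disabled():
    #     init("hccl")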
@@ -31,11 +31,13 @@ from mindspore.parallel import set_algo_parameters
 from mindspore.parallel._utils import _reset_op_id as resset_op_id
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
+from mindspore.communication._comm_helper import GlobalComm

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=0)
+GlobalComm.CHECK_ENVS = False
 init()
+GlobalComm.CHECK_ENVS = True

 def weight_variable():
     return TruncatedNormal(0.02)
@@ -18,11 +18,15 @@ from mindspore.communication.management import init
 from mindspore.parallel import set_algo_parameters
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
+from mindspore.communication._comm_helper import GlobalComm
 from .test_auto_parallel_resnet import resnet50


 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=0)
+GlobalComm.CHECK_ENVS = False
 init()
+GlobalComm.CHECK_ENVS = True

 def test_train_32k_8p(batch_size=32, num_classes=32768):
     dev_num = 8
@@ -19,7 +19,7 @@ import mindspore.nn as nn
 from mindspore import Tensor, Parameter
 from mindspore.communication.management import init
 from mindspore.ops import operations as P
+from mindspore.communication._comm_helper import GlobalComm

 class DataParallelNet(nn.Cell):
     def __init__(self):
@@ -49,7 +49,9 @@ def test_param_broadcast():
     context.set_context(mode=context.GRAPH_MODE)
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(parallel_mode="data_parallel", parameter_broadcast=True)
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     network = DataParallelNet()
     network.set_train()

@@ -62,7 +64,9 @@ def test_param_not_broadcast():
     context.set_context(mode=context.GRAPH_MODE)
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(parallel_mode="data_parallel", parameter_broadcast=False)
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     network = ModelParallelNet()
     network.set_train()

@@ -29,6 +29,7 @@ from mindspore.ops import functional as F
 from mindspore.ops import operations as P
 from mindspore.train import Model
 from mindspore.context import ParallelMode
+from mindspore.communication._comm_helper import GlobalComm

 context.set_context(mode=context.GRAPH_MODE)
 device_number = 32
@@ -124,7 +125,9 @@ class TrainOneStepCell(Cell):


 def net_trains(criterion, rank):
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     lr = 0.1
     momentum = 0.9
     max_epoch = 20
@@ -24,7 +24,7 @@ from mindspore.nn import Momentum
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops import operations as P
 from mindspore.context import ParallelMode
+from mindspore.communication._comm_helper import GlobalComm

 class Net(nn.Cell):
     def __init__(self, input_channel, out_channel):
@@ -47,7 +47,9 @@ def test_dense_gen_graph():
     context.set_context(mode=context.GRAPH_MODE)
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(parallel_mode=ParallelMode.HYBRID_PARALLEL, gradients_mean=True, device_num=8)
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     network = Net(512, 128)

     loss_fn = nn.SoftmaxCrossEntropyWithLogits()
@@ -21,6 +21,7 @@ from mindspore import Tensor
 from mindspore import amp
 from mindspore import nn
 from mindspore.communication.management import init
+from mindspore.communication._comm_helper import GlobalComm
 from mindspore.context import ParallelMode
 from mindspore.train import Model
 from ....dataset_mock import MindData
@@ -156,8 +157,8 @@ def test_compile_model_train_O2_parallel():
     net = NetNoLoss(16, 16)
     loss = nn.MSELoss()
     optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
     model.train(2, dataset, dataset_sink_mode=False)
@@ -19,9 +19,9 @@ import numpy as np
 import mindspore.context as context
 from mindspore.communication.management import init
 from mindspore.train.dataset_helper import DatasetHelper
+from mindspore.communication._comm_helper import GlobalComm
 from ....dataset_mock import MindData


 def get_dataset(batch_size=1):
     dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
     dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
@@ -75,7 +75,9 @@ def test_dataset_iter_normal():

 @pytest.mark.skipif('not context.get_context("enable_ge")')
 def test_dataset_iter_ge():
+    GlobalComm.CHECK_ENVS = False
     init("hccl")
+    GlobalComm.CHECK_ENVS = True
     dataset = get_dataset(32)
     dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
     count = 0
@@ -87,7 +89,9 @@ def test_dataset_iter_ge():

 @pytest.mark.skipif('context.get_context("enable_ge")')
 def test_dataset_iter_ms_loop_sink():
+    GlobalComm.CHECK_ENVS = False
     init("hccl")
+    GlobalComm.CHECK_ENVS = True
     context.set_context(enable_loop_sink=True)
     dataset = get_dataset(32)
     dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
@@ -101,7 +105,9 @@ def test_dataset_iter_ms_loop_sink():

 @pytest.mark.skipif('context.get_context("enable_ge")')
 def test_dataset_iter_ms():
+    GlobalComm.CHECK_ENVS = False
     init("hccl")
+    GlobalComm.CHECK_ENVS = True
     context.set_context(enable_loop_sink=False)
     dataset = get_dataset(32)
     DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)