parallel envs variable check

yao_yf 2021-04-30 17:57:52 +08:00
parent 4c1cd579a7
commit e967f1939b
11 changed files with 59 additions and 10 deletions

View File

@@ -84,6 +84,7 @@ class GlobalComm:
BACKEND = DEFAULT_BACKEND
WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
INITED = False
CHECK_ENVS = True
def is_hccl_available():
"""

View File

@@ -24,7 +24,7 @@ HCCL_LIB_CTYPES = ""
def check_group(group):
"""
- A function that check if a collection communication group is leagal.
+ A function that check if a collection communication group is legal.
Returns:
None
@@ -39,7 +39,7 @@ def check_group(group):
def check_rank_num(rank_num):
"""
- A function that check if a collection communication rank number is leagal.If not raise error.
+ A function that check if a collection communication rank number is legal.If not raise error.
Returns:
None
@@ -53,7 +53,7 @@ def check_rank_num(rank_num):
def check_rank_id(rank_id):
"""
- A function that check if a collection communication rank id is leagal.If not raise error.
+ A function that check if a collection communication rank id is legal.If not raise error.
Returns:
None
@@ -112,7 +112,7 @@ def create_group(group, rank_num, rank_ids):
c_group = c_str(group)
ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
if ret != 0:
- raise RuntimeError('Create group error.')
+ raise RuntimeError('Create group error, the error code is ', ret)
else:
raise TypeError('Rank ids must be a python list.')
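
With the changed raise, a failed HcomCreateGroup call now surfaces its HCCL return code. A hedged sketch of how a caller might observe it, assuming the module path of the file shown here; the group name and rank ids are placeholder values:

from mindspore.communication import _hccl_management as hccl

try:
    # Placeholder 2-rank group over devices 0 and 1.
    hccl.create_group("group_0", 2, [0, 1])
except RuntimeError as err:
    # err.args now carries the return code, e.g.
    # ('Create group error, the error code is ', 1)
    print(err.args)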

View File

@@ -36,6 +36,28 @@ def _get_group(group):
return group
def _check_parallel_envs():
    """
    Check whether parallel environment variables have been exported or not.

    Raises:
        RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values.
    """
    if not GlobalComm.CHECK_ENVS:
        return
    import os
    rank_id_str = os.getenv("RANK_ID")
    if not rank_id_str:
        raise RuntimeError("Environment variable RANK_ID has not been exported.")
    try:
        int(rank_id_str)
    except ValueError:
        print("RANK_ID should be a number.")
    rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
    rank_table_file_str_old = os.getenv("RANK_TABLE_FILE")
    if not rank_table_file_str and not rank_table_file_str_old:
        raise RuntimeError("Get hccl rank_table_file failed, "
                           "please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.")
def init(backend_name=None):
"""
@@ -68,6 +90,7 @@ def init(backend_name=None):
    if backend_name == "hccl":
        if device_target != "Ascend":
            raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target))
        _check_parallel_envs()
        init_hccl()
        GlobalComm.BACKEND = Backend("hccl")
        GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
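
Taken together, init("hccl") now refuses to run unless the launch environment is exported. A minimal sketch of a conforming environment, with placeholder values; the rank-table path is hypothetical and real values come from the launcher:

import os

os.environ["RANK_ID"] = "0"  # must parse as an int
os.environ["RANK_TABLE_FILE"] = "/path/to/rank_table.json"
# MINDSPORE_HCCL_CONFIG_PATH is accepted as an alternative.

from mindspore import context
from mindspore.communication.management import init

context.set_context(device_target="Ascend")
init("hccl")  # _check_parallel_envs() runs before init_hccl()

A missing variable raises RuntimeError, while a non-numeric RANK_ID only prints a warning in this version.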

View File

@@ -30,13 +30,16 @@ from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
from mindspore.ops.operations.array_ops import Gather
import mindspore
# pylint: disable=W0212
# W0212: protected-access
tag = 0
context.set_context(device_target="Ascend")
GlobalComm.CHECK_ENVS = False
init("hccl")
GlobalComm.CHECK_ENVS = True
class AllReduceNet(nn.Cell):
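
The same disable/re-enable toggle recurs around every init() call in the test diffs below. A hedged sketch of a context manager that would factor the pattern out; this helper is hypothetical and not part of the commit:

from contextlib import contextmanager

from mindspore.communication._comm_helper import GlobalComm

@contextmanager
def envs_check_disabled():
    """Hypothetical test helper: suspend the parallel-envs check around init()."""
    GlobalComm.CHECK_ENVS = False
    try:
        yield
    finally:
        GlobalComm.CHECK_ENVS = True

Written with try/finally, the flag is restored even if init() raises inside the with block.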

View File

@@ -31,11 +31,13 @@ from mindspore.parallel import set_algo_parameters
from mindspore.parallel._utils import _reset_op_id as resset_op_id
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True
def weight_variable():
return TruncatedNormal(0.02)

View File

@@ -18,11 +18,15 @@ from mindspore.communication.management import init
from mindspore.parallel import set_algo_parameters
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm
from .test_auto_parallel_resnet import resnet50
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True
def test_train_32k_8p(batch_size=32, num_classes=32768):
dev_num = 8

View File

@@ -19,7 +19,7 @@ import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.communication.management import init
from mindspore.ops import operations as P
from mindspore.communication._comm_helper import GlobalComm
class DataParallelNet(nn.Cell):
def __init__(self):
@@ -49,7 +49,9 @@ def test_param_broadcast():
context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode="data_parallel", parameter_broadcast=True)
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True
network = DataParallelNet()
network.set_train()
@@ -62,7 +64,9 @@ def test_param_not_broadcast():
context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode="data_parallel", parameter_broadcast=False)
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True
network = ModelParallelNet()
network.set_train()

View File

@@ -29,6 +29,7 @@ from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm
context.set_context(mode=context.GRAPH_MODE)
device_number = 32
@@ -124,7 +125,9 @@ class TrainOneStepCell(Cell):
def net_trains(criterion, rank):
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True
lr = 0.1
momentum = 0.9
max_epoch = 20

View File

@@ -24,7 +24,7 @@ from mindspore.nn import Momentum
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import operations as P
from mindspore.context import ParallelMode
from mindspore.communication._comm_helper import GlobalComm
class Net(nn.Cell):
def __init__(self, input_channel, out_channel):
@@ -47,7 +47,9 @@ def test_dense_gen_graph():
context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.HYBRID_PARALLEL, gradients_mean=True, device_num=8)
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True
network = Net(512, 128)
loss_fn = nn.SoftmaxCrossEntropyWithLogits()

View File

@@ -21,6 +21,7 @@ from mindspore import Tensor
from mindspore import amp
from mindspore import nn
from mindspore.communication.management import init
from mindspore.communication._comm_helper import GlobalComm
from mindspore.context import ParallelMode
from mindspore.train import Model
from ....dataset_mock import MindData
@@ -156,8 +157,8 @@ def test_compile_model_train_O2_parallel():
net = NetNoLoss(16, 16)
loss = nn.MSELoss()
optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
GlobalComm.CHECK_ENVS = False
init()
GlobalComm.CHECK_ENVS = True
model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
model.train(2, dataset, dataset_sink_mode=False)

View File

@@ -19,9 +19,9 @@ import numpy as np
import mindspore.context as context
from mindspore.communication.management import init
from mindspore.train.dataset_helper import DatasetHelper
from mindspore.communication._comm_helper import GlobalComm
from ....dataset_mock import MindData
def get_dataset(batch_size=1):
dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
@@ -75,7 +75,9 @@ def test_dataset_iter_normal():
@pytest.mark.skipif('not context.get_context("enable_ge")')
def test_dataset_iter_ge():
GlobalComm.CHECK_ENVS = False
init("hccl")
GlobalComm.CHECK_ENVS = True
dataset = get_dataset(32)
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
count = 0
@@ -87,7 +89,9 @@ def test_dataset_iter_ge():
@pytest.mark.skipif('context.get_context("enable_ge")')
def test_dataset_iter_ms_loop_sink():
GlobalComm.CHECK_ENVS = False
init("hccl")
GlobalComm.CHECK_ENVS = True
context.set_context(enable_loop_sink=True)
dataset = get_dataset(32)
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
@@ -101,7 +105,9 @@ def test_dataset_iter_ms_loop_sink():
@pytest.mark.skipif('context.get_context("enable_ge")')
def test_dataset_iter_ms():
GlobalComm.CHECK_ENVS = False
init("hccl")
GlobalComm.CHECK_ENVS = True
context.set_context(enable_loop_sink=False)
dataset = get_dataset(32)
DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)