forked from mindspore-Ecosystem/mindspore
parallel envs variable check
This commit is contained in:
parent
4c1cd579a7
commit
e967f1939b
|
@ -84,6 +84,7 @@ class GlobalComm:
|
|||
BACKEND = DEFAULT_BACKEND
|
||||
WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
||||
INITED = False
|
||||
CHECK_ENVS = True
|
||||
|
||||
def is_hccl_available():
|
||||
"""
|
||||
|
|
|
@ -24,7 +24,7 @@ HCCL_LIB_CTYPES = ""
|
|||
|
||||
def check_group(group):
|
||||
"""
|
||||
A function that check if a collection communication group is leagal.
|
||||
A function that check if a collection communication group is legal.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
@ -39,7 +39,7 @@ def check_group(group):
|
|||
|
||||
def check_rank_num(rank_num):
|
||||
"""
|
||||
A function that check if a collection communication rank number is leagal.If not raise error.
|
||||
A function that check if a collection communication rank number is legal.If not raise error.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
@ -53,7 +53,7 @@ def check_rank_num(rank_num):
|
|||
|
||||
def check_rank_id(rank_id):
|
||||
"""
|
||||
A function that check if a collection communication rank id is leagal.If not raise error.
|
||||
A function that check if a collection communication rank id is legal.If not raise error.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
@ -112,7 +112,7 @@ def create_group(group, rank_num, rank_ids):
|
|||
c_group = c_str(group)
|
||||
ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
|
||||
if ret != 0:
|
||||
raise RuntimeError('Create group error.')
|
||||
raise RuntimeError('Create group error, the error code is ', ret)
|
||||
else:
|
||||
raise TypeError('Rank ids must be a python list.')
|
||||
|
||||
|
|
|
@ -36,6 +36,28 @@ def _get_group(group):
|
|||
return group
|
||||
|
||||
|
||||
def _check_parallel_envs():
|
||||
"""
|
||||
Check whether parallel environment variables have been exported or not.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values.
|
||||
"""
|
||||
if not GlobalComm.CHECK_ENVS:
|
||||
return
|
||||
import os
|
||||
rank_id_str = os.getenv("RANK_ID")
|
||||
if not rank_id_str:
|
||||
raise RuntimeError("Environment variables RANK_ID has not been exported")
|
||||
try:
|
||||
int(rank_id_str)
|
||||
except ValueError:
|
||||
print("RANK_ID should be number")
|
||||
rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
|
||||
rank_table_file_str_old = os.getenv("RANK_TABLE_FILE")
|
||||
if not rank_table_file_str and not rank_table_file_str_old:
|
||||
raise RuntimeError("Get hccl rank_table_file failed, "
|
||||
"please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.")
|
||||
|
||||
def init(backend_name=None):
|
||||
"""
|
||||
|
@ -68,6 +90,7 @@ def init(backend_name=None):
|
|||
if backend_name == "hccl":
|
||||
if device_target != "Ascend":
|
||||
raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target))
|
||||
_check_parallel_envs()
|
||||
init_hccl()
|
||||
GlobalComm.BACKEND = Backend("hccl")
|
||||
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
||||
|
|
|
@ -30,13 +30,16 @@ from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
|
|||
from mindspore.ops.operations.array_ops import Gather
|
||||
import mindspore
|
||||
|
||||
|
||||
# pylint: disable=W0212
|
||||
# W0212: protected-access
|
||||
|
||||
tag = 0
|
||||
|
||||
context.set_context(device_target="Ascend")
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init("hccl")
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
|
||||
|
||||
class AllReduceNet(nn.Cell):
|
||||
|
|
|
@ -31,11 +31,13 @@ from mindspore.parallel import set_algo_parameters
|
|||
from mindspore.parallel._utils import _reset_op_id as resset_op_id
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication._comm_helper import GlobalComm
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(device_id=0)
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init()
|
||||
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
|
||||
def weight_variable():
|
||||
return TruncatedNormal(0.02)
|
||||
|
|
|
@ -18,11 +18,15 @@ from mindspore.communication.management import init
|
|||
from mindspore.parallel import set_algo_parameters
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication._comm_helper import GlobalComm
|
||||
from .test_auto_parallel_resnet import resnet50
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(device_id=0)
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init()
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
|
||||
def test_train_32k_8p(batch_size=32, num_classes=32768):
|
||||
dev_num = 8
|
||||
|
|
|
@ -19,7 +19,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor, Parameter
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
from mindspore.communication._comm_helper import GlobalComm
|
||||
|
||||
class DataParallelNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -49,7 +49,9 @@ def test_param_broadcast():
|
|||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel", parameter_broadcast=True)
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init()
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
network = DataParallelNet()
|
||||
network.set_train()
|
||||
|
||||
|
@ -62,7 +64,9 @@ def test_param_not_broadcast():
|
|||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel", parameter_broadcast=False)
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init()
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
network = ModelParallelNet()
|
||||
network.set_train()
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ from mindspore.ops import functional as F
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.train import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication._comm_helper import GlobalComm
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
device_number = 32
|
||||
|
@ -124,7 +125,9 @@ class TrainOneStepCell(Cell):
|
|||
|
||||
|
||||
def net_trains(criterion, rank):
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init()
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
lr = 0.1
|
||||
momentum = 0.9
|
||||
max_epoch = 20
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.nn import Momentum
|
|||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.context import ParallelMode
|
||||
|
||||
from mindspore.communication._comm_helper import GlobalComm
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, input_channel, out_channel):
|
||||
|
@ -47,7 +47,9 @@ def test_dense_gen_graph():
|
|||
context.set_context(mode=context.GRAPH_MODE)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.HYBRID_PARALLEL, gradients_mean=True, device_num=8)
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init()
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
network = Net(512, 128)
|
||||
|
||||
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
||||
|
|
|
@ -21,6 +21,7 @@ from mindspore import Tensor
|
|||
from mindspore import amp
|
||||
from mindspore import nn
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.communication._comm_helper import GlobalComm
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train import Model
|
||||
from ....dataset_mock import MindData
|
||||
|
@ -156,8 +157,8 @@ def test_compile_model_train_O2_parallel():
|
|||
net = NetNoLoss(16, 16)
|
||||
loss = nn.MSELoss()
|
||||
optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
|
||||
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init()
|
||||
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
|
||||
model.train(2, dataset, dataset_sink_mode=False)
|
||||
|
|
|
@ -19,9 +19,9 @@ import numpy as np
|
|||
import mindspore.context as context
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.dataset_helper import DatasetHelper
|
||||
from mindspore.communication._comm_helper import GlobalComm
|
||||
from ....dataset_mock import MindData
|
||||
|
||||
|
||||
def get_dataset(batch_size=1):
|
||||
dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
|
||||
dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
|
||||
|
@ -75,7 +75,9 @@ def test_dataset_iter_normal():
|
|||
|
||||
@pytest.mark.skipif('not context.get_context("enable_ge")')
|
||||
def test_dataset_iter_ge():
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init("hccl")
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
dataset = get_dataset(32)
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
|
||||
count = 0
|
||||
|
@ -87,7 +89,9 @@ def test_dataset_iter_ge():
|
|||
|
||||
@pytest.mark.skipif('context.get_context("enable_ge")')
|
||||
def test_dataset_iter_ms_loop_sink():
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init("hccl")
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
context.set_context(enable_loop_sink=True)
|
||||
dataset = get_dataset(32)
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
|
||||
|
@ -101,7 +105,9 @@ def test_dataset_iter_ms_loop_sink():
|
|||
|
||||
@pytest.mark.skipif('context.get_context("enable_ge")')
|
||||
def test_dataset_iter_ms():
|
||||
GlobalComm.CHECK_ENVS = False
|
||||
init("hccl")
|
||||
GlobalComm.CHECK_ENVS = True
|
||||
context.set_context(enable_loop_sink=False)
|
||||
dataset = get_dataset(32)
|
||||
DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
|
||||
|
|
Loading…
Reference in New Issue