From: @huangxinjing
Reviewed-by: @stsuteng,@yangzhenzhang
Signed-off-by: @stsuteng
mindspore-ci-bot 2021-05-24 17:34:11 +08:00 committed by Gitee
commit 9f2c8642da
8 changed files with 98 additions and 88 deletions

View File

@ -45,6 +45,7 @@ else:
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
class Backend:
"""
Class for available backends.
@ -79,6 +80,7 @@ class Backend:
DEFAULT_BACKEND = Backend("hccl")
class GlobalComm:
"""World communication information."""
BACKEND = DEFAULT_BACKEND
@ -86,6 +88,7 @@ class GlobalComm:
INITED = False
CHECK_ENVS = True
def is_hccl_available():
"""
Check whether the HCCL API is available.

View File

@ -67,12 +67,14 @@ def check_rank_id(rank_id):
def load_lib():
"""load hccl lib"""
try:
base_dir = os.path.dirname(os.path.realpath(__file__))
lib_path = os.path.join(base_dir, "../lib", HCCL_LIB)
hccl_lib = ctypes.CDLL(lib_path)
except Exception:
raise RuntimeError('Get hccl lib error.')
global HCCL_LIB_CTYPES
HCCL_LIB_CTYPES = hccl_lib
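For orientation, a hedged sketch of how the handle loaded here is typically consumed; only load_lib and HCCL_LIB_CTYPES come from the file above, while the lazy-loading wrapper below is a hypothetical illustration that assumes HCCL_LIB_CTYPES starts out unset.

def _get_hccl_lib():
    """Load the HCCL shared library on first use and return the ctypes handle."""
    if not HCCL_LIB_CTYPES:   # unset until load_lib() has run successfully
        load_lib()
    return HCCL_LIB_CTYPES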

View File

@ -29,6 +29,7 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
def _get_group(group):
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
if group == DEFAULT_WORLD_COMM_GROUP:
@ -113,6 +114,7 @@ def init(backend_name=None):
else:
raise RuntimeError("Backend name {} is not supported.".format(backend_name))
def release():
"""
Release distributed resources, e.g. HCCL or NCCL.
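For context, a hedged usage sketch of init/release as exported from mindspore.communication.management (see the __all__ list above); it assumes an Ascend device where the HCCL backend is available.

from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size, release

context.set_context(device_target="Ascend")
init()                                    # backend_name=None picks HCCL on Ascend
print("rank", get_rank(), "of", get_group_size())
# ... build and train the network here ...
release()                                 # release HCCL/NCCL resources afterwards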

View File

@ -328,20 +328,9 @@ class _AutoParallelContext:
if sorted(indices) != indices:
raise ValueError('elements in indices must be sorted in ascending order')
if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
new_group = self._check_and_default_group(group)
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
self._context_handle.set_all_reduce_fusion_split_indices(indices, new_group)
if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
_set_fusion_strategy_by_idx(indices)
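A hedged sketch of how these split indices are typically configured; auto_parallel_context() is the usual module-level accessor, the index values are illustrative, and the setter's group parameter is assumed to default to the empty string like the by-size setter below.

from mindspore.parallel._auto_parallel_context import auto_parallel_context

# Fuse gradients [0, 20) into the first allreduce and [20, 35) into the second;
# an empty group now resolves to the default HCCL/NCCL fusion group.
auto_parallel_context().set_all_reduce_fusion_split_indices([20, 35])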
@ -359,19 +348,8 @@ class _AutoParallelContext:
TypeError: If group is not a python str.
"""
self.check_context_handle()
if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
return self._context_handle.get_all_reduce_fusion_split_indices(group)
new_group = self._check_and_default_group(group)
return self._context_handle.get_all_reduce_fusion_split_indices(new_group)
def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
"""
@ -393,20 +371,8 @@ class _AutoParallelContext:
else:
raise TypeError('sizes must be a python list')
if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
new_group = self._check_and_default_group(group)
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, new_group)
if context.get_context("device_target") == "Ascend":
_set_fusion_strategy_by_size(sizes)
@ -424,19 +390,8 @@ class _AutoParallelContext:
TypeError: If group is not a python str.
"""
self.check_context_handle()
if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
return self._context_handle.get_all_reduce_fusion_split_sizes(group)
new_group = self._check_and_default_group(group)
return self._context_handle.get_all_reduce_fusion_split_sizes(new_group)
def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
"""
@ -550,6 +505,23 @@ class _AutoParallelContext:
self._context_handle.reset()
def _check_and_default_group(self, group):
"""Validate the given group, if group is empty, returns a default fusion group"""
if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError(f'Group name length is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
return group
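A self-contained sketch of the validate-and-default pattern this new helper centralizes; the constant values below are placeholders for illustration, not the ones defined in this file.

_MAX_GROUP_NAME_LEN = 127                                        # placeholder value
_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_default_fusion_group"    # placeholder value
_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_default_fusion_group"    # placeholder value

def check_and_default_group(group, device_target="Ascend"):
    """Validate group; map the empty string to the backend's default fusion group."""
    if not isinstance(group, str):
        raise TypeError('Group must be a python str')
    if len(group) > _MAX_GROUP_NAME_LEN:
        raise ValueError(f'Group name length is out of range {_MAX_GROUP_NAME_LEN}')
    if group == "":
        return (_DEFAULT_HCCL_FUSION_GROUP_NAME if device_target == "Ascend"
                else _DEFAULT_NCCL_FUSION_GROUP_NAME)
    return group

assert check_and_default_group("my_group") == "my_group"
assert check_and_default_group("", "GPU") == _DEFAULT_NCCL_FUSION_GROUP_NAME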
_auto_parallel_context = None
@ -706,6 +678,7 @@ def _get_auto_parallel_context(attr_key):
get_func = _get_auto_parallel_context_func_map[attr_key]
return get_func()
def _reset_auto_parallel_context():
"""
Reset auto parallel context attributes to the default values:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data paralell allreduce fusion"""
"""Data parallel allreduce fusion"""
import ctypes
@ -41,27 +41,27 @@ def _c_array(ctype, values):
return (ctype * len(values))(*values)
def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
def _set_fusion_strategy_by_idx(idx_list, group="hccl_world_group"):
"""
A function that sets the gradient segment strategy according to the index list.
Note:
In the back propagation,
the fusion of the allreduce operators whose fusion attribute equals 1
will be performed according to the idxList,
will be performed according to the idx_list,
so that calculation and communication can run in parallel.
Args:
idxList (list): The index list of the gradient.
idx_list (list): The index list of the gradient.
group (str): The hccl communication group.
Raises:
TypeError: If group is not a python str.
TypeError: If IdxList is not a python list.
TypeError: If type of idxList item is not int.
TypeError: If idx_list is not a python list.
TypeError: If type of idx_list item is not int.
ValueError: If group name length is out of range.
ValueError: If idxList length is 0.
ValueError: If idxList item is less than 0.
ValueError: If idx_list length is 0.
ValueError: If idx_list item is less than 0.
RuntimeError: If allreduce split failed.
"""
try:
@ -70,6 +70,8 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
import hccl_test.manage.api as hccl
hccl.set_fusion_strategy_by_idx()
return
finally:
pass
if isinstance(group, (str)):
group_len = len(group)
if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0):
@ -77,49 +79,49 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
else:
raise TypeError('Group must be a python str')
if isinstance(idxList, (list)):
idx_len = len(idxList)
if isinstance(idx_list, (list)):
idx_len = len(idx_list)
if idx_len == 0:
raise ValueError('IdxList length is 0')
raise ValueError('idx_list length is 0')
else:
raise TypeError('IdxList must be a python list')
raise TypeError('idx_list must be a python list')
for idx in idxList:
for idx in idx_list:
if isinstance(idx, (int)):
if idx < 0:
raise ValueError('Idx < 0')
else:
raise TypeError('Idx in idxList is invalid')
raise TypeError('Idx in idx_list is invalid')
c_array_idxList = _c_array(ctypes.c_uint, idxList)
c_idx_num = ctypes.c_uint(len(idxList))
c_array_idx_list = _c_array(ctypes.c_uint, idx_list)
c_idx_num = ctypes.c_uint(len(idx_list))
c_group = _c_str(group)
ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idxList)
ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idx_list)
if ret != 0:
raise RuntimeError('Allreduce split error')
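A hedged usage sketch of the renamed function; the index values and group name are illustrative only.

# Split gradients into two allreduce fusion buckets by gradient index on the
# default HCCL world group.
_set_fusion_strategy_by_idx([20, 40], group="hccl_world_group")

# The by-size counterpart defined below takes data size percentages instead:
# _set_fusion_strategy_by_size([30.0, 70.0], group="hccl_world_group")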
def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
def _set_fusion_strategy_by_size(data_size_list, group="hccl_world_group"):
"""
A function that sets the gradient segment strategy according to the data size percentage list.
Note:
In the back propagation,
the fusion of the allreduce operators whose fusion attribute equals 1
will be performed according to dataSizeList,
will be performed according to data_size_list,
so that calculation and communication can run in parallel.
Args:
dataSizeList (list): The data size percentage list of the gradient.
data_size_list (list): The data size percentage list of the gradient.
group (str): The hccl communication group.
Raises:
TypeError: If group is not a python str.
TypeError: If dataSizeList is not a python list.
TypeError: If type of dataSizeList item is not int or float.
TypeError: If data_size_list is not a python list.
TypeError: If type of data_size_list item is not int or float.
ValueError: If group name length is out of range.
ValueError: If dataSizeList length is 0.
ValueError: If dataSizeList item is less than 0.
ValueError: If data_size_list length is 0.
ValueError: If data_size_list item is less than 0.
RuntimeError: If allreduce split failed.
"""
try:
@ -128,24 +130,27 @@ def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
import hccl_test.manage.api as hccl
hccl.set_fusion_strategy_by_size()
return
finally:
pass
if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
raise ValueError(f'Group name is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
if isinstance(dataSizeList, (list)):
len_data_size = len(dataSizeList)
if isinstance(data_size_list, (list)):
len_data_size = len(data_size_list)
if len_data_size == 0:
raise ValueError('DataSizeList length is 0')
raise ValueError('data_size_list length is 0')
else:
raise TypeError('DataSizeList must be a python list')
for dataSize in dataSizeList:
if not isinstance(dataSize, (int, float)):
raise TypeError('DataSize in dataSizeList is invalid')
raise TypeError('data_size_list must be a python list')
for data_size in data_size_list:
if not isinstance(data_size, (int, float)):
raise TypeError('data_size in data_size_list is invalid')
c_array_sizeList = _c_array(ctypes.c_float, dataSizeList)
c_size_num = ctypes.c_uint(len(dataSizeList))
c_array_sizeList = _c_array(ctypes.c_float, data_size_list)
c_size_num = ctypes.c_uint(len(data_size_list))
c_group = _c_str(group)
ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList)
if ret != 0:

View File

@ -60,12 +60,14 @@ _get_ps_context_func_map = {
"enable_ps": ps_context().is_ps_mode
}
def _get_ps_mode_rank():
ps_rank = ps_context().ps_rank_id()
if ps_rank == -1:
raise RuntimeError("The parameter server mode training is not enabled yet.")
return ps_rank
def _set_ps_context(**kwargs):
"""
Set parameter server training mode context.
@ -103,6 +105,7 @@ def _set_ps_context(**kwargs):
set_func = _set_ps_context_func_map[key]
set_func(value)
def _get_ps_context(attr_key):
"""
Get parameter server training mode context attribute value according to the key.
@ -122,6 +125,7 @@ def _get_ps_context(attr_key):
value = get_func()
return value
def _reset_ps_context():
"""
Reset parameter server training mode context attributes to the default values:
@ -130,30 +134,39 @@ def _reset_ps_context():
"""
ps_context().reset()
def _is_role_worker():
return ps_context().is_worker()
def _is_role_pserver():
return ps_context().is_server()
def _is_role_sched():
return ps_context().is_scheduler()
def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size):
ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size)
def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size):
ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size)
def _insert_weight_init_info(name, global_seed, op_seed):
ps_context().insert_weight_init_info(name, global_seed, op_seed)
def _insert_accumu_init_info(name, init_val):
ps_context().insert_accumu_init_info(name, init_val)
def _clone_hash_table(dest_param_name, src_param_name):
ps_context().clone_hash_table(dest_param_name, src_param_name)
def _set_cache_enable(cache_enable):
# Environment variables are used to specify a maximum number of OpenBLAS threads:
# In an Ubuntu (GPU) environment, NumPy will use too many threads for computing,
@ -163,5 +176,6 @@ def _set_cache_enable(cache_enable):
os.environ['OMP_NUM_THREADS'] = '2'
ps_context().set_cache_enable(cache_enable)
def _set_rank_id(rank_id):
ps_context().set_rank_id(rank_id)
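A hedged sketch of how these private helpers fit together; enable_ps is assumed to be a key of _set_ps_context_func_map, mirroring the getter map shown above, and the calls are illustrative only.

_set_ps_context(enable_ps=True)                  # assumed setter key, mirrors the getter map
if _is_role_worker():
    print("worker rank:", _get_ps_mode_rank())   # raises if PS mode is not enabled
elif _is_role_pserver():
    _set_cache_enable(True)                      # also caps OpenBLAS/OMP thread counts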

View File

@ -37,10 +37,12 @@ def _get_full_batch():
"""Get whether to use full_batch."""
return auto_parallel_context().get_full_batch()
def _get_pipeline_stages():
"""Get pipeline stages"""
return auto_parallel_context().get_pipeline_stages()
def _check_full_batch():
"""
full_batch can only be used under semi_auto_parallel or auto_parallel; check it.
@ -62,6 +64,7 @@ def _need_to_full():
and (not full_batch))
return need
def _to_full_shapes(shapes, device_num):
"""Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution."""
new_shapes = []
@ -75,9 +78,11 @@ def _to_full_shapes(shapes, device_num):
new_shapes.append(new_shape)
return new_shapes
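A hedged illustration of _to_full_shapes, assuming the batch dimension is the first axis of each shape and is multiplied by device_num:

shapes = [(32, 3, 224, 224), (32,)]
full_shapes = _to_full_shapes(shapes, device_num=8)
print(full_shapes)   # expected under that assumption: [(256, 3, 224, 224), (256,)]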
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
"""Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
from host solution."""
from host solution.
"""
lst = []
if not isinstance(elem, (tuple, list)):
elem = [elem]
@ -109,6 +114,7 @@ def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
lst.append(Tensor(scaling_sens, mstype.float32))
return tuple(lst)
def _get_gradients_mean():
"""Get if using gradients_mean."""
return auto_parallel_context().get_gradients_mean()
@ -188,6 +194,7 @@ def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
"do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}"
.format(parallel_mode, parameter_broadcast))
def _get_python_op(op_name, op_path, instance_name, arglist):
"""Get python operator."""
module = __import__(op_path, fromlist=["None"])

View File

@ -19,6 +19,7 @@ import threading
from mindspore._c_expression import MpiConfig
from mindspore._checkparam import args_type_check
class _MpiConfig:
"""
_MpiConfig is the config tool for controlling MPI
@ -55,6 +56,8 @@ class _MpiConfig:
self._mpiconfig_handle.set_enable_mpi(enable_mpi)
_k_mpi_config = None
def _mpi_config():
"""
Get the global MPI config; if it has not been created yet, create a new one.
@ -67,13 +70,14 @@ def _mpi_config():
_k_mpi_config = _MpiConfig()
return _k_mpi_config
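A hedged sketch of the lazy-singleton behaviour above; the enable_mpi property assignment is an assumption based on the setter shown earlier in this file.

cfg_a = _mpi_config()
cfg_b = _mpi_config()
assert cfg_a is cfg_b        # repeated calls return the same _MpiConfig instance
cfg_a.enable_mpi = True      # assumption: enable_mpi is exposed as a property setter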
@args_type_check(enable_mpi=bool)
def _set_mpi_config(**kwargs):
"""
Sets the MPI config for the running environment.
The MPI config should be configured before running your program. If there is no configuration,
mpi moudle will be disabled by default.
mpi module will be disabled by default.
Note:
Attribute name is required for setting attributes.