From: @huangxinjing
Reviewed-by: @stsuteng,@yangzhenzhang
Signed-off-by: @stsuteng
This commit is contained in:
mindspore-ci-bot 2021-05-24 17:34:11 +08:00 committed by Gitee
commit 9f2c8642da
8 changed files with 98 additions and 88 deletions

View File

@ -45,6 +45,7 @@ else:
HCCL_WORLD_COMM_GROUP = "hccl_world_group" HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group" NCCL_WORLD_COMM_GROUP = "nccl_world_group"
class Backend: class Backend:
""" """
Class for available backends. Class for available backends.
@ -79,6 +80,7 @@ class Backend:
DEFAULT_BACKEND = Backend("hccl") DEFAULT_BACKEND = Backend("hccl")
class GlobalComm: class GlobalComm:
"""World communication information.""" """World communication information."""
BACKEND = DEFAULT_BACKEND BACKEND = DEFAULT_BACKEND
@ -86,6 +88,7 @@ class GlobalComm:
INITED = False INITED = False
CHECK_ENVS = True CHECK_ENVS = True
def is_hccl_available(): def is_hccl_available():
""" """
Check hccl api is available. Check hccl api is available.

View File

@ -67,12 +67,14 @@ def check_rank_id(rank_id):
def load_lib(): def load_lib():
"""load hccl lib"""
try: try:
base_dir = os.path.dirname(os.path.realpath(__file__)) base_dir = os.path.dirname(os.path.realpath(__file__))
lib_path = os.path.join(base_dir, "../lib", HCCL_LIB) lib_path = os.path.join(base_dir, "../lib", HCCL_LIB)
hccl_lib = ctypes.CDLL(lib_path) hccl_lib = ctypes.CDLL(lib_path)
except Exception: except Exception:
raise RuntimeError('Get hccl lib error.') raise RuntimeError('Get hccl lib error.')
global HCCL_LIB_CTYPES global HCCL_LIB_CTYPES
HCCL_LIB_CTYPES = hccl_lib HCCL_LIB_CTYPES = hccl_lib

View File

@ -29,6 +29,7 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
def _get_group(group): def _get_group(group):
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`.""" """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
if group == DEFAULT_WORLD_COMM_GROUP: if group == DEFAULT_WORLD_COMM_GROUP:
@ -113,6 +114,7 @@ def init(backend_name=None):
else: else:
raise RuntimeError("Backend name {} is not supported.".format(backend_name)) raise RuntimeError("Backend name {} is not supported.".format(backend_name))
def release(): def release():
""" """
Release distributed resource. e.g. HCCL/NCCL. Release distributed resource. e.g. HCCL/NCCL.

View File

@ -328,20 +328,9 @@ class _AutoParallelContext:
if sorted(indices) != indices: if sorted(indices) != indices:
raise ValueError('elements in indices must be sorted in ascending order') raise ValueError('elements in indices must be sorted in ascending order')
if isinstance(group, (str)): new_group = self._check_and_default_group(group)
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
if group == "": self._context_handle.set_all_reduce_fusion_split_indices(indices, new_group)
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"): if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
_set_fusion_strategy_by_idx(indices) _set_fusion_strategy_by_idx(indices)
@ -359,19 +348,8 @@ class _AutoParallelContext:
TypeError: If group is not a python str. TypeError: If group is not a python str.
""" """
self.check_context_handle() self.check_context_handle()
if isinstance(group, (str)): new_group = self._check_and_default_group(group)
group_len = len(group) return self._context_handle.get_all_reduce_fusion_split_indices(new_group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
return self._context_handle.get_all_reduce_fusion_split_indices(group)
def set_all_reduce_fusion_split_sizes(self, sizes, group=""): def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
""" """
@ -393,20 +371,8 @@ class _AutoParallelContext:
else: else:
raise TypeError('sizes must be a python list') raise TypeError('sizes must be a python list')
if isinstance(group, (str)): new_group = self._check_and_default_group(group)
group_len = len(group) self._context_handle.set_all_reduce_fusion_split_sizes(sizes, new_group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
if context.get_context("device_target") == "Ascend": if context.get_context("device_target") == "Ascend":
_set_fusion_strategy_by_size(sizes) _set_fusion_strategy_by_size(sizes)
@ -424,19 +390,8 @@ class _AutoParallelContext:
TypeError: If group is not a python str. TypeError: If group is not a python str.
""" """
self.check_context_handle() self.check_context_handle()
if isinstance(group, (str)): new_group = self._check_and_default_group(group)
group_len = len(group) return self._context_handle.get_all_reduce_fusion_split_sizes(new_group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
return self._context_handle.get_all_reduce_fusion_split_sizes(group)
def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion): def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
""" """
@ -550,6 +505,23 @@ class _AutoParallelContext:
self._context_handle.reset() self._context_handle.reset()
def _check_and_default_group(self, group):
"""Validate the given group, if group is empty, returns a default fusion group"""
if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
if group == "":
if context.get_context("device_target") == "Ascend":
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
else:
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
return group
_auto_parallel_context = None _auto_parallel_context = None
@ -706,6 +678,7 @@ def _get_auto_parallel_context(attr_key):
get_func = _get_auto_parallel_context_func_map[attr_key] get_func = _get_auto_parallel_context_func_map[attr_key]
return get_func() return get_func()
def _reset_auto_parallel_context(): def _reset_auto_parallel_context():
""" """
Reset auto parallel context attributes to the default values: Reset auto parallel context attributes to the default values:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Data paralell allreduce fusion""" """Data parallel allreduce fusion"""
import ctypes import ctypes
@ -41,27 +41,27 @@ def _c_array(ctype, values):
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"): def _set_fusion_strategy_by_idx(idx_list, group="hccl_world_group"):
""" """
A function set gradient segment strategy according to the index list. A function set gradient segment strategy according to the index list.
Note: Note:
In the back propagation, In the back propagation,
the fusion of the allreduce operators with a fusion attribute equals 1, the fusion of the allreduce operators with a fusion attribute equals 1,
will be performed according to the idxList, will be performed according to the idx_list,
to achieve the effect of parallel between calculation and communication. to achieve the effect of parallel between calculation and communication.
Args: Args:
idxList (list): The index list of the gradient. idx_list (list): The index list of the gradient.
group (str): The hccl communication group. group (str): The hccl communication group.
Raises: Raises:
TypeError: If group is not a python str. TypeError: If group is not a python str.
TypeError: If IdxList is not a python list. TypeError: If idx_list is not a python list.
TypeError: If type of idxList item is not int. TypeError: If type of idx_list item is not int.
ValueError: If group name length is out of range. ValueError: If group name length is out of range.
ValueError: If idxList length is 0. ValueError: If idx_list length is 0.
ValueError: If idxList item is less than 0. ValueError: If idx_list item is less than 0.
RuntimeError: If allreduce split failed. RuntimeError: If allreduce split failed.
""" """
try: try:
@ -70,6 +70,8 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
import hccl_test.manage.api as hccl import hccl_test.manage.api as hccl
hccl.set_fusion_strategy_by_idx() hccl.set_fusion_strategy_by_idx()
return return
finally:
pass
if isinstance(group, (str)): if isinstance(group, (str)):
group_len = len(group) group_len = len(group)
if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0): if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0):
@ -77,49 +79,49 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
else: else:
raise TypeError('Group must be a python str') raise TypeError('Group must be a python str')
if isinstance(idxList, (list)): if isinstance(idx_list, (list)):
idx_len = len(idxList) idx_len = len(idx_list)
if idx_len == 0: if idx_len == 0:
raise ValueError('IdxList length is 0') raise ValueError('idx_list length is 0')
else: else:
raise TypeError('IdxList must be a python list') raise TypeError('idx_list must be a python list')
for idx in idxList: for idx in idx_list:
if isinstance(idx, (int)): if isinstance(idx, (int)):
if idx < 0: if idx < 0:
raise ValueError('Idx < 0') raise ValueError('Idx < 0')
else: else:
raise TypeError('Idx in idxList is invalid') raise TypeError('Idx in idx_list is invalid')
c_array_idxList = _c_array(ctypes.c_uint, idxList) c_array_idx_list = _c_array(ctypes.c_uint, idx_list)
c_idx_num = ctypes.c_uint(len(idxList)) c_idx_num = ctypes.c_uint(len(idx_list))
c_group = _c_str(group) c_group = _c_str(group)
ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idxList) ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idx_list)
if ret != 0: if ret != 0:
raise RuntimeError('Allreduce split error') raise RuntimeError('Allreduce split error')
def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"): def _set_fusion_strategy_by_size(data_size_list, group="hccl_world_group"):
""" """
A function set gradient segment strategy according to the data size percentage list. A function set gradient segment strategy according to the data size percentage list.
Note: Note:
In the back propagation, In the back propagation,
the fusion of the allreduce operators with a fusion attribute equals 1, the fusion of the allreduce operators with a fusion attribute equals 1,
will be performed according to dataSizeList, will be performed according to data_size_list,
to achieve the effect of parallel between calculation and communication. to achieve the effect of parallel between calculation and communication.
Args: Args:
dataSizeList (list): The data size percentage list of the gradient. data_size_list (list): The data size percentage list of the gradient.
group (str): The hccl communication group. group (str): The hccl communication group.
Raises: Raises:
TypeError: If group is not a python str. TypeError: If group is not a python str.
TypeError: If dataSizeList is not a python list. TypeError: If data_size_list is not a python list.
TypeError: If type of dataSizeList item is not int or float. TypeError: If type of data_size_list item is not int or float.
ValueError: If group name length is out of range. ValueError: If group name length is out of range.
ValueError: If dataSizeList length is 0. ValueError: If data_size_list length is 0.
ValueError: If dataSizeList item is less than 0. ValueError: If data_size_list item is less than 0.
RuntimeError: If allreduce split failed. RuntimeError: If allreduce split failed.
""" """
try: try:
@ -128,24 +130,27 @@ def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
import hccl_test.manage.api as hccl import hccl_test.manage.api as hccl
hccl.set_fusion_strategy_by_size() hccl.set_fusion_strategy_by_size()
return return
finally:
pass
if isinstance(group, (str)): if isinstance(group, (str)):
group_len = len(group) group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN or group_len == 0: if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
raise ValueError('Group name is out of range {_MAX_GROUP_NAME_LEN}') raise ValueError('Group name is out of range {_MAX_GROUP_NAME_LEN}')
else: else:
raise TypeError('Group must be a python str') raise TypeError('Group must be a python str')
if isinstance(dataSizeList, (list)): if isinstance(data_size_list, (list)):
len_data_size = len(dataSizeList) len_data_size = len(data_size_list)
if len_data_size == 0: if len_data_size == 0:
raise ValueError('DataSizeList length is 0') raise ValueError('data_size_list length is 0')
else: else:
raise TypeError('DataSizeList must be a python list') raise TypeError('data_size_list must be a python list')
for dataSize in dataSizeList: for data_size in data_size_list:
if not isinstance(dataSize, (int, float)): if not isinstance(data_size, (int, float)):
raise TypeError('DataSize in dataSizeList is invalid') raise TypeError('data_size in data_size_list is invalid')
c_array_sizeList = _c_array(ctypes.c_float, dataSizeList) c_array_sizeList = _c_array(ctypes.c_float, data_size_list)
c_size_num = ctypes.c_uint(len(dataSizeList)) c_size_num = ctypes.c_uint(len(data_size_list))
c_group = _c_str(group) c_group = _c_str(group)
ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList) ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList)
if ret != 0: if ret != 0:

View File

@ -60,12 +60,14 @@ _get_ps_context_func_map = {
"enable_ps": ps_context().is_ps_mode "enable_ps": ps_context().is_ps_mode
} }
def _get_ps_mode_rank(): def _get_ps_mode_rank():
ps_rank = ps_context().ps_rank_id() ps_rank = ps_context().ps_rank_id()
if ps_rank == -1: if ps_rank == -1:
raise RuntimeError("The parameter server mode training is not enabled yet.") raise RuntimeError("The parameter server mode training is not enabled yet.")
return ps_rank return ps_rank
def _set_ps_context(**kwargs): def _set_ps_context(**kwargs):
""" """
Set parameter server training mode context. Set parameter server training mode context.
@ -103,6 +105,7 @@ def _set_ps_context(**kwargs):
set_func = _set_ps_context_func_map[key] set_func = _set_ps_context_func_map[key]
set_func(value) set_func(value)
def _get_ps_context(attr_key): def _get_ps_context(attr_key):
""" """
Get parameter server training mode context attribute value according to the key. Get parameter server training mode context attribute value according to the key.
@ -122,6 +125,7 @@ def _get_ps_context(attr_key):
value = get_func() value = get_func()
return value return value
def _reset_ps_context(): def _reset_ps_context():
""" """
Reset parameter server training mode context attributes to the default values: Reset parameter server training mode context attributes to the default values:
@ -130,30 +134,39 @@ def _reset_ps_context():
""" """
ps_context().reset() ps_context().reset()
def _is_role_worker(): def _is_role_worker():
return ps_context().is_worker() return ps_context().is_worker()
def _is_role_pserver(): def _is_role_pserver():
return ps_context().is_server() return ps_context().is_server()
def _is_role_sched(): def _is_role_sched():
return ps_context().is_scheduler() return ps_context().is_scheduler()
def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size): def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size):
ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size) ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size)
def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size): def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size):
ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size) ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size)
def _insert_weight_init_info(name, global_seed, op_seed): def _insert_weight_init_info(name, global_seed, op_seed):
ps_context().insert_weight_init_info(name, global_seed, op_seed) ps_context().insert_weight_init_info(name, global_seed, op_seed)
def _insert_accumu_init_info(name, init_val): def _insert_accumu_init_info(name, init_val):
ps_context().insert_accumu_init_info(name, init_val) ps_context().insert_accumu_init_info(name, init_val)
def _clone_hash_table(dest_param_name, src_param_name): def _clone_hash_table(dest_param_name, src_param_name):
ps_context().clone_hash_table(dest_param_name, src_param_name) ps_context().clone_hash_table(dest_param_name, src_param_name)
def _set_cache_enable(cache_enable): def _set_cache_enable(cache_enable):
# Environment variables are used to specify a maximum number of OpenBLAS threads: # Environment variables are used to specify a maximum number of OpenBLAS threads:
# In ubuntu(GPU) environment, numpy will use too many threads for computing, # In ubuntu(GPU) environment, numpy will use too many threads for computing,
@ -163,5 +176,6 @@ def _set_cache_enable(cache_enable):
os.environ['OMP_NUM_THREADS'] = '2' os.environ['OMP_NUM_THREADS'] = '2'
ps_context().set_cache_enable(cache_enable) ps_context().set_cache_enable(cache_enable)
def _set_rank_id(rank_id): def _set_rank_id(rank_id):
ps_context().set_rank_id(rank_id) ps_context().set_rank_id(rank_id)

View File

@ -37,10 +37,12 @@ def _get_full_batch():
"""Get whether to use full_batch.""" """Get whether to use full_batch."""
return auto_parallel_context().get_full_batch() return auto_parallel_context().get_full_batch()
def _get_pipeline_stages(): def _get_pipeline_stages():
"""Get pipeline stages""" """Get pipeline stages"""
return auto_parallel_context().get_pipeline_stages() return auto_parallel_context().get_pipeline_stages()
def _check_full_batch(): def _check_full_batch():
""" """
full_batch could only be used under semi_auto_parallel or auto_parallel, check it. full_batch could only be used under semi_auto_parallel or auto_parallel, check it.
@ -62,6 +64,7 @@ def _need_to_full():
and (not full_batch)) and (not full_batch))
return need return need
def _to_full_shapes(shapes, device_num): def _to_full_shapes(shapes, device_num):
"""Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution.""" """Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution."""
new_shapes = [] new_shapes = []
@ -75,9 +78,11 @@ def _to_full_shapes(shapes, device_num):
new_shapes.append(new_shape) new_shapes.append(new_shape)
return new_shapes return new_shapes
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None): def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
"""Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data """Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
from host solution.""" from host solution.
"""
lst = [] lst = []
if not isinstance(elem, (tuple, list)): if not isinstance(elem, (tuple, list)):
elem = [elem] elem = [elem]
@ -109,6 +114,7 @@ def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
lst.append(Tensor(scaling_sens, mstype.float32)) lst.append(Tensor(scaling_sens, mstype.float32))
return tuple(lst) return tuple(lst)
def _get_gradients_mean(): def _get_gradients_mean():
"""Get if using gradients_mean.""" """Get if using gradients_mean."""
return auto_parallel_context().get_gradients_mean() return auto_parallel_context().get_gradients_mean()
@ -188,6 +194,7 @@ def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
"do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}" "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}"
.format(parallel_mode, parameter_broadcast)) .format(parallel_mode, parameter_broadcast))
def _get_python_op(op_name, op_path, instance_name, arglist): def _get_python_op(op_name, op_path, instance_name, arglist):
"""Get python operator.""" """Get python operator."""
module = __import__(op_path, fromlist=["None"]) module = __import__(op_path, fromlist=["None"])

View File

@ -19,6 +19,7 @@ import threading
from mindspore._c_expression import MpiConfig from mindspore._c_expression import MpiConfig
from mindspore._checkparam import args_type_check from mindspore._checkparam import args_type_check
class _MpiConfig: class _MpiConfig:
""" """
_MpiConfig is the config tool for controlling MPI _MpiConfig is the config tool for controlling MPI
@ -55,6 +56,8 @@ class _MpiConfig:
self._mpiconfig_handle.set_enable_mpi(enable_mpi) self._mpiconfig_handle.set_enable_mpi(enable_mpi)
_k_mpi_config = None _k_mpi_config = None
def _mpi_config(): def _mpi_config():
""" """
Get the global mpi config, if mpi config is not created, create a new one. Get the global mpi config, if mpi config is not created, create a new one.
@ -67,13 +70,14 @@ def _mpi_config():
_k_mpi_config = _MpiConfig() _k_mpi_config = _MpiConfig()
return _k_mpi_config return _k_mpi_config
@args_type_check(enable_mpi=bool) @args_type_check(enable_mpi=bool)
def _set_mpi_config(**kwargs): def _set_mpi_config(**kwargs):
""" """
Sets mpi config for running environment. Sets mpi config for running environment.
mpi config should be configured before running your program. If there is no configuration, mpi config should be configured before running your program. If there is no configuration,
mpi moudle will be disabled by default. mpi module will be disabled by default.
Note: Note:
Attribute name is required for setting attributes. Attribute name is required for setting attributes.