forked from mindspore-Ecosystem/mindspore
!16563 Fix Codex
From: @huangxinjing Reviewed-by: @stsuteng,@yangzhenzhang Signed-off-by: @stsuteng
commit 9f2c8642da
@@ -45,6 +45,7 @@ else:
 HCCL_WORLD_COMM_GROUP = "hccl_world_group"
 NCCL_WORLD_COMM_GROUP = "nccl_world_group"
 
+
 class Backend:
     """
     Class for available backends.
@@ -79,6 +80,7 @@ class Backend:
 DEFAULT_BACKEND = Backend("hccl")
 
+
 class GlobalComm:
     """World communication information."""
     BACKEND = DEFAULT_BACKEND
@@ -86,6 +88,7 @@ class GlobalComm:
     INITED = False
     CHECK_ENVS = True
 
+
 def is_hccl_available():
     """
     Check hccl api is available.
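For readers unfamiliar with the pattern in this file: Backend behaves like a validated backend name and GlobalComm is a plain holder of module-wide communication state. A minimal, hypothetical sketch of that idea (illustration only, not MindSpore's actual implementation):

class Backend(str):
    """Hypothetical sketch: a backend name that validates itself on creation."""
    _SUPPORTED = ("hccl", "nccl")

    def __new__(cls, name):
        if not isinstance(name, str) or name.lower() not in cls._SUPPORTED:
            raise TypeError("Backend name {} is not supported.".format(name))
        return str.__new__(cls, name.lower())


class GlobalComm:
    """World communication information (module-wide state holder)."""
    BACKEND = Backend("hccl")
    WORLD_COMM_GROUP = "hccl_world_group"
    INITED = False
    CHECK_ENVS = True

print(GlobalComm.BACKEND)  # hccl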
@@ -67,12 +67,14 @@ def check_rank_id(rank_id):
 
+
 def load_lib():
     """load hccl lib"""
     try:
         base_dir = os.path.dirname(os.path.realpath(__file__))
         lib_path = os.path.join(base_dir, "../lib", HCCL_LIB)
         hccl_lib = ctypes.CDLL(lib_path)
     except Exception:
         raise RuntimeError('Get hccl lib error.')
 
     global HCCL_LIB_CTYPES
     HCCL_LIB_CTYPES = hccl_lib
 
+
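As an aside, the ctypes pattern used by load_lib (resolve a path, CDLL it, fail loudly) can be tried standalone. A minimal sketch, assuming a Linux host where the C math library is available as libm.so.6:

import ctypes

def load_lib(path="libm.so.6"):
    # Load a shared library and fail with a clear error, mirroring load_lib above.
    try:
        lib = ctypes.CDLL(path)
    except OSError:
        raise RuntimeError('Get {} lib error.'.format(path))
    return lib

libm = load_lib()
libm.cos.restype = ctypes.c_double       # declare the C signature before calling
libm.cos.argtypes = [ctypes.c_double]
print(libm.cos(0.0))                     # 1.0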
@@ -29,6 +29,7 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
 DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
 
+
 def _get_group(group):
     """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
     if group == DEFAULT_WORLD_COMM_GROUP:
@@ -113,6 +114,7 @@ def init(backend_name=None):
     else:
         raise RuntimeError("Backend name {} is not supported.".format(backend_name))
 
+
 def release():
     """
     Release distributed resource. e.g. HCCL/NCCL.
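For context, these helpers back the public distributed-init API. A hedged usage sketch (the launch prerequisites, such as an HCCL rank table or an mpirun/NCCL environment, are assumed to be in place and are not shown):

from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size, release

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
init("nccl")                         # "hccl" on Ascend; rank info comes from the environment
print(get_rank(), get_group_size())  # this process's rank and the world size
release()                            # free HCCL/NCCL resources when done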
@@ -328,20 +328,9 @@ class _AutoParallelContext:
         if sorted(indices) != indices:
             raise ValueError('elements in indices must be sorted in ascending order')
 
-        if isinstance(group, (str)):
-            group_len = len(group)
-            if group_len > _MAX_GROUP_NAME_LEN:
-                raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
-        else:
-            raise TypeError('Group must be a python str')
+        new_group = self._check_and_default_group(group)
 
-        if group == "":
-            if context.get_context("device_target") == "Ascend":
-                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
-            else:
-                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
-
-        self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
+        self._context_handle.set_all_reduce_fusion_split_indices(indices, new_group)
         if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
             _set_fusion_strategy_by_idx(indices)
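To make the "split indices" idea concrete: a sorted index list partitions the flat gradient list into fusion buckets, and each bucket becomes one fused all-reduce. A small illustrative sketch (not MindSpore code):

def split_by_indices(num_grads, indices):
    # indices such as [20, 35] cut gradients 0..num_grads-1 into buckets
    # [0, 20), [20, 35), [35, num_grads); each bucket is fused into one all-reduce.
    buckets, start = [], 0
    for idx in indices:
        buckets.append(range(start, idx))
        start = idx
    buckets.append(range(start, num_grads))
    return buckets

print([len(b) for b in split_by_indices(50, [20, 35])])  # [20, 15, 15]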
@@ -359,19 +348,8 @@ class _AutoParallelContext:
             TypeError: If group is not a python str.
         """
         self.check_context_handle()
-        if isinstance(group, (str)):
-            group_len = len(group)
-            if group_len > _MAX_GROUP_NAME_LEN:
-                raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
-        else:
-            raise TypeError('Group must be a python str')
-
-        if group == "":
-            if context.get_context("device_target") == "Ascend":
-                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
-            else:
-                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
-        return self._context_handle.get_all_reduce_fusion_split_indices(group)
+        new_group = self._check_and_default_group(group)
+        return self._context_handle.get_all_reduce_fusion_split_indices(new_group)
 
     def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
         """
@@ -393,20 +371,8 @@ class _AutoParallelContext:
         else:
             raise TypeError('sizes must be a python list')
 
-        if isinstance(group, (str)):
-            group_len = len(group)
-            if group_len > _MAX_GROUP_NAME_LEN:
-                raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
-        else:
-            raise TypeError('Group must be a python str')
-
-        if group == "":
-            if context.get_context("device_target") == "Ascend":
-                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
-            else:
-                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
-
-        self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
+        new_group = self._check_and_default_group(group)
+        self._context_handle.set_all_reduce_fusion_split_sizes(sizes, new_group)
         if context.get_context("device_target") == "Ascend":
             _set_fusion_strategy_by_size(sizes)
@@ -424,19 +390,8 @@ class _AutoParallelContext:
             TypeError: If group is not a python str.
         """
         self.check_context_handle()
-        if isinstance(group, (str)):
-            group_len = len(group)
-            if group_len > _MAX_GROUP_NAME_LEN:
-                raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
-        else:
-            raise TypeError('Group must be a python str')
-
-        if group == "":
-            if context.get_context("device_target") == "Ascend":
-                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
-            else:
-                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
-        return self._context_handle.get_all_reduce_fusion_split_sizes(group)
+        new_group = self._check_and_default_group(group)
+        return self._context_handle.get_all_reduce_fusion_split_sizes(new_group)
 
     def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
         """
@@ -550,6 +505,23 @@ class _AutoParallelContext:
         self._context_handle.reset()
 
+    def _check_and_default_group(self, group):
+        """Validate the given group, if group is empty, returns a default fusion group"""
+        if isinstance(group, (str)):
+            group_len = len(group)
+            if group_len > _MAX_GROUP_NAME_LEN:
+                raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
+        else:
+            raise TypeError('Group must be a python str')
+
+        if group == "":
+            if context.get_context("device_target") == "Ascend":
+                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
+            else:
+                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
+        return group
+
 
 _auto_parallel_context = None
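The net effect of this hunk and the four above is that identical validate-then-default logic, previously copied into each setter and getter, collapses into one private helper. A standalone sketch of the pattern, using placeholder constants rather than MindSpore's real group names:

_MAX_GROUP_NAME_LEN = 127                       # placeholder value for the sketch
_DEFAULT_FUSION_GROUP_NAME = "fusion_group"     # placeholder default group name

def check_and_default_group(group):
    """Validate `group`; map the empty string to a default fusion group."""
    if not isinstance(group, str):
        raise TypeError('Group must be a python str')
    if len(group) > _MAX_GROUP_NAME_LEN:
        raise ValueError('Group name len is out of range')
    return group if group else _DEFAULT_FUSION_GROUP_NAME

print(check_and_default_group(""))          # fusion_group
print(check_and_default_group("my_group"))  # my_group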
@@ -706,6 +678,7 @@ def _get_auto_parallel_context(attr_key):
     get_func = _get_auto_parallel_context_func_map[attr_key]
     return get_func()
 
+
 def _reset_auto_parallel_context():
     """
     Reset auto parallel context attributes to the default values:
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Data paralell allreduce fusion"""
+"""Data parallel allreduce fusion"""
 
 import ctypes
@@ -41,27 +41,27 @@ def _c_array(ctype, values):
     return (ctype * len(values))(*values)
 
 
-def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
+def _set_fusion_strategy_by_idx(idx_list, group="hccl_world_group"):
     """
     A function set gradient segment strategy according to the index list.
 
     Note:
         In the back propagation,
         the fusion of the allreduce operators with a fusion attribute equals 1,
-        will be performed according to the idxList,
+        will be performed according to the idx_list,
        to achieve the effect of parallel between calculation and communication.
 
     Args:
-        idxList (list): The index list of the gradient.
+        idx_list (list): The index list of the gradient.
         group (str): The hccl communication group.
 
     Raises:
         TypeError: If group is not a python str.
-        TypeError: If IdxList is not a python list.
-        TypeError: If type of idxList item is not int.
+        TypeError: If idx_list is not a python list.
+        TypeError: If type of idx_list item is not int.
         ValueError: If group name length is out of range.
-        ValueError: If idxList length is 0.
-        ValueError: If idxList item is less than 0.
+        ValueError: If idx_list length is 0.
+        ValueError: If idx_list item is less than 0.
         RuntimeError: If allreduce split failed.
     """
     try:
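The _c_array helper referenced in this hunk builds a fixed-length C array type on the fly and instantiates it. The snippet below shows the same ctypes idiom in isolation:

import ctypes

def _c_array(ctype, values):
    # (ctype * n) creates an array type of length n; calling it fills the array.
    return (ctype * len(values))(*values)

arr = _c_array(ctypes.c_uint, [2, 4, 8])
print(len(arr), list(arr))   # 3 [2, 4, 8]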
@@ -70,6 +70,8 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
         import hccl_test.manage.api as hccl
         hccl.set_fusion_strategy_by_idx()
         return
     finally:
         pass
     if isinstance(group, (str)):
         group_len = len(group)
         if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0):
@@ -77,49 +79,49 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
     else:
         raise TypeError('Group must be a python str')
 
-    if isinstance(idxList, (list)):
-        idx_len = len(idxList)
+    if isinstance(idx_list, (list)):
+        idx_len = len(idx_list)
         if idx_len == 0:
-            raise ValueError('IdxList length is 0')
+            raise ValueError('idx_list length is 0')
     else:
-        raise TypeError('IdxList must be a python list')
+        raise TypeError('idx_list must be a python list')
 
-    for idx in idxList:
+    for idx in idx_list:
         if isinstance(idx, (int)):
             if idx < 0:
                 raise ValueError('Idx < 0')
         else:
-            raise TypeError('Idx in idxList is invalid')
+            raise TypeError('Idx in idx_list is invalid')
 
-    c_array_idxList = _c_array(ctypes.c_uint, idxList)
-    c_idx_num = ctypes.c_uint(len(idxList))
+    c_array_idx_list = _c_array(ctypes.c_uint, idx_list)
+    c_idx_num = ctypes.c_uint(len(idx_list))
     c_group = _c_str(group)
-    ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idxList)
+    ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idx_list)
     if ret != 0:
         raise RuntimeError('Allreduce split error')
 
 
-def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
+def _set_fusion_strategy_by_size(data_size_list, group="hccl_world_group"):
     """
     A function set gradient segment strategy according to the data size percentage list.
 
     Note:
         In the back propagation,
         the fusion of the allreduce operators with a fusion attribute equals 1,
-        will be performed according to dataSizeList,
+        will be performed according to data_size_list,
         to achieve the effect of parallel between calculation and communication.
 
     Args:
-        dataSizeList (list): The data size percentage list of the gradient.
+        data_size_list (list): The data size percentage list of the gradient.
         group (str): The hccl communication group.
 
     Raises:
         TypeError: If group is not a python str.
-        TypeError: If dataSizeList is not a python list.
-        TypeError: If type of dataSizeList item is not int or float.
+        TypeError: If data_size_list is not a python list.
+        TypeError: If type of data_size_list item is not int or float.
         ValueError: If group name length is out of range.
-        ValueError: If dataSizeList length is 0.
-        ValueError: If dataSizeList item is less than 0.
+        ValueError: If data_size_list length is 0.
+        ValueError: If data_size_list item is less than 0.
         RuntimeError: If allreduce split failed.
     """
     try:
@@ -128,24 +130,27 @@ def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
         import hccl_test.manage.api as hccl
         hccl.set_fusion_strategy_by_size()
         return
     finally:
         pass
 
     if isinstance(group, (str)):
         group_len = len(group)
         if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
             raise ValueError('Group name is out of range {_MAX_GROUP_NAME_LEN}')
     else:
         raise TypeError('Group must be a python str')
-    if isinstance(dataSizeList, (list)):
-        len_data_size = len(dataSizeList)
+    if isinstance(data_size_list, (list)):
+        len_data_size = len(data_size_list)
         if len_data_size == 0:
-            raise ValueError('DataSizeList length is 0')
+            raise ValueError('data_size_list length is 0')
     else:
-        raise TypeError('DataSizeList must be a python list')
-    for dataSize in dataSizeList:
-        if not isinstance(dataSize, (int, float)):
-            raise TypeError('DataSize in dataSizeList is invalid')
+        raise TypeError('data_size_list must be a python list')
+    for data_size in data_size_list:
+        if not isinstance(data_size, (int, float)):
+            raise TypeError('data_size in data_size_list is invalid')
 
-    c_array_sizeList = _c_array(ctypes.c_float, dataSizeList)
-    c_size_num = ctypes.c_uint(len(dataSizeList))
+    c_array_sizeList = _c_array(ctypes.c_float, data_size_list)
+    c_size_num = ctypes.c_uint(len(data_size_list))
     c_group = _c_str(group)
     ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList)
     if ret != 0:
@@ -60,12 +60,14 @@ _get_ps_context_func_map = {
     "enable_ps": ps_context().is_ps_mode
 }
 
+
 def _get_ps_mode_rank():
     ps_rank = ps_context().ps_rank_id()
     if ps_rank == -1:
         raise RuntimeError("The parameter server mode training is not enabled yet.")
     return ps_rank
 
+
 def _set_ps_context(**kwargs):
     """
     Set parameter server training mode context.
@@ -103,6 +105,7 @@ def _set_ps_context(**kwargs):
         set_func = _set_ps_context_func_map[key]
         set_func(value)
 
+
 def _get_ps_context(attr_key):
     """
     Get parameter server training mode context attribute value according to the key.
@@ -122,6 +125,7 @@ def _get_ps_context(attr_key):
     value = get_func()
     return value
 
+
 def _reset_ps_context():
     """
     Reset parameter server training mode context attributes to the default values:
@@ -130,30 +134,39 @@ def _reset_ps_context():
     """
     ps_context().reset()
 
+
 def _is_role_worker():
     return ps_context().is_worker()
 
+
 def _is_role_pserver():
     return ps_context().is_server()
 
+
 def _is_role_sched():
     return ps_context().is_scheduler()
 
+
 def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size):
     ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size)
 
+
 def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size):
     ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size)
 
+
 def _insert_weight_init_info(name, global_seed, op_seed):
     ps_context().insert_weight_init_info(name, global_seed, op_seed)
 
+
 def _insert_accumu_init_info(name, init_val):
     ps_context().insert_accumu_init_info(name, init_val)
 
+
 def _clone_hash_table(dest_param_name, src_param_name):
     ps_context().clone_hash_table(dest_param_name, src_param_name)
 
+
 def _set_cache_enable(cache_enable):
     # Environment variables are used to specify a maximum number of OpenBLAS threads:
     # In ubuntu(GPU) environment, numpy will use too many threads for computing,
@@ -163,5 +176,6 @@ def _set_cache_enable(cache_enable):
         os.environ['OMP_NUM_THREADS'] = '2'
     ps_context().set_cache_enable(cache_enable)
 
+
 def _set_rank_id(rank_id):
     ps_context().set_rank_id(rank_id)
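The comment in _set_cache_enable refers to capping OpenBLAS/OpenMP threads through environment variables before NumPy sizes its thread pool. A small hedged sketch of that pattern (the variable names are the standard OpenBLAS/OMP ones; the cap of 2 is arbitrary):

import os

# Must be set before the first `import numpy`, otherwise the pool is already sized.
os.environ.setdefault('OPENBLAS_NUM_THREADS', '2')
os.environ.setdefault('OMP_NUM_THREADS', '2')

import numpy as np
print(np.dot(np.ones((4, 4)), np.ones((4, 4))).shape)  # (4, 4), computed with capped threads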
@@ -37,10 +37,12 @@ def _get_full_batch():
     """Get whether to use full_batch."""
     return auto_parallel_context().get_full_batch()
 
+
 def _get_pipeline_stages():
     """Get pipeline stages"""
     return auto_parallel_context().get_pipeline_stages()
 
+
 def _check_full_batch():
     """
     full_batch could only be used under semi_auto_parallel or auto_parallel, check it.
@@ -62,6 +64,7 @@ def _need_to_full():
             and (not full_batch))
     return need
 
+
 def _to_full_shapes(shapes, device_num):
     """Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution."""
     new_shapes = []
@@ -75,9 +78,11 @@ def _to_full_shapes(shapes, device_num):
         new_shapes.append(new_shape)
     return new_shapes
 
+
 def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
     """Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
-       from host solution."""
+       from host solution.
+    """
     lst = []
     if not isinstance(elem, (tuple, list)):
         elem = [elem]
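_to_full_shapes and _to_full_tensor pad each rank's shard up to the full (global) batch so the data-feeding graph sees identical shapes on every device. An illustrative NumPy-only sketch of the shape logic (a hypothetical helper, not the MindSpore implementation):

import numpy as np

def to_full_numpy(shard, device_num, global_rank):
    # Allocate the full batch and place this rank's shard at its own offset;
    # other positions stay zero, mirroring the "expand batch dimension" idea.
    batch = shard.shape[0]
    full = np.zeros((batch * device_num,) + shard.shape[1:], dtype=shard.dtype)
    full[global_rank * batch:(global_rank + 1) * batch] = shard
    return full

shard = np.ones((2, 3), dtype=np.float32)
print(to_full_numpy(shard, device_num=4, global_rank=1).shape)  # (8, 3)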
@@ -109,6 +114,7 @@ def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
         lst.append(Tensor(scaling_sens, mstype.float32))
     return tuple(lst)
 
+
 def _get_gradients_mean():
     """Get if using gradients_mean."""
     return auto_parallel_context().get_gradients_mean()
@@ -188,6 +194,7 @@ def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
                          "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}"
                          .format(parallel_mode, parameter_broadcast))
 
+
 def _get_python_op(op_name, op_path, instance_name, arglist):
     """Get python operator."""
     module = __import__(op_path, fromlist=["None"])
@@ -19,6 +19,7 @@ import threading
 from mindspore._c_expression import MpiConfig
 from mindspore._checkparam import args_type_check
 
+
 class _MpiConfig:
     """
     _MpiConfig is the config tool for controlling MPI
@@ -55,6 +56,8 @@ class _MpiConfig:
         self._mpiconfig_handle.set_enable_mpi(enable_mpi)
 
+
 _k_mpi_config = None
 
+
 def _mpi_config():
     """
     Get the global mpi config, if mpi config is not created, create a new one.
@@ -67,13 +70,14 @@ def _mpi_config():
         _k_mpi_config = _MpiConfig()
     return _k_mpi_config
 
+
 @args_type_check(enable_mpi=bool)
 def _set_mpi_config(**kwargs):
     """
     Sets mpi config for running environment.
 
     mpi config should be configured before running your program. If there is no configuration,
-    mpi moudle will be disabled by default.
+    mpi module will be disabled by default.
 
     Note:
         Attribute name is required for setting attributes.