forked from mindspore-Ecosystem/mindspore
!16563 Fix Codex
From: @huangxinjing Reviewed-by: @stsuteng,@yangzhenzhang Signed-off-by: @stsuteng
This commit is contained in:
commit
9f2c8642da
|
@ -45,6 +45,7 @@ else:
|
||||||
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
|
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
|
||||||
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
|
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
|
||||||
|
|
||||||
|
|
||||||
class Backend:
|
class Backend:
|
||||||
"""
|
"""
|
||||||
Class for available backends.
|
Class for available backends.
|
||||||
|
@ -79,6 +80,7 @@ class Backend:
|
||||||
|
|
||||||
DEFAULT_BACKEND = Backend("hccl")
|
DEFAULT_BACKEND = Backend("hccl")
|
||||||
|
|
||||||
|
|
||||||
class GlobalComm:
|
class GlobalComm:
|
||||||
"""World communication information."""
|
"""World communication information."""
|
||||||
BACKEND = DEFAULT_BACKEND
|
BACKEND = DEFAULT_BACKEND
|
||||||
|
@ -86,6 +88,7 @@ class GlobalComm:
|
||||||
INITED = False
|
INITED = False
|
||||||
CHECK_ENVS = True
|
CHECK_ENVS = True
|
||||||
|
|
||||||
|
|
||||||
def is_hccl_available():
|
def is_hccl_available():
|
||||||
"""
|
"""
|
||||||
Check hccl api is available.
|
Check hccl api is available.
|
||||||
|
|
|
@ -67,12 +67,14 @@ def check_rank_id(rank_id):
|
||||||
|
|
||||||
|
|
||||||
def load_lib():
|
def load_lib():
|
||||||
|
"""load hccl lib"""
|
||||||
try:
|
try:
|
||||||
base_dir = os.path.dirname(os.path.realpath(__file__))
|
base_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
lib_path = os.path.join(base_dir, "../lib", HCCL_LIB)
|
lib_path = os.path.join(base_dir, "../lib", HCCL_LIB)
|
||||||
hccl_lib = ctypes.CDLL(lib_path)
|
hccl_lib = ctypes.CDLL(lib_path)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise RuntimeError('Get hccl lib error.')
|
raise RuntimeError('Get hccl lib error.')
|
||||||
|
|
||||||
global HCCL_LIB_CTYPES
|
global HCCL_LIB_CTYPES
|
||||||
HCCL_LIB_CTYPES = hccl_lib
|
HCCL_LIB_CTYPES = hccl_lib
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,7 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
|
||||||
|
|
||||||
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
||||||
|
|
||||||
|
|
||||||
def _get_group(group):
|
def _get_group(group):
|
||||||
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
|
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
|
||||||
if group == DEFAULT_WORLD_COMM_GROUP:
|
if group == DEFAULT_WORLD_COMM_GROUP:
|
||||||
|
@ -113,6 +114,7 @@ def init(backend_name=None):
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Backend name {} is not supported.".format(backend_name))
|
raise RuntimeError("Backend name {} is not supported.".format(backend_name))
|
||||||
|
|
||||||
|
|
||||||
def release():
|
def release():
|
||||||
"""
|
"""
|
||||||
Release distributed resource. e.g. HCCL/NCCL.
|
Release distributed resource. e.g. HCCL/NCCL.
|
||||||
|
|
|
@ -328,20 +328,9 @@ class _AutoParallelContext:
|
||||||
if sorted(indices) != indices:
|
if sorted(indices) != indices:
|
||||||
raise ValueError('elements in indices must be sorted in ascending order')
|
raise ValueError('elements in indices must be sorted in ascending order')
|
||||||
|
|
||||||
if isinstance(group, (str)):
|
new_group = self._check_and_default_group(group)
|
||||||
group_len = len(group)
|
|
||||||
if group_len > _MAX_GROUP_NAME_LEN:
|
|
||||||
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
|
|
||||||
else:
|
|
||||||
raise TypeError('Group must be a python str')
|
|
||||||
|
|
||||||
if group == "":
|
self._context_handle.set_all_reduce_fusion_split_indices(indices, new_group)
|
||||||
if context.get_context("device_target") == "Ascend":
|
|
||||||
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
|
|
||||||
else:
|
|
||||||
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
|
|
||||||
|
|
||||||
self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
|
|
||||||
if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
|
if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
|
||||||
_set_fusion_strategy_by_idx(indices)
|
_set_fusion_strategy_by_idx(indices)
|
||||||
|
|
||||||
|
@ -359,19 +348,8 @@ class _AutoParallelContext:
|
||||||
TypeError: If group is not a python str.
|
TypeError: If group is not a python str.
|
||||||
"""
|
"""
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
if isinstance(group, (str)):
|
new_group = self._check_and_default_group(group)
|
||||||
group_len = len(group)
|
return self._context_handle.get_all_reduce_fusion_split_indices(new_group)
|
||||||
if group_len > _MAX_GROUP_NAME_LEN:
|
|
||||||
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
|
|
||||||
else:
|
|
||||||
raise TypeError('Group must be a python str')
|
|
||||||
|
|
||||||
if group == "":
|
|
||||||
if context.get_context("device_target") == "Ascend":
|
|
||||||
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
|
|
||||||
else:
|
|
||||||
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
|
|
||||||
return self._context_handle.get_all_reduce_fusion_split_indices(group)
|
|
||||||
|
|
||||||
def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
|
def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
|
||||||
"""
|
"""
|
||||||
|
@ -393,20 +371,8 @@ class _AutoParallelContext:
|
||||||
else:
|
else:
|
||||||
raise TypeError('sizes must be a python list')
|
raise TypeError('sizes must be a python list')
|
||||||
|
|
||||||
if isinstance(group, (str)):
|
new_group = self._check_and_default_group(group)
|
||||||
group_len = len(group)
|
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, new_group)
|
||||||
if group_len > _MAX_GROUP_NAME_LEN:
|
|
||||||
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
|
|
||||||
else:
|
|
||||||
raise TypeError('Group must be a python str')
|
|
||||||
|
|
||||||
if group == "":
|
|
||||||
if context.get_context("device_target") == "Ascend":
|
|
||||||
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
|
|
||||||
else:
|
|
||||||
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
|
|
||||||
|
|
||||||
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
|
|
||||||
if context.get_context("device_target") == "Ascend":
|
if context.get_context("device_target") == "Ascend":
|
||||||
_set_fusion_strategy_by_size(sizes)
|
_set_fusion_strategy_by_size(sizes)
|
||||||
|
|
||||||
|
@ -424,19 +390,8 @@ class _AutoParallelContext:
|
||||||
TypeError: If group is not a python str.
|
TypeError: If group is not a python str.
|
||||||
"""
|
"""
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
if isinstance(group, (str)):
|
new_group = self._check_and_default_group(group)
|
||||||
group_len = len(group)
|
return self._context_handle.get_all_reduce_fusion_split_sizes(new_group)
|
||||||
if group_len > _MAX_GROUP_NAME_LEN:
|
|
||||||
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
|
|
||||||
else:
|
|
||||||
raise TypeError('Group must be a python str')
|
|
||||||
|
|
||||||
if group == "":
|
|
||||||
if context.get_context("device_target") == "Ascend":
|
|
||||||
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
|
|
||||||
else:
|
|
||||||
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
|
|
||||||
return self._context_handle.get_all_reduce_fusion_split_sizes(group)
|
|
||||||
|
|
||||||
def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
|
def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
|
||||||
"""
|
"""
|
||||||
|
@ -550,6 +505,23 @@ class _AutoParallelContext:
|
||||||
self._context_handle.reset()
|
self._context_handle.reset()
|
||||||
|
|
||||||
|
|
||||||
|
def _check_and_default_group(self, group):
|
||||||
|
"""Validate the given group, if group is empty, returns a default fusion group"""
|
||||||
|
if isinstance(group, (str)):
|
||||||
|
group_len = len(group)
|
||||||
|
if group_len > _MAX_GROUP_NAME_LEN:
|
||||||
|
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
|
||||||
|
else:
|
||||||
|
raise TypeError('Group must be a python str')
|
||||||
|
|
||||||
|
if group == "":
|
||||||
|
if context.get_context("device_target") == "Ascend":
|
||||||
|
group = _DEFAULT_HCCL_FUSION_GROUP_NAME
|
||||||
|
else:
|
||||||
|
group = _DEFAULT_NCCL_FUSION_GROUP_NAME
|
||||||
|
return group
|
||||||
|
|
||||||
|
|
||||||
_auto_parallel_context = None
|
_auto_parallel_context = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -706,6 +678,7 @@ def _get_auto_parallel_context(attr_key):
|
||||||
get_func = _get_auto_parallel_context_func_map[attr_key]
|
get_func = _get_auto_parallel_context_func_map[attr_key]
|
||||||
return get_func()
|
return get_func()
|
||||||
|
|
||||||
|
|
||||||
def _reset_auto_parallel_context():
|
def _reset_auto_parallel_context():
|
||||||
"""
|
"""
|
||||||
Reset auto parallel context attributes to the default values:
|
Reset auto parallel context attributes to the default values:
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Data paralell allreduce fusion"""
|
"""Data parallel allreduce fusion"""
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
|
|
||||||
|
@ -41,27 +41,27 @@ def _c_array(ctype, values):
|
||||||
return (ctype * len(values))(*values)
|
return (ctype * len(values))(*values)
|
||||||
|
|
||||||
|
|
||||||
def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
|
def _set_fusion_strategy_by_idx(idx_list, group="hccl_world_group"):
|
||||||
"""
|
"""
|
||||||
A function set gradient segment strategy according to the index list.
|
A function set gradient segment strategy according to the index list.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
In the back propagation,
|
In the back propagation,
|
||||||
the fusion of the allreduce operators with a fusion attribute equals 1,
|
the fusion of the allreduce operators with a fusion attribute equals 1,
|
||||||
will be performed according to the idxList,
|
will be performed according to the idx_list,
|
||||||
to achieve the effect of parallel between calculation and communication.
|
to achieve the effect of parallel between calculation and communication.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
idxList (list): The index list of the gradient.
|
idx_list (list): The index list of the gradient.
|
||||||
group (str): The hccl communication group.
|
group (str): The hccl communication group.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If group is not a python str.
|
TypeError: If group is not a python str.
|
||||||
TypeError: If IdxList is not a python list.
|
TypeError: If idx_list is not a python list.
|
||||||
TypeError: If type of idxList item is not int.
|
TypeError: If type of idx_list item is not int.
|
||||||
ValueError: If group name length is out of range.
|
ValueError: If group name length is out of range.
|
||||||
ValueError: If idxList length is 0.
|
ValueError: If idx_list length is 0.
|
||||||
ValueError: If idxList item is less than 0.
|
ValueError: If idx_list item is less than 0.
|
||||||
RuntimeError: If allreduce split failed.
|
RuntimeError: If allreduce split failed.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
@ -70,6 +70,8 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
|
||||||
import hccl_test.manage.api as hccl
|
import hccl_test.manage.api as hccl
|
||||||
hccl.set_fusion_strategy_by_idx()
|
hccl.set_fusion_strategy_by_idx()
|
||||||
return
|
return
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
if isinstance(group, (str)):
|
if isinstance(group, (str)):
|
||||||
group_len = len(group)
|
group_len = len(group)
|
||||||
if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0):
|
if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0):
|
||||||
|
@ -77,49 +79,49 @@ def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
|
||||||
else:
|
else:
|
||||||
raise TypeError('Group must be a python str')
|
raise TypeError('Group must be a python str')
|
||||||
|
|
||||||
if isinstance(idxList, (list)):
|
if isinstance(idx_list, (list)):
|
||||||
idx_len = len(idxList)
|
idx_len = len(idx_list)
|
||||||
if idx_len == 0:
|
if idx_len == 0:
|
||||||
raise ValueError('IdxList length is 0')
|
raise ValueError('idx_list length is 0')
|
||||||
else:
|
else:
|
||||||
raise TypeError('IdxList must be a python list')
|
raise TypeError('idx_list must be a python list')
|
||||||
|
|
||||||
for idx in idxList:
|
for idx in idx_list:
|
||||||
if isinstance(idx, (int)):
|
if isinstance(idx, (int)):
|
||||||
if idx < 0:
|
if idx < 0:
|
||||||
raise ValueError('Idx < 0')
|
raise ValueError('Idx < 0')
|
||||||
else:
|
else:
|
||||||
raise TypeError('Idx in idxList is invalid')
|
raise TypeError('Idx in idx_list is invalid')
|
||||||
|
|
||||||
c_array_idxList = _c_array(ctypes.c_uint, idxList)
|
c_array_idx_list = _c_array(ctypes.c_uint, idx_list)
|
||||||
c_idx_num = ctypes.c_uint(len(idxList))
|
c_idx_num = ctypes.c_uint(len(idx_list))
|
||||||
c_group = _c_str(group)
|
c_group = _c_str(group)
|
||||||
ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idxList)
|
ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idx_list)
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
raise RuntimeError('Allreduce split error')
|
raise RuntimeError('Allreduce split error')
|
||||||
|
|
||||||
|
|
||||||
def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
|
def _set_fusion_strategy_by_size(data_size_list, group="hccl_world_group"):
|
||||||
"""
|
"""
|
||||||
A function set gradient segment strategy according to the data size percentage list.
|
A function set gradient segment strategy according to the data size percentage list.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
In the back propagation,
|
In the back propagation,
|
||||||
the fusion of the allreduce operators with a fusion attribute equals 1,
|
the fusion of the allreduce operators with a fusion attribute equals 1,
|
||||||
will be performed according to dataSizeList,
|
will be performed according to data_size_list,
|
||||||
to achieve the effect of parallel between calculation and communication.
|
to achieve the effect of parallel between calculation and communication.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataSizeList (list): The data size percentage list of the gradient.
|
data_size_list (list): The data size percentage list of the gradient.
|
||||||
group (str): The hccl communication group.
|
group (str): The hccl communication group.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If group is not a python str.
|
TypeError: If group is not a python str.
|
||||||
TypeError: If dataSizeList is not a python list.
|
TypeError: If data_size_list is not a python list.
|
||||||
TypeError: If type of dataSizeList item is not int or float.
|
TypeError: If type of data_size_list item is not int or float.
|
||||||
ValueError: If group name length is out of range.
|
ValueError: If group name length is out of range.
|
||||||
ValueError: If dataSizeList length is 0.
|
ValueError: If data_size_list length is 0.
|
||||||
ValueError: If dataSizeList item is less than 0.
|
ValueError: If data_size_list item is less than 0.
|
||||||
RuntimeError: If allreduce split failed.
|
RuntimeError: If allreduce split failed.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
@ -128,24 +130,27 @@ def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
|
||||||
import hccl_test.manage.api as hccl
|
import hccl_test.manage.api as hccl
|
||||||
hccl.set_fusion_strategy_by_size()
|
hccl.set_fusion_strategy_by_size()
|
||||||
return
|
return
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
if isinstance(group, (str)):
|
if isinstance(group, (str)):
|
||||||
group_len = len(group)
|
group_len = len(group)
|
||||||
if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
|
if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
|
||||||
raise ValueError('Group name is out of range {_MAX_GROUP_NAME_LEN}')
|
raise ValueError('Group name is out of range {_MAX_GROUP_NAME_LEN}')
|
||||||
else:
|
else:
|
||||||
raise TypeError('Group must be a python str')
|
raise TypeError('Group must be a python str')
|
||||||
if isinstance(dataSizeList, (list)):
|
if isinstance(data_size_list, (list)):
|
||||||
len_data_size = len(dataSizeList)
|
len_data_size = len(data_size_list)
|
||||||
if len_data_size == 0:
|
if len_data_size == 0:
|
||||||
raise ValueError('DataSizeList length is 0')
|
raise ValueError('data_size_list length is 0')
|
||||||
else:
|
else:
|
||||||
raise TypeError('DataSizeList must be a python list')
|
raise TypeError('data_size_list must be a python list')
|
||||||
for dataSize in dataSizeList:
|
for data_size in data_size_list:
|
||||||
if not isinstance(dataSize, (int, float)):
|
if not isinstance(data_size, (int, float)):
|
||||||
raise TypeError('DataSize in dataSizeList is invalid')
|
raise TypeError('data_size in data_size_list is invalid')
|
||||||
|
|
||||||
c_array_sizeList = _c_array(ctypes.c_float, dataSizeList)
|
c_array_sizeList = _c_array(ctypes.c_float, data_size_list)
|
||||||
c_size_num = ctypes.c_uint(len(dataSizeList))
|
c_size_num = ctypes.c_uint(len(data_size_list))
|
||||||
c_group = _c_str(group)
|
c_group = _c_str(group)
|
||||||
ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList)
|
ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList)
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
|
|
|
@ -60,12 +60,14 @@ _get_ps_context_func_map = {
|
||||||
"enable_ps": ps_context().is_ps_mode
|
"enable_ps": ps_context().is_ps_mode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _get_ps_mode_rank():
|
def _get_ps_mode_rank():
|
||||||
ps_rank = ps_context().ps_rank_id()
|
ps_rank = ps_context().ps_rank_id()
|
||||||
if ps_rank == -1:
|
if ps_rank == -1:
|
||||||
raise RuntimeError("The parameter server mode training is not enabled yet.")
|
raise RuntimeError("The parameter server mode training is not enabled yet.")
|
||||||
return ps_rank
|
return ps_rank
|
||||||
|
|
||||||
|
|
||||||
def _set_ps_context(**kwargs):
|
def _set_ps_context(**kwargs):
|
||||||
"""
|
"""
|
||||||
Set parameter server training mode context.
|
Set parameter server training mode context.
|
||||||
|
@ -103,6 +105,7 @@ def _set_ps_context(**kwargs):
|
||||||
set_func = _set_ps_context_func_map[key]
|
set_func = _set_ps_context_func_map[key]
|
||||||
set_func(value)
|
set_func(value)
|
||||||
|
|
||||||
|
|
||||||
def _get_ps_context(attr_key):
|
def _get_ps_context(attr_key):
|
||||||
"""
|
"""
|
||||||
Get parameter server training mode context attribute value according to the key.
|
Get parameter server training mode context attribute value according to the key.
|
||||||
|
@ -122,6 +125,7 @@ def _get_ps_context(attr_key):
|
||||||
value = get_func()
|
value = get_func()
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _reset_ps_context():
|
def _reset_ps_context():
|
||||||
"""
|
"""
|
||||||
Reset parameter server training mode context attributes to the default values:
|
Reset parameter server training mode context attributes to the default values:
|
||||||
|
@ -130,30 +134,39 @@ def _reset_ps_context():
|
||||||
"""
|
"""
|
||||||
ps_context().reset()
|
ps_context().reset()
|
||||||
|
|
||||||
|
|
||||||
def _is_role_worker():
|
def _is_role_worker():
|
||||||
return ps_context().is_worker()
|
return ps_context().is_worker()
|
||||||
|
|
||||||
|
|
||||||
def _is_role_pserver():
|
def _is_role_pserver():
|
||||||
return ps_context().is_server()
|
return ps_context().is_server()
|
||||||
|
|
||||||
|
|
||||||
def _is_role_sched():
|
def _is_role_sched():
|
||||||
return ps_context().is_scheduler()
|
return ps_context().is_scheduler()
|
||||||
|
|
||||||
|
|
||||||
def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size):
|
def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size):
|
||||||
ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size)
|
ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size)
|
||||||
|
|
||||||
|
|
||||||
def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size):
|
def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size):
|
||||||
ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size)
|
ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size)
|
||||||
|
|
||||||
|
|
||||||
def _insert_weight_init_info(name, global_seed, op_seed):
|
def _insert_weight_init_info(name, global_seed, op_seed):
|
||||||
ps_context().insert_weight_init_info(name, global_seed, op_seed)
|
ps_context().insert_weight_init_info(name, global_seed, op_seed)
|
||||||
|
|
||||||
|
|
||||||
def _insert_accumu_init_info(name, init_val):
|
def _insert_accumu_init_info(name, init_val):
|
||||||
ps_context().insert_accumu_init_info(name, init_val)
|
ps_context().insert_accumu_init_info(name, init_val)
|
||||||
|
|
||||||
|
|
||||||
def _clone_hash_table(dest_param_name, src_param_name):
|
def _clone_hash_table(dest_param_name, src_param_name):
|
||||||
ps_context().clone_hash_table(dest_param_name, src_param_name)
|
ps_context().clone_hash_table(dest_param_name, src_param_name)
|
||||||
|
|
||||||
|
|
||||||
def _set_cache_enable(cache_enable):
|
def _set_cache_enable(cache_enable):
|
||||||
# Environment variables are used to specify a maximum number of OpenBLAS threads:
|
# Environment variables are used to specify a maximum number of OpenBLAS threads:
|
||||||
# In ubuntu(GPU) environment, numpy will use too many threads for computing,
|
# In ubuntu(GPU) environment, numpy will use too many threads for computing,
|
||||||
|
@ -163,5 +176,6 @@ def _set_cache_enable(cache_enable):
|
||||||
os.environ['OMP_NUM_THREADS'] = '2'
|
os.environ['OMP_NUM_THREADS'] = '2'
|
||||||
ps_context().set_cache_enable(cache_enable)
|
ps_context().set_cache_enable(cache_enable)
|
||||||
|
|
||||||
|
|
||||||
def _set_rank_id(rank_id):
|
def _set_rank_id(rank_id):
|
||||||
ps_context().set_rank_id(rank_id)
|
ps_context().set_rank_id(rank_id)
|
||||||
|
|
|
@ -37,10 +37,12 @@ def _get_full_batch():
|
||||||
"""Get whether to use full_batch."""
|
"""Get whether to use full_batch."""
|
||||||
return auto_parallel_context().get_full_batch()
|
return auto_parallel_context().get_full_batch()
|
||||||
|
|
||||||
|
|
||||||
def _get_pipeline_stages():
|
def _get_pipeline_stages():
|
||||||
"""Get pipeline stages"""
|
"""Get pipeline stages"""
|
||||||
return auto_parallel_context().get_pipeline_stages()
|
return auto_parallel_context().get_pipeline_stages()
|
||||||
|
|
||||||
|
|
||||||
def _check_full_batch():
|
def _check_full_batch():
|
||||||
"""
|
"""
|
||||||
full_batch could only be used under semi_auto_parallel or auto_parallel, check it.
|
full_batch could only be used under semi_auto_parallel or auto_parallel, check it.
|
||||||
|
@ -62,6 +64,7 @@ def _need_to_full():
|
||||||
and (not full_batch))
|
and (not full_batch))
|
||||||
return need
|
return need
|
||||||
|
|
||||||
|
|
||||||
def _to_full_shapes(shapes, device_num):
|
def _to_full_shapes(shapes, device_num):
|
||||||
"""Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution."""
|
"""Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution."""
|
||||||
new_shapes = []
|
new_shapes = []
|
||||||
|
@ -75,9 +78,11 @@ def _to_full_shapes(shapes, device_num):
|
||||||
new_shapes.append(new_shape)
|
new_shapes.append(new_shape)
|
||||||
return new_shapes
|
return new_shapes
|
||||||
|
|
||||||
|
|
||||||
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
|
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
|
||||||
"""Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
|
"""Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
|
||||||
from host solution."""
|
from host solution.
|
||||||
|
"""
|
||||||
lst = []
|
lst = []
|
||||||
if not isinstance(elem, (tuple, list)):
|
if not isinstance(elem, (tuple, list)):
|
||||||
elem = [elem]
|
elem = [elem]
|
||||||
|
@ -109,6 +114,7 @@ def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
|
||||||
lst.append(Tensor(scaling_sens, mstype.float32))
|
lst.append(Tensor(scaling_sens, mstype.float32))
|
||||||
return tuple(lst)
|
return tuple(lst)
|
||||||
|
|
||||||
|
|
||||||
def _get_gradients_mean():
|
def _get_gradients_mean():
|
||||||
"""Get if using gradients_mean."""
|
"""Get if using gradients_mean."""
|
||||||
return auto_parallel_context().get_gradients_mean()
|
return auto_parallel_context().get_gradients_mean()
|
||||||
|
@ -188,6 +194,7 @@ def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
|
||||||
"do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}"
|
"do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}"
|
||||||
.format(parallel_mode, parameter_broadcast))
|
.format(parallel_mode, parameter_broadcast))
|
||||||
|
|
||||||
|
|
||||||
def _get_python_op(op_name, op_path, instance_name, arglist):
|
def _get_python_op(op_name, op_path, instance_name, arglist):
|
||||||
"""Get python operator."""
|
"""Get python operator."""
|
||||||
module = __import__(op_path, fromlist=["None"])
|
module = __import__(op_path, fromlist=["None"])
|
||||||
|
|
|
@ -19,6 +19,7 @@ import threading
|
||||||
from mindspore._c_expression import MpiConfig
|
from mindspore._c_expression import MpiConfig
|
||||||
from mindspore._checkparam import args_type_check
|
from mindspore._checkparam import args_type_check
|
||||||
|
|
||||||
|
|
||||||
class _MpiConfig:
|
class _MpiConfig:
|
||||||
"""
|
"""
|
||||||
_MpiConfig is the config tool for controlling MPI
|
_MpiConfig is the config tool for controlling MPI
|
||||||
|
@ -55,6 +56,8 @@ class _MpiConfig:
|
||||||
self._mpiconfig_handle.set_enable_mpi(enable_mpi)
|
self._mpiconfig_handle.set_enable_mpi(enable_mpi)
|
||||||
|
|
||||||
_k_mpi_config = None
|
_k_mpi_config = None
|
||||||
|
|
||||||
|
|
||||||
def _mpi_config():
|
def _mpi_config():
|
||||||
"""
|
"""
|
||||||
Get the global mpi config, if mpi config is not created, create a new one.
|
Get the global mpi config, if mpi config is not created, create a new one.
|
||||||
|
@ -67,13 +70,14 @@ def _mpi_config():
|
||||||
_k_mpi_config = _MpiConfig()
|
_k_mpi_config = _MpiConfig()
|
||||||
return _k_mpi_config
|
return _k_mpi_config
|
||||||
|
|
||||||
|
|
||||||
@args_type_check(enable_mpi=bool)
|
@args_type_check(enable_mpi=bool)
|
||||||
def _set_mpi_config(**kwargs):
|
def _set_mpi_config(**kwargs):
|
||||||
"""
|
"""
|
||||||
Sets mpi config for running environment.
|
Sets mpi config for running environment.
|
||||||
|
|
||||||
mpi config should be configured before running your program. If there is no configuration,
|
mpi config should be configured before running your program. If there is no configuration,
|
||||||
mpi moudle will be disabled by default.
|
mpi module will be disabled by default.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Attribute name is required for setting attributes.
|
Attribute name is required for setting attributes.
|
||||||
|
|
Loading…
Reference in New Issue