forked from mindspore-Ecosystem/mindspore
Judge cache enable in dataset helper
parent 70b2c67775
commit c7b48f1b9e
@@ -71,6 +71,7 @@ void PSContext::SetPSEnable(bool enabled) {
     if (node_id_.length() > kLength) {
      MS_LOG(EXCEPTION) << "The node id length can not exceed " << kLength;
    }
+    server_mode_ = kServerModePS;
  } else {
    MS_LOG(INFO) << "PS mode is disabled.";
    is_worker_ = false;
@@ -192,6 +193,13 @@ void PSContext::set_cache_enable(bool cache_enable) const {
 #endif
 }
 
+bool PSContext::cache_enable() const {
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
+  return PsDataPrefetch::GetInstance().cache_enable();
+#endif
+  return false;
+}
+
 void PSContext::set_rank_id(uint32_t rank_id) const {
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
   ps_cache_instance.set_rank_id(rank_id);
@@ -80,6 +80,7 @@ class PSContext {
   void InsertAccumuInitInfo(const std::string &param_name, float init_val) const;
   void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const;
   void set_cache_enable(bool cache_enable) const;
+  bool cache_enable() const;
   void set_rank_id(uint32_t rank_id) const;
 
   // In new server framework, process role, worker number, server number, scheduler ip and scheduler port should be set
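The two C++ hunks above pair the existing set_cache_enable() with a read accessor: cache_enable() returns the PsDataPrefetch cache state and falls back to false on builds without the embedding cache (Windows, or ENABLE_CPU not defined). A minimal standalone sketch of that fallback behaviour; the class below is a stand-in for illustration, not the real PSContext binding:

```python
# Stand-in sketch of the fallback behaviour of the new PSContext::cache_enable().
# Names here are illustrative; the real accessor lives in the C++ PSContext singleton.
class FakePSContext:
    def __init__(self, cache_supported, cache_enabled):
        self.cache_supported = cache_supported  # stands in for the ENABLE_CPU / !_WIN32 guard
        self.cache_enabled = cache_enabled      # stands in for PsDataPrefetch's cache state

    def cache_enable(self):
        if self.cache_supported:
            return self.cache_enabled
        return False  # same fallback as the `return false;` path after the #endif above


print(FakePSContext(cache_supported=True, cache_enabled=True).cache_enable())   # True
print(FakePSContext(cache_supported=False, cache_enabled=True).cache_enable())  # False
```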
@@ -28,7 +28,7 @@ from .._checkparam import Validator
 from .._c_expression import Tensor as Tensor_
 from ..parallel._tensor import _get_slice_index
 from ..parallel._auto_parallel_context import auto_parallel_context
-from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table
+from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table, _is_fl_mode
 from ..parallel._ps_context import _reinsert_hash_table_size
 from ..parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info
 from .seed import _get_global_and_op_seed
@@ -229,7 +229,7 @@ class Parameter(Tensor_):
         if isinstance(data, bool):
             raise ValueError('Parameter data can not be `bool`')
         if isinstance(data, Tensor) and data.has_init:
-            if context.get_fl_context('server_mode') not in ('FEDERATED_LEARNING', 'HYBRID_TRAINING'):
+            if not _is_fl_mode():
                 if _is_in_parallel_mode() or _is_role_worker() or _is_role_sched() or _is_role_pserver():
                     # do not init data while in auto parallel.
                     return (Tensor, None, data.dtype, data.shape, data.init)
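The second parameter.py hunk is a pure refactor: the inline server_mode membership test is replaced by the _is_fl_mode() helper imported above. A standalone sketch showing the two checks agree, with the context lookup reduced to a plain argument so the snippet runs by itself:

```python
# Equivalence sketch for the parameter.py change; no MindSpore import needed.
_FL_MODES = ("FEDERATED_LEARNING", "HYBRID_TRAINING")

def _is_fl_mode(server_mode):
    # Same membership test as the helper added in _ps_context.py, with the
    # server_mode value passed in explicitly instead of read from the context.
    return server_mode in _FL_MODES

for mode in ("PARAMETER_SERVER", "FEDERATED_LEARNING", "HYBRID_TRAINING"):
    old_check = mode not in _FL_MODES   # the inline test removed above
    new_check = not _is_fl_mode(mode)   # the helper-based test that replaces it
    assert old_check == new_check
```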
@@ -252,6 +252,14 @@ def _set_rank_id(rank_id):
     ps_context().set_rank_id(rank_id)
 
 
+def _is_ps_mode():
+    return _get_ps_context("server_mode") == "PARAMETER_SERVER"
+
+
+def _is_fl_mode():
+    return _get_ps_context("server_mode") in ("FEDERATED_LEARNING", "HYBRID_TRAINING")
+
+
 def _check_value(key, value):
     """
     Validate the value for parameter server context keys.
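The two new helpers split server_mode into the parameter-server case and the two federated-learning cases. A standalone usage sketch, with _get_ps_context stubbed by a dict instead of MindSpore's real PS context:

```python
# Standalone sketch of the helpers added above; _get_ps_context is a stub.
_ctx = {"server_mode": "PARAMETER_SERVER"}

def _get_ps_context(key):
    return _ctx[key]

def _is_ps_mode():
    return _get_ps_context("server_mode") == "PARAMETER_SERVER"

def _is_fl_mode():
    return _get_ps_context("server_mode") in ("FEDERATED_LEARNING", "HYBRID_TRAINING")

for mode in ("PARAMETER_SERVER", "FEDERATED_LEARNING", "HYBRID_TRAINING"):
    _ctx["server_mode"] = mode
    # Exactly one of the two helpers holds for each of the three named modes.
    assert _is_ps_mode() != _is_fl_mode()
```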
@@ -20,14 +20,10 @@ from mindspore.common.dtype import pytype_to_dtype
 from .. import context, nn
 from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
 from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes, _get_pipeline_stages
-from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
+from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode
 from ..ops import operations as P
 
 
-def _is_fl_mode():
-    return context.get_fl_context("server_mode") in ("FEDERATED_LEARNING", "HYBRID_TRAINING")
-
-
 def _send_data(dataset, epoch_num):
     """Engine dataset to write data to tdt queue."""
     if not hasattr(dataset, '__has_sent__'):
@@ -58,7 +54,7 @@ def _dynamic_sink_exception_scenario(dataset_iter):
     """The exception scenario for dynamic data is not applicable."""
     _, dataset_shapes = dataset_iter.types_shapes()
 
-    if _has_dynamic_shape(dataset_shapes) or (_is_role_worker() and not _is_fl_mode()) or \
+    if _has_dynamic_shape(dataset_shapes) or (_is_role_worker() and _is_ps_mode()) or \
             context.get_context("mode") != context.GRAPH_MODE:
         return True
     return False
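This hunk carries the behavioural part of the change: the worker-side condition tightens from "worker and not FL mode" to "worker and PS mode". The two agree when server_mode is one of the three named values, but differ when it holds anything else, in which case a worker no longer takes the PS-specific path. A standalone comparison (the empty string below stands in for an unconfigured server_mode and is an assumption for illustration):

```python
# Comparison of the old and new worker-side predicates used in dataset_helper.py.
_FL_MODES = ("FEDERATED_LEARNING", "HYBRID_TRAINING")

def old_predicate(is_worker, server_mode):
    return is_worker and server_mode not in _FL_MODES       # _is_role_worker() and not _is_fl_mode()

def new_predicate(is_worker, server_mode):
    return is_worker and server_mode == "PARAMETER_SERVER"  # _is_role_worker() and _is_ps_mode()

for mode in ("PARAMETER_SERVER", "FEDERATED_LEARNING", "HYBRID_TRAINING", ""):
    print(repr(mode), old_predicate(True, mode), new_predicate(True, mode))
# Only the last row differs: with any other server_mode value, the old check still
# routed a worker through the PS path; the new check does not.
```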
@@ -264,7 +260,7 @@ class DatasetHelper:
         if context.get_context("mode") == context.GRAPH_MODE:
             if _is_role_sched() or _is_role_pserver():
                 iterclass = _DatasetIterPSServer
-            elif _is_role_worker() and not _is_fl_mode():
+            elif _is_role_worker() and _is_ps_mode():
                 iterclass = _DatasetIterPSWork
             elif (context.get_context("device_target") == "Ascend") or \
                     (context.get_context("device_target") == "GPU"):
@@ -415,7 +411,7 @@ class _DatasetIter:
         sink_size = 1
         if hasattr(self.dataset, '__loop_size__'):
             sink_size = self.dataset.__loop_size__
-        elif _is_role_worker() and not _is_fl_mode():
+        elif _is_role_worker() and _is_ps_mode():
             # PS mode does not support loop sink.
             sink_size = 1
         else:
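For reference, the sink-size decision in the last hunk reduced to plain booleans so the control flow can be read in isolation; default_sink_size is a placeholder for whatever the truncated else: branch computes in the original:

```python
# Sketch of the sink_size selection in _DatasetIter after this change.
def choose_sink_size(dataset_loop_size, is_worker, is_ps_mode, default_sink_size):
    if dataset_loop_size is not None:     # hasattr(self.dataset, '__loop_size__')
        return dataset_loop_size
    if is_worker and is_ps_mode:          # PS mode does not support loop sink
        return 1
    return default_sink_size              # the else: branch is truncated in this excerpt

print(choose_sink_size(None, is_worker=True, is_ps_mode=True, default_sink_size=100))  # 1
```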