Judge cache enable in dataset helper
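
Enabling PS mode via PSContext::SetPSEnable now records kServerModePS in
server_mode_, and _ps_context.py gains _is_ps_mode()/_is_fl_mode() helpers so
Python code can test the configured server mode without comparing raw strings.
Parameter and the dataset helper switch to these helpers, and PSContext exposes
a cache_enable() getter alongside the existing setter.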

ZPaC 2021-12-13 11:46:22 +08:00
parent 70b2c67775
commit c7b48f1b9e
5 changed files with 23 additions and 10 deletions

View File

@@ -71,6 +71,7 @@ void PSContext::SetPSEnable(bool enabled) {
     if (node_id_.length() > kLength) {
       MS_LOG(EXCEPTION) << "The node id length can not exceed " << kLength;
     }
+    server_mode_ = kServerModePS;
   } else {
     MS_LOG(INFO) << "PS mode is disabled.";
     is_worker_ = false;
@@ -192,6 +193,13 @@ void PSContext::set_cache_enable(bool cache_enable) const {
 #endif
 }
 
+bool PSContext::cache_enable() const {
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
+  return PsDataPrefetch::GetInstance().cache_enable();
+#endif
+  return false;
+}
+
 void PSContext::set_rank_id(uint32_t rank_id) const {
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
   ps_cache_instance.set_rank_id(rank_id);
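
Aside (not part of the diff): a minimal sketch of how the new getter could be
queried from Python, assuming the pybind layer exposes cache_enable() on the
same ps_context() handle that _ps_context.py already uses for set_rank_id();
the exact binding name is an assumption, not something this commit confirms.

# Hypothetical usage sketch: ps_context() is the pybind singleton visible in
# _ps_context.py below; whether it exposes cache_enable() under this exact
# name is an assumption.
from mindspore.parallel._ps_context import ps_context

if ps_context().cache_enable():
    # On Windows or non-ENABLE_CPU builds the #if guard above makes the
    # C++ getter return false unconditionally.
    print("embedding cache is enabled")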

View File

@@ -80,6 +80,7 @@ class PSContext {
   void InsertAccumuInitInfo(const std::string &param_name, float init_val) const;
   void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const;
   void set_cache_enable(bool cache_enable) const;
+  bool cache_enable() const;
   void set_rank_id(uint32_t rank_id) const;
 
   // In new server framework, process role, worker number, server number, scheduler ip and scheduler port should be set

View File

@@ -28,7 +28,7 @@ from .._checkparam import Validator
 from .._c_expression import Tensor as Tensor_
 from ..parallel._tensor import _get_slice_index
 from ..parallel._auto_parallel_context import auto_parallel_context
-from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table
+from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table, _is_fl_mode
 from ..parallel._ps_context import _reinsert_hash_table_size
 from ..parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info
 from .seed import _get_global_and_op_seed
@@ -229,7 +229,7 @@ class Parameter(Tensor_):
         if isinstance(data, bool):
             raise ValueError('Parameter data can not be `bool`')
         if isinstance(data, Tensor) and data.has_init:
-            if context.get_fl_context('server_mode') not in ('FEDERATED_LEARNING', 'HYBRID_TRAINING'):
+            if not _is_fl_mode():
                 if _is_in_parallel_mode() or _is_role_worker() or _is_role_sched() or _is_role_pserver():
                     # do not init data while in auto parallel.
                     return (Tensor, None, data.dtype, data.shape, data.init)
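
Aside: this hunk is a pure refactor, assuming get_fl_context('server_mode') and
_get_ps_context("server_mode") read the same PSContext field; the mode tuple in
the removed line matches the one inside _is_fl_mode() in _ps_context.py exactly.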

View File

@@ -252,6 +252,14 @@ def _set_rank_id(rank_id):
     ps_context().set_rank_id(rank_id)
 
 
+def _is_ps_mode():
+    return _get_ps_context("server_mode") == "PARAMETER_SERVER"
+
+
+def _is_fl_mode():
+    return _get_ps_context("server_mode") in ("FEDERATED_LEARNING", "HYBRID_TRAINING")
+
+
 def _check_value(key, value):
     """
     Validate the value for parameter server context keys.
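
Aside (not part of the diff): the two new predicates are not complements of
each other. A standalone sketch, taking the mode string as a parameter instead
of reading _get_ps_context("server_mode"), shows where _is_ps_mode() and the
old `not _is_fl_mode()` guard disagree; the SetPSEnable change above is what
makes the stricter check safe, since PS workers now always carry kServerModePS.

# Standalone sketch of the two predicates, parameterized on the mode string.
def is_ps_mode(server_mode):
    return server_mode == "PARAMETER_SERVER"

def is_fl_mode(server_mode):
    return server_mode in ("FEDERATED_LEARNING", "HYBRID_TRAINING")

for mode in ("PARAMETER_SERVER", "FEDERATED_LEARNING", "HYBRID_TRAINING", ""):
    print(mode or "<unset>", is_ps_mode(mode), not is_fl_mode(mode))
# PARAMETER_SERVER   -> True,  True
# FEDERATED_LEARNING -> False, False
# HYBRID_TRAINING    -> False, False
# <unset>            -> False, True   (the only case where the guards differ)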

View File

@@ -20,14 +20,10 @@ from mindspore.common.dtype import pytype_to_dtype
 from .. import context, nn
 from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
 from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes, _get_pipeline_stages
-from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched
+from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode
 from ..ops import operations as P
 
 
-def _is_fl_mode():
-    return context.get_fl_context("server_mode") in ("FEDERATED_LEARNING", "HYBRID_TRAINING")
-
-
 def _send_data(dataset, epoch_num):
     """Engine dataset to write data to tdt queue."""
     if not hasattr(dataset, '__has_sent__'):
@@ -58,7 +54,7 @@ def _dynamic_sink_exception_scenario(dataset_iter):
     """The exception scenario for dynamic data is not applicable."""
     _, dataset_shapes = dataset_iter.types_shapes()
 
-    if _has_dynamic_shape(dataset_shapes) or (_is_role_worker() and not _is_fl_mode()) or \
+    if _has_dynamic_shape(dataset_shapes) or (_is_role_worker() and _is_ps_mode()) or \
             context.get_context("mode") != context.GRAPH_MODE:
         return True
     return False
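
Aside: this is the behavioral point of the commit. A worker whose server mode
is neither PARAMETER_SERVER nor an FL mode used to satisfy `not _is_fl_mode()`
and was forced into the non-sink exception path; with `_is_ps_mode()` it no
longer is, while genuine PS workers still match because SetPSEnable now records
kServerModePS. The same substitution repeats in the two hunks below.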
@@ -264,7 +260,7 @@ class DatasetHelper:
         if context.get_context("mode") == context.GRAPH_MODE:
             if _is_role_sched() or _is_role_pserver():
                 iterclass = _DatasetIterPSServer
-            elif _is_role_worker() and not _is_fl_mode():
+            elif _is_role_worker() and _is_ps_mode():
                 iterclass = _DatasetIterPSWork
             elif (context.get_context("device_target") == "Ascend") or \
                     (context.get_context("device_target") == "GPU"):
@@ -415,7 +411,7 @@ class _DatasetIter:
         sink_size = 1
         if hasattr(self.dataset, '__loop_size__'):
             sink_size = self.dataset.__loop_size__
-        elif _is_role_worker() and not _is_fl_mode():
+        elif _is_role_worker() and _is_ps_mode():
            # PS mode does not support loop sink.
            sink_size = 1
        else: