forked from mindspore-Ecosystem/mindspore
Fix dataset sink judge in cache enable case.
This commit is contained in:
parent
95e5475f58
commit
92ba08e89f
|
@ -401,6 +401,7 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
|
||||
.def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
|
||||
.def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.")
|
||||
.def("cache_enable", &PSContext::cache_enable, "Get ps mode cache enable or not.")
|
||||
.def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.")
|
||||
.def("set_server_mode", &PSContext::set_server_mode, "Set server mode.")
|
||||
.def("server_mode", &PSContext::server_mode, "Get server mode.")
|
||||
|
|
|
@ -279,6 +279,10 @@ def _set_cache_enable(cache_enable):
|
|||
ps_context().set_cache_enable(cache_enable)
|
||||
|
||||
|
||||
def _cache_enable():
|
||||
return ps_context().cache_enable()
|
||||
|
||||
|
||||
def _set_rank_id(rank_id):
|
||||
ps_context().set_rank_id(rank_id)
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Call
|
|||
from .. import context
|
||||
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check
|
||||
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode
|
||||
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode, _cache_enable
|
||||
from ..nn.metrics import Loss
|
||||
from .. import nn
|
||||
from ..boost import AutoBoost
|
||||
|
@ -883,7 +883,7 @@ class Model:
|
|||
"is not equal to value in Model.train, got {} and {} separately."
|
||||
.format(train_dataset._warmup_epoch, epoch))
|
||||
|
||||
if dataset_sink_mode and _is_ps_mode():
|
||||
if dataset_sink_mode and _is_ps_mode() and not _cache_enable():
|
||||
raise ValueError("Parameter server mode does not support 'data_sink_mode=True'.")
|
||||
|
||||
Validator.check_is_int(sink_size)
|
||||
|
|
Loading…
Reference in New Issue