!32087 Fix dataset sink judge in cache enable case.

Merge pull request !32087 from ZPaC/add-dist-execution-mode
This commit is contained in:
i-robot 2022-03-28 13:28:15 +00:00 committed by Gitee
commit bc070833c7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 7 additions and 2 deletions

View File

@ -401,6 +401,7 @@ PYBIND11_MODULE(_c_expression, m) {
.def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
.def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
.def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.")
.def("cache_enable", &PSContext::cache_enable, "Get ps mode cache enable or not.")
.def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.")
.def("set_server_mode", &PSContext::set_server_mode, "Set server mode.")
.def("server_mode", &PSContext::server_mode, "Get server mode.")

View File

@ -279,6 +279,10 @@ def _set_cache_enable(cache_enable):
ps_context().set_cache_enable(cache_enable)
def _cache_enable():
return ps_context().cache_enable()
def _set_rank_id(rank_id):
ps_context().set_rank_id(rank_id)

View File

@ -31,7 +31,7 @@ from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Call
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode, _cache_enable
from ..nn.metrics import Loss
from .. import nn
from ..boost import AutoBoost
@ -883,7 +883,7 @@ class Model:
"is not equal to value in Model.train, got {} and {} separately."
.format(train_dataset._warmup_epoch, epoch))
if dataset_sink_mode and _is_ps_mode():
if dataset_sink_mode and _is_ps_mode() and not _cache_enable():
raise ValueError("Parameter server mode does not support 'data_sink_mode=True'.")
Validator.check_is_int(sink_size)