diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 39f1443b8d6..a0c8474afbb 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -401,6 +401,7 @@ PYBIND11_MODULE(_c_expression, m) { .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.") .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.") .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.") + .def("cache_enable", &PSContext::cache_enable, "Get ps mode cache enable or not.") .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.") .def("set_server_mode", &PSContext::set_server_mode, "Set server mode.") .def("server_mode", &PSContext::server_mode, "Get server mode.") diff --git a/mindspore/python/mindspore/parallel/_ps_context.py b/mindspore/python/mindspore/parallel/_ps_context.py index f170c0fe455..1df855b5400 100644 --- a/mindspore/python/mindspore/parallel/_ps_context.py +++ b/mindspore/python/mindspore/parallel/_ps_context.py @@ -279,6 +279,10 @@ def _set_cache_enable(cache_enable): ps_context().set_cache_enable(cache_enable) +def _cache_enable(): + return ps_context().cache_enable() + + def _set_rank_id(rank_id): ps_context().set_rank_id(rank_id) diff --git a/mindspore/python/mindspore/train/model.py b/mindspore/python/mindspore/train/model.py index 00c68baa589..dee7d8c72c2 100644 --- a/mindspore/python/mindspore/train/model.py +++ b/mindspore/python/mindspore/train/model.py @@ -31,7 +31,7 @@ from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Call from .. import context from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check -from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode +from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode, _cache_enable from ..nn.metrics import Loss from .. import nn from ..boost import AutoBoost @@ -883,7 +883,7 @@ class Model: "is not equal to value in Model.train, got {} and {} separately." .format(train_dataset._warmup_epoch, epoch)) - if dataset_sink_mode and _is_ps_mode(): + if dataset_sink_mode and _is_ps_mode() and not _cache_enable(): raise ValueError("Parameter server mode does not support 'data_sink_mode=True'.") Validator.check_is_int(sink_size)