!1597 add _set_dataset_mode_config to vm

Merge pull request !1597 from jinyaohui/sink_graph
This commit is contained in:
mindspore-ci-bot 2020-05-29 16:32:01 +08:00 committed by Gitee
commit c39902153c
2 changed files with 15 additions and 7 deletions

View File

@ -86,7 +86,10 @@ class ConfigManager {
DatasetMode dataset_mode() const { return dataset_mode_; }
void set_dataset_mode(DatasetMode mode) { dataset_mode_ = mode; }
int64_t iter_num() const { return iter_num_; }
int64_t iter_num() const {
if (dataset_mode_ == DS_NORMAL_MODE) return 1;
return iter_num_;
}
void set_iter_num(const int64_t num) { iter_num_ = num; }
std::string dataset_phase() const { return dataset_phase_; }

View File

@ -341,6 +341,15 @@ class _Executor:
param.init_data(layout, set_sliced=True)
obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
def _set_dataset_mode(self, args_list):
"""set dataset mode."""
# decide whether to sink based on whether the inputs is virtual or args_list is ()
if (args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag) or \
(args_list is not None and args_list == ()):
_set_dataset_mode_config('sink')
else:
_set_dataset_mode_config('normal')
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
"""
Compiles graph.
@ -371,6 +380,8 @@ class _Executor:
use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
self._set_dataset_mode(args_list)
if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase)
return phase, False
@ -399,12 +410,6 @@ class _Executor:
# the following GE init process is not needed when use vm or ms backend
if enable_ge:
# decide whether to sink based on whether the inputs is virtual or not
if args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag:
_set_dataset_mode_config('sink')
else:
_set_dataset_mode_config('normal')
self._build_data_graph(obj, params, phase)
if "export" not in phase: