!1597 add _set_dataset_mode_config to vm
Merge pull request !1597 from jinyaohui/sink_graph
This commit is contained in:
commit
c39902153c
|
@ -86,7 +86,10 @@ class ConfigManager {
|
|||
|
||||
DatasetMode dataset_mode() const { return dataset_mode_; }
|
||||
void set_dataset_mode(DatasetMode mode) { dataset_mode_ = mode; }
|
||||
int64_t iter_num() const { return iter_num_; }
|
||||
int64_t iter_num() const {
|
||||
if (dataset_mode_ == DS_NORMAL_MODE) return 1;
|
||||
return iter_num_;
|
||||
}
|
||||
void set_iter_num(const int64_t num) { iter_num_ = num; }
|
||||
|
||||
std::string dataset_phase() const { return dataset_phase_; }
|
||||
|
|
|
@ -341,6 +341,15 @@ class _Executor:
|
|||
param.init_data(layout, set_sliced=True)
|
||||
obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
|
||||
|
||||
def _set_dataset_mode(self, args_list):
|
||||
"""set dataset mode."""
|
||||
# decide whether to sink based on whether the inputs is virtual or args_list is ()
|
||||
if (args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag) or \
|
||||
(args_list is not None and args_list == ()):
|
||||
_set_dataset_mode_config('sink')
|
||||
else:
|
||||
_set_dataset_mode_config('normal')
|
||||
|
||||
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
|
||||
"""
|
||||
Compiles graph.
|
||||
|
@ -371,6 +380,8 @@ class _Executor:
|
|||
|
||||
use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
|
||||
|
||||
self._set_dataset_mode(args_list)
|
||||
|
||||
if phase in self.compile_cache.keys():
|
||||
logger.debug("%r graph has existed.", phase)
|
||||
return phase, False
|
||||
|
@ -399,12 +410,6 @@ class _Executor:
|
|||
|
||||
# the following GE init process is not needed when use vm or ms backend
|
||||
if enable_ge:
|
||||
# decide whether to sink based on whether the inputs is virtual or not
|
||||
if args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag:
|
||||
_set_dataset_mode_config('sink')
|
||||
else:
|
||||
_set_dataset_mode_config('normal')
|
||||
|
||||
self._build_data_graph(obj, params, phase)
|
||||
|
||||
if "export" not in phase:
|
||||
|
|
Loading…
Reference in New Issue