From eb48c8d64764a6d0537f4fd1a1620bd95eecc2f9 Mon Sep 17 00:00:00 2001 From: wilfChen Date: Mon, 19 Oct 2020 15:50:31 +0800 Subject: [PATCH] compile profiling --- mindspore/common/api.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 279a93dc8a8..7f368faea93 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -383,8 +383,6 @@ class _Executor: Str, the full phase of the cell. Bool, if the graph has been compiled before, return False, else return True. """ - obj.check_names() - _check_full_batch() args_names, args_list = _generate_pip_args(obj, *args) dic = dict(zip(args_names, args_list)) key = generate_key(phase, dic) @@ -393,22 +391,23 @@ class _Executor: phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time) else: phase = self.phase_prefix + phase + '.' + str(obj.create_time) - enable_debug_runtime = context.get_context("enable_debug_runtime") - enable_ge = context.get_context("enable_ge") - - use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE) - - self._set_dataset_mode(args_list) if phase in self.compile_cache.keys(): logger.debug("%r graph has existed.", phase) return phase, False + obj.check_names() + _check_full_batch() + self._set_dataset_mode(args_list) + is_sink_mode = args and isinstance(args[0], Tensor) and args[0].virtual_flag if auto_parallel_mode and _need_to_full() and not is_sink_mode and obj.auto_parallel_compile_and_run(): args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank()) _, args_list = _generate_pip_args(obj, *args_full) + enable_debug_runtime = context.get_context("enable_debug_runtime") + enable_ge = context.get_context("enable_ge") + use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE) result = self._executor.compile(obj, args_list, phase, use_vm) self.compile_cache[phase] = phase if not result: