!19807 Add obj id to the phase as key of compile cache

Merge pull request !19807 from YuJianfeng/master
This commit is contained in:
i-robot 2021-07-10 02:14:58 +00:00 committed by Gitee
commit cd26959d40
2 changed files with 9 additions and 8 deletions

View File

@ -508,9 +508,9 @@ class _Executor:
key = generate_key(phase, dic)
obj.phase_prefix = str(key[1])
if 'export' in phase:
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time)
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time) + '.' + str(id(obj))
else:
phase = obj.phase_prefix + phase + '.' + str(obj.create_time)
phase = obj.phase_prefix + phase + '.' + str(obj.create_time) + '.' + str(id(obj))
if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase)
@ -582,15 +582,15 @@ class _Executor:
return self._executor.updata_param_node_default_input(phase, new_param)
def _get_shard_strategy(self, obj):
real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj))
return self._executor.get_strategy(real_phase)
def _get_num_parallel_ops(self, obj):
real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj))
return self._executor.get_num_parallel_ops(real_phase)
def _get_allreduce_fusion(self, obj):
real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time)
real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj))
return self._executor.get_allreduce_fusion(real_phase)
def has_compiled(self, phase='predict'):
@ -632,9 +632,9 @@ class _Executor:
Tensor/Tuple, return execute result.
"""
if phase == 'save':
return self._executor((), phase + '.' + str(obj.create_time))
return self._executor((), phase + '.' + str(obj.create_time) + '.' + str(id(obj)))
phase_real = obj.phase_prefix + phase + '.' + str(obj.create_time)
phase_real = obj.phase_prefix + phase + '.' + str(obj.create_time) + '.' + str(id(obj))
if self.has_compiled(phase_real):
return self._exec_pip(obj, *args, phase=phase_real)
raise KeyError('{} graph is not exist.'.format(phase_real))

View File

@ -258,7 +258,8 @@ class Cell(Cell_):
def get_func_graph_proto(self):
"""Return graph binary proto."""
return _executor._get_func_graph_proto(self, self.phase + "." + str(self.create_time), "anf_ir", True)
return _executor._get_func_graph_proto(self, self.phase + "." + str(self.create_time) + '.' + str(id(self)),
"anf_ir", True)
def __getattr__(self, name):
if '_params' in self.__dict__: