Merge pull request !31806 from liutongtong9/clean_code
This commit is contained in:
i-robot 2022-03-24 08:18:23 +00:00 committed by Gitee
commit bbe3ab24b3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 16 additions and 17 deletions

View File

@ -294,7 +294,7 @@ class Cell(Cell_):
def get_func_graph_proto(self):
"""Return graph binary proto."""
exec_id = self.phase + "." + str(self.create_time) + '.' + str(id(self))
exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
return _cell_graph_executor._get_func_graph_proto(self, exec_id, "anf_ir", True)
def __getattr__(self, name):
@ -730,17 +730,6 @@ class Cell(Cell_):
else:
object.__setattr__(self, name, value)
def _check_param_list_tuple(self, value):
"""
Check the type of input in list or tuple is Parameter.
:param value: list or tuple.
:return: The types of all inputs are parameter.
"""
for item in value:
if not isinstance(item, Parameter):
return False
return True
def __setattr__(self, name, value):
cells = self.__dict__.get('_cells')
params = self.__dict__.get('_params')
@ -748,7 +737,7 @@ class Cell(Cell_):
self._set_attr_for_parameter(name, value)
elif isinstance(value, ParameterTuple):
self._set_attr_for_parameter_tuple(name, value)
elif isinstance(value, (list, tuple)) and value and self._check_param_list_tuple(value):
elif isinstance(value, (list, tuple)) and value and _check_param_list_tuple(value):
self._set_attr_for_parameter_in_list_or_tuple(name, value)
elif isinstance(value, Cell):
self._set_attr_for_cell(name, value)
@ -2227,3 +2216,15 @@ class GraphCell(Cell):
self.phase = "graph_load_from_mindir"
self._add_attr("graph_load_from_mindir", self.graph)
return self.compile_and_run(*inputs)
def _check_param_list_tuple(value):
"""
Check the type of input in list or tuple is Parameter.
:param value: list or tuple.
:return: The types of all inputs are parameter.
"""
for item in value:
if not isinstance(item, Parameter):
return False
return True

View File

@ -521,18 +521,16 @@ class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
self.opt_shard = _get_enable_parallel_optimizer()
def construct(self, *inputs):
weights = self.weights
loss = self.network(*inputs)
scaling_sens = self.scale_sense
init = self.alloc_status()
status_clear = self.clear_before_grad(init)
scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
grads = self.grad(self.network, self.weights)(*inputs, scaling_sens_filled)
init = F.depend(init, grads)
get_status = self.get_status(init)
init = F.depend(init, get_status)
flag_sum = self.reduce_sum(init, (0,))
loss = F.depend(loss, status_clear)
loss = F.depend(loss, self.clear_before_grad(init))
if self.opt_shard:
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)