forked from mindspore-Ecosystem/mindspore
!31806 clean code
Merge pull request !31806 from liutongtong9/clean_code
commit bbe3ab24b3
@@ -294,7 +294,7 @@ class Cell(Cell_):
     def get_func_graph_proto(self):
         """Return graph binary proto."""
-        exec_id = self.phase + "." + str(self.create_time) + '.' + str(id(self))
+        exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
         return _cell_graph_executor._get_func_graph_proto(self, exec_id, "anf_ir", True)

     def __getattr__(self, name):
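For reference, the two ways of building exec_id are equivalent. The sketch below uses placeholder values (phase, create_time and obj_id are stand-ins for the real Cell attributes, which are not available outside MindSpore) to show that the join-based form produces the same string as the old concatenation:

# Placeholder values standing in for self.phase, self.create_time and id(self).
phase = "train"
create_time = 1650000000
obj_id = 140234

old_style = phase + "." + str(create_time) + "." + str(obj_id)
new_style = ".".join([phase, str(create_time), str(obj_id)])
assert old_style == new_style == "train.1650000000.140234"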
@@ -730,17 +730,6 @@ class Cell(Cell_):
         else:
             object.__setattr__(self, name, value)

-    def _check_param_list_tuple(self, value):
-        """
-        Check the type of input in list or tuple is Parameter.
-        :param value: list or tuple.
-        :return: The types of all inputs are parameter.
-        """
-        for item in value:
-            if not isinstance(item, Parameter):
-                return False
-        return True
-
     def __setattr__(self, name, value):
         cells = self.__dict__.get('_cells')
         params = self.__dict__.get('_params')
@@ -748,7 +737,7 @@ class Cell(Cell_):
             self._set_attr_for_parameter(name, value)
         elif isinstance(value, ParameterTuple):
            self._set_attr_for_parameter_tuple(name, value)
-        elif isinstance(value, (list, tuple)) and value and self._check_param_list_tuple(value):
+        elif isinstance(value, (list, tuple)) and value and _check_param_list_tuple(value):
             self._set_attr_for_parameter_in_list_or_tuple(name, value)
         elif isinstance(value, Cell):
             self._set_attr_for_cell(name, value)
@@ -2227,3 +2216,15 @@ class GraphCell(Cell):
         self.phase = "graph_load_from_mindir"
         self._add_attr("graph_load_from_mindir", self.graph)
         return self.compile_and_run(*inputs)
+
+
+def _check_param_list_tuple(value):
+    """
+    Check the type of input in list or tuple is Parameter.
+    :param value: list or tuple.
+    :return: The types of all inputs are parameter.
+    """
+    for item in value:
+        if not isinstance(item, Parameter):
+            return False
+    return True
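As a rough illustration of what the relocated helper accepts, the standalone sketch below uses a stand-in Parameter class (not MindSpore's) so it runs without importing mindspore:

# Stand-in Parameter class for illustration only; the real check uses
# mindspore's Parameter type.
class Parameter:
    pass

def _check_param_list_tuple(value):
    """Return True only when every item of the list/tuple is a Parameter."""
    for item in value:
        if not isinstance(item, Parameter):
            return False
    return True

print(_check_param_list_tuple([Parameter(), Parameter()]))    # True
print(_check_param_list_tuple((Parameter(), "not a param")))  # False

Note that __setattr__ additionally requires the list or tuple to be non-empty ("and value and ...") before calling the helper, so an empty container still falls through to the later branches.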
@@ -521,18 +521,16 @@ class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
         self.opt_shard = _get_enable_parallel_optimizer()

     def construct(self, *inputs):
-        weights = self.weights
         loss = self.network(*inputs)
         scaling_sens = self.scale_sense
         init = self.alloc_status()
-        status_clear = self.clear_before_grad(init)
         scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
-        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
+        grads = self.grad(self.network, self.weights)(*inputs, scaling_sens_filled)
         init = F.depend(init, grads)
         get_status = self.get_status(init)
         init = F.depend(init, get_status)
         flag_sum = self.reduce_sum(init, (0,))
-        loss = F.depend(loss, status_clear)
+        loss = F.depend(loss, self.clear_before_grad(init))
         if self.opt_shard:
             grads = self.grad_reducer(grads)
             grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)
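The scaling_sens_filled line builds a tensor with the same shape and dtype as loss, filled with the loss-scale value, before it is passed as the sensitivity to the grad call. A rough numpy analogue of that ones_like/cast pattern (numpy stands in for the MindSpore C/F ops here; shapes and dtypes are illustrative):

import numpy as np

loss = np.array(0.73, dtype=np.float16)   # stand-in for the network loss
scaling_sens = np.float32(1024.0)         # stand-in for self.scale_sense

# ones_like(loss) * cast(scaling_sens, dtype(loss)) -> same shape/dtype as loss
scaling_sens_filled = np.ones_like(loss) * scaling_sens.astype(loss.dtype)
print(scaling_sens_filled, scaling_sens_filled.dtype)  # 1024.0 float16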