diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py
index cd705edcba5..659413f49c7 100755
--- a/mindspore/python/mindspore/nn/cell.py
+++ b/mindspore/python/mindspore/nn/cell.py
@@ -294,7 +294,7 @@ class Cell(Cell_):
 
     def get_func_graph_proto(self):
         """Return graph binary proto."""
-        exec_id = self.phase + "." + str(self.create_time) + '.' + str(id(self))
+        exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
         return _cell_graph_executor._get_func_graph_proto(self, exec_id, "anf_ir", True)
 
     def __getattr__(self, name):
@@ -730,17 +730,6 @@ class Cell(Cell_):
         else:
             object.__setattr__(self, name, value)
 
-    def _check_param_list_tuple(self, value):
-        """
-        Check the type of input in list or tuple is Parameter.
-        :param value: list or tuple.
-        :return: The types of all inputs are parameter.
-        """
-        for item in value:
-            if not isinstance(item, Parameter):
-                return False
-        return True
-
     def __setattr__(self, name, value):
         cells = self.__dict__.get('_cells')
         params = self.__dict__.get('_params')
@@ -748,7 +737,7 @@ class Cell(Cell_):
             self._set_attr_for_parameter(name, value)
         elif isinstance(value, ParameterTuple):
             self._set_attr_for_parameter_tuple(name, value)
-        elif isinstance(value, (list, tuple)) and value and self._check_param_list_tuple(value):
+        elif isinstance(value, (list, tuple)) and value and _check_param_list_tuple(value):
            self._set_attr_for_parameter_in_list_or_tuple(name, value)
         elif isinstance(value, Cell):
             self._set_attr_for_cell(name, value)
@@ -2227,3 +2216,15 @@ class GraphCell(Cell):
         self.phase = "graph_load_from_mindir"
         self._add_attr("graph_load_from_mindir", self.graph)
         return self.compile_and_run(*inputs)
+
+
+def _check_param_list_tuple(value):
+    """
+    Check the type of input in list or tuple is Parameter.
+    :param value: list or tuple.
+    :return: The types of all inputs are parameter.
+    """
+    for item in value:
+        if not isinstance(item, Parameter):
+            return False
+    return True
diff --git a/mindspore/python/mindspore/nn/wrap/loss_scale.py b/mindspore/python/mindspore/nn/wrap/loss_scale.py
index 63e3644b2a3..8cbd3be78b0 100644
--- a/mindspore/python/mindspore/nn/wrap/loss_scale.py
+++ b/mindspore/python/mindspore/nn/wrap/loss_scale.py
@@ -521,18 +521,16 @@ class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
         self.opt_shard = _get_enable_parallel_optimizer()
 
     def construct(self, *inputs):
-        weights = self.weights
         loss = self.network(*inputs)
         scaling_sens = self.scale_sense
         init = self.alloc_status()
-        status_clear = self.clear_before_grad(init)
         scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
-        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
+        grads = self.grad(self.network, self.weights)(*inputs, scaling_sens_filled)
         init = F.depend(init, grads)
         get_status = self.get_status(init)
         init = F.depend(init, get_status)
         flag_sum = self.reduce_sum(init, (0,))
-        loss = F.depend(loss, status_clear)
+        loss = F.depend(loss, self.clear_before_grad(init))
         if self.opt_shard:
             grads = self.grad_reducer(grads)
             grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)
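
For reviewers, a minimal standalone sketch of how the relocated _check_param_list_tuple behaves once it is a module-level function. This is illustrative only, not part of the patch; the sample Parameter/Tensor values and names below are made up for the demo.

# Illustration only: after the move, the helper no longer depends on the
# Cell instance, so it can be exercised as a plain function.
import numpy as np
from mindspore import Parameter, Tensor

def _check_param_list_tuple(value):
    """Return True only if every element of the list/tuple is a Parameter."""
    for item in value:
        if not isinstance(item, Parameter):
            return False
    return True

# Hypothetical sample data for the demo.
params = [Parameter(Tensor(np.ones((2, 2), np.float32)), name="w0"),
          Parameter(Tensor(np.zeros((2, 2), np.float32)), name="w1")]
assert _check_param_list_tuple(params)            # all Parameters -> True
assert not _check_param_list_tuple(params + [1])  # mixed contents -> False

Functionally this is equivalent to all(isinstance(item, Parameter) for item in value); taking it out of the class lets __setattr__ call it without touching instance state.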