!1174 Make the optimizer's parameters the same as those used for gradients

Merge pull request !1174 from ghzl/fix-beg-group-parameters
Author: mindspore-ci-bot, 2020-05-15 14:20:21 +08:00, committed by Gitee
Commit: 1234439661
2 changed files with 3 additions and 3 deletions
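
Both hunks make the same one-line change: the training wrapper's weight tuple is taken from optimizer.parameters instead of being rebuilt with ParameterTuple(network.trainable_params()). That keeps the parameters that gradients are computed for identical, in both identity and order, to the parameters the optimizer updates; judging from the source branch name, the mismatch surfaced when the optimizer was constructed from parameter groups. A runnable sketch of the wrapper pattern follows the first hunk below.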

@@ -141,7 +141,7 @@ class DistributedGradReducer(Cell):
 >>>         super(TrainingWrapper, self).__init__(auto_prefix=False)
 >>>         self.network = network
 >>>         self.network.add_flags(defer_inline=True)
->>>         self.weights = ParameterTuple(network.trainable_params())
+>>>         self.weights = optimizer.parameters
 >>>         self.optimizer = optimizer
 >>>         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
 >>>         self.sens = sens
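
For context, the docstring hunk above shows only the __init__ fragment. Below is a minimal runnable sketch of the whole wrapper pattern, assuming the MindSpore 0.x-era API visible in the hunks (GradOperation with a positional name argument) and a network cell that already returns a scalar loss; the construct body is reconstructed from the usual pattern of this era, not taken from this diff.

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class TrainingWrapper(nn.Cell):
    """One training step: compute the loss, its gradients, and apply the optimizer."""
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        # Take the weight tuple from the optimizer itself, so the parameters
        # we differentiate with respect to are exactly the ones, in the same
        # order, that the optimizer will update.
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.sens = sens

    def construct(self, data, label):
        loss = self.network(data, label)
        # Build the initial sensitivity (seed gradient) with value self.sens.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(data, label, sens)
        # Apply the parameter update, then return the loss.
        return F.depend(loss, self.optimizer(grads))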

@@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from ..cell import Cell
-from ...common import Tensor, ParameterTuple
+from ...common import Tensor
 from ...common.parameter import Parameter
 from ...ops import functional as F
 from ...ops import composite as C
@@ -201,7 +201,7 @@ class TrainOneStepWithLossScaleCell(Cell):
         super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.add_flags(defer_inline=True)
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.hyper_map = C.HyperMap()
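
The source branch name (fix-beg-group-parameters) suggests the bug showed up with grouped parameters. Below is a hedged sketch of such a setup, assuming MindSpore's grouped-parameter optimizer API (a list of dicts with a 'params' key); the single Dense layer and the name-based split are illustrative stand-ins, and TrainingWrapper is the sketch from above.

import mindspore.nn as nn

net = nn.Dense(16, 10)  # illustrative network with 'weight' and 'bias' parameters
decay_params = [p for p in net.trainable_params() if 'weight' in p.name]
no_decay_params = [p for p in net.trainable_params() if 'weight' not in p.name]
# Two groups: weight decay on weights only; unset keys fall back to defaults.
group_params = [{'params': decay_params, 'weight_decay': 1e-4},
                {'params': no_decay_params}]
optimizer = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)

# With groups, optimizer.parameters is the flattened tuple in group order, so
# taking self.weights from it guarantees grads[i] lines up with the parameter
# the optimizer updates at position i; a ParameterTuple rebuilt from
# net.trainable_params() need not share that order.
loss_net = nn.WithLossCell(net, nn.SoftmaxCrossEntropyWithLogits())
wrapper = TrainingWrapper(loss_net, optimizer)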