forked from mindspore-Ecosystem/mindspore
!1174 Make optimizer parameter same as gradient
Merge pull request !1174 from ghzl/fix-beg-group-parameters
commit 1234439661
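The branch name (fix-beg-group-parameters) and the title suggest the motivation: when an optimizer is constructed from grouped parameters, the tuple it actually updates (optimizer.parameters) need not coincide with network.trainable_params(), so a wrapper that differentiates with respect to trainable_params() can pair gradients with the wrong parameters. Below is a minimal sketch of that setup, not taken from this commit; the network, parameter split, and learning rates are illustrative, and it assumes the grouped-parameter API (dicts with 'params' and 'lr' keys) that MindSpore optimizers accept.

import mindspore.nn as nn

# Any Cell with trainable parameters would do for the illustration.
net = nn.Dense(16, 4)

# Grouped parameters: the optimizer flattens the groups into its own
# ParameterTuple, exposed as opt.parameters.
bias_params = [p for p in net.trainable_params() if 'bias' in p.name]
weight_params = [p for p in net.trainable_params() if 'bias' not in p.name]
group_params = [{'params': weight_params, 'lr': 0.1},
                {'params': bias_params, 'lr': 0.01}]
opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)

# After this change, training wrappers take their weights from the optimizer,
# so the gradient list matches exactly the parameters the optimizer updates:
weights = opt.parameters   # rather than ParameterTuple(net.trainable_params())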
@@ -141,7 +141,7 @@ class DistributedGradReducer(Cell):
     >>>         super(TrainingWrapper, self).__init__(auto_prefix=False)
     >>>         self.network = network
     >>>         self.network.add_flags(defer_inline=True)
-    >>>         self.weights = ParameterTuple(network.trainable_params())
+    >>>         self.weights = optimizer.parameters
     >>>         self.optimizer = optimizer
     >>>         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
     >>>         self.sens = sens
@@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from ..cell import Cell
-from ...common import Tensor, ParameterTuple
+from ...common import Tensor
 from ...common.parameter import Parameter
 from ...ops import functional as F
 from ...ops import composite as C
@@ -201,7 +201,7 @@ class TrainOneStepWithLossScaleCell(Cell):
         super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.add_flags(defer_inline=True)
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.hyper_map = C.HyperMap()
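For context, here is a self-contained sketch of a training wrapper written against the same 0.x-era API the diff uses (GradOperation taking a name string). The class mirrors the TrainingWrapper docstring example touched in the first hunk; everything outside the lines shown in that hunk is illustrative rather than part of this commit.

import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class TrainingWrapper(nn.Cell):
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        # The point of this commit: take the weights from the optimizer so the
        # gradients are computed for exactly the parameters it will update,
        # even when the optimizer was built from grouped parameters.
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.sens = sens

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        # Seed the backward pass with a constant sensitivity of self.sens.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        # Applying the optimizer inside depend() keeps the update in the graph.
        return F.depend(loss, self.optimizer(grads))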