modify zoo transformer for training
parent 6d7e352524
commit e300aa8342
@@ -18,7 +18,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
-from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
@@ -231,7 +231,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
         super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.network.add_flags(defer_inline=True)
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True,
                                     sens_param=True)
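
Net effect of the two hunks: TransformerTrainOneStepWithLossScaleCell now takes its weight list from optimizer.parameters (the ParameterTuple that MindSpore's Optimizer base class already builds from the parameters it was given) instead of constructing a fresh ParameterTuple from network.trainable_params(), which is why the ParameterTuple import is dropped. Below is a minimal sketch of the same pattern in a plain one-step cell against the same-era MindSpore API; the class name TrainOneStepSketch is illustrative, not the repo's actual code.

# Sketch only: shows a one-step training cell that reads its weights
# from the optimizer, mirroring this commit's change.
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C


class TrainOneStepSketch(nn.Cell):
    """Bare-bones one-step training cell (no loss scaling)."""

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepSketch, self).__init__(auto_prefix=False)
        self.network = network
        # After this commit: reuse the ParameterTuple the Optimizer base class
        # already holds, so ParameterTuple does not need to be imported here.
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.sens = sens

    def construct(self, *inputs):
        # Forward pass, then backprop with a constant sensitivity, then update.
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        return F.depend(loss, self.optimizer(grads))

One common motivation for this change is that taking the weight list from the optimizer keeps the gradient tuple aligned, by construction, with the exact parameters the optimizer updates.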