forked from mindspore-Ecosystem/mindspore
!6161 fix a bug with add set_grad in wide_and_deep network
Merge pull request !6161 from lvchangquan/master
This commit is contained in:
commit
92bd24304a
|
@ -138,6 +138,7 @@ class TrainOneStepCell(nn.Cell):
|
|||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(TrainOneStepCell, self).__init__(auto_prefix=True)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
|
@ -167,7 +168,6 @@ class TrainGAT(nn.Cell):
|
|||
def __init__(self, network, num_class, label, mask, learning_rate, l2_coeff):
|
||||
super(TrainGAT, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
loss_net = LossNetWrapper(network, num_class, label, mask, l2_coeff)
|
||||
optimizer = nn.Adam(loss_net.trainable_params(),
|
||||
learning_rate=learning_rate)
|
||||
|
|
|
@ -328,7 +328,6 @@ class TrainStepWrap(nn.Cell):
|
|||
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.set_train()
|
||||
self.trainable_params = network.trainable_params()
|
||||
weights_w = []
|
||||
|
@ -361,6 +360,8 @@ class TrainStepWrap(nn.Cell):
|
|||
self.sens = sens
|
||||
self.loss_net_w = IthOutputCell(network, output_index=0)
|
||||
self.loss_net_d = IthOutputCell(network, output_index=1)
|
||||
self.loss_net_w.set_grad()
|
||||
self.loss_net_d.set_grad()
|
||||
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer_w = None
|
||||
|
|
|
@ -509,7 +509,6 @@ class TrainStepWrap(nn.Cell):
|
|||
def __init__(self, network, config, sens=1000.0):
|
||||
super(TrainStepWrap, self).__init__()
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.set_train()
|
||||
self.trainable_params = network.trainable_params()
|
||||
weights_w = []
|
||||
|
@ -544,6 +543,8 @@ class TrainStepWrap(nn.Cell):
|
|||
self.sens = sens
|
||||
self.loss_net_w = IthOutputCell(network, output_index=0)
|
||||
self.loss_net_d = IthOutputCell(network, output_index=1)
|
||||
self.loss_net_w.set_grad()
|
||||
self.loss_net_w.set_grad()
|
||||
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer_w = None
|
||||
|
|
Loading…
Reference in New Issue