Fix bug of automatic control dependency for BERT pre-training
Add comment
commit 2fdf692c2e
parent 4b702c4c66
@@ -403,9 +403,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
                   sens=None):
         """Defines the computation performed."""
         weights = self.weights
-        # alloc status
-        init = self.alloc_status()
-        self.clear_before_grad(init)
         loss = self.network(input_ids,
                             input_mask,
                             token_type_id,
@@ -417,6 +414,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
             scaling_sens = self.loss_scale
         else:
             scaling_sens = sens
+        # alloc status and clear should be right before gradoperation
+        init = self.alloc_status()
+        self.clear_before_grad(init)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
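Why the move matters: the alloc/clear of the NPU float-status register are side-effect ops, and MindSpore's automatic control-dependency pass orders them by their position relative to neighboring ops. Placing them immediately before the GradOperation call (instead of before the forward pass) ties the clear to the backward pass, so the overflow check that follows the gradients reads a freshly cleared register. The sketch below is a minimal, simplified illustration of that pattern, not the verbatim repository code: the full BERT input list, gradient unscaling, gradient clipping, and the dynamic loss-scale update are omitted, and anything not visible in the diff (get_status, reduce_sum, less_equal, base, the optimizer wiring, the loss_scale default) is an assumption.

import mindspore.nn as nn
from mindspore import Parameter, Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class TrainOneStepWithLossScaleSketch(nn.Cell):
    """Simplified sketch of the post-fix BertTrainOneStepWithLossScaleCell pattern."""

    def __init__(self, network, optimizer, loss_scale=2.0 ** 12):
        super(TrainOneStepWithLossScaleSketch, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        # Ascend overflow-status ops (same attribute names as in the diff above).
        self.alloc_status = P.NPUAllocFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        # Assumed helpers for the overflow check after the backward pass.
        self.get_status = P.NPUGetFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.less_equal = P.LessEqual()
        self.cast = P.Cast()
        self.base = Tensor(1, mstype.float32)
        self.loss_scale = Parameter(Tensor(loss_scale, mstype.float32),
                                    name="loss_scale", requires_grad=False)

    def construct(self, input_ids, input_mask, token_type_id, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        # Forward pass only; no status handling here any more.
        loss = self.network(input_ids, input_mask, token_type_id)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation, so the
        # automatically inserted control dependencies order the clear against
        # the backward pass instead of the forward pass.
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_mask,
                                                 token_type_id,
                                                 self.cast(scaling_sens, mstype.float32))
        # (Gradient unscaling and clipping from the real cell are omitted here.)
        # Read the status register after the backward pass to detect overflow.
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        overflow = self.less_equal(self.base, flag_sum)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        return F.depend((loss, overflow, scaling_sens), succ)

The first hunk of the diff removes the same two calls from before the forward pass; the only change in the commit is where they sit relative to self.grad, and the sketch simply mirrors that placement.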