forked from mindspore-Ecosystem/mindspore
remove ControlDepend from ternarybert
This commit is contained in:
parent 6cd1a0b484
commit 4ef7251e7f
@@ -423,9 +423,9 @@ class Depend(Primitive):
     In order to ensure that operator A is executed before operator B, it is recommended to
     insert the Depend operator between operators A and B. The usage method is as follows::
 
-        out_a = A(in_a)
-        in_b = Depend(in_b, out_a)
-        out_b = B(in_b)
+        a = A(x)    --->    a = A(x)
+        b = B(y)    --->    y = Depend(y, a)
+                    --->    b = B(y)
 
     Inputs:
         - **value** (Tensor) - the real value to return for depend operator.
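As a companion to the docstring hunk above, here is a minimal runnable
sketch of the pattern it describes. This is an illustration assuming
MindSpore 1.x: OrderedNet, net_a and net_b are hypothetical names, not
identifiers from this commit.

    import mindspore.nn as nn
    from mindspore.ops import functional as F

    class OrderedNet(nn.Cell):
        """Forces net_a to run before net_b via a data dependency."""
        def __init__(self, net_a, net_b):
            super(OrderedNet, self).__init__()
            self.net_a = net_a
            self.net_b = net_b

        def construct(self, x, y):
            a = self.net_a(x)        # a = A(x)
            y = F.depend(y, a)       # y = Depend(y, a)
            return self.net_b(y)     # b = B(y)

F.depend(y, a) returns y unchanged but makes every consumer of y wait
for a, which is exactly the ordering guarantee the docstring promises.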
@@ -377,21 +377,16 @@ class BertTrainWithLossScaleCell(nn.Cell):
                   sens=None):
         """Defines the computation performed."""
         weights = self.weights
-        saved = ()
         for i in range(self.length):
-            saved = saved + (F.assign(self.saved_params[i], weights[i]),)
+            F.assign(self.saved_params[i], weights[i])
 
         for i in range(self.quant_embedding_list_length):
             quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
-            quant_embedding = F.depend(quant_embedding, saved)
-            assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
-            input_ids = F.depend(input_ids, assign_embedding)
+            F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
 
         for i in range(self.quant_weight_list_length):
             quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
-            quant_weight = F.depend(quant_weight, saved)
-            assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight)
-            input_ids = F.depend(input_ids, assign_weight)
+            F.assign(weights[self.quant_weight_list[i]], quant_weight)
 
         if sens is None:
             scaling_sens = self.loss_scale
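Why the bare F.assign calls now suffice: with ControlDepend removed,
MindSpore treats Assign as a side-effect operator and sequences it in
program order inside construct, so the old bookkeeping that threaded
every Assign output through F.depend (the saved tuple and the re-bound
input_ids) is no longer needed to keep the stores ordered or alive.
A hedged sketch of just the snapshot step, with free functions standing
in for the cell's state; the names are illustrative, not from the repo.

    from mindspore.ops import functional as F

    def snapshot_old(saved_params, weights, length):
        # Pre-removal style: collect the Assign outputs so F.depend
        # edges can stop the compiler from reordering or pruning them.
        saved = ()
        for i in range(length):
            saved = saved + (F.assign(saved_params[i], weights[i]),)
        return saved

    def snapshot_new(saved_params, weights, length):
        # Post-removal style: a bare Assign is preserved and ordered by
        # the framework's side-effect handling; no explicit edges.
        for i in range(length):
            F.assign(saved_params[i], weights[i])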
@@ -411,10 +406,10 @@ class BertTrainWithLossScaleCell(nn.Cell):
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
-        restore = ()
         for i in range(self.length):
-            weights[i] = F.depend(weights[i], grads)
-            restore = restore + (F.assign(weights[i], self.saved_params[i]),)
+            param = F.depend(self.saved_params[i], grads)
+            F.assign(weights[i], param)
 
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
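One explicit edge survives the cleanup: the full-precision weights must
not be restored before the gradients have been computed through the
quantized weights, so the saved copy is routed through F.depend(...,
grads) before being assigned back. A minimal sketch of that remaining
dependency, under the same assumptions as above (illustrative names):

    from mindspore.ops import functional as F

    def restore_weights(weights, saved_params, grads, length):
        for i in range(length):
            # Reading the saved FP32 copy waits on the gradient tuple,
            # so the restore cannot be hoisted above the backward pass.
            param = F.depend(saved_params[i], grads)
            F.assign(weights[i], param)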
@@ -431,7 +426,6 @@ class BertTrainWithLossScaleCell(nn.Cell):
             succ = False
         else:
             succ = self.optimizer(grads)
-            succ = F.depend(succ, restore)
         return succ
 
@@ -490,21 +484,16 @@ class BertTrainCell(nn.Cell):
                   label_ids):
         """Defines the computation performed."""
         weights = self.weights
-        saved = ()
         for i in range(self.length):
-            saved = saved + (F.assign(self.saved_params[i], weights[i]),)
+            F.assign(self.saved_params[i], weights[i])
 
         for i in range(self.quant_embedding_list_length):
             quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
-            quant_embedding = F.depend(quant_embedding, saved)
-            assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
-            input_ids = F.depend(input_ids, assign_embedding)
+            F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
 
         for i in range(self.quant_weight_list_length):
             quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
-            quant_weight = F.depend(quant_weight, saved)
-            assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight)
-            input_ids = F.depend(input_ids, assign_weight)
+            F.assign(weights[self.quant_weight_list[i]], quant_weight)
 
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
@@ -515,11 +504,10 @@ class BertTrainCell(nn.Cell):
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
-        restore = ()
         for i in range(self.length):
-            weights[i] = F.depend(weights[i], grads)
-            restore = restore + (F.assign(weights[i], self.saved_params[i]),)
+            param = F.depend(self.saved_params[i], grads)
+            F.assign(weights[i], param)
 
         succ = self.optimizer(grads)
-        succ = F.depend(succ, restore)
         return succ