forked from mindspore-Ecosystem/mindspore
!13769 remove control_depend from py file
From: @huangbingjian Reviewed-by: @hwhewei,@zh_qh Signed-off-by: @zh_qh
commit 669a32355c
@@ -177,7 +177,6 @@ class PrimLib:
         'ReduceMax': Prim(REDUCE),
         'ReduceMin': Prim(REDUCE),
         'MakeTuple': Prim(CONTROL),
-        'ControlDepend': Prim(CONTROL),
         'Assign': Prim(ELEMWISE),
         'Tanh': Prim(ELEMWISE),
         'ExpandDims': Prim(RESHAPE),
@@ -261,12 +261,6 @@ def bprop_bool_and(x, y, out, dout):
     return C.zeros_like(x), C.zeros_like(y)


-@bprops.register("ControlDepend")
-def bprop_control_depend(x, y, out, dout):
-    """Backpropagator for primitive `Control_depend`."""
-    return C.zeros_like(x), C.zeros_like(y)
-
-
 @bprops.register("Switch")
 def bprop_switch(cond, tb, fb, out, dout):
     """Backpropagator for primitive `switch`."""
@@ -42,11 +42,6 @@ shape = P.Shape()
 rank = P.Rank()
 reshape = P.Reshape()

-# control_depend: represent dependency between two operators
-def control_depend(src, dst):
-    control_depend_op = P.ControlDepend()
-    return control_depend_op(src, dst)
-
 merge = P.Merge()
 geswitch = P.GeSwitch()
 addn = P.AddN()
@@ -40,7 +40,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter,
                        _HostAllGather, _HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print, Assert)
-from .control_ops import ControlDepend, GeSwitch, Merge
+from .control_ops import GeSwitch, Merge
 from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey,
                         FusedWeightScaleApplyMomentum, AdamWeightDecay)
@@ -278,7 +278,6 @@ __all__ = [
     'ScalarToArray',
     'ScalarToTensor',
     'TupleToArray',
-    'ControlDepend',
     'GeSwitch',
     'Merge',
     'SameTypeShape',
@@ -14,76 +14,9 @@
 # ============================================================================

 """control_ops"""
-from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
-from ..._checkparam import Rel
+from ..primitive import PrimitiveWithInfer, prim_attr_register
 from ..._checkparam import Validator as validator
 from ...common import dtype as mstype
-from ...common._decorator import deprecated


-class ControlDepend(Primitive):
-    """
-    Adds control dependency relation between source and destination operations.
-
-    In many cases, we need to control the execution order of operations. ControlDepend is designed for this.
-    ControlDepend will instruct the execution engine to run the operations in a specific order. ControlDepend
-    tells the engine that the destination operations must depend on the source operation which means the source
-    operations must be executed before the destination.
-
-    Note:
-        This operation does not work in `PYNATIVE_MODE`.
-        `ControlDepend` is deprecated from version 1.1 and will be removed in a future version, use `Depend` instead.
-
-    Args:
-        depend_mode (int): Use 0 for a normal dependency relation and 1 for a user-defined dependency relation.
-            Default: 0.
-
-    Inputs:
-        - **src** (Any) - The source input. It can be a tuple of operations output or a single operation output. We do
-          not concern about the input data, but concern about the operation that generates the input data.
-          If `depend_mode` is 1 and the source input is Parameter, we will try to find the operations that
-          used the parameter as input.
-        - **dst** (Any) - The destination input. It can be a tuple of operations output or a single operation output.
-          We do not concern about the input data, but concern about the operation that generates the input data.
-          If `depend_mode` is 1 and the source input is Parameter, we will try to find the operations that
-          used the parameter as input.
-
-    Outputs:
-        This operation has no actual data output, it will be used to setup the order of relative operations.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> class Net(nn.Cell):
-        ...     def __init__(self):
-        ...         super(Net, self).__init__()
-        ...         self.control_depend = P.ControlDepend()
-        ...         self.softmax = ops.Softmax()
-        ...
-        ...     def construct(self, x, y):
-        ...         mul = x * y
-        ...         softmax = self.softmax(x)
-        ...         ret = self.control_depend(mul, softmax)
-        ...         return ret
-        ...
-        >>> x = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
-        >>> y = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
-        >>> net = Net()
-        >>> output = net(x, y)
-        >>> print(output)
-        [[1. 1. 1. 1. 1.]
-         [1. 1. 1. 1. 1.]
-         [1. 1. 1. 1. 1.]
-         [1. 1. 1. 1. 1.]]
-    """
-    @deprecated("1.1", "Depend")
-    @prim_attr_register
-    def __init__(self, depend_mode=0):
-        """init"""
-        validator.check_int_range(depend_mode, 0, 1, Rel.INC_BOTH, "depend_mode", self.name)
-
-    def __call__(self, src, dst):
-        return src
-
-
 class GeSwitch(PrimitiveWithInfer):
@@ -420,16 +420,12 @@ class Depend(Primitive):
     Depend is used for processing dependency operations.

     In some side-effect scenarios, we need to ensure the execution order of operators.
-    In order to ensure that operator A is executed before operator B, it is recommended
-    to insert the Depend operator between operators A and B.
+    In order to ensure that operator A is executed before operator B, it is recommended to
+    insert the Depend operator between operators A and B. The usage method is as follows::

-    Previously, the ControlDepend operator was used to control the execution order.
-    Since the ControlDepend operator is deprecated from version 1.1, it is recommended
-    to use the Depend operator instead. The replacement method is as follows::
-
-        a = A(x)                --->        a = A(x)
-        b = B(y)                --->        y = Depend(y, a)
-        ControlDepend(a, b)     --->        b = B(y)
+        out_a = A(in_a)
+        in_b = Depend(in_b, out_a)
+        out_b = B(in_b)

     Inputs:
         - **value** (Tensor) - the real value to return for depend operator.
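The docstring change above spells out the migration: instead of wiring an explicit ControlDepend edge between two operators, the output of the first operator is threaded into an input of the second through Depend. A minimal runnable sketch of that pattern (the cell, the Add/Square pair and the tensor shapes are illustrative assumptions, not part of this commit):

import numpy as np
import mindspore
from mindspore import nn, ops, Tensor

class OrderedNet(nn.Cell):
    """Forces the add to run before the square by threading its output through Depend."""
    def __init__(self):
        super(OrderedNet, self).__init__()
        self.depend = ops.Depend()
        self.add = ops.Add()
        self.square = ops.Square()

    def construct(self, x, y):
        out_a = self.add(x, y)        # operator A
        x = self.depend(x, out_a)     # x now carries a data dependency on A's output
        return self.square(x)         # operator B is therefore scheduled after A

x = Tensor(np.ones([2, 2]), mindspore.float32)
y = Tensor(np.ones([2, 2]), mindspore.float32)
print(OrderedNet()(x, y))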
@@ -129,8 +129,8 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
         op_sqrt = P.Sqrt()
         scatter_add = P.ScatterAdd(use_locking)

-        assign_m = F.assign(m, op_mul(beta1, m))
-        assign_v = F.assign(v, op_mul(beta2, v))
+        success = F.depend(success, F.assign(m, op_mul(beta1, m)))
+        success = F.depend(success, F.assign(v, op_mul(beta2, v)))

         grad_indices = gradient.indices
         grad_value = gradient.values
@@ -145,27 +145,18 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,

         if use_nesterov:
             m_temp = next_m * _scaler_ten
-            assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
+            F.assign(m, op_mul(beta1, next_m))
             div_value = scatter_add(m,
                                     op_mul(grad_indices, _scaler_one),
                                     op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
             param_update = div_value / (op_sqrt(next_v) + eps)
-
-            m_recover = F.assign(m, m_temp / _scaler_ten)
-
-            F.control_depend(m_temp, assign_m_nesterov)
-            F.control_depend(assign_m_nesterov, div_value)
-            F.control_depend(param_update, m_recover)
+            F.assign(m, m_temp / _scaler_ten)
         else:
             param_update = next_m / (op_sqrt(next_v) + eps)

         lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)

         next_param = param - lr_t * param_update
-
-        F.control_depend(assign_m, next_m)
-        F.control_depend(assign_v, next_v)
-
         success = F.depend(success, F.assign(param, next_param))
         success = F.depend(success, F.assign(m, next_m))
         success = F.depend(success, F.assign(v, next_v))
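In the two optimizer hunks above, the removed F.control_depend edges are replaced by folding every F.assign into the success flag through F.depend, so each assignment becomes a data dependency of the value the update function returns. A condensed sketch of that chaining with the same F.assign/F.depend calls (the standalone parameters, constants and the decay_states helper are illustrative assumptions; in the real optimizer this runs inside graph construction):

import numpy as np
import mindspore
from mindspore import Parameter, Tensor
from mindspore.ops import functional as F
from mindspore.ops import operations as P

op_mul = P.Mul()
m = Parameter(Tensor(np.zeros([2]), mindspore.float32), name="m")
v = Parameter(Tensor(np.zeros([2]), mindspore.float32), name="v")
beta1 = Tensor(0.9, mindspore.float32)
beta2 = Tensor(0.999, mindspore.float32)

def decay_states():
    success = True
    # Each Assign is chained into `success`; the graph must execute the assignments
    # before `success` is produced, so no explicit control edge is required.
    success = F.depend(success, F.assign(m, op_mul(beta1, m)))
    success = F.depend(success, F.assign(v, op_mul(beta2, v)))
    return success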
@@ -172,7 +172,6 @@ class GRUTrainOneStepWithLossScaleCell(nn.Cell):
         self.get_status = P.NPUGetFloatStatus()
         self.clear_before_grad = P.NPUClearFloatStatus()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = P.LessEqual()
         self.hyper_map = C.HyperMap()
@@ -17,7 +17,7 @@ import numpy as np

 import mindspore.nn as nn
 from mindspore.ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, \
-    LessEqual, ControlDepend
+    LessEqual
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore import Tensor
@@ -25,7 +25,7 @@ from mindspore.context import ParallelMode
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-from mindspore.common.parameter import ParameterTuple
+from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common import dtype as mstype


@@ -69,7 +69,6 @@ class TrainOneStepWithLossScaleCell(nn.Cell):
         self.base = Tensor(1, mstype.float32)
         self.reducer_flag = False
         self.less_equal = LessEqual()
-        self.depend_parameter_use = ControlDepend(depend_mode=1)
         self.allreduce = P.AllReduce()
         self.parallel_mode = _get_parallel_mode()
         self.grad_reducer = None
@@ -341,7 +341,6 @@ class BertTrainWithLossScaleCell(nn.Cell):
         self.get_status = P.NPUGetFloatStatus()
         self.clear_before_grad = P.NPUClearFloatStatus()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = P.LessEqual()
         self.hyper_map = C.HyperMap()
@@ -381,24 +380,24 @@
         saved = ()
         for i in range(self.length):
             saved = saved + (F.assign(self.saved_params[i], weights[i]),)
-        assign_embedding = ()
+
         for i in range(self.quant_embedding_list_length):
             quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
-            assign_embedding = assign_embedding + (F.assign(weights[self.quant_embedding_list[i]], quant_embedding),)
-            F.control_depend(saved, assign_embedding[i])
-        assign_weight = ()
+            quant_embedding = F.depend(quant_embedding, saved)
+            assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
+            input_ids = F.depend(input_ids, assign_embedding)
+
         for i in range(self.quant_weight_list_length):
             quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
-            assign_weight = assign_weight + (F.assign(weights[self.quant_weight_list[i]], quant_weight),)
-            F.control_depend(saved, assign_weight[i])
-        for i in range(self.quant_embedding_list_length):
-            F.control_depend(assign_embedding[i], input_ids)
-        for i in range(self.quant_weight_list_length):
-            F.control_depend(assign_weight[i], input_ids)
+            quant_weight = F.depend(quant_weight, saved)
+            assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight)
+            input_ids = F.depend(input_ids, assign_weight)
+
         if sens is None:
             scaling_sens = self.loss_scale
         else:
             scaling_sens = sens
+
         # alloc status and clear should be right before grad operation
         init = self.alloc_status()
         self.clear_before_grad(init)
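The BertTrainWithLossScaleCell rewrite above replaces tuple accumulation plus F.control_depend edges with plain data flow: each quantized value is first made to depend on the saved copies, and input_ids is then made to depend on the resulting Assign, which serializes save, quantize, assign and the forward pass. The same three-line pattern from the new loop body, annotated (an excerpt, not a standalone script; names follow the hunk and the surrounding construct is assumed):

            quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
            quant_embedding = F.depend(quant_embedding, saved)      # quantize only after the save ops have run
            assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
            input_ids = F.depend(input_ids, assign_embedding)       # forward pass only after the assignment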
@@ -408,15 +407,15 @@
                                                  label_ids,
                                                  self.cast(scaling_sens,
                                                            mstype.float32))
-        F.control_depend(input_ids, grads)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
         restore = ()
         for i in range(self.length):
+            weights[i] = F.depend(weights[i], grads)
             restore = restore + (F.assign(weights[i], self.saved_params[i]),)
-            F.control_depend(grads, restore[i])
+
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
         if self.is_distributed:
@@ -432,8 +431,7 @@
             succ = False
         else:
             succ = self.optimizer(grads)
-        for i in range(self.length):
-            F.control_depend(restore[i], succ)
+        succ = F.depend(succ, restore)
         return succ


@@ -495,35 +493,33 @@ class BertTrainCell(nn.Cell):
         saved = ()
         for i in range(self.length):
             saved = saved + (F.assign(self.saved_params[i], weights[i]),)
-        assign_embedding = ()
+
         for i in range(self.quant_embedding_list_length):
             quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
-            assign_embedding = assign_embedding + (F.assign(weights[self.quant_embedding_list[i]], quant_embedding),)
-            F.control_depend(saved, assign_embedding[i])
-        assign_weight = ()
+            quant_embedding = F.depend(quant_embedding, saved)
+            assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
+            input_ids = F.depend(input_ids, assign_embedding)
+
         for i in range(self.quant_weight_list_length):
             quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
-            assign_weight = assign_weight + (F.assign(weights[self.quant_weight_list[i]], quant_weight),)
-            F.control_depend(saved, assign_weight[i])
-        for i in range(self.quant_embedding_list_length):
-            F.control_depend(assign_embedding[i], input_ids)
-        for i in range(self.quant_weight_list_length):
-            F.control_depend(assign_weight[i], input_ids)
+            quant_weight = F.depend(quant_weight, saved)
+            assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight)
+            input_ids = F.depend(input_ids, assign_weight)
+
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
                                                  label_ids,
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
-        F.control_depend(input_ids, grads)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
         restore = ()
         for i in range(self.length):
+            weights[i] = F.depend(weights[i], grads)
             restore = restore + (F.assign(weights[i], self.saved_params[i]),)
-            F.control_depend(grads, restore[i])
+
         succ = self.optimizer(grads)
-        for i in range(self.length):
-            F.control_depend(restore[i], succ)
+        succ = F.depend(succ, restore)
         return succ
@@ -399,7 +399,6 @@ class MixControlNet(Cell):
                                kernel_size=1, stride=1, has_bias=False,
                                weight_init='ones', pad_mode='same')
         self.bn = BatchNorm2d(num_features=in_channel)
-        self.controldepend = P.ControlDepend()
         self.assignadd = P.AssignAdd()
         self.assign = P.Assign()
         self.relu = ReLU()
@@ -428,9 +427,8 @@
         if x < 20:
             out = self.biasadd(out, self.bias)
             if x % 2 == 0:
+                self.assignadd(self.bias, self.value)
                 out = self.biasadd(out, self.bias)
-                assign = self.assignadd(self.bias, self.value)
-                self.controldepend(assign, out)
             out = self.bn(out)
         else:
             out = self.conv(out)
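The MixControlNet test drops ControlDepend by simply issuing the AssignAdd before the bias is read; reading the updated parameter gives the same ordering that the control edge used to enforce. A standalone sketch of that idea (the cell, shapes and step value are illustrative assumptions, not the test itself):

import numpy as np
import mindspore
from mindspore import nn, ops, Tensor, Parameter

class BiasStep(nn.Cell):
    """Updates the bias in place, then reads it, so the update is ordered before the read."""
    def __init__(self):
        super(BiasStep, self).__init__()
        self.bias = Parameter(Tensor(np.zeros([4]), mindspore.float32), name="bias")
        self.step = Tensor(np.ones([4]), mindspore.float32)
        self.assignadd = ops.AssignAdd()
        self.biasadd = ops.BiasAdd()

    def construct(self, x):
        self.assignadd(self.bias, self.step)   # side effect first ...
        return self.biasadd(x, self.bias)      # ... then the read depends on the updated value

x = Tensor(np.ones([2, 4]), mindspore.float32)
print(BiasStep()(x))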
@@ -33,14 +33,6 @@ def reduce_graph(shape, reduce_axis):
         gb.emit('ReduceSum', a3, 'C', attrs={'reduce_axis': reduce_axis})
     return gb.get()[0]

-def control_graph(shape):
-    gb = model.GraphBuilder()
-    with gb.graph_scope('control') as _:
-        a1 = gb.tensor(shape, 'float32')
-        a2 = gb.emit('Abs', a1)
-        gb.emit('ControlDepend', a2)
-    return gb.get()[0]
-
 def block_fusion(graphs):
     gain = model.parallel_estimate(graphs)
     print("fusion = {}, bottleneck = {}, gain = {}".format(gain.fusion_type, gain.bottleneck, gain.gain))
@@ -51,4 +43,3 @@ if __name__ == "__main__":
     assert block_fusion([reduce_graph([1024, 1024], [1]), injective_graph([24, 1024])])
     assert not block_fusion([reduce_graph([1024, 1024], [1]), injective_graph([50, 1024])])
     assert not block_fusion([reduce_graph([1024, 1024], [0, 1]), injective_graph([1024, 1024])])
-    assert block_fusion([control_graph([20, 128]), injective_graph([40, 1024])])