Add function and ops api for stop_gradient
parent 9faaa60b01
commit 120ac9e7e8
@@ -496,6 +496,7 @@ Parameter Operation Functions
 
     mindspore.ops.derivative
     mindspore.ops.jet
+    mindspore.ops.stop_gradient
 
 Debugging Functions
 -------------------
@@ -599,6 +599,7 @@ Parameter Operation Operators
 
     mindspore.ops.Map
     mindspore.ops.MultitypeFuncGraph
     mindspore.ops.Partial
+    mindspore.ops.StopGradient
 
 Operator Information Registration
 ---------------------------------
@@ -0,0 +1,8 @@
+mindspore.ops.StopGradient
+===========================
+
+.. py:class:: mindspore.ops.StopGradient
+
+    StopGradient is used to eliminate the effect of a value on the gradient, for example to truncate the gradient propagation from an output of a function.
+
+    For more details, see :func:`mindspore.ops.stop_gradient`.
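The class above is the operator-level counterpart of the functional API documented in the next file. A minimal sketch of how the two forms relate (hypothetical input `x`; assumes both names are exported from `mindspore.ops` as this commit intends)::

    import mindspore as ms
    import mindspore.ops as ops

    x = ms.Tensor([1.0, 2.0], ms.float32)
    # In the forward pass both forms are expected to return the input unchanged;
    # they only differ in how the gradient is cut off behind the scenes.
    y1 = ops.StopGradient()(x)   # instantiate the primitive, then call it
    y2 = ops.stop_gradient(x)    # functional wrapper added by this commit
    print(y1, y2)                # both print [1. 2.]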
@@ -0,0 +1,12 @@
+mindspore.ops.stop_gradient
+===========================
+
+.. py:function:: mindspore.ops.stop_gradient(value)
+
+    stop_gradient is used to eliminate the effect of a value on the gradient, such as truncating the gradient propagation from an output of a function. For more details, refer to `Stop Gradient <https://www.mindspore.cn/tutorials/zh-CN/master/beginner/autograd.html#stop-gradient>`_.
+
+    Parameters:
+        - **value** (Any) - The value whose effect on the gradient is to be eliminated.
+
+    Returns:
+        A value identical to `value`.
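A quick sanity check of the behaviour described above (a minimal sketch, assuming a MindSpore build where `ops.grad` and the new `ops.stop_gradient` are both available)::

    import mindspore as ms
    import mindspore.ops as ops

    def net(x):
        out1 = x * x                     # participates in differentiation
        out2 = ops.stop_gradient(x * x)  # same value, but cut off from the gradient
        return out1 + out2

    x = ms.Tensor(3.0, ms.float32)
    print(net(x))            # forward value is unaffected: 18.0
    print(ops.grad(net)(x))  # only out1 contributes: d(x*x)/dx = 2*x = 6.0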
@@ -496,6 +496,7 @@ Differential Functions
 
     mindspore.ops.derivative
     mindspore.ops.jet
+    mindspore.ops.stop_gradient
 
 Debugging Functions
 -------------------
@@ -597,6 +597,7 @@ Frame Operators
 
     mindspore.ops.Map
     mindspore.ops.MultitypeFuncGraph
     mindspore.ops.Partial
+    mindspore.ops.StopGradient
 
 Operator Information Registration
 ---------------------------------
@@ -606,7 +606,7 @@ constexpr char RESOLVE[] = "resolve";
 constexpr char EMBED[] = "embed";
 constexpr char CREATINSTANCE[] = "create_instance";
 constexpr char REF_TO_EMBED[] = "RefToEmbed";
-constexpr char STOP_GRADIENT[] = "stop_gradient";
+constexpr char STOP_GRADIENT[] = "StopGradient";
 constexpr char UPDATESTATE[] = "UpdateState";
 constexpr char LOAD[] = "Load";
 constexpr char OPPOSITE_RANK[] = "opposite_rank";
@@ -26,7 +26,7 @@
 namespace mindspore {
 namespace pynative {
 namespace {
-const std::set<std::string> kVmOperators = {"InsertGradientOf", "stop_gradient", "HookBackward", "CellBackwardHook"};
+const std::set<std::string> kVmOperators = {"InsertGradientOf", "StopGradient", "HookBackward", "CellBackwardHook"};
 constexpr char kBegin[] = "Begin";
 constexpr char kEnd[] = "End";
 enum class RunOpArgsEnum : size_t { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
@@ -1555,7 +1555,7 @@ GVAR_DEF(PrimitivePtr, kPrimShapeMul, std::make_shared<Primitive>("shape_mul"));
 GVAR_DEF(PrimitivePtr, kPrimTupleEqual, std::make_shared<Primitive>("tuple_equal"));
 GVAR_DEF(PrimitivePtr, kPrimListEqual, std::make_shared<Primitive>("list_equal"));
 GVAR_DEF(PrimitivePtr, kPrimMakeRange, std::make_shared<Primitive>("make_range"));
-GVAR_DEF(PrimitivePtr, kPrimStopGradient, std::make_shared<Primitive>("stop_gradient"));
+GVAR_DEF(PrimitivePtr, kPrimStopGradient, std::make_shared<Primitive>("StopGradient"));
 GVAR_DEF(PrimitivePtr, kPrimDictLen, std::make_shared<Primitive>("dict_len"));
 GVAR_DEF(PrimitivePtr, kPrimFakeBprop, std::make_shared<Primitive>("fake_bprop"));
 GVAR_DEF(PrimitivePtr, kPrimBroadcastGradientArgs, std::make_shared<Primitive>("BroadcastGradientArgs"));
@@ -31,7 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
 "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
 "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
 "InvertPermutation", "DropoutGenMask", "StatelessDropOutGenMask", "embed", "create_instance", "RefToEmbed",
-"stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
+"StopGradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
 #else
 static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
 "array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
@@ -40,7 +40,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
 "identity", "partial", "env_setitem", "env_getitem", "env_add",
 "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "Debug", "col2im_v1", "resolve", "BroadcastGradientArgs",
 "InvertPermutation", "DropoutGenMask", "StatelessDropOutGenMask", "embed", "create_instance", "RefToEmbed",
-"stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
+"StopGradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
 #endif
 static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather,
                                                             prim::kPrimMicroStepAllGather};
@@ -176,8 +176,8 @@ def tuple_to_array(x):
     return Tensor(np.array(x))
 
 
-def stop_gradient(x):
-    """Implement `stop_gradient`."""
+def StopGradient(x):
+    """Implement `StopGradient`."""
     return x
@@ -159,9 +159,9 @@ def bprop_embed(x, out, dout):
     return (C.zeros_like(x),)
 
 
-@bprops.register("stop_gradient")
+@bprops.register("StopGradient")
 def bprop_stop_gradient(x, out, dout):
-    """Backpropagator for primitive `stop_gradient`."""
+    """Backpropagator for primitive `StopGradient`."""
     return (C.zeros_like(x),)
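The re-registered bprop above is what implements the cut: whatever gradient `dout` arrives at the primitive, zeros are propagated back to the input. A standalone sketch of that rule in plain NumPy (hypothetical helper name, not the registry code)::

    import numpy as np

    def stop_gradient_bprop(x, out, dout):
        # Mirror of the registered rule: ignore `out` and `dout`,
        # return zeros shaped like the input `x`.
        return (np.zeros_like(x),)

    grads = stop_gradient_bprop(np.array([1.0, 2.0]), None, np.array([10.0, 10.0]))
    print(grads)  # (array([0., 0.]),) -- the incoming gradient is discarded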
Binary file not shown.

@@ -0,0 +1,11 @@
(New serialized MindIR content for the `bprop_stop_gradient` graph; the encoded protobuf text is not reproduced here.)
@@ -1,11 +0,0 @@
(Serialized MindIR content for the previous `bprop_stop_gradient` graph, deleted by this commit; the encoded protobuf text is not reproduced here.)
@@ -476,7 +476,7 @@ class _Grad(GradOperation_):
             if not isinstance(outputs, tuple) or len(outputs) < 2:
                 raise ValueError("When has_aux is True, origin fn requires more than one outputs.")
             res = (outputs[0],)
-            stop_gradient = Primitive("stop_gradient")
+            stop_gradient = Primitive("StopGradient")
             for item in outputs[1:]:
                 res += (stop_gradient(item),)
             return res
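The `_Grad` change above wraps every output after the first in `StopGradient` when `has_aux` is set, so only the first output drives differentiation while the remaining outputs are passed through detached. A rough usage sketch (assuming the public `ops.grad(..., has_aux=True)` entry point; the exact return structure may vary by version)::

    import mindspore as ms
    import mindspore.ops as ops

    def net(x):
        loss = (x * x).sum()  # only this output is differentiated
        aux = x + 1.0         # auxiliary output; _Grad wraps it in StopGradient
        return loss, aux

    x = ms.Tensor([1.0, 2.0], ms.float32)
    out = ops.grad(net, has_aux=True)(x)
    # `out` carries the gradient of the first output ([2. 4.]) together with the
    # detached auxiliary value; the aux part no longer influences the gradient.
    print(out)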
@@ -446,7 +446,8 @@ from .grad import (
     jvp,
     vjp,
     custom_vjp,
-    linearize
+    linearize,
+    stop_gradient
 )
 from .debug_func import (
     print_,
@@ -26,7 +26,8 @@ from .grad_func import (
     jvp,
     vjp,
     custom_vjp,
-    linearize
+    linearize,
+    stop_gradient
 )
 
 __all__ = []
@@ -1299,6 +1299,43 @@ def custom_vjp(fn=None):
     return deco
 
 
+def stop_gradient(value):
+    """
+    StopGradient is used for eliminating the effect of a value on the gradient, such as truncating
+    the gradient propagation from an output of a function.
+    For more details, please refer to `Stop Gradient
+    <https://www.mindspore.cn/tutorials/en/master/beginner/autograd.html#stop-gradient>`_.
+
+    Args:
+        value (Any): The value whose effect on the gradient is to be eliminated.
+
+    Returns:
+        The same as `value`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> def net(x, y):
+        ...     out1 = ops.MatMul()(x, y)
+        ...     out2 = ops.MatMul()(x, y)
+        ...     out2 = ops.stop_gradient(out2)
+        ...     return out1, out2
+        ...
+        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
+        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
+        >>> grad_fn = ops.grad(net)
+        >>> output = grad_fn(x, y)
+        >>> print(output)
+        [[1.4100001 1.6 6.5999994]
+         [1.4100001 1.6 6.5999994]]
+    """
+    return P.StopGradient()(value)
+
+
 __all__ = [
     'grad',
     'value_and_grad',
@@ -1309,6 +1346,7 @@ __all__ = [
     'jvp',
     'vjp',
     'custom_vjp',
-    'linearize'
+    'linearize',
+    'stop_gradient'
 ]
 __all__.sort()
@@ -117,8 +117,6 @@ switch_layer = Primitive('switch_layer')
 reduced_shape = Primitive("reduced_shape")
 # shape_mul:input must be shape multiply elements in tuple(shape)
 shape_mul = Primitive("shape_mul")
-# a primitive to compare between tuple.
-stop_gradient = Primitive("stop_gradient")
 
 tensor_operator_registry.register('add', P.Add)
 tensor_operator_registry.register('addr', addr)
@@ -98,7 +98,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
                      ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink,
                      ApplyAdamWithAmsgrad, AdaptiveAvgPool3D, AdaptiveMaxPool2D, AdaptiveMaxPool3D)
 from .other_ops import (Assign, IOU, BartlettWindow, BlackmanWindow, BoundingBoxDecode, BoundingBoxEncode,
-                        ConfusionMatrix, UpdateState, Load,
+                        ConfusionMatrix, UpdateState, Load, StopGradient,
                         CheckValid, Partial, Depend, identity, Push, Pull, PyFunc, _DynamicLossScale)
 from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, RandomGamma, Poisson, UniformInt, UniformReal,
                          RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@@ -339,6 +339,7 @@ __all__ = [
     'Partial',
     'Depend',
     'UpdateState',
+    'StopGradient',
     'identity',
     'AvgPool',
     # Back Primitive
@@ -603,6 +603,40 @@ class UpdateState(Primitive):
         return state
 
 
+class StopGradient(Primitive):
+    """
+    StopGradient is used for eliminating the effect of a value on the gradient,
+    such as truncating the gradient propagation from an output of a function.
+
+    Refer to :func:`mindspore.ops.stop_gradient` for more details.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> def net(x, y):
+        ...     out1 = ops.MatMul()(x, y)
+        ...     out2 = ops.MatMul()(x, y)
+        ...     out2 = ops.StopGradient()(out2)
+        ...     return out1, out2
+        ...
+        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
+        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
+        >>> grad_fn = ops.grad(net)
+        >>> output = grad_fn(x, y)
+        >>> print(output)
+        [[1.4100001 1.6 6.5999994]
+         [1.4100001 1.6 6.5999994]]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        pass
+
+
 class ConfusionMatrix(PrimitiveWithInfer):
     r"""
     Calculates the confusion matrix from labels and predictions.