Add function and ops API for stop_gradient

yujianfeng 2022-11-11 14:37:39 +08:00
parent 9faaa60b01
commit 120ac9e7e8
22 changed files with 124 additions and 27 deletions

View File

@@ -496,6 +496,7 @@ Parameter Operation Functions
 mindspore.ops.derivative
 mindspore.ops.jet
+mindspore.ops.stop_gradient
 Debugging Functions
 -------------------

View File

@@ -599,6 +599,7 @@ Parameter Operation Operators
 mindspore.ops.Map
 mindspore.ops.MultitypeFuncGraph
 mindspore.ops.Partial
+mindspore.ops.StopGradient
 Operator Information Registration
 ---------------------------------

View File

@@ -0,0 +1,8 @@
+mindspore.ops.StopGradient
+===========================
+
+.. py:class:: mindspore.ops.StopGradient
+
+    StopGradient is used for eliminating the effect of a value on the gradient, such as truncating the gradient propagation from an output of a function.
+
+    For more details, please refer to :func:`mindspore.ops.stop_gradient`.

View File

@@ -0,0 +1,12 @@
+mindspore.ops.stop_gradient
+===========================
+
+.. py:function:: mindspore.ops.stop_gradient(value)
+
+    StopGradient is used for eliminating the effect of a value on the gradient, such as truncating the gradient propagation from an output of a function. For more details, please refer to `Stop Gradient <https://www.mindspore.cn/tutorials/zh-CN/master/beginner/autograd.html#stop-gradient>`_.
+
+    Args:
+        - **value** (Any) - The value whose effect on the gradient is to be eliminated.
+
+    Returns:
+        A value identical to `value`.
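A minimal doctest-style sketch of the documented contract (an illustration, not part of this commit's diff; assumes MindSpore with this change applied): the forward pass is the identity, only gradient flow is cut.

>>> import mindspore.ops as ops
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> x = Tensor([1.0, 2.0], dtype=mstype.float32)
>>> y = ops.stop_gradient(x)  # forward value is unchanged
>>> print(y)
[1. 2.]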

View File

@@ -496,6 +496,7 @@ Differential Functions
 mindspore.ops.derivative
 mindspore.ops.jet
+mindspore.ops.stop_gradient
 Debugging Functions
 -------------------

View File

@@ -597,6 +597,7 @@ Frame Operators
 mindspore.ops.Map
 mindspore.ops.MultitypeFuncGraph
 mindspore.ops.Partial
+mindspore.ops.StopGradient
 Operator Information Registration
 ---------------------------------

View File

@@ -606,7 +606,7 @@ constexpr char RESOLVE[] = "resolve";
 constexpr char EMBED[] = "embed";
 constexpr char CREATINSTANCE[] = "create_instance";
 constexpr char REF_TO_EMBED[] = "RefToEmbed";
-constexpr char STOP_GRADIENT[] = "stop_gradient";
+constexpr char STOP_GRADIENT[] = "StopGradient";
 constexpr char UPDATESTATE[] = "UpdateState";
 constexpr char LOAD[] = "Load";
 constexpr char OPPOSITE_RANK[] = "opposite_rank";

View File

@@ -26,7 +26,7 @@
 namespace mindspore {
 namespace pynative {
 namespace {
-const std::set<std::string> kVmOperators = {"InsertGradientOf", "stop_gradient", "HookBackward", "CellBackwardHook"};
+const std::set<std::string> kVmOperators = {"InsertGradientOf", "StopGradient", "HookBackward", "CellBackwardHook"};
 constexpr char kBegin[] = "Begin";
 constexpr char kEnd[] = "End";
 enum class RunOpArgsEnum : size_t { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };

View File

@@ -1555,7 +1555,7 @@ GVAR_DEF(PrimitivePtr, kPrimShapeMul, std::make_shared<Primitive>("shape_mul"));
 GVAR_DEF(PrimitivePtr, kPrimTupleEqual, std::make_shared<Primitive>("tuple_equal"));
 GVAR_DEF(PrimitivePtr, kPrimListEqual, std::make_shared<Primitive>("list_equal"));
 GVAR_DEF(PrimitivePtr, kPrimMakeRange, std::make_shared<Primitive>("make_range"));
-GVAR_DEF(PrimitivePtr, kPrimStopGradient, std::make_shared<Primitive>("stop_gradient"));
+GVAR_DEF(PrimitivePtr, kPrimStopGradient, std::make_shared<Primitive>("StopGradient"));
 GVAR_DEF(PrimitivePtr, kPrimDictLen, std::make_shared<Primitive>("dict_len"));
 GVAR_DEF(PrimitivePtr, kPrimFakeBprop, std::make_shared<Primitive>("fake_bprop"));
 GVAR_DEF(PrimitivePtr, kPrimBroadcastGradientArgs, std::make_shared<Primitive>("BroadcastGradientArgs"));

View File

@@ -31,7 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
 "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
 "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
 "InvertPermutation", "DropoutGenMask", "StatelessDropOutGenMask", "embed", "create_instance", "RefToEmbed",
-"stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
+"StopGradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
 #else
 static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
 "array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
@@ -40,7 +40,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
 "identity", "partial", "env_setitem", "env_getitem", "env_add",
 "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "Debug", "col2im_v1", "resolve", "BroadcastGradientArgs",
 "InvertPermutation", "DropoutGenMask", "StatelessDropOutGenMask", "embed", "create_instance", "RefToEmbed",
-"stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
+"StopGradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
 #endif
 static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather,
                                                             prim::kPrimMicroStepAllGather};

View File

@@ -176,8 +176,8 @@ def tuple_to_array(x):
     return Tensor(np.array(x))

-def stop_gradient(x):
-    """Implement `stop_gradient`."""
+def StopGradient(x):
+    """Implement `StopGradient`."""
     return x

View File

@@ -159,9 +159,9 @@ def bprop_embed(x, out, dout):
     return (C.zeros_like(x),)

-@bprops.register("stop_gradient")
+@bprops.register("StopGradient")
 def bprop_stop_gradient(x, out, dout):
-    """Backpropagator for primitive `stop_gradient`."""
+    """Backpropagator for primitive `StopGradient`."""
     return (C.zeros_like(x),)
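The zero-filled bprop above is the whole truncation mechanism: anything routed through `StopGradient` contributes `zeros_like(x)` to the input gradient. A minimal sketch of that effect (an illustration, not part of the diff; `fn` is a hypothetical example function, and PyNative mode with this commit applied is assumed):

import mindspore.ops as ops
from mindspore import Tensor
from mindspore import dtype as mstype

def fn(x):
    y = x * x                     # this branch is differentiated normally
    z = ops.stop_gradient(x * x)  # this branch hits the zeros_like bprop
    return (y + z).sum()

x = Tensor([1.0, 2.0], dtype=mstype.float32)
print(ops.grad(fn)(x))  # only the y branch contributes: 2*x -> [2. 4.]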

View File

@@ -0,0 +1,11 @@
(Binary MindIR file added: the serialized `bprop_stop_gradient` graph for the renamed `StopGradient` primitive, generated from grad_implementations.py via hyper_map[zeros_like_leaf] and MakeTuple. Raw protobuf content omitted.)

View File

@@ -1,11 +0,0 @@
(Binary MindIR file removed: the old serialized `bprop_stop_gradient` graph keyed to the lowercase `stop_gradient` primitive. Raw protobuf content omitted.)

View File

@@ -476,7 +476,7 @@ class _Grad(GradOperation_):
             if not isinstance(outputs, tuple) or len(outputs) < 2:
                 raise ValueError("When has_aux is True, origin fn requires more than one outputs.")
             res = (outputs[0],)
-            stop_gradient = Primitive("stop_gradient")
+            stop_gradient = Primitive("StopGradient")
             for item in outputs[1:]:
                 res += (stop_gradient(item),)
             return res
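This hunk is what makes `has_aux` work: every output after the first is wrapped in the renamed `StopGradient` primitive, so only `outputs[0]` drives differentiation. A hedged usage sketch (an illustration, assuming `ops.grad(fn, has_aux=True)` returns the gradients together with the auxiliary outputs, as this code path implies):

import mindspore.ops as ops
from mindspore import Tensor
from mindspore import dtype as mstype

def net(x):
    loss = (x * x).sum()  # only this first output is differentiated
    aux = x + 1.0         # auxiliary output, wrapped in StopGradient above
    return loss, aux

x = Tensor([1.0, 2.0], dtype=mstype.float32)
grad_fn = ops.grad(net, has_aux=True)
grads, aux = grad_fn(x)  # grads == [2. 4.]; aux passes through without affecting them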

View File

@@ -446,7 +446,8 @@ from .grad import (
     jvp,
     vjp,
     custom_vjp,
-    linearize
+    linearize,
+    stop_gradient
 )
 from .debug_func import (
     print_,

View File

@@ -26,7 +26,8 @@ from .grad_func import (
     jvp,
     vjp,
     custom_vjp,
-    linearize
+    linearize,
+    stop_gradient
 )

 __all__ = []

View File

@@ -1299,6 +1299,43 @@ def custom_vjp(fn=None):
     return deco

+def stop_gradient(value):
+    """
+    StopGradient is used for eliminating the effect of a value on the gradient, such as truncating
+    the gradient propagation from an output of a function.
+
+    For more details, please refer to `Stop Gradient
+    <https://www.mindspore.cn/tutorials/en/master/beginner/autograd.html#stop-gradient>`_.
+
+    Args:
+        value (Any): The value whose effect on the gradient is to be eliminated.
+
+    Returns:
+        The same as `value`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> def net(x, y):
+        ...     out1 = ops.MatMul()(x, y)
+        ...     out2 = ops.MatMul()(x, y)
+        ...     out2 = ops.stop_gradient(out2)
+        ...     return out1, out2
+        ...
+        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
+        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
+        >>> grad_fn = ops.grad(net)
+        >>> output = grad_fn(x, y)
+        >>> print(output)
+        [[1.4100001 1.6 6.5999994]
+         [1.4100001 1.6 6.5999994]]
+    """
+    return P.StopGradient()(value)

 __all__ = [
     'grad',
     'value_and_grad',
@@ -1309,6 +1346,7 @@ __all__ = [
     'jvp',
     'vjp',
     'custom_vjp',
-    'linearize'
+    'linearize',
+    'stop_gradient'
 ]
 __all__.sort()

View File

@@ -117,8 +117,6 @@ switch_layer = Primitive('switch_layer')
 reduced_shape = Primitive("reduced_shape")
 # shape_mul:input must be shape multiply elements in tuple(shape)
 shape_mul = Primitive("shape_mul")
-# a primitive to compare between tuple.
-stop_gradient = Primitive("stop_gradient")

 tensor_operator_registry.register('add', P.Add)
 tensor_operator_registry.register('addr', addr)

View File

@@ -98,7 +98,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
                      ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink,
                      ApplyAdamWithAmsgrad, AdaptiveAvgPool3D, AdaptiveMaxPool2D, AdaptiveMaxPool3D)
 from .other_ops import (Assign, IOU, BartlettWindow, BlackmanWindow, BoundingBoxDecode, BoundingBoxEncode,
-                        ConfusionMatrix, UpdateState, Load,
+                        ConfusionMatrix, UpdateState, Load, StopGradient,
                         CheckValid, Partial, Depend, identity, Push, Pull, PyFunc, _DynamicLossScale)
 from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, RandomGamma, Poisson, UniformInt, UniformReal,
                          RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@@ -339,6 +339,7 @@ __all__ = [
     'Partial',
     'Depend',
     'UpdateState',
+    'StopGradient',
     'identity',
     'AvgPool',
     # Back Primitive

View File

@@ -603,6 +603,40 @@ class UpdateState(Primitive):
         return state

+class StopGradient(Primitive):
+    """
+    StopGradient is used for eliminating the effect of a value on the gradient,
+    such as truncating the gradient propagation from an output of a function.
+
+    Refer to :func:`mindspore.ops.stop_gradient` for more details.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> def net(x, y):
+        ...     out1 = ops.MatMul()(x, y)
+        ...     out2 = ops.MatMul()(x, y)
+        ...     out2 = ops.StopGradient()(out2)
+        ...     return out1, out2
+        ...
+        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
+        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
+        >>> grad_fn = ops.grad(net)
+        >>> output = grad_fn(x, y)
+        >>> print(output)
+        [[1.4100001 1.6 6.5999994]
+         [1.4100001 1.6 6.5999994]]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        pass

 class ConfusionMatrix(PrimitiveWithInfer):
     r"""
     Calculates the confusion matrix from labels and predictions.
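Design note: the functional API added in grad_func.py is a thin wrapper that instantiates this primitive (`P.StopGradient()(value)`), so the two spellings should behave identically. A small equivalence sketch (an illustration, assuming this commit is applied and scalar boolean Tensors convert cleanly for `assert`):

import mindspore.ops as ops
from mindspore import Tensor
from mindspore import dtype as mstype

x = Tensor([1.0, 2.0], dtype=mstype.float32)
# ops.stop_gradient(x) simply calls ops.StopGradient()(x) under the hood
assert (ops.stop_gradient(x) == ops.StopGradient()(x)).all()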