Add function and ops API for stop_gradient

parent 9faaa60b01
commit 120ac9e7e8
@@ -496,6 +496,7 @@ Parameter Operation Functions

     mindspore.ops.derivative
     mindspore.ops.jet
+    mindspore.ops.stop_gradient

 Debugging Functions
 -------------------
@@ -599,6 +599,7 @@ Parameter Operation Operators

     mindspore.ops.Map
     mindspore.ops.MultitypeFuncGraph
     mindspore.ops.Partial
+    mindspore.ops.StopGradient

 Operator Information Registration
 ---------------------------------
@@ -0,0 +1,8 @@
+mindspore.ops.StopGradient
+===========================
+
+.. py:class:: mindspore.ops.StopGradient
+
+    Eliminates the effect of a value on the gradient, for example truncating gradient propagation from the output of a function.
+
+    For more details, see :func:`mindspore.ops.stop_gradient`.
@@ -0,0 +1,12 @@
+mindspore.ops.stop_gradient
+===========================
+
+.. py:function:: mindspore.ops.stop_gradient(value)
+
+    Eliminates the effect of a value on the gradient, for example truncating gradient propagation from the output of a function. For more details, refer to `Stop Gradient <https://www.mindspore.cn/tutorials/zh-CN/master/beginner/autograd.html#stop-gradient>`_.
+
+    Args:
+        - **value** (Any) - The value whose effect on the gradient is to be eliminated.
+
+    Returns:
+        A value identical to `value`.
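A minimal usage sketch of the functional API documented above (my own illustration, not part of the commit; it assumes MindSpore 2.0 with `ops.grad` available):

    import mindspore as ms
    from mindspore import Tensor, ops

    def fn(x):
        y = x * x
        y = ops.stop_gradient(y)   # forward value unchanged; gradient path through y is cut
        return (y * x).sum()

    x = Tensor([1.0, 2.0, 3.0], ms.float32)
    # With y treated as a constant, d/dx of (y * x) is just y = x * x.
    print(ops.grad(fn)(x))  # expected: [1. 4. 9.]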
@@ -496,6 +496,7 @@ Differential Functions

     mindspore.ops.derivative
     mindspore.ops.jet
+    mindspore.ops.stop_gradient

 Debugging Functions
 -------------------
@@ -597,6 +597,7 @@ Frame Operators

     mindspore.ops.Map
     mindspore.ops.MultitypeFuncGraph
     mindspore.ops.Partial
+    mindspore.ops.StopGradient

 Operator Information Registration
 ---------------------------------
@@ -606,7 +606,7 @@ constexpr char RESOLVE[] = "resolve";
 constexpr char EMBED[] = "embed";
 constexpr char CREATINSTANCE[] = "create_instance";
 constexpr char REF_TO_EMBED[] = "RefToEmbed";
-constexpr char STOP_GRADIENT[] = "stop_gradient";
+constexpr char STOP_GRADIENT[] = "StopGradient";
 constexpr char UPDATESTATE[] = "UpdateState";
 constexpr char LOAD[] = "Load";
 constexpr char OPPOSITE_RANK[] = "opposite_rank";
@@ -26,7 +26,7 @@
 namespace mindspore {
 namespace pynative {
 namespace {
-const std::set<std::string> kVmOperators = {"InsertGradientOf", "stop_gradient", "HookBackward", "CellBackwardHook"};
+const std::set<std::string> kVmOperators = {"InsertGradientOf", "StopGradient", "HookBackward", "CellBackwardHook"};
 constexpr char kBegin[] = "Begin";
 constexpr char kEnd[] = "End";
 enum class RunOpArgsEnum : size_t { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
@@ -1555,7 +1555,7 @@ GVAR_DEF(PrimitivePtr, kPrimShapeMul, std::make_shared<Primitive>("shape_mul"));
 GVAR_DEF(PrimitivePtr, kPrimTupleEqual, std::make_shared<Primitive>("tuple_equal"));
 GVAR_DEF(PrimitivePtr, kPrimListEqual, std::make_shared<Primitive>("list_equal"));
 GVAR_DEF(PrimitivePtr, kPrimMakeRange, std::make_shared<Primitive>("make_range"));
-GVAR_DEF(PrimitivePtr, kPrimStopGradient, std::make_shared<Primitive>("stop_gradient"));
+GVAR_DEF(PrimitivePtr, kPrimStopGradient, std::make_shared<Primitive>("StopGradient"));
 GVAR_DEF(PrimitivePtr, kPrimDictLen, std::make_shared<Primitive>("dict_len"));
 GVAR_DEF(PrimitivePtr, kPrimFakeBprop, std::make_shared<Primitive>("fake_bprop"));
 GVAR_DEF(PrimitivePtr, kPrimBroadcastGradientArgs, std::make_shared<Primitive>("BroadcastGradientArgs"));
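The C++ renames above and the Python-side changes below all preserve one invariant: every registry keyed by the primitive's name string must use "StopGradient". A small Python-side illustration of that shared key (my own sketch, not from the commit):

    from mindspore.ops import Primitive

    sg = Primitive("StopGradient")
    print(sg.name)  # "StopGradient", the same string kPrimStopGradient now carries on the C++ side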
@@ -31,7 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
 "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
 "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
 "InvertPermutation", "DropoutGenMask", "StatelessDropOutGenMask", "embed", "create_instance", "RefToEmbed",
-"stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
+"StopGradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
 #else
 static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "J", "list_getitem",
 "array_getitem", "tuple_setitem", "Depend", "list_setitem", "array_setitem", "dict_getitem",
@@ -40,7 +40,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
 "identity", "partial", "env_setitem", "env_getitem", "env_add",
 "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "Debug", "col2im_v1", "resolve", "BroadcastGradientArgs",
 "InvertPermutation", "DropoutGenMask", "StatelessDropOutGenMask", "embed", "create_instance", "RefToEmbed",
-"stop_gradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
+"StopGradient", "UpdateState", "Load", "Switch", "Print", "call_instance"};
 #endif
 static const std::set<PrimitivePtr> ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather,
                                                             prim::kPrimMicroStepAllGather};
@@ -176,8 +176,8 @@ def tuple_to_array(x):
     return Tensor(np.array(x))


-def stop_gradient(x):
-    """Implement `stop_gradient`."""
+def StopGradient(x):
+    """Implement `StopGradient`."""
     return x
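The VM implementation above is deliberately an identity: on the forward pass StopGradient returns its input unchanged, and the truncation happens only in the backpropagator registered in the next hunk. A quick forward-only check (a sketch, assuming the functional API added later in this commit):

    import mindspore as ms
    from mindspore import Tensor, ops

    x = Tensor([1.0, 2.0], ms.float32)
    y = ops.stop_gradient(x)
    print(y)  # [1. 2.]; identical forward values, only gradient flow differs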
@@ -159,9 +159,9 @@ def bprop_embed(x, out, dout):
     return (C.zeros_like(x),)


-@bprops.register("stop_gradient")
+@bprops.register("StopGradient")
 def bprop_stop_gradient(x, out, dout):
-    """Backpropagator for primitive `stop_gradient`."""
+    """Backpropagator for primitive `StopGradient`."""
     return (C.zeros_like(x),)
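Because the backpropagator returns `zeros_like(x)` regardless of the incoming `dout`, the truncation is total: a function whose output depends on `x` only through StopGradient has a zero gradient with respect to `x`. A sketch of that behaviour (my own example, assuming `ops.grad`):

    import mindspore as ms
    from mindspore import Tensor, ops

    def fn(x):
        # x reaches the output only through stop_gradient, so d(fn)/dx is zero everywhere
        return (ops.stop_gradient(x) * 3.0).sum()

    x = Tensor([1.0, 2.0, 3.0], ms.float32)
    print(ops.grad(fn)(x))  # expected: [0. 0. 0.]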
Binary file not shown.

@@ -0,0 +1,11 @@
+[new binary file: serialized MindIR (MindSpore 2.0.0) of the `bprop_stop_gradient` graph (a zeros_like hyper_map feeding a MakeTuple, generated from grad_implementations.py); raw bytes are not human-readable and omitted]
@@ -1,11 +0,0 @@
-[deleted binary file: the previous serialized MindIR (MindSpore 2.0.0) of the `bprop_stop_gradient` graph, superseded by the regenerated file above; raw bytes are not human-readable and omitted]
@@ -476,7 +476,7 @@ class _Grad(GradOperation_):
             if not isinstance(outputs, tuple) or len(outputs) < 2:
                 raise ValueError("When has_aux is True, origin fn requires more than one outputs.")
             res = (outputs[0],)
-            stop_gradient = Primitive("stop_gradient")
+            stop_gradient = Primitive("StopGradient")
             for item in outputs[1:]:
                 res += (stop_gradient(item),)
             return res
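This wrapping is the mechanism behind `has_aux`: every output after the first is passed through the renamed primitive, so auxiliary outputs come back unchanged but contribute nothing to the gradient. A usage sketch (the exact shape of the return value of `ops.grad(..., has_aux=True)` is my assumption from the MindSpore 2.0 API):

    import mindspore as ms
    from mindspore import Tensor, ops

    def fn(x):
        loss = (x * x).sum()
        aux = x * 10.0          # auxiliary output: returned as-is, excluded from differentiation
        return loss, aux

    x = Tensor([1.0, 2.0], ms.float32)
    grads, aux = ops.grad(fn, has_aux=True)(x)
    print(grads)  # expected: [2. 4.]; only `loss` drives the gradient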
@@ -446,7 +446,8 @@ from .grad import (
     jvp,
     vjp,
     custom_vjp,
-    linearize
+    linearize,
+    stop_gradient
 )
 from .debug_func import (
     print_,
@@ -26,7 +26,8 @@ from .grad_func import (
     jvp,
     vjp,
     custom_vjp,
-    linearize
+    linearize,
+    stop_gradient
 )

 __all__ = []
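After the re-export hunks in this commit (the two above for the function, plus the operator exports further down for the class), both spellings should be reachable from the public namespace. A trivial check (my own sketch):

    from mindspore import ops

    assert hasattr(ops, "stop_gradient")   # functional API
    assert hasattr(ops, "StopGradient")    # primitive class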
@@ -1299,6 +1299,43 @@ def custom_vjp(fn=None):
     return deco


+def stop_gradient(value):
+    """
+    StopGradient is used for eliminating the effect of a value on the gradient, such as truncating
+    the gradient propagation from an output of a function.
+    For more details, please refer to `Stop Gradient
+    <https://www.mindspore.cn/tutorials/en/master/beginner/autograd.html#stop-gradient>`_.
+
+    Args:
+        value (Any): The value whose effect on the gradient is to be eliminated.
+
+    Returns:
+        The same as `value`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> def net(x, y):
+        ...     out1 = ops.MatMul()(x, y)
+        ...     out2 = ops.MatMul()(x, y)
+        ...     out2 = ops.stop_gradient(out2)
+        ...     return out1, out2
+        ...
+        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
+        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
+        >>> grad_fn = ops.grad(net)
+        >>> output = grad_fn(x, y)
+        >>> print(output)
+        [[1.4100001 1.6       6.5999994]
+         [1.4100001 1.6       6.5999994]]
+    """
+    return P.StopGradient()(value)
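Given the docstring example just added, a useful contrast is the same network without the truncation: both MatMul outputs then feed the gradient and every entry doubles, since for a summed output the gradient of x @ y with respect to x repeats the row sums of `y`. A sketch under the same inputs (expected values computed by hand, not taken from the commit):

    import mindspore.ops as ops
    from mindspore import Tensor
    from mindspore import dtype as mstype

    def net_no_stop(x, y):
        out1 = ops.MatMul()(x, y)
        out2 = ops.MatMul()(x, y)   # no stop_gradient: out2 also backpropagates
        return out1, out2

    x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
    y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
    print(ops.grad(net_no_stop)(x, y))
    # expected: [[2.82 3.2 13.2]
    #            [2.82 3.2 13.2]], exactly double the docstring's output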
 __all__ = [
     'grad',
     'value_and_grad',
@@ -1309,6 +1346,7 @@ __all__ = [
     'jvp',
     'vjp',
     'custom_vjp',
-    'linearize'
+    'linearize',
+    'stop_gradient'
 ]
 __all__.sort()
@@ -117,8 +117,6 @@ switch_layer = Primitive('switch_layer')
 reduced_shape = Primitive("reduced_shape")
 # shape_mul:input must be shape multiply elements in tuple(shape)
 shape_mul = Primitive("shape_mul")
-# a primitive to compare between tuple.
-stop_gradient = Primitive("stop_gradient")

 tensor_operator_registry.register('add', P.Add)
 tensor_operator_registry.register('addr', addr)
@@ -98,7 +98,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
                      ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink,
                      ApplyAdamWithAmsgrad, AdaptiveAvgPool3D, AdaptiveMaxPool2D, AdaptiveMaxPool3D)
 from .other_ops import (Assign, IOU, BartlettWindow, BlackmanWindow, BoundingBoxDecode, BoundingBoxEncode,
-                        ConfusionMatrix, UpdateState, Load,
+                        ConfusionMatrix, UpdateState, Load, StopGradient,
                         CheckValid, Partial, Depend, identity, Push, Pull, PyFunc, _DynamicLossScale)
 from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, RandomGamma, Poisson, UniformInt, UniformReal,
                          RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@@ -339,6 +339,7 @@ __all__ = [
     'Partial',
     'Depend',
     'UpdateState',
+    'StopGradient',
     'identity',
     'AvgPool',
     # Back Primitive
@@ -603,6 +603,40 @@ class UpdateState(Primitive):
         return state


+class StopGradient(Primitive):
+    """
+    StopGradient is used for eliminating the effect of a value on the gradient,
+    such as truncating the gradient propagation from an output of a function.
+
+    Refer to :func:`mindspore.ops.stop_gradient` for more details.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> from mindspore import Tensor
+        >>> from mindspore import dtype as mstype
+        >>> def net(x, y):
+        ...     out1 = ops.MatMul()(x, y)
+        ...     out2 = ops.MatMul()(x, y)
+        ...     out2 = ops.StopGradient()(out2)
+        ...     return out1, out2
+        ...
+        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
+        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
+        >>> grad_fn = ops.grad(net)
+        >>> output = grad_fn(x, y)
+        >>> print(output)
+        [[1.4100001 1.6       6.5999994]
+         [1.4100001 1.6       6.5999994]]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        pass
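The class form above and the functional form added earlier dispatch to the same primitive, so the two spellings should be interchangeable in forward code (a sketch):

    import mindspore as ms
    from mindspore import Tensor, ops

    x = Tensor([1.0, 2.0], ms.float32)
    a = ops.stop_gradient(x)     # functional API; calls P.StopGradient()(value) under the hood
    b = ops.StopGradient()(x)    # primitive API
    print(a, b)  # identical forward values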
 class ConfusionMatrix(PrimitiveWithInfer):
     r"""
     Calculates the confusion matrix from labels and predictions.