forked from mindspore-Ecosystem/mindspore
pynative add identity primitive and add comment for set_grad
This commit is contained in:
parent
492e41a4af
commit
4fa89408a1
|
@ -853,8 +853,15 @@ class Cell:
|
|||
self.add_flags_recursive(**flags)
|
||||
return self
|
||||
|
||||
def set_grad(self, mode=True):
|
||||
self.requires_grad = mode
|
||||
def set_grad(self, requires_grad=True):
|
||||
"""
|
||||
Sets the cell flag for gradient.
|
||||
|
||||
Args:
|
||||
requires_grad (bool): Specifies if the net need to grad, if it is
|
||||
True, cell will construct backward network in pynative mode. Default: True.
|
||||
"""
|
||||
self.requires_grad = requires_grad
|
||||
return self
|
||||
|
||||
def set_train(self, mode=True):
|
||||
|
|
|
@ -82,6 +82,7 @@ pack = P.Pack()
|
|||
partial = P.Partial()
|
||||
# depend: mount a node to another node
|
||||
depend = P.Depend()
|
||||
identity = P.identity()
|
||||
|
||||
tuple_setitem = Primitive('tuple_setitem')
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
|
@ -135,7 +136,6 @@ broadcast_gradient_args = Primitive('BroadcastGradientArgs')
|
|||
dot = Primitive('dot')
|
||||
array_reduce = Primitive('array_reduce')
|
||||
zeros_like = P.ZerosLike()
|
||||
identity = Primitive('identity')
|
||||
distribute = Primitive('distribute')
|
||||
embed = Primitive('embed')
|
||||
ref_to_embed = _grad_ops.RefToEmbed()
|
||||
|
|
|
@ -83,7 +83,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
|
|||
from . import _quant_ops
|
||||
from ._quant_ops import *
|
||||
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
|
||||
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
|
||||
CheckValid, MakeRefKey, Partial, Depend, identity, CheckBprop, Push, Pull)
|
||||
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
|
||||
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
|
||||
CusMatMulCubeDenseRight,
|
||||
|
@ -266,6 +266,7 @@ __all__ = [
|
|||
'MakeRefKey',
|
||||
'Partial',
|
||||
'Depend',
|
||||
'identity',
|
||||
'AvgPool',
|
||||
# Back Primitive
|
||||
'Equal',
|
||||
|
|
|
@ -560,3 +560,21 @@ class Pull(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, key_dtype, weight_dtype):
|
||||
return mstype.float32
|
||||
|
||||
class identity(Primitive):
|
||||
"""
|
||||
Make a identify primitive, used for pynative mode.
|
||||
|
||||
Inputs:
|
||||
- **x** (Any) - identity input value.
|
||||
|
||||
Outputs:
|
||||
The same as input.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, x):
|
||||
return x
|
||||
|
|
Loading…
Reference in New Issue