pynative add identity primitive and add comment for set_grad

This commit is contained in:
kpy 2020-08-21 16:36:52 +08:00
parent 492e41a4af
commit 4fa89408a1
4 changed files with 30 additions and 4 deletions

View File

@ -853,8 +853,15 @@ class Cell:
self.add_flags_recursive(**flags)
return self
def set_grad(self, mode=True):
self.requires_grad = mode
def set_grad(self, requires_grad=True):
"""
Sets the cell flag for gradient.
Args:
requires_grad (bool): Specifies if the net need to grad, if it is
True, cell will construct backward network in pynative mode. Default: True.
"""
self.requires_grad = requires_grad
return self
def set_train(self, mode=True):

View File

@ -82,6 +82,7 @@ pack = P.Pack()
partial = P.Partial()
# depend: mount a node to another node
depend = P.Depend()
identity = P.identity()
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem')
@ -135,7 +136,6 @@ broadcast_gradient_args = Primitive('BroadcastGradientArgs')
dot = Primitive('dot')
array_reduce = Primitive('array_reduce')
zeros_like = P.ZerosLike()
identity = Primitive('identity')
distribute = Primitive('distribute')
embed = Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed()

View File

@ -83,7 +83,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
CheckValid, MakeRefKey, Partial, Depend, identity, CheckBprop, Push, Pull)
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight,
@ -266,6 +266,7 @@ __all__ = [
'MakeRefKey',
'Partial',
'Depend',
'identity',
'AvgPool',
# Back Primitive
'Equal',

View File

@ -560,3 +560,21 @@ class Pull(PrimitiveWithInfer):
def infer_dtype(self, key_dtype, weight_dtype):
return mstype.float32
class identity(Primitive):
"""
Make a identify primitive, used for pynative mode.
Inputs:
- **x** (Any) - identity input value.
Outputs:
The same as input.
"""
@prim_attr_register
def __init__(self):
pass
def __call__(self, x):
return x