forked from mindspore-Ecosystem/mindspore
!9124 fix grad of the Identity
From: @yanzhenxiang2020 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @wuxuejian
This commit is contained in:
commit
756594d000
|
@ -456,6 +456,16 @@ def get_bprop_sparse_gather_v2(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Identity)
|
||||
def get_bprop_identity(self):
|
||||
"""Generate bprop for Identity"""
|
||||
|
||||
def bprop(x, out, dout):
|
||||
return (dout,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(inner.Range)
|
||||
def get_bprop_range(self):
|
||||
"""Generate bprop for Range"""
|
||||
|
|
|
@ -4105,14 +4105,14 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
Tensors, A Tuple of N N-D Tensor objects.
|
||||
|
||||
Examples:
|
||||
>>> x = np.array([1, 2, 3, 4]).astype(np.int32)
|
||||
>>> y = np.array([5, 6, 7]).astype(np.int32)
|
||||
>>> z = np.array([8, 9, 0, 1, 2]).astype(np.int32)
|
||||
>>> x = Tensor(np.array([1, 2, 3, 4]).astype(np.int32))
|
||||
>>> y = Tensor(np.array([5, 6, 7]).astype(np.int32))
|
||||
>>> z = Tensor(np.array([8, 9, 0, 1, 2]).astype(np.int32))
|
||||
>>> inputs = (x, y, z)
|
||||
>>> meshgrid = ops.Meshgrid(indexing="xy")
|
||||
>>> output = meshgrid(inputs)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[3, 4, 6], dtype=UInt32, value=
|
||||
(Tensor(shape=[3, 4, 6], dtype=Int32, value=
|
||||
[[[1, 1, 1, 1, 1],
|
||||
[2, 2, 2, 2, 2],
|
||||
[3, 3, 3, 3, 3],
|
||||
|
@ -4125,7 +4125,7 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
[2, 2, 2, 2, 2],
|
||||
[3, 3, 3, 3, 3],
|
||||
[4, 4, 4, 4, 4]]]),
|
||||
Tensor(shape=[3, 4, 6], dtype=UInt32, value=
|
||||
Tensor(shape=[3, 4, 6], dtype=Int32, value=
|
||||
[[[5, 5, 5, 5, 5],
|
||||
[5, 5, 5, 5, 5],
|
||||
[5, 5, 5, 5, 5],
|
||||
|
@ -4138,7 +4138,7 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
[7, 7, 7, 7, 7],
|
||||
[7, 7, 7, 7, 7],
|
||||
[7, 7, 7, 7, 7]]]),
|
||||
Tensor(shape=[3, 4, 6], dtype=UInt32, value=
|
||||
Tensor(shape=[3, 4, 6], dtype=Int32, value=
|
||||
[[[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2],
|
||||
|
@ -4612,6 +4612,8 @@ class Identity(PrimitiveWithInfer):
|
|||
"""Initialize identity"""
|
||||
|
||||
def __infer__(self, x):
|
||||
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
|
||||
validator.check_tensor_dtype_valid('x', x['dtype'], mstype.number_type + (mstype.bool_,), self.name)
|
||||
out = {'shape': x['shape'],
|
||||
'dtype': x['dtype'],
|
||||
'value': None}
|
||||
|
|
Loading…
Reference in New Issue