!37127 Support high grad of StridedSlice
Merge pull request !37127 from luoyang/stridedslices_high_grad
This commit is contained in:
commit
2a953992f5
|
@ -722,6 +722,21 @@ def get_bprop_strided_slice(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(G.StridedSliceGrad)
|
||||
def get_bprop_strided_slice_grad(self):
|
||||
"""Generate bprop for StridedSliceGrad"""
|
||||
strided_slice = P.StridedSlice(begin_mask=self.begin_mask,
|
||||
end_mask=self.end_mask,
|
||||
ellipsis_mask=self.ellipsis_mask,
|
||||
new_axis_mask=self.new_axis_mask,
|
||||
shrink_axis_mask=self.shrink_axis_mask)
|
||||
def bprop(dy, shapex, begin, end, strides, out, dout):
|
||||
return strided_slice(dout, begin, end, strides), zeros_like(shapex), zeros_like(begin), zeros_like(end), \
|
||||
zeros_like(strides)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.Eye)
|
||||
def get_bprop_eye(self):
|
||||
"""Generate bprop for Eye"""
|
||||
|
|
|
@ -9,6 +9,6 @@ y
|
|||
bprop.12:x*
|
||||
bprop.12:out*
|
||||
bprop.12:dout2
|
||||
bprop.12:[CNode]14:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pb&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
||||
bprop.12:[CNode]14:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPbH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -9,6 +9,6 @@ z
|
|||
bprop.15:x*
|
||||
bprop.15:out*
|
||||
bprop.15:dout2
|
||||
bprop.15:[CNode]17:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1PbH
|
||||
bprop.15:[CNode]17:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPbH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -9,6 +9,6 @@ z
|
|||
bprop.18:x*
|
||||
bprop.18:out*
|
||||
bprop.18:dout2
|
||||
bprop.18:[CNode]20:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1PbH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
||||
bprop.18:[CNode]20:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPb&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -9,6 +9,6 @@ z
|
|||
bprop.24:x*
|
||||
bprop.24:out*
|
||||
bprop.24:dout2
|
||||
bprop.24:[CNode]26:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pb&
|
||||
bprop.24:[CNode]26:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPb&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -5,5 +5,5 @@ m
|
|||
bprop.1:x*
|
||||
bprop.1:out*
|
||||
bprop.1:dout2
|
||||
bprop.1:[CNode]2:1:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pb&
|
||||
bprop.1:[CNode]2:1:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPb&
|
||||
S-Prim-MakeTuple:2S-Prim-MakeTupleh
|
|
@ -7,6 +7,6 @@ s
|
|||
bprop.6:x*
|
||||
bprop.6:out*
|
||||
bprop.6:dout2
|
||||
bprop.6:[CNode]8:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1PbH
|
||||
bprop.6:[CNode]8:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPbH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -7,6 +7,6 @@ s
|
|||
bprop.3:x*
|
||||
bprop.3:out*
|
||||
bprop.3:dout2
|
||||
bprop.3:[CNode]5:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pb&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
||||
bprop.3:[CNode]5:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPbH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -9,6 +9,6 @@ z
|
|||
bprop.27:x*
|
||||
bprop.27:out*
|
||||
bprop.27:dout2
|
||||
bprop.27:[CNode]29:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1PbH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
||||
bprop.27:[CNode]29:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPb&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -27,9 +27,9 @@ bprop.30:x*
|
|||
bprop.30:y*
|
||||
bprop.30:out*
|
||||
bprop.30:dout2
|
||||
bprop.30:[CNode]36:8:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pbv
|
||||
bprop.30:[CNode]36:8:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPbH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||
S-Prim-MakeTuple:9S-Prim-MakeTuplebv
|
||||
S-Prim-Select:5
S-Prim-Select
|
||||
output_names€ŠZoutput€3
|
||||
input_names€ŠZ condition€ŠZx€ŠZy€bH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||
S-Prim-MakeTuple:9S-Prim-MakeTupleh
|
||||
input_names€ŠZ condition€ŠZx€ŠZy€h
|
|
@ -9,6 +9,6 @@ z
|
|||
bprop.21:x*
|
||||
bprop.21:out*
|
||||
bprop.21:dout2
|
||||
bprop.21:[CNode]23:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pb&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
||||
bprop.21:[CNode]23:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPbH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTupleh
|
|
@ -7,6 +7,6 @@ v
|
|||
bprop.9:x*
|
||||
bprop.9:out*
|
||||
bprop.9:dout2
|
||||
bprop.9:[CNode]11:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pb&
|
||||
bprop.9:[CNode]11:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPb&
|
||||
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
|
||||
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
|
||||
class NetGrad(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetGrad, self).__init__()
|
||||
self.grad = G.StridedSliceGrad()
|
||||
|
||||
def construct(self, x, begin, end, strides, dout):
|
||||
return self.grad(x, begin, end, strides, dout)
|
||||
|
||||
|
||||
class NetGradGrad(nn.Cell):
|
||||
def __init__(self, forward_net):
|
||||
super(NetGradGrad, self).__init__()
|
||||
self.forward_net = forward_net
|
||||
self.grad_ops = C.GradOperation(get_all=True, sens_param=True)
|
||||
|
||||
def construct(self, dy, shapex, begin, end, strides, dout):
|
||||
backward_net = self.grad_ops(self.forward_net)
|
||||
return backward_net(dy, shapex, begin, end, strides, dout)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_stridedslice_high_grad_float32():
|
||||
"""
|
||||
Feature: StridedSlice Grad Grad operation
|
||||
Description: test the grad of StridedSliceGrad kernel, with float input.
|
||||
Expectation: the output is same with numpy
|
||||
"""
|
||||
x = np.array([[[1, 1, 1], [2, 2, 2]],
|
||||
[[3, 3, 3], [4, 4, 4]],
|
||||
[[5, 5, 5], [6, 6, 6]]]).astype(np.float32)
|
||||
|
||||
dy = Tensor(np.ones((2, 1, 1)).astype(np.float32))
|
||||
x_shape = Tensor(np.array(list(x.shape)).astype(np.int64))
|
||||
begin = (1, 0, 2)
|
||||
end = (3, 1, 3)
|
||||
strides = (1, 1, 1)
|
||||
dout = np.ones_like(x).astype(np.float32)
|
||||
|
||||
grad_net = NetGrad()
|
||||
grad_grad_net = NetGradGrad(grad_net)
|
||||
dgrad_ms = grad_grad_net(dy, x_shape, begin, end, strides, Tensor(dout))
|
||||
|
||||
stridedslice = P.StridedSlice()
|
||||
forward_res = stridedslice(Tensor(x), begin, end, strides)
|
||||
expected = np.ones_like(forward_res.asnumpy())
|
||||
assert np.allclose(dgrad_ms[0].asnumpy(), expected, 1e-4, 1e-4)
|
Loading…
Reference in New Issue