!37127 Support high grad of StridedSlice

Merge pull request !37127 from luoyang/stridedslices_high_grad
This commit is contained in:
i-robot 2022-07-06 01:12:28 +00:00 committed by Gitee
commit 2a953992f5
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
16 changed files with 114 additions and 24 deletions

View File

@ -722,6 +722,21 @@ def get_bprop_strided_slice(self):
return bprop
@bprop_getters.register(G.StridedSliceGrad)
def get_bprop_strided_slice_grad(self):
"""Generate bprop for StridedSliceGrad"""
strided_slice = P.StridedSlice(begin_mask=self.begin_mask,
end_mask=self.end_mask,
ellipsis_mask=self.ellipsis_mask,
new_axis_mask=self.new_axis_mask,
shrink_axis_mask=self.shrink_axis_mask)
def bprop(dy, shapex, begin, end, strides, out, dout):
return strided_slice(dout, begin, end, strides), zeros_like(shapex), zeros_like(begin), zeros_like(end), \
zeros_like(strides)
return bprop
@bprop_getters.register(P.Eye)
def get_bprop_eye(self):
"""Generate bprop for Eye"""

View File

@ -9,6 +9,6 @@ y
bprop.12:x*
bprop.12:out*
bprop.12:dout2
bprop.12:[CNode]14:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
bprop.12:[CNode]14:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -9,6 +9,6 @@ z
bprop.15:x*
bprop.15:out*
bprop.15:dout2
bprop.15:[CNode]17:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1PbH
bprop.15:[CNode]17:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -9,6 +9,6 @@ z
bprop.18:x*
bprop.18:out*
bprop.18:dout2
bprop.18:[CNode]20:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh
bprop.18:[CNode]20:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -9,6 +9,6 @@ z
bprop.24:x*
bprop.24:out*
bprop.24:dout2
bprop.24:[CNode]26:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pb&
bprop.24:[CNode]26:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -5,5 +5,5 @@ m
bprop.1:x*
bprop.1:out*
bprop.1:dout2
bprop.1:[CNode]2:1:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pb&
bprop.1:[CNode]2:1:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPb&
S-Prim-MakeTuple:2S-Prim-MakeTupleh

View File

@ -7,6 +7,6 @@ s
bprop.6:x*
bprop.6:out*
bprop.6:dout2
bprop.6:[CNode]8:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1PbH
bprop.6:[CNode]8:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -7,6 +7,6 @@ s
bprop.3:x*
bprop.3:out*
bprop.3:dout2
bprop.3:[CNode]5:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
bprop.3:[CNode]5:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -9,6 +9,6 @@ z
bprop.27:x*
bprop.27:out*
bprop.27:dout2
bprop.27:[CNode]29:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1PbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh
bprop.27:[CNode]29:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -27,9 +27,9 @@ bprop.30:x*
bprop.30:y*
bprop.30:out*
bprop.30:dout2
bprop.30:[CNode]36:8:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pbv
bprop.30:[CNode]36:8:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:9S-Prim-MakeTuplebv
S-Prim-Select:5 S-Prim-Select
output_names€Š Zoutput€3
input_names€ŠZ condition€ŠZx€ŠZy€bH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:9S-Prim-MakeTupleh
input_names€ŠZ condition€ŠZx€ŠZy€h

View File

@ -9,6 +9,6 @@ z
bprop.21:x*
bprop.21:out*
bprop.21:dout2
bprop.21:[CNode]23:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
bprop.21:[CNode]23:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@ -7,6 +7,6 @@ v
bprop.9:x*
bprop.9:out*
bprop.9:dout2
bprop.9:[CNode]11:3:@c0204b943f5430ecc5fb77e4a6d67388ec50534279bc72793d0a997633157be1Pb&
bprop.9:[CNode]11:3:@2d3dda8f3c50b46f8e44221415a63b0943b72687a88d69dd5e8036f528f5e18aPb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@ -0,0 +1,75 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops import composite as C
from mindspore.ops import operations as P
context.set_context(device_target="Ascend")
class NetGrad(nn.Cell):
def __init__(self):
super(NetGrad, self).__init__()
self.grad = G.StridedSliceGrad()
def construct(self, x, begin, end, strides, dout):
return self.grad(x, begin, end, strides, dout)
class NetGradGrad(nn.Cell):
def __init__(self, forward_net):
super(NetGradGrad, self).__init__()
self.forward_net = forward_net
self.grad_ops = C.GradOperation(get_all=True, sens_param=True)
def construct(self, dy, shapex, begin, end, strides, dout):
backward_net = self.grad_ops(self.forward_net)
return backward_net(dy, shapex, begin, end, strides, dout)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_stridedslice_high_grad_float32():
"""
Feature: StridedSlice Grad Grad operation
Description: test the grad of StridedSliceGrad kernel, with float input.
Expectation: the output is same with numpy
"""
x = np.array([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]]).astype(np.float32)
dy = Tensor(np.ones((2, 1, 1)).astype(np.float32))
x_shape = Tensor(np.array(list(x.shape)).astype(np.int64))
begin = (1, 0, 2)
end = (3, 1, 3)
strides = (1, 1, 1)
dout = np.ones_like(x).astype(np.float32)
grad_net = NetGrad()
grad_grad_net = NetGradGrad(grad_net)
dgrad_ms = grad_grad_net(dy, x_shape, begin, end, strides, Tensor(dout))
stridedslice = P.StridedSlice()
forward_res = stridedslice(Tensor(x), begin, end, strides)
expected = np.ones_like(forward_res.asnumpy())
assert np.allclose(dgrad_ms[0].asnumpy(), expected, 1e-4, 1e-4)