!36926 Support high grad of BiasAdd

Merge pull request !36926 from luoyang/biasadd_high_grad
i-robot 2022-07-04 13:46:42 +00:00 committed by Gitee
commit 961a8ebca0
8 changed files with 244 additions and 11 deletions

View File

@@ -39,6 +39,31 @@ def get_bprop_bias_add(self):
    return bprop


@bprop_getters.register(G.BiasAddGrad)
def get_bprop_bias_add_grad(self):
    """Grad definition for `BiasAddGrad` operation."""

    def bprop(x, out, dout):
        get_shape = P.Shape()
        concat = P.Concat(axis=0)
        reshape = P.Reshape()
        tile = P.Tile()
        shape = get_shape(x)
        bias_shape = get_shape(dout)
        if self.data_format == "NCHW":
            # Put the bias-shaped gradient back on axis 1; all other axes become 1.
            expand_shape = concat((P.ones_like(shape[:1]), bias_shape, P.ones_like(shape[2:])))
            tile_mults = concat((shape[:1], [1], shape[2:]))
        else:
            # Channel-last: the bias-shaped gradient sits on the last axis.
            expand_shape = concat((P.ones_like(shape[:-1]), bias_shape))
            tile_mults = concat((shape[:-1], [1]))
        expand_grad = reshape(dout, expand_shape)
        return (tile(expand_grad, tile_mults),)

    return bprop


@bprop_getters.register(P.Conv2D)
def get_bprop_conv2d(self):
    """Grad definition for `Conv2D` operation."""

View File

@@ -21,7 +21,7 @@ bprop.13:y*
bprop.13:keep_prob*
bprop.13:out*
bprop.13:dout2
bprop.13:[CNode]17:6:@c3b784055020f26853bf8810898ed0dc0317483adac76918f3b8ef6ff639fae5PbH
bprop.13:[CNode]17:6:@3bb59bfb3a99678d99fd070f9d30ad071146cb47ad7351bac84954974a37c85cPbH
#S-Prim-hyper_map[zeros_like_leaf]:4!S-Prim-hyper_map[zeros_like_leaf]b.
S-Prim-DropoutDoMask:2S-Prim-DropoutDoMaskb&
S-Prim-MakeTuple:7S-Prim-MakeTupleh

View File

@@ -11,6 +11,6 @@
bprop.3:keep_prob*
bprop.3:out*
bprop.3:dout2
bprop.3:[CNode]6:4:@c3b784055020f26853bf8810898ed0dc0317483adac76918f3b8ef6ff639fae5Pb&
bprop.3:[CNode]6:4:@3bb59bfb3a99678d99fd070f9d30ad071146cb47ad7351bac84954974a37c85cPb&
S-Prim-MakeTuple:5S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h

View File

@@ -19,6 +19,6 @@
bprop.7:off_value*
bprop.7:out*
bprop.7:dout2
bprop.7:[CNode]12:6:@c3b784055020f26853bf8810898ed0dc0317483adac76918f3b8ef6ff639fae5Pb&
S-Prim-MakeTuple:7S-Prim-MakeTuplebH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]h
bprop.7:[CNode]12:6:@3bb59bfb3a99678d99fd070f9d30ad071146cb47ad7351bac84954974a37c85cPbH
#S-Prim-hyper_map[zeros_like_leaf]:2!S-Prim-hyper_map[zeros_like_leaf]b&
S-Prim-MakeTuple:7S-Prim-MakeTupleh

View File

@@ -8,9 +8,9 @@ m
bprop.1:x*
bprop.1:out*
bprop.1:dout2
bprop.1:[CNode]2:3:@c3b784055020f26853bf8810898ed0dc0317483adac76918f3b8ef6ff639fae5Pb&
S-Prim-MakeTuple:4S-Prim-MakeTuplebr
bprop.1:[CNode]2:3:@3bb59bfb3a99678d99fd070f9d30ad071146cb47ad7351bac84954974a37c85cPbr
S-Prim-ReluGrad:2S-Prim-ReluGrad
output_names€Š Zoutput€+
input_names€ŠZ
y_backprop€ŠZx€h
y_backprop€ŠZx€b&
S-Prim-MakeTuple:4S-Prim-MakeTupleh

View File

@@ -15,10 +15,10 @@ bprop.18:ybprop.18:[CNode]19:3bprop.18:[CNode]19:3"(REF::S-Prim-hyper_map[ze
bprop.18:y*
bprop.18:out*
bprop.18:dout2
bprop.18:[CNode]20:5:@c3b784055020f26853bf8810898ed0dc0317483adac76918f3b8ef6ff639fae5Pb&
S-Prim-MakeTuple:6S-Prim-MakeTuplebH
bprop.18:[CNode]20:5:@3bb59bfb3a99678d99fd070f9d30ad071146cb47ad7351bac84954974a37c85cPbH
#S-Prim-hyper_map[zeros_like_leaf]:4!S-Prim-hyper_map[zeros_like_leaf]br
S-Prim-ReluGrad:2S-Prim-ReluGrad
output_names€Š Zoutput€+
input_names€ŠZ
y_backprop€ŠZx€h
y_backprop€ŠZx€b&
S-Prim-MakeTuple:6S-Prim-MakeTupleh

View File

@@ -0,0 +1,208 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.nn import Cell
from mindspore.ops.composite import GradOperation
from mindspore.common import ParameterTuple
TF_INSTALL_FLG = 1
try:
    import tensorflow as tf
except ImportError:
    TF_INSTALL_FLG = 0

context.set_context(device_target="Ascend")
class _Grad(Cell):
    def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
        super().__init__()
        self.network = network
        self.grad = grad
        self.sens_param = self.grad.sens_param
        self.wrt_params = wrt_params
        self.real_inputs_count = real_inputs_count
        if self.wrt_params:
            self.params = ParameterTuple(self.network.trainable_params())

    def construct(self, *inputs):
        if self.wrt_params:
            if self.real_inputs_count is None or self.sens_param is False:
                return self.grad(self.network, self.params)(*inputs)
            real_inputs = inputs[:self.real_inputs_count]
            sense_param_inputs = inputs[self.real_inputs_count:]
            return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
        if self.real_inputs_count is None or self.sens_param is False:
            return self.grad(self.network)(*inputs)
        real_inputs = inputs[:self.real_inputs_count]
        sense_param_inputs = inputs[self.real_inputs_count:]
        return self.grad(self.network)(*real_inputs, sense_param_inputs)

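# Note: when sens_param is True and real_inputs_count is set, _Grad treats the
# leading real_inputs_count positional arguments as the network's real inputs
# and packs the remainder into a tuple that is passed to GradOperation as the
# sensitivity (initial gradient).
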
class GradOfAllInputs(_Grad):
    """
    get grads of all inputs
    """
    def __init__(self, network, sens_param=True, real_inputs_count=None):
        super().__init__(grad=GradOperation(get_all=True, sens_param=sens_param),
                         network=network, real_inputs_count=real_inputs_count)


class HighGrad(Cell):
    """
    get any order of grad
    """
    def __init__(self, network, grad_list, sens_param=False, real_inputs_count=None):
        super().__init__()
        self.grads = [network]
        for i in range(len(grad_list) - 1):
            _grad = grad_list[i](self.grads[i], sens_param=False)
            self.grads.append(_grad)
        self.final_grad = grad_list[-1](self.grads[-1],
                                        sens_param=sens_param,
                                        real_inputs_count=real_inputs_count)

    def construct(self, *inputs):
        return self.final_grad(*inputs)

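# How HighGrad reaches second order here: each entry of grad_list wraps the
# previous network in a gradient cell, so HighGrad(net, [GradOfAllInputs,
# GradOfAllInputs]) differentiates the network once and then differentiates
# that gradient network again, returning second-order gradients of every
# input. Illustrative usage with the names defined in this file:
#
#     net = BiasAdd(ms_format="NCHW")
#     grad_net = HighGrad(net, grad_list=[GradOfAllInputs, GradOfAllInputs])
#     ddx, ddb = grad_net(x, b)
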
class BiasAdd(nn.Cell):
    def __init__(self, ms_format):
        super().__init__()
        self.op = P.BiasAdd(ms_format)

    def construct(self, x, b):
        return self.op(x, b)


def count_unequal_element(data_expected, data_me, rtol, atol):
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_me) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
        format(data_expected[greater], data_me[greater], error[greater])


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    if np.any(np.isnan(data_expected)) or np.any(np.isnan(data_me)):
        assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
        count_unequal_element(data_expected, data_me, rtol, atol)
    else:
        assert np.array(data_expected).shape == np.array(data_me).shape

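# Comparison policy: allclose_nparray first tries np.allclose; when that
# fails, count_unequal_element tolerates a small fraction (< rtol) of
# out-of-tolerance elements rather than demanding exact element-wise
# agreement, which absorbs dtype- and platform-dependent rounding
# differences against the TensorFlow reference.
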
class TestEntry:
    def __init__(self, input_x_np, dtype, ms_format, tf_format):
        self.input_x_np = input_x_np
        self.dtype = dtype
        self.ms_format = ms_format
        self.tf_format = tf_format
        if self.dtype == np.float16:
            self.loss = 1e-3
        elif self.dtype == np.float32:
            self.loss = 1e-4
        elif self.dtype == np.float64:
            self.loss = 1e-5
        elif self.dtype == np.complex64:
            self.loss = 2e-6
        elif self.dtype == np.complex128:
            self.loss = 2e-10
        else:
            self.loss = 0

    def highgrad_mindspore_impl(self):
        x = Tensor(self.input_x_np[0].copy().astype(self.dtype))
        b = Tensor(self.input_x_np[1].copy().astype(self.dtype))
        net = BiasAdd(ms_format=self.ms_format)
        grad_net = HighGrad(net, grad_list=[GradOfAllInputs, GradOfAllInputs])
        y = grad_net(x, b)
        return y

    def highgrad_tensorflow_impl(self):
        x = tf.Variable(self.input_x_np[0].copy().astype(self.dtype))
        b = tf.Variable(self.input_x_np[1].copy().astype(self.dtype))
        with tf.GradientTape(persistent=True) as tape:
            y = tf.nn.bias_add(x, b, data_format=self.tf_format)
            dydx, dydb = tape.gradient(y, [x, b])
        ddx, ddb = tape.gradient([dydx, dydb], [x, b],
                                 unconnected_gradients=tf.UnconnectedGradients.ZERO)
        return ddx, ddb

    def highgrad_cmp(self):
        out_ms = self.highgrad_mindspore_impl()
        if TF_INSTALL_FLG == 1:
            out_tf = self.highgrad_tensorflow_impl()
        else:
            out_tf = []
            out_tf.append(np.zeros_like(self.input_x_np[0]))
            out_tf.append(np.zeros_like(self.input_x_np[1]))
        allclose_nparray(out_tf[0], out_ms[0].asnumpy(), self.loss, self.loss)
        allclose_nparray(out_tf[1], out_ms[1].asnumpy(), self.loss, self.loss)

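# Why zeros are a valid fallback reference: BiasAdd is linear in both x and b,
# so its first-order gradients are constant and the second-order gradients are
# identically zero. TensorFlow reports exactly that (the first-order grads do
# not depend on x or b, hence unconnected_gradients=ZERO), and the all-zeros
# arrays used when TensorFlow is absent encode the same expectation.
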
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_biasadd_high_grad_dim2_float16():
    """
    Feature: BiasAdd grad-of-grad operation
    Description: test the high grad of BiasAdd. Input tensor has 2 dims, float16 type.
    Expectation: the output is the same as TensorFlow's
    """
    test = TestEntry(input_x_np=[np.arange(1, 7).reshape((2, 3)), np.ones(shape=(3,))],
                     dtype=np.float16, ms_format="NCHW", tf_format="NCHW")
    test.highgrad_cmp()

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_biasadd_high_grad_dim4_float32():
    """
    Feature: BiasAdd grad-of-grad operation
    Description: test the high grad of BiasAdd. Input tensor has 4 dims, float32 type.
    Expectation: the output is the same as TensorFlow's
    """
    test = TestEntry(input_x_np=[np.random.randn(3, 2, 3, 3), np.ones(shape=(2,))],
                     dtype=np.float32, ms_format="NCHW", tf_format="NCHW")
    test.highgrad_cmp()

@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_biasadd_high_grad_dim5_float64():
    """
    Feature: BiasAdd grad-of-grad operation
    Description: test the high grad of BiasAdd. Input tensor has 5 dims, float64 type.
    Expectation: the output is the same as TensorFlow's
    """
    test = TestEntry(input_x_np=[np.random.randn(1, 5, 2, 2, 2), np.ones(shape=(5,))],
                     dtype=np.float64, ms_format="NCDHW", tf_format="NCDHW")
    test.highgrad_cmp()
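
For local runs outside the CI markers, a minimal entry point such as the following could be appended (hypothetical, not part of the patch; it assumes an available Ascend device as configured by set_context above):

if __name__ == "__main__":
    test_biasadd_high_grad_dim2_float16()
    test_biasadd_high_grad_dim4_float32()
    test_biasadd_high_grad_dim5_float64()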