forked from mindspore-Ecosystem/mindspore

add new GPU kernel ReciprocalGrad

parent 4cd49ed77e
commit 88b5458f78
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ __global__ void SqrtGradKernel(const T *input, const T *dout, T *output, const s
   }
   return;
 }
+
 template <typename T>
 __global__ void RsqrtGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {

@@ -37,6 +38,7 @@ __global__ void RsqrtGradKernel(const T *input, const T *dout, T *output, const
   }
   return;
 }
+
 template <typename T>
 __global__ void AsinGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {

@@ -46,6 +48,7 @@ __global__ void AsinGradKernel(const T *input, const T *dout, T *output, const s
   }
   return;
 }
+
 template <>
 __global__ void AsinGradKernel(const half *input, const half *dout, half *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {

@@ -55,6 +58,7 @@ __global__ void AsinGradKernel(const half *input, const half *dout, half *output
   }
   return;
 }
+
 template <typename T>
 __global__ void ACosGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {

@@ -65,6 +69,7 @@ __global__ void ACosGradKernel(const T *input, const T *dout, T *output, const s
   }
   return;
 }
+
 template <>
 __global__ void ACosGradKernel(const half *input, const half *dout, half *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {

@@ -75,6 +80,7 @@ __global__ void ACosGradKernel(const half *input, const half *dout, half *output
   }
   return;
 }
+
 template <typename T>
 __global__ void AtanGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {

@@ -84,6 +90,7 @@ __global__ void AtanGradKernel(const T *input, const T *dout, T *output, const s
   }
   return;
 }
+
 template <typename T>
 __global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {

@@ -93,6 +100,7 @@ __global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const
   }
   return;
 }
+
 template <typename T>
 __global__ void AcoshGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -102,11 +110,24 @@ __global__ void AcoshGradKernel(const T *input, const T *dout, T *output, const
   }
   return;
 }
+
+template <typename T>
+__global__ void ReciprocalGradKernel(const T *input, const T *dout, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
+    float inputf = static_cast<float>(input[i]);
+    float doutf = static_cast<float>(dout[i]);
+    float res = -1 * doutf * inputf * inputf;
+    output[i] = static_cast<T>(res);
+  }
+  return;
+}
+
 template <typename T>
 void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
   SqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
   return;
 }
+
 template <typename T>
 void RsqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
   RsqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
@@ -143,20 +164,28 @@ void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cud
   return;
 }
+
+template <typename T>
+void ReciprocalGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
+  ReciprocalGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
+  return;
+}
 
 template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                               cudaStream_t cuda_stream);
 template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                                cudaStream_t cuda_stream);
 template void AsinGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                               cudaStream_t cuda_stream);
 template void ACosGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                               cudaStream_t cuda_stream);
 template void AtanGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                               cudaStream_t cuda_stream);
 template void AsinhGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                                cudaStream_t cuda_stream);
 template void AcoshGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                                cudaStream_t cuda_stream);
+template void ReciprocalGrad<float>(const float *input, const float *dout, float *output, const size_t count,
+                                    cudaStream_t cuda_stream);
 template void SqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                              cudaStream_t cuda_stream);
 template void RsqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,

@@ -164,10 +193,12 @@ template void RsqrtGrad<half>(const half *input, const half *dout, half *output,
 template void AsinGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                              cudaStream_t cuda_stream);
 template void ACosGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                              cudaStream_t cuda_stream);
 template void AtanGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                              cudaStream_t cuda_stream);
 template void AsinhGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                               cudaStream_t cuda_stream);
 template void AcoshGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                               cudaStream_t cuda_stream);
+template void ReciprocalGrad<half>(const half *input, const half *dout, half *output, const size_t count,
+                                   cudaStream_t cuda_stream);
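
Note on the math: for y = 1/x the derivative is dy/dx = -1/x^2, and since the backward pass receives the forward output y rather than x, the gradient reduces to dx = -y^2 * dout. That is exactly what ReciprocalGradKernel computes, doing the arithmetic in float and casting back to T so the half specialization keeps precision. A minimal NumPy sketch of the same computation (the function name is illustrative, not part of the patch):

    import numpy as np

    def reciprocal_grad_reference(y, dy):
        # Mirror ReciprocalGradKernel: dx = -y^2 * dy, computed in float32.
        y32 = y.astype(np.float32)
        dy32 = dy.astype(np.float32)
        return (-1 * dy32 * y32 * y32).astype(y.dtype)

    # y = 1/x with x = 2 gives y = 0.5; with dy = 1, dx = -0.25 = -1/x^2.
    print(reciprocal_grad_reference(np.array([0.5]), np.array([1.0])))  # [-0.25]
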
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@ template <typename T>
 void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
+
+template <typename T>
+void ReciprocalGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -74,5 +74,13 @@ MS_REG_GPU_KERNEL_ONE(
   AcoshGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   UnaryGradOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  ReciprocalGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  UnaryGradOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  ReciprocalGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  UnaryGradOpGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -35,12 +35,14 @@ enum UnaryGradOptype {
   UNARY_OP_ATAN_GRAD = 4,
   UNARY_OP_ASINH_GRAD = 5,
   UNARY_OP_ACOSH_GRAD = 6,
+  UNARY_OP_RECIPROCAL_GRAD = 7,
   UNARY_OP_GRAD_INVALID_TYPE = 255
 };
 static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {
-  {"SqrtGrad", UNARY_OP_SQRT_GRAD}, {"RsqrtGrad", UNARY_OP_RSQRT_GRAD}, {"AsinGrad", UNARY_OP_ASIN_GRAD},
-  {"ACosGrad", UNARY_OP_ACOS_GRAD}, {"AtanGrad", UNARY_OP_ATAN_GRAD}, {"AsinhGrad", UNARY_OP_ASINH_GRAD},
-  {"AcoshGrad", UNARY_OP_ACOSH_GRAD}};
+  {"SqrtGrad", UNARY_OP_SQRT_GRAD}, {"RsqrtGrad", UNARY_OP_RSQRT_GRAD},
+  {"AsinGrad", UNARY_OP_ASIN_GRAD}, {"ACosGrad", UNARY_OP_ACOS_GRAD},
+  {"AtanGrad", UNARY_OP_ATAN_GRAD}, {"AsinhGrad", UNARY_OP_ASINH_GRAD},
+  {"AcoshGrad", UNARY_OP_ACOSH_GRAD}, {"ReciprocalGrad", UNARY_OP_RECIPROCAL_GRAD}};
 
 template <typename T>
 class UnaryGradOpGpuKernel : public GpuKernel {
@@ -101,6 +103,11 @@ class UnaryGradOpGpuKernel : public GpuKernel {
                  reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case UNARY_OP_RECIPROCAL_GRAD: {
+        ReciprocalGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       default: {
         MS_LOG(EXCEPTION) << "Unary grad operation " << unary_grad_op_type_ << " is not supported.";
       }
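
The Launch switch above is driven by a name-to-enum lookup: the op string is resolved through kUnaryGradOpTypeMap when the kernel is initialized, and the matching case calls the corresponding CUDA wrapper, falling through to the MS_LOG(EXCEPTION) default for unknown names. A rough Python sketch of that dispatch pattern, with illustrative formulas only (this is not MindSpore API):

    # Name-keyed dispatch table, analogous to kUnaryGradOpTypeMap plus the switch in Launch().
    UNARY_GRAD_OPS = {
        "SqrtGrad": lambda y, dy: 0.5 * dy / y,       # d(sqrt(x)) given y = sqrt(x)
        "ReciprocalGrad": lambda y, dy: -dy * y * y,  # d(1/x) given y = 1/x
    }

    def launch(op_name, y, dy):
        if op_name not in UNARY_GRAD_OPS:
            raise ValueError("Unary grad operation " + op_name + " is not supported.")
        return UNARY_GRAD_OPS[op_name](y, dy)
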
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -448,22 +448,11 @@ def get_bprop_rsqrt(self):
 @bprop_getters.register(P.Reciprocal)
 def get_bprop_reciprocal(self):
     """Grad definition for `Reciprocal` operation."""
-    if self.target == "GPU":
-        neg = P.Neg()
-        mul = P.Mul()
-        square = P.Square()
-        reciprocal = P.Reciprocal()
+    reciprocal_grad = G.ReciprocalGrad()
 
-        def bprop(x, out, dout):
-            g = neg(reciprocal(square(x)))
-            dx = mul(dout, g)
-            return (dx,)
-    else:
-        reciprocal_grad = G.ReciprocalGrad()
-
-        def bprop(x, out, dout):
-            dx = reciprocal_grad(out, dout)
-            return (dx,)
+    def bprop(x, out, dout):
+        dx = reciprocal_grad(out, dout)
+        return (dx,)
 
     return bprop
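
With the dedicated kernel in place, the GPU-specific composite branch becomes redundant and the bprop collapses to one path: since out = 1/x, the old composite gradient mul(dout, neg(reciprocal(square(x)))) = -dout / x^2 equals -out^2 * dout, which is what the fused ReciprocalGrad(out, dout) computes. A quick NumPy check of that equivalence (illustrative, not MindSpore code):

    import numpy as np

    x = np.array([2.0, -3.0, 0.5], dtype=np.float32)
    dout = np.array([1.0, 4.0, -2.0], dtype=np.float32)
    out = 1.0 / x

    composite = dout * -(1.0 / np.square(x))  # old GPU branch
    fused = -np.square(out) * dout            # new fused path
    np.testing.assert_allclose(composite, fused)
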
@@ -0,0 +1,91 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+
+class NetReciprocalGrad(nn.Cell):
+    def __init__(self):
+        super(NetReciprocalGrad, self).__init__()
+        self.grad = G.ReciprocalGrad()
+
+    def construct(self, y, dy):
+        return self.grad(y, dy)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_reciprocal_grad_float32():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    y = Tensor(np.array([[[[-1, 1, 12],
+                           [5, 34, 6],
+                           [10, 2, -1]]]]).astype(np.float32))
+    dy = Tensor(np.array([[[[29, 1, 55],
+                            [2.2, 63, 2],
+                            [3, 3, 12]]]]).astype(np.float32))
+    expect = np.array([[[[-29, -1, -7920],
+                         [-55, -72828, -72],
+                         [-300, -12, -12]]]]).astype(np.float32)
+    net = NetReciprocalGrad()
+    output = net(y, dy)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expect)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    y = Tensor(np.array([[[[-1, 1, 12],
+                           [5, 34, 6],
+                           [10, 2, -1]]]]).astype(np.float32))
+    dy = Tensor(np.array([[[[29, 1, 55],
+                            [2.2, 63, 2],
+                            [3, 3, 12]]]]).astype(np.float32))
+    expect = np.array([[[[-29, -1, -7920],
+                         [-55, -72828, -72],
+                         [-300, -12, -12]]]]).astype(np.float32)
+    net = NetReciprocalGrad()
+    output = net(y, dy)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_reciprocal_grad_float16():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    y = Tensor(np.array([[0.01, 0.2, 0.22],
+                         [10.002, 2, -1]]).astype(np.float16))
+    dy = Tensor(np.array([[34, 1, 55],
+                          [3, 3, 63]]).astype(np.float16))
+    expect = np.array([[-0.0034, -0.03998, -2.662],
+                       [-300, -12, -63]]).astype(np.float16)
+    net = NetReciprocalGrad()
+    output = net(y, dy)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expect)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    y = Tensor(np.array([[0.01, 0.2, 0.22],
+                         [10.002, 2, -1]]).astype(np.float16))
+    dy = Tensor(np.array([[34, 1, 55],
+                          [3, 3, 63]]).astype(np.float16))
+    expect = np.array([[-0.0034, -0.03998, -2.662],
+                       [-300, -12, -63]]).astype(np.float16)
+    net = NetReciprocalGrad()
+    output = net(y, dy)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expect)
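
The expected arrays in both tests follow directly from dx = -dy * y^2: for example, y = 12 with dy = 55 gives -55 * 144 = -7920, and y = 34 with dy = 63 gives -63 * 1156 = -72828; the float16 expectations absorb rounding (e.g. -0.04 is stored as -0.03998). A small NumPy helper that regenerates them (hypothetical, not part of the test file):

    import numpy as np

    def expected(y, dy):
        # dx = -dy * y**2, matching the kernel's float arithmetic.
        return (-dy.astype(np.float32) * np.square(y.astype(np.float32))).astype(y.dtype)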