forked from OSSInnovation/mindspore
new add sqrt_grad and rsqrt_grad.
This commit is contained in:
parent
32361ae2ed
commit
dda3176fca
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "unary_op_grad_impl.cuh"
|
||||
template <typename T>
|
||||
__global__ void SqrtGradKernel(const T *input, const T *dout, T *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
float input_f = static_cast<float>(input[i]);
|
||||
float dout_f = static_cast<float>(dout[i]);
|
||||
float res_vmul = dout_f / (2.0 * input_f);
|
||||
output[i] = static_cast<T>(res_vmul);
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void RsqrtGradKernel(const T *input, const T *dout, T *output, const size_t count) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
|
||||
float input_f = static_cast<float>(input[i]);
|
||||
float dout_f = static_cast<float>(dout[i]);
|
||||
float res_vmul = input_f * input_f * input_f;
|
||||
res_vmul = -0.5 * res_vmul * dout_f;
|
||||
output[i] = static_cast<T>(res_vmul);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
SqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
|
||||
return;
|
||||
}
|
||||
template <typename T>
|
||||
void RsqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
|
||||
RsqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
|
||||
return;
|
||||
}
|
||||
|
||||
template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void SqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
||||
template void RsqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
|
||||
cudaStream_t cuda_stream);
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void RsqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
SqrtGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryGradOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
SqrtGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryGradOpGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
RsqrtGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
UnaryGradOpGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
RsqrtGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
UnaryGradOpGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,142 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARYOP_GRAD_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARYOP_GRAD_GPU_KERNEL_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
enum UnaryGradOptype { UNARY_OP_SQRT_GRAD = 0, UNARY_OP_RSQRT_GRAD, UNARY_OP_GRAD_INVALID_TYPE = 255 };
|
||||
static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {{"SqrtGrad", UNARY_OP_SQRT_GRAD},
|
||||
{"RsqrtGrad", UNARY_OP_RSQRT_GRAD}};
|
||||
template <typename T>
|
||||
class UnaryGradOpGpuKernel : public GpuKernel {
|
||||
public:
|
||||
UnaryGradOpGpuKernel()
|
||||
: unary_grad_op_type_(UNARY_OP_GRAD_INVALID_TYPE),
|
||||
input_size_(sizeof(T)),
|
||||
dx_size_(sizeof(T)),
|
||||
output_size_(sizeof(T)),
|
||||
workspace_size_(0),
|
||||
is_null_input_(false) {}
|
||||
~UnaryGradOpGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
T *input_x_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *input_dx_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
T *output_y_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
switch (unary_grad_op_type_) {
|
||||
case UNARY_OP_SQRT_GRAD: {
|
||||
SqrtGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
case UNARY_OP_RSQRT_GRAD: {
|
||||
RsqrtGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
MS_LOG(EXCEPTION) << "Unary grad operation " << unary_grad_op_type_ << " is not supported.";
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
auto iter = kUnaryGradOpTypeMap.find(kernel_name);
|
||||
if (iter == kUnaryGradOpTypeMap.end()) {
|
||||
MS_LOG(EXCEPTION) << "Unary grad operation " << kernel_name << " is not supported.";
|
||||
} else {
|
||||
unary_grad_op_type_ = iter->second;
|
||||
}
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but unary grad op needs 2 inputs.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but unary grad op needs 1 output.";
|
||||
return false;
|
||||
}
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
is_null_input_ = CHECK_NULL_INPUT(input_shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "UnaryGradOpGpuKernel input 0 is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
}
|
||||
auto dx_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
is_null_input_ = CHECK_NULL_INPUT(dx_shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "UnaryGradOpGpuKernel input 1 is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
for (size_t i = 0; i < dx_shape.size(); i++) {
|
||||
dx_size_ *= dx_shape[i];
|
||||
}
|
||||
if (input_size_ != dx_size_) {
|
||||
MS_LOG(WARNING) << "UnaryGradOpGpuKernel inputs should be same, but got " << input_size_ << " and " << dx_size_;
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
output_size_ = input_size_;
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_);
|
||||
input_size_list_.push_back(dx_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
UnaryGradOptype unary_grad_op_type_;
|
||||
size_t input_size_;
|
||||
size_t dx_size_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
bool is_null_input_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNARYOP_GRAD_GPU_KERNEL_H_
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
|
||||
|
||||
class NetRsqrtGrad(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetRsqrtGrad, self).__init__()
|
||||
self.rsqrt_grad = G.RsqrtGrad()
|
||||
|
||||
def construct(self, x, dx):
|
||||
return self.rsqrt_grad(x, dx)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_rsqrt_grad():
|
||||
x = Tensor(np.array([[[[-1, 1, 10],
|
||||
[5.9, 6.1, 6],
|
||||
[10, 1, -1]]]]).astype(np.float32))
|
||||
dx = Tensor(np.array([[[[1, 1, 1],
|
||||
[2, 2, 2],
|
||||
[3, 3, 3]]]]).astype(np.float32))
|
||||
expect = np.array([[[[0.5, -0.5, -500,],
|
||||
[-205.37901, -226.98099, -216],
|
||||
[-1500, -1.5, 1.5,]]]]).astype(np.float32)
|
||||
error = np.ones(shape=[3, 3]) * 1.0e-6
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
rsqrt_grad = NetRsqrtGrad()
|
||||
output = rsqrt_grad(x, dx)
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(np.abs(diff) < error)
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
|
||||
|
||||
class NetSqrtGrad(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetSqrtGrad, self).__init__()
|
||||
self.sqrt_grad = G.SqrtGrad()
|
||||
|
||||
def construct(self, x, dx):
|
||||
return self.sqrt_grad(x, dx)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_sqrt_grad():
|
||||
x = Tensor(np.array([[[[-1, 1, 10],
|
||||
[5.9, 6.1, 6],
|
||||
[10, 1, -1]]]]).astype(np.float32))
|
||||
dx = Tensor(np.array([[[[1, 1, 1],
|
||||
[2, 2, 2],
|
||||
[3, 3, 3]]]]).astype(np.float32))
|
||||
expect = np.array([[[[-0.5, 0.5, 0.05,],
|
||||
[0.16949153, 0.16393442, 0.16666667,],
|
||||
[0.15, 1.5, -1.5,]]]]).astype(np.float32)
|
||||
error = np.ones(shape=[3, 3]) * 1.0e-6
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
sqrt_grad = NetSqrtGrad()
|
||||
output = sqrt_grad(x, dx)
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(np.abs(diff) < error)
|
Loading…
Reference in New Issue