forked from mindspore-Ecosystem/mindspore
gpu support Gelu & GeluGrad kernels
This commit is contained in:
parent
d9dd6aa0b8
commit
a304304c30
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
#include "kernel/gpu/cuda_impl/gelu_impl.cuh"
|
||||
#include "device/gpu/cuda_common.h"
|
||||
|
||||
template<typename T>
|
||||
__global__ void GeluKernel(size_t size, T* input_addr, T* output_addr) {
|
||||
// formula:
|
||||
// gelu(x) = 0.5 * x * (1.0 + tanh(y))
|
||||
// tanh(y) = 2 / (1 + exp(-2y)) - 1)
|
||||
// y = sqrt(2/pi) * (x + 0.044715 * x^3)
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
float x = input_addr[pos];
|
||||
float tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x));
|
||||
output_addr[pos] = 0.5 * x * (1.0 + tanh_res);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void Gelu(size_t size, T* input_addr, T* output_addr, cudaStream_t cuda_stream) {
|
||||
GeluKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, output_addr);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
__global__ void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr) {
|
||||
// formula:
|
||||
// dx = dy * y'
|
||||
// y' = 0.5 * (1 + tanh(tanh_para)) +
|
||||
// 0.5 * x * (1 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
|
||||
// tanh_para = sqrt(2/pi) * (x + 0.044715 * x^3)
|
||||
// mul_right = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2))
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
T x = x_addr[pos];
|
||||
T tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x));
|
||||
T mul_right = 0.7978845608 + 0.1070322244 * x * x;
|
||||
T y_res = 0.5 * (1 + tanh_res) + 0.5 * x * (1 - tanh_res * tanh_res) * mul_right;
|
||||
dx_addr[pos] = dy_addr[pos] * y_res;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr, cudaStream_t cuda_stream) {
|
||||
GeluGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy_addr, x_addr, dx_addr);
|
||||
}
|
||||
|
||||
|
||||
template void Gelu(size_t size, float* input_addr, float* output_addr, cudaStream_t cuda_stream);
|
||||
template void GeluGradKernel(size_t size, float* dy_addr, float* x_addr, float* dx_addr, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_
|
||||
|
||||
#include "device/gpu/cuda_common.h"
|
||||
template<typename T>
|
||||
void Gelu(size_t input_size, T* input_addr, T* output_addr, cudaStream_t cuda_stream);
|
||||
|
||||
template<typename T>
|
||||
void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/gpu/nn/gelu_grad_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(GeluGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
GeLUGpuGradKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "kernel/gpu/gpu_kernel.h"
|
||||
#include "kernel/gpu/gpu_kernel_factory.h"
|
||||
#include "kernel/gpu/kernel_constants.h"
|
||||
#include "kernel/gpu/cuda_impl/gelu_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class GeLUGpuGradKernel : public GpuKernel {
|
||||
public:
|
||||
GeLUGpuGradKernel() : input_size_(0) {}
|
||||
~GeLUGpuGradKernel() override = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
|
||||
T *dy_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *x_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
T *dx_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
GeluGradKernel(input_size_ / sizeof(T), dy_addr, x_addr, dx_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
InitResource();
|
||||
input_size_ = sizeof(T);
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (auto dim : input_shape) {
|
||||
input_size_ *= dim;
|
||||
}
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_);
|
||||
input_size_list_.push_back(input_size_);
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(input_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
size_t input_size_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_
|
|
@ -0,0 +1,24 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel/gpu/nn/gelu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
GeluGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "kernel/gpu/gpu_kernel.h"
|
||||
#include "kernel/gpu/gpu_kernel_factory.h"
|
||||
#include "kernel/gpu/kernel_constants.h"
|
||||
#include "kernel/gpu/cuda_impl/gelu_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class GeluGpuKernel : public GpuKernel {
|
||||
public:
|
||||
GeluGpuKernel() : input_size_(0) {}
|
||||
~GeluGpuKernel() override = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
Gelu(input_size_ / sizeof(T), input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
InitResource();
|
||||
input_size_ = sizeof(T);
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (auto dim : input_shape) {
|
||||
input_size_ *= dim;
|
||||
}
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(input_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
size_t input_size_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
class GeluNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(GeluNet, self).__init__()
|
||||
self.gelu = P.Gelu()
|
||||
|
||||
def construct(self, x):
|
||||
return self.gelu(x)
|
||||
|
||||
class Grad(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(Grad, self).__init__()
|
||||
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
|
||||
self.network = network
|
||||
|
||||
def construct(self, input_data, sens):
|
||||
gout = self.grad(self.network)(input_data, sens)
|
||||
return gout
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gelugrad():
|
||||
x_ms = Tensor(np.array([0.58401114, 0.68800163, 0.9760397, 0.14702141, 0.46563736, 0.9607501,
|
||||
0.14567593, 0.12261796, 0.37054458, 0.46421242]).astype(np.float32))
|
||||
dy_ms = Tensor(np.array([0.5559598, 0.96994054, 0.24770357, 0.34646875, 0.2984393, 0.03287048,
|
||||
0.55681044, 0.966908, 0.06015943, 0.6099489 ]).astype(np.float32))
|
||||
|
||||
net = GeluNet()
|
||||
grad = Grad(net)
|
||||
|
||||
output = grad(x_ms, dy_ms)
|
||||
print(output)
|
||||
expect = [0.50963277, 0.9414753, 0.2667653, 0.21358444, 0.25243032, 0.0352667,
|
||||
0.34266686, 0.57757664, 0.04707306, 0.51536125]
|
||||
assert np.allclose(output[0].asnumpy(), expect)
|
|
@ -0,0 +1,88 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.nn as nn
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
class GeluNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(GeluNet, self).__init__()
|
||||
self.gelu = P.Gelu()
|
||||
|
||||
def construct(self, x):
|
||||
return self.gelu(x)
|
||||
|
||||
|
||||
def GeluCompute(x):
|
||||
return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x * x * x)))
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gelu_1d():
|
||||
x_np = np.random.random((50,)).astype(np.float32)
|
||||
y_np = GeluCompute(x_np)
|
||||
|
||||
x_ms = Tensor(x_np)
|
||||
net = GeluNet()
|
||||
y_ms = net(x_ms)
|
||||
|
||||
assert np.allclose(y_np, y_ms.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gelu_2d():
|
||||
x_np = np.random.random((50, 40)).astype(np.float32)
|
||||
y_np = GeluCompute(x_np)
|
||||
|
||||
x_ms = Tensor(x_np)
|
||||
net = GeluNet()
|
||||
y_ms = net(x_ms)
|
||||
|
||||
assert np.allclose(y_np, y_ms.asnumpy())
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gelu_4d():
|
||||
x_np = np.random.random((32, 3, 224, 224)).astype(np.float32)
|
||||
y_np = GeluCompute(x_np)
|
||||
|
||||
x_ms = Tensor(x_np)
|
||||
net = GeluNet()
|
||||
y_ms = net(x_ms)
|
||||
|
||||
assert np.allclose(y_np, y_ms.asnumpy())
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gelu_neg():
|
||||
x_np = np.random.random((32, 3, 224, 224)).astype(np.float32) * -1
|
||||
y_np = GeluCompute(x_np)
|
||||
|
||||
x_ms = Tensor(x_np)
|
||||
net = GeluNet()
|
||||
y_ms = net(x_ms)
|
||||
|
||||
assert np.allclose(y_np, y_ms.asnumpy())
|
Loading…
Reference in New Issue