support PReLU for GPU platform

This commit is contained in:
xuguoyang 2021-04-13 17:13:37 +08:00
parent cac91018ad
commit 7d1a35bedb
7 changed files with 238 additions and 3 deletions

View File

@ -95,3 +95,19 @@ template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *ma
cudaStream_t cuda_stream);
template void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx,
cudaStream_t cuda_stream);
template <typename T>
__global__ void CalPReLUKernel(int size, T *input_addr, T *weight_addr, T *output_addr) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
output_addr[pos] = input_addr[pos] > static_cast<T>(0) ? input_addr[pos] : *weight_addr * input_addr[pos];
}
}
template <typename T>
void CalPReLU(int size, T *input_addr, T *weight_addr, T *output_addr, cudaStream_t cuda_stream) {
CalPReLUKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, weight_addr, output_addr);
return;
}
template void CalPReLU(int size, float *input_addr, float *weight_addr, float *output_addr, cudaStream_t cuda_stream);
template void CalPReLU(int size, half *input_addr, half *weight_addr, half *output_addr, cudaStream_t cuda_stream);

View File

@ -25,4 +25,7 @@ template <typename T>
void ReluV2(const size_t num, const T *x, T *y, uint32_t *mask, cudaStream_t cuda_stream);
template <typename T>
void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cudaStream_t cuda_stream);
template <typename T>
void CalPReLU(int input_size, T *input_addr, T *weight_addr, T *output_addr, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_

View File

@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/prelu_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
PReLU,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
PReLUGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
PReLU,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PReLUGpuKernel, float)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,111 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GPU_KERNEL_H_
#include <vector>
#include <map>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class PReLUGpuKernel : public GpuKernel {
public:
PReLUGpuKernel() { ResetResource(); }
~PReLUGpuKernel() override {}
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input = GetDeviceAddress<T>(inputs, 0);
T *weight = GetDeviceAddress<T>(inputs, 1);
T *output = GetDeviceAddress<T>(outputs, 0);
const int size = input_size_ / sizeof(T);
CalPReLU(size, input, weight, output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReLUGpuFwdKernel needs 2.";
return false;
}
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "PReLUGpuFwdKernel input is null.";
}
size_t size = 1;
for (size_t i = 0; i < input_shape.size(); i++) {
size *= input_shape[i];
}
input_size_ = size * sizeof(T);
auto weight_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
is_null_input_ = CHECK_NULL_INPUT(weight_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "PReLUGpuFwdKernel weight is null.";
}
size = 1;
for (size_t i = 0; i < weight_shape.size(); i++) {
size *= weight_shape[i];
}
weight_size_ = size * sizeof(T);
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
input_size_ = 0;
workspace_size_ = 0;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(input_size_);
workspace_size_list_.push_back(workspace_size_);
}
private:
bool is_null_input_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
size_t input_size_;
size_t weight_size_;
size_t workspace_size_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GPU_KERNEL_H_

View File

@ -548,7 +548,7 @@ class PReLU(Cell):
ValueError: If length of shape of `input_data` is equal to 1.
Supported Platforms:
``Ascend``
``Ascend`` ``GPU``
Examples:
>>> input_x = Tensor(np.array([[[[0.1, 0.6], [0.9, 0.9]]]]), mindspore.float32)

View File

@ -1145,7 +1145,7 @@ class BatchNorm(PrimitiveWithInfer):
TypeError: If dtype of `input_x`, `scale` or `mean` is neither float16 nor float32.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
@ -3533,7 +3533,7 @@ class PReLU(PrimitiveWithInfer):
ValueError: If length of shape of `weight` is not equal to 1.
Supported Platforms:
``Ascend``
``Ascend`` ``GPU``
Examples:
>>> import mindspore

View File

@ -0,0 +1,74 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class NetPReLU(nn.Cell):
def __init__(self):
super(NetPReLU, self).__init__()
self.prelu = P.PReLU()
def construct(self, x, weight):
return self.prelu(x, weight)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_float16():
weight = Tensor(np.array([0.25]).astype(np.float16))
x = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.float16))
expect = np.array([[[[-0.25, 1, 10,],
[1, -0.25, 1,],
[10, 1, -0.25]]]]).astype(np.float16)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
prelu = NetPReLU()
output = prelu(x, weight)
assert (output.asnumpy() == expect).all()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
prelu = NetPReLU()
output = prelu(x, weight)
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_float32():
weight = Tensor(np.array([0.25]).astype(np.float32))
x = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.float32))
expect = np.array([[[[-0.25, 1, 10,],
[1, -0.25, 1,],
[10, 1, -0.25]]]]).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
prelu = NetPReLU()
output = prelu(x, weight)
assert (output.asnumpy() == expect).all()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
prelu = NetPReLU()
output = prelu(x, weight)
assert (output.asnumpy() == expect).all()