Cast/ReLU dynamic shape GPU op supports int32 and int64

Jonathan Yan 2020-11-27 07:03:50 -05:00
parent 3874160faf
commit 9f70ebac64
9 changed files with 199 additions and 13 deletions
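
In short, this change lets the GPU ReLU kernel handle int32 and int64 inputs (routed through a hand-written CUDA kernel instead of cuDNN, which only covers the float types), switches the ReLU primitive to a dynamic-shape-friendly base class, and gives the Cast GPU kernel a ResetResource path so it can be re-initialized when shapes change at run time. A minimal sketch of the user-visible behaviour, mirroring the new tests added below (assuming a GPU build of MindSpore):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

# ReLU now accepts integer tensors on GPU; negative entries are clamped to zero.
x = Tensor(np.array([[-1, 1, 10], [1, -1, 1]]).astype(np.int64))
print(P.ReLU()(x))  # expected: [[0 1 10] [1 0 1]]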

View File

@@ -27,7 +27,7 @@ namespace kernel {
template <typename S, typename T>
class CastGpuKernel : public GpuKernel {
public:
CastGpuKernel() : input_size_(1), output_size_(1) {}
CastGpuKernel() { ResetResource(); }
~CastGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -42,6 +42,7 @@ class CastGpuKernel : public GpuKernel {
Cast(input_size_, input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
@@ -62,6 +63,14 @@ class CastGpuKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
input_size_ = 1;
output_size_ = 1;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));

View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__global__ void CalReLUKernel(int size, T *input_addr, T *output_addr) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
output_addr[pos] = input_addr[pos] > static_cast<T>(0) ? input_addr[pos] : static_cast<T>(0);
}
}
template <typename T>
void CalReLU(int size, T *input_addr, T *output_addr, cudaStream_t cuda_stream) {
CalReLUKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, output_addr);
return;
}
template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
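
For reference, CalReLUKernel walks the flattened buffer with a grid-stride loop (each thread advances by blockDim.x * gridDim.x, so a fixed launch configuration covers any element count) and writes max(x, 0) per element; the explicit instantiations above are what make the int32/int64 variants available to the activation kernel. A NumPy reference with the same semantics, handy for checking the CUDA output (the helper name is made up for illustration):

import numpy as np

def relu_reference(x):
    # Elementwise max(x, 0) with dtype preserved, the same result CalReLU produces on device.
    return np.maximum(x, 0).astype(x.dtype)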

View File

@@ -0,0 +1,23 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalReLU(int input_size, T *input_addr, T *output_addr, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_

View File

@@ -45,6 +45,7 @@ static constexpr float kSignedMinFloat = -3.402823466e+38F;
// Used by mixprecision, cudnn dtype select
static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {{"kNumberTypeFloat32", CUDNN_DATA_FLOAT},
{"kNumberTypeFloat16", CUDNN_DATA_HALF},
{"kNumberTypeInt64", CUDNN_DATA_DOUBLE},
{"kNumberTypeInt32", CUDNN_DATA_INT32}};
// Used by mixprecision, cuda dtype select
static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F},

View File

@@ -22,6 +22,10 @@ MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
ActivationGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ActivationGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ActivationGpuFwdKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ActivationGpuFwdKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ActivationGpuFwdKernel, float)

View File

@@ -23,6 +23,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
namespace mindspore {
namespace kernel {
@@ -36,18 +37,23 @@ class ActivationGpuFwdKernel : public GpuKernel {
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *) override {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input,
&beta, data_descriptor_, output),
"cudnnActivationForward failed");
if (mode_ == CUDNN_ACTIVATION_RELU) {
const int size = input_size_ / sizeof(T);
CalReLU(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_,
input, &beta, data_descriptor_, output),
"cudnnActivationForward failed");
}
return true;
}

View File

@@ -291,7 +291,7 @@ class Softsign(PrimitiveWithInfer):
return input_x
class ReLU(PrimitiveWithInfer):
class ReLU(PrimitiveWithCheck):
r"""
Computes ReLU (Rectified Linear Unit) of input tensors element-wise.
@@ -320,12 +320,11 @@ class ReLU(PrimitiveWithInfer):
"""Initialize ReLU"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, input_x):
return input_x
def check_shape(self, input_x):
pass
def infer_dtype(self, input_x):
def check_dtype(self, input_x):
validator.check_tensor_dtype_valid('input_x', input_x, mstype.number_type, self.name)
return input_x
class ReLU6(PrimitiveWithInfer):
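
The base-class change above is what allows ReLU to participate in dynamic-shape graphs: PrimitiveWithInfer must return a concrete output shape at compile time, whereas PrimitiveWithCheck only validates its inputs and leaves shape/dtype propagation to the backend, which is what the dynamic-shape path relies on. A minimal sketch of how this is exercised, modelled on the NetReluDynamic test added later in this commit (GpuConvertToDynamicShape is a test helper from _inner_ops that erases the static shape):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner

class DynReLU(nn.Cell):
    def __init__(self):
        super(DynReLU, self).__init__()
        self.to_dynamic = inner.GpuConvertToDynamicShape()
        self.relu = P.ReLU()

    def construct(self, x):
        # The converted tensor's shape is only known at run time, so ReLU
        # must work without a statically inferred output shape.
        return self.relu(self.to_dynamic(x))

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
out = DynReLU()(Tensor(np.array([-2, 0, 3]).astype(np.int64)))
# expected: [0 0 3]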

View File

@@ -21,6 +21,7 @@ import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
class Net(Cell):
@@ -36,6 +37,22 @@ class Net(Cell):
return output
class NetDynamic(Cell):
def __init__(self, type0, type1):
super(NetDynamic, self).__init__()
self.conv = inner.GpuConvertToDynamicShape()
self.Cast = P.Cast()
self.type0 = type0
self.type1 = type1
def construct(self, x0, x1):
x0_conv = self.conv(x0)
x1_conv = self.conv(x1)
output = (self.Cast(x0_conv, self.type0),
self.Cast(x1_conv, self.type1))
return output
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@@ -563,3 +580,20 @@ def test_cast30():
assert type0 == 'uint16'
type1 = output[1].asnumpy().dtype
assert type1 == 'uint32'
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cast31():
x0 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float32))
t0 = mstype.uint16
x1 = Tensor(np.arange(24).reshape((4, 3, 2)).astype(np.float32))
t1 = mstype.uint32
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = NetDynamic(t0, t1)
output = net(x0, x1)
type0 = output[0].asnumpy().dtype
assert type0 == 'uint16'
type1 = output[1].asnumpy().dtype
assert type1 == 'uint32'

View File

@@ -20,6 +20,7 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
class NetRelu(nn.Cell):
@@ -31,10 +32,21 @@ class NetRelu(nn.Cell):
return self.relu(x)
class NetReluDynamic(nn.Cell):
def __init__(self):
super(NetReluDynamic, self).__init__()
self.conv = inner.GpuConvertToDynamicShape()
self.relu = P.ReLU()
def construct(self, x):
x_conv = self.conv(x)
return self.relu(x_conv)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu():
def test_relu_float32():
x = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.float32))
@@ -51,3 +63,65 @@ def test_relu():
relu = NetRelu()
output = relu(x)
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_int32():
x = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.int32))
expect = np.array([[[[0, 1, 10],
[1, 0, 1],
[10, 1, 0]]]]).astype(np.int32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
relu = NetRelu()
output = relu(x)
assert (output.asnumpy() == expect).all()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
relu = NetRelu()
output = relu(x)
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_int64():
x = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.int64))
expect = np.array([[[[0, 1, 10],
[1, 0, 1],
[10, 1, 0]]]]).astype(np.int64)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
relu = NetRelu()
output = relu(x)
print(output.asnumpy(), expect)
assert (output.asnumpy() == expect).all()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
relu = NetRelu()
output = relu(x)
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_int64_dynamic_shape():
x = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.int64))
expect = np.array([[[[0, 1, 10],
[1, 0, 1],
[10, 1, 0]]]]).astype(np.int64)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
relu_dynamic = NetReluDynamic()
output = relu_dynamic(x)
assert (output.asnumpy() == expect).all()