!17656 ReverseV2 gpu kernel

From: @peilin-wang
Reviewed-by: @robingrosman,@nsyca
Signed-off-by: @robingrosman
This commit is contained in:
mindspore-ci-bot 2021-06-04 07:57:35 +08:00 committed by Gitee
commit 473bfff854
6 changed files with 389 additions and 0 deletions

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/reverse_v2_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ReverseV2, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ReverseV2GpuKernel, half)
MS_REG_GPU_KERNEL_ONE(ReverseV2, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReverseV2GpuKernel, float)
MS_REG_GPU_KERNEL_ONE(ReverseV2, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ReverseV2GpuKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(ReverseV2, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
ReverseV2GpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(ReverseV2, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ReverseV2GpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(ReverseV2, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ReverseV2GpuKernel, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,142 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_REVERSE_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_REVERSE_V2_GPU_KERNEL_H_
#include <algorithm>
#include <cstdint>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/reverse_v2_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class ReverseV2GpuKernel : public GpuKernel {
public:
ReverseV2GpuKernel() { ResetResource(); }
~ReverseV2GpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input_device = GetDeviceAddress<T>(inputs, 0);
T *output_device = GetDeviceAddress<T>(outputs, 0);
size_t *input_shape_device = GetDeviceAddress<size_t>(workspace, 0);
int64_t *strides_device = GetDeviceAddress<int64_t>(workspace, 1);
int64_t *axis_device = GetDeviceAddress<int64_t>(workspace, 2);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(input_shape_device, &input_shape_[0], workspace_size_list_[0],
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync for input_shape_ failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(strides_device, &strides_[0], workspace_size_list_[1],
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync for strides_ failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(axis_device, &axis_[0], workspace_size_list_[2], cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync for axis_ failed");
CalReverseV2(input_device, output_device, input_shape_device, strides_device, axis_device, input_size_,
axis_.size(), reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_count != 1) {
MS_LOG(ERROR) << input_count << " inputs were provided, but ReverseV2GpuKernel expects 1.";
return false;
}
size_t output_count = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_count != 1) {
MS_LOG(ERROR) << "Number of outputs is " << output_count << ", but should be 1 for ReverseV2GpuKernel.";
return false;
}
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
input_rank_ = input_shape_.size();
input_size_ = 1;
for (size_t i = 0; i < input_rank_; i++) {
input_size_ *= input_shape_[i];
}
strides_.resize(input_rank_);
strides_[input_rank_ - 1] = 1;
for (int32_t i = input_rank_ - 2; i >= 0; i--) {
strides_[i] = static_cast<int64_t>(input_shape_[i + 1]) * strides_[i + 1];
}
axis_ = GetAttr<std::vector<int64_t>>(kernel_node, "axis");
for (int64_t &dimension : axis_) {
if (dimension < 0) {
dimension += input_rank_;
}
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_size_ = 0;
input_rank_ = 0;
input_shape_.clear();
strides_.clear();
axis_.clear();
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
size_t input_bytes = input_size_ * sizeof(T);
input_size_list_.push_back(input_bytes);
output_size_list_.push_back(input_bytes);
workspace_size_list_.push_back(input_rank_ * sizeof(size_t));
workspace_size_list_.push_back(input_rank_ * sizeof(int64_t));
workspace_size_list_.push_back(axis_.size() * sizeof(int64_t));
}
private:
size_t input_size_;
size_t input_rank_;
std::vector<size_t> input_shape_;
std::vector<int64_t> strides_;
std::vector<int64_t> axis_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_REVERSE_V2_GPU_KERNEL_H_

View File

@ -0,0 +1,63 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime.h>
#include "reverse_v2_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__global__ void ReverseV2(const T* input, T* output, const size_t* input_shape, const int64_t* strides,
const int64_t* axis, size_t input_size, size_t axis_size) {
for (int64_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < input_size; gt_id += blockDim.x * gridDim.x) {
int64_t intermediate_index = gt_id;
for (size_t i = 0; i < axis_size; i++) {
int64_t d = axis[i];
int64_t pre_reverse_position = (gt_id / strides[d]) % input_shape[d];
int64_t reversed_position = input_shape[d] - pre_reverse_position - 1;
intermediate_index += ((reversed_position - pre_reverse_position) * strides[d]);
}
output[intermediate_index] = input[gt_id];
}
return;
}
template <typename T>
void CalReverseV2(const T* input, T* output, const size_t* input_shape, const int64_t* strides, const int64_t* axis,
size_t input_size, size_t axis_size, cudaStream_t cuda_stream) {
ReverseV2<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input, output, input_shape, strides, axis,
input_size, axis_size);
return;
}
template void CalReverseV2<half>(const half* input, half* output, const size_t* input_shape, const int64_t* strides,
const int64_t* axis, size_t input_size, size_t axis_size, cudaStream_t cuda_stream);
template void CalReverseV2<float>(const float* input, float* output, const size_t* input_shape, const int64_t* strides,
const int64_t* axis, size_t input_size, size_t axis_size, cudaStream_t cuda_stream);
template void CalReverseV2<uint8_t>(const uint8_t* input, uint8_t* output, const size_t* input_shape,
const int64_t* strides, const int64_t* axis, size_t input_size, size_t axis_size,
cudaStream_t cuda_stream);
template void CalReverseV2<int16_t>(const int16_t* input, int16_t* output, const size_t* input_shape,
const int64_t* strides, const int64_t* axis, size_t input_size, size_t axis_size,
cudaStream_t cuda_stream);
template void CalReverseV2<int32_t>(const int32_t* input, int32_t* output, const size_t* input_shape,
const int64_t* strides, const int64_t* axis, size_t input_size, size_t axis_size,
cudaStream_t cuda_stream);
template void CalReverseV2<int64_t>(const int64_t* input, int64_t* output, const size_t* input_shape,
const int64_t* strides, const int64_t* axis, size_t input_size, size_t axis_size,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,21 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REVERSE_V2_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REVERSE_V2_CUH_
template <typename T>
void CalReverseV2(const T* input, T* output, const size_t* input_shape, const int64_t* strides, const int64_t* axis,
size_t input_size, size_t axis_size, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REVERSE_V2_CUH_

View File

@ -2672,10 +2672,21 @@ class ReverseV2(PrimitiveWithInfer):
self.axis = axis
self.init_prim_io_names(inputs=['x'], outputs=['output'])
def infer_shape(self, x_shape):
dim = len(x_shape)
for i, each in enumerate(self.axis):
validator.check_int_range(each, -dim, dim, Rel.INC_LEFT, f'axis[{i}]', self.name)
normalized_axis = []
for i, v in enumerate(self.axis):
if v < 0:
normalized_axis.append(v + dim)
else:
normalized_axis.append(v)
if len(normalized_axis) != len(set(normalized_axis)):
raise ValueError('axis cannot contain duplicate dimensions.')
return x_shape
def infer_dtype(self, x_dtype):

View File

@ -0,0 +1,114 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class ReverseV2Net(nn.Cell):
def __init__(self, axis):
super(ReverseV2Net, self).__init__()
self.reverse_v2 = P.ReverseV2(axis)
def construct(self, x):
return self.reverse_v2(x)
def reverse_v2(x_numpy, axis):
x = Tensor(x_numpy)
reverse_v2_net = ReverseV2Net(axis)
output = reverse_v2_net(x).asnumpy()
expected_output = np.flip(x_numpy, axis)
np.testing.assert_array_equal(output, expected_output)
def reverse_v2_3d(nptype):
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x_numpy = np.arange(60).reshape(3, 4, 5).astype(nptype)
reverse_v2(x_numpy, (0,))
reverse_v2(x_numpy, (1,))
reverse_v2(x_numpy, (2,))
reverse_v2(x_numpy, (2, -2))
reverse_v2(x_numpy, (-3, 1, 2))
def reverse_v2_1d(nptype):
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x_numpy = np.arange(4).astype(nptype)
reverse_v2(x_numpy, (0,))
reverse_v2(x_numpy, (-1,))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_float16():
reverse_v2_1d(np.float16)
reverse_v2_3d(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_float32():
reverse_v2_1d(np.float32)
reverse_v2_3d(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_uint8():
reverse_v2_1d(np.uint8)
reverse_v2_3d(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_int16():
reverse_v2_1d(np.int16)
reverse_v2_3d(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_int32():
reverse_v2_1d(np.int32)
reverse_v2_3d(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_int64():
reverse_v2_1d(np.int64)
reverse_v2_3d(np.int64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_invalid_axis():
x = Tensor(np.arange(60).reshape(1, 2, 3, 2, 5).astype(np.int32))
with pytest.raises(ValueError) as info:
reverse_v2_net = ReverseV2Net((0, 1, 2, 1))
_ = reverse_v2_net(x)
assert "axis cannot contain duplicate dimensions" in str(info.value)
with pytest.raises(ValueError) as info:
reverse_v2_net = ReverseV2Net((-2, -1, 3))
_ = reverse_v2_net(x)
assert "axis cannot contain duplicate dimensions" in str(info.value)