forked from mindspore-Ecosystem/mindspore
!17656 ReverseV2 gpu kernel
From: @peilin-wang Reviewed-by: @robingrosman,@nsyca Signed-off-by: @robingrosman
commit 473bfff854
@@ -0,0 +1,38 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/arrays/reverse_v2_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(ReverseV2, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      ReverseV2GpuKernel, half)

MS_REG_GPU_KERNEL_ONE(ReverseV2, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      ReverseV2GpuKernel, float)

MS_REG_GPU_KERNEL_ONE(ReverseV2, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
                      ReverseV2GpuKernel, uint8_t)

MS_REG_GPU_KERNEL_ONE(ReverseV2, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                      ReverseV2GpuKernel, int16_t)

MS_REG_GPU_KERNEL_ONE(ReverseV2, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      ReverseV2GpuKernel, int32_t)

MS_REG_GPU_KERNEL_ONE(ReverseV2, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                      ReverseV2GpuKernel, int64_t)
}  // namespace kernel
}  // namespace mindspore
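Each MS_REG_GPU_KERNEL_ONE call above binds the ReverseV2 primitive to one dtype instantiation of the templated kernel (half, float, uint8_t, int16_t, int32_t, int64_t), so the kernel factory can select the matching specialization at dispatch time. A minimal sketch of exercising one of these registrations from the Python frontend, assuming a GPU build of MindSpore (in PyNative mode a primitive can be called directly, without wrapping it in a Cell):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

# int32 input hits the kNumberTypeInt32 registration above.
x = Tensor(np.arange(6).reshape(2, 3).astype(np.int32))
reverse = P.ReverseV2(axis=(1,))
print(reverse(x))  # expected: [[2 1 0], [5 4 3]]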
@@ -0,0 +1,142 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_REVERSE_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_REVERSE_V2_GPU_KERNEL_H_

#include <algorithm>
#include <cstdint>
#include <vector>

#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/reverse_v2_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class ReverseV2GpuKernel : public GpuKernel {
 public:
  ReverseV2GpuKernel() { ResetResource(); }
  ~ReverseV2GpuKernel() = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input_device = GetDeviceAddress<T>(inputs, 0);
    T *output_device = GetDeviceAddress<T>(outputs, 0);

    // Workspace layout (see InitSizeLists): 0 = input shape, 1 = element strides, 2 = axes to reverse.
    size_t *input_shape_device = GetDeviceAddress<size_t>(workspace, 0);
    int64_t *strides_device = GetDeviceAddress<int64_t>(workspace, 1);
    int64_t *axis_device = GetDeviceAddress<int64_t>(workspace, 2);

    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemcpyAsync(input_shape_device, &input_shape_[0], workspace_size_list_[0],
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemcpyAsync for input_shape_ failed");

    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemcpyAsync(strides_device, &strides_[0], workspace_size_list_[1],
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemcpyAsync for strides_ failed");

    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemcpyAsync(axis_device, &axis_[0], workspace_size_list_[2], cudaMemcpyHostToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemcpyAsync for axis_ failed");

    CalReverseV2(input_device, output_device, input_shape_device, strides_device, axis_device, input_size_,
                 axis_.size(), reinterpret_cast<cudaStream_t>(stream_ptr));

    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_count != 1) {
      MS_LOG(ERROR) << input_count << " inputs were provided, but ReverseV2GpuKernel expects 1.";
      return false;
    }

    size_t output_count = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_count != 1) {
      MS_LOG(ERROR) << "Number of outputs is " << output_count << ", but should be 1 for ReverseV2GpuKernel.";
      return false;
    }

    input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    input_rank_ = input_shape_.size();
    input_size_ = 1;
    for (size_t i = 0; i < input_rank_; i++) {
      input_size_ *= input_shape_[i];
    }

    // Row-major element strides: strides_[i] is the flat distance between neighbours along dimension i.
    strides_.resize(input_rank_);
    strides_[input_rank_ - 1] = 1;
    for (int32_t i = input_rank_ - 2; i >= 0; i--) {
      strides_[i] = static_cast<int64_t>(input_shape_[i + 1]) * strides_[i + 1];
    }

    // Normalize negative axes into the range [0, input_rank_).
    axis_ = GetAttr<std::vector<int64_t>>(kernel_node, "axis");
    for (int64_t &dimension : axis_) {
      if (dimension < 0) {
        dimension += input_rank_;
      }
    }

    InitSizeLists();

    return true;
  }

  void ResetResource() noexcept override {
    input_size_ = 0;
    input_rank_ = 0;
    input_shape_.clear();
    strides_.clear();
    axis_.clear();

    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  void InitSizeLists() override {
    size_t input_bytes = input_size_ * sizeof(T);
    input_size_list_.push_back(input_bytes);
    output_size_list_.push_back(input_bytes);
    workspace_size_list_.push_back(input_rank_ * sizeof(size_t));
    workspace_size_list_.push_back(input_rank_ * sizeof(int64_t));
    workspace_size_list_.push_back(axis_.size() * sizeof(int64_t));
  }

 private:
  size_t input_size_;
  size_t input_rank_;
  std::vector<size_t> input_shape_;
  std::vector<int64_t> strides_;
  std::vector<int64_t> axis_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_REVERSE_V2_GPU_KERNEL_H_
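The stride computation in Init above is the standard row-major (C-order) derivation; the CUDA kernel relies on it to decode a flat index into per-dimension coordinates. A quick NumPy check of the same arithmetic (the helper name is illustrative, not part of the change):

import numpy as np

def row_major_strides(shape):
    # Mirror of the loop in ReverseV2GpuKernel::Init: strides[i] = shape[i+1] * strides[i+1].
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = shape[i + 1] * strides[i + 1]
    return strides

shape = (3, 4, 5)
assert row_major_strides(shape) == [20, 5, 1]
# NumPy reports strides in bytes; divide by itemsize to compare element strides.
assert row_major_strides(shape) == [s // 8 for s in np.zeros(shape, dtype=np.int64).strides]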
@@ -0,0 +1,63 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cuda_runtime.h>
#include "reverse_v2_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__global__ void ReverseV2(const T* input, T* output, const size_t* input_shape, const int64_t* strides,
                          const int64_t* axis, size_t input_size, size_t axis_size) {
  // Grid-stride loop: each thread scatters its input elements to their reversed positions.
  for (int64_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < input_size; gt_id += blockDim.x * gridDim.x) {
    int64_t intermediate_index = gt_id;
    for (size_t i = 0; i < axis_size; i++) {
      int64_t d = axis[i];
      // Coordinate of this element along dimension d, before and after the flip.
      int64_t pre_reverse_position = (gt_id / strides[d]) % input_shape[d];
      int64_t reversed_position = input_shape[d] - pre_reverse_position - 1;
      intermediate_index += ((reversed_position - pre_reverse_position) * strides[d]);
    }

    output[intermediate_index] = input[gt_id];
  }
  return;
}
template <typename T>
void CalReverseV2(const T* input, T* output, const size_t* input_shape, const int64_t* strides, const int64_t* axis,
                  size_t input_size, size_t axis_size, cudaStream_t cuda_stream) {
  ReverseV2<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input, output, input_shape, strides, axis,
                                                                     input_size, axis_size);
  return;
}

template void CalReverseV2<half>(const half* input, half* output, const size_t* input_shape, const int64_t* strides,
                                 const int64_t* axis, size_t input_size, size_t axis_size, cudaStream_t cuda_stream);

template void CalReverseV2<float>(const float* input, float* output, const size_t* input_shape, const int64_t* strides,
                                  const int64_t* axis, size_t input_size, size_t axis_size, cudaStream_t cuda_stream);

template void CalReverseV2<uint8_t>(const uint8_t* input, uint8_t* output, const size_t* input_shape,
                                    const int64_t* strides, const int64_t* axis, size_t input_size, size_t axis_size,
                                    cudaStream_t cuda_stream);

template void CalReverseV2<int16_t>(const int16_t* input, int16_t* output, const size_t* input_shape,
                                    const int64_t* strides, const int64_t* axis, size_t input_size, size_t axis_size,
                                    cudaStream_t cuda_stream);

template void CalReverseV2<int32_t>(const int32_t* input, int32_t* output, const size_t* input_shape,
                                    const int64_t* strides, const int64_t* axis, size_t input_size, size_t axis_size,
                                    cudaStream_t cuda_stream);

template void CalReverseV2<int64_t>(const int64_t* input, int64_t* output, const size_t* input_shape,
                                    const int64_t* strides, const int64_t* axis, size_t input_size, size_t axis_size,
                                    cudaStream_t cuda_stream);
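For each flat index, the kernel adjusts the coordinate along every reversed dimension by (reversed_position - pre_reverse_position) * stride, so a single pass over the input scatters the whole tensor. A host-side NumPy sketch of the same index transformation, checked against np.flip (function name is illustrative only):

import numpy as np

def reverse_v2_indices(shape, axes):
    # Mirror of the kernel's per-element index math.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = shape[i + 1] * strides[i + 1]
    size = int(np.prod(shape))
    out = np.empty(size, dtype=np.int64)
    for gt_id in range(size):
        intermediate = gt_id
        for d in axes:
            pre = (gt_id // strides[d]) % shape[d]
            intermediate += (shape[d] - pre - 1 - pre) * strides[d]
        out[intermediate] = gt_id  # output element `intermediate` comes from input `gt_id`
    return out

shape, axes = (3, 4, 5), (0, 2)
x = np.arange(60).reshape(shape)
scattered = x.reshape(-1)[reverse_v2_indices(shape, axes)].reshape(shape)
np.testing.assert_array_equal(scattered, np.flip(x, axes))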
@@ -0,0 +1,21 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REVERSE_V2_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REVERSE_V2_CUH_
template <typename T>
void CalReverseV2(const T* input, T* output, const size_t* input_shape, const int64_t* strides, const int64_t* axis,
                  size_t input_size, size_t axis_size, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REVERSE_V2_CUH_
@@ -2672,10 +2672,21 @@ class ReverseV2(PrimitiveWithInfer):
        self.axis = axis
        self.init_prim_io_names(inputs=['x'], outputs=['output'])


    def infer_shape(self, x_shape):
        dim = len(x_shape)
        for i, each in enumerate(self.axis):
            validator.check_int_range(each, -dim, dim, Rel.INC_LEFT, f'axis[{i}]', self.name)
        normalized_axis = []
        for i, v in enumerate(self.axis):
            if v < 0:
                normalized_axis.append(v + dim)
            else:
                normalized_axis.append(v)

        if len(normalized_axis) != len(set(normalized_axis)):
            raise ValueError('axis cannot contain duplicate dimensions.')

        return x_shape

    def infer_dtype(self, x_dtype):
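The added check normalizes negative axes before comparing, so e.g. axis=(-2, -1, 3) on a rank-5 input collapses to (3, 4, 3) and is rejected. The same logic in isolation (a standalone sketch, not the primitive itself):

def check_axis(axis, rank):
    # Mirror of the validation added to ReverseV2.infer_shape.
    normalized = [a + rank if a < 0 else a for a in axis]
    if len(normalized) != len(set(normalized)):
        raise ValueError('axis cannot contain duplicate dimensions.')
    return normalized

assert check_axis((-2, -1), rank=5) == [3, 4]
try:
    check_axis((-2, -1, 3), rank=5)  # -2 -> 3, duplicating the explicit 3
except ValueError as e:
    assert 'duplicate' in str(e)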
@@ -0,0 +1,114 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

class ReverseV2Net(nn.Cell):
    def __init__(self, axis):
        super(ReverseV2Net, self).__init__()
        self.reverse_v2 = P.ReverseV2(axis)

    def construct(self, x):
        return self.reverse_v2(x)


def reverse_v2(x_numpy, axis):
    x = Tensor(x_numpy)
    reverse_v2_net = ReverseV2Net(axis)
    output = reverse_v2_net(x).asnumpy()
    expected_output = np.flip(x_numpy, axis)
    np.testing.assert_array_equal(output, expected_output)

def reverse_v2_3d(nptype):
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    x_numpy = np.arange(60).reshape(3, 4, 5).astype(nptype)

    reverse_v2(x_numpy, (0,))
    reverse_v2(x_numpy, (1,))
    reverse_v2(x_numpy, (2,))
    reverse_v2(x_numpy, (2, -2))
    reverse_v2(x_numpy, (-3, 1, 2))

def reverse_v2_1d(nptype):
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    x_numpy = np.arange(4).astype(nptype)

    reverse_v2(x_numpy, (0,))
    reverse_v2(x_numpy, (-1,))

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_float16():
    reverse_v2_1d(np.float16)
    reverse_v2_3d(np.float16)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_float32():
    reverse_v2_1d(np.float32)
    reverse_v2_3d(np.float32)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_uint8():
    reverse_v2_1d(np.uint8)
    reverse_v2_3d(np.uint8)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_int16():
    reverse_v2_1d(np.int16)
    reverse_v2_3d(np.int16)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_int32():
    reverse_v2_1d(np.int32)
    reverse_v2_3d(np.int32)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_int64():
    reverse_v2_1d(np.int64)
    reverse_v2_3d(np.int64)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_reverse_v2_invalid_axis():
    x = Tensor(np.arange(60).reshape(1, 2, 3, 2, 5).astype(np.int32))

    with pytest.raises(ValueError) as info:
        reverse_v2_net = ReverseV2Net((0, 1, 2, 1))
        _ = reverse_v2_net(x)
    assert "axis cannot contain duplicate dimensions" in str(info.value)

    with pytest.raises(ValueError) as info:
        reverse_v2_net = ReverseV2Net((-2, -1, 3))
        _ = reverse_v2_net(x)
    assert "axis cannot contain duplicate dimensions" in str(info.value)