!14201 Add new GPU kernel Tile

From: @TFbunny
Reviewed-by: @tom__chen
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-31 21:59:40 +08:00 committed by Gitee
commit f9d1575c5f
4 changed files with 260 additions and 0 deletions

View File

@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/tile_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
TileGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TileGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
TileGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), TileGpuKernel,
int16_t)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), TileGpuKernel,
int)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TileGpuKernel,
int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,125 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TILE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TILE_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/tile_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class TileGpuKernel : public GpuKernel {
public:
TileGpuKernel() { ResetResource(); }
~TileGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
size_t *input_shape_ptr = GetDeviceAddress<size_t>(workspace, 0);
size_t *output_shape_ptr = GetDeviceAddress<size_t>(workspace, 1);
T *output = GetDeviceAddress<T>(outputs, 0);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(input_shape_ptr, &input_shape_[0], input_shape_.size() * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_shape_ failed");
CHECK_CUDA_RET_WITH_EXCEPT(
kernel_node_,
cudaMemcpyAsync(output_shape_ptr, &output_shape_[0], output_shape_.size() * sizeof(size_t),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output_shape_ failed");
CalTile(output_size_, input_size_, shape_size_, input_shape_ptr, output_shape_ptr, input, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but Tile needs 1 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but Tile has 1 output.";
return false;
}
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_size_ = 1;
for (size_t i = 0; i < input_shape_.size(); i++) {
input_size_ *= input_shape_[i];
}
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
output_size_ = 1;
if (output_shape_.size() > TILE_MAX_DIMENSION) {
MS_LOG(EXCEPTION) << "Output is " << output_shape_.size() << "-D, but Tile supports up to " << TILE_MAX_DIMENSION
<< "-D.";
}
shape_size_ = output_shape_.size();
for (size_t i = 0; i < output_shape_.size(); i++) {
output_size_ *= output_shape_[i];
}
std::vector<int64_t> multiples = GetAttr<std::vector<int64_t>>(kernel_node, "multiples");
int64_t filling_value = static_cast<int64_t>(multiples.size()) - static_cast<int64_t>(input_shape_.size());
// input_shape_.size() == output_shape_.size() == shape_size_
input_shape_.insert(input_shape_.begin(), LongToSize(filling_value), 1);
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_size_ = 1;
output_size_ = 1;
shape_size_ = 1;
input_shape_.clear();
output_shape_.clear();
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
workspace_size_list_.push_back(input_shape_.size() * sizeof(size_t));
workspace_size_list_.push_back(output_shape_.size() * sizeof(size_t));
output_size_list_.push_back(output_size_ * sizeof(T));
}
private:
size_t input_size_;
size_t output_size_;
size_t shape_size_;
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TILE_GPU_KERNEL_H_

View File

@ -0,0 +1,74 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/tile_impl.cuh"
template <typename T>
__global__ void Tile(const size_t output_size, const size_t input_size, const size_t shape_size,
const size_t *input_shape, const size_t *output_shape, const T *input, T *output) {
// for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] +
// pos_array[1] * output_shape[2] * output_shape[3] +
// pos_array[2] * output_shape[3] +
// pos_array[3]
size_t pos_array[TILE_MAX_DIMENSION];
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size; pos += blockDim.x * gridDim.x) {
size_t tmp_pos = pos;
size_t pos_size = output_size / output_shape[0];
pos_array[0] = tmp_pos / pos_size;
for (size_t i = 1; i < shape_size; i++) {
tmp_pos -= pos_array[i - 1] * pos_size;
pos_size = pos_size / output_shape[i];
pos_array[i] = tmp_pos / pos_size;
}
for (size_t i = 0; i < shape_size; i++) {
pos_array[i] = pos_array[i] % input_shape[i];
}
pos_size = input_size;
size_t input_pos = 0;
for (size_t i = 0; i < shape_size; i++) {
pos_size /= input_shape[i];
input_pos += (pos_array[i] * pos_size);
}
output[pos] = input[input_pos];
}
}
template <typename T>
void CalTile(const size_t output_size, const size_t input_size, const size_t shape_size, const size_t *input_shape,
const size_t *output_shape, const T *input, T *output, cudaStream_t cuda_stream) {
Tile<<<GET_BLOCKS(output_size), GET_THREADS, 0, cuda_stream>>>(output_size, input_size, shape_size, input_shape,
output_shape, input, output);
return;
}
template void CalTile<double>(const size_t output_size, const size_t input_size, const size_t shape_size,
const size_t *input_shape, const size_t *output_shape, const double *input,
double *output, cudaStream_t cuda_stream);
template void CalTile<float>(const size_t output_size, const size_t input_size, const size_t shape_size,
const size_t *input_shape, const size_t *output_shape, const float *input, float *output,
cudaStream_t cuda_stream);
template void CalTile<half>(const size_t output_size, const size_t input_size, const size_t shape_size,
const size_t *input_shape, const size_t *output_shape, const half *input, half *output,
cudaStream_t cuda_stream);
template void CalTile<int16_t>(const size_t output_size, const size_t input_size, const size_t shape_size,
const size_t *input_shape, const size_t *output_shape, const int16_t *input,
int16_t *output, cudaStream_t cuda_stream);
template void CalTile<int>(const size_t output_size, const size_t input_size, const size_t shape_size,
const size_t *input_shape, const size_t *output_shape, const int *input, int *output,
cudaStream_t cuda_stream);
template void CalTile<int64_t>(const size_t output_size, const size_t input_size, const size_t shape_size,
const size_t *input_shape, const size_t *output_shape, const int64_t *input,
int64_t *output, cudaStream_t cuda_stream);

View File

@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TILE_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TILE_IMPL_CUH_
#define TILE_MAX_DIMENSION 100
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalTile(const size_t output_size, const size_t input_size, const size_t shape_size, const size_t *input_shape,
const size_t *output_shape, const T *input, T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TILE_IMPL_CUH_