!41710 add fast sort gpu kernel

Merge pull request !41710 from hangq/master
i-robot 2023-01-07 07:42:30 +00:00 committed by Gitee
commit f8bc59aedb
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 1227 additions and 103 deletions


@@ -0,0 +1,174 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_FAST_SORT_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_FAST_SORT_GPU_KERNEL_H_
#include <algorithm>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>
#include <memory>
#include <map>
#include "mindspore/core/ops/sort.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_layout_helper.cuh"
#include "plugin/device/gpu/kernel/arrays/sort_key_value_inplace.h"
#include "ir/dtype/type_id.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace kernel {
constexpr int kFastSortInputsNum = 1;
constexpr int kFastSortOutputsNum = 2;
template <typename K, typename V>
class FastSortGpuKernelMod : public NativeGpuKernelMod {
public:
FastSortGpuKernelMod() = default;
~FastSortGpuKernelMod() {
delete input_info_;
delete output_index_info_;
delete output_value_info_;
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override {
auto kernel_name = base_operator->GetPrim()->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kFastSortInputsNum, kernel_name);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kFastSortOutputsNum, kernel_name);
return true;
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override {
auto ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
auto kernel_name = base_operator->GetPrim()->name();
input_shape_ = inputs[0]->GetShapeVector();
is_null_input_ = CHECK_SHAPE_NULL(input_shape_, kernel_name, "input");
if (is_null_input_) {
return KRET_OK;
}
input_rank_ = input_shape_.size();
input_size_ = 1;
for (int64_t i = 0; i < input_rank_; i++) {
input_size_ *= input_shape_[i];
}
auto kernel_ptr = std::make_shared<ops::Sort>(base_operator->GetPrim());
if (kernel_ptr == nullptr) {
MS_LOG(ERROR) << "Malloc ops::Sort failed while Resizing.";
return KRET_RESIZE_FAILED;
}
descending_ = static_cast<bool>(kernel_ptr->get_descending());
axis_ = static_cast<int64_t>(kernel_ptr->get_axis());
if (axis_ < 0) {
axis_ += input_rank_;
}
if (axis_ >= input_rank_) {
MS_LOG(ERROR) << "For '" << kernel_name << "', the value of 'axis' must be less than the dimension of input"
<< ", but got the dimension of input: "
<< ", got the value of 'axis': ";
return KRET_RESIZE_FAILED;
}
constexpr int kMaxFixedSortSize = 4096;
if (input_shape_[axis_] > kMaxFixedSortSize) {
MS_LOG(ERROR) << "For '" << kernel_name << "', only support sort dim less or equal to 4096, but got: ";
return KRET_RESIZE_FAILED;
}
delete input_info_;
delete output_index_info_;
delete output_value_info_;
int shape[MAX_TENSORINFO_DIMS];
for (int i = 0; i < input_rank_; i++) {
shape[i] = input_shape_[i];
}
input_info_ = new TensorLayoutHelper(shape, input_rank_);
if (input_info_ == nullptr) {
MS_LOG(ERROR) << "Malloc TensorLayoutHelper for input failed while Resizing.";
return KRET_RESIZE_FAILED;
}
output_index_info_ = new TensorLayoutHelper(shape, input_rank_);
if (output_index_info_ == nullptr) {
MS_LOG(ERROR) << "Malloc TensorLayoutHelper for output index failed while Resizing.";
return KRET_RESIZE_FAILED;
}
output_value_info_ = new TensorLayoutHelper(shape, input_rank_);
if (output_value_info_ == nullptr) {
MS_LOG(ERROR) << "Malloc TensorLayoutHelper for output value failed while Resizing.";
return KRET_RESIZE_FAILED;
}
return KRET_OK;
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
return LaunchKernel(inputs, workspace, outputs, stream_ptr);
}
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
if (is_null_input_) {
return true;
}
V *input_device = GetDeviceAddress<V>(inputs, 0);
V *output_device = GetDeviceAddress<V>(outputs, 0);
K *indices_device = GetDeviceAddress<K>(outputs, 1);
auto ret = InitIndexBySlice<K>(*output_index_info_, axis_, indices_device, cuda_stream_);
if (!ret) {
MS_LOG(ERROR) << "InitIndexBySlice failed.";
return false;
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(output_device, input_device, input_size_ * sizeof(V), cudaMemcpyDeviceToDevice, cuda_stream_),
"cudaMemcpyAsync for output_device failed");
return SortKeyValueInplace<V, K>(*output_value_info_, output_device, *output_index_info_, indices_device, axis_,
descending_, cuda_stream_);
}
private:
int64_t input_size_{0};
int64_t axis_{0};
bool descending_{false};
bool is_null_input_{false};
std::vector<int64_t> input_shape_;
int64_t input_rank_{0};
TensorLayoutHelper *input_info_{nullptr};
TensorLayoutHelper *output_index_info_{nullptr};
TensorLayoutHelper *output_value_info_{nullptr};
cudaStream_t cuda_stream_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_FAST_SORT_GPU_KERNEL_H_
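For orientation, here is a minimal CPU sketch of the per-slice contract this kernel implements on the GPU (the helper name below is made up for illustration): the value output is the sorted slice, and the index output records each element's original position, the same pair that Launch produces via InitIndexBySlice followed by SortKeyValueInplace.

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// CPU reference sketch (illustrative helper, not part of the kernel): sort one slice
// of `len` values starting at `slice_begin` and record the original positions.
void SortSliceWithIndices(std::vector<float> *values, std::vector<int32_t> *indices,
                          size_t slice_begin, size_t len, bool descending) {
  std::vector<int32_t> perm(len);
  std::iota(perm.begin(), perm.end(), 0);
  std::stable_sort(perm.begin(), perm.end(), [&](int32_t a, int32_t b) {
    float va = (*values)[slice_begin + a];
    float vb = (*values)[slice_begin + b];
    return descending ? va > vb : va < vb;
  });
  std::vector<float> sorted(len);
  for (size_t i = 0; i < len; ++i) {
    sorted[i] = (*values)[slice_begin + perm[i]];
    (*indices)[slice_begin + i] = perm[i];
  }
  std::copy(sorted.begin(), sorted.end(), values->begin() + slice_begin);
}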


@@ -20,19 +20,29 @@
namespace mindspore {
namespace kernel {
constexpr double kMinValue = -65504.;
template <typename T>
bool SortGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
template <typename K, typename V>
bool SortGpuKernelMod<K, V>::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
MS_LOG(EXCEPTION) << "Only support input datatype in [float16, float32] for sort kernel";
return false;
}
template <>
bool SortGpuKernelMod<int32_t, half>::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
if (is_null_input_) {
return true;
}
T *input_device = GetDeviceAddress<T>(inputs, kIndex0);
half *input_device = GetDeviceAddress<half>(inputs, kIndex0);
T *output_device = GetDeviceAddress<T>(outputs, kIndex0);
half *output_device = GetDeviceAddress<half>(outputs, kIndex0);
int32_t *indices_device = GetDeviceAddress<int32_t>(outputs, kIndex1);
T *temp_output_device = GetDeviceAddress<T>(workspace, kIndex0);
half *temp_output_device = GetDeviceAddress<half>(workspace, kIndex0);
int32_t *temp_indices_device = GetDeviceAddress<int32_t>(workspace, kIndex1);
size_t *input_shape_device = GetDeviceAddress<size_t>(workspace, kIndex2);
size_t *perm_device = GetDeviceAddress<size_t>(workspace, kIndex3);
@@ -53,14 +63,10 @@ bool SortGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const
// between using temp_output_device and output_device for intermediate calculations,
// this way only a constant number of allocations is needed instead of needing to
// allocate once for each intermediate calculation.
T *intermediate_input_device = input_device;
T *intermediate_output_device = output_device;
half *intermediate_input_device = input_device;
half *intermediate_output_device = output_device;
T topk_init_ = std::numeric_limits<T>::lowest();
if (std::is_same<T, half>::value) {
// min value representable by float16, std::numeric_limits doesn't support half
topk_init_ = static_cast<half>(kMinValue);
}
half topk_init_ = static_cast<half>(kMinValue);
// if sort not in descending order, negate input and negate back after sorting
if (!descending_) {
@@ -104,59 +110,113 @@ bool SortGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const
return true;
}
int SortGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
auto ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
input_shape_ = inputs[0]->GetShapeVector();
auto kernel_name = base_operator->GetPrim()->name();
is_null_input_ = CHECK_SHAPE_NULL(input_shape_, kernel_name, "input");
template <>
bool SortGpuKernelMod<int32_t, float>::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
if (is_null_input_) {
return KRET_RESIZE_FAILED;
return true;
}
float *input_device = GetDeviceAddress<float>(inputs, kIndex0);
float *output_device = GetDeviceAddress<float>(outputs, kIndex0);
int32_t *indices_device = GetDeviceAddress<int32_t>(outputs, kIndex1);
float *temp_output_device = GetDeviceAddress<float>(workspace, kIndex0);
int32_t *temp_indices_device = GetDeviceAddress<int32_t>(workspace, kIndex1);
size_t *input_shape_device = GetDeviceAddress<size_t>(workspace, kIndex2);
size_t *perm_device = GetDeviceAddress<size_t>(workspace, kIndex3);
size_t *transposed_shape_device = GetDeviceAddress<size_t>(workspace, kIndex4);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(input_shape_device, &input_shape_[0], workspace_size_list_[kIndex2], cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync for input_shape_ failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(perm_device, &perm_[0], workspace_size_list_[kIndex3], cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync for perm_ failed");
// Sort is implemented using a combination of Neg, Transpose, and TopK. It's
// Not safe to treat Transpose and TopK as inplace operators, so we alternate
// between using temp_output_device and output_device for intermediate calculations,
// this way only a constant number of allocations is needed instead of needing to
// allocate once for each intermediate calculation.
float *intermediate_input_device = input_device;
float *intermediate_output_device = output_device;
float topk_init_ = std::numeric_limits<float>::lowest();
// if sort not in descending order, negate input and negate back after sorting
if (!descending_) {
NegOpt(intermediate_input_device, intermediate_output_device, input_size_,
reinterpret_cast<cudaStream_t>(stream_ptr));
intermediate_input_device = output_device;
intermediate_output_device = temp_output_device;
}
input_size_ = 1;
for (size_t i = 0; i < input_rank_; i++) {
input_size_ *= static_cast<size_t>(input_shape_[i]);
// transpose so that desired dimension to sort along becomes the last one
CalTranspose(input_size_, intermediate_input_device, input_shape_device, perm_device, input_rank_,
intermediate_output_device, reinterpret_cast<cudaStream_t>(stream_ptr));
intermediate_input_device = intermediate_output_device;
intermediate_output_device = intermediate_input_device == output_device ? temp_output_device : output_device;
// topk sorts the input along the last dimension
FastTopK(outer_size_, inner_size_, intermediate_input_device, static_cast<int32_t>(input_shape_[axis_]),
intermediate_output_device, temp_indices_device, topk_init_, reinterpret_cast<cudaStream_t>(stream_ptr));
std::swap(intermediate_input_device, intermediate_output_device);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(transposed_shape_device, &transposed_shape_[0], workspace_size_list_[kIndex4],
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync for transposed_shape_ failed");
// transpose the sorted output back to the original input shape
CalTranspose(input_size_, intermediate_input_device, transposed_shape_device, perm_device, input_rank_,
intermediate_output_device, reinterpret_cast<cudaStream_t>(stream_ptr));
// transpose the indices back to the original input shape
CalTranspose(input_size_, temp_indices_device, transposed_shape_device, perm_device, input_rank_, indices_device,
reinterpret_cast<cudaStream_t>(stream_ptr));
// negate back the sorted values if we negated prior to sorting
if (!descending_) {
std::swap(intermediate_input_device, intermediate_output_device);
NegOpt(intermediate_input_device, intermediate_output_device, input_size_,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
transposed_shape_ = input_shape_;
std::swap(transposed_shape_[input_rank_ - 1], transposed_shape_[axis_]);
inner_size_ = static_cast<size_t>(input_shape_[axis_]);
outer_size_ = input_size_ / inner_size_;
MS_LOG(DEBUG) << "In gpu kernel sort Resize, axis_=" << axis_ << " descending_=" << descending_
<< " input_rank_=" << input_rank_ << " input_size_=" << input_size_ << " inner_size_=" << inner_size_
<< " outer_size_=" << outer_size_;
if (input_size_list_.size() > 0) {
size_t input_bytes = input_size_list_.at(kIndex0);
size_t indices_bytes = input_size_ * sizeof(int32_t);
workspace_size_list_.push_back(input_bytes);
workspace_size_list_.push_back(indices_bytes);
workspace_size_list_.push_back(input_rank_ * sizeof(size_t));
workspace_size_list_.push_back(input_rank_ * sizeof(size_t));
workspace_size_list_.push_back(input_rank_ * sizeof(size_t));
}
return KRET_OK;
return true;
}
std::vector<std::pair<KernelAttr, SortGpuKernelMod::SortLaunchFunc>> SortGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
&SortGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
&SortGpuKernelMod::LaunchKernel<float>}};
std::vector<KernelAttr> SortGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, SortGpuKernelMod::SortLaunchFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Sort, SortGpuKernelMod);
MS_REG_GPU_KERNEL_TWO(
Sort, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
SortGpuKernelMod, int32_t, bool);
MS_REG_GPU_KERNEL_TWO(
Sort, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
SortGpuKernelMod, int32_t, int8_t);
MS_REG_GPU_KERNEL_TWO(
Sort, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
SortGpuKernelMod, int32_t, int16_t);
MS_REG_GPU_KERNEL_TWO(
Sort, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SortGpuKernelMod, int32_t, int32_t);
MS_REG_GPU_KERNEL_TWO(
Sort, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
SortGpuKernelMod, int32_t, int64_t);
MS_REG_GPU_KERNEL_TWO(
Sort, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
SortGpuKernelMod, int32_t, uint8_t);
MS_REG_GPU_KERNEL_TWO(
Sort, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
SortGpuKernelMod, int32_t, half);
MS_REG_GPU_KERNEL_TWO(
Sort, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
SortGpuKernelMod, int32_t, float);
MS_REG_GPU_KERNEL_TWO(
Sort, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32),
SortGpuKernelMod, int32_t, double);
} // namespace kernel
} // namespace mindspore
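A small CPU illustration of the negate-to-sort-ascending trick described in the comments above; std::sort stands in for the descending FastTopK primitive here, and the function name is illustrative only.

#include <algorithm>
#include <functional>
#include <vector>

// Ascending order from a descending-only sort primitive:
// negate the inputs, sort descending, then negate the results back.
std::vector<float> AscendingViaDescendingSort(std::vector<float> data) {
  for (float &v : data) v = -v;                                 // NegOpt before sorting
  std::sort(data.begin(), data.end(), std::greater<float>());   // descending sort, i.e. top-k with k == n
  for (float &v : data) v = -v;                                 // NegOpt back afterwards
  return data;
}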


@@ -30,77 +30,88 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/topk_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/elementwise_op_impl.cuh"
#include "plugin/device/gpu/kernel/arrays/fast_sort_gpu_kernel.h"
namespace mindspore {
namespace kernel {
constexpr int kSortInputsNum = 1;
constexpr int kSortOutputsNum = 2;
template <typename K, typename V>
class SortGpuKernelMod : public NativeGpuKernelMod {
public:
SortGpuKernelMod() { ResetResource(); }
~SortGpuKernelMod() = default;
~SortGpuKernelMod() { delete fast_sort_kernel_; }
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override {
auto kernel_name = base_operator->GetPrim()->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSortInputsNum, kernel_name);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSortOutputsNum, kernel_name);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
return false;
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override {
auto ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
input_shape_ = inputs[0]->GetShapeVector();
use_fast_ = input_shape_[axis_] <= sort_dim_thres_;
if (use_fast_) {
return fast_sort_kernel_->Resize(base_operator, inputs, outputs, inputsOnHost);
} else {
if (!old_kernel_support_) {
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
MS_LOG(ERROR) << "Only support input datatype in [float16, float32] for sort kernel, but got "
<< kernel_attr.GetInputAttr(0).dtype << " in KernelAttr.";
return KRET_RESIZE_FAILED;
}
}
auto kernel_name = base_operator->GetPrim()->name();
is_null_input_ = CHECK_SHAPE_NULL(input_shape_, kernel_name, "input");
if (is_null_input_) {
return true;
return KRET_OK;
}
input_rank_ = input_shape_.size();
if (input_rank_ > TRANSPOSE_MAX_DIMENSION || input_rank_ < 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input cannot be greater than "
<< TRANSPOSE_MAX_DIMENSION << ", or less than 1"
<< ", but got " << input_rank_;
MS_LOG(ERROR) << "For '" << kernel_name << "', the dimension of input cannot be greater than "
<< TRANSPOSE_MAX_DIMENSION << ", or less than 1"
<< ", but got " << input_rank_;
return KRET_RESIZE_FAILED;
}
input_size_ = 1;
auto kernel_ptr = std::make_shared<ops::Sort>(base_operator->GetPrim());
if (kernel_ptr == nullptr) {
MS_LOG(ERROR) << "Malloc ops::Sort failed while Resizing.";
return KRET_RESIZE_FAILED;
}
descending_ = static_cast<bool>(kernel_ptr->get_descending());
axis_ = static_cast<int64_t>(kernel_ptr->get_axis());
if (axis_ < 0) {
axis_ += input_rank_;
}
if ((size_t)axis_ >= input_rank_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the value of 'axis' must be less than the dimension of input"
<< ", but got the dimension of input: " << input_rank_
<< ", got the value of 'axis': " << (size_t)axis_;
MS_LOG(ERROR) << "For '" << kernel_name << "', the value of 'axis' must be less than the dimension of input"
<< ", but got the dimension of input: " << input_rank_
<< ", got the value of 'axis': " << (size_t)axis_;
return KRET_RESIZE_FAILED;
}
perm_.resize(input_rank_);
std::iota(perm_.begin(), perm_.end(), 0);
std::swap(perm_[input_rank_ - 1], perm_[axis_]);
input_size_ = 1;
for (size_t i = 0; i < input_rank_; i++) {
input_size_ *= static_cast<size_t>(input_shape_[i]);
}
transposed_shape_ = input_shape_;
std::swap(transposed_shape_[input_rank_ - 1], transposed_shape_[axis_]);
inner_size_ = static_cast<size_t>(input_shape_[axis_]);
outer_size_ = input_size_ / inner_size_;
MS_LOG(DEBUG) << "In gpu kernel sort Init, axis_=" << axis_ << " descending_=" << descending_
MS_LOG(DEBUG) << "In gpu kernel sort Resize, axis_=" << axis_ << " descending_=" << descending_
<< " input_rank_=" << input_rank_ << " input_size_=" << input_size_ << " inner_size_=" << inner_size_
<< " outer_size_=" << outer_size_;
(void)KernelMod::Resize(base_operator, inputs, outputs);
if (input_size_list_.size() > 0) {
size_t input_bytes = input_size_list_.at(kIndex0);
size_t indices_bytes = input_size_ * sizeof(int32_t);
@@ -110,9 +121,46 @@ class SortGpuKernelMod : public NativeGpuKernelMod {
workspace_size_list_.push_back(input_rank_ * sizeof(size_t));
workspace_size_list_.push_back(input_rank_ * sizeof(size_t));
}
return KRET_OK;
}
kernel_func_ = func_list_[index].second;
return true;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override {
auto kernel_name = base_operator->GetPrim()->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSortInputsNum, kernel_name);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSortOutputsNum, kernel_name);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
KernelAttr fp16_kernel_attr;
fp16_kernel_attr.AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32);
KernelAttr fp32_kernel_attr;
fp32_kernel_attr.AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32);
std::vector<KernelAttr> support_list;
support_list.emplace_back(fp16_kernel_attr);
support_list.emplace_back(fp32_kernel_attr);
old_kernel_support_ = MatchKernelAttr(kernel_attr, support_list).first;
MS_LOG(DEBUG) << "In gpu kernel sort Init, axis_=" << axis_ << " descending_=" << descending_
<< " input_rank_=" << input_rank_ << " input_size_=" << input_size_ << " inner_size_=" << inner_size_
<< " outer_size_=" << outer_size_;
(void)KernelMod::Resize(base_operator, inputs, outputs);
fast_sort_kernel_ = new FastSortGpuKernelMod<K, V>();
if (fast_sort_kernel_ == nullptr) {
MS_LOG(ERROR) << "Malloc FastSortGpuKernelMod failed while Init.";
return false;
}
return fast_sort_kernel_->Init(base_operator, inputs, outputs);
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
if (use_fast_) {
return fast_sort_kernel_->Launch(inputs, workspace, outputs, stream_ptr);
}
return LaunchKernel(inputs, workspace, outputs, stream_ptr);
}
void ResetResource() noexcept {
@@ -131,11 +179,6 @@ class SortGpuKernelMod : public NativeGpuKernelMod {
workspace_size_list_.clear();
}
protected:
std::vector<KernelTensorPtr> outputs_{};
std::vector<KernelAttr> GetOpSupport() override;
std::vector<KernelTensorPtr> GetOutputs() override { return outputs_; }
private:
size_t input_size_;
int64_t axis_;
@@ -152,14 +195,15 @@ class SortGpuKernelMod : public NativeGpuKernelMod {
size_t outer_size_;
size_t inner_size_;
template <typename T>
// fast sort
FastSortGpuKernelMod<K, V> *fast_sort_kernel_{nullptr};
bool use_fast_{false};
constexpr static int64_t sort_dim_thres_ = 4096;
bool old_kernel_support_{false};
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
using SortLaunchFunc = std::function<bool(SortGpuKernelMod *, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
static std::vector<std::pair<KernelAttr, SortLaunchFunc>> func_list_;
SortLaunchFunc kernel_func_;
cudaStream_t cuda_stream_;
};
} // namespace kernel
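The dispatch rule added to this class is small enough to restate as a sketch (the constant and function names below are illustrative, not from the diff): slices of at most 4096 elements go to the fused block radix sort for every registered dtype, while longer float16/float32 slices keep the Transpose+TopK pipeline; a long slice of any other dtype has no valid path and is rejected in Resize with KRET_RESIZE_FAILED.

// Illustrative dispatch only; mirrors sort_dim_thres_ and use_fast_ above.
constexpr int64_t kSortDimThreshold = 4096;

// true -> FastSortGpuKernelMod, false -> legacy Transpose+TopK path.
bool UseFastSortPath(int64_t sort_dim_size) { return sort_dim_size <= kSortDimThreshold; }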


@@ -0,0 +1,230 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/arrays/sort_key_value_inplace.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sort_fixed_size.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/hal/device/gpu_common.h"
constexpr int MAX_DIMS = 8;
// Returns 2^(ceil(lg(n))) from the Stanford bit twiddling hacks
static uint64_t NextHighestPowerOf2(uint64_t n) {
const int pow0of2 = 1;
const int pow1of2 = 2;
const int pow2of2 = 4;
const int pow3of2 = 8;
const int pow4of2 = 16;
const int pow5of2 = 32;
n--;
n |= n >> pow0of2;
n |= n >> pow1of2;
n |= n >> pow2of2;
n |= n >> pow3of2;
n |= n >> pow4of2;
#ifndef _MSC_VER
n |= n >> pow5of2;
#endif
n++;
return n;
}
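// Example: NextHighestPowerOf2(3000). 3000 - 1 = 2999; OR-ing in the right-shifted
// copies sets every bit below the highest set bit, giving 4095; adding 1 yields
// 4096 = 2^12, the smallest power of two >= 3000. SegSort below uses this rounded
// size to select one of the fixed sort-size buckets.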
template <int A, typename K, typename V>
bool SegSort(const TensorLayoutHelper &key_info, K *key_data, int64_t key_slices, int64_t key_slice_size,
int64_t key_slice_stride, const TensorLayoutHelper &value_info, V *value_data, int64_t value_slice_stride,
bool descending, cudaStream_t stream) {
int64_t ceil_power_of2 = NextHighestPowerOf2(key_slice_size);
#define HANDLE_CASE(SIZE, ITEMS_PER_THREAD) \
return SortFixedSize<A, SIZE, ITEMS_PER_THREAD, K, V>(key_info, key_data, key_slices, key_slice_size, \
key_slice_stride, value_info, value_data, value_slice_stride, \
descending, stream)
constexpr int kFixedSizeLevel3SubThreshold1 = 512;
constexpr int kFixedSizeLevel3SubThreshold2 = 256;
constexpr int kFixedSizeLevel4SubThreshold = 64;
constexpr int kFixedSizeLevel5SubThreshold1 = 16;
constexpr int kFixedSizeLevel5SubThreshold2 = 8;
constexpr int kFixedSizeLevel5SubThreshold3 = 4;
constexpr int kFixedSizeLevel5SubThreshold4 = 2;
switch (ceil_power_of2) {
case kFixedSizeLevel1:
HANDLE_CASE(kFixedSizeLevel1, kFixedSizeLevel1ItemPreThread);
case kFixedSizeLevel2:
HANDLE_CASE(kFixedSizeLevel2, kFixedSizeLevel2ItemPreThread);
case kFixedSizeLevel3:
case kFixedSizeLevel3SubThreshold1:
case kFixedSizeLevel3SubThreshold2:
HANDLE_CASE(kFixedSizeLevel3, kFixedSizeLevel3ItemPreThread);
case kFixedSizeLevel4:
case kFixedSizeLevel4SubThreshold:
HANDLE_CASE(kFixedSizeLevel4, kFixedSizeLevel4ItemPreThread);
case kFixedSizeLevel5:
case kFixedSizeLevel5SubThreshold1:
case kFixedSizeLevel5SubThreshold2:
case kFixedSizeLevel5SubThreshold3:
case kFixedSizeLevel5SubThreshold4:
HANDLE_CASE(kFixedSizeLevel5, kFixedSizeLevel5ItemPreThread);
case 1:
return true;
default:
MS_LOG(ERROR) << "SortKeyValueInplace only support sort size less than or equal to 4096, but got "
<< key_slice_size;
return false;
}
#undef HANDLE_CASE
}
template <typename K>
CUDA_LIB_EXPORT bool InitIndexBySlice(const TensorLayoutHelper &t, int64_t axis, K *data, cudaStream_t cuda_stream) {
if (t.shape_size_ <= 0) {
return true;
}
if (axis < 0) {
axis += t.dim_size_;
}
if (axis >= t.dim_size_ || axis < 0) {
MS_LOG(ERROR) << "axis out of range of dim_size_.";
return false;
}
// Future work: initialize the slice data with a CUDA kernel directly, avoiding the temporary host malloc and the host-to-device cudaMemcpy.
int64_t slice_size = t.sizes_[axis];
K *slice_data_host = reinterpret_cast<K *>(malloc(slice_size * sizeof(K)));
if (slice_data_host == nullptr) {
MS_LOG(ERROR) << "Malloc slice index data failed.";
return false;
}
for (int64_t i = 0; i < slice_size; i++) {
slice_data_host[i] = i;
}
K *slice_data_device = nullptr;
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMalloc(reinterpret_cast<void **>(&slice_data_device), slice_size * sizeof(K)),
"Malloc slice data failed.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(slice_data_device, slice_data_host, slice_size * sizeof(K), cudaMemcpyHostToDevice, cuda_stream),
"Memcpy slice data from host to device failed.");
free(slice_data_host);
int in_size[MAX_DIMS];
int out_size[MAX_DIMS];
for (int i = 0; i < MAX_DIMS; i++) {
in_size[i] = 1;
}
in_size[MAX_DIMS - t.dim_size_ + axis] = t.sizes_[axis];
for (int i = t.dim_size_ - 1; i >= 0; i--) {
out_size[i + MAX_DIMS - t.dim_size_] = t.sizes_[i];
}
for (int i = MAX_DIMS - t.dim_size_ - 1; i >= 0; i--) {
out_size[i] = 1;
}
BroadcastTo<K>(in_size[0], in_size[1], in_size[2], in_size[3], in_size[4], in_size[5], in_size[6], in_size[7],
out_size[0], out_size[1], out_size[2], out_size[3], out_size[4], out_size[5], out_size[6], out_size[7],
slice_data_device, data, cuda_stream);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaFree(slice_data_device), "Free slice data failed.");
return true;
}
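// Example: for t.sizes_ = {5, 4096} and axis = 1, slice_size is 4096; the host buffer
// holds 0..4095, is copied to the device, and BroadcastTo expands it from an effective
// shape (1, ..., 1, 4096) to (..., 5, 4096), so every row of the index output starts
// out as 0, 1, ..., 4095 before the key/value sort permutes it.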
template CUDA_LIB_EXPORT bool InitIndexBySlice<int64_t>(const TensorLayoutHelper &t, int64_t axis, int64_t *data,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT bool InitIndexBySlice<int32_t>(const TensorLayoutHelper &t, int64_t axis, int32_t *data,
cudaStream_t cuda_stream);
template <typename K, typename V>
CUDA_LIB_EXPORT bool SortKeyValueInplace(const TensorLayoutHelper &key, K *key_data, const TensorLayoutHelper &value,
V *value_data, int64_t axis, bool descending, cudaStream_t cuda_stream) {
if (key.dim_size_ != value.dim_size_) {
MS_LOG(ERROR) << "dim_size of key(" << key.dim_size_ << ") should be equal to dim_size of value(" << value.dim_size_
<< ").";
return false;
}
int dims = value.dim_size_;
if (dims > MAX_DIMS) {
MS_LOG(ERROR) << "dim_size should be less than or equal to " << MAX_DIMS << ", but got " << dims << ".";
return false;
}
int in_elements = key.shape_size_;
if (in_elements == 0) {
return true;
}
int key_slice_size = key.sizes_[axis];
int key_slices = in_elements / key_slice_size;
// The constructed key/value tensor info is used to select the slice
// we are sorting on a per-block basis
TensorLayoutHelper key_info(key.sizes_, key.dim_size_);
TensorLayoutHelper value_info(value.sizes_, value.dim_size_);
auto stride_key = key_info.strides_[axis];
key_info.sizes_[axis] = 1;
int collapse_key_dim = key_info.CollapseDims(axis);
key_info.strides_[collapse_key_dim] = stride_key;
auto stride_value = value_info.strides_[axis];
value_info.sizes_[axis] = 1;
int collapse_value_dim = value_info.CollapseDims(axis);
value_info.strides_[collapse_value_dim] = stride_value;
#define HANDLE_SORT_CASE(TYPE, A) \
return SegSort<A, K, V>(key_info, key_data, (TYPE)key_slices, (TYPE)key_slice_size, \
(TYPE)key_info.strides_[collapse_key_dim], value_info, value_data, \
(TYPE)value_info.strides_[collapse_value_dim], descending, cuda_stream)
if (key_info.IsContiguous()) {
HANDLE_SORT_CASE(int64_t, kFixedSizeSortKeyDimsLastSecond);
} else {
switch (key_info.dim_size_) {
case 2: // collapsed key/value layout has exactly two dims
HANDLE_SORT_CASE(unsigned int, kFixedSizeSortKeyDimsSecond);
default: // general strided layout
HANDLE_SORT_CASE(unsigned int, kFixedSizeSortKeyDimsLast);
}
}
#undef HANDLE_SORT_CASE
}
#define SortKeyValueInplace(K, V) \
template CUDA_LIB_EXPORT bool SortKeyValueInplace<K, V>(const TensorLayoutHelper &key, K *key_data, \
const TensorLayoutHelper &value, V *value_data, \
int64_t axis, bool descending, cudaStream_t cuda_stream);
SortKeyValueInplace(bool, int64_t);
SortKeyValueInplace(int8_t, int64_t);
SortKeyValueInplace(int16_t, int64_t);
SortKeyValueInplace(int32_t, int64_t);
SortKeyValueInplace(int64_t, int64_t);
SortKeyValueInplace(uint8_t, int64_t);
SortKeyValueInplace(half, int64_t);
SortKeyValueInplace(float, int64_t);
SortKeyValueInplace(double, int64_t);
SortKeyValueInplace(bool, int32_t);
SortKeyValueInplace(int8_t, int32_t);
SortKeyValueInplace(int16_t, int32_t);
SortKeyValueInplace(int32_t, int32_t);
SortKeyValueInplace(int64_t, int32_t);
SortKeyValueInplace(uint8_t, int32_t);
SortKeyValueInplace(half, int32_t);
SortKeyValueInplace(float, int32_t);
SortKeyValueInplace(double, int32_t);


@@ -0,0 +1,31 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SORT_KEY_VALUE_INPLACE_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SORT_KEY_VALUE_INPLACE_CUH_
#include <cuda_runtime.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_layout_helper.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename K>
CUDA_LIB_EXPORT bool InitIndexBySlice(const TensorLayoutHelper &t, int64_t axis, K *data, cudaStream_t cuda_stream);
template <typename K, typename V>
CUDA_LIB_EXPORT bool SortKeyValueInplace(const TensorLayoutHelper &key, K *key_data, const TensorLayoutHelper &value,
V *value_data, int64_t axis, bool descending, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SORT_KEY_VALUE_INPLACE_CUH_


@@ -0,0 +1,241 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sort_fixed_size.cuh"
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/strided_pointer.cuh"
#if __CUDA_ARCH__ == 750
constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024;
#elif __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870
constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1536;
#else
constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 2048;
#endif
constexpr uint32_t CUDA_THREADS_PRE_BLOCK_FALLBACK = 256;
#define MS_MAX_THREAD_PER_BLOCK(val) (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) : CUDA_THREADS_PRE_BLOCK_FALLBACK)
// Maximum size per grid dimension that we assume (compute capability >= 2.0)
constexpr int64_t MAX_GRID_SIZE = 65535LL;
#define ceil_div(x, y) (((x) + (y) - 1) / (y))
static bool GetGridFromTiles(int64_t grid_tiles, dim3 *grid) {
if (grid_tiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) {
return false;
}
int64_t grid_x = grid_tiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : grid_tiles;
int64_t grid_y = 1;
int64_t grid_z = 1;
if (grid_tiles > MAX_GRID_SIZE) {
grid_tiles = ceil_div(grid_tiles, MAX_GRID_SIZE);
grid_y = grid_tiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : grid_tiles;
if (grid_tiles > MAX_GRID_SIZE) {
grid_tiles = ceil_div(grid_tiles, MAX_GRID_SIZE);
grid_z = grid_tiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : grid_tiles;
}
}
*grid = dim3(grid_x, grid_y, grid_z);
return true;
}
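// Example: with MAX_GRID_SIZE = 65535 and grid_tiles = 100000, grid_x = 65535 and the
// remainder ceil(100000 / 65535) = 2 becomes grid_y, so grid = (65535, 2, 1). That
// launches 131070 blocks for 100000 slices; the kernel's bound check on the linear
// block id discards the surplus blocks.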
template <typename index_t>
__device__ __forceinline__ index_t GetLinearBlockId() {
return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + blockIdx.x;
}
template <typename T, int Dims>
struct IndexToOffset {
static __device__ int64_t get(int64_t linear_id, const TensorLayoutHelper &info) {
int64_t offset = 0;
// Uses static dims
for (int i = Dims - 1; i > 0; --i) {
int64_t cur_dim_index = linear_id % info.sizes_[i];
int64_t cur_dim_offset = cur_dim_index * info.strides_[i];
offset += cur_dim_offset;
linear_id /= info.sizes_[i];
}
return offset + linear_id * info.strides_[0];
}
};
template <typename T>
struct IndexToOffset<T, -1> {
static inline __device__ int64_t get(int64_t linear_id, const TensorLayoutHelper &info) {
int64_t offset = 0;
for (int i = info.dim_size_ - 1; i > 0; --i) {
int64_t cur_dim_index = linear_id % info.sizes_[i];
int64_t cur_dim_offset = cur_dim_index * info.strides_[i];
offset += cur_dim_offset;
linear_id /= info.sizes_[i];
}
return offset + linear_id * info.strides_[0];
}
};
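// Example: a contiguous (6, 4096) key sorted along its last axis reaches the kernel
// (after CollapseDims in SortKeyValueInplace) with sizes_ = {6, 1} and strides_ =
// {4096, 1}; IndexToOffset<K, 2>::get(i, info) then returns i * 4096, the start of
// slice i.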
template <int kKeyDims, int kValueDims, int kBlockSize, int kItemsPerThread, typename K, typename V>
__global__ void __launch_bounds__(MS_MAX_THREAD_PER_BLOCK(kBlockSize))
RadixSortKVInPlace(TensorLayoutHelper keys, K *key_data, int64_t key_slices, int64_t key_slice_size,
int64_t key_slice_stride, TensorLayoutHelper values, V *value_data, int64_t value_slice_stride,
bool descending) {
static_assert(kBlockSize > 0, "");
// Find the slice of the tensor that we are sorting
const int64_t linearIndex = GetLinearBlockId<int64_t>();
// Tiling the slices could have us be out of bounds, if there are a
// lot of slices to sort
if (linearIndex >= key_slices) {
return;
}
const int64_t keyStartOffset = IndexToOffset<K, kKeyDims>::get(linearIndex, keys);
const int64_t valueStartOffset = IndexToOffset<V, kValueDims>::get(linearIndex, values);
K *keys_slice = &(key_data[keyStartOffset]);
V *values_slice = &(value_data[valueStartOffset]);
StridedPointer<K, int64_t> keys_iter(keys_slice, key_slice_stride);
StridedPointer<V, int64_t> values_iter(values_slice, value_slice_stride);
using LoadKeys = cub::BlockLoad<K, kBlockSize, kItemsPerThread, cub::BlockLoadAlgorithm::BLOCK_LOAD_TRANSPOSE>;
using LoadValues = cub::BlockLoad<V, kBlockSize, kItemsPerThread, cub::BlockLoadAlgorithm::BLOCK_LOAD_TRANSPOSE>;
using Sort = cub::BlockRadixSort<K, kBlockSize, kItemsPerThread, V>;
using StoreKeys = cub::BlockStore<K, kBlockSize, kItemsPerThread, cub::BLOCK_STORE_TRANSPOSE>;
using StoreValues = cub::BlockStore<V, kBlockSize, kItemsPerThread, cub::BLOCK_STORE_TRANSPOSE>;
__shared__ union {
typename LoadKeys::TempStorage load_keys;
typename LoadValues::TempStorage load_values;
typename Sort::TempStorage sort;
typename StoreKeys::TempStorage store_keys;
typename StoreValues::TempStorage store_values;
} tmp_storage;
// cub's Block operations operate on a fixed number of items, but the
// actual slice we are sorting might be smaller. So, we need to make
// up the difference with keys that will always sort higher.
const K invalid_key = [descending] {
using radix_t = typename cub::Traits<K>::UnsignedBits;
union {
K key;
radix_t radix;
} tmp;
tmp.radix = descending ? cub::Traits<K>::LOWEST_KEY : cub::Traits<K>::MAX_KEY;
return tmp.key;
}();
const V invalid_value = static_cast<V>(0);
// Load inputs
K local_keys[kItemsPerThread];
V local_values[kItemsPerThread];
LoadKeys(tmp_storage.load_keys).Load(keys_iter, local_keys, key_slice_size, invalid_key);
__syncthreads();
LoadValues(tmp_storage.load_values).Load(values_iter, local_values, key_slice_size, invalid_value);
__syncthreads();
// Sort!
if (descending) {
auto sorter = Sort(tmp_storage.sort);
sorter.SortDescending(reinterpret_cast<K(&)[kItemsPerThread]>(local_keys), local_values);
} else {
Sort(tmp_storage.sort).Sort(reinterpret_cast<K(&)[kItemsPerThread]>(local_keys), local_values);
}
__syncthreads();
// Store outputs
StoreKeys(tmp_storage.store_keys).Store(keys_iter, local_keys, key_slice_size);
__syncthreads();
StoreValues(tmp_storage.store_values).Store(values_iter, local_values, key_slice_size);
}
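// Padding note: when key_slice_size < kBlockSize * kItemsPerThread, BlockLoad fills
// the surplus lanes with invalid_key, whose radix encoding (MAX_KEY when ascending,
// LOWEST_KEY when descending) orders after every real key, so the padded entries end
// up at the tail and only the first key_slice_size sorted elements are stored back.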
template <int A, int kSortSize, int kItemsPerThread, typename K, typename V>
CUDA_LIB_EXPORT bool SortFixedSize(const TensorLayoutHelper &key_info, K *key_data, int64_t key_slices,
int64_t key_slice_size, int64_t key_slice_stride,
const TensorLayoutHelper &value_info, V *value_data, int64_t value_slice_stride,
bool descending, cudaStream_t cuda_stream) {
static_assert(kSortSize % kItemsPerThread == 0, "SortSize mod ItemsPerThread should be equal to zero.");
constexpr int block = kSortSize / kItemsPerThread;
dim3 grid;
if (!GetGridFromTiles(key_slices, &grid)) {
fprintf(stderr, "GetGridFromTiles failed\n");
return false;
}
RadixSortKVInPlace<A, -1, block, kItemsPerThread><<<grid, block, 0, cuda_stream>>>(
key_info, key_data, key_slices, key_slice_size, key_slice_stride,
value_info, value_data, value_slice_stride, descending);
const cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) {
fprintf(stderr, "Call RadixSortKVInPlace failed\n");
return false;
}
return true;
}
#define SortFixedSizeSpec(A, kSortSize, kItemsPerThread, K, V) \
template CUDA_LIB_EXPORT bool SortFixedSize<A, kSortSize, kItemsPerThread, K, V>( \
const TensorLayoutHelper &key_info, K *key_data, int64_t key_slices, int64_t key_slice_size, \
int64_t key_slice_stride, const TensorLayoutHelper &value_info, V *value_data, int64_t value_slice_stride, \
bool descending, cudaStream_t cuda_stream)
#define SortFixedSizeSpecKV(K, V) \
SortFixedSizeSpec(kFixedSizeSortKeyDimsLast, kFixedSizeLevel1, kFixedSizeLevel1ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsLast, kFixedSizeLevel2, kFixedSizeLevel2ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsLast, kFixedSizeLevel3, kFixedSizeLevel3ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsLast, kFixedSizeLevel4, kFixedSizeLevel4ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsLast, kFixedSizeLevel5, kFixedSizeLevel5ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsLastSecond, kFixedSizeLevel1, kFixedSizeLevel1ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsLastSecond, kFixedSizeLevel2, kFixedSizeLevel2ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsLastSecond, kFixedSizeLevel3, kFixedSizeLevel3ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsLastSecond, kFixedSizeLevel4, kFixedSizeLevel4ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsLastSecond, kFixedSizeLevel5, kFixedSizeLevel5ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsSecond, kFixedSizeLevel1, kFixedSizeLevel1ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsSecond, kFixedSizeLevel2, kFixedSizeLevel2ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsSecond, kFixedSizeLevel3, kFixedSizeLevel3ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsSecond, kFixedSizeLevel4, kFixedSizeLevel4ItemPreThread, K, V); \
SortFixedSizeSpec(kFixedSizeSortKeyDimsSecond, kFixedSizeLevel5, kFixedSizeLevel5ItemPreThread, K, V);
SortFixedSizeSpecKV(bool, int64_t);
SortFixedSizeSpecKV(int8_t, int64_t);
SortFixedSizeSpecKV(int16_t, int64_t);
SortFixedSizeSpecKV(int32_t, int64_t);
SortFixedSizeSpecKV(int64_t, int64_t);
SortFixedSizeSpecKV(uint8_t, int64_t);
SortFixedSizeSpecKV(half, int64_t);
SortFixedSizeSpecKV(float, int64_t);
SortFixedSizeSpecKV(double, int64_t);
SortFixedSizeSpecKV(bool, int32_t);
SortFixedSizeSpecKV(int8_t, int32_t);
SortFixedSizeSpecKV(int16_t, int32_t);
SortFixedSizeSpecKV(int32_t, int32_t);
SortFixedSizeSpecKV(int64_t, int32_t);
SortFixedSizeSpecKV(uint8_t, int32_t);
SortFixedSizeSpecKV(half, int32_t);
SortFixedSizeSpecKV(float, int32_t);
SortFixedSizeSpecKV(double, int32_t);


@@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SORT_FIXED_SIZE_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SORT_FIXED_SIZE_CUH_
#include <cuda_runtime.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/tensor_layout_helper.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
constexpr int kFixedSizeLevel1 = 4096;
constexpr int kFixedSizeLevel2 = 2048;
constexpr int kFixedSizeLevel3 = 1024;
constexpr int kFixedSizeLevel4 = 128;
constexpr int kFixedSizeLevel5 = 32;
constexpr int kFixedSizeLevel1ItemPreThread = 32;
constexpr int kFixedSizeLevel2ItemPreThread = 32;
constexpr int kFixedSizeLevel3ItemPreThread = 32;
constexpr int kFixedSizeLevel4ItemPreThread = 4;
constexpr int kFixedSizeLevel5ItemPreThread = 2;
constexpr int kFixedSizeSortKeyDimsLast = -1;
constexpr int kFixedSizeSortKeyDimsSecond = 2;
constexpr int kFixedSizeSortKeyDimsLastSecond = -2;
template <int A, int sort_size, int items_per_thread, typename K, typename V>
CUDA_LIB_EXPORT bool SortFixedSize(const TensorLayoutHelper &key_info, K *key_data, int64_t key_slices,
int64_t key_slice_size, int64_t key_slice_stride,
const TensorLayoutHelper &value_info, V *value_data, int64_t value_slice_stride,
bool descending, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SORT_FIXED_SIZE_CUH_


@@ -0,0 +1,167 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDE_POINTER_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDE_POINTER_CUH_
#include <limits.h>
#include <cuda_runtime.h>
#include <algorithm>
// ConstStridedPointer is a const random access iterator defined over a strided array.
template <typename T, typename index_t = int64_t>
class ConstStridedPointer {
public:
__device__ ConstStridedPointer() : ptr{nullptr}, stride{static_cast<index_t>(1)} {}
__device__ explicit ConstStridedPointer(T *ptr) : ptr{ptr}, stride{static_cast<index_t>(1)} {}
__device__ ConstStridedPointer(T *ptr, index_t stride) : ptr{ptr}, stride{stride} {}
// Pointer-like operations
__device__ const T &operator[](index_t idx) const { return ptr[idx * stride]; }
__device__ const T &operator*() const { return *ptr; }
__device__ const T *operator->() const { return reinterpret_cast<const T *>(ptr); }
// Prefix/postfix increment/decrement
__device__ ConstStridedPointer operator++(int) {
ConstStridedPointer copy(*this);
++*this;
return copy;
}
__device__ ConstStridedPointer &operator++() {
ptr += stride;
return *this;
}
__device__ ConstStridedPointer operator--(int) {
ConstStridedPointer copy(*this);
--*this;
return copy;
}
__device__ ConstStridedPointer &operator--() {
ptr -= stride;
return *this;
}
// Arithmetic operations
__device__ friend ConstStridedPointer operator+(index_t offset,
const ConstStridedPointer &accessor) {
return accessor + offset;
}
__device__ ConstStridedPointer operator+(index_t offset) const {
return ConstStridedPointer(ptr + offset * stride, stride);
}
__device__ ConstStridedPointer &operator+=(index_t offset) {
ptr += offset * stride;
return *this;
}
__device__ ConstStridedPointer operator-(index_t offset) const {
return ConstStridedPointer(ptr - offset * stride, stride);
}
__device__ ConstStridedPointer &operator-=(index_t offset) {
ptr -= offset * stride;
return *this;
}
__device__ index_t operator-(const ConstStridedPointer &other) const {
return (ptr - other.ptr) / stride;
}
// Comparison operators
__device__ bool operator>=(const ConstStridedPointer &other) const { return !(*this < other); }
__device__ bool operator>(const ConstStridedPointer &other) const { return !(*this <= other); }
__device__ bool operator<=(const ConstStridedPointer &other) const { return (*this < other) || (*this == other); }
__device__ bool operator<(const ConstStridedPointer &other) const { return ptr < other.ptr; }
__device__ bool operator!=(const ConstStridedPointer &other) const { return !(*this == other); }
__device__ bool operator==(const ConstStridedPointer &other) const {
return (ptr == other.ptr) && (stride == other.stride);
}
protected:
index_t stride;
T *ptr;
};
// StridedPointer is a random access iterator defined over a strided array.
template <typename T, typename index_t = int64_t>
class StridedPointer : public ConstStridedPointer<T, index_t> {
public:
__device__ explicit StridedPointer(T *ptr) : ConstStridedPointer<T, index_t>(ptr) {}
__device__ StridedPointer(T *ptr, index_t stride) : ConstStridedPointer<T, index_t>(ptr, stride) {}
__device__ StridedPointer() : ConstStridedPointer<T, index_t>() {}
// Pointer-like operations
__device__ T &operator[](index_t idx) const { return this->ptr[idx * this->stride]; }
__device__ T *operator->() const { return reinterpret_cast<T *>(this->ptr); }
__device__ T &operator*() const { return *this->ptr; }
// Prefix/postfix increment/decrement
__device__ StridedPointer operator++(int) {
StridedPointer copy(*this);
++*this;
return copy;
}
__device__ StridedPointer &operator++() {
this->ptr += this->stride;
return *this;
}
__device__ StridedPointer operator--(int) {
StridedPointer copy(*this);
--*this;
return copy;
}
__device__ StridedPointer &operator--() {
this->ptr -= this->stride;
return *this;
}
// Arithmetic operations
__device__ StridedPointer &operator-=(index_t offset) {
this->ptr -= offset * this->stride;
return *this;
}
__device__ StridedPointer operator-(index_t offset) const {
return StridedPointer(this->ptr - offset * this->stride, this->stride);
}
__device__ friend StridedPointer operator+(index_t offset, const StridedPointer &accessor) {
return accessor + offset;
}
__device__ StridedPointer operator+(index_t offset) const {
return StridedPointer(this->ptr + offset * this->stride, this->stride);
}
__device__ StridedPointer &operator+=(index_t offset) {
this->ptr += offset * this->stride;
return *this;
}
__device__ index_t operator-(const ConstStridedPointer<T, index_t> &other) const {
return (static_cast<const ConstStridedPointer<T, index_t> &>(*this) - other);
}
};
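// Usage note: StridedPointer<float, int64_t>(slice, stride) presents slice[0],
// slice[stride], slice[2 * stride], ... as a contiguous random access range, which is
// how RadixSortKVInPlace hands a non-unit-stride sort slice to cub's BlockLoad and
// BlockStore without materializing a packed copy.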
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDE_POINTER_CUH_


@@ -0,0 +1,128 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TENSOR_LAYOUT_HELPER_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TENSOR_LAYOUT_HELPER_CUH_
#include <cuda_runtime.h>
#include <climits>
#include <string>
#include <utility>
#include "ir/dtype/type_id.h"
#define MAX_TENSORINFO_DIMS 8
// CUDA kernel argument that defines tensor layout
struct TensorLayoutHelper {
TensorLayoutHelper(const int shape[MAX_TENSORINFO_DIMS], int dim_size) {
dim_size_ = dim_size;
if (dim_size_ >= MAX_TENSORINFO_DIMS) {
printf("dim_size_ >= MAX_TENSORINFO_DIMS.\n");
exit(1);
}
for (int i = 0; i < dim_size_; ++i) {
sizes_[i] = shape[i];
}
int64_t stride = 1;
for (int i = dim_size_ - 1; i >= 0; i--) {
strides_[i] = stride;
stride *= shape[i];
}
shape_size_ = stride;
}
static std::pair<int64_t, int64_t> CollapseDimsInner(int *sizes, int64_t *strides, int dim_size, int exclude_dim) {
int64_t stop_dim = (exclude_dim == -1) ? dim_size : exclude_dim;
int64_t new_index = -1;
int64_t old_index = 0;
int64_t remapped_excluded_dim = -1;
while (old_index < dim_size) {
for (; old_index < stop_dim; ++old_index) {
if (sizes[old_index] == 1) {
continue;
}
++new_index;
sizes[new_index] = sizes[old_index];
strides[new_index] = strides[old_index];
++old_index;
break;
}
for (; old_index < stop_dim; ++old_index) {
if (sizes[old_index] == 1) {
continue;
}
if (strides[new_index] == sizes[old_index] * strides[old_index]) {
sizes[new_index] *= sizes[old_index];
strides[new_index] = strides[old_index];
} else {
++new_index;
sizes[new_index] = sizes[old_index];
strides[new_index] = strides[old_index];
}
}
if (old_index != dim_size) {
++new_index;
sizes[new_index] = sizes[old_index];
strides[new_index] = strides[old_index];
remapped_excluded_dim = new_index;
++old_index;
stop_dim = dim_size;
}
}
if (new_index == -1 || (new_index == 0 && sizes[0] == 1)) {
dim_size = 1;
sizes[0] = 1;
strides[0] = 1;
return std::pair<int64_t, int64_t>(0, 1);
}
dim_size = new_index + 1;
return std::pair<int64_t, int64_t>(remapped_excluded_dim, dim_size);
}
inline int CollapseDims(int exclude_dim = -1) {
if (exclude_dim < 0) {
exclude_dim += dim_size_;
}
if (exclude_dim >= dim_size_ || exclude_dim < 0) {
printf("dim out of range of dim_size_.\n");
exit(1);
}
auto result = CollapseDimsInner(sizes_, strides_, dim_size_, exclude_dim);
dim_size_ = result.second;
return result.first;
}
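// Example: a contiguous (2, 3, 4, 5) tensor sorted along axis 1 first gets sizes_[1]
// set to 1 (in SortKeyValueInplace) and then CollapseDims(1): the trailing (4, 5) dims
// merge, leaving sizes_ = {2, 1, 20}, strides_ = {60, 20, 1}, dim_size_ = 3, and a
// return value of 1, the remapped position of the sort dim.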
// Contiguous tensors of more than one dimension are collapsed down to a single dimension
inline bool IsContiguous() const {
return (dim_size_ == 1 && strides_[0] == 1);
}
int sizes_[MAX_TENSORINFO_DIMS];
int64_t strides_[MAX_TENSORINFO_DIMS];
int dim_size_{0};
int64_t shape_size_{0};
};
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TENSOR_LAYOUT_HELPER_CUH_


@@ -52,6 +52,11 @@ abstract::TupleShapePtr SortInferShape(const PrimitivePtr &primitive, const std:
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
if (IsDynamicRank(x_shape)) {
auto unknown_shape_ptr = std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{unknown_shape_ptr, unknown_shape_ptr});
}
auto x_rank = SizeToLong(x_shape.size());
auto axis = GetValue<int64_t>(primitive->GetAttr("axis"));
CheckAndConvertUtils::CheckInRange<int64_t>("axis", axis, kIncludeBoth, {-x_rank, x_rank - 1}, prim_name);
@@ -66,7 +71,7 @@ abstract::TupleShapePtr SortInferShape(const PrimitivePtr &primitive, const std:
TuplePtr SortInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto infer_type = input_args[0]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kUInt8, kInt8, kInt16, kInt32, kInt64};
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kUInt8, kInt8, kInt16, kInt32, kInt64, kBool};
auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputx", infer_type, valid_types, primitive->name());
std::vector<TypePtr> type_tuple;
type_tuple.push_back(type);