!47679 Fix bug in the sort kernel when compiling on Windows

Merge pull request !47679 from hangq/master
This commit is contained in:
i-robot 2023-01-10 02:37:21 +00:00 committed by Gitee
commit 169bbd1639
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 27 additions and 18 deletions

View File

@ -89,7 +89,7 @@ bool SegSort(const TensorLayoutHelper &key_info, K *key_data, int64_t key_slices
}
template <typename K>
CUDA_LIB_EXPORT bool InitIndexBySlice(const TensorLayoutHelper &t, int64_t axis, K *data, cudaStream_t cuda_stream) {
bool InitIndexBySlice(const TensorLayoutHelper &t, int64_t axis, K *data, cudaStream_t cuda_stream) {
if (t.shape_size_ <= 0) {
return true;
}
@ -134,22 +134,31 @@ CUDA_LIB_EXPORT bool InitIndexBySlice(const TensorLayoutHelper &t, int64_t axis,
out_size[i] = 1;
}
BroadcastTo<K>(in_size[0], in_size[1], in_size[2], in_size[3], in_size[4], in_size[5], in_size[6], in_size[7],
out_size[0], out_size[1], out_size[2], out_size[3], out_size[4], out_size[5], out_size[6], out_size[7],
slice_data_device, data, cuda_stream);
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
constexpr size_t kIndex4 = 4;
constexpr size_t kIndex5 = 5;
constexpr size_t kIndex6 = 6;
constexpr size_t kIndex7 = 7;
BroadcastTo<K>(in_size[kIndex0], in_size[kIndex1], in_size[kIndex2], in_size[kIndex3], in_size[kIndex4],
in_size[kIndex5], in_size[kIndex6], in_size[kIndex7], out_size[kIndex0], out_size[kIndex1],
out_size[kIndex2], out_size[kIndex3], out_size[kIndex4], out_size[kIndex5], out_size[kIndex6],
out_size[kIndex7], slice_data_device, data, cuda_stream);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaFree(slice_data_device), "Free slice data failed.");
return true;
}
template CUDA_LIB_EXPORT bool InitIndexBySlice<int64_t>(const TensorLayoutHelper &t, int64_t axis, int64_t *data,
cudaStream_t cuda_stream);
template bool InitIndexBySlice<int64_t>(const TensorLayoutHelper &t, int64_t axis, int64_t *data,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT bool InitIndexBySlice<int32_t>(const TensorLayoutHelper &t, int64_t axis, int32_t *data,
cudaStream_t cuda_stream);
template bool InitIndexBySlice<int32_t>(const TensorLayoutHelper &t, int64_t axis, int32_t *data,
cudaStream_t cuda_stream);
template <typename K, typename V>
CUDA_LIB_EXPORT bool SortKeyValueInplace(const TensorLayoutHelper &key, K *key_data, const TensorLayoutHelper &value,
V *value_data, int64_t axis, bool descending, cudaStream_t cuda_stream) {
bool SortKeyValueInplace(const TensorLayoutHelper &key, K *key_data, const TensorLayoutHelper &value, V *value_data,
int64_t axis, bool descending, cudaStream_t cuda_stream) {
if (key.dim_size_ != value.dim_size_) {
MS_LOG(ERROR) << "dim_size of key(" << key.dim_size_ << ") should be equal to dim_size of value(" << value.dim_size_
<< ").";
@ -194,8 +203,9 @@ CUDA_LIB_EXPORT bool SortKeyValueInplace(const TensorLayoutHelper &key, K *key_d
if (key_info.IsContiguous()) {
HANDLE_SORT_CASE(int64_t, kFixedSizeSortKeyDimsLastSecond);
} else {
constexpr int kDimSize = 2;
switch (key_info.dim_size_) {
case 2: // if sort dim == -1:
case kDimSize: // if sort dim == -1:
HANDLE_SORT_CASE(unsigned int, kFixedSizeSortKeyDimsSecond);
default: // if sort dim != -1:
HANDLE_SORT_CASE(unsigned int, kFixedSizeSortKeyDimsLast);
@ -204,10 +214,9 @@ CUDA_LIB_EXPORT bool SortKeyValueInplace(const TensorLayoutHelper &key, K *key_d
#undef HANDLE_SORT_CASE
}
#define SortKeyValueInplace(K, V) \
template CUDA_LIB_EXPORT bool SortKeyValueInplace<K, V>(const TensorLayoutHelper &key, K *key_data, \
const TensorLayoutHelper &value, V *value_data, \
int64_t axis, bool descending, cudaStream_t cuda_stream);
#define SortKeyValueInplace(K, V) \
template bool SortKeyValueInplace<K, V>(const TensorLayoutHelper &key, K *key_data, const TensorLayoutHelper &value, \
V *value_data, int64_t axis, bool descending, cudaStream_t cuda_stream);
SortKeyValueInplace(bool, int64_t);
SortKeyValueInplace(int8_t, int64_t);

View File

@ -22,10 +22,10 @@
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename K>
CUDA_LIB_EXPORT bool InitIndexBySlice(const TensorLayoutHelper &t, int64_t axis, K *data, cudaStream_t cuda_stream);
bool InitIndexBySlice(const TensorLayoutHelper &t, int64_t axis, K *data, cudaStream_t cuda_stream);
template <typename K, typename V>
CUDA_LIB_EXPORT bool SortKeyValueInplace(const TensorLayoutHelper &key, K *key_data, const TensorLayoutHelper &value,
V *value_data, int64_t axis, bool descending, cudaStream_t cuda_stream);
bool SortKeyValueInplace(const TensorLayoutHelper &key, K *key_data, const TensorLayoutHelper &value, V *value_data,
int64_t axis, bool descending, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SORT_KEY_VALUE_INPLACE_CUH_