!39749 optimize unique_consecutive

Merge pull request !39749 from huangbingjian/opt_unique_consecutive
This commit is contained in:
i-robot 2022-08-08 08:43:40 +00:00 committed by Gitee
commit a028a42c88
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 190 additions and 187 deletions

View File

@ -68,21 +68,21 @@ class UniqueConsecutiveHelperGpuKernel : public UniqueConsecutiveHelperBase {
size_t input_tensor_size = input_size_list_[0];
size_t elements_size = num_elements_ * sizeof(S);
size_t elements_plus_one_size = (num_elements_ + 1) * sizeof(S);
// input_index workspace
work_size_list_.emplace_back(elements_size);
// sorted_index workspace
work_size_list_.emplace_back(elements_size);
// range_data workspace
work_size_list_.emplace_back(elements_plus_one_size);
// indices_data workspace
work_size_list_.emplace_back(input_tensor_size);
// Transpose scalar workspace
work_size_list_.emplace_back(input_shape_.size() * sizeof(size_t));
work_size_list_.emplace_back(input_shape_.size() * sizeof(size_t));
output_size_list_.emplace_back(input_tensor_size);
if (return_idx()) {
output_size_list_.emplace_back(elements_size);
} else {
output_size_list_.emplace_back(0);
}
if (return_counts()) {
output_size_list_.emplace_back(elements_size);
} else {
output_size_list_.emplace_back(0);
}
output_size_list_.emplace_back(elements_size);
output_size_list_.emplace_back(elements_size);
return 0;
}
@ -93,49 +93,60 @@ class UniqueConsecutiveHelperGpuKernel : public UniqueConsecutiveHelperBase {
S *s_sorted_index = nullptr;
S *s_range_data = nullptr;
T *t_indices_data = nullptr;
size_t *dev_input_shape = nullptr;
size_t *dev_input_axis = nullptr;
T *t_output_ptr = nullptr;
S *s_output_index = nullptr;
S *s_output_counts = nullptr;
int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &t_input_ptr);
int flag = GetDeviceAddress<T>(input_ptrs, kIndex0, kernel_name_, &t_input_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<S>(work_ptrs, 0, kernel_name_, &s_input_index);
flag = GetDeviceAddress<S>(work_ptrs, kIndex0, kernel_name_, &s_input_index);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<S>(work_ptrs, 1, kernel_name_, &s_sorted_index);
flag = GetDeviceAddress<S>(work_ptrs, kIndex1, kernel_name_, &s_sorted_index);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<S>(work_ptrs, 2, kernel_name_, &s_range_data);
flag = GetDeviceAddress<S>(work_ptrs, kIndex2, kernel_name_, &s_range_data);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(work_ptrs, 3, kernel_name_, &t_indices_data);
flag = GetDeviceAddress<T>(work_ptrs, kIndex3, kernel_name_, &t_indices_data);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &t_output_ptr);
flag = GetDeviceAddress<size_t>(work_ptrs, kIndex4, kernel_name_, &dev_input_shape);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<size_t>(work_ptrs, kIndex5, kernel_name_, &dev_input_axis);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(output_ptrs, kIndex0, kernel_name_, &t_output_ptr);
if (flag != 0) {
return flag;
}
if (return_idx()) {
flag = GetDeviceAddress<S>(output_ptrs, 1, kernel_name_, &s_output_index);
flag = GetDeviceAddress<S>(output_ptrs, kIndex1, kernel_name_, &s_output_index);
if (flag != 0) {
return flag;
}
}
if (return_counts()) {
flag = GetDeviceAddress<S>(output_ptrs, 2, kernel_name_, &s_output_counts);
flag = GetDeviceAddress<S>(output_ptrs, kIndex2, kernel_name_, &s_output_counts);
if (flag != 0) {
return flag;
}
}
post_output_size_ = CalUniqueConsecutive(
t_input_ptr, num_elements_, input_shape_, is_flattend(), axis(), s_input_index, s_sorted_index, s_range_data,
t_indices_data, t_output_ptr, s_output_index, s_output_counts, reinterpret_cast<cudaStream_t>(cuda_stream));
post_output_size_ =
CalUniqueConsecutive(t_input_ptr, num_elements_, input_shape_, is_flattend(), axis(), s_input_index,
s_sorted_index, s_range_data, t_indices_data, dev_input_shape, dev_input_axis, t_output_ptr,
s_output_index, s_output_counts, reinterpret_cast<cudaStream_t>(cuda_stream));
return 0;
}

View File

@ -29,6 +29,7 @@
#include <vector>
#include "unique_consecutive_impl.cuh"
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh"
template <typename T>
struct BinaryEqual {
@ -68,72 +69,42 @@ struct BinaryNotEqual {
}
};
std::vector<int> RemoveElementsIndex(const std::vector<int> &indices, int size,
const std::vector<std::vector<size_t>> &pos_map, int64_t axis) {
int num_elements = pos_map.size();
std::vector<int> output(num_elements, 1);
for (int pos = 0; pos < num_elements; ++pos) {
// Check the axis of the element.
int element_index = pos_map[pos][axis];
for (int i = 0; i < size; ++i) {
if (element_index == indices[i]) {
output[pos] = 0;
template <typename S>
struct IndexToAxis {
int num_elements;
int64_t axis;
const size_t *input_shape;
const S *range_data;
int range_size;
// Stores the values that operator() reads. Both pointers are kept as-is (no copy is made),
// so the arrays they reference must stay valid for as long as the functor is used.
// NOTE(review): operator() is __device__ and dereferences both pointers, so they presumably
// have to be device-accessible — confirm at the call site.
IndexToAxis(int _num_elements, int64_t _axis, const size_t *_input_shape, const S *_range_data, int _range_size)
: num_elements(_num_elements),
axis(_axis),
input_shape(_input_shape),
range_data(_range_data),
range_size(_range_size) {}
// Maps a flat element offset `pos` to its coordinate along `axis` and reports whether that
// coordinate is listed in range_data[0 .. range_size).
// Returns 0 when the coordinate is listed and 1 otherwise.
// Preconditions: every input_shape[d] for d in [0, axis] is non-zero (they are used as divisors),
// and pos is a non-negative offset smaller than num_elements.
__device__ S operator()(S pos) const {
  // Keep the running offset and stride unsigned so the arithmetic is not narrowed back to S.
  size_t remaining = static_cast<size_t>(pos);
  size_t stride = static_cast<size_t>(num_elements) / input_shape[0];
  size_t coord = remaining / stride;
  // Peel off one leading dimension per iteration until `coord` is the index along `axis`.
  // The counter is signed like `axis`, so a negative axis skips the loop instead of wrapping
  // to a huge unsigned bound and indexing past the end of input_shape.
  for (int64_t dim = 1; dim <= axis; ++dim) {
    remaining -= coord * stride;
    stride /= input_shape[dim];
    coord = remaining / stride;
  }
  // Linear membership test; the counter is signed like range_size for the same reason as above.
  for (int k = 0; k < range_size; ++k) {
    if (coord == static_cast<size_t>(range_data[k])) {
      return 0;
    }
  }
  return 1;
}
return output;
}
std::vector<std::vector<size_t>> GetPositionArray(int num_elements, const std::vector<int64_t> &input_shape) {
std::vector<std::vector<size_t>> pos_map(num_elements);
size_t shape_size = input_shape.size();
size_t temp_pos;
size_t pos_size;
for (int pos = 0; pos < num_elements; ++pos) {
std::vector<size_t> array(shape_size, 0);
temp_pos = pos;
pos_size = num_elements / input_shape[0];
array[0] = temp_pos / pos_size;
for (size_t i = 1; i < shape_size; ++i) {
temp_pos -= array[i - 1] * pos_size;
pos_size = pos_size / input_shape[i];
array[i] = temp_pos / pos_size;
}
pos_map[pos] = array;
}
return pos_map;
}
std::vector<int64_t> GetTransposeIndices(int num_elements, const std::vector<int64_t> &input_shape,
const std::vector<std::vector<size_t>> &pos_map, int64_t axis) {
// Get transpose axis.
size_t shape_size = input_shape.size();
int64_t cnt = 0;
std::vector<int64_t> transpose_axis(shape_size, 0);
std::generate(transpose_axis.begin(), transpose_axis.end(), [&] { return cnt++; });
transpose_axis[0] = axis;
transpose_axis[axis] = 0;
// Do Transpose
std::vector<int64_t> output_indices(num_elements, 0);
for (int pos = 0; pos < num_elements; pos++) {
auto pos_array = pos_map[pos];
size_t new_pos = pos_array[transpose_axis[shape_size - 1]];
size_t new_pos_size = 1;
for (int64_t i = shape_size - 2; i >= 0; i--) {
new_pos_size *= input_shape[transpose_axis[i + 1]];
new_pos += pos_array[transpose_axis[i]] * new_pos_size;
}
output_indices[new_pos] = pos;
}
return output_indices;
}
};
template <typename T, typename S>
std::vector<std::vector<int>> ComputeUniqueConsecutiveFlattend(const T *input, int num_elements,
const std::vector<int64_t> &input_shape, S *input_index,
S *sorted_index, S *range_data, T *output, S *index,
S *counts, cudaStream_t cuda_stream) {
std::vector<std::vector<int>> ComputeUniqueConsecutive(const T *input, int num_elements,
const std::vector<int64_t> &input_shape, S *range_data,
T *output, S *index, S *counts, cudaStream_t cuda_stream) {
auto policy = thrust::cuda::par.on(cuda_stream);
std::vector<std::vector<int>> out_shapes;
// Copy input to output.
@ -141,80 +112,88 @@ std::vector<std::vector<int>> ComputeUniqueConsecutiveFlattend(const T *input, i
thrust::device_pointer_cast(output));
// Inverse indices.
thrust::adjacent_difference(policy, thrust::device_pointer_cast(output),
thrust::device_pointer_cast(output) + num_elements,
thrust::device_pointer_cast(input_index), thrust::not_equal_to<T>());
thrust::fill(policy, thrust::device_pointer_cast(input_index), thrust::device_pointer_cast(input_index) + 1, 0);
thrust::inclusive_scan(policy, thrust::device_pointer_cast(input_index),
thrust::device_pointer_cast(input_index) + num_elements,
thrust::device_pointer_cast(input_index));
if (index != nullptr) {
thrust::sequence(policy, thrust::device_pointer_cast(sorted_index),
thrust::device_pointer_cast(sorted_index) + num_elements);
thrust::scatter(policy, thrust::device_pointer_cast(input_index),
thrust::device_pointer_cast(input_index) + num_elements, thrust::device_pointer_cast(sorted_index),
thrust::device_pointer_cast(index));
std::vector<int> idx_shape(input_shape.begin(), input_shape.end());
out_shapes.emplace_back(idx_shape);
} else {
if (index == nullptr || num_elements == 0) {
std::vector<int> idx_shape = {0};
out_shapes.emplace_back(idx_shape);
} else {
thrust::adjacent_difference(policy, thrust::device_pointer_cast(output),
thrust::device_pointer_cast(output) + num_elements, thrust::device_pointer_cast(index),
thrust::not_equal_to<T>());
thrust::fill(policy, thrust::device_pointer_cast(index), thrust::device_pointer_cast(index) + 1, 0);
thrust::inclusive_scan(policy, thrust::device_pointer_cast(index),
thrust::device_pointer_cast(index) + num_elements, thrust::device_pointer_cast(index));
std::vector<int> idx_shape(input_shape.begin(), input_shape.end());
out_shapes.emplace_back(idx_shape);
}
// Unique.
thrust::sequence(policy, thrust::device_pointer_cast(range_data),
thrust::device_pointer_cast(range_data) + num_elements + 1);
int output_size = thrust::unique_by_key(policy, thrust::device_pointer_cast(output),
thrust::device_pointer_cast(output) + num_elements,
thrust::device_pointer_cast(range_data), thrust::equal_to<T>())
.first -
thrust::device_pointer_cast(output);
std::vector<int> output_shape = {output_size};
out_shapes.insert(out_shapes.begin(), output_shape);
// Count.
if (counts != nullptr) {
// Unique and count.
int output_size = num_elements;
if (counts == nullptr || num_elements == 0) {
output_size = thrust::unique(policy, thrust::device_pointer_cast(output),
thrust::device_pointer_cast(output) + num_elements, thrust::equal_to<T>()) -
thrust::device_pointer_cast(output);
std::vector<int> counts_shape = {output_size};
out_shapes.emplace_back(counts_shape);
} else {
thrust::sequence(policy, thrust::device_pointer_cast(range_data),
thrust::device_pointer_cast(range_data) + num_elements + 1);
output_size = thrust::unique_by_key(policy, thrust::device_pointer_cast(output),
thrust::device_pointer_cast(output) + num_elements,
thrust::device_pointer_cast(range_data), thrust::equal_to<T>())
.first -
thrust::device_pointer_cast(output);
thrust::fill(policy, thrust::device_pointer_cast(range_data) + output_size,
thrust::device_pointer_cast(range_data) + output_size + 1, num_elements);
thrust::sequence(policy, thrust::device_pointer_cast(counts), thrust::device_pointer_cast(counts) + num_elements);
thrust::adjacent_difference(policy, thrust::device_pointer_cast(range_data) + 1,
thrust::device_pointer_cast(range_data) + output_size + 1, counts);
std::vector<int> counts_shape = {output_size};
out_shapes.emplace_back(counts_shape);
} else {
std::vector<int> counts_shape = {0};
out_shapes.emplace_back(counts_shape);
}
std::vector<int> output_shape = {output_size};
out_shapes.insert(out_shapes.begin(), output_shape);
return out_shapes;
}
template <typename T, typename S>
std::vector<std::vector<int>> ComputeUniqueConsecutiveByAxis(const T *input, int num_elements,
const std::vector<int64_t> &input_shape, bool is_flattend,
int64_t axis, S *input_index, S *sorted_index,
S *range_data, T *indices_data, T *output, S *index,
S *counts, cudaStream_t cuda_stream) {
const std::vector<int64_t> &input_shape, int64_t axis,
S *input_index, S *sorted_index, S *range_data,
T *indices_data, size_t *dev_input_shape,
size_t *dev_input_axis, T *output, S *index, S *counts,
cudaStream_t cuda_stream) {
// Compute UniqueConsecutive by axis.
auto policy = thrust::cuda::par.on(cuda_stream);
std::vector<std::vector<int>> out_shapes;
// Do transpose.
int64_t num_inp = input_shape[axis];
int64_t n = num_elements / num_inp;
thrust::copy(thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + num_elements,
thrust::device_pointer_cast(output));
std::vector<std::vector<size_t>> pos_map = GetPositionArray(num_elements, input_shape);
std::vector<int64_t> transpose_indices = GetTransposeIndices(num_elements, input_shape, pos_map, axis);
thrust::device_vector<int> indices_map(transpose_indices.begin(), transpose_indices.end());
thrust::gather(policy, indices_map.begin(), indices_map.end(), thrust::device_pointer_cast(input),
thrust::device_pointer_cast(indices_data));
// Do transpose.
size_t shape_size = input_shape.size();
cudaMemcpyAsync(dev_input_shape, input_shape.data(), sizeof(size_t) * shape_size, cudaMemcpyHostToDevice,
cuda_stream);
// Used for transpose: dev_input_axis={0, 1, ..., axis, ...} -> dev_input_axis[0]=axis, dev_input_axis[axis]=0
thrust::sequence(policy, thrust::device_pointer_cast(dev_input_axis),
thrust::device_pointer_cast(dev_input_axis) + shape_size);
thrust::fill(policy, thrust::device_pointer_cast(dev_input_axis), thrust::device_pointer_cast(dev_input_axis) + 1,
axis);
thrust::fill(policy, thrust::device_pointer_cast(dev_input_axis) + axis,
thrust::device_pointer_cast(dev_input_axis) + axis + 1, 0);
CalTranspose(num_elements, input, dev_input_shape, dev_input_axis, shape_size, indices_data, cuda_stream);
// Inverse indices.
int64_t num_inp = input_shape[axis];
int64_t n = num_elements / num_inp;
thrust::sequence(policy, thrust::device_pointer_cast(range_data), thrust::device_pointer_cast(range_data) + num_inp);
thrust::adjacent_difference(policy, thrust::device_pointer_cast(range_data),
thrust::device_pointer_cast(range_data) + num_inp,
thrust::device_pointer_cast(input_index), BinaryNotEqual<T>(n, indices_data));
thrust::fill(policy, thrust::device_pointer_cast(input_index), thrust::device_pointer_cast(input_index) + 1, 0);
thrust::inclusive_scan(policy, thrust::device_pointer_cast(input_index),
thrust::device_pointer_cast(input_index) + num_inp, thrust::device_pointer_cast(input_index));
if (index != nullptr) {
if (index == nullptr || num_elements == 0) {
std::vector<int> idx_shape = {0};
out_shapes.emplace_back(idx_shape);
} else {
thrust::adjacent_difference(policy, thrust::device_pointer_cast(range_data),
thrust::device_pointer_cast(range_data) + num_inp,
thrust::device_pointer_cast(input_index), BinaryNotEqual<T>(n, indices_data));
thrust::fill(policy, thrust::device_pointer_cast(input_index), thrust::device_pointer_cast(input_index) + 1, 0);
thrust::inclusive_scan(policy, thrust::device_pointer_cast(input_index),
thrust::device_pointer_cast(input_index) + num_inp,
thrust::device_pointer_cast(input_index));
thrust::sequence(policy, thrust::device_pointer_cast(sorted_index),
thrust::device_pointer_cast(sorted_index) + num_inp);
thrust::scatter(policy, thrust::device_pointer_cast(input_index),
@ -223,33 +202,24 @@ std::vector<std::vector<int>> ComputeUniqueConsecutiveByAxis(const T *input, int
std::vector<int> idx_shape;
idx_shape.push_back(num_inp);
out_shapes.emplace_back(idx_shape);
} else {
std::vector<int> idx_shape = {0};
out_shapes.emplace_back(idx_shape);
}
// Unique.
thrust::sequence(policy, thrust::device_pointer_cast(sorted_index),
thrust::device_pointer_cast(sorted_index) + num_inp + 1);
int indices_size = thrust::unique_by_key(policy, thrust::device_pointer_cast(range_data),
thrust::device_pointer_cast(range_data) + num_inp,
thrust::device_pointer_cast(sorted_index), BinaryEqual<T>(n, indices_data))
.first -
thrust::device_pointer_cast(range_data);
std::vector<int> indices(indices_size);
cudaMemcpyAsync(indices.data(), range_data, indices_size * sizeof(int), cudaMemcpyDeviceToHost, cuda_stream);
cudaStreamSynchronize(cuda_stream);
std::vector<int> elements_index = RemoveElementsIndex(indices, indices_size, pos_map, axis);
thrust::device_vector<int> remove_map(elements_index.begin(), elements_index.end());
thrust::remove_if(thrust::device_pointer_cast(output), thrust::device_pointer_cast(output) + num_elements,
remove_map.begin(), thrust::identity<T>());
std::vector<int> output_shape(input_shape.begin(), input_shape.end());
output_shape[axis] = indices_size;
out_shapes.insert(out_shapes.begin(), output_shape);
// Count.
if (counts != nullptr) {
// Unique and count.
int indices_size = num_inp;
if (counts == nullptr || num_elements == 0) {
indices_size = thrust::unique(policy, thrust::device_pointer_cast(range_data),
thrust::device_pointer_cast(range_data) + num_inp, BinaryEqual<T>(n, indices_data)) -
thrust::device_pointer_cast(range_data);
std::vector<int> counts_shape = {0};
out_shapes.emplace_back(counts_shape);
} else {
thrust::sequence(policy, thrust::device_pointer_cast(sorted_index),
thrust::device_pointer_cast(sorted_index) + num_inp + 1);
indices_size = thrust::unique_by_key(policy, thrust::device_pointer_cast(range_data),
thrust::device_pointer_cast(range_data) + num_inp,
thrust::device_pointer_cast(sorted_index), BinaryEqual<T>(n, indices_data))
.first -
thrust::device_pointer_cast(range_data);
thrust::fill(policy, thrust::device_pointer_cast(sorted_index) + indices_size,
thrust::device_pointer_cast(sorted_index) + indices_size + 1, num_inp);
thrust::sequence(policy, thrust::device_pointer_cast(counts), thrust::device_pointer_cast(counts) + num_inp);
@ -257,40 +227,51 @@ std::vector<std::vector<int>> ComputeUniqueConsecutiveByAxis(const T *input, int
thrust::device_pointer_cast(sorted_index) + num_inp + 1, counts);
std::vector<int> counts_shape = {indices_size};
out_shapes.emplace_back(counts_shape);
} else {
std::vector<int> counts_shape = {0};
out_shapes.emplace_back(counts_shape);
}
// Remove invalid dimensions according to indices, reshape the output.
std::vector<int> output_shape(input_shape.begin(), input_shape.end());
if (indices_size != num_inp) {
thrust::sequence(policy, thrust::device_pointer_cast(input_index),
thrust::device_pointer_cast(input_index) + num_elements);
thrust::transform(thrust::device_pointer_cast(input_index), thrust::device_pointer_cast(input_index) + num_elements,
thrust::device_pointer_cast(input_index),
IndexToAxis<S>(num_elements, axis, dev_input_shape, range_data, indices_size));
thrust::remove_if(thrust::device_pointer_cast(output), thrust::device_pointer_cast(output) + num_elements,
input_index, thrust::identity<T>());
output_shape[axis] = indices_size;
}
out_shapes.insert(out_shapes.begin(), output_shape);
return out_shapes;
}
template <typename T, typename S>
std::vector<std::vector<int>> CalUniqueConsecutive(const T *input, int num_elements,
const std::vector<int64_t> &input_shape, bool is_flattend,
const std::vector<int64_t> &input_shape, bool is_axis_none,
int64_t axis, S *input_index, S *sorted_index, S *range_data,
T *indices_data, T *output, S *index, S *counts,
cudaStream_t cuda_stream) {
if (is_flattend) {
return ComputeUniqueConsecutiveFlattend(input, num_elements, input_shape, input_index, sorted_index, range_data,
output, index, counts, cuda_stream);
T *indices_data, size_t *dev_input_shape, size_t *dev_input_axis,
T *output, S *index, S *counts, cudaStream_t cuda_stream) {
if (is_axis_none) {
return ComputeUniqueConsecutive(input, num_elements, input_shape, range_data, output, index, counts, cuda_stream);
}
return ComputeUniqueConsecutiveByAxis(input, num_elements, input_shape, is_flattend, axis, input_index, sorted_index,
range_data, indices_data, output, index, counts, cuda_stream);
return ComputeUniqueConsecutiveByAxis(input, num_elements, input_shape, axis, input_index, sorted_index, range_data,
indices_data, dev_input_shape, dev_input_axis, output, index, counts,
cuda_stream);
}
template CUDA_LIB_EXPORT std::vector<std::vector<int>> CalUniqueConsecutive<float, int>(
const float *input, int num_elements, const std::vector<int64_t> &input_shape, bool is_flattend, int64_t axis,
int *input_index, int *sorted_index, int *range_data, float *indices_data, float *output, int *index, int *counts,
cudaStream_t cuda_stream);
const float *input, int num_elements, const std::vector<int64_t> &input_shape, bool is_axis_none, int64_t axis,
int *input_index, int *sorted_index, int *range_data, float *indices_data, size_t *dev_input_shape,
size_t *dev_input_axis, float *output, int *index, int *counts, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT std::vector<std::vector<int>> CalUniqueConsecutive<half, int>(
const half *input, int num_elements, const std::vector<int64_t> &input_shape, bool is_flattend, int64_t axis,
int *input_index, int *sorted_index, int *range_data, half *indices_data, half *output, int *index, int *counts,
cudaStream_t cuda_stream);
const half *input, int num_elements, const std::vector<int64_t> &input_shape, bool is_axis_none, int64_t axis,
int *input_index, int *sorted_index, int *range_data, half *indices_data, size_t *dev_input_shape,
size_t *dev_input_axis, half *output, int *index, int *counts, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT std::vector<std::vector<int>> CalUniqueConsecutive<int, int>(
const int *input, int num_elements, const std::vector<int64_t> &input_shape, bool is_flattend, int64_t axis,
int *input_index, int *sorted_index, int *range_data, int *indices_data, int *output, int *index, int *counts,
cudaStream_t cuda_stream);
const int *input, int num_elements, const std::vector<int64_t> &input_shape, bool is_axis_none, int64_t axis,
int *input_index, int *sorted_index, int *range_data, int *indices_data, size_t *dev_input_shape,
size_t *dev_input_axis, int *output, int *index, int *counts, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT std::vector<std::vector<int>> CalUniqueConsecutive<int64_t, int64_t>(
const int64_t *input, int num_elements, const std::vector<int64_t> &input_shape, bool is_flattend, int64_t axis,
int64_t *input_index, int64_t *sorted_index, int64_t *range_data, int64_t *indices_data, int64_t *output,
int64_t *index, int64_t *counts, cudaStream_t cuda_stream);
const int64_t *input, int num_elements, const std::vector<int64_t> &input_shape, bool is_axis_none, int64_t axis,
int64_t *input_index, int64_t *sorted_index, int64_t *range_data, int64_t *indices_data, size_t *dev_input_shape,
size_t *dev_input_axis, int64_t *output, int64_t *index, int64_t *counts, cudaStream_t cuda_stream);

View File

@ -20,10 +20,8 @@
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T, typename S>
CUDA_LIB_EXPORT std::vector<std::vector<int>> CalUniqueConsecutive(const T *input, int num_elements,
const std::vector<int64_t> &input_shape,
bool is_flattend, int64_t axis, S *input_index,
S *sorted_index, S *range_data, T *indices_data,
T *output, S *index, S *count,
cudaStream_t cuda_stream);
CUDA_LIB_EXPORT std::vector<std::vector<int>> CalUniqueConsecutive(
const T *input, int num_elements, const std::vector<int64_t> &input_shape, bool is_axis_none, int64_t axis,
S *input_index, S *sorted_index, S *range_data, T *indices_data, size_t *dev_input_shape, size_t *dev_input_axis,
T *output, S *index, S *count, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UNIQUE_CONSECUTIVE_IMPL_CUH_

View File

@ -31,11 +31,24 @@ namespace {
constexpr int64_t kUniqueConsecutiveInputNum = 1;
// For aicpu, if axis is 1000, that represents None.
constexpr int64_t kAxisIsNone = 1000;
// Reports whether any dimension of `shape` equals zero.
// An empty shape has no dimensions to inspect and therefore yields false.
bool CheckNullInput(const std::vector<int64_t> &shape) {
  for (const int64_t dim : shape) {
    if (dim == 0) {
      return true;
    }
  }
  return false;
}
abstract::BaseShapePtr UniqueConsecutiveInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
auto input_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto input_shape_vec = input_shape_map[kShape];
if (CheckNullInput(input_shape_vec)) {
MS_LOG(EXCEPTION) << "For " << op_name << ", the shape of input cannot contain zero.";
}
auto axis_ptr = primitive->GetAttr(kAxis);
MS_EXCEPTION_IF_NULL(axis_ptr);