diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/unique_consecutive_helper.h b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/unique_consecutive_helper.h index dd3704f14d2..40dbe738467 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/unique_consecutive_helper.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/unique_consecutive_helper.h @@ -68,21 +68,21 @@ class UniqueConsecutiveHelperGpuKernel : public UniqueConsecutiveHelperBase { size_t input_tensor_size = input_size_list_[0]; size_t elements_size = num_elements_ * sizeof(S); size_t elements_plus_one_size = (num_elements_ + 1) * sizeof(S); + // input_index workspace work_size_list_.emplace_back(elements_size); + // sorted_index workspace work_size_list_.emplace_back(elements_size); + // range_data workspace work_size_list_.emplace_back(elements_plus_one_size); + // indices_data workspace work_size_list_.emplace_back(input_tensor_size); + // Transpose scalar workspace + work_size_list_.emplace_back(input_shape_.size() * sizeof(size_t)); + work_size_list_.emplace_back(input_shape_.size() * sizeof(size_t)); + output_size_list_.emplace_back(input_tensor_size); - if (return_idx()) { - output_size_list_.emplace_back(elements_size); - } else { - output_size_list_.emplace_back(0); - } - if (return_counts()) { - output_size_list_.emplace_back(elements_size); - } else { - output_size_list_.emplace_back(0); - } + output_size_list_.emplace_back(elements_size); + output_size_list_.emplace_back(elements_size); return 0; } @@ -93,49 +93,60 @@ class UniqueConsecutiveHelperGpuKernel : public UniqueConsecutiveHelperBase { S *s_sorted_index = nullptr; S *s_range_data = nullptr; T *t_indices_data = nullptr; + size_t *dev_input_shape = nullptr; + size_t *dev_input_axis = nullptr; T *t_output_ptr = nullptr; S *s_output_index = nullptr; S *s_output_counts = nullptr; - int flag = GetDeviceAddress(input_ptrs, 0, kernel_name_, &t_input_ptr); + int flag = GetDeviceAddress(input_ptrs, kIndex0, kernel_name_, &t_input_ptr); if (flag != 0) { return flag; } - flag = GetDeviceAddress(work_ptrs, 0, kernel_name_, &s_input_index); + flag = GetDeviceAddress(work_ptrs, kIndex0, kernel_name_, &s_input_index); if (flag != 0) { return flag; } - flag = GetDeviceAddress(work_ptrs, 1, kernel_name_, &s_sorted_index); + flag = GetDeviceAddress(work_ptrs, kIndex1, kernel_name_, &s_sorted_index); if (flag != 0) { return flag; } - flag = GetDeviceAddress(work_ptrs, 2, kernel_name_, &s_range_data); + flag = GetDeviceAddress(work_ptrs, kIndex2, kernel_name_, &s_range_data); if (flag != 0) { return flag; } - flag = GetDeviceAddress(work_ptrs, 3, kernel_name_, &t_indices_data); + flag = GetDeviceAddress(work_ptrs, kIndex3, kernel_name_, &t_indices_data); if (flag != 0) { return flag; } - flag = GetDeviceAddress(output_ptrs, 0, kernel_name_, &t_output_ptr); + flag = GetDeviceAddress(work_ptrs, kIndex4, kernel_name_, &dev_input_shape); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(work_ptrs, kIndex5, kernel_name_, &dev_input_axis); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(output_ptrs, kIndex0, kernel_name_, &t_output_ptr); if (flag != 0) { return flag; } if (return_idx()) { - flag = GetDeviceAddress(output_ptrs, 1, kernel_name_, &s_output_index); + flag = GetDeviceAddress(output_ptrs, kIndex1, kernel_name_, &s_output_index); if (flag != 0) { return flag; } } if (return_counts()) { - flag = GetDeviceAddress(output_ptrs, 2, kernel_name_, &s_output_counts); + flag = GetDeviceAddress(output_ptrs, kIndex2, kernel_name_, &s_output_counts); if (flag != 0) { return flag; } } - post_output_size_ = CalUniqueConsecutive( - t_input_ptr, num_elements_, input_shape_, is_flattend(), axis(), s_input_index, s_sorted_index, s_range_data, - t_indices_data, t_output_ptr, s_output_index, s_output_counts, reinterpret_cast(cuda_stream)); + post_output_size_ = + CalUniqueConsecutive(t_input_ptr, num_elements_, input_shape_, is_flattend(), axis(), s_input_index, + s_sorted_index, s_range_data, t_indices_data, dev_input_shape, dev_input_axis, t_output_ptr, + s_output_index, s_output_counts, reinterpret_cast(cuda_stream)); return 0; } diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unique_consecutive_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unique_consecutive_impl.cu index 7cd72a257db..ff77da02dca 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unique_consecutive_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unique_consecutive_impl.cu @@ -29,6 +29,7 @@ #include #include "unique_consecutive_impl.cuh" #include "include/cuda_fp16.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh" template struct BinaryEqual { @@ -68,72 +69,42 @@ struct BinaryNotEqual { } }; -std::vector RemoveElementsIndex(const std::vector &indices, int size, - const std::vector> &pos_map, int64_t axis) { - int num_elements = pos_map.size(); - std::vector output(num_elements, 1); - for (int pos = 0; pos < num_elements; ++pos) { - // Check the axis of the element. - int element_index = pos_map[pos][axis]; - for (int i = 0; i < size; ++i) { - if (element_index == indices[i]) { - output[pos] = 0; +template +struct IndexToAxis { + int num_elements; + int64_t axis; + const size_t *input_shape; + const S *range_data; + int range_size; + + IndexToAxis(int _num_elements, int64_t _axis, const size_t *_input_shape, const S *_range_data, int _range_size) + : num_elements(_num_elements), + axis(_axis), + input_shape(_input_shape), + range_data(_range_data), + range_size(_range_size) {} + + __device__ S operator()(S pos) const { + size_t pos_size = num_elements / input_shape[0]; + size_t last_axis = pos / pos_size; + for (size_t i = 1; i <= axis; ++i) { + pos -= last_axis * pos_size; + pos_size /= input_shape[i]; + last_axis = pos / pos_size; + } + for (size_t k = 0; k < range_size; k++) { + if (last_axis == range_data[k]) { + return 0; } } + return 1; } - return output; -} - -std::vector> GetPositionArray(int num_elements, const std::vector &input_shape) { - std::vector> pos_map(num_elements); - size_t shape_size = input_shape.size(); - size_t temp_pos; - size_t pos_size; - - for (int pos = 0; pos < num_elements; ++pos) { - std::vector array(shape_size, 0); - temp_pos = pos; - pos_size = num_elements / input_shape[0]; - array[0] = temp_pos / pos_size; - for (size_t i = 1; i < shape_size; ++i) { - temp_pos -= array[i - 1] * pos_size; - pos_size = pos_size / input_shape[i]; - array[i] = temp_pos / pos_size; - } - pos_map[pos] = array; - } - return pos_map; -} - -std::vector GetTransposeIndices(int num_elements, const std::vector &input_shape, - const std::vector> &pos_map, int64_t axis) { - // Get transpose axis. - size_t shape_size = input_shape.size(); - int64_t cnt = 0; - std::vector transpose_axis(shape_size, 0); - std::generate(transpose_axis.begin(), transpose_axis.end(), [&] { return cnt++; }); - transpose_axis[0] = axis; - transpose_axis[axis] = 0; - // Do Transpose - std::vector output_indices(num_elements, 0); - for (int pos = 0; pos < num_elements; pos++) { - auto pos_array = pos_map[pos]; - size_t new_pos = pos_array[transpose_axis[shape_size - 1]]; - size_t new_pos_size = 1; - for (int64_t i = shape_size - 2; i >= 0; i--) { - new_pos_size *= input_shape[transpose_axis[i + 1]]; - new_pos += pos_array[transpose_axis[i]] * new_pos_size; - } - output_indices[new_pos] = pos; - } - return output_indices; -} +}; template -std::vector> ComputeUniqueConsecutiveFlattend(const T *input, int num_elements, - const std::vector &input_shape, S *input_index, - S *sorted_index, S *range_data, T *output, S *index, - S *counts, cudaStream_t cuda_stream) { +std::vector> ComputeUniqueConsecutive(const T *input, int num_elements, + const std::vector &input_shape, S *range_data, + T *output, S *index, S *counts, cudaStream_t cuda_stream) { auto policy = thrust::cuda::par.on(cuda_stream); std::vector> out_shapes; // Copy input to output. @@ -141,80 +112,88 @@ std::vector> ComputeUniqueConsecutiveFlattend(const T *input, i thrust::device_pointer_cast(output)); // Inverse indices. - thrust::adjacent_difference(policy, thrust::device_pointer_cast(output), - thrust::device_pointer_cast(output) + num_elements, - thrust::device_pointer_cast(input_index), thrust::not_equal_to()); - thrust::fill(policy, thrust::device_pointer_cast(input_index), thrust::device_pointer_cast(input_index) + 1, 0); - thrust::inclusive_scan(policy, thrust::device_pointer_cast(input_index), - thrust::device_pointer_cast(input_index) + num_elements, - thrust::device_pointer_cast(input_index)); - if (index != nullptr) { - thrust::sequence(policy, thrust::device_pointer_cast(sorted_index), - thrust::device_pointer_cast(sorted_index) + num_elements); - thrust::scatter(policy, thrust::device_pointer_cast(input_index), - thrust::device_pointer_cast(input_index) + num_elements, thrust::device_pointer_cast(sorted_index), - thrust::device_pointer_cast(index)); - std::vector idx_shape(input_shape.begin(), input_shape.end()); - out_shapes.emplace_back(idx_shape); - } else { + if (index == nullptr || num_elements == 0) { std::vector idx_shape = {0}; out_shapes.emplace_back(idx_shape); + } else { + thrust::adjacent_difference(policy, thrust::device_pointer_cast(output), + thrust::device_pointer_cast(output) + num_elements, thrust::device_pointer_cast(index), + thrust::not_equal_to()); + thrust::fill(policy, thrust::device_pointer_cast(index), thrust::device_pointer_cast(index) + 1, 0); + thrust::inclusive_scan(policy, thrust::device_pointer_cast(index), + thrust::device_pointer_cast(index) + num_elements, thrust::device_pointer_cast(index)); + std::vector idx_shape(input_shape.begin(), input_shape.end()); + out_shapes.emplace_back(idx_shape); } - // Unique. - thrust::sequence(policy, thrust::device_pointer_cast(range_data), - thrust::device_pointer_cast(range_data) + num_elements + 1); - int output_size = thrust::unique_by_key(policy, thrust::device_pointer_cast(output), - thrust::device_pointer_cast(output) + num_elements, - thrust::device_pointer_cast(range_data), thrust::equal_to()) - .first - - thrust::device_pointer_cast(output); - std::vector output_shape = {output_size}; - out_shapes.insert(out_shapes.begin(), output_shape); - // Count. - if (counts != nullptr) { + + // Unique and count. + int output_size = num_elements; + if (counts == nullptr || num_elements == 0) { + output_size = thrust::unique(policy, thrust::device_pointer_cast(output), + thrust::device_pointer_cast(output) + num_elements, thrust::equal_to()) - + thrust::device_pointer_cast(output); + std::vector counts_shape = {output_size}; + out_shapes.emplace_back(counts_shape); + } else { + thrust::sequence(policy, thrust::device_pointer_cast(range_data), + thrust::device_pointer_cast(range_data) + num_elements + 1); + output_size = thrust::unique_by_key(policy, thrust::device_pointer_cast(output), + thrust::device_pointer_cast(output) + num_elements, + thrust::device_pointer_cast(range_data), thrust::equal_to()) + .first - + thrust::device_pointer_cast(output); thrust::fill(policy, thrust::device_pointer_cast(range_data) + output_size, thrust::device_pointer_cast(range_data) + output_size + 1, num_elements); - thrust::sequence(policy, thrust::device_pointer_cast(counts), thrust::device_pointer_cast(counts) + num_elements); thrust::adjacent_difference(policy, thrust::device_pointer_cast(range_data) + 1, thrust::device_pointer_cast(range_data) + output_size + 1, counts); std::vector counts_shape = {output_size}; out_shapes.emplace_back(counts_shape); - } else { - std::vector counts_shape = {0}; - out_shapes.emplace_back(counts_shape); } + std::vector output_shape = {output_size}; + out_shapes.insert(out_shapes.begin(), output_shape); return out_shapes; } template std::vector> ComputeUniqueConsecutiveByAxis(const T *input, int num_elements, - const std::vector &input_shape, bool is_flattend, - int64_t axis, S *input_index, S *sorted_index, - S *range_data, T *indices_data, T *output, S *index, - S *counts, cudaStream_t cuda_stream) { + const std::vector &input_shape, int64_t axis, + S *input_index, S *sorted_index, S *range_data, + T *indices_data, size_t *dev_input_shape, + size_t *dev_input_axis, T *output, S *index, S *counts, + cudaStream_t cuda_stream) { // Compute UniqueConsecutive by axis. auto policy = thrust::cuda::par.on(cuda_stream); std::vector> out_shapes; - // Do transpose. - int64_t num_inp = input_shape[axis]; - int64_t n = num_elements / num_inp; thrust::copy(thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + num_elements, thrust::device_pointer_cast(output)); - std::vector> pos_map = GetPositionArray(num_elements, input_shape); - std::vector transpose_indices = GetTransposeIndices(num_elements, input_shape, pos_map, axis); - thrust::device_vector indices_map(transpose_indices.begin(), transpose_indices.end()); - thrust::gather(policy, indices_map.begin(), indices_map.end(), thrust::device_pointer_cast(input), - thrust::device_pointer_cast(indices_data)); + // Do transpose. + size_t shape_size = input_shape.size(); + cudaMemcpyAsync(dev_input_shape, input_shape.data(), sizeof(size_t) * shape_size, cudaMemcpyHostToDevice, + cuda_stream); + // Used for transpose: dev_input_axis={0, 1, ..., axis, ...} -> dev_input_axis[0]=axis, dev_input_axis[axis]=0 + thrust::sequence(policy, thrust::device_pointer_cast(dev_input_axis), + thrust::device_pointer_cast(dev_input_axis) + shape_size); + thrust::fill(policy, thrust::device_pointer_cast(dev_input_axis), thrust::device_pointer_cast(dev_input_axis) + 1, + axis); + thrust::fill(policy, thrust::device_pointer_cast(dev_input_axis) + axis, + thrust::device_pointer_cast(dev_input_axis) + axis + 1, 0); + CalTranspose(num_elements, input, dev_input_shape, dev_input_axis, shape_size, indices_data, cuda_stream); // Inverse indices. + int64_t num_inp = input_shape[axis]; + int64_t n = num_elements / num_inp; thrust::sequence(policy, thrust::device_pointer_cast(range_data), thrust::device_pointer_cast(range_data) + num_inp); - thrust::adjacent_difference(policy, thrust::device_pointer_cast(range_data), - thrust::device_pointer_cast(range_data) + num_inp, - thrust::device_pointer_cast(input_index), BinaryNotEqual(n, indices_data)); - thrust::fill(policy, thrust::device_pointer_cast(input_index), thrust::device_pointer_cast(input_index) + 1, 0); - thrust::inclusive_scan(policy, thrust::device_pointer_cast(input_index), - thrust::device_pointer_cast(input_index) + num_inp, thrust::device_pointer_cast(input_index)); - if (index != nullptr) { + if (index == nullptr || num_elements == 0) { + std::vector idx_shape = {0}; + out_shapes.emplace_back(idx_shape); + } else { + thrust::adjacent_difference(policy, thrust::device_pointer_cast(range_data), + thrust::device_pointer_cast(range_data) + num_inp, + thrust::device_pointer_cast(input_index), BinaryNotEqual(n, indices_data)); + thrust::fill(policy, thrust::device_pointer_cast(input_index), thrust::device_pointer_cast(input_index) + 1, 0); + thrust::inclusive_scan(policy, thrust::device_pointer_cast(input_index), + thrust::device_pointer_cast(input_index) + num_inp, + thrust::device_pointer_cast(input_index)); thrust::sequence(policy, thrust::device_pointer_cast(sorted_index), thrust::device_pointer_cast(sorted_index) + num_inp); thrust::scatter(policy, thrust::device_pointer_cast(input_index), @@ -223,33 +202,24 @@ std::vector> ComputeUniqueConsecutiveByAxis(const T *input, int std::vector idx_shape; idx_shape.push_back(num_inp); out_shapes.emplace_back(idx_shape); - } else { - std::vector idx_shape = {0}; - out_shapes.emplace_back(idx_shape); } - // Unique. - thrust::sequence(policy, thrust::device_pointer_cast(sorted_index), - thrust::device_pointer_cast(sorted_index) + num_inp + 1); - int indices_size = thrust::unique_by_key(policy, thrust::device_pointer_cast(range_data), - thrust::device_pointer_cast(range_data) + num_inp, - thrust::device_pointer_cast(sorted_index), BinaryEqual(n, indices_data)) - .first - - thrust::device_pointer_cast(range_data); - - std::vector indices(indices_size); - cudaMemcpyAsync(indices.data(), range_data, indices_size * sizeof(int), cudaMemcpyDeviceToHost, cuda_stream); - cudaStreamSynchronize(cuda_stream); - std::vector elements_index = RemoveElementsIndex(indices, indices_size, pos_map, axis); - thrust::device_vector remove_map(elements_index.begin(), elements_index.end()); - thrust::remove_if(thrust::device_pointer_cast(output), thrust::device_pointer_cast(output) + num_elements, - remove_map.begin(), thrust::identity()); - std::vector output_shape(input_shape.begin(), input_shape.end()); - output_shape[axis] = indices_size; - out_shapes.insert(out_shapes.begin(), output_shape); - - // Count. - if (counts != nullptr) { + // Unique and count. + int indices_size = num_inp; + if (counts == nullptr || num_elements == 0) { + indices_size = thrust::unique(policy, thrust::device_pointer_cast(range_data), + thrust::device_pointer_cast(range_data) + num_inp, BinaryEqual(n, indices_data)) - + thrust::device_pointer_cast(range_data); + std::vector counts_shape = {0}; + out_shapes.emplace_back(counts_shape); + } else { + thrust::sequence(policy, thrust::device_pointer_cast(sorted_index), + thrust::device_pointer_cast(sorted_index) + num_inp + 1); + indices_size = thrust::unique_by_key(policy, thrust::device_pointer_cast(range_data), + thrust::device_pointer_cast(range_data) + num_inp, + thrust::device_pointer_cast(sorted_index), BinaryEqual(n, indices_data)) + .first - + thrust::device_pointer_cast(range_data); thrust::fill(policy, thrust::device_pointer_cast(sorted_index) + indices_size, thrust::device_pointer_cast(sorted_index) + indices_size + 1, num_inp); thrust::sequence(policy, thrust::device_pointer_cast(counts), thrust::device_pointer_cast(counts) + num_inp); @@ -257,40 +227,51 @@ std::vector> ComputeUniqueConsecutiveByAxis(const T *input, int thrust::device_pointer_cast(sorted_index) + num_inp + 1, counts); std::vector counts_shape = {indices_size}; out_shapes.emplace_back(counts_shape); - } else { - std::vector counts_shape = {0}; - out_shapes.emplace_back(counts_shape); } + + // Remove invalid dimensions according to indices, reshape the output. + std::vector output_shape(input_shape.begin(), input_shape.end()); + if (indices_size != num_inp) { + thrust::sequence(policy, thrust::device_pointer_cast(input_index), + thrust::device_pointer_cast(input_index) + num_elements); + thrust::transform(thrust::device_pointer_cast(input_index), thrust::device_pointer_cast(input_index) + num_elements, + thrust::device_pointer_cast(input_index), + IndexToAxis(num_elements, axis, dev_input_shape, range_data, indices_size)); + thrust::remove_if(thrust::device_pointer_cast(output), thrust::device_pointer_cast(output) + num_elements, + input_index, thrust::identity()); + output_shape[axis] = indices_size; + } + out_shapes.insert(out_shapes.begin(), output_shape); return out_shapes; } template std::vector> CalUniqueConsecutive(const T *input, int num_elements, - const std::vector &input_shape, bool is_flattend, + const std::vector &input_shape, bool is_axis_none, int64_t axis, S *input_index, S *sorted_index, S *range_data, - T *indices_data, T *output, S *index, S *counts, - cudaStream_t cuda_stream) { - if (is_flattend) { - return ComputeUniqueConsecutiveFlattend(input, num_elements, input_shape, input_index, sorted_index, range_data, - output, index, counts, cuda_stream); + T *indices_data, size_t *dev_input_shape, size_t *dev_input_axis, + T *output, S *index, S *counts, cudaStream_t cuda_stream) { + if (is_axis_none) { + return ComputeUniqueConsecutive(input, num_elements, input_shape, range_data, output, index, counts, cuda_stream); } - return ComputeUniqueConsecutiveByAxis(input, num_elements, input_shape, is_flattend, axis, input_index, sorted_index, - range_data, indices_data, output, index, counts, cuda_stream); + return ComputeUniqueConsecutiveByAxis(input, num_elements, input_shape, axis, input_index, sorted_index, range_data, + indices_data, dev_input_shape, dev_input_axis, output, index, counts, + cuda_stream); } template CUDA_LIB_EXPORT std::vector> CalUniqueConsecutive( - const float *input, int num_elements, const std::vector &input_shape, bool is_flattend, int64_t axis, - int *input_index, int *sorted_index, int *range_data, float *indices_data, float *output, int *index, int *counts, - cudaStream_t cuda_stream); + const float *input, int num_elements, const std::vector &input_shape, bool is_axis_none, int64_t axis, + int *input_index, int *sorted_index, int *range_data, float *indices_data, size_t *dev_input_shape, + size_t *dev_input_axis, float *output, int *index, int *counts, cudaStream_t cuda_stream); template CUDA_LIB_EXPORT std::vector> CalUniqueConsecutive( - const half *input, int num_elements, const std::vector &input_shape, bool is_flattend, int64_t axis, - int *input_index, int *sorted_index, int *range_data, half *indices_data, half *output, int *index, int *counts, - cudaStream_t cuda_stream); + const half *input, int num_elements, const std::vector &input_shape, bool is_axis_none, int64_t axis, + int *input_index, int *sorted_index, int *range_data, half *indices_data, size_t *dev_input_shape, + size_t *dev_input_axis, half *output, int *index, int *counts, cudaStream_t cuda_stream); template CUDA_LIB_EXPORT std::vector> CalUniqueConsecutive( - const int *input, int num_elements, const std::vector &input_shape, bool is_flattend, int64_t axis, - int *input_index, int *sorted_index, int *range_data, int *indices_data, int *output, int *index, int *counts, - cudaStream_t cuda_stream); + const int *input, int num_elements, const std::vector &input_shape, bool is_axis_none, int64_t axis, + int *input_index, int *sorted_index, int *range_data, int *indices_data, size_t *dev_input_shape, + size_t *dev_input_axis, int *output, int *index, int *counts, cudaStream_t cuda_stream); template CUDA_LIB_EXPORT std::vector> CalUniqueConsecutive( - const int64_t *input, int num_elements, const std::vector &input_shape, bool is_flattend, int64_t axis, - int64_t *input_index, int64_t *sorted_index, int64_t *range_data, int64_t *indices_data, int64_t *output, - int64_t *index, int64_t *counts, cudaStream_t cuda_stream); + const int64_t *input, int num_elements, const std::vector &input_shape, bool is_axis_none, int64_t axis, + int64_t *input_index, int64_t *sorted_index, int64_t *range_data, int64_t *indices_data, size_t *dev_input_shape, + size_t *dev_input_axis, int64_t *output, int64_t *index, int64_t *counts, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unique_consecutive_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unique_consecutive_impl.cuh index 4ba59957f59..ed41031c2dd 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unique_consecutive_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/unique_consecutive_impl.cuh @@ -20,10 +20,8 @@ #include #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" template -CUDA_LIB_EXPORT std::vector> CalUniqueConsecutive(const T *input, int num_elements, - const std::vector &input_shape, - bool is_flattend, int64_t axis, S *input_index, - S *sorted_index, S *range_data, T *indices_data, - T *output, S *index, S *count, - cudaStream_t cuda_stream); +CUDA_LIB_EXPORT std::vector> CalUniqueConsecutive( + const T *input, int num_elements, const std::vector &input_shape, bool is_axis_none, int64_t axis, + S *input_index, S *sorted_index, S *range_data, T *indices_data, size_t *dev_input_shape, size_t *dev_input_axis, + T *output, S *index, S *count, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_UNIQUE_CONSECUTIVE_IMPL_CUH_ diff --git a/mindspore/core/ops/unique_consecutive.cc b/mindspore/core/ops/unique_consecutive.cc index 9f06cea6f33..a8ce6db3bee 100644 --- a/mindspore/core/ops/unique_consecutive.cc +++ b/mindspore/core/ops/unique_consecutive.cc @@ -31,11 +31,24 @@ namespace { constexpr int64_t kUniqueConsecutiveInputNum = 1; // For aicpu, if axis is 1000, that represents None. constexpr int64_t kAxisIsNone = 1000; + +bool CheckNullInput(const std::vector &shape) { + if (shape.size() != 0) { + if (std::any_of(shape.begin(), shape.end(), [](int64_t i) { return i == 0; })) { + return true; + } + } + return false; +} + abstract::BaseShapePtr UniqueConsecutiveInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { auto op_name = primitive->name(); auto input_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape()); auto input_shape_vec = input_shape_map[kShape]; + if (CheckNullInput(input_shape_vec)) { + MS_LOG(EXCEPTION) << "For " << op_name << ", the shape of input cannot contain zero."; + } auto axis_ptr = primitive->GetAttr(kAxis); MS_EXCEPTION_IF_NULL(axis_ptr);