From 2ef17714e955d78d8f8ec168904cac7bcdf85e83 Mon Sep 17 00:00:00 2001
From: Super_Wzb <1114120549@qq.com>
Date: Thu, 4 Aug 2022 00:17:09 +0800
Subject: [PATCH] [feat] [assistant] [ops] [I5EWJI,I5EWJJ,I5EWJK,I5EWJC,I5EWJD,I5EWJG]
 New GPU operator implementation: SparseSegmentSum, SparseSegmentSumGrad,
 SparseSegmentSumWithNumSegments, SparseSegmentSqrtN, SparseSegmentSqrtNGrad,
 SparseSegmentSqrtNWithNumSegments
---
 .../cuda_ops/sparse_segment_grad_impl.cu      | 225 ++++
 .../cuda_ops/sparse_segment_grad_impl.cuh     |  30 ++
 .../cuda_impl/cuda_ops/sparse_segment_impl.cu | 307 +++++++++++
 .../cuda_ops/sparse_segment_impl.cuh          |  29 +
 .../sparse/sparse_segment_ops_gpu_kernel.cc   | 506 ++++++++++++++++++
 .../sparse/sparse_segment_ops_gpu_kernel.h    | 101 ++++
 .../sparse_segment_grad_ops_gpu_kernel.cc     | 245 +++++++++
 .../sparse_segment_grad_ops_gpu_kernel.h      |  99 ++++
 mindspore/core/ops/core_ops.h                 |  21 +-
 .../ops/grad/sparse_segment_sqrt_n_grad.cc    |  27 +-
 .../core/ops/grad/sparse_segment_sum_grad.cc  | 113 ++++
 .../core/ops/grad/sparse_segment_sum_grad.h   |  46 ++
 mindspore/core/ops/sparse_segment_sqrt_n.cc   | 215 ++++----
 ...sparse_segment_sqrt_n_with_num_segments.cc | 249 ++++-----
 mindspore/core/ops/sparse_segment_sum.cc      | 100 ++++
 mindspore/core/ops/sparse_segment_sum.h       |  45 ++
 .../sparse_segment_sum_with_num_segments.cc   | 120 +++++
 .../sparse_segment_sum_with_num_segments.h    |  47 ++
 .../ops/_grad_experimental/grad_sparse_ops.py |  60 ++-
 .../mindspore/ops/operations/_grad_ops.py     |  55 +-
 .../mindspore/ops/operations/sparse_ops.py    | 125 ++++-
 .../gpu/test_sparse_segment_sqrt_n_grad_op.py | 101 ++++
 .../ops/gpu/test_sparse_segment_sqrt_n_op.py  |  89 +++
 ...rse_segment_sqrt_n_with_num_segments_op.py | 101 ++++
 .../gpu/test_sparse_segment_sum_grad_op.py    | 101 ++++
 .../st/ops/gpu/test_sparse_segment_sum_op.py  |  89 +++
 ...sparse_segment_sum_with_num_segments_op.py | 101 ++++
 27 files changed, 3093 insertions(+), 254 deletions(-)
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cu
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cu
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h
 create mode 100644 mindspore/core/ops/grad/sparse_segment_sum_grad.cc
 create mode 100644 mindspore/core/ops/grad/sparse_segment_sum_grad.h
 create mode 100644 mindspore/core/ops/sparse_segment_sum.cc
 create mode 100644 mindspore/core/ops/sparse_segment_sum.h
 create mode 100644 mindspore/core/ops/sparse_segment_sum_with_num_segments.cc
 create mode 100644 mindspore/core/ops/sparse_segment_sum_with_num_segments.h
 create mode 100644 tests/st/ops/gpu/test_sparse_segment_sqrt_n_grad_op.py
 create mode 100644 tests/st/ops/gpu/test_sparse_segment_sqrt_n_op.py
 create mode 100644 tests/st/ops/gpu/test_sparse_segment_sqrt_n_with_num_segments_op.py
 create mode 100644 tests/st/ops/gpu/test_sparse_segment_sum_grad_op.py
 create mode 100644 tests/st/ops/gpu/test_sparse_segment_sum_op.py
create mode 100644 tests/st/ops/gpu/test_sparse_segment_sum_with_num_segments_op.py diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cu new file mode 100644 index 00000000000..a13eff55a44 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cu @@ -0,0 +1,225 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "plugin/device/cpu/kernel/nnacl/op_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh" + +template +__global__ void SparseSegmentPosKernel(const S *indices_ptr, size_t *indices_pos_ptr, size_t idx_seg_size, + size_t outer_size) { + for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id <= idx_seg_size; id += blockDim.x * gridDim.x) { + const S max_size = static_cast(outer_size); + const S min_size = S(0); + S beg_idx = (id == 0) ? min_size : indices_ptr[id - 1] + 1; + S end_idx = (id >= idx_seg_size) ? max_size : indices_ptr[id]; + beg_idx = max(min_size, min(max_size, beg_idx)); + end_idx = max(min_size, min(max_size, end_idx)); + for (S i = beg_idx; i <= end_idx; i++) { + indices_pos_ptr[i] = id; + } + } +} + +template +__global__ void SparseSegmentSumGradKernel(const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, + const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, + size_t output_dim0, R *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) { + size_t beg_pos = indices_pos_ptr[inid]; + size_t end_pos = indices_pos_ptr[inid + 1]; + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + double reduce_result = 0; + S index = inner_valid ? 
indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = static_cast(grad_ptr[index * inner_size + inner_idx]); + } + if (threadIdx.y == 0 && inner_valid) { + R *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx; + MsAtomicAdd(out_pos, static_cast(reduce_result)); + } + } + } + } +} + +template +__global__ void SparseSegmentSumGradKernel(const half *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, + const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, + size_t output_dim0, half *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) { + size_t beg_pos = indices_pos_ptr[inid]; + size_t end_pos = indices_pos_ptr[inid + 1]; + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + double reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = static_cast(__half2float(grad_ptr[index * inner_size + inner_idx])); + } + if (threadIdx.y == 0 && inner_valid) { + half *out_pos = y_ptr + segment_ids_ptr[pos]* inner_size + inner_idx; + MsAtomicAdd(out_pos, __float2half(static_cast(reduce_result))); + } + } + } + } +} + +template +__global__ void SparseSegmentSqrtNGradKernel(const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, + const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, + size_t output_dim0, R *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) { + size_t beg_pos = indices_pos_ptr[inid]; + size_t end_pos = indices_pos_ptr[inid + 1]; + double sqrt_segment_len = sqrt(static_cast(end_pos - beg_pos)); + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + double reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = static_cast(grad_ptr[index * inner_size + inner_idx]); + } + if (threadIdx.y == 0 && inner_valid) { + R *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx; + MsAtomicAdd(out_pos, static_cast(reduce_result / sqrt_segment_len)); + } + } + } + } +} + +template +__global__ void SparseSegmentSqrtNGradKernel(const half *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr, + const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, + size_t output_dim0, half *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) { + size_t beg_pos = indices_pos_ptr[inid]; + size_t end_pos = indices_pos_ptr[inid + 1]; + double sqrt_segment_len = sqrt(static_cast(end_pos - beg_pos)); + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + double reduce_result = 0; + S index = inner_valid ? 
indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = static_cast(__half2float(grad_ptr[index * inner_size + inner_idx])); + } + if (threadIdx.y == 0 && inner_valid) { + half *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx; + MsAtomicAdd(out_pos, __float2half(static_cast(reduce_result / sqrt_segment_len))); + } + } + } + } +} + +inline int Log2Floor_M(uint32_t n) { + if (n == 0) return -1; + int log = 0; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32_t x = n >> shift; + if (x) { + n = x; + log += shift; + } + } + return log; +} + +inline int Log2Floor64_M(uint64_t n) { + // Scan n first high 32 then low 32 bits. + const uint32_t high_32_bit = static_cast(n >> 32); + if (high_32_bit == 0) { + return Log2Floor_M(static_cast(n)); + } else { + return 32 + Log2Floor_M(high_32_bit); + } +} + +inline int Log2Ceil64_M(uint64_t n) { + int floor = Log2Floor64_M(n); + if (n == (n & ~(n - 1))) + return floor; + else + return floor + 1; +} + + +template +bool CalSparseSegmentGradCombination(const std::string kernel_type, const R *grad_ptr, const S *indices_ptr, + const S *segment_ids_ptr, size_t *indices_pos_ptr, size_t outer_size, + size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr, + uint32_t device_id, cudaStream_t cuda_stream) { + // Get start position of each segment and set to indices_pos_ptr. + // The last element of indices_pos_ptr must equal to idx_seg_size. + SparseSegmentPosKernel<<>>( + indices_ptr, indices_pos_ptr, idx_seg_size, outer_size); + const unsigned int max_grid_x = (1u << 31) - 1; + const unsigned int max_grid_y = (1u << 16) - 1; + unsigned int block_x = 32; + unsigned int block_y = 1; + unsigned int grid_x = std::min(static_cast(UP_DIV(inner_size, block_x)), max_grid_x); + unsigned int grid_y = std::min(static_cast(output_dim0), max_grid_y); + dim3 block(block_x, block_y); + dim3 grid(grid_x, grid_y); + unsigned int shared_memory_size = block_x * block_y * sizeof(R); + if (kernel_type == "SparseSegmentSumGrad") { + SparseSegmentSumGradKernel<<>>(grad_ptr, indices_ptr, + segment_ids_ptr, indices_pos_ptr, + outer_size, inner_size, output_dim0, + y_ptr); + } else if (kernel_type == "SparseSegmentSqrtNGrad") { + SparseSegmentSqrtNGradKernel<<>>(grad_ptr, indices_ptr, + segment_ids_ptr, indices_pos_ptr, + outer_size, inner_size, + output_dim0, y_ptr); + } + return true; +} + +#define ADD_SPARSE_SEGMENT_GRAD(R, S) \ + template CUDA_LIB_EXPORT bool CalSparseSegmentGradCombination(const std::string kernel_type, \ + const R *grad_ptr, const S *indices_ptr, \ + const S *segment_ids_ptr, \ + size_t *indices_pos_ptr, size_t outer_size, \ + size_t inner_size, size_t idx_seg_size, \ + size_t output_dim0, R *y_ptr, \ + uint32_t device_id, cudaStream_t cuda_stream); + +ADD_SPARSE_SEGMENT_GRAD(half, int32_t) +ADD_SPARSE_SEGMENT_GRAD(half, int64_t) + +ADD_SPARSE_SEGMENT_GRAD(float, int32_t) +ADD_SPARSE_SEGMENT_GRAD(float, int64_t) + +ADD_SPARSE_SEGMENT_GRAD(double, int32_t) +ADD_SPARSE_SEGMENT_GRAD(double, int64_t) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh new file mode 100644 index 00000000000..53823e86e44 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh @@ -0,0 +1,30 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * 
you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_ + +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +template +CUDA_LIB_EXPORT bool CalSparseSegmentGradCombination(const std::string kernel_type, const R *grad_ptr, + const S *indices_ptr, const S *segment_ids_ptr, + size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, + size_t idx_seg_size, size_t output_dim0, R *y_ptr, + uint32_t device_id, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cu new file mode 100644 index 00000000000..b2d2094da93 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cu @@ -0,0 +1,307 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "plugin/device/cpu/kernel/nnacl/op_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh" + +template +__global__ void SparseSegmentPosKernel(const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t idx_seg_size, + size_t output_dim0) { + for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id <= idx_seg_size; id += blockDim.x * gridDim.x) { + const S max_size = static_cast(output_dim0); + const S min_size = S(0); + S beg_idx = (id == 0) ? min_size : segment_ids_ptr[id - 1] + 1; + S end_idx = (id >= idx_seg_size) ? 
max_size : segment_ids_ptr[id]; + beg_idx = max(min_size, min(max_size, beg_idx)); + end_idx = max(min_size, min(max_size, end_idx)); + for (S i = beg_idx; i <= end_idx; i++) { + segment_pos_ptr[i] = id; + } + } +} + +template +__global__ void SparseSegmentSumKernel(const R *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, + size_t outer_size, size_t inner_size, size_t output_dim0, R *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { + size_t beg_pos = segment_pos_ptr[sid]; + size_t end_pos = segment_pos_ptr[sid + 1]; + R segment_sum = 0; + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + R reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = x_ptr[index * inner_size + inner_idx]; + } + if (threadIdx.y == 0 && inner_valid) { + segment_sum += reduce_result; + } + } + if (threadIdx.y == 0 && inner_valid) { + y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? R(0) : segment_sum; + } + } + } +} + +template +__global__ void SparseSegmentSumKernel(const float *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, + size_t outer_size, size_t inner_size, size_t output_dim0, float *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { + size_t beg_pos = segment_pos_ptr[sid]; + size_t end_pos = segment_pos_ptr[sid + 1]; + double segment_sum = 0; + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + double reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = static_cast(x_ptr[index * inner_size + inner_idx]); + } + if (threadIdx.y == 0 && inner_valid) { + segment_sum += reduce_result; + } + } + if (threadIdx.y == 0 && inner_valid) { + y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? static_cast(0) : + static_cast(segment_sum); + } + } + } +} + +template +__global__ void SparseSegmentSumKernel(const half *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, + size_t outer_size, size_t inner_size, size_t output_dim0, half *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { + size_t beg_pos = segment_pos_ptr[sid]; + size_t end_pos = segment_pos_ptr[sid + 1]; + float segment_sum = 0; + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + float reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = __half2float(x_ptr[index * inner_size + inner_idx]); + } + if (threadIdx.y == 0 && inner_valid) { + segment_sum += reduce_result; + } + } + if (threadIdx.y == 0 && inner_valid) { + y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? 
half(0) : __float2half(segment_sum); + } + } + } +} + +template +__global__ void SparseSegmentSqrtNKernel(const R *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, + size_t outer_size, size_t inner_size, size_t output_dim0, R *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { + size_t beg_pos = segment_pos_ptr[sid]; + size_t end_pos = segment_pos_ptr[sid + 1]; + R segment_sum = 0; + R sqrt_segment_len = R(sqrt(static_cast(end_pos - beg_pos))); + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + R reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = x_ptr[index * inner_size + inner_idx]; + } + if (threadIdx.y == 0 && inner_valid) { + segment_sum += reduce_result; + } + } + if (threadIdx.y == 0 && inner_valid) { + y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? R(0) : segment_sum / sqrt_segment_len; + } + } + } +} + +template +__global__ void SparseSegmentSqrtNKernel(const float *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, + size_t outer_size, size_t inner_size, size_t output_dim0, float *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { + size_t beg_pos = segment_pos_ptr[sid]; + size_t end_pos = segment_pos_ptr[sid + 1]; + double segment_sum = 0; + double sqrt_segment_len = sqrt(static_cast(end_pos - beg_pos)); + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + double reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = static_cast(x_ptr[index * inner_size + inner_idx]); + } + if (threadIdx.y == 0 && inner_valid) { + segment_sum += reduce_result; + } + } + if (threadIdx.y == 0 && inner_valid) { + y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? static_cast(0) : + static_cast(segment_sum / sqrt_segment_len); + } + } + } +} + +template +__global__ void SparseSegmentSqrtNKernel(const half *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr, + size_t outer_size, size_t inner_size, size_t output_dim0, half *y_ptr) { + size_t num_blocks = (inner_size - 1) / blockDim.x + 1; + for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) { + size_t inner_idx = threadIdx.x + bid * blockDim.x; + bool inner_valid = inner_idx < inner_size; + for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) { + size_t beg_pos = segment_pos_ptr[sid]; + size_t end_pos = segment_pos_ptr[sid + 1]; + float segment_sum = 0; + float sqrt_segment_len = sqrt(static_cast(end_pos - beg_pos)); + for (size_t pos = beg_pos; pos < end_pos; pos += 1) { + float reduce_result = 0; + S index = inner_valid ? indices_ptr[pos] : outer_size; + if (index >= 0 && index < outer_size) { + reduce_result = __half2float(x_ptr[index * inner_size + inner_idx]); + } + if (threadIdx.y == 0 && inner_valid) { + segment_sum += reduce_result; + } + } + if (threadIdx.y == 0 && inner_valid) { + y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? 
half(0) : + __float2half(segment_sum / sqrt_segment_len); + } + } + } +} + +inline int Log2Floor(uint32_t n) { + if (n == 0) return -1; + int log = 0; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32_t x = n >> shift; + if (x) { + n = x; + log += shift; + } + } + return log; +} + +inline int Log2Floor64(uint64_t n) { + // Scan n first high 32 then low 32 bits. + const uint32_t high_32_bit = static_cast(n >> 32); + if (high_32_bit == 0) { + return Log2Floor(static_cast(n)); + } else { + return 32 + Log2Floor(high_32_bit); + } +} + +inline int Log2Ceil64(uint64_t n) { + int floor = Log2Floor64(n); + if (n == (n & ~(n - 1))) + return floor; + else + return floor + 1; +} + +template +bool CalSparseSegmentCombination(const std::string kernel_type, const R *x_ptr, const S *indices_ptr, + const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t outer_size, + size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr, + uint32_t device_id, cudaStream_t cuda_stream) { + // Get start position of each segment and set to segment_pos_ptr. + // The last element of segment_pos_ptr must equal to idx_seg_size. + SparseSegmentPosKernel<<>>( + segment_ids_ptr, segment_pos_ptr, idx_seg_size, output_dim0); + + const unsigned int max_grid_x = (1u << 31) - 1; + const unsigned int max_grid_y = (1u << 16) - 1; + unsigned int block_x = 32; + unsigned int block_y = 1; + unsigned int grid_x = std::min(static_cast(UP_DIV(inner_size, block_x)), max_grid_x); + unsigned int grid_y = std::min(static_cast(output_dim0), max_grid_y); + dim3 block(block_x, block_y); + dim3 grid(grid_x, grid_y); + unsigned int shared_memory_size = block_x * block_y * sizeof(R); + if (kernel_type == "SparseSegmentSum" || kernel_type == "SparseSegmentSumWithNumSegments") { + SparseSegmentSumKernel<<>>(x_ptr, indices_ptr, segment_pos_ptr, + outer_size, inner_size, output_dim0, + y_ptr); + } else if (kernel_type == "SparseSegmentSqrtN" || kernel_type == "SparseSegmentSqrtNWithNumSegments") { + SparseSegmentSqrtNKernel<<>>(x_ptr, indices_ptr, segment_pos_ptr, + outer_size, inner_size, output_dim0, + y_ptr); + } + return true; +} + +#define ADD_SPARSE_SEGMENT(R, S) \ + template CUDA_LIB_EXPORT bool CalSparseSegmentCombination(const std::string kernel_type, const R *x_ptr, \ + const S *indices_ptr, const S *segment_ids_ptr, \ + size_t *segment_pos_ptr, size_t outer_size, \ + size_t inner_size, size_t idx_seg_size, \ + size_t output_dim0, R *y_ptr, uint32_t device_id, \ + cudaStream_t cuda_stream); + +ADD_SPARSE_SEGMENT(uint8_t, int32_t) +ADD_SPARSE_SEGMENT(uint8_t, int64_t) + +ADD_SPARSE_SEGMENT(uint16_t, int32_t) +ADD_SPARSE_SEGMENT(uint16_t, int64_t) + +ADD_SPARSE_SEGMENT(int8_t, int32_t) +ADD_SPARSE_SEGMENT(int8_t, int64_t) + +ADD_SPARSE_SEGMENT(int16_t, int32_t) +ADD_SPARSE_SEGMENT(int16_t, int64_t) + +ADD_SPARSE_SEGMENT(int32_t, int32_t) +ADD_SPARSE_SEGMENT(int32_t, int64_t) + +ADD_SPARSE_SEGMENT(int64_t, int32_t) +ADD_SPARSE_SEGMENT(int64_t, int64_t) + +ADD_SPARSE_SEGMENT(half, int32_t) +ADD_SPARSE_SEGMENT(half, int64_t) + +ADD_SPARSE_SEGMENT(float, int32_t) +ADD_SPARSE_SEGMENT(float, int64_t) + +ADD_SPARSE_SEGMENT(double, int32_t) +ADD_SPARSE_SEGMENT(double, int64_t) diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh new file mode 100644 index 00000000000..c19aa170052 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh @@ -0,0 +1,29 @@ 
+/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_ + +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" + +template +CUDA_LIB_EXPORT bool CalSparseSegmentCombination(const std::string kernel_type, const R *x_ptr, const S *indices_ptr, + const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t outer_size, + size_t inner_size, size_t indices_size, size_t output_dim0, R *y_ptr, + uint32_t device_id, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.cc new file mode 100644 index 00000000000..388255b4a77 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.cc @@ -0,0 +1,506 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr auto Sparse_Segment_Sum = "SparseSegmentSum"; +constexpr auto Sparse_Segment_Sum_With_Num_Segments = "SparseSegmentSumWithNumSegments"; +constexpr auto Sparse_Segment_Sqrt_N = "SparseSegmentSqrtN"; +constexpr auto Sparse_Segment_Sqrt_N_With_Num_Segments = "SparseSegmentSqrtNWithNumSegments"; +constexpr size_t kNumber1 = 1; +constexpr size_t kNumber3 = 3; +constexpr size_t kNumber4 = 4; +} // namespace + +bool SparseSegmentOpsGpuKernelMod::Init(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs) { + if (kernel_type_ == "SparseSegmentSum" || kernel_type_ == "SparseSegmentSqrtN") { + flag_ = true; + } else { + flag_ = false; + } + size_t inputs_num = flag_ ? 
kNumber3 : kNumber4; + size_t outputs_num = kNumber1; + CHECK_KERNEL_INPUTS_NUM(inputs.size(), inputs_num, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_); + kernel_name_ = base_operator->GetPrim()->name(); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << kernel_attr << "."; + return false; + } + kernel_func_ = kernel_attr_map_.at(kernel_type_)[index].second; + unit_x_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first); + unit_idx_seg_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).first); + return true; +} + +int SparseSegmentOpsGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs, + const std::map &) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + for (const auto &output : outputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto output_shape = output->GetShapeVector(); + if (!IsValidShape(output_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector output_shape = outputs.at(kIndex0)->GetShapeVector(); + output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + if (output_elements_ == 0) { + is_null_input_ = true; + } + std::vector x_shape = inputs.at(kIndex0)->GetShapeVector(); + x_shape_0_ = x_shape[0]; + x_elements_ = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies{}); + outer_size_ = x_shape.front(); + inner_size_ = x_elements_ / x_shape.front(); + std::vector indices_shape = inputs.at(kIndex1)->GetShapeVector(); + idx_seg_elements_ = std::accumulate(indices_shape.begin(), indices_shape.end(), 1, std::multiplies{}); + output_dim0_ = LongToSize(output_shape.front()); + + size_t input_x_size = x_elements_ * unit_x_size_; + size_t input_idx_seg_size = idx_seg_elements_ * unit_idx_seg_size_; + size_t output_size = output_elements_ * unit_x_size_; + input_size_list_.push_back(input_x_size); + input_size_list_.push_back(input_idx_seg_size); + input_size_list_.push_back(input_idx_seg_size); + if (flag_) { + input_size_list_.push_back(unit_idx_seg_size_); + } + output_size_list_.push_back(output_size); + workspace_size_list_.push_back((output_dim0_ + 1) * sizeof(size_t)); + return KRET_OK; +} + +template +bool SparseSegmentOpsGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + R *x_ptr = GetDeviceAddress(inputs, kIndex0); + S *indices_ptr = GetDeviceAddress(inputs, kIndex1); + S *segment_ids_ptr = GetDeviceAddress(inputs, kIndex2); + R *y_ptr = GetDeviceAddress(outputs, kIndex0); + size_t *segment_pos_ptr = GetDeviceAddress(workspace, kIndex0); + auto any = [](auto... 
args) -> bool { return ((args == nullptr) || ...); }; + if (any(x_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, y_ptr)) { + return false; + } + cudaStream_t stream = reinterpret_cast(cuda_stream_); + std::vector indices_host; + std::vector segment_ids_host; + std::vector num_segments_host; + indices_host.resize(idx_seg_elements_); + segment_ids_host.resize(idx_seg_elements_); + num_segments_host.resize(kNumber1); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(indices_host.data(), indices_ptr, idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream), + "cudaMemcpy failed."); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(segment_ids_host.data(), segment_ids_ptr, + idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream), + "cudaMemcpy failed."); + if (!flag_) { + auto num_segments_ptr = GetDeviceAddress(inputs, kIndex3); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(num_segments_host.data(), num_segments_ptr, sizeof(S), cudaMemcpyDeviceToHost, stream), + "cudaMemcpy failed."); + } + if (segment_ids_host[0] != 0 && flag_) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', indices in 'segment_ids' should be contiguous and start from 0."; + } + for (size_t i = 1; i < idx_seg_elements_; i++) { + if (segment_ids_host[i] < segment_ids_host[i - 1]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted."; + } + if (segment_ids_host[i] - segment_ids_host[i - 1] > 1 && flag_) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', indices in 'segment_ids' should be contiguous and start from 0."; + } + } + if (segment_ids_host[idx_seg_elements_ - 1] >= num_segments_host[kIndex0] && !flag_) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ + << "', num_segments must bigger than the last number of segment_ids."; + } + for (size_t i = 0; i < idx_seg_elements_; i++) { + if (indices_host[i] >= static_cast(x_shape_0_)) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of x's first shape."; + } + } + CalSparseSegmentCombination(kernel_type_, x_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, outer_size_, + inner_size_, idx_seg_elements_, output_dim0_, y_ptr, device_id_, stream); + return true; +} + +std::map>> + SparseSegmentOpsGpuKernelMod::kernel_attr_map_ = { + {Sparse_Segment_Sum, + {{KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt8), + 
&SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}}}, + {Sparse_Segment_Sum_With_Num_Segments, + {{KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeUInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + 
.AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt8), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}}}, + {Sparse_Segment_Sqrt_N, + {{KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + 
.AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}}}, + {Sparse_Segment_Sqrt_N_With_Num_Segments, + {{KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentOpsGpuKernelMod::LaunchKernel}}}}; // kernel_attr_map_ + +std::vector SparseSegmentOpsGpuKernelMod::GetOpSupport() { + auto iter = kernel_attr_map_.find(kernel_type_); + if (iter == kernel_attr_map_.end()) { + MS_LOG(ERROR) << "For 'SparseSegmentOpsOp', only support these types: " << kernel::Map2Str(kernel_attr_map_) + << " currently, but got " << kernel_name_; + } + std::vector support_list; + (void)std::transform( + iter->second.begin(), iter->second.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSum, + []() { return std::make_shared(Sparse_Segment_Sum); }); +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSumWithNumSegments, []() { + return std::make_shared(Sparse_Segment_Sum_With_Num_Segments); +}); +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtN, []() { + return std::make_shared(Sparse_Segment_Sqrt_N); +}); +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, 
SparseSegmentSqrtNWithNumSegments, []() { + return std::make_shared(Sparse_Segment_Sqrt_N_With_Num_Segments); +}); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h new file mode 100644 index 00000000000..e2902ca630e --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h @@ -0,0 +1,101 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SparseSegmentOpsGpuKernelMod : public NativeGpuKernelMod { + public: + explicit SparseSegmentOpsGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {} + ~SparseSegmentOpsGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; + + protected: + void ResetResource() noexcept { + outer_size_ = 0; + inner_size_ = 0; + x_elements_ = 0; + x_shape_0_ = 0; + idx_seg_elements_ = 0; + output_dim0_ = 0; + output_elements_ = 0; + is_null_input_ = false; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + + using SSLaunchFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + size_t outer_size_{0}; + size_t inner_size_{0}; + size_t x_elements_{0}; + size_t x_shape_0_{0}; + size_t idx_seg_elements_{0}; + size_t output_dim0_{0}; + size_t output_elements_{0}; + size_t unit_x_size_{1}; + size_t unit_idx_seg_size_{1}; + std::string kernel_type_{"Unknown"}; + bool is_null_input_{false}; + size_t flag_{0}; + void *cuda_stream_{nullptr}; + SSLaunchFunc kernel_func_{}; + static std::map>> kernel_attr_map_; +}; +} // 
namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.cc new file mode 100644 index 00000000000..719c1306b50 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.cc @@ -0,0 +1,245 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr auto Sparse_Segment_Sum_Grad = "SparseSegmentSumGrad"; +constexpr auto Sparse_Segment_Sqrt_N_Grad = "SparseSegmentSqrtNGrad"; +constexpr size_t kNumber1 = 1; +constexpr size_t kNumber4 = 4; +} // namespace + +bool SparseSegmentGradOpsGpuKernelMod::Init(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs) { + size_t inputs_num = kNumber4; + size_t outputs_num = kNumber1; + CHECK_KERNEL_INPUTS_NUM(inputs.size(), inputs_num, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_); + kernel_name_ = base_operator->GetPrim()->name(); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << kernel_attr << "."; + return false; + } + kernel_func_ = kernel_attr_map_.at(kernel_type_)[index].second; + unit_grad_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first); + unit_idx_seg_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).first); + return true; +} + +int SparseSegmentGradOpsGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs, + const std::map &) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + for (const auto &output : outputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. 
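+    // Here the same check is applied to the output shapes: any -1 means shapes are still
+    // dynamic, so defer sizing until inference has produced static shapes.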
+ auto output_shape = output->GetShapeVector(); + if (!IsValidShape(output_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + std::vector output_shape = outputs.at(kIndex0)->GetShapeVector(); + output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()); + if (output_elements_ == 0) { + is_null_input_ = true; + } + std::vector grad_shape = inputs.at(kIndex0)->GetShapeVector(); + grad_shape_0_ = grad_shape[0]; + grad_elements_ = std::accumulate(grad_shape.begin(), grad_shape.end(), 1, std::multiplies{}); + outer_size_ = grad_shape.front(); + inner_size_ = grad_elements_ / outer_size_; + std::vector indices_shape = inputs.at(kIndex1)->GetShapeVector(); + idx_seg_elements_ = std::accumulate(indices_shape.begin(), indices_shape.end(), 1, std::multiplies{}); + output_dim0_ = LongToSize(output_shape.front()); + + size_t input_grad_size = grad_elements_ * unit_grad_size_; + size_t input_idx_seg_size = idx_seg_elements_ * unit_idx_seg_size_; + size_t output_size = output_elements_ * unit_grad_size_; + input_size_list_.push_back(input_grad_size); + input_size_list_.push_back(input_idx_seg_size); + input_size_list_.push_back(input_idx_seg_size); + input_size_list_.push_back(unit_idx_seg_size_); + output_size_list_.push_back(output_size); + workspace_size_list_.push_back((outer_size_ + 1) * sizeof(size_t)); + return KRET_OK; +} + +template +bool SparseSegmentGradOpsGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + R *grad_ptr = GetDeviceAddress(inputs, kIndex0); + S *indices_ptr = GetDeviceAddress(inputs, kIndex1); + S *segment_ids_ptr = GetDeviceAddress(inputs, kIndex2); + R *y_ptr = GetDeviceAddress(outputs, kIndex0); + size_t *segment_pos_ptr = GetDeviceAddress(workspace, kIndex0); + auto any = [](auto... 
args) -> bool { return ((args == nullptr) || ...); }; + if (any(grad_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, y_ptr)) { + return false; + } + cudaStream_t stream = reinterpret_cast(cuda_stream_); + std::vector indices_host; + std::vector segment_ids_host; + indices_host.resize(idx_seg_elements_); + segment_ids_host.resize(idx_seg_elements_); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(indices_host.data(), indices_ptr, idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream), + "cudaMemcpy failed."); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(segment_ids_host.data(), segment_ids_ptr, + idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream), + "cudaMemcpy failed."); + for (size_t i = 1; i < idx_seg_elements_; i++) { + if (segment_ids_host[i] < segment_ids_host[i - 1]) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted."; + } + } + for (size_t i = 0; i < idx_seg_elements_; i++) { + if (indices_host[i] >= static_cast(output_dim0_)) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of output_dim0."; + } + if (segment_ids_host[i] >= static_cast(grad_shape_0_)) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids out of range of grad's first shape."; + } + } + cudaMemset(y_ptr, 0, output_elements_ * unit_grad_size_); + CalSparseSegmentGradCombination(kernel_type_, grad_ptr, segment_ids_ptr, indices_ptr, segment_pos_ptr, outer_size_, + inner_size_, idx_seg_elements_, output_dim0_, y_ptr, device_id_, stream); + return true; +} + +std::map>> + SparseSegmentGradOpsGpuKernelMod::kernel_attr_map_ = { + {Sparse_Segment_Sum_Grad, + {{KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}}}, + {Sparse_Segment_Sqrt_N_Grad, + {{KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + 
.AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat16), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat32), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseSegmentGradOpsGpuKernelMod::LaunchKernel}}}}; // kernel_attr_map_ + +std::vector SparseSegmentGradOpsGpuKernelMod::GetOpSupport() { + auto iter = kernel_attr_map_.find(kernel_type_); + if (iter == kernel_attr_map_.end()) { + MS_LOG(ERROR) << "For 'SparseSegmentGradOpsOp', only support these types: " << kernel::Map2Str(kernel_attr_map_) + << " currently, but got " << kernel_name_; + } + std::vector support_list; + (void)std::transform( + iter->second.begin(), iter->second.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSumGrad, []() { + return std::make_shared(Sparse_Segment_Sum_Grad); +}); +MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtNGrad, []() { + return std::make_shared(Sparse_Segment_Sqrt_N_Grad); +}); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h new file mode 100644 index 00000000000..acaf09db541 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h @@ -0,0 +1,99 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SparseSegmentGradOpsGpuKernelMod : public NativeGpuKernelMod { + public: + explicit SparseSegmentGradOpsGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {} + ~SparseSegmentGradOpsGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; + + protected: + void ResetResource() noexcept { + outer_size_ = 0; + inner_size_ = 0; + grad_elements_ = 0; + idx_seg_elements_ = 0; + output_dim0_ = 0; + output_elements_ = 0; + is_null_input_ = false; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + std::vector GetOpSupport() override; + + private: + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + + using SSGLaunchFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + size_t outer_size_{0}; + size_t inner_size_{0}; + size_t grad_elements_{0}; + size_t grad_shape_0_{0}; + size_t idx_seg_elements_{0}; + size_t output_dim0_{0}; + size_t output_elements_{0}; + size_t unit_grad_size_{1}; + size_t unit_idx_seg_size_{1}; + std::string kernel_type_{"Unknown"}; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + SSGLaunchFunc kernel_func_{}; + static std::map>> kernel_attr_map_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h index 00c1685be3d..058224c5dce 100644 --- a/mindspore/core/ops/core_ops.h +++ b/mindspore/core/ops/core_ops.h @@ -136,9 +136,6 @@ constexpr auto kEditDistance = "EditDistance"; constexpr auto kNextAfter = "NextAfter"; constexpr auto kMaximumGradGrad = "MaximumGradGrad"; constexpr auto kSparseSegmentMean = "SparseSegmentMean"; -constexpr auto kSparseSegmentSqrtN = "SparseSegmentSqrtN"; -constexpr auto kSparseSegmentSqrtNGrad = "SparseSegmentSqrtNGrad"; -constexpr auto kSparseSegmentSqrtNWithNumSegments = "SparseSegmentSqrtNWithNumSegments"; constexpr auto kTridiagonalMatMul = "TridiagonalMatMul"; constexpr auto kFFTWithSize = "FFTWithSize"; constexpr auto kTrace = "Trace"; @@ -348,6 +345,12 @@ constexpr auto kSparseMatrixSoftmax = "SparseMatrixSoftmax"; constexpr auto kSparseMatrixMatMul = "SparseMatrixMatMul"; constexpr auto kSparseMatrixSparseMatMul = "SparseMatrixSparseMatMul"; constexpr 
auto kSparseMatrixOrderingAMD = "SparseMatrixOrderingAMD"; +constexpr auto kSparseSegmentSum = "SparseSegmentSum"; +constexpr auto kSparseSegmentSumGrad = "SparseSegmentSumGrad"; +constexpr auto kSparseSegmentSumWithNumSegments = "SparseSegmentSumWithNumSegments"; +constexpr auto kSparseSegmentSqrtN = "SparseSegmentSqrtN"; +constexpr auto kSparseSegmentSqrtNGrad = "SparseSegmentSqrtNGrad"; +constexpr auto kSparseSegmentSqrtNWithNumSegments = "SparseSegmentSqrtNWithNumSegments"; // Sparse Grad ops constexpr auto kSparseAddGrad = "SparseAddGrad"; @@ -1040,6 +1043,14 @@ GVAR_DEF(PrimitivePtr, kPrimSparseMatrixSparseMatMul, std::make_shared("CSRSparseMatrixToDense")); GVAR_DEF(PrimitivePtr, kPrimSparseMatrixTranspose, std::make_shared(kSparseMatrixTranspose)); GVAR_DEF(PrimitivePtr, kPrimSparseMatrixOrderingAMD, std::make_shared(kSparseMatrixOrderingAMD)); +GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSum, std::make_shared("SparseSegmentSum")); +GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSumGrad, std::make_shared("SparseSegmentSumGrad")); +GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSumWithNumSegments, + std::make_shared("SparseSegmentSumWithNumSegments")); +GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtN, std::make_shared("SparseSegmentSqrtN")); +GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNGrad, std::make_shared("SparseSegmentSqrtNGrad")); +GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNWithNumSegments, + std::make_shared("SparseSegmentSqrtNWithNumSegments")); // Sparse Grad ops GVAR_DEF(PrimitivePtr, kPrimSparseAddGrad, std::make_shared(kSparseAddGrad)); @@ -1192,10 +1203,6 @@ GVAR_DEF(PrimitivePtr, kPrimBucketize, std::make_shared("Bucketize")) GVAR_DEF(PrimitivePtr, kPrimEinsum, std::make_shared("Einsum")); GVAR_DEF(PrimitivePtr, kPrimEinsumGrad, std::make_shared("EinsumGrad")); GVAR_DEF(PrimitivePtr, kPrimSparseSegmentMean, std::make_shared(kSparseSegmentMean)); -GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtN, std::make_shared("SparseSegmentSqrtN")); -GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNGrad, std::make_shared("SparseSegmentSqrtNGrad")); -GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNWithNumSegments, - std::make_shared("SparseSegmentSqrtNWithNumSegments")); GVAR_DEF(PrimitivePtr, kPrimTrace, std::make_shared("Trace")); GVAR_DEF(PrimitivePtr, kPrimTraceGrad, std::make_shared("TraceGrad")); GVAR_DEF(PrimitivePtr, kPrimTridiagonalMatMul, std::make_shared(kTridiagonalMatMul)); diff --git a/mindspore/core/ops/grad/sparse_segment_sqrt_n_grad.cc b/mindspore/core/ops/grad/sparse_segment_sqrt_n_grad.cc index 73eb06142fd..b5763003464 100644 --- a/mindspore/core/ops/grad/sparse_segment_sqrt_n_grad.cc +++ b/mindspore/core/ops/grad/sparse_segment_sqrt_n_grad.cc @@ -35,14 +35,22 @@ abstract::ShapePtr SparseSegmentSqrtNGradInferShape(const PrimitivePtr &prim, CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; auto output_dim0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; + (void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1, + prim->name()); + (void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual, + kInputIndex1, prim->name()); if (x_shape.size() < kInputIndex1) { - MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor x's rank is less than 1."; + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', " + << "tensor x's rank must be greater than 1, but got [" << x_shape.size() 
<< "]."; } if (output_dim0_shape.size() != kInputIndex0) { - MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor outputdim0 should be a scalar."; + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor output_dim0 should be a scalar, " + << "but got [" << output_dim0_shape.size() << "]."; } if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) { - MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor indices & segment_ids's ranks mismatch."; + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, " + << "but got indices [" << indices_shape[kInputIndex0] << "] " + << "and segment_ids [" << segment_ids_shape[kInputIndex0] << "]."; } if (!input_args[kInputIndex3]->BuildValue()->isa() && !input_args[kInputIndex3]->BuildValue()->isa()) { @@ -54,7 +62,8 @@ abstract::ShapePtr SparseSegmentSqrtNGradInferShape(const PrimitivePtr &prim, CheckAndConvertUtils::CheckTensorIntValue("output_dim0", output_dim0_value_ptr, prim_name); size_t dim_zero = static_cast(output_dim0_value_ptr_tensor[kInputIndex0]); if (dim_zero <= kInputIndex0) { - MS_EXCEPTION(ValueError) << "Input output_dim0 must > 0!"; + MS_EXCEPTION(ValueError) << "For '" << prim_name << "' , tensor output_dim0 must > 0, " + << "but got [" << dim_zero << "]."; } else { ShapeVector y_shape = x_shape; y_shape[kInputIndex0] = static_cast(dim_zero); @@ -62,7 +71,9 @@ abstract::ShapePtr SparseSegmentSqrtNGradInferShape(const PrimitivePtr &prim, } } else { std::vector output_shape = {-2}; - return std::make_shared(output_shape); + std::vector min_shape = {1}; + std::vector max_shape = {1}; + return std::make_shared(output_shape, min_shape, max_shape); } } @@ -72,12 +83,14 @@ TypePtr SparseSegmentSqrtNGradInferType(const PrimitivePtr &prim, const std::vec auto indices_type = input_args[kInputIndex1]->BuildType(); auto segment_ids_type = input_args[kInputIndex2]->BuildType(); auto output_dim0_type = input_args[kInputIndex3]->BuildType(); - (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, {kFloat16, kFloat32, kFloat64}, prim->name()); + const std::set valid_types = {kFloat16, kFloat32, kFloat64}; + const std::set common_valid_types = {kInt32, kInt64}; + CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name()); std::map types; (void)types.emplace("indices", indices_type); (void)types.emplace("segment_ids", segment_ids_type); (void)types.emplace("output_dim0", output_dim0_type); - (void)CheckAndConvertUtils::CheckTensorTypeSame(types, {kInt32}, prim->name()); + CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); return input_args[kInputIndex0]->BuildType(); } } // namespace diff --git a/mindspore/core/ops/grad/sparse_segment_sum_grad.cc b/mindspore/core/ops/grad/sparse_segment_sum_grad.cc new file mode 100644 index 00000000000..5241e40d3eb --- /dev/null +++ b/mindspore/core/ops/grad/sparse_segment_sum_grad.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/grad/sparse_segment_sum_grad.h" +#include "abstract/dshape.h" +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr SparseSegmentSumGradInferShape(const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto segment_ids_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + auto output_dim0_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; + (void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1, + prim->name()); + (void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual, + kInputIndex1, prim->name()); + if (grad_shape.size() < kInputIndex1) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', " + << "tensor grad's rank must be greater than 1, but got [" << grad_shape.size() << "]."; + } + if (output_dim0_shape.size() != kInputIndex0) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor output_dim0 should be a scalar, " + << "but got [" << output_dim0_shape.size() << "]."; + } + if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, " + << "but got indices [" << indices_shape[kInputIndex0] << "] " + << "and segment_ids [" << segment_ids_shape[kInputIndex0] << "]."; + } + if (!input_args[kInputIndex3]->BuildValue()->isa() && + !input_args[kInputIndex3]->BuildValue()->isa()) { + auto output_dim0_value = input_args[kInputIndex3]->cast(); + MS_EXCEPTION_IF_NULL(output_dim0_value); + auto output_dim0_value_ptr = output_dim0_value->BuildValue(); + MS_EXCEPTION_IF_NULL(output_dim0_value_ptr); + auto output_dim0_value_ptr_tensor = + CheckAndConvertUtils::CheckTensorIntValue("output_dim0", output_dim0_value_ptr, prim_name); + size_t dim_zero = output_dim0_value_ptr_tensor[kInputIndex0]; + if (dim_zero <= kInputIndex0) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "' , tensor output_dim0 must > 0, " + << "but got [" << dim_zero << "]."; + } else { + ShapeVector y_shape = grad_shape; + y_shape[kInputIndex0] = dim_zero; + return std::make_shared(y_shape); + } + } else { + std::vector output_shape = {-2}; + std::vector min_shape = {1}; + std::vector max_shape = {1}; + return std::make_shared(output_shape, min_shape, max_shape); + } +} + +TypePtr SparseSegmentSumGradInferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto grad_type = input_args[kInputIndex0]->BuildType(); + auto indices_type = input_args[kInputIndex1]->BuildType(); + auto segment_ids_type = input_args[kInputIndex2]->BuildType(); + auto output_dim0_type = input_args[kInputIndex3]->BuildType(); + const std::set valid_types = {kFloat16, kFloat32, kFloat64}; + const std::set 
common_valid_types = {kInt32, kInt64}; + CheckAndConvertUtils::CheckTensorTypeValid("grad", grad_type, valid_types, prim->name()); + std::map types; + (void)types.emplace("indices", indices_type); + (void)types.emplace("segment_ids", segment_ids_type); + (void)types.emplace("output_dim0", output_dim0_type); + CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + return input_args[kInputIndex0]->BuildType(); +} +} // namespace + +MIND_API_OPERATOR_IMPL(SparseSegmentSumGrad, BaseOperator); +AbstractBasePtr SparseSegmentSumGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + const int64_t input_num = kInputIndex4; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name); + auto types = SparseSegmentSumGradInferType(prim, input_args); + auto shapes = SparseSegmentSumGradInferShape(prim, input_args); + return abstract::MakeAbstract(shapes, types); +} +REGISTER_HOST_DEPENDS(kNameSparseSegmentSumGrad, {3}); +REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSumGrad, prim::kPrimSparseSegmentSumGrad, SparseSegmentSumGradInfer, nullptr, + true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/grad/sparse_segment_sum_grad.h b/mindspore/core/ops/grad/sparse_segment_sum_grad.h new file mode 100644 index 00000000000..ded6694e21f --- /dev/null +++ b/mindspore/core/ops/grad/sparse_segment_sum_grad.h @@ -0,0 +1,46 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_GRAD_H_ +#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_GRAD_H_ + +#include +#include +#include +#include +#include +#include +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseSegmentSumGrad = "SparseSegmentSumGrad"; +class MIND_API SparseSegmentSumGrad : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseSegmentSumGrad); + SparseSegmentSumGrad() : BaseOperator(kNameSparseSegmentSumGrad) { + InitIOName({"grad", "indices", "segment_ids", "output_dim0"}, {"output"}); + } +}; + +abstract::AbstractBasePtr SparseSegmentSumGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimSparseSegmentSumGradPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_GRAD_H_ diff --git a/mindspore/core/ops/sparse_segment_sqrt_n.cc b/mindspore/core/ops/sparse_segment_sqrt_n.cc index de88f4245f3..d2e77b569b3 100644 --- a/mindspore/core/ops/sparse_segment_sqrt_n.cc +++ b/mindspore/core/ops/sparse_segment_sqrt_n.cc @@ -1,105 +1,110 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include - -#include "ops/sparse_segment_sqrt_n.h" -#include "abstract/dshape.h" -#include "ops/op_utils.h" -#include "utils/check_convert_utils.h" -#include "utils/tensor_construct_utils.h" -#include "abstract/ops/primitive_infer_map.h" -#include "mindapi/src/helper.h" - -namespace mindspore { -namespace ops { -namespace { -abstract::ShapePtr SparseSegmentSqrtNInferShape(const PrimitivePtr &prim, - const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(prim); - auto prim_name = prim->name(); - constexpr size_t kRankOne = 1; - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; - auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; - auto segment_ids_shape = - CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; - if (indices_shape.size() != kRankOne) { - MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of indices should be 1."; - } - if (segment_ids_shape.size() != kRankOne) { - MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of segment_ids should be 1."; - } - if (x_shape.size() < kRankOne) { - MS_EXCEPTION(ValueError) << "For '" << prim_name << "', x's rank is less than 1."; - } - if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) { - MS_EXCEPTION(ValueError) << "For '" << prim_name << "', ranks of indices and segment_ids mismatch."; - } - if (!input_args[kInputIndex2]->BuildValue()->isa() && - !input_args[kInputIndex2]->BuildValue()->isa()) { - auto segment_ids_value_ptr = input_args[kInputIndex2]->BuildValue(); - MS_EXCEPTION_IF_NULL(segment_ids_value_ptr); - auto segment_ids_value_ptr_tensor = - CheckAndConvertUtils::CheckTensorIntValue("segment_ids", segment_ids_value_ptr, prim->name()); - size_t dim_zero = static_cast(segment_ids_value_ptr_tensor.back()) + kRankOne; - if (dim_zero < kRankOne) { - MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must >= 0!"; - } else { - ShapeVector y_shape = x_shape; - y_shape[kInputIndex0] = static_cast(dim_zero); - return std::make_shared(y_shape); - } - } else { - std::vector output_shape = {-2}; - return std::make_shared(output_shape); - } -} - -TypePtr SparseSegmentSqrtNInferType(const PrimitivePtr &prim, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(prim); - auto x_type = input_args[kInputIndex0]->BuildType(); - auto indices_type = input_args[kInputIndex1]->BuildType(); - auto segment_ids_type = input_args[kInputIndex2]->BuildType(); - const std::set valid_types = {kFloat16, kFloat32, kFloat64}; - const std::set common_valid_types = {kInt32, kInt64}; - (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name()); - (void)CheckAndConvertUtils::CheckTensorTypeValid("indices", indices_type, common_valid_types, prim->name()); - (void)CheckAndConvertUtils::CheckTensorTypeValid("segment_ids", segment_ids_type, common_valid_types, prim->name()); - return input_args[kInputIndex0]->BuildType(); -} -} // 
namespace - -MIND_API_OPERATOR_IMPL(SparseSegmentSqrtN, BaseOperator); -AbstractBasePtr SparseSegmentSqrtNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim, - const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(prim); - auto prim_name = prim->name(); - const int64_t input_num = static_cast(kInputIndex3); - CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name); - auto types = SparseSegmentSqrtNInferType(prim, input_args); - auto shapes = SparseSegmentSqrtNInferShape(prim, input_args); - return abstract::MakeAbstract(shapes, types); -} -REGISTER_HOST_DEPENDS(kNameSparseSegmentSqrtN, {2}); -REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSqrtN, prim::kPrimSparseSegmentSqrtN, SparseSegmentSqrtNInfer, nullptr, true); -} // namespace ops -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "ops/sparse_segment_sqrt_n.h" +#include "abstract/dshape.h" +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr SparseSegmentSqrtNInferShape(const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto segment_ids_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + (void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1, + prim->name()); + (void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual, + kInputIndex1, prim->name()); + if (x_shape.size() < kInputIndex1) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', " + << "x's rank must be greater than 1, but got [" << x_shape.size() << "]."; + } + if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, " + << "but got indices [" << indices_shape[kInputIndex0] << "] " + << "and segment_ids [" << segment_ids_shape[kInputIndex0] << "]."; + } + if (!input_args[kInputIndex2]->BuildValue()->isa() && + !input_args[kInputIndex2]->BuildValue()->isa()) { + auto segment_ids_value_ptr = input_args[kInputIndex2]->BuildValue(); + MS_EXCEPTION_IF_NULL(segment_ids_value_ptr); + auto segment_ids_value_ptr_tensor = + CheckAndConvertUtils::CheckTensorIntValue("segment_ids", segment_ids_value_ptr, prim->name()); + size_t dim_zero = segment_ids_value_ptr_tensor.back() + 
kInputIndex1; + if (dim_zero < kInputIndex1) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must be greater or equal to 0, " + << "but got [" << dim_zero << "]."; + } else { + ShapeVector y_shape = x_shape; + y_shape[kInputIndex0] = dim_zero; + return std::make_shared(y_shape); + } + } else { + std::vector output_shape = {-2}; + std::vector min_shape = {1}; + std::vector max_shape = {1}; + return std::make_shared(output_shape, min_shape, max_shape); + } +} + +TypePtr SparseSegmentSqrtNInferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto x_type = input_args[kInputIndex0]->BuildType(); + auto indices_type = input_args[kInputIndex1]->BuildType(); + auto segment_ids_type = input_args[kInputIndex2]->BuildType(); + const std::set valid_types = {kFloat16, kFloat32, kFloat64}; + const std::set common_valid_types = {kInt32, kInt64}; + CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name()); + std::map types; + (void)types.emplace("indices", indices_type); + (void)types.emplace("segment_ids", segment_ids_type); + CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + return input_args[kInputIndex0]->BuildType(); +} +} // namespace + +MIND_API_OPERATOR_IMPL(SparseSegmentSqrtN, BaseOperator); +AbstractBasePtr SparseSegmentSqrtNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + const int64_t input_num = kInputIndex3; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name); + auto types = SparseSegmentSqrtNInferType(prim, input_args); + auto shapes = SparseSegmentSqrtNInferShape(prim, input_args); + return abstract::MakeAbstract(shapes, types); +} +REGISTER_HOST_DEPENDS(kNameSparseSegmentSqrtN, {2}); +REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSqrtN, prim::kPrimSparseSegmentSqrtN, SparseSegmentSqrtNInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/sparse_segment_sqrt_n_with_num_segments.cc b/mindspore/core/ops/sparse_segment_sqrt_n_with_num_segments.cc index b4cbc9dd4af..653a3fcccb0 100644 --- a/mindspore/core/ops/sparse_segment_sqrt_n_with_num_segments.cc +++ b/mindspore/core/ops/sparse_segment_sqrt_n_with_num_segments.cc @@ -1,122 +1,127 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include -#include - -#include "ops/sparse_segment_sqrt_n_with_num_segments.h" -#include "abstract/dshape.h" -#include "ops/op_utils.h" -#include "utils/check_convert_utils.h" -#include "utils/tensor_construct_utils.h" -#include "abstract/ops/primitive_infer_map.h" -#include "mindapi/src/helper.h" - -namespace mindspore { -namespace ops { -namespace { -abstract::ShapePtr SparseSegmentSqrtNWithNumSegmentsInferShape(const PrimitivePtr &prim, - const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(prim); - auto prim_name = prim->name(); - constexpr size_t kRankOne = 1; - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; - auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; - auto segment_ids_shape = - CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; - auto num_segments_shape = - CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; - if (indices_shape.size() != kRankOne) { - MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of indices should be 1."; - } - if (segment_ids_shape.size() != kRankOne) { - MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of segment_ids should be 1."; - } - if (x_shape.size() < kRankOne) { - MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of x cannot be less than 1."; - } - if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) { - MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of indices and segment_ids mismatch."; - } - if (num_segments_shape.size() > kRankOne) { - MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments should be at most 1-D."; - } - if (num_segments_shape.size() == kRankOne) { - if (num_segments_shape[kInputIndex0] != kRankOne) { - MS_EXCEPTION(ValueError) << "For " << prim_name << ", the num element of num_segments should be 1."; - } - } - if (!input_args[kInputIndex3]->BuildValue()->isa() && - !input_args[kInputIndex3]->BuildValue()->isa()) { - auto num_segments_value = input_args[kInputIndex3]->cast(); - MS_EXCEPTION_IF_NULL(num_segments_value); - auto num_segments_value_ptr = num_segments_value->BuildValue(); - MS_EXCEPTION_IF_NULL(num_segments_value_ptr); - auto num_segments_value_ptr_tensor = - CheckAndConvertUtils::CheckTensorIntValue("num_segments", num_segments_value_ptr, prim->name()); - size_t dim_zero = static_cast(num_segments_value_ptr_tensor.back()); - if (dim_zero < kRankOne) { - MS_EXCEPTION(ValueError) << "For " << prim_name - << ", num_segments must be bigger than the largest id of segment_ids."; - } else { - ShapeVector y_shape = x_shape; - y_shape[kInputIndex0] = static_cast(dim_zero); - return std::make_shared(y_shape); - } - } else { - std::vector output_shape = {-2}; - return std::make_shared(output_shape); - } -} - -TypePtr SparseSegmentSqrtNWithNumSegmentsInferType(const PrimitivePtr &prim, - const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(prim); - auto x_type = input_args[kInputIndex0]->BuildType(); - auto indices_type = input_args[kInputIndex1]->BuildType(); - auto segment_ids_type = input_args[kInputIndex2]->BuildType(); - auto num_segments_type = input_args[kInputIndex3]->BuildType(); - const std::set valid_types = {kFloat16, kFloat32, kFloat64}; - std::map types; - (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name()); - 
(void)types.emplace("indices", indices_type); - (void)types.emplace("segment_ids", segment_ids_type); - (void)types.emplace("num_segments", num_segments_type); - (void)CheckAndConvertUtils::CheckTensorTypeSame(types, {kInt32, kInt64}, prim->name()); - return input_args[kInputIndex0]->BuildType(); -} -} // namespace - -MIND_API_OPERATOR_IMPL(SparseSegmentSqrtNWithNumSegments, BaseOperator); -AbstractBasePtr SparseSegmentSqrtNWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim, - const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(prim); - auto prim_name = prim->name(); - const int64_t input_num = static_cast(kInputIndex4); - CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name); - auto types = SparseSegmentSqrtNWithNumSegmentsInferType(prim, input_args); - auto shapes = SparseSegmentSqrtNWithNumSegmentsInferShape(prim, input_args); - return abstract::MakeAbstract(shapes, types); -} -REGISTER_HOST_DEPENDS(kNameSparseSegmentSqrtNWithNumSegments, {3}); -REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSqrtNWithNumSegments, prim::kPrimSparseSegmentSqrtNWithNumSegments, - SparseSegmentSqrtNWithNumSegmentsInfer, nullptr, true); -} // namespace ops -} // namespace mindspore +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include "ops/sparse_segment_sqrt_n_with_num_segments.h" +#include "abstract/dshape.h" +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr SparseSegmentSqrtNWithNumSegmentsInferShape(const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto segment_ids_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + auto num_segments_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; + (void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape.size(), kEqual, kInputIndex1, prim->name()); + (void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", segment_ids_shape.size(), kEqual, kInputIndex1, + prim->name()); + if (x_shape.size() < kInputIndex1) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', " + << "x's rank must be greater than 1, but got [" << x_shape.size() << "]."; + } + if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, " + << "but got indices [" << indices_shape[kInputIndex0] << "] " + << "and segment_ids [" << segment_ids_shape[kInputIndex0] << "]."; + } + if (num_segments_shape.size() > kInputIndex1) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments should be at most 1-D, but got [" + << num_segments_shape.size() << "]."; + } + if (!input_args[kInputIndex3]->BuildValue()->isa() && + !input_args[kInputIndex3]->BuildValue()->isa()) { + if (num_segments_shape.size() == kInputIndex1) { + if (num_segments_shape[kInputIndex0] != kInputIndex1) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", the num element of num_segments should be 1, but got [" + << num_segments_shape[kInputIndex0] << "]."; + } + } + auto num_segments_value = input_args[kInputIndex3]->cast(); + MS_EXCEPTION_IF_NULL(num_segments_value); + auto num_segments_value_ptr = num_segments_value->BuildValue(); + MS_EXCEPTION_IF_NULL(num_segments_value_ptr); + auto num_segments_value_ptr_tensor = + CheckAndConvertUtils::CheckTensorIntValue("num_segments", num_segments_value_ptr, prim->name()); + size_t dim_zero = num_segments_value_ptr_tensor.back(); + if (dim_zero < kInputIndex1) { + MS_EXCEPTION(ValueError) << "For " << prim_name + << ", num_segments must bigger than the last number of segment_ids, " + << "but got " << dim_zero << "."; + } else { + ShapeVector y_shape = x_shape; + y_shape[kInputIndex0] = dim_zero; + return std::make_shared(y_shape); + } + } else { + std::vector output_shape = {-2}; + std::vector min_shape = {1}; + std::vector max_shape = {1}; + return std::make_shared(output_shape, min_shape, max_shape); + } +} + +TypePtr SparseSegmentSqrtNWithNumSegmentsInferType(const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto x_type = input_args[kInputIndex0]->BuildType(); + auto indices_type = 
input_args[kInputIndex1]->BuildType(); + auto segment_ids_type = input_args[kInputIndex2]->BuildType(); + auto num_segments_type = input_args[kInputIndex3]->BuildType(); + const std::set valid_types = {kFloat16, kFloat32, kFloat64}; + const std::set common_valid_types = {kInt32, kInt64}; + CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name()); + std::map types; + (void)types.emplace("indices", indices_type); + (void)types.emplace("segment_ids", segment_ids_type); + (void)types.emplace("num_segments", num_segments_type); + CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + return input_args[kInputIndex0]->BuildType(); +} +} // namespace + +MIND_API_OPERATOR_IMPL(SparseSegmentSqrtNWithNumSegments, BaseOperator); +AbstractBasePtr SparseSegmentSqrtNWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + const int64_t input_num = kInputIndex4; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name); + auto types = SparseSegmentSqrtNWithNumSegmentsInferType(prim, input_args); + auto shapes = SparseSegmentSqrtNWithNumSegmentsInferShape(prim, input_args); + return abstract::MakeAbstract(shapes, types); +} +REGISTER_HOST_DEPENDS(kNameSparseSegmentSqrtNWithNumSegments, {3}); +REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSqrtNWithNumSegments, prim::kPrimSparseSegmentSqrtNWithNumSegments, + SparseSegmentSqrtNWithNumSegmentsInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/sparse_segment_sum.cc b/mindspore/core/ops/sparse_segment_sum.cc new file mode 100644 index 00000000000..7683486c3f4 --- /dev/null +++ b/mindspore/core/ops/sparse_segment_sum.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ops/sparse_segment_sum.h" +#include "abstract/ops/primitive_infer_map.h" +#include "ops/op_utils.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr SparseSegmentSumInferShape(const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto segment_ids_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + (void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1, + prim->name()); + (void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual, + kInputIndex1, prim->name()); + if (x_shape.size() < kInputIndex1) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', " + << "x's rank must be greater than 1, but got [" << x_shape.size() << "]."; + } + if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, " + << "but got indices [" << indices_shape[kInputIndex0] << "] " + << "and segment_ids [" << segment_ids_shape[kInputIndex0] << "]."; + } + if (!input_args[kInputIndex2]->BuildValue()->isa() && + !input_args[kInputIndex2]->BuildValue()->isa()) { + auto segment_ids_value_ptr = input_args[kInputIndex2]->BuildValue(); + MS_EXCEPTION_IF_NULL(segment_ids_value_ptr); + auto segment_ids_value_ptr_tensor = + CheckAndConvertUtils::CheckTensorIntValue("segment_ids", segment_ids_value_ptr, prim->name()); + size_t dim_zero = segment_ids_value_ptr_tensor.back() + kInputIndex1; + if (dim_zero < kInputIndex1) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must be greater or equal to 0, " + << "but got [" << dim_zero << "]."; + } else { + ShapeVector y_shape = x_shape; + y_shape[kInputIndex0] = dim_zero; + return std::make_shared(y_shape); + } + } else { + std::vector output_shape = {-2}; + std::vector min_shape = {1}; + std::vector max_shape = {1}; + return std::make_shared(output_shape, min_shape, max_shape); + } +} + +TypePtr SparseSegmentSumInferType(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + auto x_type = input_args[kInputIndex0]->BuildType(); + auto indices_type = input_args[kInputIndex1]->BuildType(); + auto segment_ids_type = input_args[kInputIndex2]->BuildType(); + const std::set valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kFloat16, kFloat32, kFloat64}; + const std::set common_valid_types = {kInt32, kInt64}; + CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name); + std::map types; + (void)types.emplace("indices", indices_type); + (void)types.emplace("segment_ids", segment_ids_type); + CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + return input_args[kInputIndex0]->BuildType(); +} +} // namespace + +MIND_API_OPERATOR_IMPL(SparseSegmentSum, BaseOperator); +AbstractBasePtr SparseSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + const 
int64_t input_num = kInputIndex3; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name); + auto types = SparseSegmentSumInferType(prim, input_args); + auto shapes = SparseSegmentSumInferShape(prim, input_args); + return abstract::MakeAbstract(shapes, types); +} +REGISTER_HOST_DEPENDS(kNameSparseSegmentSum, {2}); +REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSum, prim::kPrimSparseSegmentSum, SparseSegmentSumInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/sparse_segment_sum.h b/mindspore/core/ops/sparse_segment_sum.h new file mode 100644 index 00000000000..a70b8ada748 --- /dev/null +++ b/mindspore/core/ops/sparse_segment_sum.h @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_ +#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_ +#include +#include +#include +#include +#include +#include "ops/base_operator.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseSegmentSum = "SparseSegmentSum"; +/// \brief Computes the sum along sparse segments of a tensor. +/// Refer to Python API @ref mindspore.ops.SparseSegmentSum for more details. +class MIND_API SparseSegmentSum : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseSegmentSum); + /// \brief Constructor. + SparseSegmentSum() : BaseOperator(kNameSparseSegmentSum) { InitIOName({"x", "indices", "segment_ids"}, {"y"}); } +}; + +AbstractBasePtr SparseSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_ diff --git a/mindspore/core/ops/sparse_segment_sum_with_num_segments.cc b/mindspore/core/ops/sparse_segment_sum_with_num_segments.cc new file mode 100644 index 00000000000..6d5e162415f --- /dev/null +++ b/mindspore/core/ops/sparse_segment_sum_with_num_segments.cc @@ -0,0 +1,120 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include "ops/sparse_segment_sum_with_num_segments.h" +#include "abstract/ops/primitive_infer_map.h" +#include "ops/op_utils.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::ShapePtr SparseSegmentSumWithNumSegmentsInferShape(const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape]; + auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + auto segment_ids_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape]; + auto num_segments_shape = + CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape]; + (void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape.size(), kEqual, kInputIndex1, prim->name()); + (void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", segment_ids_shape.size(), kEqual, kInputIndex1, + prim->name()); + if (x_shape.size() < kInputIndex1) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', " + << "x's rank must be greater than 1, but got [" << x_shape.size() << "]."; + } + if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, " + << "but got indices [" << indices_shape[kInputIndex0] << "] " + << "and segment_ids [" << segment_ids_shape[kInputIndex0] << "]."; + } + if (num_segments_shape.size() > kInputIndex1) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments should be at most 1-D, but got [" + << num_segments_shape.size() << "]."; + } + if (!input_args[kInputIndex3]->BuildValue()->isa() && + !input_args[kInputIndex3]->BuildValue()->isa()) { + if (num_segments_shape.size() == kInputIndex1) { + if (num_segments_shape[kInputIndex0] != kInputIndex1) { + MS_EXCEPTION(ValueError) << "For " << prim_name << ", the num element of num_segments should be 1, but got [" + << num_segments_shape[kInputIndex0] << "]."; + } + } + auto num_segments_value = input_args[kInputIndex3]->cast(); + MS_EXCEPTION_IF_NULL(num_segments_value); + auto num_segments_value_ptr = num_segments_value->BuildValue(); + MS_EXCEPTION_IF_NULL(num_segments_value_ptr); + auto num_segments_value_ptr_tensor = + CheckAndConvertUtils::CheckTensorIntValue("num_segments", num_segments_value_ptr, prim->name()); + size_t dim_zero = num_segments_value_ptr_tensor.back(); + if (dim_zero < kInputIndex1) { + MS_EXCEPTION(ValueError) << "For " << prim_name + << ", num_segments must bigger than the last number of segment_ids, " + << "but got " << dim_zero << "."; + } else { + ShapeVector y_shape = x_shape; + y_shape[kInputIndex0] = dim_zero; + return std::make_shared(y_shape); + } + } else { + std::vector output_shape = {-2}; + std::vector min_shape = {1}; + std::vector max_shape = {1}; + return std::make_shared(output_shape, min_shape, max_shape); + } +} + +TypePtr SparseSegmentSumWithNumSegmentsInferType(const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + auto x_type = input_args[kInputIndex0]->BuildType(); + auto indices_type = input_args[kInputIndex1]->BuildType(); + auto segment_ids_type = input_args[kInputIndex2]->BuildType(); + auto num_segments_type = 
input_args[kInputIndex3]->BuildType(); + const std::set valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kFloat16, kFloat32, kFloat64}; + const std::set common_valid_types = {kInt32, kInt64}; + CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name); + std::map types; + (void)types.emplace("indices", indices_type); + (void)types.emplace("segment_ids", segment_ids_type); + (void)types.emplace("num_segments", num_segments_type); + CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name()); + return input_args[kInputIndex0]->BuildType(); +} +} // namespace + +MIND_API_OPERATOR_IMPL(SparseSegmentSumWithNumSegments, BaseOperator); +AbstractBasePtr SparseSegmentSumWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + constexpr size_t kInputsNum = kInputIndex4; + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, prim_name); + auto types = SparseSegmentSumWithNumSegmentsInferType(prim, input_args); + auto shapes = SparseSegmentSumWithNumSegmentsInferShape(prim, input_args); + return abstract::MakeAbstract(shapes, types); +} +REGISTER_HOST_DEPENDS(kNameSparseSegmentSumWithNumSegments, {3}); +REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSumWithNumSegments, prim::kPrimSparseSegmentSumWithNumSegments, + SparseSegmentSumWithNumSegmentsInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/sparse_segment_sum_with_num_segments.h b/mindspore/core/ops/sparse_segment_sum_with_num_segments.h new file mode 100644 index 00000000000..56e444055d3 --- /dev/null +++ b/mindspore/core/ops/sparse_segment_sum_with_num_segments.h @@ -0,0 +1,47 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_ +#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_ +#include +#include +#include +#include +#include +#include "ops/base_operator.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSparseSegmentSumWithNumSegments = "SparseSegmentSumWithNumSegments"; +/// \brief Computes the sum along sparse segments of a tensor, but it is allowed to miss id in segment_ids. +/// Refer to Python API @ref mindspore.ops.SparseSegmentSumWithNumSegments for more details. +class MIND_API SparseSegmentSumWithNumSegments : public BaseOperator { + public: + MIND_API_BASE_MEMBER(SparseSegmentSumWithNumSegments); + /// \brief Constructor. 
+ SparseSegmentSumWithNumSegments() : BaseOperator(kNameSparseSegmentSumWithNumSegments) { + InitIOName({"x", "indices", "segment_ids", "num_segments"}, {"y"}); + } +}; +AbstractBasePtr SparseSegmentSumWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimSparseSegmentSumWithNumSegmentsPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_ diff --git a/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py b/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py index 37225ec2c19..8b368938669 100644 --- a/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py +++ b/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py @@ -20,6 +20,8 @@ from mindspore.ops.operations.sparse_ops import SparseTensorToCSRSparseMatrix from mindspore.ops.operations.sparse_ops import SparseToDenseV2 from mindspore.ops.operations.sparse_ops import SparseSoftmax from mindspore.ops.operations.sparse_ops import SparseDenseCwiseAdd +from mindspore.ops.operations.sparse_ops import SparseSegmentSum +from mindspore.ops.operations.sparse_ops import SparseSegmentSumWithNumSegments from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments @@ -31,6 +33,7 @@ from .. import operations as P from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..operations import _grad_ops as G from .._grad.grad_base import bprop_getters +from .._utils.utils import is_shape_unknown # Unused parameters are placeholders. 
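For reference, the forward contract that the bprop registrations in the following hunks differentiate can be sketched outside MindSpore with NumPy. This is a minimal illustrative sketch of the usual SparseSegmentSum semantics (rows of x gathered by indices and summed per sorted segment id); the helper name is invented here and nothing below is part of the patch itself:

import numpy as np

def sparse_segment_sum_reference(x, indices, segment_ids):
    # y[s] accumulates x[indices[j]] for every j with segment_ids[j] == s.
    # segment_ids is assumed sorted and non-negative, matching the infer-shape checks earlier in the patch.
    num_segments = int(segment_ids[-1]) + 1
    y = np.zeros((num_segments,) + x.shape[1:], dtype=x.dtype)
    for idx, seg in zip(indices, segment_ids):
        y[seg] += x[idx]
    return y

x = np.array([[0., 1.], [2., 3.], [4., 5.]], dtype=np.float32)
indices = np.array([0, 1, 2], dtype=np.int32)
segment_ids = np.array([0, 0, 1], dtype=np.int32)
print(sparse_segment_sum_reference(x, indices, segment_ids))  # -> [[2. 4.] [4. 5.]]

The bprops in the next hunks only need dimension 0 of x (output_dim0) to scatter dout back into an x-shaped gradient, which is why they fall back to TensorShape plus a Cast when the static shape is unknown.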
@@ -103,13 +106,18 @@ def get_bprop_sparse_segment_sqrt_n(self): """Grad definition for `SparseSegmentSqrtN` operation.""" input_grad = G.SparseSegmentSqrtNGrad() shape = P.Shape() + dyn_shape_op = P.TensorShape() def bprop(x, indices, segment_ids, out, dout): - output_dim0 = F.scalar_to_tensor(shape(x)[0], mstype.int32) + shape_x = shape(x) + if is_shape_unknown(shape_x): + shape_x = dyn_shape_op(x) + output_dim0 = P.Cast()(shape_x[0], mstype.int32) indices = F.cast(indices, mstype.int32) segment_ids = F.cast(segment_ids, mstype.int32) dx = input_grad(dout, indices, segment_ids, output_dim0) - return dx, zeros_like(indices), zeros_like(segment_ids) + all_d = (dx, zeros_like(indices), zeros_like(segment_ids)) + return all_d return bprop @@ -119,9 +127,55 @@ def get_bprop_sparse_segment_sqrt_n_with_num_segments(self): """Grad definition for `SparseSegmentSqrtNWithNumSegments` operation.""" input_grad = G.SparseSegmentSqrtNGrad() shape = P.Shape() + dyn_shape_op = P.TensorShape() def bprop(x, indices, segment_ids, num_segments, out, dout): - output_dim0 = F.scalar_to_tensor(shape(x)[0], mstype.int32) + shape_x = shape(x) + if is_shape_unknown(shape_x): + shape_x = dyn_shape_op(x) + output_dim0 = P.Cast()(shape_x[0], mstype.int32) + indices = F.cast(indices, mstype.int32) + segment_ids = F.cast(segment_ids, mstype.int32) + dx = input_grad(dout, indices, segment_ids, output_dim0) + all_d = (dx, zeros_like(indices), zeros_like(segment_ids), zeros_like(num_segments)) + return all_d + + return bprop + + +@bprop_getters.register(SparseSegmentSum) +def get_bprop_sparse_segment_sum(self): + """Grad definition for `SparseSegmentSum` operation.""" + input_grad = G.SparseSegmentSumGrad() + shape = P.Shape() + dyn_shape_op = P.TensorShape() + + def bprop(x, indices, segment_ids, out, dout): + shape_x = shape(x) + if is_shape_unknown(shape_x): + shape_x = dyn_shape_op(x) + output_dim0 = P.Cast()(shape_x[0], mstype.int32) + indices = F.cast(indices, mstype.int32) + segment_ids = F.cast(segment_ids, mstype.int32) + dx = input_grad(dout, indices, segment_ids, output_dim0) + all_d = (dx, zeros_like(indices), zeros_like(segment_ids)) + return all_d + + return bprop + + +@bprop_getters.register(SparseSegmentSumWithNumSegments) +def get_bprop_sparse_segment_sum_with_num_segments(self): + """Grad definition for `SparseSegmentSumWithNumSegments` operation.""" + input_grad = G.SparseSegmentSumGrad() + shape = P.Shape() + dyn_shape_op = P.TensorShape() + + def bprop(x, indices, segment_ids, num_segments, out, dout): + shape_x = shape(x) + if is_shape_unknown(shape_x): + shape_x = dyn_shape_op(x) + output_dim0 = P.Cast()(shape_x[0], mstype.int32) indices = F.cast(indices, mstype.int32) segment_ids = F.cast(segment_ids, mstype.int32) dx = input_grad(dout, indices, segment_ids, output_dim0) diff --git a/mindspore/python/mindspore/ops/operations/_grad_ops.py b/mindspore/python/mindspore/ops/operations/_grad_ops.py index 31d8112bf29..d95a563f763 100644 --- a/mindspore/python/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/python/mindspore/ops/operations/_grad_ops.py @@ -2017,7 +2017,7 @@ class UpsampleNearest3DGrad(Primitive): """ Upsample the 3-D gradient data with the nearest neighbor interpolation algorithm. - Args: + Args: input_size (listInt): An required listInt. contain 5 elements: [min_batch, channels, depth, height, width]. Must: input_size[0] == grad_output_tensor_size[0], input_size[1] == grad_output_tensor_size[1]. output_size (listInt): An optional listInt. Defaults to none. 
@@ -2030,7 +2030,8 @@ class UpsampleNearest3DGrad(Primitive): The scale array along each dimension, contain 3 elements: scale_depth, scale_height, scale_width. The number of elements of 'scales' should be the same as the rank of input 'grad_output'. One of 'scales' and 'output_size' MUST be specified and it is an error if both are specified. - Inputs: + + Inputs: - **grad_output** (Tensor) - Tensor of shape [N, C, D, H, W], Must be one of the following types: float16, float32, float64. @@ -3486,6 +3487,50 @@ class MedianGrad(Primitive): self.init_prim_io_names(inputs=['y_grad', 'x', 'y', 'indices'], outputs=['x_grad']) +class SparseSegmentSumGrad(Primitive): + """ + Computes gradients for SparseSegmentSumGrad operation. + + Inputs: + - **grad** (Tensor) - A tensor. + - **indices** (Tensor) - Indices is a 1-D tensor. Must be one of the following types: int32, int64. + Has same rank as segment_ids. The shape should be :math:`(N,)`. + - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor. Must be one of the following types: int32, int64. + Values should be sorted and can be repeated. The shape should be :math:`(N,)`. + - **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentSum op. + + Outputs: + A Tensor. Has the same type as `grad` . + Has same shape as `grad`, except for dimension 0 which is the value of `output_dim0`. + + Raises: + TypeError: If `grad` or `indices` or `segment_ids` or `output_dim0` is not a tensor. + TypeError: If the dtype of `grad` is not any of the following data types: {float16, float32, float64}. + TypeError: If the dtype of `indices` and `segment_ids` and `output_dim0` is not int32 or int64. + ValueError: If dimension size of `grad` less than 1. + ValueError: If rank of `indices` or `segment_ids` is not 1. + ValueError: If dimension size of `output_dim0` is not 0. + ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`. + ValueError: If `segment_ids` is not sorted. + ValueError: If the last number of `segment_ids` is out of range of grad's first shape. + ValueError: If `indices` is bigger than or equal to `output_dim0`. + + Supported Platforms: + ``GPU`` + """ + __mindspore_signature__ = ( + sig.make_sig('grad', dtype=sig.sig_dtype.T1), + sig.make_sig('indices', dtype=sig.sig_dtype.T), + sig.make_sig('segment_ids', dtype=sig.sig_dtype.T), + sig.make_sig('output_dim0', dtype=sig.sig_dtype.T) + ) + + @prim_attr_register + def __init__(self): + """Initialize SparseSegmentSumGrad""" + self.init_prim_io_names(inputs=['grad', 'indices', 'segment_ids', 'output_dim0'], outputs=['y']) + + class SparseSegmentSqrtNGrad(Primitive): """ Computes gradients for SparseSegmentSqrtNGrad operation. @@ -3513,12 +3558,12 @@ class SparseSegmentSqrtNGrad(Primitive): ValueError: If rank of `indices` or `segment_ids` is not 1. ValueError: If dimension size of `output_dim0` is not 0. ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`. - ValueError: If indices in `segment_ids` are not contiguous or do not start from 0. ValueError: If `segment_ids` is not sorted. - ValueError: If `indices` is out of range of `output_dim0`. + ValueError: If the last number of `segment_ids` is out of range of x's first shape. + ValueError: If `indices` is bigger than or equal to `output_dim0`. 
Supported Platforms: - ``Ascend`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` """ @prim_attr_register diff --git a/mindspore/python/mindspore/ops/operations/sparse_ops.py b/mindspore/python/mindspore/ops/operations/sparse_ops.py index f010652b3bd..4e9f8657a1b 100644 --- a/mindspore/python/mindspore/ops/operations/sparse_ops.py +++ b/mindspore/python/mindspore/ops/operations/sparse_ops.py @@ -18,9 +18,10 @@ """Operators for sparse operators.""" -from mindspore._checkparam import Validator as validator -from mindspore.common import dtype as mstype -from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register, Primitive +from ..._checkparam import Validator as validator +from ...common import dtype as mstype +from .. import signature as sig +from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive class SparseDenseCwiseAdd(Primitive): @@ -1122,6 +1123,115 @@ class SparseConcat(Primitive): validator.check_value_type("concat_dim", concat_dim, [int], self.name) +class SparseSegmentSum(Primitive): + """ + Computes the sum along sparse segments of a tensor. + + Inputs: + - **x** (Tensor) - A tensor. + - **indices** (Tensor) - Indices is a 1-D tensor. Must be one of the following types: int32, int64. + Has same rank as segment_ids. The shape should be :math:`(N,)`. + - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor. Must be one of the following types: int32, int64. + Values should be sorted and can be repeated. The shape should be :math:`(N,)`. + + Outputs: + A Tensor. Has the same type as `x` . + Has same shape as `x`, except for dimension 0 which is the number of segments. + + Raises: + TypeError: If `x` or `indices` or `segment_ids` is not a tensor. + TypeError: If the dtype of `indices` and `segment_ids` is not int32 or int64. + ValueError: If dimension size of `x` less than 1. + ValueError: If any of `indices` and `segment_ids` is not a 1-D tensor. + ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`. + ValueError: If indices in `segment_ids` are not contiguous or do not start from 0. + ValueError: If `segment_ids` is not sorted. + ValueError: If `indices` is out of range of x's first shape. + + Supported Platforms: + ``GPU`` + + Examples: + >>> x = Tensor([[0, 1, 2], [1, 2, 3], [3, 6, 7]], dtype=ms.float32) + >>> indices = Tensor([0, 1, 2], dtype=ms.int32) + >>> segment_ids = Tensor([0, 1, 1], dtype=ms.int32) + >>> sparse_segment_sum = ops.SparseSegmentSum() + >>> out = sparse_segment_sum(x, indices, segment_ids) + >>> print(out) + [[ 0. 1. 2.] + [ 4. 8. 10.]] + """ + __mindspore_signature__ = ( + sig.make_sig('x', dtype=sig.sig_dtype.T1), + sig.make_sig('indices', dtype=sig.sig_dtype.T), + sig.make_sig('segment_ids', dtype=sig.sig_dtype.T) + ) + + @prim_attr_register + def __init__(self): + """Initialize SparseSegmentSum""" + self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids'], outputs=['y']) + self.add_prim_attr("cust_aicpu", self.name) + + +class SparseSegmentSumWithNumSegments(Primitive): + """ + Computes the sum along sparse segments of a tensor, but it is allowed to miss id in segment_ids. + + Inputs: + - **x** (Tensor) - A Tensor. + - **indices** (Tensor) - Indices is a 1-D tensor. Must be one of the following types: int32, int64. + Has same rank as segment_ids. The shape should be :math:`(N,)`. + - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor. Must be one of the following types: int32, int64. + Values should be sorted and can be repeated. The shape should be :math:`(N,)`. 
+ - **num_segments** (Tensor) - Num_segments should equal the number of distinct segment_ids. + + Outputs: + A Tensor. Has the same type as `x` . + Has same shape as `x`, except for dimension 0 which is the value of `num_segments`. + + Raises: + TypeError: If `x` or `indices` or `segment_ids` or `num_segments` is not a tensor. + TypeError: If the dtype of `indices` and `segment_ids` and `num_segments` is not int32 or int64. + ValueError: If dimension size of `x` less than 1. + ValueError: If any of `indices` and `segment_ids` is not a 1-D tensor. + ValueError: If rank of `num_segments` is bigger than 1. + ValueError: If the number of elements of `num_segments` is not 1. + ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`. + ValueError: If `segment_ids` is not sorted. + ValueError: If the last number of `segment_ids` is bigger than or equal to `num_segments`. + ValueError: If `indices` is out of range of x's first shape. + + Supported Platforms: + ``GPU`` + + Examples: + >>> x = Tensor([[0, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 0]], dtype=ms.float16) + >>> indices = Tensor([0, 2, 1], dtype=ms.int32) + >>> segment_ids = Tensor([0, 0, 2], dtype=ms.int32) + >>> num_segments = Tensor([4], dtype=ms.int32) + >>> sparse_segment_sum_with_num_segments = ops.SparseSegmentSumWithNumSegments() + >>> output = sparse_segment_sum_with_num_segments(x, indices, segment_ids, num_segments) + >>> print(output) + [[1. 1. 1. 0.] + [0. 0. 0. 0.] + [0. 1. 1. 0.] + [0. 0. 0. 0.]] + """ + __mindspore_signature__ = ( + sig.make_sig('x', dtype=sig.sig_dtype.T1), + sig.make_sig('indices', dtype=sig.sig_dtype.T), + sig.make_sig('segment_ids', dtype=sig.sig_dtype.T), + sig.make_sig('num_segments', dtype=sig.sig_dtype.T) + ) + + @prim_attr_register + def __init__(self): + """Initialize SparseSegmentSumWithNumSegments""" + self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'num_segments'], outputs=['y']) + self.add_prim_attr("cust_aicpu", self.name) + + class SparseSegmentSqrtN(Primitive): """ Computes the sum along sparse segments of a tensor divided by the sqrt of N. @@ -1151,7 +1261,7 @@ class SparseSegmentSqrtN(Primitive): ValueError: If `indices` is out of range of x's first dimension. Supported Platforms: - ``Ascend`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> x = Tensor(np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]]).astype(np.float32)) @@ -1164,6 +1274,11 @@ class SparseSegmentSqrtN(Primitive): [ 5. 6. 7. 8.] [ 9. 10. 11. 12.]] """ + __mindspore_signature__ = ( + sig.make_sig('x', dtype=sig.sig_dtype.T1), + sig.make_sig('indices', dtype=sig.sig_dtype.T), + sig.make_sig('segment_ids', dtype=sig.sig_dtype.T) + ) @prim_attr_register def __init__(self): @@ -1208,7 +1323,7 @@ class SparseSegmentSqrtNWithNumSegments(Primitive): ValueError: If `indices` is out of range of x's first dimension. Supported Platforms: - ``Ascend`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> x = Tensor([[0, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 0]], dtype=ms.float16) diff --git a/tests/st/ops/gpu/test_sparse_segment_sqrt_n_grad_op.py b/tests/st/ops/gpu/test_sparse_segment_sqrt_n_grad_op.py new file mode 100644 index 00000000000..416476898f4 --- /dev/null +++ b/tests/st/ops/gpu/test_sparse_segment_sqrt_n_grad_op.py @@ -0,0 +1,101 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations._grad_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function + + +class SparseSegmentSqrtNGradNet(nn.Cell): + def __init__(self): + super(SparseSegmentSqrtNGradNet, self).__init__() + self.net = P.SparseSegmentSqrtNGrad() + + @ms_function + def construct(self, grad, indices, segment_ids, output_dim0): + return self.net(grad, indices, segment_ids, output_dim0) + + +def sparse_segment_sqrt_n_grad(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) + output_dim0_np = np.array(8, dtype=np.int32) + grad_ms = Tensor(grad_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + output_dim0_ms = Tensor(output_dim0_np) + net_ms = SparseSegmentSqrtNGradNet() + out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms) + expected = np.array([[6, 8, 10, 12], + [6.363961, 7.071068, 7.7781744, 8.485281], + [15.55635, 16.970562, 18.384777, 19.798988], + [9.192389, 9.899495, 10.606602, 11.313708], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]], dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def sparse_segment_sqrt_n_grad_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64) + indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) + segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64) + output_dim0_np = np.array(8, dtype=np.int64) + grad_ms = Tensor(grad_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + output_dim0_ms = Tensor(output_dim0_np) + net_ms = SparseSegmentSqrtNGradNet() + out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms) + expected = np.array([[0.70710678, 1.41421356, 2.12132034, 2.82842712], + [5.70710678, 7.41421356, 9.12132034, 10.82842712], + [6.36396103, 7.07106781, 7.77817459, 8.48528137], + [19.36396103, 21.07106781, 22.77817459, 24.48528137], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]], dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparse_segment_sqrt_n_grad_graph_float32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSqrtNGrad + Expectation: the result match to tensorflow + """ + sparse_segment_sqrt_n_grad(loss=1.0e-4) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparse_segment_sqrt_n_grad_pynative_float64_int64_int64(): + """ + Feature: ALL To ALL + Description: test 
cases for SparseSegmentSqrtNGrad + Expectation: the result match to tensorflow + """ + sparse_segment_sqrt_n_grad_pynative(loss=1.0e-5) diff --git a/tests/st/ops/gpu/test_sparse_segment_sqrt_n_op.py b/tests/st/ops/gpu/test_sparse_segment_sqrt_n_op.py new file mode 100644 index 00000000000..44469d35b80 --- /dev/null +++ b/tests/st/ops/gpu/test_sparse_segment_sqrt_n_op.py @@ -0,0 +1,89 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.sparse_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function + + +class SparseSegmentSqrtNNet(nn.Cell): + def __init__(self): + super(SparseSegmentSqrtNNet, self).__init__() + self.net = P.SparseSegmentSqrtN() + + @ms_function + def construct(self, x, indices, segment_ids): + return self.net(x, indices, segment_ids) + + +def sparse_segment_sqrt_n(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + net_ms = SparseSegmentSqrtNNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms) + expected = np.array([[1, 2, 3, 4], + [1, 2, 3, 4], + [9.899495, 11.313708, 12.727922, 14.142136], + [15.556349, 16.970562, 18.384777, 19.79899]], dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def sparse_segment_sqrt_n_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64) + indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) + segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + net_ms = SparseSegmentSqrtNNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms) + expected = np.array([[4.24264069, 5.65685425, 7.07106781, 8.48528137], + [5, 6, 7, 8], + [15.55634919, 16.97056275, 18.38477631, 19.79898987], + [13, 14, 15, 16]], dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparse_segment_sqrt_n_graph_float32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSqrtN + Expectation: the result match to tensorflow + """ + sparse_segment_sqrt_n(loss=1.0e-4) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparse_segment_sqrt_n_pynative_float64_int64_int64(): + """ + 
Feature: ALL To ALL + Description: test cases for SparseSegmentSqrtN + Expectation: the result match to tensorflow + """ + sparse_segment_sqrt_n_pynative(loss=1.0e-5) diff --git a/tests/st/ops/gpu/test_sparse_segment_sqrt_n_with_num_segments_op.py b/tests/st/ops/gpu/test_sparse_segment_sqrt_n_with_num_segments_op.py new file mode 100644 index 00000000000..7c0228474b6 --- /dev/null +++ b/tests/st/ops/gpu/test_sparse_segment_sqrt_n_with_num_segments_op.py @@ -0,0 +1,101 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.sparse_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function + + +class SparseSegmentSqrtNWithNumSegmentsNet(nn.Cell): + def __init__(self): + super(SparseSegmentSqrtNWithNumSegmentsNet, self).__init__() + self.net = P.SparseSegmentSqrtNWithNumSegments() + + @ms_function + def construct(self, x, indices, segment_ids, num_segments): + return self.net(x, indices, segment_ids, num_segments) + + +def sparse_segment_sqrt_n_with_num_segments(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 3, 3, 5, 7, 7], dtype=np.int32) + num_segments_np = np.array(8, dtype=np.int32) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + num_segments_ms = Tensor(num_segments_np) + net_ms = SparseSegmentSqrtNWithNumSegmentsNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms) + expected = np.array([[1, 2, 3, 4], + [0, 0, 0, 0], + [0, 0, 0, 0], + [4.2426405, 5.656854, 7.071068, 8.485281], + [0, 0, 0, 0], + [9, 10, 11, 12], + [0, 0, 0, 0], + [15.556349, 16.970562, 18.384777, 19.79899]], dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def sparse_segment_sqrt_n_with_num_segments_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64) + indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) + segment_ids_np = np.array([0, 0, 3, 5, 5, 7], dtype=np.int64) + num_segments_np = np.array(8, dtype=np.int64) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + num_segments_ms = Tensor(num_segments_np) + net_ms = SparseSegmentSqrtNWithNumSegmentsNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms) + expected = np.array([[4.24264069, 5.65685425, 7.07106781, 8.48528137], + [0, 0, 0, 0], + [0, 0, 0, 0], + [5, 6, 7, 8], + [0, 0, 0, 0], + [15.55634919, 16.97056275, 18.38477631, 19.79898987], + [0, 0, 0, 0], + [13, 14, 15, 
16]], dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparse_segment_sqrt_n_with_num_segments_graph_float32_int32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSqrtNWithNumSegments + Expectation: the result match to tensorflow + """ + sparse_segment_sqrt_n_with_num_segments(loss=1.0e-4) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparse_segment_sqrt_n_with_num_segments_pynative_float64_int64_int64_int64(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSqrtNWithNumSegments + Expectation: the result match to tensorflow + """ + sparse_segment_sqrt_n_with_num_segments_pynative(loss=1.0e-5) diff --git a/tests/st/ops/gpu/test_sparse_segment_sum_grad_op.py b/tests/st/ops/gpu/test_sparse_segment_sum_grad_op.py new file mode 100644 index 00000000000..e40ca158145 --- /dev/null +++ b/tests/st/ops/gpu/test_sparse_segment_sum_grad_op.py @@ -0,0 +1,101 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations._grad_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function + + +class SparseSegmentSumGradNet(nn.Cell): + def __init__(self): + super(SparseSegmentSumGradNet, self).__init__() + self.net = P.SparseSegmentSumGrad() + + @ms_function + def construct(self, grad, indices, segment_ids, output_dim0): + return self.net(grad, indices, segment_ids, output_dim0) + + +def sparse_segment_sum_grad(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) + output_dim0_np = np.array(8, dtype=np.int32) + grad_ms = Tensor(grad_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + output_dim0_ms = Tensor(output_dim0_np) + net_ms = SparseSegmentSumGradNet() + out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms) + expected = np.array([[6, 8, 10, 12], + [9, 10, 11, 12], + [22, 24, 26, 28], + [13, 14, 15, 16], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]], dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def sparse_segment_sum_grad_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64) + indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) + segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64) + output_dim0_np = 
np.array(8, dtype=np.int64) + grad_ms = Tensor(grad_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + output_dim0_ms = Tensor(output_dim0_np) + net_ms = SparseSegmentSumGradNet() + out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms) + expected = np.array([[1, 2, 3, 4.], + [6, 8, 10, 12], + [9, 10, 11, 12], + [22, 24, 26, 28], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]], dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparse_segment_sum_grad_graph_float32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSumGrad + Expectation: the result match to tensorflow + """ + sparse_segment_sum_grad(loss=1.0e-4) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparse_segment_sum_grad_pynative_float64_int64_int64(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSumGrad + Expectation: the result match to tensorflow + """ + sparse_segment_sum_grad_pynative(loss=1.0e-5) diff --git a/tests/st/ops/gpu/test_sparse_segment_sum_op.py b/tests/st/ops/gpu/test_sparse_segment_sum_op.py new file mode 100644 index 00000000000..2125a825fa3 --- /dev/null +++ b/tests/st/ops/gpu/test_sparse_segment_sum_op.py @@ -0,0 +1,89 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.sparse_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function + + +class SparseSegmentSumNet(nn.Cell): + def __init__(self): + super(SparseSegmentSumNet, self).__init__() + self.net = P.SparseSegmentSum() + + @ms_function + def construct(self, x, indices, segment_ids): + return self.net(x, indices, segment_ids) + + +def sparse_segment_sum(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + net_ms = SparseSegmentSumNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms) + expected = np.array([[1, 2, 3, 4], + [1, 2, 3, 4], + [14, 16, 18, 20], + [22, 24, 26, 28]], dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def sparse_segment_sum_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64) + indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) + segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + net_ms = SparseSegmentSumNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms) + expected = np.array([[6, 8, 10, 12], + [5, 6, 7, 8], + [22, 24, 26, 28], + [13, 14, 15, 16]], dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparse_segment_sum_graph_float32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSum + Expectation: the result match to tensorflow + """ + sparse_segment_sum(loss=1.0e-4) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparse_segment_sum_pynative_float64_int64_int64(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSum + Expectation: the result match to tensorflow + """ + sparse_segment_sum_pynative(loss=1.0e-5) diff --git a/tests/st/ops/gpu/test_sparse_segment_sum_with_num_segments_op.py b/tests/st/ops/gpu/test_sparse_segment_sum_with_num_segments_op.py new file mode 100644 index 00000000000..5c480538e7c --- /dev/null +++ b/tests/st/ops/gpu/test_sparse_segment_sum_with_num_segments_op.py @@ -0,0 +1,101 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +import mindspore.ops.operations.sparse_ops as P +from mindspore import Tensor +from mindspore.common.api import ms_function + + +class SparseSegmentSumWithNumSegmentsNet(nn.Cell): + def __init__(self): + super(SparseSegmentSumWithNumSegmentsNet, self).__init__() + self.net = P.SparseSegmentSumWithNumSegments() + + @ms_function + def construct(self, x, indices, segment_ids, num_segments): + return self.net(x, indices, segment_ids, num_segments) + + +def sparse_segment_sum_with_num_segments(loss): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32) + indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32) + segment_ids_np = np.array([0, 3, 3, 5, 7, 7], dtype=np.int32) + num_segments_np = np.array(8, dtype=np.int32) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + num_segments_ms = Tensor(num_segments_np) + net_ms = SparseSegmentSumWithNumSegmentsNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms) + expected = np.array([[1, 2, 3, 4], + [0, 0, 0, 0], + [0, 0, 0, 0], + [6, 8, 10, 12], + [0, 0, 0, 0], + [9, 10, 11, 12], + [0, 0, 0, 0], + [22, 24, 26, 28]], dtype=np.float32) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +def sparse_segment_sum_with_num_segments_pynative(loss): + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64) + indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64) + segment_ids_np = np.array([0, 0, 3, 5, 5, 7], dtype=np.int64) + num_segments_np = np.array(8, dtype=np.int64) + x_ms = Tensor(x_np) + indices_ms = Tensor(indices_np) + segment_ids_ms = Tensor(segment_ids_np) + num_segments_ms = Tensor(num_segments_np) + net_ms = SparseSegmentSumWithNumSegmentsNet() + out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms) + expected = np.array([[6, 8, 10, 12], + [0, 0, 0, 0], + [0, 0, 0, 0], + [5, 6, 7, 8], + [0, 0, 0, 0], + [22, 24, 26, 28], + [0, 0, 0, 0], + [13, 14, 15, 16]], dtype=np.float64) + assert np.allclose(out_ms.asnumpy(), expected, loss, loss) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparse_segment_sum_with_num_segments_graph_float32_int32_int32_int32(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSumWithNumSegments + Expectation: the result match to tensorflow + """ + sparse_segment_sum_with_num_segments(loss=1.0e-4) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparse_segment_sum_with_num_segments_pynative_float64_int64_int64_int64(): + """ + Feature: ALL To ALL + Description: test cases for SparseSegmentSumWithNumSegments + Expectation: the result match to tensorflow + """ + sparse_segment_sum_with_num_segments_pynative(loss=1.0e-5)
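Reviewer note, not part of the patch: as a quick end-to-end sanity check of the two new forward operators, the minimal sketch below reuses the inputs from tests/st/ops/gpu/test_sparse_segment_sum_op.py and test_sparse_segment_sum_with_num_segments_op.py. It assumes a MindSpore build that already contains this patch and a visible GPU device; all operator and module names are taken from the patch itself.

import numpy as np
import mindspore.context as context
import mindspore.ops.operations.sparse_ops as sparse_ops
from mindspore import Tensor

context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

# Rows of x selected by `indices` are summed per id in `segment_ids`;
# output dimension 0 equals the number of segments.
x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32))
indices = Tensor(np.array([0, 0, 1, 2, 2, 3], dtype=np.int32))
segment_ids = Tensor(np.array([0, 1, 2, 2, 3, 3], dtype=np.int32))
out = sparse_ops.SparseSegmentSum()(x, indices, segment_ids)
print(out)
# Expected (same values as the graph-mode test case):
# [[ 1.  2.  3.  4.]
#  [ 1.  2.  3.  4.]
#  [14. 16. 18. 20.]
#  [22. 24. 26. 28.]]

# The WithNumSegments variant takes an explicit segment count, so segment ids that
# never occur simply produce zero rows in the output.
num_segments = Tensor(np.array(8, dtype=np.int32))
sparse_ids = Tensor(np.array([0, 3, 3, 5, 7, 7], dtype=np.int32))
out2 = sparse_ops.SparseSegmentSumWithNumSegments()(x, indices, sparse_ids, num_segments)
print(out2)  # 8 rows; rows 1, 2, 4 and 6 are all zeros

For gradients, the bprops registered in grad_sparse_ops.py above route both operators to G.SparseSegmentSumGrad, falling back to P.TensorShape() for output_dim0 only when the static shape of x is unknown.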