!37109 [feat] [assistant] [ops] [I5EWJI,I5EWJJ,I5EWJK,I5EWJC,I5EWJD,I5EWJG] New GPU operator implementation SparseSegmentOps

Merge pull request !37109 from 路雄博/SSSqG
This commit is contained in:
i-robot 2022-09-13 01:12:59 +00:00 committed by Gitee
commit 3324fea63f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
27 changed files with 3093 additions and 254 deletions

View File

@ -0,0 +1,225 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh"
template <typename S>
__global__ void SparseSegmentPosKernel(const S *indices_ptr, size_t *indices_pos_ptr, size_t idx_seg_size,
size_t outer_size) {
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id <= idx_seg_size; id += blockDim.x * gridDim.x) {
const S max_size = static_cast<S>(outer_size);
const S min_size = S(0);
S beg_idx = (id == 0) ? min_size : indices_ptr[id - 1] + 1;
S end_idx = (id >= idx_seg_size) ? max_size : indices_ptr[id];
beg_idx = max(min_size, min(max_size, beg_idx));
end_idx = max(min_size, min(max_size, end_idx));
for (S i = beg_idx; i <= end_idx; i++) {
indices_pos_ptr[i] = id;
}
}
}
template <typename R, typename S>
__global__ void SparseSegmentSumGradKernel(const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr,
const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size,
size_t output_dim0, R *y_ptr) {
size_t num_blocks = (inner_size - 1) / blockDim.x + 1;
for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) {
size_t inner_idx = threadIdx.x + bid * blockDim.x;
bool inner_valid = inner_idx < inner_size;
for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) {
size_t beg_pos = indices_pos_ptr[inid];
size_t end_pos = indices_pos_ptr[inid + 1];
for (size_t pos = beg_pos; pos < end_pos; pos += 1) {
double reduce_result = 0;
S index = inner_valid ? indices_ptr[pos] : outer_size;
if (index >= 0 && index < outer_size) {
reduce_result = static_cast<double>(grad_ptr[index * inner_size + inner_idx]);
}
if (threadIdx.y == 0 && inner_valid) {
R *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx;
MsAtomicAdd(out_pos, static_cast<R>(reduce_result));
}
}
}
}
}
template <typename S>
__global__ void SparseSegmentSumGradKernel(const half *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr,
const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size,
size_t output_dim0, half *y_ptr) {
size_t num_blocks = (inner_size - 1) / blockDim.x + 1;
for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) {
size_t inner_idx = threadIdx.x + bid * blockDim.x;
bool inner_valid = inner_idx < inner_size;
for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) {
size_t beg_pos = indices_pos_ptr[inid];
size_t end_pos = indices_pos_ptr[inid + 1];
for (size_t pos = beg_pos; pos < end_pos; pos += 1) {
double reduce_result = 0;
S index = inner_valid ? indices_ptr[pos] : outer_size;
if (index >= 0 && index < outer_size) {
reduce_result = static_cast<double>(__half2float(grad_ptr[index * inner_size + inner_idx]));
}
if (threadIdx.y == 0 && inner_valid) {
half *out_pos = y_ptr + segment_ids_ptr[pos]* inner_size + inner_idx;
MsAtomicAdd(out_pos, __float2half(static_cast<float>(reduce_result)));
}
}
}
}
}
template <typename R, typename S>
__global__ void SparseSegmentSqrtNGradKernel(const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr,
const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size,
size_t output_dim0, R *y_ptr) {
size_t num_blocks = (inner_size - 1) / blockDim.x + 1;
for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) {
size_t inner_idx = threadIdx.x + bid * blockDim.x;
bool inner_valid = inner_idx < inner_size;
for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) {
size_t beg_pos = indices_pos_ptr[inid];
size_t end_pos = indices_pos_ptr[inid + 1];
double sqrt_segment_len = sqrt(static_cast<double>(end_pos - beg_pos));
for (size_t pos = beg_pos; pos < end_pos; pos += 1) {
double reduce_result = 0;
S index = inner_valid ? indices_ptr[pos] : outer_size;
if (index >= 0 && index < outer_size) {
reduce_result = static_cast<double>(grad_ptr[index * inner_size + inner_idx]);
}
if (threadIdx.y == 0 && inner_valid) {
R *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx;
MsAtomicAdd(out_pos, static_cast<R>(reduce_result / sqrt_segment_len));
}
}
}
}
}
template <typename S>
__global__ void SparseSegmentSqrtNGradKernel(const half *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr,
const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size,
size_t output_dim0, half *y_ptr) {
size_t num_blocks = (inner_size - 1) / blockDim.x + 1;
for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) {
size_t inner_idx = threadIdx.x + bid * blockDim.x;
bool inner_valid = inner_idx < inner_size;
for (size_t inid = blockIdx.y; inid < outer_size; inid += gridDim.y) {
size_t beg_pos = indices_pos_ptr[inid];
size_t end_pos = indices_pos_ptr[inid + 1];
double sqrt_segment_len = sqrt(static_cast<double>(end_pos - beg_pos));
for (size_t pos = beg_pos; pos < end_pos; pos += 1) {
double reduce_result = 0;
S index = inner_valid ? indices_ptr[pos] : outer_size;
if (index >= 0 && index < outer_size) {
reduce_result = static_cast<double>(__half2float(grad_ptr[index * inner_size + inner_idx]));
}
if (threadIdx.y == 0 && inner_valid) {
half *out_pos = y_ptr + segment_ids_ptr[pos] * inner_size + inner_idx;
MsAtomicAdd(out_pos, __float2half(static_cast<float>(reduce_result / sqrt_segment_len)));
}
}
}
}
}
inline int Log2Floor_M(uint32_t n) {
if (n == 0) return -1;
int log = 0;
for (int i = 4; i >= 0; --i) {
int shift = (1 << i);
uint32_t x = n >> shift;
if (x) {
n = x;
log += shift;
}
}
return log;
}
inline int Log2Floor64_M(uint64_t n) {
// Scan n first high 32 then low 32 bits.
const uint32_t high_32_bit = static_cast<uint32_t>(n >> 32);
if (high_32_bit == 0) {
return Log2Floor_M(static_cast<uint32_t>(n));
} else {
return 32 + Log2Floor_M(high_32_bit);
}
}
inline int Log2Ceil64_M(uint64_t n) {
int floor = Log2Floor64_M(n);
if (n == (n & ~(n - 1)))
return floor;
else
return floor + 1;
}
template <typename R, typename S>
bool CalSparseSegmentGradCombination(const std::string kernel_type, const R *grad_ptr, const S *indices_ptr,
const S *segment_ids_ptr, size_t *indices_pos_ptr, size_t outer_size,
size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr,
uint32_t device_id, cudaStream_t cuda_stream) {
// Get start position of each segment and set to indices_pos_ptr.
// The last element of indices_pos_ptr must equal to idx_seg_size.
SparseSegmentPosKernel<<<CUDA_BLOCKS(device_id, idx_seg_size + 1), CUDA_THREADS(device_id), 0, cuda_stream>>>(
indices_ptr, indices_pos_ptr, idx_seg_size, outer_size);
const unsigned int max_grid_x = (1u << 31) - 1;
const unsigned int max_grid_y = (1u << 16) - 1;
unsigned int block_x = 32;
unsigned int block_y = 1;
unsigned int grid_x = std::min(static_cast<unsigned int>(UP_DIV(inner_size, block_x)), max_grid_x);
unsigned int grid_y = std::min(static_cast<unsigned int>(output_dim0), max_grid_y);
dim3 block(block_x, block_y);
dim3 grid(grid_x, grid_y);
unsigned int shared_memory_size = block_x * block_y * sizeof(R);
if (kernel_type == "SparseSegmentSumGrad") {
SparseSegmentSumGradKernel<<<grid, block, shared_memory_size, cuda_stream>>>(grad_ptr, indices_ptr,
segment_ids_ptr, indices_pos_ptr,
outer_size, inner_size, output_dim0,
y_ptr);
} else if (kernel_type == "SparseSegmentSqrtNGrad") {
SparseSegmentSqrtNGradKernel<<<grid, block, shared_memory_size, cuda_stream>>>(grad_ptr, indices_ptr,
segment_ids_ptr, indices_pos_ptr,
outer_size, inner_size,
output_dim0, y_ptr);
}
return true;
}
#define ADD_SPARSE_SEGMENT_GRAD(R, S) \
template CUDA_LIB_EXPORT bool CalSparseSegmentGradCombination<R, S>(const std::string kernel_type, \
const R *grad_ptr, const S *indices_ptr, \
const S *segment_ids_ptr, \
size_t *indices_pos_ptr, size_t outer_size, \
size_t inner_size, size_t idx_seg_size, \
size_t output_dim0, R *y_ptr, \
uint32_t device_id, cudaStream_t cuda_stream);
ADD_SPARSE_SEGMENT_GRAD(half, int32_t)
ADD_SPARSE_SEGMENT_GRAD(half, int64_t)
ADD_SPARSE_SEGMENT_GRAD(float, int32_t)
ADD_SPARSE_SEGMENT_GRAD(float, int64_t)
ADD_SPARSE_SEGMENT_GRAD(double, int32_t)
ADD_SPARSE_SEGMENT_GRAD(double, int64_t)

View File

@ -0,0 +1,30 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_
#include <string>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename R, typename S>
CUDA_LIB_EXPORT bool CalSparseSegmentGradCombination(const std::string kernel_type, const R *grad_ptr,
const S *indices_ptr, const S *segment_ids_ptr,
size_t *indices_pos_ptr, size_t outer_size, size_t inner_size,
size_t idx_seg_size, size_t output_dim0, R *y_ptr,
uint32_t device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_

View File

@ -0,0 +1,307 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh"
template <typename S>
__global__ void SparseSegmentPosKernel(const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t idx_seg_size,
size_t output_dim0) {
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id <= idx_seg_size; id += blockDim.x * gridDim.x) {
const S max_size = static_cast<S>(output_dim0);
const S min_size = S(0);
S beg_idx = (id == 0) ? min_size : segment_ids_ptr[id - 1] + 1;
S end_idx = (id >= idx_seg_size) ? max_size : segment_ids_ptr[id];
beg_idx = max(min_size, min(max_size, beg_idx));
end_idx = max(min_size, min(max_size, end_idx));
for (S i = beg_idx; i <= end_idx; i++) {
segment_pos_ptr[i] = id;
}
}
}
template <typename R, typename S>
__global__ void SparseSegmentSumKernel(const R *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr,
size_t outer_size, size_t inner_size, size_t output_dim0, R *y_ptr) {
size_t num_blocks = (inner_size - 1) / blockDim.x + 1;
for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) {
size_t inner_idx = threadIdx.x + bid * blockDim.x;
bool inner_valid = inner_idx < inner_size;
for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) {
size_t beg_pos = segment_pos_ptr[sid];
size_t end_pos = segment_pos_ptr[sid + 1];
R segment_sum = 0;
for (size_t pos = beg_pos; pos < end_pos; pos += 1) {
R reduce_result = 0;
S index = inner_valid ? indices_ptr[pos] : outer_size;
if (index >= 0 && index < outer_size) {
reduce_result = x_ptr[index * inner_size + inner_idx];
}
if (threadIdx.y == 0 && inner_valid) {
segment_sum += reduce_result;
}
}
if (threadIdx.y == 0 && inner_valid) {
y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? R(0) : segment_sum;
}
}
}
}
template <typename S>
__global__ void SparseSegmentSumKernel(const float *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr,
size_t outer_size, size_t inner_size, size_t output_dim0, float *y_ptr) {
size_t num_blocks = (inner_size - 1) / blockDim.x + 1;
for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) {
size_t inner_idx = threadIdx.x + bid * blockDim.x;
bool inner_valid = inner_idx < inner_size;
for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) {
size_t beg_pos = segment_pos_ptr[sid];
size_t end_pos = segment_pos_ptr[sid + 1];
double segment_sum = 0;
for (size_t pos = beg_pos; pos < end_pos; pos += 1) {
double reduce_result = 0;
S index = inner_valid ? indices_ptr[pos] : outer_size;
if (index >= 0 && index < outer_size) {
reduce_result = static_cast<double>(x_ptr[index * inner_size + inner_idx]);
}
if (threadIdx.y == 0 && inner_valid) {
segment_sum += reduce_result;
}
}
if (threadIdx.y == 0 && inner_valid) {
y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? static_cast<float>(0) :
static_cast<float>(segment_sum);
}
}
}
}
template <typename S>
__global__ void SparseSegmentSumKernel(const half *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr,
size_t outer_size, size_t inner_size, size_t output_dim0, half *y_ptr) {
size_t num_blocks = (inner_size - 1) / blockDim.x + 1;
for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) {
size_t inner_idx = threadIdx.x + bid * blockDim.x;
bool inner_valid = inner_idx < inner_size;
for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) {
size_t beg_pos = segment_pos_ptr[sid];
size_t end_pos = segment_pos_ptr[sid + 1];
float segment_sum = 0;
for (size_t pos = beg_pos; pos < end_pos; pos += 1) {
float reduce_result = 0;
S index = inner_valid ? indices_ptr[pos] : outer_size;
if (index >= 0 && index < outer_size) {
reduce_result = __half2float(x_ptr[index * inner_size + inner_idx]);
}
if (threadIdx.y == 0 && inner_valid) {
segment_sum += reduce_result;
}
}
if (threadIdx.y == 0 && inner_valid) {
y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? half(0) : __float2half(segment_sum);
}
}
}
}
template <typename R, typename S>
__global__ void SparseSegmentSqrtNKernel(const R *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr,
size_t outer_size, size_t inner_size, size_t output_dim0, R *y_ptr) {
size_t num_blocks = (inner_size - 1) / blockDim.x + 1;
for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) {
size_t inner_idx = threadIdx.x + bid * blockDim.x;
bool inner_valid = inner_idx < inner_size;
for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) {
size_t beg_pos = segment_pos_ptr[sid];
size_t end_pos = segment_pos_ptr[sid + 1];
R segment_sum = 0;
R sqrt_segment_len = R(sqrt(static_cast<double>(end_pos - beg_pos)));
for (size_t pos = beg_pos; pos < end_pos; pos += 1) {
R reduce_result = 0;
S index = inner_valid ? indices_ptr[pos] : outer_size;
if (index >= 0 && index < outer_size) {
reduce_result = x_ptr[index * inner_size + inner_idx];
}
if (threadIdx.y == 0 && inner_valid) {
segment_sum += reduce_result;
}
}
if (threadIdx.y == 0 && inner_valid) {
y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? R(0) : segment_sum / sqrt_segment_len;
}
}
}
}
template <typename S>
__global__ void SparseSegmentSqrtNKernel(const float *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr,
size_t outer_size, size_t inner_size, size_t output_dim0, float *y_ptr) {
size_t num_blocks = (inner_size - 1) / blockDim.x + 1;
for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) {
size_t inner_idx = threadIdx.x + bid * blockDim.x;
bool inner_valid = inner_idx < inner_size;
for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) {
size_t beg_pos = segment_pos_ptr[sid];
size_t end_pos = segment_pos_ptr[sid + 1];
double segment_sum = 0;
double sqrt_segment_len = sqrt(static_cast<double>(end_pos - beg_pos));
for (size_t pos = beg_pos; pos < end_pos; pos += 1) {
double reduce_result = 0;
S index = inner_valid ? indices_ptr[pos] : outer_size;
if (index >= 0 && index < outer_size) {
reduce_result = static_cast<double>(x_ptr[index * inner_size + inner_idx]);
}
if (threadIdx.y == 0 && inner_valid) {
segment_sum += reduce_result;
}
}
if (threadIdx.y == 0 && inner_valid) {
y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? static_cast<float>(0) :
static_cast<float>(segment_sum / sqrt_segment_len);
}
}
}
}
template <typename S>
__global__ void SparseSegmentSqrtNKernel(const half *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr,
size_t outer_size, size_t inner_size, size_t output_dim0, half *y_ptr) {
size_t num_blocks = (inner_size - 1) / blockDim.x + 1;
for (size_t bid = blockIdx.x; bid < num_blocks; bid += gridDim.x) {
size_t inner_idx = threadIdx.x + bid * blockDim.x;
bool inner_valid = inner_idx < inner_size;
for (size_t sid = blockIdx.y; sid < output_dim0; sid += gridDim.y) {
size_t beg_pos = segment_pos_ptr[sid];
size_t end_pos = segment_pos_ptr[sid + 1];
float segment_sum = 0;
float sqrt_segment_len = sqrt(static_cast<float>(end_pos - beg_pos));
for (size_t pos = beg_pos; pos < end_pos; pos += 1) {
float reduce_result = 0;
S index = inner_valid ? indices_ptr[pos] : outer_size;
if (index >= 0 && index < outer_size) {
reduce_result = __half2float(x_ptr[index * inner_size + inner_idx]);
}
if (threadIdx.y == 0 && inner_valid) {
segment_sum += reduce_result;
}
}
if (threadIdx.y == 0 && inner_valid) {
y_ptr[sid * inner_size + inner_idx] = beg_pos == end_pos ? half(0) :
__float2half(segment_sum / sqrt_segment_len);
}
}
}
}
inline int Log2Floor(uint32_t n) {
if (n == 0) return -1;
int log = 0;
for (int i = 4; i >= 0; --i) {
int shift = (1 << i);
uint32_t x = n >> shift;
if (x) {
n = x;
log += shift;
}
}
return log;
}
inline int Log2Floor64(uint64_t n) {
// Scan n first high 32 then low 32 bits.
const uint32_t high_32_bit = static_cast<uint32_t>(n >> 32);
if (high_32_bit == 0) {
return Log2Floor(static_cast<uint32_t>(n));
} else {
return 32 + Log2Floor(high_32_bit);
}
}
inline int Log2Ceil64(uint64_t n) {
int floor = Log2Floor64(n);
if (n == (n & ~(n - 1)))
return floor;
else
return floor + 1;
}
template <typename R, typename S>
bool CalSparseSegmentCombination(const std::string kernel_type, const R *x_ptr, const S *indices_ptr,
const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t outer_size,
size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr,
uint32_t device_id, cudaStream_t cuda_stream) {
// Get start position of each segment and set to segment_pos_ptr.
// The last element of segment_pos_ptr must equal to idx_seg_size.
SparseSegmentPosKernel<<<CUDA_BLOCKS(device_id, idx_seg_size + 1), CUDA_THREADS(device_id), 0, cuda_stream>>>(
segment_ids_ptr, segment_pos_ptr, idx_seg_size, output_dim0);
const unsigned int max_grid_x = (1u << 31) - 1;
const unsigned int max_grid_y = (1u << 16) - 1;
unsigned int block_x = 32;
unsigned int block_y = 1;
unsigned int grid_x = std::min(static_cast<unsigned int>(UP_DIV(inner_size, block_x)), max_grid_x);
unsigned int grid_y = std::min(static_cast<unsigned int>(output_dim0), max_grid_y);
dim3 block(block_x, block_y);
dim3 grid(grid_x, grid_y);
unsigned int shared_memory_size = block_x * block_y * sizeof(R);
if (kernel_type == "SparseSegmentSum" || kernel_type == "SparseSegmentSumWithNumSegments") {
SparseSegmentSumKernel<<<grid, block, shared_memory_size, cuda_stream>>>(x_ptr, indices_ptr, segment_pos_ptr,
outer_size, inner_size, output_dim0,
y_ptr);
} else if (kernel_type == "SparseSegmentSqrtN" || kernel_type == "SparseSegmentSqrtNWithNumSegments") {
SparseSegmentSqrtNKernel<<<grid, block, shared_memory_size, cuda_stream>>>(x_ptr, indices_ptr, segment_pos_ptr,
outer_size, inner_size, output_dim0,
y_ptr);
}
return true;
}
#define ADD_SPARSE_SEGMENT(R, S) \
template CUDA_LIB_EXPORT bool CalSparseSegmentCombination<R, S>(const std::string kernel_type, const R *x_ptr, \
const S *indices_ptr, const S *segment_ids_ptr, \
size_t *segment_pos_ptr, size_t outer_size, \
size_t inner_size, size_t idx_seg_size, \
size_t output_dim0, R *y_ptr, uint32_t device_id, \
cudaStream_t cuda_stream);
ADD_SPARSE_SEGMENT(uint8_t, int32_t)
ADD_SPARSE_SEGMENT(uint8_t, int64_t)
ADD_SPARSE_SEGMENT(uint16_t, int32_t)
ADD_SPARSE_SEGMENT(uint16_t, int64_t)
ADD_SPARSE_SEGMENT(int8_t, int32_t)
ADD_SPARSE_SEGMENT(int8_t, int64_t)
ADD_SPARSE_SEGMENT(int16_t, int32_t)
ADD_SPARSE_SEGMENT(int16_t, int64_t)
ADD_SPARSE_SEGMENT(int32_t, int32_t)
ADD_SPARSE_SEGMENT(int32_t, int64_t)
ADD_SPARSE_SEGMENT(int64_t, int32_t)
ADD_SPARSE_SEGMENT(int64_t, int64_t)
ADD_SPARSE_SEGMENT(half, int32_t)
ADD_SPARSE_SEGMENT(half, int64_t)
ADD_SPARSE_SEGMENT(float, int32_t)
ADD_SPARSE_SEGMENT(float, int64_t)
ADD_SPARSE_SEGMENT(double, int32_t)
ADD_SPARSE_SEGMENT(double, int64_t)

View File

@ -0,0 +1,29 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_
#include <string>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename R, typename S>
CUDA_LIB_EXPORT bool CalSparseSegmentCombination(const std::string kernel_type, const R *x_ptr, const S *indices_ptr,
const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t outer_size,
size_t inner_size, size_t indices_size, size_t output_dim0, R *y_ptr,
uint32_t device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_

View File

@ -0,0 +1,506 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr auto Sparse_Segment_Sum = "SparseSegmentSum";
constexpr auto Sparse_Segment_Sum_With_Num_Segments = "SparseSegmentSumWithNumSegments";
constexpr auto Sparse_Segment_Sqrt_N = "SparseSegmentSqrtN";
constexpr auto Sparse_Segment_Sqrt_N_With_Num_Segments = "SparseSegmentSqrtNWithNumSegments";
constexpr size_t kNumber1 = 1;
constexpr size_t kNumber3 = 3;
constexpr size_t kNumber4 = 4;
} // namespace
bool SparseSegmentOpsGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
if (kernel_type_ == "SparseSegmentSum" || kernel_type_ == "SparseSegmentSqrtN") {
flag_ = true;
} else {
flag_ = false;
}
size_t inputs_num = flag_ ? kNumber3 : kNumber4;
size_t outputs_num = kNumber1;
CHECK_KERNEL_INPUTS_NUM(inputs.size(), inputs_num, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_);
kernel_name_ = base_operator->GetPrim()->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << kernel_attr << ".";
return false;
}
kernel_func_ = kernel_attr_map_.at(kernel_type_)[index].second;
unit_x_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
unit_idx_seg_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).first);
return true;
}
int SparseSegmentOpsGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
for (const auto &input : inputs) {
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
for (const auto &output : outputs) {
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
auto output_shape = output->GetShapeVector();
if (!IsValidShape(output_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
std::vector<int64_t> output_shape = outputs.at(kIndex0)->GetShapeVector();
output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
if (output_elements_ == 0) {
is_null_input_ = true;
}
std::vector<int64_t> x_shape = inputs.at(kIndex0)->GetShapeVector();
x_shape_0_ = x_shape[0];
x_elements_ = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies{});
outer_size_ = x_shape.front();
inner_size_ = x_elements_ / x_shape.front();
std::vector<int64_t> indices_shape = inputs.at(kIndex1)->GetShapeVector();
idx_seg_elements_ = std::accumulate(indices_shape.begin(), indices_shape.end(), 1, std::multiplies{});
output_dim0_ = LongToSize(output_shape.front());
size_t input_x_size = x_elements_ * unit_x_size_;
size_t input_idx_seg_size = idx_seg_elements_ * unit_idx_seg_size_;
size_t output_size = output_elements_ * unit_x_size_;
input_size_list_.push_back(input_x_size);
input_size_list_.push_back(input_idx_seg_size);
input_size_list_.push_back(input_idx_seg_size);
if (flag_) {
input_size_list_.push_back(unit_idx_seg_size_);
}
output_size_list_.push_back(output_size);
workspace_size_list_.push_back((output_dim0_ + 1) * sizeof(size_t));
return KRET_OK;
}
template <typename R, typename S>
bool SparseSegmentOpsGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
R *x_ptr = GetDeviceAddress<R>(inputs, kIndex0);
S *indices_ptr = GetDeviceAddress<S>(inputs, kIndex1);
S *segment_ids_ptr = GetDeviceAddress<S>(inputs, kIndex2);
R *y_ptr = GetDeviceAddress<R>(outputs, kIndex0);
size_t *segment_pos_ptr = GetDeviceAddress<size_t>(workspace, kIndex0);
auto any = [](auto... args) -> bool { return ((args == nullptr) || ...); };
if (any(x_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, y_ptr)) {
return false;
}
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream_);
std::vector<S> indices_host;
std::vector<S> segment_ids_host;
std::vector<S> num_segments_host;
indices_host.resize(idx_seg_elements_);
segment_ids_host.resize(idx_seg_elements_);
num_segments_host.resize(kNumber1);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(indices_host.data(), indices_ptr, idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream),
"cudaMemcpy failed.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(segment_ids_host.data(), segment_ids_ptr,
idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream),
"cudaMemcpy failed.");
if (!flag_) {
auto num_segments_ptr = GetDeviceAddress<S>(inputs, kIndex3);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(num_segments_host.data(), num_segments_ptr, sizeof(S), cudaMemcpyDeviceToHost, stream),
"cudaMemcpy failed.");
}
if (segment_ids_host[0] != 0 && flag_) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
<< "', indices in 'segment_ids' should be contiguous and start from 0.";
}
for (size_t i = 1; i < idx_seg_elements_; i++) {
if (segment_ids_host[i] < segment_ids_host[i - 1]) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted.";
}
if (segment_ids_host[i] - segment_ids_host[i - 1] > 1 && flag_) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
<< "', indices in 'segment_ids' should be contiguous and start from 0.";
}
}
if (segment_ids_host[idx_seg_elements_ - 1] >= num_segments_host[kIndex0] && !flag_) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
<< "', num_segments must bigger than the last number of segment_ids.";
}
for (size_t i = 0; i < idx_seg_elements_; i++) {
if (indices_host[i] >= static_cast<S>(x_shape_0_)) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of x's first shape.";
}
}
CalSparseSegmentCombination(kernel_type_, x_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, outer_size_,
inner_size_, idx_seg_elements_, output_dim0_, y_ptr, device_id_, stream);
return true;
}
std::map<std::string, std::vector<std::pair<KernelAttr, SparseSegmentOpsGpuKernelMod::SSLaunchFunc>>>
SparseSegmentOpsGpuKernelMod::kernel_attr_map_ = {
{Sparse_Segment_Sum,
{{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt8),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<uint8_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt8),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<uint8_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<uint16_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt8),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int8_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt8),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int8_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int16_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int16_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int64_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<double, int64_t>}}},
{Sparse_Segment_Sum_With_Num_Segments,
{{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt8),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<uint8_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt8),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<uint8_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeUInt16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<uint16_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeUInt16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt8),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int8_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt8),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int8_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int16_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int16_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int64_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<double, int64_t>}}},
{Sparse_Segment_Sqrt_N,
{{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<double, int64_t>}}},
{Sparse_Segment_Sqrt_N_With_Num_Segments,
{{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
&SparseSegmentOpsGpuKernelMod::LaunchKernel<double, int64_t>}}}}; // kernel_attr_map_
std::vector<KernelAttr> SparseSegmentOpsGpuKernelMod::GetOpSupport() {
auto iter = kernel_attr_map_.find(kernel_type_);
if (iter == kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'SparseSegmentOpsOp', only support these types: " << kernel::Map2Str(kernel_attr_map_)
<< " currently, but got " << kernel_name_;
}
std::vector<KernelAttr> support_list;
(void)std::transform(
iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, SparseSegmentOpsGpuKernelMod::SSLaunchFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSum,
[]() { return std::make_shared<SparseSegmentOpsGpuKernelMod>(Sparse_Segment_Sum); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSumWithNumSegments, []() {
return std::make_shared<SparseSegmentOpsGpuKernelMod>(Sparse_Segment_Sum_With_Num_Segments);
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtN, []() {
return std::make_shared<SparseSegmentOpsGpuKernelMod>(Sparse_Segment_Sqrt_N);
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtNWithNumSegments, []() {
return std::make_shared<SparseSegmentOpsGpuKernelMod>(Sparse_Segment_Sqrt_N_With_Num_Segments);
});
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,101 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_
#include <vector>
#include <utility>
#include <string>
#include <memory>
#include <map>
#include <algorithm>
#include <functional>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class SparseSegmentOpsGpuKernelMod : public NativeGpuKernelMod {
public:
explicit SparseSegmentOpsGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
~SparseSegmentOpsGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
void ResetResource() noexcept {
outer_size_ = 0;
inner_size_ = 0;
x_elements_ = 0;
x_shape_0_ = 0;
idx_seg_elements_ = 0;
output_dim0_ = 0;
output_elements_ = 0;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename R, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using SSLaunchFunc =
std::function<bool(SparseSegmentOpsGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
private:
size_t outer_size_{0};
size_t inner_size_{0};
size_t x_elements_{0};
size_t x_shape_0_{0};
size_t idx_seg_elements_{0};
size_t output_dim0_{0};
size_t output_elements_{0};
size_t unit_x_size_{1};
size_t unit_idx_seg_size_{1};
std::string kernel_type_{"Unknown"};
bool is_null_input_{false};
size_t flag_{0};
void *cuda_stream_{nullptr};
SSLaunchFunc kernel_func_{};
static std::map<std::string, std::vector<std::pair<KernelAttr, SSLaunchFunc>>> kernel_attr_map_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_

View File

@ -0,0 +1,245 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr auto Sparse_Segment_Sum_Grad = "SparseSegmentSumGrad";
constexpr auto Sparse_Segment_Sqrt_N_Grad = "SparseSegmentSqrtNGrad";
constexpr size_t kNumber1 = 1;
constexpr size_t kNumber4 = 4;
} // namespace
bool SparseSegmentGradOpsGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
size_t inputs_num = kNumber4;
size_t outputs_num = kNumber1;
CHECK_KERNEL_INPUTS_NUM(inputs.size(), inputs_num, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_);
kernel_name_ = base_operator->GetPrim()->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << kernel_attr << ".";
return false;
}
kernel_func_ = kernel_attr_map_.at(kernel_type_)[index].second;
unit_grad_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
unit_idx_seg_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).first);
return true;
}
int SparseSegmentGradOpsGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
for (const auto &input : inputs) {
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
for (const auto &output : outputs) {
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
auto output_shape = output->GetShapeVector();
if (!IsValidShape(output_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
std::vector<int64_t> output_shape = outputs.at(kIndex0)->GetShapeVector();
output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
if (output_elements_ == 0) {
is_null_input_ = true;
}
std::vector<int64_t> grad_shape = inputs.at(kIndex0)->GetShapeVector();
grad_shape_0_ = grad_shape[0];
grad_elements_ = std::accumulate(grad_shape.begin(), grad_shape.end(), 1, std::multiplies{});
outer_size_ = grad_shape.front();
inner_size_ = grad_elements_ / outer_size_;
std::vector<int64_t> indices_shape = inputs.at(kIndex1)->GetShapeVector();
idx_seg_elements_ = std::accumulate(indices_shape.begin(), indices_shape.end(), 1, std::multiplies{});
output_dim0_ = LongToSize(output_shape.front());
size_t input_grad_size = grad_elements_ * unit_grad_size_;
size_t input_idx_seg_size = idx_seg_elements_ * unit_idx_seg_size_;
size_t output_size = output_elements_ * unit_grad_size_;
input_size_list_.push_back(input_grad_size);
input_size_list_.push_back(input_idx_seg_size);
input_size_list_.push_back(input_idx_seg_size);
input_size_list_.push_back(unit_idx_seg_size_);
output_size_list_.push_back(output_size);
workspace_size_list_.push_back((outer_size_ + 1) * sizeof(size_t));
return KRET_OK;
}
template <typename R, typename S>
bool SparseSegmentGradOpsGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
R *grad_ptr = GetDeviceAddress<R>(inputs, kIndex0);
S *indices_ptr = GetDeviceAddress<S>(inputs, kIndex1);
S *segment_ids_ptr = GetDeviceAddress<S>(inputs, kIndex2);
R *y_ptr = GetDeviceAddress<R>(outputs, kIndex0);
size_t *segment_pos_ptr = GetDeviceAddress<size_t>(workspace, kIndex0);
auto any = [](auto... args) -> bool { return ((args == nullptr) || ...); };
if (any(grad_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, y_ptr)) {
return false;
}
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream_);
std::vector<S> indices_host;
std::vector<S> segment_ids_host;
indices_host.resize(idx_seg_elements_);
segment_ids_host.resize(idx_seg_elements_);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(indices_host.data(), indices_ptr, idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream),
"cudaMemcpy failed.");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(segment_ids_host.data(), segment_ids_ptr,
idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream),
"cudaMemcpy failed.");
for (size_t i = 1; i < idx_seg_elements_; i++) {
if (segment_ids_host[i] < segment_ids_host[i - 1]) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted.";
}
}
for (size_t i = 0; i < idx_seg_elements_; i++) {
if (indices_host[i] >= static_cast<S>(output_dim0_)) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of output_dim0.";
}
if (segment_ids_host[i] >= static_cast<S>(grad_shape_0_)) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids out of range of grad's first shape.";
}
}
cudaMemset(y_ptr, 0, output_elements_ * unit_grad_size_);
CalSparseSegmentGradCombination(kernel_type_, grad_ptr, segment_ids_ptr, indices_ptr, segment_pos_ptr, outer_size_,
inner_size_, idx_seg_elements_, output_dim0_, y_ptr, device_id_, stream);
return true;
}
std::map<std::string, std::vector<std::pair<KernelAttr, SparseSegmentGradOpsGpuKernelMod::SSGLaunchFunc>>>
SparseSegmentGradOpsGpuKernelMod::kernel_attr_map_ = {
{Sparse_Segment_Sum_Grad,
{{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&SparseSegmentGradOpsGpuKernelMod::LaunchKernel<half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&SparseSegmentGradOpsGpuKernelMod::LaunchKernel<half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&SparseSegmentGradOpsGpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&SparseSegmentGradOpsGpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
&SparseSegmentGradOpsGpuKernelMod::LaunchKernel<double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
&SparseSegmentGradOpsGpuKernelMod::LaunchKernel<double, int64_t>}}},
{Sparse_Segment_Sqrt_N_Grad,
{{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16),
&SparseSegmentGradOpsGpuKernelMod::LaunchKernel<half, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16),
&SparseSegmentGradOpsGpuKernelMod::LaunchKernel<half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&SparseSegmentGradOpsGpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&SparseSegmentGradOpsGpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64),
&SparseSegmentGradOpsGpuKernelMod::LaunchKernel<double, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat64),
&SparseSegmentGradOpsGpuKernelMod::LaunchKernel<double, int64_t>}}}}; // kernel_attr_map_
std::vector<KernelAttr> SparseSegmentGradOpsGpuKernelMod::GetOpSupport() {
auto iter = kernel_attr_map_.find(kernel_type_);
if (iter == kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'SparseSegmentGradOpsOp', only support these types: " << kernel::Map2Str(kernel_attr_map_)
<< " currently, but got " << kernel_name_;
}
std::vector<KernelAttr> support_list;
(void)std::transform(
iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, SparseSegmentGradOpsGpuKernelMod::SSGLaunchFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSumGrad, []() {
return std::make_shared<SparseSegmentGradOpsGpuKernelMod>(Sparse_Segment_Sum_Grad);
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtNGrad, []() {
return std::make_shared<SparseSegmentGradOpsGpuKernelMod>(Sparse_Segment_Sqrt_N_Grad);
});
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,99 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_
#include <vector>
#include <utility>
#include <string>
#include <memory>
#include <map>
#include <algorithm>
#include <functional>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class SparseSegmentGradOpsGpuKernelMod : public NativeGpuKernelMod {
public:
explicit SparseSegmentGradOpsGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
~SparseSegmentGradOpsGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
void ResetResource() noexcept {
outer_size_ = 0;
inner_size_ = 0;
grad_elements_ = 0;
idx_seg_elements_ = 0;
output_dim0_ = 0;
output_elements_ = 0;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename R, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using SSGLaunchFunc =
std::function<bool(SparseSegmentGradOpsGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
private:
size_t outer_size_{0};
size_t inner_size_{0};
size_t grad_elements_{0};
size_t grad_shape_0_{0};
size_t idx_seg_elements_{0};
size_t output_dim0_{0};
size_t output_elements_{0};
size_t unit_grad_size_{1};
size_t unit_idx_seg_size_{1};
std::string kernel_type_{"Unknown"};
bool is_null_input_{false};
void *cuda_stream_{nullptr};
SSGLaunchFunc kernel_func_{};
static std::map<std::string, std::vector<std::pair<KernelAttr, SSGLaunchFunc>>> kernel_attr_map_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_

View File

@ -136,9 +136,6 @@ constexpr auto kEditDistance = "EditDistance";
constexpr auto kNextAfter = "NextAfter";
constexpr auto kMaximumGradGrad = "MaximumGradGrad";
constexpr auto kSparseSegmentMean = "SparseSegmentMean";
constexpr auto kSparseSegmentSqrtN = "SparseSegmentSqrtN";
constexpr auto kSparseSegmentSqrtNGrad = "SparseSegmentSqrtNGrad";
constexpr auto kSparseSegmentSqrtNWithNumSegments = "SparseSegmentSqrtNWithNumSegments";
constexpr auto kTridiagonalMatMul = "TridiagonalMatMul";
constexpr auto kFFTWithSize = "FFTWithSize";
constexpr auto kTrace = "Trace";
@ -348,6 +345,12 @@ constexpr auto kSparseMatrixSoftmax = "SparseMatrixSoftmax";
constexpr auto kSparseMatrixMatMul = "SparseMatrixMatMul";
constexpr auto kSparseMatrixSparseMatMul = "SparseMatrixSparseMatMul";
constexpr auto kSparseMatrixOrderingAMD = "SparseMatrixOrderingAMD";
constexpr auto kSparseSegmentSum = "SparseSegmentSum";
constexpr auto kSparseSegmentSumGrad = "SparseSegmentSumGrad";
constexpr auto kSparseSegmentSumWithNumSegments = "SparseSegmentSumWithNumSegments";
constexpr auto kSparseSegmentSqrtN = "SparseSegmentSqrtN";
constexpr auto kSparseSegmentSqrtNGrad = "SparseSegmentSqrtNGrad";
constexpr auto kSparseSegmentSqrtNWithNumSegments = "SparseSegmentSqrtNWithNumSegments";
// Sparse Grad ops
constexpr auto kSparseAddGrad = "SparseAddGrad";
@ -1043,6 +1046,14 @@ GVAR_DEF(PrimitivePtr, kPrimSparseMatrixSparseMatMul, std::make_shared<Primitive
GVAR_DEF(PrimitivePtr, kPrimCSRSparseMatrixToDense, std::make_shared<Primitive>("CSRSparseMatrixToDense"));
GVAR_DEF(PrimitivePtr, kPrimSparseMatrixTranspose, std::make_shared<Primitive>(kSparseMatrixTranspose));
GVAR_DEF(PrimitivePtr, kPrimSparseMatrixOrderingAMD, std::make_shared<Primitive>(kSparseMatrixOrderingAMD));
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSum, std::make_shared<Primitive>("SparseSegmentSum"));
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSumGrad, std::make_shared<Primitive>("SparseSegmentSumGrad"));
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSumWithNumSegments,
std::make_shared<Primitive>("SparseSegmentSumWithNumSegments"));
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtN, std::make_shared<Primitive>("SparseSegmentSqrtN"));
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNGrad, std::make_shared<Primitive>("SparseSegmentSqrtNGrad"));
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNWithNumSegments,
std::make_shared<Primitive>("SparseSegmentSqrtNWithNumSegments"));
// Sparse Grad ops
GVAR_DEF(PrimitivePtr, kPrimSparseAddGrad, std::make_shared<Primitive>(kSparseAddGrad));
@ -1195,10 +1206,6 @@ GVAR_DEF(PrimitivePtr, kPrimBucketize, std::make_shared<Primitive>("Bucketize"))
GVAR_DEF(PrimitivePtr, kPrimEinsum, std::make_shared<Primitive>("Einsum"));
GVAR_DEF(PrimitivePtr, kPrimEinsumGrad, std::make_shared<Primitive>("EinsumGrad"));
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentMean, std::make_shared<Primitive>(kSparseSegmentMean));
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtN, std::make_shared<Primitive>("SparseSegmentSqrtN"));
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNGrad, std::make_shared<Primitive>("SparseSegmentSqrtNGrad"));
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNWithNumSegments,
std::make_shared<Primitive>("SparseSegmentSqrtNWithNumSegments"));
GVAR_DEF(PrimitivePtr, kPrimTrace, std::make_shared<Primitive>("Trace"));
GVAR_DEF(PrimitivePtr, kPrimTraceGrad, std::make_shared<Primitive>("TraceGrad"));
GVAR_DEF(PrimitivePtr, kPrimTridiagonalMatMul, std::make_shared<Primitive>(kTridiagonalMatMul));

View File

@ -35,14 +35,22 @@ abstract::ShapePtr SparseSegmentSqrtNGradInferShape(const PrimitivePtr &prim,
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto output_dim0_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
prim->name());
(void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
kInputIndex1, prim->name());
if (x_shape.size() < kInputIndex1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor x's rank is less than 1.";
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
<< "tensor x's rank must be greater than 1, but got [" << x_shape.size() << "].";
}
if (output_dim0_shape.size() != kInputIndex0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor outputdim0 should be a scalar.";
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor output_dim0 should be a scalar, "
<< "but got [" << output_dim0_shape.size() << "].";
}
if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor indices & segment_ids's ranks mismatch.";
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, "
<< "but got indices [" << indices_shape[kInputIndex0] << "] "
<< "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
}
if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex3]->BuildValue()->isa<None>()) {
@ -54,7 +62,8 @@ abstract::ShapePtr SparseSegmentSqrtNGradInferShape(const PrimitivePtr &prim,
CheckAndConvertUtils::CheckTensorIntValue("output_dim0", output_dim0_value_ptr, prim_name);
size_t dim_zero = static_cast<size_t>(output_dim0_value_ptr_tensor[kInputIndex0]);
if (dim_zero <= kInputIndex0) {
MS_EXCEPTION(ValueError) << "Input output_dim0 must > 0!";
MS_EXCEPTION(ValueError) << "For '" << prim_name << "' , tensor output_dim0 must > 0, "
<< "but got [" << dim_zero << "].";
} else {
ShapeVector y_shape = x_shape;
y_shape[kInputIndex0] = static_cast<int64_t>(dim_zero);
@ -62,7 +71,9 @@ abstract::ShapePtr SparseSegmentSqrtNGradInferShape(const PrimitivePtr &prim,
}
} else {
std::vector<int64_t> output_shape = {-2};
return std::make_shared<abstract::Shape>(output_shape);
std::vector<int64_t> min_shape = {1};
std::vector<int64_t> max_shape = {1};
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}
}
@ -72,12 +83,14 @@ TypePtr SparseSegmentSqrtNGradInferType(const PrimitivePtr &prim, const std::vec
auto indices_type = input_args[kInputIndex1]->BuildType();
auto segment_ids_type = input_args[kInputIndex2]->BuildType();
auto output_dim0_type = input_args[kInputIndex3]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, {kFloat16, kFloat32, kFloat64}, prim->name());
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
std::map<std::string, TypePtr> types;
(void)types.emplace("indices", indices_type);
(void)types.emplace("segment_ids", segment_ids_type);
(void)types.emplace("output_dim0", output_dim0_type);
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, {kInt32}, prim->name());
CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return input_args[kInputIndex0]->BuildType();
}
} // namespace

View File

@ -0,0 +1,113 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/sparse_segment_sum_grad.h"
#include "abstract/dshape.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SparseSegmentSumGradInferShape(const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto segment_ids_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto output_dim0_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
prim->name());
(void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
kInputIndex1, prim->name());
if (grad_shape.size() < kInputIndex1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
<< "tensor grad's rank must be greater than 1, but got [" << grad_shape.size() << "].";
}
if (output_dim0_shape.size() != kInputIndex0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor output_dim0 should be a scalar, "
<< "but got [" << output_dim0_shape.size() << "].";
}
if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, "
<< "but got indices [" << indices_shape[kInputIndex0] << "] "
<< "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
}
if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex3]->BuildValue()->isa<None>()) {
auto output_dim0_value = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(output_dim0_value);
auto output_dim0_value_ptr = output_dim0_value->BuildValue();
MS_EXCEPTION_IF_NULL(output_dim0_value_ptr);
auto output_dim0_value_ptr_tensor =
CheckAndConvertUtils::CheckTensorIntValue("output_dim0", output_dim0_value_ptr, prim_name);
size_t dim_zero = output_dim0_value_ptr_tensor[kInputIndex0];
if (dim_zero <= kInputIndex0) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "' , tensor output_dim0 must > 0, "
<< "but got [" << dim_zero << "].";
} else {
ShapeVector y_shape = grad_shape;
y_shape[kInputIndex0] = dim_zero;
return std::make_shared<abstract::Shape>(y_shape);
}
} else {
std::vector<int64_t> output_shape = {-2};
std::vector<int64_t> min_shape = {1};
std::vector<int64_t> max_shape = {1};
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}
}
TypePtr SparseSegmentSumGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto grad_type = input_args[kInputIndex0]->BuildType();
auto indices_type = input_args[kInputIndex1]->BuildType();
auto segment_ids_type = input_args[kInputIndex2]->BuildType();
auto output_dim0_type = input_args[kInputIndex3]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("grad", grad_type, valid_types, prim->name());
std::map<std::string, TypePtr> types;
(void)types.emplace("indices", indices_type);
(void)types.emplace("segment_ids", segment_ids_type);
(void)types.emplace("output_dim0", output_dim0_type);
CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return input_args[kInputIndex0]->BuildType();
}
} // namespace
MIND_API_OPERATOR_IMPL(SparseSegmentSumGrad, BaseOperator);
AbstractBasePtr SparseSegmentSumGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const int64_t input_num = kInputIndex4;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto types = SparseSegmentSumGradInferType(prim, input_args);
auto shapes = SparseSegmentSumGradInferShape(prim, input_args);
return abstract::MakeAbstract(shapes, types);
}
REGISTER_HOST_DEPENDS(kNameSparseSegmentSumGrad, {3});
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSumGrad, prim::kPrimSparseSegmentSumGrad, SparseSegmentSumGradInfer, nullptr,
true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,46 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_GRAD_H_
#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_GRAD_H_
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSparseSegmentSumGrad = "SparseSegmentSumGrad";
class MIND_API SparseSegmentSumGrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(SparseSegmentSumGrad);
SparseSegmentSumGrad() : BaseOperator(kNameSparseSegmentSumGrad) {
InitIOName({"grad", "indices", "segment_ids", "output_dim0"}, {"output"});
}
};
abstract::AbstractBasePtr SparseSegmentSumGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimSparseSegmentSumGradPtr = std::shared_ptr<SparseSegmentSumGrad>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SQRT_N_GRAD_H_

View File

@ -1,105 +1,110 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include <iostream>
#include "ops/sparse_segment_sqrt_n.h"
#include "abstract/dshape.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SparseSegmentSqrtNInferShape(const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
constexpr size_t kRankOne = 1;
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto segment_ids_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
if (indices_shape.size() != kRankOne) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of indices should be 1.";
}
if (segment_ids_shape.size() != kRankOne) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of segment_ids should be 1.";
}
if (x_shape.size() < kRankOne) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', x's rank is less than 1.";
}
if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', ranks of indices and segment_ids mismatch.";
}
if (!input_args[kInputIndex2]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex2]->BuildValue()->isa<None>()) {
auto segment_ids_value_ptr = input_args[kInputIndex2]->BuildValue();
MS_EXCEPTION_IF_NULL(segment_ids_value_ptr);
auto segment_ids_value_ptr_tensor =
CheckAndConvertUtils::CheckTensorIntValue("segment_ids", segment_ids_value_ptr, prim->name());
size_t dim_zero = static_cast<size_t>(segment_ids_value_ptr_tensor.back()) + kRankOne;
if (dim_zero < kRankOne) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must >= 0!";
} else {
ShapeVector y_shape = x_shape;
y_shape[kInputIndex0] = static_cast<int64_t>(dim_zero);
return std::make_shared<abstract::Shape>(y_shape);
}
} else {
std::vector<int64_t> output_shape = {-2};
return std::make_shared<abstract::Shape>(output_shape);
}
}
TypePtr SparseSegmentSqrtNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto x_type = input_args[kInputIndex0]->BuildType();
auto indices_type = input_args[kInputIndex1]->BuildType();
auto segment_ids_type = input_args[kInputIndex2]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", indices_type, common_valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("segment_ids", segment_ids_type, common_valid_types, prim->name());
return input_args[kInputIndex0]->BuildType();
}
} // namespace
MIND_API_OPERATOR_IMPL(SparseSegmentSqrtN, BaseOperator);
AbstractBasePtr SparseSegmentSqrtNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const int64_t input_num = static_cast<int64_t>(kInputIndex3);
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto types = SparseSegmentSqrtNInferType(prim, input_args);
auto shapes = SparseSegmentSqrtNInferShape(prim, input_args);
return abstract::MakeAbstract(shapes, types);
}
REGISTER_HOST_DEPENDS(kNameSparseSegmentSqrtN, {2});
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSqrtN, prim::kPrimSparseSegmentSqrtN, SparseSegmentSqrtNInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include <iostream>
#include "ops/sparse_segment_sqrt_n.h"
#include "abstract/dshape.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SparseSegmentSqrtNInferShape(const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto segment_ids_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
prim->name());
(void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
kInputIndex1, prim->name());
if (x_shape.size() < kInputIndex1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
<< "x's rank must be greater than 1, but got [" << x_shape.size() << "].";
}
if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, "
<< "but got indices [" << indices_shape[kInputIndex0] << "] "
<< "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
}
if (!input_args[kInputIndex2]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex2]->BuildValue()->isa<None>()) {
auto segment_ids_value_ptr = input_args[kInputIndex2]->BuildValue();
MS_EXCEPTION_IF_NULL(segment_ids_value_ptr);
auto segment_ids_value_ptr_tensor =
CheckAndConvertUtils::CheckTensorIntValue("segment_ids", segment_ids_value_ptr, prim->name());
size_t dim_zero = segment_ids_value_ptr_tensor.back() + kInputIndex1;
if (dim_zero < kInputIndex1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must be greater or equal to 0, "
<< "but got [" << dim_zero << "].";
} else {
ShapeVector y_shape = x_shape;
y_shape[kInputIndex0] = dim_zero;
return std::make_shared<abstract::Shape>(y_shape);
}
} else {
std::vector<int64_t> output_shape = {-2};
std::vector<int64_t> min_shape = {1};
std::vector<int64_t> max_shape = {1};
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}
}
TypePtr SparseSegmentSqrtNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto x_type = input_args[kInputIndex0]->BuildType();
auto indices_type = input_args[kInputIndex1]->BuildType();
auto segment_ids_type = input_args[kInputIndex2]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
std::map<std::string, TypePtr> types;
(void)types.emplace("indices", indices_type);
(void)types.emplace("segment_ids", segment_ids_type);
CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return input_args[kInputIndex0]->BuildType();
}
} // namespace
MIND_API_OPERATOR_IMPL(SparseSegmentSqrtN, BaseOperator);
AbstractBasePtr SparseSegmentSqrtNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const int64_t input_num = kInputIndex3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto types = SparseSegmentSqrtNInferType(prim, input_args);
auto shapes = SparseSegmentSqrtNInferShape(prim, input_args);
return abstract::MakeAbstract(shapes, types);
}
REGISTER_HOST_DEPENDS(kNameSparseSegmentSqrtN, {2});
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSqrtN, prim::kPrimSparseSegmentSqrtN, SparseSegmentSqrtNInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -1,122 +1,127 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include "ops/sparse_segment_sqrt_n_with_num_segments.h"
#include "abstract/dshape.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SparseSegmentSqrtNWithNumSegmentsInferShape(const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
constexpr size_t kRankOne = 1;
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto segment_ids_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto num_segments_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
if (indices_shape.size() != kRankOne) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of indices should be 1.";
}
if (segment_ids_shape.size() != kRankOne) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of segment_ids should be 1.";
}
if (x_shape.size() < kRankOne) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of x cannot be less than 1.";
}
if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of indices and segment_ids mismatch.";
}
if (num_segments_shape.size() > kRankOne) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments should be at most 1-D.";
}
if (num_segments_shape.size() == kRankOne) {
if (num_segments_shape[kInputIndex0] != kRankOne) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the num element of num_segments should be 1.";
}
}
if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex3]->BuildValue()->isa<None>()) {
auto num_segments_value = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments_value);
auto num_segments_value_ptr = num_segments_value->BuildValue();
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
auto num_segments_value_ptr_tensor =
CheckAndConvertUtils::CheckTensorIntValue("num_segments", num_segments_value_ptr, prim->name());
size_t dim_zero = static_cast<size_t>(num_segments_value_ptr_tensor.back());
if (dim_zero < kRankOne) {
MS_EXCEPTION(ValueError) << "For " << prim_name
<< ", num_segments must be bigger than the largest id of segment_ids.";
} else {
ShapeVector y_shape = x_shape;
y_shape[kInputIndex0] = static_cast<int64_t>(dim_zero);
return std::make_shared<abstract::Shape>(y_shape);
}
} else {
std::vector<int64_t> output_shape = {-2};
return std::make_shared<abstract::Shape>(output_shape);
}
}
TypePtr SparseSegmentSqrtNWithNumSegmentsInferType(const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto x_type = input_args[kInputIndex0]->BuildType();
auto indices_type = input_args[kInputIndex1]->BuildType();
auto segment_ids_type = input_args[kInputIndex2]->BuildType();
auto num_segments_type = input_args[kInputIndex3]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
std::map<std::string, TypePtr> types;
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
(void)types.emplace("indices", indices_type);
(void)types.emplace("segment_ids", segment_ids_type);
(void)types.emplace("num_segments", num_segments_type);
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, {kInt32, kInt64}, prim->name());
return input_args[kInputIndex0]->BuildType();
}
} // namespace
MIND_API_OPERATOR_IMPL(SparseSegmentSqrtNWithNumSegments, BaseOperator);
AbstractBasePtr SparseSegmentSqrtNWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const int64_t input_num = static_cast<size_t>(kInputIndex4);
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto types = SparseSegmentSqrtNWithNumSegmentsInferType(prim, input_args);
auto shapes = SparseSegmentSqrtNWithNumSegmentsInferShape(prim, input_args);
return abstract::MakeAbstract(shapes, types);
}
REGISTER_HOST_DEPENDS(kNameSparseSegmentSqrtNWithNumSegments, {3});
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSqrtNWithNumSegments, prim::kPrimSparseSegmentSqrtNWithNumSegments,
SparseSegmentSqrtNWithNumSegmentsInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include "ops/sparse_segment_sqrt_n_with_num_segments.h"
#include "abstract/dshape.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SparseSegmentSqrtNWithNumSegmentsInferShape(const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto segment_ids_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto num_segments_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape.size(), kEqual, kInputIndex1, prim->name());
(void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", segment_ids_shape.size(), kEqual, kInputIndex1,
prim->name());
if (x_shape.size() < kInputIndex1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
<< "x's rank must be greater than 1, but got [" << x_shape.size() << "].";
}
if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, "
<< "but got indices [" << indices_shape[kInputIndex0] << "] "
<< "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
}
if (num_segments_shape.size() > kInputIndex1) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments should be at most 1-D, but got ["
<< num_segments_shape.size() << "].";
}
if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex3]->BuildValue()->isa<None>()) {
if (num_segments_shape.size() == kInputIndex1) {
if (num_segments_shape[kInputIndex0] != kInputIndex1) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the num element of num_segments should be 1, but got ["
<< num_segments_shape[kInputIndex0] << "].";
}
}
auto num_segments_value = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments_value);
auto num_segments_value_ptr = num_segments_value->BuildValue();
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
auto num_segments_value_ptr_tensor =
CheckAndConvertUtils::CheckTensorIntValue("num_segments", num_segments_value_ptr, prim->name());
size_t dim_zero = num_segments_value_ptr_tensor.back();
if (dim_zero < kInputIndex1) {
MS_EXCEPTION(ValueError) << "For " << prim_name
<< ", num_segments must bigger than the last number of segment_ids, "
<< "but got " << dim_zero << ".";
} else {
ShapeVector y_shape = x_shape;
y_shape[kInputIndex0] = dim_zero;
return std::make_shared<abstract::Shape>(y_shape);
}
} else {
std::vector<int64_t> output_shape = {-2};
std::vector<int64_t> min_shape = {1};
std::vector<int64_t> max_shape = {1};
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}
}
TypePtr SparseSegmentSqrtNWithNumSegmentsInferType(const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto x_type = input_args[kInputIndex0]->BuildType();
auto indices_type = input_args[kInputIndex1]->BuildType();
auto segment_ids_type = input_args[kInputIndex2]->BuildType();
auto num_segments_type = input_args[kInputIndex3]->BuildType();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
std::map<std::string, TypePtr> types;
(void)types.emplace("indices", indices_type);
(void)types.emplace("segment_ids", segment_ids_type);
(void)types.emplace("num_segments", num_segments_type);
CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return input_args[kInputIndex0]->BuildType();
}
} // namespace
MIND_API_OPERATOR_IMPL(SparseSegmentSqrtNWithNumSegments, BaseOperator);
AbstractBasePtr SparseSegmentSqrtNWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const int64_t input_num = kInputIndex4;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto types = SparseSegmentSqrtNWithNumSegmentsInferType(prim, input_args);
auto shapes = SparseSegmentSqrtNWithNumSegmentsInferShape(prim, input_args);
return abstract::MakeAbstract(shapes, types);
}
REGISTER_HOST_DEPENDS(kNameSparseSegmentSqrtNWithNumSegments, {3});
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSqrtNWithNumSegments, prim::kPrimSparseSegmentSqrtNWithNumSegments,
SparseSegmentSqrtNWithNumSegmentsInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,100 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/sparse_segment_sum.h"
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SparseSegmentSumInferShape(const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto segment_ids_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
prim->name());
(void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
kInputIndex1, prim->name());
if (x_shape.size() < kInputIndex1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
<< "x's rank must be greater than 1, but got [" << x_shape.size() << "].";
}
if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, "
<< "but got indices [" << indices_shape[kInputIndex0] << "] "
<< "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
}
if (!input_args[kInputIndex2]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex2]->BuildValue()->isa<None>()) {
auto segment_ids_value_ptr = input_args[kInputIndex2]->BuildValue();
MS_EXCEPTION_IF_NULL(segment_ids_value_ptr);
auto segment_ids_value_ptr_tensor =
CheckAndConvertUtils::CheckTensorIntValue("segment_ids", segment_ids_value_ptr, prim->name());
size_t dim_zero = segment_ids_value_ptr_tensor.back() + kInputIndex1;
if (dim_zero < kInputIndex1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must be greater or equal to 0, "
<< "but got [" << dim_zero << "].";
} else {
ShapeVector y_shape = x_shape;
y_shape[kInputIndex0] = dim_zero;
return std::make_shared<abstract::Shape>(y_shape);
}
} else {
std::vector<int64_t> output_shape = {-2};
std::vector<int64_t> min_shape = {1};
std::vector<int64_t> max_shape = {1};
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}
}
TypePtr SparseSegmentSumInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto x_type = input_args[kInputIndex0]->BuildType();
auto indices_type = input_args[kInputIndex1]->BuildType();
auto segment_ids_type = input_args[kInputIndex2]->BuildType();
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kFloat16, kFloat32, kFloat64};
const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
std::map<std::string, TypePtr> types;
(void)types.emplace("indices", indices_type);
(void)types.emplace("segment_ids", segment_ids_type);
CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return input_args[kInputIndex0]->BuildType();
}
} // namespace
MIND_API_OPERATOR_IMPL(SparseSegmentSum, BaseOperator);
AbstractBasePtr SparseSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const int64_t input_num = kInputIndex3;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto types = SparseSegmentSumInferType(prim, input_args);
auto shapes = SparseSegmentSumInferShape(prim, input_args);
return abstract::MakeAbstract(shapes, types);
}
REGISTER_HOST_DEPENDS(kNameSparseSegmentSum, {2});
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSum, prim::kPrimSparseSegmentSum, SparseSegmentSumInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_
#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_
#include <set>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSparseSegmentSum = "SparseSegmentSum";
/// \brief Computes the sum along sparse segments of a tensor.
/// Refer to Python API @ref mindspore.ops.SparseSegmentSum for more details.
class MIND_API SparseSegmentSum : public BaseOperator {
public:
MIND_API_BASE_MEMBER(SparseSegmentSum);
/// \brief Constructor.
SparseSegmentSum() : BaseOperator(kNameSparseSegmentSum) { InitIOName({"x", "indices", "segment_ids"}, {"y"}); }
};
AbstractBasePtr SparseSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_

View File

@ -0,0 +1,120 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "ops/sparse_segment_sum_with_num_segments.h"
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SparseSegmentSumWithNumSegmentsInferShape(const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto segment_ids_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto num_segments_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape.size(), kEqual, kInputIndex1, prim->name());
(void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", segment_ids_shape.size(), kEqual, kInputIndex1,
prim->name());
if (x_shape.size() < kInputIndex1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
<< "x's rank must be greater than 1, but got [" << x_shape.size() << "].";
}
if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, "
<< "but got indices [" << indices_shape[kInputIndex0] << "] "
<< "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
}
if (num_segments_shape.size() > kInputIndex1) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments should be at most 1-D, but got ["
<< num_segments_shape.size() << "].";
}
if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex3]->BuildValue()->isa<None>()) {
if (num_segments_shape.size() == kInputIndex1) {
if (num_segments_shape[kInputIndex0] != kInputIndex1) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the num element of num_segments should be 1, but got ["
<< num_segments_shape[kInputIndex0] << "].";
}
}
auto num_segments_value = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(num_segments_value);
auto num_segments_value_ptr = num_segments_value->BuildValue();
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
auto num_segments_value_ptr_tensor =
CheckAndConvertUtils::CheckTensorIntValue("num_segments", num_segments_value_ptr, prim->name());
size_t dim_zero = num_segments_value_ptr_tensor.back();
if (dim_zero < kInputIndex1) {
MS_EXCEPTION(ValueError) << "For " << prim_name
<< ", num_segments must bigger than the last number of segment_ids, "
<< "but got " << dim_zero << ".";
} else {
ShapeVector y_shape = x_shape;
y_shape[kInputIndex0] = dim_zero;
return std::make_shared<abstract::Shape>(y_shape);
}
} else {
std::vector<int64_t> output_shape = {-2};
std::vector<int64_t> min_shape = {1};
std::vector<int64_t> max_shape = {1};
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
}
}
TypePtr SparseSegmentSumWithNumSegmentsInferType(const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
auto x_type = input_args[kInputIndex0]->BuildType();
auto indices_type = input_args[kInputIndex1]->BuildType();
auto segment_ids_type = input_args[kInputIndex2]->BuildType();
auto num_segments_type = input_args[kInputIndex3]->BuildType();
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kFloat16, kFloat32, kFloat64};
const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
std::map<std::string, TypePtr> types;
(void)types.emplace("indices", indices_type);
(void)types.emplace("segment_ids", segment_ids_type);
(void)types.emplace("num_segments", num_segments_type);
CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return input_args[kInputIndex0]->BuildType();
}
} // namespace
MIND_API_OPERATOR_IMPL(SparseSegmentSumWithNumSegments, BaseOperator);
AbstractBasePtr SparseSegmentSumWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
constexpr size_t kInputsNum = kInputIndex4;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, prim_name);
auto types = SparseSegmentSumWithNumSegmentsInferType(prim, input_args);
auto shapes = SparseSegmentSumWithNumSegmentsInferShape(prim, input_args);
return abstract::MakeAbstract(shapes, types);
}
REGISTER_HOST_DEPENDS(kNameSparseSegmentSumWithNumSegments, {3});
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSumWithNumSegments, prim::kPrimSparseSegmentSumWithNumSegments,
SparseSegmentSumWithNumSegmentsInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_
#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_
#include <set>
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSparseSegmentSumWithNumSegments = "SparseSegmentSumWithNumSegments";
/// \brief Computes the sum along sparse segments of a tensor, but it is allowed to miss id in segment_ids.
/// Refer to Python API @ref mindspore.ops.SparseSegmentSumWithNumSegments for more details.
class MIND_API SparseSegmentSumWithNumSegments : public BaseOperator {
public:
MIND_API_BASE_MEMBER(SparseSegmentSumWithNumSegments);
/// \brief Constructor.
SparseSegmentSumWithNumSegments() : BaseOperator(kNameSparseSegmentSumWithNumSegments) {
InitIOName({"x", "indices", "segment_ids", "num_segments"}, {"y"});
}
};
AbstractBasePtr SparseSegmentSumWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimSparseSegmentSumWithNumSegmentsPtr = std::shared_ptr<SparseSegmentSumWithNumSegments>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_

View File

@ -20,6 +20,8 @@ from mindspore.ops.operations.sparse_ops import SparseTensorToCSRSparseMatrix
from mindspore.ops.operations.sparse_ops import SparseToDenseV2
from mindspore.ops.operations.sparse_ops import SparseSoftmax
from mindspore.ops.operations.sparse_ops import SparseDenseCwiseAdd
from mindspore.ops.operations.sparse_ops import SparseSegmentSum
from mindspore.ops.operations.sparse_ops import SparseSegmentSumWithNumSegments
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments
from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments
@ -31,6 +33,7 @@ from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
from .._grad.grad_base import bprop_getters
from .._utils.utils import is_shape_unknown
# Unused parameters are placeholders.
@ -103,13 +106,18 @@ def get_bprop_sparse_segment_sqrt_n(self):
"""Grad definition for `SparseSegmentSqrtN` operation."""
input_grad = G.SparseSegmentSqrtNGrad()
shape = P.Shape()
dyn_shape_op = P.TensorShape()
def bprop(x, indices, segment_ids, out, dout):
output_dim0 = F.scalar_to_tensor(shape(x)[0], mstype.int32)
shape_x = shape(x)
if is_shape_unknown(shape_x):
shape_x = dyn_shape_op(x)
output_dim0 = P.Cast()(shape_x[0], mstype.int32)
indices = F.cast(indices, mstype.int32)
segment_ids = F.cast(segment_ids, mstype.int32)
dx = input_grad(dout, indices, segment_ids, output_dim0)
return dx, zeros_like(indices), zeros_like(segment_ids)
all_d = (dx, zeros_like(indices), zeros_like(segment_ids))
return all_d
return bprop
@ -119,9 +127,55 @@ def get_bprop_sparse_segment_sqrt_n_with_num_segments(self):
"""Grad definition for `SparseSegmentSqrtNWithNumSegments` operation."""
input_grad = G.SparseSegmentSqrtNGrad()
shape = P.Shape()
dyn_shape_op = P.TensorShape()
def bprop(x, indices, segment_ids, num_segments, out, dout):
output_dim0 = F.scalar_to_tensor(shape(x)[0], mstype.int32)
shape_x = shape(x)
if is_shape_unknown(shape_x):
shape_x = dyn_shape_op(x)
output_dim0 = P.Cast()(shape_x[0], mstype.int32)
indices = F.cast(indices, mstype.int32)
segment_ids = F.cast(segment_ids, mstype.int32)
dx = input_grad(dout, indices, segment_ids, output_dim0)
all_d = (dx, zeros_like(indices), zeros_like(segment_ids), zeros_like(num_segments))
return all_d
return bprop
@bprop_getters.register(SparseSegmentSum)
def get_bprop_sparse_segment_sum(self):
"""Grad definition for `SparseSegmentSum` operation."""
input_grad = G.SparseSegmentSumGrad()
shape = P.Shape()
dyn_shape_op = P.TensorShape()
def bprop(x, indices, segment_ids, out, dout):
shape_x = shape(x)
if is_shape_unknown(shape_x):
shape_x = dyn_shape_op(x)
output_dim0 = P.Cast()(shape_x[0], mstype.int32)
indices = F.cast(indices, mstype.int32)
segment_ids = F.cast(segment_ids, mstype.int32)
dx = input_grad(dout, indices, segment_ids, output_dim0)
all_d = (dx, zeros_like(indices), zeros_like(segment_ids))
return all_d
return bprop
@bprop_getters.register(SparseSegmentSumWithNumSegments)
def get_bprop_sparse_segment_sum_with_num_segments(self):
"""Grad definition for `SparseSegmentSumWithNumSegments` operation."""
input_grad = G.SparseSegmentSumGrad()
shape = P.Shape()
dyn_shape_op = P.TensorShape()
def bprop(x, indices, segment_ids, num_segments, out, dout):
shape_x = shape(x)
if is_shape_unknown(shape_x):
shape_x = dyn_shape_op(x)
output_dim0 = P.Cast()(shape_x[0], mstype.int32)
indices = F.cast(indices, mstype.int32)
segment_ids = F.cast(segment_ids, mstype.int32)
dx = input_grad(dout, indices, segment_ids, output_dim0)

View File

@ -2030,7 +2030,7 @@ class UpsampleNearest3DGrad(Primitive):
"""
Upsample the 3-D gradient data with the nearest neighbor interpolation algorithm.
Args:
Args:
input_size (listInt): An required listInt. contain 5 elements: [min_batch, channels, depth, height, width].
Must: input_size[0] == grad_output_tensor_size[0], input_size[1] == grad_output_tensor_size[1].
output_size (listInt): An optional listInt. Defaults to none.
@ -2043,7 +2043,8 @@ class UpsampleNearest3DGrad(Primitive):
The scale array along each dimension, contain 3 elements: scale_depth, scale_height, scale_width.
The number of elements of 'scales' should be the same as the rank of input 'grad_output'.
One of 'scales' and 'output_size' MUST be specified and it is an error if both are specified.
Inputs:
Inputs:
- **grad_output** (Tensor) - Tensor of shape [N, C, D, H, W], Must be one of the following types:
float16, float32, float64.
@ -3499,6 +3500,50 @@ class MedianGrad(Primitive):
self.init_prim_io_names(inputs=['y_grad', 'x', 'y', 'indices'], outputs=['x_grad'])
class SparseSegmentSumGrad(Primitive):
"""
Computes gradients for SparseSegmentSumGrad operation.
Inputs:
- **grad** (Tensor) - A tensor.
- **indices** (Tensor) - Indices is a 1-D tensor. Must be one of the following types: int32, int64.
Has same rank as segment_ids. The shape should be :math:`(N,)`.
- **segment_ids** (Tensor) - Segment_ids is a 1-D tensor. Must be one of the following types: int32, int64.
Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
- **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentSum op.
Outputs:
A Tensor. Has the same type as `grad` .
Has same shape as `grad`, except for dimension 0 which is the value of `output_dim0`.
Raises:
TypeError: If `grad` or `indices` or `segment_ids` or `output_dim0` is not a tensor.
TypeError: If the dtype of `grad` is not any of the following data types: {float16, float32, float64}.
TypeError: If the dtype of `indices` and `segment_ids` and `output_dim0` is not int32 or int64.
ValueError: If dimension size of `grad` less than 1.
ValueError: If rank of `indices` or `segment_ids` is not 1.
ValueError: If dimension size of `output_dim0` is not 0.
ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
ValueError: If `segment_ids` is not sorted.
ValueError: If the last number of `segment_ids` is out of range of grad's first shape.
ValueError: If `indices` is bigger than or equal to `output_dim0`.
Supported Platforms:
``GPU``
"""
__mindspore_signature__ = (
sig.make_sig('grad', dtype=sig.sig_dtype.T1),
sig.make_sig('indices', dtype=sig.sig_dtype.T),
sig.make_sig('segment_ids', dtype=sig.sig_dtype.T),
sig.make_sig('output_dim0', dtype=sig.sig_dtype.T)
)
@prim_attr_register
def __init__(self):
"""Initialize SparseSegmentSumGrad"""
self.init_prim_io_names(inputs=['grad', 'indices', 'segment_ids', 'output_dim0'], outputs=['y'])
class SparseSegmentSqrtNGrad(Primitive):
"""
Computes gradients for SparseSegmentSqrtNGrad operation.
@ -3526,12 +3571,12 @@ class SparseSegmentSqrtNGrad(Primitive):
ValueError: If rank of `indices` or `segment_ids` is not 1.
ValueError: If dimension size of `output_dim0` is not 0.
ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
ValueError: If indices in `segment_ids` are not contiguous or do not start from 0.
ValueError: If `segment_ids` is not sorted.
ValueError: If `indices` is out of range of `output_dim0`.
ValueError: If the last number of `segment_ids` is out of range of x's first shape.
ValueError: If `indices` is bigger than or equal to `output_dim0`.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
"""
@prim_attr_register

View File

@ -18,9 +18,10 @@
"""Operators for sparse operators."""
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register, Primitive
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from .. import signature as sig
from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
class SparseDenseCwiseAdd(Primitive):
@ -1122,6 +1123,115 @@ class SparseConcat(Primitive):
validator.check_value_type("concat_dim", concat_dim, [int], self.name)
class SparseSegmentSum(Primitive):
"""
Computes the sum along sparse segments of a tensor.
Inputs:
- **x** (Tensor) - A tensor.
- **indices** (Tensor) - Indices is a 1-D tensor. Must be one of the following types: int32, int64.
Has same rank as segment_ids. The shape should be :math:`(N,)`.
- **segment_ids** (Tensor) - Segment_ids is a 1-D tensor. Must be one of the following types: int32, int64.
Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
Outputs:
A Tensor. Has the same type as `x` .
Has same shape as `x`, except for dimension 0 which is the number of segments.
Raises:
TypeError: If `x` or `indices` or `segment_ids` is not a tensor.
TypeError: If the dtype of `indices` and `segment_ids` is not int32 or int64.
ValueError: If dimension size of `x` less than 1.
ValueError: If any of `indices` and `segment_ids` is not a 1-D tensor.
ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
ValueError: If indices in `segment_ids` are not contiguous or do not start from 0.
ValueError: If `segment_ids` is not sorted.
ValueError: If `indices` is out of range of x's first shape.
Supported Platforms:
``GPU``
Examples:
>>> x = Tensor([[0, 1, 2], [1, 2, 3], [3, 6, 7]], dtype=ms.float32)
>>> indices = Tensor([0, 1, 2], dtype=ms.int32)
>>> segment_ids = Tensor([0, 1, 1], dtype=ms.int32)
>>> sparse_segment_sum = ops.SparseSegmentSum()
>>> out = sparse_segment_sum(x, indices, segment_ids)
>>> print(out)
[[ 0. 1. 2.]
[ 4. 8. 10.]]
"""
__mindspore_signature__ = (
sig.make_sig('x', dtype=sig.sig_dtype.T1),
sig.make_sig('indices', dtype=sig.sig_dtype.T),
sig.make_sig('segment_ids', dtype=sig.sig_dtype.T)
)
@prim_attr_register
def __init__(self):
"""Initialize SparseSegmentSum"""
self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids'], outputs=['y'])
self.add_prim_attr("cust_aicpu", self.name)
class SparseSegmentSumWithNumSegments(Primitive):
"""
Computes the sum along sparse segments of a tensor, but it is allowed to miss id in segment_ids.
Inputs:
- **x** (Tensor) - A Tensor.
- **indices** (Tensor) - Indices is a 1-D tensor. Must be one of the following types: int32, int64.
Has same rank as segment_ids. The shape should be :math:`(N,)`.
- **segment_ids** (Tensor) - Segment_ids is a 1-D tensor. Must be one of the following types: int32, int64.
Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
- **num_segments** (Tensor) - Num_segments should equal the number of distinct segment_ids.
Outputs:
A Tensor. Has the same type as `x` .
Has same shape as `x`, except for dimension 0 which is the value of `num_segments`.
Raises:
TypeError: If `x` or `indices` or `segment_ids` or `num_segments` is not a tensor.
TypeError: If the dtype of `indices` and `segment_ids` and `num_segments` is not int32 or int64.
ValueError: If dimension size of `x` less than 1.
ValueError: If any of `indices` and `segment_ids` is not a 1-D tensor.
ValueError: If rank of `num_segments` is bigger than 1.
ValueError: If numelements of `num_segments` is not 1.
ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
ValueError: If `segment_ids` is not sorted.
ValueError: If the last number of `segment_ids` is bigger than or equal to `num_segments`.
ValueError: If `indices` is out of range of x's first shape.
Supported Platforms:
``GPU``
Examples:
>>> x = Tensor([[0, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 0]], dtype=ms.float16)
>>> indices = Tensor([0, 2, 1], dtype=ms.int32)
>>> segment_ids = Tensor([0, 0, 2], dtype=ms.int32)
>>> num_segments = Tensor([4], dtype=ms.int32)
>>> sparse_segment_sum_with_num_segments = ops.SparseSegmentSumWithNumSegments()
>>> output = sparse_segment_sum_with_num_segments(x, indices, segment_ids, num_segments)
>>> print(output)
[[1. 1. 1. 0.]
[0. 0. 0. 0.]
[0. 1. 1. 0.]
[0. 0. 0. 0.]]
"""
__mindspore_signature__ = (
sig.make_sig('x', dtype=sig.sig_dtype.T1),
sig.make_sig('indices', dtype=sig.sig_dtype.T),
sig.make_sig('segment_ids', dtype=sig.sig_dtype.T),
sig.make_sig('num_segemnts', dtype=sig.sig_dtype.T)
)
@prim_attr_register
def __init__(self):
"""Initialize SparseSegmentSumWithNumSegments"""
self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'num_segments'], outputs=['y'])
self.add_prim_attr("cust_aicpu", self.name)
class SparseSegmentSqrtN(Primitive):
"""
Computes the sum along sparse segments of a tensor divided by the sqrt of N.
@ -1151,7 +1261,7 @@ class SparseSegmentSqrtN(Primitive):
ValueError: If `indices` is out of range of x's first dimension.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]]).astype(np.float32))
@ -1164,6 +1274,11 @@ class SparseSegmentSqrtN(Primitive):
[ 5. 6. 7. 8.]
[ 9. 10. 11. 12.]]
"""
__mindspore_signature__ = (
sig.make_sig('x', dtype=sig.sig_dtype.T1),
sig.make_sig('indices', dtype=sig.sig_dtype.T),
sig.make_sig('segment_ids', dtype=sig.sig_dtype.T)
)
@prim_attr_register
def __init__(self):
@ -1208,7 +1323,7 @@ class SparseSegmentSqrtNWithNumSegments(Primitive):
ValueError: If `indices` is out of range of x's first dimension.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor([[0, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 0]], dtype=ms.float16)

View File

@ -0,0 +1,101 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations._grad_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
class SparseSegmentSqrtNGradNet(nn.Cell):
def __init__(self):
super(SparseSegmentSqrtNGradNet, self).__init__()
self.net = P.SparseSegmentSqrtNGrad()
@ms_function
def construct(self, grad, indices, segment_ids, output_dim0):
return self.net(grad, indices, segment_ids, output_dim0)
def sparse_segment_sqrt_n_grad(loss):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32)
indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32)
segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32)
output_dim0_np = np.array(8, dtype=np.int32)
grad_ms = Tensor(grad_np)
indices_ms = Tensor(indices_np)
segment_ids_ms = Tensor(segment_ids_np)
output_dim0_ms = Tensor(output_dim0_np)
net_ms = SparseSegmentSqrtNGradNet()
out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms)
expected = np.array([[6, 8, 10, 12],
[6.363961, 7.071068, 7.7781744, 8.485281],
[15.55635, 16.970562, 18.384777, 19.798988],
[9.192389, 9.899495, 10.606602, 11.313708],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]], dtype=np.float32)
assert np.allclose(out_ms.asnumpy(), expected, loss, loss)
def sparse_segment_sqrt_n_grad_pynative(loss):
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64)
indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64)
segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64)
output_dim0_np = np.array(8, dtype=np.int64)
grad_ms = Tensor(grad_np)
indices_ms = Tensor(indices_np)
segment_ids_ms = Tensor(segment_ids_np)
output_dim0_ms = Tensor(output_dim0_np)
net_ms = SparseSegmentSqrtNGradNet()
out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms)
expected = np.array([[0.70710678, 1.41421356, 2.12132034, 2.82842712],
[5.70710678, 7.41421356, 9.12132034, 10.82842712],
[6.36396103, 7.07106781, 7.77817459, 8.48528137],
[19.36396103, 21.07106781, 22.77817459, 24.48528137],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]], dtype=np.float64)
assert np.allclose(out_ms.asnumpy(), expected, loss, loss)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_grad_graph_float32_int32_int32():
"""
Feature: ALL To ALL
Description: test cases for SparseSegmentSqrtNGrad
Expectation: the result match to tensorflow
"""
sparse_segment_sqrt_n_grad(loss=1.0e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_grad_pynative_float64_int64_int64():
"""
Feature: ALL To ALL
Description: test cases for SparseSegmentSqrtNGrad
Expectation: the result match to tensorflow
"""
sparse_segment_sqrt_n_grad_pynative(loss=1.0e-5)

View File

@ -0,0 +1,89 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.sparse_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
class SparseSegmentSqrtNNet(nn.Cell):
def __init__(self):
super(SparseSegmentSqrtNNet, self).__init__()
self.net = P.SparseSegmentSqrtN()
@ms_function
def construct(self, x, indices, segment_ids):
return self.net(x, indices, segment_ids)
def sparse_segment_sqrt_n(loss):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32)
indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32)
segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32)
x_ms = Tensor(x_np)
indices_ms = Tensor(indices_np)
segment_ids_ms = Tensor(segment_ids_np)
net_ms = SparseSegmentSqrtNNet()
out_ms = net_ms(x_ms, indices_ms, segment_ids_ms)
expected = np.array([[1, 2, 3, 4],
[1, 2, 3, 4],
[9.899495, 11.313708, 12.727922, 14.142136],
[15.556349, 16.970562, 18.384777, 19.79899]], dtype=np.float32)
assert np.allclose(out_ms.asnumpy(), expected, loss, loss)
def sparse_segment_sqrt_n_pynative(loss):
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64)
indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64)
segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64)
x_ms = Tensor(x_np)
indices_ms = Tensor(indices_np)
segment_ids_ms = Tensor(segment_ids_np)
net_ms = SparseSegmentSqrtNNet()
out_ms = net_ms(x_ms, indices_ms, segment_ids_ms)
expected = np.array([[4.24264069, 5.65685425, 7.07106781, 8.48528137],
[5, 6, 7, 8],
[15.55634919, 16.97056275, 18.38477631, 19.79898987],
[13, 14, 15, 16]], dtype=np.float64)
assert np.allclose(out_ms.asnumpy(), expected, loss, loss)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_graph_float32_int32_int32():
"""
Feature: ALL To ALL
Description: test cases for SparseSegmentSqrtN
Expectation: the result match to tensorflow
"""
sparse_segment_sqrt_n(loss=1.0e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_pynative_float64_int64_int64():
"""
Feature: ALL To ALL
Description: test cases for SparseSegmentSqrtN
Expectation: the result match to tensorflow
"""
sparse_segment_sqrt_n_pynative(loss=1.0e-5)

View File

@ -0,0 +1,101 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.sparse_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
class SparseSegmentSqrtNWithNumSegmentsNet(nn.Cell):
def __init__(self):
super(SparseSegmentSqrtNWithNumSegmentsNet, self).__init__()
self.net = P.SparseSegmentSqrtNWithNumSegments()
@ms_function
def construct(self, x, indices, segment_ids, num_segments):
return self.net(x, indices, segment_ids, num_segments)
def sparse_segment_sqrt_n_with_num_segments(loss):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32)
indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32)
segment_ids_np = np.array([0, 3, 3, 5, 7, 7], dtype=np.int32)
num_segments_np = np.array(8, dtype=np.int32)
x_ms = Tensor(x_np)
indices_ms = Tensor(indices_np)
segment_ids_ms = Tensor(segment_ids_np)
num_segments_ms = Tensor(num_segments_np)
net_ms = SparseSegmentSqrtNWithNumSegmentsNet()
out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms)
expected = np.array([[1, 2, 3, 4],
[0, 0, 0, 0],
[0, 0, 0, 0],
[4.2426405, 5.656854, 7.071068, 8.485281],
[0, 0, 0, 0],
[9, 10, 11, 12],
[0, 0, 0, 0],
[15.556349, 16.970562, 18.384777, 19.79899]], dtype=np.float32)
assert np.allclose(out_ms.asnumpy(), expected, loss, loss)
def sparse_segment_sqrt_n_with_num_segments_pynative(loss):
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64)
indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64)
segment_ids_np = np.array([0, 0, 3, 5, 5, 7], dtype=np.int64)
num_segments_np = np.array(8, dtype=np.int64)
x_ms = Tensor(x_np)
indices_ms = Tensor(indices_np)
segment_ids_ms = Tensor(segment_ids_np)
num_segments_ms = Tensor(num_segments_np)
net_ms = SparseSegmentSqrtNWithNumSegmentsNet()
out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms)
expected = np.array([[4.24264069, 5.65685425, 7.07106781, 8.48528137],
[0, 0, 0, 0],
[0, 0, 0, 0],
[5, 6, 7, 8],
[0, 0, 0, 0],
[15.55634919, 16.97056275, 18.38477631, 19.79898987],
[0, 0, 0, 0],
[13, 14, 15, 16]], dtype=np.float64)
assert np.allclose(out_ms.asnumpy(), expected, loss, loss)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_with_num_segments_graph_float32_int32_int32_int32():
"""
Feature: ALL To ALL
Description: test cases for SparseSegmentSqrtNWithNumSegments
Expectation: the result match to tensorflow
"""
sparse_segment_sqrt_n_with_num_segments(loss=1.0e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_with_num_segments_pynative_float64_int64_int64_int64():
"""
Feature: ALL To ALL
Description: test cases for SparseSegmentSqrtNWithNumSegments
Expectation: the result match to tensorflow
"""
sparse_segment_sqrt_n_with_num_segments_pynative(loss=1.0e-5)

View File

@ -0,0 +1,101 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations._grad_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
class SparseSegmentSumGradNet(nn.Cell):
def __init__(self):
super(SparseSegmentSumGradNet, self).__init__()
self.net = P.SparseSegmentSumGrad()
@ms_function
def construct(self, grad, indices, segment_ids, output_dim0):
return self.net(grad, indices, segment_ids, output_dim0)
def sparse_segment_sum_grad(loss):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32)
indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32)
segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32)
output_dim0_np = np.array(8, dtype=np.int32)
grad_ms = Tensor(grad_np)
indices_ms = Tensor(indices_np)
segment_ids_ms = Tensor(segment_ids_np)
output_dim0_ms = Tensor(output_dim0_np)
net_ms = SparseSegmentSumGradNet()
out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms)
expected = np.array([[6, 8, 10, 12],
[9, 10, 11, 12],
[22, 24, 26, 28],
[13, 14, 15, 16],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]], dtype=np.float32)
assert np.allclose(out_ms.asnumpy(), expected, loss, loss)
def sparse_segment_sum_grad_pynative(loss):
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
grad_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64)
indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64)
segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64)
output_dim0_np = np.array(8, dtype=np.int64)
grad_ms = Tensor(grad_np)
indices_ms = Tensor(indices_np)
segment_ids_ms = Tensor(segment_ids_np)
output_dim0_ms = Tensor(output_dim0_np)
net_ms = SparseSegmentSumGradNet()
out_ms = net_ms(grad_ms, indices_ms, segment_ids_ms, output_dim0_ms)
expected = np.array([[1, 2, 3, 4.],
[6, 8, 10, 12],
[9, 10, 11, 12],
[22, 24, 26, 28],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]], dtype=np.float64)
assert np.allclose(out_ms.asnumpy(), expected, loss, loss)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sum_grad_graph_float32_int32_int32():
"""
Feature: ALL To ALL
Description: test cases for SparseSegmentSumGrad
Expectation: the result match to tensorflow
"""
sparse_segment_sum_grad(loss=1.0e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sum_grad_pynative_float64_int64_int64():
"""
Feature: ALL To ALL
Description: test cases for SparseSegmentSumGrad
Expectation: the result match to tensorflow
"""
sparse_segment_sum_grad_pynative(loss=1.0e-5)

View File

@ -0,0 +1,89 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.sparse_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
class SparseSegmentSumNet(nn.Cell):
def __init__(self):
super(SparseSegmentSumNet, self).__init__()
self.net = P.SparseSegmentSum()
@ms_function
def construct(self, x, indices, segment_ids):
return self.net(x, indices, segment_ids)
def sparse_segment_sum(loss):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32)
indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32)
segment_ids_np = np.array([0, 1, 2, 2, 3, 3], dtype=np.int32)
x_ms = Tensor(x_np)
indices_ms = Tensor(indices_np)
segment_ids_ms = Tensor(segment_ids_np)
net_ms = SparseSegmentSumNet()
out_ms = net_ms(x_ms, indices_ms, segment_ids_ms)
expected = np.array([[1, 2, 3, 4],
[1, 2, 3, 4],
[14, 16, 18, 20],
[22, 24, 26, 28]], dtype=np.float32)
assert np.allclose(out_ms.asnumpy(), expected, loss, loss)
def sparse_segment_sum_pynative(loss):
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64)
indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64)
segment_ids_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int64)
x_ms = Tensor(x_np)
indices_ms = Tensor(indices_np)
segment_ids_ms = Tensor(segment_ids_np)
net_ms = SparseSegmentSumNet()
out_ms = net_ms(x_ms, indices_ms, segment_ids_ms)
expected = np.array([[6, 8, 10, 12],
[5, 6, 7, 8],
[22, 24, 26, 28],
[13, 14, 15, 16]], dtype=np.float64)
assert np.allclose(out_ms.asnumpy(), expected, loss, loss)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sum_graph_float32_int32_int32():
"""
Feature: ALL To ALL
Description: test cases for SparseSegmentSum
Expectation: the result match to tensorflow
"""
sparse_segment_sum(loss=1.0e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sum_pynative_float64_int64_int64():
"""
Feature: ALL To ALL
Description: test cases for SparseSegmentSum
Expectation: the result match to tensorflow
"""
sparse_segment_sum_pynative(loss=1.0e-5)

View File

@ -0,0 +1,101 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.sparse_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
class SparseSegmentSumWithNumSegmentsNet(nn.Cell):
def __init__(self):
super(SparseSegmentSumWithNumSegmentsNet, self).__init__()
self.net = P.SparseSegmentSumWithNumSegments()
@ms_function
def construct(self, x, indices, segment_ids, num_segments):
return self.net(x, indices, segment_ids, num_segments)
def sparse_segment_sum_with_num_segments(loss):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32)
indices_np = np.array([0, 0, 1, 2, 2, 3], dtype=np.int32)
segment_ids_np = np.array([0, 3, 3, 5, 7, 7], dtype=np.int32)
num_segments_np = np.array(8, dtype=np.int32)
x_ms = Tensor(x_np)
indices_ms = Tensor(indices_np)
segment_ids_ms = Tensor(segment_ids_np)
num_segments_ms = Tensor(num_segments_np)
net_ms = SparseSegmentSumWithNumSegmentsNet()
out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms)
expected = np.array([[1, 2, 3, 4],
[0, 0, 0, 0],
[0, 0, 0, 0],
[6, 8, 10, 12],
[0, 0, 0, 0],
[9, 10, 11, 12],
[0, 0, 0, 0],
[22, 24, 26, 28]], dtype=np.float32)
assert np.allclose(out_ms.asnumpy(), expected, loss, loss)
def sparse_segment_sum_with_num_segments_pynative(loss):
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
x_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float64)
indices_np = np.array([0, 1, 1, 2, 3, 3], dtype=np.int64)
segment_ids_np = np.array([0, 0, 3, 5, 5, 7], dtype=np.int64)
num_segments_np = np.array(8, dtype=np.int64)
x_ms = Tensor(x_np)
indices_ms = Tensor(indices_np)
segment_ids_ms = Tensor(segment_ids_np)
num_segments_ms = Tensor(num_segments_np)
net_ms = SparseSegmentSumWithNumSegmentsNet()
out_ms = net_ms(x_ms, indices_ms, segment_ids_ms, num_segments_ms)
expected = np.array([[6, 8, 10, 12],
[0, 0, 0, 0],
[0, 0, 0, 0],
[5, 6, 7, 8],
[0, 0, 0, 0],
[22, 24, 26, 28],
[0, 0, 0, 0],
[13, 14, 15, 16]], dtype=np.float64)
assert np.allclose(out_ms.asnumpy(), expected, loss, loss)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sum_with_num_segments_graph_float32_int32_int32_int32():
"""
Feature: ALL To ALL
Description: test cases for SparseSegmentSumWithNumSegments
Expectation: the result match to tensorflow
"""
sparse_segment_sum_with_num_segments(loss=1.0e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sum_with_num_segments_pynative_float64_int64_int64_int64():
"""
Feature: ALL To ALL
Description: test cases for SparseSegmentSumWithNumSegments
Expectation: the result match to tensorflow
"""
sparse_segment_sum_with_num_segments_pynative(loss=1.0e-5)