!37109 [feat] [assistant] [ops] [I5EWJI,I5EWJJ,I5EWJK,I5EWJC,I5EWJD,I5EWJG] New GPU operator implementation SparseSegmentOps
Merge pull request !37109 from 路雄博/SSSqG
This commit is contained in:
commit
3324fea63f
|
@ -0,0 +1,225 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh"
|
||||
|
||||
// Fills indices_pos_ptr with boundary positions derived from indices_ptr.
// For every id in [0, idx_seg_size] the kernel writes `id` into the slots
// (indices_ptr[id - 1], indices_ptr[id]], with both ends clamped to [0, outer_size].
// id == 0 starts at slot 0 and id == idx_seg_size ends at slot outer_size.
// NOTE(review): the slot ranges only tile [0, outer_size] without overlap when
// indices_ptr is non-decreasing — confirm the caller guarantees that.
template <typename S>
__global__ void SparseSegmentPosKernel(const S *indices_ptr, size_t *indices_pos_ptr, size_t idx_seg_size,
                                       size_t outer_size) {
  const S lower = S(0);
  const S upper = static_cast<S>(outer_size);
  const size_t step = blockDim.x * gridDim.x;
  size_t id = blockIdx.x * blockDim.x + threadIdx.x;
  while (id <= idx_seg_size) {
    S first = lower;
    if (id != 0) {
      first = indices_ptr[id - 1] + 1;
    }
    S last = upper;
    if (id < idx_seg_size) {
      last = indices_ptr[id];
    }
    first = max(lower, min(upper, first));
    last = max(lower, min(upper, last));
    for (S slot = first; slot <= last; ++slot) {
      indices_pos_ptr[slot] = id;
    }
    id += step;
  }
}
|
||||
|
||||
// Scatter-adds gradient rows into y_ptr.
// blockIdx.x / threadIdx.x stride over the inner (column) dimension and blockIdx.y
// strides over the outer_size rows. For row `row` the positions
// [indices_pos_ptr[row], indices_pos_ptr[row + 1]) are visited; at each position p
// the value grad_ptr[indices_ptr[p] * inner_size + col] (0 if the index is out of
// range) is atomically added to y_ptr[segment_ids_ptr[p] * inner_size + col].
// The value is staged through double before being cast back to R.
template <typename R, typename S>
__global__ void SparseSegmentSumGradKernel(const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr,
                                           const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size,
                                           size_t output_dim0, R *y_ptr) {
  const size_t tile_count = (inner_size - 1) / blockDim.x + 1;
  for (size_t tile = blockIdx.x; tile < tile_count; tile += gridDim.x) {
    const size_t col = threadIdx.x + tile * blockDim.x;
    const bool col_ok = col < inner_size;
    const bool may_write = threadIdx.y == 0 && col_ok;
    for (size_t row = blockIdx.y; row < outer_size; row += gridDim.y) {
      const size_t first = indices_pos_ptr[row];
      const size_t last = indices_pos_ptr[row + 1];
      for (size_t p = first; p < last; ++p) {
        // Threads outside the column range get an out-of-range index so they read nothing.
        const S src = col_ok ? indices_ptr[p] : outer_size;
        double value = 0;
        if (src >= 0 && src < outer_size) {
          value = static_cast<double>(grad_ptr[src * inner_size + col]);
        }
        if (may_write) {
          R *dst = y_ptr + segment_ids_ptr[p] * inner_size + col;
          MsAtomicAdd(dst, static_cast<R>(value));
        }
      }
    }
  }
}
|
||||
|
||||
// half overload of SparseSegmentSumGradKernel.
// Same traversal as the generic version: for each row in [0, outer_size) and each
// position p in [indices_pos_ptr[row], indices_pos_ptr[row + 1]) the gradient element
// selected by indices_ptr[p] is converted half -> float -> double, then converted back
// to half and atomically added to y_ptr[segment_ids_ptr[p] * inner_size + col].
template <typename S>
__global__ void SparseSegmentSumGradKernel(const half *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr,
                                           const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size,
                                           size_t output_dim0, half *y_ptr) {
  const size_t tile_count = (inner_size - 1) / blockDim.x + 1;
  for (size_t tile = blockIdx.x; tile < tile_count; tile += gridDim.x) {
    const size_t col = threadIdx.x + tile * blockDim.x;
    const bool col_ok = col < inner_size;
    const bool may_write = threadIdx.y == 0 && col_ok;
    for (size_t row = blockIdx.y; row < outer_size; row += gridDim.y) {
      const size_t first = indices_pos_ptr[row];
      const size_t last = indices_pos_ptr[row + 1];
      for (size_t p = first; p < last; ++p) {
        // Threads outside the column range get an out-of-range index so they read nothing.
        const S src = col_ok ? indices_ptr[p] : outer_size;
        double value = 0;
        if (src >= 0 && src < outer_size) {
          value = static_cast<double>(__half2float(grad_ptr[src * inner_size + col]));
        }
        if (may_write) {
          half *dst = y_ptr + segment_ids_ptr[p] * inner_size + col;
          MsAtomicAdd(dst, __float2half(static_cast<float>(value)));
        }
      }
    }
  }
}
|
||||
|
||||
// Scatter-adds gradient rows into y_ptr, scaled by 1 / sqrt(N).
// N is the number of positions in the current row's range
// [indices_pos_ptr[row], indices_pos_ptr[row + 1]). For each position p in that range,
// grad_ptr[indices_ptr[p] * inner_size + col] (0 if the index is out of range) is divided
// by sqrt(N) in double precision and atomically added to
// y_ptr[segment_ids_ptr[p] * inner_size + col].
template <typename R, typename S>
__global__ void SparseSegmentSqrtNGradKernel(const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr,
                                             const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size,
                                             size_t output_dim0, R *y_ptr) {
  const size_t tile_count = (inner_size - 1) / blockDim.x + 1;
  for (size_t tile = blockIdx.x; tile < tile_count; tile += gridDim.x) {
    const size_t col = threadIdx.x + tile * blockDim.x;
    const bool col_ok = col < inner_size;
    const bool may_write = threadIdx.y == 0 && col_ok;
    for (size_t row = blockIdx.y; row < outer_size; row += gridDim.y) {
      const size_t first = indices_pos_ptr[row];
      const size_t last = indices_pos_ptr[row + 1];
      const double root_len = sqrt(static_cast<double>(last - first));
      for (size_t p = first; p < last; ++p) {
        // Threads outside the column range get an out-of-range index so they read nothing.
        const S src = col_ok ? indices_ptr[p] : outer_size;
        double value = 0;
        if (src >= 0 && src < outer_size) {
          value = static_cast<double>(grad_ptr[src * inner_size + col]);
        }
        if (may_write) {
          R *dst = y_ptr + segment_ids_ptr[p] * inner_size + col;
          MsAtomicAdd(dst, static_cast<R>(value / root_len));
        }
      }
    }
  }
}
|
||||
|
||||
// half overload of SparseSegmentSqrtNGradKernel.
// Each contribution is read as half, widened to double, divided by sqrt(N) where N is the
// length of the current row's position range, narrowed to float and then half, and
// atomically added to y_ptr[segment_ids_ptr[p] * inner_size + col].
template <typename S>
__global__ void SparseSegmentSqrtNGradKernel(const half *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr,
                                             const size_t *indices_pos_ptr, size_t outer_size, size_t inner_size,
                                             size_t output_dim0, half *y_ptr) {
  const size_t tile_count = (inner_size - 1) / blockDim.x + 1;
  for (size_t tile = blockIdx.x; tile < tile_count; tile += gridDim.x) {
    const size_t col = threadIdx.x + tile * blockDim.x;
    const bool col_ok = col < inner_size;
    const bool may_write = threadIdx.y == 0 && col_ok;
    for (size_t row = blockIdx.y; row < outer_size; row += gridDim.y) {
      const size_t first = indices_pos_ptr[row];
      const size_t last = indices_pos_ptr[row + 1];
      const double root_len = sqrt(static_cast<double>(last - first));
      for (size_t p = first; p < last; ++p) {
        // Threads outside the column range get an out-of-range index so they read nothing.
        const S src = col_ok ? indices_ptr[p] : outer_size;
        double value = 0;
        if (src >= 0 && src < outer_size) {
          value = static_cast<double>(__half2float(grad_ptr[src * inner_size + col]));
        }
        if (may_write) {
          half *dst = y_ptr + segment_ids_ptr[p] * inner_size + col;
          MsAtomicAdd(dst, __float2half(static_cast<float>(value / root_len)));
        }
      }
    }
  }
}
|
||||
|
||||
// Returns floor(log2(n)) for a 32-bit value, or -1 when n is 0.
// Binary search over the bit position: try shifts of 16, 8, 4, 2, 1 and keep
// every shift that still leaves a non-zero value.
inline int Log2Floor_M(uint32_t n) {
  if (n == 0) {
    return -1;
  }
  int result = 0;
  for (int shift = 16; shift > 0; shift >>= 1) {
    const uint32_t shifted = n >> shift;
    if (shifted != 0) {
      n = shifted;
      result += shift;
    }
  }
  return result;
}
|
||||
|
||||
inline int Log2Floor64_M(uint64_t n) {
|
||||
// Scan n first high 32 then low 32 bits.
|
||||
const uint32_t high_32_bit = static_cast<uint32_t>(n >> 32);
|
||||
if (high_32_bit == 0) {
|
||||
return Log2Floor_M(static_cast<uint32_t>(n));
|
||||
} else {
|
||||
return 32 + Log2Floor_M(high_32_bit);
|
||||
}
|
||||
}
|
||||
|
||||
inline int Log2Ceil64_M(uint64_t n) {
|
||||
int floor = Log2Floor64_M(n);
|
||||
if (n == (n & ~(n - 1)))
|
||||
return floor;
|
||||
else
|
||||
return floor + 1;
|
||||
}
|
||||
|
||||
|
||||
// Host launcher for the SparseSegmentSumGrad / SparseSegmentSqrtNGrad kernels.
//
// Steps:
//   1. SparseSegmentPosKernel fills indices_pos_ptr (outer_size + 1 slots are written)
//      with range boundaries derived from indices_ptr.
//   2. The kernel selected by kernel_type scatter-adds the gradient into y_ptr.
//
// Parameters:
//   kernel_type      "SparseSegmentSumGrad" or "SparseSegmentSqrtNGrad".
//   grad_ptr         device pointer to the incoming gradient.
//   indices_ptr      device pointer to idx_seg_size index values.
//   segment_ids_ptr  device pointer to idx_seg_size segment ids.
//   indices_pos_ptr  device workspace receiving the range boundaries.
//   outer_size / inner_size / output_dim0  row and column extents passed to the kernels.
//   y_ptr            device pointer that receives atomic adds.
//                    NOTE(review): nothing here zeroes y_ptr — presumably the caller does; confirm.
//   device_id, cuda_stream  device and stream used for the launches.
//
// Returns false when kernel_type is not recognised (nothing is launched), true otherwise.
// Launch errors are not checked here.
template <typename R, typename S>
bool CalSparseSegmentGradCombination(const std::string kernel_type, const R *grad_ptr, const S *indices_ptr,
                                     const S *segment_ids_ptr, size_t *indices_pos_ptr, size_t outer_size,
                                     size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr,
                                     uint32_t device_id, cudaStream_t cuda_stream) {
  const bool is_sum_grad = (kernel_type == "SparseSegmentSumGrad");
  const bool is_sqrt_n_grad = (kernel_type == "SparseSegmentSqrtNGrad");
  if (!is_sum_grad && !is_sqrt_n_grad) {
    // Unknown operator: report failure instead of silently doing nothing.
    return false;
  }

  // Get start position of each segment and set to indices_pos_ptr.
  // The last element of indices_pos_ptr must equal to idx_seg_size.
  SparseSegmentPosKernel<<<CUDA_BLOCKS(device_id, idx_seg_size + 1), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    indices_ptr, indices_pos_ptr, idx_seg_size, outer_size);

  // A zero extent would produce a zero-sized grid, which is an invalid launch configuration.
  if (inner_size == 0 || output_dim0 == 0) {
    return true;
  }

  const size_t max_grid_x = (1u << 31) - 1;
  const size_t max_grid_y = (1u << 16) - 1;
  const unsigned int block_x = 32;
  const unsigned int block_y = 1;
  // Clamp in size_t first and narrow afterwards, so that huge extents saturate at the
  // grid limit instead of being truncated modulo 2^32. The kernels stride over the
  // remainder, so a clamped grid is still correct.
  const size_t tiles = (inner_size - 1) / block_x + 1;
  const unsigned int grid_x = static_cast<unsigned int>(std::min(tiles, max_grid_x));
  const unsigned int grid_y = static_cast<unsigned int>(std::min(output_dim0, max_grid_y));
  const dim3 block(block_x, block_y);
  const dim3 grid(grid_x, grid_y);

  // None of the kernels declares dynamic shared memory, so none is requested.
  if (is_sum_grad) {
    SparseSegmentSumGradKernel<<<grid, block, 0, cuda_stream>>>(grad_ptr, indices_ptr, segment_ids_ptr,
                                                                indices_pos_ptr, outer_size, inner_size,
                                                                output_dim0, y_ptr);
  } else {
    SparseSegmentSqrtNGradKernel<<<grid, block, 0, cuda_stream>>>(grad_ptr, indices_ptr, segment_ids_ptr,
                                                                  indices_pos_ptr, outer_size, inner_size,
                                                                  output_dim0, y_ptr);
  }
  return true;
}
|
||||
|
||||
// Explicitly instantiates CalSparseSegmentGradCombination for one (value type R, index type S)
// pair. The macro body carries its own trailing semicolon.
#define ADD_SPARSE_SEGMENT_GRAD(R, S)                                                                  \
  template CUDA_LIB_EXPORT bool CalSparseSegmentGradCombination<R, S>(                                 \
    const std::string kernel_type, const R *grad_ptr, const S *indices_ptr, const S *segment_ids_ptr,  \
    size_t *indices_pos_ptr, size_t outer_size, size_t inner_size, size_t idx_seg_size,                \
    size_t output_dim0, R *y_ptr, uint32_t device_id, cudaStream_t cuda_stream);

// Instantiated value types: half, float, double; index types: int32_t, int64_t.
ADD_SPARSE_SEGMENT_GRAD(half, int32_t)
ADD_SPARSE_SEGMENT_GRAD(half, int64_t)

ADD_SPARSE_SEGMENT_GRAD(float, int32_t)
ADD_SPARSE_SEGMENT_GRAD(float, int64_t)

ADD_SPARSE_SEGMENT_GRAD(double, int32_t)
ADD_SPARSE_SEGMENT_GRAD(double, int64_t)
|
|
@ -0,0 +1,30 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_
|
||||
|
||||
#include <string>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
|
||||
|
||||
// Launches the sparse-segment gradient computation selected by kernel_type on cuda_stream.
// R is the gradient/output element type, S the integer type of indices_ptr and segment_ids_ptr.
// indices_pos_ptr is a device workspace written by the implementation.
template <typename R, typename S>
CUDA_LIB_EXPORT bool CalSparseSegmentGradCombination(const std::string kernel_type, const R *grad_ptr,
                                                     const S *indices_ptr, const S *segment_ids_ptr,
                                                     size_t *indices_pos_ptr, size_t outer_size,
                                                     size_t inner_size, size_t idx_seg_size,
                                                     size_t output_dim0, R *y_ptr, uint32_t device_id,
                                                     cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_GRAD_IMPL_CUH_
|
|
@ -0,0 +1,307 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh"
|
||||
|
||||
// Fills segment_pos_ptr with the start position of every segment.
// For every id in [0, idx_seg_size] the kernel writes `id` into the slots
// (segment_ids_ptr[id - 1], segment_ids_ptr[id]], with both ends clamped to [0, output_dim0].
// id == 0 starts at slot 0 and id == idx_seg_size ends at slot output_dim0, so
// output_dim0 + 1 slots may be written.
// NOTE(review): the slot ranges only tile [0, output_dim0] without overlap when
// segment_ids_ptr is non-decreasing — confirm the caller validates that.
template <typename S>
__global__ void SparseSegmentPosKernel(const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t idx_seg_size,
                                       size_t output_dim0) {
  const S lower = S(0);
  const S upper = static_cast<S>(output_dim0);
  const size_t step = blockDim.x * gridDim.x;
  size_t id = blockIdx.x * blockDim.x + threadIdx.x;
  while (id <= idx_seg_size) {
    S first = lower;
    if (id != 0) {
      first = segment_ids_ptr[id - 1] + 1;
    }
    S last = upper;
    if (id < idx_seg_size) {
      last = segment_ids_ptr[id];
    }
    first = max(lower, min(upper, first));
    last = max(lower, min(upper, last));
    for (S slot = first; slot <= last; ++slot) {
      segment_pos_ptr[slot] = id;
    }
    id += step;
  }
}
|
||||
|
||||
// Generic segment-sum kernel.
// blockIdx.x / threadIdx.x stride over the inner (column) dimension and blockIdx.y strides
// over the output_dim0 segments. For segment `seg`, the rows x_ptr[indices_ptr[p]] with
// p in [segment_pos_ptr[seg], segment_pos_ptr[seg + 1]) are summed in type R; indices
// outside [0, outer_size) contribute 0. An empty segment produces R(0).
template <typename R, typename S>
__global__ void SparseSegmentSumKernel(const R *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr,
                                       size_t outer_size, size_t inner_size, size_t output_dim0, R *y_ptr) {
  const size_t tile_count = (inner_size - 1) / blockDim.x + 1;
  for (size_t tile = blockIdx.x; tile < tile_count; tile += gridDim.x) {
    const size_t col = threadIdx.x + tile * blockDim.x;
    const bool col_ok = col < inner_size;
    const bool may_write = threadIdx.y == 0 && col_ok;
    for (size_t seg = blockIdx.y; seg < output_dim0; seg += gridDim.y) {
      const size_t first = segment_pos_ptr[seg];
      const size_t last = segment_pos_ptr[seg + 1];
      R acc = 0;
      for (size_t p = first; p < last; ++p) {
        // Threads outside the column range get an out-of-range index so they read nothing.
        const S src = col_ok ? indices_ptr[p] : outer_size;
        R value = 0;
        if (src >= 0 && src < outer_size) {
          value = x_ptr[src * inner_size + col];
        }
        if (may_write) {
          acc += value;
        }
      }
      if (may_write) {
        R *dst = y_ptr + (seg * inner_size + col);
        if (first == last) {
          *dst = R(0);
        } else {
          *dst = acc;
        }
      }
    }
  }
}
|
||||
|
||||
// float overload of the segment-sum kernel.
// Same traversal as the generic version, but the per-segment sum is accumulated in double
// and narrowed to float only when stored. An empty segment produces 0.
template <typename S>
__global__ void SparseSegmentSumKernel(const float *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr,
                                       size_t outer_size, size_t inner_size, size_t output_dim0, float *y_ptr) {
  const size_t tile_count = (inner_size - 1) / blockDim.x + 1;
  for (size_t tile = blockIdx.x; tile < tile_count; tile += gridDim.x) {
    const size_t col = threadIdx.x + tile * blockDim.x;
    const bool col_ok = col < inner_size;
    const bool may_write = threadIdx.y == 0 && col_ok;
    for (size_t seg = blockIdx.y; seg < output_dim0; seg += gridDim.y) {
      const size_t first = segment_pos_ptr[seg];
      const size_t last = segment_pos_ptr[seg + 1];
      double acc = 0;
      for (size_t p = first; p < last; ++p) {
        // Threads outside the column range get an out-of-range index so they read nothing.
        const S src = col_ok ? indices_ptr[p] : outer_size;
        double value = 0;
        if (src >= 0 && src < outer_size) {
          value = static_cast<double>(x_ptr[src * inner_size + col]);
        }
        if (may_write) {
          acc += value;
        }
      }
      if (may_write) {
        float *dst = y_ptr + (seg * inner_size + col);
        if (first == last) {
          *dst = static_cast<float>(0);
        } else {
          *dst = static_cast<float>(acc);
        }
      }
    }
  }
}
|
||||
|
||||
// half overload of the segment-sum kernel.
// Same traversal as the generic version; every element is widened with __half2float,
// the per-segment sum is accumulated in float, and the result is converted back with
// __float2half when stored. An empty segment produces half(0).
template <typename S>
__global__ void SparseSegmentSumKernel(const half *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr,
                                       size_t outer_size, size_t inner_size, size_t output_dim0, half *y_ptr) {
  const size_t tile_count = (inner_size - 1) / blockDim.x + 1;
  for (size_t tile = blockIdx.x; tile < tile_count; tile += gridDim.x) {
    const size_t col = threadIdx.x + tile * blockDim.x;
    const bool col_ok = col < inner_size;
    const bool may_write = threadIdx.y == 0 && col_ok;
    for (size_t seg = blockIdx.y; seg < output_dim0; seg += gridDim.y) {
      const size_t first = segment_pos_ptr[seg];
      const size_t last = segment_pos_ptr[seg + 1];
      float acc = 0;
      for (size_t p = first; p < last; ++p) {
        // Threads outside the column range get an out-of-range index so they read nothing.
        const S src = col_ok ? indices_ptr[p] : outer_size;
        float value = 0;
        if (src >= 0 && src < outer_size) {
          value = __half2float(x_ptr[src * inner_size + col]);
        }
        if (may_write) {
          acc += value;
        }
      }
      if (may_write) {
        half *dst = y_ptr + (seg * inner_size + col);
        if (first == last) {
          *dst = half(0);
        } else {
          *dst = __float2half(acc);
        }
      }
    }
  }
}
|
||||
|
||||
// Generic segment-sum-divided-by-sqrt(N) kernel.
// For segment `seg`, N is segment_pos_ptr[seg + 1] - segment_pos_ptr[seg]. The selected rows
// are summed in type R and the sum is divided by sqrt(N) converted to R (so for integer R
// the divisor is the truncated square root). An empty segment produces R(0).
template <typename R, typename S>
__global__ void SparseSegmentSqrtNKernel(const R *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr,
                                         size_t outer_size, size_t inner_size, size_t output_dim0, R *y_ptr) {
  const size_t tile_count = (inner_size - 1) / blockDim.x + 1;
  for (size_t tile = blockIdx.x; tile < tile_count; tile += gridDim.x) {
    const size_t col = threadIdx.x + tile * blockDim.x;
    const bool col_ok = col < inner_size;
    const bool may_write = threadIdx.y == 0 && col_ok;
    for (size_t seg = blockIdx.y; seg < output_dim0; seg += gridDim.y) {
      const size_t first = segment_pos_ptr[seg];
      const size_t last = segment_pos_ptr[seg + 1];
      R acc = 0;
      const R root_len = R(sqrt(static_cast<double>(last - first)));
      for (size_t p = first; p < last; ++p) {
        // Threads outside the column range get an out-of-range index so they read nothing.
        const S src = col_ok ? indices_ptr[p] : outer_size;
        R value = 0;
        if (src >= 0 && src < outer_size) {
          value = x_ptr[src * inner_size + col];
        }
        if (may_write) {
          acc += value;
        }
      }
      if (may_write) {
        R *dst = y_ptr + (seg * inner_size + col);
        if (first == last) {
          *dst = R(0);
        } else {
          *dst = acc / root_len;
        }
      }
    }
  }
}
|
||||
|
||||
// float overload of the segment-sum-divided-by-sqrt(N) kernel.
// The sum and the division by sqrt(N) are carried out in double; the quotient is narrowed
// to float when stored. An empty segment produces 0.
template <typename S>
__global__ void SparseSegmentSqrtNKernel(const float *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr,
                                         size_t outer_size, size_t inner_size, size_t output_dim0, float *y_ptr) {
  const size_t tile_count = (inner_size - 1) / blockDim.x + 1;
  for (size_t tile = blockIdx.x; tile < tile_count; tile += gridDim.x) {
    const size_t col = threadIdx.x + tile * blockDim.x;
    const bool col_ok = col < inner_size;
    const bool may_write = threadIdx.y == 0 && col_ok;
    for (size_t seg = blockIdx.y; seg < output_dim0; seg += gridDim.y) {
      const size_t first = segment_pos_ptr[seg];
      const size_t last = segment_pos_ptr[seg + 1];
      double acc = 0;
      const double root_len = sqrt(static_cast<double>(last - first));
      for (size_t p = first; p < last; ++p) {
        // Threads outside the column range get an out-of-range index so they read nothing.
        const S src = col_ok ? indices_ptr[p] : outer_size;
        double value = 0;
        if (src >= 0 && src < outer_size) {
          value = static_cast<double>(x_ptr[src * inner_size + col]);
        }
        if (may_write) {
          acc += value;
        }
      }
      if (may_write) {
        float *dst = y_ptr + (seg * inner_size + col);
        if (first == last) {
          *dst = static_cast<float>(0);
        } else {
          *dst = static_cast<float>(acc / root_len);
        }
      }
    }
  }
}
|
||||
|
||||
// half overload of the segment-sum-divided-by-sqrt(N) kernel.
// Elements are widened with __half2float, summed in float, divided by sqrt(N) computed in
// float, and converted back with __float2half when stored. An empty segment produces half(0).
template <typename S>
__global__ void SparseSegmentSqrtNKernel(const half *x_ptr, const S *indices_ptr, const size_t *segment_pos_ptr,
                                         size_t outer_size, size_t inner_size, size_t output_dim0, half *y_ptr) {
  const size_t tile_count = (inner_size - 1) / blockDim.x + 1;
  for (size_t tile = blockIdx.x; tile < tile_count; tile += gridDim.x) {
    const size_t col = threadIdx.x + tile * blockDim.x;
    const bool col_ok = col < inner_size;
    const bool may_write = threadIdx.y == 0 && col_ok;
    for (size_t seg = blockIdx.y; seg < output_dim0; seg += gridDim.y) {
      const size_t first = segment_pos_ptr[seg];
      const size_t last = segment_pos_ptr[seg + 1];
      float acc = 0;
      const float root_len = sqrt(static_cast<float>(last - first));
      for (size_t p = first; p < last; ++p) {
        // Threads outside the column range get an out-of-range index so they read nothing.
        const S src = col_ok ? indices_ptr[p] : outer_size;
        float value = 0;
        if (src >= 0 && src < outer_size) {
          value = __half2float(x_ptr[src * inner_size + col]);
        }
        if (may_write) {
          acc += value;
        }
      }
      if (may_write) {
        half *dst = y_ptr + (seg * inner_size + col);
        if (first == last) {
          *dst = half(0);
        } else {
          *dst = __float2half(acc / root_len);
        }
      }
    }
  }
}
|
||||
|
||||
// Returns floor(log2(n)) for a 32-bit value, or -1 when n is 0.
// Binary search over the bit position: try shifts of 16, 8, 4, 2, 1 and keep
// every shift that still leaves a non-zero value.
inline int Log2Floor(uint32_t n) {
  if (n == 0) {
    return -1;
  }
  int result = 0;
  for (int shift = 16; shift > 0; shift >>= 1) {
    const uint32_t shifted = n >> shift;
    if (shifted != 0) {
      n = shifted;
      result += shift;
    }
  }
  return result;
}
|
||||
|
||||
inline int Log2Floor64(uint64_t n) {
|
||||
// Scan n first high 32 then low 32 bits.
|
||||
const uint32_t high_32_bit = static_cast<uint32_t>(n >> 32);
|
||||
if (high_32_bit == 0) {
|
||||
return Log2Floor(static_cast<uint32_t>(n));
|
||||
} else {
|
||||
return 32 + Log2Floor(high_32_bit);
|
||||
}
|
||||
}
|
||||
|
||||
inline int Log2Ceil64(uint64_t n) {
|
||||
int floor = Log2Floor64(n);
|
||||
if (n == (n & ~(n - 1)))
|
||||
return floor;
|
||||
else
|
||||
return floor + 1;
|
||||
}
|
||||
|
||||
// Host launcher for the SparseSegmentSum / SparseSegmentSqrtN kernels (with or without
// the "WithNumSegments" suffix).
//
// Steps:
//   1. SparseSegmentPosKernel fills segment_pos_ptr (output_dim0 + 1 slots are written)
//      with the start position of every segment, derived from segment_ids_ptr.
//   2. The reduction kernel selected by kernel_type writes every element of the
//      output_dim0 x inner_size output y_ptr.
//
// Parameters:
//   kernel_type      one of "SparseSegmentSum", "SparseSegmentSumWithNumSegments",
//                    "SparseSegmentSqrtN", "SparseSegmentSqrtNWithNumSegments".
//   x_ptr            device pointer to the input, indexed as [row * inner_size + col].
//   indices_ptr      device pointer to idx_seg_size row indices into x_ptr.
//   segment_ids_ptr  device pointer to idx_seg_size segment ids.
//   segment_pos_ptr  device workspace receiving the segment start positions.
//   outer_size       exclusive upper bound for valid row indices.
//   inner_size       number of columns per row.
//   output_dim0      number of output segments.
//   y_ptr            device pointer to the output.
//   device_id, cuda_stream  device and stream used for the launches.
//
// Returns false when kernel_type is not recognised (nothing is launched), true otherwise.
// Launch errors are not checked here.
template <typename R, typename S>
bool CalSparseSegmentCombination(const std::string kernel_type, const R *x_ptr, const S *indices_ptr,
                                 const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t outer_size,
                                 size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr,
                                 uint32_t device_id, cudaStream_t cuda_stream) {
  const bool is_sum = (kernel_type == "SparseSegmentSum" || kernel_type == "SparseSegmentSumWithNumSegments");
  const bool is_sqrt_n =
    (kernel_type == "SparseSegmentSqrtN" || kernel_type == "SparseSegmentSqrtNWithNumSegments");
  if (!is_sum && !is_sqrt_n) {
    // Unknown operator: report failure instead of silently leaving y_ptr untouched.
    return false;
  }

  // Get start position of each segment and set to segment_pos_ptr.
  // The last element of segment_pos_ptr must equal to idx_seg_size.
  SparseSegmentPosKernel<<<CUDA_BLOCKS(device_id, idx_seg_size + 1), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    segment_ids_ptr, segment_pos_ptr, idx_seg_size, output_dim0);

  // An empty output has nothing to compute, and a zero extent would produce a
  // zero-sized grid, which is an invalid launch configuration.
  if (inner_size == 0 || output_dim0 == 0) {
    return true;
  }

  const size_t max_grid_x = (1u << 31) - 1;
  const size_t max_grid_y = (1u << 16) - 1;
  const unsigned int block_x = 32;
  const unsigned int block_y = 1;
  // Clamp in size_t first and narrow afterwards, so that huge extents saturate at the
  // grid limit instead of being truncated modulo 2^32. The kernels stride over the
  // remainder, so a clamped grid is still correct.
  const size_t tiles = (inner_size - 1) / block_x + 1;
  const unsigned int grid_x = static_cast<unsigned int>(std::min(tiles, max_grid_x));
  const unsigned int grid_y = static_cast<unsigned int>(std::min(output_dim0, max_grid_y));
  const dim3 block(block_x, block_y);
  const dim3 grid(grid_x, grid_y);

  // None of the kernels declares dynamic shared memory, so none is requested.
  if (is_sum) {
    SparseSegmentSumKernel<<<grid, block, 0, cuda_stream>>>(x_ptr, indices_ptr, segment_pos_ptr, outer_size,
                                                            inner_size, output_dim0, y_ptr);
  } else {
    SparseSegmentSqrtNKernel<<<grid, block, 0, cuda_stream>>>(x_ptr, indices_ptr, segment_pos_ptr, outer_size,
                                                              inner_size, output_dim0, y_ptr);
  }
  return true;
}
|
||||
|
||||
// Explicitly instantiates CalSparseSegmentCombination for one (value type R, index type S)
// pair. The macro body carries its own trailing semicolon.
#define ADD_SPARSE_SEGMENT(R, S)                                                                    \
  template CUDA_LIB_EXPORT bool CalSparseSegmentCombination<R, S>(                                  \
    const std::string kernel_type, const R *x_ptr, const S *indices_ptr, const S *segment_ids_ptr,  \
    size_t *segment_pos_ptr, size_t outer_size, size_t inner_size, size_t idx_seg_size,             \
    size_t output_dim0, R *y_ptr, uint32_t device_id, cudaStream_t cuda_stream);

// Unsigned integer value types.
ADD_SPARSE_SEGMENT(uint8_t, int32_t)
ADD_SPARSE_SEGMENT(uint8_t, int64_t)

ADD_SPARSE_SEGMENT(uint16_t, int32_t)
ADD_SPARSE_SEGMENT(uint16_t, int64_t)

// Signed integer value types.
ADD_SPARSE_SEGMENT(int8_t, int32_t)
ADD_SPARSE_SEGMENT(int8_t, int64_t)

ADD_SPARSE_SEGMENT(int16_t, int32_t)
ADD_SPARSE_SEGMENT(int16_t, int64_t)

ADD_SPARSE_SEGMENT(int32_t, int32_t)
ADD_SPARSE_SEGMENT(int32_t, int64_t)

ADD_SPARSE_SEGMENT(int64_t, int32_t)
ADD_SPARSE_SEGMENT(int64_t, int64_t)

// Floating-point value types.
ADD_SPARSE_SEGMENT(half, int32_t)
ADD_SPARSE_SEGMENT(half, int64_t)

ADD_SPARSE_SEGMENT(float, int32_t)
ADD_SPARSE_SEGMENT(float, int64_t)

ADD_SPARSE_SEGMENT(double, int32_t)
ADD_SPARSE_SEGMENT(double, int64_t)
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_
|
||||
|
||||
#include <string>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
|
||||
|
||||
// Launches the sparse-segment reduction selected by kernel_type on cuda_stream.
// R is the input/output element type, S the integer type of indices_ptr and segment_ids_ptr.
// segment_pos_ptr is a device workspace written by the implementation.
// idx_seg_size is the number of elements in indices_ptr / segment_ids_ptr; the parameter
// name now matches the one used by the definition.
template <typename R, typename S>
CUDA_LIB_EXPORT bool CalSparseSegmentCombination(const std::string kernel_type, const R *x_ptr, const S *indices_ptr,
                                                 const S *segment_ids_ptr, size_t *segment_pos_ptr, size_t outer_size,
                                                 size_t inner_size, size_t idx_seg_size, size_t output_dim0, R *y_ptr,
                                                 uint32_t device_id, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_SEGMENT_IMPL_CUH_
|
|
@ -0,0 +1,506 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/sparse/sparse_segment_ops_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
// Operator type names.
constexpr auto Sparse_Segment_Sum = "SparseSegmentSum";
constexpr auto Sparse_Segment_Sum_With_Num_Segments = "SparseSegmentSumWithNumSegments";
constexpr auto Sparse_Segment_Sqrt_N = "SparseSegmentSqrtN";
constexpr auto Sparse_Segment_Sqrt_N_With_Num_Segments = "SparseSegmentSqrtNWithNumSegments";

// Small integer constants; used in this file for the expected input/output counts.
constexpr size_t kNumber1 = 1;
constexpr size_t kNumber3 = 3;
constexpr size_t kNumber4 = 4;
}  // namespace
|
||||
|
||||
bool SparseSegmentOpsGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
if (kernel_type_ == "SparseSegmentSum" || kernel_type_ == "SparseSegmentSqrtN") {
|
||||
flag_ = true;
|
||||
} else {
|
||||
flag_ = false;
|
||||
}
|
||||
size_t inputs_num = flag_ ? kNumber3 : kNumber4;
|
||||
size_t outputs_num = kNumber1;
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), inputs_num, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_);
|
||||
kernel_name_ = base_operator->GetPrim()->name();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << kernel_attr << ".";
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = kernel_attr_map_.at(kernel_type_)[index].second;
|
||||
unit_x_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
|
||||
unit_idx_seg_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).first);
|
||||
return true;
|
||||
}
|
||||
|
||||
int SparseSegmentOpsGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
for (const auto &input : inputs) {
|
||||
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (!IsValidShape(input_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
for (const auto &output : outputs) {
|
||||
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
|
||||
auto output_shape = output->GetShapeVector();
|
||||
if (!IsValidShape(output_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
ResetResource();
|
||||
std::vector<int64_t> output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
|
||||
if (output_elements_ == 0) {
|
||||
is_null_input_ = true;
|
||||
}
|
||||
std::vector<int64_t> x_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
x_shape_0_ = x_shape[0];
|
||||
x_elements_ = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies{});
|
||||
outer_size_ = x_shape.front();
|
||||
inner_size_ = x_elements_ / x_shape.front();
|
||||
std::vector<int64_t> indices_shape = inputs.at(kIndex1)->GetShapeVector();
|
||||
idx_seg_elements_ = std::accumulate(indices_shape.begin(), indices_shape.end(), 1, std::multiplies{});
|
||||
output_dim0_ = LongToSize(output_shape.front());
|
||||
|
||||
size_t input_x_size = x_elements_ * unit_x_size_;
|
||||
size_t input_idx_seg_size = idx_seg_elements_ * unit_idx_seg_size_;
|
||||
size_t output_size = output_elements_ * unit_x_size_;
|
||||
input_size_list_.push_back(input_x_size);
|
||||
input_size_list_.push_back(input_idx_seg_size);
|
||||
input_size_list_.push_back(input_idx_seg_size);
|
||||
if (flag_) {
|
||||
input_size_list_.push_back(unit_idx_seg_size_);
|
||||
}
|
||||
output_size_list_.push_back(output_size);
|
||||
workspace_size_list_.push_back((output_dim0_ + 1) * sizeof(size_t));
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename R, typename S>
|
||||
bool SparseSegmentOpsGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
R *x_ptr = GetDeviceAddress<R>(inputs, kIndex0);
|
||||
S *indices_ptr = GetDeviceAddress<S>(inputs, kIndex1);
|
||||
S *segment_ids_ptr = GetDeviceAddress<S>(inputs, kIndex2);
|
||||
R *y_ptr = GetDeviceAddress<R>(outputs, kIndex0);
|
||||
size_t *segment_pos_ptr = GetDeviceAddress<size_t>(workspace, kIndex0);
|
||||
auto any = [](auto... args) -> bool { return ((args == nullptr) || ...); };
|
||||
if (any(x_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, y_ptr)) {
|
||||
return false;
|
||||
}
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream_);
|
||||
std::vector<S> indices_host;
|
||||
std::vector<S> segment_ids_host;
|
||||
std::vector<S> num_segments_host;
|
||||
indices_host.resize(idx_seg_elements_);
|
||||
segment_ids_host.resize(idx_seg_elements_);
|
||||
num_segments_host.resize(kNumber1);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(indices_host.data(), indices_ptr, idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpy failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(segment_ids_host.data(), segment_ids_ptr,
|
||||
idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpy failed.");
|
||||
if (!flag_) {
|
||||
auto num_segments_ptr = GetDeviceAddress<S>(inputs, kIndex3);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(num_segments_host.data(), num_segments_ptr, sizeof(S), cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpy failed.");
|
||||
}
|
||||
if (segment_ids_host[0] != 0 && flag_) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
|
||||
<< "', indices in 'segment_ids' should be contiguous and start from 0.";
|
||||
}
|
||||
for (size_t i = 1; i < idx_seg_elements_; i++) {
|
||||
if (segment_ids_host[i] < segment_ids_host[i - 1]) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted.";
|
||||
}
|
||||
if (segment_ids_host[i] - segment_ids_host[i - 1] > 1 && flag_) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
|
||||
<< "', indices in 'segment_ids' should be contiguous and start from 0.";
|
||||
}
|
||||
}
|
||||
if (segment_ids_host[idx_seg_elements_ - 1] >= num_segments_host[kIndex0] && !flag_) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_
|
||||
<< "', num_segments must bigger than the last number of segment_ids.";
|
||||
}
|
||||
for (size_t i = 0; i < idx_seg_elements_; i++) {
|
||||
if (indices_host[i] >= static_cast<S>(x_shape_0_)) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of x's first shape.";
|
||||
}
|
||||
}
|
||||
CalSparseSegmentCombination(kernel_type_, x_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, outer_size_,
|
||||
inner_size_, idx_seg_elements_, output_dim0_, y_ptr, device_id_, stream);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Local builders for the entries of kernel_attr_map_.
//   SPARSE_SEGMENT_ATTR    : inputs (x, indices, segment_ids)               -> output typed like x.
//   SPARSE_SEGMENT_NS_ATTR : inputs (x, indices, segment_ids, num_segments) -> output typed like x.
// X_TYPE / IDX_TYPE are the TypeId constants; R / S are the matching C++ types used to instantiate LaunchKernel.
#define SPARSE_SEGMENT_ATTR(X_TYPE, IDX_TYPE, R, S)                                                        \
  {                                                                                                        \
    KernelAttr().AddInputAttr(X_TYPE).AddInputAttr(IDX_TYPE).AddInputAttr(IDX_TYPE).AddOutputAttr(X_TYPE), \
      &SparseSegmentOpsGpuKernelMod::LaunchKernel<R, S>                                                    \
  }
#define SPARSE_SEGMENT_NS_ATTR(X_TYPE, IDX_TYPE, R, S)  \
  {                                                     \
    KernelAttr()                                        \
      .AddInputAttr(X_TYPE)                             \
      .AddInputAttr(IDX_TYPE)                           \
      .AddInputAttr(IDX_TYPE)                           \
      .AddInputAttr(IDX_TYPE)                           \
      .AddOutputAttr(X_TYPE),                           \
      &SparseSegmentOpsGpuKernelMod::LaunchKernel<R, S> \
  }

// Supported (attribute, launch function) pairs, keyed by kernel type string. The order of the entries inside
// each list is significant: the index of a matched attribute is used to pick the launch function.
std::map<std::string, std::vector<std::pair<KernelAttr, SparseSegmentOpsGpuKernelMod::SSLaunchFunc>>>
  SparseSegmentOpsGpuKernelMod::kernel_attr_map_ = {
    {Sparse_Segment_Sum,
     {SPARSE_SEGMENT_ATTR(kNumberTypeUInt8, kNumberTypeInt32, uint8_t, int32_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeUInt8, kNumberTypeInt64, uint8_t, int64_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeUInt16, kNumberTypeInt32, uint16_t, int32_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeUInt16, kNumberTypeInt64, uint16_t, int64_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeInt8, kNumberTypeInt32, int8_t, int32_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeInt8, kNumberTypeInt64, int8_t, int64_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeInt16, kNumberTypeInt32, int16_t, int32_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeInt16, kNumberTypeInt64, int16_t, int64_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeInt32, kNumberTypeInt32, int32_t, int32_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeInt32, kNumberTypeInt64, int32_t, int64_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeInt64, kNumberTypeInt32, int64_t, int32_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeFloat16, kNumberTypeInt32, half, int32_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeFloat16, kNumberTypeInt64, half, int64_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeFloat32, kNumberTypeInt32, float, int32_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeFloat32, kNumberTypeInt64, float, int64_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeFloat64, kNumberTypeInt32, double, int32_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeFloat64, kNumberTypeInt64, double, int64_t)}},
    {Sparse_Segment_Sum_With_Num_Segments,
     {SPARSE_SEGMENT_NS_ATTR(kNumberTypeUInt8, kNumberTypeInt32, uint8_t, int32_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeUInt8, kNumberTypeInt64, uint8_t, int64_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeUInt16, kNumberTypeInt32, uint16_t, int32_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeUInt16, kNumberTypeInt64, uint16_t, int64_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeInt8, kNumberTypeInt32, int8_t, int32_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeInt8, kNumberTypeInt64, int8_t, int64_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeInt16, kNumberTypeInt32, int16_t, int32_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeInt16, kNumberTypeInt64, int16_t, int64_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeInt32, kNumberTypeInt32, int32_t, int32_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeInt32, kNumberTypeInt64, int32_t, int64_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeInt64, kNumberTypeInt32, int64_t, int32_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeInt64, kNumberTypeInt64, int64_t, int64_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeFloat16, kNumberTypeInt32, half, int32_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeFloat16, kNumberTypeInt64, half, int64_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeFloat32, kNumberTypeInt32, float, int32_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeFloat32, kNumberTypeInt64, float, int64_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeFloat64, kNumberTypeInt32, double, int32_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeFloat64, kNumberTypeInt64, double, int64_t)}},
    {Sparse_Segment_Sqrt_N,
     {SPARSE_SEGMENT_ATTR(kNumberTypeFloat16, kNumberTypeInt32, half, int32_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeFloat16, kNumberTypeInt64, half, int64_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeFloat32, kNumberTypeInt32, float, int32_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeFloat32, kNumberTypeInt64, float, int64_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeFloat64, kNumberTypeInt32, double, int32_t),
      SPARSE_SEGMENT_ATTR(kNumberTypeFloat64, kNumberTypeInt64, double, int64_t)}},
    {Sparse_Segment_Sqrt_N_With_Num_Segments,
     {SPARSE_SEGMENT_NS_ATTR(kNumberTypeFloat16, kNumberTypeInt32, half, int32_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeFloat16, kNumberTypeInt64, half, int64_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeFloat32, kNumberTypeInt32, float, int32_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeFloat32, kNumberTypeInt64, float, int64_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeFloat64, kNumberTypeInt32, double, int32_t),
      SPARSE_SEGMENT_NS_ATTR(kNumberTypeFloat64, kNumberTypeInt64, double, int64_t)}}};  // kernel_attr_map_

#undef SPARSE_SEGMENT_NS_ATTR
#undef SPARSE_SEGMENT_ATTR
|
||||
|
||||
std::vector<KernelAttr> SparseSegmentOpsGpuKernelMod::GetOpSupport() {
|
||||
auto iter = kernel_attr_map_.find(kernel_type_);
|
||||
if (iter == kernel_attr_map_.end()) {
|
||||
MS_LOG(ERROR) << "For 'SparseSegmentOpsOp', only support these types: " << kernel::Map2Str(kernel_attr_map_)
|
||||
<< " currently, but got " << kernel_name_;
|
||||
}
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(
|
||||
iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, SparseSegmentOpsGpuKernelMod::SSLaunchFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
// Register one creator per forward operator name. Every creator builds the same kernel module class and
// passes the kernel type string that selects the operator's entry in kernel_attr_map_.
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSum, []() {
  return std::make_shared<SparseSegmentOpsGpuKernelMod>(Sparse_Segment_Sum);
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSumWithNumSegments, []() {
  return std::make_shared<SparseSegmentOpsGpuKernelMod>(Sparse_Segment_Sum_With_Num_Segments);
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtN, []() {
  return std::make_shared<SparseSegmentOpsGpuKernelMod>(Sparse_Segment_Sqrt_N);
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtNWithNumSegments, []() {
  return std::make_shared<SparseSegmentOpsGpuKernelMod>(Sparse_Segment_Sqrt_N_With_Num_Segments);
});
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,101 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_impl.cuh"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class SparseSegmentOpsGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
explicit SparseSegmentOpsGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
|
||||
~SparseSegmentOpsGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
cuda_stream_ = cuda_stream;
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
protected:
|
||||
void ResetResource() noexcept {
|
||||
outer_size_ = 0;
|
||||
inner_size_ = 0;
|
||||
x_elements_ = 0;
|
||||
x_shape_0_ = 0;
|
||||
idx_seg_elements_ = 0;
|
||||
output_dim0_ = 0;
|
||||
output_elements_ = 0;
|
||||
is_null_input_ = false;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename R, typename S>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
|
||||
using SSLaunchFunc =
|
||||
std::function<bool(SparseSegmentOpsGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
|
||||
private:
|
||||
size_t outer_size_{0};
|
||||
size_t inner_size_{0};
|
||||
size_t x_elements_{0};
|
||||
size_t x_shape_0_{0};
|
||||
size_t idx_seg_elements_{0};
|
||||
size_t output_dim0_{0};
|
||||
size_t output_elements_{0};
|
||||
size_t unit_x_size_{1};
|
||||
size_t unit_idx_seg_size_{1};
|
||||
std::string kernel_type_{"Unknown"};
|
||||
bool is_null_input_{false};
|
||||
size_t flag_{0};
|
||||
void *cuda_stream_{nullptr};
|
||||
SSLaunchFunc kernel_func_{};
|
||||
static std::map<std::string, std::vector<std::pair<KernelAttr, SSLaunchFunc>>> kernel_attr_map_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_SEGMENT_GPU_KERNEL_H_
|
|
@ -0,0 +1,245 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/sparse_grad/sparse_segment_grad_ops_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
// Kernel type strings; they are the keys of kernel_attr_map_ and are handed to the factory creators.
constexpr const char *Sparse_Segment_Sum_Grad = "SparseSegmentSumGrad";
constexpr const char *Sparse_Segment_Sqrt_N_Grad = "SparseSegmentSqrtNGrad";
// Expected number of outputs / inputs, checked in Init.
constexpr size_t kNumber1{1};
constexpr size_t kNumber4{4};
}  // namespace
|
||||
|
||||
bool SparseSegmentGradOpsGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
size_t inputs_num = kNumber4;
|
||||
size_t outputs_num = kNumber1;
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), inputs_num, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_);
|
||||
kernel_name_ = base_operator->GetPrim()->name();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << kernel_attr << ".";
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = kernel_attr_map_.at(kernel_type_)[index].second;
|
||||
unit_grad_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
|
||||
unit_idx_seg_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).first);
|
||||
return true;
|
||||
}
|
||||
|
||||
int SparseSegmentGradOpsGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
for (const auto &input : inputs) {
|
||||
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (!IsValidShape(input_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
for (const auto &output : outputs) {
|
||||
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
|
||||
auto output_shape = output->GetShapeVector();
|
||||
if (!IsValidShape(output_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
ResetResource();
|
||||
std::vector<int64_t> output_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
|
||||
if (output_elements_ == 0) {
|
||||
is_null_input_ = true;
|
||||
}
|
||||
std::vector<int64_t> grad_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
grad_shape_0_ = grad_shape[0];
|
||||
grad_elements_ = std::accumulate(grad_shape.begin(), grad_shape.end(), 1, std::multiplies{});
|
||||
outer_size_ = grad_shape.front();
|
||||
inner_size_ = grad_elements_ / outer_size_;
|
||||
std::vector<int64_t> indices_shape = inputs.at(kIndex1)->GetShapeVector();
|
||||
idx_seg_elements_ = std::accumulate(indices_shape.begin(), indices_shape.end(), 1, std::multiplies{});
|
||||
output_dim0_ = LongToSize(output_shape.front());
|
||||
|
||||
size_t input_grad_size = grad_elements_ * unit_grad_size_;
|
||||
size_t input_idx_seg_size = idx_seg_elements_ * unit_idx_seg_size_;
|
||||
size_t output_size = output_elements_ * unit_grad_size_;
|
||||
input_size_list_.push_back(input_grad_size);
|
||||
input_size_list_.push_back(input_idx_seg_size);
|
||||
input_size_list_.push_back(input_idx_seg_size);
|
||||
input_size_list_.push_back(unit_idx_seg_size_);
|
||||
output_size_list_.push_back(output_size);
|
||||
workspace_size_list_.push_back((outer_size_ + 1) * sizeof(size_t));
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename R, typename S>
|
||||
bool SparseSegmentGradOpsGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
R *grad_ptr = GetDeviceAddress<R>(inputs, kIndex0);
|
||||
S *indices_ptr = GetDeviceAddress<S>(inputs, kIndex1);
|
||||
S *segment_ids_ptr = GetDeviceAddress<S>(inputs, kIndex2);
|
||||
R *y_ptr = GetDeviceAddress<R>(outputs, kIndex0);
|
||||
size_t *segment_pos_ptr = GetDeviceAddress<size_t>(workspace, kIndex0);
|
||||
auto any = [](auto... args) -> bool { return ((args == nullptr) || ...); };
|
||||
if (any(grad_ptr, indices_ptr, segment_ids_ptr, segment_pos_ptr, y_ptr)) {
|
||||
return false;
|
||||
}
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream_);
|
||||
std::vector<S> indices_host;
|
||||
std::vector<S> segment_ids_host;
|
||||
indices_host.resize(idx_seg_elements_);
|
||||
segment_ids_host.resize(idx_seg_elements_);
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(indices_host.data(), indices_ptr, idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpy failed.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(segment_ids_host.data(), segment_ids_ptr,
|
||||
idx_seg_elements_ * sizeof(S), cudaMemcpyDeviceToHost, stream),
|
||||
"cudaMemcpy failed.");
|
||||
for (size_t i = 1; i < idx_seg_elements_; i++) {
|
||||
if (segment_ids_host[i] < segment_ids_host[i - 1]) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids should be sorted.";
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < idx_seg_elements_; i++) {
|
||||
if (indices_host[i] >= static_cast<S>(output_dim0_)) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', indices out of range of output_dim0.";
|
||||
}
|
||||
if (segment_ids_host[i] >= static_cast<S>(grad_shape_0_)) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', segment_ids out of range of grad's first shape.";
|
||||
}
|
||||
}
|
||||
cudaMemset(y_ptr, 0, output_elements_ * unit_grad_size_);
|
||||
CalSparseSegmentGradCombination(kernel_type_, grad_ptr, segment_ids_ptr, indices_ptr, segment_pos_ptr, outer_size_,
|
||||
inner_size_, idx_seg_elements_, output_dim0_, y_ptr, device_id_, stream);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Local builder for the entries of kernel_attr_map_: four inputs (one of GRAD_TYPE followed by three of
// IDX_TYPE) and one output of GRAD_TYPE. R / S are the matching C++ types used to instantiate LaunchKernel.
#define SPARSE_SEGMENT_GRAD_ATTR(GRAD_TYPE, IDX_TYPE, R, S)  \
  {                                                          \
    KernelAttr()                                             \
      .AddInputAttr(GRAD_TYPE)                               \
      .AddInputAttr(IDX_TYPE)                                \
      .AddInputAttr(IDX_TYPE)                                \
      .AddInputAttr(IDX_TYPE)                                \
      .AddOutputAttr(GRAD_TYPE),                             \
      &SparseSegmentGradOpsGpuKernelMod::LaunchKernel<R, S>  \
  }

// Supported (attribute, launch function) pairs, keyed by kernel type string. The order of the entries inside
// each list is significant: the index of a matched attribute is used to pick the launch function.
std::map<std::string, std::vector<std::pair<KernelAttr, SparseSegmentGradOpsGpuKernelMod::SSGLaunchFunc>>>
  SparseSegmentGradOpsGpuKernelMod::kernel_attr_map_ = {
    {Sparse_Segment_Sum_Grad,
     {SPARSE_SEGMENT_GRAD_ATTR(kNumberTypeFloat16, kNumberTypeInt32, half, int32_t),
      SPARSE_SEGMENT_GRAD_ATTR(kNumberTypeFloat16, kNumberTypeInt64, half, int64_t),
      SPARSE_SEGMENT_GRAD_ATTR(kNumberTypeFloat32, kNumberTypeInt32, float, int32_t),
      SPARSE_SEGMENT_GRAD_ATTR(kNumberTypeFloat32, kNumberTypeInt64, float, int64_t),
      SPARSE_SEGMENT_GRAD_ATTR(kNumberTypeFloat64, kNumberTypeInt32, double, int32_t),
      SPARSE_SEGMENT_GRAD_ATTR(kNumberTypeFloat64, kNumberTypeInt64, double, int64_t)}},
    {Sparse_Segment_Sqrt_N_Grad,
     {SPARSE_SEGMENT_GRAD_ATTR(kNumberTypeFloat16, kNumberTypeInt32, half, int32_t),
      SPARSE_SEGMENT_GRAD_ATTR(kNumberTypeFloat16, kNumberTypeInt64, half, int64_t),
      SPARSE_SEGMENT_GRAD_ATTR(kNumberTypeFloat32, kNumberTypeInt32, float, int32_t),
      SPARSE_SEGMENT_GRAD_ATTR(kNumberTypeFloat32, kNumberTypeInt64, float, int64_t),
      SPARSE_SEGMENT_GRAD_ATTR(kNumberTypeFloat64, kNumberTypeInt32, double, int32_t),
      SPARSE_SEGMENT_GRAD_ATTR(kNumberTypeFloat64, kNumberTypeInt64, double, int64_t)}}};  // kernel_attr_map_

#undef SPARSE_SEGMENT_GRAD_ATTR
|
||||
|
||||
std::vector<KernelAttr> SparseSegmentGradOpsGpuKernelMod::GetOpSupport() {
|
||||
auto iter = kernel_attr_map_.find(kernel_type_);
|
||||
if (iter == kernel_attr_map_.end()) {
|
||||
MS_LOG(ERROR) << "For 'SparseSegmentGradOpsOp', only support these types: " << kernel::Map2Str(kernel_attr_map_)
|
||||
<< " currently, but got " << kernel_name_;
|
||||
}
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(
|
||||
iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, SparseSegmentGradOpsGpuKernelMod::SSGLaunchFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
// Factory registration for the operator name SparseSegmentSumGrad: the creator builds the shared
// kernel module with the Sparse_Segment_Sum_Grad kernel type.
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSumGrad, []() {
|
||||
return std::make_shared<SparseSegmentGradOpsGpuKernelMod>(Sparse_Segment_Sum_Grad);
|
||||
});
|
||||
// Factory registration for the operator name SparseSegmentSqrtNGrad: same kernel module class,
// constructed with the Sparse_Segment_Sqrt_N_Grad kernel type.
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SparseSegmentSqrtNGrad, []() {
|
||||
return std::make_shared<SparseSegmentGradOpsGpuKernelMod>(Sparse_Segment_Sqrt_N_Grad);
|
||||
});
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,99 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_segment_grad_impl.cuh"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class SparseSegmentGradOpsGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
explicit SparseSegmentGradOpsGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
|
||||
~SparseSegmentGradOpsGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
cuda_stream_ = cuda_stream;
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
protected:
|
||||
void ResetResource() noexcept {
|
||||
outer_size_ = 0;
|
||||
inner_size_ = 0;
|
||||
grad_elements_ = 0;
|
||||
idx_seg_elements_ = 0;
|
||||
output_dim0_ = 0;
|
||||
output_elements_ = 0;
|
||||
is_null_input_ = false;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename R, typename S>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
|
||||
using SSGLaunchFunc =
|
||||
std::function<bool(SparseSegmentGradOpsGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
|
||||
private:
|
||||
size_t outer_size_{0};
|
||||
size_t inner_size_{0};
|
||||
size_t grad_elements_{0};
|
||||
size_t grad_shape_0_{0};
|
||||
size_t idx_seg_elements_{0};
|
||||
size_t output_dim0_{0};
|
||||
size_t output_elements_{0};
|
||||
size_t unit_grad_size_{1};
|
||||
size_t unit_idx_seg_size_{1};
|
||||
std::string kernel_type_{"Unknown"};
|
||||
bool is_null_input_{false};
|
||||
void *cuda_stream_{nullptr};
|
||||
SSGLaunchFunc kernel_func_{};
|
||||
static std::map<std::string, std::vector<std::pair<KernelAttr, SSGLaunchFunc>>> kernel_attr_map_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_GRAD_SPARSE_SEGMENT_GRAD_GPU_KERNEL_H_
|
|
@ -136,9 +136,6 @@ constexpr auto kEditDistance = "EditDistance";
|
|||
constexpr auto kNextAfter = "NextAfter";
|
||||
constexpr auto kMaximumGradGrad = "MaximumGradGrad";
|
||||
constexpr auto kSparseSegmentMean = "SparseSegmentMean";
|
||||
constexpr auto kSparseSegmentSqrtN = "SparseSegmentSqrtN";
|
||||
constexpr auto kSparseSegmentSqrtNGrad = "SparseSegmentSqrtNGrad";
|
||||
constexpr auto kSparseSegmentSqrtNWithNumSegments = "SparseSegmentSqrtNWithNumSegments";
|
||||
constexpr auto kTridiagonalMatMul = "TridiagonalMatMul";
|
||||
constexpr auto kFFTWithSize = "FFTWithSize";
|
||||
constexpr auto kTrace = "Trace";
|
||||
|
@ -348,6 +345,12 @@ constexpr auto kSparseMatrixSoftmax = "SparseMatrixSoftmax";
|
|||
constexpr auto kSparseMatrixMatMul = "SparseMatrixMatMul";
|
||||
constexpr auto kSparseMatrixSparseMatMul = "SparseMatrixSparseMatMul";
|
||||
constexpr auto kSparseMatrixOrderingAMD = "SparseMatrixOrderingAMD";
|
||||
constexpr auto kSparseSegmentSum = "SparseSegmentSum";
|
||||
constexpr auto kSparseSegmentSumGrad = "SparseSegmentSumGrad";
|
||||
constexpr auto kSparseSegmentSumWithNumSegments = "SparseSegmentSumWithNumSegments";
|
||||
constexpr auto kSparseSegmentSqrtN = "SparseSegmentSqrtN";
|
||||
constexpr auto kSparseSegmentSqrtNGrad = "SparseSegmentSqrtNGrad";
|
||||
constexpr auto kSparseSegmentSqrtNWithNumSegments = "SparseSegmentSqrtNWithNumSegments";
|
||||
|
||||
// Sparse Grad ops
|
||||
constexpr auto kSparseAddGrad = "SparseAddGrad";
|
||||
|
@ -1043,6 +1046,14 @@ GVAR_DEF(PrimitivePtr, kPrimSparseMatrixSparseMatMul, std::make_shared<Primitive
|
|||
GVAR_DEF(PrimitivePtr, kPrimCSRSparseMatrixToDense, std::make_shared<Primitive>("CSRSparseMatrixToDense"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseMatrixTranspose, std::make_shared<Primitive>(kSparseMatrixTranspose));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseMatrixOrderingAMD, std::make_shared<Primitive>(kSparseMatrixOrderingAMD));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSum, std::make_shared<Primitive>("SparseSegmentSum"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSumGrad, std::make_shared<Primitive>("SparseSegmentSumGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSumWithNumSegments,
|
||||
std::make_shared<Primitive>("SparseSegmentSumWithNumSegments"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtN, std::make_shared<Primitive>("SparseSegmentSqrtN"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNGrad, std::make_shared<Primitive>("SparseSegmentSqrtNGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNWithNumSegments,
|
||||
std::make_shared<Primitive>("SparseSegmentSqrtNWithNumSegments"));
|
||||
|
||||
// Sparse Grad ops
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseAddGrad, std::make_shared<Primitive>(kSparseAddGrad));
|
||||
|
@ -1195,10 +1206,6 @@ GVAR_DEF(PrimitivePtr, kPrimBucketize, std::make_shared<Primitive>("Bucketize"))
|
|||
GVAR_DEF(PrimitivePtr, kPrimEinsum, std::make_shared<Primitive>("Einsum"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimEinsumGrad, std::make_shared<Primitive>("EinsumGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentMean, std::make_shared<Primitive>(kSparseSegmentMean));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtN, std::make_shared<Primitive>("SparseSegmentSqrtN"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNGrad, std::make_shared<Primitive>("SparseSegmentSqrtNGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseSegmentSqrtNWithNumSegments,
|
||||
std::make_shared<Primitive>("SparseSegmentSqrtNWithNumSegments"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTrace, std::make_shared<Primitive>("Trace"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTraceGrad, std::make_shared<Primitive>("TraceGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimTridiagonalMatMul, std::make_shared<Primitive>(kTridiagonalMatMul));
|
||||
|
|
|
@ -35,14 +35,22 @@ abstract::ShapePtr SparseSegmentSqrtNGradInferShape(const PrimitivePtr &prim,
|
|||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
auto output_dim0_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
|
||||
(void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
|
||||
prim->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
|
||||
kInputIndex1, prim->name());
|
||||
if (x_shape.size() < kInputIndex1) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor x's rank is less than 1.";
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
|
||||
<< "tensor x's rank must be greater than 1, but got [" << x_shape.size() << "].";
|
||||
}
|
||||
if (output_dim0_shape.size() != kInputIndex0) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor outputdim0 should be a scalar.";
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor output_dim0 should be a scalar, "
|
||||
<< "but got [" << output_dim0_shape.size() << "].";
|
||||
}
|
||||
if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor indices & segment_ids's ranks mismatch.";
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, "
|
||||
<< "but got indices [" << indices_shape[kInputIndex0] << "] "
|
||||
<< "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
|
||||
}
|
||||
if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
|
||||
!input_args[kInputIndex3]->BuildValue()->isa<None>()) {
|
||||
|
@ -54,7 +62,8 @@ abstract::ShapePtr SparseSegmentSqrtNGradInferShape(const PrimitivePtr &prim,
|
|||
CheckAndConvertUtils::CheckTensorIntValue("output_dim0", output_dim0_value_ptr, prim_name);
|
||||
size_t dim_zero = static_cast<size_t>(output_dim0_value_ptr_tensor[kInputIndex0]);
|
||||
if (dim_zero <= kInputIndex0) {
|
||||
MS_EXCEPTION(ValueError) << "Input output_dim0 must > 0!";
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "' , tensor output_dim0 must > 0, "
|
||||
<< "but got [" << dim_zero << "].";
|
||||
} else {
|
||||
ShapeVector y_shape = x_shape;
|
||||
y_shape[kInputIndex0] = static_cast<int64_t>(dim_zero);
|
||||
|
@ -62,7 +71,9 @@ abstract::ShapePtr SparseSegmentSqrtNGradInferShape(const PrimitivePtr &prim,
|
|||
}
|
||||
} else {
|
||||
std::vector<int64_t> output_shape = {-2};
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
std::vector<int64_t> min_shape = {1};
|
||||
std::vector<int64_t> max_shape = {1};
|
||||
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,12 +83,14 @@ TypePtr SparseSegmentSqrtNGradInferType(const PrimitivePtr &prim, const std::vec
|
|||
auto indices_type = input_args[kInputIndex1]->BuildType();
|
||||
auto segment_ids_type = input_args[kInputIndex2]->BuildType();
|
||||
auto output_dim0_type = input_args[kInputIndex3]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, {kFloat16, kFloat32, kFloat64}, prim->name());
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("indices", indices_type);
|
||||
(void)types.emplace("segment_ids", segment_ids_type);
|
||||
(void)types.emplace("output_dim0", output_dim0_type);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, {kInt32}, prim->name());
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return input_args[kInputIndex0]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/grad/sparse_segment_sum_grad.h"
|
||||
#include "abstract/dshape.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
// Infers the output shape of SparseSegmentSumGrad.
//
// Inputs (by index): 0 grad, 1 indices, 2 segment_ids, 3 output_dim0.
// Checks performed:
//   - indices and segment_ids are rank 1 and have the same first dimension;
//   - grad has rank >= 1 (the check is `size < 1`, whatever the message text says);
//   - output_dim0 is a scalar.
// When the value of output_dim0 is known, the result is grad's shape with dimension 0 replaced by
// that value, which must be strictly positive. Otherwise a {-2} shape with min/max {1} is returned.
// Raises ValueError (via MS_EXCEPTION) when any check fails.
abstract::ShapePtr SparseSegmentSumGradInferShape(const PrimitivePtr &prim,
                                                  const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto segment_ids_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto output_dim0_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
  (void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
                                           prim->name());
  (void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
                                           kInputIndex1, prim->name());
  if (grad_shape.size() < kInputIndex1) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
                             << "tensor grad's rank must be greater than 1, but got [" << grad_shape.size() << "].";
  }
  if (output_dim0_shape.size() != kInputIndex0) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', tensor output_dim0 should be a scalar, "
                             << "but got [" << output_dim0_shape.size() << "].";
  }
  // Both tensors are rank 1 here, so this compares their element counts.
  if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, "
                             << "but got indices [" << indices_shape[kInputIndex0] << "] "
                             << "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
  }
  if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
      !input_args[kInputIndex3]->BuildValue()->isa<None>()) {
    auto output_dim0_value = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(output_dim0_value);
    auto output_dim0_value_ptr = output_dim0_value->BuildValue();
    MS_EXCEPTION_IF_NULL(output_dim0_value_ptr);
    auto output_dim0_value_ptr_tensor =
      CheckAndConvertUtils::CheckTensorIntValue("output_dim0", output_dim0_value_ptr, prim_name);
    // Keep the value signed: stored in a size_t, a negative output_dim0 would wrap to a huge positive
    // number, slip past the "> 0" check and come back as a negative dimension in the output shape.
    const int64_t dim_zero = output_dim0_value_ptr_tensor[kInputIndex0];
    if (dim_zero <= 0) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name << "' , tensor output_dim0 must > 0, "
                               << "but got [" << dim_zero << "].";
    }
    ShapeVector y_shape = grad_shape;
    y_shape[kInputIndex0] = dim_zero;
    return std::make_shared<abstract::Shape>(y_shape);
  } else {
    // output_dim0 is not known yet: report the {-2} placeholder shape.
    std::vector<int64_t> output_shape = {-2};
    std::vector<int64_t> min_shape = {1};
    std::vector<int64_t> max_shape = {1};
    return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
  }
}
|
||||
|
||||
// Validates the input dtypes of SparseSegmentSumGrad and returns the type built from input 0.
//   - grad must be a float16/float32/float64 tensor;
//   - indices, segment_ids and output_dim0 must share one dtype, int32 or int64.
TypePtr SparseSegmentSumGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  // Build every input type up front, in input order, before any validation runs.
  auto type_of_grad = input_args[kInputIndex0]->BuildType();
  auto type_of_indices = input_args[kInputIndex1]->BuildType();
  auto type_of_segment_ids = input_args[kInputIndex2]->BuildType();
  auto type_of_output_dim0 = input_args[kInputIndex3]->BuildType();

  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
  const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
  CheckAndConvertUtils::CheckTensorTypeValid("grad", type_of_grad, valid_types, prim->name());

  const std::map<std::string, TypePtr> integer_inputs = {{"indices", type_of_indices},
                                                         {"segment_ids", type_of_segment_ids},
                                                         {"output_dim0", type_of_output_dim0}};
  CheckAndConvertUtils::CheckTensorTypeSame(integer_inputs, common_valid_types, prim->name());
  return input_args[kInputIndex0]->BuildType();
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SparseSegmentSumGrad, BaseOperator);
|
||||
AbstractBasePtr SparseSegmentSumGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = kInputIndex4;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
auto types = SparseSegmentSumGradInferType(prim, input_args);
|
||||
auto shapes = SparseSegmentSumGradInferShape(prim, input_args);
|
||||
return abstract::MakeAbstract(shapes, types);
|
||||
}
|
||||
// Input index 3 is output_dim0, whose value SparseSegmentSumGradInferShape reads to build the output
// shape. NOTE(review): presumably this marks that input as needing its value on the host - confirm
// against the REGISTER_HOST_DEPENDS definition.
REGISTER_HOST_DEPENDS(kNameSparseSegmentSumGrad, {3});
|
||||
// Registers SparseSegmentSumGradInfer as the infer implementation for kPrimSparseSegmentSumGrad.
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSumGrad, prim::kPrimSparseSegmentSumGrad, SparseSegmentSumGradInfer, nullptr,
|
||||
true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_GRAD_H_
|
||||
#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_GRAD_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSparseSegmentSumGrad = "SparseSegmentSumGrad";
|
||||
class MIND_API SparseSegmentSumGrad : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(SparseSegmentSumGrad);
|
||||
SparseSegmentSumGrad() : BaseOperator(kNameSparseSegmentSumGrad) {
|
||||
InitIOName({"grad", "indices", "segment_ids", "output_dim0"}, {"output"});
|
||||
}
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr SparseSegmentSumGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimSparseSegmentSumGradPtr = std::shared_ptr<SparseSegmentSumGrad>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_GRAD_H_
|
|
@ -37,22 +37,22 @@ abstract::ShapePtr SparseSegmentSqrtNInferShape(const PrimitivePtr &prim,
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
constexpr size_t kRankOne = 1;
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto segment_ids_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
if (indices_shape.size() != kRankOne) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of indices should be 1.";
|
||||
}
|
||||
if (segment_ids_shape.size() != kRankOne) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of segment_ids should be 1.";
|
||||
}
|
||||
if (x_shape.size() < kRankOne) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', x's rank is less than 1.";
|
||||
(void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
|
||||
prim->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
|
||||
kInputIndex1, prim->name());
|
||||
if (x_shape.size() < kInputIndex1) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
|
||||
<< "x's rank must be greater than 1, but got [" << x_shape.size() << "].";
|
||||
}
|
||||
if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', ranks of indices and segment_ids mismatch.";
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, "
|
||||
<< "but got indices [" << indices_shape[kInputIndex0] << "] "
|
||||
<< "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
|
||||
}
|
||||
if (!input_args[kInputIndex2]->BuildValue()->isa<AnyValue>() &&
|
||||
!input_args[kInputIndex2]->BuildValue()->isa<None>()) {
|
||||
|
@ -60,17 +60,20 @@ abstract::ShapePtr SparseSegmentSqrtNInferShape(const PrimitivePtr &prim,
|
|||
MS_EXCEPTION_IF_NULL(segment_ids_value_ptr);
|
||||
auto segment_ids_value_ptr_tensor =
|
||||
CheckAndConvertUtils::CheckTensorIntValue("segment_ids", segment_ids_value_ptr, prim->name());
|
||||
size_t dim_zero = static_cast<size_t>(segment_ids_value_ptr_tensor.back()) + kRankOne;
|
||||
if (dim_zero < kRankOne) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must >= 0!";
|
||||
size_t dim_zero = segment_ids_value_ptr_tensor.back() + kInputIndex1;
|
||||
if (dim_zero < kInputIndex1) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must be greater or equal to 0, "
|
||||
<< "but got [" << dim_zero << "].";
|
||||
} else {
|
||||
ShapeVector y_shape = x_shape;
|
||||
y_shape[kInputIndex0] = static_cast<int64_t>(dim_zero);
|
||||
y_shape[kInputIndex0] = dim_zero;
|
||||
return std::make_shared<abstract::Shape>(y_shape);
|
||||
}
|
||||
} else {
|
||||
std::vector<int64_t> output_shape = {-2};
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
std::vector<int64_t> min_shape = {1};
|
||||
std::vector<int64_t> max_shape = {1};
|
||||
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -81,9 +84,11 @@ TypePtr SparseSegmentSqrtNInferType(const PrimitivePtr &prim, const std::vector<
|
|||
auto segment_ids_type = input_args[kInputIndex2]->BuildType();
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", indices_type, common_valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("segment_ids", segment_ids_type, common_valid_types, prim->name());
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("indices", indices_type);
|
||||
(void)types.emplace("segment_ids", segment_ids_type);
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return input_args[kInputIndex0]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -93,7 +98,7 @@ AbstractBasePtr SparseSegmentSqrtNInfer(const abstract::AnalysisEnginePtr &, con
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = static_cast<int64_t>(kInputIndex3);
|
||||
const int64_t input_num = kInputIndex3;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
auto types = SparseSegmentSqrtNInferType(prim, input_args);
|
||||
auto shapes = SparseSegmentSqrtNInferShape(prim, input_args);
|
||||
|
|
|
@ -36,53 +36,57 @@ abstract::ShapePtr SparseSegmentSqrtNWithNumSegmentsInferShape(const PrimitivePt
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
constexpr size_t kRankOne = 1;
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto segment_ids_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
auto num_segments_shape =
|
||||
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
|
||||
if (indices_shape.size() != kRankOne) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of indices should be 1.";
|
||||
}
|
||||
if (segment_ids_shape.size() != kRankOne) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of segment_ids should be 1.";
|
||||
}
|
||||
if (x_shape.size() < kRankOne) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of x cannot be less than 1.";
|
||||
(void)CheckAndConvertUtils::CheckInteger("indices_shape", indices_shape.size(), kEqual, kInputIndex1, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", segment_ids_shape.size(), kEqual, kInputIndex1,
|
||||
prim->name());
|
||||
if (x_shape.size() < kInputIndex1) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
|
||||
<< "x's rank must be greater than 1, but got [" << x_shape.size() << "].";
|
||||
}
|
||||
if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", rank of indices and segment_ids mismatch.";
|
||||
}
|
||||
if (num_segments_shape.size() > kRankOne) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments should be at most 1-D.";
|
||||
}
|
||||
if (num_segments_shape.size() == kRankOne) {
|
||||
if (num_segments_shape[kInputIndex0] != kRankOne) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the num element of num_segments should be 1.";
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the rank of indices and segment_ids must be the same, "
|
||||
<< "but got indices [" << indices_shape[kInputIndex0] << "] "
|
||||
<< "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
|
||||
}
|
||||
if (num_segments_shape.size() > kInputIndex1) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments should be at most 1-D, but got ["
|
||||
<< num_segments_shape.size() << "].";
|
||||
}
|
||||
if (!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
|
||||
!input_args[kInputIndex3]->BuildValue()->isa<None>()) {
|
||||
if (num_segments_shape.size() == kInputIndex1) {
|
||||
if (num_segments_shape[kInputIndex0] != kInputIndex1) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the num element of num_segments should be 1, but got ["
|
||||
<< num_segments_shape[kInputIndex0] << "].";
|
||||
}
|
||||
}
|
||||
auto num_segments_value = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(num_segments_value);
|
||||
auto num_segments_value_ptr = num_segments_value->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
|
||||
auto num_segments_value_ptr_tensor =
|
||||
CheckAndConvertUtils::CheckTensorIntValue("num_segments", num_segments_value_ptr, prim->name());
|
||||
size_t dim_zero = static_cast<size_t>(num_segments_value_ptr_tensor.back());
|
||||
if (dim_zero < kRankOne) {
|
||||
size_t dim_zero = num_segments_value_ptr_tensor.back();
|
||||
if (dim_zero < kInputIndex1) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name
|
||||
<< ", num_segments must be bigger than the largest id of segment_ids.";
|
||||
<< ", num_segments must bigger than the last number of segment_ids, "
|
||||
<< "but got " << dim_zero << ".";
|
||||
} else {
|
||||
ShapeVector y_shape = x_shape;
|
||||
y_shape[kInputIndex0] = static_cast<int64_t>(dim_zero);
|
||||
y_shape[kInputIndex0] = dim_zero;
|
||||
return std::make_shared<abstract::Shape>(y_shape);
|
||||
}
|
||||
} else {
|
||||
std::vector<int64_t> output_shape = {-2};
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
std::vector<int64_t> min_shape = {1};
|
||||
std::vector<int64_t> max_shape = {1};
|
||||
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -94,12 +98,13 @@ TypePtr SparseSegmentSqrtNWithNumSegmentsInferType(const PrimitivePtr &prim,
|
|||
auto segment_ids_type = input_args[kInputIndex2]->BuildType();
|
||||
auto num_segments_type = input_args[kInputIndex3]->BuildType();
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||
const std::set<TypePtr> common_valid_types = {kInt32, kInt64};
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
|
||||
std::map<std::string, TypePtr> types;
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim->name());
|
||||
(void)types.emplace("indices", indices_type);
|
||||
(void)types.emplace("segment_ids", segment_ids_type);
|
||||
(void)types.emplace("num_segments", num_segments_type);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, {kInt32, kInt64}, prim->name());
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return input_args[kInputIndex0]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
@ -109,7 +114,7 @@ AbstractBasePtr SparseSegmentSqrtNWithNumSegmentsInfer(const abstract::AnalysisE
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t input_num = static_cast<size_t>(kInputIndex4);
|
||||
const int64_t input_num = kInputIndex4;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
auto types = SparseSegmentSqrtNWithNumSegmentsInferType(prim, input_args);
|
||||
auto shapes = SparseSegmentSqrtNWithNumSegmentsInferShape(prim, input_args);
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/sparse_segment_sum.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
// Infers the output shape of SparseSegmentSum.
//
// input_args holds, in order: x, indices, segment_ids.
// - If the value of segment_ids is available at inference time, the result is x's shape with
//   dimension 0 replaced by (last element of segment_ids + 1).
// - Otherwise the shape {-2} is returned together with min/max shapes {1}.
//
// Raises ValueError when x has rank 0, when indices and segment_ids differ in length, when
// segment_ids holds no element, or when its last element is negative. The 1-D checks on indices
// and segment_ids are delegated to CheckAndConvertUtils::CheckInteger.
abstract::ShapePtr SparseSegmentSumInferShape(const PrimitivePtr &prim,
                                              const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto segment_ids_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  (void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
                                           prim_name);
  (void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
                                           kInputIndex1, prim_name);
  if (x_shape.size() < kInputIndex1) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
                             << "x's rank must be greater than or equal to 1, but got [" << x_shape.size() << "].";
  }
  // Both tensors are 1-D at this point, so element 0 of each shape is its length.
  if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the size of indices and segment_ids must be the same, "
                             << "but got indices [" << indices_shape[kInputIndex0] << "] "
                             << "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
  }

  // Null-check the value before querying its kind.
  auto segment_ids_value_ptr = input_args[kInputIndex2]->BuildValue();
  MS_EXCEPTION_IF_NULL(segment_ids_value_ptr);
  if (segment_ids_value_ptr->isa<AnyValue>() || segment_ids_value_ptr->isa<None>()) {
    // The value of segment_ids is not known yet: report the placeholder shape.
    std::vector<int64_t> output_shape = {-2};
    std::vector<int64_t> min_shape = {1};
    std::vector<int64_t> max_shape = {1};
    return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
  }

  auto segment_ids_value = CheckAndConvertUtils::CheckTensorIntValue("segment_ids", segment_ids_value_ptr, prim_name);
  if (segment_ids_value.empty()) {
    // back() on an empty container is undefined behaviour.
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must not be empty.";
  }
  // Keep the id signed: mixing it with the unsigned kInputIndex1 would wrap negative ids around
  // and let values below -1 slip past the range check.
  const int64_t last_segment_id = static_cast<int64_t>(segment_ids_value.back());
  if (last_segment_id < 0) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', segment_ids must be greater or equal to 0, "
                             << "but got [" << last_segment_id << "].";
  }
  ShapeVector y_shape = x_shape;
  y_shape[kInputIndex0] = last_segment_id + 1;
  return std::make_shared<abstract::Shape>(y_shape);
}
|
||||
|
||||
// Infers the output type of SparseSegmentSum.
//
// Validates the type of x against the supported element types, validates that indices and
// segment_ids share one type out of {int32, int64}, and returns the type built from x.
TypePtr SparseSegmentSumInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  const auto op_name = prim->name();

  auto x_type = input_args[kInputIndex0]->BuildType();
  const std::map<std::string, TypePtr> id_types = {
    {"indices", input_args[kInputIndex1]->BuildType()},
    {"segment_ids", input_args[kInputIndex2]->BuildType()},
  };

  const std::set<TypePtr> x_valid_types = {kInt8,   kInt16,   kInt32,   kInt64,  kUInt8,
                                           kUInt16, kFloat16, kFloat32, kFloat64};
  const std::set<TypePtr> id_valid_types = {kInt32, kInt64};

  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, x_valid_types, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeSame(id_types, id_valid_types, op_name);
  return input_args[kInputIndex0]->BuildType();
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SparseSegmentSum, BaseOperator);
|
||||
// Entry point for SparseSegmentSum inference: checks that exactly three inputs were supplied,
// then combines the inferred type and the inferred shape into one abstract value.
AbstractBasePtr SparseSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
                                      const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  const int64_t expected_input_count = kInputIndex3;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, expected_input_count, prim->name());
  // Type inference runs first, as in the original ordering.
  const auto inferred_type = SparseSegmentSumInferType(prim, input_args);
  const auto inferred_shape = SparseSegmentSumInferShape(prim, input_args);
  return abstract::MakeAbstract(inferred_shape, inferred_type);
}
|
||||
// Input index 2 is segment_ids, whose value SparseSegmentSumInferShape reads to compute output dimension 0.
// NOTE(review): presumably this registration makes that value available on the host — confirm against the macro.
REGISTER_HOST_DEPENDS(kNameSparseSegmentSum, {2});
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSum, prim::kPrimSparseSegmentSum, SparseSegmentSumInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_
|
||||
#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSparseSegmentSum = "SparseSegmentSum";
|
||||
/// \brief Computes the sum along sparse segments of a tensor.
|
||||
/// Refer to Python API @ref mindspore.ops.SparseSegmentSum for more details.
|
||||
class MIND_API SparseSegmentSum : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(SparseSegmentSum);
|
||||
/// \brief Constructor.
|
||||
SparseSegmentSum() : BaseOperator(kNameSparseSegmentSum) { InitIOName({"x", "indices", "segment_ids"}, {"y"}); }
|
||||
};
|
||||
|
||||
AbstractBasePtr SparseSegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_H_
|
|
@ -0,0 +1,120 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "ops/sparse_segment_sum_with_num_segments.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
// Infers the output shape of SparseSegmentSumWithNumSegments.
//
// input_args holds, in order: x, indices, segment_ids, num_segments.
// - If the value of num_segments is available at inference time, the result is x's shape with
//   dimension 0 replaced by that value.
// - Otherwise the shape {-2} is returned together with min/max shapes {1}.
//
// Raises ValueError when x has rank 0, when indices and segment_ids differ in length, when
// num_segments has rank above 1 or (being 1-D) does not hold exactly one element, or when its
// value is smaller than 1. The 1-D checks on indices and segment_ids are delegated to
// CheckAndConvertUtils::CheckInteger.
abstract::ShapePtr SparseSegmentSumWithNumSegmentsInferShape(const PrimitivePtr &prim,
                                                             const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto segment_ids_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto num_segments_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
  (void)CheckAndConvertUtils::CheckInteger("indices_shape", SizeToLong(indices_shape.size()), kEqual, kInputIndex1,
                                           prim_name);
  (void)CheckAndConvertUtils::CheckInteger("segment_ids_shape", SizeToLong(segment_ids_shape.size()), kEqual,
                                           kInputIndex1, prim_name);
  if (x_shape.size() < kInputIndex1) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', "
                             << "x's rank must be greater than or equal to 1, but got [" << x_shape.size() << "].";
  }
  // Both tensors are 1-D at this point, so element 0 of each shape is its length.
  if (indices_shape[kInputIndex0] != segment_ids_shape[kInputIndex0]) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the size of indices and segment_ids must be the same, "
                             << "but got indices [" << indices_shape[kInputIndex0] << "] "
                             << "and segment_ids [" << segment_ids_shape[kInputIndex0] << "].";
  }
  if (num_segments_shape.size() > kInputIndex1) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments should be at most 1-D, but got ["
                             << num_segments_shape.size() << "].";
  }

  // Null-check the value before querying its kind.
  auto num_segments_value_ptr = input_args[kInputIndex3]->BuildValue();
  MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
  if (num_segments_value_ptr->isa<AnyValue>() || num_segments_value_ptr->isa<None>()) {
    // The value of num_segments is not known yet: report the placeholder shape.
    std::vector<int64_t> output_shape = {-2};
    std::vector<int64_t> min_shape = {1};
    std::vector<int64_t> max_shape = {1};
    return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
  }

  if (num_segments_shape.size() == kInputIndex1 &&
      num_segments_shape[kInputIndex0] != static_cast<int64_t>(kInputIndex1)) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", the num element of num_segments should be 1, but got ["
                             << num_segments_shape[kInputIndex0] << "].";
  }
  // num_segments has to be a tensor abstract, as in the original implementation.
  auto num_segments_tensor = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
  MS_EXCEPTION_IF_NULL(num_segments_tensor);
  auto num_segments_value =
    CheckAndConvertUtils::CheckTensorIntValue("num_segments", num_segments_value_ptr, prim_name);
  if (num_segments_value.empty()) {
    // back() on an empty container is undefined behaviour.
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments must not be empty.";
  }
  // Keep the value signed: storing it in size_t would wrap negative inputs around to huge
  // positive numbers and bypass the lower-bound check.
  const int64_t num_segments = static_cast<int64_t>(num_segments_value.back());
  if (num_segments < 1) {
    MS_EXCEPTION(ValueError) << "For " << prim_name << ", num_segments must be greater than or equal to 1, "
                             << "but got " << num_segments << ".";
  }
  ShapeVector y_shape = x_shape;
  y_shape[kInputIndex0] = num_segments;
  return std::make_shared<abstract::Shape>(y_shape);
}
|
||||
|
||||
// Infers the output type of SparseSegmentSumWithNumSegments.
//
// Validates the type of x against the supported element types, validates that indices,
// segment_ids and num_segments share one type out of {int32, int64}, and returns the type built
// from x.
TypePtr SparseSegmentSumWithNumSegmentsInferType(const PrimitivePtr &prim,
                                                 const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  const auto op_name = prim->name();

  auto x_type = input_args[kInputIndex0]->BuildType();
  const std::map<std::string, TypePtr> id_types = {
    {"indices", input_args[kInputIndex1]->BuildType()},
    {"segment_ids", input_args[kInputIndex2]->BuildType()},
    {"num_segments", input_args[kInputIndex3]->BuildType()},
  };

  const std::set<TypePtr> x_valid_types = {kInt8,   kInt16,   kInt32,   kInt64,  kUInt8,
                                           kUInt16, kFloat16, kFloat32, kFloat64};
  const std::set<TypePtr> id_valid_types = {kInt32, kInt64};

  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, x_valid_types, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeSame(id_types, id_valid_types, op_name);
  return input_args[kInputIndex0]->BuildType();
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(SparseSegmentSumWithNumSegments, BaseOperator);
|
||||
// Entry point for SparseSegmentSumWithNumSegments inference: checks that exactly four inputs
// were supplied, then combines the inferred type and the inferred shape into one abstract value.
AbstractBasePtr SparseSegmentSumWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &prim,
                                                     const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  constexpr size_t kExpectedInputCount = kInputIndex4;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kExpectedInputCount, prim->name());
  // Type inference runs first, as in the original ordering.
  const auto inferred_type = SparseSegmentSumWithNumSegmentsInferType(prim, input_args);
  const auto inferred_shape = SparseSegmentSumWithNumSegmentsInferShape(prim, input_args);
  return abstract::MakeAbstract(inferred_shape, inferred_type);
}
|
||||
// Input index 3 is num_segments, whose value SparseSegmentSumWithNumSegmentsInferShape reads as output dimension 0.
// NOTE(review): presumably this registration makes that value available on the host — confirm against the macro.
REGISTER_HOST_DEPENDS(kNameSparseSegmentSumWithNumSegments, {3});
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(SparseSegmentSumWithNumSegments, prim::kPrimSparseSegmentSumWithNumSegments,
|
||||
SparseSegmentSumWithNumSegmentsInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_
|
||||
#define MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSparseSegmentSumWithNumSegments = "SparseSegmentSumWithNumSegments";
|
||||
/// \brief Computes the sum along sparse segments of a tensor, but it is allowed to miss id in segment_ids.
|
||||
/// Refer to Python API @ref mindspore.ops.SparseSegmentSumWithNumSegments for more details.
|
||||
class MIND_API SparseSegmentSumWithNumSegments : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(SparseSegmentSumWithNumSegments);
|
||||
/// \brief Constructor.
|
||||
SparseSegmentSumWithNumSegments() : BaseOperator(kNameSparseSegmentSumWithNumSegments) {
|
||||
InitIOName({"x", "indices", "segment_ids", "num_segments"}, {"y"});
|
||||
}
|
||||
};
|
||||
AbstractBasePtr SparseSegmentSumWithNumSegmentsInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimSparseSegmentSumWithNumSegmentsPtr = std::shared_ptr<SparseSegmentSumWithNumSegments>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_SPARSE_SEGMENT_SUM_WITH_NUM_SEGMENTS_H_
|
|
@ -20,6 +20,8 @@ from mindspore.ops.operations.sparse_ops import SparseTensorToCSRSparseMatrix
|
|||
from mindspore.ops.operations.sparse_ops import SparseToDenseV2
|
||||
from mindspore.ops.operations.sparse_ops import SparseSoftmax
|
||||
from mindspore.ops.operations.sparse_ops import SparseDenseCwiseAdd
|
||||
from mindspore.ops.operations.sparse_ops import SparseSegmentSum
|
||||
from mindspore.ops.operations.sparse_ops import SparseSegmentSumWithNumSegments
|
||||
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
|
||||
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtNWithNumSegments
|
||||
from mindspore.ops.operations.sparse_ops import SparseSegmentMeanWithNumSegments
|
||||
|
@ -31,6 +33,7 @@ from .. import operations as P
|
|||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations import _grad_ops as G
|
||||
from .._grad.grad_base import bprop_getters
|
||||
from .._utils.utils import is_shape_unknown
|
||||
|
||||
# Unused parameters are placeholders.
|
||||
|
||||
|
@ -103,13 +106,18 @@ def get_bprop_sparse_segment_sqrt_n(self):
|
|||
"""Grad definition for `SparseSegmentSqrtN` operation."""
|
||||
input_grad = G.SparseSegmentSqrtNGrad()
|
||||
shape = P.Shape()
|
||||
dyn_shape_op = P.TensorShape()
|
||||
|
||||
def bprop(x, indices, segment_ids, out, dout):
|
||||
output_dim0 = F.scalar_to_tensor(shape(x)[0], mstype.int32)
|
||||
shape_x = shape(x)
|
||||
if is_shape_unknown(shape_x):
|
||||
shape_x = dyn_shape_op(x)
|
||||
output_dim0 = P.Cast()(shape_x[0], mstype.int32)
|
||||
indices = F.cast(indices, mstype.int32)
|
||||
segment_ids = F.cast(segment_ids, mstype.int32)
|
||||
dx = input_grad(dout, indices, segment_ids, output_dim0)
|
||||
return dx, zeros_like(indices), zeros_like(segment_ids)
|
||||
all_d = (dx, zeros_like(indices), zeros_like(segment_ids))
|
||||
return all_d
|
||||
|
||||
return bprop
|
||||
|
||||
|
@ -119,9 +127,55 @@ def get_bprop_sparse_segment_sqrt_n_with_num_segments(self):
|
|||
"""Grad definition for `SparseSegmentSqrtNWithNumSegments` operation."""
|
||||
input_grad = G.SparseSegmentSqrtNGrad()
|
||||
shape = P.Shape()
|
||||
dyn_shape_op = P.TensorShape()
|
||||
|
||||
def bprop(x, indices, segment_ids, num_segments, out, dout):
|
||||
output_dim0 = F.scalar_to_tensor(shape(x)[0], mstype.int32)
|
||||
shape_x = shape(x)
|
||||
if is_shape_unknown(shape_x):
|
||||
shape_x = dyn_shape_op(x)
|
||||
output_dim0 = P.Cast()(shape_x[0], mstype.int32)
|
||||
indices = F.cast(indices, mstype.int32)
|
||||
segment_ids = F.cast(segment_ids, mstype.int32)
|
||||
dx = input_grad(dout, indices, segment_ids, output_dim0)
|
||||
all_d = (dx, zeros_like(indices), zeros_like(segment_ids), zeros_like(num_segments))
|
||||
return all_d
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(SparseSegmentSum)
def get_bprop_sparse_segment_sum(self):
    """Grad definition for `SparseSegmentSum` operation."""
    grad_op = G.SparseSegmentSumGrad()
    static_shape_op = P.Shape()
    dynamic_shape_op = P.TensorShape()

    def bprop(x, indices, segment_ids, out, dout):
        # Fall back to the TensorShape operator when the static shape is reported as unknown.
        x_shape = static_shape_op(x)
        if is_shape_unknown(x_shape):
            x_shape = dynamic_shape_op(x)
        # Dimension 0 of `x` becomes the `output_dim0` input of the grad operator.
        first_dim = P.Cast()(x_shape[0], mstype.int32)
        indices = F.cast(indices, mstype.int32)
        segment_ids = F.cast(segment_ids, mstype.int32)
        dx = grad_op(dout, indices, segment_ids, first_dim)
        # The two index inputs receive zero gradients.
        return dx, zeros_like(indices), zeros_like(segment_ids)

    return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(SparseSegmentSumWithNumSegments)
|
||||
def get_bprop_sparse_segment_sum_with_num_segments(self):
|
||||
"""Grad definition for `SparseSegmentSumWithNumSegments` operation."""
|
||||
input_grad = G.SparseSegmentSumGrad()
|
||||
shape = P.Shape()
|
||||
dyn_shape_op = P.TensorShape()
|
||||
|
||||
def bprop(x, indices, segment_ids, num_segments, out, dout):
|
||||
shape_x = shape(x)
|
||||
if is_shape_unknown(shape_x):
|
||||
shape_x = dyn_shape_op(x)
|
||||
output_dim0 = P.Cast()(shape_x[0], mstype.int32)
|
||||
indices = F.cast(indices, mstype.int32)
|
||||
segment_ids = F.cast(segment_ids, mstype.int32)
|
||||
dx = input_grad(dout, indices, segment_ids, output_dim0)
|
||||
|
|
|
@ -2043,6 +2043,7 @@ class UpsampleNearest3DGrad(Primitive):
|
|||
The scale array along each dimension, contain 3 elements: scale_depth, scale_height, scale_width.
|
||||
The number of elements of 'scales' should be the same as the rank of input 'grad_output'.
|
||||
One of 'scales' and 'output_size' MUST be specified and it is an error if both are specified.
|
||||
|
||||
Inputs:
|
||||
- **grad_output** (Tensor) - Tensor of shape [N, C, D, H, W], Must be one of the following types:
|
||||
float16, float32, float64.
|
||||
|
@ -3499,6 +3500,50 @@ class MedianGrad(Primitive):
|
|||
self.init_prim_io_names(inputs=['y_grad', 'x', 'y', 'indices'], outputs=['x_grad'])
|
||||
|
||||
|
||||
class SparseSegmentSumGrad(Primitive):
    """
    Computes gradients for SparseSegmentSum operation.

    Inputs:
        - **grad** (Tensor) - A tensor.
        - **indices** (Tensor) - Indices is a 1-D tensor. Must be one of the following types: int32, int64.
          Has same rank as segment_ids. The shape should be :math:`(N,)`.
        - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor. Must be one of the following types: int32, int64.
          Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
        - **output_dim0** (Tensor) - Output_dim0 is a 0-D tensor. Dimension 0 of `x` passed to SparseSegmentSum op.

    Outputs:
        A Tensor. Has the same type as `grad` .
        Has same shape as `grad`, except for dimension 0 which is the value of `output_dim0`.

    Raises:
        TypeError: If `grad` or `indices` or `segment_ids` or `output_dim0` is not a tensor.
        TypeError: If the dtype of `grad` is not any of the following data types: {float16, float32, float64}.
        TypeError: If the dtype of `indices` and `segment_ids` and `output_dim0` is not int32 or int64.
        ValueError: If dimension size of `grad` less than 1.
        ValueError: If rank of `indices` or `segment_ids` is not 1.
        ValueError: If dimension size of `output_dim0` is not 0.
        ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
        ValueError: If `segment_ids` is not sorted.
        ValueError: If the last number of `segment_ids` is out of range of grad's first shape.
        ValueError: If `indices` is bigger than or equal to `output_dim0`.

    Supported Platforms:
        ``GPU``
    """
    # `grad` has its own dtype group (T1); the three index-like inputs share dtype group T.
    __mindspore_signature__ = (
        sig.make_sig('grad', dtype=sig.sig_dtype.T1),
        sig.make_sig('indices', dtype=sig.sig_dtype.T),
        sig.make_sig('segment_ids', dtype=sig.sig_dtype.T),
        sig.make_sig('output_dim0', dtype=sig.sig_dtype.T),
    )

    @prim_attr_register
    def __init__(self):
        """Initialize SparseSegmentSumGrad"""
        input_names = ['grad', 'indices', 'segment_ids', 'output_dim0']
        output_names = ['y']
        self.init_prim_io_names(inputs=input_names, outputs=output_names)
|
||||
|
||||
|
||||
class SparseSegmentSqrtNGrad(Primitive):
|
||||
"""
|
||||
Computes gradients for SparseSegmentSqrtNGrad operation.
|
||||
|
@ -3526,12 +3571,12 @@ class SparseSegmentSqrtNGrad(Primitive):
|
|||
ValueError: If rank of `indices` or `segment_ids` is not 1.
|
||||
ValueError: If dimension size of `output_dim0` is not 0.
|
||||
ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
|
||||
ValueError: If indices in `segment_ids` are not contiguous or do not start from 0.
|
||||
ValueError: If `segment_ids` is not sorted.
|
||||
ValueError: If `indices` is out of range of `output_dim0`.
|
||||
ValueError: If the last number of `segment_ids` is out of range of x's first shape.
|
||||
ValueError: If `indices` is bigger than or equal to `output_dim0`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
|
@ -18,9 +18,10 @@
|
|||
|
||||
"""Operators for sparse operators."""
|
||||
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register, Primitive
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...common import dtype as mstype
|
||||
from .. import signature as sig
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
|
||||
|
||||
|
||||
class SparseDenseCwiseAdd(Primitive):
|
||||
|
@ -1122,6 +1123,115 @@ class SparseConcat(Primitive):
|
|||
validator.check_value_type("concat_dim", concat_dim, [int], self.name)
|
||||
|
||||
|
||||
class SparseSegmentSum(Primitive):
    """
    Computes the sum along sparse segments of a tensor.

    Inputs:
        - **x** (Tensor) - A tensor.
        - **indices** (Tensor) - Indices is a 1-D tensor. Must be one of the following types: int32, int64.
          Has same rank as segment_ids. The shape should be :math:`(N,)`.
        - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor. Must be one of the following types: int32, int64.
          Values should be sorted and can be repeated. The shape should be :math:`(N,)`.

    Outputs:
        A Tensor. Has the same type as `x` .
        Has same shape as `x`, except for dimension 0 which is the number of segments.

    Raises:
        TypeError: If `x` or `indices` or `segment_ids` is not a tensor.
        TypeError: If the dtype of `indices` and `segment_ids` is not int32 or int64.
        ValueError: If dimension size of `x` less than 1.
        ValueError: If any of `indices` and `segment_ids` is not a 1-D tensor.
        ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
        ValueError: If indices in `segment_ids` are not contiguous or do not start from 0.
        ValueError: If `segment_ids` is not sorted.
        ValueError: If `indices` is out of range of x's first shape.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor([[0, 1, 2], [1, 2, 3], [3, 6, 7]], dtype=ms.float32)
        >>> indices = Tensor([0, 1, 2], dtype=ms.int32)
        >>> segment_ids = Tensor([0, 1, 1], dtype=ms.int32)
        >>> sparse_segment_sum = ops.SparseSegmentSum()
        >>> out = sparse_segment_sum(x, indices, segment_ids)
        >>> print(out)
        [[ 0.  1.  2.]
         [ 4.  8. 10.]]
    """
    # `x` has its own dtype group (T1); `indices` and `segment_ids` share dtype group T.
    __mindspore_signature__ = (
        sig.make_sig('x', dtype=sig.sig_dtype.T1),
        sig.make_sig('indices', dtype=sig.sig_dtype.T),
        sig.make_sig('segment_ids', dtype=sig.sig_dtype.T),
    )

    @prim_attr_register
    def __init__(self):
        """Initialize SparseSegmentSum"""
        input_names = ['x', 'indices', 'segment_ids']
        output_names = ['y']
        self.init_prim_io_names(inputs=input_names, outputs=output_names)
        self.add_prim_attr("cust_aicpu", self.name)
|
||||
|
||||
|
||||
class SparseSegmentSumWithNumSegments(Primitive):
    """
    Computes the sum along sparse segments of a tensor, but it is allowed to miss id in segment_ids.

    Inputs:
        - **x** (Tensor) - A Tensor.
        - **indices** (Tensor) - Indices is a 1-D tensor. Must be one of the following types: int32, int64.
          Has same rank as segment_ids. The shape should be :math:`(N,)`.
        - **segment_ids** (Tensor) - Segment_ids is a 1-D tensor. Must be one of the following types: int32, int64.
          Values should be sorted and can be repeated. The shape should be :math:`(N,)`.
        - **num_segments** (Tensor) - The number of output segments, i.e. the size of dimension 0 of the output.
          Must be bigger than the last number of `segment_ids`. Rows whose id does not appear in
          `segment_ids` are filled with zeros (see the example below).

    Outputs:
        A Tensor. Has the same type as `x` .
        Has same shape as `x`, except for dimension 0 which is the value of `num_segments`.

    Raises:
        TypeError: If `x` or `indices` or `segment_ids` or `num_segments` is not a tensor.
        TypeError: If the dtype of `indices` and `segment_ids` and `num_segments` is not int32 or int64.
        ValueError: If dimension size of `x` less than 1.
        ValueError: If any of `indices` and `segment_ids` is not a 1-D tensor.
        ValueError: If rank of `num_segments` is bigger than 1.
        ValueError: If numelements of `num_segments` is not 1.
        ValueError: If shape[0] of `indices` is not corresponding to shape[0] of `segment_ids`.
        ValueError: If `segment_ids` is not sorted.
        ValueError: If the last number of `segment_ids` is bigger than or equal to `num_segments`.
        ValueError: If `indices` is out of range of x's first shape.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor([[0, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 0]], dtype=ms.float16)
        >>> indices = Tensor([0, 2, 1], dtype=ms.int32)
        >>> segment_ids = Tensor([0, 0, 2], dtype=ms.int32)
        >>> num_segments = Tensor([4], dtype=ms.int32)
        >>> sparse_segment_sum_with_num_segments = ops.SparseSegmentSumWithNumSegments()
        >>> output = sparse_segment_sum_with_num_segments(x, indices, segment_ids, num_segments)
        >>> print(output)
        [[1. 1. 1. 0.]
         [0. 0. 0. 0.]
         [0. 1. 1. 0.]
         [0. 0. 0. 0.]]
    """
    # `x` has its own dtype group (T1); the three index-like inputs share dtype group T.
    # The last entry is spelled 'num_segments' so that it agrees with init_prim_io_names below.
    __mindspore_signature__ = (
        sig.make_sig('x', dtype=sig.sig_dtype.T1),
        sig.make_sig('indices', dtype=sig.sig_dtype.T),
        sig.make_sig('segment_ids', dtype=sig.sig_dtype.T),
        sig.make_sig('num_segments', dtype=sig.sig_dtype.T)
    )

    @prim_attr_register
    def __init__(self):
        """Initialize SparseSegmentSumWithNumSegments"""
        self.init_prim_io_names(inputs=['x', 'indices', 'segment_ids', 'num_segments'], outputs=['y'])
        self.add_prim_attr("cust_aicpu", self.name)
|
||||
|
||||
|
||||
class SparseSegmentSqrtN(Primitive):
|
||||
"""
|
||||
Computes the sum along sparse segments of a tensor divided by the sqrt of N.
|
||||
|
@ -1151,7 +1261,7 @@ class SparseSegmentSqrtN(Primitive):
|
|||
ValueError: If `indices` is out of range of x's first dimension.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]]).astype(np.float32))
|
||||
|
@ -1164,6 +1274,11 @@ class SparseSegmentSqrtN(Primitive):
|
|||
[ 5. 6. 7. 8.]
|
||||
[ 9. 10. 11. 12.]]
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
sig.make_sig('x', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('indices', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('segment_ids', dtype=sig.sig_dtype.T)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
|
@ -1208,7 +1323,7 @@ class SparseSegmentSqrtNWithNumSegments(Primitive):
|
|||
ValueError: If `indices` is out of range of x's first dimension.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor([[0, 1, 0, 0], [0, 1, 1, 0], [1, 0, 1, 0]], dtype=ms.float16)
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations._grad_ops as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
|
||||
class SparseSegmentSqrtNGradNet(nn.Cell):
    """Cell that forwards its four inputs to the SparseSegmentSqrtNGrad primitive."""

    def __init__(self):
        super().__init__()
        # Primitive instance invoked by construct().
        self.net = P.SparseSegmentSqrtNGrad()

    @ms_function
    def construct(self, grad, indices, segment_ids, output_dim0):
        result = self.net(grad, indices, segment_ids, output_dim0)
        return result
|
||||
|
||||
|
||||
def sparse_segment_sqrt_n_grad(loss):
    """Run SparseSegmentSqrtNGradNet in graph mode on GPU and check it against fixed reference values.

    Args:
        loss: tolerance handed to np.allclose as both rtol and atol.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    # Inputs: float32 gradient rows 1..16 and int32 index tensors.
    grad = Tensor(np.arange(1, 17, dtype=np.float32).reshape(4, 4))
    indices = Tensor(np.array([0, 0, 1, 2, 2, 3], dtype=np.int32))
    segment_ids = Tensor(np.array([0, 1, 2, 2, 3, 3], dtype=np.int32))
    output_dim0 = Tensor(np.array(8, dtype=np.int32))

    actual = SparseSegmentSqrtNGradNet()(grad, indices, segment_ids, output_dim0).asnumpy()

    # Only the first four of the eight output rows are non-zero.
    expected = np.zeros((8, 4), dtype=np.float32)
    expected[:4] = [[6, 8, 10, 12],
                    [6.363961, 7.071068, 7.7781744, 8.485281],
                    [15.55635, 16.970562, 18.384777, 19.798988],
                    [9.192389, 9.899495, 10.606602, 11.313708]]
    assert np.allclose(actual, expected, rtol=loss, atol=loss)
|
||||
|
||||
|
||||
def sparse_segment_sqrt_n_grad_pynative(loss):
    """Run SparseSegmentSqrtNGradNet in PyNative mode on GPU and check it against fixed reference values.

    Args:
        loss: tolerance handed to np.allclose as both rtol and atol.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

    # Inputs: float64 gradient rows 1..16 and int64 index tensors.
    grad = Tensor(np.arange(1, 17, dtype=np.float64).reshape(4, 4))
    indices = Tensor(np.array([0, 1, 1, 2, 3, 3], dtype=np.int64))
    segment_ids = Tensor(np.array([0, 0, 1, 2, 2, 3], dtype=np.int64))
    output_dim0 = Tensor(np.array(8, dtype=np.int64))

    actual = SparseSegmentSqrtNGradNet()(grad, indices, segment_ids, output_dim0).asnumpy()

    # Only the first four of the eight output rows are non-zero.
    expected = np.zeros((8, 4), dtype=np.float64)
    expected[:4] = [[0.70710678, 1.41421356, 2.12132034, 2.82842712],
                    [5.70710678, 7.41421356, 9.12132034, 10.82842712],
                    [6.36396103, 7.07106781, 7.77817459, 8.48528137],
                    [19.36396103, 21.07106781, 22.77817459, 24.48528137]]
    assert np.allclose(actual, expected, rtol=loss, atol=loss)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_grad_graph_float32_int32_int32():
    """
    Feature: SparseSegmentSqrtNGrad on GPU
    Description: graph-mode case with float32 grad and int32 indices/segment_ids/output_dim0
    Expectation: the output equals the reference values (stated by the author to match TensorFlow)
    """
    tolerance = 1.0e-4
    sparse_segment_sqrt_n_grad(loss=tolerance)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_grad_pynative_float64_int64_int64():
    """
    Feature: SparseSegmentSqrtNGrad on GPU
    Description: PyNative-mode case with float64 grad and int64 indices/segment_ids/output_dim0
    Expectation: the output equals the reference values (stated by the author to match TensorFlow)
    """
    tolerance = 1.0e-5
    sparse_segment_sqrt_n_grad_pynative(loss=tolerance)
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations.sparse_ops as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
|
||||
class SparseSegmentSqrtNNet(nn.Cell):
    """Cell that forwards its three inputs to the SparseSegmentSqrtN primitive."""

    def __init__(self):
        super().__init__()
        # Primitive instance invoked by construct().
        self.net = P.SparseSegmentSqrtN()

    @ms_function
    def construct(self, x, indices, segment_ids):
        result = self.net(x, indices, segment_ids)
        return result
|
||||
|
||||
|
||||
def sparse_segment_sqrt_n(loss):
    """Run SparseSegmentSqrtNNet in graph mode on GPU and check it against fixed reference values.

    Args:
        loss: tolerance handed to np.allclose as both rtol and atol.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    # Inputs: float32 rows 1..16 and int32 index tensors.
    x = Tensor(np.arange(1, 17, dtype=np.float32).reshape(4, 4))
    indices = Tensor(np.array([0, 0, 1, 2, 2, 3], dtype=np.int32))
    segment_ids = Tensor(np.array([0, 1, 2, 2, 3, 3], dtype=np.int32))

    actual = SparseSegmentSqrtNNet()(x, indices, segment_ids).asnumpy()

    reference_rows = [
        [1, 2, 3, 4],
        [1, 2, 3, 4],
        [9.899495, 11.313708, 12.727922, 14.142136],
        [15.556349, 16.970562, 18.384777, 19.79899],
    ]
    expected = np.array(reference_rows, dtype=np.float32)
    assert np.allclose(actual, expected, rtol=loss, atol=loss)
|
||||
|
||||
|
||||
def sparse_segment_sqrt_n_pynative(loss):
    """Run SparseSegmentSqrtNNet in PyNative mode on GPU and check it against fixed reference values.

    Args:
        loss: tolerance handed to np.allclose as both rtol and atol.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

    # Inputs: float64 rows 1..16 and int64 index tensors.
    x = Tensor(np.arange(1, 17, dtype=np.float64).reshape(4, 4))
    indices = Tensor(np.array([0, 1, 1, 2, 3, 3], dtype=np.int64))
    segment_ids = Tensor(np.array([0, 0, 1, 2, 2, 3], dtype=np.int64))

    actual = SparseSegmentSqrtNNet()(x, indices, segment_ids).asnumpy()

    reference_rows = [
        [4.24264069, 5.65685425, 7.07106781, 8.48528137],
        [5, 6, 7, 8],
        [15.55634919, 16.97056275, 18.38477631, 19.79898987],
        [13, 14, 15, 16],
    ]
    expected = np.array(reference_rows, dtype=np.float64)
    assert np.allclose(actual, expected, rtol=loss, atol=loss)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_graph_float32_int32_int32():
    """
    Feature: SparseSegmentSqrtN on GPU
    Description: graph-mode case with float32 x and int32 indices/segment_ids
    Expectation: the output equals the reference values (stated by the author to match TensorFlow)
    """
    tolerance = 1.0e-4
    sparse_segment_sqrt_n(loss=tolerance)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_pynative_float64_int64_int64():
    """
    Feature: SparseSegmentSqrtN on GPU
    Description: PyNative-mode case with float64 x and int64 indices/segment_ids
    Expectation: the output equals the reference values (stated by the author to match TensorFlow)
    """
    tolerance = 1.0e-5
    sparse_segment_sqrt_n_pynative(loss=tolerance)
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations.sparse_ops as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
|
||||
class SparseSegmentSqrtNWithNumSegmentsNet(nn.Cell):
    """Cell that forwards its four inputs to the SparseSegmentSqrtNWithNumSegments primitive."""

    def __init__(self):
        super().__init__()
        # Primitive instance invoked by construct().
        self.net = P.SparseSegmentSqrtNWithNumSegments()

    @ms_function
    def construct(self, x, indices, segment_ids, num_segments):
        result = self.net(x, indices, segment_ids, num_segments)
        return result
|
||||
|
||||
|
||||
def sparse_segment_sqrt_n_with_num_segments(loss):
    """Run SparseSegmentSqrtNWithNumSegmentsNet in graph mode on GPU and check it against fixed reference values.

    Args:
        loss: tolerance handed to np.allclose as both rtol and atol.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    # Inputs: float32 rows 1..16 and int32 index tensors.
    x = Tensor(np.arange(1, 17, dtype=np.float32).reshape(4, 4))
    indices = Tensor(np.array([0, 0, 1, 2, 2, 3], dtype=np.int32))
    segment_ids = Tensor(np.array([0, 3, 3, 5, 7, 7], dtype=np.int32))
    num_segments = Tensor(np.array(8, dtype=np.int32))

    actual = SparseSegmentSqrtNWithNumSegmentsNet()(x, indices, segment_ids, num_segments).asnumpy()

    # Rows 0, 3, 5 and 7 carry values; the remaining rows of the 8x4 reference are zero.
    expected = np.zeros((8, 4), dtype=np.float32)
    expected[[0, 3, 5, 7]] = [[1, 2, 3, 4],
                              [4.2426405, 5.656854, 7.071068, 8.485281],
                              [9, 10, 11, 12],
                              [15.556349, 16.970562, 18.384777, 19.79899]]
    assert np.allclose(actual, expected, rtol=loss, atol=loss)
|
||||
|
||||
|
||||
def sparse_segment_sqrt_n_with_num_segments_pynative(loss):
    """Run SparseSegmentSqrtNWithNumSegmentsNet in PyNative mode on GPU and check it against fixed reference values.

    Args:
        loss: tolerance handed to np.allclose as both rtol and atol.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

    # Inputs: float64 rows 1..16 and int64 index tensors.
    x = Tensor(np.arange(1, 17, dtype=np.float64).reshape(4, 4))
    indices = Tensor(np.array([0, 1, 1, 2, 3, 3], dtype=np.int64))
    segment_ids = Tensor(np.array([0, 0, 3, 5, 5, 7], dtype=np.int64))
    num_segments = Tensor(np.array(8, dtype=np.int64))

    actual = SparseSegmentSqrtNWithNumSegmentsNet()(x, indices, segment_ids, num_segments).asnumpy()

    # Rows 0, 3, 5 and 7 carry values; the remaining rows of the 8x4 reference are zero.
    expected = np.zeros((8, 4), dtype=np.float64)
    expected[[0, 3, 5, 7]] = [[4.24264069, 5.65685425, 7.07106781, 8.48528137],
                              [5, 6, 7, 8],
                              [15.55634919, 16.97056275, 18.38477631, 19.79898987],
                              [13, 14, 15, 16]]
    assert np.allclose(actual, expected, rtol=loss, atol=loss)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_with_num_segments_graph_float32_int32_int32_int32():
    """
    Feature: SparseSegmentSqrtNWithNumSegments on GPU
    Description: graph-mode case with float32 x and int32 indices/segment_ids/num_segments
    Expectation: the output equals the reference values (stated by the author to match TensorFlow)
    """
    tolerance = 1.0e-4
    sparse_segment_sqrt_n_with_num_segments(loss=tolerance)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sqrt_n_with_num_segments_pynative_float64_int64_int64_int64():
    """
    Feature: SparseSegmentSqrtNWithNumSegments on GPU
    Description: PyNative-mode case with float64 x and int64 indices/segment_ids/num_segments
    Expectation: the output equals the reference values (stated by the author to match TensorFlow)
    """
    tolerance = 1.0e-5
    sparse_segment_sqrt_n_with_num_segments_pynative(loss=tolerance)
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations._grad_ops as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
|
||||
class SparseSegmentSumGradNet(nn.Cell):
    """Cell that forwards its four inputs to the SparseSegmentSumGrad primitive."""

    def __init__(self):
        super().__init__()
        # Primitive instance invoked by construct().
        self.net = P.SparseSegmentSumGrad()

    @ms_function
    def construct(self, grad, indices, segment_ids, output_dim0):
        result = self.net(grad, indices, segment_ids, output_dim0)
        return result
|
||||
|
||||
|
||||
def sparse_segment_sum_grad(loss):
    """Run SparseSegmentSumGradNet in graph mode on GPU and check it against fixed reference values.

    Args:
        loss: tolerance handed to np.allclose as both rtol and atol.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    # Inputs: float32 gradient rows 1..16 and int32 index tensors.
    grad = Tensor(np.arange(1, 17, dtype=np.float32).reshape(4, 4))
    indices = Tensor(np.array([0, 0, 1, 2, 2, 3], dtype=np.int32))
    segment_ids = Tensor(np.array([0, 1, 2, 2, 3, 3], dtype=np.int32))
    output_dim0 = Tensor(np.array(8, dtype=np.int32))

    actual = SparseSegmentSumGradNet()(grad, indices, segment_ids, output_dim0).asnumpy()

    # Only the first four of the eight output rows are non-zero.
    expected = np.zeros((8, 4), dtype=np.float32)
    expected[:4] = [[6, 8, 10, 12],
                    [9, 10, 11, 12],
                    [22, 24, 26, 28],
                    [13, 14, 15, 16]]
    assert np.allclose(actual, expected, rtol=loss, atol=loss)
|
||||
|
||||
|
||||
def sparse_segment_sum_grad_pynative(loss):
    """Run SparseSegmentSumGradNet in PyNative mode on GPU and check it against fixed reference values.

    Args:
        loss: tolerance handed to np.allclose as both rtol and atol.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

    # Inputs: float64 gradient rows 1..16 and int64 index tensors.
    grad = Tensor(np.arange(1, 17, dtype=np.float64).reshape(4, 4))
    indices = Tensor(np.array([0, 1, 1, 2, 3, 3], dtype=np.int64))
    segment_ids = Tensor(np.array([0, 0, 1, 2, 2, 3], dtype=np.int64))
    output_dim0 = Tensor(np.array(8, dtype=np.int64))

    actual = SparseSegmentSumGradNet()(grad, indices, segment_ids, output_dim0).asnumpy()

    # Only the first four of the eight output rows are non-zero.
    expected = np.zeros((8, 4), dtype=np.float64)
    expected[:4] = [[1, 2, 3, 4.],
                    [6, 8, 10, 12],
                    [9, 10, 11, 12],
                    [22, 24, 26, 28]]
    assert np.allclose(actual, expected, rtol=loss, atol=loss)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sum_grad_graph_float32_int32_int32():
    """
    Feature: SparseSegmentSumGrad on GPU
    Description: graph-mode case with float32 grad and int32 indices/segment_ids/output_dim0
    Expectation: the output equals the reference values (stated by the author to match TensorFlow)
    """
    tolerance = 1.0e-4
    sparse_segment_sum_grad(loss=tolerance)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sum_grad_pynative_float64_int64_int64():
    """
    Feature: SparseSegmentSumGrad on GPU
    Description: PyNative-mode case with float64 grad and int64 indices/segment_ids/output_dim0
    Expectation: the output equals the reference values (stated by the author to match TensorFlow)
    """
    tolerance = 1.0e-5
    sparse_segment_sum_grad_pynative(loss=tolerance)
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations.sparse_ops as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
|
||||
class SparseSegmentSumNet(nn.Cell):
    """Cell that forwards its three inputs to the SparseSegmentSum primitive."""

    def __init__(self):
        super().__init__()
        # Primitive instance invoked by construct().
        self.net = P.SparseSegmentSum()

    @ms_function
    def construct(self, x, indices, segment_ids):
        result = self.net(x, indices, segment_ids)
        return result
|
||||
|
||||
|
||||
def sparse_segment_sum(loss):
    """Run SparseSegmentSumNet in graph mode on GPU and check it against fixed reference values.

    Args:
        loss: tolerance handed to np.allclose as both rtol and atol.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    # Inputs: float32 rows 1..16 and int32 index tensors.
    x = Tensor(np.arange(1, 17, dtype=np.float32).reshape(4, 4))
    indices = Tensor(np.array([0, 0, 1, 2, 2, 3], dtype=np.int32))
    segment_ids = Tensor(np.array([0, 1, 2, 2, 3, 3], dtype=np.int32))

    actual = SparseSegmentSumNet()(x, indices, segment_ids).asnumpy()

    reference_rows = [
        [1, 2, 3, 4],
        [1, 2, 3, 4],
        [14, 16, 18, 20],
        [22, 24, 26, 28],
    ]
    expected = np.array(reference_rows, dtype=np.float32)
    assert np.allclose(actual, expected, rtol=loss, atol=loss)
|
||||
|
||||
|
||||
def sparse_segment_sum_pynative(loss):
    """Run SparseSegmentSumNet in PyNative mode on GPU and check it against fixed reference values.

    Args:
        loss: tolerance handed to np.allclose as both rtol and atol.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

    # Inputs: float64 rows 1..16 and int64 index tensors.
    x = Tensor(np.arange(1, 17, dtype=np.float64).reshape(4, 4))
    indices = Tensor(np.array([0, 1, 1, 2, 3, 3], dtype=np.int64))
    segment_ids = Tensor(np.array([0, 0, 1, 2, 2, 3], dtype=np.int64))

    actual = SparseSegmentSumNet()(x, indices, segment_ids).asnumpy()

    reference_rows = [
        [6, 8, 10, 12],
        [5, 6, 7, 8],
        [22, 24, 26, 28],
        [13, 14, 15, 16],
    ]
    expected = np.array(reference_rows, dtype=np.float64)
    assert np.allclose(actual, expected, rtol=loss, atol=loss)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sum_graph_float32_int32_int32():
    """
    Feature: SparseSegmentSum on GPU
    Description: graph-mode case with float32 x and int32 indices/segment_ids
    Expectation: the output equals the reference values (stated by the author to match TensorFlow)
    """
    tolerance = 1.0e-4
    sparse_segment_sum(loss=tolerance)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sum_pynative_float64_int64_int64():
    """
    Feature: SparseSegmentSum on GPU
    Description: PyNative-mode case with float64 x and int64 indices/segment_ids
    Expectation: the output equals the reference values (stated by the author to match TensorFlow)
    """
    tolerance = 1.0e-5
    sparse_segment_sum_pynative(loss=tolerance)
|
|
@ -0,0 +1,101 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations.sparse_ops as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
|
||||
class SparseSegmentSumWithNumSegmentsNet(nn.Cell):
    """Cell that forwards its four inputs to the SparseSegmentSumWithNumSegments primitive."""

    def __init__(self):
        super().__init__()
        # Primitive instance invoked by construct().
        self.net = P.SparseSegmentSumWithNumSegments()

    @ms_function
    def construct(self, x, indices, segment_ids, num_segments):
        result = self.net(x, indices, segment_ids, num_segments)
        return result
|
||||
|
||||
|
||||
def sparse_segment_sum_with_num_segments(loss):
    """Run SparseSegmentSumWithNumSegmentsNet in graph mode on GPU and check it against fixed reference values.

    Args:
        loss: tolerance handed to np.allclose as both rtol and atol.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

    # Inputs: float32 rows 1..16 and int32 index tensors.
    x = Tensor(np.arange(1, 17, dtype=np.float32).reshape(4, 4))
    indices = Tensor(np.array([0, 0, 1, 2, 2, 3], dtype=np.int32))
    segment_ids = Tensor(np.array([0, 3, 3, 5, 7, 7], dtype=np.int32))
    num_segments = Tensor(np.array(8, dtype=np.int32))

    actual = SparseSegmentSumWithNumSegmentsNet()(x, indices, segment_ids, num_segments).asnumpy()

    # Rows 0, 3, 5 and 7 carry values; the remaining rows of the 8x4 reference are zero.
    expected = np.zeros((8, 4), dtype=np.float32)
    expected[[0, 3, 5, 7]] = [[1, 2, 3, 4],
                              [6, 8, 10, 12],
                              [9, 10, 11, 12],
                              [22, 24, 26, 28]]
    assert np.allclose(actual, expected, rtol=loss, atol=loss)
|
||||
|
||||
|
||||
def sparse_segment_sum_with_num_segments_pynative(loss):
    """Run SparseSegmentSumWithNumSegmentsNet in PyNative mode on GPU and check it against fixed reference values.

    Args:
        loss: tolerance handed to np.allclose as both rtol and atol.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')

    # Inputs: float64 rows 1..16 and int64 index tensors.
    x = Tensor(np.arange(1, 17, dtype=np.float64).reshape(4, 4))
    indices = Tensor(np.array([0, 1, 1, 2, 3, 3], dtype=np.int64))
    segment_ids = Tensor(np.array([0, 0, 3, 5, 5, 7], dtype=np.int64))
    num_segments = Tensor(np.array(8, dtype=np.int64))

    actual = SparseSegmentSumWithNumSegmentsNet()(x, indices, segment_ids, num_segments).asnumpy()

    # Rows 0, 3, 5 and 7 carry values; the remaining rows of the 8x4 reference are zero.
    expected = np.zeros((8, 4), dtype=np.float64)
    expected[[0, 3, 5, 7]] = [[6, 8, 10, 12],
                              [5, 6, 7, 8],
                              [22, 24, 26, 28],
                              [13, 14, 15, 16]]
    assert np.allclose(actual, expected, rtol=loss, atol=loss)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sum_with_num_segments_graph_float32_int32_int32_int32():
    """
    Feature: SparseSegmentSumWithNumSegments on GPU
    Description: graph-mode case with float32 x and int32 indices/segment_ids/num_segments
    Expectation: the output equals the reference values (stated by the author to match TensorFlow)
    """
    tolerance = 1.0e-4
    sparse_segment_sum_with_num_segments(loss=tolerance)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparse_segment_sum_with_num_segments_pynative_float64_int64_int64_int64():
    """
    Feature: SparseSegmentSumWithNumSegments on GPU
    Description: PyNative-mode case with float64 x and int64 indices/segment_ids/num_segments
    Expectation: the output equals the reference values (stated by the author to match TensorFlow)
    """
    tolerance = 1.0e-5
    sparse_segment_sum_with_num_segments_pynative(loss=tolerance)
|
Loading…
Reference in New Issue