This commit is contained in:
parent f02867e3b8
commit 09432c95f8

@@ -0,0 +1,171 @@
/* Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SEGMENT_OPS_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SEGMENT_OPS_HELPER_H_

#include <map>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/segment_impl.cuh"

namespace mindspore {
namespace cukernel {
enum SegmentOpsOptype {
  SEGMENT_OP_MAX = 0,
  SEGMENT_OP_MIN = 1,
  SEGMENT_OP_MEAN = 2,
  SEGMENT_OP_SUM = 3,
  SEGMENT_OP_PROD = 4,
  SEGMENT_OP_INVALID_TYPE = 5
};

static const std::map<std::string, SegmentOpsOptype> kSegmentOpsOpTypeMap = {{"SegmentMax", SEGMENT_OP_MAX},
                                                                             {"SegmentMin", SEGMENT_OP_MIN},
                                                                             {"SegmentMean", SEGMENT_OP_MEAN},
                                                                             {"SegmentSum", SEGMENT_OP_SUM},
                                                                             {"SegmentProd", SEGMENT_OP_PROD}};
template <typename T, typename S>
class SegmentOpsHelperGpuKernel : public GpuKernelHelperBase {
 public:
  explicit SegmentOpsHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
      : GpuKernelHelperBase(kernel_name, device_id) {
    is_null_input_ = false;
    Segment_ops_op_type_ = SEGMENT_OP_INVALID_TYPE;
  }

  virtual ~SegmentOpsHelperGpuKernel() = default;
  int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
                 const std::vector<std::vector<int64_t>> &output_shapes) override {
    ResetResource();
    auto iter = kSegmentOpsOpTypeMap.find(kernel_name_);
    if (iter == kSegmentOpsOpTypeMap.end()) {
      MS_LOG(ERROR) << "For 'SegmentOps', only support these types: " << kernel::Map2Str(kSegmentOpsOpTypeMap)
                    << " currently, but got " << kernel_name_;
      return -1;
    }
    Segment_ops_op_type_ = iter->second;
    constexpr size_t INPUT_NUM = 2;
    constexpr size_t OUTPUT_NUM = 1;
    int inp_flag = CalShapesNum(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
    int out_flag =
      CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);

    if (inp_flag == 1 || out_flag == 1) {
      is_null_input_ = true;
      return 0;
    }
    if (input_shapes[0].size() < 1) {
      MS_LOG(ERROR) << "For 'SegmentOps', data shape must be more than 1D, but got " << input_shapes[0].size();
      return -1;
    }
    if (input_shapes[1].size() != 1) {
      MS_LOG(ERROR) << "For 'SegmentOps', segment_ids' shape only support 1D, but got " << input_shapes[1].size();
      return -1;
    }
    outer_class_ = output_shapes[0][0];
    outer_size_ = input_shapes[0][0];
    inner_size_ = input_size_list_[0] / outer_size_;
    size_t segment_id_num = static_cast<size_t>(input_shapes[1][0]);
    if (segment_id_num != outer_size_) {
      MS_LOG(ERROR) << "For 'SegmentOps', the length of segment_id must be equal to input_shape[0],"
                       " but got the length of segment_id : "
                    << segment_id_num << ", and input_shape[0] " << outer_size_;
      return -1;
    }
    input_size_list_[0] *= sizeof(T);
    input_size_list_[1] *= sizeof(S);
    work_size_list_.emplace_back((outer_size_ + 1) * sizeof(size_t));
    return 0;
  }

  static int CalShapesNum(const std::vector<std::vector<int64_t>> &shapes, const size_t shape_num,
                          const std::string kernel_name, const std::string param_name,
                          std::vector<size_t> *shapes_size) {
    if (shape_num != shapes.size()) {
MS_LOG(ERROR) << "For '" << kernel_name << "', the number of " << param_name << "should be equal to " << shape_num
                    << ", but got " << shapes.size();
      return -1;
    }
    int return_flag = 0;
    for (size_t idx = 0; idx < shape_num; ++idx) {
      size_t cur_size = 1;
      if (shapes[idx].size() == 0) {
        MS_LOG(WARNING) << "For '" << kernel_name << "', the shapes[" << idx << "] is ( )";
        shapes_size->emplace_back(cur_size);
        continue;
      }
      for (const auto &val : shapes[idx]) {
        cur_size *= val;
      }
      if (cur_size == 0) {
        MS_LOG(WARNING) << "For '" << kernel_name << "', got shapes[" << idx << "] is "
                        << ConvertVectorToString(shapes[idx]);
        return_flag = 1;
      }
      shapes_size->emplace_back(cur_size);
    }
    return return_flag;
  }

  int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
              const std::vector<void *> &work_ptrs, void *cuda_stream) override {
    if (is_null_input_) {
      return 0;
    }
    T *input_addr;
    S *seg_id_addr;
    size_t *seg_pos_addr;
    T *output_addr;
    int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &input_addr);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<S>(input_ptrs, 1, kernel_name_, &seg_id_addr);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<size_t>(work_ptrs, 0, kernel_name_, &seg_pos_addr);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &output_addr);
    if (flag != 0) {
      return flag;
    }
    CalSegmentCombination(input_addr, output_addr, seg_id_addr, seg_pos_addr, Segment_ops_op_type_, inner_size_,
                          outer_size_, outer_class_, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
    return 0;
  }

  void ResetResource() noexcept override {
    inner_size_ = 1;
    outer_size_ = 1;
    is_null_input_ = false;
    input_size_list_.clear();
    output_size_list_.clear();
    work_size_list_.clear();
  }

 private:
  SegmentOpsOptype Segment_ops_op_type_;
  size_t inner_size_;
  size_t outer_size_;
  size_t outer_class_;
  bool is_null_input_;
};
}  // namespace cukernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SEGMENT_OPS_HELPER_H_
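The helper above only validates shapes and computes buffer sizes; the reduction itself lives in the CUDA file below. For orientation, here is a minimal host-side sketch of the semantics these Segment ops follow, using SegmentSum as the example. It assumes sorted, non-negative segment_ids whose length equals the first input dimension (the same precondition the helper checks); the names are illustrative and not part of this commit.

#include <cstdint>
#include <vector>

// Reference semantics only: output row `id` accumulates every input row whose segment id equals `id`.
// Segments with no rows keep the default value (0 for SegmentSum), matching the CUDA kernel's behaviour.
std::vector<float> SegmentSumReference(const std::vector<float> &input, const std::vector<int32_t> &segment_ids,
                                       size_t inner_size, size_t num_segments) {
  std::vector<float> output(num_segments * inner_size, 0.0f);
  for (size_t row = 0; row < segment_ids.size(); ++row) {
    const size_t id = static_cast<size_t>(segment_ids[row]);
    for (size_t j = 0; j < inner_size; ++j) {
      output[id * inner_size + j] += input[row * inner_size + j];
    }
  }
  return output;
}

// Example: input = {{1, 2}, {3, 4}, {5, 6}}, segment_ids = {0, 0, 2}
//          -> output = {{4, 6}, {0, 0}, {5, 6}}; segment 1 is empty and stays at the default 0.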
@@ -0,0 +1,396 @@
/* Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <algorithm>
#include <limits>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/segment_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))

// Basic function

template <typename DataType>
struct MinFunc {
  __device__ __host__ __forceinline__ DataType operator()(const DataType &lhs, const DataType &rhs) {
    return lhs < rhs ? lhs : rhs;
  }
};

template <>
struct MinFunc<Complex<float>> {
  __device__ __host__ __forceinline__ Complex<float> operator()(const Complex<float> &lhs, const Complex<float> &rhs) {
    return 0;
  }
};

template <>
struct MinFunc<Complex<double>> {
  __device__ __host__ __forceinline__
  Complex<double> operator()(const Complex<double> &lhs, const Complex<double> &rhs) {
    return 0;
  }
};


template <typename DataType>
struct MaxFunc {
  __device__ __host__ __forceinline__ DataType operator()(const DataType &lhs, const DataType &rhs) {
    return lhs > rhs ? lhs : rhs;
  }
};

template <>
struct MaxFunc<Complex<float>> {
  __device__ __host__ __forceinline__ Complex<float> operator()(const Complex<float> &lhs, const Complex<float> &rhs) {
    return 0;
  }
};

template <>
struct MaxFunc<Complex<double>> {
  __device__ __host__ __forceinline__
  Complex<double> operator()(const Complex<double> &lhs, const Complex<double> &rhs) {
    return 0;
  }
};


template <typename T>
struct AddFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const Complex<T> &rhs) {
    return (lhs + rhs);
  }
};

template <typename T>
struct MulFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs * rhs); }
  __device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const Complex<T> &rhs) {
    return (lhs * rhs);
  }
};

template <typename DataType>
DataType max_val_init() {
  return std::numeric_limits<DataType>::max();
}

template <>
half max_val_init() {
  return 65504;  // Max value for Half
}

template <typename DataType>
DataType min_val_init() {
  return std::numeric_limits<DataType>::lowest();
}

template <>
half min_val_init() {
  return -65504;  // Min value for half
}

template <typename DataType>
DataType get_default_value(size_t op) {
  return static_cast<DataType>(0);
}

template <>
half get_default_value(size_t op) {
  return op == 0 ? -65504 : 65504;
}

template <>
float get_default_value(size_t op) {
  return op == 0 ? std::numeric_limits<double>::lowest() : -std::numeric_limits<float>::lowest();
}

template <>
double get_default_value(size_t op) {
  return op == 0 ? std::numeric_limits<double>::lowest() : -std::numeric_limits<double>::lowest();
}

template <typename IndexType>
__global__ void CalSegmentPos(const IndexType *segment_ids_ptr, size_t *segment_pos_ptr, const size_t segment_size) {
  for (size_t pos = threadIdx.x + blockIdx.x * blockDim.x; pos <= segment_size; pos += blockDim.x * gridDim.x) {
    IndexType temp =
      (segment_size > (segment_ids_ptr[segment_size - 1]) + 1) ? segment_size : (segment_ids_ptr[segment_size - 1] + 1);
    IndexType begin_pos = (pos == 0) ? 0 : (segment_ids_ptr[pos - 1] + 1);
    IndexType end_pos = (pos != segment_size) ? segment_ids_ptr[pos] : temp;
    const IndexType max_size = static_cast<IndexType>(temp);
    const IndexType min_size = IndexType(0);
    begin_pos = max(min_size, min(max_size, begin_pos));
    end_pos = max(min_size, min(max_size, end_pos));
    for (IndexType j = begin_pos; j <= end_pos; ++j) {
      segment_pos_ptr[j] = pos;
    }
  }
}
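// --- Illustrative aside (not part of this file) ---------------------------------------------
// CalSegmentPos above converts the sorted segment_ids array into segment start offsets:
// seg_pos[j] ends up as the index of the first row whose id is >= j, so segment j covers rows
// [seg_pos[j], seg_pos[j + 1]). A sequential sketch of the same computation, assuming sorted,
// non-negative ids; the function name and signature are illustrative only.
//
//   std::vector<size_t> SegmentStartOffsets(const std::vector<int32_t> &ids, size_t num_segments) {
//     std::vector<size_t> seg_pos(num_segments + 1, ids.size());
//     size_t row = 0;
//     for (size_t j = 0; j <= num_segments; ++j) {
//       while (row < ids.size() && static_cast<size_t>(ids[row]) < j) {
//         ++row;
//       }
//       seg_pos[j] = row;  // lower bound of id value j within ids
//     }
//     return seg_pos;
//   }
//
// Example: ids = {0, 0, 2}, num_segments = 3 -> seg_pos = {0, 2, 2, 3};
// segment 1 spans [2, 2), i.e. it is empty.
// ---------------------------------------------------------------------------------------------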

template <typename DataType, typename Func>
__device__ DataType ReduceWithinBlock(const DataType &value) {
  // Refer to reduce3 from Mark Harris, et al. 'Optimizing Parallel Reduction in CUDA'.
  extern __shared__ __align__(16) char share_data[];
  DataType *share_data_ptr = reinterpret_cast<DataType *>(share_data);
  const unsigned int x = threadIdx.x;
  const unsigned int y = threadIdx.y;
  const unsigned int tid = y * blockDim.x + x;
  share_data_ptr[tid] = value;
  __syncthreads();
  // Reduce over the y dimension of the block.
  for (unsigned k = blockDim.y / 2; k > 0; k /= 2) {
    if (y < k) {
      share_data_ptr[tid] = Func()(share_data_ptr[tid], share_data_ptr[(y + k) * blockDim.x + x]);
    }
    __syncthreads();
  }
  return share_data_ptr[tid];
}
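// --- Illustrative aside (not part of this file) ---------------------------------------------
// ReduceWithinBlock performs a shared-memory tree reduction over the y dimension of the thread
// block; blockDim.y is a power of two here (it is derived from Log2Ceil64 below), which the
// halving loop relies on. A CPU analogue of the same pattern, for illustration only:
//
//   template <typename T, typename Func>
//   T TreeReduce(std::vector<T> lanes, Func combine) {
//     // Halve the number of active "lanes" each step, combining lane y with lane y + k,
//     // exactly as the shared-memory loop does on the GPU. lanes.size() must be a power of two.
//     for (size_t k = lanes.size() / 2; k > 0; k /= 2) {
//       for (size_t y = 0; y < k; ++y) {
//         lanes[y] = combine(lanes[y], lanes[y + k]);
//       }
//     }
//     return lanes[0];
//   }
//
// Example: TreeReduce<float>({1, 2, 3, 4}, [](float a, float b) { return a + b; }) == 10.
// ---------------------------------------------------------------------------------------------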

template <typename DataType, typename Func, typename IndexType>
__global__ void SegmentProcess(DataType *inp_ptr, DataType *out_ptr, size_t *seg_pos_ptr, const size_t inner_size,
                               const size_t outer_size, const size_t outer_class, size_t op, DataType init_K,
                               DataType default_value, IndexType *seg_id_ptr) {
  for (size_t thread_x = threadIdx.x + blockIdx.x * blockDim.x; thread_x < inner_size;
       thread_x += blockDim.x * gridDim.x) {
    for (size_t block_idx_y = blockIdx.y; block_idx_y < outer_class; block_idx_y += gridDim.y) {
      size_t begin_pos = seg_pos_ptr[block_idx_y];
      size_t end_pos = seg_pos_ptr[block_idx_y + 1];
      DataType res = init_K;
      DataType cur_data = init_K;
      for (size_t pos = begin_pos; pos < end_pos; pos += blockDim.y) {
        size_t thread_y = pos + threadIdx.y;
        cur_data = (thread_y < end_pos) ? inp_ptr[thread_y * inner_size + thread_x] : static_cast<DataType>(0);
        cur_data = ReduceWithinBlock<DataType, Func>(cur_data);
        if (threadIdx.y == 0) {
          res = Func()(res, cur_data);
        }
      }
      if (threadIdx.y == 0) {
        if (op == 2) {
          DataType segment_len = DataType(static_cast<double>(end_pos - begin_pos));
          out_ptr[block_idx_y * inner_size + thread_x] =
            (begin_pos >= end_pos) ? static_cast<DataType>(0) : res / segment_len;
        } else if (op == 3) {
          out_ptr[block_idx_y * inner_size + thread_x] = (begin_pos >= end_pos) ? static_cast<DataType>(0) : res;
        } else if (op == 4) {
          out_ptr[block_idx_y * inner_size + thread_x] = (begin_pos >= end_pos) ? static_cast<DataType>(1) : res;
        } else {
          out_ptr[block_idx_y * inner_size + thread_x] = (begin_pos >= end_pos) ? default_value : res;
        }
      }
    }
  }
}

inline int Log2Floor(uint32_t n) {
  if (n == 0) return -1;
  int log = 0;
  for (int i = 4; i >= 0; --i) {
    int shift = (1 << i);
    uint32_t x = n >> shift;
    if (x) {
      n = x;
      log += shift;
    }
  }
  return log;
}

inline int Log2Floor64(uint64_t n) {
  // Scan n first high 32 then low 32 bits.
  const uint32_t high_32_bit = static_cast<uint32_t>(n >> 32);
  if (high_32_bit == 0) {
    return Log2Floor(static_cast<uint32_t>(n));
  } else {
    return 32 + Log2Floor(high_32_bit);
  }
}

inline int Log2Ceil64(uint64_t n) {
  int floor = Log2Floor64(n);
  if (n == (n & ~(n - 1)))
    return floor;
  else
    return floor + 1;
}

template <typename DataType, typename IndexType>
CUDA_LIB_EXPORT void CalSegmentCombination(DataType *inp_ptr, DataType *out_ptr, IndexType *seg_id_ptr,
                                           size_t *seg_pos_ptr, size_t op, const size_t inner_size,
                                           const size_t outer_size, const size_t outer_class, uint32_t device_id,
                                           cudaStream_t cuda_stream) {
  // Get start position of each segment and set to segment_pos_ptr.
  // The last element of segment_pos_ptr must equal to indices_size.
  const unsigned int segment_size = outer_size + 1;
  // size_t segment_pos_length[1] = {0};
  CalSegmentPos<<<CUDA_BLOCKS(device_id, segment_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    seg_id_ptr, seg_pos_ptr, outer_size);
  const unsigned int max_grid_x = (1u << 31) - 1;
  const unsigned int max_grid_y = (1u << 16) - 1;
  const unsigned int max_block_x = 1024;
  const unsigned int max_block_y = 64;
  unsigned int inner_power2 = 1u << Log2Ceil64(inner_size);
  unsigned int avg_reduce_size = UP_DIV(outer_size, outer_size);
  unsigned int avg_reduce_size_power2 = 1u << Log2Ceil64(avg_reduce_size);
  unsigned int block_x = std::min(inner_power2, max_block_x);
  unsigned int block_y = std::min(avg_reduce_size_power2, UP_DIV(max_block_y, block_x));
  unsigned int grid_x = std::min(static_cast<unsigned int>(UP_DIV(inner_size, block_x)), max_grid_x);
  unsigned int grid_y = std::min(segment_size, max_grid_y);
  dim3 block(block_x, block_y);
  dim3 grid(grid_x, grid_y);
  unsigned int shared_memory_size = block_x * block_y * sizeof(DataType);
  DataType init_K = std::numeric_limits<DataType>::lowest();
  DataType default_value = get_default_value<DataType>(op);
  switch (op) {
    case 0:
      init_K = min_val_init<DataType>();
      return SegmentProcess<DataType, MaxFunc<DataType>><<<grid, block, shared_memory_size, cuda_stream>>>(
        inp_ptr, out_ptr, seg_pos_ptr, inner_size, outer_size, outer_class, op, init_K, default_value, seg_id_ptr);
    case 1:
      init_K = max_val_init<DataType>();
      return SegmentProcess<DataType, MinFunc<DataType>><<<grid, block, shared_memory_size, cuda_stream>>>(
        inp_ptr, out_ptr, seg_pos_ptr, inner_size, outer_size, outer_class, op, init_K, default_value, seg_id_ptr);
    case 2:
      init_K = 0.0;
      return SegmentProcess<DataType, AddFunc<DataType>><<<grid, block, shared_memory_size, cuda_stream>>>(
        inp_ptr, out_ptr, seg_pos_ptr, inner_size, outer_size, outer_class, op, init_K, default_value, seg_id_ptr);
    case 3:
      init_K = 0.0;
      return SegmentProcess<DataType, AddFunc<DataType>><<<grid, block, shared_memory_size, cuda_stream>>>(
        inp_ptr, out_ptr, seg_pos_ptr, inner_size, outer_size, outer_class, op, init_K, default_value, seg_id_ptr);
    case 4:
      init_K = 1.0;
      return SegmentProcess<DataType, MulFunc<DataType>><<<grid, block, shared_memory_size, cuda_stream>>>(
        inp_ptr, out_ptr, seg_pos_ptr, inner_size, outer_size, outer_class, op, init_K, default_value, seg_id_ptr);
    default:
      break;
  }
}

template CUDA_LIB_EXPORT void CalSegmentCombination<float, int32_t>(
  float *inp_ptr, float *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<double, int32_t>(
  double *inp_ptr, double *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<half, int32_t>(
  half *inp_ptr, half *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<Complex<float>, int32_t>(
  Complex<float> *inp_ptr, Complex<float> *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op,
  const size_t inner_size, const size_t outer_size, const size_t outer_class, uint32_t device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<Complex<double>, int32_t>(
  Complex<double> *inp_ptr, Complex<double> *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op,
  const size_t inner_size, const size_t outer_size, const size_t outer_class, uint32_t device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int8_t, int32_t>(
  int8_t *inp_ptr, int8_t *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int16_t, int32_t>(
  int16_t *inp_ptr, int16_t *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int32_t, int32_t>(
  int32_t *inp_ptr, int32_t *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int64_t, int32_t>(
  int64_t *inp_ptr, int64_t *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint8_t, int32_t>(
  uint8_t *inp_ptr, uint8_t *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint16_t, int32_t>(
  uint16_t *inp_ptr, uint16_t *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint32_t, int32_t>(
  uint32_t *inp_ptr, uint32_t *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint64_t, int32_t>(
  uint64_t *inp_ptr, uint64_t *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<float, int64_t>(
  float *inp_ptr, float *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<double, int64_t>(
  double *inp_ptr, double *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<half, int64_t>(
  half *inp_ptr, half *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<Complex<float>, int64_t>(
  Complex<float> *inp_ptr, Complex<float> *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op,
  const size_t inner_size, const size_t outer_size, const size_t outer_class, uint32_t device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<Complex<double>, int64_t>(
  Complex<double> *inp_ptr, Complex<double> *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op,
  const size_t inner_size, const size_t outer_size, const size_t outer_class, uint32_t device_id,
  cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int8_t, int64_t>(
  int8_t *inp_ptr, int8_t *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int16_t, int64_t>(
  int16_t *inp_ptr, int16_t *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int32_t, int64_t>(
  int32_t *inp_ptr, int32_t *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int64_t, int64_t>(
  int64_t *inp_ptr, int64_t *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint8_t, int64_t>(
  uint8_t *inp_ptr, uint8_t *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint16_t, int64_t>(
  uint16_t *inp_ptr, uint16_t *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint32_t, int64_t>(
  uint32_t *inp_ptr, uint32_t *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint64_t, int64_t>(
  uint64_t *inp_ptr, uint64_t *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
  const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
@@ -0,0 +1,26 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SEGMENT_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SEGMENT_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

template <typename DataType, typename IndexType>
CUDA_LIB_EXPORT void CalSegmentCombination(DataType *inp_ptr, DataType *out_ptr, IndexType *seg_id_addr,
                                           size_t *seg_pos_ptr, size_t op, const size_t inner_size,
                                           const size_t outer_size, const size_t outer_class,
                                           uint32_t device_id, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SEGMENT_IMPL_CUH_
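For orientation, here is a hedged sketch of how this exported entry point could be driven directly, outside the kernel-mod framework. Buffer names, sizes, the example op code, and the use of the default stream are assumptions for illustration; real callers go through SegmentOpsHelperGpuKernel::Process, which also provisions the (outer_size + 1)-element workspace, and error checking is omitted here.

#include <cstdint>
#include <cuda_runtime.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/segment_impl.cuh"

void LaunchSegmentSumExample() {
  const size_t outer_size = 3, inner_size = 2, num_segments = 3;
  float *d_in = nullptr, *d_out = nullptr;
  int32_t *d_ids = nullptr;
  size_t *d_seg_pos = nullptr;  // workspace: one offset per input row plus a terminator
  cudaMalloc(reinterpret_cast<void **>(&d_in), outer_size * inner_size * sizeof(float));
  cudaMalloc(reinterpret_cast<void **>(&d_out), num_segments * inner_size * sizeof(float));
  cudaMalloc(reinterpret_cast<void **>(&d_ids), outer_size * sizeof(int32_t));
  cudaMalloc(reinterpret_cast<void **>(&d_seg_pos), (outer_size + 1) * sizeof(size_t));
  // ... copy the input rows and the sorted segment ids to the device ...
  // op = 3 selects SEGMENT_OP_SUM in the enum from segment_ops_helper.h.
  CalSegmentCombination(d_in, d_out, d_ids, d_seg_pos, /*op=*/3, inner_size, outer_size, num_segments,
                        /*device_id=*/0, /*cuda_stream=*/nullptr);
  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_ids);
  cudaFree(d_seg_pos);
}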
@@ -0,0 +1,395 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <utility>
#include "plugin/device/gpu/kernel/math/segment_ops_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr auto Segment_max = "SegmentMax";
constexpr auto Segment_min = "SegmentMin";
constexpr auto Segment_mean = "SegmentMean";
constexpr auto Segment_sum = "SegmentSum";
constexpr auto Segment_prod = "SegmentProd";

template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateSegmentOpsKernelPtr(const std::string &kernel_name,
                                                                         const uint32_t &device_id) {
  return std::make_unique<cukernel::SegmentOpsHelperGpuKernel<T, S>>(kernel_name, device_id);
}
using SegmentOpsPtrCreatorFunc =
  std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;

template <typename T>
using Complex = mindspore::utils::Complex<T>;
const std::map<std::string, std::vector<std::pair<KernelAttr, SegmentOpsPtrCreatorFunc>>> kernel_attr_map = {
  {Segment_max,
   {{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
     CreateSegmentOpsKernelPtr<half, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
     CreateSegmentOpsKernelPtr<float, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
     CreateSegmentOpsKernelPtr<double, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
     CreateSegmentOpsKernelPtr<int8_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
     CreateSegmentOpsKernelPtr<int16_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
     CreateSegmentOpsKernelPtr<int32_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
     CreateSegmentOpsKernelPtr<int64_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
     CreateSegmentOpsKernelPtr<uint8_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
     CreateSegmentOpsKernelPtr<uint16_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
     CreateSegmentOpsKernelPtr<uint32_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
     CreateSegmentOpsKernelPtr<uint64_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
     CreateSegmentOpsKernelPtr<half, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
     CreateSegmentOpsKernelPtr<float, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
     CreateSegmentOpsKernelPtr<double, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
     CreateSegmentOpsKernelPtr<int8_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
     CreateSegmentOpsKernelPtr<int16_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
     CreateSegmentOpsKernelPtr<int32_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
     CreateSegmentOpsKernelPtr<int64_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
     CreateSegmentOpsKernelPtr<uint8_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
     CreateSegmentOpsKernelPtr<uint16_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
     CreateSegmentOpsKernelPtr<uint32_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
     CreateSegmentOpsKernelPtr<uint64_t, int64_t>}}},
  {Segment_min,
   {{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
     CreateSegmentOpsKernelPtr<half, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
     CreateSegmentOpsKernelPtr<float, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
     CreateSegmentOpsKernelPtr<double, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
     CreateSegmentOpsKernelPtr<int8_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
     CreateSegmentOpsKernelPtr<int16_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
     CreateSegmentOpsKernelPtr<int32_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
     CreateSegmentOpsKernelPtr<int64_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
     CreateSegmentOpsKernelPtr<uint8_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
     CreateSegmentOpsKernelPtr<uint16_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
     CreateSegmentOpsKernelPtr<uint32_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
     CreateSegmentOpsKernelPtr<uint64_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
     CreateSegmentOpsKernelPtr<half, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
     CreateSegmentOpsKernelPtr<float, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
     CreateSegmentOpsKernelPtr<double, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
     CreateSegmentOpsKernelPtr<int8_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
     CreateSegmentOpsKernelPtr<int16_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
     CreateSegmentOpsKernelPtr<int32_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
     CreateSegmentOpsKernelPtr<int64_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
     CreateSegmentOpsKernelPtr<uint8_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
     CreateSegmentOpsKernelPtr<uint16_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
     CreateSegmentOpsKernelPtr<uint32_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
     CreateSegmentOpsKernelPtr<uint64_t, int64_t>}}},
  {Segment_mean,
   {{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
     CreateSegmentOpsKernelPtr<half, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
     CreateSegmentOpsKernelPtr<Complex<float>, int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeComplex128)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeComplex128),
     CreateSegmentOpsKernelPtr<Complex<double>, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
     CreateSegmentOpsKernelPtr<float, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
     CreateSegmentOpsKernelPtr<double, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
     CreateSegmentOpsKernelPtr<int8_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
     CreateSegmentOpsKernelPtr<int16_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
     CreateSegmentOpsKernelPtr<int32_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
     CreateSegmentOpsKernelPtr<int64_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
     CreateSegmentOpsKernelPtr<uint8_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
     CreateSegmentOpsKernelPtr<uint16_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
     CreateSegmentOpsKernelPtr<uint32_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
     CreateSegmentOpsKernelPtr<uint64_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
     CreateSegmentOpsKernelPtr<half, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
     CreateSegmentOpsKernelPtr<Complex<float>, int64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeComplex128)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeComplex128),
     CreateSegmentOpsKernelPtr<Complex<double>, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
     CreateSegmentOpsKernelPtr<float, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
     CreateSegmentOpsKernelPtr<double, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
     CreateSegmentOpsKernelPtr<int8_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
     CreateSegmentOpsKernelPtr<int16_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
     CreateSegmentOpsKernelPtr<int32_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
     CreateSegmentOpsKernelPtr<int64_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
     CreateSegmentOpsKernelPtr<uint8_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
     CreateSegmentOpsKernelPtr<uint16_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
     CreateSegmentOpsKernelPtr<uint32_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
     CreateSegmentOpsKernelPtr<uint64_t, int64_t>}}},
  {Segment_sum,
   {{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
     CreateSegmentOpsKernelPtr<half, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
     CreateSegmentOpsKernelPtr<Complex<float>, int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeComplex128)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeComplex128),
     CreateSegmentOpsKernelPtr<Complex<double>, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
     CreateSegmentOpsKernelPtr<float, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
     CreateSegmentOpsKernelPtr<double, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
     CreateSegmentOpsKernelPtr<int8_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
     CreateSegmentOpsKernelPtr<int16_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
     CreateSegmentOpsKernelPtr<int32_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
     CreateSegmentOpsKernelPtr<int64_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
     CreateSegmentOpsKernelPtr<uint8_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
     CreateSegmentOpsKernelPtr<uint16_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
     CreateSegmentOpsKernelPtr<uint32_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
     CreateSegmentOpsKernelPtr<uint64_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
     CreateSegmentOpsKernelPtr<half, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
     CreateSegmentOpsKernelPtr<Complex<float>, int64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeComplex128)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeComplex128),
     CreateSegmentOpsKernelPtr<Complex<double>, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
     CreateSegmentOpsKernelPtr<float, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
     CreateSegmentOpsKernelPtr<double, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
     CreateSegmentOpsKernelPtr<int8_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
     CreateSegmentOpsKernelPtr<int16_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
     CreateSegmentOpsKernelPtr<int32_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
     CreateSegmentOpsKernelPtr<int64_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
     CreateSegmentOpsKernelPtr<uint8_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
     CreateSegmentOpsKernelPtr<uint16_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
     CreateSegmentOpsKernelPtr<uint32_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
     CreateSegmentOpsKernelPtr<uint64_t, int64_t>}}},
  {Segment_prod,
   {{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
     CreateSegmentOpsKernelPtr<half, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
     CreateSegmentOpsKernelPtr<Complex<float>, int32_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeComplex128)
       .AddInputAttr(kNumberTypeInt32)
       .AddOutputAttr(kNumberTypeComplex128),
     CreateSegmentOpsKernelPtr<Complex<double>, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
     CreateSegmentOpsKernelPtr<float, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
     CreateSegmentOpsKernelPtr<double, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
     CreateSegmentOpsKernelPtr<int8_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
     CreateSegmentOpsKernelPtr<int16_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
     CreateSegmentOpsKernelPtr<int32_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
     CreateSegmentOpsKernelPtr<int64_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
     CreateSegmentOpsKernelPtr<uint8_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
     CreateSegmentOpsKernelPtr<uint16_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
     CreateSegmentOpsKernelPtr<uint32_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
     CreateSegmentOpsKernelPtr<uint64_t, int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
     CreateSegmentOpsKernelPtr<half, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
     CreateSegmentOpsKernelPtr<Complex<float>, int64_t>},
    {KernelAttr()
       .AddInputAttr(kNumberTypeComplex128)
       .AddInputAttr(kNumberTypeInt64)
       .AddOutputAttr(kNumberTypeComplex128),
     CreateSegmentOpsKernelPtr<Complex<double>, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
     CreateSegmentOpsKernelPtr<float, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
     CreateSegmentOpsKernelPtr<double, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
     CreateSegmentOpsKernelPtr<int8_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
     CreateSegmentOpsKernelPtr<int16_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
     CreateSegmentOpsKernelPtr<int32_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
     CreateSegmentOpsKernelPtr<int64_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
     CreateSegmentOpsKernelPtr<uint8_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
     CreateSegmentOpsKernelPtr<uint16_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
     CreateSegmentOpsKernelPtr<uint32_t, int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
     CreateSegmentOpsKernelPtr<uint64_t, int64_t>}}}};  // kernel_attr_map_
}  // namespace

bool SegmentOpsGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                  const std::vector<KernelTensorPtr> &outputs) {
  constexpr size_t inputs_num = 2;
  constexpr size_t outputs_num = 1;
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), inputs_num, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_);
  kernel_name_ = base_operator->GetPrim()->name();
  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());

  if (!is_match) {
    MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << tensor_attr << ".";
    return false;
  }
  helper_ptr_ = std::move(kernel_attr_map.at(kernel_type_)[index].second(kernel_name_, device_id_));
  return true;
}

bool SegmentOpsGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                    const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  std::vector<void *> input_ptrs = ConvertPtrs(inputs);
  std::vector<void *> work_ptrs = ConvertPtrs(workspace);
  std::vector<void *> output_ptrs = ConvertPtrs(outputs);
  if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
    return false;
  }
  return true;
}

int SegmentOpsGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                   const std::vector<KernelTensorPtr> &outputs,
                                   const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  for (const auto &input : inputs) {
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  for (const auto &output : outputs) {
    auto output_shape = output->GetShapeVector();
    if (!IsValidShape(output_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<std::vector<int64_t>> output_shapes;
  std::vector<int64_t> inp_shape = inputs[0]->GetShapeVector();
  std::vector<int64_t> segment_ids_shape = inputs[1]->GetShapeVector();
  std::vector<int64_t> out_shape = outputs[0]->GetShapeVector();
  input_shapes.emplace_back(inp_shape);
  input_shapes.emplace_back(segment_ids_shape);
  output_shapes.emplace_back(out_shape);
  if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
    return KRET_RESIZE_FAILED;
  }
  input_size_list_ = helper_ptr_->GetInputSizeList();
  output_size_list_ = helper_ptr_->GetOutputSizeList();
  workspace_size_list_ = helper_ptr_->GetWorkSizeList();
  return KRET_OK;
}

std::vector<KernelAttr> SegmentOpsGpuKernelMod::GetOpSupport() {
  auto iter = kernel_attr_map.find(kernel_type_);
  if (iter == kernel_attr_map.end()) {
    MS_LOG(ERROR) << "For 'SegmentOpsOp', only support these types: " << kernel::Map2Str(kernel_attr_map)
                  << " currently, but got " << kernel_name_;
  }
  std::vector<KernelAttr> support_list;
  (void)std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, SegmentOpsPtrCreatorFunc> &item) { return item.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SegmentMax,
                                 []() { return std::make_shared<SegmentOpsGpuKernelMod>(Segment_max); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SegmentMin,
                                 []() { return std::make_shared<SegmentOpsGpuKernelMod>(Segment_min); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SegmentMean,
                                 []() { return std::make_shared<SegmentOpsGpuKernelMod>(Segment_mean); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SegmentSum,
                                 []() { return std::make_shared<SegmentOpsGpuKernelMod>(Segment_sum); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SegmentProd,
                                 []() { return std::make_shared<SegmentOpsGpuKernelMod>(Segment_prod); });
}  // namespace kernel
}  // namespace mindspore
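The kernel mod above is type-erased: kernel_attr_map keys the op name to a list of (KernelAttr, creator) pairs, MatchKernelAttr picks the index of the matching type signature, and the creator instantiates the right SegmentOpsHelperGpuKernel<T, S> specialization. A stripped-down, framework-free sketch of that dispatch pattern follows; all names here are illustrative, not MindSpore APIs.

#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct HelperBase {  // stands in for cukernel::GpuKernelHelperBase
  virtual ~HelperBase() = default;
};
template <typename T, typename S>
struct TypedHelper : HelperBase {};  // stands in for SegmentOpsHelperGpuKernel<T, S>

using Creator = std::function<std::unique_ptr<HelperBase>()>;
using Signature = std::pair<std::string, std::string>;  // (data dtype, segment_ids dtype)

// One registry entry per supported type combination, like kernel_attr_map above.
static const std::map<std::string, std::vector<std::pair<Signature, Creator>>> kRegistry = {
  {"SegmentSum",
   {{{"float32", "int32"}, [] { return std::make_unique<TypedHelper<float, int32_t>>(); }},
    {{"float32", "int64"}, [] { return std::make_unique<TypedHelper<float, int64_t>>(); }}}},
};

// Same role as MatchKernelAttr plus kernel_attr_map.at(kernel_type_)[index].second(...).
std::unique_ptr<HelperBase> CreateHelper(const std::string &op, const Signature &sig) {
  for (const auto &entry : kRegistry.at(op)) {
    if (entry.first == sig) {
      return entry.second();
    }
  }
  return nullptr;  // mirrors the "!is_match" error path in Init()
}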
@@ -0,0 +1,54 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_SEGMENT_OPS_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_SEGMENT_OPS_GPU_KERNEL_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/segment_ops_helper.h"

namespace mindspore {
namespace kernel {
class SegmentOpsGpuKernelMod : public NativeGpuKernelMod {
 public:
  explicit SegmentOpsGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
  ~SegmentOpsGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
  std::string kernel_type_{"Unknown"};
};
}  // namespace kernel
}  // namespace mindspore
#endif
|
@@ -120,7 +120,7 @@ TypePtr SegmentMaxInferType(const PrimitivePtr &primitive, const std::vector<Abs
}
}  // namespace

MIND_API_BASE_IMPL(SegmentMax, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(SegmentMax, BaseOperator);
AbstractBasePtr SegmentMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
@@ -123,7 +123,7 @@ TypePtr SegmentMeanInferType(const PrimitivePtr &primitive, const std::vector<Ab
}
}  // namespace

MIND_API_BASE_IMPL(SegmentMean, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(SegmentMean, BaseOperator);
AbstractBasePtr SegmentMeanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
@@ -120,7 +120,7 @@ TypePtr SegmentMinInferType(const PrimitivePtr &primitive, const std::vector<Abs
}
}  // namespace

MIND_API_BASE_IMPL(SegmentMin, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(SegmentMin, BaseOperator);
AbstractBasePtr SegmentMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
@@ -123,7 +123,7 @@ TypePtr SegmentProdInferType(const PrimitivePtr &primitive, const std::vector<Ab
}
}  // namespace

MIND_API_BASE_IMPL(SegmentProd, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(SegmentProd, BaseOperator);
AbstractBasePtr SegmentProdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
@@ -120,7 +120,7 @@ TypePtr SegmentSumInferType(const PrimitivePtr &primitive, const std::vector<Abs
}
}  // namespace

MIND_API_BASE_IMPL(SegmentSum, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(SegmentSum, BaseOperator);
AbstractBasePtr SegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const std::vector<AbstractBasePtr> &input_args) {
  for (const auto &item : input_args) {
@@ -7568,7 +7568,7 @@ class SegmentMax(Primitive):
        ValueError: If the values of `segment_ids` are not sorted in ascending order.

    Supported Platforms:
        ``Ascend`` ``CPU``
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float64)
@@ -7618,7 +7618,7 @@ class SegmentMin(Primitive):
        ValueError: If the values of `segment_ids` are not sorted in ascending order.

    Supported Platforms:
        ``Ascend`` ``CPU``
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float64)
@@ -7672,7 +7672,7 @@ class SegmentSum(Primitive):
        ValueError: If the values of `segment_ids` are not sorted in ascending order.

    Supported Platforms:
        ``Ascend`` ``CPU``
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float64)
@@ -7873,7 +7873,7 @@ class SegmentMean(Primitive):
        ValueError: If the values of `segment_ids` are not sorted in ascending order.

    Supported Platforms:
        ``Ascend`` ``CPU``
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor([[1, 2, 3], [1, 2, 3], [7, 8, 9]], mstype.float64)
@@ -7927,7 +7927,7 @@ class SegmentProd(Primitive):
        ValueError: If the values of `segment_ids` are not sorted in ascending order.

    Supported Platforms:
        ``Ascend`` ``CPU``
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float64)
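The docstring hunks above only extend the platform line to include GPU. The operator semantics they describe can be summarized with a small NumPy reference, shown here as an illustrative sketch (not MindSpore code). It assumes what the new GPU tests below assert: `segment_ids` is 1-D and sorted in ascending order, the output has `max(segment_ids) + 1` rows, and segments with no elements are filled with 0 for Max/Min/Sum.

import numpy as np

def segment_max_reference(x, segment_ids):
    """NumPy sketch of SegmentMax semantics as exercised by the tests below."""
    x = np.asarray(x)
    segment_ids = np.asarray(segment_ids)
    num_segments = int(segment_ids.max()) + 1
    out = np.zeros((num_segments,) + x.shape[1:], dtype=x.dtype)  # empty segments -> 0
    for seg in range(num_segments):
        rows = x[segment_ids == seg]
        if rows.size:
            out[seg] = rows.max(axis=0)
    return out

# Matches the expectation in test_segment_max_fp below: [1, 0, 0, 0, 0, 0, 3]
print(segment_max_reference([1, 2, 3], [0, 6, 6]))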
@@ -0,0 +1,176 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations.array_ops import SegmentMax, SegmentMin, SegmentMean, SegmentSum, SegmentProd
from mindspore.nn import Cell
import mindspore.common.dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

class SegmentMaxNet(Cell):
    def __init__(self):
        super().__init__()
        self.segmentmax = SegmentMax()

    def construct(self, x, segment_ids):
        return self.segmentmax(x, segment_ids)


class SegmentMinNet(Cell):
    def __init__(self):
        super().__init__()
        self.segmentmin = SegmentMin()

    def construct(self, x, segment_ids):
        return self.segmentmin(x, segment_ids)


class SegmentMeanNet(Cell):
    def __init__(self):
        super().__init__()
        self.segmentmean = SegmentMean()

    def construct(self, x, segment_ids):
        return self.segmentmean(x, segment_ids)


class SegmentSumNet(Cell):
    def __init__(self):
        super().__init__()
        self.segmentsum = SegmentSum()

    def construct(self, x, segment_ids):
        return self.segmentsum(x, segment_ids)


class SegmentProdNet(Cell):
    def __init__(self):
        super().__init__()
        self.segmentprod = SegmentProd()

    def construct(self, x, segment_ids):
        return self.segmentprod(x, segment_ids)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_segment_max_fp():
    """
    Feature: SegmentMax operator.
    Description: test cases for SegmentMax operator.
    Expectation: the result matches the expectation.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    input_x = Tensor([1, 2, 3], mstype.int32)
    segment_ids = Tensor([0, 6, 6], mstype.int32)
    net = SegmentMaxNet()
    expect = np.array([1, 0, 0, 0, 0, 0, 3]).astype(np.int32)
    output_gr = net(input_x, segment_ids).asnumpy()
    np.testing.assert_array_almost_equal(output_gr, expect)

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    output_py = net(input_x, segment_ids).asnumpy()
    np.testing.assert_almost_equal(output_py, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_segment_min_fp():
    """
    Feature: SegmentMin operator.
    Description: test cases for SegmentMin operator.
    Expectation: the result matches the expectation.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    input_x = Tensor([1, 2, 3, 4], mstype.int32)
    segment_ids = Tensor([0, 0, 1, 5], mstype.int32)
    net = SegmentMinNet()
    expect = np.array([1, 3, 0, 0, 0, 4]).astype(np.int32)
    output_gr = net(input_x, segment_ids).asnumpy()
    np.testing.assert_array_almost_equal(output_gr, expect)

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    output_py = net(input_x, segment_ids).asnumpy()
    np.testing.assert_almost_equal(output_py, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_segment_sum_fp():
    """
    Feature: SegmentSum operator.
    Description: test cases for SegmentSum operator.
    Expectation: the result matches the expectation.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    # Complex-valued input, so the tensor and expected result use a complex dtype.
    input_x = Tensor([1 + 2j, 2 + 2j, 3 + 2j], mstype.complex64)
    segment_ids = Tensor([0, 0, 2], mstype.int32)
    net = SegmentSumNet()
    expect = np.array([3 + 4j, 0, 3 + 2j]).astype(np.complex64)
    output_gr = net(input_x, segment_ids).asnumpy()
    np.testing.assert_array_almost_equal(output_gr, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_segment_mean_fp():
    """
    Feature: SegmentMean operator.
    Description: test cases for SegmentMean operator.
    Expectation: the result matches the expectation.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    input_x = Tensor([2, 2, 3, 4], mstype.float32)
    segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
    net = SegmentMeanNet()
    expect = np.array([2, 3, 4]).astype(np.float32)
    output_gr = net(input_x, segment_ids).asnumpy()
    np.testing.assert_array_almost_equal(output_gr, expect)

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    output_py = net(input_x, segment_ids).asnumpy()
    np.testing.assert_almost_equal(output_py, expect)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_segment_prod_fp():
    """
    Feature: SegmentProd operator.
    Description: test cases for SegmentProd operator.
    Expectation: the result matches the expectation.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    input_x = Tensor([1, 2, 3, 4], mstype.float32)
    segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
    net = SegmentProdNet()
    expect = np.array([2, 3, 4]).astype(np.float32)
    output_gr = net(input_x, segment_ids).asnumpy()
    np.testing.assert_array_almost_equal(output_gr, expect)

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    output_py = net(input_x, segment_ids).asnumpy()
    np.testing.assert_almost_equal(output_py, expect)
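The five cases above exercise each network in both GRAPH_MODE and PYNATIVE_MODE (SegmentSum in graph mode only). A hedged sketch of a local runner that selects them by the pytest markers defined in this file; the file name test_segment_ops.py is illustrative and not part of this commit:

# Hypothetical local runner for the new GPU segment-op tests.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["-q", "-m", "platform_x86_gpu_training and env_onecard", "test_segment_ops.py"]))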