This commit is contained in:
gp121 2022-07-08 16:11:21 +08:00 committed by zhangyizhuo1124
parent f02867e3b8
commit 09432c95f8
12 changed files with 1228 additions and 10 deletions

View File

@ -0,0 +1,171 @@
/* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SEGMENT_OPS_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SEGMENT_OPS_HELPER_H_
#include <algorithm>
#include <map>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/segment_impl.cuh"
namespace mindspore {
namespace cukernel {
enum SegmentOpsOptype {
SEGMENT_OP_MAX = 0,
SEGMENT_OP_MIN = 1,
SEGMENT_OP_MEAN = 2,
SEGMENT_OP_SUM = 3,
SEGMENT_OP_PROD = 4,
SEGMENT_OP_INVALID_TYPE = 5
};
static const std::map<std::string, SegmentOpsOptype> kSegmentOpsOpTypeMap = {{"SegmentMax", SEGMENT_OP_MAX},
{"SegmentMin", SEGMENT_OP_MIN},
{"SegmentMean", SEGMENT_OP_MEAN},
{"SegmentSum", SEGMENT_OP_SUM},
{"SegmentProd", SEGMENT_OP_PROD}};
template <typename T, typename S>
class SegmentOpsHelperGpuKernel : public GpuKernelHelperBase {
public:
explicit SegmentOpsHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
: GpuKernelHelperBase(kernel_name, device_id) {
is_null_input_ = false;
segment_ops_op_type_ = SEGMENT_OP_INVALID_TYPE;
}
virtual ~SegmentOpsHelperGpuKernel() = default;
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes) override {
ResetResource();
auto iter = kSegmentOpsOpTypeMap.find(kernel_name_);
if (iter == kSegmentOpsOpTypeMap.end()) {
MS_LOG(ERROR) << "For 'SegmentOps', only support these types: " << kernel::Map2Str(kSegmentOpsOpTypeMap)
<< " currently, but got " << kernel_name_;
return -1;
}
segment_ops_op_type_ = iter->second;
constexpr size_t INPUT_NUM = 2;
constexpr size_t OUTPUT_NUM = 1;
int inp_flag = CalShapesNum(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
int out_flag =
CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
if (inp_flag == 1 || out_flag == 1) {
is_null_input_ = true;
return 0;
}
if (input_shapes[0].size() < 1) {
MS_LOG(ERROR) << "For 'SegmentOps', the rank of the input data must be at least 1, but got " << input_shapes[0].size();
return -1;
}
if (input_shapes[1].size() != 1) {
MS_LOG(ERROR) << "For 'SegmentOps', segment_ids' shape only support 1D, but got " << input_shapes[1].size();
return -1;
}
outer_class_ = output_shapes[0][0];
outer_size_ = input_shapes[0][0];
inner_size_ = input_size_list_[0] / outer_size_;
size_t segment_id_num = static_cast<size_t>(input_shapes[1][0]);
if (segment_id_num != outer_size_) {
MS_LOG(ERROR) << "For 'SegmentOps', the length of segment_id must be equal to input_shape[0],"
" but got the length of segment_id : "
<< segment_id_num << ", and input_shape[0] " << outer_size_;
return -1;
}
input_size_list_[0] *= sizeof(T);
input_size_list_[1] *= sizeof(S);
// CalSegmentPos writes one start index per output segment plus a terminating entry, so the workspace must hold
// max(outer_size_, outer_class_) + 1 entries.
work_size_list_.emplace_back((std::max(outer_size_, outer_class_) + 1) * sizeof(size_t));
return 0;
}
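// Returns the element count (not byte size) of each shape in `shapes`. A scalar shape contributes 1; a shape
// containing a zero dimension is logged with a warning and makes the function return 1, which the caller treats
// as a null input.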
static int CalShapesNum(const std::vector<std::vector<int64_t>> &shapes, const size_t shape_num,
const std::string kernel_name, const std::string param_name,
std::vector<size_t> *shapes_size) {
if (shape_num != shapes.size()) {
MS_LOG(ERROR) << "For '" << kernel_name << "', the number of " << param_name << "should be equal to " << shape_num
<< ", but got " << shapes.size();
return -1;
}
int return_flag = 0;
for (size_t idx = 0; idx < shape_num; ++idx) {
size_t cur_size = 1;
if (shapes[idx].size() == 0) {
MS_LOG(WARNING) << "For '" << kernel_name << "', the shapes[" << idx << "] is ( )";
shapes_size->emplace_back(cur_size);
continue;
}
for (const auto &val : shapes[idx]) {
cur_size *= val;
}
if (cur_size == 0) {
MS_LOG(WARNING) << "For '" << kernel_name << "', got shapes[" << idx << "] is "
<< ConvertVectorToString(shapes[idx]);
return_flag = 1;
}
shapes_size->emplace_back(cur_size);
}
return return_flag;
}
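// Fetches the device pointers for the data input, segment_ids, the segment-position workspace and the output,
// then launches the fused segment reduction on the given CUDA stream.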
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
if (is_null_input_) {
return 0;
}
T *input_addr;
S *seg_id_addr;
size_t *seg_pos_addr;
T *output_addr;
int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &input_addr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<S>(input_ptrs, 1, kernel_name_, &seg_id_addr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<size_t>(work_ptrs, 0, kernel_name_, &seg_pos_addr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &output_addr);
if (flag != 0) {
return flag;
}
CalSegmentCombination(input_addr, output_addr, seg_id_addr, seg_pos_addr, segment_ops_op_type_, inner_size_,
outer_size_, outer_class_, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
return 0;
}
void ResetResource() noexcept override {
inner_size_ = 1;
outer_size_ = 1;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
work_size_list_.clear();
}
private:
SegmentOpsOptype segment_ops_op_type_;
size_t inner_size_;
size_t outer_size_;
size_t outer_class_;
bool is_null_input_;
};
} // namespace cukernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SEGMENT_OPS_HELPER_H_

View File

@ -0,0 +1,396 @@
/* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <limits>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/segment_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
// Basic function
template <typename DataType>
struct MinFunc {
__device__ __host__ __forceinline__ DataType operator()(const DataType &lhs, const DataType &rhs) {
return lhs < rhs ? lhs : rhs;
}
};
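// Complex numbers have no ordering; these specializations exist only so the template instantiates.
// SegmentMax/SegmentMin are not registered for complex types, so they are never invoked.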
template <>
struct MinFunc<Complex<float>> {
__device__ __host__ __forceinline__ Complex<float> operator()(const Complex<float> &lhs, const Complex<float> &rhs) {
return 0;
}
};
template <>
struct MinFunc<Complex<double>> {
__device__ __host__ __forceinline__
Complex<double> operator()(const Complex<double> &lhs, const Complex<double> &rhs) {
return 0;
}
};
template <typename DataType>
struct MaxFunc {
__device__ __host__ __forceinline__ DataType operator()(const DataType &lhs, const DataType &rhs) {
return lhs > rhs ? lhs : rhs;
}
};
template <>
struct MaxFunc<Complex<float>> {
__device__ __host__ __forceinline__ Complex<float> operator()(const Complex<float> &lhs, const Complex<float> &rhs) {
return 0;
}
};
template <>
struct MaxFunc<Complex<double>> {
__device__ __host__ __forceinline__
Complex<double> operator()(const Complex<double> &lhs, const Complex<double> &rhs) {
return 0;
}
};
template <typename T>
struct AddFunc {
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
__device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const Complex<T> &rhs) {
return (lhs + rhs);
}
};
template <typename T>
struct MulFunc {
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs * rhs); }
__device__ __host__ __forceinline__ Complex<T> operator()(const Complex<T> &lhs, const Complex<T> &rhs) {
return (lhs * rhs);
}
};
template <typename DataType>
DataType max_val_init() {
return std::numeric_limits<DataType>::max();
}
template <>
half max_val_init() {
return 65504; // Max value for Half
}
template <typename DataType>
DataType min_val_init() {
return std::numeric_limits<DataType>::lowest();
}
template <>
half min_val_init() {
return -65504; // Lowest finite value for half
}
template <typename DataType>
DataType get_default_value(size_t op) {
return static_cast<DataType>(0);
}
template <>
half get_default_value(size_t op) {
return op == 0 ? -65504 : 65504;
}
template <>
float get_default_value(size_t op) {
return op == 0 ? std::numeric_limits<float>::lowest() : std::numeric_limits<float>::max();
}
template <>
double get_default_value(size_t op) {
return op == 0 ? std::numeric_limits<double>::lowest() : std::numeric_limits<double>::max();
}
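// Builds the segment start table from the (ascending) segment_ids: after this kernel, segment_pos_ptr[k] is the
// index of the first input row whose segment id is >= k, so the rows of output segment k are
// [segment_pos_ptr[k], segment_pos_ptr[k + 1]). Note: `segment_size` here is the number of segment ids.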
template <typename IndexType>
__global__ void CalSegmentPos(const IndexType *segment_ids_ptr, size_t *segment_pos_ptr, const size_t segment_size) {
for (size_t pos = threadIdx.x + blockIdx.x * blockDim.x; pos <= segment_size; pos += blockDim.x * gridDim.x) {
IndexType temp =
(segment_size > (segment_ids_ptr[segment_size - 1]) + 1) ? segment_size : (segment_ids_ptr[segment_size - 1] + 1);
IndexType begin_pos = (pos == 0) ? 0 : (segment_ids_ptr[pos - 1] + 1);
IndexType end_pos = (pos != segment_size) ? segment_ids_ptr[pos] : temp;
const IndexType max_size = static_cast<IndexType>(temp);
const IndexType min_size = IndexType(0);
begin_pos = max(min_size, min(max_size, begin_pos));
end_pos = max(min_size, min(max_size, end_pos));
for (IndexType j = begin_pos; j <= end_pos; ++j) {
segment_pos_ptr[j] = pos;
}
}
}
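// Tree reduction over the y dimension of the thread block in shared memory. Requires blockDim.y to be a power of
// two and a dynamic shared-memory allocation of blockDim.x * blockDim.y * sizeof(DataType) bytes, which is how
// CalSegmentCombination configures the launch below.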
template <typename DataType, typename Func>
__device__ DataType ReduceWithinBlock(const DataType &value) {
// Refer to reduce3 from Mark Harris, et al. 'Optimizing Parallel Reduction in CUDA'.
extern __shared__ __align__(16) char share_data[];
DataType *share_data_ptr = reinterpret_cast<DataType *>(share_data);
const unsigned int x = threadIdx.x;
const unsigned int y = threadIdx.y;
const unsigned int tid = y * blockDim.x + x;
share_data_ptr[tid] = value;
__syncthreads();
// Reduce over the y dimension of the block.
for (unsigned k = blockDim.y / 2; k > 0; k /= 2) {
if (y < k) {
share_data_ptr[tid] = Func()(share_data_ptr[tid], share_data_ptr[(y + k) * blockDim.x + x]);
}
__syncthreads();
}
return share_data_ptr[tid];
}
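// One (blockIdx.y, thread_x) pair reduces one output segment along one inner-dimension column: threads in the
// y direction accumulate blockDim.y rows at a time via ReduceWithinBlock, and thread y == 0 combines the partial
// results. Empty segments produce 0 for mean/sum (op 2/3), 1 for prod (op 4) and default_value otherwise.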
template <typename DataType, typename Func, typename IndexType>
__global__ void SegmentProcess(DataType *inp_ptr, DataType *out_ptr, size_t *seg_pos_ptr, const size_t inner_size,
const size_t outer_size, const size_t outer_class, size_t op, DataType init_K,
DataType default_value, IndexType *seg_id_ptr) {
for (size_t thread_x = threadIdx.x + blockIdx.x * blockDim.x; thread_x < inner_size;
thread_x += blockDim.x * gridDim.x) {
for (size_t block_idx_y = blockIdx.y; block_idx_y < outer_class; block_idx_y += gridDim.y) {
size_t begin_pos = seg_pos_ptr[block_idx_y];
size_t end_pos = seg_pos_ptr[block_idx_y + 1];
DataType res = init_K;
DataType cur_data = init_K;
for (size_t pos = begin_pos; pos < end_pos; pos += blockDim.y) {
size_t thread_y = pos + threadIdx.y;
cur_data = (thread_y < end_pos) ? inp_ptr[thread_y * inner_size + thread_x] : static_cast<DataType>(0);
cur_data = ReduceWithinBlock<DataType, Func>(cur_data);
if (threadIdx.y == 0) {
res = Func()(res, cur_data);
}
}
if (threadIdx.y == 0) {
if (op == 2) {
DataType segment_len = DataType(static_cast<double>(end_pos - begin_pos));
out_ptr[block_idx_y * inner_size + thread_x] =
(begin_pos >= end_pos) ? static_cast<DataType>(0) : res / segment_len;
} else if (op == 3) {
out_ptr[block_idx_y * inner_size + thread_x] = (begin_pos >= end_pos) ? static_cast<DataType>(0) : res;
} else if (op == 4) {
out_ptr[block_idx_y * inner_size + thread_x] = (begin_pos >= end_pos) ? static_cast<DataType>(1) : res;
} else {
out_ptr[block_idx_y * inner_size + thread_x] = (begin_pos >= end_pos) ? default_value : res;
}
}
}
}
}
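// Integer log2 helpers used to round the launch dimensions up to powers of two.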
inline int Log2Floor(uint32_t n) {
if (n == 0) return -1;
int log = 0;
for (int i = 4; i >= 0; --i) {
int shift = (1 << i);
uint32_t x = n >> shift;
if (x) {
n = x;
log += shift;
}
}
return log;
}
inline int Log2Floor64(uint64_t n) {
// Scan n first high 32 then low 32 bits.
const uint32_t high_32_bit = static_cast<uint32_t>(n >> 32);
if (high_32_bit == 0) {
return Log2Floor(static_cast<uint32_t>(n));
} else {
return 32 + Log2Floor(high_32_bit);
}
}
inline int Log2Ceil64(uint64_t n) {
int floor = Log2Floor64(n);
if (n == (n & ~(n - 1)))
return floor;
else
return floor + 1;
}
template <typename DataType, typename IndexType>
CUDA_LIB_EXPORT void CalSegmentCombination(DataType *inp_ptr, DataType *out_ptr, IndexType *seg_id_ptr,
size_t *seg_pos_ptr, size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class, uint32_t device_id,
cudaStream_t cuda_stream) {
// Get start position of each segment and set to segment_pos_ptr.
// The last element of segment_pos_ptr must equal to indices_size.
const unsigned int segment_size = outer_size + 1;
CalSegmentPos<<<CUDA_BLOCKS(device_id, segment_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
seg_id_ptr, seg_pos_ptr, outer_size);
const unsigned int max_grid_x = (1u << 31) - 1;
const unsigned int max_grid_y = (1u << 16) - 1;
const unsigned int max_block_x = 1024;
const unsigned int max_block_y = 64;
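// Pick a 2-D launch: x covers the inner dimension, y cooperatively reduces the rows of one segment.
// Both block extents are rounded to powers of two so that ReduceWithinBlock can halve blockDim.y.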
unsigned int inner_power2 = 1u << Log2Ceil64(inner_size);
unsigned int avg_reduce_size = UP_DIV(outer_size, outer_class);  // average number of rows per output segment
unsigned int avg_reduce_size_power2 = 1u << Log2Ceil64(avg_reduce_size);
unsigned int block_x = std::min(inner_power2, max_block_x);
unsigned int block_y = std::min(avg_reduce_size_power2, UP_DIV(max_block_y, block_x));
unsigned int grid_x = std::min(static_cast<unsigned int>(UP_DIV(inner_size, block_x)), max_grid_x);
unsigned int grid_y = std::min(segment_size, max_grid_y);
dim3 block(block_x, block_y);
dim3 grid(grid_x, grid_y);
unsigned int shared_memory_size = block_x * block_y * sizeof(DataType);
DataType init_K = std::numeric_limits<DataType>::lowest();
DataType default_value = get_default_value<DataType>(op);
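// Op codes follow SegmentOpsOptype in segment_ops_helper.h: 0 = max, 1 = min, 2 = mean, 3 = sum, 4 = prod.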
switch (op) {
case 0:
init_K = min_val_init<DataType>();
SegmentProcess<DataType, MaxFunc<DataType>><<<grid, block, shared_memory_size, cuda_stream>>>(
inp_ptr, out_ptr, seg_pos_ptr, inner_size, outer_size, outer_class, op, init_K, default_value, seg_id_ptr);
break;
case 1:
init_K = max_val_init<DataType>();
SegmentProcess<DataType, MinFunc<DataType>><<<grid, block, shared_memory_size, cuda_stream>>>(
inp_ptr, out_ptr, seg_pos_ptr, inner_size, outer_size, outer_class, op, init_K, default_value, seg_id_ptr);
break;
case 2:
case 3:
init_K = 0.0;
SegmentProcess<DataType, AddFunc<DataType>><<<grid, block, shared_memory_size, cuda_stream>>>(
inp_ptr, out_ptr, seg_pos_ptr, inner_size, outer_size, outer_class, op, init_K, default_value, seg_id_ptr);
break;
case 4:
init_K = 1.0;
SegmentProcess<DataType, MulFunc<DataType>><<<grid, block, shared_memory_size, cuda_stream>>>(
inp_ptr, out_ptr, seg_pos_ptr, inner_size, outer_size, outer_class, op, init_K, default_value, seg_id_ptr);
break;
default:
break;
}
}
template CUDA_LIB_EXPORT void CalSegmentCombination<float, int32_t>(float *inp_ptr, float *out_ptr,
int32_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<double, int32_t>(double *inp_ptr, double *out_ptr,
int32_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<half, int32_t>(half *inp_ptr, half *out_ptr, int32_t *seg_id_addr,
size_t *seg_pos_ptr, size_t op,
const size_t inner_size, const size_t outer_size,
const size_t outer_class, uint32_t device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<Complex<float>, int32_t>(
Complex<float> *inp_ptr, Complex<float> *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op,
const size_t inner_size, const size_t outer_size, const size_t outer_class, uint32_t device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<Complex<double>, int32_t>(
Complex<double> *inp_ptr, Complex<double> *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op,
const size_t inner_size, const size_t outer_size, const size_t outer_class, uint32_t device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int8_t, int32_t>(int8_t *inp_ptr, int8_t *out_ptr,
int32_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int16_t, int32_t>(int16_t *inp_ptr, int16_t *out_ptr,
int32_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int32_t, int32_t>(int32_t *inp_ptr, int32_t *out_ptr,
int32_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int64_t, int32_t>(int64_t *inp_ptr, int64_t *out_ptr,
int32_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint8_t, int32_t>(uint8_t *inp_ptr, uint8_t *out_ptr,
int32_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint16_t, int32_t>(
uint16_t *inp_ptr, uint16_t *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint32_t, int32_t>(
uint32_t *inp_ptr, uint32_t *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint64_t, int32_t>(
uint64_t *inp_ptr, uint64_t *out_ptr, int32_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<float, int64_t>(float *inp_ptr, float *out_ptr,
int64_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<double, int64_t>(double *inp_ptr, double *out_ptr,
int64_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<half, int64_t>(half *inp_ptr, half *out_ptr, int64_t *seg_id_addr,
size_t *seg_pos_ptr, size_t op,
const size_t inner_size, const size_t outer_size,
const size_t outer_class, uint32_t device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<Complex<float>, int64_t>(
Complex<float> *inp_ptr, Complex<float> *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op,
const size_t inner_size, const size_t outer_size, const size_t outer_class, uint32_t device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<Complex<double>, int64_t>(
Complex<double> *inp_ptr, Complex<double> *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op,
const size_t inner_size, const size_t outer_size, const size_t outer_class, uint32_t device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int8_t, int64_t>(int8_t *inp_ptr, int8_t *out_ptr,
int64_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int16_t, int64_t>(int16_t *inp_ptr, int16_t *out_ptr,
int64_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int32_t, int64_t>(int32_t *inp_ptr, int32_t *out_ptr,
int64_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<int64_t, int64_t>(int64_t *inp_ptr, int64_t *out_ptr,
int64_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint8_t, int64_t>(uint8_t *inp_ptr, uint8_t *out_ptr,
int64_t *seg_id_addr, size_t *seg_pos_ptr,
size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint16_t, int64_t>(
uint16_t *inp_ptr, uint16_t *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint32_t, int64_t>(
uint32_t *inp_ptr, uint32_t *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalSegmentCombination<uint64_t, int64_t>(
uint64_t *inp_ptr, uint64_t *out_ptr, int64_t *seg_id_addr, size_t *seg_pos_ptr, size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class, uint32_t device_id, cudaStream_t cuda_stream);

View File

@ -0,0 +1,26 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SEGMENT_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SEGMENT_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename DataType, typename IndexType>
CUDA_LIB_EXPORT void CalSegmentCombination(DataType *inp_ptr, DataType *out_ptr, IndexType *seg_id_addr,
size_t *seg_pos_ptr, size_t op, const size_t inner_size,
const size_t outer_size, const size_t outer_class,
uint32_t device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SEGMENT_IMPL_CUH_

View File

@ -0,0 +1,395 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <utility>
#include "plugin/device/gpu/kernel/math/segment_ops_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr auto Segment_max = "SegmentMax";
constexpr auto Segment_min = "SegmentMin";
constexpr auto Segment_mean = "SegmentMean";
constexpr auto Segment_sum = "SegmentSum";
constexpr auto Segment_prod = "SegmentProd";
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateSegmentOpsKernelPtr(const std::string &kernel_name,
const uint32_t &device_id) {
return std::make_unique<cukernel::SegmentOpsHelperGpuKernel<T, S>>(kernel_name, device_id);
}
using SegmentOpsPtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
template <typename T>
using Complex = mindspore::utils::Complex<T>;
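// Maps each Segment* primitive name to its supported (data type, segment_ids type) combinations together with a
// factory for the typed helper kernel. Complex64/complex128 inputs are registered only for mean, sum and prod.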
const std::map<std::string, std::vector<std::pair<KernelAttr, SegmentOpsPtrCreatorFunc>>> kernel_attr_map = {
{Segment_max,
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
CreateSegmentOpsKernelPtr<half, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
CreateSegmentOpsKernelPtr<float, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
CreateSegmentOpsKernelPtr<double, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
CreateSegmentOpsKernelPtr<int8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
CreateSegmentOpsKernelPtr<int16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
CreateSegmentOpsKernelPtr<int32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
CreateSegmentOpsKernelPtr<int64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
CreateSegmentOpsKernelPtr<uint8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
CreateSegmentOpsKernelPtr<uint16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
CreateSegmentOpsKernelPtr<uint32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
CreateSegmentOpsKernelPtr<uint64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
CreateSegmentOpsKernelPtr<half, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
CreateSegmentOpsKernelPtr<float, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
CreateSegmentOpsKernelPtr<double, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
CreateSegmentOpsKernelPtr<int8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
CreateSegmentOpsKernelPtr<int16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
CreateSegmentOpsKernelPtr<int32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
CreateSegmentOpsKernelPtr<int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
CreateSegmentOpsKernelPtr<uint8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
CreateSegmentOpsKernelPtr<uint16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
CreateSegmentOpsKernelPtr<uint32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
CreateSegmentOpsKernelPtr<uint64_t, int64_t>}}},
{Segment_min,
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
CreateSegmentOpsKernelPtr<half, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
CreateSegmentOpsKernelPtr<float, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
CreateSegmentOpsKernelPtr<double, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
CreateSegmentOpsKernelPtr<int8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
CreateSegmentOpsKernelPtr<int16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
CreateSegmentOpsKernelPtr<int32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
CreateSegmentOpsKernelPtr<int64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
CreateSegmentOpsKernelPtr<uint8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
CreateSegmentOpsKernelPtr<uint16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
CreateSegmentOpsKernelPtr<uint32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
CreateSegmentOpsKernelPtr<uint64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
CreateSegmentOpsKernelPtr<half, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
CreateSegmentOpsKernelPtr<float, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
CreateSegmentOpsKernelPtr<double, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
CreateSegmentOpsKernelPtr<int8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
CreateSegmentOpsKernelPtr<int16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
CreateSegmentOpsKernelPtr<int32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
CreateSegmentOpsKernelPtr<int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
CreateSegmentOpsKernelPtr<uint8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
CreateSegmentOpsKernelPtr<uint16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
CreateSegmentOpsKernelPtr<uint32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
CreateSegmentOpsKernelPtr<uint64_t, int64_t>}}},
{Segment_mean,
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
CreateSegmentOpsKernelPtr<half, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
CreateSegmentOpsKernelPtr<Complex<float>, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex128),
CreateSegmentOpsKernelPtr<Complex<double>, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
CreateSegmentOpsKernelPtr<float, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
CreateSegmentOpsKernelPtr<double, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
CreateSegmentOpsKernelPtr<int8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
CreateSegmentOpsKernelPtr<int16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
CreateSegmentOpsKernelPtr<int32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
CreateSegmentOpsKernelPtr<int64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
CreateSegmentOpsKernelPtr<uint8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
CreateSegmentOpsKernelPtr<uint16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
CreateSegmentOpsKernelPtr<uint32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
CreateSegmentOpsKernelPtr<uint64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
CreateSegmentOpsKernelPtr<half, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
CreateSegmentOpsKernelPtr<Complex<float>, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128),
CreateSegmentOpsKernelPtr<Complex<double>, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
CreateSegmentOpsKernelPtr<float, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
CreateSegmentOpsKernelPtr<double, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
CreateSegmentOpsKernelPtr<int8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
CreateSegmentOpsKernelPtr<int16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
CreateSegmentOpsKernelPtr<int32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
CreateSegmentOpsKernelPtr<int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
CreateSegmentOpsKernelPtr<uint8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
CreateSegmentOpsKernelPtr<uint16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
CreateSegmentOpsKernelPtr<uint32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
CreateSegmentOpsKernelPtr<uint64_t, int64_t>}}},
{Segment_sum,
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
CreateSegmentOpsKernelPtr<half, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
CreateSegmentOpsKernelPtr<Complex<float>, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex128),
CreateSegmentOpsKernelPtr<Complex<double>, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
CreateSegmentOpsKernelPtr<float, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
CreateSegmentOpsKernelPtr<double, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
CreateSegmentOpsKernelPtr<int8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
CreateSegmentOpsKernelPtr<int16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
CreateSegmentOpsKernelPtr<int32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
CreateSegmentOpsKernelPtr<int64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
CreateSegmentOpsKernelPtr<uint8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
CreateSegmentOpsKernelPtr<uint16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
CreateSegmentOpsKernelPtr<uint32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
CreateSegmentOpsKernelPtr<uint64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
CreateSegmentOpsKernelPtr<half, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
CreateSegmentOpsKernelPtr<Complex<float>, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128),
CreateSegmentOpsKernelPtr<Complex<double>, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
CreateSegmentOpsKernelPtr<float, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
CreateSegmentOpsKernelPtr<double, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
CreateSegmentOpsKernelPtr<int8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
CreateSegmentOpsKernelPtr<int16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
CreateSegmentOpsKernelPtr<int32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
CreateSegmentOpsKernelPtr<int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
CreateSegmentOpsKernelPtr<uint8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
CreateSegmentOpsKernelPtr<uint16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
CreateSegmentOpsKernelPtr<uint32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
CreateSegmentOpsKernelPtr<uint64_t, int64_t>}}},
{Segment_prod,
{{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
CreateSegmentOpsKernelPtr<half, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
CreateSegmentOpsKernelPtr<Complex<float>, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex128),
CreateSegmentOpsKernelPtr<Complex<double>, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
CreateSegmentOpsKernelPtr<float, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
CreateSegmentOpsKernelPtr<double, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
CreateSegmentOpsKernelPtr<int8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
CreateSegmentOpsKernelPtr<int16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
CreateSegmentOpsKernelPtr<int32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
CreateSegmentOpsKernelPtr<int64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
CreateSegmentOpsKernelPtr<uint8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
CreateSegmentOpsKernelPtr<uint16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
CreateSegmentOpsKernelPtr<uint32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
CreateSegmentOpsKernelPtr<uint64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
CreateSegmentOpsKernelPtr<half, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
CreateSegmentOpsKernelPtr<Complex<float>, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128),
CreateSegmentOpsKernelPtr<Complex<double>, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
CreateSegmentOpsKernelPtr<float, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
CreateSegmentOpsKernelPtr<double, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
CreateSegmentOpsKernelPtr<int8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
CreateSegmentOpsKernelPtr<int16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
CreateSegmentOpsKernelPtr<int32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
CreateSegmentOpsKernelPtr<int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
CreateSegmentOpsKernelPtr<uint8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
CreateSegmentOpsKernelPtr<uint16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
CreateSegmentOpsKernelPtr<uint32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
CreateSegmentOpsKernelPtr<uint64_t, int64_t>}}}}; // kernel_attr_map
} // namespace
bool SegmentOpsGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->GetPrim()->name();
constexpr size_t inputs_num = 2;
constexpr size_t outputs_num = 1;
CHECK_KERNEL_INPUTS_NUM(inputs.size(), inputs_num, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), outputs_num, kernel_name_);
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << tensor_attr << ".";
return false;
}
helper_ptr_ = std::move(kernel_attr_map.at(kernel_type_)[index].second(kernel_name_, device_id_));
return true;
}
bool SegmentOpsGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
return false;
}
return true;
}
int SegmentOpsGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
for (const auto &input : inputs) {
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
for (const auto &output : outputs) {
auto output_shape = output->GetShapeVector();
if (!IsValidShape(output_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<int64_t>> output_shapes;
std::vector<int64_t> inp_shape = inputs[0]->GetShapeVector();
std::vector<int64_t> segment_ids_shape = inputs[1]->GetShapeVector();
std::vector<int64_t> out_shape = outputs[0]->GetShapeVector();
input_shapes.emplace_back(inp_shape);
input_shapes.emplace_back(segment_ids_shape);
output_shapes.emplace_back(out_shape);
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
return KRET_RESIZE_FAILED;
}
input_size_list_ = helper_ptr_->GetInputSizeList();
output_size_list_ = helper_ptr_->GetOutputSizeList();
workspace_size_list_ = helper_ptr_->GetWorkSizeList();
return KRET_OK;
}
std::vector<KernelAttr> SegmentOpsGpuKernelMod::GetOpSupport() {
auto iter = kernel_attr_map.find(kernel_type_);
if (iter == kernel_attr_map.end()) {
MS_LOG(ERROR) << "For 'SegmentOps', only these types are supported: " << kernel::Map2Str(kernel_attr_map)
<< ", but got " << kernel_type_;
return std::vector<KernelAttr>{};
}
std::vector<KernelAttr> support_list;
(void)std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, SegmentOpsPtrCreatorFunc> &item) { return item.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SegmentMax,
[]() { return std::make_shared<SegmentOpsGpuKernelMod>(Segment_max); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SegmentMin,
[]() { return std::make_shared<SegmentOpsGpuKernelMod>(Segment_min); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SegmentMean,
[]() { return std::make_shared<SegmentOpsGpuKernelMod>(Segment_mean); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SegmentSum,
[]() { return std::make_shared<SegmentOpsGpuKernelMod>(Segment_sum); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SegmentProd,
[]() { return std::make_shared<SegmentOpsGpuKernelMod>(Segment_prod); });
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,54 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_SEGMENT_OPS_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_SEGMENT_OPS_GPU_KERNEL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/segment_ops_helper.h"
namespace mindspore {
namespace kernel {
class SegmentOpsGpuKernelMod : public NativeGpuKernelMod {
public:
explicit SegmentOpsGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
~SegmentOpsGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
std::string kernel_type_{"Unknown"};
};
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_SEGMENT_OPS_GPU_KERNEL_H_

View File

@ -120,7 +120,7 @@ TypePtr SegmentMaxInferType(const PrimitivePtr &primitive, const std::vector<Abs
}
} // namespace
MIND_API_BASE_IMPL(SegmentMax, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(SegmentMax, BaseOperator);
AbstractBasePtr SegmentMaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {

View File

@ -123,7 +123,7 @@ TypePtr SegmentMeanInferType(const PrimitivePtr &primitive, const std::vector<Ab
}
} // namespace
MIND_API_BASE_IMPL(SegmentMean, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(SegmentMean, BaseOperator);
AbstractBasePtr SegmentMeanInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {

View File

@ -120,7 +120,7 @@ TypePtr SegmentMinInferType(const PrimitivePtr &primitive, const std::vector<Abs
}
} // namespace
MIND_API_BASE_IMPL(SegmentMin, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(SegmentMin, BaseOperator);
AbstractBasePtr SegmentMinInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {

View File

@ -123,7 +123,7 @@ TypePtr SegmentProdInferType(const PrimitivePtr &primitive, const std::vector<Ab
}
} // namespace
MIND_API_BASE_IMPL(SegmentProd, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(SegmentProd, BaseOperator);
AbstractBasePtr SegmentProdInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {

View File

@ -120,7 +120,7 @@ TypePtr SegmentSumInferType(const PrimitivePtr &primitive, const std::vector<Abs
}
} // namespace
MIND_API_BASE_IMPL(SegmentSum, PrimitiveC, BaseOperator);
MIND_API_OPERATOR_IMPL(SegmentSum, BaseOperator);
AbstractBasePtr SegmentSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {

View File

@ -7568,7 +7568,7 @@ class SegmentMax(Primitive):
ValueError: If the values of `segment_ids` are not sorted in ascending order.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float64)
@ -7618,7 +7618,7 @@ class SegmentMin(Primitive):
ValueError: If the values of `segment_ids` are not sorted in ascending order.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float64)
@ -7672,7 +7672,7 @@ class SegmentSum(Primitive):
ValueError: If the values of `segment_ids` are not sorted in ascending order.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float64)
@ -7873,7 +7873,7 @@ class SegmentMean(Primitive):
ValueError: If the values of `segment_ids` are not sorted in ascending order.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor([[1, 2, 3], [1, 2, 3], [7, 8, 9]], mstype.float64)
@ -7927,7 +7927,7 @@ class SegmentProd(Primitive):
ValueError: If the values of `segment_ids` are not sorted in ascending order.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.float64)

View File

@ -0,0 +1,176 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations.array_ops import SegmentMax, SegmentMin, SegmentMean, SegmentSum, SegmentProd
from mindspore.nn import Cell
import mindspore.common.dtype as mstype
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
class SegmentMaxNet(Cell):
def __init__(self):
super().__init__()
self.segmentmax = SegmentMax()
def construct(self, x, segment_ids):
return self.segmentmax(x, segment_ids)
class SegmentMinNet(Cell):
def __init__(self):
super().__init__()
self.segmentmin = SegmentMin()
def construct(self, x, segment_ids):
return self.segmentmin(x, segment_ids)
class SegmentMeanNet(Cell):
def __init__(self):
super().__init__()
self.segmentmean = SegmentMean()
def construct(self, x, segment_ids):
return self.segmentmean(x, segment_ids)
class SegmentSumNet(Cell):
def __init__(self):
super().__init__()
self.segmentsum = SegmentSum()
def construct(self, x, segment_ids):
return self.segmentsum(x, segment_ids)
class SegmentProdNet(Cell):
def __init__(self):
super().__init__()
self.segmentprod = SegmentProd()
def construct(self, x, segment_ids):
return self.segmentprod(x, segment_ids)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_segment_max_fp():
"""
Feature: SegmentMax operator.
Description: test cases for SegmentMax operator.
Expectation: the result matches the expectation.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x = Tensor([1, 2, 3], mstype.int32)
segment_ids = Tensor([0, 6, 6], mstype.int32)
net = SegmentMaxNet()
expect = np.array([1, 0, 0, 0, 0, 0, 3]).astype(np.int32)
output_gr = net(input_x, segment_ids).asnumpy()
np.testing.assert_array_almost_equal(output_gr, expect)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
output_py = net(input_x, segment_ids).asnumpy()
np.testing.assert_almost_equal(output_py, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_segment_min_fp():
"""
Feature: SegmentMin operator.
Description: test cases for SegmentMin operator.
Expectation: the result matches the expectation.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x = Tensor([1, 2, 3, 4], mstype.int32)
segment_ids = Tensor([0, 0, 1, 5], mstype.int32)
net = SegmentMinNet()
expect = np.array([1, 3, 0, 0, 0, 4]).astype(np.int32)
output_gr = net(input_x, segment_ids).asnumpy()
np.testing.assert_array_almost_equal(output_gr, expect)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
output_py = net(input_x, segment_ids).asnumpy()
np.testing.assert_almost_equal(output_py, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_segment_sum_fp():
"""
Feature: SegmentSum operator.
Description: test cases for SegmentSum operator.
Expectation: the result matches the expectation.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x = Tensor([1 + 2j, 2 + 2j, 3 + 2j], mstype.complex64)
segment_ids = Tensor([0, 0, 2], mstype.int32)
net = SegmentSumNet()
expect = np.array([3 + 4j, 0, 3 + 2j]).astype(np.complex64)
output_gr = net(input_x, segment_ids).asnumpy()
np.testing.assert_array_almost_equal(output_gr, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_segment_mean_fp():
"""
Feature: SegmentMean operator.
Description: test cases for SegmentMean operator.
Expectation: the result matches the expectation.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x = Tensor([2, 2, 3, 4], mstype.float32)
segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
net = SegmentMeanNet()
expect = np.array([2, 3, 4]).astype(np.float32)
output_gr = net(input_x, segment_ids).asnumpy()
np.testing.assert_array_almost_equal(output_gr, expect)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
output_py = net(input_x, segment_ids).asnumpy()
np.testing.assert_almost_equal(output_py, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_segment_prod_fp():
"""
Feature: SegmentProd operator.
Description: test cases for SegmentProd operator.
Expectation: the result matches the expectation.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x = Tensor([1, 2, 3, 4], mstype.float32)
segment_ids = Tensor([0, 0, 1, 2], mstype.int32)
net = SegmentProdNet()
expect = np.array([2, 3, 4]).astype(np.float32)
output_gr = net(input_x, segment_ids).asnumpy()
np.testing.assert_array_almost_equal(output_gr, expect)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
output_py = net(input_x, segment_ids).asnumpy()
np.testing.assert_almost_equal(output_py, expect)