!33422 Add GPU implementation of Cummin operator.

Merge pull request !33422 from hezhenhao1/add_cumop
This commit is contained in:
i-robot 2022-04-25 01:22:55 +00:00 committed by Gitee
commit dfc0d8c43a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 550 additions and 1 deletions

View File

@ -0,0 +1,190 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cum_op_impl.cuh"
#include <thrust/transform.h>
#include <thrust/functional.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <algorithm>
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
template <typename T>
__device__ bool IsNan(const T &x) {
return isnan(x);
}
__device__ bool IsNan(const half &x) { return __hisnan(x); }
template <typename T, typename OP>
struct binary_op {
const T *input_ptr_;
size_t axis_inner_size_;
size_t axis_size_;
size_t inner_size_;
OP op;
__thrust_exec_check_disable__ __device__ size_t operator()(const size_t &lhs, const size_t &rhs) const {
if (rhs % axis_size_) {
size_t batch_idx = rhs / axis_size_;
size_t axis_idx = rhs - batch_idx * axis_size_;
size_t outer_idx = batch_idx / inner_size_;
size_t inner_idx = batch_idx - outer_idx * inner_size_;
size_t fix_part = outer_idx * axis_inner_size_ + inner_idx;
size_t lhs_idx = fix_part + lhs * inner_size_;
size_t rhs_idx = fix_part + axis_idx * inner_size_;
return IsNan(input_ptr_[lhs_idx]) || op(input_ptr_[lhs_idx], input_ptr_[rhs_idx]) ? lhs : axis_idx;
} else {
return 0;
}
}
};
template <typename T, typename S>
__global__ void DecodeKernel(const T *input_ptr, const size_t *workspace_ptr, T *value_ptr, S *index_ptr,
size_t element_size, size_t axis_inner_size, size_t axis_size, size_t inner_size) {
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < element_size; tid += blockDim.x * gridDim.x) {
size_t batch_idx = tid / axis_size;
size_t axis_idx = tid - batch_idx * axis_size;
size_t outer_idx = batch_idx / inner_size;
size_t inner_idx = batch_idx - outer_idx * inner_size;
size_t fix_part = outer_idx * axis_inner_size + inner_idx;
size_t real_idx = fix_part + axis_idx * inner_size;
size_t cum_idx = fix_part + workspace_ptr[tid] * inner_size;
value_ptr[real_idx] = input_ptr[cum_idx];
index_ptr[real_idx] = workspace_ptr[tid];
}
}
template <typename T, typename S>
void CumOp(enum CumOpType op_type, const T *input_ptr, size_t *workspace_ptr, T *value_ptr, S *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size, cudaStream_t cuda_stream) {
// Cummin/Cummax cuda algorithm:
// 1. Generate a sequence from 0 to element_size-1;
// 2. Using thrust:inclusive_scan to get the cumulative maximum/minimum result of transposed array.
// Note that 1. Segmentation of array is done within binary_op of inclusive_scan;
// 2. it's not necessary to directly transpose the original array, but using the mapping rule;
// 3. Restore the transposed array using DecodeKernel, and also with the help of mapping rule.
auto device = thrust::cuda::par.on(cuda_stream);
auto thrust_ptr = thrust::device_pointer_cast(workspace_ptr);
thrust::sequence(device, thrust_ptr, thrust_ptr + element_size);
auto axis_inner_size = axis_size * inner_size;
switch (op_type) {
case CUMMIN: {
binary_op<T, thrust::less<T>> op{input_ptr, axis_inner_size, axis_size, inner_size};
thrust::inclusive_scan(device, thrust_ptr, thrust_ptr + element_size, thrust_ptr, op);
break;
}
case CUMMAX: {
binary_op<T, thrust::greater<T>> op{input_ptr, axis_inner_size, axis_size, inner_size};
thrust::inclusive_scan(device, thrust_ptr, thrust_ptr + element_size, thrust_ptr, op);
break;
}
default:
break;
}
DecodeKernel<<<GET_BLOCKS(element_size), GET_THREADS, 0, cuda_stream>>>(
input_ptr, workspace_ptr, value_ptr, index_ptr, element_size, axis_inner_size, axis_size, inner_size);
}
template CUDA_LIB_EXPORT void CumOp<int8_t, int32_t>(enum CumOpType op_type, const int8_t *input_ptr,
size_t *workspace_ptr, int8_t *value_ptr, int32_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<int16_t, int32_t>(enum CumOpType op_type, const int16_t *input_ptr,
size_t *workspace_ptr, int16_t *value_ptr, int32_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<int32_t, int32_t>(enum CumOpType op_type, const int32_t *input_ptr,
size_t *workspace_ptr, int32_t *value_ptr, int32_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<int64_t, int32_t>(enum CumOpType op_type, const int64_t *input_ptr,
size_t *workspace_ptr, int64_t *value_ptr, int32_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<uint8_t, int32_t>(enum CumOpType op_type, const uint8_t *input_ptr,
size_t *workspace_ptr, uint8_t *value_ptr, int32_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<uint16_t, int32_t>(enum CumOpType op_type, const uint16_t *input_ptr,
size_t *workspace_ptr, uint16_t *value_ptr, int32_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<uint32_t, int32_t>(enum CumOpType op_type, const uint32_t *input_ptr,
size_t *workspace_ptr, uint32_t *value_ptr, int32_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<uint64_t, int32_t>(enum CumOpType op_type, const uint64_t *input_ptr,
size_t *workspace_ptr, uint64_t *value_ptr, int32_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<half, int32_t>(enum CumOpType op_type, const half *input_ptr, size_t *workspace_ptr,
half *value_ptr, int32_t *index_ptr, size_t element_size,
size_t axis_size, size_t inner_size, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<float, int32_t>(enum CumOpType op_type, const float *input_ptr,
size_t *workspace_ptr, float *value_ptr, int32_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<double, int32_t>(enum CumOpType op_type, const double *input_ptr,
size_t *workspace_ptr, double *value_ptr, int32_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<int8_t, int64_t>(enum CumOpType op_type, const int8_t *input_ptr,
size_t *workspace_ptr, int8_t *value_ptr, int64_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<int16_t, int64_t>(enum CumOpType op_type, const int16_t *input_ptr,
size_t *workspace_ptr, int16_t *value_ptr, int64_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<int32_t, int64_t>(enum CumOpType op_type, const int32_t *input_ptr,
size_t *workspace_ptr, int32_t *value_ptr, int64_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<int64_t, int64_t>(enum CumOpType op_type, const int64_t *input_ptr,
size_t *workspace_ptr, int64_t *value_ptr, int64_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<uint8_t, int64_t>(enum CumOpType op_type, const uint8_t *input_ptr,
size_t *workspace_ptr, uint8_t *value_ptr, int64_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<uint16_t, int64_t>(enum CumOpType op_type, const uint16_t *input_ptr,
size_t *workspace_ptr, uint16_t *value_ptr, int64_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<uint32_t, int64_t>(enum CumOpType op_type, const uint32_t *input_ptr,
size_t *workspace_ptr, uint32_t *value_ptr, int64_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<uint64_t, int64_t>(enum CumOpType op_type, const uint64_t *input_ptr,
size_t *workspace_ptr, uint64_t *value_ptr, int64_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<half, int64_t>(enum CumOpType op_type, const half *input_ptr, size_t *workspace_ptr,
half *value_ptr, int64_t *index_ptr, size_t element_size,
size_t axis_size, size_t inner_size, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<float, int64_t>(enum CumOpType op_type, const float *input_ptr,
size_t *workspace_ptr, float *value_ptr, int64_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CumOp<double, int64_t>(enum CumOpType op_type, const double *input_ptr,
size_t *workspace_ptr, double *value_ptr, int64_t *index_ptr,
size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,29 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUM_OP_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUM_OP_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
enum CumOpType { CUMMIN = 0, CUMMAX, CUM_OP_INVALID_TYPE = 255 };
template <typename T, typename S>
CUDA_LIB_EXPORT void CumOp(enum CumOpType op_type, const T *input_ptr, size_t *workspace_ptr, T *value_ptr,
S *index_ptr, size_t element_size, size_t axis_size, size_t inner_size,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CUM_OP_IMPL_CUH_

View File

@ -0,0 +1,175 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/cum_op_gpu_kernel.h"
#include <functional>
#include <algorithm>
#include "mindspore/core/abstract/utils.h"
#include "kernel/common_utils.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr int kCumInputsNum = 1;
constexpr int kCumOutputsNum = 2;
constexpr char AXIS[] = "axis";
static const std::map<std::string, CumOpType> kCumOpTypeMap = {
{"Cummin", CUMMIN},
};
} // namespace
void CumOpGpuKernelMod::ResetResource() noexcept {
inner_size_ = 1;
outer_size_ = 1;
axis_size_ = 1;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
bool CumOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->GetPrim()->name();
if (kernel_name_ != kernel_type_) {
MS_LOG(EXCEPTION) << "Need to be " << kernel_type_ << ", but got kernel name as " << kernel_name_;
}
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCumInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCumOutputsNum, kernel_name_);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
}
auto iter = kCumOpTypeMap.find(kernel_name_);
if (iter == kCumOpTypeMap.end()) {
MS_LOG(EXCEPTION) << "Only support these cum operators: " << Map2Str(kCumOpTypeMap) << " currently, but got "
<< kernel_name_;
}
cum_op_type_ = iter->second;
t_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).first);
s_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex1).first);
kernel_func_ = func_list_[index].second;
return true;
}
bool CumOpGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &others) {
ResetResource();
std::vector<int64_t> input_shape = inputs[kIndex0]->GetShapeVector();
auto rank = SizeToLong(input_shape.size());
auto axis_input = GetValue<int64_t>(base_operator->GetAttr(AXIS));
auto axis = axis_input < 0 ? LongToSize(axis_input + rank) : LongToSize(axis_input);
for (size_t i = 0; i < input_shape.size(); i++) {
if (i < axis) {
outer_size_ *= input_shape.at(i);
} else if (i > axis) {
inner_size_ *= input_shape.at(i);
} else {
axis_size_ = input_shape.at(i);
}
}
element_size_ = outer_size_ * inner_size_ * axis_size_;
if (!element_size_) {
return true;
}
input_size_list_.push_back(element_size_ * t_size_);
output_size_list_.push_back(element_size_ * t_size_);
output_size_list_.push_back(element_size_ * s_size_);
workspace_size_list_.push_back(element_size_ * sizeof(size_t));
return true;
}
template <typename T, typename S>
bool CumOpGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (!element_size_) {
return true;
}
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
auto input_ptr = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
auto value_ptr = reinterpret_cast<T *>(outputs.at(kIndex0)->addr);
auto index_ptr = reinterpret_cast<S *>(outputs.at(kIndex1)->addr);
auto workspace_ptr = reinterpret_cast<size_t *>(workspace.at(kIndex0)->addr);
CumOp(cum_op_type_, input_ptr, workspace_ptr, value_ptr, index_ptr, element_size_, axis_size_, inner_size_,
cuda_stream);
return true;
}
std::vector<std::pair<KernelAttr, CumOpGpuKernelMod::CumOpLaunchFunc>> CumOpGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
&CumOpGpuKernelMod::LaunchKernel<int8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
&CumOpGpuKernelMod::LaunchKernel<int16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&CumOpGpuKernelMod::LaunchKernel<int32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
&CumOpGpuKernelMod::LaunchKernel<int64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
&CumOpGpuKernelMod::LaunchKernel<uint8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32),
&CumOpGpuKernelMod::LaunchKernel<uint16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
&CumOpGpuKernelMod::LaunchKernel<uint32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32),
&CumOpGpuKernelMod::LaunchKernel<uint64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
&CumOpGpuKernelMod::LaunchKernel<half, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
&CumOpGpuKernelMod::LaunchKernel<float, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32),
&CumOpGpuKernelMod::LaunchKernel<double, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
&CumOpGpuKernelMod::LaunchKernel<int8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64),
&CumOpGpuKernelMod::LaunchKernel<int16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
&CumOpGpuKernelMod::LaunchKernel<int32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&CumOpGpuKernelMod::LaunchKernel<int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
&CumOpGpuKernelMod::LaunchKernel<uint8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64),
&CumOpGpuKernelMod::LaunchKernel<uint16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
&CumOpGpuKernelMod::LaunchKernel<uint32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64),
&CumOpGpuKernelMod::LaunchKernel<uint64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
&CumOpGpuKernelMod::LaunchKernel<half, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
&CumOpGpuKernelMod::LaunchKernel<float, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
&CumOpGpuKernelMod::LaunchKernel<double, int64_t>}};
std::vector<KernelAttr> CumOpGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, CumOpGpuKernelMod::CumOpLaunchFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG_WITH_NAME_PARAM(NativeGpuKernelMod, Cummin, CumOpGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,74 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CUM_OP_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CUM_OP_GPU_KERNEL_H_
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include <map>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cum_op_impl.cuh"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
constexpr auto kUnKnown = "UnKnown";
class CumOpGpuKernelMod : public NativeGpuKernelMod {
public:
CumOpGpuKernelMod() = default;
explicit CumOpGpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
~CumOpGpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
std::vector<KernelAttr> GetOpSupport() override;
private:
void ResetResource() noexcept;
template <typename T, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
using CumOpLaunchFunc = std::function<bool(CumOpGpuKernelMod *, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
static std::vector<std::pair<KernelAttr, CumOpLaunchFunc>> func_list_;
CumOpType cum_op_type_;
CumOpLaunchFunc kernel_func_;
size_t t_size_{0}; // Equal to sizeof(T).
size_t s_size_{0}; // Equal to sizeof(S).
size_t inner_size_{1};
size_t outer_size_{1};
size_t axis_size_{1};
size_t element_size_{1};
std::string kernel_type_{kUnKnown};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_CUM_OP_GPU_KERNEL_H_

View File

@ -885,7 +885,7 @@ def cummin(x, axis):
ValueError:If 'axis' is out the range of [-len(`input_x`.shape) to len(`input_x`.shape) - 1]
Supported Platforms:
``Ascend``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore import Tensor, ops

View File

@ -0,0 +1,81 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.context as context
import mindspore.ops as ops
from mindspore import Tensor
def cummin_compare(x, expected, axis, data_type):
x = np.array(x).astype(data_type)
expected = (np.array(expected[0]).astype(data_type), np.array(expected[1]).astype(data_type))
# Pynative
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
output = ops.cummin(Tensor(x), axis=axis)
assert np.allclose(output[0].asnumpy(), expected[0], equal_nan=True)
assert np.allclose(output[1].asnumpy(), expected[1])
# Graph
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
output = ops.cummin(Tensor(x), axis=axis)
assert np.allclose(output[0].asnumpy(), expected[0], equal_nan=True)
assert np.allclose(output[1].asnumpy(), expected[1])
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize("data_type", [np.int32, np.float16, np.float32])
def test_cummin_multi_dims(data_type):
"""
Feature: Op Cummin
Description: test Cummin operator with multiple dimension.
Expectation: the result match expectation.
"""
axis = 1
x = [[[14, 19, 18, 11, 6], [1, 4, 18, 6, 1], [15, 13, 12, 9, 19]],
[[16, 16, 17, 10, 15], [9, 7, 10, 9, 4], [6, 14, 16, 3, 2]],
[[1, 13, 15, 1, 6], [20, 6, 8, 19, 19], [3, 14, 20, 18, 19]],
[[20, 1, 14, 9, 3], [13, 11, 2, 17, 14], [0, 15, 13, 7, 10]]]
cummin_output = (
[[[14, 19, 18, 11, 6], [1, 4, 18, 6, 1], [1, 4, 12, 6, 1]],
[[16, 16, 17, 10, 15], [9, 7, 10, 9, 4], [6, 7, 10, 3, 2]],
[[1, 13, 15, 1, 6], [1, 6, 8, 1, 6], [1, 6, 8, 1, 6]], [[20, 1, 14, 9, 3], [13, 1, 2, 9, 3], [0, 1, 2, 7, 3]]],
[[[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [1, 1, 2, 1, 1]], [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 1, 1, 2, 2]],
[[0, 0, 0, 0, 0], [0, 1, 1, 0, 0], [0, 1, 1, 0, 0]], [[0, 0, 0, 0, 0], [1, 0, 1, 0, 0], [2, 0, 1, 2, 0]]])
cummin_compare(x, cummin_output, axis, data_type)
@pytest.mark.level0
@pytest.mark.env_onecard
@pytest.mark.platform_x86_gpu_training
@pytest.mark.parametrize("data_type", [np.float16, np.float32])
def test_cummin_nan(data_type):
"""
Feature: Op Cummin
Description: test Cummin operator with nan input.
Expectation: the result match expectation.
"""
inf = float('inf')
nan = float('nan')
axis = 0
x = [4, inf, 1.5, -inf, 0, nan, 1]
cummin_output = ([4, 4, 1.5, -inf, -inf, nan, nan], [0, 0, 2, 3, 3, 5, 5])
cummin_compare(x, cummin_output, axis, data_type)