fix KernelTensor adaptation bugs in pooling op GPU kernels

fix pooling op CPU kernel bugs
lilinjie 2023-10-30 14:18:47 +08:00 committed by chenfei
parent fad3cc3031
commit aee4ae3afc
24 changed files with 153 additions and 78 deletions
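The common thread in the CPU pooling kernel changes below is where the pooling parameters come from. Under the KernelTensor adaptation, AvgPool is a 5-input op (x, kernel_size, strides, pad_mode, format) and reads those values from its input KernelTensors, while the other pooling ops still read them from primitive attributes, whose string-valued pad_mode and format must be converted to enum values. A minimal sketch of that pattern follows; it assumes the member fields and helpers named in the diff (kernel_include_nc, strides_include_nc, pad_mode_, format_, ops::PadModeStringToInt, ops::FormatStringToInt) and the usual kernel-mod class context, and is illustrative rather than the exact kernel code.

// Sketch only: assumes the surrounding kernel-mod class provides kernel_name_,
// primitive_, and the member fields assigned below; input slots 1..4 follow the diff.
void InitPoolingParamsSketch(const std::vector<KernelTensor *> &inputs) {
  if (kernel_name_ == kAvgPoolOpName) {
    // Dynamic-input form: AvgPool carries its parameters as input tensors.
    kernel_include_nc = inputs[1]->GetValue<std::vector<int64_t>>().value();
    strides_include_nc = inputs[2]->GetValue<std::vector<int64_t>>().value();
    pad_mode_ = inputs[3]->GetValue<PadMode>().value();
    format_ = inputs[4]->GetValue<Format>().value();
  } else {
    // Attribute form: pad_mode and format are stored as strings on the primitive
    // and are mapped to their enum values before use.
    kernel_include_nc = GetValue<std::vector<int64_t>>(primitive_->GetAttr(KERNEL_SIZE));
    strides_include_nc = GetValue<std::vector<int64_t>>(primitive_->GetAttr(STRIDES));
    pad_mode_ = static_cast<mindspore::PadMode>(
      ops::PadModeStringToInt(GetValue<std::string>(primitive_->GetAttr(PAD_MODE))));
    format_ = static_cast<mindspore::Format>(
      ops::FormatStringToInt(GetValue<std::string>(primitive_->GetAttr(FORMAT))));
  }
}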

View File

@ -294,6 +294,8 @@ mindspore/mindspore/core/ops/sparse_to_dense_v2.cc:mindspore::ops::SparseToDense
mindspore/mindspore/core/ops/ops_func_impl/div.cc:mindspore::ops::DivFrontendFuncImpl::InferValue
mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/cumprod_gpu_kernel.cc:mindspore::kernel::CumProdGpuKernelMod::GetFuncList
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/cumprod_cpu_kernel.cc:mindspore::kernel::CumProdCpuKernelMod::GetFuncList
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/pooling_cpu_kernel.cc:mindspore::kernel::PoolingCpuKernelMod::InitPoolingFields
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/pooling_cpu_kernel_nnacl.cc:mindspore::kernel::PoolingCpuKernelNnaclMod::Resize
# AICPU migration
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/bias_add_grad.cc:aicpu::BiasAddGradCpuKernel::BiasAddGradCompute

View File

@ -16,6 +16,7 @@
#include "plugin/device/cpu/kernel/max_pool_grad_with_argmax_cpu_kernel.h"
#include <algorithm>
#include <string>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/core/ops/grad/max_pool_grad_with_argmax.h"
@ -39,7 +40,8 @@ bool MaxPoolGradWithArgmaxCpuKernelMod::Init(const std::vector<KernelTensor *> &
"but got the window height: "
<< stride_height_ << ", and the window width: " << stride_height_;
}
pad_mode_ = PadMode(GetValue<int64_t>(primitive_->GetAttr(ops::kPadMode)));
pad_mode_ =
static_cast<mindspore::PadMode>(ops::PadModeStringToInt(GetValue<std::string>(primitive_->GetAttr(ops::kPadMode))));
// pair = [is_match, index]
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto pair = MatchKernelAttr(kernel_attr, GetOpSupport());

View File

@ -23,6 +23,7 @@
#include "mindspore/core/ops/grad/max_pool_grad_with_argmax.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "mindspore/core/ops/op_utils.h"
namespace mindspore {
namespace kernel {

View File

@ -16,6 +16,7 @@
#include "plugin/device/cpu/kernel/max_pool_with_argmax_cpu_kernel.h"
#include <algorithm>
#include <string>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "mindspore/core/ops/max_pool_with_argmax.h"
@ -31,7 +32,8 @@ constexpr int kPadHalf = 2;
bool MaxPoolWithArgmaxCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs,
const std::vector<KernelTensor *> &outputs) {
data_format_ = Format(GetValue<int64_t>(primitive_->GetAttr(ops::kFormat)));
data_format_ =
static_cast<mindspore::Format>(ops::FormatStringToInt(GetValue<std::string>(primitive_->GetAttr(ops::kFormat))));
auto kernel_size = GetValue<std::vector<int64_t>>(primitive_->GetAttr(ops::kKernelSize));
auto strides = GetValue<std::vector<int64_t>>(primitive_->GetAttr(ops::kStrides));
if (kernel_size.size() < kIndex3 || strides.size() < kIndex3) {
@ -58,7 +60,8 @@ bool MaxPoolWithArgmaxCpuKernelMod::Init(const std::vector<KernelTensor *> &inpu
"but got the window height: "
<< window_height_ << ", and the window width: " << window_width_;
}
pad_mode_ = PadMode(GetValue<int64_t>(primitive_->GetAttr(ops::kPadMode)));
pad_mode_ =
static_cast<mindspore::PadMode>(ops::PadModeStringToInt(GetValue<std::string>(primitive_->GetAttr(ops::kPadMode))));
if (pad_mode_ == PadMode::SAME) {
int tmp_height = (input_height_ / stride_height_) * stride_height_ == input_height_
? (input_height_ / stride_height_)

View File

@ -23,6 +23,7 @@
#include "mindspore/core/ops/max_pool_with_argmax.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
#include "mindspore/core/ops/op_utils.h"
namespace mindspore {
namespace kernel {

View File

@ -50,7 +50,8 @@ bool MaxPoolGradGradCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs
kernels_ = GetValue<std::vector<int64_t>>(primitive_->GetAttr(ops::kKernelSize));
strides_ = GetValue<std::vector<int64_t>>(primitive_->GetAttr(ops::kStrides));
pad_mode_ = PadMode(GetValue<int64_t>(primitive_->GetAttr(ops::kPadMode)));
pad_mode_ =
static_cast<mindspore::PadMode>(ops::PadModeStringToInt(GetValue<std::string>(primitive_->GetAttr(ops::kPadMode))));
if (pad_mode_ != PadMode::SAME && pad_mode_ != PadMode::VALID) {
MS_LOG(ERROR) << kernel_name_ << " only support pad mode same or valid, but get " << pad_mode_;
return false;

View File

@ -28,6 +28,7 @@
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/nnacl/pooling_parameter.h"
#include "plugin/device/cpu/kernel/nnacl/kernel/pooling.h"
#include "mindspore/core/ops/op_utils.h"
namespace mindspore {
namespace kernel {

View File

@ -31,13 +31,26 @@ constexpr size_t kPoolingOutputsNum = 1;
void PoolingCpuKernelMod::InitPoolingFields(const std::vector<KernelTensor *> &inputs,
const std::vector<KernelTensor *> &outputs) {
dtype_ = inputs[0]->dtype_id();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kPoolingInputsNum, kernel_name_);
if (kernel_name_ == kAvgPoolOpName) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 5, kernel_name_);
} else {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 1, kernel_name_);
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kPoolingOutputsNum, kernel_name_);
format_ = GetValue<std::string>(KernelMod::primitive_->GetAttr(FORMAT));
pad_mode = static_cast<mindspore::PadMode>(GetValue<int64_t>(KernelMod::primitive_->GetAttr(PAD_MODE)));
kernel_include_nc = GetValue<std::vector<int64_t>>(KernelMod::primitive_->GetAttr(KERNEL_SIZE));
strides_include_nc = GetValue<std::vector<int64_t>>(KernelMod::primitive_->GetAttr(STRIDES));
if (kernel_name_ == kAvgPoolOpName) {
kernel_include_nc = inputs[1]->GetValue<std::vector<int64_t>>().value();
strides_include_nc = inputs[2]->GetValue<std::vector<int64_t>>().value();
pad_mode_ = inputs[3]->GetValue<PadMode>().value();
format_ = inputs[4]->GetValue<Format>().value();
} else {
kernel_include_nc = GetValue<std::vector<int64_t>>(KernelMod::primitive_->GetAttr(KERNEL_SIZE));
strides_include_nc = GetValue<std::vector<int64_t>>(KernelMod::primitive_->GetAttr(STRIDES));
pad_mode_ = static_cast<mindspore::PadMode>(
ops::PadModeStringToInt(GetValue<std::string>(KernelMod::primitive_->GetAttr(PAD_MODE))));
format_ = static_cast<mindspore::Format>(
ops::FormatStringToInt(GetValue<std::string>(KernelMod::primitive_->GetAttr(FORMAT))));
}
if (KernelMod::primitive_->HasAttr(CEIL_MODE)) {
ValuePtr ceil_mode = KernelMod::primitive_->GetAttr(CEIL_MODE);
@ -45,14 +58,14 @@ void PoolingCpuKernelMod::InitPoolingFields(const std::vector<KernelTensor *> &i
(ceil_mode->isa<Int64Imm>() && GetValue<int64_t>(ceil_mode) == 1);
}
if (kernel_name_ == kAvgPool3DOpName && (pad_mode == mindspore::PadMode::PAD) &&
if (kernel_name_ == kAvgPool3DOpName && (pad_mode_ == mindspore::PadMode::PAD) &&
KernelMod::primitive_->HasAttr(DIVISOR_OVERRIDE) &&
GetValue<int64_t>(KernelMod::primitive_->GetAttr(DIVISOR_OVERRIDE)) != 0 &&
KernelMod::primitive_->HasAttr(COUNT_INCLUDE_PAD) &&
!GetValue<bool>(KernelMod::primitive_->GetAttr(COUNT_INCLUDE_PAD))) {
auto pad = GetValue<std::vector<int64_t>>(KernelMod::primitive_->GetAttr(PAD_LIST));
if (std::any_of(pad.begin(), pad.end(), [](int64_t pad) { return pad > 0; })) {
MS_LOG(EXCEPTION) << kernel_name_ << "does not support the scenes while padmode == " << pad_mode
MS_LOG(EXCEPTION) << kernel_name_ << "does not support the scenes while padmode == " << pad_mode_
<< " && padding > 0 && count_include_pad == False && divisor_override != None";
}
}
@ -110,7 +123,8 @@ int PoolingCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs, const
dnnl::memory::dims padding_l;
dnnl::memory::dims padding_r;
kernel_ = kernel;
PaddingInfo padding_info{pad_mode, kernel_, strides, dilation, &padding_l, &padding_r, &padding_invalid_, ceil_mode_};
PaddingInfo padding_info{pad_mode_, kernel_, strides, dilation,
&padding_l, &padding_r, &padding_invalid_, ceil_mode_};
GetPadding(src_shape, padding_info);
const auto desc = CreateDesc<dnnl::pooling_forward::desc>(dnnl::prop_kind::forward_inference, algorithm_, src_desc,
dst_desc, strides, kernel, padding_l, padding_r);
@ -196,7 +210,14 @@ void PoolingCpuKernelMod::ReComputeDivisor(T *dst) {
std::vector<KernelAttr> PoolingCpuKernelMod::GetOpSupport() {
static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
{kMaxPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kAvgPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kAvgPoolOpName,
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddInputAttr(kObjectTypeTuple, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32)}},
{kAvgPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16)}},
{kAvgPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}}};
auto iter = support_list_map.find(kernel_type_);

View File

@ -23,7 +23,7 @@
#include <unordered_map>
#include <map>
#include <string>
#include "mindspore/core/ops/op_utils.h"
#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
@ -59,7 +59,7 @@ class PoolingCpuKernelMod : public MKLCpuKernelMod {
std::vector<int64_t> kernel_;
std::vector<int64_t> padding_invalid_;
std::string format_;
mindspore::PadMode pad_mode;
mindspore::PadMode pad_mode_;
std::vector<int64_t> kernel_include_nc{};
std::vector<int64_t> strides_include_nc{};
std::map<uint32_t, tensor::TensorPtr> inputs_on_host_{};

View File

@ -88,7 +88,11 @@ void PoolingCpuKernelNnaclMod::InitPooling3DParams() {
bool PoolingCpuKernelNnaclMod::Init(const std::vector<KernelTensor *> &inputs,
const std::vector<KernelTensor *> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 5, kernel_name_);
if (kernel_name_ == kAvgPoolOpName) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 5, kernel_name_);
} else {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 1, kernel_name_);
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), 1, kernel_name_);
kernel_name_ = primitive()->name();
if (kernel_name_ == kAvgPool3DOpName || kernel_name_ == kAvgPoolOpName) {
@ -96,14 +100,24 @@ bool PoolingCpuKernelNnaclMod::Init(const std::vector<KernelTensor *> &inputs,
} else if (kernel_name_ == kMaxPool3DOpName || kernel_name_ == kMaxPoolOpName) {
pool_mode_ = MAX_POOLING;
} else {
MS_LOG(ERROR) << "Pooling only supports Avg or Max, but got " << kernel_name_;
MS_LOG(ERROR) << "Pooling only supports Avg or Max, but got: " << kernel_name_ << ".";
return false;
}
dtype_ = inputs[0]->dtype_id();
kernel_size_ = inputs[1]->GetValue<std::vector<int64_t>>().value();
stride_size_ = inputs[2]->GetValue<std::vector<int64_t>>().value();
pad_mode_ = inputs[3]->GetValue<PadMode>().value();
format_ = inputs[4]->GetValue<Format>().value();
if (kernel_name_ == kAvgPoolOpName) {
kernel_size_ = inputs[1]->GetValue<std::vector<int64_t>>().value();
stride_size_ = inputs[2]->GetValue<std::vector<int64_t>>().value();
pad_mode_ = inputs[3]->GetValue<PadMode>().value();
format_ = inputs[4]->GetValue<Format>().value();
} else {
kernel_size_ = GetValue<std::vector<int64_t>>(primitive()->GetAttr(KERNEL_SIZE));
stride_size_ = GetValue<std::vector<int64_t>>(primitive()->GetAttr(STRIDES));
pad_mode_ =
static_cast<mindspore::PadMode>(ops::PadModeStringToInt(GetValue<std::string>(primitive()->GetAttr(PAD_MODE))));
format_ =
static_cast<mindspore::Format>(ops::FormatStringToInt(GetValue<std::string>(primitive()->GetAttr(FORMAT))));
}
if (primitive()->HasAttr(COUNT_INCLUDE_PAD)) {
count_include_pad_ = GetValue<bool>(primitive()->GetAttr(COUNT_INCLUDE_PAD));
}
@ -132,23 +146,39 @@ int PoolingCpuKernelNnaclMod::Resize(const std::vector<KernelTensor *> &inputs,
// kernel_size_/stride_size_/pad_list_ in 4D will be extended to 5D later, so here we need to reset
// kernel_size_/stride_size_/pad_list_ before the extending.
if (src_dim == SHAPE_4D) {
auto kernel_size = inputs[1]->GetValue<std::vector<int64_t>>().value();
auto stride_size = inputs[2]->GetValue<std::vector<int64_t>>().value();
std::vector<int64_t> kernel_size;
std::vector<int64_t> stride_size;
if (kernel_name_ == kAvgPoolOpName) {
kernel_size = inputs[1]->GetValue<std::vector<int64_t>>().value();
stride_size = inputs[2]->GetValue<std::vector<int64_t>>().value();
} else {
kernel_size = GetValue<std::vector<int64_t>>(primitive()->GetAttr(KERNEL_SIZE));
stride_size = GetValue<std::vector<int64_t>>(primitive()->GetAttr(STRIDES));
}
size_t NC_LEN = 0;
if (kernel_name_ != kAvgPoolOpName) {
NC_LEN += 2;
}
constexpr auto kernel_size_len = 2;
if (kernel_size.size() != kernel_size_len) {
if (kernel_size.size() != kernel_size_len + NC_LEN) {
MS_LOG(INTERNAL_EXCEPTION) << "Unexpected kernel size length:" << kernel_size.size();
}
constexpr auto stride_size_len = 2;
if (stride_size.size() != stride_size_len) {
if (stride_size.size() != stride_size_len + NC_LEN) {
MS_LOG(INTERNAL_EXCEPTION) << "Unexpected stride size length:" << stride_size.size();
}
// change kernel size and strides from (H, W) to (1, 1, H, W)
kernel_size_ = {1, 1};
kernel_size_.emplace_back(kernel_size[0]);
kernel_size_.emplace_back(kernel_size[1]);
stride_size_ = {1, 1};
stride_size_.emplace_back(stride_size[0]);
stride_size_.emplace_back(stride_size[1]);
if (kernel_name_ == kAvgPoolOpName) {
kernel_size_ = {1, 1};
kernel_size_.emplace_back(kernel_size[0]);
kernel_size_.emplace_back(kernel_size[1]);
stride_size_ = {1, 1};
stride_size_.emplace_back(stride_size[0]);
stride_size_.emplace_back(stride_size[1]);
} else {
kernel_size_ = kernel_size;
stride_size_ = stride_size;
}
pad_list_.clear();
}
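A note on the 4D path just above: because AvgPool's kernel_size and strides now arrive via input tensors as plain (H, W) pairs, while the attribute-based pooling ops still carry (N, C, H, W), the length checks add the NC_LEN offset for the attribute form and only the AvgPool values are padded out to (1, 1, H, W). A minimal sketch of that normalization, under the same naming assumptions as the earlier sketch (the function name is illustrative):

#include <cstdint>
#include <vector>

// Normalize a pooling window or stride spec to NCHW order.
// from_dyn_inputs == true corresponds to the AvgPool dynamic-input case above.
std::vector<int64_t> NormalizeToNCHW(const std::vector<int64_t> &v, bool from_dyn_inputs) {
  if (from_dyn_inputs) {
    // AvgPool inputs provide only (H, W); extend to (1, 1, H, W).
    return {1, 1, v[0], v[1]};
  }
  // Attribute form already stores (N, C, H, W); use it as-is.
  return v;
}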
@ -393,7 +423,11 @@ bool PoolingCpuKernelNnaclMod::LaunchKernel(const std::vector<kernel::KernelTens
bool PoolingCpuKernelNnaclMod::Launch(const std::vector<kernel::KernelTensor *> &inputs,
const std::vector<kernel::KernelTensor *> &workspaces,
const std::vector<kernel::KernelTensor *> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 5, kernel_name_);
if (kernel_name_ == kAvgPoolOpName) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 5, kernel_name_);
} else {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 1, kernel_name_);
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), 1, kernel_name_);
if (use_channel_last_) {

View File

@ -23,7 +23,7 @@
#include <unordered_map>
#include <map>
#include <string>
#include "mindspore/core/ops/op_utils.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/nnacl/kernel/pooling.h"
#include "plugin/device/cpu/kernel/nnacl/pooling_parameter.h"

View File

@ -27,6 +27,10 @@ namespace kernel {
namespace {
constexpr size_t kMaxPoolingGradWorkSpaceNum = 2;
constexpr size_t kAvgPoolingGradWorkSpaceNum = 1;
constexpr size_t kPoolingGradInputsNum = 3;
constexpr size_t kAvgPooling3DGradInputsNum = 1;
constexpr size_t kAvgPooling3DGradDynamicInputsNum = 2;
constexpr size_t kPoolingGradOutputsNum = 1;
// avgpoolgrad and maxpoolgrad input indexes
constexpr size_t kGradIndex = 2;
constexpr size_t kKernelSizeIdx = 3;
@ -60,7 +64,7 @@ constexpr size_t kMax3DDataFormatIdx = 8;
bool PoolingGradCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs,
const std::vector<KernelTensor *> &outputs) {
if (kernel_name_ != kAvgPoolGradOpName) {
if (kernel_name_ == kAvgPool3DGradOpName || kernel_name_ == kMaxPool3DGradOpName) {
if (KernelMod::primitive_->HasAttr(CEIL_MODE)) {
ValuePtr ceil_mode = KernelMod::primitive_->GetAttr(CEIL_MODE);
ceil_mode_ = (ceil_mode->isa<BoolImm>() && GetValue<bool>(ceil_mode)) ||
(ceil_mode->isa<Int64Imm>() && GetValue<int64_t>(ceil_mode) == 1);
@ -76,7 +80,6 @@ bool PoolingGradCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs,
divisor_override_ = GetValue<int64_t>(KernelMod::primitive_->GetAttr(DIVISOR_OVERRIDE));
}
}
algorithm_ = dnnl::algorithm::pooling_avg;
grad_index_ = kernel_name_ == kAvgPool3DGradOpName ? 1 : kGradIndex;
format_ = static_cast<mindspore::Format>(
ops::FormatStringToInt(GetValue<std::string>(KernelMod::primitive_->GetAttr(FORMAT))));
@ -84,12 +87,12 @@ bool PoolingGradCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs,
ops::PadModeStringToInt(GetValue<std::string>(KernelMod::primitive_->GetAttr(PAD_MODE))));
kernel_include_nc_ = GetValue<std::vector<int64_t>>(KernelMod::primitive_->GetAttr(KERNEL_SIZE));
strides_include_nc_ = GetValue<std::vector<int64_t>>(KernelMod::primitive_->GetAttr(STRIDES));
dtype_ = inputs[grad_index_]->GetDtype();
dtype_ = inputs[grad_index_]->dtype_id();
return true;
}
grad_index_ = kGradIndex;
dtype_ = inputs[grad_index_]->dtype_id();
// avgpoolgrad input
// AvgPoolGrad input
algorithm_ = dnnl::algorithm::pooling_avg;
pad_mode_ = static_cast<mindspore::PadMode>(inputs[kPadModeIdx]->GetValueWithCheck<int64_t>());
kernel_include_nc_ = inputs[kKernelSizeIdx]->GetValueWithCheck<std::vector<int64_t>>();
@ -318,7 +321,6 @@ bool PoolingGradCpuKernelMod::Launch(const std::vector<kernel::KernelTensor *> &
ExecutePrimitive();
return true;
}
if (dtype_ == kNumberTypeFloat32) {
return LaunchKernel<float>(inputs, workspace, outputs);
} else if (dtype_ == kNumberTypeFloat16) {

View File

@ -34,9 +34,6 @@ namespace mindspore {
namespace kernel {
constexpr auto kNumberThree = 3;
constexpr auto kNumberTwo = 2;
constexpr auto kAvgPool = "AvgPool";
constexpr auto kAvgPool3D = "AvgPool3D";
constexpr auto kMaxPool3D = "MaxPool3D";
constexpr auto kKernelSizeIdx = 1;
constexpr auto kStridesIdx = 2;
constexpr auto kPadModeIdx = 3;
@ -130,19 +127,19 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) {
InitResource();
size_t format_index = kDataFormatIdx;
if (kernel_name_ == kAvgPool3D) {
if (kernel_name_ == kAvgPool3DOpName) {
divisor_override_ = GetValue<int64_t>(primitive_->GetAttr("divisor_override"));
ceil_mode_ = GetValue<bool>(primitive_->GetAttr("ceil_mode"));
AvgPool3DPadListCheck(inputs);
format_index = kFormatAvg3DIdx;
}
if (kernel_name_ == kMaxPool3D) {
if (kernel_name_ == kMaxPool3DOpName) {
format_index = kFormatMax3DIdx;
}
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(inputs[0]->dtype_id()));
data_format_ = inputs[0]->format();
mindspore::Format format_attr;
if (kernel_name_ == kAvgPool) {
if (kernel_name_ == kAvgPoolOpName) {
format_attr = static_cast<mindspore::Format>(inputs[format_index]->GetValueWithCheck<int64_t>());
} else {
format_attr =
@ -244,12 +241,12 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
void SetPoolingMode(const std::vector<KernelTensor *> &inputs) {
mode_ = kernel_name_;
bool include = false;
if (kernel_name_ == kAvgPool3D) {
if (kernel_name_ == kAvgPool3DOpName) {
if (primitive_->HasAttr("count_include_pad")) {
include = GetValue<bool>(primitive_->GetAttr("count_include_pad"));
}
}
if (mode_ == kAvgPool || mode_ == kAvgPool3D) {
if (mode_ == kAvgPoolOpName || mode_ == kAvgPool3DOpName) {
pooling_mode_ =
include ? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
pad_value_ = 0.0;
@ -261,7 +258,7 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
void SetPad(const std::vector<KernelTensor *> &inputs) {
mindspore::PadMode pad_mode_;
if (kernel_name_ == kAvgPool) {
if (kernel_name_ == kAvgPoolOpName) {
pad_mode_ = static_cast<mindspore::PadMode>(inputs[kPadModeIdx]->GetValueWithCheck<int64_t>());
} else {
pad_mode_ = static_cast<mindspore::PadMode>(
@ -270,7 +267,7 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
}
std::vector<int> window;
std::vector<int64_t> window_me;
if (kernel_name_ == kAvgPool) {
if (kernel_name_ == kAvgPoolOpName) {
window_me = inputs[kKernelSizeIdx]->GetValueWithCheck<std::vector<int64_t>>();
} else {
window_me = GetValue<std::vector<int64_t>>(primitive_->GetAttr("kernel_size"));
@ -286,7 +283,7 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
int window_height = window[0 + kNC_SIZE];
int window_width = window[1 + kNC_SIZE];
std::vector<int64_t> stride_me;
if (kernel_name_ == kAvgPool) {
if (kernel_name_ == kAvgPoolOpName) {
stride_me = inputs[kStridesIdx]->GetValueWithCheck<std::vector<int64_t>>();
} else {
stride_me = GetValue<std::vector<int64_t>>(primitive_->GetAttr("strides"));
@ -313,7 +310,7 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
pad_height_ = 0;
pad_width_ = 0;
}
if (kernel_name_ != kAvgPool) {
if (kernel_name_ != kAvgPoolOpName) {
kNC_SIZE -= 2;
}
const size_t k2dDim = 2;
@ -399,7 +396,7 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
std::vector<int64_t> kernel_size;
std::vector<int64_t> strides;
std::vector<int64_t> pad;
if (kernel_name_ == kAvgPool) {
if (kernel_name_ == kAvgPoolOpName) {
kernel_size = inputs[kKernelSizeIdx]->GetValueWithCheck<std::vector<int64_t>>();
strides = inputs[kStridesIdx]->GetValueWithCheck<std::vector<int64_t>>();
pad = inputs[kPadListIdx]->GetValueWithCheck<std::vector<int64_t>>();
@ -452,7 +449,7 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
return;
}
std::vector<int64_t> pad_list = GetValue<std::vector<int64_t>>(primitive_->GetAttr("pad_list"));
if (kernel_name_ == kAvgPool3DDOpName && !GetValue<bool>(primitive_->GetAttr("count_include_pad")) &&
if (kernel_name_ == kAvgPool3DOpName && !GetValue<bool>(primitive_->GetAttr("count_include_pad")) &&
primitive_->HasAttr("divisor_override") &&
std::any_of(pad_list.begin(), pad_list.end(), [](int64_t pad) { return pad > 0; })) {
MS_LOG(EXCEPTION) << kernel_name_ << "does not support the scenes while padmode == " << pad_mode

View File

@ -16,3 +16,5 @@ cumprod:
dtype: tensor
class:
name: CumProd
function:
name: cumprod_

View File

@ -16,3 +16,5 @@ cumsum:
dtype: tensor
class:
name: CumSum
function:
name: cumsum_

View File

@ -101,6 +101,34 @@ def get_nll_loss_grad_vmap_rule(prim, axis_size):
@vmap_rules_getters.register(G.MaxPoolGrad)
def get_max_pool_grad_vmap_rule(prim, axis_size):
"""VmapRule for `MaxPoolGrad`."""
chw_reverse_index = -3
def vmap_rule(x_bdim, y_bdim, dy_bdim):
is_all_none, result = vmap_general_preprocess(prim, x_bdim, y_bdim, dy_bdim)
if is_all_none:
return result
x, x_dim = x_bdim
y, y_dim = y_bdim
dy, dy_dim = dy_bdim
x = _bdim_at_front(x, x_dim, axis_size)
y = _bdim_at_front(y, y_dim, axis_size)
dy = _bdim_at_front(dy, dy_dim, axis_size)
x_shape = F.shape(x)
y_shape = F.shape(y)
dy_shape = F.shape(dy)
x = F.reshape(x, (-1,) + x_shape[chw_reverse_index:])
y = F.reshape(y, (-1,) + y_shape[chw_reverse_index:])
dy = F.reshape(dy, (-1,) + dy_shape[chw_reverse_index:])
out = prim(x, y, dy)
out = F.reshape(out, x_shape)
return out, 0
return vmap_rule
@vmap_rules_getters.register(G.AvgPoolGrad)
def get_avg_pool_grad_vmap_rule(prim, axis_size):
"""VmapRule for `AvgPoolGrad`."""

View File

@ -52,8 +52,6 @@ def test_cummax_forward(context_mode, dtype):
x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(dtype))
axis = -2
values, indices = cummax_forward_func(x, axis)
print("values:\n", values)
print("indices:\n", indices)
expect_values = np.asarray([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(dtype)
expect_indices = np.asarray([[0, 0, 0, 0], [1, 1, 1, 1]]).astype(np.int64)
assert np.allclose(values.asnumpy(), expect_values)
@ -80,8 +78,6 @@ def test_cummax_vmap(context_mode, dtype):
axis = 0
nest_vmap = ops.vmap(ops.vmap(cummax_forward_func, in_axes=(0, None)), in_axes=(0, None))
values, indices = nest_vmap(x, axis)
print("values:\n", values)
print("indices:\n", indices)
expect_values = np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]]).astype(dtype)
expect_indices = np.array([[[0, 1, 2, 3], [0, 1, 2, 3]]]).astype(np.int64)
assert (values.asnumpy() == expect_values).all()

View File

@ -52,8 +52,6 @@ def test_cummin_forward(context_mode, dtype):
x = Tensor(np.array([[3, 1, 4, 1], [1, 5, 9, 2]]).astype(dtype))
axis = -2
values, indices = cummin_forward_func(x, axis)
print("values:\n", values)
print("indices:\n", indices)
expect_values = np.asarray([[3, 1, 4, 1], [1, 1, 4, 1]]).astype(dtype)
expect_indices = np.asarray([[0, 0, 0, 0], [1, 0, 0, 0]]).astype(np.int64)
assert np.allclose(values.asnumpy(), expect_values)

View File

@ -23,7 +23,7 @@ import mindspore as ms
@test_utils.run_with_cell
def cumprod_forward_func(x, axis, exclusive, reverse):
return ops.auto_generate.cumprod(x, axis, exclusive, reverse)
return ops.auto_generate.cumprod_(x, axis, exclusive, reverse)
@test_utils.run_with_cell
@ -33,7 +33,7 @@ def cumprod_backward_func(x, axis, exclusive, reverse):
@test_utils.run_with_cell
def cumprod_dyn_shape_func(x, axis, exclusive, reverse):
return ops.auto_generate.cumprod(x, axis, exclusive, reverse)
return ops.auto_generate.cumprod_(x, axis, exclusive, reverse)
@pytest.mark.level0
@pytest.mark.env_onecard

View File

@ -23,7 +23,7 @@ import mindspore as ms
@test_utils.run_with_cell
def cumsum_forward_func(x, axis, exclusive, reverse):
return ops.auto_generate.cumsum(x, axis, exclusive, reverse)
return ops.auto_generate.cumsum_(x, axis, exclusive, reverse)
@test_utils.run_with_cell
@ -33,7 +33,7 @@ def cumsum_backward_func(x, axis, exclusive, reverse):
@test_utils.run_with_cell
def cumsum_dyn_shape_func(x, axis, exclusive, reverse):
return ops.auto_generate.cumsum(x, axis, exclusive, reverse)
return ops.auto_generate.cumsum_(x, axis, exclusive, reverse)
@pytest.mark.level0

View File

@ -54,12 +54,8 @@ TEST_P(TestCummax, dyn_shape) {
auto prim = std::make_shared<Primitive>("Cummax");
auto out_dtype = cummax_func_impl.InferType(prim, {x, axis});
std::cout << "out_dtype: " << out_dtype->ToString() << "\n";
std::cout << "expect_type: " << expect_type->ToString() << "\n";
ASSERT_TRUE(*out_dtype == *expect_type);
auto out_shape = cummax_func_impl.InferShape(prim, {x, axis});
std::cout << "out_shape: " << out_shape->ToString() << "\n";
std::cout << "expect_shape: " << expect_shape->ToString() << "\n";
ASSERT_TRUE(*out_shape == *expect_shape);
}

View File

@ -54,12 +54,8 @@ TEST_P(TestCummin, dyn_shape) {
auto prim = std::make_shared<Primitive>("Cummin");
auto out_dtype = cummin_func_impl.InferType(prim, {x, axis});
std::cout << "out_dtype: " << out_dtype->ToString() << "\n";
std::cout << "expect_type: " << expect_type->ToString() << "\n";
ASSERT_TRUE(*out_dtype == *expect_type);
auto out_shape = cummin_func_impl.InferShape(prim, {x, axis});
std::cout << "out_shape: " << out_shape->ToString() << "\n";
std::cout << "expect_shape: " << expect_shape->ToString() << "\n";
ASSERT_TRUE(*out_shape == *expect_shape);
}

View File

@ -54,12 +54,8 @@ TEST_P(TestCumProd, dyn_shape) {
auto prim = std::make_shared<Primitive>("CumProd");
auto out_dtype = cumprod_func_impl.InferType(prim, {x, axis, exclusive, reverse});
std::cout << "out_dtype: " << out_dtype->ToString() << "\n";
std::cout << "expect_type: " << expect_dtype->ToString() << "\n";
ASSERT_TRUE(*out_dtype == *expect_dtype);
auto out_shape = cumprod_func_impl.InferShape(prim, {x, axis, exclusive, reverse});
std::cout << "out_shape: " << out_shape->ToString() << "\n";
std::cout << "expect_shape: " << expect_shape->ToString() << "\n";
ASSERT_TRUE(*out_shape == *expect_shape);
}

View File

@ -54,12 +54,8 @@ TEST_P(TestCumSum, dyn_shape) {
auto prim = std::make_shared<Primitive>("CumSum");
auto out_dtype = cumsum_func_impl.InferType(prim, {x, axis, exclusive, reverse});
std::cout << "out_dtype: " << out_dtype->ToString() << "\n";
std::cout << "expect_type: " << expect_dtype->ToString() << "\n";
ASSERT_TRUE(*out_dtype == *expect_dtype);
auto out_shape = cumsum_func_impl.InferShape(prim, {x, axis, exclusive, reverse});
std::cout << "out_shape: " << out_shape->ToString() << "\n";
std::cout << "expect_shape: " << expect_shape->ToString() << "\n";
ASSERT_TRUE(*out_shape == *expect_shape);
}