Switch Pooling3D from oneDNN to a self-developed implementation.

This commit is contained in:
zuochuanyong 2023-08-01 19:02:13 +08:00
parent 8c8f6ceda1
commit 0583b486f4
16 changed files with 926 additions and 64 deletions

View File

@ -39,6 +39,7 @@
"mindspore/mindspore/ccsrc/backend/graph_compiler/backend_base.cc" "knownConditionTrueFalse"
"mindspore/mindspore/ccsrc/backend/graph_compiler/backend_base.cc" "variableScope"
"mindspore/mindspore/core/ops/max_pool.cc" "zerodivcond"
"mindspore/mindspore/core/ops/max_pool3d.cc" "zerodivcond"
"mindspore/core/utils/log_adapter.cc" "stlIfStrFind"
"mindspore/mindspore/ccsrc/transform/graph_ir/convert.cc" "knownConditionTrueFalse"
"mindspore/mindspore/ccsrc/frontend/expander/bprop/grad_ops/grad_array_ops.cc" "internalAstError"

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -193,13 +193,9 @@ void PoolingCpuKernelMod::ReComputeDivisor(T *dst) {
std::vector<KernelAttr> PoolingCpuKernelMod::GetOpSupport() {
static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
{kMaxPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kMaxPool3DOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kAvgPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kAvgPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16)}},
{kAvgPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}},
{kAvgPool3DOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kAvgPool3DOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}},
{kAvgPool3DOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16)}}};
{kAvgPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}}};
auto iter = support_list_map.find(kernel_type_);
if (iter == support_list_map.end()) {
MS_LOG(EXCEPTION) << "Does not support " << kernel_type_ << "!";
@ -256,13 +252,9 @@ bool PoolingCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
return true;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, MaxPool3D,
[]() { return std::make_shared<PoolingCpuKernelMod>(kMaxPool3DOpName); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, MaxPool,
[]() { return std::make_shared<PoolingCpuKernelMod>(kMaxPoolOpName); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, AvgPool,
[]() { return std::make_shared<PoolingCpuKernelMod>(kAvgPoolOpName); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, AvgPool3D,
[]() { return std::make_shared<PoolingCpuKernelMod>(kAvgPool3DOpName); });
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,346 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/mkldnn/pooling_cpu_kernel_nnacl.h"
#include <string>
#include <algorithm>
#include <functional>
#include "plugin/device/cpu/kernel/utils/cpu_utils.h"
#include "nnacl/fp32/pack_fp32.h"
#include "nnacl/fp32/pooling_fp32.h"
#include "ops/conv_pool_op_name.h"
namespace mindspore {
namespace kernel {
constexpr size_t kDepthOffset = 2;
constexpr size_t kPadListLen = 6;
enum kAxisIdx : int { kD = 2, kH, kW };
const size_t dimN = 0;
const size_t dimC = 1;
const size_t dimD = 2;
const size_t dimH = 3;
const size_t dimW = 4;
int64_t ComputeStride(const std::vector<int64_t> &shape, size_t index) {
int64_t result = 1;
for (size_t i = index + 1; i < shape.size(); ++i) {
result *= shape[i];
}
return result;
}
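For example (illustrative numbers):
// ComputeStride({2, 32, 9, 9, 9}, 1) == 9 * 9 * 9 == 729: in NCDHW layout,
// advancing the channel index by one moves 729 elements.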
void PoolingCpuKernelNnaclMod::InitPooling3DParams() {
// Access structure members in declaration order
pooling_args_.pooling_compute_param_.input_w_ = in_size_[kW];
pooling_args_.pooling_compute_param_.input_h_ = in_size_[kH];
pooling_args_.pooling_compute_param_.input_batch_ = batches_;
pooling_args_.pooling_compute_param_.input_channel_ = channels_;
pooling_args_.pooling_compute_param_.output_w_ = out_size_[kW];
pooling_args_.pooling_compute_param_.output_h_ = out_size_[kH];
pooling_args_.input_d_ = in_size_[kD];
pooling_args_.output_d_ = out_size_[kD];
pooling_param_.pooling_parameter_.window_w_ = kernel_size_[kW];
pooling_param_.pooling_parameter_.window_h_ = kernel_size_[kH];
pooling_param_.pooling_parameter_.stride_w_ = stride_size_[kW];
pooling_param_.pooling_parameter_.stride_h_ = stride_size_[kH];
pooling_param_.pooling_parameter_.pad_u_ = padding_l_[kH - kDepthOffset];
pooling_param_.pooling_parameter_.pad_d_ = padding_r_[kH - kDepthOffset];
pooling_param_.pooling_parameter_.pad_l_ = padding_l_[kW - kDepthOffset];
pooling_param_.pooling_parameter_.pad_r_ = padding_r_[kW - kDepthOffset];
pooling_param_.window_d_ = kernel_size_[kD];
pooling_param_.stride_d_ = stride_size_[kD];
pooling_param_.pad_f_ = padding_l_[kD - kDepthOffset];
pooling_param_.pad_b_ = padding_r_[kD - kDepthOffset];
pooling_param_.count_include_pad_ = count_include_pad_;
pooling_param_.divisor_override_ = divisor_override_;
}
bool PoolingCpuKernelNnaclMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 1, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), 1, kernel_name_);
if (kernel_name_ == kAvgPool3DOpName) {
pool_mode_ = MEAN_POOLING;
} else if (kernel_name_ == kMaxPool3DOpName) {
pool_mode_ = MAX_POOLING;
} else {
MS_LOG(ERROR) << "Pooling only supports Avg or Max, but got " << kernel_name_;
return false;
}
dtype_ = inputs[0]->GetDtype();
format_ = GetValue<std::string>(base_operator->GetAttr(FORMAT));
pad_mode_ = GetValue<std::string>(base_operator->GetAttr(PAD_MODE));
kernel_size_ = GetValue<std::vector<int64_t>>(base_operator->GetAttr(KERNEL_SIZE));
stride_size_ = GetValue<std::vector<int64_t>>(base_operator->GetAttr(STRIDES));
ceil_mode_ = pool_mode_ == MEAN_POOLING ? GetValue<bool>(base_operator->GetAttr(CEIL_MODE))
: (GetValue<int64_t>(base_operator->GetAttr(CEIL_MODE)) == 1);
count_include_pad_ = GetValue<bool>(base_operator->GetAttr(COUNT_INCLUDE_PAD));
divisor_override_ = GetValue<int64_t>(base_operator->GetAttr(DIVISOR_OVERRIDE));
return true;
}
int PoolingCpuKernelNnaclMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
in_size_ = inputs[0]->GetDeviceShapeAdaptively();
out_size_ = outputs[0]->GetDeviceShapeAdaptively();
auto src_dim = in_size_.size();
if (src_dim != SHAPE_5D || out_size_.size() != SHAPE_5D) {
MS_LOG(ERROR) << "Pooling only supports 5D input/output, but got input " << src_dim << "D, "
<< "output " << out_size_.size() << "D!";
return KRET_RESIZE_FAILED;
}
if (kernel_size_.size() != src_dim) {
MS_LOG(EXCEPTION) << kernel_name_ << " requires kernel_size to be " << src_dim << "D, but got "
<< kernel_size_.size() << "D!";
}
if (stride_size_.size() != src_dim) {
MS_LOG(EXCEPTION) << kernel_name_ << " requires strides to be " << src_dim << "D, but got " << stride_size_.size()
<< "D!";
}
pad_list_ = GetValue<std::vector<int64_t>>(base_operator->GetAttr(PAD_LIST));
if (pad_list_.size() != kPadListLen) {
MS_LOG(EXCEPTION) << kernel_name_ << " requires the length of pad_list to be " << kPadListLen << ", but got "
<< pad_list_.size() << "!";
}
auto pad_size = pad_list_.size() / 2;
padding_l_.resize(pad_size);
padding_r_.resize(pad_size);
for (size_t i = 0; i < padding_l_.size(); ++i) {
padding_l_[i] = pad_list_[i << 1];
padding_r_[i] = pad_list_[(i << 1) + 1];
}
input_stride_n_ = ComputeStride(in_size_, dimN);
input_stride_c_ = ComputeStride(in_size_, dimC);
input_stride_d_ = ComputeStride(in_size_, dimD);
input_stride_h_ = ComputeStride(in_size_, dimH);
input_stride_w_ = ComputeStride(in_size_, dimW);
batches_ = in_size_[0];
channels_ = in_size_[1];
output_num_ = batches_ * channels_ * out_size_[kD] * out_size_[kH] * out_size_[kW];
auto in_dtype_size = GetTypeByte(TypeIdToType(inputs[0]->GetDtype()));
auto out_dtype_size = GetTypeByte(TypeIdToType(outputs[0]->GetDtype()));
// Channel-last (NDHWC) fast path: fp32 only, and only with at least 4 channels.
use_channel_last_ = dtype_ == kNumberTypeFloat32 && channels_ >= 4;
if (use_channel_last_) {
InitPooling3DParams();
size_t ws_size = batches_ * channels_ * in_size_[kD] * in_size_[kH] * in_size_[kW] * in_dtype_size;
(void)workspace_size_list_.emplace_back(ws_size); // output buffer of transposing of input
(void)workspace_size_list_.emplace_back(output_num_ * out_dtype_size); // output buffer of pooling of ndhwc
}
return KRET_OK;
}
std::vector<KernelAttr> PoolingCpuKernelNnaclMod::GetOpSupport() {
static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
{kMaxPool3DOpName,
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}},
{kAvgPool3DOpName,
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}},
};
auto iter = support_list_map.find(kernel_type_);
if (iter == support_list_map.end()) {
MS_LOG(EXCEPTION) << "Does not support " << kernel_type_ << "!";
}
return iter->second;
}
template <typename T>
CTask PoolingCpuKernelNnaclMod::KernelAvgPool(T *input_addr, T *output_addr) {
CTask task = [&, input_addr, output_addr](size_t start, size_t end) {
const int64_t d_max = in_size_[kD] + padding_r_[kD - kDepthOffset];
const int64_t h_max = in_size_[kH] + padding_r_[kH - kDepthOffset];
const int64_t w_max = in_size_[kW] + padding_r_[kW - kDepthOffset];
const bool count_include_pad = count_include_pad_;
int64_t divisor = divisor_override_;
int64_t n = 0, c = 0, d = 0, h = 0, w = 0;
offset_to_index_init((int64_t)start, &n, batches_, &c, channels_, &d, out_size_[kD], &h, out_size_[kH], &w,
out_size_[kW]);
for (size_t i = start; i < end; i++) {
int64_t d_start = d * stride_size_[kD] - padding_l_[kD - kDepthOffset];
int64_t d_end = std::min(d_start + kernel_size_[kD], d_max);
int64_t d_start2 = std::max(d_start, (int64_t)0);
int64_t d_end2 = std::min(d_end, in_size_[kD]);
int64_t h_start = h * stride_size_[kH] - padding_l_[kH - kDepthOffset];
int64_t h_end = std::min(h_start + kernel_size_[kH], h_max);
int64_t h_start2 = std::max(h_start, (int64_t)0);
int64_t h_end2 = std::min(h_end, in_size_[kH]);
int64_t w_start = w * stride_size_[kW] - padding_l_[kW - kDepthOffset];
int64_t w_end = std::min(w_start + kernel_size_[kW], w_max);
int64_t w_start2 = std::max(w_start, (int64_t)0);
int64_t w_end2 = std::min(w_end, in_size_[kW]);
if (divisor_override_ == 0) {
if (count_include_pad) {
divisor = (d_end - d_start) * (h_end - h_start) * (w_end - w_start);
} else {
divisor = (d_end2 - d_start2) * (h_end2 - h_start2) * (w_end2 - w_start2);
}
}
T *input = input_addr + n * input_stride_n_ + c * input_stride_c_;
T sum = static_cast<T>(0.0);
for (auto dd = d_start2; dd < d_end2; ++dd) {
int64_t stride_d = dd * input_stride_d_;
for (auto hh = h_start2; hh < h_end2; ++hh) {
int64_t stride_dh = stride_d + hh * input_stride_h_;
for (auto ww = w_start2; ww < w_end2; ++ww) {
int64_t index = stride_dh + ww;
sum += input[index];
}
}
}
output_addr[i] = sum / divisor;
offset_to_index_step(&n, batches_, &c, channels_, &d, out_size_[kD], &h, out_size_[kH], &w, out_size_[kW]);
}
};
return task;
}
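The divisor selection above is the subtle part of average pooling with padding: with count_include_pad the window is measured against the padded extent, otherwise only against the real input, and a nonzero divisor_override_ bypasses both. A minimal one-axis sketch (illustrative, not the commit's code):

#include <algorithm>
#include <cstdint>

// Hypothetical one-axis sketch of the divisor selection in KernelAvgPool.
// The real kernel computes one extent per D/H/W axis and multiplies them.
int64_t WindowExtentOneAxis(int64_t out_idx, int64_t stride, int64_t pad_l, int64_t pad_r,
                            int64_t kernel, int64_t in_size, bool count_include_pad) {
  int64_t start = out_idx * stride - pad_l;                 // may land in the left padding
  int64_t end = std::min(start + kernel, in_size + pad_r);  // clipped to the padded extent
  int64_t start2 = std::max(start, int64_t{0});             // clipped to the real input
  int64_t end2 = std::min(end, in_size);
  return count_include_pad ? (end - start) : (end2 - start2);
}

// Example: in_size = 4, kernel = 3, stride = 1, pad_l = pad_r = 1. The first
// output's window covers padded positions [-1, 2), so its extent is 3 with
// count_include_pad and 2 without; divisor_override_ != 0 bypasses both.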
template <typename T>
CTask PoolingCpuKernelNnaclMod::KernelMaxPool(T *input_addr, T *output_addr) {
CTask task = [&, input_addr, output_addr](size_t start, size_t end) {
int64_t n = 0, c = 0, d = 0, h = 0, w = 0;
offset_to_index_init((int64_t)start, &n, batches_, &c, channels_, &d, out_size_[kD], &h, out_size_[kH], &w,
out_size_[kW]);
for (size_t i = start; i < end; i++) {
int64_t d_start = d * stride_size_[kD] - padding_l_[kD - kDepthOffset];
int64_t d_end = std::min(d_start + kernel_size_[kD], in_size_[kD]);
d_start = std::max(d_start, (int64_t)0);
int64_t h_start = h * stride_size_[kH] - padding_l_[kH - kDepthOffset];
int64_t h_end = std::min(h_start + kernel_size_[kH], in_size_[kH]);
h_start = std::max(h_start, (int64_t)0);
int64_t w_start = w * stride_size_[kW] - padding_l_[kW - kDepthOffset];
int64_t w_end = std::min(w_start + kernel_size_[kW], in_size_[kW]);
w_start = std::max(w_start, (int64_t)0);
T *input = input_addr + n * input_stride_n_ + c * input_stride_c_;
T tmp_max = static_cast<T>(-FLT_MAX);
for (auto dd = d_start; dd < d_end; ++dd) {
int64_t stride_d = dd * input_stride_d_;
for (auto hh = h_start; hh < h_end; ++hh) {
int64_t stride_dh = stride_d + hh * input_stride_h_;
for (auto ww = w_start; ww < w_end; ++ww) {
int64_t index = stride_dh + ww;
tmp_max = std::max(input[index], tmp_max);
}
}
}
output_addr[i] = tmp_max;
offset_to_index_step(&n, batches_, &c, channels_, &d, out_size_[kD], &h, out_size_[kH], &w, out_size_[kW]);
}
};
return task;
}
void PoolingCpuKernelNnaclMod::LaunchTransposeFp32(float *input_addr, float *output_addr, int plane, int channel) {
int m = UP_DIV(plane, C8NUM);
int n = UP_DIV(channel, C8NUM);
size_t task_num = batches_ * m * n;
CTask task = [&, input_addr, output_addr](size_t start, size_t end) {
TransposeFp32(input_addr, output_addr, batches_, plane, channel, start, end);
};
ParallelLaunch(task, task_num, 1.0);
}
void PoolingCpuKernelNnaclMod::LaunchPoolingChannelLastFp32(float *input_addr, float *transpose_out, float *pooling_out,
float *output_addr) {
size_t task_num = batches_ * out_size_[kD] * out_size_[kH] * out_size_[kW];
int in_plane = in_size_[kD] * in_size_[kH] * in_size_[kW];
LaunchTransposeFp32(input_addr, transpose_out, channels_, in_plane);
CTask task;
if (pool_mode_ == MEAN_POOLING) {
task = [&, transpose_out, pooling_out](size_t start, size_t end) {
AvgPooling3D_NDHWC(transpose_out, pooling_out, &pooling_param_, &pooling_args_, start, end);
};
} else {
task = [&, transpose_out, pooling_out](size_t start, size_t end) {
MaxPooling3D_NDHWC(transpose_out, pooling_out, &pooling_param_, &pooling_args_, start, end);
};
}
ParallelLaunch(task, task_num, 1.0);
int out_plane = out_size_[kD] * out_size_[kH] * out_size_[kW];
LaunchTransposeFp32(pooling_out, output_addr, out_plane, channels_);
}
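A shape walk-through of the channel-last pipeline above may help; the numbers are illustrative (they match the dynamic-shape test added later in this commit), not part of the code:

// Assumed input (N, C, D, H, W) = (2, 32, 9, 9, 9), window 2x2x2, stride 2:
//   transpose_out : input transposed NCDHW -> NDHWC, 2 x (9*9*9) x 32, so the
//                   pooling SIMD kernels can vectorize along C
//   pooling_out   : pooled result, still NDHWC, 2 x (4*4*4) x 32
//   output        : pooling_out transposed back to NCDHW, 2 x 32 x 4 x 4 x 4
// Both transposes reuse TransposeFp32; the second call just swaps which extent
// plays the "channel" role and which plays the "plane" role.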
template <typename T>
bool PoolingCpuKernelNnaclMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
T *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
CTask task =
pool_mode_ == MEAN_POOLING ? KernelAvgPool<T>(input_addr, output_addr) : KernelMaxPool<T>(input_addr, output_addr);
ParallelLaunch(task, output_num_, 1.0);
return true;
}
#define POOL3D_KERNEL_CHANNEL_FIRST_CASE(TYPENUM, DTYPE) \
case (TYPENUM): { \
LaunchKernel<DTYPE>(inputs, workspaces, outputs); \
break; \
}
bool PoolingCpuKernelNnaclMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspaces,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 1, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), 1, kernel_name_);
if (use_channel_last_) {
float *input_addr = reinterpret_cast<float *>(inputs[0]->addr);
float *output_addr = reinterpret_cast<float *>(outputs[0]->addr);
float *transpose_out = GetDeviceAddress<float>(workspaces, 0);
float *pooling_out = GetDeviceAddress<float>(workspaces, 1);
LaunchPoolingChannelLastFp32(input_addr, transpose_out, pooling_out, output_addr);
return true;
}
switch (dtype_) {
POOL3D_KERNEL_CHANNEL_FIRST_CASE(kNumberTypeFloat32, float)
POOL3D_KERNEL_CHANNEL_FIRST_CASE(kNumberTypeFloat16, float16)
POOL3D_KERNEL_CHANNEL_FIRST_CASE(kNumberTypeFloat64, double)
default:
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the type of input should be float16, float32 or float64, but got "
<< TypeIdToType(dtype_)->ToString();
return false;
}
return true;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, AvgPool3D,
[]() { return std::make_shared<PoolingCpuKernelNnaclMod>(kAvgPool3DOpName); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, MaxPool3D,
[]() { return std::make_shared<PoolingCpuKernelNnaclMod>(kMaxPool3DOpName); });
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,102 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_CPU_KERNEL_NNACL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_CPU_KERNEL_NNACL_H_
#include <vector>
#include <memory>
#include <utility>
#include <unordered_map>
#include <map>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/nnacl/kernel/pooling.h"
#include "plugin/device/cpu/kernel/nnacl/pooling_parameter.h"
namespace mindspore {
namespace kernel {
constexpr auto kUnknown = "Unknown";
class PoolingCpuKernelNnaclMod : public NativeCpuKernelMod {
public:
PoolingCpuKernelNnaclMod() = default;
explicit PoolingCpuKernelNnaclMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
~PoolingCpuKernelNnaclMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
std::vector<KernelAttr> GetOpSupport() override;
protected:
PoolMode pool_mode_;
std::string format_;
std::string pad_mode_;
int64_t batches_{0};
int64_t channels_{0};
int64_t input_stride_n_{1};
int64_t input_stride_c_{1};
int64_t input_stride_d_{1};
int64_t input_stride_h_{1};
int64_t input_stride_w_{1};
size_t output_num_{1};
std::vector<int64_t> in_size_;
std::vector<int64_t> out_size_;
std::vector<int64_t> kernel_size_;
std::vector<int64_t> stride_size_;
std::vector<int64_t> pad_list_;
std::vector<int64_t> padding_l_;
std::vector<int64_t> padding_r_;
bool ceil_mode_{false};
bool count_include_pad_{true};
int64_t divisor_override_{0};
Pooling3DParameter pooling_param_;
Pooling3DComputeParam pooling_args_;
private:
std::string kernel_type_{kUnknown};
void InitPooling3DParams();
template <typename T>
CTask KernelAvgPool(T *input_addr, T *output_addr);
template <typename T>
CTask KernelMaxPool(T *input_addr, T *output_addr);
void LaunchPoolingChannelLastFp32(float *input_addr, float *transpose_out, float *pooling_out, float *output_addr);
void LaunchTransposeFp32(float *input_addr, float *output_addr, int plane, int channel);
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspaces,
const std::vector<kernel::AddressPtr> &outputs);
TypeId dtype_{kTypeUnknown};
bool use_channel_last_{false};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_POOLING_CPU_KERNEL_NNACL_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -1680,6 +1680,86 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
}
}
/*
|<---------------- plane --------------->|
+---------------------------+------------+ ---
| | | | |
|8x8-blocks| ... |8x8-blocks| right | |
| | | | | |
+ - - - - -+ + - - - - -+ | |
| ... ... ... | top | channel
+ - - - - -+ + - - - - -+ | |
| | | | tails | |
|8x8-blocks| ... |8x8-blocks| | |
+---------------------------+------------+ |
| |right bottom| |
| left bottom tails | tails |
+---------------------------+------------+ ---
*/
void TransposeFp32(const void *src, void *dst, int batches, int channel, int plane, int start, int end) {
#ifdef ENABLE_ARM64
Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64;
#elif defined(ENABLE_ARM32)
Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32;
#elif defined(ENABLE_AVX)
Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx;
#elif defined(ENABLE_SSE) && !defined(ENABLE_AVX)
Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse;
#endif
int m_pad = UP_DIV(channel, C8NUM);
int n_pad = UP_DIV(plane, C8NUM);
int m_blk = channel / C8NUM;
int n_blk = plane / C8NUM;
int b_stride = plane * channel;
int b = 0, m = 0, n = 0;
// To write dst consecutively, iterate (m,n): (0,0)->(1,0)->...->(0,1)->(1,1)->...
offset_to_index_init(start, 6, &m, m_pad, &n, n_pad, &b, batches);
for (int task = start; task < end; task++) {
const float *src_batch = (const float *)src + b * b_stride;
float *dst_batch = (float *)dst + b * b_stride;
int m_start = m * C8NUM;
int n_start = n * C8NUM;
if (m < m_blk && n < n_blk) {
// process 8x8-blocks
const float *from = src_batch + m_start * plane + n_start;
float *to = dst_batch + n_start * channel + m_start;
#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32)
Transpose8X8Fp32Func_(from, to, plane, channel);
#else
for (int tr = 0; tr < C8NUM; tr++) {
for (int tc = 0; tc < C8NUM; tc++) {
to[tc * channel + tr] = from[tr * plane + tc];
}
}
#endif
} else {
// process right bottom tails
const float *from = src_batch;
float *to = dst_batch;
int i_start = m_start;
int i_end = channel;
int j_start = n_start;
int j_end = plane;
if (m >= m_blk && n < n_blk) {
// process left bottom tails
from = src_batch + n_start;
to = dst_batch + n_start * channel;
j_start = 0;
j_end = C8NUM;
} else if (m < m_blk && n >= n_blk) {
// process right top tails
from = src_batch + m_start * plane;
to = dst_batch + m_start;
i_start = 0;
i_end = C8NUM;
}
transpose_tail(from, to, j_start, j_end, i_start, i_end, channel, plane);
}
offset_to_index_step(6, &m, m_pad, &n, n_pad, &b, batches);
}
}
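Behaviorally, the tiled function above should match this naive per-batch transpose (a hypothetical reference for testing, not part of the commit); the 8x8 tiling exists to keep the blocks SIMD- and cache-friendly and to write dst consecutively across tasks:

// Reference semantics: dst[b][p][c] = src[b][c][p].
void TransposeFp32Reference(const float *src, float *dst, int batches, int channel, int plane) {
  for (int b = 0; b < batches; b++) {
    const float *s = src + b * channel * plane;
    float *d = dst + b * channel * plane;
    for (int c = 0; c < channel; c++) {
      for (int p = 0; p < plane; p++) {
        d[p * channel + c] = s[c * plane + p];
      }
    }
  }
}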
void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count) {
PackNHWCToNCHWFp32(src, dst, batch, channel, plane, task_id, thread_count);
}
@ -1837,42 +1917,57 @@ inline void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_
#endif
#ifdef ENABLE_AVX
/*
Use _mm256_insertf128_ps while loading, instead of _mm256_permute2f128_ps at the end.
Overall, this uses 4 fewer vinsertf128 and 4 fewer vperm2f128 instructions than before.
inline void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
MS_LOAD256X8_F32(src, src_ptr, src_stride)
__m256 r1 = _mm256_unpacklo_ps(src1, src2);
__m256 r2 = _mm256_unpackhi_ps(src1, src2);
__m256 r3 = _mm256_unpacklo_ps(src3, src4);
__m256 r4 = _mm256_unpackhi_ps(src3, src4);
__m256 r5 = _mm256_unpacklo_ps(src5, src6);
__m256 r6 = _mm256_unpackhi_ps(src5, src6);
__m256 r7 = _mm256_unpacklo_ps(src7, src8);
__m256 r8 = _mm256_unpackhi_ps(src7, src8);
const float *src1 = src_ptr + 0 * src_stride;
const float *src2 = src_ptr + 1 * src_stride;
const float *src3 = src_ptr + 2 * src_stride;
const float *src4 = src_ptr + 3 * src_stride;
const float *src5 = src_ptr + 4 * src_stride;
const float *src6 = src_ptr + 5 * src_stride;
const float *src7 = src_ptr + 6 * src_stride;
const float *src8 = src_ptr + 7 * src_stride;
__m256 r1, r2, r3, r4, r5, r6, r7, r8;
__m256 t1, t2, t3, t4, t5, t6, t7, t8;
// _mm256_castps128_ps256 is a compile-time cast that generates no instructions, so it has zero latency.
r1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src1 + 0)), _mm_loadu_ps(src5 + 0), 1);
r2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src2 + 0)), _mm_loadu_ps(src6 + 0), 1);
r3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src3 + 0)), _mm_loadu_ps(src7 + 0), 1);
r4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src4 + 0)), _mm_loadu_ps(src8 + 0), 1);
r5 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src1 + 4)), _mm_loadu_ps(src5 + 4), 1);
r6 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src2 + 4)), _mm_loadu_ps(src6 + 4), 1);
r7 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src3 + 4)), _mm_loadu_ps(src7 + 4), 1);
r8 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src4 + 4)), _mm_loadu_ps(src8 + 4), 1);
t1 = _mm256_unpacklo_ps(r1, r2);
t2 = _mm256_unpackhi_ps(r1, r2);
t3 = _mm256_unpacklo_ps(r3, r4);
t4 = _mm256_unpackhi_ps(r3, r4);
t5 = _mm256_unpacklo_ps(r5, r6);
t6 = _mm256_unpackhi_ps(r5, r6);
t7 = _mm256_unpacklo_ps(r7, r8);
t8 = _mm256_unpackhi_ps(r7, r8);
__m256 v;
v = _mm256_shuffle_ps(r1, r3, 0x4E);
src1 = _mm256_blend_ps(r1, v, 0xCC);
src2 = _mm256_blend_ps(r3, v, 0x33);
v = _mm256_shuffle_ps(t1, t3, 0x4E);
r1 = _mm256_blend_ps(t1, v, 0xCC);
r2 = _mm256_blend_ps(t3, v, 0x33);
v = _mm256_shuffle_ps(r2, r4, 0x4E);
src3 = _mm256_blend_ps(r2, v, 0xCC);
src4 = _mm256_blend_ps(r4, v, 0x33);
v = _mm256_shuffle_ps(t2, t4, 0x4E);
r3 = _mm256_blend_ps(t2, v, 0xCC);
r4 = _mm256_blend_ps(t4, v, 0x33);
v = _mm256_shuffle_ps(r5, r7, 0x4E);
src5 = _mm256_blend_ps(r5, v, 0xCC);
src6 = _mm256_blend_ps(r7, v, 0x33);
v = _mm256_shuffle_ps(t5, t7, 0x4E);
r5 = _mm256_blend_ps(t5, v, 0xCC);
r6 = _mm256_blend_ps(t7, v, 0x33);
v = _mm256_shuffle_ps(r6, r8, 0x4E);
src7 = _mm256_blend_ps(r6, v, 0xCC);
src8 = _mm256_blend_ps(r8, v, 0x33);
r1 = _mm256_permute2f128_ps(src1, src5, 0x20);
r2 = _mm256_permute2f128_ps(src2, src6, 0x20);
r3 = _mm256_permute2f128_ps(src3, src7, 0x20);
r4 = _mm256_permute2f128_ps(src4, src8, 0x20);
r5 = _mm256_permute2f128_ps(src1, src5, 0x31);
r6 = _mm256_permute2f128_ps(src2, src6, 0x31);
r7 = _mm256_permute2f128_ps(src3, src7, 0x31);
r8 = _mm256_permute2f128_ps(src4, src8, 0x31);
v = _mm256_shuffle_ps(t6, t8, 0x4E);
r7 = _mm256_blend_ps(t6, v, 0xCC);
r8 = _mm256_blend_ps(t8, v, 0x33);
STORE256X8_F32(dst_ptr, dst_stride, r);
}
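Either SIMD variant can be validated against a scalar 8x8 transpose (a hypothetical test helper, not part of the commit):

// dst[c][r] = src[r][c] over an 8x8 tile, with row strides in elements.
void Transpose8X8Fp32Scalar(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) {
  for (int r = 0; r < 8; r++) {
    for (int c = 0; c < 8; c++) {
      dst_ptr[c * dst_stride + r] = src_ptr[r * src_stride + c];
    }
  }
}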

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,7 +25,16 @@
#ifdef __cplusplus
extern "C" {
#endif
static inline void transpose_tail(const float *from, float *to, int j_start, int j_end, int i_start, int i_end,
int j_stride, int i_stride) {
// write consecutively
for (int j = j_start; j < j_end; j++) {
for (int i = i_start; i < i_end; i++) {
to[j * j_stride + i] = from[i * i_stride + j];
}
}
}
void TransposeFp32(const void *src, void *dst, int batches, int channel, int plane, int start, int end);
void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel);
void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -633,3 +633,154 @@ int MaxPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const
}
return NNACL_OK;
}
void MaxPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param,
const Pooling3DComputeParam *pooling_args, int start, int end) {
// Access structure members in declaration order
int in_size_w = pooling_args->pooling_compute_param_.input_w_;
int in_size_h = pooling_args->pooling_compute_param_.input_h_;
int batch = pooling_args->pooling_compute_param_.input_batch_;
int channel = pooling_args->pooling_compute_param_.input_channel_;
int out_size_w = pooling_args->pooling_compute_param_.output_w_;
int out_size_h = pooling_args->pooling_compute_param_.output_h_;
int in_size_d = pooling_args->input_d_;
int out_size_d = pooling_args->output_d_;
int kernel_w = pooling_param->pooling_parameter_.window_w_;
int kernel_h = pooling_param->pooling_parameter_.window_h_;
int stride_w = pooling_param->pooling_parameter_.stride_w_;
int stride_h = pooling_param->pooling_parameter_.stride_h_;
int pad_l_h = pooling_param->pooling_parameter_.pad_u_;
int pad_l_w = pooling_param->pooling_parameter_.pad_l_;
int kernel_d = pooling_param->window_d_;
int stride_d = pooling_param->stride_d_;
int pad_l_d = pooling_param->pad_f_;
int n_stride = in_size_d * in_size_h * in_size_w * channel;
int d_stride = in_size_h * in_size_w * channel;
int h_stride = in_size_w * channel;
int n = 0, d = 0, h = 0, w = 0;
const int parallel_dims = 4;  // parallelize over the four N/D/H/W dims
offset_to_index_init(start, parallel_dims * VA_ARG_TUPLE_LEN, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n,
batch);
for (int i = start; i < end; i++) {
int d_start = d * stride_d - pad_l_d;
int d_end = MSMIN(d_start + kernel_d, in_size_d);
d_start = MSMAX(d_start, 0);
int h_start = h * stride_h - pad_l_h;
int h_end = MSMIN(h_start + kernel_h, in_size_h);
h_start = MSMAX(h_start, 0);
int w_start = w * stride_w - pad_l_w;
int w_end = MSMIN(w_start + kernel_w, in_size_w);
w_start = MSMAX(w_start, 0);
const float *src_batch_ptr = input_ptr + n * n_stride;
float *out = output_ptr + i * channel;
int c_idx = 0;
SIMD_RUN_NO_SCALAR(MaxPooling3D, c_idx, src_batch_ptr, channel, out, d_start, d_end, h_start, h_end, w_start, w_end,
d_stride, h_stride);
for (; c_idx < channel; ++c_idx) {
const float *src_c_ptr = src_batch_ptr + c_idx;
float *dst_c_ptr = out + c_idx;
float tmp_max = -FLT_MAX;
for (int dd = d_start; dd < d_end; ++dd) {
for (int hh = h_start; hh < h_end; ++hh) {
for (int ww = w_start; ww < w_end; ++ww) {
const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel;
tmp_max = MSMAX(input[0], tmp_max);
}
}
}
dst_c_ptr[0] = tmp_max;
}
offset_to_index_step(parallel_dims * VA_ARG_TUPLE_LEN, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, batch);
}
}
void AvgPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param,
const Pooling3DComputeParam *pooling_args, int start, int end) {
// Access structure members in declaration order
int in_size_w = pooling_args->pooling_compute_param_.input_w_;
int in_size_h = pooling_args->pooling_compute_param_.input_h_;
int batch = pooling_args->pooling_compute_param_.input_batch_;
int channel = pooling_args->pooling_compute_param_.input_channel_;
int out_size_w = pooling_args->pooling_compute_param_.output_w_;
int out_size_h = pooling_args->pooling_compute_param_.output_h_;
int in_size_d = pooling_args->input_d_;
int out_size_d = pooling_args->output_d_;
int kernel_w = pooling_param->pooling_parameter_.window_w_;
int kernel_h = pooling_param->pooling_parameter_.window_h_;
int stride_w = pooling_param->pooling_parameter_.stride_w_;
int stride_h = pooling_param->pooling_parameter_.stride_h_;
int pad_l_h = pooling_param->pooling_parameter_.pad_u_;
int pad_r_h = pooling_param->pooling_parameter_.pad_d_;
int pad_l_w = pooling_param->pooling_parameter_.pad_l_;
int pad_r_w = pooling_param->pooling_parameter_.pad_r_;
int kernel_d = pooling_param->window_d_;
int stride_d = pooling_param->stride_d_;
int pad_l_d = pooling_param->pad_f_;
int pad_r_d = pooling_param->pad_b_;
bool count_include_pad = pooling_param->count_include_pad_;
int divisor = pooling_param->divisor_override_;
int n_stride = in_size_d * in_size_h * in_size_w * channel;
int d_stride = in_size_h * in_size_w * channel;
int h_stride = in_size_w * channel;
const int d_max = in_size_d + pad_r_d;
const int h_max = in_size_h + pad_r_h;
const int w_max = in_size_w + pad_r_w;
int n = 0, d = 0, h = 0, w = 0;
const int parallel_dims = 4;  // parallelize over the four N/D/H/W dims
offset_to_index_init(start, parallel_dims * VA_ARG_TUPLE_LEN, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n,
batch);
for (int i = start; i < end; i++) {
int d_start = d * stride_d - pad_l_d;
int d_end = MSMIN(d_start + kernel_d, d_max);
int d_start2 = MSMAX(d_start, 0);
int d_end2 = MSMIN(d_end, in_size_d);
int h_start = h * stride_h - pad_l_h;
int h_end = MSMIN(h_start + kernel_h, h_max);
int h_start2 = MSMAX(h_start, 0);
int h_end2 = MSMIN(h_end, in_size_h);
int w_start = w * stride_w - pad_l_w;
int w_end = MSMIN(w_start + kernel_w, w_max);
int w_start2 = MSMAX(w_start, 0);
int w_end2 = MSMIN(w_end, in_size_w);
const float *src_batch_ptr = input_ptr + n * n_stride;
float *out = output_ptr + i * channel;
if (pooling_param->divisor_override_ == 0) {
if (count_include_pad) {
divisor = (d_end - d_start) * (h_end - h_start) * (w_end - w_start);
} else {
divisor = (d_end2 - d_start2) * (h_end2 - h_start2) * (w_end2 - w_start2);
}
}
int c_idx = 0;
SIMD_RUN_NO_SCALAR(AvgPooling3D, c_idx, src_batch_ptr, channel, out, d_start2, d_end2, h_start2, h_end2, w_start2,
w_end2, d_stride, h_stride, divisor);
for (; c_idx < channel; ++c_idx) {
const float *src_c_ptr = src_batch_ptr + c_idx;
float *dst_c_ptr = out + c_idx;
float tmp_avg = 0;
for (int dd = d_start2; dd < d_end2; ++dd) {
for (int hh = h_start2; hh < h_end2; ++hh) {
for (int ww = w_start2; ww < w_end2; ++ww) {
const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel;
tmp_avg = tmp_avg + input[0];
}
}
}
dst_c_ptr[0] = tmp_avg / (float)divisor;
}
offset_to_index_step(parallel_dims * VA_ARG_TUPLE_LEN, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, batch);
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -37,6 +37,10 @@ int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const
const PoolingComputeParam *pooling_args, int task_id, int thread_num);
int MaxPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
const PoolingComputeParam *pooling_args, int task_id, int thread_num);
void MaxPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param,
const Pooling3DComputeParam *pooling_args, int start, int end);
void AvgPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param,
const Pooling3DComputeParam *pooling_args, int start, int end);
#ifdef __cplusplus
}
#endif

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -70,6 +70,45 @@ static inline int MaxPoolingBatch@SIMD_INSTRUCTION@(int ci, const float *src_pla
return ci;
}
static inline int MaxPooling3D@SIMD_INSTRUCTION@(int c_idx, const float *src_batch_ptr, int channel, float *out,
int d_start, int d_end, int h_start, int h_end, int w_start, int w_end, int d_stride, int h_stride) {
for (int block_max_size = channel - BLOCK_NUM + 1; c_idx < block_max_size; c_idx += BLOCK_NUM) {
const float *src_c_ptr = src_batch_ptr + c_idx;
float *dst_c_ptr = out + c_idx;
SIMD_F32 tmp_max = SIMD_MOV_F32(-FLT_MAX);
for (int dd = d_start; dd < d_end; ++dd) {
for (int hh = h_start; hh < h_end; ++hh) {
for (int ww = w_start; ww < w_end; ++ww) {
const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel;
tmp_max = SIMD_MAX_F32(SIMD_LD_F32(input), tmp_max);
}
}
}
SIMD_ST_F32(dst_c_ptr, tmp_max);
}
return c_idx;
}
static inline int AvgPooling3D@SIMD_INSTRUCTION@(int c_idx, const float *src_batch_ptr, int channel, float *out,
int d_start, int d_end, int h_start, int h_end, int w_start, int w_end, int d_stride, int h_stride, int divisor) {
for (int block_max_size = channel - BLOCK_NUM + 1; c_idx < block_max_size; c_idx += BLOCK_NUM) {
const float *src_c_ptr = src_batch_ptr + c_idx;
float *dst_c_ptr = out + c_idx;
SIMD_F32 tmp_avg = SIMD_SET0_F32;
for (int dd = d_start; dd < d_end; ++dd) {
for (int hh = h_start; hh < h_end; ++hh) {
for (int ww = w_start; ww < w_end; ++ww) {
const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel;
tmp_avg = SIMD_ADD_F32(SIMD_LD_F32(input), tmp_avg);
}
}
}
tmp_avg = SIMD_DIV_F32(tmp_avg, SIMD_MOV_F32(divisor));
SIMD_ST_F32(dst_c_ptr, tmp_avg);
}
return c_idx;
}
@SIMD_INSTRUCTION_END@
#ifdef __cplusplus
}

View File

@ -36,6 +36,13 @@ typedef struct PoolingComputeParam {
float maxf;
} PoolingComputeParam;
typedef struct Pooling3DComputeParam {
PoolingComputeParam pooling_compute_param_;
int input_d_;
int output_d_;
int window_d_;
} Pooling3DComputeParam;
typedef struct PoolingStruct {
KernelBase base_;
PoolingComputeParam compute_;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +17,7 @@
#ifndef NNACL_OP_BASE_H_
#define NNACL_OP_BASE_H_
#include <stdarg.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdbool.h>
@ -761,4 +762,33 @@ typedef enum CalFixedMultiplierMode {
Method_DoublePrecision
} CalFixedMultiplierMode;
#define VA_ARG_TUPLE_LEN 2
// Decompose a flat offset into per-dim indices; variadic pairs are (index ptr, extent), fastest-varying first.
static inline void offset_to_index_init(int offset, int cnt, ...) {
va_list valist;
va_start(valist, cnt);
int start = offset;
for (int i = 0; i < cnt; i += VA_ARG_TUPLE_LEN) {
int *x = va_arg(valist, int *);
int X = va_arg(valist, int);
*x = start % X;
start = start / X;
}
va_end(valist);
}
// Advance the index tuple by one step; a pair increments only while the carry is set.
static inline void offset_to_index_step(int cnt, ...) {
va_list valist;
int carry = 1;
va_start(valist, cnt);
for (int i = 0; i < cnt; i += VA_ARG_TUPLE_LEN) {
int *x = va_arg(valist, int *);
int X = va_arg(valist, int);
if (carry) {
*x += 1;
if (*x != X) {
carry = 0;
} else {
*x = 0;  // wrapped around; carry into the next (slower-varying) index
}
}
}
va_end(valist);
}
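A usage sketch for these helpers (illustrative, not part of the commit): pairs are listed fastest-varying first, and cnt counts the variadic arguments, i.e. pairs * VA_ARG_TUPLE_LEN.

static inline void offset_to_index_example(void) {
  // Decompose flat offset 7 over a [b = 2][h = 3][w = 4] space.
  int w = 0, h = 0, b = 0;
  offset_to_index_init(7, 3 * VA_ARG_TUPLE_LEN, &w, 4, &h, 3, &b, 2);
  // w = 7 % 4 = 3, h = (7 / 4) % 3 = 1, b = 7 / 12 = 0
  offset_to_index_step(3 * VA_ARG_TUPLE_LEN, &w, 4, &h, 3, &b, 2);
  // w wraps to 0 and the carry increments h: now w = 0, h = 2, b = 0 (offset 8)
}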
#endif // NNACL_OP_BASE_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -48,6 +48,8 @@ typedef struct Pooling3DParameter {
int output_d_;
int pad_f_; // front
int pad_b_; // back
bool count_include_pad_;
int divisor_override_;
} Pooling3DParameter;
#endif // NNACL_POOLING_PARAMETER_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,11 +18,35 @@
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_UTILS_CPU_UTILS_H_
#include <cmath>
#include <utility>
#include "mindspore/core/base/float16.h"
namespace mindspore {
namespace kernel {
// Decompose a flat offset into per-dim indices; pairs are listed slowest-varying first.
template <typename T>
inline T offset_to_index_init(T offset) {
return offset;
}
template <typename T, typename... Args>
inline T offset_to_index_init(T offset, T *x, const T &X, Args &&... args) {
offset = offset_to_index_init(offset, std::forward<Args>(args)...);
*x = offset % X;
return offset / X;
}
// Advance the index tuple by one; a pair increments only after all faster-varying pairs wrap.
inline bool offset_to_index_step() { return true; }
template <typename T, typename... Args>
inline bool offset_to_index_step(T *x, const T &X, Args &&... args) {
if (offset_to_index_step(std::forward<Args>(args)...)) {
*x = ((*x + 1) == X) ? 0 : (*x + 1);
return *x == 0;
}
return false;
}
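Usage sketch for the C++ overloads (illustrative, not part of the commit). Note the pair order is reversed relative to the C varargs version: pairs are listed slowest-varying first, matching the KernelAvgPool/KernelMaxPool call sites.

// Decompose flat offset 7 over a [b = 2][h = 3][w = 4] space.
int64_t b = 0, h = 0, w = 0;
offset_to_index_init(int64_t{7}, &b, int64_t{2}, &h, int64_t{3}, &w, int64_t{4});
// b = 0, h = 1, w = 3
offset_to_index_step(&b, int64_t{2}, &h, int64_t{3}, &w, int64_t{4});
// w wraps to 0 and the carry increments h: b = 0, h = 2, w = 0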
// compatible with MSVC
template <typename T>
inline bool IsNan(T x) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2022-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -46,6 +46,8 @@
namespace mindspore {
namespace ops {
constexpr size_t kMaxPool3DPadDims = 6;
MIND_API_OPERATOR_IMPL(MaxPool3D, BaseOperator);
void MaxPool3D::Init(const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride,
@ -212,6 +214,36 @@ std::vector<int64_t> GetOutputShape(const PrimitivePtr &primitive, const std::ve
return output_shape;
}
void GetPadsByPadding(const PrimitivePtr &primitive, int64_t in_d, int64_t in_h, int64_t in_w, int64_t kernel_d,
int64_t kernel_h, int64_t kernel_w, int64_t stride_d, int64_t stride_h, int64_t stride_w,
const int64_t &pad_mode, const std::vector<int64_t> &padding, std::vector<int64_t> *pad_list) {
MS_EXCEPTION_IF_NULL(pad_list);
if (pad_mode == PadMode::VALID) {
(void)pad_list->insert(pad_list->begin(), kMaxPool3DPadDims, 0);
} else if (pad_mode == PadMode::SAME) {
if (stride_d == 0 || stride_h == 0 || stride_w == 0) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
<< "', stride_d or stride_h or stride_w must be non-zero, but got stride_d: " << stride_d
<< ", stride_h: " << stride_h << ", stride_w: " << stride_w << ".";
}
int64_t tail_d = in_d % stride_d;
int64_t tail_h = in_h % stride_h;
int64_t tail_w = in_w % stride_w;
int64_t pad_d = std::max((tail_d > 0 ? kernel_d - tail_d : kernel_d - stride_d), (int64_t)0);
int64_t pad_h = std::max((tail_h > 0 ? kernel_h - tail_h : kernel_h - stride_h), (int64_t)0);
int64_t pad_w = std::max((tail_w > 0 ? kernel_w - tail_w : kernel_w - stride_w), (int64_t)0);
constexpr int twice = 2;
pad_list->push_back(static_cast<int64_t>(std::floor(pad_d / twice)));
pad_list->push_back(pad_d - pad_list->at(0));
pad_list->push_back(static_cast<int64_t>(std::floor(pad_h / twice)));
pad_list->push_back(pad_h - pad_list->at(kInputIndex2));
pad_list->push_back(static_cast<int64_t>(std::floor(pad_w / twice)));
pad_list->push_back(pad_w - pad_list->at(kInputIndex4));
} else if (pad_mode == PadMode::PAD) {
pad_list->assign(padding.begin(), padding.end());
}
}
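A quick numeric check of the SAME branch (illustrative, not the commit's code), for a single axis:

#include <algorithm>
#include <cstdint>

// Hypothetical one-axis version of the SAME computation above.
inline int64_t SamePadOneAxis(int64_t in, int64_t kernel, int64_t stride) {
  int64_t tail = in % stride;
  return std::max(tail > 0 ? kernel - tail : kernel - stride, int64_t{0});
}

// SamePadOneAxis(7, 3, 2) == 2 (tail = 1), split as front = 1, back = 1.
// The padded extent is then 7 + 2 = 9 = (4 - 1) * 2 + 3, so the output size is
// ceil(7 / 2) = 4, exactly what SAME requires.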
abstract::ShapePtr MaxPool3DInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
constexpr int64_t k5DInputDims = 5;
MS_EXCEPTION_IF_NULL(primitive);
@ -237,8 +269,12 @@ abstract::ShapePtr MaxPool3DInferShape(const PrimitivePtr &primitive, const std:
auto stride_h = strides[kInputIndex3];
auto stride_w = strides[kInputIndex4];
std::vector<int64_t> new_pad_list;
GetPadsByPadding(primitive, in_shape[kInputIndex2], in_shape[kInputIndex3], in_shape[kInputIndex4], kernel_d,
kernel_h, kernel_w, stride_d, stride_h, stride_w, pad_mode, pad_list, &new_pad_list);
primitive->set_attr(kPadList, MakeValue(new_pad_list));
std::vector<int64_t> out_shape = GetOutputShape(primitive, in_shape, kernel_d, kernel_h, kernel_w, stride_d, stride_h,
stride_w, pad_list, ceil_mode, pad_mode);
stride_w, new_pad_list, ceil_mode, pad_mode);
if (std::any_of(out_shape.begin(), out_shape.end(), [](int64_t shp_v) { return shp_v <= 0; })) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name()
<< "', output shape's all elements must be positive, but got shape: " << out_shape << ".";

View File

@ -1,4 +1,4 @@
# Copyright 2019-2022 Huawei Technologies Co., Ltd
# Copyright 2019-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -389,16 +389,17 @@ def avg_pool3d_forward_functional(nptype):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_avg_pool3d_forward_float32_functional():
@pytest.mark.parametrize("dtype", [np.float32, np.float16, np.float64])
def test_avg_pool3d_forward_functional_dtypes(dtype):
"""
Feature: test avg_pool3d forward.
Description: test float16, float32 and float64 inputs.
Expectation: the result match with expected result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
avg_pool3d_forward_functional(np.float32)
avg_pool3d_forward_functional(dtype)
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
avg_pool3d_forward_functional(np.float32)
avg_pool3d_forward_functional(dtype)
@pytest.mark.level0
@ -418,3 +419,22 @@ def test_avgpool_cpu_dynamic_shape():
output = net(Tensor(x, msdtype.float32))
expect_out_shape = (2, 32, 4, 4)
assert output.asnumpy().shape == expect_out_shape
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize("dtype", [msdtype.float32, msdtype.float16, msdtype.float64])
def test_avgpool3d_cpu_dynamic_shape(dtype):
"""
Feature: test dynamic shape of avgpool3d.
Description: test the dynamic shape output of avgpool3d.
Expectation: correct output shape.
"""
x_dyn = Tensor(shape=[None, 32, None, None, None], dtype=dtype)
net = AvgPool(dim=3, kernel_size=2, strides=2, pad_mode="VALID")
net.set_inputs(x_dyn)
x = np.random.randn(2, 32, 9, 9, 9)
output = net(Tensor(x, dtype))
expect_out_shape = (2, 32, 4, 4, 4)
assert output.asnumpy().shape == expect_out_shape

View File

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -121,14 +121,15 @@ def test_maxpool2d_same():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maxpool3d_1():
@pytest.mark.parametrize("dtype", [np.float32, np.float16, np.float64])
def test_maxpool3d_1(dtype):
"""
Feature: test maxpool3d op.
Description: including forward and backward.
Expectation: expect correct forward and backward result.
"""
x_shape = (1, 3, 2, 3, 4)
x = Tensor(np.arange(reduce(lambda x, y: x * y, x_shape))).reshape(x_shape).astype(np.float32)
x = Tensor(np.arange(reduce(lambda x, y: x * y, x_shape))).reshape(x_shape).astype(dtype)
maxpool = MaxPool(dim=3, kernel_size=(2, 2, 3), strides=1, pad_mode='VALID')
actual_output = maxpool(x)
expect_output = np.array([[[[[18, 19],
@ -166,14 +167,15 @@ def test_maxpool3d_1():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maxpool3d_2():
@pytest.mark.parametrize("dtype", [np.float32, np.float16, np.float64])
def test_maxpool3d_2(dtype):
"""
Feature: test maxpool3d op.
Description: including forward and backward.
Expectation: expect correct forward and backward result.
"""
x_shape = (1, 3, 2, 3, 4)
x = Tensor(np.arange(reduce(lambda x, y: x * y, x_shape))).reshape(x_shape).astype(np.float32)
x = Tensor(np.arange(reduce(lambda x, y: x * y, x_shape))).reshape(x_shape).astype(dtype)
maxpool = MaxPool(dim=3, kernel_size=2, strides=1, pad_mode='VALID')
actual_output = maxpool(x)
expect_output = np.array([[[[[17, 18, 19],
@ -211,14 +213,15 @@ def test_maxpool3d_2():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maxpool3d_3():
@pytest.mark.parametrize("dtype", [np.float32, np.float16, np.float64])
def test_maxpool3d_3(dtype):
"""
Feature: test maxpool3d op.
Description: including forward and backward.
Expectation: expect correct forward and backward result.
"""
x_shape = (1, 3, 2, 3, 4)
x = Tensor(np.arange(reduce(lambda x, y: x * y, x_shape))).reshape(x_shape).astype(np.float32)
x = Tensor(np.arange(reduce(lambda x, y: x * y, x_shape))).reshape(x_shape).astype(dtype)
maxpool = MaxPool(dim=3, kernel_size=2, strides=3, pad_mode='VALID')
actual_output = maxpool(x)
expect_output = np.array([[[[[17]]],
@ -253,14 +256,15 @@ def test_maxpool3d_3():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_maxpool3d_4():
@pytest.mark.parametrize("dtype", [np.float32, np.float16, np.float64])
def test_maxpool3d_4(dtype):
"""
Feature: test maxpool3d op.
Description: including forward and backward.
Expectation: expect correct forward and backward result.
"""
x_shape = (1, 3, 2, 3, 4)
x = Tensor(np.arange(reduce(lambda x, y: x * y, x_shape))).reshape(x_shape).astype(np.float32)
x = Tensor(np.arange(reduce(lambda x, y: x * y, x_shape))).reshape(x_shape).astype(dtype)
maxpool = MaxPool(dim=3, kernel_size=(2, 2, 3), strides=1, pad_mode='SAME')
actual_output = maxpool(x)
expect_output = np.array([[[[[17, 18, 19, 19],