江南明洲 2022-12-16 15:48:45 +08:00
parent fe7acc0aa3
commit 25bb7400bf
7 changed files with 95 additions and 36 deletions

View File

@@ -34,6 +34,7 @@ void PoolingCpuKernelMod::InitPoolingFields(const BaseOperatorPtr &base_operator
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
dtype_ = inputs[0]->GetDtype();
if (base_operator->HasAttr(CEIL_MODE)) {
ValuePtr ceil_mode = base_operator->GetPrim()->GetAttr(CEIL_MODE);
ceil_mode_ = (ceil_mode->isa<BoolImm>() && GetValue<bool>(ceil_mode)) ||
@@ -117,7 +118,8 @@ int PoolingCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
return KRET_OK;
}
void PoolingCpuKernelMod::EliminateInvalidPadding(float *dst) {
template <typename T>
void PoolingCpuKernelMod::EliminateInvalidPadding(T *dst) {
if (dst_shape_.size() < SHAPE_5D || kernel_.size() + NC_LEN < SHAPE_5D ||
padding_invalid_.size() + NC_LEN < SHAPE_5D) {
MS_LOG(ERROR) << "The dst_shape must be 5D, the kernel and the padding_invalid must be 3D!";
@@ -156,7 +158,8 @@ void PoolingCpuKernelMod::EliminateInvalidPadding(float *dst) {
const size_t index =
static_cast<size_t>(i * dst_shape_[D_INDEX] * dst_shape_[H_INDEX] * dst_shape_[W_INDEX] +
d * dst_shape_[H_INDEX] * dst_shape_[W_INDEX] + h * dst_shape_[W_INDEX] + w);
dst[index] = dst[index] * LongToFloat(kernel_size) / LongToFloat(valid_kernel_size);
dst[index] =
dst[index] * static_cast<T>(LongToFloat(kernel_size)) / static_cast<T>(LongToFloat(valid_kernel_size));
}
}
}
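Note: the rescale above exists because, as the surrounding code suggests, dnnl's pooling_avg_include_padding divides every window sum by the full kernel volume, even where ceil_mode introduced padding cells that should not count. A self-contained numeric check of the dst * kernel_size / valid_kernel_size correction, with illustrative values that are not taken from the commit:

#include <cassert>

// Assumes the dnnl primitive already wrote window_sum / kernel_size, while
// only valid_kernel_size cells of the window lie inside real data.
int main() {
  const double window_sum = 6.0;
  const long kernel_size = 8;        // full 2x2x2 window volume
  const long valid_kernel_size = 6;  // cells outside the invalid padding
  double dst = window_sum / kernel_size;        // primitive output: 0.75
  dst = dst * kernel_size / valid_kernel_size;  // the correction: 1.0
  assert(dst == window_sum / valid_kernel_size);
  return 0;
}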
@@ -167,12 +170,13 @@ void PoolingCpuKernelMod::EliminateInvalidPadding(float *dst) {
&parallel_search_info_);
}
void PoolingCpuKernelMod::ReComputeDivisor(float *dst) {
template <typename T>
void PoolingCpuKernelMod::ReComputeDivisor(T *dst) {
const int64_t kernel_size = std::accumulate(kernel_.begin(), kernel_.end(), int64_t(1), std::multiplies<int64_t>());
const size_t size = std::accumulate(dst_shape_.begin(), dst_shape_.end(), size_t(1), std::multiplies<size_t>());
CTask task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
dst[i] = dst[i] * LongToFloat(kernel_size) / LongToFloat(divisor_override_);
dst[i] = dst[i] * static_cast<T>(LongToFloat(kernel_size)) / static_cast<T>(LongToFloat(divisor_override_));
}
};
ParallelLaunchAutoSearch(task, size, this, &parallel_search_info_);
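Note: CTask plus ParallelLaunchAutoSearch is effectively MindSpore's chunked parallel-for; the kernel supplies only a [start, end) body. A minimal stand-alone imitation of that pattern, with plain std::thread standing in for the auto-searching launcher (which is not reproduced here):

#include <algorithm>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

using RangeTask = std::function<void(size_t, size_t)>;  // stand-in for CTask

// Splits [0, size) into equal chunks and runs the task on each chunk.
void ParallelLaunchSketch(const RangeTask &task, size_t size, size_t workers) {
  std::vector<std::thread> pool;
  const size_t chunk = (size + workers - 1) / workers;
  for (size_t begin = 0; begin < size; begin += chunk) {
    pool.emplace_back(task, begin, std::min(begin + chunk, size));
  }
  for (auto &t : pool) t.join();
}

int main() {
  std::vector<double> dst(1000, 2.0);
  const long kernel_size = 4, divisor_override = 8;
  ParallelLaunchSketch(
      [&](size_t start, size_t end) {
        for (size_t i = start; i < end; ++i) {
          // Same elementwise body as the CTask above.
          dst[i] = dst[i] * kernel_size / divisor_override;
        }
      },
      dst.size(), 4);
  return 0;
}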
@@ -183,6 +187,8 @@ std::vector<KernelAttr> PoolingCpuKernelMod::GetOpSupport() {
{kMaxPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kMaxPool3DOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kAvgPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kAvgPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16)}},
{kAvgPoolOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}},
{kAvgPool3DOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kAvgPool3DOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}},
{kAvgPool3DOpName, {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16)}}};
@@ -193,6 +199,27 @@ std::vector<KernelAttr> PoolingCpuKernelMod::GetOpSupport() {
return iter->second;
}
template <typename T>
bool PoolingCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
T *dst = reinterpret_cast<T *>(outputs[0]->addr);
if (divisor_override_ != 0) {
ReComputeDivisor(dst);
return true;
}
bool has_invalid_padding =
std::any_of(padding_invalid_.begin(), padding_invalid_.end(), [](const int64_t &padding) { return padding != 0; });
if (algorithm_ == dnnl::algorithm::pooling_avg_include_padding && has_invalid_padding) {
EliminateInvalidPadding(dst);
}
return false;
}
bool PoolingCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kPoolingInputsNum, kernel_name_);
@@ -208,21 +235,15 @@ bool PoolingCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
MS_LOG(ERROR) << "Resize PoolingCpuKernelMod while launching failed: " << resize_ret;
return false;
}
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
float *dst = reinterpret_cast<float *>(outputs[0]->addr);
if (divisor_override_ != 0) {
ReComputeDivisor(dst);
return true;
}
bool has_invalid_padding =
std::any_of(padding_invalid_.begin(), padding_invalid_.end(), [](const int64_t &padding) { return padding != 0; });
if (algorithm_ == dnnl::algorithm::pooling_avg_include_padding && has_invalid_padding) {
EliminateInvalidPadding(dst);
if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
LaunchKernel<double>(inputs, outputs);
} else {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dtype of input should be float16, float32 or float64, but got "
<< TypeIdToType(dtype_)->ToString();
}
return true;
}
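Note: the net refactor in this file is the usual split of an untyped Launch that inspects dtype_ into a typed LaunchKernel<T>. A compilable miniature of that shape; TypeId, the kNumberType values, and the +1 body are simplified stand-ins, and float16 is omitted:

#include <cstddef>
#include <iostream>
#include <vector>

enum TypeId { kTypeUnknown, kNumberTypeFloat32, kNumberTypeFloat64 };

class PoolingSketch {
 public:
  bool Launch(void *out, size_t n) {
    if (dtype_ == kNumberTypeFloat32) return LaunchKernel<float>(out, n);
    if (dtype_ == kNumberTypeFloat64) return LaunchKernel<double>(out, n);
    std::cerr << "unsupported dtype\n";  // mirrors the MS_LOG(ERROR) branch
    return false;
  }
  TypeId dtype_{kNumberTypeFloat32};

 private:
  template <typename T>
  bool LaunchKernel(void *out, size_t n) {
    T *dst = static_cast<T *>(out);  // typed view of the raw output buffer
    for (size_t i = 0; i < n; ++i) dst[i] += T(1);  // typed post-processing
    return true;
  }
};

int main() {
  std::vector<float> buf{1.0f, 2.0f};
  PoolingSketch kernel;
  kernel.Launch(buf.data(), buf.size());
  std::cout << buf[0] << " " << buf[1] << "\n";  // prints: 2 3
}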

View File

@@ -49,8 +49,10 @@ class PoolingCpuKernelMod : public MKLCpuKernelMod {
std::vector<KernelAttr> GetOpSupport() override;
protected:
void EliminateInvalidPadding(float *output);
void ReComputeDivisor(float *output);
template <typename T>
void EliminateInvalidPadding(T *output);
template <typename T>
void ReComputeDivisor(T *output);
dnnl::algorithm algorithm_{dnnl::algorithm::pooling_max};
bool ceil_mode_{false};
@@ -72,6 +74,11 @@ class PoolingCpuKernelMod : public MKLCpuKernelMod {
const std::vector<KernelTensorPtr> &outputs);
std::string kernel_type_{kUnkown};
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
TypeId dtype_{kTypeUnknown};
};
} // namespace kernel
} // namespace mindspore
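Note: on the header side only declarations are added; keeping the template definitions in the .cc stays link-safe here because every instantiation (float, float16, double) is triggered from Launch inside that same translation unit, so no explicit instantiation is needed. A sketch of that layout with a hypothetical PoolingLike class (both "files" compile as one unit below):

// pooling_like.h (hypothetical): template members are only declared.
class PoolingLike {
 public:
  bool Launch(void *out, int dtype);

 private:
  template <typename T>
  void ReComputeDivisor(T *out);  // defined in the .cc part below
};

// pooling_like.cc (hypothetical): definitions and every use share one
// translation unit, so implicit instantiation covers float and double.
template <typename T>
void PoolingLike::ReComputeDivisor(T *out) {
  *out = *out * T(2);
}

bool PoolingLike::Launch(void *out, int dtype) {
  if (dtype == 32) {
    ReComputeDivisor(static_cast<float *>(out));   // instantiates <float>
  } else {
    ReComputeDivisor(static_cast<double *>(out));  // instantiates <double>
  }
  return true;
}

int main() {
  double x = 3.0;
  PoolingLike{}.Launch(&x, 64);  // x becomes 6.0
  return 0;
}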

View File

@@ -69,6 +69,7 @@ bool PoolingGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
pad_mode_ = GetValue<std::string>(prim->GetAttr(PAD_MODE));
kernel_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(KERNEL_SIZE));
strides_include_nc_ = GetValue<std::vector<int64_t>>(prim->GetAttr(STRIDES));
dtype_ = inputs[grad_index_]->GetDtype();
return true;
}
@@ -143,18 +144,20 @@ int PoolingGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
return KRET_OK;
}
void PoolingGradCpuKernelMod::ReComputeDivisor(float *dst) {
template <typename T>
void PoolingGradCpuKernelMod::ReComputeDivisor(T *dst) {
const int64_t kernel_size = std::accumulate(kernel_.begin(), kernel_.end(), int64_t(1), std::multiplies<int64_t>());
const size_t size = std::accumulate(dst_shape_.begin(), dst_shape_.end(), size_t(1), std::multiplies<size_t>());
CTask task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
dst[i] = dst[i] * LongToFloat(kernel_size) / LongToFloat(divisor_override_);
dst[i] = dst[i] * static_cast<T>(LongToFloat(kernel_size)) / static_cast<T>(LongToFloat(divisor_override_));
}
};
ParallelLaunchAutoSearch(task, size, this, &parallel_search_info_);
}
void PoolingGradCpuKernelMod::EliminateInvalidPadding(float *dst) {
template <typename T>
void PoolingGradCpuKernelMod::EliminateInvalidPadding(T *dst) {
if (dst_shape_.size() < SHAPE_5D || kernel_.size() + NC_LEN < SHAPE_5D ||
padding_invalid_.size() + NC_LEN < SHAPE_5D) {
MS_LOG(ERROR) << "The dst_shape must be 5D, the kernel and the padding_invalid must be 3D!";
@@ -193,7 +196,8 @@ void PoolingGradCpuKernelMod::EliminateInvalidPadding(float *dst) {
const size_t index =
static_cast<size_t>(i * dst_shape_[D_INDEX] * dst_shape_[H_INDEX] * dst_shape_[W_INDEX] +
d * dst_shape_[H_INDEX] * dst_shape_[W_INDEX] + h * dst_shape_[W_INDEX] + w);
dst[index] = dst[index] * LongToFloat(kernel_size) / LongToFloat(valid_kernel_size);
dst[index] =
dst[index] * static_cast<T>(LongToFloat(kernel_size)) / static_cast<T>(LongToFloat(valid_kernel_size));
}
}
}
@@ -293,7 +297,22 @@ bool PoolingGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inpu
return true;
}
float *dst = reinterpret_cast<float *>(inputs[grad_index_]->addr);
if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
LaunchKernel<double>(inputs, outputs);
} else {
MS_LOG(ERROR) << "For '" << kernel_name_ << " error get " << TypeIdToType(dtype_)->ToString();
}
return true;
}
template <typename T>
bool PoolingGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
T *dst = reinterpret_cast<T *>(inputs[grad_index_]->addr);
if (divisor_override_ != 0) {
ReComputeDivisor(dst);
} else {
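Note (the hunk above is cut off here): in the grad kernel the rescale is applied to the incoming gradient, inputs[grad_index_], before the backward primitive runs. A plausible reading, assuming dnnl's average backward distributes grad_y / kernel_size to each window cell: pre-multiplying grad_y by kernel_size / divisor_override makes the distributed value equal the true derivative grad_y / divisor_override. Numeric check:

#include <cassert>

int main() {
  const double grad_y = 3.0;
  const long kernel_size = 4;       // window volume the primitive divides by
  const long divisor_override = 2;  // divisor actually used in the forward pass
  double scaled = grad_y * kernel_size / divisor_override;  // ReComputeDivisor: 6.0
  double per_cell = scaled / kernel_size;                   // what dnnl distributes: 1.5
  assert(per_cell == grad_y / divisor_override);            // d(sum / divisor)/dx
  return 0;
}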

View File

@@ -58,7 +58,10 @@ class PoolingGradCpuKernelMod : public MKLCpuKernelMod {
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)}}},
.AddOutputAttr(kNumberTypeFloat32)},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16)},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)}}},
{kAvgPool3DGrad,
{{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
@@ -89,8 +92,10 @@ class PoolingGradCpuKernelMod : public MKLCpuKernelMod {
}
private:
void EliminateInvalidPadding(float *output);
void ReComputeDivisor(float *output);
template <typename T>
void EliminateInvalidPadding(T *output);
template <typename T>
void ReComputeDivisor(T *output);
dnnl::algorithm algorithm_{dnnl::algorithm::pooling_max};
bool ceil_mode_{false};
@@ -119,6 +124,11 @@ class PoolingGradCpuKernelMod : public MKLCpuKernelMod {
std::shared_ptr<dnnl::pooling_forward> primitive_forward_{nullptr};
ParallelSearchInfo forward_parallel_info_{};
std::string kernel_type_{kUnknown};
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
TypeId dtype_{kTypeUnknown};
};
} // namespace kernel
} // namespace mindspore

View File

@@ -22,6 +22,8 @@ MS_REG_GPU_KERNEL_ONE(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).Add
PoolingFwdGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
PoolingFwdGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
PoolingFwdGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PoolingFwdGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
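Note: MS_REG_GPU_KERNEL_ONE registers one (op, KernelAttr) signature against a kernel class and an element type, so float64 support for AvgPool is a single new registration. A toy self-registration mechanism of the same flavor (the real macro's expansion is not reproduced; all names below are invented):

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

using Factory = std::function<void()>;

// Global (op name, dtype name) -> factory table, built before main() by the
// constructors of file-scope Registrar objects.
std::map<std::pair<std::string, std::string>, Factory> &Registry() {
  static std::map<std::pair<std::string, std::string>, Factory> table;
  return table;
}

struct Registrar {
  Registrar(std::string op, std::string dtype, Factory f) {
    Registry()[{std::move(op), std::move(dtype)}] = std::move(f);
  }
};

#define REG_KERNEL(OP, DTYPE, T)                   \
  static Registrar reg_##OP##_##DTYPE(#OP, #DTYPE, \
                                      [] { std::cout << #OP " runs as " #T "\n"; });

REG_KERNEL(AvgPool, Float64, double)  // analogous to the new registration above

int main() {
  Registry()[{"AvgPool", "Float64"}]();  // prints: AvgPool runs as double
  return 0;
}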

View File

@@ -137,12 +137,6 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
for (const auto &input : inputs) {
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
auto input_shape = inputs[0]->GetDeviceShapeAdaptively();
auto output_shape = outputs[0]->GetDeviceShapeAdaptively();
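Note: the deleted loop deferred Resize while any input shape was still dynamic; the diff alone does not show where that responsibility moved, presumably into KernelMod::Resize. For reference, a sketch of such a guard, assuming unknown dimensions are encoded as negative sentinels (an assumption about MindSpore's convention):

#include <algorithm>
#include <cstdint>
#include <vector>

constexpr int kRetOk = 0;            // stand-in for KRET_OK
constexpr int kRetUnknownShape = 1;  // stand-in for KRET_UNKNOWN_SHAPE

// A shape is "valid" once no dimension is a dynamic placeholder.
bool IsValidShapeSketch(const std::vector<int64_t> &shape) {
  return std::all_of(shape.begin(), shape.end(), [](int64_t d) { return d >= 0; });
}

int ResizeSketch(const std::vector<std::vector<int64_t>> &input_shapes) {
  for (const auto &shape : input_shapes) {
    if (!IsValidShapeSketch(shape)) {
      return kRetUnknownShape;  // defer until real shapes are known
    }
  }
  return kRetOk;  // safe to compute buffer sizes etc.
}

int main() {
  // {-1, 3, 224, 224}: batch dim not yet known -> resize must wait.
  return ResizeSketch({{-1, 3, 224, 224}});
}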

View File

@@ -532,7 +532,13 @@ std::map<std::string, std::vector<std::pair<KernelAttr, PoolingGradGpuKernelMod:
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&PoolingGradGpuKernelMod::LaunchKernel<half>}}},
&PoolingGradGpuKernelMod::LaunchKernel<half>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&PoolingGradGpuKernelMod::LaunchKernel<double>}}},
{kAvgPool3DGrad,
{{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&PoolingGradGpuKernelMod::LaunchKernel<double>},