fix adaptive_max_pool3d cpu

This commit is contained in:
fanjibin 2022-11-28 21:09:07 +08:00 committed by fan-jibin
parent d969d9e1d3
commit 1f747cdebc
1 changed files with 26 additions and 4 deletions

View File

@ -27,6 +27,9 @@ namespace {
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
constexpr size_t kInputNumDims5 = 5;
constexpr size_t kInputShapeDims4 = 4;
bool AdaptiveMaxPool3DCpuKernelMod::Init(const BaseOperatorPtr &base_operator, bool AdaptiveMaxPool3DCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) { const std::vector<KernelTensorPtr> &outputs) {
@ -56,6 +59,24 @@ int AdaptiveMaxPool3DCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
input_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively(); input_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
input_num_dims_ = input_shape_.size(); input_num_dims_ = input_shape_.size();
outputs_ = outputs; outputs_ = outputs;
if (!(input_num_dims_ == kInputNumDims5 || input_num_dims_ == kInputShapeDims4)) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input data dimensions should be equal to 4 or 5, but got "
<< input_num_dims_ << ".";
return KRET_RESIZE_FAILED;
}
auto output_size_shape = inputs[kIndex1]->GetShapeVector();
if (output_size_shape.size() != 1) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', output size dimensions should be equal to 1, but got "
<< output_size_shape.size() << ".";
return KRET_RESIZE_FAILED;
}
const size_t kOutputSizeElemNum = 3;
if (output_size_shape[0] != kOutputSizeElemNum) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', output size elem number should be equal to 3, but got "
<< output_size_shape[0] << ".";
return KRET_RESIZE_FAILED;
}
return KRET_OK; return KRET_OK;
} }
@ -127,7 +148,6 @@ bool AdaptiveMaxPool3DCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
// Set Shape // Set Shape
const size_t kInputNumDims5 = 5;
output_shape_ = {input_shape_[0]}; output_shape_ = {input_shape_[0]};
if (input_num_dims_ == kInputNumDims5) { if (input_num_dims_ == kInputNumDims5) {
output_shape_.push_back(input_shape_[1]); output_shape_.push_back(input_shape_[1]);
@ -242,10 +262,8 @@ void AdaptiveMaxPool3DCpuKernelMod::AdaptiveMaxPool3DCompute(const std::vector<A
auto input_data = reinterpret_cast<T *>(inputs[0]->addr); auto input_data = reinterpret_cast<T *>(inputs[0]->addr);
auto output_data = reinterpret_cast<T *>(outputs[0]->addr); auto output_data = reinterpret_cast<T *>(outputs[0]->addr);
auto indices_data = reinterpret_cast<int32_t *>(outputs[1]->addr); auto indices_data = reinterpret_cast<int32_t *>(outputs[1]->addr);
const size_t kInputShapeDims4 = 4;
if (input_shape_.size() == kInputShapeDims4) { if (input_shape_.size() == kInputShapeDims4) {
input_shape_.insert(input_shape_.begin(), 1); input_shape_.insert(input_shape_.begin(), 1);
output_shape_.insert(output_shape_.begin(), 1);
} }
size_B_ = input_shape_[dimB]; size_B_ = input_shape_[dimB];
size_D_ = input_shape_[dimD]; size_D_ = input_shape_[dimD];
@ -266,7 +284,11 @@ void AdaptiveMaxPool3DCpuKernelMod::AdaptiveMaxPool3DCompute(const std::vector<A
auto shard_adaptive_max_pool_3d = [&](int64_t start, int64_t end) { auto shard_adaptive_max_pool_3d = [&](int64_t start, int64_t end) {
ComputeKernel(input_data, output_data, indices_data, start, end); ComputeKernel(input_data, output_data, indices_data, start, end);
}; };
CPUKernelUtils::ParallelFor(shard_adaptive_max_pool_3d, output_size_T_);
// The AdaptiveMaxPool3D will be reinit in graph mode, so the ParallelLaunchAutoSearch dose not work, use
// ParallelLaunch instead.
const float block_size = 1.0;
ParallelLaunch(shard_adaptive_max_pool_3d, output_size_T_, block_size);
} }
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdaptiveMaxPool3D, AdaptiveMaxPool3DCpuKernelMod); MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdaptiveMaxPool3D, AdaptiveMaxPool3DCpuKernelMod);