diff --git a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc index d71c5136319..4f81d7be7e7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.cc @@ -62,15 +62,9 @@ int ResizeBaseCPUKernel::CheckParameters() { MS_LOG(INFO) << "Out shape is not assigned"; const_shape_ = false; } else { - new_height_ = out_tensors_.at(0)->shape()[1]; - if (new_height_ < 1) { - MS_LOG(ERROR) << "Resize new_height should >= 1, but got " << new_height_; - return RET_INVALID_OP_ATTR; - } - new_width_ = out_tensors_.at(0)->shape()[2]; - if (new_width_ < 1) { - MS_LOG(ERROR) << "Resize new_width should >= 1, but got " << new_width_; - return RET_INVALID_OP_ATTR; + auto ret = CalculateNewHeightWidth(); + if (ret != RET_OK) { + return ret; } const_shape_ = true; } @@ -84,6 +78,49 @@ int ResizeBaseCPUKernel::CheckParameters() { return RET_OK; } +int ResizeBaseCPUKernel::CalculateNewHeightWidth() { + if (in_tensors_.size() != 2) { + return RET_ERROR; + } + auto input_tensor = in_tensors_.at(0); + auto shape_scale_tensor = in_tensors_.at(1); + if (shape_scale_tensor->data_type() == kNumberTypeFloat32) { + // float type means scale + float *shape_scale = reinterpret_cast(shape_scale_tensor->data_c()); + if (shape_scale == nullptr) { + return RET_ERROR; + } + if (shape_scale_tensor->format() == schema::Format_NHWC) { + new_height_ = input_tensor->Height() * shape_scale[1]; + new_width_ = input_tensor->Width() * shape_scale[2]; + } else if (shape_scale_tensor->format() == schema::Format_NCHW) { + new_height_ = input_tensor->Height() * shape_scale[2]; + new_width_ = input_tensor->Width() * shape_scale[3]; + } else { + MS_LOG(ERROR) << "resize not support format " << shape_scale_tensor->format(); + return RET_ERROR; + } + } else if (shape_scale_tensor->data_type() == kNumberTypeInt32) { + // int32 type means real shape + int32_t *shape_data = reinterpret_cast(shape_scale_tensor->data_c()); + if (shape_data == nullptr) { + return RET_ERROR; + } + if (shape_scale_tensor->format() == schema::Format_NHWC) { + new_height_ = shape_data[1]; + new_width_ = shape_data[2]; + } else if (shape_scale_tensor->format() == schema::Format_NCHW) { + new_height_ = shape_data[2]; + new_width_ = shape_data[3]; + } else { + MS_LOG(ERROR) << "resize not support format " << shape_scale_tensor->format(); + return RET_ERROR; + } + } + + return RET_OK; +} + int ResizeBaseCPUKernel::CheckInputsOuputs() { if (in_tensors_.size() <= lite::kDoubleNum) { for (size_t i = 0; i < in_tensors_.size(); i++) { diff --git a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.h b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.h index 0357b18e958..e65906f818b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/resize_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/resize_base.h @@ -47,6 +47,7 @@ class ResizeBaseCPUKernel : public LiteKernel { private: int CheckParameters(); int CheckInputsOuputs(); + int CalculateNewHeightWidth(); }; } // namespace mindspore::kernel