!11087 fix resize get correct new height

From: @zhaozhenlong
Reviewed-by: @zhanghaibo5,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong,@HilbertDavid
This commit is contained in:
mindspore-ci-bot 2021-01-09 17:30:36 +08:00 committed by Gitee
commit e001199972
2 changed files with 47 additions and 9 deletions

View File

@ -62,15 +62,9 @@ int ResizeBaseCPUKernel::CheckParameters() {
MS_LOG(INFO) << "Out shape is not assigned";
const_shape_ = false;
} else {
new_height_ = out_tensors_.at(0)->shape()[1];
if (new_height_ < 1) {
MS_LOG(ERROR) << "Resize new_height should >= 1, but got " << new_height_;
return RET_INVALID_OP_ATTR;
}
new_width_ = out_tensors_.at(0)->shape()[2];
if (new_width_ < 1) {
MS_LOG(ERROR) << "Resize new_width should >= 1, but got " << new_width_;
return RET_INVALID_OP_ATTR;
auto ret = CalculateNewHeightWidth();
if (ret != RET_OK) {
return ret;
}
const_shape_ = true;
}
@ -84,6 +78,49 @@ int ResizeBaseCPUKernel::CheckParameters() {
return RET_OK;
}
int ResizeBaseCPUKernel::CalculateNewHeightWidth() {
if (in_tensors_.size() != 2) {
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(0);
auto shape_scale_tensor = in_tensors_.at(1);
if (shape_scale_tensor->data_type() == kNumberTypeFloat32) {
// float type means scale
float *shape_scale = reinterpret_cast<float *>(shape_scale_tensor->data_c());
if (shape_scale == nullptr) {
return RET_ERROR;
}
if (shape_scale_tensor->format() == schema::Format_NHWC) {
new_height_ = input_tensor->Height() * shape_scale[1];
new_width_ = input_tensor->Width() * shape_scale[2];
} else if (shape_scale_tensor->format() == schema::Format_NCHW) {
new_height_ = input_tensor->Height() * shape_scale[2];
new_width_ = input_tensor->Width() * shape_scale[3];
} else {
MS_LOG(ERROR) << "resize not support format " << shape_scale_tensor->format();
return RET_ERROR;
}
} else if (shape_scale_tensor->data_type() == kNumberTypeInt32) {
// int32 type means real shape
int32_t *shape_data = reinterpret_cast<int32_t *>(shape_scale_tensor->data_c());
if (shape_data == nullptr) {
return RET_ERROR;
}
if (shape_scale_tensor->format() == schema::Format_NHWC) {
new_height_ = shape_data[1];
new_width_ = shape_data[2];
} else if (shape_scale_tensor->format() == schema::Format_NCHW) {
new_height_ = shape_data[2];
new_width_ = shape_data[3];
} else {
MS_LOG(ERROR) << "resize not support format " << shape_scale_tensor->format();
return RET_ERROR;
}
}
return RET_OK;
}
int ResizeBaseCPUKernel::CheckInputsOuputs() {
if (in_tensors_.size() <= lite::kDoubleNum) {
for (size_t i = 0; i < in_tensors_.size(); i++) {

View File

@ -47,6 +47,7 @@ class ResizeBaseCPUKernel : public LiteKernel {
private:
int CheckParameters();
int CheckInputsOuputs();
int CalculateNewHeightWidth();
};
} // namespace mindspore::kernel