forked from mindspore-Ecosystem/mindspore
!11087 fix resize get correct new height
From: @zhaozhenlong Reviewed-by: @zhanghaibo5,@zhang_xue_tong Signed-off-by: @zhang_xue_tong,@HilbertDavid
This commit is contained in:
commit
e001199972
|
@ -62,15 +62,9 @@ int ResizeBaseCPUKernel::CheckParameters() {
|
|||
MS_LOG(INFO) << "Out shape is not assigned";
|
||||
const_shape_ = false;
|
||||
} else {
|
||||
new_height_ = out_tensors_.at(0)->shape()[1];
|
||||
if (new_height_ < 1) {
|
||||
MS_LOG(ERROR) << "Resize new_height should >= 1, but got " << new_height_;
|
||||
return RET_INVALID_OP_ATTR;
|
||||
}
|
||||
new_width_ = out_tensors_.at(0)->shape()[2];
|
||||
if (new_width_ < 1) {
|
||||
MS_LOG(ERROR) << "Resize new_width should >= 1, but got " << new_width_;
|
||||
return RET_INVALID_OP_ATTR;
|
||||
auto ret = CalculateNewHeightWidth();
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
const_shape_ = true;
|
||||
}
|
||||
|
@ -84,6 +78,49 @@ int ResizeBaseCPUKernel::CheckParameters() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int ResizeBaseCPUKernel::CalculateNewHeightWidth() {
|
||||
if (in_tensors_.size() != 2) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input_tensor = in_tensors_.at(0);
|
||||
auto shape_scale_tensor = in_tensors_.at(1);
|
||||
if (shape_scale_tensor->data_type() == kNumberTypeFloat32) {
|
||||
// float type means scale
|
||||
float *shape_scale = reinterpret_cast<float *>(shape_scale_tensor->data_c());
|
||||
if (shape_scale == nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (shape_scale_tensor->format() == schema::Format_NHWC) {
|
||||
new_height_ = input_tensor->Height() * shape_scale[1];
|
||||
new_width_ = input_tensor->Width() * shape_scale[2];
|
||||
} else if (shape_scale_tensor->format() == schema::Format_NCHW) {
|
||||
new_height_ = input_tensor->Height() * shape_scale[2];
|
||||
new_width_ = input_tensor->Width() * shape_scale[3];
|
||||
} else {
|
||||
MS_LOG(ERROR) << "resize not support format " << shape_scale_tensor->format();
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (shape_scale_tensor->data_type() == kNumberTypeInt32) {
|
||||
// int32 type means real shape
|
||||
int32_t *shape_data = reinterpret_cast<int32_t *>(shape_scale_tensor->data_c());
|
||||
if (shape_data == nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (shape_scale_tensor->format() == schema::Format_NHWC) {
|
||||
new_height_ = shape_data[1];
|
||||
new_width_ = shape_data[2];
|
||||
} else if (shape_scale_tensor->format() == schema::Format_NCHW) {
|
||||
new_height_ = shape_data[2];
|
||||
new_width_ = shape_data[3];
|
||||
} else {
|
||||
MS_LOG(ERROR) << "resize not support format " << shape_scale_tensor->format();
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ResizeBaseCPUKernel::CheckInputsOuputs() {
|
||||
if (in_tensors_.size() <= lite::kDoubleNum) {
|
||||
for (size_t i = 0; i < in_tensors_.size(); i++) {
|
||||
|
|
|
@ -47,6 +47,7 @@ class ResizeBaseCPUKernel : public LiteKernel {
|
|||
private:
|
||||
int CheckParameters();
|
||||
int CheckInputsOuputs();
|
||||
int CalculateNewHeightWidth();
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
Loading…
Reference in New Issue