forked from mindspore-Ecosystem/mindspore
回退 'Pull Request !45880 : add bicubic mode in interpolate'
This commit is contained in:
parent
afbae413e8
commit
fccc39247f
|
@ -28,7 +28,7 @@ mindspore.ops.interpolate
|
|||
|
||||
old_i = new_length != 0 ? new_i * old_length / new_length : 0 # if set to 'asymmetric'
|
||||
|
||||
- **mode** (str) - 所使用的插值方式。目前支持"linear","bilinear"和"bicubic"插值方式。默认值:"linear"。
|
||||
- **mode** (str) - 所使用的插值方式。目前支持"linear"和"bilinear"插值方式。默认值:"linear"。
|
||||
|
||||
返回:
|
||||
Tensor,数据类型与 `x` 相同。
|
||||
|
|
|
@ -1902,7 +1902,7 @@ def _check_interpolate_inputs(input_dims, roi, scales, sizes, coordinate_transfo
|
|||
validator.check_value_type("mode", mode, [str], prim_name)
|
||||
if mode == "linear":
|
||||
validator.check_int(input_dims, 3, Rel.EQ, "input dims", prim_name)
|
||||
elif mode in ["bilinear", "bicubic"]:
|
||||
elif mode == "bilinear":
|
||||
validator.check_int(input_dims, 4, Rel.EQ, "input dims", prim_name)
|
||||
else:
|
||||
raise ValueError(f"{msg_prefix} mode must be 'linear' or 'bilinear', but got {mode}")
|
||||
|
@ -1937,16 +1937,12 @@ def _interpolate_output_shape(shape, scales, sizes, mode):
|
|||
if sizes is not None:
|
||||
if mode == "bilinear":
|
||||
return sizes
|
||||
if mode == "bicubic":
|
||||
return Tensor(sizes, dtype=mstype.int32)
|
||||
return Tensor(sizes)
|
||||
ret = ()
|
||||
for i in range(2, len(shape)):
|
||||
ret = ret + (int(scales[i] * shape[i]),)
|
||||
if mode == "bilinear":
|
||||
return ret
|
||||
if mode == "bicubic":
|
||||
return Tensor(ret, dtype=mstype.int32)
|
||||
return Tensor(ret)
|
||||
|
||||
|
||||
|
@ -1989,7 +1985,7 @@ def interpolate(x, roi=None, scales=None, sizes=None, coordinate_transformation_
|
|||
|
||||
old_i = new_length != 0 ? new_i * old_length / new_length : 0 # if set to 'asymmetric'
|
||||
|
||||
mode (str): The method used to interpolate: 'linear' | 'bilinear' | 'bicubic'. Default is 'linear'.
|
||||
mode (str): The method used to interpolate: 'linear' | 'bilinear'. Default is 'linear'.
|
||||
|
||||
Returns:
|
||||
Resized tensor, with the same data type as input `x`.
|
||||
|
@ -2030,26 +2026,26 @@ def interpolate(x, roi=None, scales=None, sizes=None, coordinate_transformation_
|
|||
raise TypeError("For interpolate, the input x must be tensor")
|
||||
input_shape = x.shape
|
||||
input_dims = len(input_shape)
|
||||
_check_interpolate_inputs(input_dims, roi, scales, sizes, coordinate_transformation_mode, mode, "interpolate")
|
||||
_check_interpolate_inputs(input_dims, roi, scales, sizes, coordinate_transformation_mode, mode,
|
||||
"interpolate")
|
||||
output_size = _interpolate_output_shape(input_shape, scales, sizes, mode)
|
||||
|
||||
if mode == "linear":
|
||||
resize_linear_inner = _get_cache_prim(IMG.ResizeLinear1D)(
|
||||
coordinate_transformation_mode=coordinate_transformation_mode)
|
||||
return resize_linear_inner(x, output_size)
|
||||
if mode in ["bilinear", "bicubic"]:
|
||||
if mode == "bilinear":
|
||||
align_corners = False
|
||||
half_pixel_centers = False
|
||||
if coordinate_transformation_mode == "align_corners":
|
||||
align_corners = True
|
||||
elif coordinate_transformation_mode == "half_pixel":
|
||||
half_pixel_centers = True
|
||||
if mode == "bilinear":
|
||||
resize_bilinear_inner = _get_cache_prim(IMG.ResizeBilinearV2)(align_corners, half_pixel_centers)
|
||||
return resize_bilinear_inner(x, output_size)
|
||||
if mode == "bicubic":
|
||||
resize_bicubic_inner = _get_cache_prim(IMG.ResizeBicubic)(align_corners, half_pixel_centers)
|
||||
return resize_bicubic_inner(x.transpose((0, 2, 3, 1)), output_size).transpose((0, 3, 1, 2))
|
||||
raise TypeError("Input Error: For interpolate, {} mode is not support now".format(mode))
|
||||
resize_bilinear_inner = _get_cache_prim(IMG.ResizeBilinearV2)(align_corners, half_pixel_centers)
|
||||
return resize_bilinear_inner(x, output_size)
|
||||
|
||||
raise TypeError(
|
||||
"Input Error: For interpolate, {} mode is not support now".format(mode))
|
||||
|
||||
|
||||
def softsign(x):
|
||||
|
|
|
@ -1,62 +0,0 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, coordinate_transformation_mode, mode):
|
||||
super(Net, self).__init__()
|
||||
self.coordinate_transformation_mode = coordinate_transformation_mode
|
||||
self.mode = mode
|
||||
|
||||
def construct(self, x, scales=None, sizes=None):
|
||||
return ops.interpolate(x, None, scales, sizes, self.coordinate_transformation_mode, self.mode)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_interpolate_normal(mode):
|
||||
"""
|
||||
Feature: interpolate
|
||||
Description: Verify the result of interpolate
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net("align_corners", "bicubic")
|
||||
x = ops.arange(4, dtype=ms.float32).reshape((1, 1, 2, 2))
|
||||
out1 = net(x, sizes=(4, 6))
|
||||
expect_out1 = np.array([[[[0.0000, 0.1762, 0.3884, 0.6116, 0.8238, 1.0000],
|
||||
[0.6289, 0.8051, 1.0174, 1.2405, 1.4527, 1.6289],
|
||||
[1.3711, 1.5473, 1.7595, 1.9826, 2.1949, 2.3711],
|
||||
[2.0000, 2.1762, 2.3884, 2.6116, 2.8238, 3.0000]]]])
|
||||
assert np.allclose(out1.asnumpy(), expect_out1, rtol=1e-4)
|
||||
out2 = net(x, scales=(1.0, 1.0, 3.0, 3.0))
|
||||
expect_out2 = np.array([[[[0.0000, 0.1762, 0.3884, 0.6116, 0.8238, 1.0000],
|
||||
[0.3524, 0.5286, 0.7408, 0.9640, 1.1762, 1.3524],
|
||||
[0.7769, 0.9531, 1.1653, 1.3884, 1.6007, 1.7769],
|
||||
[1.2231, 1.3993, 1.6116, 1.8347, 2.0469, 2.2231],
|
||||
[1.6476, 1.8238, 2.0360, 2.2592, 2.4714, 2.6476],
|
||||
[2.0000, 2.1762, 2.3884, 2.6116, 2.8238, 3.0000]]]])
|
||||
assert np.allclose(out2.asnumpy(), expect_out2, rtol=1e-4)
|
Loading…
Reference in New Issue