!48082 refactor interpolate

Merge pull request !48082 from DavidFFFan/api_interpolate
This commit is contained in:
i-robot 2023-02-08 11:20:44 +00:00 committed by Gitee
commit ff751199b4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 1030 additions and 401 deletions

View File

@ -29,6 +29,7 @@ mindspore/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_dee
mindspore/mindspore/ccsrc/pipeline/jit/resource.cc:mindspore::pipeline::GetMethodMap
mindspore/mindspore/python/mindspore/ops/operations/array_ops.py:_compute_slicing_shape
mindspore/mindspore/python/mindspore/ops/function/array_func.py:scatter_nd
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:interpolate
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool3d
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool2d
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool1d

View File

@ -1,36 +1,61 @@
mindspore.ops.interpolate
=========================
.. py:function:: mindspore.ops.interpolate(x, roi=None, scales=None, sizes=None, coordinate_transformation_mode="align_corners", mode="linear")
.. py:function:: mindspore.ops.interpolate(x, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None)
使用 `mode` 设置的插值方式调整输入 `x` 大小。
.. warning::
- 实验特性,接口可能发生变化。
- `roi` 是保留输入, `crop_and_resize` 坐标变换模式下生效,当前不支持。
- Ascend平台下当前不支持将 `mode` 设置为"linear"。
按照给定的 `size``scale_factor` 根据 `mode` 设置的插值方式,对输入 `x` 调整大小。
参数:
- **x** (Tensor) - 输入Tensor。当 `mode` 是"linear"时, `x` 为三维Tensor。当 `mode` 是"bilinear"时, `x` 为四维Tensor。
- **roi** (tuple[float],可选) - 在 `crop_and_resize` 坐标变换模式下生效,当前不支持。
- **scales** (tuple[float],可选) - 输入shape每个维度resize的系数。 `scales` 中的数全是正数。 `scales` 的长度跟 `x` 的shape长度相同。 `scales``sizes` 同时只能指定一个。
- **sizes** (tuple[int],可选) - 输入shape指定轴的新维度。 `sizes` 中的数全是正数。 `scales``sizes` 同时只能指定一个。当 `mode` 是"linear"时, `sizes` 为1个int元素 :math:`(new\_width,)` 的tuple。当 `mode` 是"bilinear"时, `sizes` 为2个int元素 :math:`(new\_height, new\_width)` 的tuple。
- **coordinate_transformation_mode** (str) - 指定进行坐标变换的方式,默认值是"align_corners",还可选"half_pixel"和"asymmetric"。
假如我们需要将输入Tensor的x轴进行resize。我们记 `new_i` 为resize之后的Tenosr沿x轴的第i个坐标`old_i` 为输入Tensor沿x轴的对应坐标
`new_length` 是resize之后的Tensor沿着x轴的长度`old_length` 是输入Tensor沿x轴的长度。我们可以通过下面的公式计算出来 `old_i`
- **x** (Tensor) - 被调整大小的Tensor。输入向量必须为3维4维或5维形状为 `(batch, channels, [optional depth], [optional height], width)` 数据类型为float。
- **size** (Union[int, tuple[int], list[int]], 可选) - 目标大小。如果 `size` 为tuple或list那么其长度应该和 `x` 维度相同。 `size``scale_factor` 同时只能指定一个。默认值None。
- **scale_factor** (Union[float, tuple[float], list[float]],可选) - 每个维度的缩放系数。 `scales` 中的数全是正数。 `size``scale_factor` 同时只能指定一个。默认值None。
- **mode** (str) - 采样算法。以下采样方式的一种,'nearest', 'linear' (仅三维)'bilinear' (仅四维)'bicubic' (仅四维)'trilinear' (仅五维)'area''nearest-exact'(三维和四维)。默认值:'nearest'。
- **align_corners** (bool) - 如果为True缩放比例系数使用 `(new\_height - 1) / (height - 1)` 计算此种方式调整的数据与原始数据边角对齐。如果为False缩放系数通过 `new\_height / height` 计算。
.. code-block::
old_i = new_length != 1 ? new_i * (old_length - 1) / (new_length - 1) : 0 # if set to 'align_corners'
old_i = new_length != 1 ? new_i * (old_length - 1) / (new_length - 1) : 0 # 'align_corners' 为 True
old_i = new_length > 1 ? (new_x + 0.5) * old_length / new_length - 0.5 : 0 # 'align_corners' 为 False
old_i = new_length > 1 ? (new_x + 0.5) * old_length / new_length - 0.5 : 0 # if set to 'half_pixel'
此选项只对'linear'、'bilinear'、'bicubic'和'trilinear'模式有效。默认值False。
- **recompute_scale_factor** (bool, 可选) - 重计算 `scale_factor` 。如果为True会使用参数 `scale_factor` 计算参数 `size`,最终使用 `size` 的值进行缩放。如果为False将使用 `size``scale_factor` 直接进行插值。默认值None。
old_i = new_length != 0 ? new_i * old_length / new_length : 0 # if set to 'asymmetric'
参数支持列表和支持平台:
- **mode** (str) - 所使用的插值方式。目前支持"linear"和"bilinear"插值方式。默认值:"linear"。
+----------------+------+----------------+---------------+------------------+
| mode | dim | align_corners | scale_factor | device |
+================+======+================+===============+==================+
| nearest | 3 | \- | × | Ascend,GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| | 4 | \- | × | Ascend,GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| | 5 | \- | √ | GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| linear | 3 | √ | × | GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| bilinear | 4 | √ | × | Ascend,GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| trilinear | 5 | √ | √ | GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| bicubic | 4 | √ | × | GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| area | 3 | \- | √ | Ascend,GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| | 4 | \- | √ | GPU |
+----------------+------+----------------+---------------+------------------+
| | 5 | \- | √ | GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| nearest-exact | 3 | \- | × | Ascend,CPU |
+----------------+------+----------------+---------------+------------------+
| | 4 | \- | × | Ascend,CPU |
+----------------+------+----------------+---------------+------------------+
- `-` 表示无此参数。
- `×` 表示当前不支持此参数。
- `√` 表示当前支持此参数。
返回:
Tensor数据类型与 `x` 相同。
调整大小之后的Tensor维度和数据类型与 `x` 相同。
异常:
- **TypeError** - `x` 不是Tensor。
@ -42,4 +67,4 @@ mindspore.ops.interpolate
- **TypeError** - `coordinate_transformation_mode` 不是string。
- **ValueError** - `coordinate_transformation_mode` 不在支持的列表中。
- **TypeError** - `mode` 不是string类型。
- **ValueError** - `mode` 不在支持的列表中。
- **ValueError** - `mode` 不在支持的列表中。

View File

@ -15,14 +15,14 @@
"""Defines nn operators with functional form."""
from __future__ import absolute_import
from math import pi, log
from math import pi, log, floor
import numpy as np
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.ops.primitive import constexpr
from mindspore.ops import operations as P
from mindspore.ops.operations import nn_ops as NN_OPS
from mindspore.ops.operations import image_ops as IMG
import mindspore.common.dtype as mstype
from mindspore.ops.function.math_func import logsumexp
from mindspore.common.tensor import Tensor
@ -1927,160 +1927,249 @@ def hardswish(x):
@constexpr
def _check_interpolate_inputs(input_dims, roi, scales, sizes, coordinate_transformation_mode, mode,
prim_name):
"""Check input"""
msg_prefix = f"For '{prim_name}', the"
validator.check_value_type("coordinate_transformation_mode", coordinate_transformation_mode, [str], prim_name)
support_coordinate_mode_list = ["align_corners", "half_pixel", "asymmetric"]
if coordinate_transformation_mode not in support_coordinate_mode_list:
raise TypeError(f"{msg_prefix} coordinate_transformation_mode must be in {support_coordinate_mode_list},"
f" but got {coordinate_transformation_mode}")
validator.check_value_type("mode", mode, [str], prim_name)
if mode == "linear":
validator.check_int(input_dims, 3, Rel.EQ, "input dims", prim_name)
elif mode == "bilinear":
validator.check_int(input_dims, 4, Rel.EQ, "input dims", prim_name)
else:
raise ValueError(f"{msg_prefix} mode must be 'linear' or 'bilinear', but got {mode}")
if sizes is None and scales is None:
raise ValueError(f"{msg_prefix} 'sizes' and 'scale' both none.")
if sizes is not None and scales is not None:
raise ValueError(f"{msg_prefix} 'sizes' and 'scale' both not none.")
if sizes is not None:
if not isinstance(sizes, tuple):
raise TypeError(
f"{msg_prefix} 'sizes' must be tuple or None, but got {type(sizes).__name__}.")
for item in sizes:
validator.check_positive_int(item, 'sizes item', prim_name)
validator.check_value_type("sizes item", item, int, prim_name)
validator.check_int(len(sizes), input_dims - 2, Rel.EQ, "sizes", prim_name)
return
if not isinstance(scales, tuple):
raise TypeError(
f"{msg_prefix} 'scales' must be tuple or None, but got {type(scales).__name__}.")
for item in scales:
validator.check_positive_float(item, 'scales item', prim_name)
validator.check_value_type("scales item", item, float, prim_name)
scales_dims = len(scales)
validator.check_int(scales_dims, input_dims, Rel.EQ, "scales dims", prim_name)
validator.check_float(scales[0], 1.0, Rel.EQ, "scales[0]", prim_name)
validator.check_float(scales[1], 1.0, Rel.EQ, "scales[1]", prim_name)
def _scale_factor_convert_size(shape, scale_factor, dim):
return [int(floor(float(shape[i + 2]) * scale_factor[i])) for i in range(dim)]
def _interpolate_output_shape(shape, scales, sizes, mode):
"""calculate output shape"""
if sizes is not None:
if mode == "bilinear":
return sizes
return Tensor(sizes)
ret = ()
for i in range(2, len(shape)):
ret = ret + (int(scales[i] * shape[i]),)
if mode == "bilinear":
return ret
return Tensor(ret)
def interpolate(x, roi=None, scales=None, sizes=None, coordinate_transformation_mode="align_corners", mode="linear"):
def interpolate(x, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None):
r"""
Using the interpolate method specified by `mode` resize the input tensor `x`.
.. warning::
- This is an experimental prototype that is subject to change.
- The `roi` is reserved interface for 'crop_and_resize' coordinate transformation mode,
which is not support now.
- The Ascend platforms is currently not supported when `mode` is "linear".
Samples the input Tensor to the given size or scale_factor by using one of the interpolate algorithms.
Args:
x (Tensor): a tensor which to resize. `x` is a 3-D tensor when `mode` is "linear". `x` is a 4-D tensor when
`mode` is "bilinear".
roi (tuple[float], optional): a tuple of float. Only takes effect when attr coordinate_transformation_mode is
'crop_and_resize'.
scales (tuple[float], optional): a tuple of float. Describe the scale along each dimension.
Its length is the same as that of shape of `x`. The numbers in `scales` must all be positive. Only one of
`scales` and `sizes` can be specified.
sizes (tuple[int], optional): a tuple of int, describes the shape of the output tensor. The numbers in `sizes`
must all be positive. Only one of `scales` and `sizes` can be specified. If `sizes` is specified, then set
`scales` to 'None' in this operator's input list. It is 1 int elements :math:`(new\_width,)` when `mode`
is "linear". It is 2 int elements :math:`(new\_height, new\_width)` when `mode` is "bilinear".
coordinate_transformation_mode (str): Default is 'align_corners'. Describes how to transform the coordinate
in the resized tensor to the coordinate in the original tensor. Other optional: 'half_pixel', 'asymmetric'.
For example, we want to resize the original tensor along axis x. Let's denote `new_i` as the i-th coordinate
of the resized tensor along axis x, `old_i` as the coordinate of the original tensor along axis x,
`new_length` as the length of the resized tensor along axis x, `old_length` as the length of the original
tensor along axis x. We compute the `old_i` via the following formula:
x (Tensor): Tensor to be resized.
Input tensor must be a 3-D, 4-D, or 5-D tensor with shape
`(batch, channels, [optional depth], [optional height], width)`, with data type of float.
size (Union[int, tuple[int], list[int]], optional)): The target size.
If size is a tuple or list, size must have the same dimensions as x.
One and only one of size and scale_factor can be set to None. Default: None.
scale_factor (Union[float, tuple[float], list[float]], optional): The scale factor of new size of the tensor.
If size is a tuple or list, size must have the same dimensions as x.
One and only one of size and scale_factor can be set to None. Default: None.
mode (str): The sampling algorithm.
One of 'nearest', 'linear' (3D only), 'bilinear' (4D only), 'bicubic' (4D only), 'trilinear' (5D only),
'area', 'nearest-exact'(3D and 4D). Default: 'nearest'.
align_corners (bool): If True, rescale input by `(new\_height - 1) / (height - 1)`, which exactly
aligns the corners of data and resized data. If False, rescale by `new\_height / height`.
.. code-block::
old_i = new_length != 1 ? new_i * (old_length - 1) / (new_length - 1) : 0 # if set to 'align_corners'
old_i = new_length != 1 ? new_i * (old_length - 1) / (new_length - 1) : 0 # 'align_corners' = True
old_i = new_length > 1 ? (new_x + 0.5) * old_length / new_length - 0.5 : 0 # 'align_corners' = False
old_i = new_length > 1 ? (new_x + 0.5) * old_length / new_length - 0.5 : 0 # if set to 'half_pixel'
This is only valid for 'linear', 'bilinear', 'bicubic', or 'trilinear' modes. Default: False.
recompute_scale_factor (bool, optional): Recalculate `scale_factor`.
If True, the parameter `size` will be calculated using the value of the `scale_factor`,
and finally scaled using the value of `size`.
If False, the value of `size` or `scale_factor` will be used for direct interpolation. Default: None.
old_i = new_length != 0 ? new_i * old_length / new_length : 0 # if set to 'asymmetric'
Args Support List and Supported Platforms:
mode (str): The method used to interpolate: 'linear' | 'bilinear'. Default is 'linear'.
+----------------+------+----------------+---------------+------------------+
| mode | dim | align_corners | scale_factor | device |
+================+======+================+===============+==================+
| nearest | 3 | \- | × | Ascend,GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| | 4 | \- | × | Ascend,GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| | 5 | \- | | GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| linear | 3 | | × | GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| bilinear | 4 | | × | Ascend,GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| trilinear | 5 | | | GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| bicubic | 4 | | × | GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| area | 3 | \- | | Ascend,GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| | 4 | \- | | GPU |
+----------------+------+----------------+---------------+------------------+
| | 5 | \- | | GPU,CPU |
+----------------+------+----------------+---------------+------------------+
| nearest-exact | 3 | \- | × | Ascend,CPU |
+----------------+------+----------------+---------------+------------------+
| | 4 | \- | × | Ascend,CPU |
+----------------+------+----------------+---------------+------------------+
Returns:
Resized tensor, with the same data type as input `x`.
- `-` indicates that there is no such parameter.
- `×` indicates that this parameter is not currently supported.
- `` indicates that this parameter is supported.
Outputs:
Tensor, resized, whose dimensions and dtype are the same as `x`.
Raises:
TypeError: If `x` is not a Tensor.
TypeError: If the data type of `x` is not supported.
TypeError: If `scales` is not a float tuple.
ValueError: If not all numbers in `scales` are positive.
TypeError: If `sizes` is not an int tuple.
ValueError: If not all numbers in `sizes` are positive.
TypeError: If `coordinate_transformation_mode` is not a string.
ValueError: If `coordinate_transformation_mode` is not in the support list.
TypeError: If `mode` is not a string.
ValueError: If `mode` is not in the support list.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
TypeError: If `x` is not Tensor.
TypeError: If `mode` is not one of 'nearest', 'linear', 'bilinear', 'bicubic',
'trilinear', 'area', nor 'nearest-exact'.
TypeError: If `size` is not one of tuple, list, None.
TypeError: If `scale_factor` is neither int nor None.
TypeError: If `align_corners` is not a bool.
TypeError: If dtype of `x` is neither float16 nor float32.
ValueError: If `size` and `scale_factor` are both None or not None.
ValueError: If shape of `x` or `size` and `mode` are not match.
ValueError: If `scale_factor` is an int which is less than 0.
ValueError: If `size` is a list or tuple whose length is not match `mode`.
Examples:
>>> # case 1: linear mode
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> x = Tensor([[[1, 2, 3], [4, 5, 6]]], mindspore.float32)
>>> output = ops.interpolate(x, None, None, (6,), "align_corners")
>>> output = ops.interpolate(x, size=(6,), mode='nearest')
>>> print(output)
[[[1. 1.4 1.8 2.2 2.6 3.]
[4. 4.4 4.8 5.2 5.6 6.]]]
>>> # case 2: bilinear mode
>>> x = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
>>> output = ops.interpolate(x, None, None, (5, 5), "asymmetric", "bilinear")
>>> print(output)
[[[[1. 2. 3. 4. 5.]
[1. 2. 3. 4. 5.]
[1. 2. 3. 4. 5.]
[1. 2. 3. 4. 5.]
[1. 2. 3. 4. 5.]]]]
[[[1. 1. 2. 2. 3. 3.]
[4. 4. 5. 5. 6. 6.]]]
"""
if not isinstance(x, (Tensor, Tensor_)):
raise TypeError("For interpolate, the input x must be tensor")
input_shape = x.shape
input_dims = len(input_shape)
_check_interpolate_inputs(input_dims, roi, scales, sizes, coordinate_transformation_mode, mode,
"interpolate")
output_size = _interpolate_output_shape(input_shape, scales, sizes, mode)
if mode == "linear":
resize_linear_inner = _get_cache_prim(IMG.ResizeLinear1D)(
coordinate_transformation_mode=coordinate_transformation_mode)
return resize_linear_inner(x, output_size)
if mode == "bilinear":
def run_nearest(x, size, align_corners=None, scale_factor=None):
# 3D 4D use ResizeNearestNeighborV2, 5D use UpsampleNearest3D
if x.ndim == 3:
size = Tensor((size[0], 1), dtype=mstype.int32)
x = x.unsqueeze(-1)
x = _get_cache_prim(P.ResizeNearestNeighborV2)(data_format="NCHW")(x, size)
x = x.squeeze(-1)
elif x.ndim == 4:
size = Tensor(size, dtype=mstype.int32)
x = _get_cache_prim(P.ResizeNearestNeighborV2)(data_format="NCHW")(x, size)
else:
x = _get_cache_prim(P.UpsampleNearest3D)(size, scales=scale_factor)(x)
return x
def run_linear(x, size, align_corners=None, scale_factor=None):
coordinate_transformation_mode = "align_corners" if align_corners else "half_pixel"
resize = _get_cache_prim(P.image_ops.ResizeLinear1D)(
coordinate_transformation_mode
)
return resize(x, Tensor(size, dtype=mstype.int32))
def run_bilinear(x, size, align_corners=None, scale_factor=None):
resize = _get_cache_prim(P.ResizeBilinearV2)(align_corners, not align_corners)
return resize(x, size)
def run_trilinear(x, size, align_corners=None, scale_factor=None):
resize = _get_cache_prim(P.nn_ops.UpsampleTrilinear3D)(
output_size=size, scales=scale_factor, align_corners=align_corners
)
return resize(x)
def run_bicubic(x, size, align_corners=None, scale_factor=None):
resize = _get_cache_prim(P.image_ops.ResizeBicubic)(
align_corners=align_corners, half_pixel_centers=not align_corners
)
x = resize(x, Tensor(size, dtype=mstype.int32))
return x
def run_area(x, size, align_corners=None, scale_factor=None):
if x.ndim == 3:
resize = nn.AdaptiveAvgPool1d(output_size=size[0])
x = resize(x)
elif x.ndim == 4:
x = ops.adaptive_avg_pool2d(x, tuple(size))
else:
x = ops.adaptive_avg_pool3d(x, tuple(size))
return x
def run_nearest_exact(x, size, align_corners=None, scale_factor=None):
if x.ndim == 3:
size = Tensor((size[0], 1), dtype=mstype.int32)
# For impl of nearest 3D use 4D.
x = x.unsqueeze(-1)
resize = _get_cache_prim(P.ResizeNearestNeighborV2)(
data_format="NCHW", align_corners=False, half_pixel_centers=True
)
x = resize(x, size)
x = x.squeeze(-1)
if x.ndim == 4:
size = Tensor(size, dtype=mstype.int32)
resize = _get_cache_prim(P.ResizeNearestNeighborV2)(
data_format="NCHW", align_corners=False, half_pixel_centers=True
)
x = resize(x, size)
return x
# support_dict "mode":{dim:{"scale_factor", "align_corners"}}
supported_dict = {
"nearest": {3: (), 4: (), 5: ("scale_factor",)},
"linear": {3: ("align_corners",)},
"bilinear": {4: ("align_corners",)},
"bicubic": {4: ("align_corners",)},
"trilinear": {5: ("scale_factor", "align_corners")},
"area": {3: ("scale_factor",), 4: ("scale_factor",), 5: ("scale_factor",)},
"nearest-exact": {3: (), 4: ()},
}
resize_func = {
"nearest": run_nearest,
"linear": run_linear,
"bilinear": run_bilinear,
"bicubic": run_bicubic,
"trilinear": run_trilinear,
"area": run_area,
"nearest-exact": run_nearest_exact,
}
if not isinstance(x, Tensor):
raise TypeError(f"For 'interpolate', 'x' must be a tensor, but got {type(x)}")
if size is not None and scale_factor is not None:
raise ValueError(
"For 'interpolate', only one of size or scale_factor should be defined"
)
if size is not None:
if isinstance(size, (list, tuple)):
if len(size) != x.ndim - 2:
raise ValueError(
f"For 'interpolate', 'x' and 'size' must have same number of spatial dimensions, "
f"but got 'x' is {x.ndim - 2}D, 'size' is {len(size)}D"
)
else:
size = [size for _ in range(x.ndim - 2)]
elif scale_factor is not None:
if isinstance(scale_factor, (list, tuple)):
if len(scale_factor) != x.ndim - 2:
raise ValueError(
f"For 'interpolate', 'x' and 'scale_factor' must have same number of spatial dimensions, "
f"but got 'x' is {x.ndim - 2}D, 'scale_factor' is {len(size)}D"
)
else:
scale_factor = [scale_factor for _ in range(x.ndim - 2)]
else:
raise ValueError(
"For 'interpolate', either 'size' or 'scale_factor' should be defined"
)
if mode not in supported_dict:
raise ValueError(
f"For 'interpolate', 'mode' must be in '{list(supported_dict)}', but got {mode}"
)
if x.ndim not in supported_dict.get(mode):
raise ValueError(
f"For 'interpolate', {mode} only support '{list(supported_dict.get(mode, {}))}'D, but got {x.ndim}D"
)
# "area" mode always requires an explicit size rather than scale factor.
if mode == "area" and size is None:
recompute_scale_factor = True
if recompute_scale_factor is not None and recompute_scale_factor:
# todo: check bool type
if size is not None:
raise ValueError(
"For 'interpolate', 'recompute_scale_factor' is not meaningful with an explicit size"
)
size = _scale_factor_convert_size(x.shape, scale_factor, x.ndim - 2)
scale_factor = None
else:
if scale_factor is not None and "scale_factor" not in supported_dict.get(mode, {}).get(x.ndim):
raise ValueError(
f"For 'interpolate', 'scale_factor' option cannot currently be set with the "
f"mode = {mode} and dim = {x.ndim}D."
)
if align_corners is not None:
if "align_corners" not in supported_dict.get(mode, {}).get(x.ndim):
raise ValueError(
f"For 'interpolate', 'align_corners' option cannot currently be set with the "
f"mode = {mode}, and dim = {x.ndim}D"
)
else:
align_corners = False
half_pixel_centers = False
if coordinate_transformation_mode == "align_corners":
align_corners = True
elif coordinate_transformation_mode == "half_pixel":
half_pixel_centers = True
resize_bilinear_inner = _get_cache_prim(IMG.ResizeBilinearV2)(align_corners, half_pixel_centers)
return resize_bilinear_inner(x, output_size)
raise TypeError(
"Input Error: For interpolate, {} mode is not support now".format(mode))
return resize_func.get(mode)(x, size, align_corners, scale_factor)
def softsign(x):

View File

@ -23,7 +23,7 @@ from mindspore import ParameterTuple
class NetResizeBilinear(nn.Cell):
def construct(self, inputs, size):
return ops.interpolate(inputs, None, None, size, "asymmetric", "bilinear")
return ops.ResizeBilinearV2(align_corners=False, half_pixel_centers=False)(inputs, size)
def case():

View File

@ -44,8 +44,8 @@ class ResizeBilinearGradAlignCornerF(nn.Cell):
class NetResizeBilinearFunc(nn.Cell):
def construct(self, inputs, size, align_corner=False, half_pixel_centers=False):
if align_corner and not half_pixel_centers:
return ops.interpolate(inputs, None, None, size, "align_corners", "bilinear")
return ops.interpolate(inputs, None, None, size, "half_pixel", "bilinear")
return ops.ResizeBilinearV2(align_corners=True, half_pixel_centers=False)(inputs, size)
return ops.ResizeBilinearV2(align_corners=False, half_pixel_centers=True)(inputs, size)
def test_resize_bilinear_grad_align_corner():

View File

@ -13,12 +13,10 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import context, Tensor
from mindspore.ops import operations as P
from mindspore import nn
import mindspore.ops as ops
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@ -547,167 +545,3 @@ def test_resize_nn_grayscale_align_corners_float(datatype=np.float32):
diff = output.asnumpy() - expected_output.asnumpy()
assert np.all(abs(diff) < error)
assert np.all(abs(diff_align) < error)
class NetResizeBilinearFunc(nn.Cell):
def construct(self, x, scale, size):
return ops.interpolate(x, None, scale, size, "asymmetric", "bilinear")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_resize_bilinar_func_cpu():
"""
Feature: Test bilinear on cpu.
Description: bilinear executes by calling ResizeBilinearV2
Expectation: Assert that results are consistent with expect.
"""
input_tensor = Tensor(
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(np.float32))
resize_nn = NetResizeBilinearFunc()
output = resize_nn(input_tensor, None, (3, 7))
expect = np.array([[[[0.1, 0.15714286, 0.21428573, 0.27142859, 0.32857144,
0.3857143, 0.4],
[0.36666667, 0.42380953, 0.48095244, 0.53809524, 0.5952381,
0.65238094, 0.6666667],
[0.5, 0.55714285, 0.61428577, 0.67142856, 0.7285714,
0.78571427, 0.8]]]]).astype(np.float32)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_resize_bilinar_func_scale_cpu():
"""
Feature: Test bilinear with scale on cpu.
Description: bilinear executes by calling ResizeBilinearV2
Expectation: Assert that results are consistent with expect.
"""
input_tensor = Tensor(
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(np.float32))
resize_nn = NetResizeBilinearFunc()
output = resize_nn(input_tensor, (1.0, 1.0, 4.0, 2.0), None)
expect = resize_nn(input_tensor, None, (8, 8))
assert np.allclose(output.asnumpy(), expect.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_resize_bilinar_func_cpu_fp16():
"""
Feature: Test bilinear on cpu with fp16.
Description: bilinear executes by calling ResizeBilinearV2
Expectation: Assert that results are consistent with expect.
"""
input_tensor = Tensor(
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(np.float16))
resize_nn = NetResizeBilinearFunc()
output = resize_nn(input_tensor, None, (3, 7))
expect = np.array([[[[0.1, 0.15714286, 0.21428573, 0.27142859, 0.32857144,
0.3857143, 0.4],
[0.36666667, 0.42380953, 0.48095244, 0.53809524, 0.5952381,
0.65238094, 0.6666667],
[0.5, 0.55714285, 0.61428577, 0.67142856, 0.7285714,
0.78571427, 0.8]]]]).astype(np.float16)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_resize_bilinar_func_scale_cpu_fp16():
"""
Feature: Test bilinear with scale on cpu with fp16.
Description: bilinear executes by calling ResizeBilinearV2
Expectation: Assert that results are consistent with expect.
"""
input_tensor = Tensor(
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(np.float16))
resize_nn = NetResizeBilinearFunc()
output = resize_nn(input_tensor, (1.0, 1.0, 4.0, 2.0), None)
expect = resize_nn(input_tensor, None, (8, 8))
assert np.allclose(output.asnumpy(), expect.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_resize_bilinar_func_cpu_fp64():
"""
Feature: Test bilinear on cpu with fp64.
Description: bilinear executes by calling ResizeBilinearV2
Expectation: Assert that results are consistent with expect.
"""
input_tensor = Tensor(
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(np.float64))
resize_nn = NetResizeBilinearFunc()
output = resize_nn(input_tensor, None, (3, 7))
expect = np.array([[[[0.1, 0.15714286, 0.21428573, 0.27142859, 0.32857144,
0.3857143, 0.4],
[0.36666667, 0.42380953, 0.48095244, 0.53809524, 0.5952381,
0.65238094, 0.6666667],
[0.5, 0.55714285, 0.61428577, 0.67142856, 0.7285714,
0.78571427, 0.8]]]]).astype(np.float64)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_resize_bilinar_func_scale_cpu_fp64():
"""
Feature: Test bilinear with scale on cpu with fp64.
Description: bilinear executes by calling ResizeBilinearV2
Expectation: Assert that results are consistent with expect.
"""
input_tensor = Tensor(
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(np.float64))
resize_nn = NetResizeBilinearFunc()
output = resize_nn(input_tensor, (1.0, 1.0, 4.0, 2.0), None)
expect = resize_nn(input_tensor, None, (8, 8))
assert np.allclose(output.asnumpy(), expect.asnumpy())
@pytest.mark.level1
@pytest.mark.platform_x86_cpu_training
def test_resize_nn_func_half_pixel_centers(datatype=np.float32):
"""
Feature: Test resize_bilinear on CPU.
Description: The half_pixel_centers is True.
Expectation: Assert that results are consistent with expect.
"""
input_tensor = Tensor(
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(datatype))
resize_nn_func = NetResizeBilinearFunc()
output = resize_nn_func(input_tensor, (3, 7), align_corner=False, half_pixel_centers=True)
expected_output = np.array([[[[0.1, 0.13571429, 0.19285715, 0.25, 0.30714288,
0.36428574, 0.4],
[0.3, 0.3357143, 0.39285716, 0.45, 0.5071429,
0.56428576, 0.6],
[0.5, 0.5357143, 0.5928572, 0.65, 0.7071429,
0.76428574, 0.8]]]], dtype=datatype)
assert np.allclose(output.asnumpy(), expected_output)
@pytest.mark.level1
@pytest.mark.platform_x86_cpu_training
def test_resize_nn_func_half_pixel_centers_fp16(datatype=np.float16):
"""
Feature: Test resize_bilinear fp16 on CPU.
Description: The half_pixel_centers is True.
Expectation: Assert that results are consistent with expect.
"""
input_tensor = Tensor(
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(datatype))
resize_nn_func = NetResizeBilinearFunc()
output = resize_nn_func(input_tensor, (3, 7), align_corner=False, half_pixel_centers=True)
expected_output = np.array([[[[0.1, 0.13571429, 0.19285715, 0.25, 0.30714288,
0.36428574, 0.4],
[0.3, 0.3357143, 0.39285716, 0.45, 0.5071429,
0.56428576, 0.6],
[0.5, 0.5357143, 0.5928572, 0.65, 0.7071429,
0.76428574, 0.8]]]], dtype=datatype)
assert np.allclose(output.asnumpy(), expected_output)

View File

@ -19,7 +19,6 @@ import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations.image_ops import ResizeLinear1D
from mindspore.ops import functional as F
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@ -80,39 +79,3 @@ def test_resize_linear_1d_size_not_change(dtype):
expect = np.array([[[1., 2., 3.],
[4., 5., 6.]]]).astype(dtype)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float32, np.float16])
def test_resize_linear_1d_half_pixel_functional_interface(dtype):
"""
Feature: ResizeLinear1D cpu kernel half_pixel mode functional interface.
Description: test the rightness of ResizeLinear1D cpu kernel.
Expectation: the output is same as expect.
"""
x = Tensor(np.array([[[1, 2, 3],
[4, 5, 6]]], dtype=dtype))
output = F.interpolate(x, None, None, (6,), 'half_pixel')
expect = np.array([[[1., 1.25, 1.75, 2.25, 2.75, 3.],
[4., 4.25, 4.75, 5.25, 5.75, 6.]]]).astype(dtype)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float16, np.float64])
def test_resize_linear_1d_align_corners_functional_interface(dtype):
"""
Feature: ResizeLinear1D cpu kernel align_corners mode functional interface.
Description: test the rightness of ResizeLinear1D cpu kernel.
Expectation: the output is same as expect.
"""
x = Tensor(np.array([[[1, 2, 3],
[4, 5, 6]]], dtype=dtype))
output = F.interpolate(x, None, (1., 1., 2.), None, 'align_corners')
expect = np.array([[[1., 1.4, 1.8, 2.2, 2.6, 3.],
[4., 4.4, 4.8, 5.2, 5.6, 6.]]]).astype(dtype)
assert np.allclose(output.asnumpy(), expect)

View File

@ -41,7 +41,7 @@ class NetResizeBilinear(nn.Cell):
def construct(self, inputs, size, indices_input, axis):
unique_input_index, _ = ops.unique(indices_input)
inputs_dyn = ops.gather(inputs, unique_input_index, axis)
return ops.interpolate(inputs_dyn, None, None, size, "asymmetric", "bilinear")
return ops.ResizeBilinearV2(align_corners=False, half_pixel_centers=False)(inputs_dyn, size)
def case_input_dyn(mode, device_target, dtype="float32"):

View File

@ -576,8 +576,8 @@ def test_resize_nn_grayscale_align_corners_float(datatype=np.float32):
class NetResizeBilinearFunc(nn.Cell):
def construct(self, inputs, size, align_corner=False, half_pixel_centers=False):
if align_corner and not half_pixel_centers:
return ops.interpolate(inputs, None, None, size, "align_corners", "bilinear")
return ops.interpolate(inputs, None, None, size, "half_pixel", "bilinear")
return ops.ResizeBilinearV2(align_corners=True, half_pixel_centers=False)(inputs, size)
return ops.ResizeBilinearV2(align_corners=False, half_pixel_centers=True)(inputs, size)
@pytest.mark.level1

View File

@ -19,7 +19,6 @@ import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops.operations.image_ops import ResizeLinear1D
from mindspore.ops import functional as F
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
@ -80,39 +79,3 @@ def test_resize_linear_1d_size_not_change(dtype):
expect = np.array([[[1., 2., 3.],
[4., 5., 6.]]]).astype(dtype)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float32, np.float16])
def test_resize_linear_1d_half_pixel_functional_interface(dtype):
"""
Feature: ResizeLinear1D gpu kernel half_pixel mode functional interface.
Description: test the rightness of ResizeLinear1D cpu kernel.
Expectation: the output is same as expect.
"""
x = Tensor(np.array([[[1, 2, 3],
[4, 5, 6]]], dtype=dtype))
output = F.interpolate(x, None, None, (6,), 'half_pixel')
expect = np.array([[[1., 1.25, 1.75, 2.25, 2.75, 3.],
[4., 4.25, 4.75, 5.25, 5.75, 6.]]]).astype(dtype)
assert np.allclose(output.asnumpy(), expect)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [np.float16, np.float64])
def test_resize_linear_1d_align_corners_functional_interface(dtype):
"""
Feature: ResizeLinear1D gpu kernel align_corners mode functional interface.
Description: test the rightness of ResizeLinear1D cpu kernel.
Expectation: the output is same as expect.
"""
x = Tensor(np.array([[[1, 2, 3],
[4, 5, 6]]], dtype=dtype))
output = F.interpolate(x, None, (1., 1., 2.), None, 'align_corners')
expect = np.array([[[1., 1.4, 1.8, 2.2, 2.6, 3.],
[4., 4.4, 4.8, 5.2, 5.6, 6.]]]).astype(dtype)
assert np.allclose(output.asnumpy(), expect)

View File

@ -0,0 +1,754 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import ops
class Net(nn.Cell):
def __init__(self, mode):
super(Net, self).__init__()
self.mode = mode
def construct(self, x, size=None, scale_factor=None, align_corners=None, recompute_scale_factor=None):
return ops.interpolate(x, size, scale_factor, self.mode, align_corners, recompute_scale_factor)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_area_3d(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate area mode
1. 3D size
2. 3D scale_factor
Expectation: success
"""
ms.set_context(mode=mode)
net = Net("area")
# 1. 3D(1, 3, 5)
input_3d = np.array([[[0.816117, 0.004037, 0.746452, 0.137449, 0.337593],
[0.970709, 0.558792, 0.053919, 0.734102, 0.432973],
[0.830186, 0.77753, 0.384094, 0.905231, 0.76362]]], dtype=np.float32)
except_3d_1 = np.array([[[0.410077, 0.375244, 0.441951, 0.237521],
[0.76475, 0.306356, 0.394011, 0.583538],
[0.803858, 0.580812, 0.644662, 0.834426]]], dtype=np.float32)
size = 4
output_3d_1 = net(Tensor(input_3d), size=size)
assert np.allclose(output_3d_1.asnumpy(), except_3d_1, atol=1e-3, rtol=1e-3)
# 2. 3D(1, 3, 5) scale_factor=0.2
except_3d_2 = np.array([[[0.40833],
[0.550099],
[0.732132]]], dtype=np.float32)
scale_factor = 0.2
output_3d_2 = net(Tensor(input_3d), scale_factor=scale_factor)
assert np.allclose(output_3d_2.asnumpy(), except_3d_2, atol=1e-3, rtol=1e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_area_4d(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate area mode
1. 4D size
2. 4D scale_factor
Expectation: success
"""
ms.set_context(mode=mode)
net = Net("area")
# 1. 4D(1, 3, 3, 5) size=(5, 8)
input_4d = np.array([[[[0.4992, 0.743079, 0.570383, 0.942855, 0.833395],
[0.645754, 0.485126, 0.957497, 0.933144, 0.556276],
[0.298626, 0.594928, 0.150964, 0.447654, 0.267512],
[0.46943, 0.419558, 0.665492, 0.414906, 0.708427]]]], dtype=np.float32)
except_4d_1 = np.array([[[[0.4992, 0.62114, 0.743079, 0.656731, 0.756619, 0.942855,
0.888125, 0.833395],
[0.572477, 0.59329, 0.614102, 0.689021, 0.85097, 0.937999,
0.816418, 0.694836],
[0.47219, 0.506109, 0.540027, 0.547129, 0.622315, 0.690399,
0.551147, 0.411894],
[0.384028, 0.445635, 0.507243, 0.457736, 0.419754, 0.43128,
0.459625, 0.48797],
[0.46943, 0.444494, 0.419558, 0.542525, 0.540199, 0.414906,
0.561666, 0.708427]]]], dtype=np.float32)
size = (5, 8)
output_4d_1 = net(Tensor(input_4d), size=size)
assert np.allclose(output_4d_1.asnumpy(), except_4d_1, atol=1e-5, rtol=1e-5)
# 2. 4D(1, 3, 3, 5) scale_factor=(1.5, 0.4)
except_4d_2 = np.array([[[[0.604221, 0.782211],
[0.650173, 0.798925],
[0.696126, 0.815639],
[0.348173, 0.28871],
[0.433166, 0.442492],
[0.51816, 0.596275]]]], dtype=np.float32)
scale_factor = (1.5, 0.4)
output_4d_2 = net(Tensor(input_4d), scale_factor=scale_factor)
assert np.allclose(output_4d_2.asnumpy(), except_4d_2, atol=1e-5, rtol=1e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_area_5d(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate area mode
1. 5D size
2. 5D scale_factor
Expectation: success
"""
ms.set_context(mode=mode)
net = Net("area")
# 1. 5D(1, 1, 1, 4, 5) size=(2, 4, 3)
input_5d = np.array([[[[[0.4992, 0.743079, 0.570383, 0.942855, 0.833395],
[0.645754, 0.485126, 0.957497, 0.933144, 0.556276],
[0.298626, 0.594928, 0.150964, 0.447654, 0.267512],
[0.46943, 0.419558, 0.665492, 0.414906, 0.708427]]]]], dtype=np.float32)
except_5d_1 = np.array([[[[[0.62114, 0.752106, 0.888125],
[0.56544, 0.791922, 0.74471],
[0.446777, 0.397849, 0.357583],
[0.444494, 0.499985, 0.561666]],
[[0.62114, 0.752106, 0.888125],
[0.56544, 0.791922, 0.74471],
[0.446777, 0.397849, 0.357583],
[0.444494, 0.499985, 0.561666]]]]], dtype=np.float32)
size = (2, 4, 3)
output_5d_1 = net(Tensor(input_5d), size=size)
assert np.allclose(output_5d_1.asnumpy(),
except_5d_1, atol=1e-5, rtol=1e-5)
# 2. 5D(1, 1, 1, 4, 5) scale_factor=(3, 0.4, 0.7)
except_5d_2 = np.array([[[[[0.519463, 0.610466, 0.638021]],
[[0.519463, 0.610466, 0.638021]],
[[0.519463, 0.610466, 0.638021]]]]], dtype=np.float32)
scale_factor = (3., 0.4, 0.7)
output_5d_2 = net(Tensor(input_5d), scale_factor=scale_factor)
assert np.allclose(output_5d_2.asnumpy(), except_5d_2, atol=1e-5, rtol=1e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_bicubic(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate bicubic mode
1. 4D size tf
2. 4D size align_corners=True
Expectation: success
"""
ms.set_context(mode=mode)
net = Net("bicubic")
input_4d = np.array([[[[0.003088, 0.313131, 0.481231, 0.326219, 0.190293],
[0.711616, 0.583990, 0.718121, 0.258823, 0.121847],
[0.781316, 0.591508, 0.858185, 0.091935, 0.444639]],
[[0.389884, 0.894497, 0.471427, 0.188708, 0.557449],
[0.998047, 0.380719, 0.570574, 0.722258, 0.997173],
[0.195751, 0.050744, 0.002008, 0.482685, 0.708559]]]],
dtype=np.float32)
size = (5, 3)
# 1. 4D size=(5, 3)
except_4d_1 = np.array([[[[0.038226, 0.46334, 0.228343],
[0.287464, 0.558147, 0.186838],
[0.671836, 0.718121, 0.14377],
[0.729394, 0.819616, 0.255285],
[0.723466, 0.868763, 0.334474]],
[[0.522484, 0.463939, 0.409831],
[0.670862, 0.531739, 0.626711],
[0.821431, 0.570574, 0.926654],
[0.40312, 0.206133, 0.777062],
[0.107333, -0.040933, 0.642958]]]], dtype=np.float32)
output_4d_1 = net(Tensor(input_4d), size=size)
assert np.allclose(output_4d_1.asnumpy(),
except_4d_1, atol=1e-5, rtol=1e-5)
# 2. 4D size=(5, 3), align_corners=True
except_4d_2 = np.array([[[[0.003088, 0.481231, 0.190293],
[0.350818, 0.586545, 0.125808],
[0.711616, 0.718121, 0.121847],
[0.81289, 0.810361, 0.276826],
[0.781316, 0.858185, 0.444639]],
[[0.389884, 0.471427, 0.557449],
[0.769181, 0.574304, 0.804369],
[0.998047, 0.570574, 0.997173],
[0.653914, 0.295586, 0.89409],
[0.195751, 0.002008, 0.708559]]]], dtype=np.float32)
output_4d_2 = net(Tensor(input_4d), size=size, align_corners=True)
assert np.allclose(output_4d_2.asnumpy(), except_4d_2, atol=1e-5, rtol=1e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_linear(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate linear mode
1. 3D size
2. 3D size align_corners=True
3. 3D scale_factor recompute_scale_factor=True
Expectation: success
"""
ms.set_context(mode=mode)
net = Net("linear")
# 1. 3D(1, 3, 5) size=8
input_3d = np.array([[[0.784811, 0.008794, 0.52707, 0.821989, 0.88349],
[0.00115, 0.556748, 0.8774, 0.761826, 0.23384],
[0.842441, 0.853507, 0.294813, 0.375813, 0.399932]]], dtype=np.float32)
except_3d_1 = np.array([[[0.784811, 0.445304, 0.041186, 0.365109, 0.619232, 0.803557, 0.856583, 0.88349],
[0.00115, 0.244224, 0.576789, 0.777196, 0.841283, 0.769049, 0.464834, 0.23384],
[0.842441, 0.847282, 0.818589, 0.469405, 0.320126, 0.370751, 0.38938, 0.399932]]],
dtype=np.float32)
size = 8
output_3d_1 = net(Tensor(input_3d), size=size)
assert np.allclose(output_3d_1.asnumpy(), except_3d_1, atol=1e-5, rtol=1e-5)
# 2. 3D(1, 3, 5) size=8 align_corners=True
except_3d_2 = np.array([[[0.784811, 0.341373, 0.082833, 0.378991, 0.611333, 0.779858,
0.848347, 0.88349],
[0.00115, 0.318635, 0.602555, 0.785785, 0.844379, 0.778337,
0.535546, 0.23384],
[0.842441, 0.848764, 0.773694, 0.45444, 0.317956, 0.364242,
0.38615, 0.399932]]], dtype=np.float32)
output_3d_2 = net(Tensor(input_3d), size=size, align_corners=True)
assert np.allclose(output_3d_2.asnumpy(), except_3d_2, atol=1e-5, rtol=1e-5)
# 3. scale_factor=0.7, recompute_scale_factor=True
except_3d_3 = np.array([[[0.526139, 0.52707, 0.86299],
[0.186349, 0.8774, 0.409835],
[0.84613, 0.294813, 0.391892]]], dtype=np.float32)
scale_factor = 0.7
output_3d_3 = net(Tensor(input_3d), scale_factor=scale_factor, recompute_scale_factor=True)
assert np.allclose(output_3d_3.asnumpy(), except_3d_3, atol=1e-5, rtol=1e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_bilinear(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate bilinear mode
1. 4D size
2. 4D size align_corners=True
3. 4D scale_factor recompute_scale_factor=True
Expectation: success
"""
ms.set_context(mode=mode)
net = Net("bilinear")
input_4d = np.array([[[[0.003088, 0.313131, 0.481231, 0.326219, 0.190293],
[0.711616, 0.58399, 0.718121, 0.258823, 0.121847],
[0.781316, 0.591508, 0.858185, 0.091935, 0.444639]],
[[0.389884, 0.894497, 0.471427, 0.188708, 0.557449],
[0.998047, 0.380719, 0.570574, 0.722258, 0.997173],
[0.195751, 0.050744, 0.002008, 0.482685, 0.708559]]]],
dtype=np.float32)
size = (5, 3)
# 1. 4D size=(5, 3)
except_4d_1 = np.array([[[[0.106436, 0.481231, 0.235602],
[0.331491, 0.575987, 0.208363],
[0.669074, 0.718121, 0.167506],
[0.698458, 0.802159, 0.263245],
[0.718047, 0.858185, 0.327071]],
[[0.558088, 0.471427, 0.434535],
[0.651761, 0.511086, 0.622935],
[0.792271, 0.570574, 0.905535],
[0.405358, 0.229434, 0.742174],
[0.147415, 0.002008, 0.633268]]]], dtype=np.float32)
output_4d_1 = net(Tensor(input_4d), size=size)
assert np.allclose(output_4d_1.asnumpy(), except_4d_1, atol=1e-5, rtol=1e-5)
# 2. 4D size=(5, 3), align_corners=True
except_4d_2 = np.array([[[[0.003088, 0.481231, 0.190293],
[0.357352, 0.599676, 0.15607],
[0.711616, 0.718121, 0.121847],
[0.746466, 0.788153, 0.283243],
[0.781316, 0.858185, 0.444639]],
[[0.389884, 0.471427, 0.557449],
[0.693965, 0.521001, 0.777311],
[0.998047, 0.570574, 0.997173],
[0.596899, 0.286291, 0.852866],
[0.195751, 0.002008, 0.708559]]]], dtype=np.float32)
output_4d_2 = net(Tensor(input_4d), size=size, align_corners=True)
assert np.allclose(output_4d_2.asnumpy(), except_4d_2, atol=1e-5, rtol=1e-5)
# 3. scale_factor=(0.5, 1.5), recompute_scale_factor=True
scale_factor = (0.5, 1.5)
except_4d_3 = np.array([[[[0.711616, 0.638687, 0.622313, 0.718121, 0.390051, 0.200119,
0.121847]],
[[0.998047, 0.645288, 0.434963, 0.570574, 0.67892, 0.840079,
0.997173]]]], dtype=np.float32)
output_4d_3 = net(Tensor(input_4d), scale_factor=scale_factor, recompute_scale_factor=True)
assert np.allclose(output_4d_3.asnumpy(), except_4d_3, atol=1e-5, rtol=1e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_trilinear(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate trilinear mode
1. 5D size
2. 5D size align_corners=True
3. 5D scale_factor
4. 5D scale_factor align_corners=True
5. 5D scale_factor align_corners=True recompute_scale_factor=True
Expectation: success
"""
ms.set_context(mode=mode)
net = Net("trilinear")
# 1. 5D size (1, 1, 2, 3, 4)
input_5d = np.array([[[[[0.415192, 0.34623, 0.339439, 0.509735],
[0.004669, 0.383352, 0.951564, 0.405198],
[0.038453, 0.526071, 0.010246, 0.164402]],
[[0.530777, 0.452261, 0.012696, 0.99017],
[0.899132, 0.921811, 0.812974, 0.254313],
[0.713248, 0.420471, 0.757313, 0.815576]]]]], dtype=np.float32)
size = (3, 4, 5)
except_5d_1 = np.array([[[[[0.415192, 0.366919, 0.342835, 0.390528, 0.509735],
[0.158615, 0.306186, 0.545724, 0.638732, 0.444399],
[0.017338, 0.311012, 0.517721, 0.513469, 0.3149],
[0.038453, 0.379786, 0.268158, 0.056493, 0.164402]],
[[0.472984, 0.421367, 0.287657, 0.348233, 0.749952],
[0.459807, 0.528248, 0.587512, 0.578409, 0.487329],
[0.423382, 0.536753, 0.640338, 0.603688, 0.389843],
[0.37585, 0.444045, 0.428525, 0.415642, 0.489989]],
[[0.530777, 0.475816, 0.232478, 0.305938, 0.99017],
[0.760999, 0.75031, 0.6293, 0.518087, 0.530259],
[0.829425, 0.762494, 0.762955, 0.693907, 0.464787],
[0.713248, 0.508304, 0.588892, 0.774792, 0.815576]]]]], dtype=np.float32)
output_5d_1 = net(Tensor(input_5d), size=size)
assert np.allclose(output_5d_1.asnumpy(), except_5d_1, atol=1e-5, rtol=1e-5)
# 2. 5D size align_corners=True
except_5d_2 = np.array([[[[[0.415192, 0.36347, 0.342835, 0.382013, 0.509735],
[0.14151, 0.313611, 0.55925, 0.670653, 0.440044],
[0.01593, 0.327176, 0.534358, 0.559577, 0.324933],
[0.038453, 0.404166, 0.268158, 0.048785, 0.164402]],
[[0.472984, 0.41768, 0.287657, 0.319539, 0.749952],
[0.458928, 0.540834, 0.607502, 0.602607, 0.469821],
[0.426551, 0.551246, 0.654459, 0.632871, 0.383167],
[0.37585, 0.448916, 0.428525, 0.410332, 0.489989]],
[[0.530777, 0.47189, 0.232478, 0.257064, 0.99017],
[0.776347, 0.768057, 0.655755, 0.534561, 0.499599],
[0.837171, 0.775316, 0.774559, 0.706165, 0.441401],
[0.713248, 0.493665, 0.588892, 0.771879, 0.815576]]]]],
dtype=np.float32)
output_5d_2 = net(Tensor(input_5d), size=size, align_corners=True)
assert np.allclose(output_5d_2.asnumpy(), except_5d_2, atol=1e-5, rtol=1e-5)
# 3. 5D size scale_factor
scale_factor = (1.5, 1.2, 0.5)
except_5d_3 = np.array([[[[[0.380711, 0.424587],
[0.194011, 0.678381],
[0.282262, 0.087324]],
[[0.436115, 0.46301],
[0.552241, 0.606012],
[0.424561, 0.436884]],
[[0.491519, 0.501433],
[0.910471, 0.533643],
[0.566859, 0.786444]]]]], dtype=np.float32)
output_5d_3 = net(Tensor(input_5d), scale_factor=scale_factor)
assert np.allclose(output_5d_3.asnumpy(), except_5d_3, atol=1e-5, rtol=1e-5)
# 4. 5D size scale_factor align_corners=True
except_5d_4 = np.array([[[[[0.415192, 0.509735],
[0.004669, 0.405198],
[0.038453, 0.164402]],
[[0.472984, 0.749952],
[0.451901, 0.329755],
[0.37585, 0.489989]],
[[0.530777, 0.99017],
[0.899132, 0.254313],
[0.713248, 0.815576]]]]], dtype=np.float32)
output_5d_4 = net(Tensor(input_5d), scale_factor=scale_factor, align_corners=True)
assert np.allclose(output_5d_4.asnumpy(), except_5d_4, atol=1e-5, rtol=1e-5)
# 5. 5D size scale_factor align_corners=True recompute_scale_factor=True
except_5d_5 = np.array([[[[[0.415192, 0.509735],
[0.004669, 0.405198],
[0.038453, 0.164402]],
[[0.472984, 0.749952],
[0.451901, 0.329755],
[0.37585, 0.489989]],
[[0.530777, 0.99017],
[0.899132, 0.254313],
[0.713248, 0.815576]]]]], dtype=np.float32)
output_5d_5 = net(Tensor(input_5d), scale_factor=scale_factor, align_corners=True, recompute_scale_factor=True)
assert np.allclose(output_5d_5.asnumpy(), except_5d_5, atol=1e-5, rtol=1e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_trilinear_gpu(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate trilinear mode
1. 5D size
2. 5D size align_corners=True
3. 5D scale_factor failed I6D8EN
4. 5D scale_factor align_corners=True failed I6D8EN
5. 5D scale_factor align_corners=True recompute_scale_factor=True
Expectation: success
"""
ms.set_context(mode=mode)
net = Net("trilinear")
# 1. 5D size (1, 1, 2, 3, 4)
input_5d = np.array([[[[[0.415192, 0.34623, 0.339439, 0.509735],
[0.004669, 0.383352, 0.951564, 0.405198],
[0.038453, 0.526071, 0.010246, 0.164402]],
[[0.530777, 0.452261, 0.012696, 0.99017],
[0.899132, 0.921811, 0.812974, 0.254313],
[0.713248, 0.420471, 0.757313, 0.815576]]]]], dtype=np.float32)
size = (3, 4, 5)
except_5d_1 = np.array([[[[[0.415192, 0.366919, 0.342835, 0.390528, 0.509735],
[0.158615, 0.306186, 0.545724, 0.638732, 0.444399],
[0.017338, 0.311012, 0.517721, 0.513469, 0.3149],
[0.038453, 0.379786, 0.268158, 0.056493, 0.164402]],
[[0.472984, 0.421367, 0.287657, 0.348233, 0.749952],
[0.459807, 0.528248, 0.587512, 0.578409, 0.487329],
[0.423382, 0.536753, 0.640338, 0.603688, 0.389843],
[0.37585, 0.444045, 0.428525, 0.415642, 0.489989]],
[[0.530777, 0.475816, 0.232478, 0.305938, 0.99017],
[0.760999, 0.75031, 0.6293, 0.518087, 0.530259],
[0.829425, 0.762494, 0.762955, 0.693907, 0.464787],
[0.713248, 0.508304, 0.588892, 0.774792, 0.815576]]]]], dtype=np.float32)
output_5d_1 = net(Tensor(input_5d), size=size)
assert np.allclose(output_5d_1.asnumpy(), except_5d_1, atol=1e-5, rtol=1e-5)
# 2. 5D size align_corners=True
except_5d_2 = np.array([[[[[0.415192, 0.36347, 0.342835, 0.382013, 0.509735],
[0.14151, 0.313611, 0.55925, 0.670653, 0.440044],
[0.01593, 0.327176, 0.534358, 0.559577, 0.324933],
[0.038453, 0.404166, 0.268158, 0.048785, 0.164402]],
[[0.472984, 0.41768, 0.287657, 0.319539, 0.749952],
[0.458928, 0.540834, 0.607502, 0.602607, 0.469821],
[0.426551, 0.551246, 0.654459, 0.632871, 0.383167],
[0.37585, 0.448916, 0.428525, 0.410332, 0.489989]],
[[0.530777, 0.47189, 0.232478, 0.257064, 0.99017],
[0.776347, 0.768057, 0.655755, 0.534561, 0.499599],
[0.837171, 0.775316, 0.774559, 0.706165, 0.441401],
[0.713248, 0.493665, 0.588892, 0.771879, 0.815576]]]]],
dtype=np.float32)
output_5d_2 = net(Tensor(input_5d), size=size, align_corners=True)
assert np.allclose(output_5d_2.asnumpy(), except_5d_2, atol=1e-5, rtol=1e-5)
scale_factor = (1.5, 1.2, 0.5)
# 5. 5D size scale_factor align_corners=True recompute_scale_factor=True
except_5d_5 = np.array([[[[[0.415192, 0.509735],
[0.004669, 0.405198],
[0.038453, 0.164402]],
[[0.472984, 0.749952],
[0.451901, 0.329755],
[0.37585, 0.489989]],
[[0.530777, 0.99017],
[0.899132, 0.254313],
[0.713248, 0.815576]]]]], dtype=np.float32)
output_5d_5 = net(Tensor(input_5d), scale_factor=scale_factor, align_corners=True, recompute_scale_factor=True)
assert np.allclose(output_5d_5.asnumpy(), except_5d_5, atol=1e-5, rtol=1e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_nearest_exact_3d(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate nearest-exact mode
1. 3D size
2. 3D scale_factor recompute_scale_factor=True
Expectation: success
"""
ms.set_context(mode=mode)
net = Net("nearest-exact")
# 1. 3D(1, 3, 5) size=8
input_3d = np.array([[[0.816117, 0.004037, 0.746452, 0.137449, 0.337593],
[0.970709, 0.558792, 0.053919, 0.734102, 0.432973],
[0.830186, 0.77753, 0.384094, 0.905231, 0.76362]]], dtype=np.float32)
except_3d_1 = np.array([[[0.816117, 0.816117, 0.004037, 0.746452, 0.746452, 0.137449,
0.337593, 0.337593],
[0.970709, 0.970709, 0.558792, 0.053919, 0.053919, 0.734102,
0.432973, 0.432973],
[0.830186, 0.830186, 0.77753, 0.384094, 0.384094, 0.905231,
0.76362, 0.76362]]], dtype=np.float32)
size = 8
output_3d_1 = net(Tensor(input_3d), size=size)
assert np.allclose(output_3d_1.asnumpy(), except_3d_1, atol=1e-5, rtol=1e-5)
# 2. 3D(1, 3, 5) scale_factor=0.7 recompute_scale_factor=True
except_3d_2 = np.array([[[0.816117, 0.746452, 0.337593],
[0.970709, 0.053919, 0.432973],
[0.830186, 0.384094, 0.76362]]], dtype=np.float32)
scale_factor = 0.7
output_3d_2 = net(Tensor(input_3d), scale_factor=scale_factor, recompute_scale_factor=True)
assert np.allclose(output_3d_2.asnumpy(), except_3d_2, atol=1e-5, rtol=1e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_nearest_exact_4d(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate nearest-exact mode
Expectation: success
1. 4D size
2. 4D scale_factor recompute_scale_factor=True
"""
ms.set_context(mode=mode)
net = Net("nearest-exact")
# 1. 4D(1, 3, 3, 5) size=(5, 8)
size = (5, 8)
input_4d = np.array([[[[0.4992, 0.743079, 0.570383, 0.942855, 0.833395],
[0.645754, 0.485126, 0.957497, 0.933144, 0.556276],
[0.298626, 0.594928, 0.150964, 0.447654, 0.267512],
[0.46943, 0.419558, 0.665492, 0.414906, 0.708427]]]], dtype=np.float32)
except_4d_1 = np.array([[[[0.4992, 0.4992, 0.743079, 0.570383, 0.570383, 0.942855,
0.833395, 0.833395],
[0.645754, 0.645754, 0.485126, 0.957497, 0.957497, 0.933144,
0.556276, 0.556276],
[0.298626, 0.298626, 0.594928, 0.150964, 0.150964, 0.447654,
0.267512, 0.267512],
[0.298626, 0.298626, 0.594928, 0.150964, 0.150964, 0.447654,
0.267512, 0.267512],
[0.46943, 0.46943, 0.419558, 0.665492, 0.665492, 0.414906,
0.708427, 0.708427]]]], dtype=np.float32)
output_4d_1 = net(Tensor(input_4d), size=size)
assert np.allclose(output_4d_1.asnumpy(), except_4d_1, atol=1e-5, rtol=1e-5)
# 2. 4D(1, 3, 3, 5) scale_factor=(1.5, 0.4) recompute_scale_factor=True
scale_factor = (1.5, 0.4)
except_4d_2 = np.array([[[[0.743079, 0.942855],
[0.485126, 0.933144],
[0.485126, 0.933144],
[0.594928, 0.447654],
[0.419558, 0.414906],
[0.419558, 0.414906]]]], dtype=np.float32)
output_4d_2 = net(Tensor(input_4d), scale_factor=scale_factor, recompute_scale_factor=True)
assert np.allclose(output_4d_2.asnumpy(), except_4d_2, atol=1e-5, rtol=1e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_nearest_3d(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate nearest mode
1. 3D size
2. 3D scale_factor recompute_scale_factor=True
Expectation: success
"""
ms.set_context(mode=mode)
net = Net("nearest")
# 1. 3D(1, 3, 5) size=8
size = 8
input_3d = np.array([[[0.816117, 0.004037, 0.746452, 0.137449, 0.337593],
[0.970709, 0.558792, 0.053919, 0.734102, 0.432973],
[0.830186, 0.77753, 0.384094, 0.905231, 0.76362]]], dtype=np.float32)
except_3d_1 = np.array([[[0.816117, 0.816117, 0.004037, 0.004037, 0.746452, 0.137449,
0.137449, 0.337593],
[0.970709, 0.970709, 0.558792, 0.558792, 0.053919, 0.734102,
0.734102, 0.432973],
[0.830186, 0.830186, 0.77753, 0.77753, 0.384094, 0.905231,
0.905231, 0.76362]]], dtype=np.float32)
output_3d_1 = net(Tensor(input_3d), size=size)
assert np.allclose(output_3d_1.asnumpy(), except_3d_1, atol=1e-5, rtol=1e-5)
# 2. 3D(1, 3, 5) scale_factor=0.7 recompute_scale_factor=True
except_3d_2 = np.array([[[0.816117, 0.004037, 0.137449],
[0.970709, 0.558792, 0.734102],
[0.830186, 0.77753, 0.905231]]], dtype=np.float32)
scale_factor = 0.7
output_3d_2 = net(Tensor(input_3d), scale_factor=scale_factor, recompute_scale_factor=True)
assert np.allclose(output_3d_2.asnumpy(), except_3d_2, atol=1e-5, rtol=1e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_nearest_4d(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate nearest mode
1. 4D size
2. 4D scale_factor recompute_scale_factor=True
Expectation: success
"""
ms.set_context(mode=mode)
net = Net("nearest")
# 1. 4D(1, 3, 3, 5) size=(5, 8)
size = (5, 8)
input_4d = np.array([[[[0.4992, 0.743079, 0.570383, 0.942855, 0.833395],
[0.645754, 0.485126, 0.957497, 0.933144, 0.556276],
[0.298626, 0.594928, 0.150964, 0.447654, 0.267512],
[0.46943, 0.419558, 0.665492, 0.414906, 0.708427]]]], dtype=np.float32)
except_4d_1 = np.array([[[[0.4992, 0.4992, 0.743079, 0.743079, 0.570383, 0.942855,
0.942855, 0.833395],
[0.4992, 0.4992, 0.743079, 0.743079, 0.570383, 0.942855,
0.942855, 0.833395],
[0.645754, 0.645754, 0.485126, 0.485126, 0.957497, 0.933144,
0.933144, 0.556276],
[0.298626, 0.298626, 0.594928, 0.594928, 0.150964, 0.447654,
0.447654, 0.267512],
[0.46943, 0.46943, 0.419558, 0.419558, 0.665492, 0.414906,
0.414906, 0.708427]]]], dtype=np.float32)
output_4d_1 = net(Tensor(input_4d), size=size)
assert np.allclose(output_4d_1.asnumpy(), except_4d_1, atol=1e-5, rtol=1e-5)
# 2. 4D(1, 3, 3, 5) scale_factor=(1.5, 0.4) recompute_scale_factor=True
except_4d_2 = np.array([[[[0.4992, 0.570383],
[0.4992, 0.570383],
[0.645754, 0.957497],
[0.298626, 0.150964],
[0.298626, 0.150964],
[0.46943, 0.665492]]]], dtype=np.float32)
scale_factor = (1.5, 0.4)
output_4d_2 = net(Tensor(input_4d), scale_factor=scale_factor, recompute_scale_factor=True)
assert np.allclose(output_4d_2.asnumpy(), except_4d_2, atol=1e-5, rtol=1e-5)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_nearest_5d(mode):
"""
Feature: interpolate
Description: Verify the result of interpolate nearest mode
1. 5D size
2. 5D scale_factor recompute_scale_factor=True
3. 5D scale_factor
Expectation: success
"""
ms.set_context(mode=mode)
net = Net("nearest")
# 1. 5D(1, 1, 1, 4, 5) size=(2, 4, 3)
size = (2, 4, 3)
input_5d = np.array([[[[[0.4992, 0.743079, 0.570383, 0.942855, 0.833395],
[0.645754, 0.485126, 0.957497, 0.933144, 0.556276],
[0.298626, 0.594928, 0.150964, 0.447654, 0.267512],
[0.46943, 0.419558, 0.665492, 0.414906, 0.708427]]]]], dtype=np.float32)
except_5d_1 = np.array([[[[[0.4992, 0.743079, 0.942855],
[0.645754, 0.485126, 0.933144],
[0.298626, 0.594928, 0.447654],
[0.46943, 0.419558, 0.414906]],
[[0.4992, 0.743079, 0.942855],
[0.645754, 0.485126, 0.933144],
[0.298626, 0.594928, 0.447654],
[0.46943, 0.419558, 0.414906]]]]], dtype=np.float32)
output_5d_1 = net(Tensor(input_5d), size=size)
assert np.allclose(output_5d_1.asnumpy(), except_5d_1, atol=1e-5, rtol=1e-5)
scale_factor = (3., 0.4, 0.7)
# 2. 5D(1, 1, 1, 4, 5) scale_factor=(3, 0.4, 0.7) recompute_scale_factor=True
except_5d_2 = np.array([[[[[0.4992, 0.743079, 0.942855]],
[[0.4992, 0.743079, 0.942855]],
[[0.4992, 0.743079, 0.942855]]]]], dtype=np.float32)
output_5d_2 = net(Tensor(input_5d), scale_factor=scale_factor, recompute_scale_factor=True)
assert np.allclose(output_5d_2.asnumpy(), except_5d_2, atol=1e-5, rtol=1e-5)
# 3. 5D(1, 1, 1, 4, 5) scale_factor=(3, 0.4, 0.7)
except_5d_3 = np.array([[[[[0.4992, 0.743079, 0.570383]],
[[0.4992, 0.743079, 0.570383]],
[[0.4992, 0.743079, 0.570383]]]]], dtype=np.float32)
output_5d_3 = net(Tensor(input_5d), scale_factor=scale_factor)
assert np.allclose(output_5d_3.asnumpy(), except_5d_3, atol=1e-5, rtol=1e-5)