refactor interpolate
This commit is contained in:
parent
757406cb5f
commit
46c12d3057
|
@@ -29,6 +29,7 @@ mindspore/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_dee
|
|||
mindspore/mindspore/ccsrc/pipeline/jit/resource.cc:mindspore::pipeline::GetMethodMap
|
||||
mindspore/mindspore/python/mindspore/ops/operations/array_ops.py:_compute_slicing_shape
|
||||
mindspore/mindspore/python/mindspore/ops/function/array_func.py:scatter_nd
|
||||
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:interpolate
|
||||
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool3d
|
||||
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool2d
|
||||
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:max_unpool1d
|
||||
|
|
|
@@ -1,36 +1,61 @@
|
|||
mindspore.ops.interpolate
|
||||
=========================
|
||||
|
||||
.. py:function:: mindspore.ops.interpolate(x, roi=None, scales=None, sizes=None, coordinate_transformation_mode="align_corners", mode="linear")
|
||||
.. py:function:: mindspore.ops.interpolate(x, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None)
|
||||
|
||||
Resizes the input `x` using the interpolation method specified by `mode`.
|
||||
|
||||
.. warning::
|
||||
- This is an experimental feature and the interface may change.
|
||||
- `roi` is a reserved input, which takes effect in the `crop_and_resize` coordinate transformation mode and is not supported yet.
|
||||
- Setting `mode` to "linear" is not currently supported on the Ascend platform.
|
||||
Resizes the input `x` to the given `size` or `scale_factor`, using the interpolation method specified by `mode`.
|
||||
|
||||
Parameters:
|
||||
- **x** (Tensor) - The input Tensor. `x` is a 3-D Tensor when `mode` is "linear" and a 4-D Tensor when `mode` is "bilinear".
|
||||
- **roi** (tuple[float], optional) - Takes effect in the `crop_and_resize` coordinate transformation mode, which is not supported yet.
|
||||
- **scales** (tuple[float], optional) - The resize factor for each dimension of the input shape. All numbers in `scales` must be positive. The length of `scales` is the same as the rank of the shape of `x`. Only one of `scales` and `sizes` can be specified.
|
||||
- **sizes** (tuple[int], optional) - The new sizes of the specified axes of the input shape. All numbers in `sizes` must be positive. Only one of `scales` and `sizes` can be specified. When `mode` is "linear", `sizes` is a tuple of 1 int element :math:`(new\_width,)`. When `mode` is "bilinear", `sizes` is a tuple of 2 int elements :math:`(new\_height, new\_width)`.
|
||||
- **coordinate_transformation_mode** (str) - Specifies how coordinates are transformed. Default: "align_corners". Other options: "half_pixel" and "asymmetric".
|
||||
Suppose we resize the input Tensor along the x axis. Let `new_i` be the i-th coordinate of the resized Tensor along the x axis and `old_i` the corresponding coordinate of the input Tensor;
|
||||
let `new_length` be the length of the resized Tensor along the x axis and `old_length` the length of the input Tensor along the x axis. `old_i` can then be computed with the following formula:
|
||||
- **x** (Tensor) - The Tensor to be resized. The input must be a 3-D, 4-D or 5-D Tensor with shape `(batch, channels, [optional depth], [optional height], width)` and a float data type.
|
||||
- **size** (Union[int, tuple[int], list[int]], optional) - The target size. If `size` is a tuple or list, its length must match the number of spatial dimensions of `x`. Only one of `size` and `scale_factor` can be specified. Default: None.
|
||||
- **scale_factor** (Union[float, tuple[float], list[float]], optional) - The scale factor for each spatial dimension. All numbers in `scale_factor` must be positive. Only one of `size` and `scale_factor` can be specified. Default: None.
|
||||
- **mode** (str) - The sampling algorithm. One of 'nearest', 'linear' (3-D only), 'bilinear' (4-D only), 'bicubic' (4-D only), 'trilinear' (5-D only), 'area', 'nearest-exact' (3-D and 4-D). Default: 'nearest'.
|
||||
- **align_corners** (bool) - If True, the scale factor is computed as :math:`(new\_height - 1) / (height - 1)`, which keeps the corner points of the input and the resized output aligned. If False, it is computed as :math:`new\_height / height`.
|
||||
|
||||
.. code-block::
|
||||
|
||||
old_i = new_length != 1 ? new_i * (old_length - 1) / (new_length - 1) : 0 # if set to 'align_corners'
|
||||
old_i = new_length != 1 ? new_i * (old_length - 1) / (new_length - 1) : 0 # 'align_corners' = True
|
||||
old_i = new_length > 1 ? (new_i + 0.5) * old_length / new_length - 0.5 : 0 # 'align_corners' = False
|
||||
|
||||
old_i = new_length > 1 ? (new_x + 0.5) * old_length / new_length - 0.5 : 0 # if set to 'half_pixel'
|
||||
This option is only valid for the 'linear', 'bilinear', 'bicubic' and 'trilinear' modes. Default: False.
|
||||
- **recompute_scale_factor** (bool, optional) - Whether to recompute `scale_factor`. If True, `size` is computed from `scale_factor` and the resize then uses the computed `size`. If False, `size` or `scale_factor` is used directly for interpolation. Default: None.
|
||||
|
||||
old_i = new_length != 0 ? new_i * old_length / new_length : 0 # if set to 'asymmetric'
|
||||
Argument support list and supported platforms:
|
||||
|
||||
- **mode** (str) - The interpolation method to use. Currently "linear" and "bilinear" are supported. Default: "linear".
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| mode | dim | align_corners | scale_factor | device |
|
||||
+================+======+================+===============+==================+
|
||||
| nearest | 3 | \- | × | Ascend,GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| | 4 | \- | × | Ascend,GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| | 5 | \- | √ | GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| linear | 3 | √ | × | GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| bilinear | 4 | √ | × | Ascend,GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| trilinear | 5 | √ | √ | GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| bicubic | 4 | √ | × | GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| area | 3 | \- | √ | Ascend,GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| | 4 | \- | √ | GPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| | 5 | \- | √ | GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| nearest-exact | 3 | \- | × | Ascend,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| | 4 | \- | × | Ascend,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
|
||||
- `-` indicates that there is no such parameter.
|
||||
- `×` indicates that this parameter is not currently supported.
|
||||
- `√` indicates that this parameter is supported.
|
||||
|
||||
Returns:
|
||||
Tensor, with the same data type as `x`.
|
||||
The resized Tensor, whose number of dimensions and data type are the same as `x`.
|
||||
|
||||
Raises:
|
||||
- **TypeError** - `x` is not a Tensor.
|
||||
|
@@ -42,4 +67,4 @@ mindspore.ops.interpolate
|
|||
- **TypeError** - `coordinate_transformation_mode` is not a string.
|
||||
- **ValueError** - `coordinate_transformation_mode` is not in the supported list.
|
||||
- **TypeError** - `mode` is not a string.
|
||||
- **ValueError** - `mode` is not in the supported list.
|
||||
- **ValueError** - `mode` is not in the supported list.
|
||||
|
|
|
@@ -15,14 +15,14 @@
|
|||
|
||||
"""Defines nn operators with functional form."""
|
||||
from __future__ import absolute_import
|
||||
from math import pi, log
|
||||
from math import pi, log, floor
|
||||
import numpy as np
|
||||
|
||||
import mindspore.ops as ops
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import nn_ops as NN_OPS
|
||||
from mindspore.ops.operations import image_ops as IMG
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops.function.math_func import logsumexp
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
@@ -1927,160 +1927,249 @@ def hardswish(x):
|
|||
|
||||
|
||||
@constexpr
|
||||
def _check_interpolate_inputs(input_dims, roi, scales, sizes, coordinate_transformation_mode, mode,
|
||||
prim_name):
|
||||
"""Check input"""
|
||||
msg_prefix = f"For '{prim_name}', the"
|
||||
validator.check_value_type("coordinate_transformation_mode", coordinate_transformation_mode, [str], prim_name)
|
||||
support_coordinate_mode_list = ["align_corners", "half_pixel", "asymmetric"]
|
||||
if coordinate_transformation_mode not in support_coordinate_mode_list:
|
||||
raise TypeError(f"{msg_prefix} coordinate_transformation_mode must be in {support_coordinate_mode_list},"
|
||||
f" but got {coordinate_transformation_mode}")
|
||||
validator.check_value_type("mode", mode, [str], prim_name)
|
||||
if mode == "linear":
|
||||
validator.check_int(input_dims, 3, Rel.EQ, "input dims", prim_name)
|
||||
elif mode == "bilinear":
|
||||
validator.check_int(input_dims, 4, Rel.EQ, "input dims", prim_name)
|
||||
else:
|
||||
raise ValueError(f"{msg_prefix} mode must be 'linear' or 'bilinear', but got {mode}")
|
||||
|
||||
if sizes is None and scales is None:
|
||||
raise ValueError(f"{msg_prefix} 'sizes' and 'scale' both none.")
|
||||
if sizes is not None and scales is not None:
|
||||
raise ValueError(f"{msg_prefix} 'sizes' and 'scale' both not none.")
|
||||
if sizes is not None:
|
||||
if not isinstance(sizes, tuple):
|
||||
raise TypeError(
|
||||
f"{msg_prefix} 'sizes' must be tuple or None, but got {type(sizes).__name__}.")
|
||||
for item in sizes:
|
||||
validator.check_positive_int(item, 'sizes item', prim_name)
|
||||
validator.check_value_type("sizes item", item, int, prim_name)
|
||||
validator.check_int(len(sizes), input_dims - 2, Rel.EQ, "sizes", prim_name)
|
||||
return
|
||||
if not isinstance(scales, tuple):
|
||||
raise TypeError(
|
||||
f"{msg_prefix} 'scales' must be tuple or None, but got {type(scales).__name__}.")
|
||||
for item in scales:
|
||||
validator.check_positive_float(item, 'scales item', prim_name)
|
||||
validator.check_value_type("scales item", item, float, prim_name)
|
||||
scales_dims = len(scales)
|
||||
validator.check_int(scales_dims, input_dims, Rel.EQ, "scales dims", prim_name)
|
||||
validator.check_float(scales[0], 1.0, Rel.EQ, "scales[0]", prim_name)
|
||||
validator.check_float(scales[1], 1.0, Rel.EQ, "scales[1]", prim_name)
|
||||
def _scale_factor_convert_size(shape, scale_factor, dim):
|
||||
return [int(floor(float(shape[i + 2]) * scale_factor[i])) for i in range(dim)]
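# Worked example with illustrative values: for shape = (1, 3, 4, 5), scale_factor = (1.5, 0.7)
# and dim = 2, this helper returns [floor(4 * 1.5), floor(5 * 0.7)] == [6, 3].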
|
||||
|
||||
|
||||
def _interpolate_output_shape(shape, scales, sizes, mode):
|
||||
"""calculate output shape"""
|
||||
if sizes is not None:
|
||||
if mode == "bilinear":
|
||||
return sizes
|
||||
return Tensor(sizes)
|
||||
ret = ()
|
||||
for i in range(2, len(shape)):
|
||||
ret = ret + (int(scales[i] * shape[i]),)
|
||||
if mode == "bilinear":
|
||||
return ret
|
||||
return Tensor(ret)
|
||||
|
||||
|
||||
def interpolate(x, roi=None, scales=None, sizes=None, coordinate_transformation_mode="align_corners", mode="linear"):
|
||||
def interpolate(x, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None):
|
||||
r"""
|
||||
Using the interpolate method specified by `mode` resize the input tensor `x`.
|
||||
|
||||
.. warning::
|
||||
- This is an experimental prototype that is subject to change.
|
||||
- The `roi` is reserved interface for 'crop_and_resize' coordinate transformation mode,
|
||||
which is not support now.
|
||||
- The Ascend platforms is currently not supported when `mode` is "linear".
|
||||
Samples the input Tensor to the given size or scale_factor by using one of the interpolate algorithms.
|
||||
|
||||
Args:
|
||||
x (Tensor): a tensor which to resize. `x` is a 3-D tensor when `mode` is "linear". `x` is a 4-D tensor when
|
||||
`mode` is "bilinear".
|
||||
roi (tuple[float], optional): a tuple of float. Only takes effect when attr coordinate_transformation_mode is
|
||||
'crop_and_resize'.
|
||||
scales (tuple[float], optional): a tuple of float. Describe the scale along each dimension.
|
||||
Its length is the same as that of shape of `x`. The numbers in `scales` must all be positive. Only one of
|
||||
`scales` and `sizes` can be specified.
|
||||
sizes (tuple[int], optional): a tuple of int, describes the shape of the output tensor. The numbers in `sizes`
|
||||
must all be positive. Only one of `scales` and `sizes` can be specified. If `sizes` is specified, then set
|
||||
`scales` to 'None' in this operator's input list. It is 1 int elements :math:`(new\_width,)` when `mode`
|
||||
is "linear". It is 2 int elements :math:`(new\_height, new\_width)` when `mode` is "bilinear".
|
||||
coordinate_transformation_mode (str): Default is 'align_corners'. Describes how to transform the coordinate
|
||||
in the resized tensor to the coordinate in the original tensor. Other optional: 'half_pixel', 'asymmetric'.
|
||||
For example, we want to resize the original tensor along axis x. Let's denote `new_i` as the i-th coordinate
|
||||
of the resized tensor along axis x, `old_i` as the coordinate of the original tensor along axis x,
|
||||
`new_length` as the length of the resized tensor along axis x, `old_length` as the length of the original
|
||||
tensor along axis x. We compute the `old_i` via the following formula:
|
||||
x (Tensor): Tensor to be resized.
|
||||
Input tensor must be a 3-D, 4-D, or 5-D tensor with shape
|
||||
`(batch, channels, [optional depth], [optional height], width)`, with data type of float.
|
||||
size (Union[int, tuple[int], list[int]], optional): The target size.
|
||||
If size is a tuple or list, its length must match the number of spatial dimensions of x.
|
||||
One and only one of size and scale_factor can be set to None. Default: None.
|
||||
scale_factor (Union[float, tuple[float], list[float]], optional): The scale factor for the new size of the tensor.
|
||||
If scale_factor is a tuple or list, its length must match the number of spatial dimensions of x.
|
||||
One and only one of size and scale_factor can be set to None. Default: None.
|
||||
mode (str): The sampling algorithm.
|
||||
One of 'nearest', 'linear' (3D only), 'bilinear' (4D only), 'bicubic' (4D only), 'trilinear' (5D only),
|
||||
'area', 'nearest-exact'(3D and 4D). Default: 'nearest'.
|
||||
align_corners (bool): If True, rescale input by :math:`(new\_height - 1) / (height - 1)`, which exactly
|
||||
aligns the corners of data and resized data. If False, rescale by :math:`new\_height / height`.
|
||||
|
||||
.. code-block::
|
||||
|
||||
old_i = new_length != 1 ? new_i * (old_length - 1) / (new_length - 1) : 0 # if set to 'align_corners'
|
||||
old_i = new_length != 1 ? new_i * (old_length - 1) / (new_length - 1) : 0 # 'align_corners' = True
|
||||
old_i = new_length > 1 ? (new_i + 0.5) * old_length / new_length - 0.5 : 0 # 'align_corners' = False
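# illustrative example: resizing an axis of old_length = 5 to new_length = 8 with
# 'align_corners' = True maps new_i = 7 to old_i = 7 * (5 - 1) / (8 - 1) = 4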
|
||||
|
||||
old_i = new_length > 1 ? (new_x + 0.5) * old_length / new_length - 0.5 : 0 # if set to 'half_pixel'
|
||||
This is only valid for 'linear', 'bilinear', 'bicubic', or 'trilinear' modes. Default: False.
|
||||
recompute_scale_factor (bool, optional): Recalculate `scale_factor`.
|
||||
If True, the parameter `size` will be computed from the value of `scale_factor`,
|
||||
and the input will then be resized using the computed `size`.
|
||||
If False, the value of `size` or `scale_factor` will be used for direct interpolation. Default: None.
|
||||
|
||||
old_i = new_length != 0 ? new_i * old_length / new_length : 0 # if set to 'asymmetric'
|
||||
Args Support List and Supported Platforms:
|
||||
|
||||
mode (str): The method used to interpolate: 'linear' | 'bilinear'. Default is 'linear'.
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| mode | dim | align_corners | scale_factor | device |
|
||||
+================+======+================+===============+==================+
|
||||
| nearest | 3 | \- | × | Ascend,GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| | 4 | \- | × | Ascend,GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| | 5 | \- | √ | GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| linear | 3 | √ | × | GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| bilinear | 4 | √ | × | Ascend,GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| trilinear | 5 | √ | √ | GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| bicubic | 4 | √ | × | GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| area | 3 | \- | √ | Ascend,GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| | 4 | \- | √ | GPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| | 5 | \- | √ | GPU,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| nearest-exact | 3 | \- | × | Ascend,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
| | 4 | \- | × | Ascend,CPU |
|
||||
+----------------+------+----------------+---------------+------------------+
|
||||
|
||||
Returns:
|
||||
Resized tensor, with the same data type as input `x`.
|
||||
- `-` indicates that there is no such parameter.
|
||||
- `×` indicates that this parameter is not currently supported.
|
||||
- `√` indicates that this parameter is supported.
|
||||
|
||||
Outputs:
|
||||
The resized Tensor, whose number of dimensions and dtype are the same as `x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If the data type of `x` is not supported.
|
||||
TypeError: If `scales` is not a float tuple.
|
||||
ValueError: If not all numbers in `scales` are positive.
|
||||
TypeError: If `sizes` is not an int tuple.
|
||||
ValueError: If not all numbers in `sizes` are positive.
|
||||
TypeError: If `coordinate_transformation_mode` is not a string.
|
||||
ValueError: If `coordinate_transformation_mode` is not in the support list.
|
||||
TypeError: If `mode` is not a string.
|
||||
ValueError: If `mode` is not in the support list.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
TypeError: If `x` is not a Tensor.
|
||||
TypeError: If `mode` is not one of 'nearest', 'linear', 'bilinear', 'bicubic',
|
||||
'trilinear', 'area', or 'nearest-exact'.
|
||||
TypeError: If `size` is not one of tuple, list, or None.
|
||||
TypeError: If `scale_factor` is neither int nor None.
|
||||
TypeError: If `align_corners` is not a bool.
|
||||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
ValueError: If `size` and `scale_factor` are both None, or both specified.
|
||||
ValueError: If the shape of `x` or `size` does not match `mode`.
|
||||
ValueError: If `scale_factor` is an int which is less than 0.
|
||||
ValueError: If `size` is a list or tuple whose length does not match `mode`.
|
||||
|
||||
Examples:
|
||||
>>> # case 1: linear mode
|
||||
>>> import mindspore
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>> x = Tensor([[[1, 2, 3], [4, 5, 6]]], mindspore.float32)
|
||||
>>> output = ops.interpolate(x, None, None, (6,), "align_corners")
|
||||
>>> output = ops.interpolate(x, size=(6,), mode='nearest')
|
||||
>>> print(output)
|
||||
[[[1. 1.4 1.8 2.2 2.6 3.]
|
||||
[4. 4.4 4.8 5.2 5.6 6.]]]
|
||||
>>> # case 2: bilinear mode
|
||||
>>> x = Tensor([[[[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]]], mindspore.float32)
|
||||
>>> output = ops.interpolate(x, None, None, (5, 5), "asymmetric", "bilinear")
|
||||
>>> print(output)
|
||||
[[[[1. 2. 3. 4. 5.]
|
||||
[1. 2. 3. 4. 5.]
|
||||
[1. 2. 3. 4. 5.]
|
||||
[1. 2. 3. 4. 5.]
|
||||
[1. 2. 3. 4. 5.]]]]
|
||||
[[[1. 1. 2. 2. 3. 3.]
|
||||
[4. 4. 5. 5. 6. 6.]]]
|
||||
"""
|
||||
if not isinstance(x, (Tensor, Tensor_)):
|
||||
raise TypeError("For interpolate, the input x must be tensor")
|
||||
input_shape = x.shape
|
||||
input_dims = len(input_shape)
|
||||
_check_interpolate_inputs(input_dims, roi, scales, sizes, coordinate_transformation_mode, mode,
|
||||
"interpolate")
|
||||
output_size = _interpolate_output_shape(input_shape, scales, sizes, mode)
|
||||
|
||||
if mode == "linear":
|
||||
resize_linear_inner = _get_cache_prim(IMG.ResizeLinear1D)(
|
||||
coordinate_transformation_mode=coordinate_transformation_mode)
|
||||
return resize_linear_inner(x, output_size)
|
||||
if mode == "bilinear":
|
||||
def run_nearest(x, size, align_corners=None, scale_factor=None):
|
||||
# 3D 4D use ResizeNearestNeighborV2, 5D use UpsampleNearest3D
|
||||
if x.ndim == 3:
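# ResizeNearestNeighborV2 operates on 4-D NCHW data, so a 3-D input gets a temporary
# trailing size-1 axis here and is squeezed back after the resize.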
|
||||
size = Tensor((size[0], 1), dtype=mstype.int32)
|
||||
x = x.unsqueeze(-1)
|
||||
x = _get_cache_prim(P.ResizeNearestNeighborV2)(data_format="NCHW")(x, size)
|
||||
x = x.squeeze(-1)
|
||||
elif x.ndim == 4:
|
||||
size = Tensor(size, dtype=mstype.int32)
|
||||
x = _get_cache_prim(P.ResizeNearestNeighborV2)(data_format="NCHW")(x, size)
|
||||
else:
|
||||
x = _get_cache_prim(P.UpsampleNearest3D)(size, scales=scale_factor)(x)
|
||||
return x
|
||||
|
||||
def run_linear(x, size, align_corners=None, scale_factor=None):
|
||||
coordinate_transformation_mode = "align_corners" if align_corners else "half_pixel"
|
||||
resize = _get_cache_prim(P.image_ops.ResizeLinear1D)(
|
||||
coordinate_transformation_mode
|
||||
)
|
||||
return resize(x, Tensor(size, dtype=mstype.int32))
|
||||
|
||||
def run_bilinear(x, size, align_corners=None, scale_factor=None):
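# ResizeBilinearV2 takes (align_corners, half_pixel_centers); the flags are mutually
# exclusive here, so half-pixel centers are used exactly when align_corners is False.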
|
||||
resize = _get_cache_prim(P.ResizeBilinearV2)(align_corners, not align_corners)
|
||||
return resize(x, size)
|
||||
|
||||
def run_trilinear(x, size, align_corners=None, scale_factor=None):
|
||||
resize = _get_cache_prim(P.nn_ops.UpsampleTrilinear3D)(
|
||||
output_size=size, scales=scale_factor, align_corners=align_corners
|
||||
)
|
||||
return resize(x)
|
||||
|
||||
def run_bicubic(x, size, align_corners=None, scale_factor=None):
|
||||
resize = _get_cache_prim(P.image_ops.ResizeBicubic)(
|
||||
align_corners=align_corners, half_pixel_centers=not align_corners
|
||||
)
|
||||
x = resize(x, Tensor(size, dtype=mstype.int32))
|
||||
return x
|
||||
|
||||
def run_area(x, size, align_corners=None, scale_factor=None):
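# 'area' interpolation is implemented as adaptive average pooling over the spatial dimensions.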
|
||||
if x.ndim == 3:
|
||||
resize = nn.AdaptiveAvgPool1d(output_size=size[0])
|
||||
x = resize(x)
|
||||
elif x.ndim == 4:
|
||||
x = ops.adaptive_avg_pool2d(x, tuple(size))
|
||||
else:
|
||||
x = ops.adaptive_avg_pool3d(x, tuple(size))
|
||||
return x
|
||||
|
||||
def run_nearest_exact(x, size, align_corners=None, scale_factor=None):
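# Unlike run_nearest above, this path enables half_pixel_centers on the underlying kernel,
# which is what distinguishes 'nearest-exact' from plain 'nearest' sampling here.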
|
||||
if x.ndim == 3:
|
||||
size = Tensor((size[0], 1), dtype=mstype.int32)
|
||||
# The 3-D case is implemented via the 4-D kernel.
|
||||
x = x.unsqueeze(-1)
|
||||
resize = _get_cache_prim(P.ResizeNearestNeighborV2)(
|
||||
data_format="NCHW", align_corners=False, half_pixel_centers=True
|
||||
)
|
||||
x = resize(x, size)
|
||||
x = x.squeeze(-1)
|
||||
if x.ndim == 4:
|
||||
size = Tensor(size, dtype=mstype.int32)
|
||||
resize = _get_cache_prim(P.ResizeNearestNeighborV2)(
|
||||
data_format="NCHW", align_corners=False, half_pixel_centers=True
|
||||
)
|
||||
x = resize(x, size)
|
||||
return x
|
||||
|
||||
# supported_dict: {"mode": {dim: (supported optional args, ...)}}
|
||||
supported_dict = {
|
||||
"nearest": {3: (), 4: (), 5: ("scale_factor",)},
|
||||
"linear": {3: ("align_corners",)},
|
||||
"bilinear": {4: ("align_corners",)},
|
||||
"bicubic": {4: ("align_corners",)},
|
||||
"trilinear": {5: ("scale_factor", "align_corners")},
|
||||
"area": {3: ("scale_factor",), 4: ("scale_factor",), 5: ("scale_factor",)},
|
||||
"nearest-exact": {3: (), 4: ()},
|
||||
}
|
||||
resize_func = {
|
||||
"nearest": run_nearest,
|
||||
"linear": run_linear,
|
||||
"bilinear": run_bilinear,
|
||||
"bicubic": run_bicubic,
|
||||
"trilinear": run_trilinear,
|
||||
"area": run_area,
|
||||
"nearest-exact": run_nearest_exact,
|
||||
}
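# Dispatch sketch: for a 4-D input with mode="bilinear", supported_dict["bilinear"][4] lists
# the optional arguments that may be set, namely ("align_corners",), and resize_func["bilinear"]
# is the handler invoked at the end of this function.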
|
||||
if not isinstance(x, Tensor):
|
||||
raise TypeError(f"For 'interpolate', 'x' must be a tensor, but got {type(x)}")
|
||||
if size is not None and scale_factor is not None:
|
||||
raise ValueError(
|
||||
"For 'interpolate', only one of size or scale_factor should be defined"
|
||||
)
|
||||
if size is not None:
|
||||
if isinstance(size, (list, tuple)):
|
||||
if len(size) != x.ndim - 2:
|
||||
raise ValueError(
|
||||
f"For 'interpolate', 'x' and 'size' must have same number of spatial dimensions, "
|
||||
f"but got 'x' is {x.ndim - 2}D, 'size' is {len(size)}D"
|
||||
)
|
||||
else:
|
||||
size = [size for _ in range(x.ndim - 2)]
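# A scalar `size` is broadcast to every spatial dimension, e.g. size=4 for a 5-D input
# becomes [4, 4, 4].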
|
||||
elif scale_factor is not None:
|
||||
if isinstance(scale_factor, (list, tuple)):
|
||||
if len(scale_factor) != x.ndim - 2:
|
||||
raise ValueError(
|
||||
f"For 'interpolate', 'x' and 'scale_factor' must have same number of spatial dimensions, "
|
||||
f"but got 'x' is {x.ndim - 2}D, 'scale_factor' is {len(size)}D"
|
||||
)
|
||||
else:
|
||||
scale_factor = [scale_factor for _ in range(x.ndim - 2)]
|
||||
else:
|
||||
raise ValueError(
|
||||
"For 'interpolate', either 'size' or 'scale_factor' should be defined"
|
||||
)
|
||||
|
||||
if mode not in supported_dict:
|
||||
raise ValueError(
|
||||
f"For 'interpolate', 'mode' must be in '{list(supported_dict)}', but got {mode}"
|
||||
)
|
||||
if x.ndim not in supported_dict.get(mode):
|
||||
raise ValueError(
|
||||
f"For 'interpolate', {mode} only support '{list(supported_dict.get(mode, {}))}'D, but got {x.ndim}D"
|
||||
)
|
||||
# "area" mode always requires an explicit size rather than scale factor.
|
||||
if mode == "area" and size is None:
|
||||
recompute_scale_factor = True
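# With recompute forced on, the broadcast scale_factor is converted into an explicit size
# below before run_area is called.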
|
||||
if recompute_scale_factor is not None and recompute_scale_factor:
|
||||
# todo: check bool type
|
||||
if size is not None:
|
||||
raise ValueError(
|
||||
"For 'interpolate', 'recompute_scale_factor' is not meaningful with an explicit size"
|
||||
)
|
||||
size = _scale_factor_convert_size(x.shape, scale_factor, x.ndim - 2)
|
||||
scale_factor = None
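# Example of the conversion (see _scale_factor_convert_size above): a 3-D input of shape
# (1, 3, 5) with scale_factor=0.7 gives size = [floor(5 * 0.7)] == [3], and the resize then
# proceeds with that size instead of the scale factor.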
|
||||
else:
|
||||
if scale_factor is not None and "scale_factor" not in supported_dict.get(mode, {}).get(x.ndim):
|
||||
raise ValueError(
|
||||
f"For 'interpolate', 'scale_factor' option cannot currently be set with the "
|
||||
f"mode = {mode} and dim = {x.ndim}D."
|
||||
)
|
||||
if align_corners is not None:
|
||||
if "align_corners" not in supported_dict.get(mode, {}).get(x.ndim):
|
||||
raise ValueError(
|
||||
f"For 'interpolate', 'align_corners' option cannot currently be set with the "
|
||||
f"mode = {mode}, and dim = {x.ndim}D"
|
||||
)
|
||||
else:
|
||||
align_corners = False
|
||||
half_pixel_centers = False
|
||||
if coordinate_transformation_mode == "align_corners":
|
||||
align_corners = True
|
||||
elif coordinate_transformation_mode == "half_pixel":
|
||||
half_pixel_centers = True
|
||||
resize_bilinear_inner = _get_cache_prim(IMG.ResizeBilinearV2)(align_corners, half_pixel_centers)
|
||||
return resize_bilinear_inner(x, output_size)
|
||||
|
||||
raise TypeError(
|
||||
"Input Error: For interpolate, {} mode is not support now".format(mode))
|
||||
return resize_func.get(mode)(x, size, align_corners, scale_factor)
|
||||
|
||||
|
||||
def softsign(x):
|
||||
|
|
|
@@ -23,7 +23,7 @@ from mindspore import ParameterTuple
|
|||
|
||||
class NetResizeBilinear(nn.Cell):
|
||||
def construct(self, inputs, size):
|
||||
return ops.interpolate(inputs, None, None, size, "asymmetric", "bilinear")
|
||||
return ops.ResizeBilinearV2(align_corners=False, half_pixel_centers=False)(inputs, size)
|
||||
|
||||
|
||||
def case():
|
||||
|
|
|
@@ -44,8 +44,8 @@ class ResizeBilinearGradAlignCornerF(nn.Cell):
|
|||
class NetResizeBilinearFunc(nn.Cell):
|
||||
def construct(self, inputs, size, align_corner=False, half_pixel_centers=False):
|
||||
if align_corner and not half_pixel_centers:
|
||||
return ops.interpolate(inputs, None, None, size, "align_corners", "bilinear")
|
||||
return ops.interpolate(inputs, None, None, size, "half_pixel", "bilinear")
|
||||
return ops.ResizeBilinearV2(align_corners=True, half_pixel_centers=False)(inputs, size)
|
||||
return ops.ResizeBilinearV2(align_corners=False, half_pixel_centers=True)(inputs, size)
|
||||
|
||||
|
||||
def test_resize_bilinear_grad_align_corner():
|
||||
|
|
|
@@ -13,12 +13,10 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
@@ -547,167 +545,3 @@ def test_resize_nn_grayscale_align_corners_float(datatype=np.float32):
|
|||
diff = output.asnumpy() - expected_output.asnumpy()
|
||||
assert np.all(abs(diff) < error)
|
||||
assert np.all(abs(diff_align) < error)
|
||||
|
||||
|
||||
class NetResizeBilinearFunc(nn.Cell):
|
||||
def construct(self, x, scale, size):
|
||||
return ops.interpolate(x, None, scale, size, "asymmetric", "bilinear")
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_resize_bilinar_func_cpu():
|
||||
"""
|
||||
Feature: Test bilinear on cpu.
|
||||
Description: bilinear executes by calling ResizeBilinearV2
|
||||
Expectation: Assert that results are consistent with expect.
|
||||
"""
|
||||
input_tensor = Tensor(
|
||||
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(np.float32))
|
||||
resize_nn = NetResizeBilinearFunc()
|
||||
output = resize_nn(input_tensor, None, (3, 7))
|
||||
expect = np.array([[[[0.1, 0.15714286, 0.21428573, 0.27142859, 0.32857144,
|
||||
0.3857143, 0.4],
|
||||
[0.36666667, 0.42380953, 0.48095244, 0.53809524, 0.5952381,
|
||||
0.65238094, 0.6666667],
|
||||
[0.5, 0.55714285, 0.61428577, 0.67142856, 0.7285714,
|
||||
0.78571427, 0.8]]]]).astype(np.float32)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_resize_bilinar_func_scale_cpu():
|
||||
"""
|
||||
Feature: Test bilinear with scale on cpu.
|
||||
Description: bilinear executes by calling ResizeBilinearV2
|
||||
Expectation: Assert that results are consistent with expect.
|
||||
"""
|
||||
input_tensor = Tensor(
|
||||
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(np.float32))
|
||||
resize_nn = NetResizeBilinearFunc()
|
||||
output = resize_nn(input_tensor, (1.0, 1.0, 4.0, 2.0), None)
|
||||
expect = resize_nn(input_tensor, None, (8, 8))
|
||||
assert np.allclose(output.asnumpy(), expect.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_resize_bilinar_func_cpu_fp16():
|
||||
"""
|
||||
Feature: Test bilinear on cpu with fp16.
|
||||
Description: bilinear executes by calling ResizeBilinearV2
|
||||
Expectation: Assert that results are consistent with expect.
|
||||
"""
|
||||
input_tensor = Tensor(
|
||||
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(np.float16))
|
||||
resize_nn = NetResizeBilinearFunc()
|
||||
output = resize_nn(input_tensor, None, (3, 7))
|
||||
expect = np.array([[[[0.1, 0.15714286, 0.21428573, 0.27142859, 0.32857144,
|
||||
0.3857143, 0.4],
|
||||
[0.36666667, 0.42380953, 0.48095244, 0.53809524, 0.5952381,
|
||||
0.65238094, 0.6666667],
|
||||
[0.5, 0.55714285, 0.61428577, 0.67142856, 0.7285714,
|
||||
0.78571427, 0.8]]]]).astype(np.float16)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_resize_bilinar_func_scale_cpu_fp16():
|
||||
"""
|
||||
Feature: Test bilinear with scale on cpu with fp16.
|
||||
Description: bilinear executes by calling ResizeBilinearV2
|
||||
Expectation: Assert that results are consistent with expect.
|
||||
"""
|
||||
input_tensor = Tensor(
|
||||
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(np.float16))
|
||||
resize_nn = NetResizeBilinearFunc()
|
||||
output = resize_nn(input_tensor, (1.0, 1.0, 4.0, 2.0), None)
|
||||
expect = resize_nn(input_tensor, None, (8, 8))
|
||||
assert np.allclose(output.asnumpy(), expect.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_resize_bilinar_func_cpu_fp64():
|
||||
"""
|
||||
Feature: Test bilinear on cpu with fp64.
|
||||
Description: bilinear executes by calling ResizeBilinearV2
|
||||
Expectation: Assert that results are consistent with expect.
|
||||
"""
|
||||
input_tensor = Tensor(
|
||||
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(np.float64))
|
||||
resize_nn = NetResizeBilinearFunc()
|
||||
output = resize_nn(input_tensor, None, (3, 7))
|
||||
expect = np.array([[[[0.1, 0.15714286, 0.21428573, 0.27142859, 0.32857144,
|
||||
0.3857143, 0.4],
|
||||
[0.36666667, 0.42380953, 0.48095244, 0.53809524, 0.5952381,
|
||||
0.65238094, 0.6666667],
|
||||
[0.5, 0.55714285, 0.61428577, 0.67142856, 0.7285714,
|
||||
0.78571427, 0.8]]]]).astype(np.float64)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_resize_bilinar_func_scale_cpu_fp64():
|
||||
"""
|
||||
Feature: Test bilinear with scale on cpu with fp64.
|
||||
Description: bilinear executes by calling ResizeBilinearV2
|
||||
Expectation: Assert that results are consistent with expect.
|
||||
"""
|
||||
input_tensor = Tensor(
|
||||
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(np.float64))
|
||||
resize_nn = NetResizeBilinearFunc()
|
||||
output = resize_nn(input_tensor, (1.0, 1.0, 4.0, 2.0), None)
|
||||
expect = resize_nn(input_tensor, None, (8, 8))
|
||||
assert np.allclose(output.asnumpy(), expect.asnumpy())
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
def test_resize_nn_func_half_pixel_centers(datatype=np.float32):
|
||||
"""
|
||||
Feature: Test resize_bilinear on CPU.
|
||||
Description: The half_pixel_centers is True.
|
||||
Expectation: Assert that results are consistent with expect.
|
||||
"""
|
||||
input_tensor = Tensor(
|
||||
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(datatype))
|
||||
resize_nn_func = NetResizeBilinearFunc()
|
||||
output = resize_nn_func(input_tensor, (3, 7), align_corner=False, half_pixel_centers=True)
|
||||
expected_output = np.array([[[[0.1, 0.13571429, 0.19285715, 0.25, 0.30714288,
|
||||
0.36428574, 0.4],
|
||||
[0.3, 0.3357143, 0.39285716, 0.45, 0.5071429,
|
||||
0.56428576, 0.6],
|
||||
[0.5, 0.5357143, 0.5928572, 0.65, 0.7071429,
|
||||
0.76428574, 0.8]]]], dtype=datatype)
|
||||
assert np.allclose(output.asnumpy(), expected_output)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_cpu_training
|
||||
def test_resize_nn_func_half_pixel_centers_fp16(datatype=np.float16):
|
||||
"""
|
||||
Feature: Test resize_bilinear fp16 on CPU.
|
||||
Description: The half_pixel_centers is True.
|
||||
Expectation: Assert that results are consistent with expect.
|
||||
"""
|
||||
input_tensor = Tensor(
|
||||
np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(datatype))
|
||||
resize_nn_func = NetResizeBilinearFunc()
|
||||
output = resize_nn_func(input_tensor, (3, 7), align_corner=False, half_pixel_centers=True)
|
||||
expected_output = np.array([[[[0.1, 0.13571429, 0.19285715, 0.25, 0.30714288,
|
||||
0.36428574, 0.4],
|
||||
[0.3, 0.3357143, 0.39285716, 0.45, 0.5071429,
|
||||
0.56428576, 0.6],
|
||||
[0.5, 0.5357143, 0.5928572, 0.65, 0.7071429,
|
||||
0.76428574, 0.8]]]], dtype=datatype)
|
||||
assert np.allclose(output.asnumpy(), expected_output)
|
||||
|
|
|
@@ -19,7 +19,6 @@ import pytest
|
|||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations.image_ops import ResizeLinear1D
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
@@ -80,39 +79,3 @@ def test_resize_linear_1d_size_not_change(dtype):
|
|||
expect = np.array([[[1., 2., 3.],
|
||||
[4., 5., 6.]]]).astype(dtype)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float16])
|
||||
def test_resize_linear_1d_half_pixel_functional_interface(dtype):
|
||||
"""
|
||||
Feature: ResizeLinear1D cpu kernel half_pixel mode functional interface.
|
||||
Description: test the rightness of ResizeLinear1D cpu kernel.
|
||||
Expectation: the output is same as expect.
|
||||
"""
|
||||
x = Tensor(np.array([[[1, 2, 3],
|
||||
[4, 5, 6]]], dtype=dtype))
|
||||
output = F.interpolate(x, None, None, (6,), 'half_pixel')
|
||||
expect = np.array([[[1., 1.25, 1.75, 2.25, 2.75, 3.],
|
||||
[4., 4.25, 4.75, 5.25, 5.75, 6.]]]).astype(dtype)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [np.float16, np.float64])
|
||||
def test_resize_linear_1d_align_corners_functional_interface(dtype):
|
||||
"""
|
||||
Feature: ResizeLinear1D cpu kernel align_corners mode functional interface.
|
||||
Description: test the rightness of ResizeLinear1D cpu kernel.
|
||||
Expectation: the output is same as expect.
|
||||
"""
|
||||
x = Tensor(np.array([[[1, 2, 3],
|
||||
[4, 5, 6]]], dtype=dtype))
|
||||
output = F.interpolate(x, None, (1., 1., 2.), None, 'align_corners')
|
||||
expect = np.array([[[1., 1.4, 1.8, 2.2, 2.6, 3.],
|
||||
[4., 4.4, 4.8, 5.2, 5.6, 6.]]]).astype(dtype)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
|
|
@@ -41,7 +41,7 @@ class NetResizeBilinear(nn.Cell):
|
|||
def construct(self, inputs, size, indices_input, axis):
|
||||
unique_input_index, _ = ops.unique(indices_input)
|
||||
inputs_dyn = ops.gather(inputs, unique_input_index, axis)
|
||||
return ops.interpolate(inputs_dyn, None, None, size, "asymmetric", "bilinear")
|
||||
return ops.ResizeBilinearV2(align_corners=False, half_pixel_centers=False)(inputs_dyn, size)
|
||||
|
||||
|
||||
def case_input_dyn(mode, device_target, dtype="float32"):
|
||||
|
|
|
@@ -576,8 +576,8 @@ def test_resize_nn_grayscale_align_corners_float(datatype=np.float32):
|
|||
class NetResizeBilinearFunc(nn.Cell):
|
||||
def construct(self, inputs, size, align_corner=False, half_pixel_centers=False):
|
||||
if align_corner and not half_pixel_centers:
|
||||
return ops.interpolate(inputs, None, None, size, "align_corners", "bilinear")
|
||||
return ops.interpolate(inputs, None, None, size, "half_pixel", "bilinear")
|
||||
return ops.ResizeBilinearV2(align_corners=True, half_pixel_centers=False)(inputs, size)
|
||||
return ops.ResizeBilinearV2(align_corners=False, half_pixel_centers=True)(inputs, size)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
|
|
@@ -19,7 +19,6 @@ import pytest
|
|||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations.image_ops import ResizeLinear1D
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
|
@@ -80,39 +79,3 @@ def test_resize_linear_1d_size_not_change(dtype):
|
|||
expect = np.array([[[1., 2., 3.],
|
||||
[4., 5., 6.]]]).astype(dtype)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float16])
|
||||
def test_resize_linear_1d_half_pixel_functional_interface(dtype):
|
||||
"""
|
||||
Feature: ResizeLinear1D gpu kernel half_pixel mode functional interface.
|
||||
Description: test the rightness of ResizeLinear1D cpu kernel.
|
||||
Expectation: the output is same as expect.
|
||||
"""
|
||||
x = Tensor(np.array([[[1, 2, 3],
|
||||
[4, 5, 6]]], dtype=dtype))
|
||||
output = F.interpolate(x, None, None, (6,), 'half_pixel')
|
||||
expect = np.array([[[1., 1.25, 1.75, 2.25, 2.75, 3.],
|
||||
[4., 4.25, 4.75, 5.25, 5.75, 6.]]]).astype(dtype)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('dtype', [np.float16, np.float64])
|
||||
def test_resize_linear_1d_align_corners_functional_interface(dtype):
|
||||
"""
|
||||
Feature: ResizeLinear1D gpu kernel align_corners mode functional interface.
|
||||
Description: test the rightness of ResizeLinear1D cpu kernel.
|
||||
Expectation: the output is same as expect.
|
||||
"""
|
||||
x = Tensor(np.array([[[1, 2, 3],
|
||||
[4, 5, 6]]], dtype=dtype))
|
||||
output = F.interpolate(x, None, (1., 1., 2.), None, 'align_corners')
|
||||
expect = np.array([[[1., 1.4, 1.8, 2.2, 2.6, 3.],
|
||||
[4., 4.4, 4.8, 5.2, 5.6, 6.]]]).astype(dtype)
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
|
|
@@ -0,0 +1,754 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, mode):
|
||||
super(Net, self).__init__()
|
||||
self.mode = mode
|
||||
|
||||
def construct(self, x, size=None, scale_factor=None, align_corners=None, recompute_scale_factor=None):
|
||||
return ops.interpolate(x, size, scale_factor, self.mode, align_corners, recompute_scale_factor)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_interpolate_area_3d(mode):
|
||||
"""
|
||||
Feature: interpolate
|
||||
Description: Verify the result of interpolate area mode
|
||||
1. 3D size
|
||||
2. 3D scale_factor
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net("area")
|
||||
|
||||
# 1. 3D(1, 3, 5)
|
||||
input_3d = np.array([[[0.816117, 0.004037, 0.746452, 0.137449, 0.337593],
|
||||
[0.970709, 0.558792, 0.053919, 0.734102, 0.432973],
|
||||
[0.830186, 0.77753, 0.384094, 0.905231, 0.76362]]], dtype=np.float32)
|
||||
except_3d_1 = np.array([[[0.410077, 0.375244, 0.441951, 0.237521],
|
||||
[0.76475, 0.306356, 0.394011, 0.583538],
|
||||
[0.803858, 0.580812, 0.644662, 0.834426]]], dtype=np.float32)
|
||||
size = 4
|
||||
output_3d_1 = net(Tensor(input_3d), size=size)
|
||||
assert np.allclose(output_3d_1.asnumpy(), except_3d_1, atol=1e-3, rtol=1e-3)
|
||||
|
||||
# 2. 3D(1, 3, 5) scale_factor=0.2
|
||||
except_3d_2 = np.array([[[0.40833],
|
||||
[0.550099],
|
||||
[0.732132]]], dtype=np.float32)
|
||||
scale_factor = 0.2
|
||||
output_3d_2 = net(Tensor(input_3d), scale_factor=scale_factor)
|
||||
assert np.allclose(output_3d_2.asnumpy(), except_3d_2, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_interpolate_area_4d(mode):
|
||||
"""
|
||||
Feature: interpolate
|
||||
Description: Verify the result of interpolate area mode
|
||||
1. 4D size
|
||||
2. 4D scale_factor
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net("area")
|
||||
|
||||
# 1. 4D(1, 3, 3, 5) size=(5, 8)
|
||||
input_4d = np.array([[[[0.4992, 0.743079, 0.570383, 0.942855, 0.833395],
|
||||
[0.645754, 0.485126, 0.957497, 0.933144, 0.556276],
|
||||
[0.298626, 0.594928, 0.150964, 0.447654, 0.267512],
|
||||
[0.46943, 0.419558, 0.665492, 0.414906, 0.708427]]]], dtype=np.float32)
|
||||
except_4d_1 = np.array([[[[0.4992, 0.62114, 0.743079, 0.656731, 0.756619, 0.942855,
|
||||
0.888125, 0.833395],
|
||||
[0.572477, 0.59329, 0.614102, 0.689021, 0.85097, 0.937999,
|
||||
0.816418, 0.694836],
|
||||
[0.47219, 0.506109, 0.540027, 0.547129, 0.622315, 0.690399,
|
||||
0.551147, 0.411894],
|
||||
[0.384028, 0.445635, 0.507243, 0.457736, 0.419754, 0.43128,
|
||||
0.459625, 0.48797],
|
||||
[0.46943, 0.444494, 0.419558, 0.542525, 0.540199, 0.414906,
|
||||
0.561666, 0.708427]]]], dtype=np.float32)
|
||||
size = (5, 8)
|
||||
output_4d_1 = net(Tensor(input_4d), size=size)
|
||||
assert np.allclose(output_4d_1.asnumpy(), except_4d_1, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# 2. 4D(1, 3, 3, 5) scale_factor=(1.5, 0.4)
|
||||
except_4d_2 = np.array([[[[0.604221, 0.782211],
|
||||
[0.650173, 0.798925],
|
||||
[0.696126, 0.815639],
|
||||
[0.348173, 0.28871],
|
||||
[0.433166, 0.442492],
|
||||
[0.51816, 0.596275]]]], dtype=np.float32)
|
||||
scale_factor = (1.5, 0.4)
|
||||
output_4d_2 = net(Tensor(input_4d), scale_factor=scale_factor)
|
||||
assert np.allclose(output_4d_2.asnumpy(), except_4d_2, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_interpolate_area_5d(mode):
|
||||
"""
|
||||
Feature: interpolate
|
||||
Description: Verify the result of interpolate area mode
|
||||
1. 5D size
|
||||
2. 5D scale_factor
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net("area")
|
||||
|
||||
# 1. 5D(1, 1, 1, 4, 5) size=(2, 4, 3)
|
||||
input_5d = np.array([[[[[0.4992, 0.743079, 0.570383, 0.942855, 0.833395],
|
||||
[0.645754, 0.485126, 0.957497, 0.933144, 0.556276],
|
||||
[0.298626, 0.594928, 0.150964, 0.447654, 0.267512],
|
||||
[0.46943, 0.419558, 0.665492, 0.414906, 0.708427]]]]], dtype=np.float32)
|
||||
except_5d_1 = np.array([[[[[0.62114, 0.752106, 0.888125],
|
||||
[0.56544, 0.791922, 0.74471],
|
||||
[0.446777, 0.397849, 0.357583],
|
||||
[0.444494, 0.499985, 0.561666]],
|
||||
|
||||
[[0.62114, 0.752106, 0.888125],
|
||||
[0.56544, 0.791922, 0.74471],
|
||||
[0.446777, 0.397849, 0.357583],
|
||||
[0.444494, 0.499985, 0.561666]]]]], dtype=np.float32)
|
||||
size = (2, 4, 3)
|
||||
output_5d_1 = net(Tensor(input_5d), size=size)
|
||||
assert np.allclose(output_5d_1.asnumpy(),
|
||||
except_5d_1, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# 2. 5D(1, 1, 1, 4, 5) scale_factor=(3, 0.4, 0.7)
|
||||
except_5d_2 = np.array([[[[[0.519463, 0.610466, 0.638021]],
|
||||
[[0.519463, 0.610466, 0.638021]],
|
||||
[[0.519463, 0.610466, 0.638021]]]]], dtype=np.float32)
|
||||
scale_factor = (3., 0.4, 0.7)
|
||||
output_5d_2 = net(Tensor(input_5d), scale_factor=scale_factor)
|
||||
assert np.allclose(output_5d_2.asnumpy(), except_5d_2, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_interpolate_bicubic(mode):
|
||||
"""
|
||||
Feature: interpolate
|
||||
Description: Verify the result of interpolate bicubic mode
|
||||
1. 4D size tf
|
||||
2. 4D size align_corners=True
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net("bicubic")
|
||||
input_4d = np.array([[[[0.003088, 0.313131, 0.481231, 0.326219, 0.190293],
|
||||
[0.711616, 0.583990, 0.718121, 0.258823, 0.121847],
|
||||
[0.781316, 0.591508, 0.858185, 0.091935, 0.444639]],
|
||||
[[0.389884, 0.894497, 0.471427, 0.188708, 0.557449],
|
||||
[0.998047, 0.380719, 0.570574, 0.722258, 0.997173],
|
||||
[0.195751, 0.050744, 0.002008, 0.482685, 0.708559]]]],
|
||||
dtype=np.float32)
|
||||
size = (5, 3)
|
||||
# 1. 4D size=(5, 3)
|
||||
except_4d_1 = np.array([[[[0.038226, 0.46334, 0.228343],
|
||||
[0.287464, 0.558147, 0.186838],
|
||||
[0.671836, 0.718121, 0.14377],
|
||||
[0.729394, 0.819616, 0.255285],
|
||||
[0.723466, 0.868763, 0.334474]],
|
||||
|
||||
[[0.522484, 0.463939, 0.409831],
|
||||
[0.670862, 0.531739, 0.626711],
|
||||
[0.821431, 0.570574, 0.926654],
|
||||
[0.40312, 0.206133, 0.777062],
|
||||
[0.107333, -0.040933, 0.642958]]]], dtype=np.float32)
|
||||
output_4d_1 = net(Tensor(input_4d), size=size)
|
||||
assert np.allclose(output_4d_1.asnumpy(),
|
||||
except_4d_1, atol=1e-5, rtol=1e-5)
|
||||
# 2. 4D size=(5, 3), align_corners=True
|
||||
except_4d_2 = np.array([[[[0.003088, 0.481231, 0.190293],
|
||||
[0.350818, 0.586545, 0.125808],
|
||||
[0.711616, 0.718121, 0.121847],
|
||||
[0.81289, 0.810361, 0.276826],
|
||||
[0.781316, 0.858185, 0.444639]],
|
||||
|
||||
[[0.389884, 0.471427, 0.557449],
|
||||
[0.769181, 0.574304, 0.804369],
|
||||
[0.998047, 0.570574, 0.997173],
|
||||
[0.653914, 0.295586, 0.89409],
|
||||
[0.195751, 0.002008, 0.708559]]]], dtype=np.float32)
|
||||
output_4d_2 = net(Tensor(input_4d), size=size, align_corners=True)
|
||||
|
||||
assert np.allclose(output_4d_2.asnumpy(), except_4d_2, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_interpolate_linear(mode):
|
||||
"""
|
||||
Feature: interpolate
|
||||
Description: Verify the result of interpolate linear mode
|
||||
1. 3D size
|
||||
2. 3D size align_corners=True
|
||||
3. 3D scale_factor recompute_scale_factor=True
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net("linear")
|
||||
|
||||
# 1. 3D(1, 3, 5) size=8
|
||||
input_3d = np.array([[[0.784811, 0.008794, 0.52707, 0.821989, 0.88349],
|
||||
[0.00115, 0.556748, 0.8774, 0.761826, 0.23384],
|
||||
[0.842441, 0.853507, 0.294813, 0.375813, 0.399932]]], dtype=np.float32)
|
||||
except_3d_1 = np.array([[[0.784811, 0.445304, 0.041186, 0.365109, 0.619232, 0.803557, 0.856583, 0.88349],
|
||||
[0.00115, 0.244224, 0.576789, 0.777196, 0.841283, 0.769049, 0.464834, 0.23384],
|
||||
[0.842441, 0.847282, 0.818589, 0.469405, 0.320126, 0.370751, 0.38938, 0.399932]]],
|
||||
dtype=np.float32)
|
||||
size = 8
|
||||
output_3d_1 = net(Tensor(input_3d), size=size)
|
||||
assert np.allclose(output_3d_1.asnumpy(), except_3d_1, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# 2. 3D(1, 3, 5) size=8 align_corners=True
|
||||
except_3d_2 = np.array([[[0.784811, 0.341373, 0.082833, 0.378991, 0.611333, 0.779858,
|
||||
0.848347, 0.88349],
|
||||
[0.00115, 0.318635, 0.602555, 0.785785, 0.844379, 0.778337,
|
||||
0.535546, 0.23384],
|
||||
[0.842441, 0.848764, 0.773694, 0.45444, 0.317956, 0.364242,
|
||||
0.38615, 0.399932]]], dtype=np.float32)
|
||||
output_3d_2 = net(Tensor(input_3d), size=size, align_corners=True)
|
||||
assert np.allclose(output_3d_2.asnumpy(), except_3d_2, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# 3. scale_factor=0.7, recompute_scale_factor=True
|
||||
except_3d_3 = np.array([[[0.526139, 0.52707, 0.86299],
|
||||
[0.186349, 0.8774, 0.409835],
|
||||
[0.84613, 0.294813, 0.391892]]], dtype=np.float32)
|
||||
scale_factor = 0.7
|
||||
output_3d_3 = net(Tensor(input_3d), scale_factor=scale_factor, recompute_scale_factor=True)
|
||||
assert np.allclose(output_3d_3.asnumpy(), except_3d_3, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_interpolate_bilinear(mode):
|
||||
"""
|
||||
Feature: interpolate
|
||||
Description: Verify the result of interpolate bilinear mode
|
||||
1. 4D size
|
||||
2. 4D size align_corners=True
|
||||
3. 4D scale_factor recompute_scale_factor=True
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net("bilinear")
|
||||
input_4d = np.array([[[[0.003088, 0.313131, 0.481231, 0.326219, 0.190293],
|
||||
[0.711616, 0.58399, 0.718121, 0.258823, 0.121847],
|
||||
[0.781316, 0.591508, 0.858185, 0.091935, 0.444639]],
|
||||
[[0.389884, 0.894497, 0.471427, 0.188708, 0.557449],
|
||||
[0.998047, 0.380719, 0.570574, 0.722258, 0.997173],
|
||||
[0.195751, 0.050744, 0.002008, 0.482685, 0.708559]]]],
|
||||
dtype=np.float32)
|
||||
size = (5, 3)
|
||||
|
||||
# 1. 4D size=(5, 3)
|
||||
except_4d_1 = np.array([[[[0.106436, 0.481231, 0.235602],
|
||||
[0.331491, 0.575987, 0.208363],
|
||||
[0.669074, 0.718121, 0.167506],
|
||||
[0.698458, 0.802159, 0.263245],
|
||||
[0.718047, 0.858185, 0.327071]],
|
||||
[[0.558088, 0.471427, 0.434535],
|
||||
[0.651761, 0.511086, 0.622935],
|
||||
[0.792271, 0.570574, 0.905535],
|
||||
[0.405358, 0.229434, 0.742174],
|
||||
[0.147415, 0.002008, 0.633268]]]], dtype=np.float32)
|
||||
output_4d_1 = net(Tensor(input_4d), size=size)
|
||||
assert np.allclose(output_4d_1.asnumpy(), except_4d_1, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# 2. 4D size=(5, 3), align_corners=True
|
||||
except_4d_2 = np.array([[[[0.003088, 0.481231, 0.190293],
|
||||
[0.357352, 0.599676, 0.15607],
|
||||
[0.711616, 0.718121, 0.121847],
|
||||
[0.746466, 0.788153, 0.283243],
|
||||
[0.781316, 0.858185, 0.444639]],
|
||||
|
||||
[[0.389884, 0.471427, 0.557449],
|
||||
[0.693965, 0.521001, 0.777311],
|
||||
[0.998047, 0.570574, 0.997173],
|
||||
[0.596899, 0.286291, 0.852866],
|
||||
[0.195751, 0.002008, 0.708559]]]], dtype=np.float32)
|
||||
output_4d_2 = net(Tensor(input_4d), size=size, align_corners=True)
|
||||
assert np.allclose(output_4d_2.asnumpy(), except_4d_2, atol=1e-5, rtol=1e-5)
|
||||
|
||||
# 3. scale_factor=(0.5, 1.5), recompute_scale_factor=True
|
||||
scale_factor = (0.5, 1.5)
|
||||
except_4d_3 = np.array([[[[0.711616, 0.638687, 0.622313, 0.718121, 0.390051, 0.200119,
|
||||
0.121847]],
|
||||
|
||||
[[0.998047, 0.645288, 0.434963, 0.570574, 0.67892, 0.840079,
|
||||
0.997173]]]], dtype=np.float32)
|
||||
output_4d_3 = net(Tensor(input_4d), scale_factor=scale_factor, recompute_scale_factor=True)
|
||||
assert np.allclose(output_4d_3.asnumpy(), except_4d_3, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_interpolate_trilinear(mode):
|
||||
"""
|
||||
Feature: interpolate
|
||||
Description: Verify the result of interpolate trilinear mode
|
||||
1. 5D size
|
||||
2. 5D size align_corners=True
|
||||
3. 5D scale_factor
|
||||
4. 5D scale_factor align_corners=True
|
||||
5. 5D scale_factor align_corners=True recompute_scale_factor=True
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net("trilinear")
|
||||
# 1. 5D size (1, 1, 2, 3, 4)
|
||||
input_5d = np.array([[[[[0.415192, 0.34623, 0.339439, 0.509735],
|
||||
[0.004669, 0.383352, 0.951564, 0.405198],
|
||||
[0.038453, 0.526071, 0.010246, 0.164402]],
|
||||
|
||||
[[0.530777, 0.452261, 0.012696, 0.99017],
|
||||
[0.899132, 0.921811, 0.812974, 0.254313],
|
||||
[0.713248, 0.420471, 0.757313, 0.815576]]]]], dtype=np.float32)
|
||||
size = (3, 4, 5)
|
||||
except_5d_1 = np.array([[[[[0.415192, 0.366919, 0.342835, 0.390528, 0.509735],
|
||||
[0.158615, 0.306186, 0.545724, 0.638732, 0.444399],
|
||||
[0.017338, 0.311012, 0.517721, 0.513469, 0.3149],
|
||||
[0.038453, 0.379786, 0.268158, 0.056493, 0.164402]],
|
||||
|
||||
[[0.472984, 0.421367, 0.287657, 0.348233, 0.749952],
|
||||
[0.459807, 0.528248, 0.587512, 0.578409, 0.487329],
|
||||
[0.423382, 0.536753, 0.640338, 0.603688, 0.389843],
|
||||
[0.37585, 0.444045, 0.428525, 0.415642, 0.489989]],
|
||||
|
||||
[[0.530777, 0.475816, 0.232478, 0.305938, 0.99017],
|
||||
[0.760999, 0.75031, 0.6293, 0.518087, 0.530259],
|
||||
[0.829425, 0.762494, 0.762955, 0.693907, 0.464787],
|
||||
[0.713248, 0.508304, 0.588892, 0.774792, 0.815576]]]]], dtype=np.float32)
|
||||
output_5d_1 = net(Tensor(input_5d), size=size)
|
||||
assert np.allclose(output_5d_1.asnumpy(), except_5d_1, atol=1e-5, rtol=1e-5)
|
||||
|
||||

    # 2. 5D size align_corners=True
    except_5d_2 = np.array([[[[[0.415192, 0.36347, 0.342835, 0.382013, 0.509735],
                               [0.14151, 0.313611, 0.55925, 0.670653, 0.440044],
                               [0.01593, 0.327176, 0.534358, 0.559577, 0.324933],
                               [0.038453, 0.404166, 0.268158, 0.048785, 0.164402]],

                              [[0.472984, 0.41768, 0.287657, 0.319539, 0.749952],
                               [0.458928, 0.540834, 0.607502, 0.602607, 0.469821],
                               [0.426551, 0.551246, 0.654459, 0.632871, 0.383167],
                               [0.37585, 0.448916, 0.428525, 0.410332, 0.489989]],

                              [[0.530777, 0.47189, 0.232478, 0.257064, 0.99017],
                               [0.776347, 0.768057, 0.655755, 0.534561, 0.499599],
                               [0.837171, 0.775316, 0.774559, 0.706165, 0.441401],
                               [0.713248, 0.493665, 0.588892, 0.771879, 0.815576]]]]],
                            dtype=np.float32)
    output_5d_2 = net(Tensor(input_5d), size=size, align_corners=True)
    assert np.allclose(output_5d_2.asnumpy(), except_5d_2, atol=1e-5, rtol=1e-5)

    # 3. 5D scale_factor
    scale_factor = (1.5, 1.2, 0.5)
    except_5d_3 = np.array([[[[[0.380711, 0.424587],
                               [0.194011, 0.678381],
                               [0.282262, 0.087324]],

                              [[0.436115, 0.46301],
                               [0.552241, 0.606012],
                               [0.424561, 0.436884]],

                              [[0.491519, 0.501433],
                               [0.910471, 0.533643],
                               [0.566859, 0.786444]]]]], dtype=np.float32)
    output_5d_3 = net(Tensor(input_5d), scale_factor=scale_factor)
    assert np.allclose(output_5d_3.asnumpy(), except_5d_3, atol=1e-5, rtol=1e-5)

    # 4. 5D scale_factor align_corners=True
    except_5d_4 = np.array([[[[[0.415192, 0.509735],
                               [0.004669, 0.405198],
                               [0.038453, 0.164402]],

                              [[0.472984, 0.749952],
                               [0.451901, 0.329755],
                               [0.37585, 0.489989]],

                              [[0.530777, 0.99017],
                               [0.899132, 0.254313],
                               [0.713248, 0.815576]]]]], dtype=np.float32)
    output_5d_4 = net(Tensor(input_5d), scale_factor=scale_factor, align_corners=True)
    assert np.allclose(output_5d_4.asnumpy(), except_5d_4, atol=1e-5, rtol=1e-5)

    # 5. 5D scale_factor align_corners=True recompute_scale_factor=True
    except_5d_5 = np.array([[[[[0.415192, 0.509735],
                               [0.004669, 0.405198],
                               [0.038453, 0.164402]],

                              [[0.472984, 0.749952],
                               [0.451901, 0.329755],
                               [0.37585, 0.489989]],

                              [[0.530777, 0.99017],
                               [0.899132, 0.254313],
                               [0.713248, 0.815576]]]]], dtype=np.float32)
    output_5d_5 = net(Tensor(input_5d), scale_factor=scale_factor, align_corners=True, recompute_scale_factor=True)
    assert np.allclose(output_5d_5.asnumpy(), except_5d_5, atol=1e-5, rtol=1e-5)
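

# A minimal cross-check sketch (not collected by pytest). Assuming the Net helper defined
# earlier in this file simply forwards its keyword arguments to mindspore.ops.interpolate,
# sub-test 1 of test_interpolate_trilinear above is equivalent to the direct functional call
# below; the helper name is hypothetical and only illustrative.
def _trilinear_size_sketch(x):
    """Hypothetical helper mirroring sub-test 1 with the functional API."""
    from mindspore import ops  # local import keeps the sketch self-contained
    return ops.interpolate(x, size=(3, 4, 5), mode="trilinear")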


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_trilinear_gpu(mode):
    """
    Feature: interpolate
    Description: Verify the result of interpolate trilinear mode
    1. 5D size
    2. 5D size align_corners=True
    3. 5D scale_factor failed I6D8EN
    4. 5D scale_factor align_corners=True failed I6D8EN
    5. 5D scale_factor align_corners=True recompute_scale_factor=True
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net("trilinear")
    # 1. 5D size (1, 1, 2, 3, 4)
    input_5d = np.array([[[[[0.415192, 0.34623, 0.339439, 0.509735],
                            [0.004669, 0.383352, 0.951564, 0.405198],
                            [0.038453, 0.526071, 0.010246, 0.164402]],

                           [[0.530777, 0.452261, 0.012696, 0.99017],
                            [0.899132, 0.921811, 0.812974, 0.254313],
                            [0.713248, 0.420471, 0.757313, 0.815576]]]]], dtype=np.float32)
    size = (3, 4, 5)
    except_5d_1 = np.array([[[[[0.415192, 0.366919, 0.342835, 0.390528, 0.509735],
                               [0.158615, 0.306186, 0.545724, 0.638732, 0.444399],
                               [0.017338, 0.311012, 0.517721, 0.513469, 0.3149],
                               [0.038453, 0.379786, 0.268158, 0.056493, 0.164402]],

                              [[0.472984, 0.421367, 0.287657, 0.348233, 0.749952],
                               [0.459807, 0.528248, 0.587512, 0.578409, 0.487329],
                               [0.423382, 0.536753, 0.640338, 0.603688, 0.389843],
                               [0.37585, 0.444045, 0.428525, 0.415642, 0.489989]],

                              [[0.530777, 0.475816, 0.232478, 0.305938, 0.99017],
                               [0.760999, 0.75031, 0.6293, 0.518087, 0.530259],
                               [0.829425, 0.762494, 0.762955, 0.693907, 0.464787],
                               [0.713248, 0.508304, 0.588892, 0.774792, 0.815576]]]]], dtype=np.float32)
    output_5d_1 = net(Tensor(input_5d), size=size)
    assert np.allclose(output_5d_1.asnumpy(), except_5d_1, atol=1e-5, rtol=1e-5)

    # 2. 5D size align_corners=True
    except_5d_2 = np.array([[[[[0.415192, 0.36347, 0.342835, 0.382013, 0.509735],
                               [0.14151, 0.313611, 0.55925, 0.670653, 0.440044],
                               [0.01593, 0.327176, 0.534358, 0.559577, 0.324933],
                               [0.038453, 0.404166, 0.268158, 0.048785, 0.164402]],

                              [[0.472984, 0.41768, 0.287657, 0.319539, 0.749952],
                               [0.458928, 0.540834, 0.607502, 0.602607, 0.469821],
                               [0.426551, 0.551246, 0.654459, 0.632871, 0.383167],
                               [0.37585, 0.448916, 0.428525, 0.410332, 0.489989]],

                              [[0.530777, 0.47189, 0.232478, 0.257064, 0.99017],
                               [0.776347, 0.768057, 0.655755, 0.534561, 0.499599],
                               [0.837171, 0.775316, 0.774559, 0.706165, 0.441401],
                               [0.713248, 0.493665, 0.588892, 0.771879, 0.815576]]]]],
                            dtype=np.float32)
    output_5d_2 = net(Tensor(input_5d), size=size, align_corners=True)
    assert np.allclose(output_5d_2.asnumpy(), except_5d_2, atol=1e-5, rtol=1e-5)

    scale_factor = (1.5, 1.2, 0.5)
    # 5. 5D scale_factor align_corners=True recompute_scale_factor=True
    except_5d_5 = np.array([[[[[0.415192, 0.509735],
                               [0.004669, 0.405198],
                               [0.038453, 0.164402]],

                              [[0.472984, 0.749952],
                               [0.451901, 0.329755],
                               [0.37585, 0.489989]],

                              [[0.530777, 0.99017],
                               [0.899132, 0.254313],
                               [0.713248, 0.815576]]]]], dtype=np.float32)
    output_5d_5 = net(Tensor(input_5d), scale_factor=scale_factor, align_corners=True, recompute_scale_factor=True)
    assert np.allclose(output_5d_5.asnumpy(), except_5d_5, atol=1e-5, rtol=1e-5)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_nearest_exact_3d(mode):
    """
    Feature: interpolate
    Description: Verify the result of interpolate nearest-exact mode
    1. 3D size
    2. 3D scale_factor recompute_scale_factor=True
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net("nearest-exact")

    # 1. 3D(1, 3, 5) size=8
    input_3d = np.array([[[0.816117, 0.004037, 0.746452, 0.137449, 0.337593],
                          [0.970709, 0.558792, 0.053919, 0.734102, 0.432973],
                          [0.830186, 0.77753, 0.384094, 0.905231, 0.76362]]], dtype=np.float32)
    except_3d_1 = np.array([[[0.816117, 0.816117, 0.004037, 0.746452, 0.746452, 0.137449,
                              0.337593, 0.337593],
                             [0.970709, 0.970709, 0.558792, 0.053919, 0.053919, 0.734102,
                              0.432973, 0.432973],
                             [0.830186, 0.830186, 0.77753, 0.384094, 0.384094, 0.905231,
                              0.76362, 0.76362]]], dtype=np.float32)
    size = 8
    output_3d_1 = net(Tensor(input_3d), size=size)
    assert np.allclose(output_3d_1.asnumpy(), except_3d_1, atol=1e-5, rtol=1e-5)

    # 2. 3D(1, 3, 5) scale_factor=0.7 recompute_scale_factor=True
    except_3d_2 = np.array([[[0.816117, 0.746452, 0.337593],
                             [0.970709, 0.053919, 0.432973],
                             [0.830186, 0.384094, 0.76362]]], dtype=np.float32)
    scale_factor = 0.7
    output_3d_2 = net(Tensor(input_3d), scale_factor=scale_factor, recompute_scale_factor=True)
    assert np.allclose(output_3d_2.asnumpy(), except_3d_2, atol=1e-5, rtol=1e-5)
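

# Hedged sketch for sub-test 2 above: with recompute_scale_factor=True the target length is
# derived from the scale factor first (here floor(5 * 0.7) = 3, assuming the usual floor
# rounding), which is why the expected array has three output columns. The helper name is
# hypothetical and the function is never called by the tests.
def _nearest_exact_recompute_size_sketch(in_len=5, scale=0.7):
    import math
    return math.floor(in_len * scale)  # -> 3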


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_nearest_exact_4d(mode):
    """
    Feature: interpolate
    Description: Verify the result of interpolate nearest-exact mode
    1. 4D size
    2. 4D scale_factor recompute_scale_factor=True
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net("nearest-exact")

    # 1. 4D(1, 1, 4, 5) size=(5, 8)
    size = (5, 8)
    input_4d = np.array([[[[0.4992, 0.743079, 0.570383, 0.942855, 0.833395],
                           [0.645754, 0.485126, 0.957497, 0.933144, 0.556276],
                           [0.298626, 0.594928, 0.150964, 0.447654, 0.267512],
                           [0.46943, 0.419558, 0.665492, 0.414906, 0.708427]]]], dtype=np.float32)
    except_4d_1 = np.array([[[[0.4992, 0.4992, 0.743079, 0.570383, 0.570383, 0.942855,
                               0.833395, 0.833395],
                              [0.645754, 0.645754, 0.485126, 0.957497, 0.957497, 0.933144,
                               0.556276, 0.556276],
                              [0.298626, 0.298626, 0.594928, 0.150964, 0.150964, 0.447654,
                               0.267512, 0.267512],
                              [0.298626, 0.298626, 0.594928, 0.150964, 0.150964, 0.447654,
                               0.267512, 0.267512],
                              [0.46943, 0.46943, 0.419558, 0.665492, 0.665492, 0.414906,
                               0.708427, 0.708427]]]], dtype=np.float32)
    output_4d_1 = net(Tensor(input_4d), size=size)
    assert np.allclose(output_4d_1.asnumpy(), except_4d_1, atol=1e-5, rtol=1e-5)

    # 2. 4D(1, 1, 4, 5) scale_factor=(1.5, 0.4) recompute_scale_factor=True
    scale_factor = (1.5, 0.4)
    except_4d_2 = np.array([[[[0.743079, 0.942855],
                              [0.485126, 0.933144],
                              [0.485126, 0.933144],
                              [0.594928, 0.447654],
                              [0.419558, 0.414906],
                              [0.419558, 0.414906]]]], dtype=np.float32)
    output_4d_2 = net(Tensor(input_4d), scale_factor=scale_factor, recompute_scale_factor=True)
    assert np.allclose(output_4d_2.asnumpy(), except_4d_2, atol=1e-5, rtol=1e-5)
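

# Hedged sketch for sub-test 2 above: the (4, 5) spatial grid is scaled by (1.5, 0.4), and
# assuming the usual floor rounding when recompute_scale_factor=True, the expected output
# grid is (floor(4 * 1.5), floor(5 * 0.4)) = (6, 2), matching except_4d_2. Hypothetical,
# illustrative helper only.
def _nearest_exact_4d_size_sketch(shape=(4, 5), scale=(1.5, 0.4)):
    import math
    return tuple(math.floor(s * f) for s, f in zip(shape, scale))  # -> (6, 2)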


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_nearest_3d(mode):
    """
    Feature: interpolate
    Description: Verify the result of interpolate nearest mode
    1. 3D size
    2. 3D scale_factor recompute_scale_factor=True
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net("nearest")

    # 1. 3D(1, 3, 5) size=8
    size = 8
    input_3d = np.array([[[0.816117, 0.004037, 0.746452, 0.137449, 0.337593],
                          [0.970709, 0.558792, 0.053919, 0.734102, 0.432973],
                          [0.830186, 0.77753, 0.384094, 0.905231, 0.76362]]], dtype=np.float32)
    except_3d_1 = np.array([[[0.816117, 0.816117, 0.004037, 0.004037, 0.746452, 0.137449,
                              0.137449, 0.337593],
                             [0.970709, 0.970709, 0.558792, 0.558792, 0.053919, 0.734102,
                              0.734102, 0.432973],
                             [0.830186, 0.830186, 0.77753, 0.77753, 0.384094, 0.905231,
                              0.905231, 0.76362]]], dtype=np.float32)
    output_3d_1 = net(Tensor(input_3d), size=size)
    assert np.allclose(output_3d_1.asnumpy(), except_3d_1, atol=1e-5, rtol=1e-5)

    # 2. 3D(1, 3, 5) scale_factor=0.7 recompute_scale_factor=True
    except_3d_2 = np.array([[[0.816117, 0.004037, 0.137449],
                             [0.970709, 0.558792, 0.734102],
                             [0.830186, 0.77753, 0.905231]]], dtype=np.float32)
    scale_factor = 0.7
    output_3d_2 = net(Tensor(input_3d), scale_factor=scale_factor, recompute_scale_factor=True)
    assert np.allclose(output_3d_2.asnumpy(), except_3d_2, atol=1e-5, rtol=1e-5)
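

# Hedged sketch of how the 3D expectations above and in test_interpolate_nearest_exact_3d
# differ: plain "nearest" maps output index i to floor(i * in_len / out_len), while
# "nearest-exact" adds a half-pixel offset, floor((i + 0.5) * in_len / out_len). That is one
# plausible reading of the expected index patterns (0,0,1,1,2,3,3,4 vs 0,0,1,2,2,3,4,4) and
# not a statement about the kernel internals; hypothetical helper, never called.
def _nearest_index_sketch(in_len=5, out_len=8):
    import math
    nearest = [math.floor(i * in_len / out_len) for i in range(out_len)]
    nearest_exact = [math.floor((i + 0.5) * in_len / out_len) for i in range(out_len)]
    return nearest, nearest_exact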


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_nearest_4d(mode):
    """
    Feature: interpolate
    Description: Verify the result of interpolate nearest mode
    1. 4D size
    2. 4D scale_factor recompute_scale_factor=True
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net("nearest")

    # 1. 4D(1, 1, 4, 5) size=(5, 8)
    size = (5, 8)
    input_4d = np.array([[[[0.4992, 0.743079, 0.570383, 0.942855, 0.833395],
                           [0.645754, 0.485126, 0.957497, 0.933144, 0.556276],
                           [0.298626, 0.594928, 0.150964, 0.447654, 0.267512],
                           [0.46943, 0.419558, 0.665492, 0.414906, 0.708427]]]], dtype=np.float32)
    except_4d_1 = np.array([[[[0.4992, 0.4992, 0.743079, 0.743079, 0.570383, 0.942855,
                               0.942855, 0.833395],
                              [0.4992, 0.4992, 0.743079, 0.743079, 0.570383, 0.942855,
                               0.942855, 0.833395],
                              [0.645754, 0.645754, 0.485126, 0.485126, 0.957497, 0.933144,
                               0.933144, 0.556276],
                              [0.298626, 0.298626, 0.594928, 0.594928, 0.150964, 0.447654,
                               0.447654, 0.267512],
                              [0.46943, 0.46943, 0.419558, 0.419558, 0.665492, 0.414906,
                               0.414906, 0.708427]]]], dtype=np.float32)
    output_4d_1 = net(Tensor(input_4d), size=size)
    assert np.allclose(output_4d_1.asnumpy(), except_4d_1, atol=1e-5, rtol=1e-5)

    # 2. 4D(1, 1, 4, 5) scale_factor=(1.5, 0.4) recompute_scale_factor=True
    except_4d_2 = np.array([[[[0.4992, 0.570383],
                              [0.4992, 0.570383],
                              [0.645754, 0.957497],
                              [0.298626, 0.150964],
                              [0.298626, 0.150964],
                              [0.46943, 0.665492]]]], dtype=np.float32)
    scale_factor = (1.5, 0.4)
    output_4d_2 = net(Tensor(input_4d), scale_factor=scale_factor, recompute_scale_factor=True)
    assert np.allclose(output_4d_2.asnumpy(), except_4d_2, atol=1e-5, rtol=1e-5)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_interpolate_nearest_5d(mode):
    """
    Feature: interpolate
    Description: Verify the result of interpolate nearest mode
    1. 5D size
    2. 5D scale_factor recompute_scale_factor=True
    3. 5D scale_factor
    Expectation: success
    """
    ms.set_context(mode=mode)
    net = Net("nearest")

    # 1. 5D(1, 1, 1, 4, 5) size=(2, 4, 3)
    size = (2, 4, 3)
    input_5d = np.array([[[[[0.4992, 0.743079, 0.570383, 0.942855, 0.833395],
                            [0.645754, 0.485126, 0.957497, 0.933144, 0.556276],
                            [0.298626, 0.594928, 0.150964, 0.447654, 0.267512],
                            [0.46943, 0.419558, 0.665492, 0.414906, 0.708427]]]]], dtype=np.float32)
    except_5d_1 = np.array([[[[[0.4992, 0.743079, 0.942855],
                               [0.645754, 0.485126, 0.933144],
                               [0.298626, 0.594928, 0.447654],
                               [0.46943, 0.419558, 0.414906]],
                              [[0.4992, 0.743079, 0.942855],
                               [0.645754, 0.485126, 0.933144],
                               [0.298626, 0.594928, 0.447654],
                               [0.46943, 0.419558, 0.414906]]]]], dtype=np.float32)
    output_5d_1 = net(Tensor(input_5d), size=size)
    assert np.allclose(output_5d_1.asnumpy(), except_5d_1, atol=1e-5, rtol=1e-5)

    scale_factor = (3., 0.4, 0.7)
    # 2. 5D(1, 1, 1, 4, 5) scale_factor=(3, 0.4, 0.7) recompute_scale_factor=True
    except_5d_2 = np.array([[[[[0.4992, 0.743079, 0.942855]],
                              [[0.4992, 0.743079, 0.942855]],
                              [[0.4992, 0.743079, 0.942855]]]]], dtype=np.float32)
    output_5d_2 = net(Tensor(input_5d), scale_factor=scale_factor, recompute_scale_factor=True)
    assert np.allclose(output_5d_2.asnumpy(), except_5d_2, atol=1e-5, rtol=1e-5)

    # 3. 5D(1, 1, 1, 4, 5) scale_factor=(3, 0.4, 0.7)
    except_5d_3 = np.array([[[[[0.4992, 0.743079, 0.570383]],
                              [[0.4992, 0.743079, 0.570383]],
                              [[0.4992, 0.743079, 0.570383]]]]], dtype=np.float32)
    output_5d_3 = net(Tensor(input_5d), scale_factor=scale_factor)
    assert np.allclose(output_5d_3.asnumpy(), except_5d_3, atol=1e-5, rtol=1e-5)
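

# Hedged sketch of why sub-tests 2 and 3 above end with different values (0.942855 vs
# 0.570383): with recompute_scale_factor=True the width scale is re-derived from the integer
# output size (3 / 5), so output index 2 maps to floor(2 * 5 / 3) = 3; without it the raw 0.7
# scale is used directly, so index 2 maps to floor(2 / 0.7) = 2. This is one plausible reading
# of the expected arrays, not a statement about kernel internals; hypothetical helper only.
def _nearest_scale_mapping_sketch():
    import math
    recomputed = [math.floor(i * 5 / 3) for i in range(3)]  # [0, 1, 3] -> last pick 0.942855
    raw_scale = [math.floor(i / 0.7) for i in range(3)]     # [0, 1, 2] -> last pick 0.570383
    return recomputed, raw_scale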