!41037 [lite]CropAndResize support vmap and functional

Merge pull request !41037 from 徐安越/master4
i-robot 2022-09-30 03:43:29 +00:00 committed by Gitee
commit 7c99522569
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 241 additions and 92 deletions

View File

@ -25,6 +25,7 @@ mindspore.ops.function
mindspore.ops.ctc_greedy_decoder
mindspore.ops.conv2d
mindspore.ops.conv3d
mindspore.ops.crop_and_resize
mindspore.ops.deformable_conv2d
mindspore.ops.dropout2d
mindspore.ops.dropout3d

View File

@ -1,5 +1,5 @@
mindspore.ops.CropAndResize
============================
===========================
.. py:class:: mindspore.ops.CropAndResize(method="bilinear", extrapolation_value=0.0)
@ -22,6 +22,16 @@ mindspore.ops.CropAndResize
A 4-D Tensor of shape :math:`(num\_boxes, crop\_height, crop\_width, depth)` with data type float32.
Raises:
- **TypeError** - If `method` is not a str.
- **TypeError** - If `extrapolation_value` is not a float and its value is not "bilinear", "nearest" or "bilinear_v2".
- **ValueError** - If `method` is not 'bilinear', 'nearest' or 'bilinear_v2'.
- **TypeError** - If `x`, `boxes` or `box_index` is not a Tensor.
- **TypeError** - If `crop_size` is not a 2-tuple of int32.
- **TypeError** - If the data type of `boxes` is not float or that of `box_index` is not int32.
- **TypeError** - If `method` is not a string.
- **TypeError** - If `extrapolation_value` is not a float.
- **ValueError** - If `x` is not 4-dimensional.
- **ValueError** - If `boxes` is not 2-dimensional.
- **ValueError** - If the second dimension of `boxes` is not 4.
- **ValueError** - If `box_index` is not 1-dimensional.
- **ValueError** - If the first dimension of `box_index` is not equal to the first dimension of `boxes`.
- **ValueError** - If any element of `box_index` is outside the range `[0, batch)`.
- **ValueError** - If the values of `crop_size` are not positive integers.
- **ValueError** - If `method` is not one of "bilinear", "nearest", "bilinear_v2".

View File

@ -0,0 +1,36 @@
mindspore.ops.crop_and_resize
=============================
.. py:function:: mindspore.ops.crop_and_resize(x, boxes, box_indices, crop_size, method="bilinear", extrapolation_value=0.0)
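Extracts crops from the input image tensor and resizes them.
.. note::
The output shape depends on `crop_size`, so `crop_size` must be a constant.
Currently, the backward pass of this operator supports only the "bilinear" method; for other methods it returns 0.
Args:
- **x** (Tensor) - An image Tensor of shape :math:`(batch, image\_height, image\_width, depth)`. Data types: int8, int16, int32, int64, float16, float32, float64, uint8, uint16.
- **boxes** (Tensor) - A 2-D Tensor of shape :math:`(num\_boxes, 4)`. The :math:`i`-th row gives the normalized coordinates :math:`[y1, x1, y2, x2]` of the box to crop from image :math:`\text{box\_indices[i]}`. A normalized :math:`y` coordinate maps to the image coordinate :math:`y * (image\_height - 1)`, so the normalized height interval :math:`[0, 1]` maps to :math:`[0, image\_height - 1]` in image height coordinates. :math:`y1 > y2` is allowed, in which case the crop is an up-down flipped version of the original image; the width dimension is treated similarly. Normalized coordinates outside :math:`[0, 1]` are also allowed, in which case :math:`\text{extrapolation\_value}` is used for filling. Data type: float32.
- **box_indices** (Tensor) - A 1-D Tensor of shape :math:`(num\_boxes)` whose elements must lie in :math:`[0, batch)`. :math:`\text{box\_indices[i]}` specifies the index of the image that :math:`\text{boxes[i, :]}` refers to. Data type: int32.
- **crop_size** (Tuple[int]) - A 2-tuple :math:`(crop\_height, crop\_width)`. It must be a constant and both values must be positive. It specifies the output size to which every cropped image is resized; the aspect ratio may differ from that of the original image. Data type: int32.
- **method** (str) - The sampling method used for resizing: "bilinear", "nearest" or "bilinear_v2". "bilinear" is the standard bilinear interpolation algorithm, while "bilinear_v2" may give better results in some cases. Default: "bilinear".
- **extrapolation_value** (float) - The float value used for extrapolation. Default: 0.0.
Returns:
A Tensor of shape :math:`(num\_boxes, crop\_height, crop\_width, depth)` with data type float32.
Raises:
- **TypeError** - If `x`, `boxes` or `box_indices` is not a Tensor.
- **TypeError** - If `crop_size` is not a 2-tuple of int32.
- **TypeError** - If the data type of `boxes` is not float or that of `box_indices` is not int32.
- **TypeError** - If `method` is not a string.
- **TypeError** - If `extrapolation_value` is not a float.
- **ValueError** - If `x` is not 4-dimensional.
- **ValueError** - If `boxes` is not 2-dimensional.
- **ValueError** - If the second dimension of `boxes` is not 4.
- **ValueError** - If `box_indices` is not 1-dimensional.
- **ValueError** - If the first dimension of `box_indices` is not equal to the first dimension of `boxes`.
- **ValueError** - If any element of `box_indices` is outside the range `[0, batch)`.
- **ValueError** - If the values of `crop_size` are not positive integers.
- **ValueError** - If `method` is not one of "bilinear", "nearest", "bilinear_v2".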
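
The coordinate convention described for `boxes` can be illustrated with a short pure-Python sketch. It only restates the mapping given in the documentation above (a normalized value `y` maps to `y * (image_height - 1)`, with width handled the same way); the helper names `box_to_pixel_coords` and `needs_extrapolation` are hypothetical and are not part of MindSpore.

```python
def box_to_pixel_coords(box, image_height, image_width):
    """Map a normalized box [y1, x1, y2, x2] to pixel coordinates."""
    y1, x1, y2, x2 = box
    # A normalized coordinate v maps to v * (size - 1) along its axis.
    return (
        y1 * (image_height - 1),
        x1 * (image_width - 1),
        y2 * (image_height - 1),
        x2 * (image_width - 1),
    )


def needs_extrapolation(box):
    """True if any normalized coordinate lies outside [0, 1]."""
    return any(c < 0.0 or c > 1.0 for c in box)


# The full [0, 1] interval covers the whole 256 x 256 image.
print(box_to_pixel_coords((0.0, 0.0, 1.0, 1.0), 256, 256))  # (0.0, 0.0, 255.0, 255.0)
# A box with y1 > y2 is allowed and yields an up-down flipped crop.
print(box_to_pixel_coords((1.0, 0.0, 0.0, 1.0), 256, 256))  # (255.0, 0.0, 0.0, 255.0)
# Coordinates outside [0, 1] are filled with extrapolation_value.
print(needs_extrapolation((-0.1, 0.0, 1.0, 1.0)))  # True
```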

View File

@ -252,6 +252,7 @@ BuiltInTypeMap &GetMethodMap() {
{"argmax", std::string("argmax")}, // P.Argmax()
{"argmin", std::string("argmin")}, // P.Argmax()
{"resize", std::string("resize")}, // P.Reshape()
{"crop_and_resize", std::string("crop_and_resize")}, // P.crop_and_resize
{"select", std::string("select")}, // P.Select()
{"choose", std::string("choose")}, // P.Select()
{"diagonal", std::string("diagonal")}, // P.Eye()

View File

@ -34,9 +34,8 @@ class CropAndResizeInfer : public abstract::OpInferBase {
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
MS_EXCEPTION_IF_CHECK_FAIL(
input_args.size() == kCropAndResizeInputSize,
"For primitive[" + prim_name + "], [input number] must be 4 but got " + std::to_string(input_args.size()));
(void)CheckAndConvertUtils::CheckInteger("[input number]", static_cast<int64_t>(input_args.size()), kEqual,
kCropAndResizeInputSize, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -59,17 +58,11 @@ class CropAndResizeInfer : public abstract::OpInferBase {
abstract::Shape::kShapeDimAny, abstract::Shape::kShapeDimAny});
}
size_t batch_rank = 0;
if (primitive->HasAttr(kBatchRank)) {
batch_rank = GetValue<int64_t>(primitive->GetAttr(kBatchRank));
}
MS_EXCEPTION_IF_CHECK_FAIL(x_shape.size() == kShapeRank4 + batch_rank,
"For primitive[" + prim_name + "], the [x shape-length] should be 4, bug got " +
std::to_string(static_cast<int>(x_shape.size()) - static_cast<int>(batch_rank)) + ".");
auto x_dims = static_cast<int64_t>(x_shape.size());
(void)CheckAndConvertUtils::CheckInteger("[x shape-length]", x_dims, kEqual, kShapeRank4, prim_name);
int64_t out_channel = x_shape.back();
std::vector<int64_t> batch_shape(x_shape.begin(), x_shape.begin() + static_cast<int>(batch_rank));
auto num_boxes = ParseNumBoxes(box_shape, box_index_shape, prim_name, batch_shape);
auto num_boxes = ParseNumBoxes(box_shape, box_index_shape, prim_name);
auto crop_size_type = input_args[kInputIndex3]->BuildType();
MS_EXCEPTION_IF_CHECK_FAIL(crop_size_type != nullptr,
"For primitive[" + prim_name + "], the [crop_size TypeId] is a nullptr.");
@ -79,8 +72,7 @@ class CropAndResizeInfer : public abstract::OpInferBase {
crop_size = CheckAndConvertUtils::CheckTensorIntValue("crop_size", value_ptr, prim_name);
} else if (IsIdentidityOrSubclass(crop_size_type, kTuple)) {
auto value_tuple = value_ptr->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_CHECK_FAIL(value_tuple != nullptr,
"For primitive[" + prim_name + "], the [crop_size] must a Tuple.");
MS_EXCEPTION_IF_NULL(value_tuple);
auto &elements = value_tuple->value();
for (const auto &element : elements) {
if (element->isa<Int64Imm>()) {
@ -88,18 +80,20 @@ class CropAndResizeInfer : public abstract::OpInferBase {
} else {
auto type = element->type();
std::string real_type_str = type == nullptr ? "Unknown." : type->ToString() + ".";
MS_LOG(EXCEPTION) << ("For primitive[" + prim_name +
"], the [crop_size] must be a tuple with two Int elements, but got " + real_type_str);
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
<< "], the [crop_size] must be a tuple with two Int elements, but got "
<< real_type_str;
}
}
} else {
MS_LOG(EXCEPTION) << ("For primitive[" + prim_name +
"], the [crop_size] is must be a Tensor or a Tuple with two Int elements, but got " +
crop_size_type->ToString());
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
<< "], the [crop_size] must be a Tensor or a Tuple with two Int elements, but got "
<< crop_size_type->ToString();
}
CheckAndConvertUtils::Check("crop_size length", crop_size.size(), kEqual, kShapeRank2, prim_name);
CheckAndConvertUtils::Check("crop height", crop_size[0], kGreaterThan, 0, prim_name);
CheckAndConvertUtils::Check("crop weight", crop_size.back(), kGreaterThan, 0, prim_name);
(void)CheckAndConvertUtils::CheckInteger("[crop_size length]", static_cast<int64_t>(crop_size.size()), kEqual,
kShapeRank2, prim_name);
(void)CheckAndConvertUtils::CheckInteger("[crop height]", crop_size[0], kGreaterThan, 0, prim_name);
(void)CheckAndConvertUtils::CheckInteger("[crop width]", crop_size.back(), kGreaterThan, 0, prim_name);
ShapeVector out_shape = {num_boxes, crop_size[0], crop_size.back(), out_channel};
return std::make_shared<abstract::Shape>(out_shape);
}
@ -107,9 +101,8 @@ class CropAndResizeInfer : public abstract::OpInferBase {
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == kCropAndResizeInputSize,
"For primitive[" + prim_name + "], the [x shape-length] should be 4, bug got " +
std::to_string(input_args.size()) + ".");
(void)CheckAndConvertUtils::CheckInteger("[input number]", static_cast<int64_t>(input_args.size()), kEqual,
kCropAndResizeInputSize, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
@ -124,39 +117,27 @@ class CropAndResizeInfer : public abstract::OpInferBase {
}
protected:
int64_t ParseNumBoxes(const ShapeVector &box_shape, const ShapeVector &box_index_shape, const std::string &prim_name,
const std::vector<int64_t> &batch_shape) const {
size_t batch_rank = batch_shape.size();
MS_EXCEPTION_IF_CHECK_FAIL(box_shape.size() == kShapeRank2 + batch_rank,
"For primitive[" + prim_name + "], the [boxes shape-length] should be 2, bug got " +
std::to_string(static_cast<int>(box_shape.size()) - static_cast<int>(batch_rank)) +
".");
MS_EXCEPTION_IF_CHECK_FAIL(
batch_shape == std::vector<int64_t>(box_shape.begin(), box_shape.begin() + batch_rank),
"For primitive[" + prim_name + "], the [batch_shape] of boxes is not equal to that of input.");
MS_EXCEPTION_IF_CHECK_FAIL(box_shape.back() == kLimitValue4, "For primitive[" + prim_name +
"], the [boxes second-dim] must be 4, but got " +
std::to_string(box_shape.back()) + ".");
int64_t ParseNumBoxes(const ShapeVector &box_shape, const ShapeVector &box_index_shape,
const std::string &prim_name) const {
int64_t box_dims = static_cast<int64_t>(box_shape.size());
(void)CheckAndConvertUtils::CheckInteger("[boxes shape-length]", box_dims, kEqual, kShapeRank2, prim_name);
(void)CheckAndConvertUtils::CheckInteger("[boxes second-dim]", box_shape.back(), kEqual, kLimitValue4, prim_name);
MS_EXCEPTION_IF_CHECK_FAIL(
box_index_shape.size() == 1 + batch_rank,
"For primitive[" + prim_name + "], the [box_index shape-length] should be 1, bug got " +
std::to_string(static_cast<int>(box_index_shape.size()) - static_cast<int>(batch_rank)) + ".");
MS_EXCEPTION_IF_CHECK_FAIL(
batch_shape == std::vector<int64_t>(box_index_shape.begin(), box_index_shape.begin() + batch_rank),
"For primitive[" + prim_name + "], the [batch_shape] of box_index is not equal to that of input.");
MS_EXCEPTION_IF_CHECK_FAIL(
box_shape[batch_rank] == box_index_shape[batch_rank],
"For primitive[" + prim_name + "], the [boxes first-dim] must be equal to [box_index first-dim], but got " +
std::to_string(box_shape[batch_rank]) + " vs " + std::to_string(box_index_shape[batch_rank]) + ".");
return box_shape[batch_rank];
int64_t box_index_dims = static_cast<int64_t>(box_index_shape.size());
(void)CheckAndConvertUtils::CheckInteger("[box_index shape-length]", box_index_dims, kEqual, 1, prim_name);
if (box_shape[0] != box_index_shape[0]) {
MS_EXCEPTION(ValueError) << "For primitive[" + prim_name +
"], the [boxes first-dim] must be equal to [box_index first-dim], but got " +
std::to_string(box_shape[0]) + " vs " + std::to_string(box_index_shape[0]) + ".";
}
return box_shape[0];
}
private:
const int64_t kLimitValue4 = 4;
const size_t kCropAndResizeInputSize = 4;
const size_t kShapeRank2 = 2;
const size_t kShapeRank4 = 4;
const int64_t kCropAndResizeInputSize = 4;
const int64_t kShapeRank2 = 2;
const int64_t kShapeRank4 = 4;
};
void CropAndResize::Init(ResizeMethod method, float extrapolation_value) {
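
For readers who do not want to trace the C++ checks, the static-shape part of `InferShape` after this change can be paraphrased in Python. This is a rough sketch, not the MindSpore implementation: the function name `infer_crop_and_resize_shape` is hypothetical, dynamic shapes (the `kShapeDimAny` branch) are ignored, and the error messages are simplified.

```python
def infer_crop_and_resize_shape(x_shape, boxes_shape, box_index_shape, crop_size):
    """Return (num_boxes, crop_height, crop_width, depth) or raise ValueError."""
    if len(x_shape) != 4:
        raise ValueError(f"[x shape-length] must be 4, but got {len(x_shape)}")
    if len(boxes_shape) != 2:
        raise ValueError(f"[boxes shape-length] must be 2, but got {len(boxes_shape)}")
    if boxes_shape[-1] != 4:
        raise ValueError(f"[boxes second-dim] must be 4, but got {boxes_shape[-1]}")
    if len(box_index_shape) != 1:
        raise ValueError(f"[box_index shape-length] must be 1, but got {len(box_index_shape)}")
    if boxes_shape[0] != box_index_shape[0]:
        raise ValueError(
            "[boxes first-dim] must be equal to [box_index first-dim], "
            f"but got {boxes_shape[0]} vs {box_index_shape[0]}"
        )
    if len(crop_size) != 2:
        raise ValueError(f"[crop_size length] must be 2, but got {len(crop_size)}")
    crop_height, crop_width = crop_size
    if crop_height <= 0 or crop_width <= 0:
        raise ValueError("[crop height] and [crop width] must be greater than 0")
    # num_boxes comes from boxes, the channel count from the last dim of x.
    return (boxes_shape[0], crop_height, crop_width, x_shape[-1])


print(infer_crop_and_resize_shape((1, 256, 256, 3), (5, 4), (5,), (24, 24)))  # (5, 24, 24, 3)
```

Note that the batch-rank handling present in the removed lines is gone from the infer function; batching is instead handled by the vmap rule registered in Python.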

View File

@ -16,6 +16,7 @@
"""image_ops vmap impl."""
from __future__ import absolute_import
import mindspore.numpy as mnp
from mindspore.ops import functional as F
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import image_ops as IMG
@ -83,3 +84,47 @@ def get_resize_grad_dynamic_rule(prim, axis_size):
return out, 0
return vmap_rule
@vmap_rules_getters.register(IMG.CropAndResize)
def get_crop_and_resize_vmap_rule(prim, axis_size):
"""VmapRule for `CropAndResize` operation."""
def vmap_rule(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim):
is_all_none, result = vmap_general_preprocess(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim)
if is_all_none:
return result
x, x_dim = x_bdim
boxes, boxes_dim = boxes_bdim
box_indices, box_indices_dim = box_indices_bdim
crop_size, crop_size_dim = crop_size_bdim
if crop_size_dim is not None:
_raise_value_error(
"The axis of `crop_size` in `{}` must be None, but got {}.".format(prim.name, crop_size_dim))
boxes = _bdim_at_front(boxes, boxes_dim, axis_size)
box_indices = _bdim_at_front(box_indices, box_indices_dim, axis_size)
boxes = F.reshape(boxes, (-1, 4))
num_boxes = F.shape(box_indices)[-1]
if x_dim is None:
box_indices = F.reshape(box_indices, (-1,))
out = prim(x, boxes, box_indices, crop_size)
else:
x = _bdim_at_front(x, x_dim, axis_size)
x_shape = F.shape(x)
x = F.reshape(x, (-1,) + x_shape[2:])
counts = mnp.arange(0, axis_size * x_shape[1], x_shape[1])
counts = F.reshape(counts, (axis_size, 1))
counts = F.broadcast_to(counts, (axis_size, num_boxes))
box_indices = F.add(box_indices, counts)
box_indices = F.reshape(box_indices, (-1,))
out = prim(x, boxes, box_indices, crop_size)
out_shape = F.shape(out)
out = F.reshape(out, (-1, num_boxes) + out_shape[1:])
return out, 0
return vmap_rule
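
The index arithmetic in the batched-`x` branch above may be easier to follow on plain lists. When `x` carries a vmap axis, the rule merges that axis with the image batch axis, so slice `i` of the vmap axis occupies images `i * batch` to `(i + 1) * batch - 1` of the flattened input, and each box index has to be shifted by `i * batch`. A minimal sketch of that shift, with the hypothetical helper name `flatten_batched_box_indices`:

```python
def flatten_batched_box_indices(box_indices, batch):
    """Shift per-slice box indices so they address the flattened image batch.

    box_indices: list of axis_size rows, each holding num_boxes indices in [0, batch).
    batch: number of images in each vmap slice (x_shape[1] in the rule above).
    """
    flat = []
    for i, row in enumerate(box_indices):
        offset = i * batch  # same values as mnp.arange(0, axis_size * batch, batch)
        flat.extend(index + offset for index in row)
    return flat


# Two vmap slices, two images per slice, two boxes per slice.
print(flatten_batched_box_indices([[0, 1], [1, 0]], batch=2))  # [0, 1, 3, 2]
```

After the primitive runs on the flattened inputs, the rule reshapes the output back to `(axis_size, num_boxes, ...)` and reports the batch axis at position 0.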

View File

@ -390,7 +390,8 @@ from .debug_func import (
from .image_func import (
bounding_box_decode,
bounding_box_encode,
check_valid
check_valid,
crop_and_resize
)
from .spectral_func import (
blackman_window

View File

@ -134,10 +134,116 @@ def check_valid(bboxes, img_metas):
return check_valid_op(bboxes, img_metas)
def crop_and_resize(x, boxes, box_indices, crop_size, method="bilinear", extrapolation_value=0.0):
"""
Extracts crops from the input image tensor and resizes them.
Note:
The output shape depends on crop_size, so crop_size must be constant.
For now, the backward of the operator supports only the bilinear method; for other methods it returns 0.
Args:
x (Tensor): The input image must be a 4-D tensor of shape [batch, image_height, image_width, depth].
Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16.
boxes (Tensor): A 2-D tensor of shape [num_boxes, 4].
The i-th row of the tensor specifies the coordinates of a box in the box_indices[i] image
and is specified in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of y is mapped to
the image coordinate at y * (image_height - 1), so the [0, 1] interval of normalized image height is
mapped to [0, image_height - 1] in image height coordinates. We do allow y1 > y2, in which case the sampled
crop is an up-down flipped version of the original image. The width dimension is treated similarly.
Normalized coordinates outside the [0, 1] range are allowed, in which case we use extrapolation_value to
extrapolate the input image values. Types allowed: float32.
box_indices (Tensor): A 1-D tensor of shape [num_boxes] with int32 values in [0, batch).
The value of box_indices[i] specifies the image that the i-th box refers to. Types allowed: int32.
crop_size (Tuple[int]): A tuple of two int32 elements: (crop_height, crop_width).
Only constant value is allowed. All cropped image patches are resized to this size.
The aspect ratio of the image content is not preserved. Both crop_height and crop_width need to be positive.
method (str): An optional string that specifies the sampling method for resizing.
It can be "bilinear", "nearest" or "bilinear_v2". The option "bilinear" stands for the standard bilinear
interpolation algorithm, while "bilinear_v2" may give better results in some cases. Default: "bilinear".
extrapolation_value (float): An optional float value used for extrapolation, if applicable. Default: 0.0.
Returns:
A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth] with type: float32.
Raises:
TypeError: If `x` or `boxes` or `box_indices` is not a Tensor.
TypeError: If `crop_size` is not a Tuple with two int32 elements.
TypeError: If dtype of `boxes` is not float or that of `box_indices` is not int.
TypeError: If `method` is not a str.
TypeError: If `extrapolation_value` is not a float.
ValueError: If the shape rank of `x` is not 4.
ValueError: If the shape rank of `boxes` is not 2.
ValueError: If the second dim of `boxes` is not 4.
ValueError: If the shape rank of `box_indices` is not 1.
ValueError: If the first dim of `box_indices` is not equal to that of `boxes`.
ValueError: If any element of `box_indices` is out of range `[0, batch)`.
ValueError: If the data of `crop_size` is not positive.
ValueError: If `method` is not one of 'bilinear', 'nearest', 'bilinear_v2'.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> BATCH_SIZE = 1
>>> NUM_BOXES = 5
>>> IMAGE_HEIGHT = 256
>>> IMAGE_WIDTH = 256
>>> CHANNELS = 3
>>> image = np.random.normal(size=[BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS]).astype(np.float32)
>>> boxes = np.random.uniform(size=[NUM_BOXES, 4]).astype(np.float32)
>>> box_indices = np.random.uniform(size=[NUM_BOXES], low=0, high=BATCH_SIZE).astype(np.int32)
>>> crop_size = (24, 24)
>>> output = ops.crop_and_resize(Tensor(image), Tensor(boxes), Tensor(box_indices), crop_size)
>>> print(output.shape)
(5, 24, 24, 3)
"""
if not isinstance(x, (Tensor, Tensor_)):
raise TypeError("For crop_and_resize, the input x must be a tensor")
if not isinstance(boxes, (Tensor, Tensor_)):
raise TypeError("For crop_and_resize, the input boxes must be a tensor")
if not isinstance(box_indices, (Tensor, Tensor_)):
raise TypeError("For crop_and_resize, the input box_indices must be a tensor")
if not isinstance(crop_size, tuple):
raise TypeError("For crop_and_resize, the input crop_size must be a tuple, but got {}".format(type(crop_size)))
if len(crop_size) != 2:
raise ValueError("For crop_and_resize, the crop_size's length must be 2, but got {}".format(len(crop_size)))
if not isinstance(crop_size[0], int) or not isinstance(crop_size[1], int):
raise TypeError("For crop_and_resize, the crop_size's value must be int.")
if crop_size[0] <= 0 or crop_size[1] <= 0:
raise ValueError("For crop_and_resize, the crop_size's value must be positive.")
x_shape = x.shape
if len(x_shape) != 4:
raise ValueError("For crop_and_resize, the input x must be a 4D Tensor, but got {}D".format(len(x_shape)))
boxes_dtype = _get_cache_prim(P.DType)()(boxes)
if boxes_dtype not in [mstype.float32]:
raise TypeError(
"For crop_and_resize, the input boxes must be {}, but got {}".format(mstype.float32, boxes_dtype))
boxes_shape = boxes.shape
if len(boxes_shape) != 2 or boxes_shape[-1] != 4:
raise ValueError("For crop_and_resize, the input boxes must be 2D and the second-dim must be 4, "
"but got {}".format(boxes_shape))
box_indices_dtype = _get_cache_prim(P.DType)()(box_indices)
if box_indices_dtype not in [mstype.int32]:
raise TypeError(
"For crop_and_resize, the input box_indices must be {}, but got {}".format(mstype.int32, box_indices_dtype))
box_indices_shape = box_indices.shape
if len(box_indices_shape) != 1:
raise ValueError("For crop_and_resize, the input box_indices must be 1D, but got {}".format(box_indices_shape))
if boxes_shape[0] != box_indices_shape[0]:
raise ValueError("For crop_and_resize, the first dim of input box_indices must be equal to that of input boxes"
", but got {} vs {}".format(box_indices_shape[0], boxes_shape[0]))
_crop_and_resize = _get_cache_prim(IMG.CropAndResize)(method, extrapolation_value)
return _crop_and_resize(x, boxes, box_indices, crop_size)
__all__ = [
'bounding_box_decode',
'bounding_box_encode',
'check_valid'
'check_valid',
'crop_and_resize'
]
__all__.sort()
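
The docstring above fixes how normalized box coordinates map to pixels and says that out-of-range samples take `extrapolation_value`, but it does not spell out how sample points are spaced inside a box. The pure-Python sketch below is one plausible reading for the "bilinear" method on a single-channel image. It rests on two assumptions that are not confirmed by this change: samples are spaced evenly between the two box edges (edges included), and a crop of size 1 samples the box centre. The function name `crop_and_resize_2d` is hypothetical and this is not the MindSpore kernel.

```python
import math


def crop_and_resize_2d(image, box, crop_size, extrapolation_value=0.0):
    """Bilinear crop of one single-channel image given as a list of rows."""
    height, width = len(image), len(image[0])
    y1, x1, y2, x2 = box
    crop_h, crop_w = crop_size

    def sample_positions(lo, hi, size, n):
        # Assumption: n evenly spaced samples from lo to hi, in pixel units.
        if n == 1:
            return [0.5 * (lo + hi) * (size - 1)]
        step = (hi - lo) * (size - 1) / (n - 1)
        return [lo * (size - 1) + k * step for k in range(n)]

    ys = sample_positions(y1, y2, height, crop_h)
    xs = sample_positions(x1, x2, width, crop_w)
    out = []
    for y in ys:
        row = []
        for x in xs:
            # Samples outside the image take the extrapolation value.
            if y < 0 or y > height - 1 or x < 0 or x > width - 1:
                row.append(extrapolation_value)
                continue
            top, left = int(math.floor(y)), int(math.floor(x))
            bottom, right = min(top + 1, height - 1), min(left + 1, width - 1)
            dy, dx = y - top, x - left
            # Interpolate along x on the two neighbouring rows, then along y.
            upper = image[top][left] * (1 - dx) + image[top][right] * dx
            lower = image[bottom][left] * (1 - dx) + image[bottom][right] * dx
            row.append(upper * (1 - dy) + lower * dy)
        out.append(row)
    return out


image = [[0.0, 1.0], [2.0, 3.0]]
print(crop_and_resize_2d(image, (0.0, 0.0, 1.0, 1.0), (2, 2)))  # [[0.0, 1.0], [2.0, 3.0]]
print(crop_and_resize_2d(image, (1.0, 0.0, 0.0, 1.0), (2, 2)))  # [[2.0, 3.0], [0.0, 1.0]]
```

The second call uses `y1 > y2` and returns the rows in reverse order, matching the up-down flip described in the docstring.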

View File

@ -270,39 +270,7 @@ class CropAndResize(Primitive):
"""
Extracts crops from the input image tensor and resizes them.
Note:
In case that the output shape depends on crop_size, the crop_size must be constant.
For now, the backward of the operator only support bilinear method, for other methods, will return 0.
Args:
method (str, optional): An optional string that specifies the sampling method for resizing.
It can be "bilinear", "nearest" or "bilinear_v2". The option "bilinear" stands for standard bilinear
interpolation algorithm, while "bilinear_v2" may result in better result in some cases. Default: "bilinear"
extrapolation_value (float, optional): An optional float value used extrapolation, if applicable. Default: 0.0.
Inputs:
- **x** (Tensor) - The input image must be a 4-D tensor of shape [batch, image_height, image_width, depth].
Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16.
- **boxes** (Tensor) - A 2-D tensor of shape [num_boxes, 4].
The i-th row of the tensor specifies the coordinates of a box in the box_ind[i] image
and is specified in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of y is mapped to
the image coordinate at y * (image_height - 1), so as the [0, 1] interval of normalized image height is
mapped to [0, image_height - 1] in image height coordinates. We do allow y1 > y2, in which case the sampled
crop is an up-down flipped version of the original image. The width dimension is treated similarly.
Normalized coordinates outside the [0, 1] range are allowed, in which case we use extrapolation_value to
extrapolate the input image values. Types allowed: float32.
- **box_index** (Tensor) - A 1-D tensor of shape [num_boxes] with int32 values in [0, batch).
The value of box_ind[i] specifies the image that the i-th box refers to. Types allowed: int32.
- **crop_size** (Tuple[int]) - A tuple of two int32 elements: (crop_height, crop_width).
Only constant value is allowed. All cropped image patches are resized to this size.
The aspect ratio of the image content is not preserved. Both crop_height and crop_width need to be positive.
Outputs:
A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth] with type: float32.
Raises:
TypeError: If `method` is not a str.
TypeError: If `extrapolation_value` is not a float.
ValueError: If `method` is not one of 'bilinear', 'nearest', 'bilinear_v2'.
Refer to :func:`mindspore.ops.crop_and_resize` for more details.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``