diff --git a/docs/api/api_python/mindspore.ops.function.rst b/docs/api/api_python/mindspore.ops.function.rst index b2a0f44a30d..71140742520 100644 --- a/docs/api/api_python/mindspore.ops.function.rst +++ b/docs/api/api_python/mindspore.ops.function.rst @@ -25,6 +25,7 @@ mindspore.ops.function mindspore.ops.ctc_greedy_decoder mindspore.ops.conv2d mindspore.ops.conv3d + mindspore.ops.crop_and_resize mindspore.ops.deformable_conv2d mindspore.ops.dropout2d mindspore.ops.dropout3d diff --git a/docs/api/api_python/ops/mindspore.ops.CropAndResize.rst b/docs/api/api_python/ops/mindspore.ops.CropAndResize.rst index 544b8934b15..1f31365d6bc 100644 --- a/docs/api/api_python/ops/mindspore.ops.CropAndResize.rst +++ b/docs/api/api_python/ops/mindspore.ops.CropAndResize.rst @@ -1,5 +1,5 @@ mindspore.ops.CropAndResize -============================ +=========================== .. py:class:: mindspore.ops.CropAndResize(method="bilinear", extrapolation_value=0.0) @@ -22,6 +22,16 @@ mindspore.ops.CropAndResize 四维Tensor,其shape为 :math:`(num\_boxes, crop\_height, crop\_width, depth)` ,数据类型类型为float32。 异常: - - **TypeError** - 如果 `method` 不是str。 - - **TypeError** - 如果 `extrapolation_value` 不是float,且取值不是"bilinear"、"nearest"或"bilinear_v2"。 - - **ValueError** - 如果 `method` 不是'bilinear'、 'nearest'或者'bilinear_v2'。 + - **TypeError** - `x`、 `boxes` 或 `box_index` 不是Tensor。 + - **TypeError** - `crop_size` 不是int32的2元组。 + - **TypeError** - `boxes` 的数据类型不是float, 或者,`box_index` 的数据类型不是int32。 + - **TypeError** - `method` 不是字符串。 + - **TypeError** - `extrapolation_value` 不是浮点值。 + - **ValueError** - `x` 的维度不是4维。 + - **ValueError** - `boxes` 的纬度不是2维。 + - **ValueError** - `boxes` 的第2维不是4。 + - **ValueError** - `box_index` 的维度不是1维。 + - **ValueError** - `box_index` 的第1维与 `boxes` 的第1维不相等。 + - **ValueError** - `box_index` 存在元素不在 `[0, batch)` 的范围内. + - **ValueError** - `crop_size` 的数据不是正整数. + - **ValueError** - `method` 不是 "bilinear"、"nearest"、"bilinear_v2"之一。 diff --git a/docs/api/api_python/ops/mindspore.ops.func_crop_and_resize.rst b/docs/api/api_python/ops/mindspore.ops.func_crop_and_resize.rst new file mode 100644 index 00000000000..d0e5a749e68 --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.func_crop_and_resize.rst @@ -0,0 +1,36 @@ +mindspore.ops.crop_and_resize +============================= + +.. py:function:: mindspore.ops.crop_and_resize(x, boxes, box_indices, crop_size, method="bilinear", extrapolation_value=0.0) + + 对输入图像进行裁剪并调整其大小。 + + .. note:: + 输出的shape依赖`crop_size`, `crop_size` 必须为常量。 + 当前该算子的反向仅支持"bilinear"模式,其他模式将会返回0。 + + 参数: + - **x** (Tensor) - shape为 :math:`(batch, image_height, image_width, depth)` 的图像Tensor。数据类型:int8, int16, int32, int64, float16, float32, float64, uint8, uint16。 + - **boxes** (Tensor) - shape为 :math:`(num_boxes, 4)` 的2维Tensor。其中,第 :math:`i` 行指定对第 :math:`\text{box_indices[i]}` 张图像裁剪时的归一化坐标 :math:`[y1, x1, y2, x2]`,那么通过归一化的 :math:`y` 坐标值可映射到的图像坐标为 :math:`y * (image\_height - 1)`,因此,归一化的图像高度 :math:`[0, 1]` 间隔映射到的图像高度间隔为 :math:`[0, image\_height - 1]`。我们也允许 :math:`y1 > y2`,这种情况下,就是对图像进行的上下翻转,宽度方向与此类似。同时,我们也允许归一化的坐标值超出 :math:`[0, 1]` 的区间,这种情况下,采用 :math:`\text{extrapolation_value}` 进行填充。数据类型:float32。 + - **box_indices** (Tensor) - shape为 :math:`(num_boxes)` 的1维Tensor,其中,每一个元素必须是 :math:`[0, batch)` 区间内的值。:math:`\test{box_indices[i]}` 指定 :math:`\test{boxes[i, :]}` 所指向的图像索引。数据类型:int32。 + - **crop_size** (Tuple[int]) - 2元组 :math:`(crop_height, crop_width)`,该输入必须为常量并且均为正值。指定对裁剪出的图像进行调整时的输出大小,纵横比可与原图不一致。数据类型:int32。 + - **method** (str): 指定调整大小时的采样方法,取值为"bilinear"、 "nearest"或"bilinear_v2",其中,"bilinear"是标准的线性插值算法,而在某些情况下,"bilinear_v2"可能会得到更优的效果。默认值:"bilinear"。 + - **extrapolation_value** (float): 指定外插时的浮点值。默认值: 0.0。 + + 返回: + Tensor,shape为 :math:`(num_boxes, crop_height, crop_width, depth)`,数据类型:float32 。 + + 异常: + - **TypeError** - `x`、 `boxes` 或 `box_indices` 不是Tensor。 + - **TypeError** - `crop_size` 不是int32的2元组。 + - **TypeError** - `boxes` 的数据类型不是float, 或者,`box_indices` 的数据类型不是int32。 + - **TypeError** - `method` 不是字符串。 + - **TypeError** - `extrapolation_value` 不是浮点值。 + - **ValueError** - `x` 的维度不是4维。 + - **ValueError** - `boxes` 的纬度不是2维。 + - **ValueError** - `boxes` 的第2维不是4。 + - **ValueError** - `box_indices` 的维度不是1维。 + - **ValueError** - `box_indices` 的第1维与 `boxes` 的第1维不相等。 + - **ValueError** - `box_indices` 存在元素不在 `[0, batch)` 的范围内. + - **ValueError** - `crop_size` 的数据不是正整数. + - **ValueError** - `method` 不是 "bilinear"、"nearest"、"bilinear_v2"之一。 diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index 59a165c1923..2879a9bd831 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -252,6 +252,7 @@ BuiltInTypeMap &GetMethodMap() { {"argmax", std::string("argmax")}, // P.Argmax() {"argmin", std::string("argmin")}, // P.Argmax() {"resize", std::string("resize")}, // P.Reshape() + {"crop_and_resize", std::string("crop_and_resize")}, // P.crop_and_resize {"select", std::string("select")}, // P.Select() {"choose", std::string("choose")}, // P.Select() {"diagonal", std::string("diagonal")}, // P.Eye() diff --git a/mindspore/core/ops/crop_and_resize.cc b/mindspore/core/ops/crop_and_resize.cc index 0ddedb3c42e..f008621429d 100644 --- a/mindspore/core/ops/crop_and_resize.cc +++ b/mindspore/core/ops/crop_and_resize.cc @@ -34,9 +34,8 @@ class CropAndResizeInfer : public abstract::OpInferBase { const std::vector &input_args) const override { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - MS_EXCEPTION_IF_CHECK_FAIL( - input_args.size() == kCropAndResizeInputSize, - "For primitive[" + prim_name + "], [input number] must be 4 but got " + std::to_string(input_args.size())); + (void)CheckAndConvertUtils::CheckInteger("[input number]", static_cast(input_args.size()), kEqual, + kCropAndResizeInputSize, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } @@ -59,17 +58,11 @@ class CropAndResizeInfer : public abstract::OpInferBase { abstract::Shape::kShapeDimAny, abstract::Shape::kShapeDimAny}); } - size_t batch_rank = 0; - if (primitive->HasAttr(kBatchRank)) { - batch_rank = GetValue(primitive->GetAttr(kBatchRank)); - } - MS_EXCEPTION_IF_CHECK_FAIL(x_shape.size() == kShapeRank4 + batch_rank, - "For primitive[" + prim_name + "], the [x shape-length] should be 4, bug got " + - std::to_string(static_cast(x_shape.size()) - static_cast(batch_rank)) + "."); + auto x_dims = static_cast(x_shape.size()); + (void)CheckAndConvertUtils::CheckInteger("[x shape-length]", x_dims, kEqual, kShapeRank4, prim_name); int64_t out_channel = x_shape.back(); - std::vector batch_shape(x_shape.begin(), x_shape.begin() + static_cast(batch_rank)); - auto num_boxes = ParseNumBoxes(box_shape, box_index_shape, prim_name, batch_shape); + auto num_boxes = ParseNumBoxes(box_shape, box_index_shape, prim_name); auto crop_size_type = input_args[kInputIndex3]->BuildType(); MS_EXCEPTION_IF_CHECK_FAIL(crop_size_type != nullptr, "For primitive[" + prim_name + "], the [crop_size TypeId] is a nullptr."); @@ -79,8 +72,7 @@ class CropAndResizeInfer : public abstract::OpInferBase { crop_size = CheckAndConvertUtils::CheckTensorIntValue("crop_size", value_ptr, prim_name); } else if (IsIdentidityOrSubclass(crop_size_type, kTuple)) { auto value_tuple = value_ptr->cast(); - MS_EXCEPTION_IF_CHECK_FAIL(value_tuple != nullptr, - "For primitive[" + prim_name + "], the [crop_size] must a Tuple."); + MS_EXCEPTION_IF_NULL(value_tuple); auto &elements = value_tuple->value(); for (const auto &element : elements) { if (element->isa()) { @@ -88,18 +80,20 @@ class CropAndResizeInfer : public abstract::OpInferBase { } else { auto type = element->type(); std::string real_type_str = type == nullptr ? "Unknown." : type->ToString() + "."; - MS_LOG(EXCEPTION) << ("For primitive[" + prim_name + - "], the [crop_size] must be a tuple with two Int elements, but got " + real_type_str); + MS_EXCEPTION(TypeError) << "For primitive[" << prim_name + << "], the [crop_size] must be a tuple with two Int elements, but got " + << real_type_str; } } } else { - MS_LOG(EXCEPTION) << ("For primitive[" + prim_name + - "], the [crop_size] is must be a Tensor or a Tuple with two Int elements, but got " + - crop_size_type->ToString()); + MS_EXCEPTION(TypeError) << "For primitive[" + prim_name + << "], the [crop_size] is must be a Tensor or a Tuple with two Int elements, but got " + << crop_size_type->ToString(); } - CheckAndConvertUtils::Check("crop_size length", crop_size.size(), kEqual, kShapeRank2, prim_name); - CheckAndConvertUtils::Check("crop height", crop_size[0], kGreaterThan, 0, prim_name); - CheckAndConvertUtils::Check("crop weight", crop_size.back(), kGreaterThan, 0, prim_name); + (void)CheckAndConvertUtils::CheckInteger("[crop_size length]", static_cast(crop_size.size()), kEqual, + kShapeRank2, prim_name); + (void)CheckAndConvertUtils::CheckInteger("[crop height]", crop_size[0], kGreaterThan, 0, prim_name); + (void)CheckAndConvertUtils::CheckInteger("[crop width]", crop_size.back(), kGreaterThan, 0, prim_name); ShapeVector out_shape = {num_boxes, crop_size[0], crop_size.back(), out_channel}; return std::make_shared(out_shape); } @@ -107,9 +101,8 @@ class CropAndResizeInfer : public abstract::OpInferBase { TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == kCropAndResizeInputSize, - "For primitive[" + prim_name + "], the [x shape-length] should be 4, bug got " + - std::to_string(input_args.size()) + "."); + (void)CheckAndConvertUtils::CheckInteger("[input number]", static_cast(input_args.size()), kEqual, + kCropAndResizeInputSize, prim_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } @@ -124,39 +117,27 @@ class CropAndResizeInfer : public abstract::OpInferBase { } protected: - int64_t ParseNumBoxes(const ShapeVector &box_shape, const ShapeVector &box_index_shape, const std::string &prim_name, - const std::vector &batch_shape) const { - size_t batch_rank = batch_shape.size(); - MS_EXCEPTION_IF_CHECK_FAIL(box_shape.size() == kShapeRank2 + batch_rank, - "For primitive[" + prim_name + "], the [boxes shape-length] should be 2, bug got " + - std::to_string(static_cast(box_shape.size()) - static_cast(batch_rank)) + - "."); - MS_EXCEPTION_IF_CHECK_FAIL( - batch_shape == std::vector(box_shape.begin(), box_shape.begin() + batch_rank), - "For primitive[" + prim_name + "], the [batch_shape] of boxes is not equal to that of input."); - MS_EXCEPTION_IF_CHECK_FAIL(box_shape.back() == kLimitValue4, "For primitive[" + prim_name + - "], the [boxes second-dim] must be 4, but got " + - std::to_string(box_shape.back()) + "."); + int64_t ParseNumBoxes(const ShapeVector &box_shape, const ShapeVector &box_index_shape, + const std::string &prim_name) const { + int64_t box_dims = static_cast(box_shape.size()); + (void)CheckAndConvertUtils::CheckInteger("[boxes shape-length]", box_dims, kEqual, kShapeRank2, prim_name); + (void)CheckAndConvertUtils::CheckInteger("[boxes second-dim]", box_shape.back(), kEqual, kLimitValue4, prim_name); - MS_EXCEPTION_IF_CHECK_FAIL( - box_index_shape.size() == 1 + batch_rank, - "For primitive[" + prim_name + "], the [box_index shape-length] should be 1, bug got " + - std::to_string(static_cast(box_index_shape.size()) - static_cast(batch_rank)) + "."); - MS_EXCEPTION_IF_CHECK_FAIL( - batch_shape == std::vector(box_index_shape.begin(), box_index_shape.begin() + batch_rank), - "For primitive[" + prim_name + "], the [batch_shape] of box_index is not equal to that of input."); - MS_EXCEPTION_IF_CHECK_FAIL( - box_shape[batch_rank] == box_index_shape[batch_rank], - "For primitive[" + prim_name + "], the [boxes first-dim] must be equal to [box_index first-dim], but got " + - std::to_string(box_shape[batch_rank]) + " vs " + std::to_string(box_index_shape[batch_rank]) + "."); - return box_shape[batch_rank]; + int64_t box_index_dims = static_cast(box_index_shape.size()); + (void)CheckAndConvertUtils::CheckInteger("[box_index shape-length]", box_index_dims, kEqual, 1, prim_name); + if (box_shape[0] != box_index_shape[0]) { + MS_EXCEPTION(ValueError) << "For primitive[" + prim_name + + "], the [boxes first-dim] must be equal to [box_index first-dim], but got " + + std::to_string(box_shape[0]) + " vs " + std::to_string(box_index_shape[0]) + "."; + } + return box_shape[0]; } private: const int64_t kLimitValue4 = 4; - const size_t kCropAndResizeInputSize = 4; - const size_t kShapeRank2 = 2; - const size_t kShapeRank4 = 4; + const int64_t kCropAndResizeInputSize = 4; + const int64_t kShapeRank2 = 2; + const int64_t kShapeRank4 = 4; }; void CropAndResize::Init(ResizeMethod method, float extrapolation_value) { diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_image_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_image_ops.py index 4b0fcb0d8dd..b97962e797c 100644 --- a/mindspore/python/mindspore/ops/_vmap/vmap_image_ops.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_image_ops.py @@ -16,6 +16,7 @@ """image_ops vmap impl.""" from __future__ import absolute_import +import mindspore.numpy as mnp from mindspore.ops import functional as F from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import image_ops as IMG @@ -83,3 +84,47 @@ def get_resize_grad_dynamic_rule(prim, axis_size): return out, 0 return vmap_rule + + +@vmap_rules_getters.register(IMG.CropAndResize) +def get_crop_and_resize_vmap_rule(prim, axis_size): + """VmapRule for `CropAndResize` operation.""" + + def vmap_rule(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim): + is_all_none, result = vmap_general_preprocess(x_bdim, boxes_bdim, box_indices_bdim, crop_size_bdim) + if is_all_none: + return result + + x, x_dim = x_bdim + boxes, boxes_dim = boxes_bdim + box_indices, box_indices_dim = box_indices_bdim + crop_size, crop_size_dim = crop_size_bdim + if crop_size_dim is not None: + _raise_value_error( + "The axis of `crop_size` in `{}` must be None, but got {}.".format(prim.name, crop_size_dim)) + + boxes = _bdim_at_front(boxes, boxes_dim, axis_size) + box_indices = _bdim_at_front(box_indices, box_indices_dim, axis_size) + boxes = F.reshape(boxes, (-1, 4)) + num_boxes = F.shape(box_indices)[-1] + + if x_dim is None: + box_indices = F.reshape(box_indices, (-1,)) + out = prim(x, boxes, box_indices, crop_size) + else: + x = _bdim_at_front(x, x_dim, axis_size) + x_shape = F.shape(x) + x = F.reshape(x, (-1,) + x_shape[2:]) + counts = mnp.arange(0, axis_size * x_shape[1], x_shape[1]) + counts = F.reshape(counts, (axis_size, 1)) + counts = F.broadcast_to(counts, (axis_size, num_boxes)) + box_indices = F.add(box_indices, counts) + box_indices = F.reshape(box_indices, (-1,)) + out = prim(x, boxes, box_indices, crop_size) + + out_shape = F.shape(out) + out = F.reshape(out, (-1, num_boxes) + out_shape[1:]) + return out, 0 + + + return vmap_rule diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py index 42f496ddf19..b3c5a4c7285 100644 --- a/mindspore/python/mindspore/ops/function/__init__.py +++ b/mindspore/python/mindspore/ops/function/__init__.py @@ -390,7 +390,8 @@ from .debug_func import ( from .image_func import ( bounding_box_decode, bounding_box_encode, - check_valid + check_valid, + crop_and_resize ) from .spectral_func import ( blackman_window diff --git a/mindspore/python/mindspore/ops/function/image_func.py b/mindspore/python/mindspore/ops/function/image_func.py index 81fbb4a17ee..60a5f55d952 100644 --- a/mindspore/python/mindspore/ops/function/image_func.py +++ b/mindspore/python/mindspore/ops/function/image_func.py @@ -134,10 +134,116 @@ def check_valid(bboxes, img_metas): return check_valid_op(bboxes, img_metas) +def crop_and_resize(x, boxes, box_indices, crop_size, method="bilinear", extrapolation_value=0.0): + """ + Extracts crops from the input image tensor and resizes them. + + Note: + In case that the output shape depends on crop_size, the crop_size must be constant. + For now, the backward of the operator only support bilinear method, for other methods, will return 0. + + Args: + x (Tensor): The input image must be a 4-D tensor of shape [batch, image_height, image_width, depth]. + Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16. + boxes (Tensor): A 2-D tensor of shape [num_boxes, 4]. + The i-th row of the tensor specifies the coordinates of a box in the box_ind[i] image + and is specified in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of y is mapped to + the image coordinate at y * (image_height - 1), so as the [0, 1] interval of normalized image height is + mapped to [0, image_height - 1] in image height coordinates. We do allow y1 > y2, in which case the sampled + crop is an up-down flipped version of the original image. The width dimension is treated similarly. + Normalized coordinates outside the [0, 1] range are allowed, in which case we use extrapolation_value to + extrapolate the input image values. Types allowed: float32. + box_indices (Tensor): A 1-D tensor of shape [num_boxes] with int32 values in [0, batch). + The value of box_ind[i] specifies the image that the i-th box refers to. Types allowed: int32. + crop_size (Tuple[int]): A tuple of two int32 elements: (crop_height, crop_width). + Only constant value is allowed. All cropped image patches are resized to this size. + The aspect ratio of the image content is not preserved. Both crop_height and crop_width need to be positive. + method (str): An optional string that specifies the sampling method for resizing. + It can be "bilinear", "nearest" or "bilinear_v2". The option "bilinear" stands for standard bilinear + interpolation algorithm, while "bilinear_v2" may result in better result in some cases. Default: "bilinear" + extrapolation_value (float): An optional float value used extrapolation, if applicable. Default: 0.0. + + Returns: + A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth] with type: float32. + + Raises: + TypeError: If `x` or `boxes` or `box_indices` is not a Tensor. + TypeError: If `crop_size` is not a Tuple with two int32 elements. + TypeError: If dtype of `boxes` is not float or that of `box_indices` is not int. + TypeError: If `method` is not a str. + TypeError: If `extrapolation_value` is not a float. + ValueError: If the shape rank of `x` is not 4. + ValueError: If the shape rank of `boxes` is not 2. + ValueError: If the second dim of `boxes` is not 4. + ValueError: If the shape rank of `box_indices` is not 1. + ValueError: If the first dim of `box_indices` is not equal to that of `boxes`. + ValueError: If existing element in `box_indices` is out of range `[0, batch)`. + ValueError: If the data of `crop_size` is not positive. + ValueError: If `method` is not one of 'bilinear', 'nearest', 'bilinear_v2'. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> BATCH_SIZE = 1 + >>> NUM_BOXES = 5 + >>> IMAGE_HEIGHT = 256 + >>> IMAGE_WIDTH = 256 + >>> CHANNELS = 3 + >>> image = np.random.normal(size=[BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS]).astype(np.float32) + >>> boxes = np.random.uniform(size=[NUM_BOXES, 4]).astype(np.float32) + >>> box_indices = np.random.uniform(size=[NUM_BOXES], low=0, high=BATCH_SIZE).astype(np.int32) + >>> crop_size = (24, 24) + >>> output = F.crop_and_resize(Tensor(image), Tensor(boxes), Tensor(box_indices), crop_size) + >>> print(output.shape) + (5, 24, 24, 3) + """ + if not isinstance(x, (Tensor, Tensor_)): + raise TypeError("For crop_and_resize, the input x must be a tensor") + if not isinstance(boxes, (Tensor, Tensor_)): + raise TypeError("For crop_and_resize, the input boxes must be a tensor") + if not isinstance(box_indices, (Tensor, Tensor_)): + raise TypeError("For crop_and_resize, the input box_indices must be a tensor") + if not isinstance(crop_size, tuple): + raise TypeError("For crop_and_resize, the input crop_size must be a tuple, but got {}".format(type(crop_size))) + if len(crop_size) != 2: + raise ValueError("For crop_and_resize, the crop_size's length must be 2, bot got {}".format(len(crop_size))) + if not isinstance(crop_size[0], int) or not isinstance(crop_size[1], int): + raise TypeError("For crop_and_resize, the crop_size's value must be int.") + if crop_size[0] <= 0 or crop_size[1] <= 0: + raise ValueError("For crop_and_resize, the crop_size's value must be positive.") + + x_shape = x.shape + if len(x_shape) != 4: + raise ValueError("For crop_and_resize, the input x must be 4D Tensor, but got is {}D".format(len(x_shape))) + boxes_dtype = _get_cache_prim(P.DType)()(boxes) + if boxes_dtype not in [mstype.float32]: + raise TypeError( + "For crop_and_resize, the input boxes must be {}, but got {}".format(mstype.float32, boxes_dtype)) + boxes_shape = boxes.shape + if len(boxes_shape) != 2 or boxes_shape[-1] != 4: + raise ValueError("For crop_and_resize, the input boxes must be 2D and the second-dim must be 4, " + "but got {}".format(boxes_shape)) + box_indices_dtype = _get_cache_prim(P.DType)()(box_indices) + if box_indices_dtype not in [mstype.int32]: + raise TypeError( + "For crop_and_resize, the input box_indices must be {}, but got {}".format(mstype.int32, box_indices_dtype)) + box_indices_shape = box_indices.shape + if len(box_indices_shape) != 1: + raise ValueError("For crop_and_resize, the input box_indices must be 1D, but got {}".format(box_indices_shape)) + if boxes_shape[0] != box_indices_shape[0]: + raise ValueError("For crop_and_resize, the first dim of input box_indices must be equal to that of input boxes" + ", but got {} vs {}".format(box_indices_shape[0], boxes_shape[0])) + + _crop_and_resize = _get_cache_prim(IMG.CropAndResize)(method, extrapolation_value) + return _crop_and_resize(x, boxes, box_indices, crop_size) + + __all__ = [ 'bounding_box_decode', 'bounding_box_encode', - 'check_valid' + 'check_valid', + 'crop_and_resize' ] __all__.sort() diff --git a/mindspore/python/mindspore/ops/operations/image_ops.py b/mindspore/python/mindspore/ops/operations/image_ops.py index 25c1c6ffe84..762aa0c021e 100644 --- a/mindspore/python/mindspore/ops/operations/image_ops.py +++ b/mindspore/python/mindspore/ops/operations/image_ops.py @@ -270,39 +270,7 @@ class CropAndResize(Primitive): """ Extracts crops from the input image tensor and resizes them. - Note: - In case that the output shape depends on crop_size, the crop_size must be constant. - For now, the backward of the operator only support bilinear method, for other methods, will return 0. - - Args: - method (str, optional): An optional string that specifies the sampling method for resizing. - It can be "bilinear", "nearest" or "bilinear_v2". The option "bilinear" stands for standard bilinear - interpolation algorithm, while "bilinear_v2" may result in better result in some cases. Default: "bilinear" - extrapolation_value (float, optional): An optional float value used extrapolation, if applicable. Default: 0.0. - - Inputs: - - **x** (Tensor) - The input image must be a 4-D tensor of shape [batch, image_height, image_width, depth]. - Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16. - - **boxes** (Tensor) - A 2-D tensor of shape [num_boxes, 4]. - The i-th row of the tensor specifies the coordinates of a box in the box_ind[i] image - and is specified in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of y is mapped to - the image coordinate at y * (image_height - 1), so as the [0, 1] interval of normalized image height is - mapped to [0, image_height - 1] in image height coordinates. We do allow y1 > y2, in which case the sampled - crop is an up-down flipped version of the original image. The width dimension is treated similarly. - Normalized coordinates outside the [0, 1] range are allowed, in which case we use extrapolation_value to - extrapolate the input image values. Types allowed: float32. - - **box_index** (Tensor) - A 1-D tensor of shape [num_boxes] with int32 values in [0, batch). - The value of box_ind[i] specifies the image that the i-th box refers to. Types allowed: int32. - - **crop_size** (Tuple[int]) - A tuple of two int32 elements: (crop_height, crop_width). - Only constant value is allowed. All cropped image patches are resized to this size. - The aspect ratio of the image content is not preserved. Both crop_height and crop_width need to be positive. - Outputs: - A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth] with type: float32. - - Raises: - TypeError: If `method` is not a str. - TypeError: If `extrapolation_value` is not a float. - ValueError: If `method` is not one of 'bilinear', 'nearest', 'bilinear_v2'. + Refer to :func:`mindspore.ops.crop_and_resize` for more detail. Supported Platforms: ``Ascend`` ``GPU`` ``CPU``