forked from mindspore-Ecosystem/mindspore
!4932 Add CropAndResize for old backend.
Merge pull request !4932 from liuxiao93/Add-CropAndResize-for-old-backend
This commit is contained in:
commit
03991a969e
|
@ -189,6 +189,7 @@ constexpr const char kNameRange[] = "Range";
|
|||
constexpr const char kNameSquareSumAll[] = "SquareSumAll";
|
||||
constexpr const char kNameAscendQuant[] = "Quant";
|
||||
constexpr const char kNameAscendDequant[] = "Dequant";
|
||||
constexpr const char kNameCropAndResize[] = "CropAndResize";
|
||||
constexpr const char kNameReverseSequence[] = "ReverseSequence";
|
||||
constexpr const char kNameEditDistance[] = "EditDistance";
|
||||
constexpr const char kNameCase[] = "Case";
|
||||
|
|
|
@ -45,4 +45,12 @@ ATTR_MAP(ResizeBilinearV2D) = {
|
|||
{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(ResizeBilinearV2D) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ResizeBilinearV2D, kNameResizeBilinear, ADPT_DESC(ResizeBilinearV2D))
|
||||
|
||||
// CropAndResize
|
||||
INPUT_MAP(CropAndResize) = {
|
||||
{1, INPUT_DESC(x)}, {2, INPUT_DESC(boxes)}, {3, INPUT_DESC(box_index)}, {4, INPUT_DESC(crop_size)}};
|
||||
ATTR_MAP(CropAndResize) = {{"extrapolation_value", ATTR_DESC(extrapolation_value, AnyTraits<float>())},
|
||||
{"method", ATTR_DESC(method, AnyTraits<std::string>())}};
|
||||
OUTPUT_MAP(CropAndResize) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(CropAndResize, kNameCropAndResize, ADPT_DESC(CropAndResize))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -34,5 +34,8 @@ DECLARE_OP_USE_OUTPUT(ResizeBilinearV2D)
|
|||
|
||||
DECLARE_OP_ADAPTER(ResizeBilinearV2Grad)
|
||||
DECLARE_OP_USE_OUTPUT(ResizeBilinearV2Grad)
|
||||
|
||||
DECLARE_OP_ADAPTER(CropAndResize)
|
||||
DECLARE_OP_USE_OUTPUT(CropAndResize)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_IMAGE_OPS_DECLARE_H_
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
|
||||
"""image_ops"""
|
||||
from ... import context
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
|
@ -84,6 +85,7 @@ class CropAndResize(PrimitiveWithInfer):
|
|||
self.method = method
|
||||
validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name)
|
||||
self.extrapolation_value = extrapolation_value
|
||||
self.is_ge = context.get_context("enable_ge")
|
||||
|
||||
def __infer__(self, x, boxes, box_index, crop_size):
|
||||
# get shape
|
||||
|
@ -124,6 +126,9 @@ class CropAndResize(PrimitiveWithInfer):
|
|||
crop_height = crop_size_value[0]
|
||||
crop_width = crop_size_value[1]
|
||||
depth = x_shape[3]
|
||||
return {'shape': (num_boxes, crop_height, crop_width, depth),
|
||||
out_shape = (num_boxes, crop_height, crop_width, depth)
|
||||
if self.is_ge:
|
||||
out_shape = (num_boxes, x_shape[1], crop_height, crop_width)
|
||||
return {'shape': out_shape,
|
||||
'dtype': mstype.float32,
|
||||
'value': None}
|
||||
|
|
Loading…
Reference in New Issue