diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 923a8783b31..41495be9137 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -682,3 +682,14 @@ def get_bprop_broadcast_to(self): dx = reshape(reduced_grad, x_shape) return (dx,) return bprop + + +@bprop_getters.register(P.ReverseSequence) +def get_bprop_reverse_sequence(self): + """Generate bprop for ReverseSequence""" + reverse_sequence_grad = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_) + + def bprop(x, seq_lengths, out, dout): + dx = reverse_sequence_grad(dout, seq_lengths) + return dx, zeros_like(seq_lengths) + return bprop diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py index 24f4ad750cd..9349e10cfff 100644 --- a/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/ops/_op_impl/aicpu/__init__.py @@ -26,3 +26,6 @@ from .expand_dims import _expand_dims_aicpu from .random_choice_with_mask import _random_choice_with_mask_aicpu from .pack import _pack_aicpu from .normal import _normal_aicpu +from .ctcloss import _ctcloss_aicpu +from .reverse_sequence import _reverse_sequence_aicpu +from .crop_and_resize import _crop_and_resize_aicpu diff --git a/mindspore/ops/_op_impl/aicpu/crop_and_resize.py b/mindspore/ops/_op_impl/aicpu/crop_and_resize.py new file mode 100644 index 00000000000..f52e6b00ee2 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/crop_and_resize.py @@ -0,0 +1,69 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""CropAndResize op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType +crop_and_resize_op_info = AiCPURegOp("CropAndResize") \ + .fusion_type("OPAQUE") \ + .input(0, "image", "required") \ + .input(1, "boxes", "required") \ + .input(2, "box_index", "required") \ + .input(3, "crop_size", "required") \ + .output(0, "y", "required") \ + .attr("method", "str") \ + .attr("extrapolation_value", "float") \ + .dtype_format(DataType.I8_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default) \ + .dtype_format(DataType.I16_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default) \ + .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default) \ + .dtype_format(DataType.U8_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default) \ + .dtype_format(DataType.U16_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default) \ + .dtype_format(DataType.I8_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, + DataType.F32_NHWC) \ + .dtype_format(DataType.I16_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, + DataType.F32_NHWC) \ + .dtype_format(DataType.I32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, + DataType.F32_NHWC) \ + .dtype_format(DataType.I64_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, + DataType.F32_NHWC) \ + .dtype_format(DataType.F16_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, + DataType.F32_NHWC) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, + DataType.F32_NHWC) \ + .dtype_format(DataType.F64_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, + DataType.F32_NHWC) \ + .dtype_format(DataType.U8_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, + DataType.F32_NHWC) \ + .dtype_format(DataType.U16_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC, + DataType.F32_NHWC) \ + .get_op_info() + + +@op_info_register(crop_and_resize_op_info) +def _crop_and_resize_aicpu(): + """CropAndResize AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/ctcloss.py b/mindspore/ops/_op_impl/aicpu/ctcloss.py new file mode 100644 index 00000000000..c393cb04b61 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/ctcloss.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""CTCLoss op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType +ctcloss_op_info = AiCPURegOp("CTCLoss") \ + .fusion_type("OPAQUE") \ + .input(0, "inputs", "required") \ + .input(1, "labels_indices", "required") \ + .input(2, "labels_values", "required") \ + .input(3, "sequence_length", "required") \ + .output(0, "loss", "required") \ + .output(1, "gradient", "required") \ + .attr("preprocess_collapse_repeated", "bool") \ + .attr("ctc_merge_repeated", "bool") \ + .attr("ignore_longer_outputs_than_inputs", "bool") \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.F32_NCHW, DataType.I64_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, + DataType.F32_NCHW, DataType.F32_NCHW) \ + .dtype_format(DataType.F64_NCHW, DataType.I64_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, + DataType.F64_NCHW, DataType.F64_NCHW) \ + .get_op_info() + +@op_info_register(ctcloss_op_info) +def _ctcloss_aicpu(): + """CTCLoss AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/reverse_sequence.py b/mindspore/ops/_op_impl/aicpu/reverse_sequence.py new file mode 100644 index 00000000000..678a4a61f31 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/reverse_sequence.py @@ -0,0 +1,78 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReverseSequence op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType +reverse_sequence_op_info = AiCPURegOp("ReverseSequence") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .input(1, "seq_lengths", "required") \ + .output(0, "y", "required") \ + .attr("seq_dim", "int") \ + .attr("batch_dim", "int") \ + .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \ + .dtype_format(DataType.BOOL_NCHW, DataType.I32_NCHW, DataType.BOOL_NCHW) \ + .dtype_format(DataType.I8_NCHW, DataType.I32_NCHW, DataType.I8_NCHW) \ + .dtype_format(DataType.I16_NCHW, DataType.I32_NCHW, DataType.I16_NCHW) \ + .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW) \ + .dtype_format(DataType.I64_NCHW, DataType.I32_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.U8_NCHW, DataType.I32_NCHW, DataType.U8_NCHW) \ + .dtype_format(DataType.U16_NCHW, DataType.I32_NCHW, DataType.U16_NCHW) \ + .dtype_format(DataType.U32_NCHW, DataType.I32_NCHW, DataType.U32_NCHW) \ + .dtype_format(DataType.U64_NCHW, DataType.I32_NCHW, DataType.U64_NCHW) \ + .dtype_format(DataType.F16_NCHW, DataType.I32_NCHW, DataType.F16_NCHW) \ + .dtype_format(DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW) \ + .dtype_format(DataType.F64_NCHW, DataType.I32_NCHW, DataType.F64_NCHW) \ + .dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \ + .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \ + .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \ + .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \ + .dtype_format(DataType.BOOL_NCHW, DataType.I64_NCHW, DataType.BOOL_NCHW) \ + .dtype_format(DataType.I8_NCHW, DataType.I64_NCHW, DataType.I8_NCHW) \ + .dtype_format(DataType.I16_NCHW, DataType.I64_NCHW, DataType.I16_NCHW) \ + .dtype_format(DataType.I32_NCHW, DataType.I64_NCHW, DataType.I32_NCHW) \ + .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.U8_NCHW, DataType.I64_NCHW, DataType.U8_NCHW) \ + .dtype_format(DataType.U16_NCHW, DataType.I64_NCHW, DataType.U16_NCHW) \ + .dtype_format(DataType.U32_NCHW, DataType.I64_NCHW, DataType.U32_NCHW) \ + .dtype_format(DataType.U64_NCHW, DataType.I64_NCHW, DataType.U64_NCHW) \ + .dtype_format(DataType.F16_NCHW, DataType.I64_NCHW, DataType.F16_NCHW) \ + .dtype_format(DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW) \ + .dtype_format(DataType.F64_NCHW, DataType.I64_NCHW, DataType.F64_NCHW) \ + .get_op_info() + +@op_info_register(reverse_sequence_op_info) +def _reverse_sequence_aicpu(): + """ReverseSequence AiCPU register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 901db32c46b..e84cf44945f 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -19,6 +19,7 @@ Primitive operator classes. A collection of operators to build nerual networks or computing functions. """ +from .image_ops import (CropAndResize) from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Diag, DiagPart, DType, ExpandDims, Eye, Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, @@ -30,7 +31,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Squeeze, StridedSlice, Tile, TensorScatterUpdate, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, - SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate) + SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence) from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, _MirrorOperator, ReduceOp, _VirtualDataset, _VirtualDiv, _GetTensorSlice, @@ -79,6 +80,8 @@ from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, from .thor_ops import * __all__ = [ + 'ReverseSequence', + 'CropAndResize', 'TensorAdd', 'Argmax', 'Argmin', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 395d3c509c2..558d54b95a1 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2841,3 +2841,52 @@ class InplaceUpdate(PrimitiveWithInfer): Rel.EQ, self.name) return x_shape + + +class ReverseSequence(PrimitiveWithInfer): + """ + Reverses variable length slices. + + Args: + seq_dim (int): The dimension along which reversal is performed. Required. + batch_dim (int): The input is sliced along this dimmension. Default: 0. + + Inputs: + - **x** (Tensor) - The input to reverse, support all number types including bool. + - **seq_lengths** (Tensor) - Must be 1-D vector with types: int32, int64. + + Outputs: + Reversed tensor with the same shape and data type as input. + + Examples: + >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32) + >>> seq_lengths = Tensor(np.array([1, 2, 3])) + >>> reverse_sequence = P.ReverseSequence(seq_dim=1) + >>> output = reverse_sequence(x, seq_lengths) + [[1 2 3] + [5 4 6] + [9 8 7]] + """ + + @prim_attr_register + def __init__(self, seq_dim, batch_dim=0): + """init ReverseSequence""" + self.init_prim_io_names(inputs=['x', 'seq_lengths'], outputs=['y']) + validator.check_value_type("seq_dim", seq_dim, [int], self.name) + self.seq_dim_ = seq_dim + validator.check_value_type("batch_dim", batch_dim, [int], self.name) + self.batch_dim_ = batch_dim + + def infer_shape(self, x, seq_lengths): + validator.check("seq_dim", self.seq_dim_, "x rank", len(x), Rel.LE, self.name) + validator.check("batch_dim", self.batch_dim_, "x rank", len(x), Rel.LE, self.name) + validator.check("batch_dim", self.batch_dim_, "seq_dim", self.seq_dim_, Rel.NE, self.name) + validator.check("seq_lengths rank", len(seq_lengths), "expected", 1, Rel.EQ, self.name) + validator.check("seq_lengths vector size", seq_lengths[0], + "input size along batch_dim", x[self.batch_dim_], Rel.EQ, self.name) + return x + + def infer_dtype(self, x, seq_lengths): + validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name) + validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name) + return x diff --git a/mindspore/ops/operations/image_ops.py b/mindspore/ops/operations/image_ops.py new file mode 100644 index 00000000000..68dae34530b --- /dev/null +++ b/mindspore/ops/operations/image_ops.py @@ -0,0 +1,126 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""image_ops""" +from ..._checkparam import Validator as validator +from ..._checkparam import Rel +from ...common import dtype as mstype +from ..primitive import PrimitiveWithInfer, prim_attr_register + + +class CropAndResize(PrimitiveWithInfer): + """ + Extracts crops from the input image tensor and resizes them. + + Note: + In case that the output shape depends on crop_size, the crop_size should be constant. + + Args: + method (str): An optional string specifying the sampling method for resizing. + It can be either "bilinear" or "nearest" and default to "bilinear" + extrapolation_value (float): An optional float defaults to 0. Value used for extrapolation, when applicable. + + Inputs: + - **x** (Tensor) - The input image must be a 4-D tensor of shape [batch, image_height, image_width, depth]. + Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16. + - **boxes** (Tensor) - A 2-D tensor of shape [num_boxes, 4]. + The i-th row of the tensor specifies the coordinates of a box in the box_ind[i] image + and is specified in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of y is mapped to + the image coordinate at y * (image_height - 1), so as the [0, 1] interval of normalized image height is + mapped to [0, image_height - 1] in image height coordinates. We do allow y1 > y2, in which case the sampled + crop is an up-down flipped version of the original image. The width dimension is treated similarly. + Normalized coordinates outside the [0, 1] range are allowed, in which case we use extrapolation_value to + extrapolate the input image values. Types allowd: float32. + - **box_index** (Tensor) - A 1-D tensor of shape [num_boxes] with int32 values in [0, batch). + The value of box_ind[i] specifies the image that the i-th box refers to. Types allowd: int32. + - **crop_size** (Tensor) - Only constant value is allowd. Types allowed: int32. + A 1-D tensor of 2 elements, size = [crop_height, crop_width]. + All cropped image patches are resized to this size. The aspect ratio of the image content is not preserved. + Both crop_height and crop_width need to be positive. + Outputs: + A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth] with type: float32. + + Examples: + >>> class CropAndResizeNet(nn.Cell): + >>> def __init__(self, crop_size): + >>> super(CropAndResizeNet, self).__init__() + >>> self.crop_and_resize = P.CropAndResize() + >>> self.crop_size = crop_size + >>> @ms_function + >>> def construct(self, x, boxes, box_index): + >>> return self.crop_and_resize(x, boxes, box_index, self.crop_size) + >>> + >>> BATCH_SIZE = 1 + >>> NUM_BOXES = 5 + >>> IMAGE_HEIGHT = 256 + >>> IMAGE_WIDTH = 256 + >>> CHANNELS = 3 + >>> image = np.random.normal(size=[BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS]).astype(np.float32) + >>> boxes = np.random.uniform(shape=[NUM_BOXES, 4]).astype(np.float32) + >>> box_index = np.random.uniform(shape=[NUM_BOXES], low=0, high=BATCH_SIZE).astype(np.int32) + >>> crop_size = np.array([24, 24]).astype(np.int32) + >>> crop_and_resize = CropAndResizeNet(crop_size=Tensor(crop_size)) + >>> output = crop_and_resize(Tensor(image), Tensor(boxes), Tensor(box_index)) + >>> print(output.asnumpy()) + """ + + @prim_attr_register + def __init__(self, method="bilinear", extrapolation_value=0.0): + """init CropAndResize""" + self.init_prim_io_names(inputs=['x', 'boxes', 'box_index', 'crop_size'], outputs=['y']) + validator.check_value_type("method", method, [str], self.name) + validator.check_string("method", method, ["bilinear", "nearest"], self.name) + self.method = method + validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name) + self.extrapolation_value = extrapolation_value + + def __infer__(self, x, boxes, box_index, crop_size): + # get shape + x_shape = list(x['shape']) + boxes_shape = list(boxes['shape']) + box_index_shape = list(box_index['shape']) + crop_size_shape = list(crop_size['shape']) + # get value + if crop_size['value'] is None: + raise ValueError(f"For {self.name}, crop_size must be const.") + crop_size_value = crop_size['value'].asnumpy() + # get dtype + x_dtype = x['dtype'] + boxes_dtype = boxes['dtype'] + box_index_dtype = box_index['dtype'] + crop_size_dtype = crop_size['dtype'] + # check dytpe + validator.check_tensor_type_same({"x": x_dtype}, + [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16, + mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name) + validator.check_tensor_type_same({"boxes": boxes_dtype}, [mstype.float32], self.name) + validator.check_tensor_type_same({"box_index": box_index_dtype}, [mstype.int32], self.name) + validator.check_tensor_type_same({"crop_size": crop_size_dtype}, [mstype.int32], self.name) + # check input shape rank + validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name) + validator.check("boxes rank", len(boxes_shape), "expected", 2, Rel.EQ, self.name) + validator.check("box_index rank", len(box_index_shape), "expected", 1, Rel.EQ, self.name) + validator.check("crop_size rank", len(crop_size_shape), "expected", 1, Rel.EQ, self.name) + + validator.check("boxes dim_0", boxes_shape[0], "box_index dim_0", box_index_shape[0], Rel.EQ, self.name) + validator.check("boxes dim_1", boxes_shape[1], "expected", 4, Rel.EQ, self.name) + + num_boxes = boxes_shape[0] + crop_height = crop_size_value[0] + crop_width = crop_size_value[1] + depth = x_shape[3] + return {'shape': (num_boxes, crop_height, crop_width, depth), + 'dtype': mstype.float32, + 'value': None} diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_crop_and_reszie.py b/tests/st/ops/ascend/test_aicpu_ops/test_crop_and_reszie.py new file mode 100644 index 00000000000..f85751975d5 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_crop_and_reszie.py @@ -0,0 +1,49 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, crop_size): + super(Net, self).__init__() + self.crop_and_resize = P.CropAndResize() + self.crop_size = crop_size + + @ms_function + def construct(self, x, boxes, box_index): + return self.crop_and_resize(x, boxes, box_index, self.crop_size) + + +def test_net_float32(): + batch_size = 1 + num_boxes = 5 + image_height = 256 + image_width = 256 + channels = 3 + image = np.random.normal(size=[batch_size, image_height, image_width, channels]).astype(np.float32) + boxes = np.random.uniform(shape=[num_boxes, 4]).astype(np.float32) + box_index = np.random.uniform(shape=[num_boxes], low=0, high=batch_size).astype(np.int32) + crop_size = np.array([24, 24]).astype(np.int32) + net = Net(crop_size=Tensor(crop_size)) + output = net(Tensor(image), Tensor(boxes), Tensor(box_index)) + print(output.asnumpy()) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_ctc_loss.py b/tests/st/ops/ascend/test_aicpu_ops/test_ctc_loss.py new file mode 100644 index 00000000000..67949bf767c --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_ctc_loss.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.ctc_loss = P.CTCLoss() + + @ms_function + def construct(self, inputs, labels_indices, labels_values, sequence_length): + return self.ctc_loss(inputs, labels_indices, labels_values, sequence_length) + + +def test_net_float32(): + x = np.rand.randn(2, 2, 3).astype(np.float32) + labels_indices = np.array([[0, 0], [1, 0]]).astype(np.int64) + labels_values = np.array([2, 2]).astype(np.int32) + sequence_length = np.array([2, 2]).astype(np.int32) + net = Net() + output = net(Tensor(x), Tensor(labels_indices), Tensor(labels_values), Tensor(sequence_length)) + print(output.asnumpy()) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_reverse_sequence.py b/tests/st/ops/ascend/test_aicpu_ops/test_reverse_sequence.py new file mode 100644 index 00000000000..5927b62560f --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_reverse_sequence.py @@ -0,0 +1,55 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self, seq_dim, batch_dim): + super(Net, self).__init__() + self.reverse_sequence = P.ReverseSequence(seq_dim=seq_dim, batch_dim=batch_dim) + + @ms_function + def construct(self, x, seq_lengths): + return self.reverse_sequence(x, seq_lengths) + + +def test_net_int8(): + x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.int8) + seq_lengths = np.array([1, 2, 3]).astype(np.int32) + seq_dim = 0 + batch_dim = 1 + net = Net(seq_dim, batch_dim) + output = net(Tensor(x), Tensor(seq_lengths)) + expected = np.array([1, 5, 9], [4, 2, 6], [7, 8, 3]).astype(np.int8) + assert np.array_equal(output.asnumpy(), expected) + + +def test_net_int32(): + x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.int32) + seq_lengths = np.array([1, 2, 3]).astype(np.int64) + seq_dim = 1 + batch_dim = 0 + net = Net(seq_dim, batch_dim) + output = net(Tensor(x), Tensor(seq_lengths)) + expected = np.array([1, 2, 3], [5, 4, 6], [9, 8, 7]).astype(np.int32) + assert np.array_equal(output.asnumpy(), expected) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 5927d97c50b..5b5fd57aa9f 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1594,6 +1594,11 @@ test_case_array_ops = [ Tensor(np.arange(16).reshape(2, 4, 2).astype(np.float32))], 'skip': ['backward'], }), + ('ReverseSequence', { + 'block': P.ReverseSequence(1, 0), + 'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)), + Tensor(np.array([1, 2, 3]).astype(np.int32))], + 'desc_bprop': [[3, 3]]}), ] test_case_other_ops = [