forked from mindspore-Ecosystem/mindspore
!2424 Add aicpu op: CTCLoss, ReverseSequence and CropAndResize
Merge pull request !2424 from xutianchun/ctcloss
This commit is contained in:
commit
0327d7e79b
|
@ -682,3 +682,14 @@ def get_bprop_broadcast_to(self):
|
|||
dx = reshape(reduced_grad, x_shape)
|
||||
return (dx,)
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.ReverseSequence)
def get_bprop_reverse_sequence(self):
    """Build the backward function for the ReverseSequence primitive."""
    # The gradient op is a ReverseSequence configured with the same batch/seq
    # dimensions as the forward primitive.
    grad_op = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_)

    def bprop(x, seq_lengths, out, dout):
        # Gradient w.r.t. x: apply the reversal to the incoming gradient.
        input_grad = grad_op(dout, seq_lengths)
        # seq_lengths receives an all-zero gradient of matching shape.
        seq_lengths_grad = zeros_like(seq_lengths)
        return input_grad, seq_lengths_grad

    return bprop
|
||||
|
|
|
@ -26,3 +26,6 @@ from .expand_dims import _expand_dims_aicpu
|
|||
from .random_choice_with_mask import _random_choice_with_mask_aicpu
|
||||
from .pack import _pack_aicpu
|
||||
from .normal import _normal_aicpu
|
||||
from .ctcloss import _ctcloss_aicpu
|
||||
from .reverse_sequence import _reverse_sequence_aicpu
|
||||
from .crop_and_resize import _crop_and_resize_aicpu
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""CropAndResize op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
# AiCPU registration info for CropAndResize.
# Every dtype_format row lists: image, boxes, box_index, crop_size, y.
# Only the image dtype varies; boxes are float32, indices/sizes are int32 and
# the output is always float32.
crop_and_resize_op_info = (
    AiCPURegOp("CropAndResize")
    .fusion_type("OPAQUE")
    .input(0, "image", "required")
    .input(1, "boxes", "required")
    .input(2, "box_index", "required")
    .input(3, "crop_size", "required")
    .output(0, "y", "required")
    .attr("method", "str")
    .attr("extrapolation_value", "float")
    # Default format
    .dtype_format(DataType.I8_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default)
    .dtype_format(DataType.I16_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default)
    .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default)
    .dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default)
    .dtype_format(DataType.F16_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default)
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default)
    .dtype_format(DataType.F64_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default)
    .dtype_format(DataType.U8_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default)
    .dtype_format(DataType.U16_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default)
    # NHWC format
    .dtype_format(DataType.I8_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC)
    .dtype_format(DataType.I16_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC)
    .dtype_format(DataType.I32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC)
    .dtype_format(DataType.I64_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC)
    .dtype_format(DataType.F16_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC)
    .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC)
    .dtype_format(DataType.F64_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC)
    .dtype_format(DataType.U8_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC)
    .dtype_format(DataType.U16_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, DataType.I32_NHWC,
                  DataType.F32_NHWC)
    .get_op_info()
)
|
||||
|
||||
|
||||
@op_info_register(crop_and_resize_op_info)
def _crop_and_resize_aicpu():
    """Register the CropAndResize AiCPU op info; the body does no work."""
    # The registration happens in the decorator; the function only needs to exist.
    return None
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""CTCLoss op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
# AiCPU registration info for CTCLoss.
# Every dtype_format row lists: inputs, labels_indices, labels_values,
# sequence_length, loss, gradient. Both outputs share the dtype of `inputs`.
ctcloss_op_info = (
    AiCPURegOp("CTCLoss")
    .fusion_type("OPAQUE")
    .input(0, "inputs", "required")
    .input(1, "labels_indices", "required")
    .input(2, "labels_values", "required")
    .input(3, "sequence_length", "required")
    .output(0, "loss", "required")
    .output(1, "gradient", "required")
    .attr("preprocess_collapse_repeated", "bool")
    .attr("ctc_merge_repeated", "bool")
    .attr("ignore_longer_outputs_than_inputs", "bool")
    # Default format
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F32_Default, DataType.F32_Default)
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default,
                  DataType.F64_Default, DataType.F64_Default)
    # NCHW format
    .dtype_format(DataType.F32_NCHW, DataType.I64_NCHW, DataType.I32_NCHW, DataType.I32_NCHW,
                  DataType.F32_NCHW, DataType.F32_NCHW)
    .dtype_format(DataType.F64_NCHW, DataType.I64_NCHW, DataType.I32_NCHW, DataType.I32_NCHW,
                  DataType.F64_NCHW, DataType.F64_NCHW)
    .get_op_info()
)
|
||||
|
||||
@op_info_register(ctcloss_op_info)
def _ctcloss_aicpu():
    """Register the CTCLoss AiCPU op info; the body does no work."""
    # The registration happens in the decorator; the function only needs to exist.
    return None
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ReverseSequence op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
# AiCPU registration info for ReverseSequence.
# Every dtype_format row lists: x, seq_lengths, y. The output dtype always
# equals the dtype of x; seq_lengths is either int32 or int64.
reverse_sequence_op_info = (
    AiCPURegOp("ReverseSequence")
    .fusion_type("OPAQUE")
    .input(0, "x", "required")
    .input(1, "seq_lengths", "required")
    .output(0, "y", "required")
    .attr("seq_dim", "int")
    .attr("batch_dim", "int")
    # int32 seq_lengths, Default format
    .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default)
    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default)
    .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default)
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default)
    .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default)
    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default)
    .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default)
    .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default)
    .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default)
    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default)
    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default)
    .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default)
    # int32 seq_lengths, NCHW format
    .dtype_format(DataType.BOOL_NCHW, DataType.I32_NCHW, DataType.BOOL_NCHW)
    .dtype_format(DataType.I8_NCHW, DataType.I32_NCHW, DataType.I8_NCHW)
    .dtype_format(DataType.I16_NCHW, DataType.I32_NCHW, DataType.I16_NCHW)
    .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW)
    .dtype_format(DataType.I64_NCHW, DataType.I32_NCHW, DataType.I64_NCHW)
    .dtype_format(DataType.U8_NCHW, DataType.I32_NCHW, DataType.U8_NCHW)
    .dtype_format(DataType.U16_NCHW, DataType.I32_NCHW, DataType.U16_NCHW)
    .dtype_format(DataType.U32_NCHW, DataType.I32_NCHW, DataType.U32_NCHW)
    .dtype_format(DataType.U64_NCHW, DataType.I32_NCHW, DataType.U64_NCHW)
    .dtype_format(DataType.F16_NCHW, DataType.I32_NCHW, DataType.F16_NCHW)
    .dtype_format(DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW)
    .dtype_format(DataType.F64_NCHW, DataType.I32_NCHW, DataType.F64_NCHW)
    # int64 seq_lengths, Default format
    .dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.BOOL_Default)
    .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default)
    .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default)
    # Fixed: this row previously registered (I64, I64, I32), which left the
    # int32-input case unregistered and paired an int64 input with an int32
    # output. It now mirrors the NCHW row (I32, I64, I32).
    .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default)
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default)
    .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default)
    .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default)
    .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default)
    .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default)
    .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default)
    .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default)
    .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default)
    # int64 seq_lengths, NCHW format
    .dtype_format(DataType.BOOL_NCHW, DataType.I64_NCHW, DataType.BOOL_NCHW)
    .dtype_format(DataType.I8_NCHW, DataType.I64_NCHW, DataType.I8_NCHW)
    .dtype_format(DataType.I16_NCHW, DataType.I64_NCHW, DataType.I16_NCHW)
    .dtype_format(DataType.I32_NCHW, DataType.I64_NCHW, DataType.I32_NCHW)
    .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW, DataType.I64_NCHW)
    .dtype_format(DataType.U8_NCHW, DataType.I64_NCHW, DataType.U8_NCHW)
    .dtype_format(DataType.U16_NCHW, DataType.I64_NCHW, DataType.U16_NCHW)
    .dtype_format(DataType.U32_NCHW, DataType.I64_NCHW, DataType.U32_NCHW)
    .dtype_format(DataType.U64_NCHW, DataType.I64_NCHW, DataType.U64_NCHW)
    .dtype_format(DataType.F16_NCHW, DataType.I64_NCHW, DataType.F16_NCHW)
    .dtype_format(DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW)
    .dtype_format(DataType.F64_NCHW, DataType.I64_NCHW, DataType.F64_NCHW)
    .get_op_info()
)
|
||||
|
||||
@op_info_register(reverse_sequence_op_info)
def _reverse_sequence_aicpu():
    """Register the ReverseSequence AiCPU op info; the body does no work."""
    # The registration happens in the decorator; the function only needs to exist.
    return None
|
|
@ -19,6 +19,7 @@ Primitive operator classes.
|
|||
A collection of operators to build neural networks or computing functions.
|
||||
"""
|
||||
|
||||
from .image_ops import (CropAndResize)
|
||||
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||
Diag, DiagPart, DType, ExpandDims, Eye,
|
||||
Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
|
||||
|
@ -30,7 +31,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate)
|
||||
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence)
|
||||
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
|
||||
_MirrorOperator, ReduceOp, _VirtualDataset,
|
||||
_VirtualDiv, _GetTensorSlice,
|
||||
|
@ -79,6 +80,8 @@ from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
|
|||
from .thor_ops import *
|
||||
|
||||
__all__ = [
|
||||
'ReverseSequence',
|
||||
'CropAndResize',
|
||||
'TensorAdd',
|
||||
'Argmax',
|
||||
'Argmin',
|
||||
|
|
|
@ -2841,3 +2841,52 @@ class InplaceUpdate(PrimitiveWithInfer):
|
|||
Rel.EQ, self.name)
|
||||
|
||||
return x_shape
|
||||
|
||||
|
||||
class ReverseSequence(PrimitiveWithInfer):
    """
    Reverses variable length slices.

    Args:
        seq_dim (int): The dimension along which reversal is performed. Required.
        batch_dim (int): The input is sliced along this dimension. Default: 0.

    Inputs:
        - **x** (Tensor) - The input to reverse, support all number types including bool.
        - **seq_lengths** (Tensor) - Must be 1-D vector with types: int32, int64.

    Outputs:
        Reversed tensor with the same shape and data type as input.

    Examples:
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([1, 2, 3]))
        >>> reverse_sequence = P.ReverseSequence(seq_dim=1)
        >>> output = reverse_sequence(x, seq_lengths)
        [[1 2 3]
         [5 4 6]
         [9 8 7]]
    """

    @prim_attr_register
    def __init__(self, seq_dim, batch_dim=0):
        """init ReverseSequence"""
        self.init_prim_io_names(inputs=['x', 'seq_lengths'], outputs=['y'])
        validator.check_value_type("seq_dim", seq_dim, [int], self.name)
        self.seq_dim_ = seq_dim
        validator.check_value_type("batch_dim", batch_dim, [int], self.name)
        self.batch_dim_ = batch_dim

    def infer_shape(self, x, seq_lengths):
        """
        Validate the input shapes and return the output shape.

        Args:
            x: Shape of the input tensor.
            seq_lengths: Shape of the sequence-length tensor.

        Returns:
            The shape of `x`, unchanged.
        """
        # A dimension index must be strictly smaller than the rank; allowing
        # equality would let x[self.batch_dim_] below index past the end.
        validator.check("seq_dim", self.seq_dim_, "x rank", len(x), Rel.LT, self.name)
        validator.check("batch_dim", self.batch_dim_, "x rank", len(x), Rel.LT, self.name)
        validator.check("batch_dim", self.batch_dim_, "seq_dim", self.seq_dim_, Rel.NE, self.name)
        validator.check("seq_lengths rank", len(seq_lengths), "expected", 1, Rel.EQ, self.name)
        # One sequence length is required per slice along batch_dim.
        validator.check("seq_lengths vector size", seq_lengths[0],
                        "input size along batch_dim", x[self.batch_dim_], Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, seq_lengths):
        """
        Validate the input dtypes and return the output dtype.

        Args:
            x: Dtype of the input tensor; any number type or bool.
            seq_lengths: Dtype of the sequence-length tensor; int32 or int64.

        Returns:
            The dtype of `x`, unchanged.
        """
        validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
        validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
        return x
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""image_ops"""
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||
|
||||
|
||||
class CropAndResize(PrimitiveWithInfer):
    """
    Extracts crops from the input image tensor and resizes them.

    Note:
        In case that the output shape depends on crop_size, the crop_size should be constant.

    Args:
        method (str): An optional string specifying the sampling method for resizing.
            It can be either "bilinear" or "nearest" and default to "bilinear"
        extrapolation_value (float): An optional float defaults to 0. Value used for extrapolation, when applicable.

    Inputs:
        - **x** (Tensor) - The input image must be a 4-D tensor of shape [batch, image_height, image_width, depth].
          Types allowed: int8, int16, int32, int64, float16, float32, float64, uint8, uint16.
        - **boxes** (Tensor) - A 2-D tensor of shape [num_boxes, 4].
          The i-th row of the tensor specifies the coordinates of a box in the box_ind[i] image
          and is specified in normalized coordinates [y1, x1, y2, x2]. A normalized coordinate value of y is mapped to
          the image coordinate at y * (image_height - 1), so as the [0, 1] interval of normalized image height is
          mapped to [0, image_height - 1] in image height coordinates. We do allow y1 > y2, in which case the sampled
          crop is an up-down flipped version of the original image. The width dimension is treated similarly.
          Normalized coordinates outside the [0, 1] range are allowed, in which case we use extrapolation_value to
          extrapolate the input image values. Types allowed: float32.
        - **box_index** (Tensor) - A 1-D tensor of shape [num_boxes] with int32 values in [0, batch).
          The value of box_ind[i] specifies the image that the i-th box refers to. Types allowed: int32.
        - **crop_size** (Tensor) - Only constant value is allowed. Types allowed: int32.
          A 1-D tensor of 2 elements, size = [crop_height, crop_width].
          All cropped image patches are resized to this size. The aspect ratio of the image content is not preserved.
          Both crop_height and crop_width need to be positive.

    Outputs:
        A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth] with type: float32.

    Examples:
        >>> class CropAndResizeNet(nn.Cell):
        >>>     def __init__(self, crop_size):
        >>>         super(CropAndResizeNet, self).__init__()
        >>>         self.crop_and_resize = P.CropAndResize()
        >>>         self.crop_size = crop_size
        >>>     @ms_function
        >>>     def construct(self, x, boxes, box_index):
        >>>         return self.crop_and_resize(x, boxes, box_index, self.crop_size)
        >>>
        >>> BATCH_SIZE = 1
        >>> NUM_BOXES = 5
        >>> IMAGE_HEIGHT = 256
        >>> IMAGE_WIDTH = 256
        >>> CHANNELS = 3
        >>> image = np.random.normal(size=[BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS]).astype(np.float32)
        >>> boxes = np.random.uniform(size=[NUM_BOXES, 4]).astype(np.float32)
        >>> box_index = np.random.uniform(size=[NUM_BOXES], low=0, high=BATCH_SIZE).astype(np.int32)
        >>> crop_size = np.array([24, 24]).astype(np.int32)
        >>> crop_and_resize = CropAndResizeNet(crop_size=Tensor(crop_size))
        >>> output = crop_and_resize(Tensor(image), Tensor(boxes), Tensor(box_index))
        >>> print(output.asnumpy())
    """

    @prim_attr_register
    def __init__(self, method="bilinear", extrapolation_value=0.0):
        """init CropAndResize"""
        self.init_prim_io_names(inputs=['x', 'boxes', 'box_index', 'crop_size'], outputs=['y'])
        validator.check_value_type("method", method, [str], self.name)
        validator.check_string("method", method, ["bilinear", "nearest"], self.name)
        self.method = method
        validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name)
        self.extrapolation_value = extrapolation_value

    def __infer__(self, x, boxes, box_index, crop_size):
        """
        Validate the inputs and derive the output shape, dtype and value.

        Each argument is a dict with 'shape', 'dtype' and 'value' entries.

        Returns:
            dict with the output 'shape' (num_boxes, crop_height, crop_width, depth),
            'dtype' float32 and 'value' None.

        Raises:
            ValueError: If crop_size is not a constant.
        """
        # get shape
        x_shape = list(x['shape'])
        boxes_shape = list(boxes['shape'])
        box_index_shape = list(box_index['shape'])
        crop_size_shape = list(crop_size['shape'])
        # get value: the output shape depends on it, so it must be known here
        if crop_size['value'] is None:
            raise ValueError(f"For {self.name}, crop_size must be const.")
        crop_size_value = crop_size['value'].asnumpy()
        # get dtype
        x_dtype = x['dtype']
        boxes_dtype = boxes['dtype']
        box_index_dtype = box_index['dtype']
        crop_size_dtype = crop_size['dtype']
        # check dtype
        validator.check_tensor_type_same({"x": x_dtype},
                                         [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.float16,
                                          mstype.float32, mstype.float64, mstype.uint8, mstype.uint16], self.name)
        validator.check_tensor_type_same({"boxes": boxes_dtype}, [mstype.float32], self.name)
        validator.check_tensor_type_same({"box_index": box_index_dtype}, [mstype.int32], self.name)
        validator.check_tensor_type_same({"crop_size": crop_size_dtype}, [mstype.int32], self.name)
        # check input shape rank
        validator.check("x rank", len(x_shape), "expected", 4, Rel.EQ, self.name)
        validator.check("boxes rank", len(boxes_shape), "expected", 2, Rel.EQ, self.name)
        validator.check("box_index rank", len(box_index_shape), "expected", 1, Rel.EQ, self.name)
        validator.check("crop_size rank", len(crop_size_shape), "expected", 1, Rel.EQ, self.name)

        validator.check("boxes dim_0", boxes_shape[0], "box_index dim_0", box_index_shape[0], Rel.EQ, self.name)
        validator.check("boxes dim_1", boxes_shape[1], "expected", 4, Rel.EQ, self.name)
        # crop_size must hold exactly [crop_height, crop_width]; without this
        # check a shorter tensor fails below with a bare IndexError.
        validator.check("crop_size dim_0", crop_size_shape[0], "expected", 2, Rel.EQ, self.name)

        num_boxes = boxes_shape[0]
        # Convert the numpy scalars to plain ints so the shape tuple is homogeneous.
        crop_height = int(crop_size_value[0])
        crop_width = int(crop_size_value[1])
        # Both crop dimensions are documented as strictly positive.
        validator.check("crop_height", crop_height, "minimum", 0, Rel.GT, self.name)
        validator.check("crop_width", crop_width, "minimum", 0, Rel.GT, self.name)
        depth = x_shape[3]
        return {'shape': (num_boxes, crop_height, crop_width, depth),
                'dtype': mstype.float32,
                'value': None}
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
# Run this test module in graph mode with Ascend as the device target.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
    """Cell wrapping CropAndResize with a crop size fixed at construction time."""

    def __init__(self, crop_size):
        super().__init__()
        self.crop_and_resize = P.CropAndResize()
        self.crop_size = crop_size

    @ms_function
    def construct(self, x, boxes, box_index):
        """Crop `x` at `boxes`/`box_index` and resize to the stored crop size."""
        cropped = self.crop_and_resize(x, boxes, box_index, self.crop_size)
        return cropped
|
||||
|
||||
|
||||
def test_net_float32():
    """Run CropAndResize on a random float32 image batch and print the result."""
    batch_size = 1
    num_boxes = 5
    image_height = 256
    image_width = 256
    channels = 3
    image = np.random.normal(size=[batch_size, image_height, image_width, channels]).astype(np.float32)
    # np.random.uniform takes `size`, not `shape`; the latter raises TypeError.
    boxes = np.random.uniform(size=[num_boxes, 4]).astype(np.float32)
    # Samples lie in [0, batch_size), so the int32 cast yields valid batch indices.
    box_index = np.random.uniform(low=0, high=batch_size, size=[num_boxes]).astype(np.int32)
    crop_size = np.array([24, 24]).astype(np.int32)
    net = Net(crop_size=Tensor(crop_size))
    output = net(Tensor(image), Tensor(boxes), Tensor(box_index))
    print(output.asnumpy())
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
# Run this test module in graph mode with Ascend as the device target.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
    """Cell wrapping the CTCLoss primitive with its default attributes."""

    def __init__(self):
        super().__init__()
        self.ctc_loss = P.CTCLoss()

    @ms_function
    def construct(self, inputs, labels_indices, labels_values, sequence_length):
        """Forward all four inputs to CTCLoss and return its result."""
        result = self.ctc_loss(inputs, labels_indices, labels_values, sequence_length)
        return result
|
||||
|
||||
|
||||
def test_net_float32():
    """Run CTCLoss on a small random float32 batch and print every output."""
    # `np.rand` does not exist; the random generators live under `np.random`.
    x = np.random.randn(2, 2, 3).astype(np.float32)
    labels_indices = np.array([[0, 0], [1, 0]]).astype(np.int64)
    labels_values = np.array([2, 2]).astype(np.int32)
    sequence_length = np.array([2, 2]).astype(np.int32)
    net = Net()
    output = net(Tensor(x), Tensor(labels_indices), Tensor(labels_values), Tensor(sequence_length))
    # The op is registered with two outputs (loss, gradient), so the result may
    # be a tuple; normalise to a sequence before converting to numpy.
    results = output if isinstance(output, (tuple, list)) else (output,)
    for item in results:
        print(item.asnumpy())
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
# Run this test module in graph mode with Ascend as the device target.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
    """Cell wrapping ReverseSequence configured with the given dimensions."""

    def __init__(self, seq_dim, batch_dim):
        super().__init__()
        self.reverse_sequence = P.ReverseSequence(seq_dim=seq_dim, batch_dim=batch_dim)

    @ms_function
    def construct(self, x, seq_lengths):
        """Apply the configured ReverseSequence op to `x` and `seq_lengths`."""
        reversed_x = self.reverse_sequence(x, seq_lengths)
        return reversed_x
|
||||
|
||||
|
||||
def test_net_int8():
    """Check ReverseSequence on int8 data, reversing along dim 0 with batches on dim 1."""
    x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.int8)
    seq_lengths = np.array([1, 2, 3]).astype(np.int32)
    seq_dim = 0
    batch_dim = 1
    net = Net(seq_dim, batch_dim)
    output = net(Tensor(x), Tensor(seq_lengths))
    # The rows must be wrapped in one list: passing them as separate positional
    # arguments to np.array raises TypeError.
    # Column b has its first seq_lengths[b] entries reversed.
    expected = np.array([[1, 5, 9], [4, 2, 6], [7, 8, 3]]).astype(np.int8)
    assert np.array_equal(output.asnumpy(), expected)
|
||||
|
||||
|
||||
def test_net_int32():
    """Check ReverseSequence on int32 data with int64 lengths, reversing along dim 1."""
    x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.int32)
    seq_lengths = np.array([1, 2, 3]).astype(np.int64)
    seq_dim = 1
    batch_dim = 0
    net = Net(seq_dim, batch_dim)
    output = net(Tensor(x), Tensor(seq_lengths))
    # The rows must be wrapped in one list: passing them as separate positional
    # arguments to np.array raises TypeError.
    # Row b has its first seq_lengths[b] entries reversed.
    expected = np.array([[1, 2, 3], [5, 4, 6], [9, 8, 7]]).astype(np.int32)
    assert np.array_equal(output.asnumpy(), expected)
|
|
@ -1594,6 +1594,11 @@ test_case_array_ops = [
|
|||
Tensor(np.arange(16).reshape(2, 4, 2).astype(np.float32))],
|
||||
'skip': ['backward'],
|
||||
}),
|
||||
('ReverseSequence', {
|
||||
'block': P.ReverseSequence(1, 0),
|
||||
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)),
|
||||
Tensor(np.array([1, 2, 3]).astype(np.int32))],
|
||||
'desc_bprop': [[3, 3]]}),
|
||||
]
|
||||
|
||||
test_case_other_ops = [
|
||||
|
|
Loading…
Reference in New Issue