From 3045af13cb3f42f1eb063277b6c1387e3a5ab2e6 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Sat, 5 Sep 2020 17:37:53 +0800 Subject: [PATCH] fix nn.CentralCrop calulation result. --- mindspore/nn/layer/image.py | 11 +++++++---- mindspore/ops/operations/nn_ops.py | 5 +++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index 0c7cf88231e..af7e729b9bc 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -386,15 +386,15 @@ def _raise_dims_rank_error(input_shape, param_name, func_name): raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}") @constexpr -def _get_bbox(rank, shape, central_fraction): +def _get_bbox(rank, shape, size_h, size_w): """get bbox start and size for slice""" if rank == 3: c, h, w = shape else: n, c, h, w = shape - bbox_h_start = int(np.round((float(h) - float(h) * central_fraction) / 2)) - bbox_w_start = int(np.round((float(w) - float(w) * central_fraction) / 2)) + bbox_h_start = int((float(h) - size_h) / 2) + bbox_w_start = int((float(w) - size_w) / 2) bbox_h_size = h - bbox_h_start * 2 bbox_w_size = w - bbox_w_start * 2 @@ -436,12 +436,15 @@ class CentralCrop(Cell): def construct(self, image): image_shape = F.shape(image) rank = len(image_shape) + h, w = image_shape[-2], image_shape[-1] if not rank in (3, 4): return _raise_dims_rank_error(image_shape, "image", self.cls_name) if self.central_fraction == 1.0: return image - bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction) + size_h = self.central_fraction * h + size_w = self.central_fraction * w + bbox_begin, bbox_size = _get_bbox(rank, image_shape, size_h, size_w) image = self.slice(image, bbox_begin, bbox_size) return image diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index f3ab97b96af..cf092f78b02 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -5298,7 +5298,7 @@ class CTCLoss(PrimitiveWithInfer): - **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is :math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels` indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`. - Data type must be float32 or float64. + Data type must be float16, float32 or float64. - **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]` stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2. - **labels_values** (Tensor) - A `1-D` input tensor. The values are associated with the given batch and time. @@ -5348,7 +5348,8 @@ class CTCLoss(PrimitiveWithInfer): return batch_size, inputs def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length): - validator.check_tensor_type_same({"inputs_dtype": inputs}, [mstype.float32, mstype.double], self.name) + valid_dtype = [mstype.float16, mstype.float32, mstype.double] + validator.check_tensor_type_same({"inputs_dtype": inputs}, valid_dtype, self.name) validator.check_tensor_type_same({"labels_indices_dtype": labels_indices}, [mstype.int64], self.name) validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name) validator.check_tensor_type_same({"sequence_length_dtype": sequence_length}, [mstype.int32], self.name)