!5805 Fix bug about nn.CentralCrop calulation result.
Merge pull request !5805 from liuxiao93/central-crop
This commit is contained in:
commit
5389348148
|
@ -386,15 +386,15 @@ def _raise_dims_rank_error(input_shape, param_name, func_name):
|
|||
raise ValueError(f"{func_name} {param_name} should be 3d or 4d, but got shape {input_shape}")
|
||||
|
||||
@constexpr
|
||||
def _get_bbox(rank, shape, central_fraction):
|
||||
def _get_bbox(rank, shape, size_h, size_w):
|
||||
"""get bbox start and size for slice"""
|
||||
if rank == 3:
|
||||
c, h, w = shape
|
||||
else:
|
||||
n, c, h, w = shape
|
||||
|
||||
bbox_h_start = int(np.round((float(h) - float(h) * central_fraction) / 2))
|
||||
bbox_w_start = int(np.round((float(w) - float(w) * central_fraction) / 2))
|
||||
bbox_h_start = int((float(h) - size_h) / 2)
|
||||
bbox_w_start = int((float(w) - size_w) / 2)
|
||||
bbox_h_size = h - bbox_h_start * 2
|
||||
bbox_w_size = w - bbox_w_start * 2
|
||||
|
||||
|
@ -436,12 +436,15 @@ class CentralCrop(Cell):
|
|||
def construct(self, image):
|
||||
image_shape = F.shape(image)
|
||||
rank = len(image_shape)
|
||||
h, w = image_shape[-2], image_shape[-1]
|
||||
if not rank in (3, 4):
|
||||
return _raise_dims_rank_error(image_shape, "image", self.cls_name)
|
||||
if self.central_fraction == 1.0:
|
||||
return image
|
||||
|
||||
bbox_begin, bbox_size = _get_bbox(rank, image_shape, self.central_fraction)
|
||||
size_h = self.central_fraction * h
|
||||
size_w = self.central_fraction * w
|
||||
bbox_begin, bbox_size = _get_bbox(rank, image_shape, size_h, size_w)
|
||||
image = self.slice(image, bbox_begin, bbox_size)
|
||||
|
||||
return image
|
||||
|
|
|
@ -5298,7 +5298,7 @@ class CTCLoss(PrimitiveWithInfer):
|
|||
- **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is
|
||||
:math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels`
|
||||
indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`.
|
||||
Data type must be float32 or float64.
|
||||
Data type must be float16, float32 or float64.
|
||||
- **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]`
|
||||
stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2.
|
||||
- **labels_values** (Tensor) - A `1-D` input tensor. The values are associated with the given batch and time.
|
||||
|
@ -5348,7 +5348,8 @@ class CTCLoss(PrimitiveWithInfer):
|
|||
return batch_size, inputs
|
||||
|
||||
def infer_dtype(self, inputs, labels_indices, labels_values, sequence_length):
|
||||
validator.check_tensor_type_same({"inputs_dtype": inputs}, [mstype.float32, mstype.double], self.name)
|
||||
valid_dtype = [mstype.float16, mstype.float32, mstype.double]
|
||||
validator.check_tensor_type_same({"inputs_dtype": inputs}, valid_dtype, self.name)
|
||||
validator.check_tensor_type_same({"labels_indices_dtype": labels_indices}, [mstype.int64], self.name)
|
||||
validator.check_tensor_type_same({"labels_values_dtype": labels_values}, [mstype.int32], self.name)
|
||||
validator.check_tensor_type_same({"sequence_length_dtype": sequence_length}, [mstype.int32], self.name)
|
||||
|
|
Loading…
Reference in New Issue