From ec77704198688c8d9ff621ee8538967cb365ce77 Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Wed, 4 Nov 2020 16:37:28 +0800 Subject: [PATCH] fix BNTrainingUpdate shape check about inputs. --- mindspore/ops/operations/nn_ops.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index fb34e62dbf6..38e6bc8a18c 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -836,12 +836,12 @@ class BNTrainingUpdate(PrimitiveWithInfer): validator.check_equal_int(len(b), 1, "b rank", self.name) validator.check_equal_int(len(mean), 1, "mean rank", self.name) validator.check_equal_int(len(variance), 1, "variance rank", self.name) - validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name) + validator.check("sum shape", sum[0], "x_shape[1]", x[1], Rel.EQ, self.name) validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name) - validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name) - validator.check("offset shape", b, "x_shape[1]", x[1], Rel.EQ, self.name) - validator.check("mean shape", mean, "x_shape[1]", x[1], Rel.EQ, self.name) - validator.check("variance shape", variance, "x_shape[1]", x[1], Rel.EQ, self.name) + validator.check("scale shape", scale[0], "x_shape[1]", x[1], Rel.EQ, self.name) + validator.check("offset shape", b[0], "x_shape[1]", x[1], Rel.EQ, self.name) + validator.check("mean shape", mean[0], "x_shape[1]", x[1], Rel.EQ, self.name) + validator.check("variance shape", variance[0], "x_shape[1]", x[1], Rel.EQ, self.name) return (x, variance, variance, variance, variance) def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance): @@ -5436,7 +5436,7 @@ class CTCGreedyDecoder(PrimitiveWithInfer): `num_labels` indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`. Data type must be float32 or float64. - **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of (`batch_size`). - The type must be int32. Each value in the tensor must not greater than `max_time`. + The type must be int32. Each value in the tensor must be equal to or less than `max_time`. Outputs: - **decoded_indices** (Tensor) - A tensor with shape of (`total_decoded_outputs`, 2).