fix BNTrainingUpdate shape check about inputs.

This commit is contained in:
liuxiao93 2020-11-04 16:37:28 +08:00
parent 57a8911eb0
commit ec77704198
1 changed files with 6 additions and 6 deletions

View File

@ -836,12 +836,12 @@ class BNTrainingUpdate(PrimitiveWithInfer):
validator.check_equal_int(len(b), 1, "b rank", self.name)
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check_equal_int(len(variance), 1, "variance rank", self.name)
validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("sum shape", sum[0], "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name)
validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("offset shape", b, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("mean shape", mean, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("variance shape", variance, "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("scale shape", scale[0], "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("offset shape", b[0], "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("mean shape", mean[0], "x_shape[1]", x[1], Rel.EQ, self.name)
validator.check("variance shape", variance[0], "x_shape[1]", x[1], Rel.EQ, self.name)
return (x, variance, variance, variance, variance)
def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
@ -5436,7 +5436,7 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
`num_labels` indicates the number of actual labels. Blank labels are reserved.
Default blank label is `num_classes - 1`. Data type must be float32 or float64.
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of (`batch_size`).
The type must be int32. Each value in the tensor must not greater than `max_time`.
The type must be int32. Each value in the tensor must be equal to or less than `max_time`.
Outputs:
- **decoded_indices** (Tensor) - A tensor with shape of (`total_decoded_outputs`, 2).