forked from mindspore-Ecosystem/mindspore
fix BNTrainingUpdate shape check about inputs.
This commit is contained in:
parent
57a8911eb0
commit
ec77704198
|
@ -836,12 +836,12 @@ class BNTrainingUpdate(PrimitiveWithInfer):
|
|||
validator.check_equal_int(len(b), 1, "b rank", self.name)
|
||||
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
|
||||
validator.check_equal_int(len(variance), 1, "variance rank", self.name)
|
||||
validator.check("sum shape", sum, "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("sum shape", sum[0], "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("square_sum shape", square_sum, "sum", sum, Rel.EQ, self.name)
|
||||
validator.check("scale shape", scale, "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("offset shape", b, "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("mean shape", mean, "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("variance shape", variance, "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("scale shape", scale[0], "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("offset shape", b[0], "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("mean shape", mean[0], "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
validator.check("variance shape", variance[0], "x_shape[1]", x[1], Rel.EQ, self.name)
|
||||
return (x, variance, variance, variance, variance)
|
||||
|
||||
def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
|
||||
|
@ -5436,7 +5436,7 @@ class CTCGreedyDecoder(PrimitiveWithInfer):
|
|||
`num_labels` indicates the number of actual labels. Blank labels are reserved.
|
||||
Default blank label is `num_classes - 1`. Data type must be float32 or float64.
|
||||
- **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of (`batch_size`).
|
||||
The type must be int32. Each value in the tensor must not greater than `max_time`.
|
||||
The type must be int32. Each value in the tensor must be equal to or less than `max_time`.
|
||||
|
||||
Outputs:
|
||||
- **decoded_indices** (Tensor) - A tensor with shape of (`total_decoded_outputs`, 2).
|
||||
|
|
Loading…
Reference in New Issue