From bdcc607b1a87c00a43f17e2ab5ca88a69d7aa710 Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Mon, 13 Jul 2020 11:23:43 +0800 Subject: [PATCH] fix ParallelConcat --- mindspore/ops/operations/array_ops.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index d68fc79a0eb..5ea52785f65 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1532,7 +1532,8 @@ class ParallelConcat(PrimitiveWithInfer): The input tensors are all required to have size 1 in the first dimension. Inputs: - - **values** (tuple, list) - Tuple or list of input tensors. + - **values** (tuple, list) - Tuple or list of input tensors. The data type and shape of these + tensors must be same. Outputs: Tensor, data type same as `values`. @@ -1542,6 +1543,7 @@ class ParallelConcat(PrimitiveWithInfer): >>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32)) >>> op = P.ParallelConcat() >>> output = op((data1, data2)) + [[0, 1], [2, 1]] """ @prim_attr_register @@ -1553,14 +1555,15 @@ class ParallelConcat(PrimitiveWithInfer): x_type = values['dtype'] validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name) + + args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)} + validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) + first_elem = x_shp[0] - args = {} for i, elem in enumerate(x_shp[1:]): j = i + 1 - args[f'x_type[{j}]'] = x_type[j] validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name) validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name) - validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) ret_shp = x_shp[0].copy() ret_shp[0] = len(x_shp)