!3029 fix validator for ParallelConcat
Merge pull request !3029 from jiangjinsheng/issue_fix4
This commit is contained in:
commit
ece99192e8
|
@ -1532,7 +1532,8 @@ class ParallelConcat(PrimitiveWithInfer):
|
||||||
The input tensors are all required to have size 1 in the first dimension.
|
The input tensors are all required to have size 1 in the first dimension.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **values** (tuple, list) - Tuple or list of input tensors.
|
- **values** (tuple, list) - Tuple or list of input tensors. The data type and shape of these
|
||||||
|
tensors must be same.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor, data type same as `values`.
|
Tensor, data type same as `values`.
|
||||||
|
@ -1542,6 +1543,7 @@ class ParallelConcat(PrimitiveWithInfer):
|
||||||
>>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32))
|
>>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32))
|
||||||
>>> op = P.ParallelConcat()
|
>>> op = P.ParallelConcat()
|
||||||
>>> output = op((data1, data2))
|
>>> output = op((data1, data2))
|
||||||
|
[[0, 1], [2, 1]]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
|
@ -1553,14 +1555,15 @@ class ParallelConcat(PrimitiveWithInfer):
|
||||||
x_type = values['dtype']
|
x_type = values['dtype']
|
||||||
|
|
||||||
validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name)
|
validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name)
|
||||||
|
|
||||||
|
args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)}
|
||||||
|
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
|
||||||
|
|
||||||
first_elem = x_shp[0]
|
first_elem = x_shp[0]
|
||||||
args = {}
|
|
||||||
for i, elem in enumerate(x_shp[1:]):
|
for i, elem in enumerate(x_shp[1:]):
|
||||||
j = i + 1
|
j = i + 1
|
||||||
args[f'x_type[{j}]'] = x_type[j]
|
|
||||||
validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name)
|
validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name)
|
||||||
validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name)
|
validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name)
|
||||||
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
|
|
||||||
|
|
||||||
ret_shp = x_shp[0].copy()
|
ret_shp = x_shp[0].copy()
|
||||||
ret_shp[0] = len(x_shp)
|
ret_shp[0] = len(x_shp)
|
||||||
|
|
Loading…
Reference in New Issue