Merge pull request !28863 from zong_shuai/unstack_debug
This commit is contained in:
i-robot 2022-01-12 03:15:00 +00:00 committed by Gitee
commit 528470d537
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 7 additions and 6 deletions

View File

@ -2790,19 +2790,20 @@ class Unstack(PrimitiveWithInfer):
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
x_shape = list(x['shape'])
dim = len(x_shape)
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
if self.axis < 0:
self.axis = self.axis + dim
output_num = x_shape[self.axis]
axis = self.axis
validator.check_int_range(axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
if axis < 0:
axis = axis + dim
output_num = x_shape[axis]
validator.check_value_type("num", output_num, [int], self.name)
validator.check_positive_int(output_num, "output_num", self.name)
self.add_prim_attr('num', output_num)
output_valid_check = x_shape[self.axis] - output_num
output_valid_check = x_shape[axis] - output_num
validator.check_int(output_valid_check, 0, Rel.EQ,
"The dimension which to unstack divides output_num", self.name)
out_shapes = []
out_dtypes = []
out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
out_shape = x_shape[:axis] + x_shape[axis + 1:]
for _ in range(output_num):
out_shapes.append(tuple(out_shape))
out_dtypes.append(x['dtype'])