!28863 axis debug
Merge pull request !28863 from zong_shuai/unstack_debug
This commit is contained in:
commit
528470d537
|
@ -2790,19 +2790,20 @@ class Unstack(PrimitiveWithInfer):
|
|||
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
|
||||
x_shape = list(x['shape'])
|
||||
dim = len(x_shape)
|
||||
validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
|
||||
if self.axis < 0:
|
||||
self.axis = self.axis + dim
|
||||
output_num = x_shape[self.axis]
|
||||
axis = self.axis
|
||||
validator.check_int_range(axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name)
|
||||
if axis < 0:
|
||||
axis = axis + dim
|
||||
output_num = x_shape[axis]
|
||||
validator.check_value_type("num", output_num, [int], self.name)
|
||||
validator.check_positive_int(output_num, "output_num", self.name)
|
||||
self.add_prim_attr('num', output_num)
|
||||
output_valid_check = x_shape[self.axis] - output_num
|
||||
output_valid_check = x_shape[axis] - output_num
|
||||
validator.check_int(output_valid_check, 0, Rel.EQ,
|
||||
"The dimension which to unstack divides output_num", self.name)
|
||||
out_shapes = []
|
||||
out_dtypes = []
|
||||
out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
|
||||
out_shape = x_shape[:axis] + x_shape[axis + 1:]
|
||||
for _ in range(output_num):
|
||||
out_shapes.append(tuple(out_shape))
|
||||
out_dtypes.append(x['dtype'])
|
||||
|
|
Loading…
Reference in New Issue