forked from OSSInnovation/mindspore
!2741 fix BatchToSpaceND
Merge pull request !2741 from jiangjinsheng/issue_fix4
This commit is contained in:
commit
cf1628a3d9
|
@ -25,8 +25,8 @@ batch_to_space_nd_op_info = TBERegOp("BatchToSpaceND") \
|
|||
.partial_flag(True) \
|
||||
.attr("block_shape", "required", "listInt", "all") \
|
||||
.attr("crops", "required", "listListInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.input(0, "x", False, "required", "all", reshape_type="NH") \
|
||||
.output(0, "y", False, "required", "all", reshape_type="NH") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
|
|
@ -27,6 +27,8 @@ conv2d_backprop_filter_op_info = TBERegOp("Conv2DBackpropFilter") \
|
|||
.attr("stride", "required", "listInt", "all") \
|
||||
.attr("pad_list", "required", "listInt", "all") \
|
||||
.attr("dilation", "required", "listInt", "all") \
|
||||
.attr("groups", "optional", "int", "all") \
|
||||
.attr("data_format", "optional", "str", "all") \
|
||||
.input(0, "out_backprop", False, "required", "all") \
|
||||
.input(1, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
|
|
|
@ -25,8 +25,8 @@ space_to_batch_nd_op_info = TBERegOp("SpaceToBatchND") \
|
|||
.partial_flag(True) \
|
||||
.attr("block_shape", "required", "listInt", "all") \
|
||||
.attr("paddings", "required", "listListInt", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.input(0, "x", False, "required", "all", reshape_type="NH") \
|
||||
.output(0, "y", False, "required", "all", reshape_type="NH") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.get_op_info()
|
||||
|
|
|
@ -237,6 +237,7 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
|
|||
self.add_prim_attr('stride', self.stride)
|
||||
self.dilation = dilation
|
||||
self.group = group
|
||||
self.add_prim_attr('groups', group)
|
||||
self.add_prim_attr('data_format', "NCHW")
|
||||
|
||||
def __infer__(self, doutput, x, w_size):
|
||||
|
|
|
@ -2636,16 +2636,20 @@ class SpaceToBatchND(PrimitiveWithInfer):
|
|||
|
||||
def infer_shape(self, x_shape):
|
||||
x_rank = len(x_shape)
|
||||
validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
|
||||
out_shape = copy.deepcopy(x_shape)
|
||||
|
||||
block_shape_prod = 1
|
||||
for i in range(x_rank - 2):
|
||||
padded = out_shape[i + 2] + self.paddings[i][0] + \
|
||||
offset = 2
|
||||
if x_rank < 4:
|
||||
offset = 1
|
||||
for i in range(len(self.block_shape)):
|
||||
padded = out_shape[i + offset] + self.paddings[i][0] + \
|
||||
self.paddings[i][1]
|
||||
if padded % self.block_shape[i] != 0:
|
||||
raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
|
||||
f'block_shape[{i}] {self.block_shape[i]}')
|
||||
out_shape[i + 2] = padded // self.block_shape[i]
|
||||
out_shape[i + offset] = padded // self.block_shape[i]
|
||||
block_shape_prod = block_shape_prod * self.block_shape[i]
|
||||
out_shape[0] *= block_shape_prod
|
||||
return out_shape
|
||||
|
@ -2716,15 +2720,19 @@ class BatchToSpaceND(PrimitiveWithInfer):
|
|||
|
||||
def infer_shape(self, x_shape):
|
||||
x_rank = len(x_shape)
|
||||
validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
|
||||
out_shape = copy.deepcopy(x_shape)
|
||||
|
||||
block_shape_prod = 1
|
||||
for i in range(x_rank - 2):
|
||||
offset = 2
|
||||
if x_rank < 4:
|
||||
offset = 1
|
||||
for i in range(len(self.block_shape)):
|
||||
block_shape_prod = block_shape_prod * self.block_shape[i]
|
||||
x_block_prod = out_shape[i + 2] * self.block_shape[i]
|
||||
x_block_prod = out_shape[i + offset] * self.block_shape[i]
|
||||
crops_sum = self.crops[i][0] + self.crops[i][1]
|
||||
validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
|
||||
out_shape[i + 2] = x_block_prod - crops_sum
|
||||
out_shape[i + offset] = x_block_prod - crops_sum
|
||||
|
||||
if out_shape[0] % block_shape_prod != 0:
|
||||
raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
|
||||
|
|
Loading…
Reference in New Issue