From d1060690b22ca785d6f954a58336f6fb5aaaec55 Mon Sep 17 00:00:00 2001
From: jiangjinsheng
Date: Tue, 30 Jun 2020 10:29:39 +0800
Subject: [PATCH] fixed SpaceToBatchND, Conv2DBackpropFilter

---
 .../ops/_op_impl/tbe/batch_to_space_nd.py    |  4 ++--
 .../_op_impl/tbe/conv2d_backprop_filter.py   |  2 ++
 .../ops/_op_impl/tbe/space_to_batch_nd.py    |  4 ++--
 mindspore/ops/operations/_grad_ops.py        |  1 +
 mindspore/ops/operations/array_ops.py        | 20 +++++++++++++------
 5 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/mindspore/ops/_op_impl/tbe/batch_to_space_nd.py b/mindspore/ops/_op_impl/tbe/batch_to_space_nd.py
index ad5060e7c1..f942a3836d 100644
--- a/mindspore/ops/_op_impl/tbe/batch_to_space_nd.py
+++ b/mindspore/ops/_op_impl/tbe/batch_to_space_nd.py
@@ -25,8 +25,8 @@ batch_to_space_nd_op_info = TBERegOp("BatchToSpaceND") \
     .partial_flag(True) \
     .attr("block_shape", "required", "listInt", "all") \
     .attr("crops", "required", "listListInt", "all") \
-    .input(0, "x", False, "required", "all") \
-    .output(0, "y", False, "required", "all") \
+    .input(0, "x", False, "required", "all", reshape_type="NH") \
+    .output(0, "y", False, "required", "all", reshape_type="NH") \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
     .get_op_info()
diff --git a/mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py b/mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py
index 04b55bb2a3..c309a4f2ab 100644
--- a/mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py
+++ b/mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py
@@ -27,6 +27,8 @@ conv2d_backprop_filter_op_info = TBERegOp("Conv2DBackpropFilter") \
     .attr("stride", "required", "listInt", "all") \
     .attr("pad_list", "required", "listInt", "all") \
     .attr("dilation", "required", "listInt", "all") \
+    .attr("groups", "optional", "int", "all") \
+    .attr("data_format", "optional", "str", "all") \
     .input(0, "out_backprop", False, "required", "all") \
     .input(1, "x", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
diff --git a/mindspore/ops/_op_impl/tbe/space_to_batch_nd.py b/mindspore/ops/_op_impl/tbe/space_to_batch_nd.py
index 3a50b56a24..c1094cb55c 100644
--- a/mindspore/ops/_op_impl/tbe/space_to_batch_nd.py
+++ b/mindspore/ops/_op_impl/tbe/space_to_batch_nd.py
@@ -25,8 +25,8 @@ space_to_batch_nd_op_info = TBERegOp("SpaceToBatchND") \
     .partial_flag(True) \
     .attr("block_shape", "required", "listInt", "all") \
     .attr("paddings", "required", "listListInt", "all") \
-    .input(0, "x", False, "required", "all") \
-    .output(0, "y", False, "required", "all") \
+    .input(0, "x", False, "required", "all", reshape_type="NH") \
+    .output(0, "y", False, "required", "all", reshape_type="NH") \
     .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
     .get_op_info()
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 211c30d143..94ba2f1bd9 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -237,6 +237,7 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
         self.add_prim_attr('stride', self.stride)
         self.dilation = dilation
         self.group = group
+        self.add_prim_attr('groups', group)
         self.add_prim_attr('data_format', "NCHW")
 
     def __infer__(self, doutput, x, w_size):
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index f59a3a37e5..c85e87b93a 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -2635,16 +2635,20 @@ class SpaceToBatchND(PrimitiveWithInfer):
 
     def infer_shape(self, x_shape):
         x_rank = len(x_shape)
+        validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
         out_shape = copy.deepcopy(x_shape)
 
         block_shape_prod = 1
-        for i in range(x_rank - 2):
-            padded = out_shape[i + 2] + self.paddings[i][0] + \
+        offset = 2
+        if x_rank < 4:
+            offset = 1
+        for i in range(len(self.block_shape)):
+            padded = out_shape[i + offset] + self.paddings[i][0] + \
                      self.paddings[i][1]
             if padded % self.block_shape[i] != 0:
                 raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
                                  f'block_shape[{i}] {self.block_shape[i]}')
-            out_shape[i + 2] = padded // self.block_shape[i]
+            out_shape[i + offset] = padded // self.block_shape[i]
             block_shape_prod = block_shape_prod * self.block_shape[i]
         out_shape[0] *= block_shape_prod
         return out_shape
@@ -2715,15 +2719,19 @@ class BatchToSpaceND(PrimitiveWithInfer):
 
     def infer_shape(self, x_shape):
         x_rank = len(x_shape)
+        validator.check_integer('x_shape rank', x_rank, 4, Rel.EQ, self.name)
        out_shape = copy.deepcopy(x_shape)
 
         block_shape_prod = 1
-        for i in range(x_rank - 2):
+        offset = 2
+        if x_rank < 4:
+            offset = 1
+        for i in range(len(self.block_shape)):
             block_shape_prod = block_shape_prod * self.block_shape[i]
-            x_block_prod = out_shape[i + 2] * self.block_shape[i]
+            x_block_prod = out_shape[i + offset] * self.block_shape[i]
             crops_sum = self.crops[i][0] + self.crops[i][1]
             validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
-            out_shape[i + 2] = x_block_prod - crops_sum
+            out_shape[i + offset] = x_block_prod - crops_sum
         if out_shape[0] % block_shape_prod != 0:
             raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
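
Note: as a cross-check of the patched shape inference, below is a minimal standalone sketch of the same arithmetic in plain Python. It carries no MindSpore dependency; the function names are illustrative, not MindSpore API. Because the patch pins the input rank to 4, the `x_rank < 4` branch never fires and `offset` is always 2 here, and the final batch-dimension division in batch_to_space_nd_shape reflects BatchToSpaceND's inverse behaviour beyond the truncated hunk above.

def space_to_batch_nd_shape(x_shape, block_shape, paddings):
    # Mirrors SpaceToBatchND.infer_shape after this patch (rank-4 NCHW input).
    assert len(x_shape) == 4, "the patch enforces a rank-4 input"
    out_shape = list(x_shape)
    offset = 2  # spatial dims start after N and C
    block_shape_prod = 1
    for i, block in enumerate(block_shape):
        padded = out_shape[i + offset] + paddings[i][0] + paddings[i][1]
        if padded % block != 0:
            raise ValueError(f'padded[{i}] {padded} should be divisible by '
                             f'block_shape[{i}] {block}')
        out_shape[i + offset] = padded // block
        block_shape_prod *= block
    out_shape[0] *= block_shape_prod  # batch grows by the block product
    return out_shape

def batch_to_space_nd_shape(x_shape, block_shape, crops):
    # Mirrors BatchToSpaceND.infer_shape after this patch (rank-4 NCHW input).
    assert len(x_shape) == 4, "the patch enforces a rank-4 input"
    out_shape = list(x_shape)
    offset = 2
    block_shape_prod = 1
    for i, block in enumerate(block_shape):
        block_shape_prod *= block
        x_block_prod = out_shape[i + offset] * block
        crops_sum = crops[i][0] + crops[i][1]
        if x_block_prod <= crops_sum:
            raise ValueError(f'x block shape prod {x_block_prod} should be '
                             f'greater than crops sum {crops_sum}')
        out_shape[i + offset] = x_block_prod - crops_sum
    if out_shape[0] % block_shape_prod != 0:
        raise ValueError(f'input_x dimension 0 {out_shape[0]} should be '
                         f'divisible by block_shape_prod {block_shape_prod}')
    out_shape[0] //= block_shape_prod  # batch shrinks by the block product
    return out_shape

# The two operations are shape inverses when paddings == crops, e.g.:
# space_to_batch_nd_shape([1, 3, 4, 4], [2, 2], [[0, 0], [0, 0]]) -> [4, 3, 2, 2]
# batch_to_space_nd_shape([4, 3, 2, 2], [2, 2], [[0, 0], [0, 0]]) -> [1, 3, 4, 4]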