fix SpaceToBatchND and BatchToSpace.

This commit is contained in:
liuxiao93 2020-09-19 16:11:27 +08:00
parent 55751c6c33
commit 832a9d5fbb
3 changed files with 27 additions and 13 deletions

View File

@ -59,6 +59,7 @@ class ReduceLogSumExp(Cell):
>>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
>>> op = nn.ReduceLogSumExp(keep_dims=True)
>>> output = op(input_x, 1)
>>> output.shape
(3, 1, 5, 6)
"""

View File

@ -776,7 +776,7 @@ def get_bprop_batch_to_space(self):
@bprop_getters.register(P.SpaceToBatchND)
def get_bprop_space_to_batch_nd(self):
"""Generate bprop for SpaceToBatchND"""
space_to_batch_nd_grad = P.BatchToSpaceND(self.block_shape, self.paddings)
space_to_batch_nd_grad = P.BatchToSpaceND(self.ori_block_shape, self.ori_paddings)
def bprop(x, out, dout):
dx = space_to_batch_nd_grad(dout)
return (dx,)
@ -786,7 +786,7 @@ def get_bprop_space_to_batch_nd(self):
@bprop_getters.register(P.BatchToSpaceND)
def get_bprop_batch_to_space_nd(self):
"""Generate bprop for BatchToSpaceND"""
batch_to_space_nd_grad = P.SpaceToBatchND(self.block_shape, self.crops)
batch_to_space_nd_grad = P.SpaceToBatchND(self.ori_block_shape, self.ori_crops)
def bprop(x, out, dout):
dx = batch_to_space_nd_grad(dout)
return (dx,)

View File

@ -3259,8 +3259,8 @@ class SpaceToBatchND(PrimitiveWithInfer):
Args:
block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value greater than 1.
The length of `block_shape` is M correspoding to the number of spatial dimensions.
paddings (list): The padding values for H and W dimension, containing M subtraction list.
The length of `block_shape` is M correspoding to the number of spatial dimensions. M must be 2.
paddings (list): The padding values for H and W dimension, containing 2 subtraction list.
Each contains 2 integer value. All values must be greater than 0.
`paddings[i]` specifies the paddings for the spatial dimension i,
which corresponds to the input dimension i+2.
@ -3294,21 +3294,28 @@ class SpaceToBatchND(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, block_shape, paddings):
"""Init SpaceToBatchND"""
self.ori_block_shape = block_shape
self.ori_paddings = paddings
validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
block_rank = len(block_shape)
validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name)
for elem in block_shape:
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
validator.check_value_type('block_shape element', elem, [int], self.name)
self.block_shape = block_shape
validator.check_value_type('paddings type', paddings, [list, tuple], self.name)
validator.check('paddings length', len(paddings), '', 2, Rel.EQ, self.name)
validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name)
for elem in itertools.chain(*paddings):
validator.check_integer('paddings element', elem, 0, Rel.GE, self.name)
validator.check_value_type('paddings element', elem, [int], self.name)
self.paddings = paddings
block_shape_append = [1] + list(self.block_shape)
self.add_prim_attr("block_shape", block_shape_append)
paddings_append = [[0, 0]] + list(self.paddings)
self.add_prim_attr("paddings", paddings_append)
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
@ -3321,7 +3328,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
block_shape_prod = 1
offset = 2
if x_rank < 4:
if x_rank <= 4:
offset = 1
for i in range(len(self.block_shape)):
padded = out_shape[i + offset] + self.paddings[i][0] + \
@ -3345,7 +3352,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
Args:
block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value >= 1.
The length of block_shape is M correspoding to the number of spatial dimensions.
The length of block_shape is M correspoding to the number of spatial dimensions. M must be 2.
crops (Union[list(int), tuple(int)]): The crop value for H and W dimension, containing 2 subtraction list,
each containing 2 int value.
All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
@ -3380,22 +3387,28 @@ class BatchToSpaceND(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, block_shape, crops):
"""Init BatchToSpaceND"""
self.ori_block_shape = block_shape
self.ori_crops = crops
validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
block_rank = len(block_shape)
validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name)
for elem in block_shape:
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
validator.check_value_type('block_shape element', elem, [int], self.name)
self.block_shape = block_shape
validator.check_value_type('crops type', crops, [list, tuple], self.name)
validator.check('crops length', len(crops), '', 2, Rel.EQ, self.name)
validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name)
for elem in itertools.chain(*crops):
validator.check_integer('crops element', elem, 0, Rel.GE, self.name)
validator.check_value_type('crops element', elem, [int], self.name)
self.crops = crops
block_shape_append = [1] + list(self.block_shape)
self.add_prim_attr("block_shape", block_shape_append)
crops_append = [[0, 0]] + list(self.crops)
self.add_prim_attr("crops", crops_append)
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
@ -3408,7 +3421,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
block_shape_prod = 1
offset = 2
if x_rank < 4:
if x_rank <= 4:
offset = 1
for i in range(len(self.block_shape)):
block_shape_prod = block_shape_prod * self.block_shape[i]
@ -3591,12 +3604,12 @@ class EditDistance(PrimitiveWithInfer):
The shape of tensor is :math:`(N, R)`.
- **hypothesis_values** (Tensor) - The values of the hypothesis list SparseTensor.
Must be 1-D vector with length of N.
- **hypothesis_shape** (Tensor) - The values of the hypothesis list SparseTensor.
- **hypothesis_shape** (Tensor) - The shape of the hypothesis list SparseTensor.
Must be R-length vector with int64 data type. Only constant value is allowed.
- **truth_indices** (Tensor) - The indices of the truth list SparseTensor. With int64 data type.
The shape of tensor is :math:`(M, R)`.
- **truth_values** (Tensor) - The values of the truth list SparseTensor. Must be 1-D vector with length of M.
- **truth_shape** (Tensor) - The values of the truth list SparseTensor.
- **truth_shape** (Tensor) - The shape of the truth list SparseTensor.
Must be R-length vector with int64 data type. Only constant value is allowed.
Outputs: