forked from mindspore-Ecosystem/mindspore
fix SpaceToBatchND and BatchToSpace.
This commit is contained in:
parent
55751c6c33
commit
832a9d5fbb
|
@ -59,6 +59,7 @@ class ReduceLogSumExp(Cell):
|
|||
>>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
|
||||
>>> op = nn.ReduceLogSumExp(keep_dims=True)
|
||||
>>> output = op(input_x, 1)
|
||||
>>> output.shape
|
||||
(3, 1, 5, 6)
|
||||
"""
|
||||
|
||||
|
|
|
@ -776,7 +776,7 @@ def get_bprop_batch_to_space(self):
|
|||
@bprop_getters.register(P.SpaceToBatchND)
|
||||
def get_bprop_space_to_batch_nd(self):
|
||||
"""Generate bprop for SpaceToBatchND"""
|
||||
space_to_batch_nd_grad = P.BatchToSpaceND(self.block_shape, self.paddings)
|
||||
space_to_batch_nd_grad = P.BatchToSpaceND(self.ori_block_shape, self.ori_paddings)
|
||||
def bprop(x, out, dout):
|
||||
dx = space_to_batch_nd_grad(dout)
|
||||
return (dx,)
|
||||
|
@ -786,7 +786,7 @@ def get_bprop_space_to_batch_nd(self):
|
|||
@bprop_getters.register(P.BatchToSpaceND)
|
||||
def get_bprop_batch_to_space_nd(self):
|
||||
"""Generate bprop for BatchToSpaceND"""
|
||||
batch_to_space_nd_grad = P.SpaceToBatchND(self.block_shape, self.crops)
|
||||
batch_to_space_nd_grad = P.SpaceToBatchND(self.ori_block_shape, self.ori_crops)
|
||||
def bprop(x, out, dout):
|
||||
dx = batch_to_space_nd_grad(dout)
|
||||
return (dx,)
|
||||
|
|
|
@ -3259,8 +3259,8 @@ class SpaceToBatchND(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value greater than 1.
|
||||
The length of `block_shape` is M correspoding to the number of spatial dimensions.
|
||||
paddings (list): The padding values for H and W dimension, containing M subtraction list.
|
||||
The length of `block_shape` is M correspoding to the number of spatial dimensions. M must be 2.
|
||||
paddings (list): The padding values for H and W dimension, containing 2 subtraction list.
|
||||
Each contains 2 integer value. All values must be greater than 0.
|
||||
`paddings[i]` specifies the paddings for the spatial dimension i,
|
||||
which corresponds to the input dimension i+2.
|
||||
|
@ -3294,21 +3294,28 @@ class SpaceToBatchND(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, block_shape, paddings):
|
||||
"""Init SpaceToBatchND"""
|
||||
self.ori_block_shape = block_shape
|
||||
self.ori_paddings = paddings
|
||||
validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
|
||||
validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
|
||||
block_rank = len(block_shape)
|
||||
|
||||
validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name)
|
||||
for elem in block_shape:
|
||||
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
|
||||
validator.check_value_type('block_shape element', elem, [int], self.name)
|
||||
|
||||
self.block_shape = block_shape
|
||||
|
||||
validator.check_value_type('paddings type', paddings, [list, tuple], self.name)
|
||||
validator.check('paddings length', len(paddings), '', 2, Rel.EQ, self.name)
|
||||
validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name)
|
||||
for elem in itertools.chain(*paddings):
|
||||
validator.check_integer('paddings element', elem, 0, Rel.GE, self.name)
|
||||
validator.check_value_type('paddings element', elem, [int], self.name)
|
||||
self.paddings = paddings
|
||||
block_shape_append = [1] + list(self.block_shape)
|
||||
self.add_prim_attr("block_shape", block_shape_append)
|
||||
paddings_append = [[0, 0]] + list(self.paddings)
|
||||
self.add_prim_attr("paddings", paddings_append)
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
|
||||
|
@ -3321,7 +3328,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
|
|||
|
||||
block_shape_prod = 1
|
||||
offset = 2
|
||||
if x_rank < 4:
|
||||
if x_rank <= 4:
|
||||
offset = 1
|
||||
for i in range(len(self.block_shape)):
|
||||
padded = out_shape[i + offset] + self.paddings[i][0] + \
|
||||
|
@ -3345,7 +3352,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value >= 1.
|
||||
The length of block_shape is M correspoding to the number of spatial dimensions.
|
||||
The length of block_shape is M correspoding to the number of spatial dimensions. M must be 2.
|
||||
crops (Union[list(int), tuple(int)]): The crop value for H and W dimension, containing 2 subtraction list,
|
||||
each containing 2 int value.
|
||||
All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
|
||||
|
@ -3380,22 +3387,28 @@ class BatchToSpaceND(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self, block_shape, crops):
|
||||
"""Init BatchToSpaceND"""
|
||||
self.ori_block_shape = block_shape
|
||||
self.ori_crops = crops
|
||||
validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
|
||||
validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
|
||||
block_rank = len(block_shape)
|
||||
|
||||
validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name)
|
||||
for elem in block_shape:
|
||||
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
|
||||
validator.check_value_type('block_shape element', elem, [int], self.name)
|
||||
|
||||
self.block_shape = block_shape
|
||||
|
||||
validator.check_value_type('crops type', crops, [list, tuple], self.name)
|
||||
validator.check('crops length', len(crops), '', 2, Rel.EQ, self.name)
|
||||
validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name)
|
||||
for elem in itertools.chain(*crops):
|
||||
validator.check_integer('crops element', elem, 0, Rel.GE, self.name)
|
||||
validator.check_value_type('crops element', elem, [int], self.name)
|
||||
self.crops = crops
|
||||
block_shape_append = [1] + list(self.block_shape)
|
||||
self.add_prim_attr("block_shape", block_shape_append)
|
||||
crops_append = [[0, 0]] + list(self.crops)
|
||||
self.add_prim_attr("crops", crops_append)
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
|
||||
|
@ -3408,7 +3421,7 @@ class BatchToSpaceND(PrimitiveWithInfer):
|
|||
|
||||
block_shape_prod = 1
|
||||
offset = 2
|
||||
if x_rank < 4:
|
||||
if x_rank <= 4:
|
||||
offset = 1
|
||||
for i in range(len(self.block_shape)):
|
||||
block_shape_prod = block_shape_prod * self.block_shape[i]
|
||||
|
@ -3591,12 +3604,12 @@ class EditDistance(PrimitiveWithInfer):
|
|||
The shape of tensor is :math:`(N, R)`.
|
||||
- **hypothesis_values** (Tensor) - The values of the hypothesis list SparseTensor.
|
||||
Must be 1-D vector with length of N.
|
||||
- **hypothesis_shape** (Tensor) - The values of the hypothesis list SparseTensor.
|
||||
- **hypothesis_shape** (Tensor) - The shape of the hypothesis list SparseTensor.
|
||||
Must be R-length vector with int64 data type. Only constant value is allowed.
|
||||
- **truth_indices** (Tensor) - The indices of the truth list SparseTensor. With int64 data type.
|
||||
The shape of tensor is :math:`(M, R)`.
|
||||
- **truth_values** (Tensor) - The values of the truth list SparseTensor. Must be 1-D vector with length of M.
|
||||
- **truth_shape** (Tensor) - The values of the truth list SparseTensor.
|
||||
- **truth_shape** (Tensor) - The shape of the truth list SparseTensor.
|
||||
Must be R-length vector with int64 data type. Only constant value is allowed.
|
||||
|
||||
Outputs:
|
||||
|
|
Loading…
Reference in New Issue