forked from mindspore-Ecosystem/mindspore
!1477 support vm for SpaceToBatchND and BatchToSpaceND
Merge pull request !1477 from jiangjinsheng/BatchToSpaceND
This commit is contained in:
commit
b94949ea99
|
@ -82,6 +82,8 @@ static std::map<string, string> tbe_func_adapter_map = {
|
||||||
{"argmax", "arg_max_d"},
|
{"argmax", "arg_max_d"},
|
||||||
{"space_to_batch", "space_to_batch_d"},
|
{"space_to_batch", "space_to_batch_d"},
|
||||||
{"batch_to_space", "batch_to_space_d"},
|
{"batch_to_space", "batch_to_space_d"},
|
||||||
|
{"space_to_batch_nd", "space_to_batch_nd_d"},
|
||||||
|
{"batch_to_space_nd", "batch_to_space_nd_d"},
|
||||||
{"resize_bilinear", "resize_bilinear_v2_d"},
|
{"resize_bilinear", "resize_bilinear_v2_d"},
|
||||||
{"resize_bilinear_grad", "resize_bilinear_v2_grad"},
|
{"resize_bilinear_grad", "resize_bilinear_v2_grad"},
|
||||||
{"adam", "apply_adam"},
|
{"adam", "apply_adam"},
|
||||||
|
|
|
@ -536,3 +536,23 @@ def get_bprop_batch_to_space(self):
|
||||||
dx = batch_to_space_grad(dout)
|
dx = batch_to_space_grad(dout)
|
||||||
return (dx,)
|
return (dx,)
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.SpaceToBatchND)
|
||||||
|
def get_bprop_space_to_batch_nd(self):
|
||||||
|
"""Generate bprop for SpaceToBatchND"""
|
||||||
|
space_to_batch_nd_grad = P.BatchToSpaceND(self.block_shape, self.paddings)
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
dx = space_to_batch_nd_grad(dout)
|
||||||
|
return (dx,)
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.BatchToSpaceND)
|
||||||
|
def get_bprop_batch_to_space_nd(self):
|
||||||
|
"""Generate bprop for BatchToSpaceND"""
|
||||||
|
batch_to_space_nd_grad = P.SpaceToBatchND(self.block_shape, self.crops)
|
||||||
|
def bprop(x, out, dout):
|
||||||
|
dx = batch_to_space_nd_grad(dout)
|
||||||
|
return (dx,)
|
||||||
|
return bprop
|
||||||
|
|
|
@ -200,3 +200,5 @@ from .reduce_prod import _reduce_prod_tbe
|
||||||
from .flatten_grad import _flatten_grad_tbe
|
from .flatten_grad import _flatten_grad_tbe
|
||||||
from .scatter_add import _scatter_add_tbe
|
from .scatter_add import _scatter_add_tbe
|
||||||
from .atan2 import _atan2_tbe
|
from .atan2 import _atan2_tbe
|
||||||
|
from .batch_to_space_nd import _batch_to_space_nd_tbe
|
||||||
|
from .space_to_batch_nd import _space_to_batch_nd_tbe
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""BatchToSpaceND op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
batch_to_space_nd_op_info = TBERegOp("BatchToSpaceND") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("batch_to_space_nd_d.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("batch_to_space_nd_d") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.attr("block_shape", "required", "listInt", "all") \
|
||||||
|
.attr("crops", "required", "listListInt", "all") \
|
||||||
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(batch_to_space_nd_op_info)
|
||||||
|
def _batch_to_space_nd_tbe():
|
||||||
|
"""BatchToSpaceND TBE register"""
|
||||||
|
return
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""SpaceToBatchND op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
space_to_batch_nd_op_info = TBERegOp("SpaceToBatchND") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("space_to_batch_nd_d.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("space_to_batch_nd_d") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.attr("block_shape", "required", "listInt", "all") \
|
||||||
|
.attr("paddings", "required", "listListInt", "all") \
|
||||||
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(space_to_batch_nd_op_info)
|
||||||
|
def _space_to_batch_nd_tbe():
|
||||||
|
"""SpaceToBatchND TBE register"""
|
||||||
|
return
|
|
@ -29,7 +29,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||||
Shape, Size, Slice, Split,
|
Shape, Size, Slice, Split,
|
||||||
Squeeze, StridedSlice, Tile,
|
Squeeze, StridedSlice, Tile,
|
||||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
||||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace)
|
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||||
|
SpaceToBatchND, BatchToSpaceND)
|
||||||
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
|
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
|
||||||
_MirrorOperator, ReduceOp, _VirtualDataset,
|
_MirrorOperator, ReduceOp, _VirtualDataset,
|
||||||
_VirtualDiv, _GetTensorSlice)
|
_VirtualDiv, _GetTensorSlice)
|
||||||
|
@ -260,6 +261,8 @@ __all__ = [
|
||||||
"Atan2",
|
"Atan2",
|
||||||
"ApplyRMSProp",
|
"ApplyRMSProp",
|
||||||
"ApplyCenteredRMSProp",
|
"ApplyCenteredRMSProp",
|
||||||
|
"SpaceToBatchND",
|
||||||
|
"BatchToSpaceND",
|
||||||
"SquareSumAll"
|
"SquareSumAll"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -980,7 +980,7 @@ class InvertPermutation(PrimitiveWithInfer):
|
||||||
validator.check_value_type("shape", x_shp, [tuple, list], self.name)
|
validator.check_value_type("shape", x_shp, [tuple, list], self.name)
|
||||||
if mstype.issubclass_(x['dtype'], mstype.tensor):
|
if mstype.issubclass_(x['dtype'], mstype.tensor):
|
||||||
validator.check('x dimension', len(x_shp), '', 1, Rel.EQ, self.name)
|
validator.check('x dimension', len(x_shp), '', 1, Rel.EQ, self.name)
|
||||||
validator.check_type_same({'x dtype': x['dtype']}, mstype.int_type, self.name)
|
validator.check_tensor_type_same({'x dtype': x['dtype']}, mstype.int_type, self.name)
|
||||||
x_value = [int(i) for i in x_value.asnumpy()]
|
x_value = [int(i) for i in x_value.asnumpy()]
|
||||||
z = [x_value[i] for i in range(len(x_value))]
|
z = [x_value[i] for i in range(len(x_value))]
|
||||||
z.sort()
|
z.sort()
|
||||||
|
@ -2491,3 +2491,163 @@ class BatchToSpace(PrimitiveWithInfer):
|
||||||
f'block_size_prod {block_size_prod}')
|
f'block_size_prod {block_size_prod}')
|
||||||
out_shape[0] = out_shape[0] // block_size_prod
|
out_shape[0] = out_shape[0] // block_size_prod
|
||||||
return out_shape
|
return out_shape
|
||||||
|
|
||||||
|
|
||||||
|
class SpaceToBatchND(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Divide spatial dimensions into blocks and combine the block size with the original batch.
|
||||||
|
|
||||||
|
This operation will divide spatial dimensions (H, W) into blocks with block_shape, the output tensor's H and W
|
||||||
|
dimension is the corresponding number of blocks after division. The output tensor's batch dimension is the
|
||||||
|
product of the original batch and the product of block_shape. Prior to division into blocks, the spatial dimensions
|
||||||
|
of the input are zero padded according to paddings if necessary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value >= 1.
|
||||||
|
The length of block_shape is M correspoding to the number of spatial dimensions.
|
||||||
|
paddings (list): The padding value for H and W dimension, containing M sub list, each containing 2 int value.
|
||||||
|
All values must be >= 0. paddings[i] specifies the paddings for spatial dimension i, which corresponds to
|
||||||
|
input dimension i+2. It is required that input_shape[i+2]+paddings[i][0]+paddings[i][1] is divisible
|
||||||
|
by block_shape[i].
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The input tensor.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, the output tensor with the same type as input. Assume input shape is :math:`(n, c, h, w)` with
|
||||||
|
:math:`block\_shape` and :math:`padddings`. The output tensor shape will be :math:`(n', c', h', w')`, where
|
||||||
|
|
||||||
|
:math:`n' = n*(block\_shape[0]*block\_shape[1])`
|
||||||
|
|
||||||
|
:math:`c' = c`
|
||||||
|
|
||||||
|
:math:`h' = (h+paddings[0][0]+paddings[0][1])//block\_shape[0]`
|
||||||
|
|
||||||
|
:math:`w' = (w+paddings[1][0]+paddings[1][1])//block\_shape[1]`
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> block_shape = [2, 2]
|
||||||
|
>>> paddings = [[0, 0], [0, 0]]
|
||||||
|
>>> space_to_batch_nd = P.SpaceToBatchND(block_shape, paddings)
|
||||||
|
>>> input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mindspore.float32)
|
||||||
|
>>> space_to_batch_nd(input_x)
|
||||||
|
[[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, block_shape, paddings):
|
||||||
|
"""Init SpaceToBatchND"""
|
||||||
|
validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
|
||||||
|
validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
|
||||||
|
block_rank = len(block_shape)
|
||||||
|
|
||||||
|
for elem in block_shape:
|
||||||
|
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
|
||||||
|
self.block_shape = block_shape
|
||||||
|
|
||||||
|
validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name)
|
||||||
|
for elem in itertools.chain(*paddings):
|
||||||
|
validator.check_integer('paddings element', elem, 0, Rel.GE, self.name)
|
||||||
|
validator.check_value_type('paddings element', elem, [int], self.name)
|
||||||
|
self.paddings = paddings
|
||||||
|
|
||||||
|
def infer_dtype(self, x_dtype):
|
||||||
|
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
|
||||||
|
return x_dtype
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape):
|
||||||
|
x_rank = len(x_shape)
|
||||||
|
out_shape = copy.deepcopy(x_shape)
|
||||||
|
|
||||||
|
block_shape_prod = 1
|
||||||
|
for i in range(x_rank - 2):
|
||||||
|
padded = out_shape[i + 2] + self.paddings[i][0] + \
|
||||||
|
self.paddings[i][1]
|
||||||
|
if padded % self.block_shape[i] != 0:
|
||||||
|
raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
|
||||||
|
f'block_shape[{i}] {self.block_shape[i]}')
|
||||||
|
out_shape[i + 2] = padded // self.block_shape[i]
|
||||||
|
block_shape_prod = block_shape_prod * self.block_shape[i]
|
||||||
|
out_shape[0] *= block_shape_prod
|
||||||
|
return out_shape
|
||||||
|
|
||||||
|
|
||||||
|
class BatchToSpaceND(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Divide batch dimension with blocks and interleaves these blocks back into spatial dimensions.
|
||||||
|
|
||||||
|
This operation will divide batch dimension N into blocks with block_shape, the output tensor's N dimension
|
||||||
|
is the corresponding number of blocks after division. The output tensor's H, W dimension is product of original H, W
|
||||||
|
dimension and block_shape with given amount to crop from dimension, respectively.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block_shape (Union[list(int), tuple(int)]): The block shape of dividing block with all value >= 1.
|
||||||
|
The length of block_shape is M correspoding to the number of spatial dimensions.
|
||||||
|
crops (list): The crop value for H and W dimension, containing 2 sub list, each containing 2 int value.
|
||||||
|
All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
|
||||||
|
input dimension i+2. It is required that input_shape[i+2]*block_size[i] >= crops[i][0]+crops[i][1].
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The input tensor.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_shape
|
||||||
|
and crops. The output shape will be (n', c', h', w'), where
|
||||||
|
|
||||||
|
:math:`n' = n//(block\_shape[0]*block\_shape[1])`
|
||||||
|
|
||||||
|
:math:`c' = c`
|
||||||
|
|
||||||
|
:math:`h' = h*block\_shape[0]-crops[0][0]-crops[0][1]`
|
||||||
|
|
||||||
|
:math:`w' = w*block\_shape[1]-crops[1][0]-crops[1][1]`
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> block_shape = [2, 2]
|
||||||
|
>>> crops = [[0, 0], [0, 0]]
|
||||||
|
>>> batch_to_space_nd = P.BatchToSpaceND(block_shape, crops)
|
||||||
|
>>> input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]), mindspore.float32)
|
||||||
|
>>> output = batch_to_space_nd(input_x)
|
||||||
|
[[[[1., 2.], [3., 4.]]]]
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, block_shape, crops):
|
||||||
|
"""Init BatchToSpaceND"""
|
||||||
|
validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
|
||||||
|
validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
|
||||||
|
block_rank = len(block_shape)
|
||||||
|
|
||||||
|
for elem in block_shape:
|
||||||
|
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
|
||||||
|
self.block_shape = block_shape
|
||||||
|
|
||||||
|
validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name)
|
||||||
|
for elem in itertools.chain(*crops):
|
||||||
|
validator.check_integer('crops element', elem, 0, Rel.GE, self.name)
|
||||||
|
validator.check_value_type('crops element', elem, [int], self.name)
|
||||||
|
self.crops = crops
|
||||||
|
|
||||||
|
def infer_dtype(self, x_dtype):
|
||||||
|
validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
|
||||||
|
return x_dtype
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape):
|
||||||
|
x_rank = len(x_shape)
|
||||||
|
out_shape = copy.deepcopy(x_shape)
|
||||||
|
|
||||||
|
block_shape_prod = 1
|
||||||
|
for i in range(x_rank - 2):
|
||||||
|
block_shape_prod = block_shape_prod * self.block_shape[i]
|
||||||
|
x_block_prod = out_shape[i + 2] * self.block_shape[i]
|
||||||
|
crops_sum = self.crops[i][0] + self.crops[i][1]
|
||||||
|
validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
|
||||||
|
out_shape[i + 2] = x_block_prod - crops_sum
|
||||||
|
|
||||||
|
if out_shape[0] % block_shape_prod != 0:
|
||||||
|
raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
|
||||||
|
f'block_shape_prod {block_shape_prod}')
|
||||||
|
out_shape[0] = out_shape[0] // block_shape_prod
|
||||||
|
return out_shape
|
||||||
|
|
|
@ -264,6 +264,27 @@ class DepthToSpaceNet(Cell):
|
||||||
return self.depth_to_space(x)
|
return self.depth_to_space(x)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchToSpaceNDNet(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(BatchToSpaceNDNet, self).__init__()
|
||||||
|
block_shape = [2, 2]
|
||||||
|
crops = [[0, 0], [0, 0]]
|
||||||
|
self.batch_to_space_nd = P.BatchToSpaceND(block_shape, crops)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.batch_to_space_nd(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SpaceToBatchNDNet(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(SpaceToBatchNDNet, self).__init__()
|
||||||
|
block_shape = [2, 2]
|
||||||
|
paddings = [[0, 0], [0, 0]]
|
||||||
|
self.space_to_batch_nd = P.SpaceToBatchND(block_shape, paddings)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return self.space_to_batch_nd(x)
|
||||||
|
|
||||||
test_case_array_ops = [
|
test_case_array_ops = [
|
||||||
('CustNet1', {
|
('CustNet1', {
|
||||||
'block': CustNet1(),
|
'block': CustNet1(),
|
||||||
|
@ -298,6 +319,12 @@ test_case_array_ops = [
|
||||||
('DepthToSpaceNet', {
|
('DepthToSpaceNet', {
|
||||||
'block': DepthToSpaceNet(),
|
'block': DepthToSpaceNet(),
|
||||||
'desc_inputs': [Tensor(np.random.rand(1,12,1,1).astype(np.float16))]}),
|
'desc_inputs': [Tensor(np.random.rand(1,12,1,1).astype(np.float16))]}),
|
||||||
|
('SpaceToBatchNDNet', {
|
||||||
|
'block': SpaceToBatchNDNet(),
|
||||||
|
'desc_inputs': [Tensor(np.random.rand(1,1,2,2).astype(np.float16))]}),
|
||||||
|
('BatchToSpaceNDNet', {
|
||||||
|
'block': BatchToSpaceNDNet(),
|
||||||
|
'desc_inputs': [Tensor(np.random.rand(4,1,1,1).astype(np.float16))]}),
|
||||||
]
|
]
|
||||||
|
|
||||||
test_case_lists = [test_case_array_ops]
|
test_case_lists = [test_case_array_ops]
|
||||||
|
|
Loading…
Reference in New Issue