add operator SpaceToBatch and BatchToSpace for ge

zhaozhenlong 2020-03-31 15:15:08 +08:00
parent d84bf8d357
commit cf40305bf0
8 changed files with 228 additions and 2 deletions

View File

@@ -180,6 +180,8 @@ const char kNamePrint[] = "Print";
const char kNameApplyFtrl[] = "ApplyFtrl";
const char kNameDiag[] = "Diag";
const char kNameDiagPart[] = "DiagPart";
const char kNameSpaceToBatch[] = "SpaceToBatch";
const char kNameBatchToSpace[] = "BatchToSpace";
// -----------------OpAdapter initialization--------------
std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_map() {
@@ -361,7 +363,9 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameRound), ADPT_DESC(Round)},
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)},
{string(kNameDiag), ADPT_DESC(Diag)},
{string(kNameDiagPart), ADPT_DESC(DiagPart)},
{string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)},
{string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}};
#ifdef ENABLE_GE
adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
#endif

View File

@@ -744,6 +744,28 @@ class OpAdapter : public BaseOpAdapter {
    return list;
  }
  static std::vector<int64_t> ConvertAny(const ValuePtr& value, const AnyTraits<std::vector<std::vector<int64_t>>>,
                                         const AnyTraits<std::vector<int64_t>>) {
    MS_EXCEPTION_IF_NULL(value);
    MS_LOG(DEBUG) << "Value: " << value->type_name();
    if (!value->isa<ValueList>()) {
      MS_LOG(EXCEPTION) << "Value should be ValueList, but got " << value->type_name();
    }
    auto vec = value->cast<ValueListPtr>();
    std::vector<int64_t> list;
    for (auto& it : vec->value()) {
      MS_EXCEPTION_IF_NULL(it);
      if (!it->isa<ValueList>()) {
        MS_LOG(EXCEPTION) << "It should be ValueList, but got " << it->type_name();
      }
      auto sub_vector = it->cast<ValueListPtr>();
      for (auto& item : sub_vector->value()) {
        list.push_back(static_cast<int64_t>(GetValue<int>(item)));
      }
    }
    return list;
  }
  static std::vector<int64_t> ConvertAny(const ValuePtr& value, const AnyTraits<std::vector<int64_t>>,
                                         const AnyTraits<std::vector<int64_t>>) {
    MS_EXCEPTION_IF_NULL(value);
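The new ConvertAny overload above is what lets a nested attribute such as paddings = [[0, 0], [1, 1]] (a 2 x 2 list of ints) be flattened into a single int64 vector before it reaches the GE adapter. A rough Python illustration of the same flattening; the helper name here is hypothetical and not part of the commit:

def flatten_nested_attr(nested):
    # e.g. [[0, 0], [1, 1]] -> [0, 0, 1, 1]; each element must itself be a list,
    # mirroring the ValueList check in the C++ overload above
    flat = []
    for sub in nested:
        if not isinstance(sub, (list, tuple)):
            raise TypeError(f"expected a nested list, got {type(sub).__name__}")
        flat.extend(int(v) for v in sub)
    return flat

print(flatten_nested_attr([[0, 0], [1, 1]]))  # [0, 0, 1, 1]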

View File

@@ -1183,6 +1183,19 @@ INPUT_MAP(DiagPart) = {{1, INPUT_DESC(x)}};
ATTR_MAP(DiagPart) = EMPTY_ATTR_MAP;
OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}};
// SpaceToBatchD
INPUT_MAP(SpaceToBatchD) = {{1, INPUT_DESC(x)}};
ATTR_MAP(SpaceToBatchD) = {
{"block_size", ATTR_DESC(block_size, AnyTraits<int64_t>())},
{"paddings", ATTR_DESC(paddings, AnyTraits<std::vector<std::vector<int64_t>>>(), AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(SpaceToBatchD) = {{0, OUTPUT_DESC(y)}};
// BatchToSpaceD
INPUT_MAP(BatchToSpaceD) = {{1, INPUT_DESC(x)}};
ATTR_MAP(BatchToSpaceD) = {
{"block_size", ATTR_DESC(block_size, AnyTraits<int64_t>())},
{"crops", ATTR_DESC(crops, AnyTraits<std::vector<std::vector<int64_t>>>(), AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(BatchToSpaceD) = {{0, OUTPUT_DESC(y)}};
#ifdef ENABLE_GE
// Print
INPUT_MAP(Print) = EMPTY_INPUT_MAP;

View File

@@ -439,6 +439,10 @@ DECLARE_OP_ADAPTER(Diag)
DECLARE_OP_USE_OUTPUT(Diag)
DECLARE_OP_ADAPTER(DiagPart)
DECLARE_OP_USE_OUTPUT(DiagPart)
DECLARE_OP_ADAPTER(SpaceToBatchD)
DECLARE_OP_USE_OUTPUT(SpaceToBatchD)
DECLARE_OP_ADAPTER(BatchToSpaceD)
DECLARE_OP_USE_OUTPUT(BatchToSpaceD)
#ifdef ENABLE_GE
DECLARE_OP_ADAPTER(Print)
DECLARE_OP_USE_DYN_INPUT(Print)

View File

@@ -430,3 +430,23 @@ def get_bprop_diag_part(self):
        return (op(dout),)
    return bprop


@bprop_getters.register(P.SpaceToBatch)
def get_bprop_space_to_batch(self):
    """Generate bprop for SpaceToBatch"""
    space_to_batch_grad = P.BatchToSpace(self.block_size, self.paddings)

    def bprop(x, out, dout):
        dx = space_to_batch_grad(dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.BatchToSpace)
def get_bprop_batch_to_space(self):
    """Generate bprop for BatchToSpace"""
    batch_to_space_grad = P.SpaceToBatch(self.block_size, self.crops)

    def bprop(x, out, dout):
        dx = batch_to_space_grad(dout)
        return (dx,)

    return bprop
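Both bprops above lean on the fact that BatchToSpace with the same block_size, and with the paddings reused as crops, exactly undoes SpaceToBatch (and vice versa), so each gradient is simply the opposite op applied to dout. A self-contained NumPy sketch of that inverse relationship; the batch interleaving order chosen here is an assumption for illustration and may differ from the actual GE kernels:

import numpy as np

def space_to_batch_ref(x, block, paddings):
    # pad H and W, then move the per-block offsets out of the spatial dims into the batch dim
    n, c, _, _ = x.shape
    xp = np.pad(x, ((0, 0), (0, 0), tuple(paddings[0]), tuple(paddings[1])))
    hp, wp = xp.shape[2], xp.shape[3]
    y = xp.reshape(n, c, hp // block, block, wp // block, block)
    return y.transpose(3, 5, 0, 1, 2, 4).reshape(block * block * n, c, hp // block, wp // block)

def batch_to_space_ref(y, block, crops):
    # exact inverse of space_to_batch_ref when crops equal the paddings used above
    nb, c, h, w = y.shape
    n = nb // (block * block)
    x = y.reshape(block, block, n, c, h, w).transpose(2, 3, 4, 0, 5, 1)
    x = x.reshape(n, c, h * block, w * block)
    return x[:, :, crops[0][0]:x.shape[2] - crops[0][1], crops[1][0]:x.shape[3] - crops[1][1]]

x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
paddings = [[0, 0], [0, 0]]
y = space_to_batch_ref(x, 2, paddings)
assert np.array_equal(batch_to_space_ref(y, 2, paddings), x)  # the round trip recovers x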

View File

@@ -29,7 +29,7 @@ from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat,
Shape, Size, Slice, Split,
Squeeze, StridedSlice, Tile,
Transpose, TruncatedNormal, TupleToArray,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice)
@@ -225,6 +225,8 @@ __all__ = [
"LARSUpdate",
"Round",
"ApplyFtrl",
"SpaceToBatch",
"BatchToSpace"
]
__all__.sort()

View File

@@ -20,6 +20,7 @@
import copy
import functools
import itertools
import numbers
import numpy as np
@@ -2020,3 +2021,143 @@ class DepthToSpace(PrimitiveWithInfer):
    def infer_dtype(self, x_dtype):
        validator.check_subclass("x_dtype", x_dtype, mstype.tensor)
        return x_dtype


class SpaceToBatch(PrimitiveWithInfer):
    r"""
    Divide spatial dimensions into blocks and combine them with the original batch dimension.

    This operation will divide spatial dimensions (H, W) into blocks with block_size; the output tensor's H and W
    dimensions are the corresponding number of blocks after division. The output tensor's batch dimension is the
    product of the original batch and the square of block_size. Prior to division into blocks, the spatial
    dimensions of the input are zero padded according to paddings if necessary.

    Args:
        block_size (int): The block size of dividing blocks, with value >= 2.
        paddings (list): The padding values for the H and W dimensions, containing 2 sub lists, each containing
            2 int values. All values must be >= 0. paddings[i] specifies the paddings for spatial dimension i,
            which corresponds to input dimension i+2. It is required that
            input_shape[i+2]+paddings[i][0]+paddings[i][1] is divisible by block_size.

    Inputs:
        - **input_x** (Tensor) - The input tensor.

    Outputs:
        Tensor, the output tensor with the same type as input. Assume the input shape is :math:`(n, c, h, w)` with
        :math:`block\_size` and :math:`paddings`. The output tensor shape will be :math:`(n', c', h', w')`, where

        :math:`n' = n*(block\_size*block\_size)`

        :math:`c' = c`

        :math:`h' = (h+paddings[0][0]+paddings[0][1])//block\_size`

        :math:`w' = (w+paddings[1][0]+paddings[1][1])//block\_size`

    Examples:
        >>> block_size = 2
        >>> paddings = [[0, 0], [0, 0]]
        >>> space_to_batch = P.SpaceToBatch(block_size, paddings)
        >>> x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mstype.float32)
        >>> space_to_batch(x)
        [[[[1.]]], [[[2.]]], [[[3.]]], [[[4.]]]]
    """

    @prim_attr_register
    def __init__(self, block_size, paddings):
        """Init SpaceToBatch"""
        validator.check_type('block_size', block_size, [int])
        validator.check('block_size', block_size, '', 1, Rel.GT)
        self.block_size = block_size
        validator.check('paddings shape', np.array(paddings).shape, '', (2, 2))
        for elem in itertools.chain(*paddings):
            validator.check_type('paddings element', elem, [int])
        self.paddings = paddings

    def infer_dtype(self, x_dtype):
        validator.check_subclass("input_x", x_dtype, mstype.tensor)
        validator.check_typename('input_x', x_dtype, mstype.number_type)
        return x_dtype

    def infer_shape(self, x_shape):
        validator.check('rank of input_x', len(x_shape), '', 4)
        out_shape = copy.deepcopy(x_shape)
        for i in range(2):
            padded = out_shape[i+2] + self.paddings[i][0] + \
                     self.paddings[i][1]
            if padded % self.block_size != 0:
                raise ValueError(f'padded[{i}] {padded} should be divisible by '
                                 f'block_size {self.block_size}')
            out_shape[i+2] = padded // self.block_size
        out_shape[0] *= self.block_size * self.block_size
        return out_shape
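As a quick check of the shape formulas in the docstring, here is the arithmetic for its example worked out as a plain Python sketch (no MindSpore needed):

n, c, h, w = 1, 1, 2, 2                 # shape of x in the example above
block_size, paddings = 2, [[0, 0], [0, 0]]
h_out = (h + paddings[0][0] + paddings[0][1]) // block_size  # 1
w_out = (w + paddings[1][0] + paddings[1][1]) // block_size  # 1
n_out = n * block_size * block_size                          # 4
print((n_out, c, h_out, w_out))  # (4, 1, 1, 1), the shape of the example output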


class BatchToSpace(PrimitiveWithInfer):
    r"""
    Divide the batch dimension into blocks and interleave these blocks back into spatial dimensions.

    This operation will divide the batch dimension N into blocks with block_size; the output tensor's N dimension
    is the corresponding number of blocks after division. The output tensor's H and W dimensions are the original
    H and W dimensions multiplied by block_size, minus the corresponding crop amounts.

    Args:
        block_size (int): The block size of dividing blocks, with value >= 2.
        crops (list): The crop values for the H and W dimensions, containing 2 sub lists, each containing 2 int
            values. All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which
            corresponds to input dimension i+2. It is required that
            input_shape[i+2]*block_size > crops[i][0]+crops[i][1].

    Inputs:
        - **input_x** (Tensor) - The input tensor.

    Outputs:
        Tensor, the output tensor with the same type as input. Assume the input shape is (n, c, h, w) with
        block_size and crops. The output shape will be (n', c', h', w'), where

        :math:`n' = n//(block\_size*block\_size)`

        :math:`c' = c`

        :math:`h' = h*block\_size-crops[0][0]-crops[0][1]`

        :math:`w' = w*block\_size-crops[1][0]-crops[1][1]`

    Examples:
        >>> block_size = 2
        >>> crops = [[0, 0], [0, 0]]
        >>> op = P.BatchToSpace(block_size, crops)
        >>> x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]), mstype.float32)
        >>> output = op(x)
        [[[[1., 2.], [3., 4.]]]]
    """

    @prim_attr_register
    def __init__(self, block_size, crops):
        """Init BatchToSpace"""
        validator.check_type('block_size', block_size, [int])
        validator.check('block_size', block_size, '', 1, Rel.GT)
        self.block_size = block_size
        validator.check('crops shape', np.array(crops).shape, '', (2, 2))
        for elem in itertools.chain(*crops):
            validator.check_type('crops element', elem, [int])
        self.crops = crops

    def infer_dtype(self, x_dtype):
        validator.check_subclass("input_x", x_dtype, mstype.tensor)
        validator.check_typename('input_x', x_dtype, mstype.number_type)
        return x_dtype

    def infer_shape(self, x_shape):
        validator.check('rank of input_x', len(x_shape), '', 4)
        out_shape = copy.deepcopy(x_shape)
        for i in range(2):
            x_block_prod = out_shape[i+2] * self.block_size
            crops_sum = self.crops[i][0] + self.crops[i][1]
            validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT)
            out_shape[i+2] = x_block_prod - crops_sum
        block_size_prod = self.block_size * self.block_size
        if out_shape[0] % block_size_prod != 0:
            raise ValueError(f'input_x dimension 0 {out_shape[0]} should be divisible by '
                             f'block_size_prod {block_size_prod}')
        out_shape[0] = out_shape[0] // block_size_prod
        return out_shape
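And the reverse arithmetic for the BatchToSpace example above, again as a plain Python sketch:

n, c, h, w = 4, 1, 1, 1                 # shape of x in the example above
block_size, crops = 2, [[0, 0], [0, 0]]
assert n % (block_size * block_size) == 0             # batch must divide evenly by block_size**2
n_out = n // (block_size * block_size)                # 1
h_out = h * block_size - crops[0][0] - crops[0][1]    # 2
w_out = w * block_size - crops[1][0] - crops[1][1]    # 2
print((n_out, c, h_out, w_out))  # (1, 1, 2, 2), the shape of the example output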

View File

@@ -952,6 +952,26 @@ test_case_array_ops = [
        'desc_inputs': [[4, 4]],
        'desc_bprop': [[4]],
    }),
    ('SpaceToBatch_1', {
        'block': P.SpaceToBatch(2, [[0, 0], [0, 0]]),
        'desc_inputs': [[1, 3, 2, 2]],
        'desc_bprop': [[4, 3, 1, 1]],
    }),
    ('SpaceToBatch_2', {
        'block': P.SpaceToBatch(2, [[1, 1], [0, 4]]),
        'desc_inputs': [[1, 3, 2, 2]],
        'desc_bprop': [[4, 3, 2, 3]],
    }),
    ('BatchToSpace_1', {
        'block': P.BatchToSpace(2, [[0, 0], [0, 0]]),
        'desc_inputs': [[4, 3, 1, 1]],
        'desc_bprop': [[1, 3, 2, 2]],
    }),
    ('BatchToSpace_2', {
        'block': P.BatchToSpace(2, [[0, 0], [0, 1]]),
        'desc_inputs': [[4, 3, 1, 1]],
        'desc_bprop': [[1, 3, 2, 1]],
    }),
]

test_case_other_ops = [