!2843 Add TransShape operator

Merge pull request !2843 from fanglei/trans_shape
This commit is contained in:
mindspore-ci-bot 2020-07-06 09:14:13 +08:00 committed by Gitee
commit 17319d8dfd
7 changed files with 53 additions and 1 deletions

View File

@ -134,6 +134,7 @@ const char kNameAssignSub[] = "AssignSub";
const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus"; const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus";
const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus"; const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus";
const char kNameReshape[] = "Reshape"; const char kNameReshape[] = "Reshape";
const char kNameTransShape[] = "TransShape";
const char kNameRealDiv[] = "RealDiv"; const char kNameRealDiv[] = "RealDiv";
const char kNameTile[] = "Tile"; const char kNameTile[] = "Tile";
const char kNameCos[] = "Cos"; const char kNameCos[] = "Cos";
@ -242,6 +243,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBatchNorm), ADPT_DESC(BatchNorm)}, {string(kNameBatchNorm), ADPT_DESC(BatchNorm)},
{string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)}, {string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)},
{string(kNameReshape), ADPT_DESC(Reshape)}, {string(kNameReshape), ADPT_DESC(Reshape)},
{string(kNameTransShape), ADPT_DESC(TransShape)},
{string(kNameFlattenGrad), ADPT_DESC(Reshape)}, {string(kNameFlattenGrad), ADPT_DESC(Reshape)},
{prim::kPrimFlatten->name(), ADPT_DESC(Flatten)}, {prim::kPrimFlatten->name(), ADPT_DESC(Flatten)},
{string(kNameAddN), ADPT_DESC(AddN)}, {string(kNameAddN), ADPT_DESC(AddN)},

View File

@ -442,6 +442,12 @@ INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
ATTR_MAP(Reshape) = EMPTY_ATTR_MAP; ATTR_MAP(Reshape) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}}; OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}};
// TransShape
INPUT_MAP(TransShape) = {{1, INPUT_DESC(x)}};
INPUT_ATTR_MAP(TransShape) = {{2, ATTR_DESC(outShape, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(TransShape) = EMPTY_ATTR_MAP;
OUTPUT_MAP(TransShape) = {{0, OUTPUT_DESC(y)}};
// BiasAdd // BiasAdd
INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}}; INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}};
ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}}; ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};

View File

@ -112,6 +112,9 @@ DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD)
DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD) DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD)
DECLARE_OP_ADAPTER(Reshape) DECLARE_OP_ADAPTER(Reshape)
DECLARE_OP_USE_OUTPUT(Reshape) DECLARE_OP_USE_OUTPUT(Reshape)
DECLARE_OP_ADAPTER(TransShape)
DECLARE_OP_USE_INPUT_ATTR(TransShape)
DECLARE_OP_USE_OUTPUT(TransShape)
DECLARE_OP_ADAPTER(Iou) DECLARE_OP_ADAPTER(Iou)
DECLARE_OP_USE_OUTPUT(Iou) DECLARE_OP_USE_OUTPUT(Iou)
DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D) DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D)

View File

@ -696,3 +696,13 @@ def get_bprop_reverse_sequence(self):
dx = reverse_sequence_grad(dout, seq_lengths) dx = reverse_sequence_grad(dout, seq_lengths)
return dx, zeros_like(seq_lengths) return dx, zeros_like(seq_lengths)
return bprop return bprop
@bprop_getters.register(P.TransShape)
def get_bprop_trans_shape(self):
"""Generate bprop for TransShape"""
op = P.TransShape()
def bprop(x, shape, out, dout):
dx = op(dout, shape_op(x))
return (dx, zeros_like(shape))
return bprop

View File

@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin, SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, Shape, Size, Slice, Split, TransShape,
Squeeze, StridedSlice, Tile, TensorScatterUpdate, Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,

View File

@ -3106,3 +3106,28 @@ class ReverseSequence(PrimitiveWithInfer):
validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name) validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name) validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
return x return x
class TransShape(PrimitiveWithInfer):
"""
Transform the shape of input tensor to target shape.
Inputs:
- **input_x** (Tensor) - A input tensor.
- **out_shape** (tuple[int]) - The shape of output data.
Outputs:
Tensor, a tensor whose data type is same as 'input_x', and the shape is same as the `out_shape`.
"""
@prim_attr_register
def __init__(self):
self.__setattr_flag__ = True
def __infer__(self, x, shape):
shp = shape['value']
dtype = x['dtype']
validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('out_shape', tuple(shp))
return {'shape': shp,
'dtype': dtype,
'value': None}

View File

@ -1865,6 +1865,12 @@ test_case_array_ops = [
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)], Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
'skip': ['backward'], 'skip': ['backward'],
}), }),
('TransShape', {
'block': P.TransShape(),
'desc_const': [(1, 12, 24, 24)],
'desc_inputs': [[1, 3, 24, 24]],
'desc_bprop': [[1, 12, 24, 24]],
}),
] ]
test_case_other_ops = [ test_case_other_ops = [