forked from mindspore-Ecosystem/mindspore
!2843 Add TransShape operator
Merge pull request !2843 from fanglei/trans_shape
This commit is contained in:
commit
17319d8dfd
|
@ -134,6 +134,7 @@ const char kNameAssignSub[] = "AssignSub";
|
||||||
const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus";
|
const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus";
|
||||||
const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus";
|
const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus";
|
||||||
const char kNameReshape[] = "Reshape";
|
const char kNameReshape[] = "Reshape";
|
||||||
|
const char kNameTransShape[] = "TransShape";
|
||||||
const char kNameRealDiv[] = "RealDiv";
|
const char kNameRealDiv[] = "RealDiv";
|
||||||
const char kNameTile[] = "Tile";
|
const char kNameTile[] = "Tile";
|
||||||
const char kNameCos[] = "Cos";
|
const char kNameCos[] = "Cos";
|
||||||
|
@ -242,6 +243,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
|
||||||
{string(kNameBatchNorm), ADPT_DESC(BatchNorm)},
|
{string(kNameBatchNorm), ADPT_DESC(BatchNorm)},
|
||||||
{string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)},
|
{string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)},
|
||||||
{string(kNameReshape), ADPT_DESC(Reshape)},
|
{string(kNameReshape), ADPT_DESC(Reshape)},
|
||||||
|
{string(kNameTransShape), ADPT_DESC(TransShape)},
|
||||||
{string(kNameFlattenGrad), ADPT_DESC(Reshape)},
|
{string(kNameFlattenGrad), ADPT_DESC(Reshape)},
|
||||||
{prim::kPrimFlatten->name(), ADPT_DESC(Flatten)},
|
{prim::kPrimFlatten->name(), ADPT_DESC(Flatten)},
|
||||||
{string(kNameAddN), ADPT_DESC(AddN)},
|
{string(kNameAddN), ADPT_DESC(AddN)},
|
||||||
|
|
|
@ -442,6 +442,12 @@ INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
|
||||||
ATTR_MAP(Reshape) = EMPTY_ATTR_MAP;
|
ATTR_MAP(Reshape) = EMPTY_ATTR_MAP;
|
||||||
OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}};
|
OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
|
// TransShape
|
||||||
|
INPUT_MAP(TransShape) = {{1, INPUT_DESC(x)}};
|
||||||
|
INPUT_ATTR_MAP(TransShape) = {{2, ATTR_DESC(outShape, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}};
|
||||||
|
ATTR_MAP(TransShape) = EMPTY_ATTR_MAP;
|
||||||
|
OUTPUT_MAP(TransShape) = {{0, OUTPUT_DESC(y)}};
|
||||||
|
|
||||||
// BiasAdd
|
// BiasAdd
|
||||||
INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}};
|
INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}};
|
||||||
ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
|
ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
|
||||||
|
|
|
@ -112,6 +112,9 @@ DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD)
|
||||||
DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD)
|
DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD)
|
||||||
DECLARE_OP_ADAPTER(Reshape)
|
DECLARE_OP_ADAPTER(Reshape)
|
||||||
DECLARE_OP_USE_OUTPUT(Reshape)
|
DECLARE_OP_USE_OUTPUT(Reshape)
|
||||||
|
DECLARE_OP_ADAPTER(TransShape)
|
||||||
|
DECLARE_OP_USE_INPUT_ATTR(TransShape)
|
||||||
|
DECLARE_OP_USE_OUTPUT(TransShape)
|
||||||
DECLARE_OP_ADAPTER(Iou)
|
DECLARE_OP_ADAPTER(Iou)
|
||||||
DECLARE_OP_USE_OUTPUT(Iou)
|
DECLARE_OP_USE_OUTPUT(Iou)
|
||||||
DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D)
|
DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D)
|
||||||
|
|
|
@ -696,3 +696,13 @@ def get_bprop_reverse_sequence(self):
|
||||||
dx = reverse_sequence_grad(dout, seq_lengths)
|
dx = reverse_sequence_grad(dout, seq_lengths)
|
||||||
return dx, zeros_like(seq_lengths)
|
return dx, zeros_like(seq_lengths)
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.TransShape)
|
||||||
|
def get_bprop_trans_shape(self):
|
||||||
|
"""Generate bprop for TransShape"""
|
||||||
|
op = P.TransShape()
|
||||||
|
def bprop(x, shape, out, dout):
|
||||||
|
dx = op(dout, shape_op(x))
|
||||||
|
return (dx, zeros_like(shape))
|
||||||
|
return bprop
|
||||||
|
|
|
@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
||||||
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
|
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
|
||||||
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||||
Shape, Size, Slice, Split,
|
Shape, Size, Slice, Split, TransShape,
|
||||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
||||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
|
||||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||||
|
|
|
@ -3106,3 +3106,28 @@ class ReverseSequence(PrimitiveWithInfer):
|
||||||
validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
|
validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
|
||||||
validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
|
validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransShape(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Transform the shape of input tensor to target shape.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - A input tensor.
|
||||||
|
- **out_shape** (tuple[int]) - The shape of output data.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, a tensor whose data type is same as 'input_x', and the shape is same as the `out_shape`.
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
self.__setattr_flag__ = True
|
||||||
|
|
||||||
|
def __infer__(self, x, shape):
|
||||||
|
shp = shape['value']
|
||||||
|
dtype = x['dtype']
|
||||||
|
validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name)
|
||||||
|
self.add_prim_attr('out_shape', tuple(shp))
|
||||||
|
return {'shape': shp,
|
||||||
|
'dtype': dtype,
|
||||||
|
'value': None}
|
||||||
|
|
|
@ -1865,6 +1865,12 @@ test_case_array_ops = [
|
||||||
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
|
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
|
||||||
'skip': ['backward'],
|
'skip': ['backward'],
|
||||||
}),
|
}),
|
||||||
|
('TransShape', {
|
||||||
|
'block': P.TransShape(),
|
||||||
|
'desc_const': [(1, 12, 24, 24)],
|
||||||
|
'desc_inputs': [[1, 3, 24, 24]],
|
||||||
|
'desc_bprop': [[1, 12, 24, 24]],
|
||||||
|
}),
|
||||||
]
|
]
|
||||||
|
|
||||||
test_case_other_ops = [
|
test_case_other_ops = [
|
||||||
|
|
Loading…
Reference in New Issue