!2843 Add TransShape operator
Merge pull request !2843 from fanglei/trans_shape
commit 17319d8dfd
@@ -134,6 +134,7 @@ const char kNameAssignSub[] = "AssignSub";
 const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus";
 const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus";
 const char kNameReshape[] = "Reshape";
+const char kNameTransShape[] = "TransShape";
 const char kNameRealDiv[] = "RealDiv";
 const char kNameTile[] = "Tile";
 const char kNameCos[] = "Cos";
@@ -242,6 +243,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
 {string(kNameBatchNorm), ADPT_DESC(BatchNorm)},
 {string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)},
 {string(kNameReshape), ADPT_DESC(Reshape)},
+{string(kNameTransShape), ADPT_DESC(TransShape)},
 {string(kNameFlattenGrad), ADPT_DESC(Reshape)},
 {prim::kPrimFlatten->name(), ADPT_DESC(Flatten)},
 {string(kNameAddN), ADPT_DESC(AddN)},
@@ -442,6 +442,12 @@ INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
 ATTR_MAP(Reshape) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}};
 
+// TransShape
+INPUT_MAP(TransShape) = {{1, INPUT_DESC(x)}};
+INPUT_ATTR_MAP(TransShape) = {{2, ATTR_DESC(outShape, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}};
+ATTR_MAP(TransShape) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(TransShape) = {{0, OUTPUT_DESC(y)}};
+
 // BiasAdd
 INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}};
 ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
@@ -112,6 +112,9 @@ DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD)
 DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD)
 DECLARE_OP_ADAPTER(Reshape)
 DECLARE_OP_USE_OUTPUT(Reshape)
+DECLARE_OP_ADAPTER(TransShape)
+DECLARE_OP_USE_INPUT_ATTR(TransShape)
+DECLARE_OP_USE_OUTPUT(TransShape)
 DECLARE_OP_ADAPTER(Iou)
 DECLARE_OP_USE_OUTPUT(Iou)
 DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D)
@@ -696,3 +696,13 @@ def get_bprop_reverse_sequence(self):
         dx = reverse_sequence_grad(dout, seq_lengths)
         return dx, zeros_like(seq_lengths)
     return bprop
+
+
+@bprop_getters.register(P.TransShape)
+def get_bprop_trans_shape(self):
+    """Generate bprop for TransShape"""
+    op = P.TransShape()
+    def bprop(x, shape, out, dout):
+        dx = op(dout, shape_op(x))
+        return (dx, zeros_like(shape))
+    return bprop
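The gradient registered above simply routes dout back through TransShape using the original input's shape, while the non-differentiable shape argument receives zeros. Below is a minimal standalone sketch of that closure structure; trans_shape_fn, shape_fn, and zeros_like_fn are hypothetical stand-ins for P.TransShape(), shape_op, and zeros_like, not the real MindSpore objects.

def make_trans_shape_bprop(trans_shape_fn, shape_fn, zeros_like_fn):
    # Mirrors get_bprop_trans_shape: map the output gradient back to the
    # input's shape; the shape argument itself gets a zero gradient.
    def bprop(x, shape, out, dout):
        dx = trans_shape_fn(dout, shape_fn(x))
        return dx, zeros_like_fn(shape)
    return bprop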
@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
                         SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
                         ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
-                        Shape, Size, Slice, Split,
+                        Shape, Size, Slice, Split, TransShape,
                         Squeeze, StridedSlice, Tile, TensorScatterUpdate,
                         Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
                         UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
@@ -3106,3 +3106,28 @@ class ReverseSequence(PrimitiveWithInfer):
         validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
         validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
         return x
+
+
+class TransShape(PrimitiveWithInfer):
+    """
+    Transforms the shape of the input tensor to the target shape.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+        - **out_shape** (tuple[int]) - The shape of the output data.
+
+    Outputs:
+        Tensor, a tensor whose data type is the same as `input_x` and whose shape is the given `out_shape`.
+    """
+    @prim_attr_register
+    def __init__(self):
+        self.__setattr_flag__ = True
+
+    def __infer__(self, x, shape):
+        shp = shape['value']
+        dtype = x['dtype']
+        validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name)
+        self.add_prim_attr('out_shape', tuple(shp))
+        return {'shape': shp,
+                'dtype': dtype,
+                'value': None}
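As __infer__ above shows, TransShape adopts the requested out_shape verbatim and only propagates the input dtype; it performs no element-count check in __infer__ (the test entry below maps a (1, 3, 24, 24) input to (1, 12, 24, 24)). A small self-contained sketch of that inference step, assuming the same {'shape', 'dtype', 'value'} abstract-value convention used by PrimitiveWithInfer; infer_trans_shape is illustrative, not the real class.

def infer_trans_shape(x_spec, out_shape):
    # Standalone illustration of TransShape.__infer__: take the target shape
    # as given, keep the input dtype, and carry no concrete value.
    return {'shape': tuple(out_shape),
            'dtype': x_spec['dtype'],
            'value': None}

print(infer_trans_shape({'shape': (1, 3, 24, 24), 'dtype': 'float32'},
                        (1, 12, 24, 24)))
# {'shape': (1, 12, 24, 24), 'dtype': 'float32', 'value': None}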
@@ -1865,6 +1865,12 @@ test_case_array_ops = [
         Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
         'skip': ['backward'],
     }),
+    ('TransShape', {
+        'block': P.TransShape(),
+        'desc_const': [(1, 12, 24, 24)],
+        'desc_inputs': [[1, 3, 24, 24]],
+        'desc_bprop': [[1, 12, 24, 24]],
+    }),
 ]
 
 test_case_other_ops = [
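Roughly speaking, each entry in test_case_array_ops is data-driven: 'block' is the primitive under test, 'desc_const' lists constant (non-tensor) arguments, and 'desc_inputs'/'desc_bprop' give tensor shapes for the forward and backward passes. The snippet below is a hypothetical illustration of how such an entry is read, not the actual MindSpore test harness; 'block' is omitted so it runs without MindSpore installed.

import numpy as np

entry = ('TransShape', {
    'desc_const': [(1, 12, 24, 24)],   # constant arg: the target shape
    'desc_inputs': [[1, 3, 24, 24]],   # tensor inputs, given by shape only
    'desc_bprop': [[1, 12, 24, 24]],   # shape of the gradient fed to bprop
})

name, cfg = entry
inputs = [np.ones(shape, np.float32) for shape in cfg['desc_inputs']]
print(name, [t.shape for t in inputs], 'const args:', cfg['desc_const'])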