[mlir][tosa] Add tosa.tile to linalg.generic lowering
Tiling operations are generic operations with modified indexing. Updated the tosa-to-linalg lowerings to perform this lowering.

Differential Revision: https://reviews.llvm.org/D99113
commit 2d72b675d5
parent 1bc33eb6a3
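For orientation, a sketch of what the new lowering produces: tiling a tensor<2x3xi8> by multiples = [2, 1] yields IR of roughly the following shape (condensed from the CHECK lines in the updated test below; SSA names and the elided attributes are illustrative, not literal output):

    %reshape = "tosa.reshape"(%arg0) {...} : (tensor<2x3xi8>) -> tensor<1x2x1x3xi8>
    %init = linalg.init_tensor [2, 2, 1, 3] : tensor<2x2x1x3xi8>
    %generic = linalg.generic {...}
        ins(%reshape : tensor<1x2x1x3xi8>) outs(%init : tensor<2x2x1x3xi8>) {
      ^bb0(%in: i8, %out: i8):
        // The body just forwards the input element; the broadcast is done
        // by the indexing maps.
        linalg.yield %in : i8
    } -> tensor<2x2x1x3xi8>
    %result = "tosa.reshape"(%generic) {...} : (tensor<2x2x1x3xi8>) -> tensor<4x3xi8>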
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

@@ -702,6 +702,11 @@ public:
     ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
 
+    if (operandTy == resultTy) {
+      rewriter.replaceOp(reshape, args[0]);
+      return success();
+    }
+
     if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
       return failure();
 
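The ReshapeConverter hunk above adds a fast path: when the operand and result types already match, the reshape is the identity, so the op is replaced with its input and no linalg IR is emitted. A small illustrative example (attribute spelling per the TOSA dialect; sketch only):

    // An identity reshape such as this now folds to a direct use of %arg0
    // instead of going through the generic reshape lowering.
    %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<2x3xf32>) -> tensor<2x3xf32>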
@@ -1086,6 +1091,70 @@ public:
         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
           nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
         });
     return success();
   }
 };
+
+// This converter translates a tile operation to a reshape, broadcast, reshape.
+// The first reshape minimally expands each tiled dimension to include a
+// preceding size-1 dim. This dim is then broadcast to the appropriate
+// multiple.
+struct TileConverter : public OpConversionPattern<tosa::TileOp> {
+  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::TileOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto input = op.input1();
+    auto inputTy = input.getType().cast<ShapedType>();
+    auto inputShape = inputTy.getShape();
+    auto resultTy = op.getType().cast<ShapedType>();
+    auto elementTy = inputTy.getElementType();
+    int64_t rank = inputTy.getRank();
+
+    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+
+    SmallVector<int64_t> multiples;
+    getValuesFromIntArrayAttribute(op.multiples(), multiples);
+
+    llvm::SmallVector<int64_t, 4> reshapeShape;
+    reshapeShape.reserve(rank * 2);
+    for (int i = 0; i < rank; i++) {
+      reshapeShape.push_back(1);
+      reshapeShape.push_back(inputShape[i]);
+    }
+
+    ShapedType reshapeTy = RankedTensorType::get(reshapeShape, elementTy);
+    Value reshape = rewriter.create<tosa::ReshapeOp>(
+        loc, reshapeTy, input, rewriter.getI64ArrayAttr(reshapeTy.getShape()));
+
+    // Broadcast the newly added dimensions to their appropriate multiple.
+    SmallVector<int64_t, 2> genericShape;
+    for (int i = 0; i < rank; i++) {
+      genericShape.push_back(multiples[i]);
+      genericShape.push_back(inputShape[i]);
+    }
+
+    auto initTensor = rewriter.create<linalg::InitTensorOp>(
+        op.getLoc(), ArrayRef<Value>({}), genericShape, elementTy);
+
+    SmallVector<AffineMap, 2> affineMaps = {
+        createAffineMapForType(reshapeTy, rewriter),
+        rewriter.getMultiDimIdentityMap(genericShape.size())};
+
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, RankedTensorType::get(genericShape, elementTy), reshape,
+        ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(genericShape.size()),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
+        });
+
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, resultTy, genericOp.getResult(0),
+        rewriter.getI64ArrayAttr(resultTy.getShape()));
+
+    return success();
+  }
+};
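Tracing TileConverter through the first test case below (input tensor<2x3xi8>, multiples = [2, 1]): rank = 2, so reshapeShape = [1, 2, 1, 3] and genericShape = [2, 2, 1, 3]; the linalg.generic broadcasts each inserted size-1 dim to its multiple, and the final tosa.reshape collapses [2, 2, 1, 3] into the result type tensor<4x3xi8> (2*2 = 4 rows, 1*3 = 3 columns).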
@@ -1119,6 +1188,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
       ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
-      RescaleConverter, ReverseConverter, TransposeConverter, MatMulConverter,
-      FullyConnectedConverter>(patterns->getContext());
+      RescaleConverter, ReverseConverter, TileConverter, TransposeConverter,
+      MatMulConverter, FullyConnectedConverter>(patterns->getContext());
 }
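The new pattern is exercised by the tosa-to-linalg test below. Assuming the usual RUN line for this test file (the pass name here is inferred from populateTosaToLinalgOnTensorsConversionPatterns, so verify against the actual file), the checks are driven by something like:

    // RUN: mlir-opt --split-input-file --tosa-to-linalg-on-tensors %s | FileCheck %s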
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

@@ -636,6 +636,40 @@ func @reverse(%arg0: tensor<5x4xi32>) -> () {
   // CHECK: ^bb0(%arg1: i32, %arg2: i32):
   // CHECK: linalg.yield %arg1 : i32
   %1 = "tosa.reverse"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
   return
 }
 
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>
+// CHECK: #[[$MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP4:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+// CHECK: #[[$MAP5:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+
+// CHECK-LABEL: @tile
+func @tile(%arg0 : tensor<2x3xi8>) -> () {
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 2, 1, 3]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<2x2x1x3xi8>)
+  // CHECK: linalg.yield %arg1 : i8
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP0]], #[[$MAP1]]]
+  %0 = "tosa.tile"(%arg0) {multiples = [2, 1]} : (tensor<2x3xi8>) -> (tensor<4x3xi8>)
+
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 2, 2, 3]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>)
+  // CHECK: linalg.yield %arg1 : i8
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP4]], #[[$MAP5]]]
+  %1 = "tosa.tile"(%arg0) {multiples = [1, 2]} : (tensor<2x3xi8>) -> (tensor<2x6xi8>)
+
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]]] : tensor<2x3xi8> into tensor<1x2x1x3xi8>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2, 7, 3]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[RESHAPE]] : tensor<1x2x1x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>)
+  // CHECK: linalg.yield %arg1 : i8
+  // CHECK: linalg.tensor_reshape [[GENERIC]] [#[[$MAP4]], #[[$MAP5]]]
+  %2 = "tosa.tile"(%arg0) {multiples = [5, 7]} : (tensor<2x3xi8>) -> (tensor<10x21xi8>)
+
+  return
+}
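Note how the broadcast is carried entirely by the indexing maps: the input map #[[$MAP2]], (d0, d1, d2, d3) -> (0, d1, 0, d3), pins the two inserted size-1 dimensions of the 1x2x1x3 operand to constant index 0, while the output map #[[$MAP3]] is the full identity; iterating all four dimensions in parallel therefore replicates each input element across d0 and d2, which is exactly the tile.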