[MLIR][TOSA] Lower tosa.transpose to linalg.generic

Lowers the transpose operation to a generic linalg op when permutations
is a constant value.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D97508
This commit is contained in:
Rob Suderman 2021-03-01 11:00:34 -08:00
parent 2fcc3f4b18
commit 087bc20fe4
2 changed files with 63 additions and 1 deletions

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@ -438,6 +439,48 @@ public:
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshape, resultTy, args[0], reassociationMap);
return success();
}
};
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const final {
DenseIntElementsAttr perms;
if (!matchPattern(op.perms(), m_Constant(&perms))) {
return failure();
}
auto resultTy = op.getType().cast<ShapedType>();
if (!resultTy.hasStaticShape())
return failure();
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.resize(resultTy.getRank());
for (auto permutation : llvm::enumerate(perms.getIntValues())) {
inputExprs[permutation.value().getZExtValue()] =
rewriter.getAffineDimExpr(permutation.index());
}
auto initTensor = rewriter.create<linalg::InitTensorOp>(
op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
resultTy.getElementType());
SmallVector<AffineMap, 2> affineMaps = {
AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
rewriter.getContext()),
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
getNParallelLoopsAttrs(resultTy.getRank()),
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
});
return success();
}
};
@ -478,5 +521,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
IdentityNConverter<tosa::IdentityOp>,
IdentityNConverter<tosa::IdentityNOp>, ReshapeOpConverter>(context);
IdentityNConverter<tosa::IdentityNOp>,
ReshapeOpConverter, TransposeConverter>(context);
}

View File

@ -317,3 +317,21 @@ func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32
// CHECK: return %arg0, %arg1
return %2#0, %2#1 : tensor<1xf32>, tensor<1xi32>
}
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: @test_transpose
// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x2x3xi32>)
func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
%0 = constant dense<[1, 2, 0]> : tensor<3xi32>
// CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3, 1]
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[ARG0]] : tensor<1x2x3xi32>) outs([[OUT:%.+]] : tensor<2x3x1xi32>)
// CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
// CHECK: linalg.yield [[ARG1]]
// CHECK: }
%1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
return
}