forked from OSchip/llvm-project
[MLIR][TOSA] Lower tosa.transpose to linalg.generic
Lowers the transpose operation to a generic linalg op when permutations is a constant value. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D97508
This commit is contained in:
parent
2fcc3f4b18
commit
087bc20fe4
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
@ -438,6 +439,48 @@ public:
|
|||
|
||||
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
||||
reshape, resultTy, args[0], reassociationMap);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
|
||||
public:
|
||||
using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::TransposeOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
DenseIntElementsAttr perms;
|
||||
if (!matchPattern(op.perms(), m_Constant(&perms))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto resultTy = op.getType().cast<ShapedType>();
|
||||
if (!resultTy.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
SmallVector<AffineExpr, 2> inputExprs;
|
||||
inputExprs.resize(resultTy.getRank());
|
||||
for (auto permutation : llvm::enumerate(perms.getIntValues())) {
|
||||
inputExprs[permutation.value().getZExtValue()] =
|
||||
rewriter.getAffineDimExpr(permutation.index());
|
||||
}
|
||||
|
||||
auto initTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
|
||||
resultTy.getElementType());
|
||||
|
||||
SmallVector<AffineMap, 2> affineMaps = {
|
||||
AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
|
||||
rewriter.getContext()),
|
||||
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
|
||||
|
||||
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
|
||||
op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
|
||||
getNParallelLoopsAttrs(resultTy.getRank()),
|
||||
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||
nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -478,5 +521,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
|
|||
PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
|
||||
PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
|
||||
IdentityNConverter<tosa::IdentityOp>,
|
||||
IdentityNConverter<tosa::IdentityNOp>, ReshapeOpConverter>(context);
|
||||
IdentityNConverter<tosa::IdentityNOp>,
|
||||
ReshapeOpConverter, TransposeConverter>(context);
|
||||
}
|
||||
|
|
|
@ -317,3 +317,21 @@ func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32
|
|||
// CHECK: return %arg0, %arg1
|
||||
return %2#0, %2#1 : tensor<1xf32>, tensor<1xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
|
||||
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
|
||||
// CHECK-LABEL: @test_transpose
|
||||
// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x2x3xi32>)
|
||||
func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
|
||||
%0 = constant dense<[1, 2, 0]> : tensor<3xi32>
|
||||
// CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3, 1]
|
||||
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[ARG0]] : tensor<1x2x3xi32>) outs([[OUT:%.+]] : tensor<2x3x1xi32>)
|
||||
// CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
|
||||
// CHECK: linalg.yield [[ARG1]]
|
||||
// CHECK: }
|
||||
%1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue