forked from OSchip/llvm-project
[mlir][tosa] Add tosa.reverse lowering to linalg.generic
Reverse lowers to a linalg.generic op by reversing the read order in the index map.

Differential Revision: https://reviews.llvm.org/D98997
This commit is contained in:
parent
6c9cac5da1
commit
e990fa2170
|
@ -585,7 +585,7 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
|
||||
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
|
||||
public:
|
||||
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
|
||||
|
||||
|
@ -727,7 +727,7 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class RescaleOpConverter : public OpRewritePattern<tosa::RescaleOp> {
|
||||
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
|
||||
public:
|
||||
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
|
||||
|
||||
|
@ -889,7 +889,7 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
struct ConcatOpConversion : public OpConversionPattern<tosa::ConcatOp> {
|
||||
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
|
||||
using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
|
@ -936,6 +936,56 @@ struct ConcatOpConversion : public OpConversionPattern<tosa::ConcatOp> {
|
|||
}
|
||||
};
|
||||
|
||||
// Lowers tosa.reverse to a linalg.generic op. The reversal is encoded purely
// in the input indexing map: dimension `axis` is read as
// (dim_size - 1 - d_axis) while all other dimensions use the identity, and
// the region body is a plain element copy.
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    // Not a class template, so no `.template` disambiguator is needed here.
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();
    auto rank = resultTy.getRank();
    auto axis = op.axis();

    // The affine map below bakes in the constant (dim_size - 1), so the
    // reversed dimension's extent must be statically known.
    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "cannot reverse a tensor with dynamic shape");

    // Create an uninitialized output tensor; every element is overwritten by
    // the generic op, so no fill is required.
    auto initTensor = rewriter
                          .create<linalg::InitTensorOp>(
                              loc, ArrayRef<Value>({}), inputTy.getShape(),
                              inputTy.getElementType())
                          .result();

    // Identity read on every dimension except `axis`.
    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(rank);
    for (int i = 0; i < rank; i++)
      inputExprs[i] = rewriter.getAffineDimExpr(i);

    // Reverse the traversal order of `axis`: d_axis -> (dim_size - 1 - d_axis).
    inputExprs[axis] =
        rewriter.getAffineConstantExpr(inputTy.getDimSize(axis) - 1) -
        inputExprs[axis];

    // Input map carries the reversal; output map is the identity.
    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(rank, /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(rank)};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          // Pure copy: yield the (reversed-index) input element unchanged.
          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
        });

    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
|
||||
|
@ -963,6 +1013,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
|
|||
IdentityNConverter<tosa::IdentityOp>,
|
||||
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
|
||||
ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
|
||||
ReduceConverter<tosa::ReduceProdOp>, ConcatOpConversion,
|
||||
ReshapeOpConverter, TransposeConverter, RescaleOpConverter>(context);
|
||||
ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
|
||||
RescaleConverter, ReverseConverter, TransposeConverter>(context);
|
||||
}
|
||||
|
|
|
@ -598,3 +598,26 @@ func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) {
|
|||
%0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>)
|
||||
return %0 : tensor<1xi8>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Verifies the tosa.reverse -> linalg.generic lowering: the input indexing
// map negates the reversed dimension (offset by dim_size - 1) while the
// output map stays the identity, and the region body is a plain copy.

// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (-d0 + 4, d1)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 3)>

// CHECK-LABEL: @reverse
func @reverse(%arg0: tensor<5x4xi32>) -> () {
  // Reverse along axis 0: rows are read in order 4,3,2,1,0 (map #MAP0).
  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4]
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) {
  // CHECK: ^bb0(%arg1: i32, %arg2: i32):
  // CHECK: linalg.yield %arg1 : i32
  %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>

  // Reverse along axis 1: columns are read in order 3,2,1,0 (map #MAP2).
  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4]
  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) {
  // CHECK: ^bb0(%arg1: i32, %arg2: i32):
  // CHECK: linalg.yield %arg1 : i32
  %1 = "tosa.reverse"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>

  return
}
|
||||
|
|
Loading…
Reference in New Issue