From f4bb076a4419767cf35a17e3c08f392505a5acd2 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 17 Mar 2021 15:53:18 -0700 Subject: [PATCH] [mlir][tosa] Add tosa.slice to std.subtensor lowering Lowering to subtensor is added for tosa.slice operator. Differential Revision: https://reviews.llvm.org/D98825 --- .../TosaToStandard/TosaToStandard.cpp | 21 ++++++++++++++++++- .../TosaToStandard/TosaToStandardPass.cpp | 3 ++- .../TosaToStandard/tosa-to-standard.mlir | 8 +++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp index 21a8da291aee..6e5411dd5ecb 100644 --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp @@ -32,9 +32,28 @@ public: } }; +class SliceOpConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, + PatternRewriter &rewriter) const final { + Value input = sliceOp.input(); + SmallVector strides; + strides.resize(sliceOp.getType().template cast().getRank(), 1); + + rewriter.replaceOpWithNewOp( + sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}), + ValueRange({}), sliceOp.start(), sliceOp.size(), + rewriter.getI64ArrayAttr(strides)); + + return success(); + } +}; + } // namespace void mlir::tosa::populateTosaToStandardConversionPatterns( MLIRContext *context, OwningRewritePatternList *patterns) { - patterns->insert(context); + patterns->insert(context); } diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp index 225855e78bda..78a0e65da81b 100644 --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp @@ -32,7 +32,8 @@ public: OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addIllegalOp(); - target.addLegalOp(); + target.addIllegalOp(); + target.addLegalDialect(); auto *op = getOperation(); mlir::tosa::populateTosaToStandardConversionPatterns(op->getContext(), diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir index 86304dcba862..94925aec15c7 100644 --- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir +++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir @@ -8,3 +8,11 @@ func @const_test() -> (tensor) { // CHECK: return [[C3]] return %0 : tensor } + +// ---- + +func @slice(%arg0: tensor<6xf32>) ->() { + // CHECK: [[SLICE:%.+]] = subtensor %arg0[2] [1] [1] + %0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>) + return +}