[mlir][tosa] Add tosa.slice to std.subtensor lowering

Lowering to subtensor is added for tosa.slice operator.

Differential Revision: https://reviews.llvm.org/D98825
This commit is contained in:
Rob Suderman 2021-03-17 15:53:18 -07:00
parent d672d5219a
commit f4bb076a44
3 changed files with 30 additions and 2 deletions

View File

@ -32,9 +32,28 @@ public:
}
};
class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
public:
using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
PatternRewriter &rewriter) const final {
Value input = sliceOp.input();
SmallVector<int64_t> strides;
strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
rewriter.replaceOpWithNewOp<SubTensorOp>(
sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
ValueRange({}), sliceOp.start(), sliceOp.size(),
rewriter.getI64ArrayAttr(strides));
return success();
}
};
} // namespace
void mlir::tosa::populateTosaToStandardConversionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) {
patterns->insert<ConstOpConverter>(context);
patterns->insert<ConstOpConverter, SliceOpConverter>(context);
}

View File

@ -32,7 +32,8 @@ public:
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addIllegalOp<tosa::ConstOp>();
target.addLegalOp<ConstantOp>();
target.addIllegalOp<tosa::SliceOp>();
target.addLegalDialect<StandardOpsDialect>();
auto *op = getOperation();
mlir::tosa::populateTosaToStandardConversionPatterns(op->getContext(),

View File

@ -8,3 +8,11 @@ func @const_test() -> (tensor<i32>) {
// CHECK: return [[C3]]
return %0 : tensor<i32>
}
// ----
func @slice(%arg0: tensor<6xf32>) ->() {
// CHECK: [[SLICE:%.+]] = subtensor %arg0[2] [1] [1]
%0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>)
return
}