forked from OSchip/llvm-project
[mlir][tosa] Add tosa.slice to std.subtensor lowering
Lowering to subtensor is added for tosa.slice operator. Differential Revision: https://reviews.llvm.org/D98825
This commit is contained in:
parent
d672d5219a
commit
f4bb076a44
|
@ -32,9 +32,28 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class SliceOpConverter : public OpRewritePattern<tosa::SliceOp> {
|
||||
public:
|
||||
using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Value input = sliceOp.input();
|
||||
SmallVector<int64_t> strides;
|
||||
strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);
|
||||
|
||||
rewriter.replaceOpWithNewOp<SubTensorOp>(
|
||||
sliceOp, sliceOp.getType(), input, ValueRange({}), ValueRange({}),
|
||||
ValueRange({}), sliceOp.start(), sliceOp.size(),
|
||||
rewriter.getI64ArrayAttr(strides));
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::tosa::populateTosaToStandardConversionPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList *patterns) {
|
||||
patterns->insert<ConstOpConverter>(context);
|
||||
patterns->insert<ConstOpConverter, SliceOpConverter>(context);
|
||||
}
|
||||
|
|
|
@ -32,7 +32,8 @@ public:
|
|||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(getContext());
|
||||
target.addIllegalOp<tosa::ConstOp>();
|
||||
target.addLegalOp<ConstantOp>();
|
||||
target.addIllegalOp<tosa::SliceOp>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
|
||||
auto *op = getOperation();
|
||||
mlir::tosa::populateTosaToStandardConversionPatterns(op->getContext(),
|
||||
|
|
|
@ -8,3 +8,11 @@ func @const_test() -> (tensor<i32>) {
|
|||
// CHECK: return [[C3]]
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// ----
|
||||
|
||||
func @slice(%arg0: tensor<6xf32>) ->() {
|
||||
// CHECK: [[SLICE:%.+]] = subtensor %arg0[2] [1] [1]
|
||||
%0 = "tosa.slice"(%arg0) {start = [2], size = [1]} : (tensor<6xf32>) -> (tensor<1xf32>)
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue