forked from OSchip/llvm-project
[mlir][Linalg] Fix tensor.extract_slice(linalg.init_tensor) canonicalization for rank-reducing extract
Differential Revision: https://reviews.llvm.org/D105636
This commit is contained in:
parent
8c7ff9da90
commit
4747e1b83b
|
@ -772,11 +772,11 @@ struct FoldInitTensorWithExtractSliceOp
|
|||
PatternRewriter &rewriter) const override {
|
||||
if (!sliceOp.source().getDefiningOp<linalg::InitTensorOp>())
|
||||
return failure();
|
||||
// ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved
|
||||
// as well as its result type.
|
||||
rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
|
||||
sliceOp, sliceOp.sizes(),
|
||||
llvm::to_vector<4>(llvm::map_range(
|
||||
sliceOp.static_sizes(),
|
||||
[](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); })),
|
||||
sliceOp.result().getType().cast<RankedTensorType>().getShape(),
|
||||
sliceOp.getSourceType().getElementType());
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -890,3 +890,15 @@ func @init_canonicalize(%i : index) {
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @rank_reducing_init_extract
|
||||
func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
|
||||
// CHECK: linalg.init_tensor [2] : tensor<2xf32>
|
||||
%a = linalg.init_tensor [%sz, 2] : tensor<?x2xf32>
|
||||
|
||||
// CHECK-NOT: extract
|
||||
%r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
|
||||
return %r: tensor<2xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue