[mlir] Add support for tensor.extract to comprehensive bufferization

Differential Revision: https://reviews.llvm.org/D105870
This commit is contained in:
thomasraoux 2021-07-13 09:34:48 -07:00
parent 46e8970817
commit ae4cea38f1
2 changed files with 30 additions and 0 deletions

View File

@ -2115,6 +2115,22 @@ static LogicalResult bufferize(OpBuilder &b, linalg::YieldOp yieldOp,
return success();
llvm_unreachable("unexpected yieldOp");
}
/// Bufferization for tensor::ExtractOp just translate to memref.load, it only
/// reads the tensor.
static LogicalResult bufferize(OpBuilder &b, tensor::ExtractOp extractOp,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(extractOp);
Location loc = extractOp.getLoc();
Value srcMemref = lookup(bvm, extractOp.tensor());
Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
extractOp.replaceAllUsesWith(l);
return success();
}
//===----------------------------------------------------------------------===//
// Bufferization analyses.
//===----------------------------------------------------------------------===//
@ -2310,6 +2326,7 @@ static LogicalResult bufferizeFuncOpInternals(
scf::ForOp,
InitTensorOp,
InsertSliceOp,
tensor::ExtractOp,
LinalgOp,
ReturnOp,
TiledLoopOp,

View File

@ -34,6 +34,19 @@ func @fill_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}) -> tensor<?xf
// -----
// CHECK-LABEL: func @tensor_extract(%{{.*}}: memref<?xf32, #{{.*}}>) -> f32 {
func @tensor_extract(%A : tensor<?xf32>) -> (f32) {
%c0 = constant 0 : index
// CHECK: %[[RES:.*]] = memref.load {{.*}} : memref<?xf32, #{{.*}}>
%0 = tensor.extract %A[%c0] : tensor<?xf32>
// CHECK: return %[[RES]] : f32
return %0 : f32
}
// -----
// CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
/// No linalg.inplaceable flag, must allocate.