[mlir][Linalg] Add comprehensive bufferization support for linalg::InitTensor and tensor::CastOp (11/n)

Also add an integration test that connects all the dots end to end, including a cast to an unranked tensor for external library calls.

Differential Revision: https://reviews.llvm.org/D105106
Nicolas Vasilache 2021-06-29 12:53:09 +00:00
parent 01b846674d
commit 231b9dd9de
3 changed files with 185 additions and 13 deletions
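
The end-to-end story, distilled from the new integration test added in this commit: a tensor-level program can now be bufferized all the way down to a buffer that is cast to an unranked memref and handed to an external library call. A minimal before/after sketch (simplified MLIR under the -linalg-comprehensive-module-bufferize pass from the RUN lines below; %v0 is an f32 constant and the @print_memref_f32 declaration is elided):

  // Tensor land, before bufferization:
  %C = linalg.init_tensor [] : tensor<f32>
  %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
  %u = tensor.cast %CC : tensor<f32> to tensor<*xf32>
  call @print_memref_f32(%u) : (tensor<*xf32>) -> ()

  // Buffer land, after bufferization: one alloc, in-place fill, no copy for the cast.
  %C = memref.alloc() : memref<f32>
  linalg.fill(%v0, %C) : f32, memref<f32>
  %u = memref.cast %C : memref<f32> to memref<*xf32>
  call @print_memref_f32(%u) : (memref<*xf32>) -> ()
  memref.dealloc %C : memref<f32>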


@@ -357,7 +357,9 @@ static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
return
// clang-format off
isa<CallOpInterface,
tensor::CastOp,
scf::ForOp,
InitTensorOp,
LinalgOp,
ReturnOp,
ExtractSliceOp,
@@ -418,6 +420,14 @@ static OpResult getInplaceableOpResult(InsertSliceOp op, OpOperand &opOperand) {
return op->getResult(0);
}
/// Return the OpResult that may bufferize into the same buffer as `opOperand`
/// when the op is bufferized inplace.
/// Return null if no such result exists.
static OpResult getInplaceableOpResult(tensor::CastOp op,
OpOperand &opOperand) {
return op->getResult(0);
}
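
That is, the single result of a tensor.cast is the candidate for in-place bufferization with its single operand: when the analysis decides in place, the cast forces no allocation or copy. A hedged one-line illustration (%t is assumed to be a tensor<64xf32> value defined earlier):

  // Result %0 may bufferize into the very buffer that backs %t (operand 0).
  %0 = tensor.cast %t : tensor<64xf32> to tensor<?xf32>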
/// Return the OpResult that may bufferize into the same buffer as `opOperand`
/// when the op is bufferized inplace.
/// The inplace analysis uses this information along with interfering read
@@ -428,7 +438,8 @@ static OpResult getInplaceableOpResult(OpOperand &opOperand) {
// clang-format off
// Ops that perform destructive updates on operand(s) to produce
// result(s).
.Case<scf::ForOp,
.Case<tensor::CastOp,
scf::ForOp,
LinalgOp,
InsertSliceOp,
VectorTransferOpInterface>(
@@ -455,6 +466,7 @@ static Optional<OpOperand *> getAliasingOpOperand(OpResult result) {
if (!hasKnownBufferizationAliasingBehavior(result.getDefiningOp()))
return None;
return TypeSwitch<Operation *, OpOperand *>(result.getDefiningOp())
.Case([&](tensor::CastOp op) { return &op->getOpOperand(0); })
.Case([&](LinalgOp op) {
return op.getOutputTensorOperands()[result.getResultNumber()];
})
@@ -1559,6 +1571,35 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
return success();
}
/// tensor::CastOp bufferizes to memref::CastOp.
static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(castOp);
Type sourceType = lookup(bvm, castOp.source()).getType();
auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
assert(rankedMemRefType || unrankedMemRefType);
unsigned memorySpace = rankedMemRefType
? rankedMemRefType.getMemorySpaceAsInt()
: unrankedMemRefType.getMemorySpaceAsInt();
TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
ArrayRef<AffineMap> affineMaps =
rankedMemRefType && tensorType.isa<RankedTensorType>()
? rankedMemRefType.getAffineMaps()
: ArrayRef<AffineMap>{};
Type memRefType = getContiguousOrUnrankedMemRefType(
castOp.getResult().getType(), {}, memorySpace);
Value res = b.create<memref::CastOp>(castOp.getLoc(), memRefType,
lookup(bvm, castOp.source()));
aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
map(bvm, castOp.getResult(), res);
return success();
}
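
Concretely, for the ranked-to-unranked case exercised by the new tests, the rewrite keeps the source buffer (and its memory space) and only changes its type. A sketch, assuming the source tensor %res is already mapped to the memref %C:

  // Tensor form:
  %res2 = tensor.cast %res : tensor<f32> to tensor<*xf32>

  // Bufferized form: no allocation, the source buffer is reused.
  %dC = memref.cast %C : memref<f32> to memref<*xf32>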
/// DimOp tensor operand is modified inplace. This allows leaving dead
/// tensors behind that will get DCE'd.
static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
@@ -1635,6 +1676,21 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
return success();
}
/// InitTensor always allocates.
/// TODO: consider hoisting across function boundaries prior to bufferization.
static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(initTensorOp);
Value alloc = createNewAllocDeallocPairForShapedValue(
b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo);
map(bvm, initTensorOp.result(), alloc);
return success();
}
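
In the new tests this materializes as a static allocation paired with a deallocation at the end of the function. A minimal sketch (%A_buf is a hypothetical name for the buffer produced by createNewAllocDeallocPairForShapedValue):

  // Tensor form:
  %A = linalg.init_tensor [64] : tensor<64xf32>

  // Bufferized form: InitTensor always allocates ...
  %A_buf = memref.alloc() : memref<64xf32>
  // ... and the matching dealloc is inserted automatically.
  memref.dealloc %A_buf : memref<64xf32>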
/// ReturnOp always creates memref::TensorLoadOp.
static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
BlockAndValueMapping &bvm,
@@ -2070,16 +2126,18 @@ static LogicalResult bufferizeFuncOpInternals(
// Since walk has to be PreOrder, we need to erase ops that require it
// separately: this is the case for CallOp
SmallVector<Operation *> toErase;
WalkResult result =
funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
// clang-format off
WalkResult result =
TypeSwitch<Operation *, LogicalResult>(op)
// Skip BufferCast and TensorLoad ops.
.Case<memref::BufferCastOp,
memref::TensorLoadOp>([&](auto) { return success(); })
.Case<tensor::DimOp,
.Case<tensor::CastOp,
tensor::DimOp,
scf::ForOp,
InitTensorOp,
LinalgOp,
ReturnOp,
ExtractSliceOp,
@@ -2100,16 +2158,16 @@ static LogicalResult bufferizeFuncOpInternals(
return failure();
return success();
});
// clang-format on
// Register post-walk erasure, if necessary.
if (isa<CallOpInterface>(op))
if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
llvm::any_of(op->getResultTypes(), isaTensor))
toErase.push_back(op);
return result;
});
LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
for (Operation *op : toErase)
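
At function boundaries, the new test below shows what this walk plus the post-walk erasure amounts to for a caller: the statically shaped buffers are memref.cast to the callee's dynamic-layout memref types, the call is re-emitted on memrefs, and the original tensor-typed call is erased afterwards. A sketch (#map0/#map1 stand in for the dynamic offset/stride layout maps checked in the test):

  // Before: the call passes tensors.
  %res = call @init_and_dot(%AA, %BB, %CC)
    : (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>

  // After: the buffers are cast to the callee's expected layouts and the
  // tensor-typed call above is erased post-walk.
  %cA = memref.cast %A : memref<64xf32> to memref<64xf32, #map1>
  %cB = memref.cast %B : memref<64xf32> to memref<64xf32, #map1>
  %cC = memref.cast %C : memref<f32> to memref<f32, #map0>
  call @init_and_dot(%cA, %cB, %cC)
    : (memref<64xf32, #map1>, memref<64xf32, #map1>, memref<f32, #map0>) -> ()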


@@ -58,3 +58,73 @@ func @bar(
// CHECK-NEXT: return
return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
}
// -----
// CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
// CHECK: func @init_and_dot(
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: %[[C0:.*]] = constant 0{{.*}} : f32
%v0 = constant 0.0 : f32
// CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32, #[[$DYN_0D_MAP]]>
%d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
// CHECK-NEXT: linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
%e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
outs(%d: tensor<f32>) -> tensor<f32>
// CHECK-NEXT: return
return %e : tensor<f32>
}
// CHECK: func @main()
func @main() {
// CHECK-DAG: %[[C0:.*]] = constant 0{{.*}} : f32
// CHECK-DAG: %[[C1:.*]] = constant 1{{.*}} : f32
// CHECK-DAG: %[[C2:.*]] = constant 2{{.*}} : f32
%v0 = constant 0.0 : f32
%v1 = constant 1.0 : f32
%v2 = constant 2.0 : f32
// CHECK-NEXT: %[[A:.*]] = memref.alloc() : memref<64xf32>
// CHECK-NEXT: %[[B:.*]] = memref.alloc() : memref<64xf32>
// CHECK-NEXT: %[[C:.*]] = memref.alloc() : memref<f32>
%A = linalg.init_tensor [64] : tensor<64xf32>
%B = linalg.init_tensor [64] : tensor<64xf32>
%C = linalg.init_tensor [] : tensor<f32>
// CHECK-NEXT: linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32>
// CHECK-NEXT: linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32>
// CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32>
%AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
%BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
%CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
// CHECK-NEXT: %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
// CHECK-NEXT: %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
// CHECK-NEXT: %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>
// CHECK-NEXT: call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
%res = call @init_and_dot(%AA, %BB, %CC) :
(tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[dC:.*]] = memref.cast %[[C]] : memref<f32> to memref<*xf32>
%res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
// CHECK-NEXT: call @print_memref_f32(%[[dC]]) : (memref<*xf32>) -> ()
call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
// CHECK-DAG: memref.dealloc %[[A]] : memref<64xf32>
// CHECK-DAG: memref.dealloc %[[B]] : memref<64xf32>
// CHECK-DAG: memref.dealloc %[[C]] : memref<f32>
// CHECK-NEXT: return
return
}
// CHECK: func private @print_memref_f32(memref<*xf32>)
func private @print_memref_f32(tensor<*xf32>)


@@ -0,0 +1,44 @@
// RUN: mlir-opt %s -canonicalize -cse -linalg-comprehensive-module-bufferize |\
// RUN: mlir-opt -convert-vector-to-scf -lower-affine -convert-linalg-to-loops |\
// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext |\
// RUN: FileCheck %s
func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
%v0 = constant 0.0 : f32
%d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
%e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
outs(%d: tensor<f32>) -> tensor<f32>
return %e : tensor<f32>
}
func @main() {
%v0 = constant 0.0 : f32
%v1 = constant 1.0 : f32
%v2 = constant 2.0 : f32
%A = linalg.init_tensor [64] : tensor<64xf32>
%B = linalg.init_tensor [64] : tensor<64xf32>
%C = linalg.init_tensor [] : tensor<f32>
%AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
%BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
%CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
%res = call @init_and_dot(%AA, %BB, %CC) :
(tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
%res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
// CHECK: Unranked Memref base@ = {{.*}} rank = 0 offset = 0 sizes = [] strides = [] data =
// CHECK-NEXT: [128]
call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
return
}
func private @print_memref_f32(tensor<*xf32>) attributes { llvm.emit_c_interface }