[mlir][linalg][bufferize] Op interface implementation for Bufferization dialect ops

This change provides `BufferizableOpInterface` implementations for ops from the Bufferization dialect. These ops are needed at the bufferization boundaries for partial bufferization.

Differential Revision: https://reviews.llvm.org/D114618
Matthias Springer 2021-12-03 16:25:08 +09:00
parent 1c16b0db9d
commit d30fcadf07
9 changed files with 246 additions and 26 deletions
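
For orientation, here is a minimal sketch of how a downstream pass or tool would pick up these external models. The function name registerMyBufferizationModels is hypothetical; the registration call mirrors the one added to the LinalgComprehensiveModuleBufferize pass in this change.

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
#include "mlir/IR/Dialect.h"

// Hypothetical registration hook (not part of this commit): make the
// Bufferization dialect available and attach the external
// BufferizableOpInterface models for to_tensor / to_memref.
void registerMyBufferizationModels(mlir::DialectRegistry &registry) {
  registry.insert<mlir::bufferization::BufferizationDialect>();
  mlir::linalg::comprehensive_bufferize::bufferization_ext::
      registerBufferizableOpInterfaceExternalModels(registry);
}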

@@ -0,0 +1,27 @@
//===- BufferizationInterfaceImpl.h - Bufferization Impl. of Op Interface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
namespace mlir {
class DialectRegistry;
namespace linalg {
namespace comprehensive_bufferize {
namespace bufferization_ext {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace bufferization_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H

@@ -416,10 +416,6 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
BufferizationState &state) {
OpBuilder b(op->getContext());
// Skip ToMemrefOp and ToTensorOp.
if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
return success();
// Check if op has tensor results or operands.
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);

@@ -0,0 +1,101 @@
//===- BufferizationInterfaceImpl.cpp - Bufferization Impl. of Interface --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
using namespace mlir;
using namespace linalg;
using namespace comprehensive_bufferize;
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace bufferization_ext {
// TODO: These ops should implement BufferizableOpInterface directly when moved
// to the Bufferization dialect.
// TODO: These implementations are conservative and will likely have to be
// loosened for partial bufferization.
/// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory
/// location of the incoming tensor once it has been bufferized. In the
/// analysis, the incoming tensor is assumed to bufferize to a memory read and
/// to an inplace memory write, since it is unknown what will happen to the
/// resulting memref.
struct ToMemrefOpInterface
: public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface,
bufferization::ToMemrefOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
// It is unknown whether the resulting MemRef will be read or not.
return true;
}
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
return {};
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
return success();
}
};
/// ToTensorOp conceptually loads a tensor from a memory location. Such ops do
/// not lower any further, and they should have disappeared by the time the
/// input is fully bufferized.
///
/// The analysis has no information about the memref that the ToTensorOp loads
/// from. We have to assume that, after bufferization, the loaded tensor may
/// alias with any other bufferized tensor. Since ToTensorOp and ToMemrefOp
/// have no aliasing OpOperand/OpResult pairs, this cannot be encoded directly
/// in the analysis. However, declaring ToTensorOp results as not writable also
/// enforces a buffer copy and has the same effect.
struct ToTensorOpInterface
: public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
bufferization::ToTensorOp> {
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
return {};
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto tensorLoadOp = cast<bufferization::ToTensorOp>(op);
state.mapBuffer(tensorLoadOp.result(), tensorLoadOp.memref());
return success();
}
bool isWritable(Operation *op, Value value) const {
// It is unknown whether the MemRef operand is writable or not.
return false;
}
};
} // namespace bufferization_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
void mlir::linalg::comprehensive_bufferize::bufferization_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<bufferization::ToMemrefOp,
bufferization_ext::ToMemrefOpInterface>();
registry.addOpInterface<bufferization::ToTensorOp,
bufferization_ext::ToTensorOpInterface>();
}
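
As a usage note, the analysis drives these models through the generic BufferizableOpInterface API. The helper below is a hypothetical sketch (not part of the diff) that follows the same query pattern used by checkAliasInfoConsistency later in this commit.

#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::linalg::comprehensive_bufferize;

/// Hypothetical helper (illustration only): return true if the analysis must
/// assume that `opOperand` is read through some buffer.
static bool isConservativelyRead(OpOperand &opOperand) {
  auto bufferizableOp =
      dyn_cast<BufferizableOpInterface>(opOperand.getOwner());
  if (!bufferizableOp)
    return false;
  // For bufferization.to_memref, the external model above always returns
  // true: it is unknown whether the resulting memref will be read.
  return bufferizableOp.bufferizesToMemoryRead(opOperand);
}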

@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
AffineInterfaceImpl.cpp
ArithInterfaceImpl.cpp
BufferizableOpInterface.cpp
BufferizationInterfaceImpl.cpp
ComprehensiveBufferize.cpp
LinalgInterfaceImpl.cpp
ModuleBufferization.cpp
@@ -80,6 +81,7 @@ add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl
)
add_mlir_dialect_library(MLIRComprehensiveBufferize
BufferizationInterfaceImpl.cpp
ComprehensiveBufferize.cpp
ModuleBufferization.cpp

@@ -239,6 +239,12 @@ static std::string printValueInfo(Value value, bool prefix) {
/// Return true if opOperand has been decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
const BufferizationAliasInfo &aliasInfo) {
// The analysis does not know what happens to the result of a ToMemrefOp, so
// we assume that it is written to.
// TODO: This is a conservative implementation. This rule will have to be
// relaxed for partial bufferization.
if (isa<bufferization::ToMemrefOp>(opOperand.getOwner()))
return true;
// OpOperands without an aliasing OpResult do not write.
OpResult opResult = getAliasingOpResult(opOperand);
if (!opResult)
@@ -453,14 +459,23 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
/// If `checkConsistencyOnly` is true, this function checks if there is a
/// read-after-write conflict without bufferizing `operand` inplace. This would
/// indicate a problem with the current inplace bufferization decisions.
///
/// Note: If `checkConsistencyOnly` is true, this function may be called with a
/// null OpResult. In that case, only the consistency of bufferization decisions
/// involving aliases of the given OpOperand is checked.
bool wouldCreateReadAfterWriteInterference(
OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
const BufferizationAliasInfo &aliasInfo,
bool checkConsistencyOnly = false) {
#ifndef NDEBUG
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
assert(llvm::find(opOperands, &operand) != opOperands.end() &&
"operand and result do not match");
if (result) {
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
assert(llvm::find(opOperands, &operand) != opOperands.end() &&
"operand and result do not match");
} else {
assert(checkConsistencyOnly &&
"result not provided, can only check consistency");
}
#endif // NDEBUG
// Helper function to iterate on aliases of `root` and capture the reads.
@@ -486,9 +501,11 @@ bool wouldCreateReadAfterWriteInterference(
// Collect reads and writes of all aliases of OpOperand and OpResult.
DenseSet<OpOperand *> usesRead, usesWrite;
getAliasingReads(usesRead, operand.get());
getAliasingReads(usesRead, result);
if (result)
getAliasingReads(usesRead, result);
getAliasingInplaceWrites(usesWrite, operand.get());
getAliasingInplaceWrites(usesWrite, result);
if (result)
getAliasingInplaceWrites(usesWrite, result);
if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);
@@ -673,25 +690,38 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
return res;
}
#ifndef NDEBUG
/// Assert that the current bufferization decisions are consistent.
static void checkAliasInfoConsistency(FuncOp funcOp,
const DominanceInfo &domInfo,
const BufferizationAliasInfo &aliasInfo) {
funcOp.walk([&](Operation *op) {
static LogicalResult
checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo,
const BufferizationAliasInfo &aliasInfo) {
Operation *inconsistentOp = nullptr;
WalkResult walkResult = funcOp.walk([&](Operation *op) {
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
for (OpOperand &opOperand : op->getOpOperands())
if (opOperand.get().getType().isa<TensorType>())
if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
// If this assertion fails, there is probably an inconsistent
// combination of "mustBufferizeInPlace" decisions.
assert(!wouldCreateReadAfterWriteInterference(
opOperand, opResult, domInfo, aliasInfo,
/*checkConsistencyOnly=*/true) &&
"found read after write conflict before running analysis");
if (opOperand.get().getType().isa<TensorType>()) {
OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand);
if (wouldCreateReadAfterWriteInterference(
opOperand, opResult, domInfo, aliasInfo,
/*checkConsistencyOnly=*/true)) {
// This error can happen for two reasons: either the input IR already has
// a read-after-write conflict, or certain "mustBufferizeInPlace" interface
// methods are implemented incorrectly.
inconsistentOp = op;
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
// This can currently happen in one situation: when a tensor is passed into
// a ToMemrefOp and subsequently read by another op. ToMemrefOps are
// currently handled conservatively: once a tensor is passed into a
// ToMemrefOp, it may no longer be read.
return inconsistentOp->emitError("input IR has RaW conflict");
return success();
}
#endif
/// Annotate the IR with the result of the analysis. For testing/debugging only.
static void
@@ -720,9 +750,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
if (funcOp.body().empty())
return success();
#ifndef NDEBUG
checkAliasInfoConsistency(funcOp, domInfo, aliasInfo);
#endif // NDEBUG
if (failed(checkAliasInfoConsistency(funcOp, domInfo, aliasInfo)))
return failure();
// If the analysis fails, just return.
if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,

@@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
@@ -47,6 +48,7 @@ struct LinalgComprehensiveModuleBufferize
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerBufferizableOpInterfaceExternalModels(registry);

@@ -1492,3 +1492,44 @@ func @main_func(%A : tensor<?xf32> {linalg.inplaceable = true},
%0 = call @some_use(%A, %v) : (tensor<?xf32>, vector<5xf32>) -> (tensor<?xf32>)
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: func @to_tensor_op_not_writable
func @to_tensor_op_not_writable(%m: memref<?xf32>, %v: vector<5xf32>,
%idx1: index, %idx2: index)
-> vector<10xf32> {
%0 = bufferization.to_tensor %m : memref<?xf32>
// Write to the tensor. Cannot be inplace since the to_tensor result is not writable.
// CHECK: vector.transfer_write
// CHECK-SAME: {__inplace_results_attr__ = ["false"]
%w = vector.transfer_write %v, %0[%idx1] : vector<5xf32>, tensor<?xf32>
// Read from the tensor and return result.
%cst = arith.constant 0.0 : f32
%r = vector.transfer_read %w[%idx2], %cst : tensor<?xf32>, vector<10xf32>
return %r : vector<10xf32>
}
// -----
// CHECK-LABEL: func @to_memref_op_is_reading
func @to_memref_op_is_reading(%t1: tensor<?xf32> {linalg.inplaceable = true},
%idx1: index, %idx2: index, %idx3: index,
%v1: vector<5xf32>)
-> (vector<5xf32>, vector<5xf32>) {
// Write + read to/from tensor.
// CHECK: vector.transfer_write
// CHECK-SAME: {__inplace_results_attr__ = ["false"]
%1 = vector.transfer_write %v1, %t1[%idx2] : vector<5xf32>, tensor<?xf32>
%cst = arith.constant 0.0 : f32
%r1 = vector.transfer_read %1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
// Write + read to/from same memref.
%0 = bufferization.to_memref %t1 : memref<?xf32>
vector.transfer_write %v1, %0[%idx1] : vector<5xf32>, memref<?xf32>
%r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32>
return %r1, %r2 : vector<5xf32>, vector<5xf32>
}

@@ -167,3 +167,23 @@ func @main() -> tensor<4xi32> {
}
return %r: tensor<4xi32>
}
// -----
func @to_memref_op_is_writing(
%t1: tensor<?xf32> {linalg.inplaceable = true}, %idx1: index,
%idx2: index, %idx3: index, %v1: vector<5xf32>) -> (vector<5xf32>, vector<5xf32>) {
// This is a RaW conflict because to_memref is an inplace write and %t1 is
// read further down. This will likely have to change with partial
// bufferization.
// expected-error @+1 {{input IR has RaW conflict}}
%0 = bufferization.to_memref %t1 : memref<?xf32>
// Read from both.
%cst = arith.constant 0.0 : f32
%r1 = vector.transfer_read %t1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
%r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32>
return %r1, %r2 : vector<5xf32>, vector<5xf32>
}

@@ -6673,10 +6673,12 @@ cc_library(
cc_library(
name = "ComprehensiveBufferize",
srcs = [
"lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp",
"lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp",
"lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp",
],
hdrs = [
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h",
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h",
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h",
],