forked from OSchip/llvm-project
[mlir][linalg][bufferize] Op interface implementation for Bufferization dialect ops
This change provides `BufferizableOpInterface` implementations for ops from the Bufferization dialect. These ops are needed at the bufferization boundaries for partial bufferization. Differential Revision: https://reviews.llvm.org/D114618
This commit is contained in:
parent
1c16b0db9d
commit
d30fcadf07
|
@ -0,0 +1,27 @@
|
|||
//===- BufferizationInterfaceImpl.h - Bufferization Impl. of Op Interface -===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
|
||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class DialectRegistry;
|
||||
|
||||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
namespace bufferization_ext {
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
|
||||
} // namespace bufferization_ext
|
||||
} // namespace comprehensive_bufferize
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
|
|
@ -416,10 +416,6 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
|
|||
BufferizationState &state) {
|
||||
OpBuilder b(op->getContext());
|
||||
|
||||
// Skip ToMemrefOp and ToTensorOp.
|
||||
if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
|
||||
return success();
|
||||
|
||||
// Check if op has tensor results or operands.
|
||||
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
|
||||
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
//===- BufferizationInterfaceImpl.cpp - Bufferization Impl. of Interface --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace linalg;
|
||||
using namespace comprehensive_bufferize;
|
||||
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
namespace bufferization_ext {
|
||||
|
||||
// TODO: These ops should implement BufferizableOpInterface directly when moved
|
||||
// to the Bufferization dialect.
|
||||
|
||||
// TODO: These implementations are conservative and will likely have to be
|
||||
// loosened for partial bufferization.
|
||||
|
||||
/// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory
|
||||
/// location of the incoming tensor once it is bufferized. In the analysis,
|
||||
/// the incoming tensor is assumed to bufferize to a memory read and to an
|
||||
/// inplace memory write, since it is unknown what will happen to the resulting
|
||||
/// memref.
|
||||
struct ToMemrefOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface,
|
||||
bufferization::ToMemrefOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
|
||||
// It is unknown whether the resulting MemRef will be read or not.
|
||||
return true;
|
||||
}
|
||||
|
||||
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
|
||||
OpResult opResult) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
|
||||
return OpResult();
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, OpBuilder &b,
|
||||
BufferizationState &state) const {
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// ToTensorOp conceptually loads a tensor from a memory location. Such ops do
|
||||
/// not lower any further, and they should have disappeared by the time the
|
||||
/// input is fully bufferized.
|
||||
///
|
||||
/// The analysis has no information about the memref that is loaded from by the
|
||||
/// ToTensorOp. We have to assume that the loaded tensor may after bufferization
|
||||
/// potentially alias with any other bufferized tensor. Since ToTensorOp and
|
||||
/// ToMemrefOp have no aliasing OpOperand/OpResult pairs, this cannot be encoded
|
||||
/// directly in the analysis. However, declaring ToTensorOp results as not
|
||||
/// writable also enforces a buffer copy and has the same effect.
|
||||
struct ToTensorOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
|
||||
bufferization::ToTensorOp> {
|
||||
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
|
||||
OpResult opResult) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, OpBuilder &b,
|
||||
BufferizationState &state) const {
|
||||
auto tensorLoadOp = cast<bufferization::ToTensorOp>(op);
|
||||
state.mapBuffer(tensorLoadOp.result(), tensorLoadOp.memref());
|
||||
return success();
|
||||
}
|
||||
|
||||
bool isWritable(Operation *op, Value value) const {
|
||||
// It is unknown whether the MemRef operand is writable or not.
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace bufferization_ext
|
||||
} // namespace comprehensive_bufferize
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
void mlir::linalg::comprehensive_bufferize::bufferization_ext::
|
||||
registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
|
||||
registry.addOpInterface<bufferization::ToMemrefOp,
|
||||
bufferization_ext::ToMemrefOpInterface>();
|
||||
registry.addOpInterface<bufferization::ToTensorOp,
|
||||
bufferization_ext::ToTensorOpInterface>();
|
||||
}
|
|
@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
|
|||
AffineInterfaceImpl.cpp
|
||||
ArithInterfaceImpl.cpp
|
||||
BufferizableOpInterface.cpp
|
||||
BufferizationInterfaceImpl.cpp
|
||||
ComprehensiveBufferize.cpp
|
||||
LinalgInterfaceImpl.cpp
|
||||
ModuleBufferization.cpp
|
||||
|
@ -80,6 +81,7 @@ add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl
|
|||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRComprehensiveBufferize
|
||||
BufferizationInterfaceImpl.cpp
|
||||
ComprehensiveBufferize.cpp
|
||||
ModuleBufferization.cpp
|
||||
|
||||
|
|
|
@ -239,6 +239,12 @@ static std::string printValueInfo(Value value, bool prefix) {
|
|||
/// Return true if opOperand has been decided to bufferize in-place.
|
||||
static bool isInplaceMemoryWrite(OpOperand &opOperand,
|
||||
const BufferizationAliasInfo &aliasInfo) {
|
||||
// The analysis does not know what happens to the result of a ToMemrefOp, so
|
||||
// we assume that it is written to.
|
||||
// TODO: This is a conservative implementation. This rule will have to be
|
||||
// relaxed for partial bufferization.
|
||||
if (isa<bufferization::ToMemrefOp>(opOperand.getOwner()))
|
||||
return true;
|
||||
// OpOperands without an aliasing OpResult do not write.
|
||||
OpResult opResult = getAliasingOpResult(opOperand);
|
||||
if (!opResult)
|
||||
|
@ -453,14 +459,23 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
|
|||
/// If `checkConsistencyOnly` is true, this function checks if there is a
|
||||
/// read-after-write conflict without bufferizing `operand` inplace. This would
|
||||
/// indicate a problem with the current inplace bufferization decisions.
|
||||
///
|
||||
/// Note: If `checkConsistencyOnly`, this function may be called with a null
|
||||
/// OpResult. In that case, only the consistency of bufferization decisions
|
||||
/// involving aliases of the given OpOperand are checked.
|
||||
bool wouldCreateReadAfterWriteInterference(
|
||||
OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
|
||||
const BufferizationAliasInfo &aliasInfo,
|
||||
bool checkConsistencyOnly = false) {
|
||||
#ifndef NDEBUG
|
||||
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
|
||||
assert(llvm::find(opOperands, &operand) != opOperands.end() &&
|
||||
"operand and result do not match");
|
||||
if (result) {
|
||||
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
|
||||
assert(llvm::find(opOperands, &operand) != opOperands.end() &&
|
||||
"operand and result do not match");
|
||||
} else {
|
||||
assert(checkConsistencyOnly &&
|
||||
"result not provided, can only check consistency");
|
||||
}
|
||||
#endif // NDEBUG
|
||||
|
||||
// Helper function to iterate on aliases of `root` and capture the reads.
|
||||
|
@ -486,9 +501,11 @@ bool wouldCreateReadAfterWriteInterference(
|
|||
// Collect reads and writes of all aliases of OpOperand and OpResult.
|
||||
DenseSet<OpOperand *> usesRead, usesWrite;
|
||||
getAliasingReads(usesRead, operand.get());
|
||||
getAliasingReads(usesRead, result);
|
||||
if (result)
|
||||
getAliasingReads(usesRead, result);
|
||||
getAliasingInplaceWrites(usesWrite, operand.get());
|
||||
getAliasingInplaceWrites(usesWrite, result);
|
||||
if (result)
|
||||
getAliasingInplaceWrites(usesWrite, result);
|
||||
if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
|
||||
usesWrite.insert(&operand);
|
||||
|
||||
|
@ -673,25 +690,38 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
|
|||
return res;
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
/// Assert that the current bufferization decisions are consistent.
|
||||
static void checkAliasInfoConsistency(FuncOp funcOp,
|
||||
const DominanceInfo &domInfo,
|
||||
const BufferizationAliasInfo &aliasInfo) {
|
||||
funcOp.walk([&](Operation *op) {
|
||||
static LogicalResult
|
||||
checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo,
|
||||
const BufferizationAliasInfo &aliasInfo) {
|
||||
Operation *inconsistentOp = nullptr;
|
||||
WalkResult walkResult = funcOp.walk([&](Operation *op) {
|
||||
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
|
||||
for (OpOperand &opOperand : op->getOpOperands())
|
||||
if (opOperand.get().getType().isa<TensorType>())
|
||||
if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
|
||||
// If this assertion fails, there is probably an inconsistent
|
||||
// combination of "mustBufferizeInPlace" decisions.
|
||||
assert(!wouldCreateReadAfterWriteInterference(
|
||||
opOperand, opResult, domInfo, aliasInfo,
|
||||
/*checkConsistencyOnly=*/true) &&
|
||||
"found read after write conflict before running analysis");
|
||||
if (opOperand.get().getType().isa<TensorType>()) {
|
||||
OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand);
|
||||
if (wouldCreateReadAfterWriteInterference(
|
||||
opOperand, opResult, domInfo, aliasInfo,
|
||||
/*checkConsistencyOnly=*/true)) {
|
||||
// This error can happen for two reasons. Either the input IR
|
||||
// already has a read-after-write conflict. Or certain
|
||||
// "mustBufferizeInPlace" interface methods are implemented
|
||||
// incorrectly.
|
||||
inconsistentOp = op;
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (walkResult.wasInterrupted())
|
||||
// This can currently happen in one situation: When a tensor is passed into
|
||||
// a ToMemrefOp and read by another op consecutively. ToMemrefOps are
|
||||
// currently handled conservatively. Once a tensor is passed into a
|
||||
// ToMemrefOp, it may no longer be read.
|
||||
return inconsistentOp->emitError("input IR has RaW conflict");
|
||||
return success();
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Annotate the IR with the result of the analysis. For testing/debugging only.
|
||||
static void
|
||||
|
@ -720,9 +750,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
|||
if (funcOp.body().empty())
|
||||
return success();
|
||||
|
||||
#ifndef NDEBUG
|
||||
checkAliasInfoConsistency(funcOp, domInfo, aliasInfo);
|
||||
#endif // NDEBUG
|
||||
if (failed(checkAliasInfoConsistency(funcOp, domInfo, aliasInfo)))
|
||||
return failure();
|
||||
|
||||
// If the analysis fails, just return.
|
||||
if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
|
||||
|
@ -47,6 +48,7 @@ struct LinalgComprehensiveModuleBufferize
|
|||
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
|
||||
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
|
|
|
@ -1492,3 +1492,44 @@ func @main_func(%A : tensor<?xf32> {linalg.inplaceable = true},
|
|||
%0 = call @some_use(%A, %v) : (tensor<?xf32>, vector<5xf32>) -> (tensor<?xf32>)
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @to_tensor_op_not_writable
|
||||
func @to_tensor_op_not_writable(%m: memref<?xf32>, %v: vector<5xf32>,
|
||||
%idx1: index, %idx2: index)
|
||||
-> vector<10xf32> {
|
||||
%0 = bufferization.to_tensor %m : memref<?xf32>
|
||||
|
||||
// Write to the tensor. Cannot be inplace because the to_tensor result is not writable.
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]
|
||||
%w = vector.transfer_write %v, %0[%idx1] : vector<5xf32>, tensor<?xf32>
|
||||
|
||||
// Read from the tensor and return result.
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%r = vector.transfer_read %w[%idx2], %cst : tensor<?xf32>, vector<10xf32>
|
||||
return %r : vector<10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @to_memref_op_is_reading
|
||||
func @to_memref_op_is_reading(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%idx1: index, %idx2: index, %idx3: index,
|
||||
%v1: vector<5xf32>)
|
||||
-> (vector<5xf32>, vector<5xf32>) {
|
||||
// Write + read to/from tensor.
|
||||
// CHECK: vector.transfer_write
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]
|
||||
%1 = vector.transfer_write %v1, %t1[%idx2] : vector<5xf32>, tensor<?xf32>
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%r1 = vector.transfer_read %1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
|
||||
|
||||
// Write + read to/from same memref.
|
||||
%0 = bufferization.to_memref %t1 : memref<?xf32>
|
||||
vector.transfer_write %v1, %0[%idx1] : vector<5xf32>, memref<?xf32>
|
||||
%r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32>
|
||||
|
||||
return %r1, %r2 : vector<5xf32>, vector<5xf32>
|
||||
}
|
||||
|
|
|
@ -167,3 +167,23 @@ func @main() -> tensor<4xi32> {
|
|||
}
|
||||
return %r: tensor<4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @to_memref_op_is_writing(
|
||||
%t1: tensor<?xf32> {linalg.inplaceable = true}, %idx1: index,
|
||||
%idx2: index, %idx3: index, %v1: vector<5xf32>) -> (vector<5xf32>, vector<5xf32>) {
|
||||
// This is a RaW conflict because to_memref is an inplace write and %t1 is
|
||||
// read further down. This will likely have to change with partial
|
||||
// bufferization.
|
||||
|
||||
// expected-error @+1 {{input IR has RaW conflict}}
|
||||
%0 = bufferization.to_memref %t1 : memref<?xf32>
|
||||
|
||||
// Read from both.
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%r1 = vector.transfer_read %t1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
|
||||
%r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32>
|
||||
|
||||
return %r1, %r2 : vector<5xf32>, vector<5xf32>
|
||||
}
|
||||
|
|
|
@ -6673,10 +6673,12 @@ cc_library(
|
|||
cc_library(
|
||||
name = "ComprehensiveBufferize",
|
||||
srcs = [
|
||||
"lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp",
|
||||
"lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp",
|
||||
"lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp",
|
||||
],
|
||||
hdrs = [
|
||||
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h",
|
||||
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h",
|
||||
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h",
|
||||
],
|
||||
|
|
Loading…
Reference in New Issue