[mlir][linalg][bufferize][NFC] Move helper function to op interface

This is in preparation of changing the op traversal during bufferization.

Differential Revision: https://reviews.llvm.org/D114040
This commit is contained in:
Matthias Springer 2021-11-23 11:20:27 +09:00
parent 8d0994ed21
commit 26c0dd83ab
6 changed files with 33 additions and 27 deletions

View File

@ -297,6 +297,11 @@ struct BufferizationState {
/// bufferization is necessary.
Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this
/// function returns immediately. Otherwise, it calls the `bufferize` interface
/// method of `BufferizableOpInterface`.
LogicalResult bufferizeOp(Operation *op, BufferizationState &state);
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
/// executed after the analysis, but before bufferization. They can be used
/// implement custom dialect-specific optimizations.

View File

@ -24,9 +24,6 @@ static constexpr int64_t kBufferAlignments = 128;
/// Return default allocation callbacks.
std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
/// Bufferize one particular op.
LogicalResult bufferizeOp(Operation *op, BufferizationState &state);
/// Register external models implemented for the `BufferizableOpInterface`.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
@ -390,6 +391,31 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
return operandBuffer;
}
LogicalResult
mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op,
BufferizationState &state) {
OpBuilder b(op->getContext());
// Skip BufferCast and TensorLoad ops.
if (isa<memref::BufferCastOp, memref::TensorLoadOp>(op))
return success();
// Check if op has tensor results or operands.
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
if (!hasTensorResult && !hasTensorOperand)
return success();
// Bufferize using `BufferizableOpInterface`.
b.setInsertionPoint(op);
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
return bufferizableOp.bufferize(b, state);
// Other op with tensors. No bufferization method specified.
return op->emitError() << "unsupported op with tensors";
}
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//

View File

@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRBufferizableOpInterface
LINK_LIBS PUBLIC
MLIRIR
MLIRMemRef
)
add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl

View File

@ -927,30 +927,6 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
// Bufferization entry-point for functions.
//===----------------------------------------------------------------------===//
LogicalResult
mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op,
BufferizationState &state) {
OpBuilder b(op->getContext());
// Skip BufferCast and TensorLoad ops.
if (isa<memref::BufferCastOp, memref::TensorLoadOp>(op))
return success();
// Check if op has tensor results or operands.
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
if (!hasTensorResult && !hasTensorOperand)
return success();
// Bufferize using `BufferizableOpInterface`.
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
return bufferizableOp.bufferize(b, state);
// Other op with tensors. No bufferization method specified.
return op->emitError() << "unsupported op with tensors";
}
static LogicalResult bufferizeFuncOpInternals(FuncOp funcOp,
BufferizationState &state) {
LLVM_DEBUG(llvm::dbgs() << "\n\n");

View File

@ -6299,6 +6299,7 @@ cc_library(
deps = [
":BufferizableOpInterfaceIncGen",
":IR",
":MemRefDialect",
":Support",
"//llvm:Support",
],