[mlir][bufferize][NFC] Move scf BufferizableOpInterface impl to scf dialect

Differential Revision: https://reviews.llvm.org/D118557
This commit is contained in:
Matthias Springer 2022-01-30 21:53:02 +09:00
parent 7a9765e8a8
commit 19efe141f7
10 changed files with 69 additions and 106 deletions

View File

@ -1,4 +1,4 @@
//===- SCFInterfaceImpl.h - SCF Impl. of BufferizableOpInterface ----------===//
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,19 +6,15 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
#ifndef MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H
#define MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
namespace mlir {
class DialectRegistry;
namespace linalg {
namespace comprehensive_bufferize {
namespace scf_ext {
namespace scf {
/// Assert that yielded values of an scf.for op are aliasing their corresponding
/// bbArgs. This is required because the i-th OpResult of an scf.for op is
/// currently assumed to alias with the i-th iter_arg (in the absence of
@ -30,10 +26,7 @@ struct AssertScfForAliasingProperties : public bufferization::PostAnalysisStep {
};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace scf_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace scf
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
#endif // MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H

View File

@ -2,7 +2,6 @@ set(LLVM_OPTIONAL_SOURCES
AffineInterfaceImpl.cpp
LinalgInterfaceImpl.cpp
ModuleBufferization.cpp
SCFInterfaceImpl.cpp
StdInterfaceImpl.cpp
VectorInterfaceImpl.cpp
)
@ -26,16 +25,6 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
MLIRTensor
)
add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
SCFInterfaceImpl.cpp
LINK_LIBS PUBLIC
MLIRBufferization
MLIRBufferizationTransforms
MLIRIR
MLIRSCF
)
add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
StdInterfaceImpl.cpp

View File

@ -47,7 +47,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRLinalgUtils
MLIRModuleBufferization
MLIRSCF
MLIRSCFBufferizableOpInterfaceImpl
MLIRSCFTransforms
MLIRSCFUtils
MLIRStdBufferizableOpInterfaceImpl

View File

@ -15,10 +15,10 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@ -54,7 +54,7 @@ struct LinalgComprehensiveModuleBufferize
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerModuleBufferizationExternalModels(registry);
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
@ -132,7 +132,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
}
// Only certain scf.for ops are supported by the analysis.
options->addPostAnalysisStep<scf_ext::AssertScfForAliasingProperties>();
options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
ModuleOp moduleOp = getOperation();
applyEnablingTransformations(moduleOp);

View File

@ -1,4 +1,4 @@
//===- SCFInterfaceImpl.cpp - SCF Impl. of BufferizableOpInterface --------===//
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SCF/SCF.h"
@ -14,12 +15,13 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace scf_ext {
namespace scf {
namespace {
// bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
@ -384,42 +386,6 @@ struct ForOpInterface
}
};
LogicalResult
mlir::linalg::comprehensive_bufferize::scf_ext::AssertScfForAliasingProperties::
run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
LogicalResult status = success();
op->walk([&](scf::ForOp forOp) {
auto yieldOp =
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
for (OpOperand &operand : yieldOp->getOpOperands()) {
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
if (!tensorType)
continue;
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
// Note: This is overly strict. We should check for aliasing bufferized
// values. But we don't have a "must-alias" analysis yet.
if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
// TODO: this could get resolved with copies but it can also turn into
// swaps so we need to be careful about order of copies.
status =
yieldOp->emitError()
<< "Yield operand #" << operand.getOperandNumber()
<< " does not bufferize to a buffer that is aliasing the matching"
<< " enclosing scf::for operand";
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
return status;
}
/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
@ -462,18 +428,51 @@ struct YieldOpInterface
}
};
} // namespace scf_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace
} // namespace scf
} // namespace mlir
void mlir::linalg::comprehensive_bufferize::scf_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<scf::ExecuteRegionOp,
scf_ext::ExecuteRegionOpInterface>();
registry.addOpInterface<scf::ForOp, scf_ext::ForOpInterface>();
registry.addOpInterface<scf::IfOp, scf_ext::IfOpInterface>();
registry.addOpInterface<scf::YieldOp, scf_ext::YieldOpInterface>();
registry.addOpInterface<scf::ParallelOp,
AllocationHoistingBarrierOnly<scf::ParallelOp>>();
LogicalResult mlir::scf::AssertScfForAliasingProperties::run(
Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
LogicalResult status = success();
op->walk([&](scf::ForOp forOp) {
auto yieldOp =
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
for (OpOperand &operand : yieldOp->getOpOperands()) {
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
if (!tensorType)
continue;
OpOperand &forOperand = forOp.getOpOperandForResult(
forOp->getResult(operand.getOperandNumber()));
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
// Note: This is overly strict. We should check for aliasing bufferized
// values. But we don't have a "must-alias" analysis yet.
if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
// TODO: this could get resolved with copies but it can also turn into
// swaps so we need to be careful about order of copies.
status =
yieldOp->emitError()
<< "Yield operand #" << operand.getOperandNumber()
<< " does not bufferize to a buffer that is aliasing the matching"
<< " enclosing scf::for operand";
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
return status;
}
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addOpInterface<ExecuteRegionOp, ExecuteRegionOpInterface>();
registry.addOpInterface<ForOp, ForOpInterface>();
registry.addOpInterface<IfOp, IfOpInterface>();
registry.addOpInterface<YieldOp, YieldOpInterface>();
registry
.addOpInterface<ParallelOp, AllocationHoistingBarrierOnly<ParallelOp>>();
}

View File

@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRSCFTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ForToWhile.cpp
LoopCanonicalization.cpp
@ -20,6 +21,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
MLIRAffine
MLIRAffineAnalysis
MLIRArithmetic
MLIRBufferization
MLIRBufferizationTransforms
MLIRDialectUtils
MLIRIR

View File

@ -26,7 +26,7 @@ add_mlir_library(MLIRLinalgTestPasses
MLIRMemRef
MLIRPass
MLIRSCF
MLIRSCFBufferizableOpInterfaceImpl
MLIRSCFTransforms
MLIRStdBufferizableOpInterfaceImpl
MLIRStandard
MLIRTensor

View File

@ -18,11 +18,11 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/VectorOps.h"
@ -61,7 +61,7 @@ struct TestComprehensiveFunctionBufferize
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
@ -106,7 +106,7 @@ void TestComprehensiveFunctionBufferize::runOnOperation() {
auto options = std::make_unique<AnalysisBufferizationOptions>();
if (!allowReturnMemref)
options->addPostAnalysisStep<scf_ext::AssertScfForAliasingProperties>();
options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
options->allowReturnMemref = allowReturnMemref;
options->allowUnknownOps = allowUnknownOps;

View File

@ -1782,6 +1782,7 @@ cc_library(
"lib/Dialect/SCF/Transforms/*.h",
]),
hdrs = [
"include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
"include/mlir/Dialect/SCF/Passes.h",
"include/mlir/Dialect/SCF/Transforms.h",
],
@ -2435,6 +2436,7 @@ cc_library(
"include/mlir/Dialect/SCF/*.h",
],
exclude = [
"include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
"include/mlir/Dialect/SCF/Transforms.h",
],
),
@ -6656,25 +6658,6 @@ cc_library(
],
)
cc_library(
name = "SCFBufferizableOpInterfaceImpl",
srcs = [
"lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp",
],
hdrs = [
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h",
],
includes = ["include"],
deps = [
":BufferizationDialect",
":BufferizationTransforms",
":IR",
":SCFDialect",
":Support",
"//llvm:Support",
],
)
cc_library(
name = "StdBufferizableOpInterfaceImpl",
srcs = [
@ -6928,7 +6911,6 @@ cc_library(
":MemRefDialect",
":ModuleBufferization",
":Pass",
":SCFBufferizableOpInterfaceImpl",
":SCFDialect",
":SCFTransforms",
":SCFUtils",

View File

@ -400,7 +400,6 @@ cc_library(
"//mlir:LinalgTransforms",
"//mlir:MemRefDialect",
"//mlir:Pass",
"//mlir:SCFBufferizableOpInterfaceImpl",
"//mlir:SCFDialect",
"//mlir:SCFTransforms",
"//mlir:StandardOps",