forked from OSchip/llvm-project
[mlir][bufferize][NFC] Move scf BufferizableOpInterface impl to scf dialect
Differential Revision: https://reviews.llvm.org/D118557
This commit is contained in:
parent
7a9765e8a8
commit
19efe141f7
|
@ -1,4 +1,4 @@
|
|||
//===- SCFInterfaceImpl.h - SCF Impl. of BufferizableOpInterface ----------===//
|
||||
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -6,19 +6,15 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
|
||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
|
||||
#ifndef MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
#define MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
|
||||
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class DialectRegistry;
|
||||
|
||||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
namespace scf_ext {
|
||||
|
||||
namespace scf {
|
||||
/// Assert that yielded values of an scf.for op are aliasing their corresponding
|
||||
/// bbArgs. This is required because the i-th OpResult of an scf.for op is
|
||||
/// currently assumed to alias with the i-th iter_arg (in the absence of
|
||||
|
@ -30,10 +26,7 @@ struct AssertScfForAliasingProperties : public bufferization::PostAnalysisStep {
|
|||
};
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
|
||||
} // namespace scf_ext
|
||||
} // namespace comprehensive_bufferize
|
||||
} // namespace linalg
|
||||
} // namespace scf
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
|
||||
#endif // MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H
|
|
@ -2,7 +2,6 @@ set(LLVM_OPTIONAL_SOURCES
|
|||
AffineInterfaceImpl.cpp
|
||||
LinalgInterfaceImpl.cpp
|
||||
ModuleBufferization.cpp
|
||||
SCFInterfaceImpl.cpp
|
||||
StdInterfaceImpl.cpp
|
||||
VectorInterfaceImpl.cpp
|
||||
)
|
||||
|
@ -26,16 +25,6 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
|
|||
MLIRTensor
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
|
||||
SCFInterfaceImpl.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRBufferization
|
||||
MLIRBufferizationTransforms
|
||||
MLIRIR
|
||||
MLIRSCF
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
|
||||
StdInterfaceImpl.cpp
|
||||
|
||||
|
|
|
@ -47,7 +47,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
|||
MLIRLinalgUtils
|
||||
MLIRModuleBufferization
|
||||
MLIRSCF
|
||||
MLIRSCFBufferizableOpInterfaceImpl
|
||||
MLIRSCFTransforms
|
||||
MLIRSCFUtils
|
||||
MLIRStdBufferizableOpInterfaceImpl
|
||||
|
|
|
@ -15,10 +15,10 @@
|
|||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
@ -54,7 +54,7 @@ struct LinalgComprehensiveModuleBufferize
|
|||
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
scf::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
std_ext::registerModuleBufferizationExternalModels(registry);
|
||||
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
tensor::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
|
@ -132,7 +132,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
|
|||
}
|
||||
|
||||
// Only certain scf.for ops are supported by the analysis.
|
||||
options->addPostAnalysisStep<scf_ext::AssertScfForAliasingProperties>();
|
||||
options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
|
||||
|
||||
ModuleOp moduleOp = getOperation();
|
||||
applyEnablingTransformations(moduleOp);
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- SCFInterfaceImpl.cpp - SCF Impl. of BufferizableOpInterface --------===//
|
||||
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -6,7 +6,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
|
||||
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
|
||||
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
@ -14,12 +15,13 @@
|
|||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::bufferization;
|
||||
using namespace mlir::scf;
|
||||
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
namespace scf_ext {
|
||||
namespace scf {
|
||||
namespace {
|
||||
|
||||
// bufferization.to_memref is not allowed to change the rank.
|
||||
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
|
||||
|
@ -384,42 +386,6 @@ struct ForOpInterface
|
|||
}
|
||||
};
|
||||
|
||||
LogicalResult
|
||||
mlir::linalg::comprehensive_bufferize::scf_ext::AssertScfForAliasingProperties::
|
||||
run(Operation *op, BufferizationState &state,
|
||||
BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
|
||||
LogicalResult status = success();
|
||||
|
||||
op->walk([&](scf::ForOp forOp) {
|
||||
auto yieldOp =
|
||||
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
|
||||
for (OpOperand &operand : yieldOp->getOpOperands()) {
|
||||
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
|
||||
if (!tensorType)
|
||||
continue;
|
||||
|
||||
OpOperand &forOperand = forOp.getOpOperandForResult(
|
||||
forOp->getResult(operand.getOperandNumber()));
|
||||
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
|
||||
// Note: This is overly strict. We should check for aliasing bufferized
|
||||
// values. But we don't have a "must-alias" analysis yet.
|
||||
if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
|
||||
// TODO: this could get resolved with copies but it can also turn into
|
||||
// swaps so we need to be careful about order of copies.
|
||||
status =
|
||||
yieldOp->emitError()
|
||||
<< "Yield operand #" << operand.getOperandNumber()
|
||||
<< " does not bufferize to a buffer that is aliasing the matching"
|
||||
<< " enclosing scf::for operand";
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
|
||||
/// this is for analysis only.
|
||||
struct YieldOpInterface
|
||||
|
@ -462,18 +428,51 @@ struct YieldOpInterface
|
|||
}
|
||||
};
|
||||
|
||||
} // namespace scf_ext
|
||||
} // namespace comprehensive_bufferize
|
||||
} // namespace linalg
|
||||
} // namespace
|
||||
} // namespace scf
|
||||
} // namespace mlir
|
||||
|
||||
void mlir::linalg::comprehensive_bufferize::scf_ext::
|
||||
registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
|
||||
registry.addOpInterface<scf::ExecuteRegionOp,
|
||||
scf_ext::ExecuteRegionOpInterface>();
|
||||
registry.addOpInterface<scf::ForOp, scf_ext::ForOpInterface>();
|
||||
registry.addOpInterface<scf::IfOp, scf_ext::IfOpInterface>();
|
||||
registry.addOpInterface<scf::YieldOp, scf_ext::YieldOpInterface>();
|
||||
registry.addOpInterface<scf::ParallelOp,
|
||||
AllocationHoistingBarrierOnly<scf::ParallelOp>>();
|
||||
LogicalResult mlir::scf::AssertScfForAliasingProperties::run(
|
||||
Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) {
|
||||
LogicalResult status = success();
|
||||
|
||||
op->walk([&](scf::ForOp forOp) {
|
||||
auto yieldOp =
|
||||
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
|
||||
for (OpOperand &operand : yieldOp->getOpOperands()) {
|
||||
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
|
||||
if (!tensorType)
|
||||
continue;
|
||||
|
||||
OpOperand &forOperand = forOp.getOpOperandForResult(
|
||||
forOp->getResult(operand.getOperandNumber()));
|
||||
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
|
||||
// Note: This is overly strict. We should check for aliasing bufferized
|
||||
// values. But we don't have a "must-alias" analysis yet.
|
||||
if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
|
||||
// TODO: this could get resolved with copies but it can also turn into
|
||||
// swaps so we need to be careful about order of copies.
|
||||
status =
|
||||
yieldOp->emitError()
|
||||
<< "Yield operand #" << operand.getOperandNumber()
|
||||
<< " does not bufferize to a buffer that is aliasing the matching"
|
||||
<< " enclosing scf::for operand";
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addOpInterface<ExecuteRegionOp, ExecuteRegionOpInterface>();
|
||||
registry.addOpInterface<ForOp, ForOpInterface>();
|
||||
registry.addOpInterface<IfOp, IfOpInterface>();
|
||||
registry.addOpInterface<YieldOp, YieldOpInterface>();
|
||||
registry
|
||||
.addOpInterface<ParallelOp, AllocationHoistingBarrierOnly<ParallelOp>>();
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
add_mlir_dialect_library(MLIRSCFTransforms
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
Bufferize.cpp
|
||||
ForToWhile.cpp
|
||||
LoopCanonicalization.cpp
|
||||
|
@ -20,6 +21,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
|
|||
MLIRAffine
|
||||
MLIRAffineAnalysis
|
||||
MLIRArithmetic
|
||||
MLIRBufferization
|
||||
MLIRBufferizationTransforms
|
||||
MLIRDialectUtils
|
||||
MLIRIR
|
||||
|
|
|
@ -26,7 +26,7 @@ add_mlir_library(MLIRLinalgTestPasses
|
|||
MLIRMemRef
|
||||
MLIRPass
|
||||
MLIRSCF
|
||||
MLIRSCFBufferizableOpInterfaceImpl
|
||||
MLIRSCFTransforms
|
||||
MLIRStdBufferizableOpInterfaceImpl
|
||||
MLIRStandard
|
||||
MLIRTensor
|
||||
|
|
|
@ -18,11 +18,11 @@
|
|||
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
|
@ -61,7 +61,7 @@ struct TestComprehensiveFunctionBufferize
|
|||
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
scf::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
tensor::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
|
@ -106,7 +106,7 @@ void TestComprehensiveFunctionBufferize::runOnOperation() {
|
|||
auto options = std::make_unique<AnalysisBufferizationOptions>();
|
||||
|
||||
if (!allowReturnMemref)
|
||||
options->addPostAnalysisStep<scf_ext::AssertScfForAliasingProperties>();
|
||||
options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
|
||||
|
||||
options->allowReturnMemref = allowReturnMemref;
|
||||
options->allowUnknownOps = allowUnknownOps;
|
||||
|
|
|
@ -1782,6 +1782,7 @@ cc_library(
|
|||
"lib/Dialect/SCF/Transforms/*.h",
|
||||
]),
|
||||
hdrs = [
|
||||
"include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
|
||||
"include/mlir/Dialect/SCF/Passes.h",
|
||||
"include/mlir/Dialect/SCF/Transforms.h",
|
||||
],
|
||||
|
@ -2435,6 +2436,7 @@ cc_library(
|
|||
"include/mlir/Dialect/SCF/*.h",
|
||||
],
|
||||
exclude = [
|
||||
"include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
|
||||
"include/mlir/Dialect/SCF/Transforms.h",
|
||||
],
|
||||
),
|
||||
|
@ -6656,25 +6658,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "SCFBufferizableOpInterfaceImpl",
|
||||
srcs = [
|
||||
"lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp",
|
||||
],
|
||||
hdrs = [
|
||||
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":BufferizationDialect",
|
||||
":BufferizationTransforms",
|
||||
":IR",
|
||||
":SCFDialect",
|
||||
":Support",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "StdBufferizableOpInterfaceImpl",
|
||||
srcs = [
|
||||
|
@ -6928,7 +6911,6 @@ cc_library(
|
|||
":MemRefDialect",
|
||||
":ModuleBufferization",
|
||||
":Pass",
|
||||
":SCFBufferizableOpInterfaceImpl",
|
||||
":SCFDialect",
|
||||
":SCFTransforms",
|
||||
":SCFUtils",
|
||||
|
|
|
@ -400,7 +400,6 @@ cc_library(
|
|||
"//mlir:LinalgTransforms",
|
||||
"//mlir:MemRefDialect",
|
||||
"//mlir:Pass",
|
||||
"//mlir:SCFBufferizableOpInterfaceImpl",
|
||||
"//mlir:SCFDialect",
|
||||
"//mlir:SCFTransforms",
|
||||
"//mlir:StandardOps",
|
||||
|
|
Loading…
Reference in New Issue