[mlir][bufferize][NFC] Move std BufferizableOpInterfaceImpl to std dialect

Also reimplement `std-bufferize` in terms of BufferizableOpInterface-based bufferization. The old `std.select` bufferization pattern is no longer needed and deleted.

Differential Revision: https://reviews.llvm.org/D118559
This commit is contained in:
Matthias Springer 2022-01-30 22:02:22 +09:00
parent 8f12175fed
commit e448c793c6
14 changed files with 47 additions and 121 deletions

View File

@ -1,27 +0,0 @@
//===- StdInterfaceImpl.h - Standard Impl. of BufferizableOpInterface- ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
namespace mlir {
class DialectRegistry;
namespace linalg {
namespace comprehensive_bufferize {
namespace std_ext {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace std_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H

View File

@ -0,0 +1,18 @@
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H
#define MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace mlir
#endif // MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H

View File

@ -23,10 +23,6 @@ class BufferizeTypeConverter;
class RewritePatternSet; class RewritePatternSet;
void populateStdBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Creates an instance of std bufferization pass. /// Creates an instance of std bufferization pass.
std::unique_ptr<Pass> createStdBufferizePass(); std::unique_ptr<Pass> createStdBufferizePass();

View File

@ -14,8 +14,6 @@ include "mlir/Pass/PassBase.td"
def StdBufferize : Pass<"std-bufferize", "FuncOp"> { def StdBufferize : Pass<"std-bufferize", "FuncOp"> {
let summary = "Bufferize the std dialect"; let summary = "Bufferize the std dialect";
let constructor = "mlir::createStdBufferizePass()"; let constructor = "mlir::createStdBufferizePass()";
let dependentDialects = ["bufferization::BufferizationDialect",
"memref::MemRefDialect", "scf::SCFDialect"];
} }
def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> { def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {

View File

@ -25,14 +25,6 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
MLIRTensor MLIRTensor
) )
add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
StdInterfaceImpl.cpp
LINK_LIBS PUBLIC
MLIRBufferization
MLIRStandard
)
add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl
VectorInterfaceImpl.cpp VectorInterfaceImpl.cpp

View File

@ -49,7 +49,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRSCF MLIRSCF
MLIRSCFTransforms MLIRSCFTransforms
MLIRSCFUtils MLIRSCFUtils
MLIRStdBufferizableOpInterfaceImpl
MLIRPass MLIRPass
MLIRStandard MLIRStandard
MLIRStandardOpsTransforms MLIRStandardOpsTransforms

View File

@ -15,10 +15,10 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
@ -56,7 +56,7 @@ struct LinalgComprehensiveModuleBufferize
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerModuleBufferizationExternalModels(registry); std_ext::registerModuleBufferizationExternalModels(registry);
std_ext::registerBufferizableOpInterfaceExternalModels(registry); mlir::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry); vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
} }

View File

@ -1,4 +1,4 @@
//===- StdInterfaceImpl.cpp - Standard Impl. of BufferizableOpInterface ---===// //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@ -6,19 +6,18 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" #include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
using namespace mlir;
using namespace mlir::bufferization; using namespace mlir::bufferization;
namespace mlir { namespace mlir {
namespace linalg { namespace {
namespace comprehensive_bufferize {
namespace std_ext {
/// Bufferization of std.select. Just replace the operands. /// Bufferization of std.select. Just replace the operands.
struct SelectOpInterface struct SelectOpInterface
@ -69,12 +68,10 @@ struct SelectOpInterface
} }
}; };
} // namespace std_ext } // namespace
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir } // namespace mlir
void mlir::linalg::comprehensive_bufferize::std_ext:: void mlir::registerBufferizableOpInterfaceExternalModels(
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { DialectRegistry &registry) {
registry.addOpInterface<SelectOp, std_ext::SelectOpInterface>(); registry.addOpInterface<SelectOp, SelectOpInterface>();
} }

View File

@ -12,64 +12,34 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir; using namespace mlir;
using namespace mlir::bufferization;
namespace {
class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!op.getCondition().getType().isa<IntegerType>())
return rewriter.notifyMatchFailure(op, "requires scalar condition");
rewriter.replaceOpWithNewOp<SelectOp>(op, adaptor.getCondition(),
adaptor.getTrueValue(),
adaptor.getFalseValue());
return success();
}
};
} // namespace
void mlir::populateStdBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeSelectOp>(typeConverter, patterns.getContext());
}
namespace { namespace {
struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> { struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
void runOnOperation() override { void runOnOperation() override {
auto *context = &getContext(); std::unique_ptr<BufferizationOptions> options =
bufferization::BufferizeTypeConverter typeConverter; getPartialBufferizationOptions();
RewritePatternSet patterns(context); options->addToDialectFilter<StandardOpsDialect>();
ConversionTarget target(*context);
target.addLegalDialect<scf::SCFDialect, StandardOpsDialect, if (failed(bufferizeOp(getOperation(), *options)))
memref::MemRefDialect>();
populateStdBufferizePatterns(typeConverter, patterns);
// We only bufferize the case of tensor selected type and scalar condition,
// as that boils down to a select over memref descriptors (don't need to
// touch the data).
target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
return typeConverter.isLegal(op.getType()) ||
!op.getCondition().getType().isa<IntegerType>();
});
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure(); signalPassFailure();
} }
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
StandardOpsDialect, scf::SCFDialect>();
mlir::registerBufferizableOpInterfaceExternalModels(registry);
}
}; };
} // namespace } // namespace

View File

@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRStandardOpsTransforms add_mlir_dialect_library(MLIRStandardOpsTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp Bufferize.cpp
DecomposeCallGraphTypes.cpp DecomposeCallGraphTypes.cpp
FuncBufferize.cpp FuncBufferize.cpp
@ -13,6 +14,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRAffine MLIRAffine
MLIRArithmeticTransforms MLIRArithmeticTransforms
MLIRBufferization
MLIRBufferizationTransforms MLIRBufferizationTransforms
MLIRIR MLIRIR
MLIRMemRef MLIRMemRef

View File

@ -27,8 +27,8 @@ add_mlir_library(MLIRLinalgTestPasses
MLIRPass MLIRPass
MLIRSCF MLIRSCF
MLIRSCFTransforms MLIRSCFTransforms
MLIRStdBufferizableOpInterfaceImpl
MLIRStandard MLIRStandard
MLIRStandardOpsTransforms
MLIRTensor MLIRTensor
MLIRTensorTransforms MLIRTensorTransforms
MLIRTransformUtils MLIRTransformUtils

View File

@ -18,12 +18,12 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
@ -62,7 +62,7 @@ struct TestComprehensiveFunctionBufferize
arith::registerBufferizableOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerBufferizableOpInterfaceExternalModels(registry); mlir::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry); vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
} }

View File

@ -6658,24 +6658,6 @@ cc_library(
], ],
) )
cc_library(
name = "StdBufferizableOpInterfaceImpl",
srcs = [
"lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp",
],
hdrs = [
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h",
],
includes = ["include"],
deps = [
":BufferizationDialect",
":IR",
":StandardOps",
":Support",
"//llvm:Support",
],
)
cc_library( cc_library(
name = "VectorBufferizableOpInterfaceImpl", name = "VectorBufferizableOpInterfaceImpl",
srcs = [ srcs = [
@ -6916,7 +6898,6 @@ cc_library(
":SCFUtils", ":SCFUtils",
":StandardOps", ":StandardOps",
":StandardOpsTransforms", ":StandardOpsTransforms",
":StdBufferizableOpInterfaceImpl",
":Support", ":Support",
":TensorDialect", ":TensorDialect",
":TensorTransforms", ":TensorTransforms",

View File

@ -403,7 +403,7 @@ cc_library(
"//mlir:SCFDialect", "//mlir:SCFDialect",
"//mlir:SCFTransforms", "//mlir:SCFTransforms",
"//mlir:StandardOps", "//mlir:StandardOps",
"//mlir:StdBufferizableOpInterfaceImpl", "//mlir:StandardOpsTransforms",
"//mlir:TensorDialect", "//mlir:TensorDialect",
"//mlir:TensorTransforms", "//mlir:TensorTransforms",
"//mlir:TransformUtils", "//mlir:TransformUtils",