From e448c793c66521ee48d0107c33b80a2ff1baaaaf Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 30 Jan 2022 22:02:22 +0900 Subject: [PATCH] [mlir][bufferize][NFC] Move std BufferizableOpInterfaceImpl to std dialect Also reimplement `std-bufferize` in terms of BufferizableOpInterface-based bufferization. The old `std.select` bufferization pattern is no longer needed and deleted. Differential Revision: https://reviews.llvm.org/D118559 --- .../ComprehensiveBufferize/StdInterfaceImpl.h | 27 --------- .../Transforms/BufferizableOpInterfaceImpl.h | 18 ++++++ .../Dialect/StandardOps/Transforms/Passes.h | 4 -- .../Dialect/StandardOps/Transforms/Passes.td | 2 - .../ComprehensiveBufferize/CMakeLists.txt | 8 --- .../Dialect/Linalg/Transforms/CMakeLists.txt | 1 - .../Transforms/ComprehensiveBufferizePass.cpp | 4 +- .../BufferizableOpInterfaceImpl.cpp} | 19 +++---- .../StandardOps/Transforms/Bufferize.cpp | 56 +++++-------------- .../StandardOps/Transforms/CMakeLists.txt | 2 + mlir/test/lib/Dialect/Linalg/CMakeLists.txt | 2 +- .../Linalg/TestComprehensiveBufferize.cpp | 4 +- .../llvm-project-overlay/mlir/BUILD.bazel | 19 ------- .../mlir/test/BUILD.bazel | 2 +- 14 files changed, 47 insertions(+), 121 deletions(-) delete mode 100644 mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h create mode 100644 mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h rename mlir/lib/Dialect/{Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp => StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp} (83%) diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h deleted file mode 100644 index ae3b3db23e64..000000000000 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h +++ /dev/null @@ -1,27 +0,0 @@ -//===- StdInterfaceImpl.h - Standard Impl. of BufferizableOpInterface- ----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H -#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H - -namespace mlir { - -class DialectRegistry; - -namespace linalg { -namespace comprehensive_bufferize { -namespace std_ext { - -void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); - -} // namespace std_ext -} // namespace comprehensive_bufferize -} // namespace linalg -} // namespace mlir - -#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 index 000000000000..a85acbbb195a --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,18 @@ +//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H +#define MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); +} // namespace mlir + +#endif // MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h index 52bbea000d1f..d6b8d2028e0e 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -23,10 +23,6 @@ class BufferizeTypeConverter; class RewritePatternSet; -void populateStdBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns); - /// Creates an instance of std bufferization pass. std::unique_ptr createStdBufferizePass(); diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td index 6bd83938346e..3e08865c6f71 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -14,8 +14,6 @@ include "mlir/Pass/PassBase.td" def StdBufferize : Pass<"std-bufferize", "FuncOp"> { let summary = "Bufferize the std dialect"; let constructor = "mlir::createStdBufferizePass()"; - let dependentDialects = ["bufferization::BufferizationDialect", - "memref::MemRefDialect", "scf::SCFDialect"]; } def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt index 10d0f72ebcfe..5733f88c953a 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -25,14 +25,6 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl MLIRTensor ) -add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl - StdInterfaceImpl.cpp - - LINK_LIBS PUBLIC - MLIRBufferization - MLIRStandard -) - add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl VectorInterfaceImpl.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 5a42474d49da..d1418c40f035 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -49,7 +49,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms MLIRSCF MLIRSCFTransforms MLIRSCFUtils - MLIRStdBufferizableOpInterfaceImpl MLIRPass MLIRStandard MLIRStandardOpsTransforms diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp index 9d2e6a539a18..f809bf35dc6f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -15,10 +15,10 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -56,7 +56,7 @@ struct LinalgComprehensiveModuleBufferize linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerModuleBufferizationExternalModels(registry); - std_ext::registerBufferizableOpInterfaceExternalModels(registry); + mlir::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); vector_ext::registerBufferizableOpInterfaceExternalModels(registry); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp b/mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp similarity index 83% rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp rename to mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp index 7941c979b09e..b89a5372a48b 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1,4 +1,4 @@ -//===- StdInterfaceImpl.cpp - Standard Impl. of BufferizableOpInterface ---===// +//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,19 +6,18 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" +#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +using namespace mlir; using namespace mlir::bufferization; namespace mlir { -namespace linalg { -namespace comprehensive_bufferize { -namespace std_ext { +namespace { /// Bufferization of std.select. Just replace the operands. struct SelectOpInterface @@ -69,12 +68,10 @@ struct SelectOpInterface } }; -} // namespace std_ext -} // namespace comprehensive_bufferize -} // namespace linalg +} // namespace } // namespace mlir -void mlir::linalg::comprehensive_bufferize::std_ext:: - registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { - registry.addOpInterface(); +void mlir::registerBufferizableOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addOpInterface(); } diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp index 9851784d4951..64f9d040a71c 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -12,64 +12,34 @@ #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "PassDetail.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/Transforms/DialectConversion.h" using namespace mlir; - -namespace { -class BufferizeSelectOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getCondition().getType().isa()) - return rewriter.notifyMatchFailure(op, "requires scalar condition"); - - rewriter.replaceOpWithNewOp(op, adaptor.getCondition(), - adaptor.getTrueValue(), - adaptor.getFalseValue()); - return success(); - } -}; -} // namespace - -void mlir::populateStdBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add(typeConverter, patterns.getContext()); -} +using namespace mlir::bufferization; namespace { struct StdBufferizePass : public StdBufferizeBase { void runOnOperation() override { - auto *context = &getContext(); - bufferization::BufferizeTypeConverter typeConverter; - RewritePatternSet patterns(context); - ConversionTarget target(*context); + std::unique_ptr options = + getPartialBufferizationOptions(); + options->addToDialectFilter(); - target.addLegalDialect(); - - populateStdBufferizePatterns(typeConverter, patterns); - // We only bufferize the case of tensor selected type and scalar condition, - // as that boils down to a select over memref descriptors (don't need to - // touch the data). - target.addDynamicallyLegalOp([&](SelectOp op) { - return typeConverter.isLegal(op.getType()) || - !op.getCondition().getType().isa(); - }); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + if (failed(bufferizeOp(getOperation(), *options))) signalPassFailure(); } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + mlir::registerBufferizableOpInterfaceExternalModels(registry); + } }; } // namespace diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt index d5869ce207cf..7db425fdc361 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms + BufferizableOpInterfaceImpl.cpp Bufferize.cpp DecomposeCallGraphTypes.cpp FuncBufferize.cpp @@ -13,6 +14,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms LINK_LIBS PUBLIC MLIRAffine MLIRArithmeticTransforms + MLIRBufferization MLIRBufferizationTransforms MLIRIR MLIRMemRef diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt index 7ad34948bceb..4fa607092180 100644 --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -27,8 +27,8 @@ add_mlir_library(MLIRLinalgTestPasses MLIRPass MLIRSCF MLIRSCFTransforms - MLIRStdBufferizableOpInterfaceImpl MLIRStandard + MLIRStandardOpsTransforms MLIRTensor MLIRTensorTransforms MLIRTransformUtils diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp index 5cde7cf2ac09..b074043b7a3a 100644 --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -18,12 +18,12 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Pass/PassManager.h" @@ -62,7 +62,7 @@ struct TestComprehensiveFunctionBufferize arith::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry); - std_ext::registerBufferizableOpInterfaceExternalModels(registry); + mlir::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); vector_ext::registerBufferizableOpInterfaceExternalModels(registry); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 185998a0bf94..e786e82d4f5e 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6658,24 +6658,6 @@ cc_library( ], ) -cc_library( - name = "StdBufferizableOpInterfaceImpl", - srcs = [ - "lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp", - ], - hdrs = [ - "include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h", - ], - includes = ["include"], - deps = [ - ":BufferizationDialect", - ":IR", - ":StandardOps", - ":Support", - "//llvm:Support", - ], -) - cc_library( name = "VectorBufferizableOpInterfaceImpl", srcs = [ @@ -6916,7 +6898,6 @@ cc_library( ":SCFUtils", ":StandardOps", ":StandardOpsTransforms", - ":StdBufferizableOpInterfaceImpl", ":Support", ":TensorDialect", ":TensorTransforms", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel index d23ca654fd09..4292c258c051 100644 --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -403,7 +403,7 @@ cc_library( "//mlir:SCFDialect", "//mlir:SCFTransforms", "//mlir:StandardOps", - "//mlir:StdBufferizableOpInterfaceImpl", + "//mlir:StandardOpsTransforms", "//mlir:TensorDialect", "//mlir:TensorTransforms", "//mlir:TransformUtils",