[mlir][bufferize][NFC] Move std BufferizableOpInterfaceImpl to std dialect

Also reimplement `std-bufferize` in terms of BufferizableOpInterface-based bufferization. The old `std.select` bufferization pattern is no longer needed and deleted.

Differential Revision: https://reviews.llvm.org/D118559
This commit is contained in:
Matthias Springer 2022-01-30 22:02:22 +09:00
parent 8f12175fed
commit e448c793c6
14 changed files with 47 additions and 121 deletions
mlir
include/mlir/Dialect
Linalg/ComprehensiveBufferize
StandardOps/Transforms
lib/Dialect
test/lib/Dialect/Linalg
utils/bazel/llvm-project-overlay/mlir

View File

@ -1,27 +0,0 @@
//===- StdInterfaceImpl.h - Standard Impl. of BufferizableOpInterface- ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
namespace mlir {
class DialectRegistry;
namespace linalg {
namespace comprehensive_bufferize {
namespace std_ext {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace std_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H

View File

@ -0,0 +1,18 @@
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H
#define MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace mlir
#endif // MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H

View File

@ -23,10 +23,6 @@ class BufferizeTypeConverter;
class RewritePatternSet;
void populateStdBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Creates an instance of std bufferization pass.
std::unique_ptr<Pass> createStdBufferizePass();

View File

@ -14,8 +14,6 @@ include "mlir/Pass/PassBase.td"
def StdBufferize : Pass<"std-bufferize", "FuncOp"> {
let summary = "Bufferize the std dialect";
let constructor = "mlir::createStdBufferizePass()";
let dependentDialects = ["bufferization::BufferizationDialect",
"memref::MemRefDialect", "scf::SCFDialect"];
}
def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {

View File

@ -25,14 +25,6 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
MLIRTensor
)
add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
StdInterfaceImpl.cpp
LINK_LIBS PUBLIC
MLIRBufferization
MLIRStandard
)
add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl
VectorInterfaceImpl.cpp

View File

@ -49,7 +49,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRSCF
MLIRSCFTransforms
MLIRSCFUtils
MLIRStdBufferizableOpInterfaceImpl
MLIRPass
MLIRStandard
MLIRStandardOpsTransforms

View File

@ -15,10 +15,10 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@ -56,7 +56,7 @@ struct LinalgComprehensiveModuleBufferize
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerModuleBufferizationExternalModels(registry);
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
mlir::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
}

View File

@ -1,4 +1,4 @@
//===- StdInterfaceImpl.cpp - Standard Impl. of BufferizableOpInterface ---===//
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,19 +6,18 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
using namespace mlir;
using namespace mlir::bufferization;
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace std_ext {
namespace {
/// Bufferization of std.select. Just replace the operands.
struct SelectOpInterface
@ -69,12 +68,10 @@ struct SelectOpInterface
}
};
} // namespace std_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace
} // namespace mlir
void mlir::linalg::comprehensive_bufferize::std_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<SelectOp, std_ext::SelectOpInterface>();
void mlir::registerBufferizableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addOpInterface<SelectOp, SelectOpInterface>();
}

View File

@ -12,64 +12,34 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!op.getCondition().getType().isa<IntegerType>())
return rewriter.notifyMatchFailure(op, "requires scalar condition");
rewriter.replaceOpWithNewOp<SelectOp>(op, adaptor.getCondition(),
adaptor.getTrueValue(),
adaptor.getFalseValue());
return success();
}
};
} // namespace
void mlir::populateStdBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeSelectOp>(typeConverter, patterns.getContext());
}
using namespace mlir::bufferization;
namespace {
struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
void runOnOperation() override {
auto *context = &getContext();
bufferization::BufferizeTypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
std::unique_ptr<BufferizationOptions> options =
getPartialBufferizationOptions();
options->addToDialectFilter<StandardOpsDialect>();
target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
memref::MemRefDialect>();
populateStdBufferizePatterns(typeConverter, patterns);
// We only bufferize the case of tensor selected type and scalar condition,
// as that boils down to a select over memref descriptors (don't need to
// touch the data).
target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
return typeConverter.isLegal(op.getType()) ||
!op.getCondition().getType().isa<IntegerType>();
});
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
if (failed(bufferizeOp(getOperation(), *options)))
signalPassFailure();
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
StandardOpsDialect, scf::SCFDialect>();
mlir::registerBufferizableOpInterfaceExternalModels(registry);
}
};
} // namespace

View File

@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRStandardOpsTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
DecomposeCallGraphTypes.cpp
FuncBufferize.cpp
@ -13,6 +14,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
LINK_LIBS PUBLIC
MLIRAffine
MLIRArithmeticTransforms
MLIRBufferization
MLIRBufferizationTransforms
MLIRIR
MLIRMemRef

View File

@ -27,8 +27,8 @@ add_mlir_library(MLIRLinalgTestPasses
MLIRPass
MLIRSCF
MLIRSCFTransforms
MLIRStdBufferizableOpInterfaceImpl
MLIRStandard
MLIRStandardOpsTransforms
MLIRTensor
MLIRTensorTransforms
MLIRTransformUtils

View File

@ -18,12 +18,12 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/PassManager.h"
@ -62,7 +62,7 @@ struct TestComprehensiveFunctionBufferize
arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
mlir::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
}

View File

@ -6658,24 +6658,6 @@ cc_library(
],
)
cc_library(
name = "StdBufferizableOpInterfaceImpl",
srcs = [
"lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp",
],
hdrs = [
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h",
],
includes = ["include"],
deps = [
":BufferizationDialect",
":IR",
":StandardOps",
":Support",
"//llvm:Support",
],
)
cc_library(
name = "VectorBufferizableOpInterfaceImpl",
srcs = [
@ -6916,7 +6898,6 @@ cc_library(
":SCFUtils",
":StandardOps",
":StandardOpsTransforms",
":StdBufferizableOpInterfaceImpl",
":Support",
":TensorDialect",
":TensorTransforms",

View File

@ -403,7 +403,7 @@ cc_library(
"//mlir:SCFDialect",
"//mlir:SCFTransforms",
"//mlir:StandardOps",
"//mlir:StdBufferizableOpInterfaceImpl",
"//mlir:StandardOpsTransforms",
"//mlir:TensorDialect",
"//mlir:TensorTransforms",
"//mlir:TransformUtils",