forked from OSchip/llvm-project
[mlir][bufferize][NFC] Move std BufferizableOpInterfaceImpl to std dialect
Also reimplement `std-bufferize` in terms of BufferizableOpInterface-based bufferization. The old `std.select` bufferization pattern is no longer needed and deleted. Differential Revision: https://reviews.llvm.org/D118559
This commit is contained in:
parent
8f12175fed
commit
e448c793c6
|
@ -1,27 +0,0 @@
|
|||
//===- StdInterfaceImpl.h - Standard Impl. of BufferizableOpInterface- ----===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
|
||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class DialectRegistry;
|
||||
|
||||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
namespace std_ext {
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
|
||||
} // namespace std_ext
|
||||
} // namespace comprehensive_bufferize
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
|
|
@ -0,0 +1,18 @@
|
|||
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
#define MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H
|
|
@ -23,10 +23,6 @@ class BufferizeTypeConverter;
|
|||
|
||||
class RewritePatternSet;
|
||||
|
||||
void populateStdBufferizePatterns(
|
||||
bufferization::BufferizeTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Creates an instance of std bufferization pass.
|
||||
std::unique_ptr<Pass> createStdBufferizePass();
|
||||
|
||||
|
|
|
@ -14,8 +14,6 @@ include "mlir/Pass/PassBase.td"
|
|||
def StdBufferize : Pass<"std-bufferize", "FuncOp"> {
|
||||
let summary = "Bufferize the std dialect";
|
||||
let constructor = "mlir::createStdBufferizePass()";
|
||||
let dependentDialects = ["bufferization::BufferizationDialect",
|
||||
"memref::MemRefDialect", "scf::SCFDialect"];
|
||||
}
|
||||
|
||||
def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
|
||||
|
|
|
@ -25,14 +25,6 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
|
|||
MLIRTensor
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
|
||||
StdInterfaceImpl.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRBufferization
|
||||
MLIRStandard
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl
|
||||
VectorInterfaceImpl.cpp
|
||||
|
||||
|
|
|
@ -49,7 +49,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
|||
MLIRSCF
|
||||
MLIRSCFTransforms
|
||||
MLIRSCFUtils
|
||||
MLIRStdBufferizableOpInterfaceImpl
|
||||
MLIRPass
|
||||
MLIRStandard
|
||||
MLIRStandardOpsTransforms
|
||||
|
|
|
@ -15,10 +15,10 @@
|
|||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
@ -56,7 +56,7 @@ struct LinalgComprehensiveModuleBufferize
|
|||
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
scf::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
std_ext::registerModuleBufferizationExternalModels(registry);
|
||||
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
mlir::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
tensor::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- StdInterfaceImpl.cpp - Standard Impl. of BufferizableOpInterface ---===//
|
||||
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -6,19 +6,18 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::bufferization;
|
||||
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
namespace std_ext {
|
||||
namespace {
|
||||
|
||||
/// Bufferization of std.select. Just replace the operands.
|
||||
struct SelectOpInterface
|
||||
|
@ -69,12 +68,10 @@ struct SelectOpInterface
|
|||
}
|
||||
};
|
||||
|
||||
} // namespace std_ext
|
||||
} // namespace comprehensive_bufferize
|
||||
} // namespace linalg
|
||||
} // namespace
|
||||
} // namespace mlir
|
||||
|
||||
void mlir::linalg::comprehensive_bufferize::std_ext::
|
||||
registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
|
||||
registry.addOpInterface<SelectOp, std_ext::SelectOpInterface>();
|
||||
void mlir::registerBufferizableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addOpInterface<SelectOp, SelectOpInterface>();
|
||||
}
|
|
@ -12,64 +12,34 @@
|
|||
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(SelectOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!op.getCondition().getType().isa<IntegerType>())
|
||||
return rewriter.notifyMatchFailure(op, "requires scalar condition");
|
||||
|
||||
rewriter.replaceOpWithNewOp<SelectOp>(op, adaptor.getCondition(),
|
||||
adaptor.getTrueValue(),
|
||||
adaptor.getFalseValue());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::populateStdBufferizePatterns(
|
||||
bufferization::BufferizeTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<BufferizeSelectOp>(typeConverter, patterns.getContext());
|
||||
}
|
||||
using namespace mlir::bufferization;
|
||||
|
||||
namespace {
|
||||
struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
|
||||
void runOnOperation() override {
|
||||
auto *context = &getContext();
|
||||
bufferization::BufferizeTypeConverter typeConverter;
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
std::unique_ptr<BufferizationOptions> options =
|
||||
getPartialBufferizationOptions();
|
||||
options->addToDialectFilter<StandardOpsDialect>();
|
||||
|
||||
target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
|
||||
memref::MemRefDialect>();
|
||||
|
||||
populateStdBufferizePatterns(typeConverter, patterns);
|
||||
// We only bufferize the case of tensor selected type and scalar condition,
|
||||
// as that boils down to a select over memref descriptors (don't need to
|
||||
// touch the data).
|
||||
target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
|
||||
return typeConverter.isLegal(op.getType()) ||
|
||||
!op.getCondition().getType().isa<IntegerType>();
|
||||
});
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
if (failed(bufferizeOp(getOperation(), *options)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
|
||||
StandardOpsDialect, scf::SCFDialect>();
|
||||
mlir::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
add_mlir_dialect_library(MLIRStandardOpsTransforms
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
Bufferize.cpp
|
||||
DecomposeCallGraphTypes.cpp
|
||||
FuncBufferize.cpp
|
||||
|
@ -13,6 +14,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
|
|||
LINK_LIBS PUBLIC
|
||||
MLIRAffine
|
||||
MLIRArithmeticTransforms
|
||||
MLIRBufferization
|
||||
MLIRBufferizationTransforms
|
||||
MLIRIR
|
||||
MLIRMemRef
|
||||
|
|
|
@ -27,8 +27,8 @@ add_mlir_library(MLIRLinalgTestPasses
|
|||
MLIRPass
|
||||
MLIRSCF
|
||||
MLIRSCFTransforms
|
||||
MLIRStdBufferizableOpInterfaceImpl
|
||||
MLIRStandard
|
||||
MLIRStandardOpsTransforms
|
||||
MLIRTensor
|
||||
MLIRTensorTransforms
|
||||
MLIRTransformUtils
|
||||
|
|
|
@ -18,12 +18,12 @@
|
|||
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
@ -62,7 +62,7 @@ struct TestComprehensiveFunctionBufferize
|
|||
arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
scf::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
mlir::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
tensor::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
}
|
||||
|
|
|
@ -6658,24 +6658,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "StdBufferizableOpInterfaceImpl",
|
||||
srcs = [
|
||||
"lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp",
|
||||
],
|
||||
hdrs = [
|
||||
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":BufferizationDialect",
|
||||
":IR",
|
||||
":StandardOps",
|
||||
":Support",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "VectorBufferizableOpInterfaceImpl",
|
||||
srcs = [
|
||||
|
@ -6916,7 +6898,6 @@ cc_library(
|
|||
":SCFUtils",
|
||||
":StandardOps",
|
||||
":StandardOpsTransforms",
|
||||
":StdBufferizableOpInterfaceImpl",
|
||||
":Support",
|
||||
":TensorDialect",
|
||||
":TensorTransforms",
|
||||
|
|
|
@ -403,7 +403,7 @@ cc_library(
|
|||
"//mlir:SCFDialect",
|
||||
"//mlir:SCFTransforms",
|
||||
"//mlir:StandardOps",
|
||||
"//mlir:StdBufferizableOpInterfaceImpl",
|
||||
"//mlir:StandardOpsTransforms",
|
||||
"//mlir:TensorDialect",
|
||||
"//mlir:TensorTransforms",
|
||||
"//mlir:TransformUtils",
|
||||
|
|
Loading…
Reference in New Issue