[mlir][bufferize] Move arith BufferizableOpInterface impl to arith dialect

Also switch the implementation of `-arith-bufferize` to BufferizableOpInterface.

Differential Revision: https://reviews.llvm.org/D118325
This commit is contained in:
Matthias Springer 2022-01-28 01:11:22 +09:00
parent ccce1a03c9
commit 075e3fdda1
15 changed files with 62 additions and 147 deletions

View File

@ -0,0 +1,21 @@
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H
#define MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
namespace arith {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace arith
} // namespace mlir
#endif // MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H

View File

@ -12,17 +12,8 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace bufferization {
class BufferizeTypeConverter;
} // namespace bufferization
namespace arith { namespace arith {
/// Add patterns to bufferize Arithmetic ops.
void populateArithmeticBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Create a pass to bufferize Arithmetic ops. /// Create a pass to bufferize Arithmetic ops.
std::unique_ptr<Pass> createArithmeticBufferizePass(); std::unique_ptr<Pass> createArithmeticBufferizePass();

View File

@ -14,8 +14,6 @@ include "mlir/Pass/PassBase.td"
def ArithmeticBufferize : Pass<"arith-bufferize", "FuncOp"> { def ArithmeticBufferize : Pass<"arith-bufferize", "FuncOp"> {
let summary = "Bufferize Arithmetic dialect ops."; let summary = "Bufferize Arithmetic dialect ops.";
let constructor = "mlir::arith::createArithmeticBufferizePass()"; let constructor = "mlir::arith::createArithmeticBufferizePass()";
let dependentDialects = ["bufferization::BufferizationDialect",
"memref::MemRefDialect"];
} }
def ArithmeticExpandOps : Pass<"arith-expand", "FuncOp"> { def ArithmeticExpandOps : Pass<"arith-expand", "FuncOp"> {

View File

@ -1,27 +0,0 @@
//===- ArithInterfaceImpl.h - Arith Impl. of BufferizableOpInterface ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITHINTERFACEIMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITHINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
namespace linalg {
namespace comprehensive_bufferize {
namespace arith_ext {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace arith_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITHINTERFACEIMPL_H

View File

@ -1,4 +1,4 @@
//===- ArithInterfaceImpl.cpp - Arith Impl. of BufferizableOpInterface ----===// //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
// //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. // See https://llvm.org/LICENSE.txt for license information.
@ -6,8 +6,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h" #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
@ -18,9 +17,8 @@
using namespace mlir::bufferization; using namespace mlir::bufferization;
namespace mlir { namespace mlir {
namespace linalg { namespace arith {
namespace comprehensive_bufferize { namespace {
namespace arith_ext {
/// Bufferization of arith.constant. Replace with memref.get_global. /// Bufferization of arith.constant. Replace with memref.get_global.
struct ConstantOpInterface struct ConstantOpInterface
@ -100,14 +98,13 @@ struct IndexCastOpInterface
return success(); return success();
} }
}; };
} // namespace arith_ext
} // namespace comprehensive_bufferize } // namespace
} // namespace linalg } // namespace arith
} // namespace mlir } // namespace mlir
void mlir::linalg::comprehensive_bufferize::arith_ext:: void mlir::arith::registerBufferizableOpInterfaceExternalModels(
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { DialectRegistry &registry) {
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>(); registry.addOpInterface<ConstantOp, ConstantOpInterface>();
registry registry.addOpInterface<IndexCastOp, IndexCastOpInterface>();
.addOpInterface<arith::IndexCastOp, arith_ext::IndexCastOpInterface>();
} }

View File

@ -8,61 +8,37 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h" #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
using namespace mlir; using namespace mlir;
using namespace bufferization;
namespace { namespace {
/// Bufferize arith.index_cast.
struct BufferizeIndexCastOp : public OpConversionPattern<arith::IndexCastOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto tensorType = op.getType().cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
op, adaptor.getIn(),
MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
return success();
}
};
/// Pass to bufferize Arithmetic ops. /// Pass to bufferize Arithmetic ops.
struct ArithmeticBufferizePass struct ArithmeticBufferizePass
: public ArithmeticBufferizeBase<ArithmeticBufferizePass> { : public ArithmeticBufferizeBase<ArithmeticBufferizePass> {
void runOnOperation() override { void runOnOperation() override {
bufferization::BufferizeTypeConverter typeConverter; std::unique_ptr<BufferizationOptions> options =
RewritePatternSet patterns(&getContext()); getPartialBufferizationOptions();
ConversionTarget target(getContext()); options->addToDialectFilter<arith::ArithmeticDialect>();
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>(); if (failed(bufferizeOp(getOperation(), *options)))
arith::populateArithmeticBufferizePatterns(typeConverter, patterns);
target.addDynamicallyLegalOp<arith::IndexCastOp>(
[&](arith::IndexCastOp op) {
return typeConverter.isLegal(op.getType());
});
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure(); signalPassFailure();
} }
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
arith::ArithmeticDialect>();
arith::registerBufferizableOpInterfaceExternalModels(registry);
}
}; };
} // namespace } // namespace
void mlir::arith::populateArithmeticBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeIndexCastOp>(typeConverter, patterns.getContext());
}
std::unique_ptr<Pass> mlir::arith::createArithmeticBufferizePass() { std::unique_ptr<Pass> mlir::arith::createArithmeticBufferizePass() {
return std::make_unique<ArithmeticBufferizePass>(); return std::make_unique<ArithmeticBufferizePass>();
} }

View File

@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRArithmeticTransforms add_mlir_dialect_library(MLIRArithmeticTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp Bufferize.cpp
ExpandOps.cpp ExpandOps.cpp
@ -10,6 +11,7 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRArithmetic MLIRArithmetic
MLIRBufferization
MLIRBufferizationTransforms MLIRBufferizationTransforms
MLIRIR MLIRIR
MLIRMemRef MLIRMemRef

View File

@ -1,6 +1,5 @@
set(LLVM_OPTIONAL_SOURCES set(LLVM_OPTIONAL_SOURCES
AffineInterfaceImpl.cpp AffineInterfaceImpl.cpp
ArithInterfaceImpl.cpp
LinalgInterfaceImpl.cpp LinalgInterfaceImpl.cpp
ModuleBufferization.cpp ModuleBufferization.cpp
SCFInterfaceImpl.cpp SCFInterfaceImpl.cpp
@ -16,17 +15,6 @@ add_mlir_dialect_library(MLIRAffineBufferizableOpInterfaceImpl
MLIRBufferization MLIRBufferization
) )
add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl
ArithInterfaceImpl.cpp
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRBufferization
MLIRIR
MLIRMemRef
MLIRStandardOpsTransforms
)
add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
LinalgInterfaceImpl.cpp LinalgInterfaceImpl.cpp

View File

@ -34,8 +34,8 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRAffineBufferizableOpInterfaceImpl MLIRAffineBufferizableOpInterfaceImpl
MLIRAffineUtils MLIRAffineUtils
MLIRAnalysis MLIRAnalysis
MLIRArithBufferizableOpInterfaceImpl
MLIRArithmetic MLIRArithmetic
MLIRArithmeticTransforms
MLIRBufferization MLIRBufferization
MLIRComplex MLIRComplex
MLIRInferTypeOpInterface MLIRInferTypeOpInterface

View File

@ -8,11 +8,11 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
@ -52,7 +52,7 @@ struct LinalgComprehensiveModuleBufferize
vector::VectorDialect, scf::SCFDialect, vector::VectorDialect, scf::SCFDialect,
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>(); arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
affine_ext::registerBufferizableOpInterfaceExternalModels(registry); affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith_ext::registerBufferizableOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry); scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerModuleBufferizationExternalModels(registry); std_ext::registerModuleBufferizationExternalModels(registry);

View File

@ -96,19 +96,3 @@ func @rank_reducing(
} }
return %5: tensor<?x1x6x8xf32> return %5: tensor<?x1x6x8xf32>
} }
// -----
// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-LABEL: func @index_cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
%index_tensor = arith.index_cast %tensor : tensor<i32> to tensor<index>
%index_scalar = arith.index_cast %scalar : i32 to index
return %index_tensor, %index_scalar : tensor<index>, index
}
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32, #[[$MAP]]>
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
// CHECK-SAME: memref<i32, #[[$MAP]]> to memref<index, #[[$MAP]]>
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
// CHECK: return %[[INDEX_TENSOR]]

View File

@ -14,8 +14,8 @@ add_mlir_library(MLIRLinalgTestPasses
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRAffine MLIRAffine
MLIRAffineBufferizableOpInterfaceImpl MLIRAffineBufferizableOpInterfaceImpl
MLIRArithBufferizableOpInterfaceImpl
MLIRArithmetic MLIRArithmetic
MLIRArithmeticTransforms
MLIRBufferization MLIRBufferization
MLIRBufferizationTransforms MLIRBufferizationTransforms
MLIRGPUTransforms MLIRGPUTransforms

View File

@ -12,11 +12,11 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
@ -59,7 +59,7 @@ struct TestComprehensiveFunctionBufferize
vector::VectorDialect, scf::SCFDialect, StandardOpsDialect, vector::VectorDialect, scf::SCFDialect, StandardOpsDialect,
arith::ArithmeticDialect, AffineDialect>(); arith::ArithmeticDialect, AffineDialect>();
affine_ext::registerBufferizableOpInterfaceExternalModels(registry); affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith_ext::registerBufferizableOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry); scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerBufferizableOpInterfaceExternalModels(registry);

View File

@ -6580,27 +6580,6 @@ cc_library(
], ],
) )
cc_library(
name = "ArithBufferizableOpInterfaceImpl",
srcs = [
"lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp",
],
hdrs = [
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h",
],
includes = ["include"],
deps = [
":ArithmeticDialect",
":BufferizationDialect",
":BufferizationTransforms",
":IR",
":MemRefDialect",
":Support",
":TransformUtils",
"//llvm:Support",
],
)
cc_library( cc_library(
name = "LinalgBufferizableOpInterfaceImpl", name = "LinalgBufferizableOpInterfaceImpl",
srcs = [ srcs = [
@ -6876,8 +6855,8 @@ cc_library(
":AffineBufferizableOpInterfaceImpl", ":AffineBufferizableOpInterfaceImpl",
":AffineUtils", ":AffineUtils",
":Analysis", ":Analysis",
":ArithBufferizableOpInterfaceImpl",
":ArithmeticDialect", ":ArithmeticDialect",
":ArithmeticTransforms",
":BufferizationDialect", ":BufferizationDialect",
":BufferizationTransforms", ":BufferizationTransforms",
":ComplexDialect", ":ComplexDialect",
@ -7566,7 +7545,10 @@ cc_library(
"lib/Dialect/Arithmetic/Transforms/*.cpp", "lib/Dialect/Arithmetic/Transforms/*.cpp",
"lib/Dialect/Arithmetic/Transforms/*.h", "lib/Dialect/Arithmetic/Transforms/*.h",
]), ]),
hdrs = ["include/mlir/Dialect/Arithmetic/Transforms/Passes.h"], hdrs = [
"include/mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h",
"include/mlir/Dialect/Arithmetic/Transforms/Passes.h",
],
includes = ["include"], includes = ["include"],
deps = [ deps = [
":ArithmeticDialect", ":ArithmeticDialect",
@ -7577,7 +7559,10 @@ cc_library(
":MemRefDialect", ":MemRefDialect",
":Pass", ":Pass",
":StandardOps", ":StandardOps",
":Support",
":TransformUtils",
":Transforms", ":Transforms",
"//llvm:Support",
], ],
) )

View File

@ -389,8 +389,8 @@ cc_library(
"//llvm:Support", "//llvm:Support",
"//mlir:Affine", "//mlir:Affine",
"//mlir:AffineBufferizableOpInterfaceImpl", "//mlir:AffineBufferizableOpInterfaceImpl",
"//mlir:ArithBufferizableOpInterfaceImpl",
"//mlir:ArithmeticDialect", "//mlir:ArithmeticDialect",
"//mlir:ArithmeticTransforms",
"//mlir:BufferizationDialect", "//mlir:BufferizationDialect",
"//mlir:BufferizationTransforms", "//mlir:BufferizationTransforms",
"//mlir:GPUDialect", "//mlir:GPUDialect",