[mlir][bufferize] Move arith BufferizableOpInterface impl to arith dialect

Also switch the implementation of `-arith-bufferize` to BufferizableOpInterface.

Differential Revision: https://reviews.llvm.org/D118325
This commit is contained in:
Matthias Springer 2022-01-28 01:11:22 +09:00
parent ccce1a03c9
commit 075e3fdda1
15 changed files with 62 additions and 147 deletions

View File

@ -0,0 +1,21 @@
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H
#define MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
namespace arith {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace arith
} // namespace mlir
#endif // MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H

View File

@ -12,17 +12,8 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace bufferization {
class BufferizeTypeConverter;
} // namespace bufferization
namespace arith {
/// Add patterns to bufferize Arithmetic ops.
void populateArithmeticBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Create a pass to bufferize Arithmetic ops.
std::unique_ptr<Pass> createArithmeticBufferizePass();

View File

@ -14,8 +14,6 @@ include "mlir/Pass/PassBase.td"
def ArithmeticBufferize : Pass<"arith-bufferize", "FuncOp"> {
let summary = "Bufferize Arithmetic dialect ops.";
let constructor = "mlir::arith::createArithmeticBufferizePass()";
let dependentDialects = ["bufferization::BufferizationDialect",
"memref::MemRefDialect"];
}
def ArithmeticExpandOps : Pass<"arith-expand", "FuncOp"> {

View File

@ -1,27 +0,0 @@
//===- ArithInterfaceImpl.h - Arith Impl. of BufferizableOpInterface ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITHINTERFACEIMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITHINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
namespace linalg {
namespace comprehensive_bufferize {
namespace arith_ext {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace arith_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITHINTERFACEIMPL_H

View File

@ -1,4 +1,4 @@
//===- ArithInterfaceImpl.cpp - Arith Impl. of BufferizableOpInterface ----===//
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,8 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
@ -18,9 +17,8 @@
using namespace mlir::bufferization;
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace arith_ext {
namespace arith {
namespace {
/// Bufferization of arith.constant. Replace with memref.get_global.
struct ConstantOpInterface
@ -100,14 +98,13 @@ struct IndexCastOpInterface
return success();
}
};
} // namespace arith_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace
} // namespace arith
} // namespace mlir
void mlir::linalg::comprehensive_bufferize::arith_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
registry
.addOpInterface<arith::IndexCastOp, arith_ext::IndexCastOpInterface>();
void mlir::arith::registerBufferizableOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addOpInterface<ConstantOp, ConstantOpInterface>();
registry.addOpInterface<IndexCastOp, IndexCastOpInterface>();
}

View File

@ -8,61 +8,37 @@
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
using namespace mlir;
using namespace bufferization;
namespace {
/// Bufferize arith.index_cast.
struct BufferizeIndexCastOp : public OpConversionPattern<arith::IndexCastOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto tensorType = op.getType().cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
op, adaptor.getIn(),
MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
return success();
}
};
/// Pass to bufferize Arithmetic ops.
struct ArithmeticBufferizePass
: public ArithmeticBufferizeBase<ArithmeticBufferizePass> {
void runOnOperation() override {
bufferization::BufferizeTypeConverter typeConverter;
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
std::unique_ptr<BufferizationOptions> options =
getPartialBufferizationOptions();
options->addToDialectFilter<arith::ArithmeticDialect>();
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
arith::populateArithmeticBufferizePatterns(typeConverter, patterns);
target.addDynamicallyLegalOp<arith::IndexCastOp>(
[&](arith::IndexCastOp op) {
return typeConverter.isLegal(op.getType());
});
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
if (failed(bufferizeOp(getOperation(), *options)))
signalPassFailure();
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
arith::ArithmeticDialect>();
arith::registerBufferizableOpInterfaceExternalModels(registry);
}
};
} // namespace
void mlir::arith::populateArithmeticBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeIndexCastOp>(typeConverter, patterns.getContext());
}
std::unique_ptr<Pass> mlir::arith::createArithmeticBufferizePass() {
return std::make_unique<ArithmeticBufferizePass>();
}

View File

@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRArithmeticTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ExpandOps.cpp
@ -10,6 +11,7 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRBufferization
MLIRBufferizationTransforms
MLIRIR
MLIRMemRef

View File

@ -1,6 +1,5 @@
set(LLVM_OPTIONAL_SOURCES
AffineInterfaceImpl.cpp
ArithInterfaceImpl.cpp
LinalgInterfaceImpl.cpp
ModuleBufferization.cpp
SCFInterfaceImpl.cpp
@ -16,17 +15,6 @@ add_mlir_dialect_library(MLIRAffineBufferizableOpInterfaceImpl
MLIRBufferization
)
add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl
ArithInterfaceImpl.cpp
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRBufferization
MLIRIR
MLIRMemRef
MLIRStandardOpsTransforms
)
add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
LinalgInterfaceImpl.cpp

View File

@ -34,8 +34,8 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRAffineBufferizableOpInterfaceImpl
MLIRAffineUtils
MLIRAnalysis
MLIRArithBufferizableOpInterfaceImpl
MLIRArithmetic
MLIRArithmeticTransforms
MLIRBufferization
MLIRComplex
MLIRInferTypeOpInterface

View File

@ -8,11 +8,11 @@
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
@ -52,7 +52,7 @@ struct LinalgComprehensiveModuleBufferize
vector::VectorDialect, scf::SCFDialect,
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerModuleBufferizationExternalModels(registry);

View File

@ -96,19 +96,3 @@ func @rank_reducing(
}
return %5: tensor<?x1x6x8xf32>
}
// -----
// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0)>
// CHECK-LABEL: func @index_cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
%index_tensor = arith.index_cast %tensor : tensor<i32> to tensor<index>
%index_scalar = arith.index_cast %scalar : i32 to index
return %index_tensor, %index_scalar : tensor<index>, index
}
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32, #[[$MAP]]>
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
// CHECK-SAME: memref<i32, #[[$MAP]]> to memref<index, #[[$MAP]]>
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
// CHECK: return %[[INDEX_TENSOR]]

View File

@ -14,8 +14,8 @@ add_mlir_library(MLIRLinalgTestPasses
LINK_LIBS PUBLIC
MLIRAffine
MLIRAffineBufferizableOpInterfaceImpl
MLIRArithBufferizableOpInterfaceImpl
MLIRArithmetic
MLIRArithmeticTransforms
MLIRBufferization
MLIRBufferizationTransforms
MLIRGPUTransforms

View File

@ -12,11 +12,11 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
@ -59,7 +59,7 @@ struct TestComprehensiveFunctionBufferize
vector::VectorDialect, scf::SCFDialect, StandardOpsDialect,
arith::ArithmeticDialect, AffineDialect>();
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerBufferizableOpInterfaceExternalModels(registry);

View File

@ -6580,27 +6580,6 @@ cc_library(
],
)
cc_library(
name = "ArithBufferizableOpInterfaceImpl",
srcs = [
"lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp",
],
hdrs = [
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h",
],
includes = ["include"],
deps = [
":ArithmeticDialect",
":BufferizationDialect",
":BufferizationTransforms",
":IR",
":MemRefDialect",
":Support",
":TransformUtils",
"//llvm:Support",
],
)
cc_library(
name = "LinalgBufferizableOpInterfaceImpl",
srcs = [
@ -6876,8 +6855,8 @@ cc_library(
":AffineBufferizableOpInterfaceImpl",
":AffineUtils",
":Analysis",
":ArithBufferizableOpInterfaceImpl",
":ArithmeticDialect",
":ArithmeticTransforms",
":BufferizationDialect",
":BufferizationTransforms",
":ComplexDialect",
@ -7566,7 +7545,10 @@ cc_library(
"lib/Dialect/Arithmetic/Transforms/*.cpp",
"lib/Dialect/Arithmetic/Transforms/*.h",
]),
hdrs = ["include/mlir/Dialect/Arithmetic/Transforms/Passes.h"],
hdrs = [
"include/mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h",
"include/mlir/Dialect/Arithmetic/Transforms/Passes.h",
],
includes = ["include"],
deps = [
":ArithmeticDialect",
@ -7577,7 +7559,10 @@ cc_library(
":MemRefDialect",
":Pass",
":StandardOps",
":Support",
":TransformUtils",
":Transforms",
"//llvm:Support",
],
)

View File

@ -389,8 +389,8 @@ cc_library(
"//llvm:Support",
"//mlir:Affine",
"//mlir:AffineBufferizableOpInterfaceImpl",
"//mlir:ArithBufferizableOpInterfaceImpl",
"//mlir:ArithmeticDialect",
"//mlir:ArithmeticTransforms",
"//mlir:BufferizationDialect",
"//mlir:BufferizationTransforms",
"//mlir:GPUDialect",