forked from OSchip/llvm-project
[mlir][bufferize] Move arith BufferizableOpInterface impl to arith dialect
Also switch the implementation of `-arith-bufferize` to BufferizableOpInterface. Differential Revision: https://reviews.llvm.org/D118325
This commit is contained in:
parent
ccce1a03c9
commit
075e3fdda1
|
@ -0,0 +1,21 @@
|
|||
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
#define MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class DialectRegistry;
|
||||
|
||||
namespace arith {
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
} // namespace arith
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H
|
|
@ -12,17 +12,8 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace bufferization {
|
||||
class BufferizeTypeConverter;
|
||||
} // namespace bufferization
|
||||
|
||||
namespace arith {
|
||||
|
||||
/// Add patterns to bufferize Arithmetic ops.
|
||||
void populateArithmeticBufferizePatterns(
|
||||
bufferization::BufferizeTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Create a pass to bufferize Arithmetic ops.
|
||||
std::unique_ptr<Pass> createArithmeticBufferizePass();
|
||||
|
||||
|
|
|
@ -14,8 +14,6 @@ include "mlir/Pass/PassBase.td"
|
|||
def ArithmeticBufferize : Pass<"arith-bufferize", "FuncOp"> {
|
||||
let summary = "Bufferize Arithmetic dialect ops.";
|
||||
let constructor = "mlir::arith::createArithmeticBufferizePass()";
|
||||
let dependentDialects = ["bufferization::BufferizationDialect",
|
||||
"memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def ArithmeticExpandOps : Pass<"arith-expand", "FuncOp"> {
|
||||
|
|
|
@ -1,27 +0,0 @@
|
|||
//===- ArithInterfaceImpl.h - Arith Impl. of BufferizableOpInterface ------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITHINTERFACEIMPL_H
|
||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITHINTERFACEIMPL_H
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class DialectRegistry;
|
||||
|
||||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
namespace arith_ext {
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
|
||||
} // namespace arith_ext
|
||||
} // namespace comprehensive_bufferize
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITHINTERFACEIMPL_H
|
|
@ -1,4 +1,4 @@
|
|||
//===- ArithInterfaceImpl.cpp - Arith Impl. of BufferizableOpInterface ----===//
|
||||
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -6,8 +6,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
|
||||
|
@ -18,9 +17,8 @@
|
|||
using namespace mlir::bufferization;
|
||||
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
namespace arith_ext {
|
||||
namespace arith {
|
||||
namespace {
|
||||
|
||||
/// Bufferization of arith.constant. Replace with memref.get_global.
|
||||
struct ConstantOpInterface
|
||||
|
@ -100,14 +98,13 @@ struct IndexCastOpInterface
|
|||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace arith_ext
|
||||
} // namespace comprehensive_bufferize
|
||||
} // namespace linalg
|
||||
|
||||
} // namespace
|
||||
} // namespace arith
|
||||
} // namespace mlir
|
||||
|
||||
void mlir::linalg::comprehensive_bufferize::arith_ext::
|
||||
registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
|
||||
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
|
||||
registry
|
||||
.addOpInterface<arith::IndexCastOp, arith_ext::IndexCastOpInterface>();
|
||||
void mlir::arith::registerBufferizableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addOpInterface<ConstantOp, ConstantOpInterface>();
|
||||
registry.addOpInterface<IndexCastOp, IndexCastOpInterface>();
|
||||
}
|
|
@ -8,61 +8,37 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Bufferize arith.index_cast.
|
||||
struct BufferizeIndexCastOp : public OpConversionPattern<arith::IndexCastOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto tensorType = op.getType().cast<RankedTensorType>();
|
||||
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
|
||||
op, adaptor.getIn(),
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Pass to bufferize Arithmetic ops.
|
||||
struct ArithmeticBufferizePass
|
||||
: public ArithmeticBufferizeBase<ArithmeticBufferizePass> {
|
||||
void runOnOperation() override {
|
||||
bufferization::BufferizeTypeConverter typeConverter;
|
||||
RewritePatternSet patterns(&getContext());
|
||||
ConversionTarget target(getContext());
|
||||
std::unique_ptr<BufferizationOptions> options =
|
||||
getPartialBufferizationOptions();
|
||||
options->addToDialectFilter<arith::ArithmeticDialect>();
|
||||
|
||||
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
|
||||
|
||||
arith::populateArithmeticBufferizePatterns(typeConverter, patterns);
|
||||
|
||||
target.addDynamicallyLegalOp<arith::IndexCastOp>(
|
||||
[&](arith::IndexCastOp op) {
|
||||
return typeConverter.isLegal(op.getType());
|
||||
});
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
if (failed(bufferizeOp(getOperation(), *options)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
|
||||
arith::ArithmeticDialect>();
|
||||
arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::arith::populateArithmeticBufferizePatterns(
|
||||
bufferization::BufferizeTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<BufferizeIndexCastOp>(typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::arith::createArithmeticBufferizePass() {
|
||||
return std::make_unique<ArithmeticBufferizePass>();
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
add_mlir_dialect_library(MLIRArithmeticTransforms
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
Bufferize.cpp
|
||||
ExpandOps.cpp
|
||||
|
||||
|
@ -10,6 +11,7 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
|
|||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArithmetic
|
||||
MLIRBufferization
|
||||
MLIRBufferizationTransforms
|
||||
MLIRIR
|
||||
MLIRMemRef
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
set(LLVM_OPTIONAL_SOURCES
|
||||
AffineInterfaceImpl.cpp
|
||||
ArithInterfaceImpl.cpp
|
||||
LinalgInterfaceImpl.cpp
|
||||
ModuleBufferization.cpp
|
||||
SCFInterfaceImpl.cpp
|
||||
|
@ -16,17 +15,6 @@ add_mlir_dialect_library(MLIRAffineBufferizableOpInterfaceImpl
|
|||
MLIRBufferization
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl
|
||||
ArithInterfaceImpl.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArithmetic
|
||||
MLIRBufferization
|
||||
MLIRIR
|
||||
MLIRMemRef
|
||||
MLIRStandardOpsTransforms
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
|
||||
LinalgInterfaceImpl.cpp
|
||||
|
||||
|
|
|
@ -34,8 +34,8 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
|||
MLIRAffineBufferizableOpInterfaceImpl
|
||||
MLIRAffineUtils
|
||||
MLIRAnalysis
|
||||
MLIRArithBufferizableOpInterfaceImpl
|
||||
MLIRArithmetic
|
||||
MLIRArithmeticTransforms
|
||||
MLIRBufferization
|
||||
MLIRComplex
|
||||
MLIRInferTypeOpInterface
|
||||
|
|
|
@ -8,11 +8,11 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
|
||||
|
@ -52,7 +52,7 @@ struct LinalgComprehensiveModuleBufferize
|
|||
vector::VectorDialect, scf::SCFDialect,
|
||||
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
|
||||
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
std_ext::registerModuleBufferizationExternalModels(registry);
|
||||
|
|
|
@ -96,19 +96,3 @@ func @rank_reducing(
|
|||
}
|
||||
return %5: tensor<?x1x6x8xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0)>
|
||||
// CHECK-LABEL: func @index_cast(
|
||||
// CHECK-SAME: %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
|
||||
func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
|
||||
%index_tensor = arith.index_cast %tensor : tensor<i32> to tensor<index>
|
||||
%index_scalar = arith.index_cast %scalar : i32 to index
|
||||
return %index_tensor, %index_scalar : tensor<index>, index
|
||||
}
|
||||
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32, #[[$MAP]]>
|
||||
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
|
||||
// CHECK-SAME: memref<i32, #[[$MAP]]> to memref<index, #[[$MAP]]>
|
||||
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
|
||||
// CHECK: return %[[INDEX_TENSOR]]
|
||||
|
|
|
@ -14,8 +14,8 @@ add_mlir_library(MLIRLinalgTestPasses
|
|||
LINK_LIBS PUBLIC
|
||||
MLIRAffine
|
||||
MLIRAffineBufferizableOpInterfaceImpl
|
||||
MLIRArithBufferizableOpInterfaceImpl
|
||||
MLIRArithmetic
|
||||
MLIRArithmeticTransforms
|
||||
MLIRBufferization
|
||||
MLIRBufferizationTransforms
|
||||
MLIRGPUTransforms
|
||||
|
|
|
@ -12,11 +12,11 @@
|
|||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
|
||||
|
@ -59,7 +59,7 @@ struct TestComprehensiveFunctionBufferize
|
|||
vector::VectorDialect, scf::SCFDialect, StandardOpsDialect,
|
||||
arith::ArithmeticDialect, AffineDialect>();
|
||||
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
|
|
|
@ -6580,27 +6580,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ArithBufferizableOpInterfaceImpl",
|
||||
srcs = [
|
||||
"lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp",
|
||||
],
|
||||
hdrs = [
|
||||
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":ArithmeticDialect",
|
||||
":BufferizationDialect",
|
||||
":BufferizationTransforms",
|
||||
":IR",
|
||||
":MemRefDialect",
|
||||
":Support",
|
||||
":TransformUtils",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "LinalgBufferizableOpInterfaceImpl",
|
||||
srcs = [
|
||||
|
@ -6876,8 +6855,8 @@ cc_library(
|
|||
":AffineBufferizableOpInterfaceImpl",
|
||||
":AffineUtils",
|
||||
":Analysis",
|
||||
":ArithBufferizableOpInterfaceImpl",
|
||||
":ArithmeticDialect",
|
||||
":ArithmeticTransforms",
|
||||
":BufferizationDialect",
|
||||
":BufferizationTransforms",
|
||||
":ComplexDialect",
|
||||
|
@ -7566,7 +7545,10 @@ cc_library(
|
|||
"lib/Dialect/Arithmetic/Transforms/*.cpp",
|
||||
"lib/Dialect/Arithmetic/Transforms/*.h",
|
||||
]),
|
||||
hdrs = ["include/mlir/Dialect/Arithmetic/Transforms/Passes.h"],
|
||||
hdrs = [
|
||||
"include/mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h",
|
||||
"include/mlir/Dialect/Arithmetic/Transforms/Passes.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":ArithmeticDialect",
|
||||
|
@ -7577,7 +7559,10 @@ cc_library(
|
|||
":MemRefDialect",
|
||||
":Pass",
|
||||
":StandardOps",
|
||||
":Support",
|
||||
":TransformUtils",
|
||||
":Transforms",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -389,8 +389,8 @@ cc_library(
|
|||
"//llvm:Support",
|
||||
"//mlir:Affine",
|
||||
"//mlir:AffineBufferizableOpInterfaceImpl",
|
||||
"//mlir:ArithBufferizableOpInterfaceImpl",
|
||||
"//mlir:ArithmeticDialect",
|
||||
"//mlir:ArithmeticTransforms",
|
||||
"//mlir:BufferizationDialect",
|
||||
"//mlir:BufferizationTransforms",
|
||||
"//mlir:GPUDialect",
|
||||
|
|
Loading…
Reference in New Issue