[mlir][linalg][bufferize][NFC] Move arith interface impl to new build target

This makes ComprehensiveBufferize entirely independent of the arith dialect.

Differential Revision: https://reviews.llvm.org/D114219
This commit is contained in:
Matthias Springer 2021-11-25 10:06:16 +09:00
parent 7bd87a03fd
commit d3bb4fec2a
7 changed files with 138 additions and 50 deletions

View File

@ -0,0 +1,27 @@
//===- ArithInterfaceImpl.h - Arith Impl. of BufferizableOpInterface ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
namespace mlir {
class DialectRegistry;
namespace linalg {
namespace comprehensive_bufferize {
namespace arith_ext {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace arith_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H

View File

@ -0,0 +1,73 @@
//===- ArithInterfaceImpl.cpp - Arith Impl. of BufferizableOpInterface ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/BufferUtils.h"
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace arith_ext {
struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
return {};
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
if (!constantOp.getResult().getType().isa<TensorType>())
return success();
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
"not a constant ranked tensor");
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp) {
return constantOp.emitError(
"cannot bufferize constants not within builtin.module op");
}
GlobalCreator globalCreator(moduleOp);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(constantOp);
auto globalMemref = globalCreator.getGlobalFor(constantOp);
Value memref = b.create<memref::GetGlobalOp>(
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
state.mapBuffer(constantOp, memref);
return success();
}
bool isWritable(Operation *op, Value value) const {
// Memory locations returned by memref::GetGlobalOp may not be written to.
assert(value.isa<OpResult>());
return false;
}
};
} // namespace arith_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
void mlir::linalg::comprehensive_bufferize::arith_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
}

View File

@ -1,4 +1,5 @@
set(LLVM_OPTIONAL_SOURCES
ArithInterfaceImpl.cpp
BufferizableOpInterface.cpp
ComprehensiveBufferize.cpp
LinalgInterfaceImpl.cpp
@ -17,6 +18,17 @@ add_mlir_dialect_library(MLIRBufferizableOpInterface
MLIRMemRef
)
add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl
ArithInterfaceImpl.cpp
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRBufferizableOpInterface
MLIRIR
MLIRMemRef
MLIRStandardOpsTransforms
)
add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
LinalgInterfaceImpl.cpp

View File

@ -116,16 +116,17 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/BufferUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#define DEBUG_TYPE "comprehensive-module-bufferize"
@ -1287,52 +1288,6 @@ BufferizationOptions::BufferizationOptions()
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace arith_ext {
struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
OpResult opResult) const {
return {};
}
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
if (!isaTensor(constantOp.getResult().getType()))
return success();
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
"not a constant ranked tensor");
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp) {
return constantOp.emitError(
"cannot bufferize constants not within builtin.module op");
}
GlobalCreator globalCreator(moduleOp);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(constantOp);
auto globalMemref = globalCreator.getGlobalFor(constantOp);
Value memref = b.create<memref::GetGlobalOp>(
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
state.mapBuffer(constantOp, memref);
return success();
}
bool isWritable(Operation *op, Value value) const {
// Memory locations returned by memref::GetGlobalOp may not be written to.
assert(value.isa<OpResult>());
return false;
}
};
} // namespace arith_ext
namespace scf_ext {
struct ExecuteRegionOpInterface
@ -1813,7 +1768,6 @@ struct ReturnOpInterface
} // namespace std_ext
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
registry.addOpInterface<scf::ExecuteRegionOp,
scf_ext::ExecuteRegionOpInterface>();
registry.addOpInterface<scf::ForOp, scf_ext::ForOpInterface>();

View File

@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRAffine
MLIRAffineUtils
MLIRAnalysis
MLIRArithBufferizableOpInterfaceImpl
MLIRArithmetic
MLIRBufferizableOpInterface
MLIRComplex

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
@ -39,6 +40,7 @@ struct LinalgComprehensiveModuleBufferize
tensor::TensorDialect, vector::VectorDialect, scf::SCFDialect,
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
registerBufferizableOpInterfaceExternalModels(registry);
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);

View File

@ -6306,6 +6306,26 @@ cc_library(
],
)
cc_library(
name = "ArithBufferizableOpInterfaceImpl",
srcs = [
"lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp",
],
hdrs = [
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h",
],
includes = ["include"],
deps = [
":ArithmeticDialect",
":BufferizableOpInterface",
":IR",
":MemRefDialect",
":Support",
":TransformUtils",
"//llvm:Support",
],
)
cc_library(
name = "LinalgBufferizableOpInterfaceImpl",
srcs = [
@ -6563,6 +6583,7 @@ cc_library(
":Affine",
":AffineUtils",
":Analysis",
":ArithBufferizableOpInterfaceImpl",
":ArithmeticDialect",
":BufferizableOpInterface",
":ComplexDialect",
@ -6604,7 +6625,6 @@ cc_library(
includes = ["include"],
deps = [
":Affine",
":ArithmeticDialect",
":BufferizableOpInterface",
":DialectUtils",
":IR",
@ -6614,7 +6634,6 @@ cc_library(
":SCFDialect",
":StandardOps",
":Support",
":TransformUtils",
"//llvm:Support",
],
)