forked from OSchip/llvm-project
[mlir][linalg][bufferize][NFC] Move arith interface impl to new build target
This makes ComprehensiveBufferize entirely independent of the arith dialect. Differential Revision: https://reviews.llvm.org/D114219
This commit is contained in:
parent
7bd87a03fd
commit
d3bb4fec2a
|
@ -0,0 +1,27 @@
|
|||
//===- ArithInterfaceImpl.h - Arith Impl. of BufferizableOpInterface ------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
|
||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class DialectRegistry;
|
||||
|
||||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
namespace arith_ext {
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
|
||||
} // namespace arith_ext
|
||||
} // namespace comprehensive_bufferize
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
|
|
@ -0,0 +1,73 @@
|
|||
//===- ArithInterfaceImpl.cpp - Arith Impl. of BufferizableOpInterface ----===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
|
||||
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Transforms/BufferUtils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
namespace arith_ext {
|
||||
|
||||
struct ConstantOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
|
||||
arith::ConstantOp> {
|
||||
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
|
||||
OpResult opResult) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, OpBuilder &b,
|
||||
BufferizationState &state) const {
|
||||
auto constantOp = cast<arith::ConstantOp>(op);
|
||||
if (!constantOp.getResult().getType().isa<TensorType>())
|
||||
return success();
|
||||
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
|
||||
"not a constant ranked tensor");
|
||||
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp) {
|
||||
return constantOp.emitError(
|
||||
"cannot bufferize constants not within builtin.module op");
|
||||
}
|
||||
GlobalCreator globalCreator(moduleOp);
|
||||
|
||||
// Take a guard before anything else.
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
b.setInsertionPoint(constantOp);
|
||||
|
||||
auto globalMemref = globalCreator.getGlobalFor(constantOp);
|
||||
Value memref = b.create<memref::GetGlobalOp>(
|
||||
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
|
||||
state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
|
||||
state.mapBuffer(constantOp, memref);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
bool isWritable(Operation *op, Value value) const {
|
||||
// Memory locations returned by memref::GetGlobalOp may not be written to.
|
||||
assert(value.isa<OpResult>());
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace arith_ext
|
||||
} // namespace comprehensive_bufferize
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
void mlir::linalg::comprehensive_bufferize::arith_ext::
|
||||
registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
|
||||
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
set(LLVM_OPTIONAL_SOURCES
|
||||
ArithInterfaceImpl.cpp
|
||||
BufferizableOpInterface.cpp
|
||||
ComprehensiveBufferize.cpp
|
||||
LinalgInterfaceImpl.cpp
|
||||
|
@ -17,6 +18,17 @@ add_mlir_dialect_library(MLIRBufferizableOpInterface
|
|||
MLIRMemRef
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl
|
||||
ArithInterfaceImpl.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArithmetic
|
||||
MLIRBufferizableOpInterface
|
||||
MLIRIR
|
||||
MLIRMemRef
|
||||
MLIRStandardOpsTransforms
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
|
||||
LinalgInterfaceImpl.cpp
|
||||
|
||||
|
|
|
@ -116,16 +116,17 @@
|
|||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Dominance.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/BufferUtils.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#define DEBUG_TYPE "comprehensive-module-bufferize"
|
||||
|
@ -1287,52 +1288,6 @@ BufferizationOptions::BufferizationOptions()
|
|||
namespace mlir {
|
||||
namespace linalg {
|
||||
namespace comprehensive_bufferize {
|
||||
namespace arith_ext {
|
||||
|
||||
struct ConstantOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
|
||||
arith::ConstantOp> {
|
||||
SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
|
||||
OpResult opResult) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, OpBuilder &b,
|
||||
BufferizationState &state) const {
|
||||
auto constantOp = cast<arith::ConstantOp>(op);
|
||||
if (!isaTensor(constantOp.getResult().getType()))
|
||||
return success();
|
||||
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
|
||||
"not a constant ranked tensor");
|
||||
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp) {
|
||||
return constantOp.emitError(
|
||||
"cannot bufferize constants not within builtin.module op");
|
||||
}
|
||||
GlobalCreator globalCreator(moduleOp);
|
||||
|
||||
// Take a guard before anything else.
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
b.setInsertionPoint(constantOp);
|
||||
|
||||
auto globalMemref = globalCreator.getGlobalFor(constantOp);
|
||||
Value memref = b.create<memref::GetGlobalOp>(
|
||||
constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
|
||||
state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
|
||||
state.mapBuffer(constantOp, memref);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
bool isWritable(Operation *op, Value value) const {
|
||||
// Memory locations returned by memref::GetGlobalOp may not be written to.
|
||||
assert(value.isa<OpResult>());
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace arith_ext
|
||||
|
||||
namespace scf_ext {
|
||||
|
||||
struct ExecuteRegionOpInterface
|
||||
|
@ -1813,7 +1768,6 @@ struct ReturnOpInterface
|
|||
} // namespace std_ext
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
|
||||
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
|
||||
registry.addOpInterface<scf::ExecuteRegionOp,
|
||||
scf_ext::ExecuteRegionOpInterface>();
|
||||
registry.addOpInterface<scf::ForOp, scf_ext::ForOpInterface>();
|
||||
|
|
|
@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
|||
MLIRAffine
|
||||
MLIRAffineUtils
|
||||
MLIRAnalysis
|
||||
MLIRArithBufferizableOpInterfaceImpl
|
||||
MLIRArithmetic
|
||||
MLIRBufferizableOpInterface
|
||||
MLIRComplex
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
|
||||
|
@ -39,6 +40,7 @@ struct LinalgComprehensiveModuleBufferize
|
|||
tensor::TensorDialect, vector::VectorDialect, scf::SCFDialect,
|
||||
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
|
||||
registerBufferizableOpInterfaceExternalModels(registry);
|
||||
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
|
|
|
@ -6306,6 +6306,26 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ArithBufferizableOpInterfaceImpl",
|
||||
srcs = [
|
||||
"lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp",
|
||||
],
|
||||
hdrs = [
|
||||
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":ArithmeticDialect",
|
||||
":BufferizableOpInterface",
|
||||
":IR",
|
||||
":MemRefDialect",
|
||||
":Support",
|
||||
":TransformUtils",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "LinalgBufferizableOpInterfaceImpl",
|
||||
srcs = [
|
||||
|
@ -6563,6 +6583,7 @@ cc_library(
|
|||
":Affine",
|
||||
":AffineUtils",
|
||||
":Analysis",
|
||||
":ArithBufferizableOpInterfaceImpl",
|
||||
":ArithmeticDialect",
|
||||
":BufferizableOpInterface",
|
||||
":ComplexDialect",
|
||||
|
@ -6604,7 +6625,6 @@ cc_library(
|
|||
includes = ["include"],
|
||||
deps = [
|
||||
":Affine",
|
||||
":ArithmeticDialect",
|
||||
":BufferizableOpInterface",
|
||||
":DialectUtils",
|
||||
":IR",
|
||||
|
@ -6614,7 +6634,6 @@ cc_library(
|
|||
":SCFDialect",
|
||||
":StandardOps",
|
||||
":Support",
|
||||
":TransformUtils",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue