forked from OSchip/llvm-project
[mlir] Add alignment option to constant tensor bufferization pass
Reviewed By: bkramer Differential Revision: https://reviews.llvm.org/D111364
This commit is contained in:
parent
4b46a41343
commit
e2a37bb540
|
@ -39,7 +39,7 @@ void populateTensorConstantBufferizePatterns(
|
|||
RewritePatternSet &patterns);
|
||||
|
||||
/// Creates an instance of tensor constant bufferization pass.
|
||||
std::unique_ptr<Pass> createTensorConstantBufferizePass();
|
||||
std::unique_ptr<Pass> createTensorConstantBufferizePass(unsigned alignment = 0);
|
||||
|
||||
/// Creates an instance of the StdExpand pass that legalizes Std
|
||||
/// dialect ops to be convertible to LLVM. For example,
|
||||
|
|
|
@ -62,6 +62,10 @@ def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
|
|||
}];
|
||||
let constructor = "mlir::createTensorConstantBufferizePass()";
|
||||
let dependentDialects = ["memref::MemRefDialect"];
|
||||
let options = [
|
||||
Option<"alignment", "alignment", "unsigned", /*default=*/"0",
|
||||
"Create global memrefs with a specified alignment">,
|
||||
];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
|
||||
|
|
|
@ -125,11 +125,13 @@ class GlobalOp;
|
|||
// names. Duplicates are avoided.
|
||||
class GlobalCreator {
|
||||
public:
|
||||
explicit GlobalCreator(ModuleOp module) : moduleOp(module) {}
|
||||
GlobalCreator(ModuleOp module, unsigned alignment = 0)
|
||||
: moduleOp(module), alignment(alignment) {}
|
||||
memref::GlobalOp getGlobalFor(ConstantOp constantOp);
|
||||
|
||||
private:
|
||||
ModuleOp moduleOp;
|
||||
unsigned alignment;
|
||||
// This could use memref::GlobalOp key but we avoid introducing a new
|
||||
// dependence to the memref dialect for this.
|
||||
DenseMap<Attribute, Operation *> globals;
|
||||
|
|
|
@ -43,13 +43,18 @@ memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) {
|
|||
interleave(type.getShape(), os, "x");
|
||||
os << "x" << type.getElementType();
|
||||
|
||||
// Add an optional alignment to the global memref.
|
||||
IntegerAttr memrefAlignment =
|
||||
alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
|
||||
: IntegerAttr();
|
||||
|
||||
auto global = globalBuilder.create<memref::GlobalOp>(
|
||||
constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
|
||||
/*sym_visibility=*/globalBuilder.getStringAttr("private"),
|
||||
/*type=*/typeConverter.convertType(type).cast<MemRefType>(),
|
||||
/*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
|
||||
/*constant=*/true,
|
||||
/*alignment=*/IntegerAttr());
|
||||
/*alignment=*/memrefAlignment);
|
||||
symbolTable.insert(global);
|
||||
// The symbol table inserts at the end of the module, but globals are a bit
|
||||
// nicer if they are at the beginning.
|
||||
|
@ -90,11 +95,17 @@ void mlir::populateTensorConstantBufferizePatterns(
|
|||
}
|
||||
|
||||
namespace {
|
||||
struct TensorConstantBufferizePass
|
||||
class TensorConstantBufferizePass
|
||||
: public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
|
||||
public:
|
||||
explicit TensorConstantBufferizePass(unsigned alignment) {
|
||||
if (alignment)
|
||||
this->alignment = alignment;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
GlobalCreator globals(module);
|
||||
GlobalCreator globals(module, alignment);
|
||||
|
||||
auto *context = &getContext();
|
||||
BufferizeTypeConverter typeConverter;
|
||||
|
@ -111,6 +122,7 @@ struct TensorConstantBufferizePass
|
|||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTensorConstantBufferizePass() {
|
||||
return std::make_unique<TensorConstantBufferizePass>();
|
||||
std::unique_ptr<Pass>
|
||||
mlir::createTensorConstantBufferizePass(unsigned alignment) {
|
||||
return std::make_unique<TensorConstantBufferizePass>(alignment);
|
||||
}
|
||||
|
|
|
@ -1,9 +1,17 @@
|
|||
// RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -tensor-constant-bufferize=alignment=64 -split-input-file | FileCheck --check-prefix=ALIGNED %s
|
||||
|
||||
// CHECK-LABEL: module {
|
||||
|
||||
// We check the debug name too since we put some effort into making that readable.
|
||||
// The name isn't load-bearing though.
|
||||
|
||||
// CHECK: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
|
||||
// CHECK-NOT: alignment
|
||||
|
||||
// ALIGNED: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
|
||||
// ALIGNED-SAME: {alignment = 64 : i64}
|
||||
|
||||
// CHECK: @basic
|
||||
func @basic() -> tensor<3x4xf32> {
|
||||
// CHECK: %[[MEMREF:.*]] = memref.get_global @__constant_3x4xf32 : memref<3x4xf32>
|
||||
|
|
Loading…
Reference in New Issue