[mlir] Add alignment option to constant tensor bufferization pass

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D111364
This commit is contained in:
Eugene Zhulenev 2021-10-07 15:52:17 -07:00
parent 4b46a41343
commit e2a37bb540
5 changed files with 33 additions and 7 deletions

View File

@ -39,7 +39,7 @@ void populateTensorConstantBufferizePatterns(
RewritePatternSet &patterns);
/// Creates an instance of tensor constant bufferization pass.
std::unique_ptr<Pass> createTensorConstantBufferizePass();
std::unique_ptr<Pass> createTensorConstantBufferizePass(unsigned alignment = 0);
/// Creates an instance of the StdExpand pass that legalizes Std
/// dialect ops to be convertible to LLVM. For example,

View File

@ -62,6 +62,10 @@ def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
}];
let constructor = "mlir::createTensorConstantBufferizePass()";
let dependentDialects = ["memref::MemRefDialect"];
let options = [
Option<"alignment", "alignment", "unsigned", /*default=*/"0",
"Create global memrefs with a specified alignment">,
];
}
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES

View File

@ -125,11 +125,13 @@ class GlobalOp;
// names. Duplicates are avoided.
class GlobalCreator {
public:
explicit GlobalCreator(ModuleOp module) : moduleOp(module) {}
GlobalCreator(ModuleOp module, unsigned alignment = 0)
: moduleOp(module), alignment(alignment) {}
memref::GlobalOp getGlobalFor(ConstantOp constantOp);
private:
ModuleOp moduleOp;
unsigned alignment;
// This could use memref::GlobalOp key but we avoid introducing a new
// dependence to the memref dialect for this.
DenseMap<Attribute, Operation *> globals;

View File

@ -43,13 +43,18 @@ memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) {
interleave(type.getShape(), os, "x");
os << "x" << type.getElementType();
// Add an optional alignment to the global memref.
IntegerAttr memrefAlignment =
alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
: IntegerAttr();
auto global = globalBuilder.create<memref::GlobalOp>(
constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
/*sym_visibility=*/globalBuilder.getStringAttr("private"),
/*type=*/typeConverter.convertType(type).cast<MemRefType>(),
/*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
/*constant=*/true,
/*alignment=*/IntegerAttr());
/*alignment=*/memrefAlignment);
symbolTable.insert(global);
// The symbol table inserts at the end of the module, but globals are a bit
// nicer if they are at the beginning.
@ -90,11 +95,17 @@ void mlir::populateTensorConstantBufferizePatterns(
}
namespace {
struct TensorConstantBufferizePass
class TensorConstantBufferizePass
: public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
public:
explicit TensorConstantBufferizePass(unsigned alignment) {
if (alignment)
this->alignment = alignment;
}
void runOnOperation() override {
auto module = getOperation();
GlobalCreator globals(module);
GlobalCreator globals(module, alignment);
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
@ -111,6 +122,7 @@ struct TensorConstantBufferizePass
};
} // namespace
std::unique_ptr<Pass> mlir::createTensorConstantBufferizePass() {
return std::make_unique<TensorConstantBufferizePass>();
std::unique_ptr<Pass>
mlir::createTensorConstantBufferizePass(unsigned alignment) {
return std::make_unique<TensorConstantBufferizePass>(alignment);
}

View File

@ -1,9 +1,17 @@
// RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s
// RUN: mlir-opt %s -tensor-constant-bufferize=alignment=64 -split-input-file | FileCheck --check-prefix=ALIGNED %s
// CHECK-LABEL: module {
// We check the debug name too since we put some effort into making that readable.
// The name isn't load-bearing though.
// CHECK: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
// CHECK-NOT: alignment
// ALIGNED: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
// ALIGNED-SAME: {alignment = 64 : i64}
// CHECK: @basic
func @basic() -> tensor<3x4xf32> {
// CHECK: %[[MEMREF:.*]] = memref.get_global @__constant_3x4xf32 : memref<3x4xf32>