[mlir] Add alignment option to constant tensor bufferization pass

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D111364
This commit is contained in:
Eugene Zhulenev 2021-10-07 15:52:17 -07:00
parent 4b46a41343
commit e2a37bb540
5 changed files with 33 additions and 7 deletions

View File

@ -39,7 +39,7 @@ void populateTensorConstantBufferizePatterns(
RewritePatternSet &patterns); RewritePatternSet &patterns);
/// Creates an instance of tensor constant bufferization pass. /// Creates an instance of tensor constant bufferization pass.
std::unique_ptr<Pass> createTensorConstantBufferizePass(); std::unique_ptr<Pass> createTensorConstantBufferizePass(unsigned alignment = 0);
/// Creates an instance of the StdExpand pass that legalizes Std /// Creates an instance of the StdExpand pass that legalizes Std
/// dialect ops to be convertible to LLVM. For example, /// dialect ops to be convertible to LLVM. For example,

View File

@ -62,6 +62,10 @@ def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
}]; }];
let constructor = "mlir::createTensorConstantBufferizePass()"; let constructor = "mlir::createTensorConstantBufferizePass()";
let dependentDialects = ["memref::MemRefDialect"]; let dependentDialects = ["memref::MemRefDialect"];
let options = [
Option<"alignment", "alignment", "unsigned", /*default=*/"0",
"Create global memrefs with a specified alignment">,
];
} }
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES

View File

@ -125,11 +125,13 @@ class GlobalOp;
// names. Duplicates are avoided. // names. Duplicates are avoided.
class GlobalCreator { class GlobalCreator {
public: public:
explicit GlobalCreator(ModuleOp module) : moduleOp(module) {} GlobalCreator(ModuleOp module, unsigned alignment = 0)
: moduleOp(module), alignment(alignment) {}
memref::GlobalOp getGlobalFor(ConstantOp constantOp); memref::GlobalOp getGlobalFor(ConstantOp constantOp);
private: private:
ModuleOp moduleOp; ModuleOp moduleOp;
unsigned alignment;
// This could use memref::GlobalOp key but we avoid introducing a new // This could use memref::GlobalOp key but we avoid introducing a new
// dependence to the memref dialect for this. // dependence to the memref dialect for this.
DenseMap<Attribute, Operation *> globals; DenseMap<Attribute, Operation *> globals;

View File

@ -43,13 +43,18 @@ memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) {
interleave(type.getShape(), os, "x"); interleave(type.getShape(), os, "x");
os << "x" << type.getElementType(); os << "x" << type.getElementType();
// Add an optional alignment to the global memref.
IntegerAttr memrefAlignment =
alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
: IntegerAttr();
auto global = globalBuilder.create<memref::GlobalOp>( auto global = globalBuilder.create<memref::GlobalOp>(
constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
/*sym_visibility=*/globalBuilder.getStringAttr("private"), /*sym_visibility=*/globalBuilder.getStringAttr("private"),
/*type=*/typeConverter.convertType(type).cast<MemRefType>(), /*type=*/typeConverter.convertType(type).cast<MemRefType>(),
/*initial_value=*/constantOp.getValue().cast<ElementsAttr>(), /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
/*constant=*/true, /*constant=*/true,
/*alignment=*/IntegerAttr()); /*alignment=*/memrefAlignment);
symbolTable.insert(global); symbolTable.insert(global);
// The symbol table inserts at the end of the module, but globals are a bit // The symbol table inserts at the end of the module, but globals are a bit
// nicer if they are at the beginning. // nicer if they are at the beginning.
@ -90,11 +95,17 @@ void mlir::populateTensorConstantBufferizePatterns(
} }
namespace { namespace {
struct TensorConstantBufferizePass class TensorConstantBufferizePass
: public TensorConstantBufferizeBase<TensorConstantBufferizePass> { : public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
public:
explicit TensorConstantBufferizePass(unsigned alignment) {
if (alignment)
this->alignment = alignment;
}
void runOnOperation() override { void runOnOperation() override {
auto module = getOperation(); auto module = getOperation();
GlobalCreator globals(module); GlobalCreator globals(module, alignment);
auto *context = &getContext(); auto *context = &getContext();
BufferizeTypeConverter typeConverter; BufferizeTypeConverter typeConverter;
@ -111,6 +122,7 @@ struct TensorConstantBufferizePass
}; };
} // namespace } // namespace
std::unique_ptr<Pass> mlir::createTensorConstantBufferizePass() { std::unique_ptr<Pass>
return std::make_unique<TensorConstantBufferizePass>(); mlir::createTensorConstantBufferizePass(unsigned alignment) {
return std::make_unique<TensorConstantBufferizePass>(alignment);
} }

View File

@ -1,9 +1,17 @@
// RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s // RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s
// RUN: mlir-opt %s -tensor-constant-bufferize=alignment=64 -split-input-file | FileCheck --check-prefix=ALIGNED %s
// CHECK-LABEL: module { // CHECK-LABEL: module {
// We check the debug name too since we put some effort into making that readable. // We check the debug name too since we put some effort into making that readable.
// The name isn't load-bearing though. // The name isn't load-bearing though.
// CHECK: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00> // CHECK: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
// CHECK-NOT: alignment
// ALIGNED: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
// ALIGNED-SAME: {alignment = 64 : i64}
// CHECK: @basic // CHECK: @basic
func @basic() -> tensor<3x4xf32> { func @basic() -> tensor<3x4xf32> {
// CHECK: %[[MEMREF:.*]] = memref.get_global @__constant_3x4xf32 : memref<3x4xf32> // CHECK: %[[MEMREF:.*]] = memref.get_global @__constant_3x4xf32 : memref<3x4xf32>