forked from OSchip/llvm-project
[mlir] Add alignment option to constant tensor bufferization pass
Reviewed By: bkramer Differential Revision: https://reviews.llvm.org/D111364
This commit is contained in:
parent
4b46a41343
commit
e2a37bb540
|
@ -39,7 +39,7 @@ void populateTensorConstantBufferizePatterns(
|
||||||
RewritePatternSet &patterns);
|
RewritePatternSet &patterns);
|
||||||
|
|
||||||
/// Creates an instance of tensor constant bufferization pass.
|
/// Creates an instance of tensor constant bufferization pass.
|
||||||
std::unique_ptr<Pass> createTensorConstantBufferizePass();
|
std::unique_ptr<Pass> createTensorConstantBufferizePass(unsigned alignment = 0);
|
||||||
|
|
||||||
/// Creates an instance of the StdExpand pass that legalizes Std
|
/// Creates an instance of the StdExpand pass that legalizes Std
|
||||||
/// dialect ops to be convertible to LLVM. For example,
|
/// dialect ops to be convertible to LLVM. For example,
|
||||||
|
|
|
@ -62,6 +62,10 @@ def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
|
||||||
}];
|
}];
|
||||||
let constructor = "mlir::createTensorConstantBufferizePass()";
|
let constructor = "mlir::createTensorConstantBufferizePass()";
|
||||||
let dependentDialects = ["memref::MemRefDialect"];
|
let dependentDialects = ["memref::MemRefDialect"];
|
||||||
|
let options = [
|
||||||
|
Option<"alignment", "alignment", "unsigned", /*default=*/"0",
|
||||||
|
"Create global memrefs with a specified alignment">,
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
|
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
|
||||||
|
|
|
@ -125,11 +125,13 @@ class GlobalOp;
|
||||||
// names. Duplicates are avoided.
|
// names. Duplicates are avoided.
|
||||||
class GlobalCreator {
|
class GlobalCreator {
|
||||||
public:
|
public:
|
||||||
explicit GlobalCreator(ModuleOp module) : moduleOp(module) {}
|
GlobalCreator(ModuleOp module, unsigned alignment = 0)
|
||||||
|
: moduleOp(module), alignment(alignment) {}
|
||||||
memref::GlobalOp getGlobalFor(ConstantOp constantOp);
|
memref::GlobalOp getGlobalFor(ConstantOp constantOp);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ModuleOp moduleOp;
|
ModuleOp moduleOp;
|
||||||
|
unsigned alignment;
|
||||||
// This could use memref::GlobalOp key but we avoid introducing a new
|
// This could use memref::GlobalOp key but we avoid introducing a new
|
||||||
// dependence to the memref dialect for this.
|
// dependence to the memref dialect for this.
|
||||||
DenseMap<Attribute, Operation *> globals;
|
DenseMap<Attribute, Operation *> globals;
|
||||||
|
|
|
@ -43,13 +43,18 @@ memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) {
|
||||||
interleave(type.getShape(), os, "x");
|
interleave(type.getShape(), os, "x");
|
||||||
os << "x" << type.getElementType();
|
os << "x" << type.getElementType();
|
||||||
|
|
||||||
|
// Add an optional alignment to the global memref.
|
||||||
|
IntegerAttr memrefAlignment =
|
||||||
|
alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
|
||||||
|
: IntegerAttr();
|
||||||
|
|
||||||
auto global = globalBuilder.create<memref::GlobalOp>(
|
auto global = globalBuilder.create<memref::GlobalOp>(
|
||||||
constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
|
constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
|
||||||
/*sym_visibility=*/globalBuilder.getStringAttr("private"),
|
/*sym_visibility=*/globalBuilder.getStringAttr("private"),
|
||||||
/*type=*/typeConverter.convertType(type).cast<MemRefType>(),
|
/*type=*/typeConverter.convertType(type).cast<MemRefType>(),
|
||||||
/*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
|
/*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
|
||||||
/*constant=*/true,
|
/*constant=*/true,
|
||||||
/*alignment=*/IntegerAttr());
|
/*alignment=*/memrefAlignment);
|
||||||
symbolTable.insert(global);
|
symbolTable.insert(global);
|
||||||
// The symbol table inserts at the end of the module, but globals are a bit
|
// The symbol table inserts at the end of the module, but globals are a bit
|
||||||
// nicer if they are at the beginning.
|
// nicer if they are at the beginning.
|
||||||
|
@ -90,11 +95,17 @@ void mlir::populateTensorConstantBufferizePatterns(
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
struct TensorConstantBufferizePass
|
class TensorConstantBufferizePass
|
||||||
: public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
|
: public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
|
||||||
|
public:
|
||||||
|
explicit TensorConstantBufferizePass(unsigned alignment) {
|
||||||
|
if (alignment)
|
||||||
|
this->alignment = alignment;
|
||||||
|
}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
auto module = getOperation();
|
auto module = getOperation();
|
||||||
GlobalCreator globals(module);
|
GlobalCreator globals(module, alignment);
|
||||||
|
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
BufferizeTypeConverter typeConverter;
|
BufferizeTypeConverter typeConverter;
|
||||||
|
@ -111,6 +122,7 @@ struct TensorConstantBufferizePass
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> mlir::createTensorConstantBufferizePass() {
|
std::unique_ptr<Pass>
|
||||||
return std::make_unique<TensorConstantBufferizePass>();
|
mlir::createTensorConstantBufferizePass(unsigned alignment) {
|
||||||
|
return std::make_unique<TensorConstantBufferizePass>(alignment);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,17 @@
|
||||||
// RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s
|
// RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s
|
||||||
|
// RUN: mlir-opt %s -tensor-constant-bufferize=alignment=64 -split-input-file | FileCheck --check-prefix=ALIGNED %s
|
||||||
|
|
||||||
// CHECK-LABEL: module {
|
// CHECK-LABEL: module {
|
||||||
|
|
||||||
// We check the debug name too since we put some effort into making that readable.
|
// We check the debug name too since we put some effort into making that readable.
|
||||||
// The name isn't load-bearing though.
|
// The name isn't load-bearing though.
|
||||||
|
|
||||||
// CHECK: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
|
// CHECK: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
|
||||||
|
// CHECK-NOT: alignment
|
||||||
|
|
||||||
|
// ALIGNED: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
|
||||||
|
// ALIGNED-SAME: {alignment = 64 : i64}
|
||||||
|
|
||||||
// CHECK: @basic
|
// CHECK: @basic
|
||||||
func @basic() -> tensor<3x4xf32> {
|
func @basic() -> tensor<3x4xf32> {
|
||||||
// CHECK: %[[MEMREF:.*]] = memref.get_global @__constant_3x4xf32 : memref<3x4xf32>
|
// CHECK: %[[MEMREF:.*]] = memref.get_global @__constant_3x4xf32 : memref<3x4xf32>
|
||||||
|
|
Loading…
Reference in New Issue