From f68939d3d91c3e1b57fba5450fa9146c3dcf5fdc Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Wed, 15 Sep 2021 16:58:19 +0530 Subject: [PATCH] [MLIR] Tighten type constraint on memref.global op def Tighten the def of memref.global op to use the right kind of TypeAttr (of MemRefType). Differential Revision: https://reviews.llvm.org/D109822 --- .../mlir/Dialect/MemRef/IR/MemRefOps.td | 24 ++++++++++++------- .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 2 +- .../Transforms/TensorConstantBufferize.cpp | 2 +- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index c6c3ac8e1266..e0cb3816efaf 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -18,6 +18,12 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/SymbolInterfaces.td" +/// A TypeAttr for memref types. +def MemRefTypeAttr + : TypeAttrBase<"::mlir::MemRefType", "memref type attribute"> { + let constBuilderCall = "::mlir::TypeAttr::get($0)"; +} + class MemRef_Op traits = []> : Op { let printer = [{ return ::print(p, *this); }]; @@ -597,14 +603,14 @@ def MemRef_GetGlobalOp : MemRef_Op<"get_global", def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> { let summary = "declare or define a global memref variable"; let description = [{ - The `memref.global` operation declares or defines a named global variable. - The backing memory for the variable is allocated statically and is described - by the type of the variable (which should be a statically shaped memref - type). The operation is a declaration if no `inital_value` is specified, - else it is a definition. The `initial_value` can either be a unit attribute - to represent a definition of an uninitialized global variable, or an - elements attribute to represent the definition of a global variable with an - initial value. The global variable can also be marked constant using the + The `memref.global` operation declares or defines a named global memref + variable. The backing memory for the variable is allocated statically and is + described by the type of the variable (which should be a statically shaped + memref type). The operation is a declaration if no `inital_value` is + specified, else it is a definition. The `initial_value` can either be a unit + attribute to represent a definition of an uninitialized global variable, or + an elements attribute to represent the definition of a global variable with + an initial value. The global variable can also be marked constant using the `constant` unit attribute. Writing to such constant global variables is undefined. @@ -633,7 +639,7 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> { let arguments = (ins SymbolNameAttr:$sym_name, OptionalAttr:$sym_visibility, - TypeAttr:$type, + MemRefTypeAttr:$type, OptionalAttr:$initial_value, UnitAttr:$constant ); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index ea3c9943f4b4..ebca204ab848 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -434,7 +434,7 @@ struct GlobalMemrefOpLowering LogicalResult matchAndRewrite(memref::GlobalOp global, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - MemRefType type = global.type().cast(); + MemRefType type = global.type(); if (!isConvertibleAndHasIdentityMaps(type)) return failure(); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp index 518405aabb49..c916a73e16d9 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp @@ -46,7 +46,7 @@ memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) { auto global = globalBuilder.create( constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), /*sym_visibility=*/globalBuilder.getStringAttr("private"), - /*type=*/typeConverter.convertType(type), + /*type=*/typeConverter.convertType(type).cast(), /*initial_value=*/constantOp.getValue().cast(), /*constant=*/true); symbolTable.insert(global);