[MLIR] Introduce std.global_memref and std.get_global_memref operations.

- Add standard dialect operations to define global variables with memref types and to
  retrieve the memref for to a named global variable
- Extend unit tests to test verification for these operations.

Differential Revision: https://reviews.llvm.org/D90337
This commit is contained in:
Rahul Joshi 2020-11-02 11:21:29 -08:00
parent 934b27a9da
commit c254b0bb69
8 changed files with 336 additions and 15 deletions

View File

@ -2005,6 +2005,97 @@ def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
let hasFolder = 0;
}
//===----------------------------------------------------------------------===//
// GlobalMemrefOp
//===----------------------------------------------------------------------===//
def GlobalMemrefOp : Std_Op<"global_memref", [NoSideEffect, Symbol]> {
let summary = "declare or define a global memref variable";
let description = [{
The `global_memref` operation declares or defines a named global variable.
The backing memory for the variable is allocated statically and is described
by the type of the variable (which should be a statically shaped memref
type). The operation is a declaration if no `inital_value` is specified,
else it is a definition. The `initial_value` can either be a unit attribute
to represent a definition of an uninitialized global variable, or an
elements attribute to represent the definition of a global variable with an
initial value. The global variable can also be marked constant using the
`constant` unit attribute. Writing to such constant global variables is
undefined.
The global variable can be accessed by using the `get_global_memref` to
retrieve the memref for the global variable. Note that the memref
for such global variable itself is immutable (i.e., get_global_memref for a
given global variable will always return the same memref descriptor).
Example:
```mlir
// Private variable with an initial value.
global_memref @x : memref<2xf32> { sym_visibility = "private",
initial_value = dense<0.0,2.0> : tensor<2xf32> }
// External variable.
global_memref @y : memref<4xi32> { sym_visibility = "public" }
// Uninitialized externally visible variable.
global_memref @z : memref<3xf16> { sym_visibility = "public",
initial_value }
```
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
OptionalAttr<StrAttr>:$sym_visibility,
TypeAttr:$type,
OptionalAttr<AnyAttr>:$initial_value,
UnitAttr:$constant
);
let assemblyFormat = [{
($sym_visibility^)?
(`constant` $constant^)?
$sym_name `:`
custom<GlobalMemrefOpTypeAndInitialValue>($type, $initial_value)
attr-dict
}];
let extraClassDeclaration = [{
bool isExternal() { return !initial_value(); }
bool isUnitialized() {
return !isExternal() && initial_value().getValue().isa<UnitAttr>();
}
}];
}
//===----------------------------------------------------------------------===//
// GetGlobalMemrefOp
//===----------------------------------------------------------------------===//
def GetGlobalMemrefOp : Std_Op<"get_global_memref",
[NoSideEffect, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "get the memref pointing to a global variable";
let description = [{
The `get_global_memref` operation retrieves the memref pointing to a
named global variable. If the global variable is marked constant, writing
to the result memref (such as through a `std.store` operation) is
undefined.
Example:
```mlir
%x = get_global_memref @foo : memref<2xf32>
```
}];
let arguments = (ins FlatSymbolRefAttr:$name);
let results = (outs AnyStaticShapeMemRef:$result);
let assemblyFormat = "$name `:` type($result) attr-dict";
// `GetGlobalMemrefOp` is fully verified by its traits.
let verifier = ?;
}
//===----------------------------------------------------------------------===//
// ImOp
//===----------------------------------------------------------------------===//

View File

@ -395,7 +395,7 @@ public:
// Parse any kind of attribute.
Attribute attr;
if (parseAttribute(attr))
if (parseAttribute(attr, type))
return failure();
// Check for the right kind of attribute.
@ -436,6 +436,10 @@ public:
Type type,
StringRef attrName,
NamedAttrList &attrs) = 0;
virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
Type type,
StringRef attrName,
NamedAttrList &attrs) = 0;
/// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified

View File

@ -245,6 +245,18 @@ static bool areVectorCastSimpleCompatible(
return false;
}
//===----------------------------------------------------------------------===//
// Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp
//===----------------------------------------------------------------------===//
static Type getTensorTypeFromMemRefType(Type type) {
if (auto memref = type.dyn_cast<MemRefType>())
return RankedTensorType::get(memref.getShape(), memref.getElementType());
if (auto memref = type.dyn_cast<UnrankedMemRefType>())
return UnrankedTensorType::get(memref.getElementType());
return NoneType::get(type.getContext());
}
//===----------------------------------------------------------------------===//
// AddFOp
//===----------------------------------------------------------------------===//
@ -2140,6 +2152,106 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
}
//===----------------------------------------------------------------------===//
// GlobalMemrefOp
//===----------------------------------------------------------------------===//
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p,
GlobalMemrefOp op,
TypeAttr type,
Attribute initialValue) {
p << type;
if (!op.isExternal()) {
p << " = ";
if (op.isUnitialized())
p << "uninitialized";
else
p.printAttributeWithoutType(initialValue);
}
}
static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &initialValue) {
Type type;
if (parser.parseType(type))
return failure();
auto memrefType = type.dyn_cast<MemRefType>();
if (!memrefType || !memrefType.hasStaticShape())
return parser.emitError(parser.getNameLoc())
<< "type should be static shaped memref, but got " << type;
typeAttr = TypeAttr::get(type);
if (parser.parseOptionalEqual())
return success();
if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
initialValue = UnitAttr::get(parser.getBuilder().getContext());
return success();
}
Type tensorType = getTensorTypeFromMemRefType(memrefType);
if (parser.parseAttribute(initialValue, tensorType))
return failure();
if (!initialValue.isa<ElementsAttr>())
return parser.emitError(parser.getNameLoc())
<< "initial value should be a unit or elements attribute";
return success();
}
static LogicalResult verify(GlobalMemrefOp op) {
auto memrefType = op.type().dyn_cast<MemRefType>();
if (!memrefType || !memrefType.hasStaticShape())
return op.emitOpError("type should be static shaped memref, but got ")
<< op.type();
// Verify that the initial value, if present, is either a unit attribute or
// an elements attribute.
if (op.initial_value().hasValue()) {
Attribute initValue = op.initial_value().getValue();
if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
return op.emitOpError("initial value should be a unit or elements "
"attribute, but got ")
<< initValue;
// Check that the type of the initial value is compatible with the type of
// the global variable.
if (initValue.isa<ElementsAttr>()) {
Type initType = initValue.getType();
Type tensorType = getTensorTypeFromMemRefType(memrefType);
if (initType != tensorType)
return op.emitOpError("initial value expected to be of type ")
<< tensorType << ", but was of type " << initType;
}
}
// TODO: verify visibility for declarations.
return success();
}
//===----------------------------------------------------------------------===//
// GetGlobalMemrefOp
//===----------------------------------------------------------------------===//
LogicalResult
GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Verify that the result type is same as the type of the referenced
// global_memref op.
auto global =
symbolTable.lookupNearestSymbolFrom<GlobalMemrefOp>(*this, nameAttr());
if (!global)
return emitOpError("'")
<< name() << "' does not reference a valid global memref";
Type resultType = result().getType();
if (global.type() != resultType)
return emitOpError("result type ")
<< resultType << " does not match type " << global.type()
<< " of the global memref @" << name();
return success();
}
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
@ -3891,18 +4003,6 @@ void TensorCastOp::getCanonicalizationPatterns(
results.insert<ChainedTensorCast>(context);
}
//===----------------------------------------------------------------------===//
// Helpers for Tensor[Load|Store]Op and TensorToMemrefOp
//===----------------------------------------------------------------------===//
static Type getTensorTypeFromMemRefType(Type type) {
if (auto memref = type.dyn_cast<MemRefType>())
return RankedTensorType::get(memref.getShape(), memref.getElementType());
if (auto memref = type.dyn_cast<UnrankedMemRefType>())
return UnrankedTensorType::get(memref.getElementType());
return NoneType::get(type.getContext());
}
//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//

View File

@ -226,6 +226,10 @@ OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
Type type) {
return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
}
OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
Type type) {
return parseOptionalAttributeWithToken(Token::string, attribute, type);
}
/// Attribute dictionary.
///
@ -807,6 +811,7 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
/// Parse a dense elements attribute.
Attribute Parser::parseDenseElementsAttr(Type attrType) {
auto attribLoc = getToken().getLoc();
consumeToken(Token::kw_dense);
if (parseToken(Token::less, "expected '<' after 'dense'"))
return nullptr;
@ -819,11 +824,14 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
return nullptr;
}
auto typeLoc = getToken().getLoc();
// If the type is specified `parseElementsLiteralType` will not parse a type.
// Use the attribute location as the location for error reporting in that
// case.
auto loc = attrType ? attribLoc : getToken().getLoc();
auto type = parseElementsLiteralType(attrType);
if (!type)
return nullptr;
return literalParser.getAttr(typeLoc, type);
return literalParser.getAttr(loc, type);
}
/// Parse an opaque elements attribute.

View File

@ -1065,6 +1065,11 @@ public:
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
OptionalParseResult parseOptionalAttribute(StringAttr &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
/// Parse a named dictionary into 'result' if it is present.
ParseResult parseOptionalAttrDict(NamedAttrList &result) override {

View File

@ -188,6 +188,7 @@ public:
OptionalParseResult parseOptionalAttribute(Attribute &attribute,
Type type = {});
OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute, Type type);
OptionalParseResult parseOptionalAttribute(StringAttr &attribute, Type type);
/// Parse an optional attribute that is demarcated by a specific token.
template <typename AttributeT>

View File

@ -231,3 +231,84 @@ func @memref_reshape_result_affine_map_is_not_identity(
memref_reshape %buf(%shape)
: (memref<4x4xf32>, memref<1xi32>) -> memref<8xf32, offset: 0, strides: [2]>
}
// -----
// expected-error @+1 {{type should be static shaped memref}}
global_memref @foo : i32
// -----
// expected-error @+1 {{type should be static shaped memref}}
global_memref @foo : i32 = 5
// -----
// expected-error @+1 {{type should be static shaped memref}}
global_memref @foo : memref<*xf32>
// -----
// expected-error @+1 {{type should be static shaped memref}}
global_memref @foo : memref<?x?xf32>
// -----
// expected-error @+1 {{initial value should be a unit or elements attribute}}
global_memref @foo : memref<2x2xf32> = "foo"
// -----
// expected-error @+1 {{inferred shape of elements literal ([2]) does not match type ([2, 2])}}
global_memref @foo : memref<2x2xf32> = dense<[0.0, 1.0]>
// -----
// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
global_memref "private" "public" @foo : memref<2x2xf32> = "foo"
// -----
// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
global_memref constant external @foo : memref<2x2xf32> = "foo"
// -----
// constant qualifier must be after visibility.
// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
global_memref constant "private" @foo : memref<2x2xf32> = "foo"
// -----
// expected-error @+1 {{op visibility expected to be one of ["public", "private", "nested"], but got "priate"}}
global_memref "priate" constant @memref5 : memref<2xf32> = uninitialized
// -----
func @nonexistent_global_memref() {
// expected-error @+1 {{'gv' does not reference a valid global memref}}
%0 = get_global_memref @gv : memref<3xf32>
return
}
// -----
func @foo()
func @nonexistent_global_memref() {
// expected-error @+1 {{'foo' does not reference a valid global memref}}
%0 = get_global_memref @foo : memref<3xf32>
return
}
// -----
global_memref @gv : memref<3xi32>
func @mismatched_types() {
// expected-error @+1 {{result type 'memref<3xf32>' does not match type 'memref<3xi32>' of the global memref @gv}}
%0 = get_global_memref @gv : memref<3xf32>
return
}

View File

@ -77,3 +77,34 @@ func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
: (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
return %new_unranked : memref<*xf32>
}
// CHECK-LABEL: global_memref @memref0 : memref<2xf32>
global_memref @memref0 : memref<2xf32>
// CHECK-LABEL: global_memref constant @memref1 : memref<2xf32> = dense<[0.000000e+00, 1.000000e+00]>
global_memref constant @memref1 : memref<2xf32> = dense<[0.0, 1.0]>
// CHECK-LABEL: global_memref @memref2 : memref<2xf32> = uninitialized
global_memref @memref2 : memref<2xf32> = uninitialized
// CHECK-LABEL: global_memref "private" @memref3 : memref<2xf32> = uninitialized
global_memref "private" @memref3 : memref<2xf32> = uninitialized
// CHECK-LABEL: global_memref "private" constant @memref4 : memref<2xf32> = uninitialized
global_memref "private" constant @memref4 : memref<2xf32> = uninitialized
// CHECK-LABEL: func @write_global_memref
func @write_global_memref() {
%0 = get_global_memref @memref0 : memref<2xf32>
%1 = constant dense<[1.0, 2.0]> : tensor<2xf32>
tensor_store %1, %0 : memref<2xf32>
return
}
// CHECK-LABEL: func @read_global_memref
func @read_global_memref() {
%0 = get_global_memref @memref0 : memref<2xf32>
%1 = tensor_load %0 : memref<2xf32>
return
}