forked from OSchip/llvm-project
[MLIR] Introduce std.global_memref and std.get_global_memref operations.
- Add standard dialect operations to define global variables with memref types and to retrieve the memref for to a named global variable - Extend unit tests to test verification for these operations. Differential Revision: https://reviews.llvm.org/D90337
This commit is contained in:
parent
934b27a9da
commit
c254b0bb69
|
@ -2005,6 +2005,97 @@ def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
|
|||
let hasFolder = 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def GlobalMemrefOp : Std_Op<"global_memref", [NoSideEffect, Symbol]> {
|
||||
let summary = "declare or define a global memref variable";
|
||||
let description = [{
|
||||
The `global_memref` operation declares or defines a named global variable.
|
||||
The backing memory for the variable is allocated statically and is described
|
||||
by the type of the variable (which should be a statically shaped memref
|
||||
type). The operation is a declaration if no `inital_value` is specified,
|
||||
else it is a definition. The `initial_value` can either be a unit attribute
|
||||
to represent a definition of an uninitialized global variable, or an
|
||||
elements attribute to represent the definition of a global variable with an
|
||||
initial value. The global variable can also be marked constant using the
|
||||
`constant` unit attribute. Writing to such constant global variables is
|
||||
undefined.
|
||||
|
||||
The global variable can be accessed by using the `get_global_memref` to
|
||||
retrieve the memref for the global variable. Note that the memref
|
||||
for such global variable itself is immutable (i.e., get_global_memref for a
|
||||
given global variable will always return the same memref descriptor).
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Private variable with an initial value.
|
||||
global_memref @x : memref<2xf32> { sym_visibility = "private",
|
||||
initial_value = dense<0.0,2.0> : tensor<2xf32> }
|
||||
|
||||
// External variable.
|
||||
global_memref @y : memref<4xi32> { sym_visibility = "public" }
|
||||
|
||||
// Uninitialized externally visible variable.
|
||||
global_memref @z : memref<3xf16> { sym_visibility = "public",
|
||||
initial_value }
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SymbolNameAttr:$sym_name,
|
||||
OptionalAttr<StrAttr>:$sym_visibility,
|
||||
TypeAttr:$type,
|
||||
OptionalAttr<AnyAttr>:$initial_value,
|
||||
UnitAttr:$constant
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
($sym_visibility^)?
|
||||
(`constant` $constant^)?
|
||||
$sym_name `:`
|
||||
custom<GlobalMemrefOpTypeAndInitialValue>($type, $initial_value)
|
||||
attr-dict
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
bool isExternal() { return !initial_value(); }
|
||||
bool isUnitialized() {
|
||||
return !isExternal() && initial_value().getValue().isa<UnitAttr>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetGlobalMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def GetGlobalMemrefOp : Std_Op<"get_global_memref",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
|
||||
let summary = "get the memref pointing to a global variable";
|
||||
let description = [{
|
||||
The `get_global_memref` operation retrieves the memref pointing to a
|
||||
named global variable. If the global variable is marked constant, writing
|
||||
to the result memref (such as through a `std.store` operation) is
|
||||
undefined.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%x = get_global_memref @foo : memref<2xf32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins FlatSymbolRefAttr:$name);
|
||||
let results = (outs AnyStaticShapeMemRef:$result);
|
||||
let assemblyFormat = "$name `:` type($result) attr-dict";
|
||||
|
||||
// `GetGlobalMemrefOp` is fully verified by its traits.
|
||||
let verifier = ?;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ImOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -395,7 +395,7 @@ public:
|
|||
|
||||
// Parse any kind of attribute.
|
||||
Attribute attr;
|
||||
if (parseAttribute(attr))
|
||||
if (parseAttribute(attr, type))
|
||||
return failure();
|
||||
|
||||
// Check for the right kind of attribute.
|
||||
|
@ -436,6 +436,10 @@ public:
|
|||
Type type,
|
||||
StringRef attrName,
|
||||
NamedAttrList &attrs) = 0;
|
||||
virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
|
||||
Type type,
|
||||
StringRef attrName,
|
||||
NamedAttrList &attrs) = 0;
|
||||
|
||||
/// Parse an arbitrary attribute of a given type and return it in result. This
|
||||
/// also adds the attribute to the specified attribute list with the specified
|
||||
|
|
|
@ -245,6 +245,18 @@ static bool areVectorCastSimpleCompatible(
|
|||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static Type getTensorTypeFromMemRefType(Type type) {
|
||||
if (auto memref = type.dyn_cast<MemRefType>())
|
||||
return RankedTensorType::get(memref.getShape(), memref.getElementType());
|
||||
if (auto memref = type.dyn_cast<UnrankedMemRefType>())
|
||||
return UnrankedTensorType::get(memref.getElementType());
|
||||
return NoneType::get(type.getContext());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AddFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2140,6 +2152,106 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
|
|||
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p,
|
||||
GlobalMemrefOp op,
|
||||
TypeAttr type,
|
||||
Attribute initialValue) {
|
||||
p << type;
|
||||
if (!op.isExternal()) {
|
||||
p << " = ";
|
||||
if (op.isUnitialized())
|
||||
p << "uninitialized";
|
||||
else
|
||||
p.printAttributeWithoutType(initialValue);
|
||||
}
|
||||
}
|
||||
|
||||
static ParseResult
|
||||
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
|
||||
Attribute &initialValue) {
|
||||
Type type;
|
||||
if (parser.parseType(type))
|
||||
return failure();
|
||||
|
||||
auto memrefType = type.dyn_cast<MemRefType>();
|
||||
if (!memrefType || !memrefType.hasStaticShape())
|
||||
return parser.emitError(parser.getNameLoc())
|
||||
<< "type should be static shaped memref, but got " << type;
|
||||
typeAttr = TypeAttr::get(type);
|
||||
|
||||
if (parser.parseOptionalEqual())
|
||||
return success();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
|
||||
initialValue = UnitAttr::get(parser.getBuilder().getContext());
|
||||
return success();
|
||||
}
|
||||
|
||||
Type tensorType = getTensorTypeFromMemRefType(memrefType);
|
||||
if (parser.parseAttribute(initialValue, tensorType))
|
||||
return failure();
|
||||
if (!initialValue.isa<ElementsAttr>())
|
||||
return parser.emitError(parser.getNameLoc())
|
||||
<< "initial value should be a unit or elements attribute";
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(GlobalMemrefOp op) {
|
||||
auto memrefType = op.type().dyn_cast<MemRefType>();
|
||||
if (!memrefType || !memrefType.hasStaticShape())
|
||||
return op.emitOpError("type should be static shaped memref, but got ")
|
||||
<< op.type();
|
||||
|
||||
// Verify that the initial value, if present, is either a unit attribute or
|
||||
// an elements attribute.
|
||||
if (op.initial_value().hasValue()) {
|
||||
Attribute initValue = op.initial_value().getValue();
|
||||
if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
|
||||
return op.emitOpError("initial value should be a unit or elements "
|
||||
"attribute, but got ")
|
||||
<< initValue;
|
||||
|
||||
// Check that the type of the initial value is compatible with the type of
|
||||
// the global variable.
|
||||
if (initValue.isa<ElementsAttr>()) {
|
||||
Type initType = initValue.getType();
|
||||
Type tensorType = getTensorTypeFromMemRefType(memrefType);
|
||||
if (initType != tensorType)
|
||||
return op.emitOpError("initial value expected to be of type ")
|
||||
<< tensorType << ", but was of type " << initType;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: verify visibility for declarations.
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetGlobalMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
// Verify that the result type is same as the type of the referenced
|
||||
// global_memref op.
|
||||
auto global =
|
||||
symbolTable.lookupNearestSymbolFrom<GlobalMemrefOp>(*this, nameAttr());
|
||||
if (!global)
|
||||
return emitOpError("'")
|
||||
<< name() << "' does not reference a valid global memref";
|
||||
|
||||
Type resultType = result().getType();
|
||||
if (global.type() != resultType)
|
||||
return emitOpError("result type ")
|
||||
<< resultType << " does not match type " << global.type()
|
||||
<< " of the global memref @" << name();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IndexCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -3891,18 +4003,6 @@ void TensorCastOp::getCanonicalizationPatterns(
|
|||
results.insert<ChainedTensorCast>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helpers for Tensor[Load|Store]Op and TensorToMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static Type getTensorTypeFromMemRefType(Type type) {
|
||||
if (auto memref = type.dyn_cast<MemRefType>())
|
||||
return RankedTensorType::get(memref.getShape(), memref.getElementType());
|
||||
if (auto memref = type.dyn_cast<UnrankedMemRefType>())
|
||||
return UnrankedTensorType::get(memref.getElementType());
|
||||
return NoneType::get(type.getContext());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorLoadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -226,6 +226,10 @@ OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
|
|||
Type type) {
|
||||
return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
|
||||
}
|
||||
OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
|
||||
Type type) {
|
||||
return parseOptionalAttributeWithToken(Token::string, attribute, type);
|
||||
}
|
||||
|
||||
/// Attribute dictionary.
|
||||
///
|
||||
|
@ -807,6 +811,7 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
|
|||
|
||||
/// Parse a dense elements attribute.
|
||||
Attribute Parser::parseDenseElementsAttr(Type attrType) {
|
||||
auto attribLoc = getToken().getLoc();
|
||||
consumeToken(Token::kw_dense);
|
||||
if (parseToken(Token::less, "expected '<' after 'dense'"))
|
||||
return nullptr;
|
||||
|
@ -819,11 +824,14 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto typeLoc = getToken().getLoc();
|
||||
// If the type is specified `parseElementsLiteralType` will not parse a type.
|
||||
// Use the attribute location as the location for error reporting in that
|
||||
// case.
|
||||
auto loc = attrType ? attribLoc : getToken().getLoc();
|
||||
auto type = parseElementsLiteralType(attrType);
|
||||
if (!type)
|
||||
return nullptr;
|
||||
return literalParser.getAttr(typeLoc, type);
|
||||
return literalParser.getAttr(loc, type);
|
||||
}
|
||||
|
||||
/// Parse an opaque elements attribute.
|
||||
|
|
|
@ -1065,6 +1065,11 @@ public:
|
|||
NamedAttrList &attrs) override {
|
||||
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
|
||||
}
|
||||
OptionalParseResult parseOptionalAttribute(StringAttr &result, Type type,
|
||||
StringRef attrName,
|
||||
NamedAttrList &attrs) override {
|
||||
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
|
||||
}
|
||||
|
||||
/// Parse a named dictionary into 'result' if it is present.
|
||||
ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
|
||||
|
|
|
@ -188,6 +188,7 @@ public:
|
|||
OptionalParseResult parseOptionalAttribute(Attribute &attribute,
|
||||
Type type = {});
|
||||
OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute, Type type);
|
||||
OptionalParseResult parseOptionalAttribute(StringAttr &attribute, Type type);
|
||||
|
||||
/// Parse an optional attribute that is demarcated by a specific token.
|
||||
template <typename AttributeT>
|
||||
|
|
|
@ -231,3 +231,84 @@ func @memref_reshape_result_affine_map_is_not_identity(
|
|||
memref_reshape %buf(%shape)
|
||||
: (memref<4x4xf32>, memref<1xi32>) -> memref<8xf32, offset: 0, strides: [2]>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{type should be static shaped memref}}
|
||||
global_memref @foo : i32
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{type should be static shaped memref}}
|
||||
global_memref @foo : i32 = 5
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{type should be static shaped memref}}
|
||||
global_memref @foo : memref<*xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{type should be static shaped memref}}
|
||||
global_memref @foo : memref<?x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{initial value should be a unit or elements attribute}}
|
||||
global_memref @foo : memref<2x2xf32> = "foo"
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{inferred shape of elements literal ([2]) does not match type ([2, 2])}}
|
||||
global_memref @foo : memref<2x2xf32> = dense<[0.0, 1.0]>
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
|
||||
global_memref "private" "public" @foo : memref<2x2xf32> = "foo"
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
|
||||
global_memref constant external @foo : memref<2x2xf32> = "foo"
|
||||
|
||||
// -----
|
||||
|
||||
// constant qualifier must be after visibility.
|
||||
// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
|
||||
global_memref constant "private" @foo : memref<2x2xf32> = "foo"
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{op visibility expected to be one of ["public", "private", "nested"], but got "priate"}}
|
||||
global_memref "priate" constant @memref5 : memref<2xf32> = uninitialized
|
||||
|
||||
// -----
|
||||
|
||||
func @nonexistent_global_memref() {
|
||||
// expected-error @+1 {{'gv' does not reference a valid global memref}}
|
||||
%0 = get_global_memref @gv : memref<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @foo()
|
||||
|
||||
func @nonexistent_global_memref() {
|
||||
// expected-error @+1 {{'foo' does not reference a valid global memref}}
|
||||
%0 = get_global_memref @foo : memref<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
global_memref @gv : memref<3xi32>
|
||||
|
||||
func @mismatched_types() {
|
||||
// expected-error @+1 {{result type 'memref<3xf32>' does not match type 'memref<3xi32>' of the global memref @gv}}
|
||||
%0 = get_global_memref @gv : memref<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -77,3 +77,34 @@ func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
|
|||
: (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
|
||||
return %new_unranked : memref<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: global_memref @memref0 : memref<2xf32>
|
||||
global_memref @memref0 : memref<2xf32>
|
||||
|
||||
// CHECK-LABEL: global_memref constant @memref1 : memref<2xf32> = dense<[0.000000e+00, 1.000000e+00]>
|
||||
global_memref constant @memref1 : memref<2xf32> = dense<[0.0, 1.0]>
|
||||
|
||||
// CHECK-LABEL: global_memref @memref2 : memref<2xf32> = uninitialized
|
||||
global_memref @memref2 : memref<2xf32> = uninitialized
|
||||
|
||||
// CHECK-LABEL: global_memref "private" @memref3 : memref<2xf32> = uninitialized
|
||||
global_memref "private" @memref3 : memref<2xf32> = uninitialized
|
||||
|
||||
// CHECK-LABEL: global_memref "private" constant @memref4 : memref<2xf32> = uninitialized
|
||||
global_memref "private" constant @memref4 : memref<2xf32> = uninitialized
|
||||
|
||||
// CHECK-LABEL: func @write_global_memref
|
||||
func @write_global_memref() {
|
||||
%0 = get_global_memref @memref0 : memref<2xf32>
|
||||
%1 = constant dense<[1.0, 2.0]> : tensor<2xf32>
|
||||
tensor_store %1, %0 : memref<2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @read_global_memref
|
||||
func @read_global_memref() {
|
||||
%0 = get_global_memref @memref0 : memref<2xf32>
|
||||
%1 = tensor_load %0 : memref<2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue