From bda669beea2c6de2e832533bf79fa0add6425d0c Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 27 Jun 2019 09:12:19 -0700 Subject: [PATCH] Allow attaching a type to StringAttr. Some dialects allow for string types, and this allows for reusing StringAttr for constants of these types. PiperOrigin-RevId: 255413948 --- mlir/g3doc/LangRef.md | 2 +- mlir/include/mlir/IR/Attributes.h | 4 ++++ mlir/include/mlir/IR/Builders.h | 1 + mlir/lib/IR/AttributeDetail.h | 11 +++++++---- mlir/lib/IR/Attributes.cpp | 10 ++++++++-- mlir/lib/IR/Builders.cpp | 4 ++++ mlir/lib/Parser/Parser.cpp | 9 ++++++++- mlir/lib/SPIRV/SPIRVOps.cpp | 4 +++- mlir/test/IR/parser.mlir | 3 +++ mlir/test/SPIRV/ops.mlir | 4 ++-- 10 files changed, 41 insertions(+), 11 deletions(-) diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index 7e7c025c25e8..5e5d00c07594 100644 --- a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -871,7 +871,7 @@ the given function. Syntax: ``` {.ebnf} -string-attribute ::= string-literal +string-attribute ::= string-literal (`:` type)? ``` A string attribute is an attribute that represents a string literal value. diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 215dff12d34d..5b9bfca35adb 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -405,8 +405,12 @@ public: using Base::Base; using ValueType = StringRef; + /// Get an instance of a StringAttr with the given string. static StringAttr get(StringRef bytes, MLIRContext *context); + /// Get an instance of a StringAttr with the given string and Type. + static StringAttr get(StringRef bytes, Type type); + StringRef getValue() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index b99f091a2ab3..27d0c28b7701 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -107,6 +107,7 @@ public: FloatAttr getFloatAttr(Type type, double value); FloatAttr getFloatAttr(Type type, const APFloat &value); StringAttr getStringAttr(StringRef bytes); + StringAttr getStringAttr(StringRef bytes, Type type); ArrayAttr getArrayAttr(ArrayRef value); AffineMapAttr getAffineMapAttr(AffineMap map); IntegerSetAttr getIntegerSetAttr(IntegerSet set); diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 0fe07a979169..8e757364ff3d 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -297,18 +297,21 @@ struct OpaqueAttributeStorage : public AttributeStorage { /// An attribute representing a string value. struct StringAttributeStorage : public AttributeStorage { - using KeyTy = StringRef; + using KeyTy = std::pair; - StringAttributeStorage(StringRef value) : value(value) {} + StringAttributeStorage(StringRef value, Type type) + : AttributeStorage(type), value(value) {} /// Key equality function. - bool operator==(const KeyTy &key) const { return key == value; } + bool operator==(const KeyTy &key) const { + return key == KeyTy(value, getType()); + } /// Construct a new storage instance. static StringAttributeStorage *construct(AttributeStorageAllocator &allocator, const KeyTy &key) { return new (allocator.allocate()) - StringAttributeStorage(allocator.copyInto(key)); + StringAttributeStorage(allocator.copyInto(key.first), key.second); } StringRef value; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 37ed96b1c278..01f9a060bd9e 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -255,7 +255,8 @@ FunctionAttr FunctionAttr::get(Function *value) { } FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) { - return Base::get(ctx, StandardAttributes::Function, value); + return Base::get(ctx, StandardAttributes::Function, value, + NoneType::get(ctx)); } StringRef FunctionAttr::getValue() const { return getImpl()->value; } @@ -332,7 +333,12 @@ LogicalResult OpaqueAttr::verifyConstructionInvariants( //===----------------------------------------------------------------------===// StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { - return Base::get(context, StandardAttributes::String, bytes); + return get(bytes, NoneType::get(context)); +} + +/// Get an instance of a StringAttr with the given string and Type. +StringAttr StringAttr::get(StringRef bytes, Type type) { + return Base::get(type.getContext(), StandardAttributes::String, bytes, type); } StringRef StringAttr::getValue() const { return getImpl()->value; } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 43e6a44fa7a3..9b30205abdbc 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -159,6 +159,10 @@ StringAttr Builder::getStringAttr(StringRef bytes) { return StringAttr::get(bytes, context); } +StringAttr Builder::getStringAttr(StringRef bytes, Type type) { + return StringAttr::get(bytes, type); +} + ArrayAttr Builder::getArrayAttr(ArrayRef value) { return ArrayAttr::get(value, context); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index d2443082711b..3e0f4e86e858 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -926,7 +926,7 @@ ParseResult Parser::parseXInDimensionList() { /// | bool-literal /// | integer-literal (`:` (index-type | integer-type))? /// | float-literal (`:` float-type)? -/// | string-literal +/// | string-literal (`:` type)? /// | type /// | `[` (attribute-value (`,` attribute-value)*)? `]` /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` @@ -1034,6 +1034,13 @@ Attribute Parser::parseAttribute(Type type) { case Token::string: { auto val = getToken().getStringValue(); consumeToken(Token::string); + + // Parse the optional trailing colon type. + if (!type && consumeIf(Token::colon)) { + Type stringType = parseType(); + return stringType ? StringAttr::get(val, stringType) : Attribute(); + } + return builder.getStringAttr(val); } diff --git a/mlir/lib/SPIRV/SPIRVOps.cpp b/mlir/lib/SPIRV/SPIRVOps.cpp index a32af8419c69..1fd27c49aefb 100644 --- a/mlir/lib/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/SPIRV/SPIRVOps.cpp @@ -43,7 +43,9 @@ static ParseResult parseStorageClassAttribute(spirv::StorageClass &storageClass, Attribute storageClassAttr; SmallVector storageAttr; auto loc = parser->getCurrentLocation(); - if (parser->parseAttribute(storageClassAttr, "storage_class", storageAttr)) { + if (parser->parseAttribute(storageClassAttr, + parser->getBuilder().getNoneType(), + "storage_class", storageAttr)) { return failure(); } if (!storageClassAttr.isa()) { diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 8a81be2b8e22..c3ed9daaafe2 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -512,6 +512,9 @@ func @stringquote() -> () { ^bb0: // CHECK: "foo"() {bar = "a\22quoted\22string"} : () -> () "foo"(){bar = "a\"quoted\"string"} : () -> () + + // CHECK-NEXT: "typed_string" : !foo.string + "foo"(){bar = "typed_string" : !foo.string} : () -> () return } diff --git a/mlir/test/SPIRV/ops.mlir b/mlir/test/SPIRV/ops.mlir index 947622d84ba0..407ce4f4b4e5 100644 --- a/mlir/test/SPIRV/ops.mlir +++ b/mlir/test/SPIRV/ops.mlir @@ -111,7 +111,7 @@ func @volatile_load_missing_lbrace() -> () { func @volatile_load_missing_rbrace() -> () { %0 = spv.Variable : !spv.ptr // expected-error @+1 {{expected ']'}} - %1 = spv.Load "Function" %0 ["Volatile" : f32 + %1 = spv.Load "Function" %0 ["Volatile"} : f32 return } @@ -247,7 +247,7 @@ func @volatile_store_missing_lbrace(%arg0 : f32) -> () { func @volatile_store_missing_rbrace(%arg0 : f32) -> () { %0 = spv.Variable : !spv.ptr // expected-error @+1 {{expected ']'}} - spv.Store "Function" %0, %arg0 ["Volatile" : f32 + spv.Store "Function" %0, %arg0 ["Volatile"} : f32 return }