[mlir][ods] Extend the EnumAttr tablegen class to support BitEnum attributes

This diff allows the EnumAttr class to be used for bit enum attributes (in
addition to previously supported integer enum attributes). While integer
and bit enum attributes share many common implementation aspects, parsing
bit enum values requires a separate implementation. This is accomplished
by creating empty parser and printer strings in the EnumAttrInfo record,
and having derived classes (specific to bit and integer enums) override with
an appropriate parser/printer string.

To support existing bit enums that may use a vertical bar separator, the
parser is modified to support the | token.

Tests were added for bit enums alongside integer enums.

Future diffs for fastmath attributes in the arithmetic dialect will use these
changes.

(resubmission of earlier abaondoned diff, updated to reflect subsequent changes
in the repository)

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D123880
This commit is contained in:
Jeremy Furtek 2022-04-25 18:51:09 +00:00 committed by Mogball
parent 4e5dee2f30
commit a266a21000
8 changed files with 167 additions and 19 deletions

View File

@ -53,7 +53,7 @@ class BitEnumAttrCaseBase<I intType, string sym, int val, string str = sym> :
// A bit enum case stored with a 32-bit IntegerAttr. `val` here is *not* the
// ordinal number of a bit that is set. It is a 32-bit integer value with bits
// set to match the case.
// set to match the case.
class I32BitEnumAttrCase<string sym, int val, string str = sym>
: BitEnumAttrCaseBase<I32, sym, val, str>;
@ -182,6 +182,14 @@ class EnumAttrInfo<
cppNamespace # "::" # specializedAttrClassName # "::get($_builder.getContext(), $0)",
baseAttrClass.constBuilderCall);
let valueType = baseAttrClass.valueType;
// C++ type wrapped by attribute
string cppType = cppNamespace # "::" # className;
// Parser and printer code used by the EnumParameter class, to be provided by
// derived classes
string parameterParser = ?;
string parameterPrinter = ?;
}
// An enum attribute backed by IntegerAttr.
@ -202,7 +210,25 @@ class IntEnumAttr<I intType, string name, string summary,
IntEnumAttrBase<intType, cases,
!if(!empty(summary), "allowed " # intType.summary # " cases: " #
!interleave(!foreach(case, cases, case.value), ", "),
summary)>>;
summary)>> {
// Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the
// symbol is not valid.
let parameterParser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> {
auto loc = $_parser.getCurrentLocation();
::llvm::StringRef enumKeyword;
if (::mlir::failed($_parser.parseKeyword(&enumKeyword)))
return ::mlir::failure();
auto maybeEnum = }] # cppNamespace # "::" #
stringToSymbolFnName # [{(enumKeyword);
if (maybeEnum)
return *maybeEnum;
return {(::mlir::LogicalResult)$_parser.emitError(loc, "expected }] #
cppType # [{ to be one of: }] #
!interleave(!foreach(enum, enumerants, enum.str), ", ") # [{")};
}()}];
// Print the enum by calling `symbolToString`.
let parameterPrinter = "$_printer << " # symbolToStringFnName # "($_self)";
}
class I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases> :
IntEnumAttr<I32, name, summary, cases> {
@ -244,6 +270,35 @@ class BitEnumAttr<I intType, string name, string summary,
// The delimiter used to separate bit enum cases in strings.
string separator = "|";
// Parsing function that corresponds to the enum separator. Only
// "," and "|" are supported by this definition.
string parseSeparatorFn = !if(!eq(separator,"|"),"parseOptionalVerticalBar",
"parseOptionalComma");
// Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the
// symbol is not valid.
let parameterParser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> {
}] # cppType # [{ flags = {};
auto loc = $_parser.getCurrentLocation();
::llvm::StringRef enumKeyword;
do {
if (::mlir::failed($_parser.parseKeyword(&enumKeyword)))
return ::mlir::failure();
auto maybeEnum = }] # cppNamespace # "::" #
stringToSymbolFnName # [{(enumKeyword);
if (!maybeEnum) {
return {(::mlir::LogicalResult)$_parser.emitError(loc, "expected }] #
cppType # [{ to be one of: }] #
!interleave(!foreach(enum, enumerants, enum.str),
", ") # [{")};
}
flags = flags | *maybeEnum;
} while(::mlir::succeeded($_parser.}] # parseSeparatorFn # [{()));
return flags;
}()}];
// Print the enum by calling `symbolToString`.
let parameterPrinter = "$_printer << " # symbolToStringFnName # "($_self)";
// Print the "primary group" only for bits that are members of case groups
// that have all bits present. When the value is 0, printing will display both
// both individual bit case names AND the names for all groups that the bit is
@ -272,23 +327,8 @@ class I64BitEnumAttr<string name, string summary,
class EnumParameter<EnumAttrInfo enumInfo>
: AttrParameter<enumInfo.cppNamespace # "::" # enumInfo.className,
"an enum of type " # enumInfo.className> {
// Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the
// symbol is not valid.
let parser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> {
auto loc = $_parser.getCurrentLocation();
::llvm::StringRef enumKeyword;
if (::mlir::failed($_parser.parseKeyword(&enumKeyword)))
return ::mlir::failure();
auto maybeEnum = }] # enumInfo.cppNamespace # "::" #
enumInfo.stringToSymbolFnName # [{(enumKeyword);
if (maybeEnum)
return *maybeEnum;
return {(::mlir::LogicalResult)$_parser.emitError(loc, "expected }] #
cppType # [{ to be one of: }] #
!interleave(!foreach(enum, enumInfo.enumerants, enum.str), ", ") # [{")};
}()}];
// Print the enum by calling `symbolToString`.
let printer = "$_printer << " # enumInfo.symbolToStringFnName # "($_self)";
let parser = enumInfo.parameterParser;
let printer = enumInfo.parameterPrinter;
}
// An attribute backed by a C++ enum. The attribute contains a single

View File

@ -464,6 +464,12 @@ public:
/// Parse a '*' token if present.
virtual ParseResult parseOptionalStar() = 0;
/// Parse a '|' token.
virtual ParseResult parseVerticalBar() = 0;
/// Parse a '|' token if present.
virtual ParseResult parseOptionalVerticalBar() = 0;
/// Parse a quoted string token.
ParseResult parseString(std::string *string) {
auto loc = getCurrentLocation();

View File

@ -221,6 +221,16 @@ public:
return success(parser.consumeIf(Token::plus));
}
/// Parse a '|' token.
virtual ParseResult parseVerticalBar() override {
return parser.parseToken(Token::vertical_bar, "expected '|'");
}
/// Parse a '|' token if present.
virtual ParseResult parseOptionalVerticalBar() override {
return success(parser.consumeIf(Token::vertical_bar));
}
/// Parses a quoted string token if present.
ParseResult parseOptionalString(std::string *string) override {
if (!parser.getToken().is(Token::string))

View File

@ -127,6 +127,9 @@ Token Lexer::lexToken() {
case '?':
return formToken(Token::question, tokStart);
case '|':
return formToken(Token::vertical_bar, tokStart);
case '/':
if (*curPtr == '/') {
skipComment();

View File

@ -70,6 +70,7 @@ TOK_PUNCTUATION(r_brace, "}")
TOK_PUNCTUATION(r_paren, ")")
TOK_PUNCTUATION(r_square, "]")
TOK_PUNCTUATION(star, "*")
TOK_PUNCTUATION(vertical_bar, "|")
// Keywords. These turn "foo" into Token::kw_foo enums.

View File

@ -407,6 +407,42 @@ func.func @disallowed_case7_fail() {
// -----
//===----------------------------------------------------------------------===//
// Test BitEnumAttr
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @allowed_cases_pass
func @allowed_cases_pass() {
// CHECK: test.op_with_bit_enum <read,write>
"test.op_with_bit_enum"() {value = #test.bit_enum<read, write>} : () -> ()
// CHECK: test.op_with_bit_enum <read,execute>
test.op_with_bit_enum <read,execute>
return
}
// -----
// CHECK-LABEL: func @allowed_cases_pass
func @allowed_cases_pass() {
// CHECK: test.op_with_bit_enum_vbar <user|group>
"test.op_with_bit_enum_vbar"() {
value = #test.bit_enum_vbar<user|group>
} : () -> ()
// CHECK: test.op_with_bit_enum_vbar <user|group|other>
test.op_with_bit_enum_vbar <user | group | other>
return
}
// -----
func @disallowed_case_sticky_fail() {
// expected-error@+2 {{expected test::TestBitEnum to be one of: read, write, execute}}
// expected-error@+1 {{failed to parse TestBitEnumAttr}}
"test.op_with_bit_enum"() {value = #test.bit_enum<sticky>} : () -> ()
}
// -----
//===----------------------------------------------------------------------===//
// Test FloatElementsAttr
//===----------------------------------------------------------------------===//

View File

@ -22,6 +22,7 @@
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
// Include this before the using namespace lines below to

View File

@ -311,6 +311,57 @@ def : Pat<(OpWithEnum ConstantAttr<TestEnumAttr,
"::test::TestEnum::Second">,
ConstantAttr<I32Attr, "1">)>;
//===----------------------------------------------------------------------===//
// Test Bit Enum Attributes
//===----------------------------------------------------------------------===//
// Define the C++ enum.
def TestBitEnum
: I32BitEnumAttr<"TestBitEnum", "a test bit enum", [
I32BitEnumAttrCaseBit<"Read", 0, "read">,
I32BitEnumAttrCaseBit<"Write", 1, "write">,
I32BitEnumAttrCaseBit<"Execute", 2, "execute">,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "test";
let separator = ",";
}
// Define the enum attribute.
def TestBitEnumAttr : EnumAttr<Test_Dialect, TestBitEnum, "bit_enum"> {
let assemblyFormat = "`<` $value `>`";
}
// Define an op that contains the enum attribute.
def OpWithBitEnum : TEST_Op<"op_with_bit_enum"> {
let arguments = (ins TestBitEnumAttr:$value, OptionalAttr<AnyAttr>:$tag);
let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
}
// Define an enum with a different separator
def TestBitEnumVerticalBar
: I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [
I32BitEnumAttrCaseBit<"User", 0, "user">,
I32BitEnumAttrCaseBit<"Group", 1, "group">,
I32BitEnumAttrCaseBit<"Other", 2, "other">,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "test";
let separator = "|";
}
def TestBitEnumVerticalBarAttr
: EnumAttr<Test_Dialect, TestBitEnumVerticalBar, "bit_enum_vbar"> {
let assemblyFormat = "`<` $value `>`";
}
// Define an op that contains the enum attribute.
def OpWithBitEnumVerticalBar : TEST_Op<"op_with_bit_enum_vbar"> {
let arguments = (ins TestBitEnumVerticalBarAttr:$value,
OptionalAttr<AnyAttr>:$tag);
let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
}
//===----------------------------------------------------------------------===//
// Test Attribute Constraints
//===----------------------------------------------------------------------===//