[mlir][ods] Added EnumAttr, an AttrDef implementation of enum attributes

`EnumAttr` is a pure TableGen implementation of enum attributes using `AttrDef`. This is meant as a drop-in replacement for `StrEnumAttr`, which is soon to be deprecated. `StrEnumAttr` is often used over `IntEnumAttr` because its more readable in MLIR assembly formats. However, storing and manipulating strings is not efficient. Defining `StrEnumAttr` can also be awkward and relies on a lot of special logic in `EnumsGen`, and has some hidden sharp edges.

Also, `EnumAttr` stores the enum directly,  removing the need to convert to/from integers when calling attribute getters on ops.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D115181
This commit is contained in:
Mogball 2021-12-17 02:44:56 +00:00
parent c50a4b3f97
commit 319d8cf685
7 changed files with 189 additions and 1 deletions

View File

@ -0,0 +1,96 @@
//===-- EnumAttr.td - Enum attributes ----------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef ENUM_ATTR
#define ENUM_ATTR
include "mlir/IR/OpBase.td"
// A C++ enum as an attribute parameter. The parameter implements a parser and
// printer for the enum by dispatching calls to `stringToSymbol` and
// `symbolToString`.
class EnumParameter<EnumAttrInfo enumInfo>
: AttrParameter<enumInfo.cppNamespace # "::" # enumInfo.className,
"an enum of type " # enumInfo.className> {
// Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the
// symbol is not valid.
let parser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> {
auto loc = $_parser.getCurrentLocation();
::llvm::StringRef enumKeyword;
if (::mlir::failed($_parser.parseKeyword(&enumKeyword)))
return ::mlir::failure();
auto maybeEnum = }] # enumInfo.cppNamespace # "::" #
enumInfo.stringToSymbolFnName # [{(enumKeyword);
if (maybeEnum)
return *maybeEnum;
return {$_parser.emitError(loc, "expected }] #
cppType # [{ to be one of: }] #
!interleave(!foreach(enum, enumInfo.enumerants, enum.str), ", ") # [{")};
}()}];
// Print the enum by calling `symbolToString`.
let printer = "$_printer << " # enumInfo.symbolToStringFnName # "($_self)";
}
// An attribute backed by a C++ enum. The attribute contains a single
// parameter `value` whose type is the C++ enum class.
//
// Example:
//
// ```
// def MyEnum : I32EnumAttr<"MyEnum", "a simple enum", [
// I32EnumAttrCase<"First", 0, "first">,
// I32EnumAttrCase<"Second", 1, "second>]> {
// let genSpecializedAttr = 0;
// }
//
// def MyEnumAttr : EnumAttr<MyDialect, MyEnum, "enum">;
// ```
//
// By default, the assembly format of the attribute works best with operation
// assembly formats. For example:
//
// ```
// def MyOp : Op<MyDialect, "my_op"> {
// let arguments = (ins MyEnumAttr:$enum);
// let assemblyFormat = "$enum attr-dict";
// }
// ```
//
// The op will appear in the IR as `my_dialect.my_op first`. However, the
// generic format of the attribute will be `#my_dialect<"enum first">`. Override
// the attribute's assembly format as required.
class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
list <Trait> traits = []>
: AttrDef<dialect, enumInfo.className, traits> {
let summary = enumInfo.summary;
// Inherit the C++ namespace from the enum.
let cppNamespace = enumInfo.cppNamespace;
// Define a constant builder for the attribute to convert from C++ enums.
let constBuilderCall = cppNamespace # "::" # cppClassName #
"::get($_builder.getContext(), $0)";
// Op attribute getters should return the underlying C++ enum type.
let returnType = enumInfo.cppNamespace # "::" # enumInfo.className;
// Convert from attribute to the underlying C++ type in op getters.
let convertFromStorage = "$_self.getValue()";
// The enum attribute has one parameter: the C++ enum value.
let parameters = (ins EnumParameter<enumInfo>:$value);
// If a mnemonic was provided, use it to generate a custom assembly format.
let mnemonic = name;
// The default assembly format for enum attributes. Selected to best work with
// operation assembly formats.
let assemblyFormat = "$value";
}
#endif // ENUM_ATTR

View File

@ -0,0 +1,30 @@
// RUN: mlir-opt -verify-diagnostics -split-input-file %s
func @test_invalid_enum_case() -> () {
// expected-error@+2 {{expected test::TestEnum to be one of: first, second, third}}
// expected-error@+1 {{failed to parse TestEnumAttr}}
test.op_with_enum #test<"enum fourth">
}
// -----
func @test_invalid_enum_case() -> () {
// expected-error@+1 {{expected test::TestEnum to be one of: first, second, third}}
test.op_with_enum fourth
// expected-error@+1 {{failed to parse TestEnumAttr}}
}
// -----
func @test_invalid_attr() -> () {
// expected-error@+1 {{op attribute 'value' failed to satisfy constraint: a test enum}}
"test.op_with_enum"() {value = 1 : index} : () -> ()
}
// -----
func @test_parse_invalid_attr() -> () {
// expected-error@+2 {{expected valid keyword}}
// expected-error@+1 {{failed to parse TestEnumAttr parameter 'value'}}
test.op_with_enum 1 : index
}

View File

@ -0,0 +1,28 @@
// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s
// CHECK-LABEL: @test_enum_attr_roundtrip
func @test_enum_attr_roundtrip() -> () {
// CHECK: value = #test<"enum first">
"test.op"() {value = #test<"enum first">} : () -> ()
// CHECK: value = #test<"enum second">
"test.op"() {value = #test<"enum second">} : () -> ()
// CHECK: value = #test<"enum third">
"test.op"() {value = #test<"enum third">} : () -> ()
return
}
// CHECK-LABEL: @test_op_with_enum
func @test_op_with_enum() -> () {
// CHECK: test.op_with_enum third
test.op_with_enum third
return
}
// CHECK-LABEL: @test_match_op_with_enum
func @test_match_op_with_enum() -> () {
// CHECK: test.op_with_enum third tag 0 : i32
test.op_with_enum third tag 0 : i32
// CHECK: test.op_with_enum second tag 1 : i32
test.op_with_enum first tag 0 : i32
return
}

View File

@ -23,6 +23,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "TestAttrInterfaces.h.inc"
#include "TestOpEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "TestAttrDefs.h.inc"

View File

@ -39,7 +39,6 @@ class DLTIDialect;
class RewritePatternSet;
} // namespace mlir
#include "TestOpEnums.h.inc"
#include "TestOpInterfaces.h.inc"
#include "TestOpStructs.h.inc"
#include "TestOpsDialect.h.inc"

View File

@ -11,6 +11,7 @@
include "TestDialect.td"
include "mlir/Dialect/DLTI/DLTIBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/RegionKindInterface.td"
@ -287,6 +288,38 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
);
}
//===----------------------------------------------------------------------===//
// Test Enum Attributes
//===----------------------------------------------------------------------===//
// Define the C++ enum.
def TestEnum
: I32EnumAttr<"TestEnum", "a test enum", [
I32EnumAttrCase<"First", 0, "first">,
I32EnumAttrCase<"Second", 1, "second">,
I32EnumAttrCase<"Third", 2, "third">,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "test";
}
// Define the enum attribute.
def TestEnumAttr : EnumAttr<Test_Dialect, TestEnum, "enum">;
// Define an op that contains the enum attribute.
def OpWithEnum : TEST_Op<"op_with_enum"> {
let arguments = (ins TestEnumAttr:$value, OptionalAttr<AnyAttr>:$tag);
let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
}
// Define a pattern that matches and creates an enum attribute.
def : Pat<(OpWithEnum ConstantAttr<TestEnumAttr,
"::test::TestEnum::First">:$value,
ConstantAttr<I32Attr, "0">:$tag),
(OpWithEnum ConstantAttr<TestEnumAttr,
"::test::TestEnum::Second">,
ConstantAttr<I32Attr, "1">)>;
//===----------------------------------------------------------------------===//
// Test Attribute Constraints
//===----------------------------------------------------------------------===//

View File

@ -839,6 +839,7 @@ cc_binary(
td_library(
name = "OpBaseTdFiles",
srcs = [
"include/mlir/IR/EnumAttr.td",
"include/mlir/IR/OpAsmInterface.td",
"include/mlir/IR/OpBase.td",
"include/mlir/IR/RegionKindInterface.td",