forked from OSchip/llvm-project
[mlir][ods] Added EnumAttr, an AttrDef implementation of enum attributes
`EnumAttr` is a pure TableGen implementation of enum attributes using `AttrDef`. This is meant as a drop-in replacement for `StrEnumAttr`, which is soon to be deprecated. `StrEnumAttr` is often used over `IntEnumAttr` because its more readable in MLIR assembly formats. However, storing and manipulating strings is not efficient. Defining `StrEnumAttr` can also be awkward and relies on a lot of special logic in `EnumsGen`, and has some hidden sharp edges. Also, `EnumAttr` stores the enum directly, removing the need to convert to/from integers when calling attribute getters on ops. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D115181
This commit is contained in:
parent
c50a4b3f97
commit
319d8cf685
|
@ -0,0 +1,96 @@
|
|||
//===-- EnumAttr.td - Enum attributes ----------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ENUM_ATTR
|
||||
#define ENUM_ATTR
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
// A C++ enum as an attribute parameter. The parameter implements a parser and
|
||||
// printer for the enum by dispatching calls to `stringToSymbol` and
|
||||
// `symbolToString`.
|
||||
class EnumParameter<EnumAttrInfo enumInfo>
|
||||
: AttrParameter<enumInfo.cppNamespace # "::" # enumInfo.className,
|
||||
"an enum of type " # enumInfo.className> {
|
||||
// Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the
|
||||
// symbol is not valid.
|
||||
let parser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> {
|
||||
auto loc = $_parser.getCurrentLocation();
|
||||
::llvm::StringRef enumKeyword;
|
||||
if (::mlir::failed($_parser.parseKeyword(&enumKeyword)))
|
||||
return ::mlir::failure();
|
||||
auto maybeEnum = }] # enumInfo.cppNamespace # "::" #
|
||||
enumInfo.stringToSymbolFnName # [{(enumKeyword);
|
||||
if (maybeEnum)
|
||||
return *maybeEnum;
|
||||
return {$_parser.emitError(loc, "expected }] #
|
||||
cppType # [{ to be one of: }] #
|
||||
!interleave(!foreach(enum, enumInfo.enumerants, enum.str), ", ") # [{")};
|
||||
}()}];
|
||||
// Print the enum by calling `symbolToString`.
|
||||
let printer = "$_printer << " # enumInfo.symbolToStringFnName # "($_self)";
|
||||
}
|
||||
|
||||
// An attribute backed by a C++ enum. The attribute contains a single
|
||||
// parameter `value` whose type is the C++ enum class.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ```
|
||||
// def MyEnum : I32EnumAttr<"MyEnum", "a simple enum", [
|
||||
// I32EnumAttrCase<"First", 0, "first">,
|
||||
// I32EnumAttrCase<"Second", 1, "second>]> {
|
||||
// let genSpecializedAttr = 0;
|
||||
// }
|
||||
//
|
||||
// def MyEnumAttr : EnumAttr<MyDialect, MyEnum, "enum">;
|
||||
// ```
|
||||
//
|
||||
// By default, the assembly format of the attribute works best with operation
|
||||
// assembly formats. For example:
|
||||
//
|
||||
// ```
|
||||
// def MyOp : Op<MyDialect, "my_op"> {
|
||||
// let arguments = (ins MyEnumAttr:$enum);
|
||||
// let assemblyFormat = "$enum attr-dict";
|
||||
// }
|
||||
// ```
|
||||
//
|
||||
// The op will appear in the IR as `my_dialect.my_op first`. However, the
|
||||
// generic format of the attribute will be `#my_dialect<"enum first">`. Override
|
||||
// the attribute's assembly format as required.
|
||||
class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
|
||||
list <Trait> traits = []>
|
||||
: AttrDef<dialect, enumInfo.className, traits> {
|
||||
let summary = enumInfo.summary;
|
||||
|
||||
// Inherit the C++ namespace from the enum.
|
||||
let cppNamespace = enumInfo.cppNamespace;
|
||||
|
||||
// Define a constant builder for the attribute to convert from C++ enums.
|
||||
let constBuilderCall = cppNamespace # "::" # cppClassName #
|
||||
"::get($_builder.getContext(), $0)";
|
||||
|
||||
// Op attribute getters should return the underlying C++ enum type.
|
||||
let returnType = enumInfo.cppNamespace # "::" # enumInfo.className;
|
||||
|
||||
// Convert from attribute to the underlying C++ type in op getters.
|
||||
let convertFromStorage = "$_self.getValue()";
|
||||
|
||||
// The enum attribute has one parameter: the C++ enum value.
|
||||
let parameters = (ins EnumParameter<enumInfo>:$value);
|
||||
|
||||
// If a mnemonic was provided, use it to generate a custom assembly format.
|
||||
let mnemonic = name;
|
||||
|
||||
// The default assembly format for enum attributes. Selected to best work with
|
||||
// operation assembly formats.
|
||||
let assemblyFormat = "$value";
|
||||
}
|
||||
|
||||
#endif // ENUM_ATTR
|
|
@ -0,0 +1,30 @@
|
|||
// RUN: mlir-opt -verify-diagnostics -split-input-file %s
|
||||
|
||||
func @test_invalid_enum_case() -> () {
|
||||
// expected-error@+2 {{expected test::TestEnum to be one of: first, second, third}}
|
||||
// expected-error@+1 {{failed to parse TestEnumAttr}}
|
||||
test.op_with_enum #test<"enum fourth">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_invalid_enum_case() -> () {
|
||||
// expected-error@+1 {{expected test::TestEnum to be one of: first, second, third}}
|
||||
test.op_with_enum fourth
|
||||
// expected-error@+1 {{failed to parse TestEnumAttr}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_invalid_attr() -> () {
|
||||
// expected-error@+1 {{op attribute 'value' failed to satisfy constraint: a test enum}}
|
||||
"test.op_with_enum"() {value = 1 : index} : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_parse_invalid_attr() -> () {
|
||||
// expected-error@+2 {{expected valid keyword}}
|
||||
// expected-error@+1 {{failed to parse TestEnumAttr parameter 'value'}}
|
||||
test.op_with_enum 1 : index
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @test_enum_attr_roundtrip
|
||||
func @test_enum_attr_roundtrip() -> () {
|
||||
// CHECK: value = #test<"enum first">
|
||||
"test.op"() {value = #test<"enum first">} : () -> ()
|
||||
// CHECK: value = #test<"enum second">
|
||||
"test.op"() {value = #test<"enum second">} : () -> ()
|
||||
// CHECK: value = #test<"enum third">
|
||||
"test.op"() {value = #test<"enum third">} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_op_with_enum
|
||||
func @test_op_with_enum() -> () {
|
||||
// CHECK: test.op_with_enum third
|
||||
test.op_with_enum third
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_match_op_with_enum
|
||||
func @test_match_op_with_enum() -> () {
|
||||
// CHECK: test.op_with_enum third tag 0 : i32
|
||||
test.op_with_enum third tag 0 : i32
|
||||
// CHECK: test.op_with_enum second tag 1 : i32
|
||||
test.op_with_enum first tag 0 : i32
|
||||
return
|
||||
}
|
|
@ -23,6 +23,7 @@
|
|||
#include "mlir/IR/DialectImplementation.h"
|
||||
|
||||
#include "TestAttrInterfaces.h.inc"
|
||||
#include "TestOpEnums.h.inc"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "TestAttrDefs.h.inc"
|
||||
|
|
|
@ -39,7 +39,6 @@ class DLTIDialect;
|
|||
class RewritePatternSet;
|
||||
} // namespace mlir
|
||||
|
||||
#include "TestOpEnums.h.inc"
|
||||
#include "TestOpInterfaces.h.inc"
|
||||
#include "TestOpStructs.h.inc"
|
||||
#include "TestOpsDialect.h.inc"
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
include "TestDialect.td"
|
||||
include "mlir/Dialect/DLTI/DLTIBase.td"
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/RegionKindInterface.td"
|
||||
|
@ -287,6 +288,38 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
|
|||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Enum Attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Define the C++ enum.
|
||||
def TestEnum
|
||||
: I32EnumAttr<"TestEnum", "a test enum", [
|
||||
I32EnumAttrCase<"First", 0, "first">,
|
||||
I32EnumAttrCase<"Second", 1, "second">,
|
||||
I32EnumAttrCase<"Third", 2, "third">,
|
||||
]> {
|
||||
let genSpecializedAttr = 0;
|
||||
let cppNamespace = "test";
|
||||
}
|
||||
|
||||
// Define the enum attribute.
|
||||
def TestEnumAttr : EnumAttr<Test_Dialect, TestEnum, "enum">;
|
||||
|
||||
// Define an op that contains the enum attribute.
|
||||
def OpWithEnum : TEST_Op<"op_with_enum"> {
|
||||
let arguments = (ins TestEnumAttr:$value, OptionalAttr<AnyAttr>:$tag);
|
||||
let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
|
||||
}
|
||||
|
||||
// Define a pattern that matches and creates an enum attribute.
|
||||
def : Pat<(OpWithEnum ConstantAttr<TestEnumAttr,
|
||||
"::test::TestEnum::First">:$value,
|
||||
ConstantAttr<I32Attr, "0">:$tag),
|
||||
(OpWithEnum ConstantAttr<TestEnumAttr,
|
||||
"::test::TestEnum::Second">,
|
||||
ConstantAttr<I32Attr, "1">)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Attribute Constraints
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -839,6 +839,7 @@ cc_binary(
|
|||
td_library(
|
||||
name = "OpBaseTdFiles",
|
||||
srcs = [
|
||||
"include/mlir/IR/EnumAttr.td",
|
||||
"include/mlir/IR/OpAsmInterface.td",
|
||||
"include/mlir/IR/OpBase.td",
|
||||
"include/mlir/IR/RegionKindInterface.td",
|
||||
|
|
Loading…
Reference in New Issue