[mlir][ods] Allow specifying return types of builders

This patch allows custom attribute and type builders to return
something other than the C++ type of the attribute or type.

This is useful for attributes or types that may perform extra work during
construction (e.g. canonicalization) that could result in a different
kind of attribute or type being returned.

Reviewed By: rriddle, lattner

Differential Revision: https://reviews.llvm.org/D129792
This commit is contained in:
Jeff Niu 2022-07-14 10:31:38 -07:00
parent a7789d6315
commit 7fe2294e47
11 changed files with 108 additions and 33 deletions

View File

@ -347,6 +347,7 @@ def MyType : ... {
// its arguments.
return Base::get(typeParam.getContext(), ...);
}]>,
TypeBuilder<(ins "int":$intParam), [{}], "IntegerType">,
];
}
```
@ -461,6 +462,28 @@ the builder used `TypeBuilderWithInferredContext` implies that the context
parameter is not necessary as it can be inferred from the arguments to the
builder.
The fifth builder will generate the declaration of a builder method with a
custom return type, like:
```tablegen
let builders = [
TypeBuilder<(ins "int":$intParam), [{}], "IntegerType">,
]
```
```c++
class MyType : /*...*/ {
/*...*/
static IntegerType get(::mlir::MLIRContext *context, int intParam);
};
```
This generates a builder declaration the same as the first three examples, but
the return type of the builder is user-specified instead of the attribute or
type class. This is useful for defining builders of attributes and types that
may fold or canonicalize on construction.
### Parsing and Printing
If a mnemonic was specified, the `hasCustomAssemblyFormat` and `assemblyFormat`

View File

@ -96,30 +96,38 @@ class PredTypeTrait<string descr, Pred pred> : PredTrait<descr, pred>;
// This is necessary because the `body` is also used to generate `getChecked`
// methods, which have a different underlying `Base::get*` call.
//
class AttrOrTypeBuilder<dag parameters, code bodyCode = ""> {
class AttrOrTypeBuilder<dag parameters, code bodyCode = "",
string returnTypeStr = ""> {
dag dagParams = parameters;
code body = bodyCode;
// Change the return type of the builder. By default, it is the type of the
// attribute or type.
string returnType = returnTypeStr;
// The context parameter can be inferred from one of the other parameters and
// is not implicitly added to the parameter list.
bit hasInferredContextParam = 0;
}
class AttrBuilder<dag parameters, code bodyCode = "">
: AttrOrTypeBuilder<parameters, bodyCode>;
class TypeBuilder<dag parameters, code bodyCode = "">
: AttrOrTypeBuilder<parameters, bodyCode>;
class AttrBuilder<dag parameters, code bodyCode = "", string returnType = "">
: AttrOrTypeBuilder<parameters, bodyCode, returnType>;
class TypeBuilder<dag parameters, code bodyCode = "", string returnType = "">
: AttrOrTypeBuilder<parameters, bodyCode, returnType>;
// A class of AttrOrTypeBuilder that is able to infer the MLIRContext parameter
// from one of the other builder parameters. Instances of this builder do not
// have `MLIRContext *` implicitly added to the parameter list.
class AttrOrTypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
: TypeBuilder<parameters, bodyCode> {
class AttrOrTypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
string returnType = "">
: TypeBuilder<parameters, bodyCode, returnType> {
let hasInferredContextParam = 1;
}
class AttrBuilderWithInferredContext<dag parameters, code bodyCode = "">
: AttrOrTypeBuilderWithInferredContext<parameters, bodyCode>;
class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
: AttrOrTypeBuilderWithInferredContext<parameters, bodyCode>;
class AttrBuilderWithInferredContext<dag parameters, code bodyCode = "",
string returnType = "">
: AttrOrTypeBuilderWithInferredContext<parameters, bodyCode, returnType>;
class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
string returnType = "">
: AttrOrTypeBuilderWithInferredContext<parameters, bodyCode, returnType>;
//===----------------------------------------------------------------------===//
// Definitions

View File

@ -792,14 +792,14 @@ public:
/// unlike `OpBuilder::getType`, this method does not implicitly insert a
/// context parameter.
template <typename T, typename... ParamsT>
T getChecked(SMLoc loc, ParamsT &&...params) {
auto getChecked(SMLoc loc, ParamsT &&...params) {
return T::getChecked([&] { return emitError(loc); },
std::forward<ParamsT>(params)...);
}
/// A variant of `getChecked` that uses the result of `getNameLoc` to emit
/// errors.
template <typename T, typename... ParamsT>
T getChecked(ParamsT &&...params) {
auto getChecked(ParamsT &&...params) {
return T::getChecked([&] { return emitError(getNameLoc()); },
std::forward<ParamsT>(params)...);
}

View File

@ -37,6 +37,9 @@ class AttrOrTypeBuilder : public Builder {
public:
using Builder::Builder;
/// Returns an optional builder return type.
Optional<StringRef> getReturnType() const;
/// Returns true if this builder is able to infer the MLIRContext parameter.
bool hasInferredContextParameter() const;
};

View File

@ -20,7 +20,11 @@ using namespace mlir::tblgen;
// AttrOrTypeBuilder
//===----------------------------------------------------------------------===//
/// Returns true if this builder is able to infer the MLIRContext parameter.
Optional<StringRef> AttrOrTypeBuilder::getReturnType() const {
Optional<StringRef> type = def->getValueAsOptionalString("returnType");
return type && !type->empty() ? type : llvm::None;
}
bool AttrOrTypeBuilder::hasInferredContextParameter() const {
return def->getValueAsBit("hasInferredContextParam");
}

View File

@ -223,6 +223,19 @@ def TestAttrSelfTypeParameterFormat
let assemblyFormat = "`<` $a `>`";
}
// Test overridding attribute builders with a custom builder.
def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
let mnemonic = "override_builder";
let parameters = (ins "int":$a);
let assemblyFormat = "`<` $a `>`";
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
let builders = [AttrBuilder<(ins "int":$a), [{
return ::mlir::IntegerAttr::get(::mlir::IndexType::get($_ctxt), a);
}], "::mlir::Attribute">];
}
// Test simple extern 1D vector using ElementsAttrInterface.
def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
ElementsAttrInterface

View File

@ -55,8 +55,8 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
// ATTR: if (odsParser.parseRParen())
// ATTR: return {};
// ATTR: return TestAAttr::get(odsParser.getContext(),
// ATTR: (*_result_value),
// ATTR: (*_result_complex));
// ATTR: IntegerAttr((*_result_value)),
// ATTR: TestParamA((*_result_complex)));
// ATTR: }
// ATTR: void TestAAttr::print(::mlir::AsmPrinter &odsPrinter) const {
@ -114,8 +114,8 @@ def AttrA : TestAttr<"TestA"> {
// ATTR: return {};
// ATTR: }
// ATTR: return TestBAttr::get(odsParser.getContext(),
// ATTR: (*_result_v0),
// ATTR: (*_result_v1));
// ATTR: TestParamA((*_result_v0)),
// ATTR: TestParamB((*_result_v1)));
// ATTR: }
// ATTR: void TestBAttr::print(::mlir::AsmPrinter &odsPrinter) const {
@ -151,8 +151,8 @@ def AttrB : TestAttr<"TestB"> {
// ATTR: if (::mlir::failed(_result_v1))
// ATTR: return {};
// ATTR: return TestFAttr::get(odsParser.getContext(),
// ATTR: (*_result_v0),
// ATTR: (*_result_v1));
// ATTR: int((*_result_v0)),
// ATTR: int((*_result_v1)));
// ATTR: }
def AttrC : TestAttr<"TestF"> {
@ -278,10 +278,10 @@ def TypeA : TestType<"TestC"> {
// TYPE: if (::mlir::failed(_result_v3))
// TYPE: return {};
// TYPE: return TestDType::get(odsParser.getContext(),
// TYPE: (*_result_v0),
// TYPE: (*_result_v1),
// TYPE: (*_result_v2),
// TYPE: (*_result_v3));
// TYPE: TestParamC((*_result_v0)),
// TYPE: TestParamD((*_result_v1)),
// TYPE: TestParamC((*_result_v2)),
// TYPE: TestParamD((*_result_v3)));
// TYPE: }
// TYPE: void TestDType::print(::mlir::AsmPrinter &odsPrinter) const {
@ -369,10 +369,10 @@ def TypeB : TestType<"TestD"> {
// TYPE: return {};
// TYPE: }
// TYPE: return TestEType::get(odsParser.getContext(),
// TYPE: (*_result_v0),
// TYPE: (*_result_v1),
// TYPE: (*_result_v2),
// TYPE: (*_result_v3));
// TYPE: IntegerAttr((*_result_v0)),
// TYPE: IntegerAttr((*_result_v1)),
// TYPE: IntegerAttr((*_result_v2)),
// TYPE: IntegerAttr((*_result_v3)));
// TYPE: }
// TYPE: void TestEType::print(::mlir::AsmPrinter &odsPrinter) const {

View File

@ -31,9 +31,9 @@ include "mlir/IR/OpBase.td"
// DEF-NEXT: .Case(::test::IndexAttr::getMnemonic()
// DEF-NEXT: value = ::test::IndexAttr::parse(parser, type);
// DEF-NEXT: return ::mlir::success(!!value);
// DEF: .Default([&](llvm::StringRef keyword,
// DEF: .Default([&](llvm::StringRef keyword,
// DEF-NEXT: *mnemonic = keyword;
// DEF-NEXT: return llvm::None;
// DEF-NEXT: return llvm::None;
def Test_Dialect: Dialect {
// DECL-NOT: TestDialect
@ -148,3 +148,13 @@ def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> {
// DEF: ParamWithAccessorTypeAttrStorage
// DEF: ParamWithAccessorTypeAttrStorage(std::string param)
// DEF: StringRef ParamWithAccessorTypeAttr::getParam()
def G_BuilderWithReturnTypeAttr : TestAttr<"BuilderWithReturnType"> {
let parameters = (ins "int":$a);
let genVerifyDecl = 1;
let builders = [AttrBuilder<(ins), [{ return {}; }], "::mlir::Attribute">];
}
// DECL-LABEL: class BuilderWithReturnTypeAttr
// DECL: ::mlir::Attribute get(
// DECL: ::mlir::Attribute getChecked(

View File

@ -13,3 +13,9 @@ func.func private @compoundA() attributes {foo = #test.cmpnd_a<1, !test.smpla, [
// CHECK-LABEL: @qualifiedAttr()
// CHECK-SAME: #test.cmpnd_nested_outer_qual<i #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>>
func.func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qual<i #test.cmpnd_nested_inner<42 <1, !test.smpla, [5, 6]>>>}
// CHECK-LABEL: @overriddenAttr
// CHECK-SAME: foo = 5 : index
func.func private @overriddenAttr() attributes {
foo = #test.override_builder<5>
}

View File

@ -348,7 +348,10 @@ getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
// Don't emit a body if there isn't one.
auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
Method *m = defCls.addMethod(def.getCppClassName(), "get", props,
StringRef returnType = def.getCppClassName();
if (Optional<StringRef> builderReturnType = builder.getReturnType())
returnType = *builderReturnType;
Method *m = defCls.addMethod(returnType, "get", props,
getCustomBuilderParams({}, builder));
if (!builder.getBody())
return;
@ -373,8 +376,11 @@ static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
// Don't emit a body if there isn't one.
auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
StringRef returnType = def.getCppClassName();
if (Optional<StringRef> builderReturnType = builder.getReturnType())
returnType = *builderReturnType;
Method *m = defCls.addMethod(
def.getCppClassName(), "getChecked", props,
returnType, "getChecked", props,
getCustomBuilderParams(
{{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
builder));

View File

@ -311,7 +311,9 @@ void DefFormat::genParser(MethodBody &os) {
} else {
selfOs << formatv("(*_result_{0})", param.getName());
}
os << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()));
os << param.getCppType() << "("
<< tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()))
<< ")";
}
os << ");";
}