forked from OSchip/llvm-project
TableGen: untie Attr from Type
In TableGen definitions, the "Type" class has been used for types of things that can be stored in Attributes, but not necessarily present in the MLIR type system. As a consequence, records like "String" or "DerviedAttrBody" were of class "Type", which can be confusing. Furthermore, the "builderCall" field of the "Type" class serves only for attribute construction. Some TableGen "Type" subclasses that correspond to MLIR kinds of types do not have a canonical way of construction only from the data available in TableGen, e.g. MemRefType would require the list of affine maps. This leads to a conclusion that the entities that describe types of objects appearing in Attributes should be independent of "Type": they have some properties "Type"s don't and vice versa. Do not parameterize Tablegen "Attr" class by an instance of "Type". Instead, provide a "constBuilderCall" field that can be used to build an attribute from a constant value stored in TableGen instead of indirectly going through Attribute.Type.builderCall. Some attributes still don't have a "constBuilderCall" because they used to depend on types without a "builderCall". Drop definitions of class "Type" that don't correspond to MLIR Types. Provide infrastructure to define type-dependent attributes and string-backed attributes for convenience. PiperOrigin-RevId: 229570087
This commit is contained in:
parent
590012772d
commit
bd161ae5bc
|
@ -84,12 +84,20 @@ class TypeConstraint<CPred condition, string descr = ""> {
|
||||||
string description = descr;
|
string description = descr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// A specific type that can be constructed. Also carries type constraints, but
|
// A type, carries type constraints, but accepts any type by default.
|
||||||
// accepts any type by default.
|
class Type<CPred condition = CPred<"true">, string descr = "">
|
||||||
class Type<CPred condition = CPred<"true">, string descr = ""> : TypeConstraint<condition, descr> {
|
: TypeConstraint<condition, descr>;
|
||||||
// The builder call to invoke (if specified) to construct the Type.
|
|
||||||
|
// A type that can be constructed using MLIR::Builder.
|
||||||
|
// Note that this does not "inherit" from Type because it would require
|
||||||
|
// duplicating Type subclasses for buildable and non-buildable cases to avoid
|
||||||
|
// diamond "inheritance".
|
||||||
|
// TODO(zinenko): we may extend this to a more general 'Buildable' trait,
|
||||||
|
// making some Types and some Attrs buildable.
|
||||||
|
class BuildableType<code builder> {
|
||||||
|
// The builder call to invoke (if specified) to construct the BuildableType.
|
||||||
// Format: this will be affixed to the builder.
|
// Format: this will be affixed to the builder.
|
||||||
code builderCall = ?;
|
code builderCall = builder;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Integer types.
|
// Integer types.
|
||||||
|
@ -103,9 +111,9 @@ def Index : IntegerBase<CPred<"{0}.isa<IndexType>()">, "index">;
|
||||||
|
|
||||||
// Integer type of a specific width.
|
// Integer type of a specific width.
|
||||||
class I<int width>
|
class I<int width>
|
||||||
: IntegerBase<CPred<"{0}.isInteger(" # width # ")">, "i" # width> {
|
: IntegerBase<CPred<"{0}.isInteger(" # width # ")">, "i" # width>,
|
||||||
|
BuildableType<"getIntegerType(" # width # ")"> {
|
||||||
int bitwidth = width;
|
int bitwidth = width;
|
||||||
let builderCall = "getIntegerType(" # bitwidth # ")";
|
|
||||||
}
|
}
|
||||||
def I1 : I<1>;
|
def I1 : I<1>;
|
||||||
def I32 : I<32>;
|
def I32 : I<32>;
|
||||||
|
@ -118,9 +126,9 @@ def Float : FloatBase<CPred<"{0}.isa<FloatType>()">, "floating point">;
|
||||||
|
|
||||||
// Float type of a specific width.
|
// Float type of a specific width.
|
||||||
class F<int width>
|
class F<int width>
|
||||||
: FloatBase<CPred<"{0}.isF" # width # "()">, "f" # width> {
|
: FloatBase<CPred<"{0}.isF" # width # "()">, "f" # width>,
|
||||||
|
BuildableType<"getF" # width # "Type()"> {
|
||||||
int bitwidth = width;
|
int bitwidth = width;
|
||||||
let builderCall = "getF" # width # "Type()";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def F32 : F<32>;
|
def F32 : F<32>;
|
||||||
|
@ -180,12 +188,6 @@ class TypedTensor<Type t>
|
||||||
|
|
||||||
def F32Tensor : TypedTensor<F32>;
|
def F32Tensor : TypedTensor<F32>;
|
||||||
|
|
||||||
// String type.
|
|
||||||
def String : Type;
|
|
||||||
|
|
||||||
// Type corresponding to derived attribute.
|
|
||||||
def DerivedAttrBody : Type;
|
|
||||||
|
|
||||||
// Type constraint for integer-like types: integers, indices, vectors of
|
// Type constraint for integer-like types: integers, indices, vectors of
|
||||||
// integers, tensors of integers.
|
// integers, tensors of integers.
|
||||||
def IntegerLike : TypeConstraint<AnyOf<[Integer.predicate, Index.predicate,
|
def IntegerLike : TypeConstraint<AnyOf<[Integer.predicate, Index.predicate,
|
||||||
|
@ -202,9 +204,7 @@ def FloatLike : TypeConstraint<AnyOf<[Float.predicate,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Base class for all attributes.
|
// Base class for all attributes.
|
||||||
class Attr<Type t> {
|
class Attr {
|
||||||
Type type = t;
|
|
||||||
|
|
||||||
code storageType = ?; // The backing mlir::Attribute type
|
code storageType = ?; // The backing mlir::Attribute type
|
||||||
code returnType = ?; // The underlying C++ value type
|
code returnType = ?; // The underlying C++ value type
|
||||||
|
|
||||||
|
@ -216,36 +216,65 @@ class Attr<Type t> {
|
||||||
// '{0}.getValue().convertToFloat()' for 'FloatAttr val' will expand to
|
// '{0}.getValue().convertToFloat()' for 'FloatAttr val' will expand to
|
||||||
// 'getAttrOfType<FloatAttr>("val").getValue().convertToFloat()'.
|
// 'getAttrOfType<FloatAttr>("val").getValue().convertToFloat()'.
|
||||||
code convertFromStorage = "{0}.getValue()";
|
code convertFromStorage = "{0}.getValue()";
|
||||||
|
|
||||||
|
// The call expression that builds an attribute from a constant value.
|
||||||
|
//
|
||||||
|
// Format: {0} will be expanded to an instance of mlir::Builder, {1} will be
|
||||||
|
// expanded to the constant value of the attribute. For example,
|
||||||
|
// '{0}.getStringAttr("{1}")' for 'StringAttr:"foo"' will expand to
|
||||||
|
// 'builder.getStringAttr("foo")'.
|
||||||
|
code constBuilderCall = ?;
|
||||||
}
|
}
|
||||||
|
|
||||||
def BoolAttr : Attr<I1> {
|
// A generic attribute that must be constructed around a specific type.
|
||||||
|
// Backed by a C++ class "attrName".
|
||||||
|
class TypeBasedAttr<BuildableType t, string attrName> : Attr {
|
||||||
|
let constBuilderCall =
|
||||||
|
"{0}.get" # attrName # "({0}." # t.builderCall # ", {1})";
|
||||||
|
let storageType = attrName;
|
||||||
|
}
|
||||||
|
|
||||||
|
// An attribute backed by a string type.
|
||||||
|
class StringBasedAttr : Attr {
|
||||||
|
let constBuilderCall = [{ {0}.getStringAttr("{1}") }];
|
||||||
|
let storageType = [{ StringAttr }];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base class for instantiating float attributes of fixed width.
|
||||||
|
class FloatAttrBase<BuildableType t> : TypeBasedAttr<t, "FloatAttr">;
|
||||||
|
|
||||||
|
// Base class for instantiating integer attributes of fixed width.
|
||||||
|
class IntegerAttrBase<BuildableType t> : TypeBasedAttr<t, "IntegerAttr">;
|
||||||
|
|
||||||
|
def BoolAttr : Attr {
|
||||||
let storageType = [{ BoolAttr }];
|
let storageType = [{ BoolAttr }];
|
||||||
let returnType = [{ bool }];
|
let returnType = [{ bool }];
|
||||||
|
let constBuilderCall = [{ {0}.getBoolAttr({1})" }];
|
||||||
}
|
}
|
||||||
def ElementsAttr : Attr<?> {
|
def ElementsAttr : Attr {
|
||||||
let storageType = [{ ElementsAttr }];
|
let storageType = [{ ElementsAttr }];
|
||||||
let returnType = [{ ElementsAttr }];
|
let returnType = [{ ElementsAttr }];
|
||||||
code convertFromStorage = "{0}";
|
let convertFromStorage = "{0}";
|
||||||
}
|
}
|
||||||
def F32Attr : Attr<F32> {
|
def F32Attr : FloatAttrBase<F32> {
|
||||||
let storageType = [{ FloatAttr }];
|
|
||||||
let returnType = [{ float }];
|
let returnType = [{ float }];
|
||||||
let convertFromStorage = [{ {0}.getValue().convertToFloat() }];
|
let convertFromStorage = [{ {0}.getValue().convertToFloat() }];
|
||||||
}
|
}
|
||||||
def I32Attr : Attr<I32> {
|
def I32Attr : IntegerAttrBase<I32> {
|
||||||
let storageType = [{ IntegerAttr }];
|
let storageType = [{ IntegerAttr }];
|
||||||
let returnType = [{ int }];
|
let returnType = [{ int }];
|
||||||
let convertFromStorage = [{ {0}.getValue().getSExtValue() }];
|
let convertFromStorage = [{ {0}.getValue().getSExtValue() }];
|
||||||
}
|
}
|
||||||
def StrAttr : Attr<String> {
|
def StrAttr : StringBasedAttr {
|
||||||
let storageType = [{ StringAttr }];
|
let storageType = [{ StringAttr }];
|
||||||
let returnType = [{ StringRef }];
|
let returnType = [{ StringRef }];
|
||||||
|
let constBuilderCall = [{ {0}.getStringAttr("{1}") }];
|
||||||
}
|
}
|
||||||
|
|
||||||
// DerivedAttr are attributes whose value is computed from properties
|
// DerivedAttr are attributes whose value is computed from properties
|
||||||
// of the operation. They do not require additional storage and are
|
// of the operation. They do not require additional storage and are
|
||||||
// materialized as needed.
|
// materialized as needed.
|
||||||
class DerivedAttr<code ReturnType, code Body> : Attr<DerivedAttrBody> {
|
class DerivedAttr<code ReturnType, code Body> : Attr {
|
||||||
let returnType = ReturnType;
|
let returnType = ReturnType;
|
||||||
code body = Body;
|
code body = Body;
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,6 @@
|
||||||
#define MLIR_TABLEGEN_ATTRIBUTE_H_
|
#define MLIR_TABLEGEN_ATTRIBUTE_H_
|
||||||
|
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/TableGen/Type.h"
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
namespace llvm {
|
namespace llvm {
|
||||||
|
@ -48,9 +47,6 @@ public:
|
||||||
// of `DrivedAttr`).
|
// of `DrivedAttr`).
|
||||||
bool isDerivedAttr() const;
|
bool isDerivedAttr() const;
|
||||||
|
|
||||||
// Returns the type of this attribute.
|
|
||||||
Type getType() const;
|
|
||||||
|
|
||||||
// Returns true if this attribute has storage type set.
|
// Returns true if this attribute has storage type set.
|
||||||
bool hasStorageType() const;
|
bool hasStorageType() const;
|
||||||
|
|
||||||
|
@ -66,6 +62,17 @@ public:
|
||||||
// The call will contain a `{0}` which will be expanded to this attribute.
|
// The call will contain a `{0}` which will be expanded to this attribute.
|
||||||
StringRef getConvertFromStorageCall() const;
|
StringRef getConvertFromStorageCall() const;
|
||||||
|
|
||||||
|
// Returns true if this attribute can be built from a constant value.
|
||||||
|
bool isConstBuildable() const;
|
||||||
|
|
||||||
|
// Returns the template that can be used to produce an instance of the
|
||||||
|
// attribute.
|
||||||
|
// Syntax: {0} should be replaced with a builder, {1} should be replaced with
|
||||||
|
// the constant value.
|
||||||
|
StringRef getConstBuilderTemplate() const;
|
||||||
|
|
||||||
|
StringRef getTableGenDefName() const;
|
||||||
|
|
||||||
// Returns the code body for derived attribute. Aborts if this is not a
|
// Returns the code body for derived attribute. Aborts if this is not a
|
||||||
// derived attribute.
|
// derived attribute.
|
||||||
StringRef getDerivedCodeBody() const;
|
StringRef getDerivedCodeBody() const;
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
|
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/TableGen/Attribute.h"
|
#include "mlir/TableGen/Attribute.h"
|
||||||
|
#include "mlir/TableGen/Type.h"
|
||||||
#include "llvm/ADT/PointerUnion.h"
|
#include "llvm/ADT/PointerUnion.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
|
@ -70,11 +70,6 @@ public:
|
||||||
|
|
||||||
// Returns the TableGen def name for this type.
|
// Returns the TableGen def name for this type.
|
||||||
StringRef getTableGenDefName() const;
|
StringRef getTableGenDefName() const;
|
||||||
|
|
||||||
// Returns the method call to invoke upon a MLIR pattern rewriter to
|
|
||||||
// construct this type. Returns an empty StringRef if the method call
|
|
||||||
// is undefined or unset.
|
|
||||||
StringRef getBuilderCall() const;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace tblgen
|
} // end namespace tblgen
|
||||||
|
|
|
@ -47,10 +47,6 @@ bool tblgen::Attribute::isDerivedAttr() const {
|
||||||
return def.isSubClassOf("DerivedAttr");
|
return def.isSubClassOf("DerivedAttr");
|
||||||
}
|
}
|
||||||
|
|
||||||
tblgen::Type tblgen::Attribute::getType() const {
|
|
||||||
return Type(def.getValueAsDef("type"));
|
|
||||||
}
|
|
||||||
|
|
||||||
bool tblgen::Attribute::hasStorageType() const {
|
bool tblgen::Attribute::hasStorageType() const {
|
||||||
const auto *init = def.getValueInit("storageType");
|
const auto *init = def.getValueInit("storageType");
|
||||||
return !getValueAsString(init).empty();
|
return !getValueAsString(init).empty();
|
||||||
|
@ -74,6 +70,20 @@ StringRef tblgen::Attribute::getConvertFromStorageCall() const {
|
||||||
return getValueAsString(init);
|
return getValueAsString(init);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool tblgen::Attribute::isConstBuildable() const {
|
||||||
|
const auto *init = def.getValueInit("constBuilderCall");
|
||||||
|
return !getValueAsString(init).empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef tblgen::Attribute::getConstBuilderTemplate() const {
|
||||||
|
const auto *init = def.getValueInit("constBuilderCall");
|
||||||
|
return getValueAsString(init);
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef tblgen::Attribute::getTableGenDefName() const {
|
||||||
|
return def.getName();
|
||||||
|
}
|
||||||
|
|
||||||
StringRef tblgen::Attribute::getDerivedCodeBody() const {
|
StringRef tblgen::Attribute::getDerivedCodeBody() const {
|
||||||
assert(isDerivedAttr() && "only derived attribute has 'body' field");
|
assert(isDerivedAttr() && "only derived attribute has 'body' field");
|
||||||
return def.getValueAsString("body");
|
return def.getValueAsString("body");
|
||||||
|
|
|
@ -62,13 +62,3 @@ tblgen::Type::Type(const llvm::Record &record) : TypeConstraint(record) {
|
||||||
tblgen::Type::Type(const llvm::DefInit *init) : Type(*init->getDef()) {}
|
tblgen::Type::Type(const llvm::DefInit *init) : Type(*init->getDef()) {}
|
||||||
|
|
||||||
StringRef tblgen::Type::getTableGenDefName() const { return def.getName(); }
|
StringRef tblgen::Type::getTableGenDefName() const { return def.getName(); }
|
||||||
|
|
||||||
StringRef tblgen::Type::getBuilderCall() const {
|
|
||||||
const auto *val = def.getValue("builderCall");
|
|
||||||
assert(val && "TableGen 'Type' class should have 'builderCall' field");
|
|
||||||
|
|
||||||
if (const auto *builder = dyn_cast<llvm::CodeInit>(val->getValue()))
|
|
||||||
return builder->getValue();
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
|
@ -3,10 +3,8 @@
|
||||||
include "mlir/IR/op_base.td"
|
include "mlir/IR/op_base.td"
|
||||||
|
|
||||||
// Create a Type and Attribute.
|
// Create a Type and Attribute.
|
||||||
def YT : Type {
|
def YT : BuildableType<"buildYT">;
|
||||||
let builderCall = "buildYT()";
|
def Y_Attr : TypeBasedAttr<YT, "Attribute">;
|
||||||
}
|
|
||||||
def Y_Attr : Attr<YT>;
|
|
||||||
def Y_Const_Attr {
|
def Y_Const_Attr {
|
||||||
Attr attr = Y_Attr;
|
Attr attr = Y_Attr;
|
||||||
string value = "attrValue";
|
string value = "attrValue";
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/TableGen/Attribute.h"
|
||||||
#include "mlir/TableGen/GenInfo.h"
|
#include "mlir/TableGen/GenInfo.h"
|
||||||
#include "mlir/TableGen/Operator.h"
|
#include "mlir/TableGen/Operator.h"
|
||||||
#include "mlir/TableGen/Predicate.h"
|
#include "mlir/TableGen/Predicate.h"
|
||||||
|
@ -93,24 +94,14 @@ private:
|
||||||
void Pattern::emitAttributeValue(Record *constAttr) {
|
void Pattern::emitAttributeValue(Record *constAttr) {
|
||||||
Attribute attr(constAttr->getValueAsDef("attr"));
|
Attribute attr(constAttr->getValueAsDef("attr"));
|
||||||
auto value = constAttr->getValue("value");
|
auto value = constAttr->getValue("value");
|
||||||
Type type = attr.getType();
|
|
||||||
auto storageType = attr.getStorageType();
|
|
||||||
|
|
||||||
// For attributes stored as strings we do not need to query builder etc.
|
if (!attr.isConstBuildable())
|
||||||
if (storageType == "StringAttr") {
|
|
||||||
os << formatv("rewriter.getStringAttr({0})",
|
|
||||||
value->getValue()->getAsString());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto builder = type.getBuilderCall();
|
|
||||||
if (builder.empty())
|
|
||||||
PrintFatalError(pattern->getLoc(),
|
PrintFatalError(pattern->getLoc(),
|
||||||
"no builder specified for " + type.getTableGenDefName());
|
"Attribute " + attr.getTableGenDefName() +
|
||||||
|
" does not have the 'constBuilderCall' field");
|
||||||
|
|
||||||
// Construct the attribute based on storage type and builder.
|
|
||||||
// TODO(jpienaar): Verify the constants here
|
// TODO(jpienaar): Verify the constants here
|
||||||
os << formatv("{0}::get(rewriter.{1}, {2})", storageType, builder,
|
os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
|
||||||
value->getValue()->getAsUnquotedString());
|
value->getValue()->getAsUnquotedString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue