forked from OSchip/llvm-project
Tablegen helpers for accessing properties of shaped types
Tablegen's lack of functions continues to be annoying PiperOrigin-RevId: 271680947
This commit is contained in:
parent
5f8dff936b
commit
e7c3ca92f8
|
@ -27,13 +27,32 @@
|
||||||
// Common utilities for defining TableGen mechanisms
|
// Common utilities for defining TableGen mechanisms
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Concatenates a list of strings with a separator (default ", ")
|
// A workaround for the inability to define functions in Tablegen.
|
||||||
class StrJoin<list<string> strings, string sep = ", "> {
|
//
|
||||||
string result =
|
// The template parameter defines a string that can be extracted from an
|
||||||
!if(!empty(strings), "",
|
// instance of this class by accessing the "result" member. Subclasses can take
|
||||||
!foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur));
|
// their own template parameters as function "arguments" and use them to
|
||||||
|
// populate result.
|
||||||
|
// For example, if it didn't already exist, a concat function could be defined
|
||||||
|
// like:
|
||||||
|
//
|
||||||
|
// class StrConcat<list<string> strings> :
|
||||||
|
// StrFunc<!foldl("", strings, prev, cur, prev # cur)>
|
||||||
|
//
|
||||||
|
// and then called like
|
||||||
|
//
|
||||||
|
// StrConcat<["a", "b", "c"]>.result
|
||||||
|
//
|
||||||
|
// to get the string "abc"
|
||||||
|
class StrFunc<string r> {
|
||||||
|
string result = r;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Concatenates a list of strings with a separator (default ", ")
|
||||||
|
class StrJoin<list<string> strings, string sep = ", "> :
|
||||||
|
StrFunc<!if(!empty(strings), "",
|
||||||
|
!foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur))>;
|
||||||
|
|
||||||
// Concatenates a list of integers into a string with a separator (default ", ")
|
// Concatenates a list of integers into a string with a separator (default ", ")
|
||||||
class StrJoinInt<list<int> integers, string sep = ", "> :
|
class StrJoinInt<list<int> integers, string sep = ", "> :
|
||||||
StrJoin<!foreach(i, integers, !cast<string>(i)), sep>;
|
StrJoin<!foreach(i, integers, !cast<string>(i)), sep>;
|
||||||
|
@ -1437,6 +1456,14 @@ def HasNoUseOf: Constraint<
|
||||||
|
|
||||||
// TODO(b/135033717): Improve the autogenerated error messages.
|
// TODO(b/135033717): Improve the autogenerated error messages.
|
||||||
|
|
||||||
|
class Rank<string name> :
|
||||||
|
StrFunc<"$" # name # ".getType().cast<ShapedType>().getRank()">;
|
||||||
|
|
||||||
|
class ElementCount<string name> :
|
||||||
|
StrFunc<"$" # name # ".getType().cast<ShapedType>().getNumElements()">;
|
||||||
|
|
||||||
|
class ElementType<string name> : StrFunc<"getElementTypeOrSelf($" # name # ")">;
|
||||||
|
|
||||||
class AllMatchPred<list<string> values> :
|
class AllMatchPred<list<string> values> :
|
||||||
CPred<"llvm::is_splat(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
|
CPred<"llvm::is_splat(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
|
||||||
|
|
||||||
|
@ -1454,17 +1481,15 @@ class AllMatchSameOperatorTrait<list<string> names, string operator,
|
||||||
AllMatchSameOperatorPred<names, operator>>;
|
AllMatchSameOperatorPred<names, operator>>;
|
||||||
|
|
||||||
class AllElementCountsMatch<list<string> names> :
|
class AllElementCountsMatch<list<string> names> :
|
||||||
AllMatchSameOperatorTrait<
|
AllMatchSameOperatorTrait<names, ElementCount<"_self">.result,
|
||||||
names, "$_self.getType().cast<ShapedType>().getNumElements()",
|
|
||||||
"element count">;
|
"element count">;
|
||||||
|
|
||||||
class AllElementTypesMatch<list<string> names> :
|
class AllElementTypesMatch<list<string> names> :
|
||||||
AllMatchSameOperatorTrait<names,
|
AllMatchSameOperatorTrait<names, ElementType<"_self">.result,
|
||||||
"getElementTypeOrSelf($_self)", "element type">;
|
"element type">;
|
||||||
|
|
||||||
class AllRanksMatch<list<string> names> :
|
class AllRanksMatch<list<string> names> :
|
||||||
AllMatchSameOperatorTrait<
|
AllMatchSameOperatorTrait<names, Rank<"_self">.result, "rank">;
|
||||||
names, "$_self.getType().cast<ShapedType>().getRank()", "rank">;
|
|
||||||
|
|
||||||
class AllTypesMatch<list<string> names> :
|
class AllTypesMatch<list<string> names> :
|
||||||
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
|
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
|
||||||
|
|
|
@ -293,9 +293,8 @@ def FourEqualsFive :
|
||||||
|
|
||||||
def OperandRankEqualsResultSize :
|
def OperandRankEqualsResultSize :
|
||||||
TEST_Op<"operand_rank_equals_result_size",
|
TEST_Op<"operand_rank_equals_result_size",
|
||||||
[AllMatch<["$operand.getType().cast<ShapedType>().getRank()",
|
[AllMatch<[Rank<"operand">.result, ElementCount<"result">.result],
|
||||||
"$result.getType().cast<ShapedType>().getNumElements()"
|
"operand rank equals result size">]> {
|
||||||
], "operand rank equals result size">]> {
|
|
||||||
let arguments = (ins AnyTensor:$operand);
|
let arguments = (ins AnyTensor:$operand);
|
||||||
let results = (outs AnyTensor:$result);
|
let results = (outs AnyTensor:$result);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue