Tablegen helpers for accessing properties of shaped types

Tablegen's lack of functions continues to be annoying

PiperOrigin-RevId: 271680947
This commit is contained in:
Geoffrey Martin-Noble 2019-09-27 17:34:56 -07:00 committed by A. Unique TensorFlower
parent 5f8dff936b
commit e7c3ca92f8
2 changed files with 39 additions and 15 deletions

View File

@ -27,13 +27,32 @@
// Common utilities for defining TableGen mechanisms // Common utilities for defining TableGen mechanisms
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Concatenates a list of strings with a separator (default ", ") // A workaround for the inability to define functions in Tablegen.
class StrJoin<list<string> strings, string sep = ", "> { //
string result = // The template parameter defines a string that can be extracted from an
!if(!empty(strings), "", // instance of this class by accessing the "result" member. Subclasses can take
!foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur)); // their own template parameters as function "arguments" and use them to
// populate result.
// For example, if it didn't already exist, a concat function could be defined
// like:
//
// class StrConcat<list<string> strings> :
// StrFunc<!foldl("", strings, prev, cur, prev # cur)>
//
// and then called like
//
// StrConcat<["a", "b", "c"]>.result
//
// to get the string "abc"
class StrFunc<string r> {
string result = r;
} }
// Concatenates a list of strings with a separator (default ", ")
class StrJoin<list<string> strings, string sep = ", "> :
StrFunc<!if(!empty(strings), "",
!foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur))>;
// Concatenates a list of integers into a string with a separator (default ", ") // Concatenates a list of integers into a string with a separator (default ", ")
class StrJoinInt<list<int> integers, string sep = ", "> : class StrJoinInt<list<int> integers, string sep = ", "> :
StrJoin<!foreach(i, integers, !cast<string>(i)), sep>; StrJoin<!foreach(i, integers, !cast<string>(i)), sep>;
@ -1437,6 +1456,14 @@ def HasNoUseOf: Constraint<
// TODO(b/135033717): Improve the autogenerated error messages. // TODO(b/135033717): Improve the autogenerated error messages.
class Rank<string name> :
StrFunc<"$" # name # ".getType().cast<ShapedType>().getRank()">;
class ElementCount<string name> :
StrFunc<"$" # name # ".getType().cast<ShapedType>().getNumElements()">;
class ElementType<string name> : StrFunc<"getElementTypeOrSelf($" # name # ")">;
class AllMatchPred<list<string> values> : class AllMatchPred<list<string> values> :
CPred<"llvm::is_splat(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">; CPred<"llvm::is_splat(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
@ -1454,17 +1481,15 @@ class AllMatchSameOperatorTrait<list<string> names, string operator,
AllMatchSameOperatorPred<names, operator>>; AllMatchSameOperatorPred<names, operator>>;
class AllElementCountsMatch<list<string> names> : class AllElementCountsMatch<list<string> names> :
AllMatchSameOperatorTrait< AllMatchSameOperatorTrait<names, ElementCount<"_self">.result,
names, "$_self.getType().cast<ShapedType>().getNumElements()",
"element count">; "element count">;
class AllElementTypesMatch<list<string> names> : class AllElementTypesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, AllMatchSameOperatorTrait<names, ElementType<"_self">.result,
"getElementTypeOrSelf($_self)", "element type">; "element type">;
class AllRanksMatch<list<string> names> : class AllRanksMatch<list<string> names> :
AllMatchSameOperatorTrait< AllMatchSameOperatorTrait<names, Rank<"_self">.result, "rank">;
names, "$_self.getType().cast<ShapedType>().getRank()", "rank">;
class AllTypesMatch<list<string> names> : class AllTypesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">; AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;

View File

@ -293,9 +293,8 @@ def FourEqualsFive :
def OperandRankEqualsResultSize : def OperandRankEqualsResultSize :
TEST_Op<"operand_rank_equals_result_size", TEST_Op<"operand_rank_equals_result_size",
[AllMatch<["$operand.getType().cast<ShapedType>().getRank()", [AllMatch<[Rank<"operand">.result, ElementCount<"result">.result],
"$result.getType().cast<ShapedType>().getNumElements()" "operand rank equals result size">]> {
], "operand rank equals result size">]> {
let arguments = (ins AnyTensor:$operand); let arguments = (ins AnyTensor:$operand);
let results = (outs AnyTensor:$result); let results = (outs AnyTensor:$result);
} }