forked from OSchip/llvm-project
[mlir][ods] Use lambda in element type check pred rather than repeated casts
Avoids multiple cast & getElementType calls. Just a local change for ShapedType containers but reduces one model case from 24.7 to 24.04s. Resultant code generated change: https://gist.github.com/jpienaar/7ffd2e9b0737134ba2ea2729b91c9572 Differential Revision: https://reviews.llvm.org/D113621
This commit is contained in:
parent
4a0c225616
commit
32b327e4ed
|
@ -566,20 +566,17 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
|
|||
Type<And<[containerPred,
|
||||
SubstLeaves<"$_self", !cast<string>(elementTypeCall),
|
||||
etype.predicate>]>,
|
||||
descr # " of " # etype.summary # " values", cppClassName> {
|
||||
// The type of elements in the container.
|
||||
Type elementType = etype;
|
||||
|
||||
// Call to retrieve.
|
||||
code getElementTypeCall = elementTypeCall;
|
||||
}
|
||||
descr # " of " # etype.summary # " values", cppClassName>;
|
||||
|
||||
class ShapedContainerType<list<Type> allowedTypes,
|
||||
Pred containerPred, string descr,
|
||||
string cppClassName = "::mlir::Type"> :
|
||||
ContainerType<AnyTypeOf<allowedTypes>, containerPred,
|
||||
"$_self.cast<::mlir::ShapedType>().getElementType()", descr,
|
||||
cppClassName>;
|
||||
Type<And<[containerPred,
|
||||
Concat<"[](::mlir::Type elementType) { return ",
|
||||
SubstLeaves<"$_self", "elementType",
|
||||
AnyTypeOf<allowedTypes>.predicate>,
|
||||
"; }($_self.cast<::mlir::ShapedType>().getElementType())">]>,
|
||||
descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppClassName>;
|
||||
|
||||
// Whether a shaped type is ranked.
|
||||
def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">;
|
||||
|
|
|
@ -25,11 +25,11 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
|
|||
// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
|
||||
|
||||
// CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
|
||||
// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
|
||||
// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ([](Type elementType) { return (true); }(type.cast<::mlir::ShapedType>().getElementType())))) {
|
||||
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
|
||||
|
||||
// CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
|
||||
// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
|
||||
// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ([](Type elementType) { return ((elementType.isF32())) || ((elementType.isSignlessInteger(32))); }(type.cast<::mlir::ShapedType>().getElementType())))) {
|
||||
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
|
||||
|
||||
// CHECK-LABEL: OpA::verify
|
||||
|
|
Loading…
Reference in New Issue