[mlir][Shape] Canonicalize assume_all with one input and tensor_cast of constant_shape

This allows simplifying some more complicated shape expressions

Differential Revision: https://reviews.llvm.org/D92843
This commit is contained in:
Benjamin Kramer 2020-12-08 15:37:32 +01:00
parent febe75032f
commit 5844bc540c
4 changed files with 57 additions and 5 deletions

View File

@ -105,6 +105,7 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
def Shape_ConstSizeOp : Shape_Op<"const_size", [
@ -630,6 +631,7 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]>
let assemblyFormat = "$inputs attr-dict";
let hasFolder = 1;
let hasCanonicalizer = 1;
let verifier = [{ return ::verify(*this); }];
}

View File

@ -271,6 +271,12 @@ void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//
void AssumingAllOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<AssumingAllOneOp>(context);
}
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
// Iterate in reverse to first handle all constant operands. They are
// guaranteed to be the tail of the inputs because this is commutative.
@ -394,6 +400,11 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
void ConstShapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<TensorCastConstShape>(context);
}
//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

View File

@ -1,4 +1,5 @@
include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
def AllInputShapesEq : Constraint<CPred< [{
llvm::all_of($0, [&](mlir::Value val) {
@ -6,8 +7,16 @@ def AllInputShapesEq : Constraint<CPred< [{
})
}]>>;
def HasSingleElement : Constraint<CPred< [{
$0.size() == 1
}]>>;
// Canonicalization patterns.
def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
(replaceWithValue $args),
[(HasSingleElement $args)]>;
def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $x, $x),
(Shape_ConstWitnessOp ConstBoolAttrTrue)>;
@ -23,3 +32,5 @@ def SizeToIndexToSizeCanonicalization : Pat<
(Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
(replaceWithValue $arg)>;
def TensorCastConstShape : Pat <
(TensorCastOp (Shape_ConstShapeOp:$c $ty)), (replaceWithValue $c)>;

View File

@ -427,20 +427,23 @@ func @f() {
// -----
// assuming_all should not be removed if not all witnesses are statically passing.
// assuming_all should not be removed if more than one witness is not
// statically passing
//
// Additionally check that the attribute is moved to the end as this op is
// commutative.
// CHECK-LABEL: func @f
func @f() {
// CHECK-NEXT: %[[UNKNOWN:.*]] = "test.source"
// CHECK-NEXT: shape.assuming_all %[[UNKNOWN]]
// CHECK-NEXT: %[[UNKNOWN1:.*]] = "test.source"
// CHECK-NEXT: %[[UNKNOWN2:.*]] = "test.source"
// CHECK-NEXT: shape.assuming_all %[[UNKNOWN1]], %[[UNKNOWN2]]
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%0 = shape.const_witness true
%1 = "test.source"() : () -> !shape.witness
%2 = shape.assuming_all %0, %1
"consume.witness"(%2) : (!shape.witness) -> ()
%2 = "test.source"() : () -> !shape.witness
%3 = shape.assuming_all %0, %1, %2
"consume.witness"(%3) : (!shape.witness) -> ()
return
}
@ -854,3 +857,28 @@ func @fold_to_extent_tensor_on_tensor(%arg: tensor<?xindex>) -> tensor<?xindex>
%casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
return %casted : tensor<?xindex>
}
// -----
// Fold assuming_all with a single input
// CHECK-LABEL: @fold_assuming_all_single_element
func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
// CHECK-NOT: assuming_all
%0 = "test.source"() : () -> (!shape.witness)
%1 = shape.assuming_all %0
"consume.witness"(%1) : (!shape.witness) -> ()
return
}
// -----
// Fold tensor_cast of a const_shape to const_shape
// CHECK-LABEL: @fold_tensor_cast_of_const_shape
func @fold_tensor_cast_of_const_shape(%arg: tensor<?xindex>) {
// CHECK-NOT: tensor_cast
%0 = shape.const_shape [2] : tensor<?xindex>
%1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
%2 = shape.cstr_broadcastable %1, %0 : tensor<1xindex>, tensor<?xindex>
"consume.witness"(%2) : (!shape.witness) -> ()
return
}