[mlir][Shape] Canonicalize assume_all with one input and tensor_cast of constant_shape

This allows simplifying some more complicated shape expressions.

Differential Revision: https://reviews.llvm.org/D92843
commit 5844bc540c (parent febe75032f)
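In effect, a shape.assuming_all with exactly one operand is just that operand: the combined witness of a single witness is the witness itself. A minimal before/after sketch of the first new canonicalization (illustrative IR in the style of the tests below, using the unregistered test.source and consume.witness helper ops; the tensor_cast pattern is sketched after the TableGen changes):

    // Before -canonicalize:
    %w = "test.source"() : () -> !shape.witness
    %a = shape.assuming_all %w
    "consume.witness"(%a) : (!shape.witness) -> ()

    // After -canonicalize: uses of %a become uses of %w, so the
    // assuming_all disappears as dead code.
    %w = "test.source"() : () -> !shape.witness
    "consume.witness"(%w) : (!shape.witness) -> ()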
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -105,6 +105,7 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Shape_ConstSizeOp : Shape_Op<"const_size", [
@@ -630,6 +631,7 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]>
   let assemblyFormat = "$inputs attr-dict";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 
   let verifier = [{ return ::verify(*this); }];
 }
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -271,6 +271,12 @@ void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//
+
+void AssumingAllOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<AssumingAllOneOp>(context);
+}
 
 OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
   // Iterate in reverse to first handle all constant operands. They are
   // guaranteed to be the tail of the inputs because this is commutative.
@@ -394,6 +400,11 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
 
 OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
 
+void ConstShapeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<TensorCastConstShape>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CstrBroadcastableOp
 //===----------------------------------------------------------------------===//
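The fold above leans on assuming_all being Commutative: constant (statically true) witnesses are moved to the tail of the operand list, so the fold can walk the inputs in reverse and strip them. A sketch of the net effect the updated test below checks (assumed, abbreviated IR):

    %t  = shape.const_witness true
    %u1 = "test.source"() : () -> !shape.witness
    %u2 = "test.source"() : () -> !shape.witness
    %w  = shape.assuming_all %t, %u1, %u2
    // folds to shape.assuming_all %u1, %u2: the true witness is dropped,
    // but the op survives because two witnesses remain unknown.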
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -1,4 +1,5 @@
 include "mlir/Dialect/Shape/IR/ShapeOps.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
 
 def AllInputShapesEq : Constraint<CPred< [{
   llvm::all_of($0, [&](mlir::Value val) {
@@ -6,8 +7,16 @@ def AllInputShapesEq : Constraint<CPred< [{
   })
 }]>>;
 
+def HasSingleElement : Constraint<CPred< [{
+  $0.size() == 1
+}]>>;
+
 // Canonicalization patterns.
 
+def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
+                           (replaceWithValue $args),
+                           [(HasSingleElement $args)]>;
+
 def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $x, $x),
                                  (Shape_ConstWitnessOp ConstBoolAttrTrue)>;
@@ -23,3 +32,5 @@ def SizeToIndexToSizeCanonicalization : Pat<
   (Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
   (replaceWithValue $arg)>;
+
+def TensorCastConstShape : Pat <
+  (TensorCastOp (Shape_ConstShapeOp:$c $ty)), (replaceWithValue $c)>;
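The new include of StandardOps/IR/Ops.td is what brings TensorCastOp into scope for DRR. HasSingleElement constrains AssumingAllOneOp to fire only when the variadic $inputs list has exactly one element, and replaceWithValue then forwards that operand; TensorCastConstShape likewise forwards a const_shape result past a cast. A sketch of the latter rewrite (illustrative values):

    %s = shape.const_shape [2] : tensor<?xindex>
    %c = tensor_cast %s : tensor<?xindex> to tensor<1xindex>
    // Uses of %c are replaced with %s; the now-dead tensor_cast is erased.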
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -427,20 +427,23 @@ func @f() {
 
 // -----
 
-// assuming_all should not be removed if not all witnesses are statically passing.
+// assuming_all should not be removed if more than one witness is not
+// statically passing
+//
+// Additionally check that the attribute is moved to the end as this op is
+// commutative.
 // CHECK-LABEL: func @f
 func @f() {
-  // CHECK-NEXT: %[[UNKNOWN:.*]] = "test.source"
-  // CHECK-NEXT: shape.assuming_all %[[UNKNOWN]]
+  // CHECK-NEXT: %[[UNKNOWN1:.*]] = "test.source"
+  // CHECK-NEXT: %[[UNKNOWN2:.*]] = "test.source"
+  // CHECK-NEXT: shape.assuming_all %[[UNKNOWN1]], %[[UNKNOWN2]]
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
   %0 = shape.const_witness true
   %1 = "test.source"() : () -> !shape.witness
-  %2 = shape.assuming_all %0, %1
-  "consume.witness"(%2) : (!shape.witness) -> ()
+  %2 = "test.source"() : () -> !shape.witness
+  %3 = shape.assuming_all %0, %1, %2
+  "consume.witness"(%3) : (!shape.witness) -> ()
   return
 }
@@ -854,3 +857,28 @@ func @fold_to_extent_tensor_on_tensor(%arg: tensor<?xindex>) -> tensor<?xindex>
   %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
   return %casted : tensor<?xindex>
 }
+
+// -----
+
+// Fold assuming_all with a single input
+// CHECK-LABEL: @fold_assuming_all_single_element
+func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
+  // CHECK-NOT: assuming_all
+  %0 = "test.source"() : () -> (!shape.witness)
+  %1 = shape.assuming_all %0
+  "consume.witness"(%1) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+
+// Fold tensor_cast of a const_shape to const_shape
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape
+func @fold_tensor_cast_of_const_shape(%arg: tensor<?xindex>) {
+  // CHECK-NOT: tensor_cast
+  %0 = shape.const_shape [2] : tensor<?xindex>
+  %1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
+  %2 = shape.cstr_broadcastable %1, %0 : tensor<1xindex>, tensor<?xindex>
+  "consume.witness"(%2) : (!shape.witness) -> ()
+  return
+}
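These cases are driven by the file's RUN line; assuming this test file follows the usual conventions, it looks roughly like the following. -split-input-file treats each // ----- block as a separate module, and -allow-unregistered-dialect permits the test.source/consume.witness helper ops:

    // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize %s | FileCheck %s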