forked from OSchip/llvm-project
[mlir] Canonicalization and folding of shape.cstr_broadcastable
This allows replacing of this op with a true witness in the case of both inputs being const_shapes and being found to be broadcastable. Differential Revision: https://reviews.llvm.org/D80304
This commit is contained in:
parent
4a255bbd29
commit
6aab709459
|
@ -531,7 +531,7 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
|
|||
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||
}
|
||||
|
||||
def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
|
||||
def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
|
||||
let summary = "Determines if 2 shapes can be successfully broadcasted";
|
||||
let description = [{
|
||||
Given 2 input shapes, return a witness specifying if they are broadcastable.
|
||||
|
@ -550,6 +550,9 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
|
|||
let results = (outs Shape_WitnessType:$result);
|
||||
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict";
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
set(LLVM_TARGET_DEFINITIONS IR/ShapeCanonicalization.td)
|
||||
mlir_tablegen(IR/ShapeCanonicalization.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRShapeCanonicalizationIncGen)
|
||||
|
||||
add_mlir_dialect_library(MLIRShape
|
||||
IR/Shape.cpp
|
||||
|
||||
|
|
|
@ -18,6 +18,10 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::shape;
|
||||
|
||||
namespace {
|
||||
#include "IR/ShapeCanonicalization.inc"
|
||||
}
|
||||
|
||||
ShapeDialect::ShapeDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addOperations<
|
||||
|
@ -260,6 +264,32 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
|
|||
|
||||
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CstrBroadcastableOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void CstrBroadcastableOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||
// If inputs are equal, return passing witness
|
||||
patterns.insert<CstrBroadcastableEqOps>(context);
|
||||
}
|
||||
|
||||
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!operands[0] || !operands[1])
|
||||
return nullptr;
|
||||
auto lhsShape = llvm::to_vector<6>(
|
||||
operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
|
||||
auto rhsShape = llvm::to_vector<6>(
|
||||
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
|
||||
SmallVector<int64_t, 6> resultShape;
|
||||
if (OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
|
||||
return BoolAttr::get(true, getContext());
|
||||
|
||||
// Because a failing witness result here represents an eventual assertion
|
||||
// failure, we do not replace it with a constant witness.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstSizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
include "mlir/Dialect/Shape/IR/ShapeOps.td"
|
||||
|
||||
def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
|
||||
|
||||
// Canonicalization patterns.
|
||||
def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $lhs, $rhs),
|
||||
(Shape_ConstWitnessOp ConstBoolAttrTrue),
|
||||
[(EqualBinaryOperands $lhs, $rhs)]>;
|
|
@ -267,3 +267,59 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
|
|||
%1 = shape.any %arg0, %arg1
|
||||
return %1 : !shape.shape
|
||||
}
|
||||
|
||||
// -----
|
||||
// Broadcastable with broadcastable constant shapes can be removed.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f() {
|
||||
// CHECK-NEXT: shape.const_witness true
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%cs0 = shape.const_shape [3, 1]
|
||||
%cs1 = shape.const_shape [1, 5]
|
||||
%0 = shape.cstr_broadcastable %cs0, %cs1
|
||||
"consume.witness"(%0) : (!shape.witness) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// Broadcastable with non-broadcastable constant shapes is always false
|
||||
// CHECK-LABEL: func @f
|
||||
func @f() {
|
||||
// CHECK-NEXT: shape.const_shape
|
||||
// CHECK-NEXT: shape.const_shape
|
||||
// CHECK-NEXT: shape.cstr_broadcastable
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%cs0 = shape.const_shape [1, 3]
|
||||
%cs1 = shape.const_shape [1, 5]
|
||||
%0 = shape.cstr_broadcastable %cs0, %cs1
|
||||
"consume.witness"(%0) : (!shape.witness) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// Broadcastable without guaranteed broadcastable shapes cannot be removed.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0 : !shape.shape) {
|
||||
// CHECK-NEXT: shape.const_shape
|
||||
// CHECK-NEXT: shape.cstr_broadcastable
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%cs0 = shape.const_shape [1,3]
|
||||
%0 = shape.cstr_broadcastable %arg0, %cs0
|
||||
"consume.witness"(%0) : (!shape.witness) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// Broadcastable with non-constant but known equal shapes can be removed.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg0 : !shape.shape) {
|
||||
// CHECK-NEXT: shape.const_witness true
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%0 = shape.cstr_broadcastable %arg0, %arg0
|
||||
"consume.witness"(%0) : (!shape.witness) -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue