diff --git a/mlir/include/mlir/Dialect/Traits.h b/mlir/include/mlir/Dialect/Traits.h
index 2b4fb6d77855..aecceaac0a42 100644
--- a/mlir/include/mlir/Dialect/Traits.h
+++ b/mlir/include/mlir/Dialect/Traits.h
@@ -47,6 +47,21 @@ namespace util {
 bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
                          SmallVectorImpl<int64_t> &resultShape);
 
+/// Returns true if a broadcast between the 2 shapes is guaranteed to be
+/// successful and not result in an error. False does not guarantee that the
+/// shapes are not broadcastable; it might guarantee that they are not
+/// broadcastable or it might mean that this function does not have enough
+/// information to know.
+///
+/// Conceptually, this returns true if getBroadcastedShape would have returned
+/// true and vice versa, with one exception. If a dimension is unknown in both
+/// shapes, getBroadcastedShape would return true and have a result with unknown
+/// dimension, while this function will return false because it's possible for
+/// both shapes to have a dimension greater than 1 and different, which would
+/// fail to broadcast.
+bool staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
+                                  ArrayRef<int64_t> shape2);
+
 /// Returns the result broadcast composition type from the two given types by
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index e251a3887cd4..0a0608bbcda4 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -317,21 +317,101 @@ OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
 // CstrBroadcastableOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+// Given an input shape Value, try to obtain the shape's values.
+LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
+  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
+    auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
+    if (!type.hasRank())
+      return failure();
+    shapeValues = llvm::to_vector<6>(type.getShape());
+    return success();
+  } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
+    shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
+    return success();
+  } else {
+    return failure();
+  }
+}
+
+// For shapes that were created by some operations, we can obtain partial
+// information on the shapes and sometimes determine if they will be
+// broadcastable with that.
+struct CstrBroadcastablePartialInfo
+    : public OpRewritePattern<CstrBroadcastableOp> {
+  using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t, 6> lhsShape, rhsShape;
+    if (failed(getShapeVec(op.lhs(), lhsShape)))
+      return failure();
+    if (failed(getShapeVec(op.rhs(), rhsShape)))
+      return failure();
+    if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
+    return success();
+  }
+};
+
+// Scalars are always broadcastable.
+struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> {
+  using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t, 6> shape;
+    if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0)
+      return failure();
+    if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
+    return success();
+  }
+};
+
+} // namespace
+
 void CstrBroadcastableOp::getCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  // If inputs are equal, return passing witness
-  patterns.insert<CstrBroadcastableEqOps>(context);
+  // Canonicalization patterns have overlap with the considerations during
+  // folding in case additional shape information is inferred at some point that
+  // does not result in folding.
+  patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
+                  CstrBroadcastableScalar>(context);
 }
 
 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[0] || !operands[1])
+  // Both operands are not needed if one is a scalar.
+  if (operands[0] &&
+      operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
+    return BoolAttr::get(true, getContext());
+  if (operands[1] &&
+      operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
+    return BoolAttr::get(true, getContext());
+
+  if (operands[0] && operands[1]) {
+    auto lhsShape = llvm::to_vector<6>(
+        operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+    auto rhsShape = llvm::to_vector<6>(
+        operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+    SmallVector<int64_t, 6> resultShape;
+    if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
+      return BoolAttr::get(true, getContext());
+  }
+
+  // Lastly, see if folding can be completed based on what constraints are known
+  // on the input shapes.
+  SmallVector<int64_t, 6> lhsShape, rhsShape;
+  if (failed(getShapeVec(lhs(), lhsShape)))
     return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  SmallVector<int64_t, 6> resultShape;
-  if (OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
+  if (failed(getShapeVec(rhs(), rhsShape)))
+    return nullptr;
+
+  if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
     return BoolAttr::get(true, getContext());
 
   // Because a failing witness result here represents an eventual assertion
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index c974e2fc097b..2a557c489e0b 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -13,6 +13,23 @@
 
 using namespace mlir;
 
+bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
+                                                 ArrayRef<int64_t> shape2) {
+  // Two dimensions are compatible when
+  // 1. they are defined and equal, or
+  // 2. one of them is 1
+  return llvm::all_of(llvm::zip(llvm::reverse(shape1), llvm::reverse(shape2)),
+                      [](auto dimensions) {
+                        auto dim1 = std::get<0>(dimensions);
+                        auto dim2 = std::get<1>(dimensions);
+                        if (dim1 == 1 || dim2 == 1)
+                          return true;
+                        if (dim1 == dim2 && !ShapedType::isDynamic(dim1))
+                          return true;
+                        return false;
+                      });
+}
+
 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                         ArrayRef<int64_t> shape2,
                                         SmallVectorImpl<int64_t> &resultShape) {
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 1b9f3924b8b0..1665ef73f3e3 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -403,8 +403,8 @@ func @f() {
 // -----
 
 // Broadcastable with non-broadcastable constant shapes is always false
-// CHECK-LABEL: func @f
-func @f() {
+// CHECK-LABEL: func @static_non_broadcastable
+func @static_non_broadcastable() {
   // CHECK-NEXT: shape.const_shape
   // CHECK-NEXT: shape.const_shape
   // CHECK-NEXT: shape.cstr_broadcastable
@@ -515,3 +515,49 @@ func @size_to_index_to_size(%size : !shape.size) -> !shape.size {
   return %result : !shape.size
 }
 
+// -----
+
+// Canonicalize scalar cstr_broadcastable checks
+// CHECK-LABEL: @cstr_broadcastable_scalar
+func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.const_shape []
+  %1 = shape.shape_of %arg0 : tensor<?xf32>
+  %2 = shape.cstr_broadcastable %0, %1
+  "consume.witness"(%2) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+
+// Do not canonicalize cstr_broadcastable checks with 2 unknowns
+// CHECK-LABEL: @cstr_broadcastable_unknown
+func @cstr_broadcastable_unknown(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) {
+  // CHECK-NEXT: shape.shape_of %arg0
+  // CHECK-NEXT: shape.shape_of %arg1
+  // CHECK-NEXT: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.shape_of %arg0 : tensor<?xf32>
+  %1 = shape.shape_of %arg1 : tensor<?xf32>
+  %2 = shape.cstr_broadcastable %0, %1
+  "consume.witness"(%2) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+
+// Scalars are safe to broadcast to unranked sizes.
+// CHECK-LABEL: @cstr_broadcastable_scalar_unranked
+func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<f32>) {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.shape_of %arg1 : tensor<f32>
+  %1 = shape.shape_of %arg0 : tensor<*xf32>
+  %2 = shape.cstr_broadcastable %0, %1
+  "consume.witness"(%2) : (!shape.witness) -> ()
+  return
+}
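
A minimal sketch of the contract documented in Traits.h above, for reference only and not part of the patch; it assumes -1 encodes an unknown (dynamic) dimension, as ShapedType::isDynamic checks elsewhere in MLIR:

// Illustrative example exercising staticallyKnownBroadcastable versus
// getBroadcastedShape on unknown dimensions. Not part of the patch.
#include "mlir/Dialect/Traits.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>

static void broadcastabilityExamples() {
  llvm::SmallVector<int64_t, 4> result;

  // Static dimensions: 1 broadcasts against 4, and 3 matches 3.
  assert(mlir::OpTrait::util::staticallyKnownBroadcastable({3, 1}, {3, 4}));

  // Both extents unknown (-1): getBroadcastedShape succeeds and produces an
  // unknown result dimension, while staticallyKnownBroadcastable conservatively
  // returns false because the two runtime extents could differ.
  assert(mlir::OpTrait::util::getBroadcastedShape({-1}, {-1}, result));
  assert(!mlir::OpTrait::util::staticallyKnownBroadcastable({-1}, {-1}));
}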