[mlir][Shape] Generalize cstr_broadcastable folding for n-ary broadcasts

This is still fairly tricky code, but I tried to untangle it a bit.

Differential Revision: https://reviews.llvm.org/D96800
This commit is contained in:
Benjamin Kramer 2021-02-16 19:08:34 +01:00
parent 1e2d50936a
commit 63a35f35ec
4 changed files with 165 additions and 42 deletions

View File

@ -47,7 +47,7 @@ namespace util {
bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
SmallVectorImpl<int64_t> &resultShape);
/// Returns true if a broadcast between the 2 shapes is guaranteed to be
/// Returns true if a broadcast between n shapes is guaranteed to be
/// successful and not result in an error. False does not guarantee that the
/// shapes are not broadcastable; it might guarantee that they are not
/// broadcastable or it might mean that this function does not have enough
@ -59,6 +59,7 @@ bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
/// dimension, while this function will return false because it's possible for
/// both shapes to have a dimension greater than 1 and different which would
/// fail to broadcast.
bool staticallyKnownBroadcastable(ArrayRef<SmallVector<int64_t, 6>> shapes);
bool staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
ArrayRef<int64_t> shape2);

View File

@ -490,38 +490,48 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
patterns.insert<CstrBroadcastableEqOps>(context);
}
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
// TODO: Add folding for the nary case
if (operands.size() != 2)
return nullptr;
// Both operands are not needed if one is a scalar.
if (operands[0] &&
operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
return BoolAttr::get(getContext(), true);
if (operands[1] &&
operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
return BoolAttr::get(getContext(), true);
if (operands[0] && operands[1]) {
auto lhsShape = llvm::to_vector<6>(
operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
auto rhsShape = llvm::to_vector<6>(
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
SmallVector<int64_t, 6> resultShape;
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
return BoolAttr::get(getContext(), true);
// Returns true if at most one of the given attributes represents a non-scalar
// (non-empty) extent tensor. A null attribute — an operand whose shape is not
// statically known — is conservatively treated as non-scalar. When all
// operands but (at most) one are scalars, broadcasting trivially succeeds.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    // An extent tensor with zero elements encodes a scalar (rank-0) shape.
    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
// No broadcasting is needed if all operands but one are scalar.
if (hasAtMostSingleNonScalar(operands))
return BoolAttr::get(getContext(), true);
if ([&] {
SmallVector<SmallVector<int64_t, 6>, 6> extents;
for (const auto &operand : operands) {
if (!operand)
return false;
extents.push_back(llvm::to_vector<6>(
operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
}
return OpTrait::util::staticallyKnownBroadcastable(extents);
}())
return BoolAttr::get(getContext(), true);
// Lastly, see if folding can be completed based on what constraints are known
// on the input shapes.
SmallVector<int64_t, 6> lhsShape, rhsShape;
if (failed(getShapeVec(shapes()[0], lhsShape)))
return nullptr;
if (failed(getShapeVec(shapes()[1], rhsShape)))
return nullptr;
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
if ([&] {
SmallVector<SmallVector<int64_t, 6>, 6> extents;
for (const auto &shape : shapes()) {
extents.emplace_back();
if (failed(getShapeVec(shape, extents.back())))
return false;
}
return OpTrait::util::staticallyKnownBroadcastable(extents);
}())
return BoolAttr::get(getContext(), true);
// Because a failing witness result here represents an eventual assertion

View File

@ -15,19 +15,45 @@ using namespace mlir;
// Returns true if a broadcast of the two given shapes is guaranteed to
// succeed at runtime. Implemented by delegating to the n-ary overload with a
// two-element shape list.
//
// NOTE: the previous hand-rolled pairwise `llvm::all_of`/`llvm::zip`
// implementation was left in place above the new delegating body, making the
// delegation unreachable dead code; this keeps only the intended delegation.
bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                 ArrayRef<int64_t> shape2) {
  SmallVector<SmallVector<int64_t, 6>, 2> extents;
  extents.emplace_back(shape1.begin(), shape1.end());
  extents.emplace_back(shape2.begin(), shape2.end());
  return staticallyKnownBroadcastable(extents);
}
// Returns true if a broadcast across all of `shapes` is guaranteed to succeed
// at runtime, i.e. no dimension can conflict. False means "cannot prove it",
// not "will fail".
bool OpTrait::util::staticallyKnownBroadcastable(
    ArrayRef<SmallVector<int64_t, 6>> shapes) {
  assert(!shapes.empty() && "Expected at least one shape");
  size_t maxRank = shapes[0].size();
  for (size_t i = 1; i != shapes.size(); ++i)
    maxRank = std::max(maxRank, shapes[i].size());
  // We look backwards through every column of `shapes`.
  for (size_t i = 0; i != maxRank; ++i) {
    bool seenDynamic = false;
    Optional<int64_t> nonOneDim;
    for (ArrayRef<int64_t> extent : shapes) {
      // Shapes shorter than `maxRank` are implicitly padded with leading 1s.
      int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
      // An extent of 1 broadcasts against anything.
      if (dim == 1)
        continue;
      // Dimensions are compatible when
      // 1. One is dynamic, the rest are 1. A dynamic extent next to another
      //    dynamic extent, or next to a static extent != 1 (caught below via
      //    `nonOneDim`), could disagree at runtime, so it is rejected.
      if (ShapedType::isDynamic(dim)) {
        if (seenDynamic || nonOneDim)
          return false;
        seenDynamic = true;
      }
      // 2. All are 1 or the same specific constant.
      if (nonOneDim && dim != *nonOneDim)
        return false;
      nonOneDim = dim;
    }
  }
  return true;
}
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,

View File

@ -600,6 +600,92 @@ func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
return
}
// -----
// Fold ternary broadcastable: [8, 1], [1, 8] and [1, 1] are compatible in
// every dimension, so the constraint folds to a constant true witness.
// CHECK-LABEL: func @f
func @f() {
  // CHECK-NEXT: shape.const_witness true
  // CHECK-NEXT: consume.witness
  // CHECK-NEXT: return
  %cs0 = shape.const_shape [8, 1] : !shape.shape
  %cs1 = shape.const_shape [1, 8] : !shape.shape
  %cs2 = shape.const_shape [1, 1] : !shape.shape
  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
  "consume.witness"(%0) : (!shape.witness) -> ()
  return
}
// -----
// Fold ternary broadcastable with a dynamic extent: the dynamic dimension
// (-1) only ever meets extents of 1, so the broadcast is statically known to
// succeed and the constraint folds to true.
// CHECK-LABEL: func @f
func @f() {
  // CHECK-NEXT: shape.const_witness true
  // CHECK-NEXT: consume.witness
  // CHECK-NEXT: return
  %cs0 = shape.const_shape [8, 1] : !shape.shape
  %cs1 = shape.const_shape [1, -1] : !shape.shape
  %0 = shape.cstr_broadcastable %cs0, %cs0, %cs1 : !shape.shape, !shape.shape, !shape.shape
  "consume.witness"(%0) : (!shape.witness) -> ()
  return
}
// -----
// A dynamic extent (-1) shares its dimension with a static extent of 8; they
// could disagree at runtime, so the constraint cannot be folded at compile
// time.
// CHECK-LABEL: func @f
func @f() {
  // CHECK: shape.cstr_broadcastable
  // CHECK-NEXT: consume.witness
  // CHECK-NEXT: return
  %cs0 = shape.const_shape [8, 1] : !shape.shape
  %cs1 = shape.const_shape [1, 8] : !shape.shape
  %cs2 = shape.const_shape [1, -1] : !shape.shape
  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
  "consume.witness"(%0) : (!shape.witness) -> ()
  return
}
// -----
// Two dynamic extents (-1) share the same dimension and could disagree at
// runtime, so the constraint cannot be folded at compile time.
// CHECK-LABEL: func @f
func @f() {
  // CHECK: shape.cstr_broadcastable
  // CHECK-NEXT: consume.witness
  // CHECK-NEXT: return
  %cs0 = shape.const_shape [8, 1] : !shape.shape
  %cs1 = shape.const_shape [1, -1] : !shape.shape
  %cs2 = shape.const_shape [1, -1] : !shape.shape
  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
  "consume.witness"(%0) : (!shape.witness) -> ()
  return
}
// -----
// Broadcastable where all operands but one are scalar shapes ([]) can be
// constant folded: scalars broadcast with anything, even an unknown shape.
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) {
  // CHECK-NEXT: shape.const_witness true
  // CHECK-NEXT: consume.witness
  // CHECK-NEXT: return
  %cs0 = shape.const_shape [] : !shape.shape
  %0 = shape.cstr_broadcastable %cs0, %cs0, %arg0 : !shape.shape, !shape.shape, !shape.shape
  "consume.witness"(%0) : (!shape.witness) -> ()
  return
}
// -----
// One scalar, one non-scalar and one unknown operand cannot be folded: the
// unknown shape %arg0 could conflict with [2] at runtime.
// CHECK-LABEL: func @f
func @f(%arg0 : !shape.shape) {
  // CHECK: shape.cstr_broadcastable
  // CHECK-NEXT: consume.witness
  // CHECK-NEXT: return
  %cs0 = shape.const_shape [] : !shape.shape
  %cs1 = shape.const_shape [2] : !shape.shape
  %0 = shape.cstr_broadcastable %cs0, %cs1, %arg0 : !shape.shape, !shape.shape, !shape.shape
  "consume.witness"(%0) : (!shape.witness) -> ()
  return
}
// -----
// Fold `rank` based on constant shape.