[mlir][Shape] Generalize cstr_broadcastable folding for n-ary broadcasts

This is still fairly tricky code, but I tried to untangle it a bit.

Differential Revision: https://reviews.llvm.org/D96800
commit 63a35f35ec
parent 1e2d50936a
mlir/include/mlir/Dialect/Traits.h:

@@ -47,7 +47,7 @@ namespace util {
 bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
                          SmallVectorImpl<int64_t> &resultShape);
 
-/// Returns true if a broadcast between the 2 shapes is guaranteed to be
+/// Returns true if a broadcast between n shapes is guaranteed to be
 /// successful and not result in an error. False does not guarantee that the
 /// shapes are not broadcastable; it might guarantee that they are not
 /// broadcastable or it might mean that this function does not have enough
@@ -59,6 +59,7 @@ bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
 /// dimension, while this function will return false because it's possible for
 /// both shapes to have a dimension greater than 1 and different which would
 /// fail to broadcast.
+bool staticallyKnownBroadcastable(ArrayRef<SmallVector<int64_t, 6>> shapes);
 bool staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                   ArrayRef<int64_t> shape2);
 
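
The new n-ary overload takes a list of shapes, and the binary entry point
becomes a thin wrapper around it (see mlir/lib/Dialect/Traits.cpp below). A
minimal, hypothetical caller, shown only for illustration and not part of
this commit:

    // Assumes the usual MLIR headers; `shapes` is built by the caller.
    SmallVector<SmallVector<int64_t, 6>, 3> shapes;
    shapes.push_back({8, 1});
    shapes.push_back({1, 8});
    shapes.push_back({1, 1});
    // True: in each column, everything other than a single extent is 1.
    bool ok = mlir::OpTrait::util::staticallyKnownBroadcastable(shapes);
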
mlir/lib/Dialect/Shape/IR/Shape.cpp:

@@ -490,38 +490,48 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
   patterns.insert<CstrBroadcastableEqOps>(context);
 }
 
-OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
-  // TODO: Add folding for the nary case
-  if (operands.size() != 2)
-    return nullptr;
-
-  // Both operands are not needed if one is a scalar.
-  if (operands[0] &&
-      operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
-    return BoolAttr::get(getContext(), true);
-  if (operands[1] &&
-      operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
-    return BoolAttr::get(getContext(), true);
-
-  if (operands[0] && operands[1]) {
-    auto lhsShape = llvm::to_vector<6>(
-        operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-    auto rhsShape = llvm::to_vector<6>(
-        operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-    SmallVector<int64_t, 6> resultShape;
-    if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
-      return BoolAttr::get(getContext(), true);
-  }
-
+// Return true if there is at most one attribute not representing a scalar
+// broadcast.
+static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
+  bool nonScalarSeen = false;
+  for (Attribute a : attributes) {
+    if (!a || a.cast<DenseIntElementsAttr>().getNumElements() != 0) {
+      if (nonScalarSeen)
+        return false;
+      nonScalarSeen = true;
+    }
+  }
+  return true;
+}
+
+OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+  // No broadcasting is needed if all operands but one are scalar.
+  if (hasAtMostSingleNonScalar(operands))
+    return BoolAttr::get(getContext(), true);
+
+  if ([&] {
+        SmallVector<SmallVector<int64_t, 6>, 6> extents;
+        for (const auto &operand : operands) {
+          if (!operand)
+            return false;
+          extents.push_back(llvm::to_vector<6>(
+              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
+        }
+        return OpTrait::util::staticallyKnownBroadcastable(extents);
+      }())
+    return BoolAttr::get(getContext(), true);
+
   // Lastly, see if folding can be completed based on what constraints are known
   // on the input shapes.
-  SmallVector<int64_t, 6> lhsShape, rhsShape;
-  if (failed(getShapeVec(shapes()[0], lhsShape)))
-    return nullptr;
-  if (failed(getShapeVec(shapes()[1], rhsShape)))
-    return nullptr;
-
-  if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
+  if ([&] {
+        SmallVector<SmallVector<int64_t, 6>, 6> extents;
+        for (const auto &shape : shapes()) {
+          extents.emplace_back();
+          if (failed(getShapeVec(shape, extents.back())))
+            return false;
+        }
+        return OpTrait::util::staticallyKnownBroadcastable(extents);
+      }())
     return BoolAttr::get(getContext(), true);
 
   // Because a failing witness result here represents an eventual assertion
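
Both folding attempts above wrap their gathering loop in an immediately-invoked
lambda, so that `return false` aborts only that one check while the surrounding
fold still gets to try the next strategy. A minimal, self-contained
illustration of the pattern (hypothetical values, not code from this commit):

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> values = {2, 4, 6};
      // The lambda runs on the spot; an early `return false` leaves only
      // the lambda, so control continues after the `if`.
      if ([&] {
            for (int v : values)
              if (v % 2 != 0)
                return false; // bail out of this check alone
            return true;
          }())
        std::puts("all even");
      return 0;
    }
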
mlir/lib/Dialect/Traits.cpp:

@@ -15,19 +15,45 @@ using namespace mlir;
 
 bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                  ArrayRef<int64_t> shape2) {
-  // Two dimensions are compatible when
-  //   1. they are defined and equal, or
-  //   2. one of them is 1
-  return llvm::all_of(llvm::zip(llvm::reverse(shape1), llvm::reverse(shape2)),
-                      [](auto dimensions) {
-                        auto dim1 = std::get<0>(dimensions);
-                        auto dim2 = std::get<1>(dimensions);
-                        if (dim1 == 1 || dim2 == 1)
-                          return true;
-                        if (dim1 == dim2 && !ShapedType::isDynamic(dim1))
-                          return true;
-                        return false;
-                      });
+  SmallVector<SmallVector<int64_t, 6>, 2> extents;
+  extents.emplace_back(shape1.begin(), shape1.end());
+  extents.emplace_back(shape2.begin(), shape2.end());
+  return staticallyKnownBroadcastable(extents);
+}
+
+bool OpTrait::util::staticallyKnownBroadcastable(
+    ArrayRef<SmallVector<int64_t, 6>> shapes) {
+  assert(!shapes.empty() && "Expected at least one shape");
+  size_t maxRank = shapes[0].size();
+  for (size_t i = 1; i != shapes.size(); ++i)
+    maxRank = std::max(maxRank, shapes[i].size());
+
+  // We look backwards through every column of `shapes`.
+  for (size_t i = 0; i != maxRank; ++i) {
+    bool seenDynamic = false;
+    Optional<int64_t> nonOneDim;
+    for (ArrayRef<int64_t> extent : shapes) {
+      int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
+
+      if (dim == 1)
+        continue;
+
+      // Dimensions are compatible when
+      //   1. One is dynamic, the rest are 1
+      if (ShapedType::isDynamic(dim)) {
+        if (seenDynamic || nonOneDim)
+          return false;
+        seenDynamic = true;
+      }
+
+      // 2. All are 1 or a specific constant.
+      if (nonOneDim && dim != *nonOneDim)
+        return false;
+
+      nonOneDim = dim;
+    }
+  }
+  return true;
 }
 
 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
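
The column-wise scan generalizes the old pairwise zip: a dimension is provably
broadcastable only if, ignoring 1s, all shapes agree on one static extent, or
exactly one shape is dynamic there and every other shape has extent 1. A
self-contained sketch of the same algorithm (plain C++17 with no MLIR
dependencies; -1 stands in for a dynamic extent, as in the tests below):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <optional>
    #include <vector>

    // -1 plays the role of ShapedType's dynamic extent in this sketch.
    constexpr int64_t kDynamic = -1;

    // True only if broadcasting `shapes` is guaranteed to succeed.
    static bool knownBroadcastable(
        const std::vector<std::vector<int64_t>> &shapes) {
      assert(!shapes.empty() && "expected at least one shape");
      size_t maxRank = 0;
      for (const auto &s : shapes)
        maxRank = std::max(maxRank, s.size());

      // Scan dimensions right-to-left; missing leading dims act like 1.
      for (size_t i = 0; i != maxRank; ++i) {
        bool seenDynamic = false;
        std::optional<int64_t> nonOneDim;
        for (const auto &extent : shapes) {
          int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
          if (dim == 1)
            continue;
          // A dynamic extent is only provably safe if all others are 1.
          if (dim == kDynamic && (seenDynamic || nonOneDim))
            return false;
          seenDynamic |= (dim == kDynamic);
          // Two different non-one extents can never broadcast.
          if (nonOneDim && dim != *nonOneDim)
            return false;
          nonOneDim = dim;
        }
      }
      return true;
    }

    int main() {
      // Mirrors the ternary test cases below.
      std::printf("%d\n", knownBroadcastable({{8, 1}, {1, 8}, {1, 1}}));  // 1
      std::printf("%d\n", knownBroadcastable({{8, 1}, {1, 8}, {1, -1}})); // 0
    }
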
mlir/test/Dialect/Shape/canonicalize.mlir:

@@ -600,6 +600,92 @@ func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
   return
 }
 
+// -----
+// Fold ternary broadcastable
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, 8] : !shape.shape
+  %cs2 = shape.const_shape [1, 1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Fold ternary broadcastable with dynamic extents
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, -1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs0, %cs1 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Cannot fold: a dynamic extent conflicts with a non-unit static extent
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, 8] : !shape.shape
+  %cs2 = shape.const_shape [1, -1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Cannot fold: two dynamic extents in the same dimension
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, -1] : !shape.shape
+  %cs2 = shape.const_shape [1, -1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Broadcastable with known scalars and a single unknown shape can be folded
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs0, %arg0 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// One scalar and one non-scalar and one unknown cannot be folded.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+  // CHECK: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [] : !shape.shape
+  %cs1 = shape.const_shape [2] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %arg0 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
 // -----
 
 // Fold `rank` based on constant shape.
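
These are FileCheck tests driven by the RUN line at the top of
canonicalize.mlir, which this hunk does not show. Assuming the usual setup for
this file, the cases can be exercised with an invocation along these lines
(the exact flags are an assumption, not part of this commit):

    mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize \
        canonicalize.mlir | FileCheck canonicalize.mlir
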