[mlir][shape] Generalize broadcast to a variadic number of shapes

Previously, broadcast was a binary op. Now it supports a variadic number of
input shapes. The change is made in such a way that, for now, it is an NFC
for all broadcast operations that were previously legal.

Differential Revision: https://reviews.llvm.org/D95777
Tres Popp 2021-02-01 09:49:54 +01:00
parent 7e75f6fc1d
commit f30f347da1
4 changed files with 162 additions and 159 deletions
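
Editorial illustration (not part of the commit, drawn from the tests updated
below): the assembly form goes from fixed binary to variadic.

  // Before: exactly two shape operands.
  %r0 = shape.broadcast %a, %b : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  // After: two or more shape operands (the binary form still parses).
  %r1 = shape.broadcast %a, %b, %c
      : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>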

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

@@ -50,12 +50,13 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
 }

 def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
-  let summary = "Returns the broadcasted output shape of two inputs";
+  let summary = "Returns the broadcasted output shape of two or more inputs";

   let description = [{
-    Returns the broadcasted shape for two input shapes or extent tensors. Both
-    operands can be of type `shape.shape` or `tensor<?xindex>`. The result is of
-    type `shape.shape` and, if both operands are tensors, may be of type
-    `tensor<?xindex>`.
+    Returns the broadcasted shape for input shapes or extent tensors. The rest
+    of this description is simplified for the 2 input case but can be extended
+    to more inputs. Both operands can be of type `shape.shape` or
+    `tensor<?xindex>`. The result is of type `shape.shape` and, if both
+    operands are tensors, may be of type `tensor<?xindex>`.

     If the two operand shapes are of different rank the smaller one is padded
     with 1's from the left. The resulting broadcasted shape is then defined as
@@ -72,19 +73,26 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
     attribute can be used to describe the error case.
   }];

-  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
-                       Shape_ShapeOrExtentTensorType:$rhs,
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes,
                        OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeOrExtentTensorType:$result);

   let assemblyFormat = [{
-    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+    $shapes attr-dict `:` type($shapes) `->` type($result)
   }];

-  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
-  let hasFolder = 1;
+  let builders = [OpBuilderDAG<(ins "::mlir::Type":$result,
+    "::mlir::Value":$lhs, "::mlir::Value":$rhs,
+    "/*optional*/ ::mlir::StringAttr":$error), [{
+      build($_builder, $_state, result, ::llvm::makeArrayRef({lhs, rhs}), error);
+    }]>
+  ];
+
+  let hasFolder = 1;
+
+  let verifier = [{
+    return success(succeeded(::verifyShapeOrExtentTensorOp(*this)) &&
+                   getNumOperands() >= 2);
+  }];
 }

 def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {

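A worked example of the description above (an editorial sketch; the
`shape.const_shape` literals are illustrative): the lower-rank shape is padded
with 1's from the left, and each extent of 1 yields to the other operand's
extent. Note that the new verifier also requires at least two operands.

  %a = shape.const_shape [2, 3, 4] : tensor<3xindex>
  %b = shape.const_shape [3, 1] : tensor<2xindex>
  // [3, 1] is padded to [1, 3, 1]; broadcasting with [2, 3, 4] yields [2, 3, 4].
  %r = shape.broadcast %a, %b : tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>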
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp

@@ -14,7 +14,9 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"

 using namespace mlir;
 using namespace mlir::shape;
@@ -73,6 +75,48 @@ struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
   matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };

+// Get the resulting extent in a given dimension. This is computed with any
+// number of extent tensors and shifted offsets into them.
+Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
+                        ValueRange rankDiffs, Value outputDimension) {
+  Value one = lb.create<ConstantIndexOp>(1);
+  Value broadcastedDim = one;
+  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
+    Value shape = std::get<0>(tup);
+    Value rankDiff = std::get<1>(tup);
+    Value outOfBounds =
+        lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff);
+    Type indexTy = lb.getIndexType();
+    broadcastedDim =
+        lb.create<IfOp>(
+              TypeRange{indexTy}, outOfBounds,
+              [&](OpBuilder &b, Location loc) {
+                b.create<scf::YieldOp>(loc, broadcastedDim);
+              },
+              [&](OpBuilder &b, Location loc) {
+                // The broadcasting logic is:
+                // - if one extent (here we arbitrarily choose the
+                //   extent from the greater-rank operand) is equal to 1,
+                //   then take the extent from the other operand
+                // - otherwise, take the extent as-is.
+                // Note that this logic remains correct in the presence
+                // of dimensions of zero extent.
+                Value lesserRankOperandDimension =
+                    b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
+                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
+                    loc, shape, ValueRange{lesserRankOperandDimension});
+                Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                                  lesserRankOperandExtent, one);
+                Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
+                                               lesserRankOperandExtent);
+                b.create<scf::YieldOp>(loc, dim);
+              })
+            .getResult(0);
+  }
+  return broadcastedDim;
+}
+
 } // namespace
@@ -83,76 +127,44 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
   if (op.getType().isa<ShapeType>())
     return failure();

-  assert(!op.lhs().getType().isa<ShapeType>() &&
-         !op.rhs().getType().isa<ShapeType>());
   auto loc = op.getLoc();
+  ImplicitLocOpBuilder lb(loc, rewriter);
   BroadcastOp::Adaptor transformed(operands);
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

-  // Find smaller and greater rank and extent tensor.
-  Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
-  Value lhsRankULE =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
-  Type indexTy = rewriter.getIndexType();
-  Value lesserRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
-  Value greaterRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
-  auto erasedRankType =
-      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
-  Value rankErasedLhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
-  Value rankErasedRhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
-  Value lesserRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
-  Value greaterRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
+  Value zero = lb.create<ConstantIndexOp>(0);
+  Type indexTy = lb.getIndexType();

-  Value rankDiff =
-      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
-  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
-      op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value outputDimension = args[0];
-        Value isUnchallengedDimension = b.create<CmpIOp>(
-            loc, CmpIPredicate::ult, outputDimension, rankDiff);
-        Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
-            loc, greaterRankOperand, outputDimension);
-        // The initial dimensions of the greater-rank operand are unchallenged,
-        // so we can take them as-is. Otherwise, we need to do a comparison.
-        // We need an actual branch here (instead of a select) because the
-        // lesser-rank operand might be rank 0, so any tensor.extract would be
-        // invalid.
-        auto ifOp = b.create<IfOp>(
-            loc, TypeRange{indexTy}, isUnchallengedDimension,
-            [&](OpBuilder &b, Location loc) {
-              b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
-            },
-            [&](OpBuilder &b, Location loc) {
-              // The broadcasting logic is:
-              // - if one extent (here we arbitrarily choose the extent from
-              //   the greater-rank operand) is equal to 1, then take the extent
-              //   from the other operand
-              // - otherwise, take the extent as-is.
-              // Note that this logic remains correct in the presence of
-              // dimensions of zero extent.
-              Value lesserRankOperandDimension =
-                  b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
-              Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
-                  loc, lesserRankOperand,
-                  ValueRange{lesserRankOperandDimension});
-              Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
-                  loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
-              Value broadcastedExtent = b.create<SelectOp>(
-                  loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
-                  greaterRankOperandExtent);
-              b.create<scf::YieldOp>(loc, broadcastedExtent);
-            });
-        b.create<tensor::YieldOp>(loc, ifOp.getResult(0));
-      });
+  // Save all the ranks for bounds checking. Because this is a tensor
+  // representing the shape extents, the rank is the extent of the only
+  // dimension in the tensor.
+  SmallVector<Value> ranks, rankDiffs;
+  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
+                       return lb.create<DimOp>(v, zero);
+                     }));
+
+  // Find the maximum rank
+  Value maxRank = ranks.front();
+  for (Value v : llvm::drop_begin(ranks, 1)) {
+    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
+    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
+  }
+
+  // Calculate the difference of ranks and the maximum rank for later offsets.
+  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
+                       return lb.create<SubIOp>(indexTy, maxRank, v);
+                     }));
+
+  rewriter.replaceOp(
+      op, lb.create<tensor::GenerateOp>(
+                getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value broadcastedDim = getBroadcastedDim(
+                      ImplicitLocOpBuilder(loc, b), transformed.shapes(),
+                      rankDiffs, args[0]);
+
+                  b.create<tensor::YieldOp>(loc, broadcastedDim);
+                })
+              ->getResults());
   return success();
 }

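In short, the new lowering emits one scf.if chain per operand inside a single
tensor.generate, folding each operand's extent into a running extent that
starts at 1. A sketch of one step of that chain, with illustrative SSA names
(the exact output is checked in the updated test below):

  // %acc is the extent folded so far; %rankDiff = maxRank - rank of %shape.
  %c1 = constant 1 : index
  %outOfBounds = cmpi ult, %idx, %rankDiff : index
  %next = scf.if %outOfBounds -> (index) {
    // %shape has no extent at output dimension %idx; keep the running extent.
    scf.yield %acc : index
  } else {
    %i = subi %idx, %rankDiff : index
    %extent = tensor.extract %shape[%i] : tensor<?xindex>
    %isOne = cmpi eq, %extent, %c1 : index
    // An extent of 1 defers to the extents seen so far.
    %merged = select %isOne, %acc, %extent : index
    scf.yield %merged : index
  }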
mlir/lib/Dialect/Shape/IR/Shape.cpp

@@ -357,10 +357,14 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   if (!operands[1])
     return nullptr;

+  // TODO: Support folding with more than 2 input shapes
+  if (operands.size() > 2 && !operands[2].isa<StringAttr>())
+    return nullptr;
+
   auto rhsShape = llvm::to_vector<6>(
       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   if (rhsShape.empty())
-    return lhs();
+    return shapes()[0];

   if (!operands[0])
     return nullptr;

@@ -368,7 +372,7 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   auto lhsShape = llvm::to_vector<6>(
       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
   if (lhsShape.empty())
-    return rhs();
+    return shapes()[1];

   SmallVector<int64_t, 6> resultShape;
   // If the shapes are not compatible, we can't fold it.

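For illustration (an editorial sketch mirroring the two empty-shape cases
above): broadcasting with a constant rank-0 shape is an identity, so the op
folds to its other operand.

  %empty = shape.const_shape [] : tensor<0xindex>
  // rhsShape is empty here, so fold() returns shapes()[0], i.e. %a.
  %r = shape.broadcast %a, %empty : tensor<?xindex>, tensor<0xindex> -> tensor<?xindex>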
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

@@ -305,86 +305,6 @@ func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {

 // -----

-// CHECK-LABEL: func @broadcast_unknown_extents(
-// CHECK-SAME:  %[[LHS:.*]]: tensor<?xindex>,
-// CHECK-SAME:  %[[RHS:.*]]: tensor<?xindex>) {
-func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-  // CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
-  // CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
-  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
-  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
-  // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-  // CHECK: %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] {
-  // CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
-  // CHECK:   %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
-  // CHECK:   %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
-  // CHECK:   %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
-  // CHECK:     scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:   } else {
-  // CHECK:     %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
-  // CHECK:     %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
-  // CHECK:     %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-  // CHECK:     %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:     scf.yield %[[BROADCASTED_EXTENT]] : index
-  // CHECK:   }
-  // CHECK:   yield %[[OUTPUT_EXTENT:.*]] : index
-  // CHECK: } : tensor<?xindex>
-  // CHECK: return
-  // CHECK: }
-  %0 = shape.broadcast %a, %b
-      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @broadcast_known_different_extents(
-// CHECK-SAME:  %[[LHS:.*]]: tensor<2xindex>,
-// CHECK-SAME:  %[[RHS:.*]]: tensor<3xindex>) {
-func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
-  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
-  // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-  // CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
-  // CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
-  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
-  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
-  // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-  // CHECK: %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] {
-  // CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
-  // CHECK:   %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
-  // CHECK:   %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
-  // CHECK:   %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
-  // CHECK:     scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:   } else {
-  // CHECK:     %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
-  // CHECK:     %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
-  // CHECK:     %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-  // CHECK:     %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:     scf.yield %[[BROADCASTED_EXTENT]] : index
-  // CHECK:   }
-  // CHECK:   yield %[[OUTPUT_EXTENT:.*]] : index
-  // CHECK: } : tensor<?xindex>
-  // CHECK: return
-  // CHECK: }
-  %0 = shape.broadcast %a, %b
-      : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
-  return
-}
-
-// -----
-
 func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
   %0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex>
   return %0 : i1
@@ -459,3 +379,62 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
 // CHECK: %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
 // CHECK: return %[[RESULT]] : !shape.witness
 // CHECK: }
+
+// -----
+
+func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
+                                           %b : tensor<3xindex>,
+                                           %c : tensor<2xindex>) {
+// CHECK-LABEL: func @broadcast_3_shapes_different_extents(
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<2xindex>,
+// CHECK-SAME:  %[[ARG1:.*]]: tensor<3xindex>,
+// CHECK-SAME:  %[[ARG2:.*]]: tensor<2xindex>) {
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<3xindex>
+// CHECK: %[[RANK2:.*]] = dim %[[ARG2]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[CMP1:.*]] = cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
+// CHECK: %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
+// CHECK: %[[DIM_DIFF2:.*]] = subi %[[MAX_RANK]], %[[RANK2]] : index
+// CHECK: %[[RESULT:.*]] = tensor.generate %[[MAX_RANK]] {
+// CHECK: ^bb0(%[[IDX:.*]]: index):
+// CHECK:   %[[C1:.*]] = constant 1 : index
+// CHECK:   %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:   %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
+// CHECK:     scf.yield %[[C1]] : index
+// CHECK:   } else {
+// CHECK:     %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:     %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
+// CHECK:     %[[DIM0_IS_1:.*]] = cmpi eq, %[[EXTRACTED_0:.*]], %[[C1]] : index
+// CHECK:     %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1]], %[[EXTRACTED_0]] : index
+// CHECK:   }
+// CHECK:   %[[VAL_28:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:   %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
+// CHECK:     scf.yield %[[DIM0]] : index
+// CHECK:   } else {
+// CHECK:     %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:     %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
+// CHECK:     %[[DIM1_IS_1:.*]] = cmpi eq, %[[EXTRACTED_1:.*]], %[[C1]] : index
+// CHECK:     %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
+// CHECK:   }
+// CHECK:   %[[VAL_36:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:   %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
+// CHECK:     scf.yield %[[DIM1]] : index
+// CHECK:   } else {
+// CHECK:     %[[IDX2:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:     %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
+// CHECK:     %[[DIM2_IS_1:.*]] = cmpi eq, %[[EXTRACTED_2:.*]], %[[C1]] : index
+// CHECK:     %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
+// CHECK:   }
+// CHECK:   tensor.yield %[[DIM2]] : index
+// CHECK: } : tensor<?xindex>
+// CHECK: return
+// CHECK: }
+  %0 = shape.broadcast %a, %b, %c
+      : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
+  return
+}