forked from OSchip/llvm-project
[mlir][shape] Generalize broadcast to a variadic number of shapes
Previously broadcast was a binary op. Now it can support more inputs. This has been changed in such a way that for now, this is an NFC for all broadcast operations that were previously legal. Differential Revision: https://reviews.llvm.org/D95777
This commit is contained in:
parent
7e75f6fc1d
commit
f30f347da1
|
@ -50,12 +50,13 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
|
|||
}
|
||||
|
||||
def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
|
||||
let summary = "Returns the broadcasted output shape of two inputs";
|
||||
let summary = "Returns the broadcasted output shape of two or more inputs";
|
||||
let description = [{
|
||||
Returns the broadcasted shape for two input shapes or extent tensors. Both
|
||||
operands can be of type `shape.shape` or `tensor<?xindex>`. The result is of
|
||||
type `shape.shape` and, if both operands are tensors, may be of type
|
||||
`tensor<?xindex>`.
|
||||
Returns the broadcasted shape for input shapes or extent tensors. The rest
|
||||
of this description is simplified for the 2 input case but can be extended
|
||||
to more inputs. Both operands can be of type `shape.shape` or
|
||||
`tensor<?xindex>`. The result is of type `shape.shape` and, if both
|
||||
operands are tensors, may be of type `tensor<?xindex>`.
|
||||
|
||||
If the two operand shapes are of different rank the smaller one is padded
|
||||
with 1's from the left. The resulting broadcasted shape is then defined as
|
||||
|
@ -72,19 +73,26 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
|
|||
attribute can be used to describe the error case.
|
||||
}];
|
||||
|
||||
let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
|
||||
Shape_ShapeOrExtentTensorType:$rhs,
|
||||
let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes,
|
||||
OptionalAttr<StrAttr>:$error);
|
||||
let results = (outs Shape_ShapeOrExtentTensorType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
|
||||
$shapes attr-dict `:` type($shapes) `->` type($result)
|
||||
}];
|
||||
|
||||
let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
|
||||
let hasFolder = 1;
|
||||
let builders = [OpBuilderDAG<(ins "::mlir::Type":$result,
|
||||
"::mlir::Value":$lhs, "::mlir::Value":$rhs,
|
||||
"/*optional*/ ::mlir::StringAttr":$error), [{
|
||||
build($_builder, $_state, result, ::llvm::makeArrayRef({lhs, rhs}), error);
|
||||
}]>
|
||||
];
|
||||
|
||||
let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
|
||||
let hasFolder = 1;
|
||||
let verifier = [{
|
||||
return success(succeeded(::verifyShapeOrExtentTensorOp(*this)) &&
|
||||
getNumOperands() >= 2);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
|
||||
|
|
|
@ -14,7 +14,9 @@
|
|||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::shape;
|
||||
|
@ -73,6 +75,48 @@ struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
|
|||
matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
// Get the resulting extent in a given dimension. This is computed with any
|
||||
// number of extent tensors and shifted offsets into them.
|
||||
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
|
||||
ValueRange rankDiffs, Value outputDimension) {
|
||||
Value one = lb.create<ConstantIndexOp>(1);
|
||||
Value broadcastedDim = one;
|
||||
for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
|
||||
Value shape = std::get<0>(tup);
|
||||
Value rankDiff = std::get<1>(tup);
|
||||
Value outOfBounds =
|
||||
lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff);
|
||||
Type indexTy = lb.getIndexType();
|
||||
broadcastedDim =
|
||||
lb.create<IfOp>(
|
||||
TypeRange{indexTy}, outOfBounds,
|
||||
[&](OpBuilder &b, Location loc) {
|
||||
b.create<scf::YieldOp>(loc, broadcastedDim);
|
||||
},
|
||||
[&](OpBuilder &b, Location loc) {
|
||||
// The broadcasting logic is:
|
||||
// - if one extent (here we arbitrarily choose the
|
||||
// extent from the greater-rank operand) is equal to 1,
|
||||
// then take the extent from the other operand
|
||||
// - otherwise, take the extent as-is.
|
||||
// Note that this logic remains correct in the presence
|
||||
// of dimensions of zero extent.
|
||||
Value lesserRankOperandDimension =
|
||||
b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
|
||||
Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
|
||||
loc, shape, ValueRange{lesserRankOperandDimension});
|
||||
|
||||
Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
|
||||
lesserRankOperandExtent, one);
|
||||
Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
|
||||
lesserRankOperandExtent);
|
||||
b.create<scf::YieldOp>(loc, dim);
|
||||
})
|
||||
.getResult(0);
|
||||
}
|
||||
return broadcastedDim;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
LogicalResult BroadcastOpConverter::matchAndRewrite(
|
||||
|
@ -83,76 +127,44 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
|
|||
if (op.getType().isa<ShapeType>())
|
||||
return failure();
|
||||
|
||||
assert(!op.lhs().getType().isa<ShapeType>() &&
|
||||
!op.rhs().getType().isa<ShapeType>());
|
||||
auto loc = op.getLoc();
|
||||
ImplicitLocOpBuilder lb(loc, rewriter);
|
||||
BroadcastOp::Adaptor transformed(operands);
|
||||
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
|
||||
|
||||
// Find smaller and greater rank and extent tensor.
|
||||
Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
|
||||
Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
|
||||
Value lhsRankULE =
|
||||
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
|
||||
Type indexTy = rewriter.getIndexType();
|
||||
Value lesserRank =
|
||||
rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
|
||||
Value greaterRank =
|
||||
rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
|
||||
auto erasedRankType =
|
||||
RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
|
||||
Value rankErasedLhs =
|
||||
rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
|
||||
Value rankErasedRhs =
|
||||
rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
|
||||
Value lesserRankOperand =
|
||||
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
|
||||
Value greaterRankOperand =
|
||||
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
|
||||
Value zero = lb.create<ConstantIndexOp>(0);
|
||||
Type indexTy = lb.getIndexType();
|
||||
|
||||
Value rankDiff =
|
||||
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
|
||||
rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
|
||||
op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value outputDimension = args[0];
|
||||
Value isUnchallengedDimension = b.create<CmpIOp>(
|
||||
loc, CmpIPredicate::ult, outputDimension, rankDiff);
|
||||
Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
|
||||
loc, greaterRankOperand, outputDimension);
|
||||
// The initial dimensions of the greater-rank operand are unchallenged,
|
||||
// so we can take them as-is. Otherwise, we need to do a comparison.
|
||||
// We need an actual branch here (instead of a select) because the
|
||||
// lesser-rank operand might be rank 0, so any tensor.extract would be
|
||||
// invalid.
|
||||
auto ifOp = b.create<IfOp>(
|
||||
loc, TypeRange{indexTy}, isUnchallengedDimension,
|
||||
[&](OpBuilder &b, Location loc) {
|
||||
b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
|
||||
},
|
||||
[&](OpBuilder &b, Location loc) {
|
||||
// The broadcasting logic is:
|
||||
// - if one extent (here we arbitrarily choose the extent from
|
||||
// the greater-rank operand) is equal to 1, then take the extent
|
||||
// from the other operand
|
||||
// - otherwise, take the extent as-is.
|
||||
// Note that this logic remains correct in the presence of
|
||||
// dimensions of zero extent.
|
||||
Value lesserRankOperandDimension =
|
||||
b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
|
||||
Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
|
||||
loc, lesserRankOperand,
|
||||
ValueRange{lesserRankOperandDimension});
|
||||
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
|
||||
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
|
||||
Value broadcastedExtent = b.create<SelectOp>(
|
||||
loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
|
||||
greaterRankOperandExtent);
|
||||
b.create<scf::YieldOp>(loc, broadcastedExtent);
|
||||
});
|
||||
b.create<tensor::YieldOp>(loc, ifOp.getResult(0));
|
||||
});
|
||||
// Save all the ranks for bounds checking. Because this is a tensor
|
||||
// representing the shape extents, the rank is the extent of the only
|
||||
// dimension in the tensor.
|
||||
SmallVector<Value> ranks, rankDiffs;
|
||||
llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
|
||||
return lb.create<DimOp>(v, zero);
|
||||
}));
|
||||
|
||||
// Find the maximum rank
|
||||
Value maxRank = ranks.front();
|
||||
for (Value v : llvm::drop_begin(ranks, 1)) {
|
||||
Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
|
||||
maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
|
||||
}
|
||||
|
||||
// Calculate the difference of ranks and the maximum rank for later offsets.
|
||||
llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
|
||||
return lb.create<SubIOp>(indexTy, maxRank, v);
|
||||
}));
|
||||
|
||||
rewriter.replaceOp(
|
||||
op, lb.create<tensor::GenerateOp>(
|
||||
getExtentTensorType(lb.getContext()), ValueRange{maxRank},
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value broadcastedDim = getBroadcastedDim(
|
||||
ImplicitLocOpBuilder(loc, b), transformed.shapes(),
|
||||
rankDiffs, args[0]);
|
||||
|
||||
b.create<tensor::YieldOp>(loc, broadcastedDim);
|
||||
})
|
||||
->getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -357,10 +357,14 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
|||
if (!operands[1])
|
||||
return nullptr;
|
||||
|
||||
// TODO: Support folding with more than 2 input shapes
|
||||
if (operands.size() > 2 && !operands[2].isa<StringAttr>())
|
||||
return nullptr;
|
||||
|
||||
auto rhsShape = llvm::to_vector<6>(
|
||||
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
|
||||
if (rhsShape.empty())
|
||||
return lhs();
|
||||
return shapes()[0];
|
||||
|
||||
if (!operands[0])
|
||||
return nullptr;
|
||||
|
@ -368,7 +372,7 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
|||
auto lhsShape = llvm::to_vector<6>(
|
||||
operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
|
||||
if (lhsShape.empty())
|
||||
return rhs();
|
||||
return shapes()[1];
|
||||
|
||||
SmallVector<int64_t, 6> resultShape;
|
||||
// If the shapes are not compatible, we can't fold it.
|
||||
|
|
|
@ -305,86 +305,6 @@ func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @broadcast_unknown_extents(
|
||||
// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) {
|
||||
func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
|
||||
// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
|
||||
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
|
||||
// CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
|
||||
// CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
|
||||
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
|
||||
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
|
||||
// CHECK: %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] {
|
||||
// CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
|
||||
// CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
|
||||
// CHECK: %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
|
||||
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
|
||||
// CHECK: %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
|
||||
// CHECK: scf.yield %[[BROADCASTED_EXTENT]] : index
|
||||
// CHECK: }
|
||||
// CHECK: yield %[[OUTPUT_EXTENT:.*]] : index
|
||||
// CHECK: } : tensor<?xindex>
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
%0 = shape.broadcast %a, %b
|
||||
: tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @broadcast_known_different_extents(
|
||||
// CHECK-SAME: %[[LHS:.*]]: tensor<2xindex>,
|
||||
// CHECK-SAME: %[[RHS:.*]]: tensor<3xindex>) {
|
||||
func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
|
||||
// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
|
||||
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
|
||||
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
|
||||
// CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
|
||||
// CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
|
||||
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
|
||||
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
|
||||
// CHECK: %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] {
|
||||
// CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
|
||||
// CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
|
||||
// CHECK: %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
|
||||
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
|
||||
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
|
||||
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
|
||||
// CHECK: %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
|
||||
// CHECK: scf.yield %[[BROADCASTED_EXTENT]] : index
|
||||
// CHECK: }
|
||||
// CHECK: yield %[[OUTPUT_EXTENT:.*]] : index
|
||||
// CHECK: } : tensor<?xindex>
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
%0 = shape.broadcast %a, %b
|
||||
: tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
|
||||
%0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex>
|
||||
return %0 : i1
|
||||
|
@ -459,3 +379,62 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
|
|||
// CHECK: %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
|
||||
// CHECK: return %[[RESULT]] : !shape.witness
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
|
||||
%b : tensor<3xindex>,
|
||||
%c : tensor<2xindex>) {
|
||||
// CHECK-LABEL: func @broadcast_3_shapes_different_extents(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: tensor<2xindex>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: tensor<3xindex>,
|
||||
// CHECK-SAME: %[[ARG2:.*]]: tensor<2xindex>) {
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<2xindex>
|
||||
// CHECK: %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<3xindex>
|
||||
// CHECK: %[[RANK2:.*]] = dim %[[ARG2]], %[[C0]] : tensor<2xindex>
|
||||
// CHECK: %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
|
||||
// CHECK: %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
|
||||
// CHECK: %[[CMP1:.*]] = cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
|
||||
// CHECK: %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
|
||||
// CHECK: %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
|
||||
// CHECK: %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
|
||||
// CHECK: %[[DIM_DIFF2:.*]] = subi %[[MAX_RANK]], %[[RANK2]] : index
|
||||
// CHECK: %[[RESULT:.*]] = tensor.generate %[[MAX_RANK]] {
|
||||
// CHECK: ^bb0(%[[IDX:.*]]: index):
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
|
||||
// CHECK: %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
|
||||
// CHECK: scf.yield %[[C1]] : index
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
|
||||
// CHECK: %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
|
||||
// CHECK: %[[DIM0_IS_1:.*]] = cmpi eq, %[[EXTRACTED_0:.*]], %[[C1]] : index
|
||||
// CHECK: %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1]], %[[EXTRACTED_0]] : index
|
||||
// CHECK: }
|
||||
// CHECK: %[[VAL_28:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
|
||||
// CHECK: %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
|
||||
// CHECK: scf.yield %[[DIM0]] : index
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
|
||||
// CHECK: %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
|
||||
// CHECK: %[[DIM1_IS_1:.*]] = cmpi eq, %[[EXTRACTED_1:.*]], %[[C1]] : index
|
||||
// CHECK: %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
|
||||
// CHECK: }
|
||||
// CHECK: %[[VAL_36:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
|
||||
// CHECK: %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
|
||||
// CHECK: scf.yield %[[DIM1]] : index
|
||||
// CHECK: } else {
|
||||
// CHECK: %[[IDX2:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
|
||||
// CHECK: %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
|
||||
// CHECK: %[[DIM2_IS_1:.*]] = cmpi eq, %[[EXTRACTED_2:.*]], %[[C1]] : index
|
||||
// CHECK: %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
|
||||
// CHECK: }
|
||||
// CHECK: tensor.yield %[[DIM2]] : index
|
||||
// CHECK: } : tensor<?xindex>
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
%0 = shape.broadcast %a, %b, %c
|
||||
: tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue