[mlir] Use dynamic_tensor_from_elements in shape.broadcast conversion

Now, convert-shape-to-std doesn't internally create memrefs, which was
previously a bit of a layering violation. The conversion to memrefs should
logically happen as part of bufferization.

Differential Revision: https://reviews.llvm.org/D89669
parent 7885bf8b78
commit 57211fd239
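For readers unfamiliar with the op, here is a minimal sketch of dynamic_tensor_from_elements as spelled at the time of this commit (the function name and values are illustrative): the op takes one index operand per dynamic dimension and a region that computes each element, so no buffer is materialized until bufferization.

    func @example(%size : index) -> tensor<?xindex> {
      %result = dynamic_tensor_from_elements %size {
      ^bb0(%i : index):
        // Each element is produced by the region; here, each element is its own index.
        yield %i : index
      } : tensor<?xindex>
      return %result : tensor<?xindex>
    }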
@@ -110,47 +110,48 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
   Value greaterRankOperand =
       rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
 
-  // Allocate stack memory for the broadcasted extent tensor.
-  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
-  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
-
-  // Copy extents from greater operand that are not challenged.
   Value rankDiff =
       rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
-  rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
-                         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
-                           Value extent = b.create<ExtractElementOp>(
-                               loc, greaterRankOperand, ValueRange{iv});
-                           b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
-                           b.create<scf::YieldOp>(loc);
-                         });
-
-  // Determine remaining broadcasted extents.
-  rewriter.create<ForOp>(
-      loc, rankDiff, greaterRank, one, llvm::None,
-      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
-        Value greaterOperandExtent =
-            b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
-        Value greaterOperandExtentIsOne =
-            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
+  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
+      op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value outputDimension = args[0];
+        Value isUnchallengedDimension = b.create<CmpIOp>(
+            loc, CmpIPredicate::ult, outputDimension, rankDiff);
+        Value greaterRankOperandExtent = b.create<ExtractElementOp>(
+            loc, greaterRankOperand, outputDimension);
+        // The initial dimensions of the greater-rank operand are unchallenged,
+        // so we can take them as-is. Otherwise, we need to do a comparison.
+        // We need an actual branch here (instead of a select) because the
+        // lesser-rank operand might be rank 0, so any extract_element would be
+        // invalid.
         auto ifOp = b.create<IfOp>(
-            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
+            loc, TypeRange{indexTy}, isUnchallengedDimension,
             [&](OpBuilder &b, Location loc) {
-              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
-              Value lesserRankOperandExtent = b.create<ExtractElementOp>(
-                  loc, lesserRankOperand, ValueRange{ivShifted});
-              b.create<scf::YieldOp>(loc, lesserRankOperandExtent);
+              b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
             },
             [&](OpBuilder &b, Location loc) {
-              b.create<scf::YieldOp>(loc, greaterOperandExtent);
+              // The broadcasting logic is:
+              // - if one extent (here we arbitrarily choose the extent from
+              //   the greater-rank operand) is equal to 1, then take the extent
+              //   from the other operand
+              // - otherwise, take the extent as-is.
+              // Note that this logic remains correct in the presence of
+              // dimensions of zero extent.
+              Value lesserRankOperandDimension =
+                  b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
+              Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+                  loc, lesserRankOperand,
+                  ValueRange{lesserRankOperandDimension});
+              Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
+                  loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
+              Value broadcastedExtent = b.create<SelectOp>(
+                  loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
+                  greaterRankOperandExtent);
+              b.create<scf::YieldOp>(loc, broadcastedExtent);
             });
-        Value extent = ifOp.getResult(0);
-        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
-        b.create<scf::YieldOp>(loc);
+        b.create<mlir::YieldOp>(loc, ifOp.getResult(0));
       });
-
-  // Load broadcasted shape as an extent tensor.
-  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
   return success();
 }
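The broadcasting rule the else-branch builds can be stated as a tiny standalone function (a sketch under the std-dialect spelling of the time; %a stands for the greater-rank operand's extent, %b for the lesser-rank operand's):

    func @broadcast_extent(%a : index, %b : index) -> index {
      %c1 = constant 1 : index
      // If the greater-rank operand's extent is 1 it defers to the other
      // operand, otherwise it wins: (1, 5) -> 5, (5, 1) -> 5, (4, 4) -> 4.
      %a_is_one = cmpi "eq", %a, %c1 : index
      %bcast = select %a_is_one, %b, %a : index
      return %bcast : index
    }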
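The comment about needing a real branch can also be made concrete: scf.if yields a value, and the extract_element on the lesser-rank operand only executes on the else path, so a rank-0 lesser shape (an extent tensor with no elements) is never read out of bounds. A sketch with illustrative names:

    func @guarded_extent(%lesser : tensor<?xindex>, %fallback : index,
                         %is_unchallenged : i1, %idx : index) -> index {
      %extent = scf.if %is_unchallenged -> (index) {
        // Unchallenged dimension: %lesser is not touched at all.
        scf.yield %fallback : index
      } else {
        %e = extract_element %lesser[%idx] : tensor<?xindex>
        scf.yield %e : index
      }
      return %extent : index
    }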
@@ -305,39 +305,39 @@ func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
 
 // -----
 
-// CHECK-LABEL: @broadcast_unknown_extents
-// CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
+// CHECK-LABEL: func @broadcast_unknown_extents(
+// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
+// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) {
 func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
-// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
-// CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
-// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
-// CHECK: %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
-// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
-// CHECK: }
-// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
-// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
-// CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-// CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
-// CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
-// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
-// CHECK: scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
-// CHECK: }
-// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
-// CHECK: }
-// CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
+// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
+// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
+// CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
+// CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
+// CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
+// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
+// CHECK: %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
+// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
+// CHECK: } else {
+// CHECK: %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
+// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
+// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
+// CHECK: %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
+// CHECK: scf.yield %[[BROADCASTED_EXTENT]] : index
+// CHECK: }
+// CHECK: yield %[[OUTPUT_EXTENT:.*]] : index
+// CHECK: } : tensor<?xindex>
+// CHECK: return
+// CHECK: }
 %0 = shape.broadcast %a, %b
     : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
 return
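As a worked example of what this lowering computes (hypothetical inputs): broadcasting extent tensors holding [2, 1] and [3] gives rankDiff = 1; output dimension 0 is taken from the greater-rank operand as-is (2), and output dimension 1 sees a greater extent of 1, so it defers to 3, yielding [2, 3].

    func @worked_example(%lhs : tensor<2xindex>, %rhs : tensor<1xindex>)
        -> tensor<?xindex> {
      // With %lhs = [2, 1] and %rhs = [3], the lowered code produces [2, 3].
      %0 = shape.broadcast %lhs, %rhs
          : tensor<2xindex>, tensor<1xindex> -> tensor<?xindex>
      return %0 : tensor<?xindex>
    }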
@@ -345,39 +345,39 @@ func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
 
 // -----
 
-// CHECK-LABEL: @broadcast_known_different_extents
-// CHECK-SAME: (%[[LHS:.*]]: tensor<2xindex>, %[[RHS:.*]]: tensor<3xindex>)
+// CHECK-LABEL: func @broadcast_known_different_extents(
+// CHECK-SAME: %[[LHS:.*]]: tensor<2xindex>,
+// CHECK-SAME: %[[RHS:.*]]: tensor<3xindex>) {
 func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
-// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
-// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
-// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
-// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
-// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
-// CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
-// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
-// CHECK: %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
-// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
-// CHECK: }
-// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
-// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
-// CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-// CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
-// CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
-// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
-// CHECK: scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
-// CHECK: } else {
-// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
-// CHECK: }
-// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
-// CHECK: }
-// CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
+// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
+// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
+// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
+// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
+// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
+// CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
+// CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
+// CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
+// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
+// CHECK: %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
+// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
+// CHECK: } else {
+// CHECK: %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
+// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
+// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
+// CHECK: %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
+// CHECK: scf.yield %[[BROADCASTED_EXTENT]] : index
+// CHECK: }
+// CHECK: yield %[[OUTPUT_EXTENT:.*]] : index
+// CHECK: } : tensor<?xindex>
+// CHECK: return
+// CHECK: }
 %0 = shape.broadcast %a, %b
     : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
 return
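For reference, a test file like this is typically driven by a RUN line such as the following (the actual RUN line lies outside this diff, so the exact flags are an assumption):

    // RUN: mlir-opt -convert-shape-to-std %s | FileCheck %s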