[mlir] Use dynamic_tensor_from_elements in shape.broadcast conversion

Now, convert-shape-to-std doesn't internally create memrefs, which was
previously a bit of a layering violation. The conversion to memrefs
should logically happen as part of bufferization.

Differential Revision: https://reviews.llvm.org/D89669
This commit is contained in:
Sean Silva 2020-10-18 18:34:05 -07:00
parent 7885bf8b78
commit 57211fd239
2 changed files with 98 additions and 97 deletions

View File

@ -110,47 +110,48 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
Value greaterRankOperand =
rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
// Allocate stack memory for the broadcasted extent tensor.
Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
// Copy extents from greater operand that are not challenged.
Value rankDiff =
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
Value extent = b.create<ExtractElementOp>(
loc, greaterRankOperand, ValueRange{iv});
b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
b.create<scf::YieldOp>(loc);
});
// Determine remaining broadcasted extents.
rewriter.create<ForOp>(
loc, rankDiff, greaterRank, one, llvm::None,
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
Value greaterOperandExtent =
b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
Value greaterOperandExtentIsOne =
b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
[&](OpBuilder &b, Location loc, ValueRange args) {
Value outputDimension = args[0];
Value isUnchallengedDimension = b.create<CmpIOp>(
loc, CmpIPredicate::ult, outputDimension, rankDiff);
Value greaterRankOperandExtent = b.create<ExtractElementOp>(
loc, greaterRankOperand, outputDimension);
// The initial dimensions of the greater-rank operand are unchallenged,
// so we can take them as-is. Otherwise, we need to do a comparison.
// We need an actual branch here (instead of a select) because the
// lesser-rank operand might be rank 0, so any extract_element would be
// invalid.
auto ifOp = b.create<IfOp>(
loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
loc, TypeRange{indexTy}, isUnchallengedDimension,
[&](OpBuilder &b, Location loc) {
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
loc, lesserRankOperand, ValueRange{ivShifted});
b.create<scf::YieldOp>(loc, lesserRankOperandExtent);
b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
},
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, greaterOperandExtent);
// The broadcasting logic is:
// - if one extent (here we arbitrariliy choose the extent from
// the greater-rank operand) is equal to 1, then take the extent
// from the other operand
// - otherwise, take the extent as-is.
// Note that this logic remains correct in the presence of
// dimensions of zero extent.
Value lesserRankOperandDimension =
b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
loc, lesserRankOperand,
ValueRange{lesserRankOperandDimension});
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
Value broadcastedExtent = b.create<SelectOp>(
loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
greaterRankOperandExtent);
b.create<scf::YieldOp>(loc, broadcastedExtent);
});
Value extent = ifOp.getResult(0);
b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
b.create<scf::YieldOp>(loc);
b.create<mlir::YieldOp>(loc, ifOp.getResult(0));
});
// Load broadcasted shape as an extent tensor.
rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
return success();
}

View File

@ -305,39 +305,39 @@ func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
// -----
// CHECK-LABEL: @broadcast_unknown_extents
// CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
// CHECK-LABEL: func @broadcast_unknown_extents(
// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) {
func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
// CHECK: %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
// CHECK: }
// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
// CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
// CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
// CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
// CHECK: scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
// CHECK: } else {
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
// CHECK: }
// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
// CHECK: }
// CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
// CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
// CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
// CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
// CHECK: %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
// CHECK: } else {
// CHECK: %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
// CHECK: %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
// CHECK: scf.yield %[[BROADCASTED_EXTENT]] : index
// CHECK: }
// CHECK: yield %[[OUTPUT_EXTENT:.*]] : index
// CHECK: } : tensor<?xindex>
// CHECK: return
// CHECK: }
%0 = shape.broadcast %a, %b
: tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
return
@ -345,39 +345,39 @@ func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
// -----
// CHECK-LABEL: @broadcast_known_different_extents
// CHECK-SAME: (%[[LHS:.*]]: tensor<2xindex>, %[[RHS:.*]]: tensor<3xindex>)
// CHECK-LABEL: func @broadcast_known_different_extents(
// CHECK-SAME: %[[LHS:.*]]: tensor<2xindex>,
// CHECK-SAME: %[[RHS:.*]]: tensor<3xindex>) {
func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
// CHECK: %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
// CHECK: }
// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
// CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
// CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
// CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
// CHECK: scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
// CHECK: } else {
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
// CHECK: }
// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
// CHECK: }
// CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
// CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
// CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
// CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
// CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
// CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
// CHECK: %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
// CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
// CHECK: } else {
// CHECK: %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
// CHECK: %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
// CHECK: scf.yield %[[BROADCASTED_EXTENT]] : index
// CHECK: }
// CHECK: yield %[[OUTPUT_EXTENT:.*]] : index
// CHECK: } : tensor<?xindex>
// CHECK: return
// CHECK: }
%0 = shape.broadcast %a, %b
: tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
return