[mlir][vector] Untangle TransferWriteDistribution and avoid crashing in the 0-D case.

This revision avoids a crash in the 0-D case of distributing vector.transfer ops out of
vector.warp_execute_on_lane_0.
Due to the code complexity and lack of documentation, it took untangling the implementation
before realizing that the simple fix was to fail in the 0-D case.
The rewrite is still very useful to understand this code better.

Differential Revision: https://reviews.llvm.org/D128793
This commit is contained in:
Nicolas Vasilache 2022-06-29 01:59:33 -07:00
parent 9b994593cc
commit 6a57d8fba5
2 changed files with 65 additions and 20 deletions

View File

@ -262,6 +262,28 @@ private:
const WarpExecuteOnLane0LoweringOptions &options;
};
/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
/// op with the proper return type.
/// The new write op is updated to write the result of the new warp execute op.
/// The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
WarpExecuteOnLane0Op warpOp,
vector::TransferWriteOp writeOp,
VectorType targetType) {
assert(writeOp->getParentOp() == warpOp &&
"write must be nested immediately under warp");
OpBuilder::InsertionGuard g(rewriter);
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, ValueRange{{writeOp.getVector()}},
TypeRange{targetType});
rewriter.setInsertionPointAfter(newWarpOp);
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
rewriter.eraseOp(writeOp);
newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
return newWriteOp;
}
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
@ -290,11 +312,21 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
LogicalResult tryDistributeOp(RewriterBase &rewriter,
vector::TransferWriteOp writeOp,
WarpExecuteOnLane0Op warpOp) const {
VectorType writtenVectorType = writeOp.getVectorType();
// 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op
// to separate it from the rest.
if (writtenVectorType.getRank() == 0)
return failure();
// 2. Compute the distribution map.
AffineMap map = distributionMapFn(writeOp);
SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
writeOp.getVectorType().getShape().end());
assert(map.getNumResults() == 1 &&
"multi-dim distribution not implemented yet");
if (map.getNumResults() != 1)
return writeOp->emitError("multi-dim distribution not implemented yet");
// 3. Compute the targetType using the distribution map.
SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
writtenVectorType.getShape().end());
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);
if (targetShape[position] % warpOp.getWarpSize() != 0)
@ -302,20 +334,16 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
targetShape[position] = targetShape[position] / warpOp.getWarpSize();
}
VectorType targetType =
VectorType::get(targetShape, writeOp.getVectorType().getElementType());
VectorType::get(targetShape, writtenVectorType.getElementType());
SmallVector<Value> yieldValues = {writeOp.getVector()};
SmallVector<Type> retTypes = {targetType};
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, yieldValues, retTypes);
rewriter.setInsertionPointAfter(newWarpOp);
// Move op outside of region: Insert clone at the insertion point and delete
// the old op.
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
rewriter.eraseOp(writeOp);
// 4. clone the write into a new WarpExecuteOnLane0Op to separate it from
// the rest.
vector::TransferWriteOp newWriteOp =
cloneWriteOp(rewriter, warpOp, writeOp, targetType);
// 5. Reindex the write using the distribution map.
auto newWarpOp =
newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
rewriter.setInsertionPoint(newWriteOp);
AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
Location loc = newWriteOp.getLoc();
@ -329,13 +357,11 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
continue;
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
auto scale =
getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
indices[indexPos] =
makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
{indices[indexPos], newWarpOp.getLaneid()});
}
newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
newWriteOp.getIndicesMutable().assign(indices);
return success();
@ -634,7 +660,6 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
Value broadcasted = rewriter.create<vector::BroadcastOp>(
loc, destVecType, newWarpOp->getResults().back());
newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
return success();
}
};

View File

@ -491,3 +491,23 @@ func.func @vector_reduction(%laneid: index) -> (f32) {
}
return %r : f32
}
// -----
func.func @vector_reduction(%laneid: index, %m0: memref<4x2x32xf32>, %m1: memref<f32>) {
%c0 = arith.constant 0: index
%f0 = arith.constant 0.0: f32
// CHECK-D: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
// CHECK-D: vector.transfer_write %[[R]], %{{.*}}[] : vector<f32>, memref<f32>
vector.warp_execute_on_lane_0(%laneid)[32] {
%0 = vector.transfer_read %m0[%c0, %c0, %c0], %f0 {in_bounds = [true]} : memref<4x2x32xf32>, vector<32xf32>
%1 = vector.transfer_read %m1[], %f0 : memref<f32>, vector<f32>
%2 = vector.extractelement %1[] : vector<f32>
%3 = vector.reduction <add>, %0 : vector<32xf32> into f32
%4 = arith.addf %3, %2 : f32
%5 = vector.broadcast %4 : f32 to vector<f32>
vector.transfer_write %5, %m1[] : vector<f32>, memref<f32>
}
return
}