forked from OSchip/llvm-project
[mlir][vector] Untangle TransferWriteDistribution and avoid crashing in the 0-D case.
This revision avoids a crash in the 0-D case of distributing vector.transfer ops out of vector.warp_execute_on_lane_0. Due to the code's complexity and lack of documentation, the implementation had to be untangled before it became clear that the simple fix was to fail in the 0-D case. The rewrite is still very useful for understanding this code better. Differential Revision: https://reviews.llvm.org/D128793
This commit is contained in:
parent
9b994593cc
commit
6a57d8fba5
|
@ -262,6 +262,28 @@ private:
|
|||
const WarpExecuteOnLane0LoweringOptions &options;
|
||||
};
|
||||
|
||||
/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
|
||||
/// op with the proper return type.
|
||||
/// The new write op is updated to write the result of the new warp execute op.
|
||||
/// The old `writeOp` is deleted.
|
||||
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
|
||||
WarpExecuteOnLane0Op warpOp,
|
||||
vector::TransferWriteOp writeOp,
|
||||
VectorType targetType) {
|
||||
assert(writeOp->getParentOp() == warpOp &&
|
||||
"write must be nested immediately under warp");
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
|
||||
rewriter, warpOp, ValueRange{{writeOp.getVector()}},
|
||||
TypeRange{targetType});
|
||||
rewriter.setInsertionPointAfter(newWarpOp);
|
||||
auto newWriteOp =
|
||||
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
|
||||
rewriter.eraseOp(writeOp);
|
||||
newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
|
||||
return newWriteOp;
|
||||
}
|
||||
|
||||
/// Distribute transfer_write ops based on the affine map returned by
|
||||
/// `distributionMapFn`.
|
||||
/// Example:
|
||||
|
@ -290,11 +312,21 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
|
|||
LogicalResult tryDistributeOp(RewriterBase &rewriter,
|
||||
vector::TransferWriteOp writeOp,
|
||||
WarpExecuteOnLane0Op warpOp) const {
|
||||
VectorType writtenVectorType = writeOp.getVectorType();
|
||||
|
||||
// 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op
|
||||
// to separate it from the rest.
|
||||
if (writtenVectorType.getRank() == 0)
|
||||
return failure();
|
||||
|
||||
// 2. Compute the distribution map.
|
||||
AffineMap map = distributionMapFn(writeOp);
|
||||
SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
|
||||
writeOp.getVectorType().getShape().end());
|
||||
assert(map.getNumResults() == 1 &&
|
||||
"multi-dim distribution not implemented yet");
|
||||
if (map.getNumResults() != 1)
|
||||
return writeOp->emitError("multi-dim distribution not implemented yet");
|
||||
|
||||
// 3. Compute the targetType using the distribution map.
|
||||
SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
|
||||
writtenVectorType.getShape().end());
|
||||
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
|
||||
unsigned position = map.getDimPosition(i);
|
||||
if (targetShape[position] % warpOp.getWarpSize() != 0)
|
||||
|
@ -302,20 +334,16 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
|
|||
targetShape[position] = targetShape[position] / warpOp.getWarpSize();
|
||||
}
|
||||
VectorType targetType =
|
||||
VectorType::get(targetShape, writeOp.getVectorType().getElementType());
|
||||
VectorType::get(targetShape, writtenVectorType.getElementType());
|
||||
|
||||
SmallVector<Value> yieldValues = {writeOp.getVector()};
|
||||
SmallVector<Type> retTypes = {targetType};
|
||||
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
|
||||
rewriter, warpOp, yieldValues, retTypes);
|
||||
rewriter.setInsertionPointAfter(newWarpOp);
|
||||
|
||||
// Move op outside of region: Insert clone at the insertion point and delete
|
||||
// the old op.
|
||||
auto newWriteOp =
|
||||
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
|
||||
rewriter.eraseOp(writeOp);
|
||||
// 4. clone the write into a new WarpExecuteOnLane0Op to separate it from
|
||||
// the rest.
|
||||
vector::TransferWriteOp newWriteOp =
|
||||
cloneWriteOp(rewriter, warpOp, writeOp, targetType);
|
||||
|
||||
// 5. Reindex the write using the distribution map.
|
||||
auto newWarpOp =
|
||||
newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
|
||||
rewriter.setInsertionPoint(newWriteOp);
|
||||
AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
|
||||
Location loc = newWriteOp.getLoc();
|
||||
|
@ -329,13 +357,11 @@ struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
|
|||
continue;
|
||||
unsigned indexPos = indexExpr.getPosition();
|
||||
unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
|
||||
auto scale =
|
||||
getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
|
||||
auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
|
||||
indices[indexPos] =
|
||||
makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
|
||||
{indices[indexPos], newWarpOp.getLaneid()});
|
||||
}
|
||||
newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
|
||||
newWriteOp.getIndicesMutable().assign(indices);
|
||||
|
||||
return success();
|
||||
|
@ -634,7 +660,6 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
|||
Value broadcasted = rewriter.create<vector::BroadcastOp>(
|
||||
loc, destVecType, newWarpOp->getResults().back());
|
||||
newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -491,3 +491,23 @@ func.func @vector_reduction(%laneid: index) -> (f32) {
|
|||
}
|
||||
return %r : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// NOTE(review): regression test for distributing out of
// vector.warp_execute_on_lane_0 when a 0-D vector (vector<f32>) is written.
// Per the CHECK-D lines, the 0-D transfer_write is expected to stay inside a
// warp region (it is not distributed) while the written value is yielded from
// the producing region — confirm against the pass output.
func.func @vector_reduction(%laneid: index, %m0: memref<4x2x32xf32>, %m1: memref<f32>) {
  %c0 = arith.constant 0: index
  %f0 = arith.constant 0.0: f32
  // CHECK-D: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
  // CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
  // CHECK-D: vector.transfer_write %[[R]], %{{.*}}[] : vector<f32>, memref<f32>
  vector.warp_execute_on_lane_0(%laneid)[32] {
    %0 = vector.transfer_read %m0[%c0, %c0, %c0], %f0 {in_bounds = [true]} : memref<4x2x32xf32>, vector<32xf32>
    %1 = vector.transfer_read %m1[], %f0 : memref<f32>, vector<f32>
    %2 = vector.extractelement %1[] : vector<f32>
    %3 = vector.reduction <add>, %0 : vector<32xf32> into f32
    %4 = arith.addf %3, %2 : f32
    %5 = vector.broadcast %4 : f32 to vector<f32>
    vector.transfer_write %5, %m1[] : vector<f32>, memref<f32>
  }
  return
}
|
||||
|
|
Loading…
Reference in New Issue