Generalize the vector transfer flattening patterns (dyn shapes).

Differential Revision: https://reviews.llvm.org/D130284
Benoit Jacob 2022-07-25 15:21:33 +00:00
parent 953a98ef8d
commit f4ac950957
2 changed files with 158 additions and 52 deletions
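For context, a minimal sketch (not taken from this diff; the function and value names are illustrative) of the flattening rewrite these patterns perform, shown on a fully static, contiguous source, which is the case already handled before this change. The commit generalizes the same rewrite to sources with dynamic leading dimensions and non-zero leading indices, as exercised by the new tests at the bottom.

func.func @flatten_static_example(%m : memref<2x8x4xi8>, %pad : i8) -> vector<2x8x4xi8> {
  %c0 = arith.constant 0 : index
  // Before: a contiguous, in-bounds n-D read.
  %v = vector.transfer_read %m[%c0, %c0, %c0], %pad {in_bounds = [true, true, true]}
      : memref<2x8x4xi8>, vector<2x8x4xi8>
  // After (conceptually):
  //   %flat_m = memref.collapse_shape %m [[0, 1, 2]] : memref<2x8x4xi8> into memref<64xi8>
  //   %flat_v = vector.transfer_read %flat_m[%c0], %pad : memref<64xi8>, vector<64xi8>
  //   %v      = vector.shape_cast %flat_v : vector<64xi8> to vector<2x8x4xi8>
  return %v : vector<2x8x4xi8>
}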


@@ -339,23 +339,71 @@ class TransferWriteDropUnitDimsPattern
}
};
/// Creates a memref.collapse_shape collapsing all of the dimensions of the
/// input into a 1D shape.
// TODO: move helper function
static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
mlir::Location loc,
Value input) {
Value rankReducedInput =
rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
ShapedType rankReducedInputType =
rankReducedInput.getType().cast<ShapedType>();
if (rankReducedInputType.getRank() == 1)
return rankReducedInput;
ReassociationIndices indices;
for (int i = 0; i < rankReducedInputType.getRank(); ++i)
indices.push_back(i);
return rewriter.create<memref::CollapseShapeOp>(
loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
/// Returns the innermost (largest) dimension index `d` such that the trailing
/// dimensions of `memrefType` starting at `d` are statically known to be
/// contiguous (row-major) and together contain at least
/// `requiredContiguousSize` elements.
/// When such a dimension is found, the return value satisfies:
/// 0 <= return_value <= memrefType.getRank() - 1.
/// When no such dimension is found, the return value is memrefType.getRank().
static int64_t getContiguousInnerDim(MemRefType memrefType,
int64_t requiredContiguousSize) {
auto shape = memrefType.getShape();
SmallVector<int64_t> strides;
int64_t offset;
int64_t innerDim = shape.size();
if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
int64_t innerSize = 1;
while (true) {
if (innerDim == 0)
break;
const int64_t nextDim = innerDim - 1;
if (shape[nextDim] == ShapedType::kDynamicSize)
break;
if (strides[nextDim] != innerSize)
break;
innerSize *= shape[nextDim];
innerDim = nextDim;
if (innerSize >= requiredContiguousSize)
break;
}
}
return innerDim;
}
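A worked example, not part of the patch; the layout below mirrors the new tests. The strides are [s1, 32, 4, 1] and the shape is [?, ?, 8, 4], so the scan covers 4 and then 4 * 8 = 32 elements and stops there; getContiguousInnerDim(srcType, /*requiredContiguousSize=*/32) would return 2. The dynamic size of d1 would stop the scan at that point in any case.

// Hypothetical example type; getContiguousInnerDim(srcType, 32) == 2 here.
#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
func.func private @contiguous_inner_dim_example(%src : memref<?x?x8x4xi8, #map0>) {
  return
}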
/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
Value input, int64_t firstDimToCollapse) {
ShapedType inputType = input.getType().cast<ShapedType>();
if (inputType.getRank() == 1)
return input;
SmallVector<ReassociationIndices> reassociation;
for (int64_t i = 0; i < firstDimToCollapse; ++i)
reassociation.push_back(ReassociationIndices{i});
ReassociationIndices collapsedIndices;
for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
collapsedIndices.push_back(i);
reassociation.push_back(collapsedIndices);
return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}
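For the same example type, collapseInnerDims with firstDimToCollapse = 2 keeps dimensions 0 and 1 and merges dimensions 2 and 3, roughly producing the IR below. The result layout #map1 is an assumed rendering of the inferred collapsed layout; the exact printed form may differ.

#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
// Assumed collapsed layout: strides [s1, 32, 1], offset s0.
#map1 = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2)>
func.func @collapse_inner_dims_example(%src : memref<?x?x8x4xi8, #map0>) -> memref<?x?x32xi8, #map1> {
  // Dimensions 0 and 1 are kept as-is; dimensions 2 and 3 are merged into one.
  %collapsed = memref.collapse_shape %src [[0], [1], [2, 3]]
      : memref<?x?x8x4xi8, #map0> into memref<?x?x32xi8, #map1>
  return %collapsed : memref<?x?x32xi8, #map1>
}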
/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
SmallVector<Value> &outIndices) {
int64_t rank = indices.size();
if (firstDimToCollapse >= rank)
return failure();
for (int64_t i = firstDimToCollapse; i < rank; ++i) {
arith::ConstantIndexOp cst =
indices[i].getDefiningOp<arith::ConstantIndexOp>();
if (!cst || cst.value() != 0)
return failure();
}
outIndices = indices;
outIndices.resize(firstDimToCollapse + 1);
return success();
}
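Illustration with hypothetical value names, mirroring the new read test below: with firstDimToCollapse = 2, the index list [%i, %j, %c0, %c0] passes the check because the two trailing indices are constant zeros, and outIndices becomes [%i, %j, %c0].

#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
func.func @collapse_indices_example(%src : memref<?x?x8x4xi8, #map0>, %i : index, %j : index) -> vector<8x4xi8> {
  %pad = arith.constant 0 : i8
  %c0 = arith.constant 0 : index
  // Indices [%i, %j, %c0, %c0] with firstDimToCollapse = 2: the two trailing
  // indices must be constant 0, and the collapsed indices are [%i, %j, %c0].
  %v = vector.transfer_read %src[%i, %j, %c0, %c0], %pad {in_bounds = [true, true]}
      : memref<?x?x8x4xi8, #map0>, vector<8x4xi8>
  return %v : vector<8x4xi8>
}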
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
@@ -379,12 +427,9 @@ class FlattenContiguousRowMajorTransferReadPattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
if (!isStaticShapeAndContiguousRowMajor(sourceType))
return failure();
if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
// This pattern requires the source to already be rank-reduced.
return failure();
if (sourceType.getNumElements() != vectorType.getNumElements())
int64_t firstContiguousInnerDim =
getContiguousInnerDim(sourceType, vectorType.getNumElements());
if (firstContiguousInnerDim >= sourceType.getRank() - 1)
return failure();
// TODO: generalize this pattern, relax the requirements here.
if (transferReadOp.hasOutOfBoundsDim())
@@ -393,19 +438,28 @@ class FlattenContiguousRowMajorTransferReadPattern
return failure();
if (transferReadOp.getMask())
return failure();
if (llvm::any_of(transferReadOp.getIndices(),
[](Value v) { return !isZero(v); }))
SmallVector<Value> collapsedIndices;
if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
firstContiguousInnerDim,
collapsedIndices)))
return failure();
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
sourceType.getElementType());
Value source1d =
collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
Value read1d = rewriter.create<vector::TransferReadOp>(
loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
MemRefType collapsedSourceType =
collapsedSource.getType().dyn_cast<MemRefType>();
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstContiguousInnerDim + 1);
SmallVector<AffineExpr, 1> dimExprs{
getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
auto collapsedMap =
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
transferReadOp, vector.getType().cast<VectorType>(), read1d);
transferReadOp, vector.getType().cast<VectorType>(), flatRead);
return success();
}
};
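One detail worth noting (an observation, not code from the patch): for collapsedRank == 3 the collapsedMap built above is the minor identity map that projects onto the innermost, collapsed dimension, so the flattened read carries a single in_bounds entry. Because this is the default permutation map for a 1-D transfer on a rank-3 source, it is not printed explicitly in the IR checked by the new tests below.

// Map produced by AffineMap::get(/*collapsedRank=*/3, 0, {d2}, ctx) in the rewrite above.
#collapsed_read_map = affine_map<(d0, d1, d2) -> (d2)>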
@@ -431,12 +485,9 @@ class FlattenContiguousRowMajorTransferWritePattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
if (!isStaticShapeAndContiguousRowMajor(sourceType))
return failure();
if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
// This pattern requires the source to already be rank-reduced.
return failure();
if (sourceType.getNumElements() != vectorType.getNumElements())
int64_t firstContiguousInnerDim =
getContiguousInnerDim(sourceType, vectorType.getNumElements());
if (firstContiguousInnerDim >= sourceType.getRank() - 1)
return failure();
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
@@ -445,19 +496,29 @@ class FlattenContiguousRowMajorTransferWritePattern
return failure();
if (transferWriteOp.getMask())
return failure();
if (llvm::any_of(transferWriteOp.getIndices(),
[](Value v) { return !isZero(v); }))
SmallVector<Value> collapsedIndices;
if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
firstContiguousInnerDim,
collapsedIndices)))
return failure();
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
sourceType.getElementType());
Value source1d =
collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
Value vector1d =
rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
ValueRange{c0}, identityMap1D);
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
MemRefType collapsedSourceType =
collapsedSource.getType().cast<MemRefType>();
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstContiguousInnerDim + 1);
SmallVector<AffineExpr, 1> dimExprs{
getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
auto collapsedMap =
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
Value flatVector =
rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
vector::TransferWriteOp flatWrite =
rewriter.create<vector::TransferWriteOp>(
loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
rewriter.eraseOp(transferWriteOp);
return success();
}


@@ -59,3 +59,48 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
// CHECK: %[[CST:.+]] = arith.constant 0 : i8
// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
// CHECK: return %[[READ]]
// -----
#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
%c0_i8 = arith.constant 0 : i8
%c0 = arith.constant 0 : index
%result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, #map0>, vector<8x4xi8>
return %result : vector<8x4xi8>
}
// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
// CHECK-SAME: {in_bounds = [true]}
// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
// CHECK: return %[[VEC2D]] : vector<8x4xi8>
// -----
#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, #map0>
return
}
// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
// CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
// CHECK-SAME: {in_bounds = [true]}
// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>