Generalize the vector transfer flattening patterns (dyn shapes).
Differential Revision: https://reviews.llvm.org/D130284
commit f4ac950957 (parent 953a98ef8d)
@@ -339,23 +339,71 @@ class TransferWriteDropUnitDimsPattern
   }
 };
 
-/// Creates a memref.collapse_shape collapsing all of the dimensions of the
-/// input into a 1D shape.
-// TODO: move helper function
-static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
-                                                  mlir::Location loc,
-                                                  Value input) {
-  Value rankReducedInput =
-      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
-  ShapedType rankReducedInputType =
-      rankReducedInput.getType().cast<ShapedType>();
-  if (rankReducedInputType.getRank() == 1)
-    return rankReducedInput;
-  ReassociationIndices indices;
-  for (int i = 0; i < rankReducedInputType.getRank(); ++i)
-    indices.push_back(i);
-  return rewriter.create<memref::CollapseShapeOp>(
-      loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
-}
+/// Returns the position of the first inner dimension that has contiguous layout
+/// with at least `requiredContiguousSize` contiguous elements.
+/// When such a dimension is found, the return value satisfies:
+///   0 <= return_value <= memrefType.getRank() - 1.
+/// When no such dimension is found, the return value is memrefType.getRank().
+static int64_t getContiguousInnerDim(MemRefType memrefType,
+                                     int64_t requiredContiguousSize) {
+  auto shape = memrefType.getShape();
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  int64_t innerDim = shape.size();
+  if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
+    int64_t innerSize = 1;
+    while (true) {
+      if (innerDim == 0)
+        break;
+      const int64_t nextDim = innerDim - 1;
+      if (shape[nextDim] == ShapedType::kDynamicSize)
+        break;
+      if (strides[nextDim] != innerSize)
+        break;
+      innerSize *= shape[nextDim];
+      innerDim = nextDim;
+      if (innerSize >= requiredContiguousSize)
+        break;
+    }
+  }
+  return innerDim;
+}
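
A minimal standalone sketch of the stride walk above, assuming nothing beyond the C++ standard library: plain vectors stand in for MemRefType, -1 plays the role of ShapedType::kDynamicSize, and the function names are illustrative only.

// Sketch only: mirrors getContiguousInnerDim over plain vectors.
#include <cassert>
#include <cstdint>
#include <vector>

static int64_t contiguousInnerDimSketch(const std::vector<int64_t> &shape,
                                        const std::vector<int64_t> &strides,
                                        int64_t requiredContiguousSize) {
  int64_t innerDim = shape.size();
  int64_t innerSize = 1;
  while (innerDim > 0) {
    const int64_t nextDim = innerDim - 1;
    // Stop at the first dynamic size (here: -1) or non-contiguous stride.
    if (shape[nextDim] < 0 || strides[nextDim] != innerSize)
      break;
    innerSize *= shape[nextDim];
    innerDim = nextDim;
    if (innerSize >= requiredContiguousSize)
      break;
  }
  return innerDim;
}

int main() {
  // memref<?x?x8x4xi8> with row-major strides [?, 32, 4, 1]: the two inner
  // dims hold 8 * 4 = 32 contiguous elements, so the walk stops at dim 2.
  assert(contiguousInnerDimSketch({-1, -1, 8, 4}, {-1, 32, 4, 1}, 32) == 2);
  // Fully static, fully contiguous 2x8x4 memref: the walk can reach dim 0.
  assert(contiguousInnerDimSketch({2, 8, 4}, {32, 4, 1}, 64) == 0);
  return 0;
}

This is why the dynamic-shape tests below qualify: the vector<8x4xi8> transfer needs 32 contiguous elements, which the static 8x4 inner dims provide, while the two dynamic outer dims are left alone.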
+
+/// Creates a memref.collapse_shape collapsing all inner dimensions of the
+/// input starting at `firstDimToCollapse`.
+static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
+                               Value input, int64_t firstDimToCollapse) {
+  ShapedType inputType = input.getType().cast<ShapedType>();
+  if (inputType.getRank() == 1)
+    return input;
+  SmallVector<ReassociationIndices> reassociation;
+  for (int64_t i = 0; i < firstDimToCollapse; ++i)
+    reassociation.push_back(ReassociationIndices{i});
+  ReassociationIndices collapsedIndices;
+  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
+    collapsedIndices.push_back(i);
+  reassociation.push_back(collapsedIndices);
+  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
+}
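
A small standalone sketch of how the reassociation groups are formed, illustrative C++ only and without the MLIR ReassociationIndices type:

// Sketch only: every dim before firstDimToCollapse stays in its own group,
// all remaining dims are folded into one trailing group.
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<std::vector<int64_t>>
reassociationSketch(int64_t rank, int64_t firstDimToCollapse) {
  std::vector<std::vector<int64_t>> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back({i});
  std::vector<int64_t> collapsed;
  for (int64_t i = firstDimToCollapse; i < rank; ++i)
    collapsed.push_back(i);
  reassociation.push_back(collapsed);
  return reassociation;
}

int main() {
  // A rank-4 source collapsed from dim 2 gives [[0], [1], [2, 3]], matching
  // the memref.collapse_shape reassociation checked in the tests below.
  std::vector<std::vector<int64_t>> expected = {{0}, {1}, {2, 3}};
  assert(reassociationSketch(4, 2) == expected);
  return 0;
}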
+
+/// Checks that the indices corresponding to dimensions starting at
+/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
+/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+static LogicalResult
+checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
+                                 SmallVector<Value> &outIndices) {
+  int64_t rank = indices.size();
+  if (firstDimToCollapse >= rank)
+    return failure();
+  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
+    arith::ConstantIndexOp cst =
+        indices[i].getDefiningOp<arith::ConstantIndexOp>();
+    if (!cst || cst.value() != 0)
+      return failure();
+  }
+  outIndices = indices;
+  outIndices.resize(firstDimToCollapse + 1);
+  return success();
+}
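
The index handling is easiest to see on symbolic names. A hedged standalone sketch: strings stand in for index Values, "0" for a constant-zero index, and the function name is invented for the example.

// Sketch only: the inner indices must all be constant 0; they are replaced by
// a single trailing zero index for the collapsed dimension.
#include <cassert>
#include <optional>
#include <string>
#include <vector>

static std::optional<std::vector<std::string>>
collapseZeroIndicesSketch(std::vector<std::string> indices,
                          size_t firstDimToCollapse) {
  if (firstDimToCollapse >= indices.size())
    return std::nullopt;
  for (size_t i = firstDimToCollapse; i < indices.size(); ++i)
    if (indices[i] != "0")
      return std::nullopt; // a non-zero inner index blocks the rewrite
  indices.resize(firstDimToCollapse + 1);
  return indices;
}

int main() {
  // [%arg1, %arg2, 0, 0] with firstDimToCollapse = 2 becomes [%arg1, %arg2, 0].
  auto ok = collapseZeroIndicesSketch({"%arg1", "%arg2", "0", "0"}, 2);
  assert(ok && *ok == std::vector<std::string>({"%arg1", "%arg2", "0"}));
  // Any non-zero inner index makes the pattern bail out.
  assert(!collapseZeroIndicesSketch({"%arg1", "%arg2", "%i", "0"}, 2));
  return 0;
}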
 
 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
@@ -379,12 +427,9 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!isStaticShapeAndContiguousRowMajor(sourceType))
-      return failure();
-    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
-      // This pattern requires the source to already be rank-reduced.
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    int64_t firstContiguousInnerDim =
+        getContiguousInnerDim(sourceType, vectorType.getNumElements());
+    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
@@ -393,19 +438,28 @@ class FlattenContiguousRowMajorTransferReadPattern
       return failure();
     if (transferReadOp.getMask())
       return failure();
-    if (llvm::any_of(transferReadOp.getIndices(),
-                     [](Value v) { return !isZero(v); }))
+    SmallVector<Value> collapsedIndices;
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstContiguousInnerDim,
+                                                collapsedIndices)))
       return failure();
-    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
-    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
-                                              sourceType.getElementType());
-    Value source1d =
-        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
-    Value read1d = rewriter.create<vector::TransferReadOp>(
-        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
+    Value collapsedSource =
+        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+    MemRefType collapsedSourceType =
+        collapsedSource.getType().dyn_cast<MemRefType>();
+    int64_t collapsedRank = collapsedSourceType.getRank();
+    assert(collapsedRank == firstContiguousInnerDim + 1);
+    SmallVector<AffineExpr, 1> dimExprs{
+        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+    auto collapsedMap =
+        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
+                                                vectorType.getElementType());
+    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
+        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
+    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-        transferReadOp, vector.getType().cast<VectorType>(), read1d);
+        transferReadOp, vector.getType().cast<VectorType>(), flatRead);
     return success();
   }
 };
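
The one non-obvious piece of the rewrite is the permutation map on the collapsed memref: since collapsedRank == firstContiguousInnerDim + 1, the map simply picks the new innermost dimension. A minimal sketch, with the map rendered as text rather than as an MLIR AffineMap (names are illustrative only):

// Sketch only: renders the collapsed transfer map as a string.
#include <cassert>
#include <cstdint>
#include <string>

static std::string collapsedMapSketch(int64_t collapsedRank,
                                      int64_t firstContiguousInnerDim) {
  std::string map = "(";
  for (int64_t i = 0; i < collapsedRank; ++i)
    map += "d" + std::to_string(i) + (i + 1 < collapsedRank ? ", " : "");
  return map + ") -> (d" + std::to_string(firstContiguousInnerDim) + ")";
}

int main() {
  // For the memref<?x?x8x4xi8> test case below, the source collapses to rank 3
  // and the flattened vector<32xi8> transfer uses (d0, d1, d2) -> (d2).
  assert(collapsedMapSketch(3, 2) == "(d0, d1, d2) -> (d2)");
  return 0;
}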

@@ -431,12 +485,9 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!isStaticShapeAndContiguousRowMajor(sourceType))
-      return failure();
-    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
-      // This pattern requires the source to already be rank-reduced.
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    int64_t firstContiguousInnerDim =
+        getContiguousInnerDim(sourceType, vectorType.getNumElements());
+    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
@@ -445,19 +496,29 @@ class FlattenContiguousRowMajorTransferWritePattern
       return failure();
     if (transferWriteOp.getMask())
       return failure();
-    if (llvm::any_of(transferWriteOp.getIndices(),
-                     [](Value v) { return !isZero(v); }))
+    SmallVector<Value> collapsedIndices;
+    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
+                                                firstContiguousInnerDim,
+                                                collapsedIndices)))
       return failure();
-    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
-    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
-                                              sourceType.getElementType());
-    Value source1d =
-        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
-    Value vector1d =
-        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
-    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
-                                             ValueRange{c0}, identityMap1D);
+    Value collapsedSource =
+        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+    MemRefType collapsedSourceType =
+        collapsedSource.getType().cast<MemRefType>();
+    int64_t collapsedRank = collapsedSourceType.getRank();
+    assert(collapsedRank == firstContiguousInnerDim + 1);
+    SmallVector<AffineExpr, 1> dimExprs{
+        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+    auto collapsedMap =
+        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
+                                                vectorType.getElementType());
+    Value flatVector =
+        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
+    vector::TransferWriteOp flatWrite =
+        rewriter.create<vector::TransferWriteOp>(
+            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
+    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
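
The write pattern mirrors the read pattern; the extra step is the shape_cast that flattens the stored vector before the collapsed transfer_write. A trivial standalone sketch of that element-count computation, illustrative only:

// Sketch only: a rank-N vector is shape_cast to a rank-1 vector holding the
// same number of elements.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

static int64_t flatNumElements(const std::vector<int64_t> &vectorShape) {
  return std::accumulate(vectorShape.begin(), vectorShape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  // vector<8x4xi8> becomes vector<32xi8> before being written through the
  // collapsed memref<?x?x32xi8, ...>.
  assert(flatNumElements({8, 4}) == 32);
  return 0;
}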

@@ -59,3 +59,48 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
 // CHECK: %[[CST:.+]] = arith.constant 0 : i8
 // CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
 // CHECK: return %[[READ]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
+
+func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
+    %c0_i8 = arith.constant 0 : i8
+    %c0 = arith.constant 0 : index
+    %result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, #map0>, vector<8x4xi8>
+    return %result : vector<8x4xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME: {in_bounds = [true]}
+// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
+// CHECK: return %[[VEC2D]] : vector<8x4xi8>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
+
+func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, #map0>
+    return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
+// CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME: {in_bounds = [true]}
+// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>