forked from OSchip/llvm-project
[mlir][linalg] Fix vectorization bug in vector transfer indexing map calculation
The current implementation had a bug as it was relying on the target vector dimension sizes to calculate where to insert broadcast. If several dimensions have the same size we may insert the broadcast on the wrong dimension. The correct broadcast cannot be inferred from the type of the source and destination vector. Instead when we want to extend transfer ops we calculate an "inverse" map to the projected permutation and insert broadcast in place of the projected dimensions. Differential Revision: https://reviews.llvm.org/D101738
This commit is contained in:
parent
456efbc0f1
commit
9621c1ef56
|
@ -1383,18 +1383,6 @@ def Vector_TransferReadOp :
|
|||
"ArrayAttr":$inBounds)>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return a new `result` map with `0` inserted in the proper positions so
|
||||
/// that vector.transfer_read `result` produces a vector of same element
|
||||
/// type as `vt` and shape `targetShape`.
|
||||
/// Assume that `map` is a permutation map for a vector.transfer_read op,
|
||||
/// `vt` the vector type produced by the vector.transfer_read and
|
||||
/// `targetShape` is the desired `targetShape` for a broadcast version of
|
||||
/// `vt`.
|
||||
static AffineMap insertBroadcasts(AffineMap map, VectorType vt,
|
||||
ArrayRef<int64_t> targetShape);
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -404,6 +404,48 @@ AffineMap removeDuplicateExprs(AffineMap map);
|
|||
/// ```
|
||||
AffineMap inversePermutation(AffineMap map);
|
||||
|
||||
/// Return the reverse map of a projected permutation where the projected
|
||||
/// dimensions are transformed into 0s.
|
||||
///
|
||||
/// Prerequisites: `map` must be a projected permutation.
|
||||
///
|
||||
/// Example 1:
|
||||
///
|
||||
/// ```mlir
|
||||
/// affine_map<(d0, d1, d2, d3) -> (d2, d0)>
|
||||
/// ```
|
||||
///
|
||||
/// returns:
|
||||
///
|
||||
/// ```mlir
|
||||
/// affine_map<(d0, d1) -> (d1, 0, d0, 0)>
|
||||
/// ```
|
||||
///
|
||||
/// Example 2:
|
||||
///
|
||||
/// ```mlir
|
||||
/// affine_map<(d0, d1, d2, d3) -> (d0, d3)>
|
||||
/// ```
|
||||
///
|
||||
/// returns:
|
||||
///
|
||||
/// ```mlir
|
||||
/// affine_map<(d0, d1) -> (d0, 0, 0, d1)>
|
||||
/// ```
|
||||
///
|
||||
/// Example 3:
|
||||
///
|
||||
/// ```mlir
|
||||
/// affine_map<(d0, d1, d2, d3) -> (d2)>
|
||||
/// ```
|
||||
///
|
||||
/// returns:
|
||||
///
|
||||
/// ```mlir
|
||||
/// affine_map<(d0) -> (0, 0, d0, 0)>
|
||||
/// ```
|
||||
AffineMap inverseAndBroadcastProjectedPermuation(AffineMap map);
|
||||
|
||||
/// Concatenates a list of `maps` into a single AffineMap, stepping over
|
||||
/// potentially empty maps. Assumes each of the underlying map has 0 symbols.
|
||||
/// The resulting map has a number of dims equal to the max of `maps`' dims and
|
||||
|
|
|
@ -493,15 +493,18 @@ LogicalResult vectorizeAsLinalgGeneric(
|
|||
bvm.map(shapedArg, loaded);
|
||||
continue;
|
||||
}
|
||||
AffineMap map = inversePermutation(
|
||||
reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
|
||||
VectorType vectorType = VectorType::get(map.compose(shapedType.getShape()),
|
||||
shapedType.getElementType());
|
||||
AffineMap map;
|
||||
VectorType vectorType;
|
||||
if (broadcastToMaximalCommonShape) {
|
||||
map = vector::TransferReadOp::insertBroadcasts(map, vectorType,
|
||||
commonVectorShape);
|
||||
map = inverseAndBroadcastProjectedPermuation(
|
||||
linalgOp.getIndexingMap(bbarg.getArgNumber()));
|
||||
vectorType =
|
||||
VectorType::get(commonVectorShape, vectorType.getElementType());
|
||||
VectorType::get(commonVectorShape, shapedType.getElementType());
|
||||
} else {
|
||||
map = inversePermutation(
|
||||
reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
|
||||
vectorType = VectorType::get(map.compose(shapedType.getShape()),
|
||||
shapedType.getElementType());
|
||||
}
|
||||
Value vectorRead = buildVectorRead(builder, shapedArg, vectorType, map);
|
||||
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
|
||||
|
|
|
@ -2253,29 +2253,6 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
|
|||
// TransferReadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AffineMap TransferReadOp::insertBroadcasts(AffineMap map, VectorType vt,
|
||||
ArrayRef<int64_t> targetShape) {
|
||||
unsigned targetRank = targetShape.size();
|
||||
assert(vt.getShape().size() <= targetRank && "mismatching ranks");
|
||||
if (vt.getShape().size() == targetRank)
|
||||
return map;
|
||||
MLIRContext *ctx = map.getContext();
|
||||
SmallVector<AffineExpr> exprs;
|
||||
exprs.reserve(targetRank);
|
||||
for (unsigned idx = 0, vtidx = 0; idx < targetRank; ++idx) {
|
||||
// If shapes match, just keep the existing indexing and advance ranks.
|
||||
if (vtidx < vt.getShape().size() &&
|
||||
vt.getShape()[vtidx] == targetShape[idx]) {
|
||||
exprs.push_back(map.getResult(vtidx));
|
||||
++vtidx;
|
||||
continue;
|
||||
}
|
||||
// Otherwise insert a broadcast.
|
||||
exprs.push_back(getAffineConstantExpr(0, ctx));
|
||||
}
|
||||
return AffineMap::get(map.getNumDims(), /*numSymbols=*/0, exprs, ctx);
|
||||
}
|
||||
|
||||
template <typename EmitFun>
|
||||
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
|
||||
EmitFun emitOpError) {
|
||||
|
|
|
@ -664,6 +664,19 @@ AffineMap mlir::inversePermutation(AffineMap map) {
|
|||
return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
|
||||
}
|
||||
|
||||
AffineMap mlir::inverseAndBroadcastProjectedPermuation(AffineMap map) {
|
||||
assert(map.isProjectedPermutation());
|
||||
MLIRContext *context = map.getContext();
|
||||
AffineExpr zero = mlir::getAffineConstantExpr(0, context);
|
||||
// Start with all the results as 0.
|
||||
SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero);
|
||||
for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
|
||||
// Reverse each dimension existing in the original map result.
|
||||
exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context);
|
||||
}
|
||||
return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
|
||||
}
|
||||
|
||||
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
|
||||
unsigned numResults = 0, numDims = 0, numSymbols = 0;
|
||||
for (auto m : maps)
|
||||
|
|
|
@ -381,6 +381,43 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, 0, 0, d1)>
|
||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0, 0, 0, 0)>
|
||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (0, 0, d0, 0)>
|
||||
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1, 0, d0, 0)>
|
||||
// CHECK: func @generic_vectorize_broadcast_transpose
|
||||
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK-DAG: %[[CF:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK: %[[V0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {permutation_map = #[[$MAP0]]} : memref<4x4xf32>, vector<4x4x4x4xf32>
|
||||
// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {permutation_map = #[[$MAP1]]} : memref<4xf32>, vector<4x4x4x4xf32>
|
||||
// CHECK: %[[V2:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {permutation_map = #[[$MAP2]]} : memref<4xf32>, vector<4x4x4x4xf32>
|
||||
// CHECK: %[[V3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {permutation_map = #[[$MAP3]]} : memref<4x4xf32>, vector<4x4x4x4xf32>
|
||||
// CHECK: %[[SUB:.*]] = subf %[[V0]], %[[V1]] : vector<4x4x4x4xf32>
|
||||
// CHECK: %[[ADD0:.*]] = addf %[[V2]], %[[SUB]] : vector<4x4x4x4xf32>
|
||||
// CHECK: %[[ADD1:.*]] = addf %[[V3]], %[[ADD0]] : vector<4x4x4x4xf32>
|
||||
// CHECK: vector.transfer_write %[[ADD1]], {{.*}} : vector<4x4x4x4xf32>, memref<4x4x4x4xf32>
|
||||
func @generic_vectorize_broadcast_transpose(
|
||||
%A: memref<4xf32>, %B: memref<4x4xf32>, %C: memref<4x4x4x4xf32>) {
|
||||
linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
|
||||
affine_map<(d0, d1, d2, d3) -> (d0)>,
|
||||
affine_map<(d0, d1, d2, d3) -> (d2)>,
|
||||
affine_map<(d0, d1, d2, d3) -> (d2, d0)>,
|
||||
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
|
||||
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
||||
ins(%B, %A, %A, %B: memref<4x4xf32>, memref<4xf32>, memref<4xf32>, memref<4x4xf32>)
|
||||
outs(%C : memref<4x4x4x4xf32>) {
|
||||
^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
|
||||
%s = subf %arg0, %arg1 : f32
|
||||
%a = addf %arg2, %s : f32
|
||||
%b = addf %arg3, %a : f32
|
||||
linalg.yield %b : f32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test different input maps.
|
||||
#matmul_trait = {
|
||||
indexing_maps = [
|
||||
|
|
Loading…
Reference in New Issue