[mlir][linalg] Fix vectorization bug in vector transfer indexing map calculation

The current implementation had a bug as it was relying on the target vector
dimension sizes to calculate where to insert broadcast. If several dimensions
have the same size we may insert the broadcast on the wrong dimension. The
correct broadcast cannot be inferred from the type of the source and
destination vector.

Instead when we want to extend transfer ops we calculate an "inverse" map to the
projected permutation and insert broadcast in place of the projected dimensions.

Differential Revision: https://reviews.llvm.org/D101738
This commit is contained in:
thomasraoux 2021-05-03 11:56:11 -07:00
parent 456efbc0f1
commit 9621c1ef56
6 changed files with 102 additions and 42 deletions

View File

@ -1383,18 +1383,6 @@ def Vector_TransferReadOp :
"ArrayAttr":$inBounds)> "ArrayAttr":$inBounds)>
]; ];
let extraClassDeclaration = [{
/// Return a new `result` map with `0` inserted in the proper positions so
/// that vector.transfer_read `result` produces a vector of same element
/// type as `vt` and shape `targetShape.
/// Assume that `map` is a permutation map for a vector.transfer_read op,
/// `vt` the vector type produced by the vector.transfer_read and
/// `targetShape` is the desired `targetShape` for a broadcast version of
/// `vt`.
static AffineMap insertBroadcasts(AffineMap map, VectorType vt,
ArrayRef<int64_t> targetShape);
}];
let hasFolder = 1; let hasFolder = 1;
} }

View File

@ -404,6 +404,48 @@ AffineMap removeDuplicateExprs(AffineMap map);
/// ``` /// ```
AffineMap inversePermutation(AffineMap map); AffineMap inversePermutation(AffineMap map);
/// Return the reverse map of a projected permutation where the projected
/// dimensions are transformed into 0s.
///
/// Prerequisites: `map` must be a projected permuation.
///
/// Example 1:
///
/// ```mlir
/// affine_map<(d0, d1, d2, d3) -> (d2, d0)>
/// ```
///
/// returns:
///
/// ```mlir
/// affine_map<(d0, d1) -> (d1, 0, d0, 0)>
/// ```
///
/// Example 2:
///
/// ```mlir
/// affine_map<(d0, d1, d2, d3) -> (d0, d3)>
/// ```
///
/// returns:
///
/// ```mlir
/// affine_map<(d0, d1) -> (d0, 0, 0, d1)>
/// ```
///
/// Example 3:
///
/// ```mlir
/// affine_map<(d0, d1, d2, d3) -> (d2)>
/// ```
///
/// returns:
///
/// ```mlir
/// affine_map<(d0) -> (0, 0, d0, 0)>
/// ```
AffineMap inverseAndBroadcastProjectedPermuation(AffineMap map);
/// Concatenates a list of `maps` into a single AffineMap, stepping over /// Concatenates a list of `maps` into a single AffineMap, stepping over
/// potentially empty maps. Assumes each of the underlying map has 0 symbols. /// potentially empty maps. Assumes each of the underlying map has 0 symbols.
/// The resulting map has a number of dims equal to the max of `maps`' dims and /// The resulting map has a number of dims equal to the max of `maps`' dims and

View File

@ -493,15 +493,18 @@ LogicalResult vectorizeAsLinalgGeneric(
bvm.map(shapedArg, loaded); bvm.map(shapedArg, loaded);
continue; continue;
} }
AffineMap map = inversePermutation( AffineMap map;
reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber()))); VectorType vectorType;
VectorType vectorType = VectorType::get(map.compose(shapedType.getShape()),
shapedType.getElementType());
if (broadcastToMaximalCommonShape) { if (broadcastToMaximalCommonShape) {
map = vector::TransferReadOp::insertBroadcasts(map, vectorType, map = inverseAndBroadcastProjectedPermuation(
commonVectorShape); linalgOp.getIndexingMap(bbarg.getArgNumber()));
vectorType = vectorType =
VectorType::get(commonVectorShape, vectorType.getElementType()); VectorType::get(commonVectorShape, shapedType.getElementType());
} else {
map = inversePermutation(
reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
vectorType = VectorType::get(map.compose(shapedType.getShape()),
shapedType.getElementType());
} }
Value vectorRead = buildVectorRead(builder, shapedArg, vectorType, map); Value vectorRead = buildVectorRead(builder, shapedArg, vectorType, map);
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg(" LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("

View File

@ -2253,29 +2253,6 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
// TransferReadOp // TransferReadOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
AffineMap TransferReadOp::insertBroadcasts(AffineMap map, VectorType vt,
ArrayRef<int64_t> targetShape) {
unsigned targetRank = targetShape.size();
assert(vt.getShape().size() <= targetRank && "mismatching ranks");
if (vt.getShape().size() == targetRank)
return map;
MLIRContext *ctx = map.getContext();
SmallVector<AffineExpr> exprs;
exprs.reserve(targetRank);
for (unsigned idx = 0, vtidx = 0; idx < targetRank; ++idx) {
// If shapes match, just keep the existing indexing and advance ranks.
if (vtidx < vt.getShape().size() &&
vt.getShape()[vtidx] == targetShape[idx]) {
exprs.push_back(map.getResult(vtidx));
++vtidx;
continue;
}
// Otherwise insert a broadcast.
exprs.push_back(getAffineConstantExpr(0, ctx));
}
return AffineMap::get(map.getNumDims(), /*numSymbols=*/0, exprs, ctx);
}
template <typename EmitFun> template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap, static LogicalResult verifyPermutationMap(AffineMap permutationMap,
EmitFun emitOpError) { EmitFun emitOpError) {

View File

@ -664,6 +664,19 @@ AffineMap mlir::inversePermutation(AffineMap map) {
return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext()); return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
} }
AffineMap mlir::inverseAndBroadcastProjectedPermuation(AffineMap map) {
assert(map.isProjectedPermutation());
MLIRContext *context = map.getContext();
AffineExpr zero = mlir::getAffineConstantExpr(0, context);
// Start with all the results as 0.
SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero);
for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
// Reverse each dimension existing in the oringal map result.
exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context);
}
return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
}
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) { AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
unsigned numResults = 0, numDims = 0, numSymbols = 0; unsigned numResults = 0, numDims = 0, numSymbols = 0;
for (auto m : maps) for (auto m : maps)

View File

@ -381,6 +381,43 @@ func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
// ----- // -----
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, 0, 0, d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0, 0, 0, 0)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (0, 0, d0, 0)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1, 0, d0, 0)>
// CHECK: func @generic_vectorize_broadcast_transpose
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[CF:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[V0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {permutation_map = #[[$MAP0]]} : memref<4x4xf32>, vector<4x4x4x4xf32>
// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {permutation_map = #[[$MAP1]]} : memref<4xf32>, vector<4x4x4x4xf32>
// CHECK: %[[V2:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {permutation_map = #[[$MAP2]]} : memref<4xf32>, vector<4x4x4x4xf32>
// CHECK: %[[V3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {permutation_map = #[[$MAP3]]} : memref<4x4xf32>, vector<4x4x4x4xf32>
// CHECK: %[[SUB:.*]] = subf %[[V0]], %[[V1]] : vector<4x4x4x4xf32>
// CHECK: %[[ADD0:.*]] = addf %[[V2]], %[[SUB]] : vector<4x4x4x4xf32>
// CHECK: %[[ADD1:.*]] = addf %[[V3]], %[[ADD0]] : vector<4x4x4x4xf32>
// CHECK: vector.transfer_write %[[ADD1]], {{.*}} : vector<4x4x4x4xf32>, memref<4x4x4x4xf32>
func @generic_vectorize_broadcast_transpose(
%A: memref<4xf32>, %B: memref<4x4xf32>, %C: memref<4x4x4x4xf32>) {
linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0)>,
affine_map<(d0, d1, d2, d3) -> (d2)>,
affine_map<(d0, d1, d2, d3) -> (d2, d0)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%B, %A, %A, %B: memref<4x4xf32>, memref<4xf32>, memref<4xf32>, memref<4x4xf32>)
outs(%C : memref<4x4x4x4xf32>) {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
%s = subf %arg0, %arg1 : f32
%a = addf %arg2, %s : f32
%b = addf %arg3, %a : f32
linalg.yield %b : f32
}
return
}
// -----
// Test different input maps. // Test different input maps.
#matmul_trait = { #matmul_trait = {
indexing_maps = [ indexing_maps = [