forked from OSchip/llvm-project
[mlir][linalg] fix crash in vectorization of elementwise operations
The current vectorization logic implicitly expects "elementwise" linalg ops to have projected permutations for indexing maps, but the precondition logic misses this check. This can result in a crash when executing the generic vectorization transform on an op with a non-projected-permutation input indexing map. This change fixes the precondition logic and adds a regression test case that triggers the crash when compiled without this fix. Differential Revision: https://reviews.llvm.org/D127000
This commit is contained in:
parent
f60875254b
commit
9f819f4c62
|
@ -432,6 +432,14 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
|
|||
return true;
|
||||
}
|
||||
|
||||
/// Returns `true` if all indexing maps of the linalg op are projected
|
||||
/// permutations.
|
||||
static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
|
||||
return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
|
||||
return m.isProjectedPermutation(/*allowZeroInResults=*/true);
|
||||
});
|
||||
}
|
||||
|
||||
// Return true if the op is an element-wise linalg op.
|
||||
static bool isElementwise(Operation *op) {
|
||||
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
|
||||
|
@ -439,6 +447,10 @@ static bool isElementwise(Operation *op) {
|
|||
return false;
|
||||
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
|
||||
return false;
|
||||
|
||||
if (!allIndexingsAreProjectedPermutation(linalgOp))
|
||||
return false;
|
||||
|
||||
// TODO: relax the restrictions on indexing map.
|
||||
for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
|
||||
if (!linalgOp.getTiedIndexingMap(opOperand).isPermutation())
|
||||
|
@ -564,17 +576,6 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Helper function to vectorize a `linalgOp` with contraction semantics in a
|
||||
/// generic fashion.
|
||||
/// This helper is needed atm because the truly generic implementation requires
|
||||
/// good vector.multi_reduce folding patterns that are currently NYI.
|
||||
// TODO: drop reliance on a specific pattern.
|
||||
/// Returns `true` if all indexing maps of the linalg op are projected
/// permutations (constant-zero results are tolerated).
static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  // Predicate applied uniformly to every indexing map of the op.
  auto isProjected = [](AffineMap map) {
    return map.isProjectedPermutation(/*allowZeroInResults=*/true);
  };
  return llvm::all_of(op.getIndexingMaps(), isProjected);
}
|
||||
|
||||
// TODO: probably need some extra checks for reduction followed by consumer
|
||||
// ops that may not commute (e.g. linear reduction + non-linear instructions).
|
||||
static LogicalResult reductionPreconditions(LinalgOp op) {
|
||||
|
|
|
@ -1077,3 +1077,27 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
|
|||
|
||||
return %2 : tensor<f32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// This test checks that vectorization does not occur when an input indexing map
|
||||
// is not a projected permutation. In the future, this can be converted to a
|
||||
// positive test when support is added.
|
||||
|
||||
// CHECK-LABEL: func @not_projected_permutation
|
||||
// The input indexing map (d0 + d2, d1 + d3) is not a projected permutation,
// so vectorization must leave the linalg.generic untouched.
func.func @not_projected_permutation(%arg0: tensor<8x8xf32>) -> tensor<6x6x3x3xf32> {
  %zero = arith.constant 0.0 : f32
  %dest = linalg.init_tensor [6, 6, 3, 3] : tensor<6x6x3x3xf32>
  %filled = linalg.fill ins(%zero : f32) outs(%dest : tensor<6x6x3x3xf32>) -> tensor<6x6x3x3xf32>
  // CHECK: linalg.generic
  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>,
                                          affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
                         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<8x8xf32>)
      outs(%filled : tensor<6x6x3x3xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<6x6x3x3xf32>
  return %res : tensor<6x6x3x3xf32>
}
|
||||
|
|
Loading…
Reference in New Issue