[mlir] Add an additional check to vectorizeStaticLinalgOpPrecondition.

We need to make sure that the types used in the body are valid element types
for VectorType.

Differential Revision: https://reviews.llvm.org/D128336
This commit is contained in:
Adrian Kuegel 2022-06-22 13:50:30 +02:00
parent 719658d078
commit 991547703a
2 changed files with 30 additions and 0 deletions

View File

@ -552,6 +552,19 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
}
static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
// All types in the body should be a supported element type for VectorType.
for (Operation &innerOp : op->getRegion(0).front()) {
if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
return !VectorType::isValidElementType(type);
})) {
return failure();
}
if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
return !VectorType::isValidElementType(type);
})) {
return failure();
}
}
if (isElementwise(op))
return success();
// TODO: isaConvolutionOpInterface that can also infer from generic features.

View File

@ -207,6 +207,23 @@ func.func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) {
// -----
// CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types
func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomplex<f32>>, %arg0 : complex<f32>) {
// CHECK-NOT: vector.broadcast
// CHECK-NOT: vector.transfer_write
linalg.generic {
indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : complex<f32>)
outs(%A: memref<8x16xcomplex<f32>>) {
^bb(%0: complex<f32>, %1: complex<f32>) :
linalg.yield %0 : complex<f32>
}
return
}
// -----
// CHECK-LABEL: func @test_vectorize_fill
func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>