forked from OSchip/llvm-project
[mlir] Add an additional check to vectorizeStaticLinalgOpPrecondition.
We need to make sure that the types used in the body are valid element types for VectorType. Differential Revision: https://reviews.llvm.org/D128336
This commit is contained in:
parent
719658d078
commit
991547703a
|
@ -552,6 +552,19 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
|
|||
}
|
||||
|
||||
static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
|
||||
// All types in the body should be a supported element type for VectorType.
|
||||
for (Operation &innerOp : op->getRegion(0).front()) {
|
||||
if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
|
||||
return !VectorType::isValidElementType(type);
|
||||
})) {
|
||||
return failure();
|
||||
}
|
||||
if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
|
||||
return !VectorType::isValidElementType(type);
|
||||
})) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
if (isElementwise(op))
|
||||
return success();
|
||||
// TODO: isaConvolutionOpInterface that can also infer from generic features.
|
||||
|
|
|
@ -207,6 +207,23 @@ func.func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types
|
||||
func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomplex<f32>>, %arg0 : complex<f32>) {
|
||||
// CHECK-NOT: vector.broadcast
|
||||
// CHECK-NOT: vector.transfer_write
|
||||
linalg.generic {
|
||||
indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%arg0 : complex<f32>)
|
||||
outs(%A: memref<8x16xcomplex<f32>>) {
|
||||
^bb(%0: complex<f32>, %1: complex<f32>) :
|
||||
linalg.yield %0 : complex<f32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @test_vectorize_fill
|
||||
func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
|
||||
// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
|
||||
|
|
Loading…
Reference in New Issue