[linalg] When removing noop linalg.generics, check that inserting a cast is valid

linalg.generic can also take scalars instead of tensors, which
tensor.cast doesn't support. We don't have an easy way to cast between
scalars and tensors, so just keep the linalg.generic in those cases.

Differential Revision: https://reviews.llvm.org/D122575
This commit is contained in:
Benjamin Kramer 2022-03-28 14:10:26 +02:00
parent a8ebd85e46
commit 35dab904c0
2 changed files with 23 additions and 1 deletions

View File

@ -836,9 +836,13 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
sparse_tensor::getSparseTensorEncoding(resultType))
returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
genericOp.getLoc(), resultType, returnedArg);
else
else {
if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
resultType))
return failure();
returnedArg = rewriter.create<tensor::CastOp>(
genericOp.getLoc(), resultType, returnedArg);
}
}
returnedArgs.push_back(returnedArg);
}

View File

@ -175,6 +175,24 @@ func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
// -----
// Regression test: an identity linalg.generic whose input is a scalar (f32)
// and whose result is a tensor (tensor<f32>) must be left in place, because
// replacing it would require casting the scalar to the tensor result type.
#map = affine_map<() -> ()>
func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor<f32> {
%out = linalg.init_tensor [] : tensor<f32>
// Rank-0 generic: no iterators, and both operands use the empty map.
%g = linalg.generic {
indexing_maps = [#map, #map],
iterator_types = []
} ins(%arg0 : f32)
outs(%out : tensor<f32>) {
^bb0(%arg2 : f32, %arg3 : f32):
// The body yields the input block argument unchanged.
linalg.yield %arg2 : f32
} -> (tensor<f32>)
return %g : tensor<f32>
}
// The op must still be present in the output of the pass under test.
// CHECK-LABEL: func @cant_fold_to_tensor_cast
// CHECK: linalg.generic
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index