forked from OSchip/llvm-project
[linalg] When removing noop linalg.generics, check that inserting a cast is valid
linalg.generic can also take scalars instead of tensors, which tensor.cast doesn't support. We don't have an easy way to cast between scalars and tensors so just keep the linalg.generic in those cases. Differential Revision: https://reviews.llvm.org/D122575
This commit is contained in:
parent
a8ebd85e46
commit
35dab904c0
|
@ -836,9 +836,13 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
|
|||
sparse_tensor::getSparseTensorEncoding(resultType))
|
||||
returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
|
||||
genericOp.getLoc(), resultType, returnedArg);
|
||||
else
|
||||
else {
|
||||
if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
|
||||
resultType))
|
||||
return failure();
|
||||
returnedArg = rewriter.create<tensor::CastOp>(
|
||||
genericOp.getLoc(), resultType, returnedArg);
|
||||
}
|
||||
}
|
||||
returnedArgs.push_back(returnedArg);
|
||||
}
|
||||
|
|
|
@ -175,6 +175,24 @@ func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
|
|||
|
||||
// -----
|
||||
|
||||
#map = affine_map<() -> ()>
|
||||
func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor<f32> {
|
||||
%out = linalg.init_tensor [] : tensor<f32>
|
||||
%g = linalg.generic {
|
||||
indexing_maps = [#map, #map],
|
||||
iterator_types = []
|
||||
} ins(%arg0 : f32)
|
||||
outs(%out : tensor<f32>) {
|
||||
^bb0(%arg2 : f32, %arg3 : f32):
|
||||
linalg.yield %arg2 : f32
|
||||
} -> (tensor<f32>)
|
||||
return %g : tensor<f32>
|
||||
}
|
||||
// CHECK-LABEL: func @cant_fold_to_tensor_cast
|
||||
// CHECK: linalg.generic
|
||||
|
||||
// -----
|
||||
|
||||
#map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
|
|
Loading…
Reference in New Issue