[mlir][linalg] Insert a cast for identity linalg.generics when the types don't match

This can happen when the result has different dynamic dimensions than
the input.

Differential Revision: https://reviews.llvm.org/D117498
This commit is contained in:
Benjamin Kramer 2022-01-17 17:38:07 +01:00
parent 7294d7dae7
commit f100bedb03
2 changed files with 31 additions and 3 deletions

View File

@ -857,12 +857,19 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
// Get the argument number of the returned values. That is the operand
// number to use for replacing uses of this operation.
SmallVector<Value> returnedArgs;
for (Value yieldVal : yieldOp.values()) {
auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
for (const auto &yieldVal : llvm::enumerate(yieldOp.values())) {
auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
if (!yieldArg || yieldArg.getOwner() != &body)
return failure();
unsigned argumentNumber = yieldArg.getArgNumber();
returnedArgs.push_back(genericOp->getOperand(argumentNumber));
Value returnedArg = genericOp->getOperand(argumentNumber);
Type resultType = genericOp->getResult(yieldVal.index()).getType();
// The input can have a different type than the result, e.g. a dynamic
// input dimension can be turned into a static output dimension.
if (returnedArg.getType() != resultType)
returnedArg = rewriter.create<tensor::CastOp>(genericOp.getLoc(),
resultType, returnedArg);
returnedArgs.push_back(returnedArg);
}
if (returnedArgs.size() != genericOp->getNumResults())
return failure();

View File

@ -179,6 +179,27 @@ func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
// -----
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// Test: a linalg.generic whose body only yields its input block argument,
// where the input type (tensor<?x?x?xf32>) differs from the result type
// (tensor<1x2x3xf32>). The FileCheck directives that follow this function
// expect the generic op to be replaced by a tensor.cast of %arg0 to the
// result type, and the cast to be returned.
func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
-> tensor<1x2x3xf32> {
// Statically shaped tensor used as the outs operand of the generic op.
%out = linalg.init_tensor [1, 2, 3] : tensor<1x2x3xf32>
%g = linalg.generic {
indexing_maps = [#map, #map],
iterator_types = ["parallel", "parallel", "parallel"]
} ins(%arg0 : tensor<?x?x?xf32>)
outs(%out : tensor<1x2x3xf32>) {
^bb0(%arg2 : f32, %arg3 : f32):
// Yield the first block argument unchanged; %arg3 is unused.
linalg.yield %arg2 : f32
} -> (tensor<1x2x3xf32>)
return %g : tensor<1x2x3xf32>
}
// CHECK-LABEL: func @remove_no_op_mismatched_types
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<1x2x3xf32>
// CHECK: return %[[CAST]]
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index