forked from OSchip/llvm-project
[mlir][linalg] Insert a cast for identity linalg.generics when the types don't match
This can happen when the result type has static dimensions where the corresponding input dimensions are dynamic (e.g. tensor<?x?x?xf32> input, tensor<1x2x3xf32> result), so the input value cannot directly replace the result. Differential Revision: https://reviews.llvm.org/D117498
This commit is contained in:
parent
7294d7dae7
commit
f100bedb03
|
@ -857,12 +857,19 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
|
|||
// Get the argument number of the returned values. That is the operand
|
||||
// number to use for replacing uses of this operation.
|
||||
SmallVector<Value> returnedArgs;
|
||||
for (Value yieldVal : yieldOp.values()) {
|
||||
auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
|
||||
for (const auto &yieldVal : llvm::enumerate(yieldOp.values())) {
|
||||
auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
|
||||
if (!yieldArg || yieldArg.getOwner() != &body)
|
||||
return failure();
|
||||
unsigned argumentNumber = yieldArg.getArgNumber();
|
||||
returnedArgs.push_back(genericOp->getOperand(argumentNumber));
|
||||
Value returnedArg = genericOp->getOperand(argumentNumber);
|
||||
Type resultType = genericOp->getResult(yieldVal.index()).getType();
|
||||
// The input can have a different type than the result, e.g. a dynamic
|
||||
// input dimension can be turned into a static output dimension.
|
||||
if (returnedArg.getType() != resultType)
|
||||
returnedArg = rewriter.create<tensor::CastOp>(genericOp.getLoc(),
|
||||
resultType, returnedArg);
|
||||
returnedArgs.push_back(returnedArg);
|
||||
}
|
||||
if (returnedArgs.size() != genericOp->getNumResults())
|
||||
return failure();
|
||||
|
|
|
@ -179,6 +179,27 @@ func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
|
|||
|
||||
// -----
|
||||
|
||||
// Identity indexing map: the generic op just forwards its input element-wise.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// Identity linalg.generic whose input type (tensor<?x?x?xf32>) differs from
// its result type (tensor<1x2x3xf32>): the dynamic input dimensions become
// static on the result. Erasing the op must therefore insert a tensor.cast
// (see the CHECK lines below) instead of replacing the result with %arg0
// directly, which would be type-incorrect.
func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
|
||||
-> tensor<1x2x3xf32> {
|
||||
%out = linalg.init_tensor [1, 2, 3] : tensor<1x2x3xf32>
|
||||
%g = linalg.generic {
|
||||
indexing_maps = [#map, #map],
|
||||
iterator_types = ["parallel", "parallel", "parallel"]
|
||||
} ins(%arg0 : tensor<?x?x?xf32>)
|
||||
outs(%out : tensor<1x2x3xf32>) {
|
||||
^bb0(%arg2 : f32, %arg3 : f32):
|
||||
// Yields the input argument unchanged — this is what makes the op an
// erasable identity for the EraseIdentityGenericOp pattern.
linalg.yield %arg2 : f32
|
||||
} -> (tensor<1x2x3xf32>)
|
||||
return %g : tensor<1x2x3xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @remove_no_op_mismatched_types
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
|
||||
// CHECK: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<1x2x3xf32>
|
||||
// CHECK: return %[[CAST]]
|
||||
|
||||
// -----
|
||||
|
||||
#map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
|
|
Loading…
Reference in New Issue