forked from OSchip/llvm-project
[mlir][Linalg] Disable const -> linalg.generic when fused op is illegal.
Fusing a constant with a linalg.generic operation can result in the fused operation being illegal since the loop bound computation fails. Avoid such fusions. Differential Revision: https://reviews.llvm.org/D100272
This commit is contained in:
parent
15689f3af0
commit
b0fc712b14
|
@ -1103,6 +1103,12 @@ public:
|
|||
linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
|
||||
fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
|
||||
|
||||
// Check if the operation shapes to loops map is computable.
|
||||
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
linalgOp, "fused op loop bound computation failed");
|
||||
}
|
||||
|
||||
// The operands list is same as the linalgOp with the argument for
|
||||
// constant index dropped.
|
||||
SmallVector<Value, 4> fusedOperands(linalgOp.getInputs());
|
||||
|
|
|
@ -678,3 +678,26 @@ func @index_op(%arg0: tensor<1x8xindex>, %arg1: tensor<1x8xindex>) -> tensor<1x8
|
|||
} -> tensor<1x8xindex>
|
||||
return %1 : tensor<1x8xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @no_fuse_constant_with_reduction
|
||||
func @no_fuse_constant_with_reduction() -> tensor<3xf32>
|
||||
{
|
||||
// CHECK: %[[CONST:.+]] = constant {{.+}} : tensor<3x2xf32>
|
||||
// CHECK: %[[RESULT:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[CONST]] : tensor<3x2xf32>)
|
||||
// CHECK: return %[[RESULT]]
|
||||
%three = constant dense<3.0> : tensor<3x2xf32>
|
||||
%init = linalg.init_tensor [3] : tensor<3xf32>
|
||||
%result = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||
affine_map<(d0, d1) -> (d0)>],
|
||||
iterator_types = ["parallel", "reduction"]}
|
||||
ins(%three : tensor<3x2xf32>) outs(%init : tensor<3xf32>) {
|
||||
^bb0(%arg0 : f32, %arg1 : f32):
|
||||
%0 = addf %arg0, %arg1 : f32
|
||||
linalg.yield %0 : f32
|
||||
} -> tensor<3xf32>
|
||||
return %result : tensor<3xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue