[mlir][Linalg] Add bounded recursion declaration to FMAOp -> LLVM conversion.

FMAOp -> LLVM conversion is done progressively by peeling off 1 dimension from FMAOp at each pattern iteration. Add the recursively bounded property declaration to the pattern so that the rewriter can apply it multiple times.

Without this, FMAOps with 3+D do not lower to LLVM.

Differential Revision: https://reviews.llvm.org/D113886
This commit is contained in:
Nicolas Vasilache 2021-11-15 12:19:00 +00:00
parent 9b1d90e8ac
commit ee80ffbf9a
2 changed files with 13 additions and 2 deletions

View File

@ -752,6 +752,12 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
using OpRewritePattern<FMAOp>::OpRewritePattern;
void initialize() {
// This pattern recursively unpacks one dimension at a time. The recursion
// bounded as the rank is strictly decreasing.
setHasBoundedRewriteRecursion();
}
LogicalResult matchAndRewrite(FMAOp op,
PatternRewriter &rewriter) const override {
auto vType = op.getVectorType();

View File

@ -941,10 +941,11 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
// -----
func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>) {
// CHECK-LABEL: @vector_fma
// CHECK-SAME: %[[A:.*]]: vector<8xf32>
// CHECK-SAME: %[[B:.*]]: vector<2x4xf32>
// CHECK-SAME: %[[C:.*]]: vector<1x1x1xf32>
// CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
// CHECK: "llvm.intr.fmuladd"
// CHECK-SAME: (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
@ -964,7 +965,11 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
// CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm.array<2 x vector<4xf32>>
%1 = vector.fma %b, %b, %b : vector<2x4xf32>
return %0, %1: vector<8xf32>, vector<2x4xf32>
// CHECK: %[[C0:.*]] = "llvm.intr.fmuladd"
// CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
%2 = vector.fma %c, %c, %c : vector<1x1x1xf32>
return %0, %1, %2: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>
}
// -----