forked from OSchip/llvm-project
[mlir][Linalg] Add bounded recursion declaration to FMAOp -> LLVM conversion.
FMAOp -> LLVM conversion is done progressively by peeling off one dimension from the FMAOp at each pattern iteration. Add the bounded-recursion property declaration to the pattern so that the rewriter can apply it multiple times. Without this, FMAOps on vectors of rank 3 or higher do not lower to LLVM. Differential Revision: https://reviews.llvm.org/D113886
This commit is contained in:
parent
9b1d90e8ac
commit
ee80ffbf9a
|
@ -752,6 +752,12 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
|
|||
public:
|
||||
using OpRewritePattern<FMAOp>::OpRewritePattern;
|
||||
|
||||
// Declares the bounded-rewrite-recursion property on this pattern so that the
// rewriter can apply it multiple times.
void initialize() {
|
||||
// This pattern recursively unpacks one dimension at a time. The recursion is
|
||||
// bounded because the rank is strictly decreasing.
|
||||
setHasBoundedRewriteRecursion();
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(FMAOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto vType = op.getVectorType();
|
||||
|
|
|
@ -941,10 +941,11 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
|
|||
|
||||
// -----
|
||||
|
||||
func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
|
||||
func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>) {
|
||||
// CHECK-LABEL: @vector_fma
|
||||
// CHECK-SAME: %[[A:.*]]: vector<8xf32>
|
||||
// CHECK-SAME: %[[B:.*]]: vector<2x4xf32>
|
||||
// CHECK-SAME: %[[C:.*]]: vector<1x1x1xf32>
|
||||
// CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
|
||||
// CHECK: "llvm.intr.fmuladd"
|
||||
// CHECK-SAME: (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
|
||||
|
@ -964,7 +965,11 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
|
|||
// CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm.array<2 x vector<4xf32>>
|
||||
%1 = vector.fma %b, %b, %b : vector<2x4xf32>
|
||||
|
||||
return %0, %1: vector<8xf32>, vector<2x4xf32>
|
||||
// CHECK: %[[C0:.*]] = "llvm.intr.fmuladd"
|
||||
// CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
|
||||
%2 = vector.fma %c, %c, %c : vector<1x1x1xf32>
|
||||
|
||||
return %0, %1, %2: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
Loading…
Reference in New Issue