[MLIR][arith] More float op folders

Fold `arith.addf %x, -0.0 -> %x`, and similarly `arith.subf %x, +0.0 -> %x`, `arith.mulf %x, 1.0 -> %x` and `arith.divf %x, 1.0 -> %x`.

Fold `arith.minf %x, %x -> %x`, `arith.minf %x, +inf -> %x` and similarly for `arith.maxf` (with `-inf` in place of `+inf`).

Reviewed By: pifon2a, mehdi_amini, bondhugula

Differential Revision: https://reviews.llvm.org/D118244
This commit is contained in:
Christian Sigg 2022-01-31 14:07:25 +01:00
parent 23091f7d50
commit f278cf9cbc
6 changed files with 283 additions and 67 deletions

View File

@ -653,6 +653,7 @@ def Arith_MaxFOp : Arith_FloatBinaryOp<"maxf", [Commutative]> {
%a = arith.maxf %b, %c : f64 %a = arith.maxf %b, %c : f64
``` ```
}]; }];
let hasFolder = 1;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -696,6 +697,7 @@ def Arith_MinFOp : Arith_FloatBinaryOp<"minf", [Commutative]> {
%a = arith.minf %b, %c : f64 %a = arith.minf %b, %c : f64
``` ```
}]; }];
let hasFolder = 1;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -91,6 +91,43 @@ struct constant_op_binder {
} }
}; };
/// The matcher that matches a constant scalar / vector splat / tensor splat
/// float operation and binds the constant float value.
struct constant_float_op_binder {
FloatAttr::ValueType *bind_value;
/// Creates a matcher instance that binds the value to bv if match succeeds.
constant_float_op_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
bool match(Operation *op) {
Attribute attr;
if (!constant_op_binder<Attribute>(&attr).match(op))
return false;
auto type = op->getResult(0).getType();
if (type.isa<FloatType>())
return attr_value_binder<FloatAttr>(bind_value).match(attr);
if (type.isa<VectorType, RankedTensorType>()) {
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
return attr_value_binder<FloatAttr>(bind_value)
.match(splatAttr.getSplatValue<Attribute>());
}
}
return false;
}
};
/// The matcher that matches a given target constant scalar / vector splat /
/// tensor splat float value that fulfills a predicate.
struct constant_float_predicate_matcher {
bool (*predicate)(const APFloat &);
bool match(Operation *op) {
APFloat value(APFloat::Bogus());
return constant_float_op_binder(&value).match(op) && predicate(value);
}
};
/// The matcher that matches a constant scalar / vector splat / tensor splat /// The matcher that matches a constant scalar / vector splat / tensor splat
/// integer operation and binds the constant integer value. /// integer operation and binds the constant integer value.
struct constant_int_op_binder { struct constant_int_op_binder {
@ -118,22 +155,13 @@ struct constant_int_op_binder {
}; };
/// The matcher that matches a given target constant scalar / vector splat / /// The matcher that matches a given target constant scalar / vector splat /
/// tensor splat integer value. /// tensor splat integer value that fulfills a predicate.
template <int64_t TargetValue> struct constant_int_predicate_matcher {
struct constant_int_value_matcher { bool (*predicate)(const APInt &);
bool match(Operation *op) {
APInt value;
return constant_int_op_binder(&value).match(op) && TargetValue == value;
}
};
/// The matcher that matches anything except the given target constant scalar /
/// vector splat / tensor splat integer value.
template <int64_t TargetNotValue>
struct constant_int_not_value_matcher {
bool match(Operation *op) { bool match(Operation *op) {
APInt value; APInt value;
return constant_int_op_binder(&value).match(op) && TargetNotValue != value; return constant_int_op_binder(&value).match(op) && predicate(value);
} }
}; };
@ -239,9 +267,59 @@ inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
return detail::constant_op_binder<AttrT>(bind_value); return detail::constant_op_binder<AttrT>(bind_value);
} }
/// Matches a float constant (scalar, vector splat or tensor splat) that is
/// zero of either sign, i.e. +0.0 or -0.0.
inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {
  bool (*isAnyZero)(const APFloat &) = [](const APFloat &constant) {
    return constant.isZero();
  };
  return {isAnyZero};
}
/// Matches a float constant (scalar, vector splat or tensor splat) that is
/// exactly +0.0.
inline detail::constant_float_predicate_matcher m_PosZeroFloat() {
  bool (*isPositiveZero)(const APFloat &) = [](const APFloat &constant) {
    return constant.isPosZero();
  };
  return {isPositiveZero};
}
/// Matches a float constant (scalar, vector splat or tensor splat) that is
/// exactly -0.0.
inline detail::constant_float_predicate_matcher m_NegZeroFloat() {
  bool (*isNegativeZero)(const APFloat &) = [](const APFloat &constant) {
    return constant.isNegZero();
  };
  return {isNegativeZero};
}
/// Matches a float constant (scalar, vector splat or tensor splat) equal to
/// one.
inline detail::constant_float_predicate_matcher m_OneFloat() {
  return detail::constant_float_predicate_matcher{[](const APFloat &constant) {
    // Build the reference value in the constant's own semantics so both sides
    // of the comparison use the same float format.
    const APFloat one(constant.getSemantics(), 1);
    return one == constant;
  }};
}
/// Matches a float constant (scalar, vector splat or tensor splat) equal to
/// +infinity.
inline detail::constant_float_predicate_matcher m_PosInfFloat() {
  return detail::constant_float_predicate_matcher{[](const APFloat &constant) {
    if (constant.isNegative())
      return false;
    return constant.isInfinity();
  }};
}
/// Matches a float constant (scalar, vector splat or tensor splat) equal to
/// -infinity.
inline detail::constant_float_predicate_matcher m_NegInfFloat() {
  return detail::constant_float_predicate_matcher{[](const APFloat &constant) {
    if (!constant.isNegative())
      return false;
    return constant.isInfinity();
  }};
}
/// Matches an integer constant (scalar, vector splat or tensor splat) equal
/// to zero.
inline detail::constant_int_predicate_matcher m_Zero() {
  bool (*isZero)(const APInt &) = [](const APInt &constant) {
    return constant == 0;
  };
  return {isZero};
}
/// Matches an integer constant (scalar, vector splat or tensor splat) whose
/// value is anything other than zero.
inline detail::constant_int_predicate_matcher m_NonZero() {
  bool (*isNonZero)(const APInt &) = [](const APInt &constant) {
    return constant != 0;
  };
  return {isNonZero};
}
/// Matches a constant scalar / vector splat / tensor splat integer one. /// Matches a constant scalar / vector splat / tensor splat integer one.
inline detail::constant_int_value_matcher<1> m_One() { inline detail::constant_int_predicate_matcher m_One() {
return detail::constant_int_value_matcher<1>(); return {[](const APInt &value) { return 1 == value; }};
} }
/// Matches the given OpClass. /// Matches the given OpClass.
@ -250,17 +328,6 @@ inline detail::op_matcher<OpClass> m_Op() {
return detail::op_matcher<OpClass>(); return detail::op_matcher<OpClass>();
} }
/// Matches a constant scalar / vector splat / tensor splat integer zero.
/// The target value 0 is carried as the template argument of
/// detail::constant_int_value_matcher, so the returned matcher is simply
/// default-constructed.
inline detail::constant_int_value_matcher<0> m_Zero() {
return detail::constant_int_value_matcher<0>();
}
/// Matches a constant scalar / vector splat / tensor splat integer that is any
/// non-zero value.
/// The excluded value 0 is carried as the template argument of
/// detail::constant_int_not_value_matcher, so the returned matcher is simply
/// default-constructed.
inline detail::constant_int_not_value_matcher<0> m_NonZero() {
return detail::constant_int_not_value_matcher<0>();
}
/// Entry point for matching a pattern over a Value. /// Entry point for matching a pattern over a Value.
template <typename Pattern> template <typename Pattern>
inline bool matchPattern(Value value, const Pattern &pattern) { inline bool matchPattern(Value value, const Pattern &pattern) {
@ -276,6 +343,13 @@ inline bool matchPattern(Operation *op, const Pattern &pattern) {
return const_cast<Pattern &>(pattern).match(op); return const_cast<Pattern &>(pattern).match(op);
} }
/// Builds a matcher for a constant that holds a scalar float or a
/// vector/tensor float splat; on a successful match the value is stored in
/// `bind_value`.
inline detail::constant_float_op_binder
m_ConstantFloat(FloatAttr::ValueType *bind_value) {
  detail::constant_float_op_binder binder(bind_value);
  return binder;
}
/// Matches a constant holding a scalar/vector/tensor integer (splat) and /// Matches a constant holding a scalar/vector/tensor integer (splat) and
/// writes the integer value to bind_value. /// writes the integer value to bind_value.
inline detail::constant_int_op_binder inline detail::constant_int_op_binder

View File

@ -194,12 +194,12 @@ OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
if (matchPattern(getRhs(), m_Zero())) if (matchPattern(getRhs(), m_Zero()))
return getLhs(); return getLhs();
// add(sub(a, b), b) -> a // addi(subi(a, b), b) -> a
if (auto sub = getLhs().getDefiningOp<SubIOp>()) if (auto sub = getLhs().getDefiningOp<SubIOp>())
if (getRhs() == sub.getRhs()) if (getRhs() == sub.getRhs())
return sub.getLhs(); return sub.getLhs();
// add(b, sub(a, b)) -> a // addi(b, subi(a, b)) -> a
if (auto sub = getRhs().getDefiningOp<SubIOp>()) if (auto sub = getRhs().getDefiningOp<SubIOp>())
if (getLhs() == sub.getRhs()) if (getLhs() == sub.getRhs())
return sub.getLhs(); return sub.getLhs();
@ -576,6 +576,14 @@ void arith::XOrIOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Folder for arith.addf.
//
// An operand is dropped only when it is the constant -0.0: in the default
// round-to-nearest mode x + (-0.0) equals x for every x, whereas x + (+0.0)
// would turn x = -0.0 into +0.0, so +0.0 is deliberately not matched.
// When neither operand is -0.0 the result comes from constFoldBinaryOp, which
// is handed the operand attributes and the APFloat addition.
OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) { OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
// addf(x, -0) -> x
if (matchPattern(getRhs(), m_NegZeroFloat()))
return getLhs();
// addf(-0, x) -> x
if (matchPattern(getLhs(), m_NegZeroFloat()))
return getRhs();
return constFoldBinaryOp<FloatAttr>( return constFoldBinaryOp<FloatAttr>(
operands, [](const APFloat &a, const APFloat &b) { return a + b; }); operands, [](const APFloat &a, const APFloat &b) { return a + b; });
} }
@ -585,10 +593,34 @@ OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Folder for arith.subf.
//
// The right-hand side is dropped only when it is the constant +0.0: in the
// default round-to-nearest mode x - (+0.0) equals x for every x, whereas
// x - (-0.0) would turn x = -0.0 into +0.0, so -0.0 is deliberately not
// matched. Subtraction is not commutative, hence only the rhs is inspected.
// Otherwise the result comes from constFoldBinaryOp, which is handed the
// operand attributes and the APFloat subtraction.
OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) { OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
// subf(x, +0) -> x
if (matchPattern(getRhs(), m_PosZeroFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>( return constFoldBinaryOp<FloatAttr>(
operands, [](const APFloat &a, const APFloat &b) { return a - b; }); operands, [](const APFloat &a, const APFloat &b) { return a - b; });
} }
//===----------------------------------------------------------------------===//
// MaxFOp
//===----------------------------------------------------------------------===//
/// Folder for arith.maxf.
///
/// Returns an existing operand when both operands are the same value or when
/// the right-hand side is a constant -inf; otherwise defers to
/// constFoldBinaryOp with llvm::maximum as the element-wise computation.
OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "maxf takes two operands");

  Value lhs = getLhs();
  Value rhs = getRhs();

  // maxf(x, x) -> x
  if (lhs == rhs)
    return rhs;

  // maxf(x, -inf) -> x
  if (matchPattern(rhs, m_NegInfFloat()))
    return lhs;

  auto computeMax = [](const APFloat &a, const APFloat &b) {
    return llvm::maximum(a, b);
  };
  return constFoldBinaryOp<FloatAttr>(operands, computeMax);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MaxSIOp // MaxSIOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -643,6 +675,26 @@ OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
}); });
} }
//===----------------------------------------------------------------------===//
// MinFOp
//===----------------------------------------------------------------------===//
/// Folder for arith.minf.
///
/// Returns an existing operand when both operands are the same value or when
/// the right-hand side is a constant +inf; otherwise defers to
/// constFoldBinaryOp with llvm::minimum as the element-wise computation.
OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "minf takes two operands");

  Value lhs = getLhs();
  Value rhs = getRhs();

  // minf(x, x) -> x
  if (lhs == rhs)
    return rhs;

  // minf(x, +inf) -> x
  if (matchPattern(rhs, m_PosInfFloat()))
    return lhs;

  auto computeMin = [](const APFloat &a, const APFloat &b) {
    return llvm::minimum(a, b);
  };
  return constFoldBinaryOp<FloatAttr>(operands, computeMin);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MinSIOp // MinSIOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -702,6 +754,15 @@ OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Folder for arith.mulf.
///
/// An operand that is the constant 1.0 (scalar or splat) is dropped and the
/// other operand is returned unchanged. Otherwise the result comes from
/// constFoldBinaryOp, which is handed the operand attributes and the APFloat
/// multiplication.
OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
  // mulf(x, 1) -> x
  if (matchPattern(getRhs(), m_OneFloat()))
    return getLhs();

  // mulf(1, x) -> x
  if (matchPattern(getLhs(), m_OneFloat()))
    return getRhs();

  return constFoldBinaryOp<FloatAttr>(
      operands, [](const APFloat &a, const APFloat &b) { return a * b; });
}
@ -711,6 +772,11 @@ OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
/// Folder for arith.divf.
///
/// A divisor that is the constant 1.0 (scalar or splat) is dropped and the
/// dividend is returned unchanged. Division is not commutative, so only the
/// right-hand side is inspected. Otherwise the result comes from
/// constFoldBinaryOp, which is handed the operand attributes and the APFloat
/// division.
OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
  // divf(x, 1) -> x
  if (matchPattern(getRhs(), m_OneFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      operands, [](const APFloat &a, const APFloat &b) { return a / b; });
}

View File

@ -619,6 +619,113 @@ func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
// ----- // -----
// Exercises the arith.minf folder on a scalar f32 argument:
//   * minf(0.0, %arg0) is not folded; the op is expected to survive with the
//     constant as its second operand.
//   * minf(%arg0, %arg0) is expected to fold to %arg0.
//   * minf(+inf, %arg0) is expected to fold to %arg0 (0x7F800000 is the f32
//     bit pattern of +infinity).
// CHECK-LABEL: @test_minf(
func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
// CHECK-NEXT: %[[X:.+]] = arith.minf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
%inf = arith.constant 0x7F800000 : f32
%0 = arith.minf %c0, %arg0 : f32
%1 = arith.minf %arg0, %arg0 : f32
%2 = arith.minf %inf, %arg0 : f32
return %0, %1, %2 : f32, f32, f32
}
// -----
// Exercises the arith.maxf folder on a scalar f32 argument:
//   * maxf(0.0, %arg0) is not folded; the op is expected to survive with the
//     constant as its second operand.
//   * maxf(%arg0, %arg0) is expected to fold to %arg0.
//   * maxf(-inf, %arg0) is expected to fold to %arg0 (0xFF800000 is the f32
//     bit pattern of -infinity).
// CHECK-LABEL: @test_maxf(
func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-DAG: %[[C0:.+]] = arith.constant
// CHECK-NEXT: %[[X:.+]] = arith.maxf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0
%c0 = arith.constant 0.0 : f32
%-inf = arith.constant 0xFF800000 : f32
%0 = arith.maxf %c0, %arg0 : f32
%1 = arith.maxf %arg0, %arg0 : f32
%2 = arith.maxf %-inf, %arg0 : f32
return %0, %1, %2 : f32, f32, f32
}
// -----
// Exercises the arith.addf folder on a scalar f32 argument:
//   * addf(%arg0, +0.0) is not folded; the op is expected to survive.
//   * addf(%arg0, -0.0) and addf(-0.0, %arg0) are expected to fold to %arg0.
//   * addf(1.0, 1.0) is expected to constant-fold to 2.0.
// CHECK-LABEL: @test_addf(
func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
// CHECK-DAG: %[[C2:.+]] = arith.constant 2.0
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
// CHECK-NEXT: %[[X:.+]] = arith.addf %arg0, %[[C0]]
// CHECK-NEXT: return %[[X]], %arg0, %arg0, %[[C2]]
%c0 = arith.constant 0.0 : f32
%c-0 = arith.constant -0.0 : f32
%c1 = arith.constant 1.0 : f32
%0 = arith.addf %arg0, %c0 : f32
%1 = arith.addf %arg0, %c-0 : f32
%2 = arith.addf %c-0, %arg0 : f32
%3 = arith.addf %c1, %c1 : f32
return %0, %1, %2, %3 : f32, f32, f32, f32
}
// -----
// Exercises the arith.subf folder on a scalar f16 argument:
//   * subf(%arg0, +0.0) is expected to fold to %arg0.
//   * subf(%arg0, -0.0) is not folded; the op is expected to survive.
//   * subf(0.0, 1.0) is expected to constant-fold to -1.0.
// CHECK-LABEL: @test_subf(
func @test_subf(%arg0 : f16) -> (f16, f16, f16) {
// CHECK-DAG: %[[C1:.+]] = arith.constant -1.0
// CHECK-DAG: %[[C0:.+]] = arith.constant -0.0
// CHECK-NEXT: %[[X:.+]] = arith.subf %arg0, %[[C0]]
// CHECK-NEXT: return %arg0, %[[X]], %[[C1]]
%c0 = arith.constant 0.0 : f16
%c-0 = arith.constant -0.0 : f16
%c1 = arith.constant 1.0 : f16
%0 = arith.subf %arg0, %c0 : f16
%1 = arith.subf %arg0, %c-0 : f16
%2 = arith.subf %c0, %c1 : f16
return %0, %1, %2 : f16, f16, f16
}
// -----
// Exercises the arith.mulf folder on a scalar f32 argument:
//   * mulf(%arg0, 1.0) and mulf(1.0, %arg0) are expected to fold to %arg0.
//   * mulf(2.0, 2.0) is expected to constant-fold to 4.0.
// CHECK-LABEL: @test_mulf(
func @test_mulf(%arg0 : f32) -> (f32, f32, f32) {
// CHECK-NEXT: %[[C4:.+]] = arith.constant 4.0
// CHECK-NEXT: return %arg0, %arg0, %[[C4]]
%c1 = arith.constant 1.0 : f32
%c2 = arith.constant 2.0 : f32
%0 = arith.mulf %arg0, %c1 : f32
%1 = arith.mulf %c1, %arg0 : f32
%2 = arith.mulf %c2, %c2 : f32
return %0, %1, %2 : f32, f32, f32
}
// -----
// Exercises the arith.divf folder on a scalar f64 argument:
//   * divf(%arg0, 1.0) is expected to fold to %arg0.
//   * divf(1.0, 2.0) is expected to constant-fold to 0.5.
// CHECK-LABEL: @test_divf(
func @test_divf(%arg0 : f64) -> (f64, f64) {
// CHECK-NEXT: %[[C5:.+]] = arith.constant 5.000000e-01
// CHECK-NEXT: return %arg0, %[[C5]]
%c1 = arith.constant 1.0 : f64
%c2 = arith.constant 2.0 : f64
%0 = arith.divf %arg0, %c1 : f64
%1 = arith.divf %c1, %c2 : f64
return %0, %1 : f64, f64
}
// -----
// Exercises arith.cmpf folding when one operand is a constant NaN
// (0x7fffffff is an f32 NaN bit pattern) and the other is a function
// argument: the ordered predicate `olt` is expected to fold to false and the
// unordered predicate `ugt` to true, whichever side the NaN is on.
// CHECK-LABEL: @test_cmpf(
func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
// CHECK-DAG: %[[T:.*]] = arith.constant true
// CHECK-DAG: %[[F:.*]] = arith.constant false
// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
%nan = arith.constant 0x7fffffff : f32
%0 = arith.cmpf olt, %nan, %arg0 : f32
%1 = arith.cmpf olt, %arg0, %nan : f32
%2 = arith.cmpf ugt, %nan, %arg0 : f32
%3 = arith.cmpf ugt, %arg0, %nan : f32
return %0, %1, %2, %3 : i1, i1, i1, i1
}
// -----
// CHECK-LABEL: @constant_FPtoUI( // CHECK-LABEL: @constant_FPtoUI(
func @constant_FPtoUI() -> i32 { func @constant_FPtoUI() -> i32 {
// CHECK: %[[C0:.+]] = arith.constant 2 : i32 // CHECK: %[[C0:.+]] = arith.constant 2 : i32
@ -678,30 +785,3 @@ func @constant_UItoFP() -> f32 {
%res = arith.sitofp %c0 : i32 to f32 %res = arith.sitofp %c0 : i32 to f32
return %res : f32 return %res : f32
} }
// -----
// Chains minf(0.0, %arg0) into maxf(0.0, %min). Neither op is expected to
// fold away; in the expected output both ops take the shared constant as
// their second operand.
// CHECK-LABEL: @constant_MinMax(
func @constant_MinMax(%arg0 : f32) -> f32 {
// CHECK: %[[const:.+]] = arith.constant
// CHECK: %[[min:.+]] = arith.minf %arg0, %[[const]] : f32
// CHECK: %[[res:.+]] = arith.maxf %[[min]], %[[const]] : f32
// CHECK: return %[[res]]
%const = arith.constant 0.0 : f32
%min = arith.minf %const, %arg0 : f32
%res = arith.maxf %const, %min : f32
return %res : f32
}
// -----
// Exercises arith.cmpf folding when one operand is a constant NaN
// (0x7fffffff is an f32 NaN bit pattern) and the other is a function
// argument: the ordered predicate `olt` is expected to fold to false and the
// unordered predicate `ugt` to true, whichever side the NaN is on.
// CHECK-LABEL: @cmpf_nan(
func @cmpf_nan(%arg0 : f32) -> (i1, i1, i1, i1) {
// CHECK-DAG: %[[T:.*]] = arith.constant true
// CHECK-DAG: %[[F:.*]] = arith.constant false
// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
%nan = arith.constant 0x7fffffff : f32
%0 = arith.cmpf olt, %nan, %arg0 : f32
%1 = arith.cmpf olt, %arg0, %nan : f32
%2 = arith.cmpf ugt, %nan, %arg0 : f32
%3 = arith.cmpf ugt, %arg0, %nan : f32
return %0, %1, %2, %3 : i1, i1, i1, i1
}

View File

@ -878,7 +878,6 @@ func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
// CHECK: vector.multi_reduction <mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.multi_reduction <mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
// CHECK: mulf {{.*}} : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%ident = arith.constant 1.0 : f32 %ident = arith.constant 1.0 : f32
%init = linalg.init_tensor [4] : tensor<4xf32> %init = linalg.init_tensor [4] : tensor<4xf32>

View File

@ -224,17 +224,14 @@ func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
// CHECK-NEXT: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]] // CHECK-NEXT: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]], // CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
// CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) { // CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[CSTF]], %[[ADDARG]] : f32 // CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LARG]], %[[ADDARG]] : f32
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LARG]], %[[MUL0]] : f32
// CHECK-NEXT: %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index // CHECK-NEXT: %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index
// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32> // CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
// CHECK-NEXT: scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32 // CHECK-NEXT: scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32
// CHECK-NEXT: } // CHECK-NEXT: }
// Epilogue: // Epilogue:
// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[CSTF]], %[[R]]#1 : f32 // CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[R]]#1 : f32
// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[MUL1]] : f32 // CHECK-NEXT: return %[[ADD2]] : f32
// CHECK-NEXT: %[[MUL2:.*]] = arith.mulf %[[CSTF]], %[[ADD2]] : f32
// CHECK-NEXT: return %[[MUL2]] : f32
func @backedge_different_stage(%A: memref<?xf32>) -> f32 { func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
%c0 = arith.constant 0 : index %c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index %c1 = arith.constant 1 : index
@ -264,15 +261,13 @@ func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]], // CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
// CHECK-SAME: %[[LARG:.*]] = %[[L0]]) -> (f32, f32) { // CHECK-SAME: %[[LARG:.*]] = %[[L0]]) -> (f32, f32) {
// CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[LARG]], %[[C]] : f32 // CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[LARG]], %[[C]] : f32
// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[CSTF]], %[[ADD0]] : f32
// CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index // CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32> // CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
// CHECK-NEXT: scf.yield %[[MUL0]], %[[L2]] : f32, f32 // CHECK-NEXT: scf.yield %[[ADD0]], %[[L2]] : f32, f32
// CHECK-NEXT: } // CHECK-NEXT: }
// Epilogue: // Epilogue:
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[R]]#1, %[[R]]#0 : f32 // CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[R]]#1, %[[R]]#0 : f32
// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[CSTF]], %[[ADD1]] : f32 // CHECK-NEXT: return %[[ADD1]] : f32
// CHECK-NEXT: return %[[MUL1]] : f32
func @backedge_same_stage(%A: memref<?xf32>) -> f32 { func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
%c0 = arith.constant 0 : index %c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index %c1 = arith.constant 1 : index