forked from OSchip/llvm-project
[MLIR][arith] More float op folders
Fold `arith.fadd %x, -0.0 -> %x` and similarly for `fsub`, `fmul`, `fdiv`. Fold `arith.fmin %x, %x -> %x`, `arith.fmin %x, +inf -> %x` and similarly for `fmax`. Reviewed By: pifon2a, mehdi_amini, bondhugula Differential Revision: https://reviews.llvm.org/D118244
This commit is contained in:
parent
23091f7d50
commit
f278cf9cbc
|
@ -653,6 +653,7 @@ def Arith_MaxFOp : Arith_FloatBinaryOp<"maxf", [Commutative]> {
|
|||
%a = arith.maxf %b, %c : f64
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -696,6 +697,7 @@ def Arith_MinFOp : Arith_FloatBinaryOp<"minf", [Commutative]> {
|
|||
%a = arith.minf %b, %c : f64
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -91,6 +91,43 @@ struct constant_op_binder {
|
|||
}
|
||||
};
|
||||
|
||||
/// The matcher that matches a constant scalar / vector splat / tensor splat
|
||||
/// float operation and binds the constant float value.
|
||||
struct constant_float_op_binder {
|
||||
FloatAttr::ValueType *bind_value;
|
||||
|
||||
/// Creates a matcher instance that binds the value to bv if match succeeds.
|
||||
constant_float_op_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
|
||||
|
||||
bool match(Operation *op) {
|
||||
Attribute attr;
|
||||
if (!constant_op_binder<Attribute>(&attr).match(op))
|
||||
return false;
|
||||
auto type = op->getResult(0).getType();
|
||||
|
||||
if (type.isa<FloatType>())
|
||||
return attr_value_binder<FloatAttr>(bind_value).match(attr);
|
||||
if (type.isa<VectorType, RankedTensorType>()) {
|
||||
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
|
||||
return attr_value_binder<FloatAttr>(bind_value)
|
||||
.match(splatAttr.getSplatValue<Attribute>());
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/// The matcher that matches a given target constant scalar / vector splat /
|
||||
/// tensor splat float value that fulfills a predicate.
|
||||
struct constant_float_predicate_matcher {
|
||||
bool (*predicate)(const APFloat &);
|
||||
|
||||
bool match(Operation *op) {
|
||||
APFloat value(APFloat::Bogus());
|
||||
return constant_float_op_binder(&value).match(op) && predicate(value);
|
||||
}
|
||||
};
|
||||
|
||||
/// The matcher that matches a constant scalar / vector splat / tensor splat
|
||||
/// integer operation and binds the constant integer value.
|
||||
struct constant_int_op_binder {
|
||||
|
@ -118,22 +155,13 @@ struct constant_int_op_binder {
|
|||
};
|
||||
|
||||
/// The matcher that matches a given target constant scalar / vector splat /
|
||||
/// tensor splat integer value.
|
||||
template <int64_t TargetValue>
|
||||
struct constant_int_value_matcher {
|
||||
bool match(Operation *op) {
|
||||
APInt value;
|
||||
return constant_int_op_binder(&value).match(op) && TargetValue == value;
|
||||
}
|
||||
};
|
||||
/// tensor splat integer value that fulfills a predicate.
|
||||
struct constant_int_predicate_matcher {
|
||||
bool (*predicate)(const APInt &);
|
||||
|
||||
/// The matcher that matches anything except the given target constant scalar /
|
||||
/// vector splat / tensor splat integer value.
|
||||
template <int64_t TargetNotValue>
|
||||
struct constant_int_not_value_matcher {
|
||||
bool match(Operation *op) {
|
||||
APInt value;
|
||||
return constant_int_op_binder(&value).match(op) && TargetNotValue != value;
|
||||
return constant_int_op_binder(&value).match(op) && predicate(value);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -239,9 +267,59 @@ inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
|
|||
return detail::constant_op_binder<AttrT>(bind_value);
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat float (both positive
|
||||
/// and negative) zero.
|
||||
inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {
|
||||
return {[](const APFloat &value) { return value.isZero(); }};
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat float positive zero.
|
||||
inline detail::constant_float_predicate_matcher m_PosZeroFloat() {
|
||||
return {[](const APFloat &value) { return value.isPosZero(); }};
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat float negative zero.
|
||||
inline detail::constant_float_predicate_matcher m_NegZeroFloat() {
|
||||
return {[](const APFloat &value) { return value.isNegZero(); }};
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat float ones.
|
||||
inline detail::constant_float_predicate_matcher m_OneFloat() {
|
||||
return {[](const APFloat &value) {
|
||||
return APFloat(value.getSemantics(), 1) == value;
|
||||
}};
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat float positive
|
||||
/// infinity.
|
||||
inline detail::constant_float_predicate_matcher m_PosInfFloat() {
|
||||
return {[](const APFloat &value) {
|
||||
return !value.isNegative() && value.isInfinity();
|
||||
}};
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat float negative
|
||||
/// infinity.
|
||||
inline detail::constant_float_predicate_matcher m_NegInfFloat() {
|
||||
return {[](const APFloat &value) {
|
||||
return value.isNegative() && value.isInfinity();
|
||||
}};
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat integer zero.
|
||||
inline detail::constant_int_predicate_matcher m_Zero() {
|
||||
return {[](const APInt &value) { return 0 == value; }};
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat integer that is any
|
||||
/// non-zero value.
|
||||
inline detail::constant_int_predicate_matcher m_NonZero() {
|
||||
return {[](const APInt &value) { return 0 != value; }};
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat integer one.
|
||||
inline detail::constant_int_value_matcher<1> m_One() {
|
||||
return detail::constant_int_value_matcher<1>();
|
||||
inline detail::constant_int_predicate_matcher m_One() {
|
||||
return {[](const APInt &value) { return 1 == value; }};
|
||||
}
|
||||
|
||||
/// Matches the given OpClass.
|
||||
|
@ -250,17 +328,6 @@ inline detail::op_matcher<OpClass> m_Op() {
|
|||
return detail::op_matcher<OpClass>();
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat integer zero.
|
||||
inline detail::constant_int_value_matcher<0> m_Zero() {
|
||||
return detail::constant_int_value_matcher<0>();
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat integer that is any
|
||||
/// non-zero value.
|
||||
inline detail::constant_int_not_value_matcher<0> m_NonZero() {
|
||||
return detail::constant_int_not_value_matcher<0>();
|
||||
}
|
||||
|
||||
/// Entry point for matching a pattern over a Value.
|
||||
template <typename Pattern>
|
||||
inline bool matchPattern(Value value, const Pattern &pattern) {
|
||||
|
@ -276,6 +343,13 @@ inline bool matchPattern(Operation *op, const Pattern &pattern) {
|
|||
return const_cast<Pattern &>(pattern).match(op);
|
||||
}
|
||||
|
||||
/// Matches a constant holding a scalar/vector/tensor float (splat) and
|
||||
/// writes the float value to bind_value.
|
||||
inline detail::constant_float_op_binder
|
||||
m_ConstantFloat(FloatAttr::ValueType *bind_value) {
|
||||
return detail::constant_float_op_binder(bind_value);
|
||||
}
|
||||
|
||||
/// Matches a constant holding a scalar/vector/tensor integer (splat) and
|
||||
/// writes the integer value to bind_value.
|
||||
inline detail::constant_int_op_binder
|
||||
|
|
|
@ -194,12 +194,12 @@ OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
|
|||
if (matchPattern(getRhs(), m_Zero()))
|
||||
return getLhs();
|
||||
|
||||
// add(sub(a, b), b) -> a
|
||||
// addi(subi(a, b), b) -> a
|
||||
if (auto sub = getLhs().getDefiningOp<SubIOp>())
|
||||
if (getRhs() == sub.getRhs())
|
||||
return sub.getLhs();
|
||||
|
||||
// add(b, sub(a, b)) -> a
|
||||
// addi(b, subi(a, b)) -> a
|
||||
if (auto sub = getRhs().getDefiningOp<SubIOp>())
|
||||
if (getLhs() == sub.getRhs())
|
||||
return sub.getLhs();
|
||||
|
@ -576,6 +576,14 @@ void arith::XOrIOp::getCanonicalizationPatterns(
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
|
||||
// addf(x, -0) -> x
|
||||
if (matchPattern(getRhs(), m_NegZeroFloat()))
|
||||
return getLhs();
|
||||
|
||||
// addf(-0, x) -> x
|
||||
if (matchPattern(getLhs(), m_NegZeroFloat()))
|
||||
return getRhs();
|
||||
|
||||
return constFoldBinaryOp<FloatAttr>(
|
||||
operands, [](const APFloat &a, const APFloat &b) { return a + b; });
|
||||
}
|
||||
|
@ -585,10 +593,34 @@ OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
|
||||
// subf(x, +0) -> x
|
||||
if (matchPattern(getRhs(), m_PosZeroFloat()))
|
||||
return getLhs();
|
||||
|
||||
return constFoldBinaryOp<FloatAttr>(
|
||||
operands, [](const APFloat &a, const APFloat &b) { return a - b; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaxFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "maxf takes two operands");
|
||||
|
||||
// maxf(x,x) -> x
|
||||
if (getLhs() == getRhs())
|
||||
return getRhs();
|
||||
|
||||
// maxf(x, -inf) -> x
|
||||
if (matchPattern(getRhs(), m_NegInfFloat()))
|
||||
return getLhs();
|
||||
|
||||
return constFoldBinaryOp<FloatAttr>(
|
||||
operands,
|
||||
[](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaxSIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -643,6 +675,26 @@ OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MinFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "minf takes two operands");
|
||||
|
||||
// minf(x,x) -> x
|
||||
if (getLhs() == getRhs())
|
||||
return getRhs();
|
||||
|
||||
// minf(x, +inf) -> x
|
||||
if (matchPattern(getRhs(), m_PosInfFloat()))
|
||||
return getLhs();
|
||||
|
||||
return constFoldBinaryOp<FloatAttr>(
|
||||
operands,
|
||||
[](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MinSIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -702,6 +754,15 @@ OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
|
||||
APFloat floatValue(0.0f), inverseValue(0.0f);
|
||||
// mulf(x, 1) -> x
|
||||
if (matchPattern(getRhs(), m_OneFloat()))
|
||||
return getLhs();
|
||||
|
||||
// mulf(1, x) -> x
|
||||
if (matchPattern(getLhs(), m_OneFloat()))
|
||||
return getRhs();
|
||||
|
||||
return constFoldBinaryOp<FloatAttr>(
|
||||
operands, [](const APFloat &a, const APFloat &b) { return a * b; });
|
||||
}
|
||||
|
@ -711,6 +772,11 @@ OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
|
||||
APFloat floatValue(0.0f), inverseValue(0.0f);
|
||||
// divf(x, 1) -> x
|
||||
if (matchPattern(getRhs(), m_OneFloat()))
|
||||
return getLhs();
|
||||
|
||||
return constFoldBinaryOp<FloatAttr>(
|
||||
operands, [](const APFloat &a, const APFloat &b) { return a / b; });
|
||||
}
|
||||
|
|
|
@ -619,6 +619,113 @@ func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_minf(
|
||||
func @test_minf(%arg0 : f32) -> (f32, f32, f32) {
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
|
||||
// CHECK-NEXT: %[[X:.+]] = arith.minf %arg0, %[[C0]]
|
||||
// CHECK-NEXT: return %[[X]], %arg0, %arg0
|
||||
%c0 = arith.constant 0.0 : f32
|
||||
%inf = arith.constant 0x7F800000 : f32
|
||||
%0 = arith.minf %c0, %arg0 : f32
|
||||
%1 = arith.minf %arg0, %arg0 : f32
|
||||
%2 = arith.minf %inf, %arg0 : f32
|
||||
return %0, %1, %2 : f32, f32, f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_maxf(
|
||||
func @test_maxf(%arg0 : f32) -> (f32, f32, f32) {
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant
|
||||
// CHECK-NEXT: %[[X:.+]] = arith.maxf %arg0, %[[C0]]
|
||||
// CHECK-NEXT: return %[[X]], %arg0, %arg0
|
||||
%c0 = arith.constant 0.0 : f32
|
||||
%-inf = arith.constant 0xFF800000 : f32
|
||||
%0 = arith.maxf %c0, %arg0 : f32
|
||||
%1 = arith.maxf %arg0, %arg0 : f32
|
||||
%2 = arith.maxf %-inf, %arg0 : f32
|
||||
return %0, %1, %2 : f32, f32, f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_addf(
|
||||
func @test_addf(%arg0 : f32) -> (f32, f32, f32, f32) {
|
||||
// CHECK-DAG: %[[C2:.+]] = arith.constant 2.0
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.0
|
||||
// CHECK-NEXT: %[[X:.+]] = arith.addf %arg0, %[[C0]]
|
||||
// CHECK-NEXT: return %[[X]], %arg0, %arg0, %[[C2]]
|
||||
%c0 = arith.constant 0.0 : f32
|
||||
%c-0 = arith.constant -0.0 : f32
|
||||
%c1 = arith.constant 1.0 : f32
|
||||
%0 = arith.addf %arg0, %c0 : f32
|
||||
%1 = arith.addf %arg0, %c-0 : f32
|
||||
%2 = arith.addf %c-0, %arg0 : f32
|
||||
%3 = arith.addf %c1, %c1 : f32
|
||||
return %0, %1, %2, %3 : f32, f32, f32, f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_subf(
|
||||
func @test_subf(%arg0 : f16) -> (f16, f16, f16) {
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant -1.0
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant -0.0
|
||||
// CHECK-NEXT: %[[X:.+]] = arith.subf %arg0, %[[C0]]
|
||||
// CHECK-NEXT: return %arg0, %[[X]], %[[C1]]
|
||||
%c0 = arith.constant 0.0 : f16
|
||||
%c-0 = arith.constant -0.0 : f16
|
||||
%c1 = arith.constant 1.0 : f16
|
||||
%0 = arith.subf %arg0, %c0 : f16
|
||||
%1 = arith.subf %arg0, %c-0 : f16
|
||||
%2 = arith.subf %c0, %c1 : f16
|
||||
return %0, %1, %2 : f16, f16, f16
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_mulf(
|
||||
func @test_mulf(%arg0 : f32) -> (f32, f32, f32) {
|
||||
// CHECK-NEXT: %[[C4:.+]] = arith.constant 4.0
|
||||
// CHECK-NEXT: return %arg0, %arg0, %[[C4]]
|
||||
%c1 = arith.constant 1.0 : f32
|
||||
%c2 = arith.constant 2.0 : f32
|
||||
%0 = arith.mulf %arg0, %c1 : f32
|
||||
%1 = arith.mulf %c1, %arg0 : f32
|
||||
%2 = arith.mulf %c2, %c2 : f32
|
||||
return %0, %1, %2 : f32, f32, f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_divf(
|
||||
func @test_divf(%arg0 : f64) -> (f64, f64) {
|
||||
// CHECK-NEXT: %[[C5:.+]] = arith.constant 5.000000e-01
|
||||
// CHECK-NEXT: return %arg0, %[[C5]]
|
||||
%c1 = arith.constant 1.0 : f64
|
||||
%c2 = arith.constant 2.0 : f64
|
||||
%0 = arith.divf %arg0, %c1 : f64
|
||||
%1 = arith.divf %c1, %c2 : f64
|
||||
return %0, %1 : f64, f64
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_cmpf(
|
||||
func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
|
||||
// CHECK-DAG: %[[T:.*]] = arith.constant true
|
||||
// CHECK-DAG: %[[F:.*]] = arith.constant false
|
||||
// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
|
||||
%nan = arith.constant 0x7fffffff : f32
|
||||
%0 = arith.cmpf olt, %nan, %arg0 : f32
|
||||
%1 = arith.cmpf olt, %arg0, %nan : f32
|
||||
%2 = arith.cmpf ugt, %nan, %arg0 : f32
|
||||
%3 = arith.cmpf ugt, %arg0, %nan : f32
|
||||
return %0, %1, %2, %3 : i1, i1, i1, i1
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @constant_FPtoUI(
|
||||
func @constant_FPtoUI() -> i32 {
|
||||
// CHECK: %[[C0:.+]] = arith.constant 2 : i32
|
||||
|
@ -678,30 +785,3 @@ func @constant_UItoFP() -> f32 {
|
|||
%res = arith.sitofp %c0 : i32 to f32
|
||||
return %res : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @constant_MinMax(
|
||||
func @constant_MinMax(%arg0 : f32) -> f32 {
|
||||
// CHECK: %[[const:.+]] = arith.constant
|
||||
// CHECK: %[[min:.+]] = arith.minf %arg0, %[[const]] : f32
|
||||
// CHECK: %[[res:.+]] = arith.maxf %[[min]], %[[const]] : f32
|
||||
// CHECK: return %[[res]]
|
||||
%const = arith.constant 0.0 : f32
|
||||
%min = arith.minf %const, %arg0 : f32
|
||||
%res = arith.maxf %const, %min : f32
|
||||
return %res : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: @cmpf_nan(
|
||||
func @cmpf_nan(%arg0 : f32) -> (i1, i1, i1, i1) {
|
||||
// CHECK-DAG: %[[T:.*]] = arith.constant true
|
||||
// CHECK-DAG: %[[F:.*]] = arith.constant false
|
||||
// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
|
||||
%nan = arith.constant 0x7fffffff : f32
|
||||
%0 = arith.cmpf olt, %nan, %arg0 : f32
|
||||
%1 = arith.cmpf olt, %arg0, %nan : f32
|
||||
%2 = arith.cmpf ugt, %nan, %arg0 : f32
|
||||
%3 = arith.cmpf ugt, %arg0, %nan : f32
|
||||
return %0, %1, %2, %3 : i1, i1, i1, i1
|
||||
}
|
||||
|
|
|
@ -878,7 +878,6 @@ func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
|
|||
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
|
||||
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
|
||||
// CHECK: vector.multi_reduction <mul>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
|
||||
// CHECK: mulf {{.*}} : vector<4xf32>
|
||||
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
|
||||
%ident = arith.constant 1.0 : f32
|
||||
%init = linalg.init_tensor [4] : tensor<4xf32>
|
||||
|
|
|
@ -224,17 +224,14 @@ func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
|
|||
// CHECK-NEXT: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
|
||||
// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
|
||||
// CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
|
||||
// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[CSTF]], %[[ADDARG]] : f32
|
||||
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LARG]], %[[MUL0]] : f32
|
||||
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LARG]], %[[ADDARG]] : f32
|
||||
// CHECK-NEXT: %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index
|
||||
// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
|
||||
// CHECK-NEXT: scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32
|
||||
// CHECK-NEXT: scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32
|
||||
// CHECK-NEXT: }
|
||||
// Epilogue:
|
||||
// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[CSTF]], %[[R]]#1 : f32
|
||||
// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[MUL1]] : f32
|
||||
// CHECK-NEXT: %[[MUL2:.*]] = arith.mulf %[[CSTF]], %[[ADD2]] : f32
|
||||
// CHECK-NEXT: return %[[MUL2]] : f32
|
||||
// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[R]]#1 : f32
|
||||
// CHECK-NEXT: return %[[ADD2]] : f32
|
||||
func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
|
@ -264,15 +261,13 @@ func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
|
|||
// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
|
||||
// CHECK-SAME: %[[LARG:.*]] = %[[L0]]) -> (f32, f32) {
|
||||
// CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[LARG]], %[[C]] : f32
|
||||
// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[CSTF]], %[[ADD0]] : f32
|
||||
// CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
|
||||
// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
|
||||
// CHECK-NEXT: scf.yield %[[MUL0]], %[[L2]] : f32, f32
|
||||
// CHECK-NEXT: scf.yield %[[ADD0]], %[[L2]] : f32, f32
|
||||
// CHECK-NEXT: }
|
||||
// Epilogue:
|
||||
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[R]]#1, %[[R]]#0 : f32
|
||||
// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[CSTF]], %[[ADD1]] : f32
|
||||
// CHECK-NEXT: return %[[MUL1]] : f32
|
||||
// CHECK-NEXT: return %[[ADD1]] : f32
|
||||
func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
|
|
Loading…
Reference in New Issue