diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp index d642a09f6a50..445d391e65b7 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -121,6 +121,15 @@ OpFoldResult AddOp::fold(ArrayRef operands) { if (getLhs() == sub.getRhs()) return sub.getLhs(); + // complex.add(a, complex.constant<0.0, 0.0>) -> a + if (auto constantOp = getRhs().getDefiningOp()) { + auto arrayAttr = constantOp.getValue(); + if (arrayAttr[0].cast().getValue().isZero() && + arrayAttr[1].cast().getValue().isZero()) { + return getLhs(); + } + } + return {}; } diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir index 21ad95af3998..d57dc693d160 100644 --- a/mlir/test/Dialect/Complex/canonicalize.mlir +++ b/mlir/test/Dialect/Complex/canonicalize.mlir @@ -123,4 +123,14 @@ func.func @complex_conj_conj() -> complex { %conj1 = complex.conj %complex1 : complex %conj2 = complex.conj %conj1 : complex return %conj2 : complex +} + +// CHECK-LABEL: func @complex_add_zero +func.func @complex_add_zero() -> complex { + %complex1 = complex.constant [1.0 : f32, 0.0 : f32] : complex + %complex2 = complex.constant [0.0 : f32, 0.0 : f32] : complex + // CHECK: %[[CPLX:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex + // CHECK-NEXT: return %[[CPLX:.*]] : complex + %add = complex.add %complex1, %complex2 : complex + return %add : complex } \ No newline at end of file