forked from OSchip/llvm-project
[mlir][Complex]: Add lowerings for AddOp and SubOp from Complex dialect to
Standard. Differential Revision: https://reviews.llvm.org/D106429
This commit is contained in:
parent
80e0bd1496
commit
fb978f092c
|
@ -79,6 +79,35 @@ struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
|
|||
}
|
||||
};
|
||||
|
||||
// Default conversion which applies the BinaryStandardOp separately on the real
|
||||
// and imaginary parts. Can for example be used for complex::AddOp and
|
||||
// complex::SubOp.
|
||||
template <typename BinaryComplexOp, typename BinaryStandardOp>
|
||||
struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
|
||||
using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(BinaryComplexOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
typename BinaryComplexOp::Adaptor transformed(operands);
|
||||
auto type = transformed.lhs().getType().template cast<ComplexType>();
|
||||
auto elementType = type.getElementType().template cast<FloatType>();
|
||||
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
||||
|
||||
Value realLhs = b.create<complex::ReOp>(elementType, transformed.lhs());
|
||||
Value realRhs = b.create<complex::ReOp>(elementType, transformed.rhs());
|
||||
Value resultReal =
|
||||
b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
|
||||
Value imagLhs = b.create<complex::ImOp>(elementType, transformed.lhs());
|
||||
Value imagRhs = b.create<complex::ImOp>(elementType, transformed.rhs());
|
||||
Value resultImag =
|
||||
b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
|
||||
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
||||
resultImag);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
|
||||
using OpConversionPattern<complex::DivOp>::OpConversionPattern;
|
||||
|
||||
|
@ -554,6 +583,8 @@ void mlir::populateComplexToStandardConversionPatterns(
|
|||
AbsOpConversion,
|
||||
ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
|
||||
ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
|
||||
BinaryComplexOpConversion<complex::AddOp, AddFOp>,
|
||||
BinaryComplexOpConversion<complex::SubOp, SubFOp>,
|
||||
DivOpConversion,
|
||||
ExpOpConversion,
|
||||
LogOpConversion,
|
||||
|
@ -578,12 +609,8 @@ void ConvertComplexToStandardPass::runOnFunction() {
|
|||
populateComplexToStandardConversionPatterns(patterns);
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<StandardOpsDialect, math::MathDialect,
|
||||
complex::ComplexDialect>();
|
||||
target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
|
||||
complex::ExpOp, complex::LogOp, complex::Log1pOp,
|
||||
complex::MulOp, complex::NegOp, complex::NotEqualOp,
|
||||
complex::SignOp>();
|
||||
target.addLegalDialect<StandardOpsDialect, math::MathDialect>();
|
||||
target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
|
||||
if (failed(applyPartialConversion(function, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -14,6 +14,21 @@ func @complex_abs(%arg: complex<f32>) -> f32 {
|
|||
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
|
||||
// CHECK: return %[[NORM]] : f32
|
||||
|
||||
// CHECK-LABEL: func @complex_add
|
||||
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
|
||||
func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
|
||||
%add = complex.add %lhs, %rhs: complex<f32>
|
||||
return %add : complex<f32>
|
||||
}
|
||||
// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
|
||||
// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
|
||||
// CHECK: %[[RESULT_REAL:.*]] = addf %[[REAL_LHS]], %[[REAL_RHS]] : f32
|
||||
// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
|
||||
// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
|
||||
// CHECK: %[[RESULT_IMAG:.*]] = addf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
|
||||
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// CHECK-LABEL: func @complex_div
|
||||
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
|
||||
func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
|
||||
|
@ -366,3 +381,18 @@ func @complex_sign(%arg: complex<f32>) -> complex<f32> {
|
|||
// CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex<f32>
|
||||
// CHECK: %[[RESULT:.*]] = select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
||||
// CHECK-LABEL: func @complex_sub
|
||||
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
|
||||
func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
|
||||
%sub = complex.sub %lhs, %rhs: complex<f32>
|
||||
return %sub : complex<f32>
|
||||
}
|
||||
// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
|
||||
// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
|
||||
// CHECK: %[[RESULT_REAL:.*]] = subf %[[REAL_LHS]], %[[REAL_RHS]] : f32
|
||||
// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
|
||||
// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
|
||||
// CHECK: %[[RESULT_IMAG:.*]] = subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
|
||||
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
|
||||
// CHECK: return %[[RESULT]] : complex<f32>
|
||||
|
|
Loading…
Reference in New Issue