[mlir][complex] Convert complex.abs to libm

Convert complex.abs to libm library

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D127476
This commit is contained in:
lewuathe 2022-07-08 09:24:34 +09:00
parent 85768677f8
commit eaba6e0b5c
2 changed files with 58 additions and 14 deletions

View File

@ -16,14 +16,43 @@
using namespace mlir;
namespace {
// Functor to resolve the function name corresponding to the given complex
// result type.
struct ComplexTypeResolver {
llvm::Optional<bool> operator()(Type type) const {
auto complexType = type.cast<ComplexType>();
auto elementType = complexType.getElementType();
if (!elementType.isa<Float32Type, Float64Type>())
return {};
return elementType.getIntOrFloatBitWidth() == 64;
}
};
// Functor to resolve the function name corresponding to the given float result
// type.
struct FloatTypeResolver {
llvm::Optional<bool> operator()(Type type) const {
auto elementType = type.cast<FloatType>();
if (!elementType.isa<Float32Type, Float64Type>())
return {};
return elementType.getIntOrFloatBitWidth() == 64;
}
};
// Pattern to convert scalar complex operations to calls to libm functions.
// Additionally the libm function signatures are declared.
template <typename Op>
// TypeResolver is a functor returning the libm function name according to the
// expected type double or float.
template <typename Op, typename TypeResolver = ComplexTypeResolver>
struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;
ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
StringRef doubleFunc, PatternBenefit benefit)
ScalarOpToLibmCall<Op, TypeResolver>(MLIRContext *context,
StringRef floatFunc,
StringRef doubleFunc,
PatternBenefit benefit)
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
doubleFunc(doubleFunc){};
@ -34,18 +63,16 @@ private:
};
} // namespace
template <typename Op>
LogicalResult
ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
PatternRewriter &rewriter) const {
template <typename Op, typename TypeResolver>
LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite(
Op op, PatternRewriter &rewriter) const {
auto module = SymbolTable::getNearestSymbolTable(op);
auto type = op.getType().template cast<ComplexType>();
Type elementType = type.getElementType();
if (!elementType.isa<Float32Type, Float64Type>())
auto isDouble = TypeResolver()(op.getType());
if (!isDouble.hasValue())
return failure();
auto name =
elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
auto name = isDouble.value() ? doubleFunc : floatFunc;
auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
SymbolTable::lookupSymbolIn(module, name));
// Forward declare function if it hasn't already been
@ -60,7 +87,8 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
}
assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
rewriter.replaceOpWithNewOp<func::CallOp>(op, name, type, op->getOperands());
rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
op->getOperands());
return success();
}
@ -79,6 +107,8 @@ void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
"csinf", "csin", benefit);
patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(),
"conjf", "conj", benefit);
patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>(
patterns.getContext(), "cabsf", "cabs", benefit);
}
namespace {
@ -96,7 +126,8 @@ void ConvertComplexToLibmPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp>();
target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp,
complex::AbsOp>();
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
}

View File

@ -9,6 +9,7 @@
// CHECK-DAG: @ccos(complex<f64>) -> complex<f64>
// CHECK-DAG: @csin(complex<f64>) -> complex<f64>
// CHECK-DAG: @conj(complex<f64>) -> complex<f64>
// CHECK-DAG: @cabs(complex<f64>) -> f64
// CHECK-LABEL: func @cpow_caller
// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
@ -80,4 +81,16 @@ func.func @conj_caller(%float: complex<f32>, %double: complex<f64>) -> (complex<
%double_result = complex.conj %double : complex<f64>
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : complex<f32>, complex<f64>
}
// CHECK-LABEL: func @cabs_caller
// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
// CHECK-SAME: %[[DOUBLE:.*]]: complex<f64>
func.func @cabs_caller(%float: complex<f32>, %double: complex<f64>) -> (f32, f64) {
// CHECK: %[[FLOAT_RESULT:.*]] = call @cabsf(%[[FLOAT]])
%float_result = complex.abs %float : complex<f32>
// CHECK: %[[DOUBLE_RESULT:.*]] = call @cabs(%[[DOUBLE]])
%double_result = complex.abs %double : complex<f64>
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
return %float_result, %double_result : f32, f64
}