forked from OSchip/llvm-project
[mlir][complex] Convert complex.abs to libm
Convert complex.abs to libm library Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D127476
This commit is contained in:
parent
85768677f8
commit
eaba6e0b5c
|
@ -16,14 +16,43 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
// Functor to resolve the function name corresponding to the given complex
|
||||
// result type.
|
||||
struct ComplexTypeResolver {
|
||||
llvm::Optional<bool> operator()(Type type) const {
|
||||
auto complexType = type.cast<ComplexType>();
|
||||
auto elementType = complexType.getElementType();
|
||||
if (!elementType.isa<Float32Type, Float64Type>())
|
||||
return {};
|
||||
|
||||
return elementType.getIntOrFloatBitWidth() == 64;
|
||||
}
|
||||
};
|
||||
|
||||
// Functor to resolve the function name corresponding to the given float result
|
||||
// type.
|
||||
struct FloatTypeResolver {
|
||||
llvm::Optional<bool> operator()(Type type) const {
|
||||
auto elementType = type.cast<FloatType>();
|
||||
if (!elementType.isa<Float32Type, Float64Type>())
|
||||
return {};
|
||||
|
||||
return elementType.getIntOrFloatBitWidth() == 64;
|
||||
}
|
||||
};
|
||||
|
||||
// Pattern to convert scalar complex operations to calls to libm functions.
|
||||
// Additionally the libm function signatures are declared.
|
||||
template <typename Op>
|
||||
// TypeResolver is a functor returning the libm function name according to the
|
||||
// expected type double or float.
|
||||
template <typename Op, typename TypeResolver = ComplexTypeResolver>
|
||||
struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
|
||||
public:
|
||||
using OpRewritePattern<Op>::OpRewritePattern;
|
||||
ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
|
||||
StringRef doubleFunc, PatternBenefit benefit)
|
||||
ScalarOpToLibmCall<Op, TypeResolver>(MLIRContext *context,
|
||||
StringRef floatFunc,
|
||||
StringRef doubleFunc,
|
||||
PatternBenefit benefit)
|
||||
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
|
||||
doubleFunc(doubleFunc){};
|
||||
|
||||
|
@ -34,18 +63,16 @@ private:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
template <typename Op>
|
||||
LogicalResult
|
||||
ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
|
||||
PatternRewriter &rewriter) const {
|
||||
template <typename Op, typename TypeResolver>
|
||||
LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite(
|
||||
Op op, PatternRewriter &rewriter) const {
|
||||
auto module = SymbolTable::getNearestSymbolTable(op);
|
||||
auto type = op.getType().template cast<ComplexType>();
|
||||
Type elementType = type.getElementType();
|
||||
if (!elementType.isa<Float32Type, Float64Type>())
|
||||
auto isDouble = TypeResolver()(op.getType());
|
||||
if (!isDouble.hasValue())
|
||||
return failure();
|
||||
|
||||
auto name =
|
||||
elementType.getIntOrFloatBitWidth() == 64 ? doubleFunc : floatFunc;
|
||||
auto name = isDouble.value() ? doubleFunc : floatFunc;
|
||||
|
||||
auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
|
||||
SymbolTable::lookupSymbolIn(module, name));
|
||||
// Forward declare function if it hasn't already been
|
||||
|
@ -60,7 +87,8 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
|
|||
}
|
||||
assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
|
||||
|
||||
rewriter.replaceOpWithNewOp<func::CallOp>(op, name, type, op->getOperands());
|
||||
rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
|
||||
op->getOperands());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -79,6 +107,8 @@ void mlir::populateComplexToLibmConversionPatterns(RewritePatternSet &patterns,
|
|||
"csinf", "csin", benefit);
|
||||
patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(),
|
||||
"conjf", "conj", benefit);
|
||||
patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>(
|
||||
patterns.getContext(), "cabsf", "cabs", benefit);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -96,7 +126,8 @@ void ConvertComplexToLibmPass::runOnOperation() {
|
|||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<func::FuncDialect>();
|
||||
target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp>();
|
||||
target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp,
|
||||
complex::AbsOp>();
|
||||
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
// CHECK-DAG: @ccos(complex<f64>) -> complex<f64>
|
||||
// CHECK-DAG: @csin(complex<f64>) -> complex<f64>
|
||||
// CHECK-DAG: @conj(complex<f64>) -> complex<f64>
|
||||
// CHECK-DAG: @cabs(complex<f64>) -> f64
|
||||
|
||||
// CHECK-LABEL: func @cpow_caller
|
||||
// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
|
||||
|
@ -80,4 +81,16 @@ func.func @conj_caller(%float: complex<f32>, %double: complex<f64>) -> (complex<
|
|||
%double_result = complex.conj %double : complex<f64>
|
||||
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
|
||||
return %float_result, %double_result : complex<f32>, complex<f64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cabs_caller
|
||||
// CHECK-SAME: %[[FLOAT:.*]]: complex<f32>
|
||||
// CHECK-SAME: %[[DOUBLE:.*]]: complex<f64>
|
||||
func.func @cabs_caller(%float: complex<f32>, %double: complex<f64>) -> (f32, f64) {
|
||||
// CHECK: %[[FLOAT_RESULT:.*]] = call @cabsf(%[[FLOAT]])
|
||||
%float_result = complex.abs %float : complex<f32>
|
||||
// CHECK: %[[DOUBLE_RESULT:.*]] = call @cabs(%[[DOUBLE]])
|
||||
%double_result = complex.abs %double : complex<f64>
|
||||
// CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
|
||||
return %float_result, %double_result : f32, f64
|
||||
}
|
Loading…
Reference in New Issue