[mlir][spirv] Add math.fma lowering to spirv

Differential Revision: https://reviews.llvm.org/D117704
This commit is contained in:
Thomas Raoux 2022-01-19 10:32:16 -08:00
parent 4060b81e76
commit d9edc1a585
5 changed files with 67 additions and 50 deletions

View File

@ -790,25 +790,25 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
patterns.add<
ConstantCompositeOpPattern,
ConstantScalarOpPattern,
spirv::UnaryAndBinaryOpPattern<arith::AddIOp, spirv::IAddOp>,
spirv::UnaryAndBinaryOpPattern<arith::SubIOp, spirv::ISubOp>,
spirv::UnaryAndBinaryOpPattern<arith::MulIOp, spirv::IMulOp>,
spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>,
spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>,
spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>,
spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
RemSIOpGLSLPattern, RemSIOpOCLPattern,
BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
XOrIOpLogicalPattern, XOrIOpBooleanPattern,
spirv::UnaryAndBinaryOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
spirv::UnaryAndBinaryOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
spirv::UnaryAndBinaryOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
spirv::UnaryAndBinaryOpPattern<arith::NegFOp, spirv::FNegateOp>,
spirv::UnaryAndBinaryOpPattern<arith::AddFOp, spirv::FAddOp>,
spirv::UnaryAndBinaryOpPattern<arith::SubFOp, spirv::FSubOp>,
spirv::UnaryAndBinaryOpPattern<arith::MulFOp, spirv::FMulOp>,
spirv::UnaryAndBinaryOpPattern<arith::DivFOp, spirv::FDivOp>,
spirv::UnaryAndBinaryOpPattern<arith::RemFOp, spirv::FRemOp>,
spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,

View File

@ -64,35 +64,36 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
// GLSL patterns
patterns.add<
Log1pOpPattern<spirv::GLSLLogOp>,
spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>,
spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>,
spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>,
spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
typeConverter, patterns.getContext());
patterns
.add<Log1pOpPattern<spirv::GLSLLogOp>,
spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>>(
typeConverter, patterns.getContext());
// OpenCL patterns
patterns.add<Log1pOpPattern<spirv::OCLLogOp>,
spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::OCLCeilOp>,
spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::OCLCosOp>,
spirv::UnaryAndBinaryOpPattern<math::ErfOp, spirv::OCLErfOp>,
spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::OCLExpOp>,
spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::OCLFloorOp>,
spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::OCLLogOp>,
spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::OCLPowOp>,
spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::OCLSinOp>,
spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>,
spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>,
spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
typeConverter, patterns.getContext());
}

View File

@ -15,16 +15,17 @@
namespace mlir {
namespace spirv {
/// Converts unary and binary standard operations to SPIR-V operations.
/// Converts elementwise unary, binary and ternary standard operations to SPIR-V
/// operations.
template <typename Op, typename SPIRVOp>
class UnaryAndBinaryOpPattern final : public OpConversionPattern<Op> {
class ElementwiseOpPattern final : public OpConversionPattern<Op> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() <= 2);
assert(adaptor.getOperands().size() <= 3);
auto dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return failure();

View File

@ -230,12 +230,12 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns.add<
// Unary and binary patterns
spirv::UnaryAndBinaryOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
spirv::UnaryAndBinaryOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
spirv::UnaryAndBinaryOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
spirv::UnaryAndBinaryOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
spirv::UnaryAndBinaryOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
spirv::UnaryAndBinaryOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
CondBranchOpPattern>(typeConverter, context);

View File

@ -68,4 +68,19 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
return
}
// CHECK-LABEL: @float32_ternary_scalar
func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
// CHECK: spv.GLSL.Fma %{{.*}}: f32
%0 = math.fma %a, %b, %c : f32
return
}
// CHECK-LABEL: @float32_ternary_vector
func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
%c: vector<4xf32>) {
// CHECK: spv.GLSL.Fma %{{.*}}: vector<4xf32>
%0 = math.fma %a, %b, %c : vector<4xf32>
return
}
} // end module