forked from OSchip/llvm-project
[mlir][spirv] Add math.fma lowering to spirv
Differential Revision: https://reviews.llvm.org/D117704
This commit is contained in:
parent
4060b81e76
commit
d9edc1a585
|
@ -790,25 +790,25 @@ void mlir::arith::populateArithmeticToSPIRVPatterns(
|
|||
patterns.add<
|
||||
ConstantCompositeOpPattern,
|
||||
ConstantScalarOpPattern,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::AddIOp, spirv::IAddOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::SubIOp, spirv::ISubOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::MulIOp, spirv::IMulOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>,
|
||||
spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
|
||||
spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
|
||||
spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
|
||||
spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
|
||||
spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
|
||||
spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
|
||||
RemSIOpGLSLPattern, RemSIOpOCLPattern,
|
||||
BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
|
||||
BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
|
||||
XOrIOpLogicalPattern, XOrIOpBooleanPattern,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::NegFOp, spirv::FNegateOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::AddFOp, spirv::FAddOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::SubFOp, spirv::FSubOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::MulFOp, spirv::FMulOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::DivFOp, spirv::FDivOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::RemFOp, spirv::FRemOp>,
|
||||
spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
|
||||
spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
|
||||
spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
|
||||
spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
|
||||
spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
|
||||
spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
|
||||
spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
|
||||
spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
|
||||
spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
|
||||
TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
|
||||
TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
|
||||
TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
|
||||
|
|
|
@ -64,35 +64,36 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
|||
RewritePatternSet &patterns) {
|
||||
|
||||
// GLSL patterns
|
||||
patterns.add<
|
||||
Log1pOpPattern<spirv::GLSLLogOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
|
||||
typeConverter, patterns.getContext());
|
||||
patterns
|
||||
.add<Log1pOpPattern<spirv::GLSLLogOp>,
|
||||
spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
|
||||
spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
|
||||
spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
|
||||
spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
|
||||
spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
|
||||
spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
|
||||
spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
|
||||
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
|
||||
spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
|
||||
spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
|
||||
spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
|
||||
spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>>(
|
||||
typeConverter, patterns.getContext());
|
||||
|
||||
// OpenCL patterns
|
||||
patterns.add<Log1pOpPattern<spirv::OCLLogOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::OCLCeilOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::OCLCosOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::ErfOp, spirv::OCLErfOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::OCLExpOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::OCLFloorOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::OCLLogOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::OCLPowOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::OCLSinOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
|
||||
spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
|
||||
spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>,
|
||||
spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>,
|
||||
spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
|
||||
spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
|
||||
spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
|
||||
spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
|
||||
spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
|
||||
spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
|
||||
spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
|
||||
spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
|
||||
spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
|
|
|
@ -15,16 +15,17 @@
|
|||
namespace mlir {
|
||||
namespace spirv {
|
||||
|
||||
/// Converts unary and binary standard operations to SPIR-V operations.
|
||||
/// Converts elementwise unary, binary and ternary standard operations to SPIR-V
|
||||
/// operations.
|
||||
template <typename Op, typename SPIRVOp>
|
||||
class UnaryAndBinaryOpPattern final : public OpConversionPattern<Op> {
|
||||
class ElementwiseOpPattern final : public OpConversionPattern<Op> {
|
||||
public:
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
assert(adaptor.getOperands().size() <= 2);
|
||||
assert(adaptor.getOperands().size() <= 3);
|
||||
auto dstType = this->getTypeConverter()->convertType(op.getType());
|
||||
if (!dstType)
|
||||
return failure();
|
||||
|
|
|
@ -230,12 +230,12 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
|||
|
||||
patterns.add<
|
||||
// Unary and binary patterns
|
||||
spirv::UnaryAndBinaryOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
|
||||
spirv::UnaryAndBinaryOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
|
||||
spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
|
||||
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
|
||||
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
|
||||
spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
|
||||
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
|
||||
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
|
||||
|
||||
ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
|
||||
CondBranchOpPattern>(typeConverter, context);
|
||||
|
|
|
@ -68,4 +68,19 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @float32_ternary_scalar
|
||||
func @float32_ternary_scalar(%a: f32, %b: f32, %c: f32) {
|
||||
// CHECK: spv.GLSL.Fma %{{.*}}: f32
|
||||
%0 = math.fma %a, %b, %c : f32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @float32_ternary_vector
|
||||
func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
|
||||
%c: vector<4xf32>) {
|
||||
// CHECK: spv.GLSL.Fma %{{.*}}: vector<4xf32>
|
||||
%0 = math.fma %a, %b, %c : vector<4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
} // end module
|
||||
|
|
Loading…
Reference in New Issue