forked from OSchip/llvm-project
[mlir][StandardToSPIRV] Add support for lowering unary ops
Differential Revision: https://reviews.llvm.org/D76661
This commit is contained in:
parent
7caba33907
commit
58cdb8bff0
|
@ -107,16 +107,16 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
|
|||
|
||||
namespace {
|
||||
|
||||
/// Converts binary standard operations to SPIR-V operations.
|
||||
/// Converts unary and binary standard operations to SPIR-V operations.
|
||||
template <typename StdOp, typename SPIRVOp>
|
||||
class BinaryOpPattern final : public SPIRVOpLowering<StdOp> {
|
||||
class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
|
||||
public:
|
||||
using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
assert(operands.size() == 2);
|
||||
assert(operands.size() <= 2);
|
||||
auto dstType = this->typeConverter.convertType(operation.getType());
|
||||
if (!dstType)
|
||||
return failure();
|
||||
|
@ -572,21 +572,31 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
|
|||
SPIRVTypeConverter &typeConverter,
|
||||
OwningRewritePatternList &patterns) {
|
||||
patterns.insert<
|
||||
BinaryOpPattern<AddFOp, spirv::FAddOp>,
|
||||
BinaryOpPattern<AddIOp, spirv::IAddOp>,
|
||||
BinaryOpPattern<DivFOp, spirv::FDivOp>,
|
||||
BinaryOpPattern<MulFOp, spirv::FMulOp>,
|
||||
BinaryOpPattern<MulIOp, spirv::IMulOp>,
|
||||
BinaryOpPattern<RemFOp, spirv::FRemOp>,
|
||||
BinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
|
||||
BinaryOpPattern<SignedShiftRightOp, spirv::ShiftRightArithmeticOp>,
|
||||
BinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
|
||||
BinaryOpPattern<SignedRemIOp, spirv::SRemOp>,
|
||||
BinaryOpPattern<SubFOp, spirv::FSubOp>,
|
||||
BinaryOpPattern<SubIOp, spirv::ISubOp>,
|
||||
BinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
|
||||
BinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
|
||||
BinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
|
||||
UnaryAndBinaryOpPattern<AbsFOp, spirv::GLSLFAbsOp>,
|
||||
UnaryAndBinaryOpPattern<AddFOp, spirv::FAddOp>,
|
||||
UnaryAndBinaryOpPattern<AddIOp, spirv::IAddOp>,
|
||||
UnaryAndBinaryOpPattern<CeilFOp, spirv::GLSLCeilOp>,
|
||||
UnaryAndBinaryOpPattern<CosOp, spirv::GLSLCosOp>,
|
||||
UnaryAndBinaryOpPattern<DivFOp, spirv::FDivOp>,
|
||||
UnaryAndBinaryOpPattern<ExpOp, spirv::GLSLExpOp>,
|
||||
UnaryAndBinaryOpPattern<LogOp, spirv::GLSLLogOp>,
|
||||
UnaryAndBinaryOpPattern<MulFOp, spirv::FMulOp>,
|
||||
UnaryAndBinaryOpPattern<MulIOp, spirv::IMulOp>,
|
||||
UnaryAndBinaryOpPattern<NegFOp, spirv::FNegateOp>,
|
||||
UnaryAndBinaryOpPattern<RemFOp, spirv::FRemOp>,
|
||||
UnaryAndBinaryOpPattern<RsqrtOp, spirv::GLSLInverseSqrtOp>,
|
||||
UnaryAndBinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
|
||||
UnaryAndBinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
|
||||
UnaryAndBinaryOpPattern<SignedRemIOp, spirv::SRemOp>,
|
||||
UnaryAndBinaryOpPattern<SignedShiftRightOp,
|
||||
spirv::ShiftRightArithmeticOp>,
|
||||
UnaryAndBinaryOpPattern<SqrtOp, spirv::GLSLSqrtOp>,
|
||||
UnaryAndBinaryOpPattern<SubFOp, spirv::FSubOp>,
|
||||
UnaryAndBinaryOpPattern<SubIOp, spirv::ISubOp>,
|
||||
UnaryAndBinaryOpPattern<TanhOp, spirv::GLSLTanhOp>,
|
||||
UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
|
||||
UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
|
||||
UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
|
||||
BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
|
||||
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
|
||||
ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern,
|
||||
|
|
|
@ -31,9 +31,33 @@ func @int32_scalar(%lhs: i32, %rhs: i32) {
|
|||
return
|
||||
}
|
||||
|
||||
// Check float operation conversions.
|
||||
// CHECK-LABEL: @float32_scalar
|
||||
func @float32_scalar(%lhs: f32, %rhs: f32) {
|
||||
// Check float unary operation conversions.
|
||||
// CHECK-LABEL: @float32_unary_scalar
|
||||
func @float32_unary_scalar(%arg0: f32) {
|
||||
// CHECK: spv.GLSL.FAbs %{{.*}}: f32
|
||||
%0 = absf %arg0 : f32
|
||||
// CHECK: spv.GLSL.Ceil %{{.*}}: f32
|
||||
%1 = ceilf %arg0 : f32
|
||||
// CHECK: spv.GLSL.Cos %{{.*}}: f32
|
||||
%2 = cos %arg0 : f32
|
||||
// CHECK: spv.GLSL.Exp %{{.*}}: f32
|
||||
%3 = exp %arg0 : f32
|
||||
// CHECK: spv.GLSL.Log %{{.*}}: f32
|
||||
%4 = log %arg0 : f32
|
||||
// CHECK: spv.FNegate %{{.*}}: f32
|
||||
%5 = negf %arg0 : f32
|
||||
// CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32
|
||||
%6 = rsqrt %arg0 : f32
|
||||
// CHECK: spv.GLSL.Sqrt %{{.*}}: f32
|
||||
%7 = sqrt %arg0 : f32
|
||||
// CHECK: spv.GLSL.Tanh %{{.*}}: f32
|
||||
%8 = tanh %arg0 : f32
|
||||
return
|
||||
}
|
||||
|
||||
// Check float binary operation conversions.
|
||||
// CHECK-LABEL: @float32_binary_scalar
|
||||
func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
|
||||
// CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32
|
||||
%0 = addf %lhs, %rhs: f32
|
||||
// CHECK: spv.FSub %{{.*}}, %{{.*}}: f32
|
||||
|
|
Loading…
Reference in New Issue