[mlir][StandardToSPIRV] Add support for lowering unary ops

Differential Revision: https://reviews.llvm.org/D76661
This commit is contained in:
Hanhan Wang 2020-03-24 09:15:54 -04:00 committed by Lei Zhang
parent 7caba33907
commit 58cdb8bff0
2 changed files with 55 additions and 21 deletions

View File

@ -107,16 +107,16 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
namespace {
/// Converts binary standard operations to SPIR-V operations.
/// Converts unary and binary standard operations to SPIR-V operations.
template <typename StdOp, typename SPIRVOp>
class BinaryOpPattern final : public SPIRVOpLowering<StdOp> {
class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
public:
using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
LogicalResult
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.size() == 2);
assert(operands.size() <= 2);
auto dstType = this->typeConverter.convertType(operation.getType());
if (!dstType)
return failure();
@ -572,21 +572,31 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
patterns.insert<
BinaryOpPattern<AddFOp, spirv::FAddOp>,
BinaryOpPattern<AddIOp, spirv::IAddOp>,
BinaryOpPattern<DivFOp, spirv::FDivOp>,
BinaryOpPattern<MulFOp, spirv::FMulOp>,
BinaryOpPattern<MulIOp, spirv::IMulOp>,
BinaryOpPattern<RemFOp, spirv::FRemOp>,
BinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
BinaryOpPattern<SignedShiftRightOp, spirv::ShiftRightArithmeticOp>,
BinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
BinaryOpPattern<SignedRemIOp, spirv::SRemOp>,
BinaryOpPattern<SubFOp, spirv::FSubOp>,
BinaryOpPattern<SubIOp, spirv::ISubOp>,
BinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
BinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
BinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
UnaryAndBinaryOpPattern<AbsFOp, spirv::GLSLFAbsOp>,
UnaryAndBinaryOpPattern<AddFOp, spirv::FAddOp>,
UnaryAndBinaryOpPattern<AddIOp, spirv::IAddOp>,
UnaryAndBinaryOpPattern<CeilFOp, spirv::GLSLCeilOp>,
UnaryAndBinaryOpPattern<CosOp, spirv::GLSLCosOp>,
UnaryAndBinaryOpPattern<DivFOp, spirv::FDivOp>,
UnaryAndBinaryOpPattern<ExpOp, spirv::GLSLExpOp>,
UnaryAndBinaryOpPattern<LogOp, spirv::GLSLLogOp>,
UnaryAndBinaryOpPattern<MulFOp, spirv::FMulOp>,
UnaryAndBinaryOpPattern<MulIOp, spirv::IMulOp>,
UnaryAndBinaryOpPattern<NegFOp, spirv::FNegateOp>,
UnaryAndBinaryOpPattern<RemFOp, spirv::FRemOp>,
UnaryAndBinaryOpPattern<RsqrtOp, spirv::GLSLInverseSqrtOp>,
UnaryAndBinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
UnaryAndBinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
UnaryAndBinaryOpPattern<SignedRemIOp, spirv::SRemOp>,
UnaryAndBinaryOpPattern<SignedShiftRightOp,
spirv::ShiftRightArithmeticOp>,
UnaryAndBinaryOpPattern<SqrtOp, spirv::GLSLSqrtOp>,
UnaryAndBinaryOpPattern<SubFOp, spirv::FSubOp>,
UnaryAndBinaryOpPattern<SubIOp, spirv::ISubOp>,
UnaryAndBinaryOpPattern<TanhOp, spirv::GLSLTanhOp>,
UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern,

View File

@ -31,9 +31,33 @@ func @int32_scalar(%lhs: i32, %rhs: i32) {
return
}
// Check float operation conversions.
// CHECK-LABEL: @float32_scalar
func @float32_scalar(%lhs: f32, %rhs: f32) {
// Check float unary operation conversions.
// CHECK-LABEL: @float32_unary_scalar
func @float32_unary_scalar(%arg0: f32) {
// CHECK: spv.GLSL.FAbs %{{.*}}: f32
%0 = absf %arg0 : f32
// CHECK: spv.GLSL.Ceil %{{.*}}: f32
%1 = ceilf %arg0 : f32
// CHECK: spv.GLSL.Cos %{{.*}}: f32
%2 = cos %arg0 : f32
// CHECK: spv.GLSL.Exp %{{.*}}: f32
%3 = exp %arg0 : f32
// CHECK: spv.GLSL.Log %{{.*}}: f32
%4 = log %arg0 : f32
// CHECK: spv.FNegate %{{.*}}: f32
%5 = negf %arg0 : f32
// CHECK: spv.GLSL.InverseSqrt %{{.*}}: f32
%6 = rsqrt %arg0 : f32
// CHECK: spv.GLSL.Sqrt %{{.*}}: f32
%7 = sqrt %arg0 : f32
// CHECK: spv.GLSL.Tanh %{{.*}}: f32
%8 = tanh %arg0 : f32
return
}
// Check float binary operation conversions.
// CHECK-LABEL: @float32_binary_scalar
func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
// CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32
%0 = addf %lhs, %rhs: f32
// CHECK: spv.FSub %{{.*}}, %{{.*}}: f32