[mlir] Add n-D vector lowering to LLVM for cast ops

The casting ops (sitofp, uitofp, fptosi, fptoui) lowering currently does
not handle n-D vectors. This patch fixes that.

Differential Revision: https://reviews.llvm.org/D103207
This commit is contained in:
harsh-nod 2021-05-26 15:18:32 -07:00 committed by thomasraoux
parent fd0a2f75ff
commit 94d67b51dd
2 changed files with 104 additions and 41 deletions

View File

@ -1717,6 +1717,10 @@ using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
using FPExtOpLowering = VectorConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp>;
using FPToSIOpLowering = VectorConvertToLLVMPattern<FPToSIOp, LLVM::FPToSIOp>;
using FPToUIOpLowering = VectorConvertToLLVMPattern<FPToUIOp, LLVM::FPToUIOp>;
using FPTruncOpLowering = VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
using Log10OpLowering =
@ -1729,6 +1733,7 @@ using NegFOpLowering = VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp>;
using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
using SIToFPOpLowering = VectorConvertToLLVMPattern<SIToFPOp, LLVM::SIToFPOp>;
using SelectOpLowering = VectorConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
using SignExtendIOpLowering =
VectorConvertToLLVMPattern<SignExtendIOp, LLVM::SExtOp>;
@ -1744,6 +1749,8 @@ using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
using SubFOpLowering = VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp>;
using SubIOpLowering = VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp>;
using TruncateIOpLowering = VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
using UIToFPOpLowering = VectorConvertToLLVMPattern<UIToFPOp, LLVM::UIToFPOp>;
using UnsignedDivIOpLowering =
VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp>;
using UnsignedRemIOpLowering =
@ -3112,41 +3119,6 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
}
};
struct SIToFPLowering
: public OneToOneConvertToLLVMPattern<SIToFPOp, LLVM::SIToFPOp> {
using Super::Super;
};
struct UIToFPLowering
: public OneToOneConvertToLLVMPattern<UIToFPOp, LLVM::UIToFPOp> {
using Super::Super;
};
struct FPExtLowering
: public OneToOneConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp> {
using Super::Super;
};
struct FPToSILowering
: public OneToOneConvertToLLVMPattern<FPToSIOp, LLVM::FPToSIOp> {
using Super::Super;
};
struct FPToUILowering
: public OneToOneConvertToLLVMPattern<FPToUIOp, LLVM::FPToUIOp> {
using Super::Super;
};
struct FPTruncLowering
: public OneToOneConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp> {
using Super::Super;
};
struct TruncateIOpLowering
: public OneToOneConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp> {
using Super::Super;
};
// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
@ -3908,10 +3880,10 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
Log10OpLowering,
Log1pOpLowering,
Log2OpLowering,
FPExtLowering,
FPToSILowering,
FPToUILowering,
FPTruncLowering,
FPExtOpLowering,
FPToSIOpLowering,
FPToUIOpLowering,
FPTruncOpLowering,
IndexCastOpLowering,
MulFOpLowering,
MulIOpLowering,
@ -3922,7 +3894,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
RemFOpLowering,
ReturnOpLowering,
RsqrtOpLowering,
SIToFPLowering,
SIToFPOpLowering,
SelectOpLowering,
ShiftLeftOpLowering,
SignExtendIOpLowering,
@ -3936,7 +3908,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
SubFOpLowering,
SubIOpLowering,
TruncateIOpLowering,
UIToFPLowering,
UIToFPOpLowering,
UnsignedDivIOpLowering,
UnsignedRemIOpLowering,
UnsignedShiftRightOpLowering,

View File

@ -47,3 +47,94 @@ func @zexti_vector(%arg0 : vector<1x2x3xi32>, %arg1 : vector<1x2x3xi64>) {
%0 = zexti %arg0: vector<1x2x3xi32> to vector<1x2x3xi64>
return
}
// CHECK-LABEL: @sitofp
func @sitofp_vector(%arg0 : vector<1x2x3xi32>) -> vector<1x2x3xf32> {
// CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xf32>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi32>>>
// CHECK: llvm.sitofp %{{.*}} : vector<3xi32> to vector<3xf32>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xf32>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi32>>>
// CHECK: llvm.sitofp %{{.*}} : vector<3xi32> to vector<3xf32>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xf32>>>
%0 = sitofp %arg0: vector<1x2x3xi32> to vector<1x2x3xf32>
return %0 : vector<1x2x3xf32>
}
// CHECK-LABEL: @uitofp
func @uitofp_vector(%arg0 : vector<1x2x3xi32>) -> vector<1x2x3xf32> {
// CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xf32>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi32>>>
// CHECK: llvm.uitofp %{{.*}} : vector<3xi32> to vector<3xf32>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xf32>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi32>>>
// CHECK: llvm.uitofp %{{.*}} : vector<3xi32> to vector<3xf32>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xf32>>>
%0 = uitofp %arg0: vector<1x2x3xi32> to vector<1x2x3xf32>
return %0 : vector<1x2x3xf32>
}
// CHECK-LABEL: @fptosi
func @fptosi_vector(%arg0 : vector<1x2x3xf32>) -> vector<1x2x3xi32> {
// CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xi32>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xf32>>>
// CHECK: llvm.fptosi %{{.*}} : vector<3xf32> to vector<3xi32>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi32>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xf32>>>
// CHECK: llvm.fptosi %{{.*}} : vector<3xf32> to vector<3xi32>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi32>>>
%0 = fptosi %arg0: vector<1x2x3xf32> to vector<1x2x3xi32>
return %0 : vector<1x2x3xi32>
}
// CHECK-LABEL: @fptoui
func @fptoui_vector(%arg0 : vector<1x2x3xf32>) -> vector<1x2x3xi32> {
// CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xi32>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xf32>>>
// CHECK: llvm.fptoui %{{.*}} : vector<3xf32> to vector<3xi32>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi32>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xf32>>>
// CHECK: llvm.fptoui %{{.*}} : vector<3xf32> to vector<3xi32>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi32>>>
%0 = fptoui %arg0: vector<1x2x3xf32> to vector<1x2x3xi32>
return %0 : vector<1x2x3xi32>
}
// CHECK-LABEL: @fpext
func @fpext_vector(%arg0 : vector<1x2x3xf16>) -> vector<1x2x3xf64> {
// CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xf64>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xf16>>>
// CHECK: llvm.fpext %{{.*}} : vector<3xf16> to vector<3xf64>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xf64>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xf16>>>
// CHECK: llvm.fpext %{{.*}} : vector<3xf16> to vector<3xf64>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xf64>>>
%0 = fpext %arg0: vector<1x2x3xf16> to vector<1x2x3xf64>
return %0 : vector<1x2x3xf64>
}
// CHECK-LABEL: @fptrunc
func @fptrunc_vector(%arg0 : vector<1x2x3xf64>) -> vector<1x2x3xf16> {
// CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xf16>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xf64>>>
// CHECK: llvm.fptrunc %{{.*}} : vector<3xf64> to vector<3xf16>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xf16>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xf64>>>
// CHECK: llvm.fptrunc %{{.*}} : vector<3xf64> to vector<3xf16>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xf16>>>
%0 = fptrunc %arg0: vector<1x2x3xf64> to vector<1x2x3xf16>
return %0 : vector<1x2x3xf16>
}
// CHECK-LABEL: @trunci
func @trunci_vector(%arg0 : vector<1x2x3xi64>) -> vector<1x2x3xi16> {
// CHECK: llvm.mlir.undef : !llvm.array<1 x array<2 x vector<3xi16>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi64>>>
// CHECK: llvm.trunc %{{.*}} : vector<3xi64> to vector<3xi16>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 0] : !llvm.array<1 x array<2 x vector<3xi16>>>
// CHECK: llvm.extractvalue %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi64>>>
// CHECK: llvm.trunc %{{.*}} : vector<3xi64> to vector<3xi16>
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0, 1] : !llvm.array<1 x array<2 x vector<3xi16>>>
%0 = trunci %arg0: vector<1x2x3xi64> to vector<1x2x3xi16>
return %0 : vector<1x2x3xi16>
}