forked from OSchip/llvm-project
[mlir][ArmSVE] Add basic arithmetic operations
While we figure out how to best add Standard support for scalable vectors, these instructions provide a workaround for basic arithmetic between scalable vectors. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D100837
This commit is contained in:
parent
e510860656
commit
001d601ac4
|
@ -122,6 +122,42 @@ class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
|
|||
/*list<OpTrait> traits=*/traits,
|
||||
/*int numResults=*/1>;
|
||||
|
||||
class ScalableFOp<string mnemonic, string op_description,
|
||||
list<OpTrait> traits = []> :
|
||||
ArmSVE_Op<mnemonic, !listconcat(traits,
|
||||
[AllTypesMatch<["src1", "src2", "dst"]>])> {
|
||||
let summary = op_description # " for scalable vectors of floats";
|
||||
let description = [{
|
||||
The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
|
||||
returns one scalable vector with the result of the }] # op_description # [{.
|
||||
}];
|
||||
let arguments = (ins
|
||||
ScalableVectorOf<[AnyFloat]>:$src1,
|
||||
ScalableVectorOf<[AnyFloat]>:$src2
|
||||
);
|
||||
let results = (outs ScalableVectorOf<[AnyFloat]>:$dst);
|
||||
let assemblyFormat =
|
||||
"$src1 `,` $src2 attr-dict `:` type($src1)";
|
||||
}
|
||||
|
||||
class ScalableIOp<string mnemonic, string op_description,
|
||||
list<OpTrait> traits = []> :
|
||||
ArmSVE_Op<mnemonic, !listconcat(traits,
|
||||
[AllTypesMatch<["src1", "src2", "dst"]>])> {
|
||||
let summary = op_description # " for scalable vectors of integers";
|
||||
let description = [{
|
||||
The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
|
||||
returns one scalable vector with the result of the }] # op_description # [{.
|
||||
}];
|
||||
let arguments = (ins
|
||||
ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
|
||||
ScalableVectorOf<[I8, I16, I32, I64]>:$src2
|
||||
);
|
||||
let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst);
|
||||
let assemblyFormat =
|
||||
"$src1 `,` $src2 attr-dict `:` type($src1)";
|
||||
}
|
||||
|
||||
def SdotOp : ArmSVE_Op<"sdot",
|
||||
[NoSideEffect,
|
||||
AllTypesMatch<["src1", "src2"]>,
|
||||
|
@ -266,6 +302,25 @@ def VectorScaleOp : ArmSVE_Op<"vector_scale",
|
|||
"attr-dict `:` type($res)";
|
||||
}
|
||||
|
||||
|
||||
def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>;
|
||||
|
||||
def ScalableAddFOp : ScalableFOp<"addf", "addition", [Commutative]>;
|
||||
|
||||
def ScalableSubIOp : ScalableIOp<"subi", "subtraction">;
|
||||
|
||||
def ScalableSubFOp : ScalableFOp<"subf", "subtraction">;
|
||||
|
||||
def ScalableMulIOp : ScalableIOp<"muli", "multiplication", [Commutative]>;
|
||||
|
||||
def ScalableMulFOp : ScalableFOp<"mulf", "multiplication", [Commutative]>;
|
||||
|
||||
def ScalableSDivIOp : ScalableIOp<"divi_signed", "signed division">;
|
||||
|
||||
def ScalableUDivIOp : ScalableIOp<"divi_unsigned", "unsigned division">;
|
||||
|
||||
def ScalableDivFOp : ScalableFOp<"divf", "division">;
|
||||
|
||||
def UmmlaIntrOp :
|
||||
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
|
||||
Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
|
||||
|
|
|
@ -84,6 +84,38 @@ using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
|
|||
using VectorScaleOpLowering =
|
||||
OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
|
||||
|
||||
/// Adds to `patterns` one OneToOneConvertToLLVMPattern per basic ArmSVE
/// scalable-vector arithmetic op, pairing each with an LLVM dialect op:
/// addi/addf -> LLVM::AddOp/FAddOp, subi/subf -> LLVM::SubOp/FSubOp,
/// muli/mulf -> LLVM::MulOp/FMulOp, signed/unsigned integer division ->
/// LLVM::SDivOp/UDivOp, and divf -> LLVM::FDivOp. All patterns are
/// constructed with `converter`.
static void
|
||||
populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns) {
|
||||
// clang-format off
|
||||
patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
|
||||
OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
|
||||
OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
|
||||
OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
|
||||
OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
|
||||
OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
|
||||
OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
|
||||
OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
|
||||
OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
|
||||
>(converter);
|
||||
// clang-format on
|
||||
|
||||
}
|
||||
|
||||
/// Marks every basic ArmSVE scalable-vector arithmetic op (integer and
/// floating-point add, sub, mul, and div, including both signed and unsigned
/// integer division) as illegal on `target`. This is the same set of ops for
/// which populateBasicSVEArithmeticExportPatterns registers conversions.
static void
|
||||
configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
|
||||
// clang-format off
|
||||
target.addIllegalOp<ScalableAddIOp,
|
||||
ScalableAddFOp,
|
||||
ScalableSubIOp,
|
||||
ScalableSubFOp,
|
||||
ScalableMulIOp,
|
||||
ScalableMulFOp,
|
||||
ScalableSDivIOp,
|
||||
ScalableUDivIOp,
|
||||
ScalableDivFOp>();
|
||||
// clang-format on
|
||||
|
||||
}
|
||||
|
||||
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
|
||||
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
|
@ -106,20 +138,14 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
|
|||
UmmlaOpLowering,
|
||||
VectorScaleOpLowering>(converter);
|
||||
// clang-format on
|
||||
populateBasicSVEArithmeticExportPatterns(converter, patterns);
|
||||
}
|
||||
|
||||
void mlir::configureArmSVELegalizeForExportTarget(
|
||||
LLVMConversionTarget &target) {
|
||||
target.addLegalOp<SdotIntrOp>();
|
||||
target.addIllegalOp<SdotOp>();
|
||||
target.addLegalOp<SmmlaIntrOp>();
|
||||
target.addIllegalOp<SmmlaOp>();
|
||||
target.addLegalOp<UdotIntrOp>();
|
||||
target.addIllegalOp<UdotOp>();
|
||||
target.addLegalOp<UmmlaIntrOp>();
|
||||
target.addIllegalOp<UmmlaOp>();
|
||||
target.addLegalOp<VectorScaleIntrOp>();
|
||||
target.addIllegalOp<VectorScaleOp>();
|
||||
target.addLegalOp<SdotIntrOp, SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp,
|
||||
VectorScaleIntrOp>();
|
||||
target.addIllegalOp<SdotOp, SmmlaOp, UdotOp, UmmlaOp, VectorScaleOp>();
|
||||
auto hasScalableVectorType = [](TypeRange types) {
|
||||
for (Type type : types)
|
||||
if (type.isa<arm_sve::ScalableVectorType>())
|
||||
|
@ -135,4 +161,5 @@ void mlir::configureArmSVELegalizeForExportTarget(
|
|||
return !hasScalableVectorType(op->getOperandTypes()) &&
|
||||
!hasScalableVectorType(op->getResultTypes());
|
||||
});
|
||||
configureBasicSVEArithmeticLegalizations(target);
|
||||
}
|
||||
|
|
|
@ -40,6 +40,40 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
|
|||
return %0 : !arm_sve.vector<4xi32>
|
||||
}
|
||||
|
||||
func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
|
||||
%b: !arm_sve.vector<4xi32>,
|
||||
%c: !arm_sve.vector<4xi32>,
|
||||
%d: !arm_sve.vector<4xi32>,
|
||||
%e: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
|
||||
// CHECK: llvm.mul {{.*}}: !llvm.vec<? x 4 x i32>
|
||||
%0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
|
||||
// CHECK: llvm.add {{.*}}: !llvm.vec<? x 4 x i32>
|
||||
%1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
|
||||
// CHECK: llvm.sub {{.*}}: !llvm.vec<? x 4 x i32>
|
||||
%2 = arm_sve.subi %1, %d : !arm_sve.vector<4xi32>
|
||||
// CHECK: llvm.sdiv {{.*}}: !llvm.vec<? x 4 x i32>
|
||||
%3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32>
|
||||
// CHECK: llvm.udiv {{.*}}: !llvm.vec<? x 4 x i32>
|
||||
%4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32>
|
||||
return %3 : !arm_sve.vector<4xi32>
|
||||
}
|
||||
|
||||
func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
|
||||
%b: !arm_sve.vector<4xf32>,
|
||||
%c: !arm_sve.vector<4xf32>,
|
||||
%d: !arm_sve.vector<4xf32>,
|
||||
%e: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
|
||||
// CHECK: llvm.fmul {{.*}}: !llvm.vec<? x 4 x f32>
|
||||
%0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
|
||||
// CHECK: llvm.fadd {{.*}}: !llvm.vec<? x 4 x f32>
|
||||
%1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
|
||||
// CHECK: llvm.fsub {{.*}}: !llvm.vec<? x 4 x f32>
|
||||
%2 = arm_sve.subf %1, %d : !arm_sve.vector<4xf32>
|
||||
// CHECK: llvm.fdiv {{.*}}: !llvm.vec<? x 4 x f32>
|
||||
%3 = arm_sve.divf %2, %e : !arm_sve.vector<4xf32>
|
||||
return %3 : !arm_sve.vector<4xf32>
|
||||
}
|
||||
|
||||
func @get_vector_scale() -> index {
|
||||
// CHECK: arm_sve.vscale
|
||||
%0 = arm_sve.vector_scale : index
|
||||
|
|
|
@ -36,6 +36,26 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
|
|||
return %0 : !arm_sve.vector<4xi32>
|
||||
}
|
||||
|
||||
func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
|
||||
%b: !arm_sve.vector<4xi32>,
|
||||
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
|
||||
// CHECK: arm_sve.muli {{.*}}: !arm_sve.vector<4xi32>
|
||||
%0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
|
||||
// CHECK: arm_sve.addi {{.*}}: !arm_sve.vector<4xi32>
|
||||
%1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
|
||||
return %1 : !arm_sve.vector<4xi32>
|
||||
}
|
||||
|
||||
func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
|
||||
%b: !arm_sve.vector<4xf32>,
|
||||
%c: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
|
||||
// CHECK: arm_sve.mulf {{.*}}: !arm_sve.vector<4xf32>
|
||||
%0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
|
||||
// CHECK: arm_sve.addf {{.*}}: !arm_sve.vector<4xf32>
|
||||
%1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
|
||||
return %1 : !arm_sve.vector<4xf32>
|
||||
}
|
||||
|
||||
func @get_vector_scale() -> index {
|
||||
// CHECK: arm_sve.vector_scale : index
|
||||
%0 = arm_sve.vector_scale : index
|
||||
|
|
|
@ -48,6 +48,30 @@ llvm.func @arm_sve_ummla(%arg0: !llvm.vec<?x16 x i8>,
|
|||
llvm.return %0 : !llvm.vec<?x4 x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
|
||||
llvm.func @arm_sve_arithi(%arg0: !llvm.vec<? x 4 x i32>,
|
||||
%arg1: !llvm.vec<? x 4 x i32>,
|
||||
%arg2: !llvm.vec<? x 4 x i32>)
|
||||
-> !llvm.vec<? x 4 x i32> {
|
||||
// CHECK: mul <vscale x 4 x i32>
|
||||
%0 = llvm.mul %arg0, %arg1 : !llvm.vec<? x 4 x i32>
|
||||
// CHECK: add <vscale x 4 x i32>
|
||||
%1 = llvm.add %0, %arg2 : !llvm.vec<? x 4 x i32>
|
||||
llvm.return %1 : !llvm.vec<? x 4 x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf
|
||||
llvm.func @arm_sve_arithf(%arg0: !llvm.vec<? x 4 x f32>,
|
||||
%arg1: !llvm.vec<? x 4 x f32>,
|
||||
%arg2: !llvm.vec<? x 4 x f32>)
|
||||
-> !llvm.vec<? x 4 x f32> {
|
||||
// CHECK: fmul <vscale x 4 x float>
|
||||
%0 = llvm.fmul %arg0, %arg1 : !llvm.vec<? x 4 x f32>
|
||||
// CHECK: fadd <vscale x 4 x float>
|
||||
%1 = llvm.fadd %0, %arg2 : !llvm.vec<? x 4 x f32>
|
||||
llvm.return %1 : !llvm.vec<? x 4 x f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define i64 @get_vector_scale()
|
||||
llvm.func @get_vector_scale() -> i64 {
|
||||
// CHECK: call i64 @llvm.vscale.i64()
|
||||
|
|
Loading…
Reference in New Issue