[mlir][ArmSVE] Add basic arithmetic operations

While we figure out how to best add Standard support for scalable
vectors, these instructions provide a workaround for basic arithmetic
between scalable vectors.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D100837
This commit is contained in:
Javier Setoain 2021-05-05 09:38:50 +02:00 committed by Alex Zinenko
parent e510860656
commit 001d601ac4
5 changed files with 170 additions and 10 deletions

View File

@ -122,6 +122,42 @@ class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
/*list<OpTrait> traits=*/traits,
/*int numResults=*/1>;
// Template for binary arithmetic operations on scalable vectors of floats.
// Both source operands and the destination must have the same type
// (enforced by AllTypesMatch), so the assembly format prints the type once.
class ScalableFOp<string mnemonic, string op_description,
                  list<OpTrait> traits = []> :
  ArmSVE_Op<mnemonic, !listconcat(traits,
                       [AllTypesMatch<["src1", "src2", "dst"]>])> {
  let summary = op_description # " for scalable vectors of floats";
  // Fixed grammar: "operation takes" (was "operations takes"), matching the
  // parallel ScalableIOp description.
  let description = [{
    The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
    returns one scalable vector with the result of the }] # op_description # [{.
  }];
  let arguments = (ins
          ScalableVectorOf<[AnyFloat]>:$src1,
          ScalableVectorOf<[AnyFloat]>:$src2
  );
  let results = (outs ScalableVectorOf<[AnyFloat]>:$dst);
  let assemblyFormat =
    "$src1 `,` $src2 attr-dict `:` type($src1)";
}
// Template for binary arithmetic operations on scalable vectors of integers
// (i8/i16/i32/i64). Both source operands and the destination must have the
// same type (enforced by AllTypesMatch), so the assembly format prints the
// type once.
class ScalableIOp<string mnemonic, string op_description,
                  list<OpTrait> traits = []> :
  ArmSVE_Op<mnemonic, !listconcat(traits,
                       [AllTypesMatch<["src1", "src2", "dst"]>])> {
  let summary = op_description # " for scalable vectors of integers";
  let description = [{
    The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and
    returns one scalable vector with the result of the }] # op_description # [{.
  }];
  let arguments = (ins
          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
  );
  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst);
  let assemblyFormat =
    "$src1 `,` $src2 attr-dict `:` type($src1)";
}
def SdotOp : ArmSVE_Op<"sdot",
[NoSideEffect,
AllTypesMatch<["src1", "src2"]>,
@ -266,6 +302,25 @@ def VectorScaleOp : ArmSVE_Op<"vector_scale",
"attr-dict `:` type($res)";
}
// Basic arithmetic between scalable vectors, instantiated from the
// ScalableIOp/ScalableFOp templates above. Addition and multiplication are
// tagged Commutative; subtraction and the three division flavors are not.
def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>;
def ScalableAddFOp : ScalableFOp<"addf", "addition", [Commutative]>;
def ScalableSubIOp : ScalableIOp<"subi", "subtraction">;
def ScalableSubFOp : ScalableFOp<"subf", "subtraction">;
def ScalableMulIOp : ScalableIOp<"muli", "multiplication", [Commutative]>;
def ScalableMulFOp : ScalableFOp<"mulf", "multiplication", [Commutative]>;
def ScalableSDivIOp : ScalableIOp<"divi_signed", "signed division">;
def ScalableUDivIOp : ScalableIOp<"divi_unsigned", "unsigned division">;
def ScalableDivFOp : ScalableFOp<"divf", "division">;
// Intrinsic op for the SVE ummla instruction; operands are LLVM vectors.
def UmmlaIntrOp :
  ArmSVE_IntrBinaryOverloadedOp<"ummla">,
  Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;

View File

@ -84,6 +84,38 @@ using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using VectorScaleOpLowering =
OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
/// Adds the 1-1 conversion patterns that rewrite the basic ArmSVE arithmetic
/// ops into the equivalent LLVM dialect ops on scalable vectors.
static void
populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
                                         OwningRewritePatternList &patterns) {
  // Integer arithmetic maps onto the LLVM integer ops.
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
               OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
               OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
               OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
               OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>
              >(converter);
  // Floating-point arithmetic maps onto the LLVM floating-point ops.
  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
               OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
               OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
               OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
              >(converter);
  // clang-format on
}
/// Marks every basic ArmSVE arithmetic op illegal so the conversion fails
/// unless the patterns above have rewritten all of them into LLVM dialect ops.
static void
configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
  target.addIllegalOp<ScalableAddIOp, ScalableAddFOp, ScalableSubIOp,
                      ScalableSubFOp, ScalableMulIOp, ScalableMulFOp,
                      ScalableSDivIOp, ScalableUDivIOp, ScalableDivFOp>();
}
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
@ -106,20 +138,14 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
UmmlaOpLowering,
VectorScaleOpLowering>(converter);
// clang-format on
populateBasicSVEArithmeticExportPatterns(converter, patterns);
}
void mlir::configureArmSVELegalizeForExportTarget(
LLVMConversionTarget &target) {
target.addLegalOp<SdotIntrOp>();
target.addIllegalOp<SdotOp>();
target.addLegalOp<SmmlaIntrOp>();
target.addIllegalOp<SmmlaOp>();
target.addLegalOp<UdotIntrOp>();
target.addIllegalOp<UdotOp>();
target.addLegalOp<UmmlaIntrOp>();
target.addIllegalOp<UmmlaOp>();
target.addLegalOp<VectorScaleIntrOp>();
target.addIllegalOp<VectorScaleOp>();
target.addLegalOp<SdotIntrOp, SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp,
VectorScaleIntrOp>();
target.addIllegalOp<SdotOp, SmmlaOp, UdotOp, UmmlaOp, VectorScaleOp>();
auto hasScalableVectorType = [](TypeRange types) {
for (Type type : types)
if (type.isa<arm_sve::ScalableVectorType>())
@ -135,4 +161,5 @@ void mlir::configureArmSVELegalizeForExportTarget(
return !hasScalableVectorType(op->getOperandTypes()) &&
!hasScalableVectorType(op->getResultTypes());
});
configureBasicSVEArithmeticLegalizations(target);
}

View File

@ -40,6 +40,40 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
return %0 : !arm_sve.vector<4xi32>
}
// Integer arithmetic on scalable vectors must lower to the corresponding
// LLVM dialect integer ops on !llvm.vec<? x ...> types. %4 (udiv) is only
// produced to exercise its lowering; the function returns %3.
func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
                     %b: !arm_sve.vector<4xi32>,
                     %c: !arm_sve.vector<4xi32>,
                     %d: !arm_sve.vector<4xi32>,
                     %e: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
  // CHECK: llvm.mul {{.*}}: !llvm.vec<? x 4 x i32>
  %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
  // CHECK: llvm.add {{.*}}: !llvm.vec<? x 4 x i32>
  %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
  // CHECK: llvm.sub {{.*}}: !llvm.vec<? x 4 x i32>
  %2 = arm_sve.subi %1, %d : !arm_sve.vector<4xi32>
  // CHECK: llvm.sdiv {{.*}}: !llvm.vec<? x 4 x i32>
  %3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32>
  // CHECK: llvm.udiv {{.*}}: !llvm.vec<? x 4 x i32>
  %4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32>
  return %3 : !arm_sve.vector<4xi32>
}
// Floating-point arithmetic on scalable vectors must lower to the
// corresponding LLVM dialect floating-point ops on !llvm.vec<? x ...> types.
func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
                     %b: !arm_sve.vector<4xf32>,
                     %c: !arm_sve.vector<4xf32>,
                     %d: !arm_sve.vector<4xf32>,
                     %e: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
  // CHECK: llvm.fmul {{.*}}: !llvm.vec<? x 4 x f32>
  %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
  // CHECK: llvm.fadd {{.*}}: !llvm.vec<? x 4 x f32>
  %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
  // CHECK: llvm.fsub {{.*}}: !llvm.vec<? x 4 x f32>
  %2 = arm_sve.subf %1, %d : !arm_sve.vector<4xf32>
  // CHECK: llvm.fdiv {{.*}}: !llvm.vec<? x 4 x f32>
  %3 = arm_sve.divf %2, %e : !arm_sve.vector<4xf32>
  return %3 : !arm_sve.vector<4xf32>
}
func @get_vector_scale() -> index {
// CHECK: arm_sve.vscale
%0 = arm_sve.vector_scale : index

View File

@ -36,6 +36,26 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
return %0 : !arm_sve.vector<4xi32>
}
// Roundtrip: the integer arithmetic ops must parse and print back unchanged.
func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
                     %b: !arm_sve.vector<4xi32>,
                     %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
  // CHECK: arm_sve.muli {{.*}}: !arm_sve.vector<4xi32>
  %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32>
  // CHECK: arm_sve.addi {{.*}}: !arm_sve.vector<4xi32>
  %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32>
  return %1 : !arm_sve.vector<4xi32>
}
// Roundtrip: the floating-point arithmetic ops must parse and print back
// unchanged.
func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
                     %b: !arm_sve.vector<4xf32>,
                     %c: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> {
  // CHECK: arm_sve.mulf {{.*}}: !arm_sve.vector<4xf32>
  %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32>
  // CHECK: arm_sve.addf {{.*}}: !arm_sve.vector<4xf32>
  %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32>
  return %1 : !arm_sve.vector<4xf32>
}
func @get_vector_scale() -> index {
// CHECK: arm_sve.vector_scale : index
%0 = arm_sve.vector_scale : index

View File

@ -48,6 +48,30 @@ llvm.func @arm_sve_ummla(%arg0: !llvm.vec<?x16 x i8>,
llvm.return %0 : !llvm.vec<?x4 x i32>
}
// LLVM dialect integer arithmetic on !llvm.vec<? x ...> must translate to
// LLVM IR instructions on <vscale x ...> vector types.
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
llvm.func @arm_sve_arithi(%arg0: !llvm.vec<? x 4 x i32>,
                          %arg1: !llvm.vec<? x 4 x i32>,
                          %arg2: !llvm.vec<? x 4 x i32>)
                          -> !llvm.vec<? x 4 x i32> {
  // CHECK: mul <vscale x 4 x i32>
  %0 = llvm.mul %arg0, %arg1 : !llvm.vec<? x 4 x i32>
  // CHECK: add <vscale x 4 x i32>
  %1 = llvm.add %0, %arg2 : !llvm.vec<? x 4 x i32>
  llvm.return %1 : !llvm.vec<? x 4 x i32>
}
// LLVM dialect floating-point arithmetic on !llvm.vec<? x ...> must translate
// to LLVM IR instructions on <vscale x ...> vector types.
// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_arithf
llvm.func @arm_sve_arithf(%arg0: !llvm.vec<? x 4 x f32>,
                          %arg1: !llvm.vec<? x 4 x f32>,
                          %arg2: !llvm.vec<? x 4 x f32>)
                          -> !llvm.vec<? x 4 x f32> {
  // CHECK: fmul <vscale x 4 x float>
  %0 = llvm.fmul %arg0, %arg1 : !llvm.vec<? x 4 x f32>
  // CHECK: fadd <vscale x 4 x float>
  %1 = llvm.fadd %0, %arg2 : !llvm.vec<? x 4 x f32>
  llvm.return %1 : !llvm.vec<? x 4 x f32>
}
// CHECK-LABEL: define i64 @get_vector_scale()
llvm.func @get_vector_scale() -> i64 {
// CHECK: call i64 @llvm.vscale.i64()