From 95861216ac6558dc0dbcf638902feb9072c84661 Mon Sep 17 00:00:00 2001
From: Javier Setoain
Date: Mon, 19 Apr 2021 15:37:29 +0100
Subject: [PATCH] [mlir][ArmSVE] Add masked arithmetic operations

These operations map to SVE-specific intrinsics that accept a predicate
operand, in order to support control flow in vector code.

Differential Revision: https://reviews.llvm.org/D100982
---
 mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td    | 134 +++++++++++++++++-
 mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp  |  15 ++
 .../Transforms/LegalizeForLLVMExport.cpp      |  71 +++++++++-
 .../Dialect/ArmSVE/legalize-for-llvm.mlir     |  49 ++++++-
 mlir/test/Dialect/ArmSVE/roundtrip.mlir       |  47 ++++++
 mlir/test/Target/LLVMIR/arm-sve.mlir          |  67 +++++++++
 6 files changed, 374 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
index 33c60ba7c8a5..e34177bb5094 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td
@@ -95,6 +95,13 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Additional LLVM type constraints
+//===----------------------------------------------------------------------===//
+def LLVMScalableVectorType :
+  Type<CPred<"$_self.isa<::mlir::LLVM::LLVMScalableVectorType>()">,
+       "LLVM dialect scalable vector type">;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -158,6 +165,52 @@ class ScalableIOp<string mnemonic, string op_description,
+class ScalableMaskedFOp<string mnemonic, string op_description,
+                        list<OpTrait> traits = []> :
+  ArmSVE_Op<mnemonic, !listconcat(traits,
+                       [AllTypesMatch<["src1", "src2", "res"]>,
+                        TypesMatchWith<
+                          "mask has i1 element type and same shape as operands",
+                          "src1", "mask", "getI1SameShape($_self)">])> {
+  let summary = "masked " # op_description # " for scalable vectors of floats";
+  let description = [{
+    The `arm_sve.}] # mnemonic # [{` operation takes one scalable vector mask
+    and two scalable vector operands, and performs floating-point }] #
+    op_description # [{ on the active lanes. Inactive lanes will keep the
+    value of the first operand.}];
+  let arguments = (ins
+          ScalableVectorOf<[I1]>:$mask,
+          ScalableVectorOf<[AnyFloat]>:$src1,
+          ScalableVectorOf<[AnyFloat]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[AnyFloat]>:$res);
+  let assemblyFormat =
+    "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
+}
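As a usage sketch (not an excerpt from the patch; the concrete types follow the tests added further down), an op built from this class takes the mask first, then the two operands, with the mask and result types spelled out after the colon:

    %res = arm_sve.masked.addf %mask, %a, %b : !arm_sve.vector<4xi1>,
                                               !arm_sve.vector<4xf32>

Lanes where %mask is set receive the sum of %a and %b; the remaining lanes carry over the corresponding lane of %a.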
+
+class ScalableMaskedIOp<string mnemonic, string op_description,
+                        list<OpTrait> traits = []> :
+  ArmSVE_Op<mnemonic, !listconcat(traits,
+                       [AllTypesMatch<["src1", "src2", "res"]>,
+                        TypesMatchWith<
+                          "mask has i1 element type and same shape as operands",
+                          "src1", "mask", "getI1SameShape($_self)">])> {
+  let summary = "masked " # op_description # " for scalable vectors of integers";
+  let description = [{
+    The `arm_sve.}] # mnemonic # [{` operation takes one scalable vector mask
+    and two scalable vector operands, and performs integer }] #
+    op_description # [{ on the active lanes. Inactive lanes will keep the
+    value of the first operand.}];
+  let arguments = (ins
+          ScalableVectorOf<[I1]>:$mask,
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
+          ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$res);
+  let assemblyFormat =
+    "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
+}
+
 def SdotOp : ArmSVE_Op<"sdot",
                [NoSideEffect,
                AllTypesMatch<["src1", "src2"]>,
@@ -321,21 +374,94 @@ def ScalableUDivIOp : ScalableIOp<"divi_unsigned",
                                   "unsigned division">;
 
 def ScalableDivFOp : ScalableFOp<"divf", "division">;
 
+def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
+                                             [Commutative]>;
+
+def ScalableMaskedAddFOp : ScalableMaskedFOp<"masked.addf", "addition",
+                                             [Commutative]>;
+
+def ScalableMaskedSubIOp : ScalableMaskedIOp<"masked.subi", "subtraction">;
+
+def ScalableMaskedSubFOp : ScalableMaskedFOp<"masked.subf", "subtraction">;
+
+def ScalableMaskedMulIOp : ScalableMaskedIOp<"masked.muli", "multiplication",
+                                             [Commutative]>;
+
+def ScalableMaskedMulFOp : ScalableMaskedFOp<"masked.mulf", "multiplication",
+                                             [Commutative]>;
+
+def ScalableMaskedSDivIOp : ScalableMaskedIOp<"masked.divi_signed",
+                                              "signed division">;
+
+def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
+                                              "unsigned division">;
+
+def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
+
 def UmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"ummla">,
-  Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
 
 def SmmlaIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
-  Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
 
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
-  Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
 
 def UdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"udot">,
-  Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
+
+def ScalableMaskedAddIIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"add">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
+
+def ScalableMaskedAddFIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"fadd">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
+
+def ScalableMaskedMulIIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"mul">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
+
+def ScalableMaskedMulFIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"fmul">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
+
+def ScalableMaskedSubIIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"sub">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
+
+def ScalableMaskedSubFIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"fsub">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
+
+def ScalableMaskedSDivIIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
+
+def ScalableMaskedUDivIIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"udiv">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
+
+def ScalableMaskedDivFIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
+  Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                 LLVMScalableVectorType)>;
 
 def VectorScaleIntrOp:
  ArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;
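To show how these definitions are meant to compose (a sketch using the ops declared above; the value names are illustrative), several masked operations can share one predicate so that a whole expression is evaluated only on the active lanes, which is the vector counterpart of the control flow mentioned in the commit message:

    %quot = arm_sve.masked.divi_signed %mask, %a, %b : !arm_sve.vector<4xi1>,
                                                       !arm_sve.vector<4xi32>
    %acc = arm_sve.masked.addi %mask, %quot, %c : !arm_sve.vector<4xi1>,
                                                  !arm_sve.vector<4xi32>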
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
index 6091626c011c..b86ba14303f8 100644
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -21,6 +21,8 @@
 
 using namespace mlir;
 
+static Type getI1SameShape(Type type);
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
 
@@ -59,3 +61,16 @@ void arm_sve::ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
   if (failed(generatedTypePrinter(type, os)))
     llvm_unreachable("unexpected 'arm_sve' type kind");
 }
+
+//===----------------------------------------------------------------------===//
+// ScalableVector versions of general helpers for comparison ops
+//===----------------------------------------------------------------------===//
+
+// Return a scalable vector with the same shape as the input, but with an i1
+// element type.
+static Type getI1SameShape(Type type) {
+  auto i1Type = IntegerType::get(type.getContext(), 1);
+  if (auto sVectorType = type.dyn_cast<arm_sve::ScalableVectorType>())
+    return arm_sve::ScalableVectorType::get(type.getContext(),
+                                            sVectorType.getShape(), i1Type);
+  return nullptr;
+}
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index b258f2ad9315..845f407fba3f 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -83,6 +83,33 @@ using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
 using VectorScaleOpLowering =
     OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
+using ScalableMaskedAddIOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
+                                 ScalableMaskedAddIIntrOp>;
+using ScalableMaskedAddFOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
+                                 ScalableMaskedAddFIntrOp>;
+using ScalableMaskedSubIOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
+                                 ScalableMaskedSubIIntrOp>;
+using ScalableMaskedSubFOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
+                                 ScalableMaskedSubFIntrOp>;
+using ScalableMaskedMulIOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
+                                 ScalableMaskedMulIIntrOp>;
+using ScalableMaskedMulFOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
+                                 ScalableMaskedMulFIntrOp>;
+using ScalableMaskedSDivIOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
+                                 ScalableMaskedSDivIIntrOp>;
+using ScalableMaskedUDivIOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
+                                 ScalableMaskedUDivIIntrOp>;
+using ScalableMaskedDivFOpLowering =
+    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
+                                 ScalableMaskedDivFIntrOp>;
 
 static void populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
@@ -136,16 +163,52 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
-               VectorScaleOpLowering>(converter);
+               VectorScaleOpLowering,
+               ScalableMaskedAddIOpLowering,
+               ScalableMaskedAddFOpLowering,
+               ScalableMaskedSubIOpLowering,
+               ScalableMaskedSubFOpLowering,
+               ScalableMaskedMulIOpLowering,
+               ScalableMaskedMulFOpLowering,
+               ScalableMaskedSDivIOpLowering,
+               ScalableMaskedUDivIOpLowering,
+               ScalableMaskedDivFOpLowering>(converter);
   // clang-format on
   populateBasicSVEArithmeticExportPatterns(converter, patterns);
 }
 
 void mlir::configureArmSVELegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addLegalOp<SdotIntrOp, SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp,
-                    VectorScaleIntrOp>();
-  target.addIllegalOp<SdotOp, SmmlaOp, UdotOp, UmmlaOp, VectorScaleOp>();
+  // clang-format off
+  target.addLegalOp<SdotIntrOp,
+                    SmmlaIntrOp,
+                    UdotIntrOp,
+                    UmmlaIntrOp,
+                    VectorScaleIntrOp,
+                    ScalableMaskedAddIIntrOp,
+                    ScalableMaskedAddFIntrOp,
+                    ScalableMaskedSubIIntrOp,
+                    ScalableMaskedSubFIntrOp,
+                    ScalableMaskedMulIIntrOp,
+                    ScalableMaskedMulFIntrOp,
+                    ScalableMaskedSDivIIntrOp,
+                    ScalableMaskedUDivIIntrOp,
+                    ScalableMaskedDivFIntrOp>();
+  target.addIllegalOp<SdotOp,
+                      SmmlaOp,
+                      UdotOp,
+                      UmmlaOp,
+                      VectorScaleOp,
+                      ScalableMaskedAddIOp,
+                      ScalableMaskedAddFOp,
+                      ScalableMaskedSubIOp,
+                      ScalableMaskedSubFOp,
+                      ScalableMaskedMulIOp,
+                      ScalableMaskedMulFOp,
+                      ScalableMaskedSDivIOp,
+                      ScalableMaskedUDivIOp,
+                      ScalableMaskedDivFOp>();
+  // clang-format on
   auto hasScalableVectorType = [](TypeRange types) {
     for (Type type : types)
      if (type.isa<arm_sve::ScalableVectorType>())
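Each of the patterns above is a one-to-one rewrite: the masked op is replaced by its intrinsic counterpart and the operands, mask first, are passed through unchanged. Schematically (a sketch, not test output; after conversion the operands carry LLVM dialect scalable vector types, elided here as <...>):

    // Before legalization:
    %0 = arm_sve.masked.addi %mask, %a, %b : !arm_sve.vector<4xi1>,
                                             !arm_sve.vector<4xi32>
    // After legalization:
    %0 = "arm_sve.intr.add"(%mask, %a, %b)
           : (!llvm.vec<...>, !llvm.vec<...>, !llvm.vec<...>) -> !llvm.vec<...>

The tests that follow check for exactly this shape.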
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index f81196f4928f..2b2eda0bf32e 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -55,7 +55,7 @@ func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>,
   %3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32>
   // CHECK: llvm.udiv {{.*}}: !llvm.vec
   %4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32>
-  return %3 : !arm_sve.vector<4xi32>
+  return %4 : !arm_sve.vector<4xi32>
 }
 
 func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
@@ -74,6 +74,53 @@ func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
   return %3 : !arm_sve.vector<4xf32>
 }
 
+func @arm_sve_arithi_masked(%a: !arm_sve.vector<4xi32>,
+                            %b: !arm_sve.vector<4xi32>,
+                            %c: !arm_sve.vector<4xi32>,
+                            %d: !arm_sve.vector<4xi32>,
+                            %e: !arm_sve.vector<4xi32>,
+                            %mask: !arm_sve.vector<4xi1>
+                            ) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.intr.add{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec
+  %0 = arm_sve.masked.addi %mask, %a, %b : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.intr.sub{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec
+  %1 = arm_sve.masked.subi %mask, %0, %c : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.intr.mul{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec
+  %2 = arm_sve.masked.muli %mask, %1, %d : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.intr.sdiv{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec
+  %3 = arm_sve.masked.divi_signed %mask, %2, %e : !arm_sve.vector<4xi1>,
+                                                  !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.intr.udiv{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec
+  %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : !arm_sve.vector<4xi1>,
+                                                    !arm_sve.vector<4xi32>
+  return %4 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_arithf_masked(%a: !arm_sve.vector<4xf32>,
+                            %b: !arm_sve.vector<4xf32>,
+                            %c: !arm_sve.vector<4xf32>,
+                            %d: !arm_sve.vector<4xf32>,
+                            %e: !arm_sve.vector<4xf32>,
+                            %mask: !arm_sve.vector<4xi1>
+                            ) -> !arm_sve.vector<4xf32> {
+  // CHECK: arm_sve.intr.fadd{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec
+  %0 = arm_sve.masked.addf %mask, %a, %b : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.intr.fsub{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec
+  %1 = arm_sve.masked.subf %mask, %0, %c : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.intr.fmul{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec
+  %2 = arm_sve.masked.mulf %mask, %1, %d : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.intr.fdiv{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec
+  %3 = arm_sve.masked.divf %mask, %2, %e : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  return %3 : !arm_sve.vector<4xf32>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vscale
   %0 = arm_sve.vector_scale : index
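The roundtrip test below exercises the custom assembly form declared by `assemblyFormat`. For reference (not part of the patch), the same op can also be written in MLIR's generic form, which spells out every operand and result type:

    %0 = arm_sve.masked.muli %mask, %a, %b : !arm_sve.vector<4xi1>,
                                             !arm_sve.vector<4xi32>
    // The generic form of the op above:
    %0 = "arm_sve.masked.muli"(%mask, %a, %b)
           : (!arm_sve.vector<4xi1>, !arm_sve.vector<4xi32>, !arm_sve.vector<4xi32>)
           -> !arm_sve.vector<4xi32>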
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index 44cc2fa12217..4666d16f33f2 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -56,6 +56,53 @@ func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>,
   return %1 : !arm_sve.vector<4xf32>
 }
 
+func @arm_sve_masked_arithi(%a: !arm_sve.vector<4xi32>,
+                            %b: !arm_sve.vector<4xi32>,
+                            %c: !arm_sve.vector<4xi32>,
+                            %d: !arm_sve.vector<4xi32>,
+                            %e: !arm_sve.vector<4xi32>,
+                            %mask: !arm_sve.vector<4xi1>)
+                            -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.masked.muli {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector<4xi32>
+  %0 = arm_sve.masked.muli %mask, %a, %b : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.masked.addi {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector<4xi32>
+  %1 = arm_sve.masked.addi %mask, %0, %c : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.masked.subi {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector<4xi32>
+  %2 = arm_sve.masked.subi %mask, %1, %d : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.masked.divi_signed
+  %3 = arm_sve.masked.divi_signed %mask, %2, %e : !arm_sve.vector<4xi1>,
+                                                  !arm_sve.vector<4xi32>
+  // CHECK: arm_sve.masked.divi_unsigned
+  %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : !arm_sve.vector<4xi1>,
+                                                    !arm_sve.vector<4xi32>
+  return %4 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_masked_arithf(%a: !arm_sve.vector<4xf32>,
+                            %b: !arm_sve.vector<4xf32>,
+                            %c: !arm_sve.vector<4xf32>,
+                            %d: !arm_sve.vector<4xf32>,
+                            %e: !arm_sve.vector<4xf32>,
+                            %mask: !arm_sve.vector<4xi1>)
+                            -> !arm_sve.vector<4xf32> {
+  // CHECK: arm_sve.masked.mulf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector<4xf32>
+  %0 = arm_sve.masked.mulf %mask, %a, %b : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.masked.addf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector<4xf32>
+  %1 = arm_sve.masked.addf %mask, %0, %c : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.masked.subf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector<4xf32>
+  %2 = arm_sve.masked.subf %mask, %1, %d : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  // CHECK: arm_sve.masked.divf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector<4xf32>
+  %3 = arm_sve.masked.divf %mask, %2, %e : !arm_sve.vector<4xi1>,
+                                           !arm_sve.vector<4xf32>
+  return %3 : !arm_sve.vector<4xf32>
+}
+
 func @get_vector_scale() -> index {
   // CHECK: arm_sve.vector_scale : index
   %0 = arm_sve.vector_scale : index
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 71d4b0aee9b4..cf367904f899 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -72,6 +72,73 @@ llvm.func @arm_sve_arithf(%arg0: !llvm.vec,
   llvm.return %1 : !llvm.vec
 }
 
+// CHECK-LABEL: define @arm_sve_arithi_masked
+llvm.func @arm_sve_arithi_masked(%arg0: !llvm.vec,
+                                 %arg1: !llvm.vec,
+                                 %arg2: !llvm.vec,
+                                 %arg3: !llvm.vec,
+                                 %arg4: !llvm.vec,
+                                 %arg5: !llvm.vec)
+                                 -> !llvm.vec {
+  // CHECK: call @llvm.aarch64.sve.add.nxv4i32
+  %0 = "arm_sve.intr.add"(%arg5, %arg0, %arg1) : (!llvm.vec,
+                                                  !llvm.vec,
+                                                  !llvm.vec)
+                                                 -> !llvm.vec
+  // CHECK: call @llvm.aarch64.sve.sub.nxv4i32
+  %1 = "arm_sve.intr.sub"(%arg5, %0, %arg1) : (!llvm.vec,
+                                               !llvm.vec,
+                                               !llvm.vec)
+                                              -> !llvm.vec
+  // CHECK: call @llvm.aarch64.sve.mul.nxv4i32
+  %2 = "arm_sve.intr.mul"(%arg5, %1, %arg3) : (!llvm.vec,
+                                               !llvm.vec,
+                                               !llvm.vec)
+                                              -> !llvm.vec
+  // CHECK: call @llvm.aarch64.sve.sdiv.nxv4i32
+  %3 = "arm_sve.intr.sdiv"(%arg5, %2, %arg4) : (!llvm.vec,
+                                                !llvm.vec,
+                                                !llvm.vec)
+                                               -> !llvm.vec
+  // CHECK: call @llvm.aarch64.sve.udiv.nxv4i32
+  %4 = "arm_sve.intr.udiv"(%arg5, %3, %arg4) : (!llvm.vec,
+                                                !llvm.vec,
+                                                !llvm.vec)
+                                               -> !llvm.vec
+  llvm.return %4 : !llvm.vec
+}
+
+// CHECK-LABEL: define @arm_sve_arithf_masked
+llvm.func @arm_sve_arithf_masked(%arg0: !llvm.vec,
+                                 %arg1: !llvm.vec,
+                                 %arg2: !llvm.vec,
+                                 %arg3: !llvm.vec,
+                                 %arg4: !llvm.vec,
+                                 %arg5: !llvm.vec)
+                                 -> !llvm.vec {
+  // CHECK: call @llvm.aarch64.sve.fadd.nxv4f32
+  %0 = "arm_sve.intr.fadd"(%arg5, %arg0, %arg1) : (!llvm.vec,
+                                                   !llvm.vec,
+                                                   !llvm.vec)
+                                                  -> !llvm.vec
+  // CHECK: call @llvm.aarch64.sve.fsub.nxv4f32
+  %1 = "arm_sve.intr.fsub"(%arg5, %0, %arg2) : (!llvm.vec,
+                                                !llvm.vec,
+                                                !llvm.vec)
+                                               -> !llvm.vec
+  // CHECK: call @llvm.aarch64.sve.fmul.nxv4f32
+  %2 = "arm_sve.intr.fmul"(%arg5, %1, %arg3) : (!llvm.vec,
+                                                !llvm.vec,
+                                                !llvm.vec)
+                                               -> !llvm.vec
+  // CHECK: call @llvm.aarch64.sve.fdiv.nxv4f32
+  %3 = "arm_sve.intr.fdiv"(%arg5, %2, %arg4) : (!llvm.vec,
+                                                !llvm.vec,
+                                                !llvm.vec)
+                                               -> !llvm.vec
+  llvm.return %3 : !llvm.vec
+}
+
 // CHECK-LABEL: define i64 @get_vector_scale()
 llvm.func @get_vector_scale() -> i64 {
   // CHECK: call i64 @llvm.vscale.i64()