[mlir][ArmSVE][RFC] Add an ArmSVE dialect

This revision starts an Arm-specific ArmSVE dialect discussed in the discourse RFC thread:

https://llvm.discourse.group/t/rfc-vector-dialects-neon-and-sve/2284

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D92172
This commit is contained in:
Javier Setoain 2020-12-14 21:30:53 +00:00 committed by Mehdi Amini
parent 2e0e03c6a0
commit aece4e2793
28 changed files with 914 additions and 8 deletions

View File

@ -0,0 +1,23 @@
//===- ArmSVEToLLVM.h - Conversion Patterns from ArmSVE to LLVM -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_
#define MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_
namespace mlir {
class LLVMTypeConverter;
class OwningRewritePatternList;
/// Collect a set of patterns to convert from the ArmSVE dialect to LLVM.
void populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);
} // namespace mlir
#endif // MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_

View File

@ -396,8 +396,8 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
operations. The lowering pass provides several options to control
the kinds of optimizations that are allowed. It also provides options
that enable the use of one or more architectural-specific dialects
(AVX512, ArmNeon, SVE, etc.) in combination with the architectural-neutral
vector dialect lowering.
(AVX512, ArmNeon, ArmSVE, etc.) in combination with the
architectural-neutral vector dialect lowering.
}];
let constructor = "mlir::createConvertVectorToLLVMPass()";
@ -418,7 +418,11 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
Option<"enableArmNeon", "enable-arm-neon",
"bool", /*default=*/"false",
"Enables the use of ArmNeon dialect while lowering the vector "
"dialect.">
"dialect.">,
Option<"enableArmSVE", "enable-arm-sve",
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">
];
}

View File

@ -23,7 +23,7 @@ class OperationPass;
struct LowerVectorToLLVMOptions {
LowerVectorToLLVMOptions()
: reassociateFPReductions(false), enableIndexOptimizations(true),
enableArmNeon(false), enableAVX512(false) {}
enableArmNeon(false), enableArmSVE(false), enableAVX512(false) {}
LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
reassociateFPReductions = b;
@ -33,18 +33,23 @@ struct LowerVectorToLLVMOptions {
enableIndexOptimizations = b;
return *this;
}
LowerVectorToLLVMOptions &setEnableAVX512(bool b) {
enableAVX512 = b;
return *this;
}
LowerVectorToLLVMOptions &setEnableArmNeon(bool b) {
enableArmNeon = b;
return *this;
}
LowerVectorToLLVMOptions &setEnableArmSVE(bool b) {
enableArmSVE = b;
return *this;
}
LowerVectorToLLVMOptions &setEnableAVX512(bool b) {
enableAVX512 = b;
return *this;
}
bool reassociateFPReductions;
bool enableIndexOptimizations;
bool enableArmNeon;
bool enableArmSVE;
bool enableAVX512;
};

View File

@ -0,0 +1,276 @@
//===-- ArmSVE.td - ArmSVE dialect operation definitions ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the basic operations for the ArmSVE dialect.
//
//===----------------------------------------------------------------------===//
#ifndef ARMSVE_OPS
#define ARMSVE_OPS
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
// ArmSVE dialect definition
//===----------------------------------------------------------------------===//
def ArmSVE_Dialect : Dialect {
let name = "arm_sve";
let cppNamespace = "::mlir::arm_sve";
let summary = "Basic dialect to target Arm SVE architectures";
let description = [{
This dialect contains the definitions necessary to target Arm SVE scalable
vector operations, including a scalable vector type and intrinsics for
some Arm SVE instructions.
}];
}
//===----------------------------------------------------------------------===//
// ArmSVE type definitions
//===----------------------------------------------------------------------===//
def ArmSVE_ScalableVectorType : DialectType<ArmSVE_Dialect,
CPred<"$_self.isa<ScalableVectorType>()">,
"scalable vector type">,
BuildableType<"$_builder.getType<ScalableVectorType>()"> {
let typeDescription = [{
`arm_sve.vector` represents vectors that will be processed by a scalable
vector architecture.
}];
}
class ArmSVE_Type<string name> : TypeDef<ArmSVE_Dialect, name> { }
def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
let mnemonic = "vector";
let summary = "Scalable vector type";
let description = [{
A type representing scalable length SIMD vectors. Unlike fixed-length SIMD
vectors, whose size is constant and known at compile time, scalable
vectors' length is constant but determined by the specific hardware at
run time.
}];
let parameters = (ins
ArrayRefParameter<"int64_t", "Vector shape">:$shape,
"Type":$elementType
);
let printer = [{
$_printer << "vector<";
for (int64_t dim : getShape())
$_printer << dim << 'x';
$_printer << getElementType() << '>';
}];
let parser = [{
VectorType vector;
if ($_parser.parseType(vector))
return Type();
return get(ctxt, vector.getShape(), vector.getElementType());
}];
let extraClassDeclaration = [{
bool hasStaticShape() const {
return llvm::none_of(getShape(), ShapedType::isDynamic);
}
int64_t getNumElements() const {
assert(hasStaticShape() &&
"cannot get element count of dynamic shaped type");
ArrayRef<int64_t> shape = getShape();
int64_t num = 1;
for (auto dim : shape)
num *= dim;
return num;
}
}];
}
//===----------------------------------------------------------------------===//
// ArmSVE type traits
//===----------------------------------------------------------------------===//
def IsScalableVectorTypePred :
CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">;
class ScalableVectorOf<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, IsScalableVectorTypePred,
"$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()",
"scalable vector">;
class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
And<[IsScalableVectorTypePred,
Or<!foreach(allowedlength, allowedLengths, CPred<
[{$_self.cast<::mlir::arm_sve::ScalableVectorType>().getNumElements() == }]
# allowedlength>)>]>;
class ScalableVectorOfLength<list<int> allowedLengths> : Type<
IsScalableVectorOfLengthPred<allowedLengths>,
" of length " # StrJoinInt<allowedLengths, "/">.result>;
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : Type<
And<[ScalableVectorOf<allowedTypes>.predicate,
ScalableVectorOfLength<allowedLengths>.predicate]>,
ScalableVectorOf<allowedTypes>.description #
ScalableVectorOfLength<allowedLengths>.description>;
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
class ArmSVE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<ArmSVE_Dialect, mnemonic, traits> {}
def SdotOp : ArmSVE_Op<"sdot",
[NoSideEffect,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
let summary = "Vector-vector dot product and accumulate op";
let description = [{
SDOT: Signed integer addition of dot product.
This function maps to the SDOT instruction, and it takes signless integer
operands that the operation interprets as signed. It partitions the second
and third vector inputs into groups of four elements. They calculate the dot
product of each group (without loss of precision) and then add each result
to the overlapping element of the first vector input.
Source:
https://developer.arm.com/documentation/100987/0000
}];
// Supports either:
// (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
// (vector<8xi16>. vector<8xi16>) -> (vector<2xi64>)
let arguments = (ins
ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$acc,
ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src1,
ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src2
);
let results = (outs ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$dst);
let assemblyFormat =
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
def SmmlaOp : ArmSVE_Op<"smmla",
[NoSideEffect,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
let summary = "Matrix-matrix mutiply and accumulate op";
let description = [{
SMMLA: Signed integer matrix multiply-accumulate.
This function maps to the SMMLA instruction, and it takes signless integer
operands that the operation interprets as signed. It partitions the inputs
into 128-bit quadwords, with the first input containing a row-by-row 2×2
matrix of 32-bit integers, the second input containing a row-by-row 2×8
matrix of 8-bit integers, and the third input containing a column-by-column
8×2 matrix of 8-bit integers. For each quadword, they multiply the second
input matrix by the third input matrix using natural arithmetic and then add
the result to the first input using modular arithmetic.
Source:
https://developer.arm.com/documentation/100987/0000
}];
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
let arguments = (ins
ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
ScalableVectorOfLengthAndType<[16], [I8]>:$src2
);
let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
let assemblyFormat =
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
def UdotOp : ArmSVE_Op<"udot",
[NoSideEffect,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
let summary = "Vector-vector dot product and accumulate op";
let description = [{
UDOT: Unsigned integer addition of dot product.
This function maps to the UDOT instruction, and it takes signless integer
operands that the operation interprets as unsigned. It partitions the second
and third vector inputs into groups of four elements. They calculate the dot
product of each group (without loss of precision) and then add each result
to the overlapping element of the first vector input.
Source:
https://developer.arm.com/documentation/100987/0000
}];
// Supports either:
// (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
// (vector<8xi16>. vector<8xi16>) -> (vector<2xi64>)
let arguments = (ins
ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$acc,
ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src1,
ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src2
);
let results = (outs ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$dst);
let assemblyFormat =
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
def UmmlaOp : ArmSVE_Op<"ummla",
[NoSideEffect,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "dst"]>,
]> {
let summary = "Matrix-matrix mutiply and accumulate op";
let description = [{
UMMLA: Unsigned integer matrix multiply-accumulate.
This function maps to the UMMLA instruction, and it takes signless integer
operands that the operation interprets as unsigned. It partitions the inputs
into 128-bit quadwords, with the first input containing a row-by-row 2×2
matrix of 32-bit integers, the second input containing a row-by-row 2×8
matrix of 8-bit integers, and the third input containing a column-by-column
8×2 matrix of 8-bit integers. For each quadword, they multiply the second
input matrix by the third input matrix using natural arithmetic and then add
the result to the first input using modular arithmetic.
Source:
https://developer.arm.com/documentation/100987/0000
}];
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
let arguments = (ins
ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
ScalableVectorOfLengthAndType<[16], [I8]>:$src2
);
let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
let assemblyFormat =
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
def VectorScaleOp : ArmSVE_Op<"vector_scale",
[NoSideEffect]> {
let summary = "Load vector scale size";
let description = [{
The vector_scale op returns the scale of the scalable vectors, a positive
integer value that is constant at runtime but unknown at compile time.
The scale of the vector indicates the multiplicity of the vectors and
vector operations. I.e.: an !arm_sve.vector<4xi32> is equivalent to
vector_scale consecutive vector<4xi32>; and an operation on an
!arm_sve.vector<4xi32> is equivalent to performing that operation vector_scale
times, once on each <4xi32> segment of the scalable vector. The vector_scale
op can be used to calculate the step in vector-length agnostic (VLA) loops.
}];
let results = (outs Index:$res);
let assemblyFormat =
"attr-dict `:` type($res)";
}
#endif // ARMSVE_OPS

View File

@ -0,0 +1,29 @@
//===- ArmSVEDialect.h - MLIR Dialect for Arm SVE ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the Target dialect for ArmSVE in MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H
#define MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVETypes.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVE.h.inc"
#endif // MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H

View File

@ -0,0 +1 @@
add_mlir_dialect(ArmSVE arm_sve ArmSVE)

View File

@ -1,6 +1,7 @@
add_subdirectory(Affine)
add_subdirectory(Async)
add_subdirectory(ArmNeon)
add_subdirectory(ArmSVE)
add_subdirectory(AVX512)
add_subdirectory(GPU)
add_subdirectory(Linalg)

View File

@ -37,3 +37,9 @@ add_mlir_doc(LLVMArmNeon -gen-dialect-doc LLVMArmNeon Dialects/)
set(LLVM_TARGET_DEFINITIONS LLVMArmNeon.td)
mlir_tablegen(LLVMArmNeonConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRLLVMArmNeonConversionsIncGen)
add_mlir_dialect(LLVMArmSVE llvm_arm_sve LLVMArmSVE)
add_mlir_doc(LLVMArmSVE -gen-dialect-doc LLVMArmSve Dialects/)
set(LLVM_TARGET_DEFINITIONS LLVMArmSVE.td)
mlir_tablegen(LLVMArmSVEConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRLLVMArmSVEConversionsIncGen)

View File

@ -0,0 +1,70 @@
//===-- LLVMArmSVE.td - LLVMARMSVE dialect op definitions --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the basic operations for the LLVMArmSVE dialect.
//
//===----------------------------------------------------------------------===//
#ifndef LLVMIR_ARMSVE_OPS
#define LLVMIR_ARMSVE_OPS
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
//===----------------------------------------------------------------------===//
// LLVMArmSVE dialect definition
//===----------------------------------------------------------------------===//
def LLVMArmSVE_Dialect : Dialect {
let name = "llvm_arm_sve";
let cppNamespace = "::mlir::LLVM";
}
//----------------------------------------------------------------------------//
// MLIR LLVM Arm SVE intrinsics using the MLIR LLVM Dialect type system
//----------------------------------------------------------------------------//
class LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<string mnemonic,
list<OpTrait> traits =[]> :
LLVM_IntrOpBase</*Dialect dialect=*/LLVMArmSVE_Dialect,
/*string opName=*/mnemonic,
/*string enumName=*/mnemonic,
/*list<int> overloadedResults=*/[0],
/*list<int> overloadedOperands=*/[], // defined by result overload
/*list<OpTrait> traits=*/traits,
/*int numResults=*/1>;
class LLVMArmSVE_IntrBinaryOverloadedOp<string mnemonic,
list<OpTrait> traits = []> :
LLVM_IntrOpBase</*Dialect dialect=*/LLVMArmSVE_Dialect,
/*string opName=*/mnemonic,
/*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
/*list<int> overloadedResults=*/[0],
/*list<int> overloadedOperands=*/[], // defined by result overload
/*list<OpTrait> traits=*/traits,
/*int numResults=*/1>;
def LLVM_aarch64_arm_sve_ummla :
LLVMArmSVE_IntrBinaryOverloadedOp<"ummla">,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_aarch64_arm_sve_smmla :
LLVMArmSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_aarch64_arm_sve_sdot :
LLVMArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_aarch64_arm_sve_udot :
LLVMArmSVE_IntrBinaryOverloadedOp<"udot">,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_vector_scale :
LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;
#endif // ARMSVE_OPS

View File

@ -0,0 +1,24 @@
//===- LLVMSVEDialect.h - MLIR Dialect for LLVMSVE --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the Target dialect for LLVMArmSVE in MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_
#define MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMArmSVE.h.inc"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h.inc"
#endif // MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_

View File

@ -17,10 +17,12 @@
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@ -54,6 +56,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
LLVM::LLVMAVX512Dialect,
LLVM::LLVMDialect,
LLVM::LLVMArmNeonDialect,
LLVM::LLVMArmSVEDialect,
linalg::LinalgDialect,
scf::SCFDialect,
omp::OpenMPDialect,
@ -62,6 +65,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
quant::QuantizationDialect,
spirv::SPIRVDialect,
StandardOpsDialect,
arm_sve::ArmSVEDialect,
vector::VectorDialect,
NVVM::NVVMDialect,
ROCDL::ROCDLDialect,

View File

@ -24,6 +24,7 @@ void registerToNVVMIRTranslation();
void registerToROCDLIRTranslation();
void registerArmNeonToLLVMIRTranslation();
void registerAVX512ToLLVMIRTranslation();
void registerArmSVEToLLVMIRTranslation();
// This function should be called before creating any MLIRContext if one
// expects all the possible translations to be made available to the context
@ -38,6 +39,7 @@ inline void registerAllTranslations() {
registerToROCDLIRTranslation();
registerArmNeonToLLVMIRTranslation();
registerAVX512ToLLVMIRTranslation();
registerArmSVEToLLVMIRTranslation();
return true;
}();
(void)initOnce;

View File

@ -0,0 +1,75 @@
//===- ArmSVEToLLVM.cpp - Convert ArmSVE to the LLVM dialect --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::arm_sve;
using namespace mlir::vector;
using SdotOpLowering =
OneToOneConvertToLLVMPattern<SdotOp, LLVM::aarch64_arm_sve_sdot>;
using SmmlaOpLowering =
OneToOneConvertToLLVMPattern<SmmlaOp, LLVM::aarch64_arm_sve_smmla>;
using UdotOpLowering =
OneToOneConvertToLLVMPattern<UdotOp, LLVM::aarch64_arm_sve_udot>;
using UmmlaOpLowering =
OneToOneConvertToLLVMPattern<UmmlaOp, LLVM::aarch64_arm_sve_ummla>;
using VectorScaleOpLowering =
OneToOneConvertToLLVMPattern<VectorScaleOp, LLVM::vector_scale>;
// Extract an LLVM IR type from the LLVM IR dialect type.
static LLVM::LLVMType unwrap(Type type) {
if (!type)
return nullptr;
auto *mlirContext = type.getContext();
auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
if (!wrappedLLVMType)
emitError(UnknownLoc::get(mlirContext),
"conversion resulted in a non-LLVM type");
return wrappedLLVMType;
}
static Optional<Type>
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
LLVMTypeConverter &converter) {
auto elementType = unwrap(converter.convertType(svType.getElementType()));
if (!elementType)
return {};
auto sVectorType =
LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
return sVectorType;
}
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVEToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
converter.addConversion([&converter](ScalableVectorType svType) {
return convertScalableVectorTypeToLLVM(svType, converter);
});
// clang-format off
patterns.insert<SdotOpLowering,
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
VectorScaleOpLowering>(converter);
// clang-format on
}

View File

@ -0,0 +1,19 @@
add_mlir_conversion_library(MLIRArmSVEToLLVM
ArmSVEToLLVM.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSVEToLLVM
DEPENDS
MLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRArmSVE
MLIRLLVMArmSVE
MLIRLLVMIR
MLIRStandardToLLVM
MLIRTransforms
)

View File

@ -20,6 +20,7 @@ add_subdirectory(ShapeToStandard)
add_subdirectory(SPIRVToLLVM)
add_subdirectory(StandardToLLVM)
add_subdirectory(StandardToSPIRV)
add_subdirectory(ArmSVEToLLVM)
add_subdirectory(VectorToROCDL)
add_subdirectory(VectorToLLVM)
add_subdirectory(VectorToSCF)

View File

@ -26,6 +26,7 @@ class GPUModuleOp;
namespace LLVM {
class LLVMArmNeonDialect;
class LLVMArmSVEDialect;
class LLVMAVX512Dialect;
class LLVMDialect;
} // end namespace LLVM

View File

@ -19,6 +19,9 @@ add_mlir_conversion_library(MLIRVectorToLLVM
MLIRAVX512ToLLVM
MLIRLLVMArmNeon
MLIRLLVMAVX512
MLIRArmSVE
MLIRArmSVEToLLVM
MLIRLLVMArmSVE
MLIRLLVMIR
MLIRStandardToLLVM
MLIRTargetLLVMIRModuleTranslation

View File

@ -12,12 +12,15 @@
#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
#include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h"
#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@ -32,6 +35,7 @@ struct LowerVectorToLLVMPass
this->reassociateFPReductions = options.reassociateFPReductions;
this->enableIndexOptimizations = options.enableIndexOptimizations;
this->enableArmNeon = options.enableArmNeon;
this->enableArmSVE = options.enableArmSVE;
this->enableAVX512 = options.enableAVX512;
}
// Override explicitly to allow conditional dialect dependence.
@ -39,6 +43,8 @@ struct LowerVectorToLLVMPass
registry.insert<LLVM::LLVMDialect>();
if (enableArmNeon)
registry.insert<LLVM::LLVMArmNeonDialect>();
if (enableArmSVE)
registry.insert<LLVM::LLVMArmSVEDialect>();
if (enableAVX512)
registry.insert<LLVM::LLVMAVX512Dialect>();
}
@ -73,6 +79,11 @@ void LowerVectorToLLVMPass::runOnOperation() {
target.addIllegalDialect<arm_neon::ArmNeonDialect>();
populateArmNeonToLLVMConversionPatterns(converter, patterns);
}
if (enableArmSVE) {
target.addLegalDialect<LLVM::LLVMArmSVEDialect>();
target.addIllegalDialect<arm_sve::ArmSVEDialect>();
populateArmSVEToLLVMConversionPatterns(converter, patterns);
}
if (enableAVX512) {
target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
target.addIllegalDialect<avx512::AVX512Dialect>();

View File

@ -0,0 +1,13 @@
add_mlir_dialect_library(MLIRArmSVE
IR/ArmSVEDialect.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE
DEPENDS
MLIRArmSVEIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRSideEffectInterfaces
)

View File

@ -0,0 +1,57 @@
//===- ArmSVEDialect.cpp - MLIR ArmSVE dialect implementation -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the ArmSVE dialect and its operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
void arm_sve::ArmSVEDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
>();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
//===----------------------------------------------------------------------===//
// ScalableVectorType
//===----------------------------------------------------------------------===//
Type arm_sve::ArmSVEDialect::parseType(DialectAsmParser &parser) const {
llvm::SMLoc typeLoc = parser.getCurrentLocation();
auto genType = generatedTypeParser(getContext(), parser, "vector");
if (genType != Type())
return genType;
parser.emitError(typeLoc, "unknown type in ArmSVE dialect");
return Type();
}
void arm_sve::ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
if (failed(generatedTypePrinter(type, os)))
llvm_unreachable("unexpected 'arm_sve' type kind");
}

View File

@ -1,5 +1,6 @@
add_subdirectory(Affine)
add_subdirectory(ArmNeon)
add_subdirectory(ArmSVE)
add_subdirectory(Async)
add_subdirectory(AVX512)
add_subdirectory(GPU)

View File

@ -70,6 +70,27 @@ add_mlir_dialect_library(MLIRLLVMArmNeon
MLIRSideEffectInterfaces
)
add_mlir_dialect_library(MLIRLLVMArmSVE
IR/LLVMArmSVEDialect.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
DEPENDS
MLIRLLVMArmSVEIncGen
MLIRLLVMArmSVEConversionsIncGen
intrinsics_gen
LINK_COMPONENTS
AsmParser
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMIR
MLIRSideEffectInterfaces
)
add_mlir_dialect_library(MLIRNVVMIR
IR/NVVMDialect.cpp

View File

@ -0,0 +1,31 @@
//===- LLVMArmSVEDialect.cpp - MLIR LLVMSVE ops implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the LLVMArmSVE dialect and its operations.
//
//===----------------------------------------------------------------------===//
#include "llvm/IR/IntrinsicsAArch64.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
void LLVM::LLVMArmSVEDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc"
>();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc"

View File

@ -74,6 +74,25 @@ add_mlir_translation_library(MLIRTargetArmNeon
MLIRTargetLLVMIRModuleTranslation
)
add_mlir_translation_library(MLIRTargetArmSVE
LLVMIR/LLVMArmSVEIntr.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
DEPENDS
MLIRLLVMArmSVEConversionsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMArmSVE
MLIRLLVMIR
MLIRTargetLLVMIRModuleTranslation
)
add_mlir_translation_library(MLIRTargetNVVMIR
LLVMIR/ConvertToNVVMIR.cpp

View File

@ -0,0 +1,63 @@
//===- LLVMArmSVEIntr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR LLVM and ArmSVE dialects
// and LLVM IR with Arm SVE intrinsics.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Translation.h"
#include "llvm/IR/IntrinsicsAArch64.h"
using namespace mlir;
namespace {
class LLVMArmSVEModuleTranslation : public LLVM::ModuleTranslation {
friend LLVM::ModuleTranslation;
public:
using LLVM::ModuleTranslation::ModuleTranslation;
protected:
LogicalResult convertOperation(Operation &opInst,
llvm::IRBuilder<> &builder) override {
#include "mlir/Dialect/LLVMIR/LLVMArmSVEConversions.inc"
return LLVM::ModuleTranslation::convertOperation(opInst, builder);
}
};
} // end namespace
static std::unique_ptr<llvm::Module>
translateLLVMArmSVEModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext,
StringRef name) {
return LLVM::ModuleTranslation::translateModule<LLVMArmSVEModuleTranslation>(
m, llvmContext, name);
}
namespace mlir {
void registerArmSVEToLLVMIRTranslation() {
TranslateFromMLIRRegistration reg(
"arm-sve-mlir-to-llvmir",
[](ModuleOp module, raw_ostream &output) {
llvm::LLVMContext llvmContext;
auto llvmModule = translateLLVMArmSVEModuleToLLVMIR(
module, llvmContext, "LLVMDialectModule");
if (!llvmModule)
return failure();
llvmModule->print(output, nullptr);
return success();
},
[](DialectRegistry &registry) {
registry.insert<LLVM::LLVMArmSVEDialect, LLVM::LLVMDialect>();
});
}
} // namespace mlir

View File

@ -0,0 +1,47 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s
func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
// CHECK: llvm_arm_sve.sdot
%0 = arm_sve.sdot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
}
func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
// CHECK: llvm_arm_sve.smmla
%0 = arm_sve.smmla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
}
func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
// CHECK: llvm_arm_sve.udot
%0 = arm_sve.udot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
}
func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
// CHECK: llvm_arm_sve.ummla
%0 = arm_sve.ummla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
}
func @get_vector_scale() -> index {
// CHECK: llvm_arm_sve.vscale
%0 = arm_sve.vector_scale : index
return %0 : index
}

View File

@ -0,0 +1,43 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
%0 = arm_sve.sdot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
}
func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
%0 = arm_sve.smmla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
}
func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
%0 = arm_sve.udot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
}
func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
// CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
%0 = arm_sve.ummla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
}
func @get_vector_scale() -> index {
// CHECK: arm_sve.vector_scale : index
%0 = arm_sve.vector_scale : index
return %0 : index
}

View File

@ -0,0 +1,56 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate --arm-sve-mlir-to-llvmir | FileCheck %s
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_sdot
llvm.func @arm_sve_sdot(%arg0: !llvm.vec<? x 16 x i8>,
%arg1: !llvm.vec<? x 16 x i8>,
%arg2: !llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4
%0 = "llvm_arm_sve.sdot"(%arg2, %arg0, %arg1) :
(!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
-> !llvm.vec<? x 4 x i32>
llvm.return %0 : !llvm.vec<? x 4 x i32>
}
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_smmla
llvm.func @arm_sve_smmla(%arg0: !llvm.vec<? x 16 x i8>,
%arg1: !llvm.vec<? x 16 x i8>,
%arg2: !llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4
%0 = "llvm_arm_sve.smmla"(%arg2, %arg0, %arg1) :
(!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
-> !llvm.vec<? x 4 x i32>
llvm.return %0 : !llvm.vec<? x 4 x i32>
}
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_udot
llvm.func @arm_sve_udot(%arg0: !llvm.vec<? x 16 x i8>,
%arg1: !llvm.vec<? x 16 x i8>,
%arg2: !llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4
%0 = "llvm_arm_sve.udot"(%arg2, %arg0, %arg1) :
(!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
-> !llvm.vec<? x 4 x i32>
llvm.return %0 : !llvm.vec<? x 4 x i32>
}
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_ummla
llvm.func @arm_sve_ummla(%arg0: !llvm.vec<? x 16 x i8>,
%arg1: !llvm.vec<? x 16 x i8>,
%arg2: !llvm.vec<? x 4 x i32>)
-> !llvm.vec<? x 4 x i32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4
%0 = "llvm_arm_sve.ummla"(%arg2, %arg0, %arg1) :
(!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
-> !llvm.vec<? x 4 x i32>
llvm.return %0 : !llvm.vec<? x 4 x i32>
}
// CHECK-LABEL: define i64 @get_vector_scale()
llvm.func @get_vector_scale() -> !llvm.i64 {
// CHECK: call i64 @llvm.vscale.i64()
%0 = "llvm_arm_sve.vscale"() : () -> !llvm.i64
llvm.return %0 : !llvm.i64
}