forked from OSchip/llvm-project
[mlir][ArmSVE][RFC] Add an ArmSVE dialect
This revision starts an Arm-specific ArmSVE dialect discussed in the discourse RFC thread: https://llvm.discourse.group/t/rfc-vector-dialects-neon-and-sve/2284 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D92172
This commit is contained in:
parent
2e0e03c6a0
commit
aece4e2793
|
@ -0,0 +1,23 @@
|
|||
//===- ArmSVEToLLVM.h - Conversion Patterns from ArmSVE to LLVM -----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_
|
||||
#define MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class LLVMTypeConverter;
|
||||
class OwningRewritePatternList;
|
||||
|
||||
/// Collect a set of patterns to convert from the ArmSVE dialect to LLVM.
|
||||
void populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_
|
|
@ -396,8 +396,8 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
|
|||
operations. The lowering pass provides several options to control
|
||||
the kinds of optimizations that are allowed. It also provides options
|
||||
that enable the use of one or more architectural-specific dialects
|
||||
(AVX512, ArmNeon, SVE, etc.) in combination with the architectural-neutral
|
||||
vector dialect lowering.
|
||||
(AVX512, ArmNeon, ArmSVE, etc.) in combination with the
|
||||
architectural-neutral vector dialect lowering.
|
||||
|
||||
}];
|
||||
let constructor = "mlir::createConvertVectorToLLVMPass()";
|
||||
|
@ -418,7 +418,11 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
|
|||
Option<"enableArmNeon", "enable-arm-neon",
|
||||
"bool", /*default=*/"false",
|
||||
"Enables the use of ArmNeon dialect while lowering the vector "
|
||||
"dialect.">
|
||||
"dialect.">,
|
||||
Option<"enableArmSVE", "enable-arm-sve",
|
||||
"bool", /*default=*/"false",
|
||||
"Enables the use of ArmSVE dialect while lowering the vector "
|
||||
"dialect.">
|
||||
];
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ class OperationPass;
|
|||
struct LowerVectorToLLVMOptions {
|
||||
LowerVectorToLLVMOptions()
|
||||
: reassociateFPReductions(false), enableIndexOptimizations(true),
|
||||
enableArmNeon(false), enableAVX512(false) {}
|
||||
enableArmNeon(false), enableArmSVE(false), enableAVX512(false) {}
|
||||
|
||||
LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) {
|
||||
reassociateFPReductions = b;
|
||||
|
@ -33,18 +33,23 @@ struct LowerVectorToLLVMOptions {
|
|||
enableIndexOptimizations = b;
|
||||
return *this;
|
||||
}
|
||||
LowerVectorToLLVMOptions &setEnableAVX512(bool b) {
|
||||
enableAVX512 = b;
|
||||
return *this;
|
||||
}
|
||||
LowerVectorToLLVMOptions &setEnableArmNeon(bool b) {
|
||||
enableArmNeon = b;
|
||||
return *this;
|
||||
}
|
||||
LowerVectorToLLVMOptions &setEnableArmSVE(bool b) {
|
||||
enableArmSVE = b;
|
||||
return *this;
|
||||
}
|
||||
LowerVectorToLLVMOptions &setEnableAVX512(bool b) {
|
||||
enableAVX512 = b;
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool reassociateFPReductions;
|
||||
bool enableIndexOptimizations;
|
||||
bool enableArmNeon;
|
||||
bool enableArmSVE;
|
||||
bool enableAVX512;
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,276 @@
|
|||
//===-- ArmSVE.td - ArmSVE dialect operation definitions ---*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the basic operations for the ArmSVE dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ARMSVE_OPS
|
||||
#define ARMSVE_OPS
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ArmSVE dialect definition
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ArmSVE_Dialect : Dialect {
|
||||
let name = "arm_sve";
|
||||
let cppNamespace = "::mlir::arm_sve";
|
||||
let summary = "Basic dialect to target Arm SVE architectures";
|
||||
let description = [{
|
||||
This dialect contains the definitions necessary to target Arm SVE scalable
|
||||
vector operations, including a scalable vector type and intrinsics for
|
||||
some Arm SVE instructions.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ArmSVE type definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ArmSVE_ScalableVectorType : DialectType<ArmSVE_Dialect,
|
||||
CPred<"$_self.isa<ScalableVectorType>()">,
|
||||
"scalable vector type">,
|
||||
BuildableType<"$_builder.getType<ScalableVectorType>()"> {
|
||||
let typeDescription = [{
|
||||
`arm_sve.vector` represents vectors that will be processed by a scalable
|
||||
vector architecture.
|
||||
}];
|
||||
}
|
||||
|
||||
class ArmSVE_Type<string name> : TypeDef<ArmSVE_Dialect, name> { }
|
||||
|
||||
def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
|
||||
let mnemonic = "vector";
|
||||
|
||||
let summary = "Scalable vector type";
|
||||
|
||||
let description = [{
|
||||
A type representing scalable length SIMD vectors. Unlike fixed-length SIMD
|
||||
vectors, whose size is constant and known at compile time, scalable
|
||||
vectors' length is constant but determined by the specific hardware at
|
||||
run time.
|
||||
}];
|
||||
|
||||
let parameters = (ins
|
||||
ArrayRefParameter<"int64_t", "Vector shape">:$shape,
|
||||
"Type":$elementType
|
||||
);
|
||||
|
||||
let printer = [{
|
||||
$_printer << "vector<";
|
||||
for (int64_t dim : getShape())
|
||||
$_printer << dim << 'x';
|
||||
$_printer << getElementType() << '>';
|
||||
}];
|
||||
|
||||
let parser = [{
|
||||
VectorType vector;
|
||||
if ($_parser.parseType(vector))
|
||||
return Type();
|
||||
return get(ctxt, vector.getShape(), vector.getElementType());
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
bool hasStaticShape() const {
|
||||
return llvm::none_of(getShape(), ShapedType::isDynamic);
|
||||
}
|
||||
int64_t getNumElements() const {
|
||||
assert(hasStaticShape() &&
|
||||
"cannot get element count of dynamic shaped type");
|
||||
ArrayRef<int64_t> shape = getShape();
|
||||
int64_t num = 1;
|
||||
for (auto dim : shape)
|
||||
num *= dim;
|
||||
return num;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ArmSVE type traits
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def IsScalableVectorTypePred :
|
||||
CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">;
|
||||
|
||||
class ScalableVectorOf<list<Type> allowedTypes> :
|
||||
ContainerType<AnyTypeOf<allowedTypes>, IsScalableVectorTypePred,
|
||||
"$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()",
|
||||
"scalable vector">;
|
||||
|
||||
class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
|
||||
And<[IsScalableVectorTypePred,
|
||||
Or<!foreach(allowedlength, allowedLengths, CPred<
|
||||
[{$_self.cast<::mlir::arm_sve::ScalableVectorType>().getNumElements() == }]
|
||||
# allowedlength>)>]>;
|
||||
|
||||
class ScalableVectorOfLength<list<int> allowedLengths> : Type<
|
||||
IsScalableVectorOfLengthPred<allowedLengths>,
|
||||
" of length " # StrJoinInt<allowedLengths, "/">.result>;
|
||||
|
||||
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
|
||||
list<Type> allowedTypes> : Type<
|
||||
And<[ScalableVectorOf<allowedTypes>.predicate,
|
||||
ScalableVectorOfLength<allowedLengths>.predicate]>,
|
||||
ScalableVectorOf<allowedTypes>.description #
|
||||
ScalableVectorOfLength<allowedLengths>.description>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ArmSVE op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class ArmSVE_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<ArmSVE_Dialect, mnemonic, traits> {}
|
||||
|
||||
def SdotOp : ArmSVE_Op<"sdot",
|
||||
[NoSideEffect,
|
||||
AllTypesMatch<["src1", "src2"]>,
|
||||
AllTypesMatch<["acc", "dst"]>,
|
||||
]> {
|
||||
let summary = "Vector-vector dot product and accumulate op";
|
||||
let description = [{
|
||||
SDOT: Signed integer addition of dot product.
|
||||
|
||||
This function maps to the SDOT instruction, and it takes signless integer
|
||||
operands that the operation interprets as signed. It partitions the second
|
||||
and third vector inputs into groups of four elements. They calculate the dot
|
||||
product of each group (without loss of precision) and then add each result
|
||||
to the overlapping element of the first vector input.
|
||||
|
||||
Source:
|
||||
https://developer.arm.com/documentation/100987/0000
|
||||
}];
|
||||
// Supports either:
|
||||
// (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
|
||||
// (vector<8xi16>, vector<8xi16>) -> (vector<2xi64>)
|
||||
let arguments = (ins
|
||||
ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$acc,
|
||||
ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src1,
|
||||
ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src2
|
||||
);
|
||||
let results = (outs ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$dst);
|
||||
let assemblyFormat =
|
||||
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
|
||||
}
|
||||
|
||||
def SmmlaOp : ArmSVE_Op<"smmla",
|
||||
[NoSideEffect,
|
||||
AllTypesMatch<["src1", "src2"]>,
|
||||
AllTypesMatch<["acc", "dst"]>,
|
||||
]> {
|
||||
let summary = "Matrix-matrix multiply and accumulate op";
|
||||
let description = [{
|
||||
SMMLA: Signed integer matrix multiply-accumulate.
|
||||
|
||||
This function maps to the SMMLA instruction, and it takes signless integer
|
||||
operands that the operation interprets as signed. It partitions the inputs
|
||||
into 128-bit quadwords, with the first input containing a row-by-row 2×2
|
||||
matrix of 32-bit integers, the second input containing a row-by-row 2×8
|
||||
matrix of 8-bit integers, and the third input containing a column-by-column
|
||||
8×2 matrix of 8-bit integers. For each quadword, they multiply the second
|
||||
input matrix by the third input matrix using natural arithmetic and then add
|
||||
the result to the first input using modular arithmetic.
|
||||
|
||||
Source:
|
||||
https://developer.arm.com/documentation/100987/0000
|
||||
}];
|
||||
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
|
||||
let arguments = (ins
|
||||
ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
|
||||
ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
|
||||
ScalableVectorOfLengthAndType<[16], [I8]>:$src2
|
||||
);
|
||||
let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
|
||||
let assemblyFormat =
|
||||
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
|
||||
}
|
||||
|
||||
def UdotOp : ArmSVE_Op<"udot",
|
||||
[NoSideEffect,
|
||||
AllTypesMatch<["src1", "src2"]>,
|
||||
AllTypesMatch<["acc", "dst"]>,
|
||||
]> {
|
||||
let summary = "Vector-vector dot product and accumulate op";
|
||||
let description = [{
|
||||
UDOT: Unsigned integer addition of dot product.
|
||||
|
||||
This function maps to the UDOT instruction, and it takes signless integer
|
||||
operands that the operation interprets as unsigned. It partitions the second
|
||||
and third vector inputs into groups of four elements. They calculate the dot
|
||||
product of each group (without loss of precision) and then add each result
|
||||
to the overlapping element of the first vector input.
|
||||
|
||||
Source:
|
||||
https://developer.arm.com/documentation/100987/0000
|
||||
}];
|
||||
// Supports either:
|
||||
// (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
|
||||
// (vector<8xi16>, vector<8xi16>) -> (vector<2xi64>)
|
||||
let arguments = (ins
|
||||
ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$acc,
|
||||
ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src1,
|
||||
ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src2
|
||||
);
|
||||
let results = (outs ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$dst);
|
||||
let assemblyFormat =
|
||||
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
|
||||
}
|
||||
|
||||
def UmmlaOp : ArmSVE_Op<"ummla",
|
||||
[NoSideEffect,
|
||||
AllTypesMatch<["src1", "src2"]>,
|
||||
AllTypesMatch<["acc", "dst"]>,
|
||||
]> {
|
||||
let summary = "Matrix-matrix multiply and accumulate op";
|
||||
let description = [{
|
||||
UMMLA: Unsigned integer matrix multiply-accumulate.
|
||||
|
||||
This function maps to the UMMLA instruction, and it takes signless integer
|
||||
operands that the operation interprets as unsigned. It partitions the inputs
|
||||
into 128-bit quadwords, with the first input containing a row-by-row 2×2
|
||||
matrix of 32-bit integers, the second input containing a row-by-row 2×8
|
||||
matrix of 8-bit integers, and the third input containing a column-by-column
|
||||
8×2 matrix of 8-bit integers. For each quadword, they multiply the second
|
||||
input matrix by the third input matrix using natural arithmetic and then add
|
||||
the result to the first input using modular arithmetic.
|
||||
|
||||
Source:
|
||||
https://developer.arm.com/documentation/100987/0000
|
||||
}];
|
||||
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
|
||||
let arguments = (ins
|
||||
ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
|
||||
ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
|
||||
ScalableVectorOfLengthAndType<[16], [I8]>:$src2
|
||||
);
|
||||
let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
|
||||
let assemblyFormat =
|
||||
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
|
||||
}
|
||||
|
||||
def VectorScaleOp : ArmSVE_Op<"vector_scale",
|
||||
[NoSideEffect]> {
|
||||
let summary = "Load vector scale size";
|
||||
let description = [{
|
||||
The vector_scale op returns the scale of the scalable vectors, a positive
|
||||
integer value that is constant at runtime but unknown at compile time.
|
||||
The scale of the vector indicates the multiplicity of the vectors and
|
||||
vector operations. I.e.: an !arm_sve.vector<4xi32> is equivalent to
|
||||
vector_scale consecutive vector<4xi32>; and an operation on an
|
||||
!arm_sve.vector<4xi32> is equivalent to performing that operation vector_scale
|
||||
times, once on each <4xi32> segment of the scalable vector. The vector_scale
|
||||
op can be used to calculate the step in vector-length agnostic (VLA) loops.
|
||||
}];
|
||||
let results = (outs Index:$res);
|
||||
let assemblyFormat =
|
||||
"attr-dict `:` type($res)";
|
||||
}
|
||||
|
||||
#endif // ARMSVE_OPS
|
|
@ -0,0 +1,29 @@
|
|||
//===- ArmSVEDialect.h - MLIR Dialect for Arm SVE ---------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file declares the Target dialect for ArmSVE in MLIR.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H
|
||||
#define MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVETypes.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVE.h.inc"
|
||||
|
||||
#endif // MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H
|
|
@ -0,0 +1 @@
|
|||
add_mlir_dialect(ArmSVE arm_sve ArmSVE)
|
|
@ -1,6 +1,7 @@
|
|||
add_subdirectory(Affine)
|
||||
add_subdirectory(Async)
|
||||
add_subdirectory(ArmNeon)
|
||||
add_subdirectory(ArmSVE)
|
||||
add_subdirectory(AVX512)
|
||||
add_subdirectory(GPU)
|
||||
add_subdirectory(Linalg)
|
||||
|
|
|
@ -37,3 +37,9 @@ add_mlir_doc(LLVMArmNeon -gen-dialect-doc LLVMArmNeon Dialects/)
|
|||
set(LLVM_TARGET_DEFINITIONS LLVMArmNeon.td)
|
||||
mlir_tablegen(LLVMArmNeonConversions.inc -gen-llvmir-conversions)
|
||||
add_public_tablegen_target(MLIRLLVMArmNeonConversionsIncGen)
|
||||
|
||||
add_mlir_dialect(LLVMArmSVE llvm_arm_sve LLVMArmSVE)
|
||||
add_mlir_doc(LLVMArmSVE -gen-dialect-doc LLVMArmSve Dialects/)
|
||||
set(LLVM_TARGET_DEFINITIONS LLVMArmSVE.td)
|
||||
mlir_tablegen(LLVMArmSVEConversions.inc -gen-llvmir-conversions)
|
||||
add_public_tablegen_target(MLIRLLVMArmSVEConversionsIncGen)
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
//===-- LLVMArmSVE.td - LLVMARMSVE dialect op definitions --*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the basic operations for the LLVMArmSVE dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LLVMIR_ARMSVE_OPS
|
||||
#define LLVMIR_ARMSVE_OPS
|
||||
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LLVMArmSVE dialect definition
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def LLVMArmSVE_Dialect : Dialect {
|
||||
let name = "llvm_arm_sve";
|
||||
let cppNamespace = "::mlir::LLVM";
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// MLIR LLVM Arm SVE intrinsics using the MLIR LLVM Dialect type system
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
class LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<string mnemonic,
|
||||
list<OpTrait> traits =[]> :
|
||||
LLVM_IntrOpBase</*Dialect dialect=*/LLVMArmSVE_Dialect,
|
||||
/*string opName=*/mnemonic,
|
||||
/*string enumName=*/mnemonic,
|
||||
/*list<int> overloadedResults=*/[0],
|
||||
/*list<int> overloadedOperands=*/[], // defined by result overload
|
||||
/*list<OpTrait> traits=*/traits,
|
||||
/*int numResults=*/1>;
|
||||
|
||||
class LLVMArmSVE_IntrBinaryOverloadedOp<string mnemonic,
|
||||
list<OpTrait> traits = []> :
|
||||
LLVM_IntrOpBase</*Dialect dialect=*/LLVMArmSVE_Dialect,
|
||||
/*string opName=*/mnemonic,
|
||||
/*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
|
||||
/*list<int> overloadedResults=*/[0],
|
||||
/*list<int> overloadedOperands=*/[], // defined by result overload
|
||||
/*list<OpTrait> traits=*/traits,
|
||||
/*int numResults=*/1>;
|
||||
|
||||
def LLVM_aarch64_arm_sve_ummla :
|
||||
LLVMArmSVE_IntrBinaryOverloadedOp<"ummla">,
|
||||
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
|
||||
|
||||
def LLVM_aarch64_arm_sve_smmla :
|
||||
LLVMArmSVE_IntrBinaryOverloadedOp<"smmla">,
|
||||
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
|
||||
|
||||
def LLVM_aarch64_arm_sve_sdot :
|
||||
LLVMArmSVE_IntrBinaryOverloadedOp<"sdot">,
|
||||
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
|
||||
|
||||
def LLVM_aarch64_arm_sve_udot :
|
||||
LLVMArmSVE_IntrBinaryOverloadedOp<"udot">,
|
||||
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
|
||||
|
||||
def LLVM_vector_scale :
|
||||
LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;
|
||||
|
||||
#endif // LLVMIR_ARMSVE_OPS
|
|
@ -0,0 +1,24 @@
|
|||
//===- LLVMArmSVEDialect.h - MLIR Dialect for LLVMArmSVE --------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file declares the Target dialect for LLVMArmSVE in MLIR.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_
|
||||
#define MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/LLVMIR/LLVMArmSVE.h.inc"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h.inc"
|
||||
|
||||
#endif // MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_
|
|
@ -17,10 +17,12 @@
|
|||
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
|
||||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||||
|
@ -54,6 +56,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
|
|||
LLVM::LLVMAVX512Dialect,
|
||||
LLVM::LLVMDialect,
|
||||
LLVM::LLVMArmNeonDialect,
|
||||
LLVM::LLVMArmSVEDialect,
|
||||
linalg::LinalgDialect,
|
||||
scf::SCFDialect,
|
||||
omp::OpenMPDialect,
|
||||
|
@ -62,6 +65,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
|
|||
quant::QuantizationDialect,
|
||||
spirv::SPIRVDialect,
|
||||
StandardOpsDialect,
|
||||
arm_sve::ArmSVEDialect,
|
||||
vector::VectorDialect,
|
||||
NVVM::NVVMDialect,
|
||||
ROCDL::ROCDLDialect,
|
||||
|
|
|
@ -24,6 +24,7 @@ void registerToNVVMIRTranslation();
|
|||
void registerToROCDLIRTranslation();
|
||||
void registerArmNeonToLLVMIRTranslation();
|
||||
void registerAVX512ToLLVMIRTranslation();
|
||||
void registerArmSVEToLLVMIRTranslation();
|
||||
|
||||
// This function should be called before creating any MLIRContext if one
|
||||
// expects all the possible translations to be made available to the context
|
||||
|
@ -38,6 +39,7 @@ inline void registerAllTranslations() {
|
|||
registerToROCDLIRTranslation();
|
||||
registerArmNeonToLLVMIRTranslation();
|
||||
registerAVX512ToLLVMIRTranslation();
|
||||
registerArmSVEToLLVMIRTranslation();
|
||||
return true;
|
||||
}();
|
||||
(void)initOnce;
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
//===- ArmSVEToLLVM.cpp - Convert ArmSVE to the LLVM dialect --------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::arm_sve;
|
||||
using namespace mlir::vector;
|
||||
|
||||
using SdotOpLowering =
|
||||
OneToOneConvertToLLVMPattern<SdotOp, LLVM::aarch64_arm_sve_sdot>;
|
||||
|
||||
using SmmlaOpLowering =
|
||||
OneToOneConvertToLLVMPattern<SmmlaOp, LLVM::aarch64_arm_sve_smmla>;
|
||||
|
||||
using UdotOpLowering =
|
||||
OneToOneConvertToLLVMPattern<UdotOp, LLVM::aarch64_arm_sve_udot>;
|
||||
|
||||
using UmmlaOpLowering =
|
||||
OneToOneConvertToLLVMPattern<UmmlaOp, LLVM::aarch64_arm_sve_ummla>;
|
||||
|
||||
using VectorScaleOpLowering =
|
||||
OneToOneConvertToLLVMPattern<VectorScaleOp, LLVM::vector_scale>;
|
||||
|
||||
// Extract an LLVM IR type from the LLVM IR dialect type.
//
// A null `type` is passed through as a null result. Otherwise `type` is
// dyn_cast to LLVM::LLVMType; if that cast does not succeed, an error is
// emitted at an unknown location (no operation is available here to attach
// it to) and the result of the failed cast is returned, so callers must check
// the returned type before using it.
|
||||
static LLVM::LLVMType unwrap(Type type) {
|
||||
if (!type)
|
||||
return nullptr;
|
||||
auto *mlirContext = type.getContext();
|
||||
auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
|
||||
if (!wrappedLLVMType)
|
||||
emitError(UnknownLoc::get(mlirContext),
|
||||
"conversion resulted in a non-LLVM type");
|
||||
return wrappedLLVMType;
|
||||
}
|
||||
|
||||
// Convert an arm_sve ScalableVectorType into an LLVM scalable vector type.
//
// The element type is converted with `converter` and unwrapped to an LLVM
// dialect type; when that yields no LLVM type, an empty Optional is returned.
// Otherwise the result is an LLVM::LLVMScalableVectorType built from the
// converted element type and the last dimension of the shape.
//
// NOTE(review): only `getShape().back()` is used, so any leading dimensions
// are dropped and an empty shape would be invalid here — presumably only 1-D
// scalable vectors are expected; confirm against the type's verifier/users.
static Optional<Type>
|
||||
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
|
||||
LLVMTypeConverter &converter) {
|
||||
auto elementType = unwrap(converter.convertType(svType.getElementType()));
|
||||
if (!elementType)
|
||||
return {};
|
||||
|
||||
auto sVectorType =
|
||||
LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
|
||||
return sVectorType;
|
||||
}
|
||||
|
||||
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
///
/// Besides adding the patterns, this also registers a type conversion on
/// `converter` that maps ScalableVectorType through
/// convertScalableVectorTypeToLLVM. The registered callback captures
/// `converter` by reference. The patterns added are the one-to-one lowerings
/// for the sdot, smmla, udot, ummla and vector_scale ops, all constructed
/// with `converter`.
|
||||
void mlir::populateArmSVEToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
converter.addConversion([&converter](ScalableVectorType svType) {
|
||||
return convertScalableVectorTypeToLLVM(svType, converter);
|
||||
});
|
||||
// clang-format off
|
||||
patterns.insert<SdotOpLowering,
|
||||
SmmlaOpLowering,
|
||||
UdotOpLowering,
|
||||
UmmlaOpLowering,
|
||||
VectorScaleOpLowering>(converter);
|
||||
// clang-format on
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
add_mlir_conversion_library(MLIRArmSVEToLLVM
|
||||
ArmSVEToLLVM.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSVEToLLVM
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArmSVE
|
||||
MLIRLLVMArmSVE
|
||||
MLIRLLVMIR
|
||||
MLIRStandardToLLVM
|
||||
MLIRTransforms
|
||||
)
|
|
@ -20,6 +20,7 @@ add_subdirectory(ShapeToStandard)
|
|||
add_subdirectory(SPIRVToLLVM)
|
||||
add_subdirectory(StandardToLLVM)
|
||||
add_subdirectory(StandardToSPIRV)
|
||||
add_subdirectory(ArmSVEToLLVM)
|
||||
add_subdirectory(VectorToROCDL)
|
||||
add_subdirectory(VectorToLLVM)
|
||||
add_subdirectory(VectorToSCF)
|
||||
|
|
|
@ -26,6 +26,7 @@ class GPUModuleOp;
|
|||
|
||||
namespace LLVM {
|
||||
class LLVMArmNeonDialect;
|
||||
class LLVMArmSVEDialect;
|
||||
class LLVMAVX512Dialect;
|
||||
class LLVMDialect;
|
||||
} // end namespace LLVM
|
||||
|
|
|
@ -19,6 +19,9 @@ add_mlir_conversion_library(MLIRVectorToLLVM
|
|||
MLIRAVX512ToLLVM
|
||||
MLIRLLVMArmNeon
|
||||
MLIRLLVMAVX512
|
||||
MLIRArmSVE
|
||||
MLIRArmSVEToLLVM
|
||||
MLIRLLVMArmSVE
|
||||
MLIRLLVMIR
|
||||
MLIRStandardToLLVM
|
||||
MLIRTargetLLVMIRModuleTranslation
|
||||
|
|
|
@ -12,12 +12,15 @@
|
|||
|
||||
#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
|
||||
#include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h"
|
||||
#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
|
||||
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
@ -32,6 +35,7 @@ struct LowerVectorToLLVMPass
|
|||
this->reassociateFPReductions = options.reassociateFPReductions;
|
||||
this->enableIndexOptimizations = options.enableIndexOptimizations;
|
||||
this->enableArmNeon = options.enableArmNeon;
|
||||
this->enableArmSVE = options.enableArmSVE;
|
||||
this->enableAVX512 = options.enableAVX512;
|
||||
}
|
||||
// Override explicitly to allow conditional dialect dependence.
|
||||
|
@ -39,6 +43,8 @@ struct LowerVectorToLLVMPass
|
|||
registry.insert<LLVM::LLVMDialect>();
|
||||
if (enableArmNeon)
|
||||
registry.insert<LLVM::LLVMArmNeonDialect>();
|
||||
if (enableArmSVE)
|
||||
registry.insert<LLVM::LLVMArmSVEDialect>();
|
||||
if (enableAVX512)
|
||||
registry.insert<LLVM::LLVMAVX512Dialect>();
|
||||
}
|
||||
|
@ -73,6 +79,11 @@ void LowerVectorToLLVMPass::runOnOperation() {
|
|||
target.addIllegalDialect<arm_neon::ArmNeonDialect>();
|
||||
populateArmNeonToLLVMConversionPatterns(converter, patterns);
|
||||
}
|
||||
if (enableArmSVE) {
|
||||
target.addLegalDialect<LLVM::LLVMArmSVEDialect>();
|
||||
target.addIllegalDialect<arm_sve::ArmSVEDialect>();
|
||||
populateArmSVEToLLVMConversionPatterns(converter, patterns);
|
||||
}
|
||||
if (enableAVX512) {
|
||||
target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
|
||||
target.addIllegalDialect<avx512::AVX512Dialect>();
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
add_mlir_dialect_library(MLIRArmSVE
|
||||
IR/ArmSVEDialect.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE
|
||||
|
||||
DEPENDS
|
||||
MLIRArmSVEIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRSideEffectInterfaces
|
||||
)
|
|
@ -0,0 +1,57 @@
|
|||
//===- ArmSVEDialect.cpp - MLIR ArmSVE dialect implementation -------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements the ArmSVE dialect and its operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
// Register the dialect's operations and types. Both lists are supplied by
// including the generated .cpp.inc files with GET_OP_LIST / GET_TYPEDEF_LIST
// defined, which expand to the template arguments of addOperations<> and
// addTypes<> respectively.
void arm_sve::ArmSVEDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
|
||||
>();
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ScalableVectorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Parses a type of the ArmSVE dialect by delegating to generatedTypeParser
/// with the fixed string "vector". If that yields a null type, an error is
/// emitted at the location where parsing started and a null Type is returned.
Type arm_sve::ArmSVEDialect::parseType(DialectAsmParser &parser) const {
  // Captured up front so the diagnostic points at the start of the type.
  const llvm::SMLoc startLoc = parser.getCurrentLocation();

  if (Type parsed = generatedTypeParser(getContext(), parser, "vector"))
    return parsed;

  parser.emitError(startLoc, "unknown type in ArmSVE dialect");
  return Type();
}
|
||||
|
||||
/// Prints a type of the ArmSVE dialect through generatedTypePrinter. A type
/// the generated printer rejects is treated as unreachable.
void arm_sve::ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
  if (!failed(generatedTypePrinter(type, os)))
    return;

  llvm_unreachable("unexpected 'arm_sve' type kind");
}
|
|
@ -1,5 +1,6 @@
|
|||
add_subdirectory(Affine)
|
||||
add_subdirectory(ArmNeon)
|
||||
add_subdirectory(ArmSVE)
|
||||
add_subdirectory(Async)
|
||||
add_subdirectory(AVX512)
|
||||
add_subdirectory(GPU)
|
||||
|
|
|
@ -70,6 +70,27 @@ add_mlir_dialect_library(MLIRLLVMArmNeon
|
|||
MLIRSideEffectInterfaces
|
||||
)
|
||||
|
||||
# Library target for the LLVMArmSVE dialect.
add_mlir_dialect_library(MLIRLLVMArmSVE
  IR/LLVMArmSVEDialect.cpp

  ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR

  # Targets this library depends on; the *IncGen names suggest generated
  # includes, and intrinsics_gen comes from LLVM.
  DEPENDS
  MLIRLLVMArmSVEIncGen
  MLIRLLVMArmSVEConversionsIncGen
  intrinsics_gen

  # LLVM components linked into the library.
  LINK_COMPONENTS AsmParser Core

  LINK_LIBS PUBLIC MLIRIR MLIRLLVMIR MLIRSideEffectInterfaces
)
|
||||
|
||||
add_mlir_dialect_library(MLIRNVVMIR
|
||||
IR/NVVMDialect.cpp
|
||||
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
//===- LLVMArmSVEDialect.cpp - MLIR LLVMSVE ops implementation ------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements the LLVMArmSVE dialect and its operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/IR/IntrinsicsAArch64.h"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Registers the operations of the LLVMArmSVE dialect. The template argument
/// list is spliced in from LLVMArmSVE.cpp.inc with GET_OP_LIST defined.
void LLVM::LLVMArmSVEDialect::initialize() {
#define GET_OP_LIST
  addOperations<
#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc"
      >();
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc"
|
|
@ -74,6 +74,25 @@ add_mlir_translation_library(MLIRTargetArmNeon
|
|||
MLIRTargetLLVMIRModuleTranslation
|
||||
)
|
||||
|
||||
# Translation library built from the LLVMArmSVE intrinsic translation source.
add_mlir_translation_library(MLIRTargetArmSVE
  LLVMIR/LLVMArmSVEIntr.cpp

  ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR

  # The translation source includes LLVMArmSVEConversions.inc.
  DEPENDS MLIRLLVMArmSVEConversionsIncGen

  LINK_COMPONENTS Core

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRLLVMArmSVE
  MLIRLLVMIR
  MLIRTargetLLVMIRModuleTranslation
)
|
||||
|
||||
add_mlir_translation_library(MLIRTargetNVVMIR
|
||||
LLVMIR/ConvertToNVVMIR.cpp
|
||||
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
//===- LLVMArmSVEIntr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a translation between the MLIR LLVM and ArmSVE dialects
|
||||
// and LLVM IR with Arm SVE intrinsics.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
|
||||
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
||||
#include "mlir/Translation.h"
|
||||
#include "llvm/IR/IntrinsicsAArch64.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
/// Module translation that expands the LLVMArmSVE conversion fragment ahead of
/// a call to the base LLVM::ModuleTranslation implementation.
class LLVMArmSVEModuleTranslation : public LLVM::ModuleTranslation {
public:
  // Reuse the base-class constructors unchanged.
  using LLVM::ModuleTranslation::ModuleTranslation;

protected:
  /// Converts one operation. The parameter names are kept as they are because
  /// the included fragment is expanded inside this body and may refer to them
  /// (its contents are not visible here).
  LogicalResult convertOperation(Operation &opInst,
                                 llvm::IRBuilder<> &builder) override {
#include "mlir/Dialect/LLVMIR/LLVMArmSVEConversions.inc"

    // Anything that reaches this point is handed to the base implementation.
    return LLVM::ModuleTranslation::convertOperation(opInst, builder);
  }

private:
  // Gives the base class access to the non-public members of this class.
  friend LLVM::ModuleTranslation;
};
} // end namespace
|
||||
|
||||
static std::unique_ptr<llvm::Module>
|
||||
translateLLVMArmSVEModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext,
|
||||
StringRef name) {
|
||||
return LLVM::ModuleTranslation::translateModule<LLVMArmSVEModuleTranslation>(
|
||||
m, llvmContext, name);
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
void registerArmSVEToLLVMIRTranslation() {
|
||||
TranslateFromMLIRRegistration reg(
|
||||
"arm-sve-mlir-to-llvmir",
|
||||
[](ModuleOp module, raw_ostream &output) {
|
||||
llvm::LLVMContext llvmContext;
|
||||
auto llvmModule = translateLLVMArmSVEModuleToLLVMIR(
|
||||
module, llvmContext, "LLVMDialectModule");
|
||||
if (!llvmModule)
|
||||
return failure();
|
||||
|
||||
llvmModule->print(output, nullptr);
|
||||
return success();
|
||||
},
|
||||
[](DialectRegistry ®istry) {
|
||||
registry.insert<LLVM::LLVMArmSVEDialect, LLVM::LLVMDialect>();
|
||||
});
|
||||
}
|
||||
} // namespace mlir
|
|
@ -0,0 +1,47 @@
|
|||
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s
|
||||
|
||||
// arm_sve.sdot is expected to appear as llvm_arm_sve.sdot after conversion.
func @arm_sve_sdot(%lhs: !arm_sve.vector<16xi8>,
                   %rhs: !arm_sve.vector<16xi8>,
                   %acc: !arm_sve.vector<4xi32>)
    -> !arm_sve.vector<4xi32> {
  // CHECK: llvm_arm_sve.sdot
  %res = arm_sve.sdot %acc, %lhs, %rhs :
               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
  return %res : !arm_sve.vector<4xi32>
}
|
||||
|
||||
// arm_sve.smmla is expected to appear as llvm_arm_sve.smmla after conversion.
func @arm_sve_smmla(%lhs: !arm_sve.vector<16xi8>,
                    %rhs: !arm_sve.vector<16xi8>,
                    %acc: !arm_sve.vector<4xi32>)
    -> !arm_sve.vector<4xi32> {
  // CHECK: llvm_arm_sve.smmla
  %res = arm_sve.smmla %acc, %lhs, %rhs :
               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
  return %res : !arm_sve.vector<4xi32>
}
|
||||
|
||||
// arm_sve.udot is expected to appear as llvm_arm_sve.udot after conversion.
func @arm_sve_udot(%lhs: !arm_sve.vector<16xi8>,
                   %rhs: !arm_sve.vector<16xi8>,
                   %acc: !arm_sve.vector<4xi32>)
    -> !arm_sve.vector<4xi32> {
  // CHECK: llvm_arm_sve.udot
  %res = arm_sve.udot %acc, %lhs, %rhs :
               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
  return %res : !arm_sve.vector<4xi32>
}
|
||||
|
||||
// arm_sve.ummla is expected to appear as llvm_arm_sve.ummla after conversion.
func @arm_sve_ummla(%lhs: !arm_sve.vector<16xi8>,
                    %rhs: !arm_sve.vector<16xi8>,
                    %acc: !arm_sve.vector<4xi32>)
    -> !arm_sve.vector<4xi32> {
  // CHECK: llvm_arm_sve.ummla
  %res = arm_sve.ummla %acc, %lhs, %rhs :
               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
  return %res : !arm_sve.vector<4xi32>
}
|
||||
|
||||
// arm_sve.vector_scale is expected to appear as llvm_arm_sve.vscale after
// conversion.
func @get_vector_scale() -> index {
  // CHECK: llvm_arm_sve.vscale
  %vscale = arm_sve.vector_scale : index
  return %vscale : index
}
|
|
@ -0,0 +1,43 @@
|
|||
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
|
||||
|
||||
// Round-trip of arm_sve.sdot through the parser and printer.
func @arm_sve_sdot(%lhs: !arm_sve.vector<16xi8>,
                   %rhs: !arm_sve.vector<16xi8>,
                   %acc: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
  // CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
  %res = arm_sve.sdot %acc, %lhs, %rhs :
             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
  return %res : !arm_sve.vector<4xi32>
}
|
||||
|
||||
// Round-trip of arm_sve.smmla through the parser and printer.
func @arm_sve_smmla(%lhs: !arm_sve.vector<16xi8>,
                    %rhs: !arm_sve.vector<16xi8>,
                    %acc: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
  // CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
  %res = arm_sve.smmla %acc, %lhs, %rhs :
             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
  return %res : !arm_sve.vector<4xi32>
}
|
||||
|
||||
// Round-trip of arm_sve.udot through the parser and printer.
func @arm_sve_udot(%lhs: !arm_sve.vector<16xi8>,
                   %rhs: !arm_sve.vector<16xi8>,
                   %acc: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
  // CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
  %res = arm_sve.udot %acc, %lhs, %rhs :
             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
  return %res : !arm_sve.vector<4xi32>
}
|
||||
|
||||
// Round-trip of arm_sve.ummla through the parser and printer.
func @arm_sve_ummla(%lhs: !arm_sve.vector<16xi8>,
                    %rhs: !arm_sve.vector<16xi8>,
                    %acc: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
  // CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
  %res = arm_sve.ummla %acc, %lhs, %rhs :
             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
  return %res : !arm_sve.vector<4xi32>
}
|
||||
|
||||
// Round-trip of arm_sve.vector_scale through the parser and printer.
func @get_vector_scale() -> index {
  // CHECK: arm_sve.vector_scale : index
  %vscale = arm_sve.vector_scale : index
  return %vscale : index
}
|
|
@ -0,0 +1,56 @@
|
|||
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate --arm-sve-mlir-to-llvmir | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_sdot
|
||||
// llvm_arm_sve.sdot is expected to translate to a call of the
// llvm.aarch64.sve.sdot intrinsic.
llvm.func @arm_sve_sdot(%lhs: !llvm.vec<? x 16 x i8>,
                        %rhs: !llvm.vec<? x 16 x i8>,
                        %acc: !llvm.vec<? x 4 x i32>)
    -> !llvm.vec<? x 4 x i32> {
  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4
  %res = "llvm_arm_sve.sdot"(%acc, %lhs, %rhs) :
    (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
        -> !llvm.vec<? x 4 x i32>
  llvm.return %res : !llvm.vec<? x 4 x i32>
}
|
||||
|
||||
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_smmla
|
||||
// llvm_arm_sve.smmla is expected to translate to a call of the
// llvm.aarch64.sve.smmla intrinsic.
llvm.func @arm_sve_smmla(%lhs: !llvm.vec<? x 16 x i8>,
                         %rhs: !llvm.vec<? x 16 x i8>,
                         %acc: !llvm.vec<? x 4 x i32>)
    -> !llvm.vec<? x 4 x i32> {
  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4
  %res = "llvm_arm_sve.smmla"(%acc, %lhs, %rhs) :
    (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
        -> !llvm.vec<? x 4 x i32>
  llvm.return %res : !llvm.vec<? x 4 x i32>
}
|
||||
|
||||
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_udot
|
||||
// llvm_arm_sve.udot is expected to translate to a call of the
// llvm.aarch64.sve.udot intrinsic.
llvm.func @arm_sve_udot(%lhs: !llvm.vec<? x 16 x i8>,
                        %rhs: !llvm.vec<? x 16 x i8>,
                        %acc: !llvm.vec<? x 4 x i32>)
    -> !llvm.vec<? x 4 x i32> {
  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4
  %res = "llvm_arm_sve.udot"(%acc, %lhs, %rhs) :
    (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
        -> !llvm.vec<? x 4 x i32>
  llvm.return %res : !llvm.vec<? x 4 x i32>
}
|
||||
|
||||
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_ummla
|
||||
// llvm_arm_sve.ummla is expected to translate to a call of the
// llvm.aarch64.sve.ummla intrinsic.
llvm.func @arm_sve_ummla(%lhs: !llvm.vec<? x 16 x i8>,
                         %rhs: !llvm.vec<? x 16 x i8>,
                         %acc: !llvm.vec<? x 4 x i32>)
    -> !llvm.vec<? x 4 x i32> {
  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4
  %res = "llvm_arm_sve.ummla"(%acc, %lhs, %rhs) :
    (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
        -> !llvm.vec<? x 4 x i32>
  llvm.return %res : !llvm.vec<? x 4 x i32>
}
|
||||
|
||||
// CHECK-LABEL: define i64 @get_vector_scale()
|
||||
// llvm_arm_sve.vscale is expected to translate to a call of llvm.vscale.i64.
llvm.func @get_vector_scale() -> !llvm.i64 {
  // CHECK: call i64 @llvm.vscale.i64()
  %vscale = "llvm_arm_sve.vscale"() : () -> !llvm.i64
  llvm.return %vscale : !llvm.i64
}
|
Loading…
Reference in New Issue