[mlir][ArmSVE] Cleanup dialect registration

ArmSVE dialect is behind the recent changes in how the Vector dialect
interacts with backend vector dialects and the MLIR -> LLVM IR
translation module. This patch cleans up ArmSVE initialization within
Vector and removes the need for an LLVMArmSVE dialect.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D100171
This commit is contained in:
Javier Setoain 2021-04-16 15:51:17 +02:00 committed by Alex Zinenko
parent 1f8a6dcf12
commit b739bada9d
32 changed files with 287 additions and 372 deletions

View File

@ -1,24 +0,0 @@
//===- ArmSVEToLLVM.h - Conversion Patterns from ArmSVE to LLVM -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_
#define MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_
namespace mlir {
class LLVMTypeConverter;
class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
/// Collect a set of patterns to convert from the ArmSVE dialect to LLVM.
void populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace mlir
#endif // MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_

View File

@ -14,6 +14,8 @@
#define ARMSVE_OPS
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Dialect/ArmSVE/ArmSVEOpBase.td"
//===----------------------------------------------------------------------===//
// ArmSVE dialect definition
@ -93,35 +95,6 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
}];
}
//===----------------------------------------------------------------------===//
// ArmSVE type traits
//===----------------------------------------------------------------------===//
def IsScalableVectorTypePred :
CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">;
class ScalableVectorOf<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, IsScalableVectorTypePred,
"$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()",
"scalable vector">;
class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
And<[IsScalableVectorTypePred,
Or<!foreach(allowedlength, allowedLengths, CPred<
[{$_self.cast<::mlir::arm_sve::ScalableVectorType>().getNumElements() == }]
# allowedlength>)>]>;
class ScalableVectorOfLength<list<int> allowedLengths> : Type<
IsScalableVectorOfLengthPred<allowedLengths>,
" of length " # !interleave(allowedLengths, "/")>;
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : Type<
And<[ScalableVectorOf<allowedTypes>.predicate,
ScalableVectorOfLength<allowedLengths>.predicate]>,
ScalableVectorOf<allowedTypes>.summary #
ScalableVectorOfLength<allowedLengths>.summary>;
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
@ -129,6 +102,26 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
class ArmSVE_Op<string mnemonic, list<OpTrait> traits = []> :
Op<ArmSVE_Dialect, mnemonic, traits> {}
class ArmSVE_NonSVEIntrUnaryOverloadedOp<string mnemonic,
list<OpTrait> traits =[]> :
LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
/*string opName=*/mnemonic,
/*string enumName=*/mnemonic,
/*list<int> overloadedResults=*/[0],
/*list<int> overloadedOperands=*/[], // defined by result overload
/*list<OpTrait> traits=*/traits,
/*int numResults=*/1>;
class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
list<OpTrait> traits = []> :
LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
/*string opName=*/"intr." # mnemonic,
/*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
/*list<int> overloadedResults=*/[0],
/*list<int> overloadedOperands=*/[], // defined by result overload
/*list<OpTrait> traits=*/traits,
/*int numResults=*/1>;
def SdotOp : ArmSVE_Op<"sdot",
[NoSideEffect,
AllTypesMatch<["src1", "src2"]>,
@ -273,4 +266,23 @@ def VectorScaleOp : ArmSVE_Op<"vector_scale",
"attr-dict `:` type($res)";
}
def UmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
def SmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
def UdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"udot">,
Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
def VectorScaleIntrOp:
ArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;
#endif // ARMSVE_OPS

View File

@ -0,0 +1,53 @@
//===-- ArmSVEOpBase.td - Base op definitions for ArmSVE ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the base operation definition file for ArmSVE scalable vector types.
//
//===----------------------------------------------------------------------===//
#ifndef ARMSVE_OP_BASE
#define ARMSVE_OP_BASE
//===----------------------------------------------------------------------===//
// ArmSVE scalable vector type constraints
//===----------------------------------------------------------------------===//
def IsScalableVectorTypePred :
CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">;
class ScalableVectorOf<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, IsScalableVectorTypePred,
"$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()",
"scalable vector">;
// Whether the number of elements of a scalable vector is from the given
// `allowedLengths` list
class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
And<[IsScalableVectorTypePred,
Or<!foreach(allowedlength, allowedLengths, CPred<
[{$_self.cast<::mlir::arm_sve::ScalableVectorType>().getNumElements() == }]
# allowedlength>)>]>;
// Any scalable vector where the number of elements is from the given
// `allowedLengths` list
class ScalableVectorOfLength<list<int> allowedLengths> : Type<
IsScalableVectorOfLengthPred<allowedLengths>,
" of length " # !interleave(allowedLengths, "/"),
"::mlir::arm_sve::ScalableVectorType">;
// Any scalable vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes` list
class ScalableVectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : Type<
And<[ScalableVectorOf<allowedTypes>.predicate,
ScalableVectorOfLength<allowedLengths>.predicate]>,
ScalableVectorOf<allowedTypes>.summary #
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::arm_sve::ScalableVectorType">;
#endif // ARMSVE_OP_BASE

View File

@ -1,2 +1,6 @@
add_mlir_dialect(ArmSVE arm_sve ArmSVE)
add_mlir_doc(ArmSVE ArmSVE Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS ArmSVE.td)
mlir_tablegen(ArmSVEConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRArmSVEConversionsIncGen)

View File

@ -0,0 +1,30 @@
//===- Transforms.h - ArmSVE Dialect Transformation Entrypoints -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_H
#define MLIR_DIALECT_ARMSVE_TRANSFORMS_H
namespace mlir {
class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
/// Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM
/// intrinsics.
void populateArmSVELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
/// intrinsics.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
} // namespace mlir
#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_H

View File

@ -35,8 +35,3 @@ set(LLVM_TARGET_DEFINITIONS ROCDLOps.td)
mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRROCDLConversionsIncGen)
add_mlir_dialect(LLVMArmSVE llvm_arm_sve LLVMArmSVE)
add_mlir_doc(LLVMArmSVE LLVMArmSve Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS LLVMArmSVE.td)
mlir_tablegen(LLVMArmSVEConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRLLVMArmSVEConversionsIncGen)

View File

@ -1,70 +0,0 @@
//===-- LLVMArmSVE.td - LLVMARMSVE dialect op definitions --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the basic operations for the LLVMArmSVE dialect.
//
//===----------------------------------------------------------------------===//
#ifndef LLVMIR_ARMSVE_OPS
#define LLVMIR_ARMSVE_OPS
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
//===----------------------------------------------------------------------===//
// LLVMArmSVE dialect definition
//===----------------------------------------------------------------------===//
def LLVMArmSVE_Dialect : Dialect {
let name = "llvm_arm_sve";
let cppNamespace = "::mlir::LLVM";
}
//----------------------------------------------------------------------------//
// MLIR LLVM Arm SVE intrinsics using the MLIR LLVM Dialect type system
//----------------------------------------------------------------------------//
class LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<string mnemonic,
list<OpTrait> traits =[]> :
LLVM_IntrOpBase</*Dialect dialect=*/LLVMArmSVE_Dialect,
/*string opName=*/mnemonic,
/*string enumName=*/mnemonic,
/*list<int> overloadedResults=*/[0],
/*list<int> overloadedOperands=*/[], // defined by result overload
/*list<OpTrait> traits=*/traits,
/*int numResults=*/1>;
class LLVMArmSVE_IntrBinaryOverloadedOp<string mnemonic,
list<OpTrait> traits = []> :
LLVM_IntrOpBase</*Dialect dialect=*/LLVMArmSVE_Dialect,
/*string opName=*/mnemonic,
/*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
/*list<int> overloadedResults=*/[0],
/*list<int> overloadedOperands=*/[], // defined by result overload
/*list<OpTrait> traits=*/traits,
/*int numResults=*/1>;
def LLVM_aarch64_arm_sve_ummla :
LLVMArmSVE_IntrBinaryOverloadedOp<"ummla">,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_aarch64_arm_sve_smmla :
LLVMArmSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_aarch64_arm_sve_sdot :
LLVMArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_aarch64_arm_sve_udot :
LLVMArmSVE_IntrBinaryOverloadedOp<"udot">,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_vector_scale :
LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;
#endif // ARMSVE_OPS

View File

@ -1,24 +0,0 @@
//===- LLVMSVEDialect.h - MLIR Dialect for LLVMSVE --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the Target dialect for LLVMArmSVE in MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_
#define MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMArmSVE.h.inc"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h.inc"
#endif // MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_

View File

@ -22,7 +22,6 @@
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@ -59,7 +58,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
DLTIDialect,
gpu::GPUDialect,
LLVM::LLVMDialect,
LLVM::LLVMArmSVEDialect,
linalg::LinalgDialect,
math::MathDialect,
memref::MemRefDialect,

View File

@ -16,7 +16,7 @@
#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
@ -31,7 +31,7 @@ class DialectRegistry;
static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
registerArmNeonDialectTranslation(registry);
registerAMXDialectTranslation(registry);
registerLLVMArmSVEDialectTranslation(registry);
registerArmSVEDialectTranslation(registry);
registerLLVMDialectTranslation(registry);
registerNVVMDialectTranslation(registry);
registerOpenMPDialectTranslation(registry);

View File

@ -0,0 +1,31 @@
//=======- ArmSVEToLLVMIRTranslation.h - ArmSVE to LLVM IR --*- C++ -*-=======//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides registration calls for ArmSVE dialect to LLVM IR translation.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TARGET_LLVMIR_DIALECT_ARMSVE_ARMSVETOLLVMIRTRANSLATION_H
#define MLIR_TARGET_LLVMIR_DIALECT_ARMSVE_ARMSVETOLLVMIRTRANSLATION_H
namespace mlir {
class DialectRegistry;
class MLIRContext;
/// Register the ArmSVE dialect and the translation from it to the LLVM IR in
/// the given registry;
void registerArmSVEDialectTranslation(DialectRegistry &registry);
/// Register the ArmSVE dialect and the translation from it in the registry
/// associated with the given context.
void registerArmSVEDialectTranslation(MLIRContext &context);
} // namespace mlir
#endif // MLIR_TARGET_LLVMIR_DIALECT_ARMSVE_ARMSVETOLLVMIRTRANSLATION_H

View File

@ -1,32 +0,0 @@
//===- LLVMArmSVEToLLVMIRTranslation.h - LLVMArmSVE to LLVM IR --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides registration calls for LLVMArmSVE dialect to LLVM IR
// translation.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TARGET_LLVMIR_DIALECT_LLVMARMSVE_LLVMARMSVETOLLVMIRTRANSLATION_H
#define MLIR_TARGET_LLVMIR_DIALECT_LLVMARMSVE_LLVMARMSVETOLLVMIRTRANSLATION_H
namespace mlir {
class DialectRegistry;
class MLIRContext;
/// Register the LLVMArmSVE dialect and the translation from it to the LLVM IR
/// in the given registry;
void registerLLVMArmSVEDialectTranslation(DialectRegistry &registry);
/// Register the LLVMArmSVE dialect and the translation from it in the registry
/// associated with the given context.
void registerLLVMArmSVEDialectTranslation(MLIRContext &context);
} // namespace mlir
#endif // MLIR_TARGET_LLVMIR_DIALECT_LLVMARMSVE_LLVMARMSVETOLLVMIRTRANSLATION_H

View File

@ -1,19 +0,0 @@
add_mlir_conversion_library(MLIRArmSVEToLLVM
ArmSVEToLLVM.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSVEToLLVM
DEPENDS
MLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRArmSVE
MLIRLLVMArmSVE
MLIRLLVMIR
MLIRStandardToLLVM
MLIRTransforms
)

View File

@ -22,7 +22,6 @@ add_subdirectory(StandardToSPIRV)
add_subdirectory(TosaToLinalg)
add_subdirectory(TosaToSCF)
add_subdirectory(TosaToStandard)
add_subdirectory(ArmSVEToLLVM)
add_subdirectory(VectorToROCDL)
add_subdirectory(VectorToLLVM)
add_subdirectory(VectorToSCF)

View File

@ -29,7 +29,6 @@ class GPUModuleOp;
} // end namespace gpu
namespace LLVM {
class LLVMArmSVEDialect;
class LLVMDialect;
} // end namespace LLVM

View File

@ -14,11 +14,10 @@ add_mlir_conversion_library(MLIRVectorToLLVM
LINK_LIBS PUBLIC
MLIRArmNeon
MLIRArmSVE
MLIRArmSVETransforms
MLIRAMX
MLIRAMXTransforms
MLIRArmSVE
MLIRArmSVEToLLVM
MLIRLLVMArmSVE
MLIRLLVMIR
MLIRMemRef
MLIRStandardToLLVM

View File

@ -10,14 +10,13 @@
#include "../PassDetail.h"
#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -47,7 +46,7 @@ struct LowerVectorToLLVMPass
if (enableArmNeon)
registry.insert<arm_neon::ArmNeonDialect>();
if (enableArmSVE)
registry.insert<LLVM::LLVMArmSVEDialect>();
registry.insert<arm_sve::ArmSVEDialect>();
if (enableAMX)
registry.insert<amx::AMXDialect>();
if (enableX86Vector)
@ -90,26 +89,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
target.addLegalDialect<arm_neon::ArmNeonDialect>();
}
if (enableArmSVE) {
target.addLegalDialect<LLVM::LLVMArmSVEDialect>();
target.addIllegalDialect<arm_sve::ArmSVEDialect>();
auto hasScalableVectorType = [](TypeRange types) {
for (Type type : types)
if (type.isa<arm_sve::ScalableVectorType>())
return true;
return false;
};
// Remove any ArmSVE-specific types from function signatures and results.
populateFuncOpTypeConversionPattern(patterns, converter);
target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
return !hasScalableVectorType(op.getType().getInputs()) &&
!hasScalableVectorType(op.getType().getResults());
});
target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
[hasScalableVectorType](Operation *op) {
return !hasScalableVectorType(op->getOperandTypes()) &&
!hasScalableVectorType(op->getResultTypes());
});
populateArmSVEToLLVMConversionPatterns(converter, patterns);
configureArmSVELegalizeForExportTarget(target);
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
if (enableAMX) {
configureAMXLegalizeForExportTarget(target);

View File

@ -1,13 +1,2 @@
add_mlir_dialect_library(MLIRArmSVE
IR/ArmSVEDialect.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE
DEPENDS
MLIRArmSVEIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRSideEffectInterfaces
)
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"

View File

@ -0,0 +1,14 @@
add_mlir_dialect_library(MLIRArmSVE
ArmSVEDialect.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE
DEPENDS
MLIRArmSVEIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMIR
MLIRSideEffectInterfaces
)

View File

@ -0,0 +1,12 @@
add_mlir_dialect_library(MLIRArmSVETransforms
LegalizeForLLVMExport.cpp
DEPENDS
MLIRArmSVEConversionsIncGen
LINK_LIBS PUBLIC
MLIRArmSVE
MLIRIR
MLIRLLVMIR
MLIRStandardToLLVM
)

View File

@ -1,4 +1,4 @@
//===- ArmSVEToLLVM.cpp - Convert ArmSVE to the LLVM dialect --------------===//
//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,34 +6,16 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::arm_sve;
using namespace mlir::vector;
using SdotOpLowering =
OneToOneConvertToLLVMPattern<SdotOp, LLVM::aarch64_arm_sve_sdot>;
using SmmlaOpLowering =
OneToOneConvertToLLVMPattern<SmmlaOp, LLVM::aarch64_arm_sve_smmla>;
using UdotOpLowering =
OneToOneConvertToLLVMPattern<UdotOp, LLVM::aarch64_arm_sve_udot>;
using UmmlaOpLowering =
OneToOneConvertToLLVMPattern<UmmlaOp, LLVM::aarch64_arm_sve_ummla>;
using VectorScaleOpLowering =
OneToOneConvertToLLVMPattern<VectorScaleOp, LLVM::vector_scale>;
// Extract an LLVM IR type from the LLVM IR dialect type.
static Type unwrap(Type type) {
@ -95,9 +77,19 @@ static Optional<Value> addUnrealizedCast(OpBuilder &builder,
.getResult(0);
}
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using VectorScaleOpLowering =
OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// Populate conversion patterns
// Remove any ArmSVE-specific types from function signatures and results.
populateFuncOpTypeConversionPattern(patterns, converter);
converter.addConversion([&converter](ScalableVectorType svType) {
return convertScalableVectorTypeToLLVM(svType, converter);
});
@ -115,3 +107,32 @@ void mlir::populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter,
VectorScaleOpLowering>(converter);
// clang-format on
}
void mlir::configureArmSVELegalizeForExportTarget(
LLVMConversionTarget &target) {
target.addLegalOp<SdotIntrOp>();
target.addIllegalOp<SdotOp>();
target.addLegalOp<SmmlaIntrOp>();
target.addIllegalOp<SmmlaOp>();
target.addLegalOp<UdotIntrOp>();
target.addIllegalOp<UdotOp>();
target.addLegalOp<UmmlaIntrOp>();
target.addIllegalOp<UmmlaOp>();
target.addLegalOp<VectorScaleIntrOp>();
target.addIllegalOp<VectorScaleOp>();
auto hasScalableVectorType = [](TypeRange types) {
for (Type type : types)
if (type.isa<arm_sve::ScalableVectorType>())
return true;
return false;
};
target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
return !hasScalableVectorType(op.getType().getInputs()) &&
!hasScalableVectorType(op.getType().getResults());
});
target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
[hasScalableVectorType](Operation *op) {
return !hasScalableVectorType(op->getOperandTypes()) &&
!hasScalableVectorType(op->getResultTypes());
});
}

View File

@ -29,27 +29,6 @@ add_mlir_dialect_library(MLIRLLVMIR
MLIRSupport
)
add_mlir_dialect_library(MLIRLLVMArmSVE
IR/LLVMArmSVEDialect.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
DEPENDS
MLIRLLVMArmSVEIncGen
MLIRLLVMArmSVEConversionsIncGen
intrinsics_gen
LINK_COMPONENTS
AsmParser
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMIR
MLIRSideEffectInterfaces
)
add_mlir_dialect_library(MLIRNVVMIR
IR/NVVMDialect.cpp

View File

@ -1,31 +0,0 @@
//===- LLVMArmSVEDialect.cpp - MLIR LLVMSVE ops implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the LLVMArmSVE dialect and its operations.
//
//===----------------------------------------------------------------------===//
#include "llvm/IR/IntrinsicsAArch64.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
void LLVM::LLVMArmSVEDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc"
>();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc"

View File

@ -37,9 +37,9 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
LINK_LIBS PUBLIC
MLIRArmNeonToLLVMIRTranslation
MLIRArmSVEToLLVMIRTranslation
MLIRAMXToLLVMIRTranslation
MLIRX86VectorToLLVMIRTranslation
MLIRLLVMArmSVEToLLVMIRTranslation
MLIRLLVMToLLVMIRTranslation
MLIRNVVMToLLVMIRTranslation
MLIROpenMPToLLVMIRTranslation

View File

@ -1,4 +1,4 @@
//===- LLVMArmSVEToLLVMIRTranslation.cpp - Translate LLVMArmSVE to LLVM IR-===//
//======- ArmSVEToLLVMIRTranslation.cpp - Translate ArmSVE to LLVM IR -=======//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,13 +6,12 @@
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR LLVMArmSVE dialect and
// LLVM IR.
// This file implements a translation between the ArmSVE dialect and LLVM IR.
//
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
@ -24,8 +23,8 @@ using namespace mlir::LLVM;
namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the LLVMArmSVE dialect to LLVM IR.
class LLVMArmSVEDialectLLVMIRTranslationInterface
/// to the ArmSVE dialect to LLVM IR.
class ArmSVEDialectLLVMIRTranslationInterface
: public LLVMTranslationDialectInterface {
public:
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
@ -36,21 +35,21 @@ public:
convertOperation(Operation *op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) const final {
Operation &opInst = *op;
#include "mlir/Dialect/LLVMIR/LLVMArmSVEConversions.inc"
#include "mlir/Dialect/ArmSVE/ArmSVEConversions.inc"
return failure();
}
};
} // end namespace
void mlir::registerLLVMArmSVEDialectTranslation(DialectRegistry &registry) {
registry.insert<LLVM::LLVMArmSVEDialect>();
registry.addDialectInterface<LLVM::LLVMArmSVEDialect,
LLVMArmSVEDialectLLVMIRTranslationInterface>();
void mlir::registerArmSVEDialectTranslation(DialectRegistry &registry) {
registry.insert<arm_sve::ArmSVEDialect>();
registry.addDialectInterface<arm_sve::ArmSVEDialect,
ArmSVEDialectLLVMIRTranslationInterface>();
}
void mlir::registerLLVMArmSVEDialectTranslation(MLIRContext &context) {
void mlir::registerArmSVEDialectTranslation(MLIRContext &context) {
DialectRegistry registry;
registerLLVMArmSVEDialectTranslation(registry);
registerArmSVEDialectTranslation(registry);
context.appendDialectRegistry(registry);
}

View File

@ -0,0 +1,16 @@
add_mlir_translation_library(MLIRArmSVEToLLVMIRTranslation
ArmSVEToLLVMIRTranslation.cpp
DEPENDS
MLIRArmSVEConversionsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRArmSVE
MLIRLLVMIR
MLIRSupport
MLIRTargetLLVMIRExport
)

View File

@ -1,6 +1,6 @@
add_subdirectory(ArmNeon)
add_subdirectory(ArmSVE)
add_subdirectory(AMX)
add_subdirectory(LLVMArmSVE)
add_subdirectory(LLVMIR)
add_subdirectory(NVVM)
add_subdirectory(OpenMP)

View File

@ -1,16 +0,0 @@
add_mlir_translation_library(MLIRLLVMArmSVEToLLVMIRTranslation
LLVMArmSVEToLLVMIRTranslation.cpp
DEPENDS
MLIRLLVMArmSVEConversionsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMArmSVE
MLIRLLVMIR
MLIRSupport
MLIRTargetLLVMIRExport
)

View File

@ -4,7 +4,7 @@ func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
// CHECK: llvm_arm_sve.sdot
// CHECK: arm_sve.intr.sdot
%0 = arm_sve.sdot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
@ -14,7 +14,7 @@ func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
// CHECK: llvm_arm_sve.smmla
// CHECK: arm_sve.intr.smmla
%0 = arm_sve.smmla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
@ -24,7 +24,7 @@ func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
// CHECK: llvm_arm_sve.udot
// CHECK: arm_sve.intr.udot
%0 = arm_sve.udot %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
@ -34,14 +34,14 @@ func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
%b: !arm_sve.vector<16xi8>,
%c: !arm_sve.vector<4xi32>)
-> !arm_sve.vector<4xi32> {
// CHECK: llvm_arm_sve.ummla
// CHECK: arm_sve.intr.ummla
%0 = arm_sve.ummla %c, %a, %b :
!arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
return %0 : !arm_sve.vector<4xi32>
}
func @get_vector_scale() -> index {
// CHECK: llvm_arm_sve.vscale
// CHECK: arm_sve.vscale
%0 = arm_sve.vector_scale : index
return %0 : index
}

View File

@ -6,7 +6,7 @@ llvm.func @arm_sve_sdot(%arg0: !llvm.vec<?x16 x i8>,
%arg2: !llvm.vec<?x4 x i32>)
-> !llvm.vec<?x4 x i32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4
%0 = "llvm_arm_sve.sdot"(%arg2, %arg0, %arg1) :
%0 = "arm_sve.intr.sdot"(%arg2, %arg0, %arg1) :
(!llvm.vec<?x4 x i32>, !llvm.vec<?x16 x i8>, !llvm.vec<?x16 x i8>)
-> !llvm.vec<?x4 x i32>
llvm.return %0 : !llvm.vec<?x4 x i32>
@ -18,7 +18,7 @@ llvm.func @arm_sve_smmla(%arg0: !llvm.vec<?x16 x i8>,
%arg2: !llvm.vec<?x4 x i32>)
-> !llvm.vec<?x4 x i32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4
%0 = "llvm_arm_sve.smmla"(%arg2, %arg0, %arg1) :
%0 = "arm_sve.intr.smmla"(%arg2, %arg0, %arg1) :
(!llvm.vec<?x4 x i32>, !llvm.vec<?x16 x i8>, !llvm.vec<?x16 x i8>)
-> !llvm.vec<?x4 x i32>
llvm.return %0 : !llvm.vec<?x4 x i32>
@ -30,7 +30,7 @@ llvm.func @arm_sve_udot(%arg0: !llvm.vec<?x16 x i8>,
%arg2: !llvm.vec<?x4 x i32>)
-> !llvm.vec<?x4 x i32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4
%0 = "llvm_arm_sve.udot"(%arg2, %arg0, %arg1) :
%0 = "arm_sve.intr.udot"(%arg2, %arg0, %arg1) :
(!llvm.vec<?x4 x i32>, !llvm.vec<?x16 x i8>, !llvm.vec<?x16 x i8>)
-> !llvm.vec<?x4 x i32>
llvm.return %0 : !llvm.vec<?x4 x i32>
@ -42,7 +42,7 @@ llvm.func @arm_sve_ummla(%arg0: !llvm.vec<?x16 x i8>,
%arg2: !llvm.vec<?x4 x i32>)
-> !llvm.vec<?x4 x i32> {
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4
%0 = "llvm_arm_sve.ummla"(%arg2, %arg0, %arg1) :
%0 = "arm_sve.intr.ummla"(%arg2, %arg0, %arg1) :
(!llvm.vec<?x4 x i32>, !llvm.vec<?x16 x i8>, !llvm.vec<?x16 x i8>)
-> !llvm.vec<?x4 x i32>
llvm.return %0 : !llvm.vec<?x4 x i32>
@ -51,6 +51,6 @@ llvm.func @arm_sve_ummla(%arg0: !llvm.vec<?x16 x i8>,
// CHECK-LABEL: define i64 @get_vector_scale()
llvm.func @get_vector_scale() -> i64 {
// CHECK: call i64 @llvm.vscale.i64()
%0 = "llvm_arm_sve.vscale"() : () -> i64
%0 = "arm_sve.vscale"() : () -> i64
llvm.return %0 : i64
}

View File

@ -11,7 +11,6 @@
// CHECK-NEXT: gpu
// CHECK-NEXT: linalg
// CHECK-NEXT: llvm
// CHECK-NEXT: llvm_arm_sve
// CHECK-NEXT: math
// CHECK-NEXT: memref
// CHECK-NEXT: nvvm