llvm-project/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp

142 lines
5.4 KiB
C++

//===- LegalizeForLLVMExport.cpp - Prepare AVX512 for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AVX512/Transforms.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::avx512;
/// Returns the element type of the vector that drives intrinsic selection for
/// the given AVX512 operation. In the generic case this is the type of the
/// op's `src` operand, which is cast to a VectorType.
template <typename OpTy>
static Type getSrcVectorElementType(OpTy op) {
  auto srcVectorType = op.src().getType().template cast<VectorType>();
  return srcVectorType.getElementType();
}
/// Specialization for Vp2IntersectOp: the element type is taken from the `a`
/// operand rather than from `src`.
template <>
Type getSrcVectorElementType(Vp2IntersectOp op) {
  auto operandVectorType = op.a().getType().cast<VectorType>();
  return operandVectorType.getElementType();
}
namespace {
/// Base conversion for AVX512 ops that can be lowered to one of the two
/// intrinsics based on the bitwidth of their "main" vector element type (as
/// reported by getSrcVectorElementType). This relies on the to-LLVM-dialect
/// conversion helpers to correctly pack the results of multi-result intrinsic
/// ops.
///
/// Template parameters:
///   OpTy       - the "main" AVX512 op being lowered.
///   Intr32OpTy - intrinsic op used when the element type is 32 bits wide.
///   Intr64OpTy - intrinsic op used when the element type is 64 bits wide.
template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
  explicit LowerToIntrinsic(LLVMTypeConverter &converter)
      : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}

  /// Returns the type converter this pattern was constructed with. The
  /// constructor only accepts an LLVMTypeConverter, so the downcast from the
  /// base-class converter pointer is safe.
  LLVMTypeConverter &getTypeConverter() const {
    return *static_cast<LLVMTypeConverter *>(
        OpConversionPattern<OpTy>::getTypeConverter());
  }

  /// Replaces `op` with Intr32OpTy or Intr64OpTy depending on whether the
  /// "main" vector element type is 32 or 64 bits wide; reports a match failure
  /// for any other width.
  LogicalResult
  matchAndRewrite(OpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type elementType = getSrcVectorElementType<OpTy>(op);
    unsigned bitwidth = elementType.getIntOrFloatBitWidth();
    if (bitwidth == 32)
      return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
                                           operands, getTypeConverter(),
                                           rewriter);
    if (bitwidth == 64)
      return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
                                           operands, getTypeConverter(),
                                           rewriter);
    // The inspected operand is not called 'src' for every op (Vp2IntersectOp
    // uses 'a'), and only the bit width is checked here, so the diagnostic is
    // phrased in terms of element width rather than a specific operand/type.
    return rewriter.notifyMatchFailure(
        op, "expected the main vector operand to have 32-bit or 64-bit "
            "elements");
  }
};
/// Lowers MaskCompressOp to MaskCompressIntrOp. The intrinsic always takes an
/// explicit source vector, so when the op carries no `src` operand one is
/// materialized as a constant: from the `constant_src` attribute if present,
/// otherwise from a zero attribute of the operand type.
struct MaskCompressOpConversion
    : public ConvertOpToLLVMPattern<MaskCompressOp> {
  using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    MaskCompressOp::Adaptor adaptor(operands);
    Type vectorType = adaptor.a().getType();
    Location loc = op.getLoc();

    Value passThrough;
    if (op.src()) {
      // An explicit source operand was provided; forward its converted value.
      passThrough = adaptor.src();
    } else {
      // No source operand: pick the attribute describing the constant source.
      Attribute fillAttr;
      if (op.constant_src())
        fillAttr = op.constant_srcAttr();
      else
        fillAttr = rewriter.getZeroAttr(vectorType);
      passThrough = rewriter.create<ConstantOp>(loc, vectorType, fillAttr);
    }

    rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(
        op, vectorType, adaptor.a(), passThrough, adaptor.k());
    return success();
  }
};
/// Compile-time record tying a "main" AVX512 op to the two ops it is
/// instantiated as: one for vectors of 32-bit elements and one for vectors of
/// 64-bit elements. It carries no data, only the three type aliases.
template <typename Main, typename Intr32, typename Intr64>
struct RegEntry {
  using MainOp = Main;
  using Intr32Op = Intr32;
  using Intr64Op = Intr64;
};
/// A container for op association entries facilitating the configuration of
/// dialect conversion.
template <typename... Args>
struct RegistryImpl {
/// Registers the patterns specializing the "main" op to one of the
/// "intrinsic" ops depending on elemental type.
static void registerPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns
.add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
typename Args::Intr64Op>...>(converter);
}
/// Configures the conversion target to lower out "main" ops.
static void configureTarget(LLVMConversionTarget &target) {
target.addIllegalOp<typename Args::MainOp...>();
target.addLegalOp<typename Args::Intr32Op...>();
target.addLegalOp<typename Args::Intr64Op...>();
}
};
using Registry = RegistryImpl<
RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
} // namespace
/// Adds to `patterns` every rewrite that converts AVX512 ops to their LLVM
/// intrinsic counterparts: first the width-dispatched patterns from the
/// registry, then the dedicated MaskCompressOp conversion. All patterns are
/// constructed with `converter`.
void mlir::populateAVX512LegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // Patterns generated from the op registry.
  Registry::registerPatterns(converter, patterns);
  // MaskCompressOp has its own pattern.
  patterns.add<MaskCompressOpConversion>(converter);
}
/// Configures `target` so that the AVX512 "main" ops must be converted away
/// while the corresponding intrinsic ops are accepted as legal.
void mlir::configureAVX512LegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // Ops covered by the registry (main ops illegal, intrinsic ops legal).
  Registry::configureTarget(target);
  // MaskCompressOp is handled outside the registry.
  target.addIllegalOp<MaskCompressOp>();
  target.addLegalOp<MaskCompressIntrOp>();
}