2020-06-06 00:53:41 +08:00
|
|
|
//===- VectorToROCDL.cpp - Vector to ROCDL lowering passes ------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// This file implements a pass to generate ROCDLIR operations for higher-level
|
|
|
|
// Vector operations.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
|
|
|
|
|
|
|
|
#include "../PassDetail.h"
|
|
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
|
|
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
|
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
|
|
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
|
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::vector;
|
|
|
|
|
|
|
|
/// Replaces the transfer read `xferOp` with a `ROCDL::MubufLoadOp` whose
/// result type is `vecTy`.
///
/// `dwordConfig`, `vindex`, `offsetSizeInBytes`, `glc` and `slc` are forwarded
/// unchanged as the operands of the new op. `operands`, `typeConverter` and
/// `loc` are not needed for a load; they are accepted only so that this
/// overload has the same parameter list as the TransferWriteOp overload.
///
/// Always returns success().
static LogicalResult replaceTransferOpWithMubuf(
    ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
    LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
    Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
    Value &glc, Value &slc) {
  // Intentionally unused for the read lowering.
  (void)operands;
  (void)typeConverter;
  (void)loc;

  rewriter.replaceOpWithNewOp<ROCDL::MubufLoadOp>(xferOp, vecTy, dwordConfig,
                                                  vindex, offsetSizeInBytes,
                                                  glc, slc);
  return success();
}
|
|
|
|
|
|
|
|
/// Replaces the transfer write `xferOp` with a `ROCDL::MubufStoreOp`.
///
/// The value to store is taken from the already-converted `operands` through
/// a `TransferWriteOpAdaptor`; `dwordConfig`, `vindex`, `offsetSizeInBytes`,
/// `glc` and `slc` are forwarded unchanged as the remaining operands.
/// `typeConverter`, `loc` and `vecTy` are not needed for a store; they are
/// accepted only so that this overload has the same parameter list as the
/// TransferReadOp overload.
///
/// Always returns success().
static LogicalResult replaceTransferOpWithMubuf(
    ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
    LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
    Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
    Value &glc, Value &slc) {
  // Intentionally unused for the write lowering.
  (void)typeConverter;
  (void)loc;
  (void)vecTy;

  TransferWriteOpAdaptor writeAdaptor(operands);
  Value storedVector = writeAdaptor.vector();
  rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(
      xferOp, storedVector, dwordConfig, vindex, offsetSizeInBytes, glc, slc);
  return success();
}
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
/// Conversion pattern that converts a 1-D vector transfer read/write.
|
|
|
|
/// Note that this conversion pass only converts vector x2 or x4 f32
|
|
|
|
/// types. For unsupported cases, they will fall back to the vector to
|
|
|
|
/// llvm conversion pattern.
|
|
|
|
template <typename ConcreteOp>
|
2020-12-10 10:18:35 +08:00
|
|
|
class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
|
2020-06-06 00:53:41 +08:00
|
|
|
public:
|
2020-12-10 10:18:35 +08:00
|
|
|
using ConvertOpToLLVMPattern<ConcreteOp>::ConvertOpToLLVMPattern;
|
2020-06-06 00:53:41 +08:00
|
|
|
|
|
|
|
LogicalResult
|
2020-12-10 10:18:35 +08:00
|
|
|
matchAndRewrite(ConcreteOp xferOp, ArrayRef<Value> operands,
|
2020-06-06 00:53:41 +08:00
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2020-06-15 21:01:31 +08:00
|
|
|
typename ConcreteOp::Adaptor adaptor(operands);
|
2020-06-06 00:53:41 +08:00
|
|
|
|
|
|
|
if (xferOp.getVectorType().getRank() > 1 ||
|
|
|
|
llvm::size(xferOp.indices()) == 0)
|
|
|
|
return failure();
|
|
|
|
|
2020-07-11 04:47:51 +08:00
|
|
|
if (!xferOp.permutation_map().isMinorIdentity())
|
2020-06-06 00:53:41 +08:00
|
|
|
return failure();
|
|
|
|
|
|
|
|
// Have it handled in vector->llvm conversion pass.
|
|
|
|
if (!xferOp.isMaskedDim(0))
|
|
|
|
return failure();
|
|
|
|
|
2020-12-10 10:18:35 +08:00
|
|
|
auto toLLVMTy = [&](Type t) {
|
|
|
|
return this->getTypeConverter()->convertType(t);
|
|
|
|
};
|
[mlir] use built-in vector types instead of LLVM dialect types when possible
Continue the convergence between LLVM dialect and built-in types by using the
built-in vector type whenever possible, that is for fixed vectors of built-in
integers and built-in floats. LLVM dialect vector type is still in use for
pointers, less frequent floating point types that do not have a built-in
equivalent, and scalable vectors. However, the top-level `LLVMVectorType` class
has been removed in favor of free functions capable of inspecting both built-in
and LLVM dialect vector types: `LLVM::getVectorElementType`,
`LLVM::getNumVectorElements` and `LLVM::getFixedVectorType`. Additional work is
necessary to design an implemented the extensions to built-in types so as to
remove the `LLVMFixedVectorType` entirely.
Note that the default output format for the built-in vectors does not have
whitespace around the `x` separator, e.g., `vector<4xf32>` as opposed to the
LLVM dialect vector type format that does, e.g., `!llvm.vec<4 x fp128>`. This
required changing the FileCheck patterns in several tests.
Reviewed By: mehdi_amini, silvas
Differential Revision: https://reviews.llvm.org/D94405
2021-01-11 20:58:05 +08:00
|
|
|
auto vecTy = toLLVMTy(xferOp.getVectorType());
|
|
|
|
unsigned vecWidth = LLVM::getVectorNumElements(vecTy).getFixedValue();
|
2020-12-10 10:18:35 +08:00
|
|
|
Location loc = xferOp->getLoc();
|
2020-06-06 00:53:41 +08:00
|
|
|
|
|
|
|
// The backend result vector scalarization have trouble scalarize
|
|
|
|
// <1 x ty> result, exclude the x1 width from the lowering.
|
|
|
|
if (vecWidth != 2 && vecWidth != 4)
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
// Obtain dataPtr and elementType from the memref.
|
2020-12-18 08:26:07 +08:00
|
|
|
auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
|
|
|
|
if (!memRefType)
|
|
|
|
return failure();
|
2020-06-06 00:53:41 +08:00
|
|
|
// MUBUF instruction operate only on addresspace 0(unified) or 1(global)
|
|
|
|
// In case of 3(LDS): fall back to vector->llvm pass
|
|
|
|
// In case of 5(VGPR): wrong
|
|
|
|
if ((memRefType.getMemorySpace() != 0) &&
|
|
|
|
(memRefType.getMemorySpace() != 1))
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
// Note that the dataPtr starts at the offset address specified by
|
2020-10-29 03:03:15 +08:00
|
|
|
// indices, so no need to calculate offset size in bytes again in
|
2020-06-06 00:53:41 +08:00
|
|
|
// the MUBUF instruction.
|
2020-12-10 10:18:35 +08:00
|
|
|
Value dataPtr = this->getStridedElementPtr(
|
2020-12-18 08:26:07 +08:00
|
|
|
loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
|
2020-06-06 00:53:41 +08:00
|
|
|
|
|
|
|
// 1. Create and fill a <4 x i32> dwordConfig with:
|
|
|
|
// 1st two elements holding the address of dataPtr.
|
|
|
|
// 3rd element: -1.
|
|
|
|
// 4th element: 0x27000.
|
|
|
|
SmallVector<int32_t, 4> constConfigAttr{0, 0, -1, 0x27000};
|
|
|
|
Type i32Ty = rewriter.getIntegerType(32);
|
|
|
|
VectorType i32Vecx4 = VectorType::get(4, i32Ty);
|
|
|
|
Value constConfig = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
loc, toLLVMTy(i32Vecx4),
|
|
|
|
DenseElementsAttr::get(i32Vecx4, ArrayRef<int32_t>(constConfigAttr)));
|
|
|
|
|
|
|
|
// Treat first two element of <4 x i32> as i64, and save the dataPtr
|
|
|
|
// to it.
|
|
|
|
Type i64Ty = rewriter.getIntegerType(64);
|
|
|
|
Value i64x2Ty = rewriter.create<LLVM::BitcastOp>(
|
[mlir] use built-in vector types instead of LLVM dialect types when possible
Continue the convergence between LLVM dialect and built-in types by using the
built-in vector type whenever possible, that is for fixed vectors of built-in
integers and built-in floats. LLVM dialect vector type is still in use for
pointers, less frequent floating point types that do not have a built-in
equivalent, and scalable vectors. However, the top-level `LLVMVectorType` class
has been removed in favor of free functions capable of inspecting both built-in
and LLVM dialect vector types: `LLVM::getVectorElementType`,
`LLVM::getNumVectorElements` and `LLVM::getFixedVectorType`. Additional work is
necessary to design an implemented the extensions to built-in types so as to
remove the `LLVMFixedVectorType` entirely.
Note that the default output format for the built-in vectors does not have
whitespace around the `x` separator, e.g., `vector<4xf32>` as opposed to the
LLVM dialect vector type format that does, e.g., `!llvm.vec<4 x fp128>`. This
required changing the FileCheck patterns in several tests.
Reviewed By: mehdi_amini, silvas
Differential Revision: https://reviews.llvm.org/D94405
2021-01-11 20:58:05 +08:00
|
|
|
loc, LLVM::getFixedVectorType(toLLVMTy(i64Ty), 2), constConfig);
|
2020-06-06 00:53:41 +08:00
|
|
|
Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
|
2021-01-05 23:22:53 +08:00
|
|
|
loc, toLLVMTy(i64Ty).template cast<Type>(), dataPtr);
|
2020-12-10 10:18:35 +08:00
|
|
|
Value zero = this->createIndexConstant(rewriter, loc, 0);
|
2020-06-06 00:53:41 +08:00
|
|
|
Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
|
[mlir] use built-in vector types instead of LLVM dialect types when possible
Continue the convergence between LLVM dialect and built-in types by using the
built-in vector type whenever possible, that is for fixed vectors of built-in
integers and built-in floats. LLVM dialect vector type is still in use for
pointers, less frequent floating point types that do not have a built-in
equivalent, and scalable vectors. However, the top-level `LLVMVectorType` class
has been removed in favor of free functions capable of inspecting both built-in
and LLVM dialect vector types: `LLVM::getVectorElementType`,
`LLVM::getNumVectorElements` and `LLVM::getFixedVectorType`. Additional work is
necessary to design an implemented the extensions to built-in types so as to
remove the `LLVMFixedVectorType` entirely.
Note that the default output format for the built-in vectors does not have
whitespace around the `x` separator, e.g., `vector<4xf32>` as opposed to the
LLVM dialect vector type format that does, e.g., `!llvm.vec<4 x fp128>`. This
required changing the FileCheck patterns in several tests.
Reviewed By: mehdi_amini, silvas
Differential Revision: https://reviews.llvm.org/D94405
2021-01-11 20:58:05 +08:00
|
|
|
loc, LLVM::getFixedVectorType(toLLVMTy(i64Ty), 2), i64x2Ty,
|
|
|
|
dataPtrAsI64, zero);
|
2020-06-06 00:53:41 +08:00
|
|
|
dwordConfig =
|
|
|
|
rewriter.create<LLVM::BitcastOp>(loc, toLLVMTy(i32Vecx4), dwordConfig);
|
|
|
|
|
|
|
|
// 2. Rewrite op as a buffer read or write.
|
|
|
|
Value int1False = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
loc, toLLVMTy(rewriter.getIntegerType(1)),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
|
|
|
Value int32Zero = rewriter.create<LLVM::ConstantOp>(
|
|
|
|
loc, toLLVMTy(i32Ty),
|
|
|
|
rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
|
2020-12-03 19:34:26 +08:00
|
|
|
return replaceTransferOpWithMubuf(
|
2020-12-10 10:18:35 +08:00
|
|
|
rewriter, operands, *this->getTypeConverter(), loc, xferOp, vecTy,
|
2020-12-03 19:34:26 +08:00
|
|
|
dwordConfig, int32Zero, int32Zero, int1False, int1False);
|
2020-06-06 00:53:41 +08:00
|
|
|
}
|
|
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
|
|
|
|
/// Adds the vector transfer read and transfer write lowering patterns, both
/// constructed from `converter`, to `patterns`.
void mlir::populateVectorToROCDLConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  using ReadLowering = VectorTransferConversion<TransferReadOp>;
  using WriteLowering = VectorTransferConversion<TransferWriteOp>;
  patterns.insert<ReadLowering, WriteLowering>(converter);
}
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
struct LowerVectorToROCDLPass
|
|
|
|
: public ConvertVectorToROCDLBase<LowerVectorToROCDLPass> {
|
|
|
|
void runOnOperation() override;
|
|
|
|
};
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
void LowerVectorToROCDLPass::runOnOperation() {
|
|
|
|
LLVMTypeConverter converter(&getContext());
|
|
|
|
OwningRewritePatternList patterns;
|
|
|
|
|
|
|
|
populateVectorToROCDLConversionPatterns(converter, patterns);
|
|
|
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
|
|
|
|
|
|
|
LLVMConversionTarget target(getContext());
|
|
|
|
target.addLegalDialect<ROCDL::ROCDLDialect>();
|
|
|
|
|
2020-10-27 08:25:01 +08:00
|
|
|
if (failed(
|
|
|
|
applyPartialConversion(getOperation(), target, std::move(patterns))))
|
2020-06-06 00:53:41 +08:00
|
|
|
signalPassFailure();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Creates a new instance of the pass that lowers vector transfer operations
/// to ROCDL.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToROCDLPass() {
  return std::make_unique<LowerVectorToROCDLPass>();
}
|