forked from OSchip/llvm-project
Extract common code to deal with multidimensional vectors.
Summary: Also replace dyn_cast_or_null with dyn_cast when possible. Differential Revision: https://reviews.llvm.org/D75733
This commit is contained in:
parent
6ef953c2d6
commit
86306df7dd
|
@ -24,6 +24,7 @@
|
|||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/Functional.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/Utils.h"
|
||||
|
@ -32,6 +33,7 @@
|
|||
#include "llvm/IR/Type.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include <functional>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -1165,6 +1167,36 @@ void ValidateOpCount() {
|
|||
OpCountValidator<SourceOp, OpCount>();
|
||||
}
|
||||
|
||||
static LogicalResult HandleMultidimensionalVectors(
|
||||
Operation *op, ArrayRef<Value> operands, LLVMTypeConverter &typeConverter,
|
||||
std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
|
||||
if (!vectorType)
|
||||
return failure();
|
||||
auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter);
|
||||
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
|
||||
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
|
||||
if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
|
||||
return failure();
|
||||
|
||||
auto loc = op->getLoc();
|
||||
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
|
||||
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
|
||||
// For this unrolled `position` corresponding to the `linearIndex`^th
|
||||
// element, extract operand vectors
|
||||
SmallVector<Value, 4> extractedOperands;
|
||||
for (auto operand : operands)
|
||||
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, llvmVectorTy, operand, position));
|
||||
Value newVal = createOperand(llvmVectorTy, extractedOperands);
|
||||
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, newVal,
|
||||
position);
|
||||
});
|
||||
rewriter.replaceOp(op, desc);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
|
||||
// Ops for N-ary ops with one result. This supports higher-dimensional vector
|
||||
// types.
|
||||
|
@ -1192,7 +1224,6 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
return this->matchFailure();
|
||||
}
|
||||
|
||||
auto loc = op->getLoc();
|
||||
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
|
||||
|
||||
if (!llvmArrayTy.isArrayTy()) {
|
||||
|
@ -1202,31 +1233,15 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
return this->matchSuccess();
|
||||
}
|
||||
|
||||
auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
|
||||
if (!vectorType)
|
||||
return this->matchFailure();
|
||||
auto vectorTypeInfo =
|
||||
extractNDVectorTypeInfo(vectorType, this->typeConverter);
|
||||
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
|
||||
if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
|
||||
return this->matchFailure();
|
||||
|
||||
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
|
||||
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
|
||||
// For this unrolled `position` corresponding to the `linearIndex`^th
|
||||
// element, extract operand vectors
|
||||
SmallVector<Value, OpCount> extractedOperands;
|
||||
for (unsigned i = 0; i < OpCount; ++i) {
|
||||
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, llvmVectorTy, operands[i], position));
|
||||
}
|
||||
Value newVal = rewriter.create<TargetOp>(
|
||||
loc, llvmVectorTy, extractedOperands, op->getAttrs());
|
||||
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
|
||||
newVal, position);
|
||||
});
|
||||
rewriter.replaceOp(op, desc);
|
||||
return this->matchSuccess();
|
||||
if (succeeded(HandleMultidimensionalVectors(
|
||||
op, operands, this->typeConverter,
|
||||
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
|
||||
return rewriter.create<TargetOp>(op->getLoc(), llvmVectorTy,
|
||||
operands, op->getAttrs());
|
||||
},
|
||||
rewriter)))
|
||||
return this->matchSuccess();
|
||||
return this->matchFailure();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1673,7 +1688,7 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
OperandAdaptor<RsqrtOp> transformed(operands);
|
||||
auto operandType =
|
||||
transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
|
||||
transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
|
||||
|
||||
if (!operandType)
|
||||
return matchFailure();
|
||||
|
@ -1694,41 +1709,31 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
|
|||
}
|
||||
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
|
||||
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
|
||||
return matchSuccess();
|
||||
return this->matchSuccess();
|
||||
}
|
||||
|
||||
auto vectorType = resultType.dyn_cast<VectorType>();
|
||||
if (!vectorType)
|
||||
return this->matchFailure();
|
||||
|
||||
auto vectorTypeInfo =
|
||||
extractNDVectorTypeInfo(vectorType, this->typeConverter);
|
||||
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
|
||||
if (!llvmVectorTy || operandType != vectorTypeInfo.llvmArrayTy)
|
||||
return this->matchFailure();
|
||||
|
||||
Value desc = rewriter.create<LLVM::UndefOp>(loc, operandType);
|
||||
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
|
||||
// For this unrolled `position` corresponding to the `linearIndex`^th
|
||||
// element, extract operand vectors
|
||||
auto extractedOperand = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, llvmVectorTy, operands[0], position);
|
||||
auto splatAttr = SplatElementsAttr::get(
|
||||
mlir::VectorType::get(
|
||||
{llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
|
||||
floatType),
|
||||
floatOne);
|
||||
auto one =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
|
||||
auto sqrt =
|
||||
rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, extractedOperand);
|
||||
auto div = rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
|
||||
desc = rewriter.create<LLVM::InsertValueOp>(loc, operandType, desc, div,
|
||||
position);
|
||||
});
|
||||
rewriter.replaceOp(op, desc);
|
||||
|
||||
return matchSuccess();
|
||||
if (succeeded(HandleMultidimensionalVectors(
|
||||
op, operands, typeConverter,
|
||||
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
|
||||
auto splatAttr = SplatElementsAttr::get(
|
||||
mlir::VectorType::get({llvmVectorTy.getUnderlyingType()
|
||||
->getVectorNumElements()},
|
||||
floatType),
|
||||
floatOne);
|
||||
auto one = rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy,
|
||||
splatAttr);
|
||||
auto sqrt =
|
||||
rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
|
||||
return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one,
|
||||
sqrt);
|
||||
},
|
||||
rewriter)))
|
||||
return this->matchSuccess();
|
||||
return this->matchFailure();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1745,7 +1750,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
|
|||
|
||||
OperandAdaptor<TanhOp> transformed(operands);
|
||||
LLVMTypeT operandType =
|
||||
transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
|
||||
transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
|
||||
|
||||
if (!operandType)
|
||||
return matchFailure();
|
||||
|
|
Loading…
Reference in New Issue