Extract common code to deal with multidimensional vectors.

Summary: Also replace dyn_cast_or_null with dyn_cast when possible.

Differential Revision: https://reviews.llvm.org/D75733
This commit is contained in:
Adrian Kuegel 2020-03-06 13:04:37 +01:00
parent 6ef953c2d6
commit 86306df7dd
1 changed files with 62 additions and 57 deletions

View File

@ -24,6 +24,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
@ -32,6 +33,7 @@
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include <functional>
using namespace mlir;
@ -1165,6 +1167,36 @@ void ValidateOpCount() {
OpCountValidator<SourceOp, OpCount>();
}
static LogicalResult HandleMultidimensionalVectors(
Operation *op, ArrayRef<Value> operands, LLVMTypeConverter &typeConverter,
std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter) {
auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
if (!vectorType)
return failure();
auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter);
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
return failure();
auto loc = op->getLoc();
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
SmallVector<Value, 4> extractedOperands;
for (auto operand : operands)
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
loc, llvmVectorTy, operand, position));
Value newVal = createOperand(llvmVectorTy, extractedOperands);
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, newVal,
position);
});
rewriter.replaceOp(op, desc);
return success();
}
// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
// Ops for N-ary ops with one result. This supports higher-dimensional vector
// types.
@ -1192,7 +1224,6 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
return this->matchFailure();
}
auto loc = op->getLoc();
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
if (!llvmArrayTy.isArrayTy()) {
@ -1202,31 +1233,15 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
return this->matchSuccess();
}
auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
if (!vectorType)
return this->matchFailure();
auto vectorTypeInfo =
extractNDVectorTypeInfo(vectorType, this->typeConverter);
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
return this->matchFailure();
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
SmallVector<Value, OpCount> extractedOperands;
for (unsigned i = 0; i < OpCount; ++i) {
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
loc, llvmVectorTy, operands[i], position));
}
Value newVal = rewriter.create<TargetOp>(
loc, llvmVectorTy, extractedOperands, op->getAttrs());
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
newVal, position);
});
rewriter.replaceOp(op, desc);
return this->matchSuccess();
if (succeeded(HandleMultidimensionalVectors(
op, operands, this->typeConverter,
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
return rewriter.create<TargetOp>(op->getLoc(), llvmVectorTy,
operands, op->getAttrs());
},
rewriter)))
return this->matchSuccess();
return this->matchFailure();
}
};
@ -1673,7 +1688,7 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
ConversionPatternRewriter &rewriter) const override {
OperandAdaptor<RsqrtOp> transformed(operands);
auto operandType =
transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
if (!operandType)
return matchFailure();
@ -1694,41 +1709,31 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
}
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
return matchSuccess();
return this->matchSuccess();
}
auto vectorType = resultType.dyn_cast<VectorType>();
if (!vectorType)
return this->matchFailure();
auto vectorTypeInfo =
extractNDVectorTypeInfo(vectorType, this->typeConverter);
auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
if (!llvmVectorTy || operandType != vectorTypeInfo.llvmArrayTy)
return this->matchFailure();
Value desc = rewriter.create<LLVM::UndefOp>(loc, operandType);
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
// For this unrolled `position` corresponding to the `linearIndex`^th
// element, extract operand vectors
auto extractedOperand = rewriter.create<LLVM::ExtractValueOp>(
loc, llvmVectorTy, operands[0], position);
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get(
{llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
floatType),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
auto sqrt =
rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, extractedOperand);
auto div = rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
desc = rewriter.create<LLVM::InsertValueOp>(loc, operandType, desc, div,
position);
});
rewriter.replaceOp(op, desc);
return matchSuccess();
if (succeeded(HandleMultidimensionalVectors(
op, operands, typeConverter,
[&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
auto splatAttr = SplatElementsAttr::get(
mlir::VectorType::get({llvmVectorTy.getUnderlyingType()
->getVectorNumElements()},
floatType),
floatOne);
auto one = rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy,
splatAttr);
auto sqrt =
rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one,
sqrt);
},
rewriter)))
return this->matchSuccess();
return this->matchFailure();
}
};
@ -1745,7 +1750,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
OperandAdaptor<TanhOp> transformed(operands);
LLVMTypeT operandType =
transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
if (!operandType)
return matchFailure();