forked from OSchip/llvm-project
Add support for LLVM lowering of binary ops on n-D vector types
This CL allows binary operations on n-D vector types to be lowered to LLVMIR by performing an (n-1)-D extractvalue, 1-D vector operation and an (n-1)-D insertvalue. PiperOrigin-RevId: 264339118
This commit is contained in:
parent
07ecb011a7
commit
f55ac5c076
|
@ -67,21 +67,26 @@ public:
|
|||
/// Array type utilities.
|
||||
LLVMType getArrayElementType();
|
||||
unsigned getArrayNumElements();
|
||||
bool isArrayTy();
|
||||
|
||||
/// Vector type utilities.
|
||||
LLVMType getVectorElementType();
|
||||
bool isVectorTy();
|
||||
|
||||
/// Function type utilities.
|
||||
LLVMType getFunctionParamType(unsigned argIdx);
|
||||
unsigned getFunctionNumParams();
|
||||
LLVMType getFunctionResultType();
|
||||
bool isFunctionTy();
|
||||
|
||||
/// Pointer type utilities.
|
||||
LLVMType getPointerTo(unsigned addrSpace = 0);
|
||||
LLVMType getPointerElementTy();
|
||||
bool isPointerTy();
|
||||
|
||||
/// Struct type utilities.
|
||||
LLVMType getStructElementType(unsigned i);
|
||||
bool isStructTy();
|
||||
|
||||
/// Utilities used to generate floating point types.
|
||||
static LLVMType getDoubleTy(LLVMDialect *dialect);
|
||||
|
|
|
@ -137,6 +137,7 @@ public:
|
|||
ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
|
||||
ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
|
||||
ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
|
||||
ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values);
|
||||
ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
|
||||
ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
|
||||
ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
|
||||
|
|
|
@ -346,58 +346,156 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
auto type = this->lowering.convertType(op->getResult(i)->getType());
|
||||
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
op->getLoc(), type, newOp.getOperation()->getResult(0),
|
||||
this->getIntegerArrayAttr(rewriter, i)));
|
||||
rewriter.getIndexArrayAttr(i)));
|
||||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
return this->matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// Express `linearIndex` in terms of coordinates of `basis`.
|
||||
// Returns the empty vector when linearIndex is out of the range [0, P] where
|
||||
// P is the product of all the basis coordinates.
|
||||
//
|
||||
// Prerequisites:
|
||||
// Basis is an array of nonnegative integers (signed type inherited from
|
||||
// vector shape type).
|
||||
static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
|
||||
unsigned linearIndex) {
|
||||
SmallVector<int64_t, 4> res;
|
||||
res.reserve(basis.size());
|
||||
for (unsigned basisElement : llvm::reverse(basis)) {
|
||||
res.push_back(linearIndex % basisElement);
|
||||
linearIndex = linearIndex / basisElement;
|
||||
}
|
||||
if (linearIndex > 0)
|
||||
return {};
|
||||
std::reverse(res.begin(), res.end());
|
||||
return res;
|
||||
}
|
||||
|
||||
// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
|
||||
// Ops for binary ops with one result. This supports higher-dimensional vector
|
||||
// types.
|
||||
template <typename SourceOp, typename TargetOp>
|
||||
struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
||||
using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
|
||||
using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>;
|
||||
|
||||
// Convert the type of the result to an LLVM type, pass operands as is,
|
||||
// preserve attributes.
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
static_assert(
|
||||
std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value,
|
||||
"expected binary op");
|
||||
static_assert(
|
||||
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
|
||||
"expected single result op");
|
||||
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
|
||||
SourceOp>::value,
|
||||
"expected single result op");
|
||||
|
||||
auto loc = op->getLoc();
|
||||
auto llvmArrayTy = operands[0]->getType().cast<LLVM::LLVMType>();
|
||||
|
||||
if (!llvmArrayTy.isArrayTy()) {
|
||||
auto newOp = rewriter.create<TargetOp>(
|
||||
op->getLoc(), operands[0]->getType(), operands, op->getAttrs());
|
||||
rewriter.replaceOp(op, newOp.getResult());
|
||||
return this->matchSuccess();
|
||||
}
|
||||
|
||||
// Unroll iterated array type until we hit a non-array type.
|
||||
auto llvmTy = llvmArrayTy;
|
||||
SmallVector<int64_t, 4> arraySizes;
|
||||
while (llvmTy.isArrayTy()) {
|
||||
arraySizes.push_back(llvmTy.getArrayNumElements());
|
||||
llvmTy = llvmTy.getArrayElementType();
|
||||
}
|
||||
assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type");
|
||||
auto llvmVectorTy = llvmTy;
|
||||
|
||||
// Iteratively extract a position coordinates with basis `arraySize` from a
|
||||
// `linearIndex` that is incremented at each step. This terminates when
|
||||
// `linearIndex` exceeds the range specified by `arraySize`.
|
||||
// This has the effect of fully unrolling the dimensions of the n-D array
|
||||
// type, getting to the underlying vector element.
|
||||
Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
|
||||
unsigned ub = 1;
|
||||
for (auto s : arraySizes)
|
||||
ub *= s;
|
||||
for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
|
||||
auto coords = getCoordinates(arraySizes, linearIndex);
|
||||
// Linear index is out of bounds, we are done.
|
||||
if (coords.empty())
|
||||
break;
|
||||
|
||||
auto position = rewriter.getIndexArrayAttr(coords);
|
||||
|
||||
// For this unrolled `position` corresponding to the `linearIndex`^th
|
||||
// element, extract operand vectors
|
||||
Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, llvmVectorTy, operands[0], position);
|
||||
Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, llvmVectorTy, operands[1], position);
|
||||
Value *newVal = rewriter.create<TargetOp>(
|
||||
loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS},
|
||||
op->getAttrs());
|
||||
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
|
||||
newVal, position);
|
||||
}
|
||||
rewriter.replaceOp(op, desc);
|
||||
return this->matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// Specific lowerings.
|
||||
// FIXME: this should be tablegen'ed.
|
||||
struct AddIOpLowering : public OneToOneLLVMOpLowering<AddIOp, LLVM::AddOp> {
|
||||
struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct SubIOpLowering : public OneToOneLLVMOpLowering<SubIOp, LLVM::SubOp> {
|
||||
struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct MulIOpLowering : public OneToOneLLVMOpLowering<MulIOp, LLVM::MulOp> {
|
||||
struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct DivISOpLowering : public OneToOneLLVMOpLowering<DivISOp, LLVM::SDivOp> {
|
||||
struct DivISOpLowering : public BinaryOpLLVMOpLowering<DivISOp, LLVM::SDivOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct DivIUOpLowering : public OneToOneLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
|
||||
struct DivIUOpLowering : public BinaryOpLLVMOpLowering<DivIUOp, LLVM::UDivOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct RemISOpLowering : public OneToOneLLVMOpLowering<RemISOp, LLVM::SRemOp> {
|
||||
struct RemISOpLowering : public BinaryOpLLVMOpLowering<RemISOp, LLVM::SRemOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct RemIUOpLowering : public OneToOneLLVMOpLowering<RemIUOp, LLVM::URemOp> {
|
||||
struct RemIUOpLowering : public BinaryOpLLVMOpLowering<RemIUOp, LLVM::URemOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct AndOpLowering : public OneToOneLLVMOpLowering<AndOp, LLVM::AndOp> {
|
||||
struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct OrOpLowering : public OneToOneLLVMOpLowering<OrOp, LLVM::OrOp> {
|
||||
struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct XOrOpLowering : public OneToOneLLVMOpLowering<XOrOp, LLVM::XOrOp> {
|
||||
struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct AddFOpLowering : public OneToOneLLVMOpLowering<AddFOp, LLVM::FAddOp> {
|
||||
struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct SubFOpLowering : public OneToOneLLVMOpLowering<SubFOp, LLVM::FSubOp> {
|
||||
struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct MulFOpLowering : public OneToOneLLVMOpLowering<MulFOp, LLVM::FMulOp> {
|
||||
struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct DivFOpLowering : public OneToOneLLVMOpLowering<DivFOp, LLVM::FDivOp> {
|
||||
struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct RemFOpLowering : public OneToOneLLVMOpLowering<RemFOp, LLVM::FRemOp> {
|
||||
struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> {
|
||||
using Super::Super;
|
||||
};
|
||||
struct SelectOpLowering
|
||||
|
@ -516,14 +614,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
|
||||
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
||||
op->getLoc(), structType, memRefDescriptor, allocated,
|
||||
getIntegerArrayAttr(rewriter, 0));
|
||||
rewriter.getIndexArrayAttr(0));
|
||||
|
||||
// Store dynamically allocated sizes in the descriptor. Dynamic sizes are
|
||||
// passed in as operands.
|
||||
for (auto indexedSize : llvm::enumerate(operands)) {
|
||||
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
||||
op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
|
||||
getIntegerArrayAttr(rewriter, 1 + indexedSize.index()));
|
||||
rewriter.getIndexArrayAttr(1 + indexedSize.index()));
|
||||
}
|
||||
|
||||
// Return the final value of the descriptor.
|
||||
|
@ -553,7 +651,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
|
|||
}
|
||||
|
||||
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
|
||||
auto hasStaticShape = type.getUnderlyingType()->isPointerTy();
|
||||
auto hasStaticShape = type.isPointerTy();
|
||||
Type elementPtrType = hasStaticShape ? type : type.getStructElementType(0);
|
||||
Value *bufferPtr =
|
||||
extractMemRefElementPtr(rewriter, op->getLoc(), transformed.memref(),
|
||||
|
@ -603,7 +701,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
// Otherwise target type is dynamic memref, so create a proper descriptor.
|
||||
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
||||
op->getLoc(), structType, newDescriptor, buffer,
|
||||
getIntegerArrayAttr(rewriter, 0));
|
||||
rewriter.getIndexArrayAttr(0));
|
||||
|
||||
// Fill in the dynamic sizes of the new descriptor. If the size was
|
||||
// dynamic, copy it from the old descriptor. If the size was static, insert
|
||||
|
@ -626,11 +724,11 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
? rewriter.create<LLVM::ExtractValueOp>(
|
||||
op->getLoc(), getIndexType(),
|
||||
transformed.source(), // NB: dynamic memref
|
||||
getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++))
|
||||
rewriter.getIndexArrayAttr(sourceDynamicDimIdx++))
|
||||
: createIndexConstant(rewriter, op->getLoc(), sourceSize);
|
||||
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
||||
op->getLoc(), structType, newDescriptor, size,
|
||||
getIntegerArrayAttr(rewriter, targetDynamicDimIdx++));
|
||||
rewriter.getIndexArrayAttr(targetDynamicDimIdx++));
|
||||
}
|
||||
assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
|
||||
"source dynamic dimensions were not processed");
|
||||
|
@ -673,7 +771,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
|
|||
}
|
||||
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
|
||||
op, getIndexType(), transformed.memrefOrTensor(),
|
||||
getIntegerArrayAttr(rewriter, position));
|
||||
rewriter.getIndexArrayAttr(position));
|
||||
} else {
|
||||
rewriter.replaceOp(
|
||||
op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
|
||||
|
@ -739,7 +837,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
|||
if (s == -1) {
|
||||
Value *size = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, this->getIndexType(), memRefDescriptor,
|
||||
this->getIntegerArrayAttr(rewriter, dynamicSizeIdx++));
|
||||
rewriter.getIndexArrayAttr(dynamicSizeIdx++));
|
||||
sizes.push_back(size);
|
||||
} else {
|
||||
sizes.push_back(this->createIndexConstant(rewriter, loc, s));
|
||||
|
@ -751,8 +849,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
|
|||
Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
|
||||
|
||||
Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, elementTypePtr, memRefDescriptor,
|
||||
this->getIntegerArrayAttr(rewriter, 0));
|
||||
loc, elementTypePtr, memRefDescriptor, rewriter.getIndexArrayAttr(0));
|
||||
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr,
|
||||
ArrayRef<Value *>{dataPtr, subscript},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
|
@ -970,7 +1067,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
|
|||
for (unsigned i = 0; i < numArguments; ++i) {
|
||||
packed = rewriter.create<LLVM::InsertValueOp>(
|
||||
op->getLoc(), packedType, packed, operands[i],
|
||||
getIntegerArrayAttr(rewriter, i));
|
||||
rewriter.getIndexArrayAttr(i));
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
|
||||
op, llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),
|
||||
|
|
|
@ -1281,11 +1281,13 @@ LLVMType LLVMType::getArrayElementType() {
|
|||
unsigned LLVMType::getArrayNumElements() {
|
||||
return getUnderlyingType()->getArrayNumElements();
|
||||
}
|
||||
bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); }
|
||||
|
||||
/// Vector type utilities.
|
||||
LLVMType LLVMType::getVectorElementType() {
|
||||
return get(getContext(), getUnderlyingType()->getVectorElementType());
|
||||
}
|
||||
bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }
|
||||
|
||||
/// Function type utilities.
|
||||
LLVMType LLVMType::getFunctionParamType(unsigned argIdx) {
|
||||
|
@ -1299,6 +1301,7 @@ LLVMType LLVMType::getFunctionResultType() {
|
|||
getContext(),
|
||||
llvm::cast<llvm::FunctionType>(getUnderlyingType())->getReturnType());
|
||||
}
|
||||
bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); }
|
||||
|
||||
/// Pointer type utilities.
|
||||
LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
|
||||
|
@ -1310,11 +1313,13 @@ LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
|
|||
LLVMType LLVMType::getPointerElementTy() {
|
||||
return get(getContext(), getUnderlyingType()->getPointerElementType());
|
||||
}
|
||||
bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); }
|
||||
|
||||
/// Struct type utilities.
|
||||
LLVMType LLVMType::getStructElementType(unsigned i) {
|
||||
return get(getContext(), getUnderlyingType()->getStructElementType(i));
|
||||
}
|
||||
bool LLVMType::isStructTy() { return getUnderlyingType()->isStructTy(); }
|
||||
|
||||
/// Utilities used to generate floating point types.
|
||||
LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
|
||||
|
|
|
@ -218,6 +218,15 @@ ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
|
|||
return getArrayAttr(attrs);
|
||||
}
|
||||
|
||||
ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
|
||||
auto attrs = functional::map(
|
||||
[this](int64_t v) -> Attribute {
|
||||
return getIntegerAttr(IndexType::get(getContext()), v);
|
||||
},
|
||||
values);
|
||||
return getArrayAttr(attrs);
|
||||
}
|
||||
|
||||
ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
|
||||
auto attrs = functional::map(
|
||||
[this](float v) -> Attribute { return getF32FloatAttr(v); }, values);
|
||||
|
|
|
@ -513,3 +513,28 @@ func @fcmp(f32, f32) -> () {
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @vec_bin
|
||||
func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
|
||||
%0 = addf %arg0, %arg0 : vector<2x2x2xf32>
|
||||
return %0 : vector<2x2x2xf32>
|
||||
|
||||
// CHECK-NEXT: llvm.undef : !llvm<"[2 x [2 x <2 x float>]]">
|
||||
|
||||
// This block appears 2x2 times
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]">
|
||||
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]">
|
||||
// CHECK-NEXT: llvm.fadd %{{.*}} : !llvm<"<2 x float>">
|
||||
// CHECK-NEXT: llvm.insertvalue %{{.*}}[0 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]">
|
||||
|
||||
// We check the proper indexing of extract/insert in the remaining 3 positions.
|
||||
// CHECK: llvm.extractvalue %{{.*}}[0 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]">
|
||||
// CHECK: llvm.insertvalue %{{.*}}[0 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]">
|
||||
// CHECK: llvm.extractvalue %{{.*}}[1 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]">
|
||||
// CHECK: llvm.insertvalue %{{.*}}[1 : index, 0 : index] : !llvm<"[2 x [2 x <2 x float>]]">
|
||||
// CHECK: llvm.extractvalue %{{.*}}[1 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]">
|
||||
// CHECK: llvm.insertvalue %{{.*}}[1 : index, 1 : index] : !llvm<"[2 x [2 x <2 x float>]]">
|
||||
|
||||
// And we're done
|
||||
// CHECK-NEXT: return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue