LLVM IR Dialect conversion: use builder arguments instead of named attributes

The first version of TableGen-defined LLVM IR Dialect did not include the
mandatory or optional attributes of the operations due to the missing support
for some of the relevant attribute types.  This support has been recently
introduced, along with named attributes as arguments in the TableGen operation
definitions.  With these changes, LLVM IR Dialect operations now have factory
functions accepting (unnamed) attributes and attaching their canonical names.
Use these factories instead of manually constructing named attributes in the
dialect convreter to avoid hardcoded attribute names in unexpected places.

PiperOrigin-RevId: 237237769
This commit is contained in:
Alex Zinenko 2019-03-07 06:41:17 -08:00 committed by jpienaar
parent b9724e98c2
commit 6621f39d19
1 changed files with 39 additions and 58 deletions

View File

@ -375,37 +375,32 @@ public:
Value *createIndexConstant(FuncBuilder &builder, Location loc,
uint64_t value) const {
auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
auto namedAttr = builder.getNamedAttr("value", attr);
return builder.create<LLVM::ConstantOp>(
loc, getIndexType(), ArrayRef<Value *>{},
ArrayRef<NamedAttribute>{namedAttr});
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
}
// Get the array attribute named "position" containing the given list of
// integers as integer attribute elements.
static NamedAttribute getPositionAttribute(FuncBuilder &builder,
ArrayRef<int64_t> positions) {
SmallVector<Attribute, 4> attrPositions;
attrPositions.reserve(positions.size());
for (int64_t pos : positions)
attrPositions.push_back(
builder.getIntegerAttr(builder.getIndexType(), pos));
return builder.getNamedAttr("position",
builder.getArrayAttr(attrPositions));
static ArrayAttr getIntegerArrayAttr(FuncBuilder &builder,
ArrayRef<int64_t> values) {
SmallVector<Attribute, 4> attrs;
attrs.reserve(values.size());
for (int64_t pos : values)
attrs.push_back(builder.getIntegerAttr(builder.getIndexType(), pos));
return builder.getArrayAttr(attrs);
}
// Extract raw data pointer value from a value representing a memref.
static Value *extractMemRefElementPtr(FuncBuilder &builder, Location loc,
Value *convertedMemRefValue,
Type elementTypePtr,
bool statically_shaped) {
bool hasStaticShape) {
Value *buffer;
if (statically_shaped)
if (hasStaticShape)
return convertedMemRefValue;
else
return builder.create<LLVM::ExtractValueOp>(
loc, elementTypePtr, ArrayRef<Value *>{convertedMemRefValue},
getPositionAttribute(builder, 0));
loc, elementTypePtr, convertedMemRefValue,
getIntegerArrayAttr(builder, 0));
return buffer;
}
@ -461,13 +456,11 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
SmallVector<Value *, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
auto positionNamedAttr = this->getPositionAttribute(rewriter, i);
auto type = TypeConverter::convert(op->getResult(i)->getType(),
this->dialect.getLLVMModule());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), type,
ArrayRef<Value *>(newOp->getInstruction()->getResult(0)),
llvm::makeArrayRef(positionNamedAttr)));
op->getLoc(), type, newOp->getInstruction()->getResult(0),
this->getIntegerArrayAttr(rewriter, i)));
}
return results;
}
@ -608,13 +601,11 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
auto mallocNamedAttr =
rewriter.getNamedAttr("callee", rewriter.getFunctionAttr(mallocFunc));
Value *allocated =
rewriter
.create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(),
ArrayRef<Value *>(cumulativeSize),
llvm::makeArrayRef(mallocNamedAttr))
rewriter.getFunctionAttr(mallocFunc),
cumulativeSize)
->getResult(0);
auto structElementType = TypeConverter::convert(elementType, getModule());
auto elementPtrType = LLVM::LLVMType::get(
@ -634,21 +625,16 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
op->getLoc(), structType, ArrayRef<Value *>{});
auto namedPositionAttr = getPositionAttribute(rewriter, 0);
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType,
ArrayRef<Value *>{memRefDescriptor, allocated},
llvm::makeArrayRef(namedPositionAttr));
op->getLoc(), structType, memRefDescriptor, allocated,
getIntegerArrayAttr(rewriter, 0));
// Store dynamically allocated sizes in the descriptor. Dynamic sizes are
// passed in as operands.
for (auto indexedSize : llvm::enumerate(operands)) {
auto positionAttr =
getPositionAttribute(rewriter, 1 + indexedSize.index());
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType,
ArrayRef<Value *>{memRefDescriptor, indexedSize.value()},
llvm::makeArrayRef(positionAttr));
op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
getIntegerArrayAttr(rewriter, 1 + indexedSize.index()));
}
// Return the final value of the descriptor.
@ -677,20 +663,18 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
auto *type =
operands[0]->getType().cast<LLVM::LLVMType>().getUnderlyingType();
auto statically_shaped = type->isPointerTy();
auto hasStaticShape = type->isPointerTy();
Type elementPtrType =
(statically_shaped)
(hasStaticShape)
? rewriter.getType<LLVM::LLVMType>(type)
: rewriter.getType<LLVM::LLVMType>(
cast<llvm::StructType>(type)->getStructElementType(0));
Value *bufferPtr = extractMemRefElementPtr(
rewriter, op->getLoc(), operands[0], elementPtrType, statically_shaped);
rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape);
Value *casted = rewriter.create<LLVM::BitcastOp>(
op->getLoc(), getVoidPtrType(), bufferPtr);
auto freeNamedAttr =
rewriter.getNamedAttr("callee", rewriter.getFunctionAttr(freeFunc));
rewriter.create<LLVM::CallOp>(op->getLoc(), casted,
llvm::makeArrayRef(freeNamedAttr));
rewriter.create<LLVM::CallOp>(op->getLoc(), ArrayRef<Type>(),
rewriter.getFunctionAttr(freeFunc), casted);
return {};
}
};
@ -734,8 +718,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
op->getLoc(), structType, ArrayRef<Value *>{});
// Otherwise target type is dynamic memref, so create a proper descriptor.
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, ArrayRef<Value *>{newDescriptor, buffer},
getPositionAttribute(rewriter, 0));
op->getLoc(), structType, newDescriptor, buffer,
getIntegerArrayAttr(rewriter, 0));
// Fill in the dynamic sizes of the new descriptor. If the size was
// dynamic, copy it from the old descriptor. If the size was static, insert
@ -757,12 +741,12 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
sourceSize == -1
? rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), getIndexType(),
ArrayRef<Value *>{operands[0]}, // NB: dynamic memref
getPositionAttribute(rewriter, sourceDynamicDimIdx++))
operands[0], // NB: dynamic memref
getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++))
: createIndexConstant(rewriter, op->getLoc(), sourceSize);
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), structType, ArrayRef<Value *>{newDescriptor, size},
getPositionAttribute(rewriter, targetDynamicDimIdx++));
op->getLoc(), structType, newDescriptor, size,
getIntegerArrayAttr(rewriter, targetDynamicDimIdx++));
}
assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
"source dynamic dimensions were not processed");
@ -807,8 +791,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
++position;
}
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), getIndexType(), operands,
getPositionAttribute(rewriter, position)));
op->getLoc(), getIndexType(), operands[0],
getIntegerArrayAttr(rewriter, position)));
} else {
results.push_back(
createIndexConstant(rewriter, op->getLoc(), shape[index]));
@ -876,9 +860,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
for (int64_t s : shape) {
if (s == -1) {
Value *size = rewriter.create<LLVM::ExtractValueOp>(
loc, this->getIndexType(), ArrayRef<Value *>{memRefDescriptor},
llvm::makeArrayRef(
this->getPositionAttribute(rewriter, dynamicSizeIdx++)));
loc, this->getIndexType(), memRefDescriptor,
this->getIntegerArrayAttr(rewriter, dynamicSizeIdx++));
sizes.push_back(size);
} else {
sizes.push_back(this->createIndexConstant(rewriter, loc, s));
@ -890,8 +873,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
loc, elementTypePtr, ArrayRef<Value *>{memRefDescriptor},
llvm::makeArrayRef(this->getPositionAttribute(rewriter, 0)));
loc, elementTypePtr, memRefDescriptor,
this->getIntegerArrayAttr(rewriter, 0));
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr,
ArrayRef<Value *>{dataPtr, subscript},
ArrayRef<NamedAttribute>{});
@ -1018,11 +1001,9 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
for (unsigned i = 0; i < numArguments; ++i) {
auto positionNamedAttr = getPositionAttribute(rewriter, i);
packed = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), packedType,
llvm::ArrayRef<Value *>{packed, operands[i]},
llvm::makeArrayRef(positionNamedAttr));
op->getLoc(), packedType, packed, operands[i],
getIntegerArrayAttr(rewriter, i));
}
rewriter.create<LLVM::ReturnOp>(
op->getLoc(), llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),