forked from OSchip/llvm-project
[mlir][spirv] Switch to kEmitAccessorPrefix_Predixed
Fixes https://github.com/llvm/llvm-project/issues/57887 Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D134580
This commit is contained in:
parent
cde3de5381
commit
90a1632d0b
|
@ -72,10 +72,6 @@ def SPIRV_Dialect : Dialect {
|
|||
void printAttribute(
|
||||
Attribute attr, DialectAsmPrinter &printer) const override;
|
||||
}];
|
||||
|
||||
// TODO(https://github.com/llvm/llvm-project/issues/57887): Switch to
|
||||
// _Prefixed accessors.
|
||||
let emitAccessorPrefix = kEmitAccessorPrefix_Both;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -65,7 +65,7 @@ def SPV_BranchOp : SPV_Op<"Branch", [
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
/// Returns the block arguments.
|
||||
operand_range getBlockArguments() { return targetOperands(); }
|
||||
operand_range getBlockArguments() { return getTargetOperands(); }
|
||||
}];
|
||||
|
||||
let autogenSerialization = 0;
|
||||
|
@ -161,22 +161,22 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [
|
|||
|
||||
/// Returns the number of arguments to the true target block.
|
||||
unsigned getNumTrueBlockArguments() {
|
||||
return trueTargetOperands().size();
|
||||
return getTrueTargetOperands().size();
|
||||
}
|
||||
|
||||
/// Returns the number of arguments to the false target block.
|
||||
unsigned getNumFalseBlockArguments() {
|
||||
return falseTargetOperands().size();
|
||||
return getFalseTargetOperands().size();
|
||||
}
|
||||
|
||||
// Iterator and range support for true target block arguments.
|
||||
operand_range getTrueBlockArguments() {
|
||||
return trueTargetOperands();
|
||||
return getTrueTargetOperands();
|
||||
}
|
||||
|
||||
// Iterator and range support for false target block arguments.
|
||||
operand_range getFalseBlockArguments() {
|
||||
return falseTargetOperands();
|
||||
return getFalseTargetOperands();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -394,9 +394,9 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
|
|||
CArg<"FlatSymbolRefAttr", "nullptr">:$initializer),
|
||||
[{
|
||||
$_state.addAttribute("type", type);
|
||||
$_state.addAttribute(sym_nameAttrName($_state.name), sym_name);
|
||||
$_state.addAttribute(getSymNameAttrName($_state.name), sym_name);
|
||||
if (initializer)
|
||||
$_state.addAttribute(initializerAttrName($_state.name), initializer);
|
||||
$_state.addAttribute(getInitializerAttrName($_state.name), initializer);
|
||||
}]>,
|
||||
OpBuilder<(ins "TypeAttr":$type, "ArrayRef<NamedAttribute>":$namedAttrs),
|
||||
[{
|
||||
|
@ -412,9 +412,9 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
|
|||
CArg<"FlatSymbolRefAttr", "{}">:$initializer),
|
||||
[{
|
||||
$_state.addAttribute("type", TypeAttr::get(type));
|
||||
$_state.addAttribute(sym_nameAttrName($_state.name), $_builder.getStringAttr(sym_name));
|
||||
$_state.addAttribute(getSymNameAttrName($_state.name), $_builder.getStringAttr(sym_name));
|
||||
if (initializer)
|
||||
$_state.addAttribute(initializerAttrName($_state.name), initializer);
|
||||
$_state.addAttribute(getInitializerAttrName($_state.name), initializer);
|
||||
}]>
|
||||
];
|
||||
|
||||
|
@ -424,7 +424,7 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::spirv::StorageClass storageClass() {
|
||||
return this->type().cast<::mlir::spirv::PointerType>().getStorageClass();
|
||||
return this->getType().cast<::mlir::spirv::PointerType>().getStorageClass();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -509,7 +509,7 @@ def SPV_ModuleOp : SPV_Op<"module",
|
|||
|
||||
bool isOptionalSymbol() { return true; }
|
||||
|
||||
Optional<StringRef> getName() { return sym_name(); }
|
||||
Optional<StringRef> getName() { return getSymName(); }
|
||||
|
||||
static StringRef getVCETripleAttrName() { return "vce_triple"; }
|
||||
}];
|
||||
|
|
|
@ -69,12 +69,12 @@ static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
|
|||
builder.getIntegerAttr(targetType, targetBits / sourceBits);
|
||||
auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
|
||||
auto lastDim = op->getOperand(op.getNumOperands() - 1);
|
||||
auto indices = llvm::to_vector<4>(op.indices());
|
||||
auto indices = llvm::to_vector<4>(op.getIndices());
|
||||
// There are two elements if this is a 1-D tensor.
|
||||
assert(indices.size() == 2);
|
||||
indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
|
||||
Type t = typeConverter.convertType(op.component_ptr().getType());
|
||||
return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
|
||||
Type t = typeConverter.convertType(op.getComponentPtr().getType());
|
||||
return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
|
||||
}
|
||||
|
||||
/// Returns the shifted `targetBits`-bit value with the given offset.
|
||||
|
@ -371,7 +371,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
|
|||
// Assume that getElementPtr() works linearizely. If it's a scalar, the method
|
||||
// still returns a linearized accessing. If the accessing is not linearized,
|
||||
// there will be offset issues.
|
||||
assert(accessChainOp.indices().size() == 2);
|
||||
assert(accessChainOp.getIndices().size() == 2);
|
||||
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
|
||||
srcBits, dstBits, rewriter);
|
||||
Value spvLoadOp = rewriter.create<spirv::LoadOp>(
|
||||
|
@ -507,7 +507,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
|
|||
// 6) store 32-bit value back
|
||||
// The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
|
||||
// 4 to step 6 are done by AtomicOr as another atomic step.
|
||||
assert(accessChainOp.indices().size() == 2);
|
||||
assert(accessChainOp.getIndices().size() == 2);
|
||||
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
|
||||
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
|
||||
|
||||
|
|
|
@ -174,7 +174,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
|
|||
// Create the block for the header.
|
||||
auto *header = new Block();
|
||||
// Insert the header.
|
||||
loopOp.body().getBlocks().insert(getBlockIt(loopOp.body(), 1), header);
|
||||
loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1), header);
|
||||
|
||||
// Create the new induction variable to use.
|
||||
Value adapLowerBound = adaptor.getLowerBound();
|
||||
|
@ -197,13 +197,13 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
|
|||
|
||||
// Move the blocks from the forOp into the loopOp. This is the body of the
|
||||
// loopOp.
|
||||
rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
|
||||
getBlockIt(loopOp.body(), 2));
|
||||
rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
|
||||
getBlockIt(loopOp.getBody(), 2));
|
||||
|
||||
SmallVector<Value, 8> args(1, adaptor.getLowerBound());
|
||||
args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
|
||||
// Branch into it from the entry.
|
||||
rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
|
||||
rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
|
||||
rewriter.create<spirv::BranchOp>(loc, header, args);
|
||||
|
||||
// Generate the rest of the loop header.
|
||||
|
@ -252,12 +252,12 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
|
|||
auto selectionOp =
|
||||
rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
|
||||
auto *mergeBlock =
|
||||
rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
|
||||
rewriter.createBlock(&selectionOp.getBody(), selectionOp.getBody().end());
|
||||
rewriter.create<spirv::MergeOp>(loc);
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
auto *selectionHeaderBlock =
|
||||
rewriter.createBlock(&selectionOp.body().front());
|
||||
rewriter.createBlock(&selectionOp.getBody().front());
|
||||
|
||||
// Inline `then` region before the merge block and branch to it.
|
||||
auto &thenRegion = ifOp.getThenRegion();
|
||||
|
@ -367,12 +367,12 @@ WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
|
|||
return failure();
|
||||
|
||||
// Move the while before block as the initial loop header block.
|
||||
rewriter.inlineRegionBefore(beforeRegion, loopOp.body(),
|
||||
getBlockIt(loopOp.body(), 1));
|
||||
rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
|
||||
getBlockIt(loopOp.getBody(), 1));
|
||||
|
||||
// Move the while after block as the initial loop body block.
|
||||
rewriter.inlineRegionBefore(afterRegion, loopOp.body(),
|
||||
getBlockIt(loopOp.body(), 2));
|
||||
rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
|
||||
getBlockIt(loopOp.getBody(), 2));
|
||||
|
||||
// Jump from the loop entry block to the loop header block.
|
||||
rewriter.setInsertionPointToEnd(&entryBlock);
|
||||
|
|
|
@ -89,7 +89,7 @@ createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
|
|||
op->getAttrOfType<IntegerAttr>(descriptorSetName());
|
||||
IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
|
||||
return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
|
||||
kernelModuleName.str(), op.sym_name().str(),
|
||||
kernelModuleName.str(), op.getSymName().str(),
|
||||
std::to_string(descriptorSet.getInt()),
|
||||
std::to_string(binding.getInt()));
|
||||
}
|
||||
|
@ -126,14 +126,14 @@ static LogicalResult getKernelGlobalVariables(
|
|||
/// Encodes the SPIR-V module's symbolic name into the name of the entry point
|
||||
/// function.
|
||||
static LogicalResult encodeKernelName(spirv::ModuleOp module) {
|
||||
StringRef spvModuleName = *module.sym_name();
|
||||
StringRef spvModuleName = *module.getSymName();
|
||||
// We already know that the module contains exactly one entry point function
|
||||
// based on `getKernelGlobalVariables()` call. Update this function's name
|
||||
// to:
|
||||
// {spv_module_name}_{function_name}
|
||||
auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
|
||||
StringRef funcName = entryPoint.fn();
|
||||
auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr());
|
||||
StringRef funcName = entryPoint.getFn();
|
||||
auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
|
||||
StringAttr newFuncName =
|
||||
StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
|
||||
if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
|
||||
|
@ -236,7 +236,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
|
|||
// LLVM dialect global variable.
|
||||
spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
|
||||
auto pointeeType =
|
||||
spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
|
||||
spirvGlobal.getType().cast<spirv::PointerType>().getPointeeType();
|
||||
auto dstGlobalType = typeConverter->convertType(pointeeType);
|
||||
if (!dstGlobalType)
|
||||
return failure();
|
||||
|
|
|
@ -228,14 +228,14 @@ static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
|
|||
if (!dstType)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
|
||||
loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment,
|
||||
loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
|
||||
isVolatile, isNonTemporal);
|
||||
return success();
|
||||
}
|
||||
auto storeOp = cast<spirv::StoreOp>(op);
|
||||
spirv::StoreOpAdaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
|
||||
adaptor.ptr(), alignment,
|
||||
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
|
||||
adaptor.getPtr(), alignment,
|
||||
isVolatile, isNonTemporal);
|
||||
return success();
|
||||
}
|
||||
|
@ -305,19 +305,19 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = typeConverter.convertType(op.component_ptr().getType());
|
||||
auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
|
||||
if (!dstType)
|
||||
return failure();
|
||||
// To use GEP we need to add a first 0 index to go through the pointer.
|
||||
auto indices = llvm::to_vector<4>(adaptor.indices());
|
||||
Type indexType = op.indices().front().getType();
|
||||
auto indices = llvm::to_vector<4>(adaptor.getIndices());
|
||||
Type indexType = op.getIndices().front().getType();
|
||||
auto llvmIndexType = typeConverter.convertType(indexType);
|
||||
if (!llvmIndexType)
|
||||
return failure();
|
||||
Value zero = rewriter.create<LLVM::ConstantOp>(
|
||||
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
|
||||
indices.insert(indices.begin(), zero);
|
||||
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.base_ptr(),
|
||||
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.getBasePtr(),
|
||||
indices);
|
||||
return success();
|
||||
}
|
||||
|
@ -330,10 +330,10 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstType = typeConverter.convertType(op.pointer().getType());
|
||||
auto dstType = typeConverter.convertType(op.getPointer().getType());
|
||||
if (!dstType)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable());
|
||||
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.getVariable());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -353,9 +353,9 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
|
||||
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
|
||||
Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
|
||||
Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
|
||||
typeConverter, rewriter);
|
||||
Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
|
||||
Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
|
||||
typeConverter, rewriter);
|
||||
|
||||
// Create a mask with bits set outside [Offset, Offset + Count - 1].
|
||||
|
@ -372,9 +372,9 @@ public:
|
|||
// Extract unchanged bits from the `Base` that are outside of
|
||||
// [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
|
||||
Value baseAndMask =
|
||||
rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
|
||||
rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
|
||||
Value insertShiftedByOffset =
|
||||
rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
|
||||
rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
|
||||
rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
|
||||
insertShiftedByOffset);
|
||||
return success();
|
||||
|
@ -408,14 +408,14 @@ public:
|
|||
auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
|
||||
|
||||
if (srcType.isa<VectorType>()) {
|
||||
auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
|
||||
auto dstElementsAttr = constOp.getValue().cast<DenseIntElementsAttr>();
|
||||
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
|
||||
constOp, dstType,
|
||||
dstElementsAttr.mapValues(
|
||||
signlessType, [&](const APInt &value) { return value; }));
|
||||
return success();
|
||||
}
|
||||
auto srcAttr = constOp.value().cast<IntegerAttr>();
|
||||
auto srcAttr = constOp.getValue().cast<IntegerAttr>();
|
||||
auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
|
||||
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
|
||||
return success();
|
||||
|
@ -441,9 +441,9 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
|
||||
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
|
||||
Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
|
||||
Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
|
||||
typeConverter, rewriter);
|
||||
Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
|
||||
Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
|
||||
typeConverter, rewriter);
|
||||
|
||||
// Create a constant that holds the size of the `Base`.
|
||||
|
@ -468,7 +468,7 @@ public:
|
|||
Value amountToShiftLeft =
|
||||
rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
|
||||
Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
|
||||
loc, dstType, op.base(), amountToShiftLeft);
|
||||
loc, dstType, op.getBase(), amountToShiftLeft);
|
||||
|
||||
// Shift the result right, filling the bits with the sign bit.
|
||||
Value amountToShiftRight =
|
||||
|
@ -494,9 +494,9 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
|
||||
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
|
||||
Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
|
||||
Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
|
||||
typeConverter, rewriter);
|
||||
Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
|
||||
Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
|
||||
typeConverter, rewriter);
|
||||
|
||||
// Create a mask with bits set at [0, Count - 1].
|
||||
|
@ -508,7 +508,7 @@ public:
|
|||
|
||||
// Shift `Base` by `Offset` and apply the mask on it.
|
||||
Value shiftedBase =
|
||||
rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
|
||||
rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
|
||||
rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
|
||||
return success();
|
||||
}
|
||||
|
@ -538,20 +538,20 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
// If branch weights exist, map them to 32-bit integer vector.
|
||||
ElementsAttr branchWeights = nullptr;
|
||||
if (auto weights = op.branch_weights()) {
|
||||
if (auto weights = op.getBranchWeights()) {
|
||||
VectorType weightType = VectorType::get(2, rewriter.getI32Type());
|
||||
branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
|
||||
op, op.condition(), op.getTrueBlockArguments(),
|
||||
op, op.getCondition(), op.getTrueBlockArguments(),
|
||||
op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
|
||||
op.getFalseBlock());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
|
||||
/// Converts `spv.getCompositeExtract` to `llvm.extractvalue` if the container type
|
||||
/// is an aggregate type (struct or array). Otherwise, converts to
|
||||
/// `llvm.extractelement` that operates on vectors.
|
||||
class CompositeExtractPattern
|
||||
|
@ -566,23 +566,23 @@ public:
|
|||
if (!dstType)
|
||||
return failure();
|
||||
|
||||
Type containerType = op.composite().getType();
|
||||
Type containerType = op.getComposite().getType();
|
||||
if (containerType.isa<VectorType>()) {
|
||||
Location loc = op.getLoc();
|
||||
IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
|
||||
IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
|
||||
Value index = createI32ConstantOf(loc, rewriter, value.getInt());
|
||||
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
|
||||
op, dstType, adaptor.composite(), index);
|
||||
op, dstType, adaptor.getComposite(), index);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
|
||||
op, adaptor.composite(), LLVM::convertArrayToIndices(op.indices()));
|
||||
op, adaptor.getComposite(), LLVM::convertArrayToIndices(op.getIndices()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
|
||||
/// Converts `spv.getCompositeInsert` to `llvm.insertvalue` if the container type
|
||||
/// is an aggregate type (struct or array). Otherwise, converts to
|
||||
/// `llvm.insertelement` that operates on vectors.
|
||||
class CompositeInsertPattern
|
||||
|
@ -597,19 +597,19 @@ public:
|
|||
if (!dstType)
|
||||
return failure();
|
||||
|
||||
Type containerType = op.composite().getType();
|
||||
Type containerType = op.getComposite().getType();
|
||||
if (containerType.isa<VectorType>()) {
|
||||
Location loc = op.getLoc();
|
||||
IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
|
||||
IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
|
||||
Value index = createI32ConstantOf(loc, rewriter, value.getInt());
|
||||
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
|
||||
op, dstType, adaptor.composite(), adaptor.object(), index);
|
||||
op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
|
||||
op, adaptor.composite(), adaptor.object(),
|
||||
LLVM::convertArrayToIndices(op.indices()));
|
||||
op, adaptor.getComposite(), adaptor.getObject(),
|
||||
LLVM::convertArrayToIndices(op.getIndices()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -647,14 +647,14 @@ public:
|
|||
// this entry point's execution mode. We set it to be:
|
||||
// __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
|
||||
ModuleOp module = op->getParentOfType<ModuleOp>();
|
||||
spirv::ExecutionModeAttr executionModeAttr = op.execution_modeAttr();
|
||||
spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
|
||||
std::string moduleName;
|
||||
if (module.getName().has_value())
|
||||
moduleName = "_" + module.getName().value().str();
|
||||
moduleName = "_" + module.getName()->str();
|
||||
else
|
||||
moduleName = "";
|
||||
std::string executionModeInfoName = llvm::formatv(
|
||||
"__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.fn().str(),
|
||||
"__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
|
||||
static_cast<uint32_t>(executionModeAttr.getValue()));
|
||||
|
||||
MLIRContext *context = rewriter.getContext();
|
||||
|
@ -669,7 +669,7 @@ public:
|
|||
auto llvmI32Type = IntegerType::get(context, 32);
|
||||
SmallVector<Type, 2> fields;
|
||||
fields.push_back(llvmI32Type);
|
||||
ArrayAttr values = op.values();
|
||||
ArrayAttr values = op.getValues();
|
||||
if (!values.empty()) {
|
||||
auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
|
||||
fields.push_back(arrayType);
|
||||
|
@ -722,10 +722,10 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Currently, there is no support of initialization with a constant value in
|
||||
// SPIR-V dialect. Specialization constants are not considered as well.
|
||||
if (op.initializer())
|
||||
if (op.getInitializer())
|
||||
return failure();
|
||||
|
||||
auto srcType = op.type().cast<spirv::PointerType>();
|
||||
auto srcType = op.getType().cast<spirv::PointerType>();
|
||||
auto dstType = typeConverter.convertType(srcType.getPointeeType());
|
||||
if (!dstType)
|
||||
return failure();
|
||||
|
@ -759,12 +759,12 @@ public:
|
|||
? LLVM::Linkage::Private
|
||||
: LLVM::Linkage::External;
|
||||
auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
|
||||
op, dstType, isConstant, linkage, op.sym_name(), Attribute(),
|
||||
op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
|
||||
/*alignment=*/0);
|
||||
|
||||
// Attach location attribute if applicable
|
||||
if (op.locationAttr())
|
||||
newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr());
|
||||
if (op.getLocationAttr())
|
||||
newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -781,7 +781,7 @@ public:
|
|||
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
Type fromType = operation.operand().getType();
|
||||
Type fromType = operation.getOperand().getType();
|
||||
Type toType = operation.getType();
|
||||
|
||||
auto dstType = this->typeConverter.convertType(toType);
|
||||
|
@ -839,8 +839,8 @@ public:
|
|||
return failure();
|
||||
|
||||
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
|
||||
operation, dstType, predicate, operation.operand1(),
|
||||
operation.operand2());
|
||||
operation, dstType, predicate, operation.getOperand1(),
|
||||
operation.getOperand2());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -860,8 +860,8 @@ public:
|
|||
return failure();
|
||||
|
||||
rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
|
||||
operation, dstType, predicate, operation.operand1(),
|
||||
operation.operand2());
|
||||
operation, dstType, predicate, operation.getOperand1(),
|
||||
operation.getOperand2());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -881,7 +881,7 @@ public:
|
|||
|
||||
Location loc = op.getLoc();
|
||||
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
|
||||
Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
|
||||
Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
|
||||
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
|
||||
return success();
|
||||
}
|
||||
|
@ -896,20 +896,20 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!op.memory_access()) {
|
||||
if (!op.getMemoryAccess()) {
|
||||
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
|
||||
this->typeConverter, /*alignment=*/0,
|
||||
/*isVolatile=*/false,
|
||||
/*isNonTemporal=*/false);
|
||||
}
|
||||
auto memoryAccess = *op.memory_access();
|
||||
auto memoryAccess = *op.getMemoryAccess();
|
||||
switch (memoryAccess) {
|
||||
case spirv::MemoryAccess::Aligned:
|
||||
case spirv::MemoryAccess::None:
|
||||
case spirv::MemoryAccess::Nontemporal:
|
||||
case spirv::MemoryAccess::Volatile: {
|
||||
unsigned alignment =
|
||||
memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
|
||||
memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
|
||||
bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
|
||||
bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
|
||||
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
|
||||
|
@ -946,7 +946,7 @@ public:
|
|||
srcType.template cast<VectorType>(), minusOne))
|
||||
: rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
|
||||
rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
|
||||
notOp.operand(), mask);
|
||||
notOp.getOperand(), mask);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -1047,7 +1047,7 @@ public:
|
|||
matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// There is no support of loop control at the moment.
|
||||
if (loopOp.loop_control() != spirv::LoopControl::None)
|
||||
if (loopOp.getLoopControl() != spirv::LoopControl::None)
|
||||
return failure();
|
||||
|
||||
Location loc = loopOp.getLoc();
|
||||
|
@ -1077,7 +1077,7 @@ public:
|
|||
rewriter.setInsertionPointToEnd(mergeBlock);
|
||||
rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
|
||||
|
||||
rewriter.inlineRegionBefore(loopOp.body(), endBlock);
|
||||
rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
|
||||
rewriter.replaceOp(loopOp, endBlock->getArguments());
|
||||
return success();
|
||||
}
|
||||
|
@ -1096,14 +1096,14 @@ public:
|
|||
// There is no support for `Flatten` or `DontFlatten` selection control at
|
||||
// the moment. This are just compiler hints and can be performed during the
|
||||
// optimization passes.
|
||||
if (op.selection_control() != spirv::SelectionControl::None)
|
||||
if (op.getSelectionControl() != spirv::SelectionControl::None)
|
||||
return failure();
|
||||
|
||||
// `spv.mlir.selection` should have at least two blocks: one selection
|
||||
// header block and one merge block. If no blocks are present, or control
|
||||
// flow branches straight to merge block (two blocks are present), the op is
|
||||
// redundant and it is erased.
|
||||
if (op.body().getBlocks().size() <= 2) {
|
||||
if (op.getBody().getBlocks().size() <= 2) {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
@ -1140,11 +1140,11 @@ public:
|
|||
Block *trueBlock = condBrOp.getTrueBlock();
|
||||
Block *falseBlock = condBrOp.getFalseBlock();
|
||||
rewriter.setInsertionPointToEnd(currentBlock);
|
||||
rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
|
||||
condBrOp.trueTargetOperands(), falseBlock,
|
||||
condBrOp.falseTargetOperands());
|
||||
rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
|
||||
condBrOp.getTrueTargetOperands(), falseBlock,
|
||||
condBrOp.getFalseTargetOperands());
|
||||
|
||||
rewriter.inlineRegionBefore(op.body(), continueBlock);
|
||||
rewriter.inlineRegionBefore(op.getBody(), continueBlock);
|
||||
rewriter.replaceOp(op, continueBlock->getArguments());
|
||||
return success();
|
||||
}
|
||||
|
@ -1167,8 +1167,8 @@ public:
|
|||
if (!dstType)
|
||||
return failure();
|
||||
|
||||
Type op1Type = operation.operand1().getType();
|
||||
Type op2Type = operation.operand2().getType();
|
||||
Type op1Type = operation.getOperand1().getType();
|
||||
Type op2Type = operation.getOperand2().getType();
|
||||
|
||||
if (op1Type == op2Type) {
|
||||
rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
|
||||
|
@ -1180,13 +1180,13 @@ public:
|
|||
Value extended;
|
||||
if (isUnsignedIntegerOrVector(op2Type)) {
|
||||
extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
|
||||
adaptor.operand2());
|
||||
adaptor.getOperand2());
|
||||
} else {
|
||||
extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
|
||||
adaptor.operand2());
|
||||
adaptor.getOperand2());
|
||||
}
|
||||
Value result = rewriter.template create<LLVMOp>(
|
||||
loc, dstType, adaptor.operand1(), extended);
|
||||
loc, dstType, adaptor.getOperand1(), extended);
|
||||
rewriter.replaceOp(operation, result);
|
||||
return success();
|
||||
}
|
||||
|
@ -1204,8 +1204,8 @@ public:
|
|||
return failure();
|
||||
|
||||
Location loc = tanOp.getLoc();
|
||||
Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
|
||||
Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
|
||||
Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
|
||||
Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
|
||||
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
|
||||
return success();
|
||||
}
|
||||
|
@ -1232,7 +1232,7 @@ public:
|
|||
Location loc = tanhOp.getLoc();
|
||||
Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
|
||||
Value multiplied =
|
||||
rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
|
||||
rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
|
||||
Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
|
||||
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
|
||||
Value numerator =
|
||||
|
@ -1255,7 +1255,7 @@ public:
|
|||
auto srcType = varOp.getType();
|
||||
// Initialization is supported for scalars and vectors only.
|
||||
auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
|
||||
auto init = varOp.initializer();
|
||||
auto init = varOp.getInitializer();
|
||||
if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
|
||||
return failure();
|
||||
|
||||
|
@ -1270,7 +1270,7 @@ public:
|
|||
return success();
|
||||
}
|
||||
Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
|
||||
rewriter.create<LLVM::StoreOp>(loc, adaptor.initializer(), allocated);
|
||||
rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
|
||||
rewriter.replaceOp(varOp, allocated);
|
||||
return success();
|
||||
}
|
||||
|
@ -1305,7 +1305,7 @@ public:
|
|||
|
||||
// Convert SPIR-V Function Control to equivalent LLVM function attribute
|
||||
MLIRContext *context = funcOp.getContext();
|
||||
switch (funcOp.function_control()) {
|
||||
switch (funcOp.getFunctionControl()) {
|
||||
#define DISPATCH(functionControl, llvmAttr) \
|
||||
case functionControl: \
|
||||
newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
|
||||
|
@ -1374,9 +1374,9 @@ public:
|
|||
matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto components = adaptor.components();
|
||||
auto vector1 = adaptor.vector1();
|
||||
auto vector2 = adaptor.vector2();
|
||||
auto components = adaptor.getComponents();
|
||||
auto vector1 = adaptor.getVector1();
|
||||
auto vector2 = adaptor.getVector2();
|
||||
int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
|
||||
int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
|
||||
if (vector1Size == vector2Size) {
|
||||
|
@ -1589,8 +1589,8 @@ void mlir::encodeBindAttribute(ModuleOp module) {
|
|||
// SPIR-V module has a name, add it at the beginning.
|
||||
auto moduleAndName =
|
||||
spvModule.getName().has_value()
|
||||
? spvModule.getName().value().str() + "_" + op.sym_name().str()
|
||||
: op.sym_name().str();
|
||||
? spvModule.getName()->str() + "_" + op.getSymName().str()
|
||||
: op.getSymName().str();
|
||||
std::string name =
|
||||
llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
|
||||
std::to_string(descriptorSet.getInt()),
|
||||
|
|
|
@ -88,19 +88,19 @@ struct CombineChainedAccessChain
|
|||
LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
|
||||
accessChainOp.base_ptr().getDefiningOp());
|
||||
accessChainOp.getBasePtr().getDefiningOp());
|
||||
|
||||
if (!parentAccessChainOp) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Combine indices.
|
||||
SmallVector<Value, 4> indices(parentAccessChainOp.indices());
|
||||
indices.append(accessChainOp.indices().begin(),
|
||||
accessChainOp.indices().end());
|
||||
SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
|
||||
indices.append(accessChainOp.getIndices().begin(),
|
||||
accessChainOp.getIndices().end());
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||
accessChainOp, parentAccessChainOp.base_ptr(), indices);
|
||||
accessChainOp, parentAccessChainOp.getBasePtr(), indices);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -126,23 +126,24 @@ void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (auto insertOp = composite().getDefiningOp<spirv::CompositeInsertOp>()) {
|
||||
if (indices() == insertOp.indices())
|
||||
return insertOp.object();
|
||||
if (auto insertOp =
|
||||
getComposite().getDefiningOp<spirv::CompositeInsertOp>()) {
|
||||
if (getIndices() == insertOp.getIndices())
|
||||
return insertOp.getObject();
|
||||
}
|
||||
|
||||
if (auto constructOp =
|
||||
composite().getDefiningOp<spirv::CompositeConstructOp>()) {
|
||||
getComposite().getDefiningOp<spirv::CompositeConstructOp>()) {
|
||||
auto type = constructOp.getType().cast<spirv::CompositeType>();
|
||||
if (indices().size() == 1 &&
|
||||
constructOp.constituents().size() == type.getNumElements()) {
|
||||
auto i = indices().begin()->cast<IntegerAttr>();
|
||||
return constructOp.constituents()[i.getValue().getSExtValue()];
|
||||
if (getIndices().size() == 1 &&
|
||||
constructOp.getConstituents().size() == type.getNumElements()) {
|
||||
auto i = getIndices().begin()->cast<IntegerAttr>();
|
||||
return constructOp.getConstituents()[i.getValue().getSExtValue()];
|
||||
}
|
||||
}
|
||||
|
||||
auto indexVector =
|
||||
llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
|
||||
llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) {
|
||||
return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
|
||||
}));
|
||||
return extractCompositeElement(operands[0], indexVector);
|
||||
|
@ -154,7 +155,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
|
|||
|
||||
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.empty() && "spv.Constant has no operands");
|
||||
return value();
|
||||
return getValue();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -164,8 +165,8 @@ OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
|
|||
OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "spv.IAdd expects two operands");
|
||||
// x + 0 = x
|
||||
if (matchPattern(operand2(), m_Zero()))
|
||||
return operand1();
|
||||
if (matchPattern(getOperand2(), m_Zero()))
|
||||
return getOperand1();
|
||||
|
||||
// According to the SPIR-V spec:
|
||||
//
|
||||
|
@ -183,11 +184,11 @@ OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
|
|||
OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "spv.IMul expects two operands");
|
||||
// x * 0 == 0
|
||||
if (matchPattern(operand2(), m_Zero()))
|
||||
return operand2();
|
||||
if (matchPattern(getOperand2(), m_Zero()))
|
||||
return getOperand2();
|
||||
// x * 1 = x
|
||||
if (matchPattern(operand2(), m_One()))
|
||||
return operand1();
|
||||
if (matchPattern(getOperand2(), m_One()))
|
||||
return getOperand1();
|
||||
|
||||
// According to the SPIR-V spec:
|
||||
//
|
||||
|
@ -204,7 +205,7 @@ OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
|
|||
|
||||
OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
|
||||
// x - x = 0
|
||||
if (operand1() == operand2())
|
||||
if (getOperand1() == getOperand2())
|
||||
return Builder(getContext()).getIntegerAttr(getType(), 0);
|
||||
|
||||
// According to the SPIR-V spec:
|
||||
|
@ -226,7 +227,7 @@ OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
|
|||
if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
|
||||
// x && true = x
|
||||
if (rhs.value())
|
||||
return operand1();
|
||||
return getOperand1();
|
||||
|
||||
// x && false = false
|
||||
if (!rhs.value())
|
||||
|
@ -262,7 +263,7 @@ OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
|
|||
|
||||
// x || false = x
|
||||
if (!rhs.value())
|
||||
return operand1();
|
||||
return getOperand1();
|
||||
}
|
||||
|
||||
return Attribute();
|
||||
|
@ -339,8 +340,8 @@ struct ConvertSelectionOpToSelect
|
|||
cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
|
||||
|
||||
auto selectOp = rewriter.create<spirv::SelectOp>(
|
||||
selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
|
||||
trueValue, falseValue);
|
||||
selectionOp.getLoc(), trueValue.getType(),
|
||||
brConditionalOp.getCondition(), trueValue, falseValue);
|
||||
rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
|
||||
selectOp.getResult(), storeOpAttributes);
|
||||
|
||||
|
@ -371,13 +372,13 @@ private:
|
|||
// Returns a source value for the given block.
|
||||
Value getSrcValue(Block *block) const {
|
||||
auto storeOp = cast<spirv::StoreOp>(block->front());
|
||||
return storeOp.value();
|
||||
return storeOp.getValue();
|
||||
}
|
||||
|
||||
// Returns a destination value for the given block.
|
||||
Value getDstPtr(Block *block) const {
|
||||
auto storeOp = cast<spirv::StoreOp>(block->front());
|
||||
return storeOp.ptr();
|
||||
return storeOp.getPtr();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -406,14 +407,14 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
|
|||
// "Before version 1.4, Result Type must be a pointer, scalar, or vector.
|
||||
// Starting with version 1.4, Result Type can additionally be a composite type
|
||||
// other than a vector."
|
||||
bool isScalarOrVector = trueBrStoreOp.value()
|
||||
bool isScalarOrVector = trueBrStoreOp.getValue()
|
||||
.getType()
|
||||
.cast<spirv::SPIRVType>()
|
||||
.isScalarOrVector();
|
||||
|
||||
// Check that each `spv.Store` uses the same pointer, memory access
|
||||
// attributes and a valid type of the value.
|
||||
if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
|
||||
if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
|
||||
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
|
||||
return failure();
|
||||
}
|
||||
|
|
|
@ -106,7 +106,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
|
|||
// Replace the values directly with the return operands.
|
||||
assert(valuesToRepl.size() == 1 &&
|
||||
"spv.ReturnValue expected to only handle one result");
|
||||
valuesToRepl.front().replaceAllUsesWith(retValOp.value());
|
||||
valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -94,16 +94,16 @@ OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
|
|||
return nullptr;
|
||||
|
||||
spirv::ModuleOp firstModule = inputModules.front();
|
||||
auto addressingModel = firstModule.addressing_model();
|
||||
auto memoryModel = firstModule.memory_model();
|
||||
auto vceTriple = firstModule.vce_triple();
|
||||
auto addressingModel = firstModule.getAddressingModel();
|
||||
auto memoryModel = firstModule.getMemoryModel();
|
||||
auto vceTriple = firstModule.getVceTriple();
|
||||
|
||||
// First check whether there are conflicts between addressing/memory model.
|
||||
// Return early if so.
|
||||
for (auto module : inputModules) {
|
||||
if (module.addressing_model() != addressingModel ||
|
||||
module.memory_model() != memoryModel ||
|
||||
module.vce_triple() != vceTriple) {
|
||||
if (module.getAddressingModel() != addressingModel ||
|
||||
module.getMemoryModel() != memoryModel ||
|
||||
module.getVceTriple() != vceTriple) {
|
||||
module.emitError("input modules differ in addressing model, memory "
|
||||
"model, and/or VCE triple");
|
||||
return nullptr;
|
||||
|
|
|
@ -40,7 +40,7 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<NamedAttribute, 4> globalVarAttrs;
|
||||
|
||||
auto ptrType = op.type().cast<spirv::PointerType>();
|
||||
auto ptrType = op.getType().cast<spirv::PointerType>();
|
||||
auto structType = VulkanLayoutUtils::decorateType(
|
||||
ptrType.getPointeeType().cast<spirv::StructType>());
|
||||
|
||||
|
@ -71,11 +71,11 @@ public:
|
|||
LogicalResult matchAndRewrite(spirv::AddressOfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto spirvModule = op->getParentOfType<spirv::ModuleOp>();
|
||||
auto varName = op.variableAttr();
|
||||
auto varName = op.getVariableAttr();
|
||||
auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
|
||||
op, varOp.type(), SymbolRefAttr::get(varName.getAttr()));
|
||||
op, varOp.getType(), SymbolRefAttr::get(varName.getAttr()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -121,12 +121,12 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
|
|||
target.addLegalOp<func::FuncOp>();
|
||||
target.addDynamicallyLegalOp<spirv::GlobalVariableOp>(
|
||||
[](spirv::GlobalVariableOp op) {
|
||||
return VulkanLayoutUtils::isLegalType(op.type());
|
||||
return VulkanLayoutUtils::isLegalType(op.getType());
|
||||
});
|
||||
|
||||
// Change the type for the direct users.
|
||||
target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
|
||||
return VulkanLayoutUtils::isLegalType(op.pointer().getType());
|
||||
return VulkanLayoutUtils::isLegalType(op.getPointer().getType());
|
||||
});
|
||||
|
||||
// Change the type for the indirect users.
|
||||
|
@ -134,7 +134,8 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
|
|||
spirv::StoreOp>([&](Operation *op) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
auto addrOp = operand.getDefiningOp<spirv::AddressOfOp>();
|
||||
if (addrOp && !VulkanLayoutUtils::isLegalType(addrOp.pointer().getType()))
|
||||
if (addrOp &&
|
||||
!VulkanLayoutUtils::isLegalType(addrOp.getPointer().getType()))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -88,13 +88,13 @@ getInterfaceVariables(spirv::FuncOp funcOp,
|
|||
// instructions in this function.
|
||||
funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
|
||||
auto var =
|
||||
module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
|
||||
module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
|
||||
// TODO: Per SPIR-V spec: "Before version 1.4, the interface’s
|
||||
// storage classes are limited to the Input and Output storage classes.
|
||||
// Starting with version 1.4, the interface’s storage classes are all
|
||||
// storage classes used in declaring all global variables referenced by the
|
||||
// entry point’s call tree." We should consider the target environment here.
|
||||
switch (var.type().cast<spirv::PointerType>().getStorageClass()) {
|
||||
switch (var.getType().cast<spirv::PointerType>().getStorageClass()) {
|
||||
case spirv::StorageClass::Input:
|
||||
case spirv::StorageClass::Output:
|
||||
interfaceVarSet.insert(var.getOperation());
|
||||
|
@ -105,7 +105,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
|
|||
});
|
||||
for (auto &var : interfaceVarSet) {
|
||||
interfaceVars.push_back(SymbolRefAttr::get(
|
||||
funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name()));
|
||||
funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -223,7 +223,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
|
|||
auto zero =
|
||||
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
|
||||
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
|
||||
funcOp.getLoc(), replacement, zero.constant());
|
||||
funcOp.getLoc(), replacement, zero.getConstant());
|
||||
replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
|
||||
}
|
||||
signatureConverter.remapInput(argType.index(), replacement);
|
||||
|
|
|
@ -63,7 +63,7 @@ void RewriteInsertsPass::runOnOperation() {
|
|||
SmallVector<Value, 4> operands;
|
||||
// Collect inserted objects.
|
||||
for (auto insertionOp : insertions)
|
||||
operands.push_back(insertionOp.object());
|
||||
operands.push_back(insertionOp.getObject());
|
||||
|
||||
OpBuilder builder(lastCompositeInsertOp);
|
||||
auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
|
||||
|
@ -84,11 +84,13 @@ void RewriteInsertsPass::runOnOperation() {
|
|||
LogicalResult RewriteInsertsPass::collectInsertionChain(
|
||||
spirv::CompositeInsertOp op,
|
||||
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
|
||||
auto indicesArrayAttr = op.indices().cast<ArrayAttr>();
|
||||
auto indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
|
||||
// TODO: handle nested composite object.
|
||||
if (indicesArrayAttr.size() == 1) {
|
||||
auto numElements =
|
||||
op.composite().getType().cast<spirv::CompositeType>().getNumElements();
|
||||
auto numElements = op.getComposite()
|
||||
.getType()
|
||||
.cast<spirv::CompositeType>()
|
||||
.getNumElements();
|
||||
|
||||
auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
|
||||
// Need a last index to collect a sequential chain.
|
||||
|
@ -102,12 +104,12 @@ LogicalResult RewriteInsertsPass::collectInsertionChain(
|
|||
if (index == 0)
|
||||
return success();
|
||||
|
||||
op = op.composite().getDefiningOp<spirv::CompositeInsertOp>();
|
||||
op = op.getComposite().getDefiningOp<spirv::CompositeInsertOp>();
|
||||
if (!op)
|
||||
return failure();
|
||||
|
||||
--index;
|
||||
indicesArrayAttr = op.indices().cast<ArrayAttr>();
|
||||
indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
|
||||
if ((indicesArrayAttr.size() != 1) ||
|
||||
(indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
|
||||
return failure();
|
||||
|
|
|
@ -642,7 +642,7 @@ static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
|
|||
static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
|
||||
unsigned elementCount) {
|
||||
for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
|
||||
auto ptrType = varOp.type().dyn_cast<spirv::PointerType>();
|
||||
auto ptrType = varOp.getType().dyn_cast<spirv::PointerType>();
|
||||
if (!ptrType)
|
||||
continue;
|
||||
|
||||
|
@ -874,7 +874,7 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
|
|||
// Special treatment for global variables, whose type requirements are
|
||||
// conveyed by type attributes.
|
||||
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
|
||||
valueTypes.push_back(globalVar.type());
|
||||
valueTypes.push_back(globalVar.getType());
|
||||
|
||||
// Make sure the op's operands/results use types that are allowed by the
|
||||
// target environment.
|
||||
|
|
|
@ -51,8 +51,8 @@ static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
|
|||
AliasedResourceMap aliasedResources;
|
||||
moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
|
||||
if (varOp->getAttrOfType<UnitAttr>("aliased")) {
|
||||
Optional<uint32_t> set = varOp.descriptor_set();
|
||||
Optional<uint32_t> binding = varOp.binding();
|
||||
Optional<uint32_t> set = varOp.getDescriptorSet();
|
||||
Optional<uint32_t> binding = varOp.getBinding();
|
||||
if (set && binding)
|
||||
aliasedResources[{*set, *binding}].push_back(varOp);
|
||||
}
|
||||
|
@ -222,16 +222,16 @@ bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
|
|||
}
|
||||
if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
|
||||
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
|
||||
auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable());
|
||||
auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable());
|
||||
return shouldUnify(varOp);
|
||||
}
|
||||
|
||||
if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
|
||||
return shouldUnify(acOp.base_ptr().getDefiningOp());
|
||||
return shouldUnify(acOp.getBasePtr().getDefiningOp());
|
||||
if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
|
||||
return shouldUnify(loadOp.ptr().getDefiningOp());
|
||||
return shouldUnify(loadOp.getPtr().getDefiningOp());
|
||||
if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
|
||||
return shouldUnify(storeOp.ptr().getDefiningOp());
|
||||
return shouldUnify(storeOp.getPtr().getDefiningOp());
|
||||
|
||||
return false;
|
||||
}
|
||||
|
@ -265,7 +265,7 @@ void ResourceAliasAnalysis::recordIfUnifiable(
|
|||
// Collect the element types for all resources in the current set.
|
||||
SmallVector<spirv::SPIRVType> elementTypes;
|
||||
for (spirv::GlobalVariableOp resource : resources) {
|
||||
Type elementType = getRuntimeArrayElementType(resource.type());
|
||||
Type elementType = getRuntimeArrayElementType(resource.getType());
|
||||
if (!elementType)
|
||||
return; // Unexpected resource variable type.
|
||||
|
||||
|
@ -326,7 +326,7 @@ struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
|
|||
// Rewrite the AddressOf op to get the address of the canoncical resource.
|
||||
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
|
||||
auto srcVarOp = cast<spirv::GlobalVariableOp>(
|
||||
SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
|
||||
SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
|
||||
auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
|
||||
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
|
||||
return success();
|
||||
|
@ -339,13 +339,13 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
|
|||
LogicalResult
|
||||
matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
|
||||
auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
|
||||
if (!addressOp)
|
||||
return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
|
||||
|
||||
auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
|
||||
auto srcVarOp = cast<spirv::GlobalVariableOp>(
|
||||
SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
|
||||
SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
|
||||
auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
|
||||
|
||||
spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
|
||||
|
@ -356,7 +356,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
|
|||
// We have the same bitwidth for source and destination element types.
|
||||
// Thie indices keep the same.
|
||||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||
acOp, adaptor.base_ptr(), adaptor.indices());
|
||||
acOp, adaptor.getBasePtr(), adaptor.getIndices());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -375,7 +375,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
|
|||
auto ratioValue = rewriter.create<spirv::ConstantOp>(
|
||||
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
|
||||
|
||||
auto indices = llvm::to_vector<4>(acOp.indices());
|
||||
auto indices = llvm::to_vector<4>(acOp.getIndices());
|
||||
Value oldIndex = indices.back();
|
||||
indices.back() =
|
||||
rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
|
||||
|
@ -383,7 +383,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
|
|||
rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||
acOp, adaptor.base_ptr(), indices);
|
||||
acOp, adaptor.getBasePtr(), indices);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -399,13 +399,13 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
|
|||
auto ratioValue = rewriter.create<spirv::ConstantOp>(
|
||||
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
|
||||
|
||||
auto indices = llvm::to_vector<4>(acOp.indices());
|
||||
auto indices = llvm::to_vector<4>(acOp.getIndices());
|
||||
Value oldIndex = indices.back();
|
||||
indices.back() =
|
||||
rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||
acOp, adaptor.base_ptr(), indices);
|
||||
acOp, adaptor.getBasePtr(), indices);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -420,13 +420,13 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
|
|||
LogicalResult
|
||||
matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcPtrType = loadOp.ptr().getType().cast<spirv::PointerType>();
|
||||
auto srcPtrType = loadOp.getPtr().getType().cast<spirv::PointerType>();
|
||||
auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>();
|
||||
auto dstPtrType = adaptor.ptr().getType().cast<spirv::PointerType>();
|
||||
auto dstPtrType = adaptor.getPtr().getType().cast<spirv::PointerType>();
|
||||
auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>();
|
||||
|
||||
Location loc = loadOp.getLoc();
|
||||
auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
|
||||
auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
|
||||
if (srcElemType == dstElemType) {
|
||||
rewriter.replaceOp(loadOp, newLoadOp->getResults());
|
||||
return success();
|
||||
|
@ -434,7 +434,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
|
|||
|
||||
if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
|
||||
auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
|
||||
newLoadOp.value());
|
||||
newLoadOp.getValue());
|
||||
rewriter.replaceOp(loadOp, castOp->getResults());
|
||||
|
||||
return success();
|
||||
|
@ -457,19 +457,19 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
|
|||
components.reserve(ratio);
|
||||
components.push_back(newLoadOp);
|
||||
|
||||
auto acOp = adaptor.ptr().getDefiningOp<spirv::AccessChainOp>();
|
||||
auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
|
||||
if (!acOp)
|
||||
return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain");
|
||||
|
||||
auto i32Type = rewriter.getI32Type();
|
||||
Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
|
||||
auto indices = llvm::to_vector<4>(acOp.indices());
|
||||
auto indices = llvm::to_vector<4>(acOp.getIndices());
|
||||
for (int i = 1; i < ratio; ++i) {
|
||||
// Load all subsequent components belonging to this element.
|
||||
indices.back() = rewriter.create<spirv::IAddOp>(
|
||||
loc, i32Type, indices.back(), oneValue);
|
||||
auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
|
||||
loc, acOp.base_ptr(), indices);
|
||||
loc, acOp.getBasePtr(), indices);
|
||||
// Assuming little endian, this reads lower-ordered bits of the number
|
||||
// to lower-numbered components of the vector.
|
||||
components.push_back(
|
||||
|
@ -504,19 +504,19 @@ struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
|
|||
matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcElemType =
|
||||
storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
|
||||
storeOp.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
|
||||
auto dstElemType =
|
||||
adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
|
||||
adaptor.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
|
||||
if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
|
||||
return rewriter.notifyMatchFailure(storeOp, "not scalar type");
|
||||
if (!areSameBitwidthScalarType(srcElemType, dstElemType))
|
||||
return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
|
||||
|
||||
Location loc = storeOp.getLoc();
|
||||
Value value = adaptor.value();
|
||||
Value value = adaptor.getValue();
|
||||
if (srcElemType != dstElemType)
|
||||
value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
|
||||
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value,
|
||||
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(), value,
|
||||
storeOp->getAttrs());
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -151,7 +151,7 @@ void UpdateVCEPass::runOnOperation() {
|
|||
// Special treatment for global variables, whose type requirements are
|
||||
// conveyed by type attributes.
|
||||
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
|
||||
valueTypes.push_back(globalVar.type());
|
||||
valueTypes.push_back(globalVar.getType());
|
||||
|
||||
// Requirements from values' types
|
||||
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
|
||||
|
|
|
@ -46,20 +46,20 @@ Value spirv::Deserializer::getValue(uint32_t id) {
|
|||
}
|
||||
if (auto varOp = getGlobalVariable(id)) {
|
||||
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
|
||||
unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation()));
|
||||
return addressOfOp.pointer();
|
||||
unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
|
||||
return addressOfOp.getPointer();
|
||||
}
|
||||
if (auto constOp = getSpecConstant(id)) {
|
||||
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
|
||||
unknownLoc, constOp.default_value().getType(),
|
||||
unknownLoc, constOp.getDefaultValue().getType(),
|
||||
SymbolRefAttr::get(constOp.getOperation()));
|
||||
return referenceOfOp.reference();
|
||||
return referenceOfOp.getReference();
|
||||
}
|
||||
if (auto constCompositeOp = getSpecConstantComposite(id)) {
|
||||
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
|
||||
unknownLoc, constCompositeOp.type(),
|
||||
unknownLoc, constCompositeOp.getType(),
|
||||
SymbolRefAttr::get(constCompositeOp.getOperation()));
|
||||
return referenceOfOp.reference();
|
||||
return referenceOfOp.getReference();
|
||||
}
|
||||
if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
|
||||
return materializeSpecConstantOperation(
|
||||
|
|
|
@ -1414,7 +1414,7 @@ Value spirv::Deserializer::materializeSpecConstantOperation(
|
|||
auto specConstOperationOp =
|
||||
opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
|
||||
|
||||
Region &body = specConstOperationOp.body();
|
||||
Region &body = specConstOperationOp.getBody();
|
||||
// Move the new block into SpecConstantOperation's body.
|
||||
body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
|
||||
Region::iterator(enclosedBlock));
|
||||
|
@ -1983,17 +1983,17 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
|
|||
assert((branchCondOp.getTrueBlock() == target ||
|
||||
branchCondOp.getFalseBlock() == target) &&
|
||||
"expected target to be either the true or false target");
|
||||
if (target == branchCondOp.trueTarget())
|
||||
if (target == branchCondOp.getTrueTarget())
|
||||
opBuilder.create<spirv::BranchConditionalOp>(
|
||||
branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
|
||||
branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
|
||||
branchCondOp.getFalseBlockArguments(),
|
||||
branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
|
||||
branchCondOp.falseTarget());
|
||||
branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
|
||||
branchCondOp.getFalseTarget());
|
||||
else
|
||||
opBuilder.create<spirv::BranchConditionalOp>(
|
||||
branchCondOp.getLoc(), branchCondOp.condition(),
|
||||
branchCondOp.getLoc(), branchCondOp.getCondition(),
|
||||
branchCondOp.getTrueBlockArguments(), blockArgs,
|
||||
branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
|
||||
branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
|
||||
branchCondOp.getFalseBlock());
|
||||
|
||||
branchCondOp.erase();
|
||||
|
|
|
@ -24,7 +24,7 @@ namespace mlir {
|
|||
LogicalResult spirv::serialize(spirv::ModuleOp module,
|
||||
SmallVectorImpl<uint32_t> &binary,
|
||||
const SerializationOptions &options) {
|
||||
if (!module.vce_triple())
|
||||
if (!module.getVceTriple())
|
||||
return module.emitError(
|
||||
"module must have 'vce_triple' attribute to be serializeable");
|
||||
|
||||
|
|
|
@ -58,7 +58,8 @@ visitInPrettyBlockOrder(Block *headerBlock,
|
|||
namespace mlir {
|
||||
namespace spirv {
|
||||
LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
|
||||
if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
|
||||
if (auto resultID =
|
||||
prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
|
||||
valueIDMap[op.getResult()] = resultID;
|
||||
return success();
|
||||
}
|
||||
|
@ -66,7 +67,7 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
|
|||
}
|
||||
|
||||
LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
|
||||
if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(),
|
||||
if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
|
||||
/*isSpec=*/true)) {
|
||||
// Emit the OpDecorate instruction for SpecId.
|
||||
if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
|
||||
|
@ -75,8 +76,8 @@ LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
|
|||
return failure();
|
||||
}
|
||||
|
||||
specConstIDMap[op.sym_name()] = resultID;
|
||||
return processName(resultID, op.sym_name());
|
||||
specConstIDMap[op.getSymName()] = resultID;
|
||||
return processName(resultID, op.getSymName());
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
@ -84,7 +85,7 @@ LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
|
|||
LogicalResult
|
||||
Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
|
||||
uint32_t typeID = 0;
|
||||
if (failed(processType(op.getLoc(), op.type(), typeID))) {
|
||||
if (failed(processType(op.getLoc(), op.getType(), typeID))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
@ -94,7 +95,7 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
|
|||
operands.push_back(typeID);
|
||||
operands.push_back(resultID);
|
||||
|
||||
auto constituents = op.constituents();
|
||||
auto constituents = op.getConstituents();
|
||||
|
||||
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
|
||||
auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
|
||||
|
@ -112,9 +113,9 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
|
|||
|
||||
encodeInstructionInto(typesGlobalValues,
|
||||
spirv::Opcode::OpSpecConstantComposite, operands);
|
||||
specConstIDMap[op.sym_name()] = resultID;
|
||||
specConstIDMap[op.getSymName()] = resultID;
|
||||
|
||||
return processName(resultID, op.sym_name());
|
||||
return processName(resultID, op.getSymName());
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
|
@ -199,7 +200,7 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
|
|||
operands.push_back(resTypeID);
|
||||
auto funcID = getOrCreateFunctionID(op.getName());
|
||||
operands.push_back(funcID);
|
||||
operands.push_back(static_cast<uint32_t>(op.function_control()));
|
||||
operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
|
||||
operands.push_back(fnTypeID);
|
||||
encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
|
||||
|
||||
|
@ -310,7 +311,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
|
|||
// Get TypeID.
|
||||
uint32_t resultTypeID = 0;
|
||||
SmallVector<StringRef, 4> elidedAttrs;
|
||||
if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
|
||||
if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
@ -320,7 +321,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
|
|||
auto resultID = getNextID();
|
||||
|
||||
// Encode the name.
|
||||
auto varName = varOp.sym_name();
|
||||
auto varName = varOp.getSymName();
|
||||
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
|
||||
if (failed(processName(resultID, varName))) {
|
||||
return failure();
|
||||
|
@ -332,7 +333,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
|
|||
operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
|
||||
|
||||
// Encode initialization.
|
||||
if (auto initializer = varOp.initializer()) {
|
||||
if (auto initializer = varOp.getInitializer()) {
|
||||
auto initializerID = getVariableID(*initializer);
|
||||
if (!initializerID) {
|
||||
return emitError(varOp.getLoc(),
|
||||
|
@ -364,7 +365,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
|
|||
LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
|
||||
// Assign <id>s to all blocks so that branches inside the SelectionOp can
|
||||
// resolve properly.
|
||||
auto &body = selectionOp.body();
|
||||
auto &body = selectionOp.getBody();
|
||||
for (Block &block : body)
|
||||
getOrCreateBlockID(&block);
|
||||
|
||||
|
@ -390,7 +391,7 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
|
|||
lastProcessedWasMergeInst = true;
|
||||
encodeInstructionInto(
|
||||
functionBody, spirv::Opcode::OpSelectionMerge,
|
||||
{mergeID, static_cast<uint32_t>(selectionOp.selection_control())});
|
||||
{mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
|
||||
return success();
|
||||
};
|
||||
if (failed(
|
||||
|
@ -420,7 +421,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
|
|||
// Assign <id>s to all blocks so that branches inside the LoopOp can resolve
|
||||
// properly. We don't need to assign for the entry block, which is just for
|
||||
// satisfying MLIR region's structural requirement.
|
||||
auto &body = loopOp.body();
|
||||
auto &body = loopOp.getBody();
|
||||
for (Block &block : llvm::drop_begin(body))
|
||||
getOrCreateBlockID(&block);
|
||||
|
||||
|
@ -452,7 +453,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
|
|||
lastProcessedWasMergeInst = true;
|
||||
encodeInstructionInto(
|
||||
functionBody, spirv::Opcode::OpLoopMerge,
|
||||
{mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())});
|
||||
{mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
|
||||
return success();
|
||||
};
|
||||
if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
|
||||
|
@ -483,12 +484,12 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
|
|||
|
||||
LogicalResult Serializer::processBranchConditionalOp(
|
||||
spirv::BranchConditionalOp condBranchOp) {
|
||||
auto conditionID = getValueID(condBranchOp.condition());
|
||||
auto conditionID = getValueID(condBranchOp.getCondition());
|
||||
auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
|
||||
auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
|
||||
SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
|
||||
|
||||
if (auto weights = condBranchOp.branch_weights()) {
|
||||
if (auto weights = condBranchOp.getBranchWeights()) {
|
||||
for (auto val : weights->getValue())
|
||||
arguments.push_back(val.cast<IntegerAttr>().getInt());
|
||||
}
|
||||
|
@ -509,26 +510,26 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
|
|||
}
|
||||
|
||||
LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
|
||||
auto varName = addressOfOp.variable();
|
||||
auto varName = addressOfOp.getVariable();
|
||||
auto variableID = getVariableID(varName);
|
||||
if (!variableID) {
|
||||
return addressOfOp.emitError("unknown result <id> for variable ")
|
||||
<< varName;
|
||||
}
|
||||
valueIDMap[addressOfOp.pointer()] = variableID;
|
||||
valueIDMap[addressOfOp.getPointer()] = variableID;
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
|
||||
auto constName = referenceOfOp.spec_const();
|
||||
auto constName = referenceOfOp.getSpecConst();
|
||||
auto constID = getSpecConstID(constName);
|
||||
if (!constID) {
|
||||
return referenceOfOp.emitError(
|
||||
"unknown result <id> for specialization constant ")
|
||||
<< constName;
|
||||
}
|
||||
valueIDMap[referenceOfOp.reference()] = constID;
|
||||
valueIDMap[referenceOfOp.getReference()] = constID;
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -537,21 +538,21 @@ LogicalResult
|
|||
Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
|
||||
SmallVector<uint32_t, 4> operands;
|
||||
// Add the ExecutionModel.
|
||||
operands.push_back(static_cast<uint32_t>(op.execution_model()));
|
||||
operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
|
||||
// Add the function <id>.
|
||||
auto funcID = getFunctionID(op.fn());
|
||||
auto funcID = getFunctionID(op.getFn());
|
||||
if (!funcID) {
|
||||
return op.emitError("missing <id> for function ")
|
||||
<< op.fn()
|
||||
<< op.getFn()
|
||||
<< "; function needs to be defined before spv.EntryPoint is "
|
||||
"serialized";
|
||||
}
|
||||
operands.push_back(funcID);
|
||||
// Add the name of the function.
|
||||
spirv::encodeStringLiteralInto(operands, op.fn());
|
||||
spirv::encodeStringLiteralInto(operands, op.getFn());
|
||||
|
||||
// Add the interface values.
|
||||
if (auto interface = op.interface()) {
|
||||
if (auto interface = op.getInterface()) {
|
||||
for (auto var : interface.getValue()) {
|
||||
auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
|
||||
if (!id) {
|
||||
|
@ -571,19 +572,19 @@ LogicalResult
|
|||
Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
|
||||
SmallVector<uint32_t, 4> operands;
|
||||
// Add the function <id>.
|
||||
auto funcID = getFunctionID(op.fn());
|
||||
auto funcID = getFunctionID(op.getFn());
|
||||
if (!funcID) {
|
||||
return op.emitError("missing <id> for function ")
|
||||
<< op.fn()
|
||||
<< op.getFn()
|
||||
<< "; function needs to be serialized before ExecutionModeOp is "
|
||||
"serialized";
|
||||
}
|
||||
operands.push_back(funcID);
|
||||
// Add the ExecutionMode.
|
||||
operands.push_back(static_cast<uint32_t>(op.execution_mode()));
|
||||
operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
|
||||
|
||||
// Serialize values if any.
|
||||
auto values = op.values();
|
||||
auto values = op.getValues();
|
||||
if (values) {
|
||||
for (auto &intVal : values.getValue()) {
|
||||
operands.push_back(static_cast<uint32_t>(
|
||||
|
@ -598,7 +599,7 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
|
|||
template <>
|
||||
LogicalResult
|
||||
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
|
||||
auto funcName = op.callee();
|
||||
auto funcName = op.getCallee();
|
||||
uint32_t resTypeID = 0;
|
||||
|
||||
Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
|
||||
|
@ -609,7 +610,7 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
|
|||
auto funcCallID = getNextID();
|
||||
SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
|
||||
|
||||
for (auto value : op.arguments()) {
|
||||
for (auto value : op.getArguments()) {
|
||||
auto valueID = getValueID(value);
|
||||
assert(valueID && "cannot find a value for spv.FunctionCall");
|
||||
operands.push_back(valueID);
|
||||
|
|
|
@ -119,7 +119,8 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
|
|||
binary.clear();
|
||||
binary.reserve(moduleSize);
|
||||
|
||||
spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID);
|
||||
spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
|
||||
nextID);
|
||||
binary.append(capabilities.begin(), capabilities.end());
|
||||
binary.append(extensions.begin(), extensions.end());
|
||||
binary.append(extendedSets.begin(), extendedSets.end());
|
||||
|
@ -166,7 +167,7 @@ uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
|
|||
}
|
||||
|
||||
void Serializer::processCapability() {
|
||||
for (auto cap : module.vce_triple()->getCapabilities())
|
||||
for (auto cap : module.getVceTriple()->getCapabilities())
|
||||
encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
|
||||
{static_cast<uint32_t>(cap)});
|
||||
}
|
||||
|
@ -186,7 +187,7 @@ void Serializer::processDebugInfo() {
|
|||
|
||||
void Serializer::processExtension() {
|
||||
llvm::SmallVector<uint32_t, 16> extName;
|
||||
for (spirv::Extension ext : module.vce_triple()->getExtensions()) {
|
||||
for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
|
||||
extName.clear();
|
||||
spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
|
||||
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
|
||||
|
@ -1045,11 +1046,11 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
|
|||
} else if (auto branchCondOp =
|
||||
dyn_cast<spirv::BranchConditionalOp>(terminator)) {
|
||||
Optional<OperandRange> blockOperands;
|
||||
if (branchCondOp.trueTarget() == block) {
|
||||
blockOperands = branchCondOp.trueTargetOperands();
|
||||
if (branchCondOp.getTrueTarget() == block) {
|
||||
blockOperands = branchCondOp.getTrueTargetOperands();
|
||||
} else {
|
||||
assert(branchCondOp.falseTarget() == block);
|
||||
blockOperands = branchCondOp.falseTargetOperands();
|
||||
assert(branchCondOp.getFalseTarget() == block);
|
||||
blockOperands = branchCondOp.getFalseTargetOperands();
|
||||
}
|
||||
|
||||
assert(!blockOperands->empty() &&
|
||||
|
|
|
@ -1360,7 +1360,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
|
|||
os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
|
||||
"static_cast<{0}::{1}>(1 << i);\n",
|
||||
enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
|
||||
namedAttr.name);
|
||||
srcOp.getGetterName(namedAttr.name));
|
||||
os << formatv(
|
||||
" if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
|
||||
enumAttr.getUnderlyingType());
|
||||
|
@ -1368,7 +1368,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
|
|||
// For IntEnumAttr, we just need to query the value as a whole.
|
||||
os << " {\n";
|
||||
os << formatv(" auto tblgen_attrVal = this->{0}();\n",
|
||||
namedAttr.name);
|
||||
srcOp.getGetterName(namedAttr.name));
|
||||
}
|
||||
os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
|
||||
enumAttr.getCppNamespace(), avail.getQueryFnName());
|
||||
|
|
Loading…
Reference in New Issue