[mlir][spirv] Switch to kEmitAccessorPrefix_Predixed

Fixes https://github.com/llvm/llvm-project/issues/57887

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D134580
This commit is contained in:
Jakub Kuderski 2022-09-24 00:36:53 -04:00
parent cde3de5381
commit 90a1632d0b
23 changed files with 466 additions and 457 deletions

View File

@ -72,10 +72,6 @@ def SPIRV_Dialect : Dialect {
void printAttribute( void printAttribute(
Attribute attr, DialectAsmPrinter &printer) const override; Attribute attr, DialectAsmPrinter &printer) const override;
}]; }];
// TODO(https://github.com/llvm/llvm-project/issues/57887): Switch to
// _Prefixed accessors.
let emitAccessorPrefix = kEmitAccessorPrefix_Both;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -65,7 +65,7 @@ def SPV_BranchOp : SPV_Op<"Branch", [
let extraClassDeclaration = [{ let extraClassDeclaration = [{
/// Returns the block arguments. /// Returns the block arguments.
operand_range getBlockArguments() { return targetOperands(); } operand_range getBlockArguments() { return getTargetOperands(); }
}]; }];
let autogenSerialization = 0; let autogenSerialization = 0;
@ -161,22 +161,22 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [
/// Returns the number of arguments to the true target block. /// Returns the number of arguments to the true target block.
unsigned getNumTrueBlockArguments() { unsigned getNumTrueBlockArguments() {
return trueTargetOperands().size(); return getTrueTargetOperands().size();
} }
/// Returns the number of arguments to the false target block. /// Returns the number of arguments to the false target block.
unsigned getNumFalseBlockArguments() { unsigned getNumFalseBlockArguments() {
return falseTargetOperands().size(); return getFalseTargetOperands().size();
} }
// Iterator and range support for true target block arguments. // Iterator and range support for true target block arguments.
operand_range getTrueBlockArguments() { operand_range getTrueBlockArguments() {
return trueTargetOperands(); return getTrueTargetOperands();
} }
// Iterator and range support for false target block arguments. // Iterator and range support for false target block arguments.
operand_range getFalseBlockArguments() { operand_range getFalseBlockArguments() {
return falseTargetOperands(); return getFalseTargetOperands();
} }
private: private:

View File

@ -394,9 +394,9 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
CArg<"FlatSymbolRefAttr", "nullptr">:$initializer), CArg<"FlatSymbolRefAttr", "nullptr">:$initializer),
[{ [{
$_state.addAttribute("type", type); $_state.addAttribute("type", type);
$_state.addAttribute(sym_nameAttrName($_state.name), sym_name); $_state.addAttribute(getSymNameAttrName($_state.name), sym_name);
if (initializer) if (initializer)
$_state.addAttribute(initializerAttrName($_state.name), initializer); $_state.addAttribute(getInitializerAttrName($_state.name), initializer);
}]>, }]>,
OpBuilder<(ins "TypeAttr":$type, "ArrayRef<NamedAttribute>":$namedAttrs), OpBuilder<(ins "TypeAttr":$type, "ArrayRef<NamedAttribute>":$namedAttrs),
[{ [{
@ -412,9 +412,9 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
CArg<"FlatSymbolRefAttr", "{}">:$initializer), CArg<"FlatSymbolRefAttr", "{}">:$initializer),
[{ [{
$_state.addAttribute("type", TypeAttr::get(type)); $_state.addAttribute("type", TypeAttr::get(type));
$_state.addAttribute(sym_nameAttrName($_state.name), $_builder.getStringAttr(sym_name)); $_state.addAttribute(getSymNameAttrName($_state.name), $_builder.getStringAttr(sym_name));
if (initializer) if (initializer)
$_state.addAttribute(initializerAttrName($_state.name), initializer); $_state.addAttribute(getInitializerAttrName($_state.name), initializer);
}]> }]>
]; ];
@ -424,7 +424,7 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
let extraClassDeclaration = [{ let extraClassDeclaration = [{
::mlir::spirv::StorageClass storageClass() { ::mlir::spirv::StorageClass storageClass() {
return this->type().cast<::mlir::spirv::PointerType>().getStorageClass(); return this->getType().cast<::mlir::spirv::PointerType>().getStorageClass();
} }
}]; }];
} }
@ -509,7 +509,7 @@ def SPV_ModuleOp : SPV_Op<"module",
bool isOptionalSymbol() { return true; } bool isOptionalSymbol() { return true; }
Optional<StringRef> getName() { return sym_name(); } Optional<StringRef> getName() { return getSymName(); }
static StringRef getVCETripleAttrName() { return "vce_triple"; } static StringRef getVCETripleAttrName() { return "vce_triple"; }
}]; }];

View File

@ -69,12 +69,12 @@ static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
builder.getIntegerAttr(targetType, targetBits / sourceBits); builder.getIntegerAttr(targetType, targetBits / sourceBits);
auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr); auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
auto lastDim = op->getOperand(op.getNumOperands() - 1); auto lastDim = op->getOperand(op.getNumOperands() - 1);
auto indices = llvm::to_vector<4>(op.indices()); auto indices = llvm::to_vector<4>(op.getIndices());
// There are two elements if this is a 1-D tensor. // There are two elements if this is a 1-D tensor.
assert(indices.size() == 2); assert(indices.size() == 2);
indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx); indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
Type t = typeConverter.convertType(op.component_ptr().getType()); Type t = typeConverter.convertType(op.getComponentPtr().getType());
return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices); return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
} }
/// Returns the shifted `targetBits`-bit value with the given offset. /// Returns the shifted `targetBits`-bit value with the given offset.
@ -371,7 +371,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
// Assume that getElementPtr() works linearizely. If it's a scalar, the method // Assume that getElementPtr() works linearizely. If it's a scalar, the method
// still returns a linearized accessing. If the accessing is not linearized, // still returns a linearized accessing. If the accessing is not linearized,
// there will be offset issues. // there will be offset issues.
assert(accessChainOp.indices().size() == 2); assert(accessChainOp.getIndices().size() == 2);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter); srcBits, dstBits, rewriter);
Value spvLoadOp = rewriter.create<spirv::LoadOp>( Value spvLoadOp = rewriter.create<spirv::LoadOp>(
@ -507,7 +507,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
// 6) store 32-bit value back // 6) store 32-bit value back
// The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
// 4 to step 6 are done by AtomicOr as another atomic step. // 4 to step 6 are done by AtomicOr as another atomic step.
assert(accessChainOp.indices().size() == 2); assert(accessChainOp.getIndices().size() == 2);
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

View File

@ -174,7 +174,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
// Create the block for the header. // Create the block for the header.
auto *header = new Block(); auto *header = new Block();
// Insert the header. // Insert the header.
loopOp.body().getBlocks().insert(getBlockIt(loopOp.body(), 1), header); loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1), header);
// Create the new induction variable to use. // Create the new induction variable to use.
Value adapLowerBound = adaptor.getLowerBound(); Value adapLowerBound = adaptor.getLowerBound();
@ -197,13 +197,13 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
// Move the blocks from the forOp into the loopOp. This is the body of the // Move the blocks from the forOp into the loopOp. This is the body of the
// loopOp. // loopOp.
rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(), rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
getBlockIt(loopOp.body(), 2)); getBlockIt(loopOp.getBody(), 2));
SmallVector<Value, 8> args(1, adaptor.getLowerBound()); SmallVector<Value, 8> args(1, adaptor.getLowerBound());
args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
// Branch into it from the entry. // Branch into it from the entry.
rewriter.setInsertionPointToEnd(&(loopOp.body().front())); rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
rewriter.create<spirv::BranchOp>(loc, header, args); rewriter.create<spirv::BranchOp>(loc, header, args);
// Generate the rest of the loop header. // Generate the rest of the loop header.
@ -252,12 +252,12 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
auto selectionOp = auto selectionOp =
rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None); rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
auto *mergeBlock = auto *mergeBlock =
rewriter.createBlock(&selectionOp.body(), selectionOp.body().end()); rewriter.createBlock(&selectionOp.getBody(), selectionOp.getBody().end());
rewriter.create<spirv::MergeOp>(loc); rewriter.create<spirv::MergeOp>(loc);
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
auto *selectionHeaderBlock = auto *selectionHeaderBlock =
rewriter.createBlock(&selectionOp.body().front()); rewriter.createBlock(&selectionOp.getBody().front());
// Inline `then` region before the merge block and branch to it. // Inline `then` region before the merge block and branch to it.
auto &thenRegion = ifOp.getThenRegion(); auto &thenRegion = ifOp.getThenRegion();
@ -367,12 +367,12 @@ WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
return failure(); return failure();
// Move the while before block as the initial loop header block. // Move the while before block as the initial loop header block.
rewriter.inlineRegionBefore(beforeRegion, loopOp.body(), rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
getBlockIt(loopOp.body(), 1)); getBlockIt(loopOp.getBody(), 1));
// Move the while after block as the initial loop body block. // Move the while after block as the initial loop body block.
rewriter.inlineRegionBefore(afterRegion, loopOp.body(), rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
getBlockIt(loopOp.body(), 2)); getBlockIt(loopOp.getBody(), 2));
// Jump from the loop entry block to the loop header block. // Jump from the loop entry block to the loop header block.
rewriter.setInsertionPointToEnd(&entryBlock); rewriter.setInsertionPointToEnd(&entryBlock);

View File

@ -89,7 +89,7 @@ createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
op->getAttrOfType<IntegerAttr>(descriptorSetName()); op->getAttrOfType<IntegerAttr>(descriptorSetName());
IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName()); IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}", return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
kernelModuleName.str(), op.sym_name().str(), kernelModuleName.str(), op.getSymName().str(),
std::to_string(descriptorSet.getInt()), std::to_string(descriptorSet.getInt()),
std::to_string(binding.getInt())); std::to_string(binding.getInt()));
} }
@ -126,14 +126,14 @@ static LogicalResult getKernelGlobalVariables(
/// Encodes the SPIR-V module's symbolic name into the name of the entry point /// Encodes the SPIR-V module's symbolic name into the name of the entry point
/// function. /// function.
static LogicalResult encodeKernelName(spirv::ModuleOp module) { static LogicalResult encodeKernelName(spirv::ModuleOp module) {
StringRef spvModuleName = *module.sym_name(); StringRef spvModuleName = *module.getSymName();
// We already know that the module contains exactly one entry point function // We already know that the module contains exactly one entry point function
// based on `getKernelGlobalVariables()` call. Update this function's name // based on `getKernelGlobalVariables()` call. Update this function's name
// to: // to:
// {spv_module_name}_{function_name} // {spv_module_name}_{function_name}
auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin(); auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
StringRef funcName = entryPoint.fn(); StringRef funcName = entryPoint.getFn();
auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr()); auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
StringAttr newFuncName = StringAttr newFuncName =
StringAttr::get(module->getContext(), spvModuleName + "_" + funcName); StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module))) if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
@ -236,7 +236,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
// LLVM dialect global variable. // LLVM dialect global variable.
spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()]; spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
auto pointeeType = auto pointeeType =
spirvGlobal.type().cast<spirv::PointerType>().getPointeeType(); spirvGlobal.getType().cast<spirv::PointerType>().getPointeeType();
auto dstGlobalType = typeConverter->convertType(pointeeType); auto dstGlobalType = typeConverter->convertType(pointeeType);
if (!dstGlobalType) if (!dstGlobalType)
return failure(); return failure();

View File

@ -228,14 +228,14 @@ static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
if (!dstType) if (!dstType)
return failure(); return failure();
rewriter.replaceOpWithNewOp<LLVM::LoadOp>( rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment, loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
isVolatile, isNonTemporal); isVolatile, isNonTemporal);
return success(); return success();
} }
auto storeOp = cast<spirv::StoreOp>(op); auto storeOp = cast<spirv::StoreOp>(op);
spirv::StoreOpAdaptor adaptor(operands); spirv::StoreOpAdaptor adaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(), rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
adaptor.ptr(), alignment, adaptor.getPtr(), alignment,
isVolatile, isNonTemporal); isVolatile, isNonTemporal);
return success(); return success();
} }
@ -305,19 +305,19 @@ public:
LogicalResult LogicalResult
matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor, matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(op.component_ptr().getType()); auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
if (!dstType) if (!dstType)
return failure(); return failure();
// To use GEP we need to add a first 0 index to go through the pointer. // To use GEP we need to add a first 0 index to go through the pointer.
auto indices = llvm::to_vector<4>(adaptor.indices()); auto indices = llvm::to_vector<4>(adaptor.getIndices());
Type indexType = op.indices().front().getType(); Type indexType = op.getIndices().front().getType();
auto llvmIndexType = typeConverter.convertType(indexType); auto llvmIndexType = typeConverter.convertType(indexType);
if (!llvmIndexType) if (!llvmIndexType)
return failure(); return failure();
Value zero = rewriter.create<LLVM::ConstantOp>( Value zero = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
indices.insert(indices.begin(), zero); indices.insert(indices.begin(), zero);
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.base_ptr(), rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.getBasePtr(),
indices); indices);
return success(); return success();
} }
@ -330,10 +330,10 @@ public:
LogicalResult LogicalResult
matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor, matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto dstType = typeConverter.convertType(op.pointer().getType()); auto dstType = typeConverter.convertType(op.getPointer().getType());
if (!dstType) if (!dstType)
return failure(); return failure();
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable()); rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.getVariable());
return success(); return success();
} }
}; };
@ -353,9 +353,9 @@ public:
Location loc = op.getLoc(); Location loc = op.getLoc();
// Process `Offset` and `Count`: broadcast and extend/truncate if needed. // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
typeConverter, rewriter); typeConverter, rewriter);
Value count = processCountOrOffset(loc, op.count(), srcType, dstType, Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
typeConverter, rewriter); typeConverter, rewriter);
// Create a mask with bits set outside [Offset, Offset + Count - 1]. // Create a mask with bits set outside [Offset, Offset + Count - 1].
@ -372,9 +372,9 @@ public:
// Extract unchanged bits from the `Base` that are outside of // Extract unchanged bits from the `Base` that are outside of
// [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
Value baseAndMask = Value baseAndMask =
rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask); rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
Value insertShiftedByOffset = Value insertShiftedByOffset =
rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset); rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
insertShiftedByOffset); insertShiftedByOffset);
return success(); return success();
@ -408,14 +408,14 @@ public:
auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
if (srcType.isa<VectorType>()) { if (srcType.isa<VectorType>()) {
auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>(); auto dstElementsAttr = constOp.getValue().cast<DenseIntElementsAttr>();
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
constOp, dstType, constOp, dstType,
dstElementsAttr.mapValues( dstElementsAttr.mapValues(
signlessType, [&](const APInt &value) { return value; })); signlessType, [&](const APInt &value) { return value; }));
return success(); return success();
} }
auto srcAttr = constOp.value().cast<IntegerAttr>(); auto srcAttr = constOp.getValue().cast<IntegerAttr>();
auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr); rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
return success(); return success();
@ -441,9 +441,9 @@ public:
Location loc = op.getLoc(); Location loc = op.getLoc();
// Process `Offset` and `Count`: broadcast and extend/truncate if needed. // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
typeConverter, rewriter); typeConverter, rewriter);
Value count = processCountOrOffset(loc, op.count(), srcType, dstType, Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
typeConverter, rewriter); typeConverter, rewriter);
// Create a constant that holds the size of the `Base`. // Create a constant that holds the size of the `Base`.
@ -468,7 +468,7 @@ public:
Value amountToShiftLeft = Value amountToShiftLeft =
rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset); rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>( Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
loc, dstType, op.base(), amountToShiftLeft); loc, dstType, op.getBase(), amountToShiftLeft);
// Shift the result right, filling the bits with the sign bit. // Shift the result right, filling the bits with the sign bit.
Value amountToShiftRight = Value amountToShiftRight =
@ -494,9 +494,9 @@ public:
Location loc = op.getLoc(); Location loc = op.getLoc();
// Process `Offset` and `Count`: broadcast and extend/truncate if needed. // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
typeConverter, rewriter); typeConverter, rewriter);
Value count = processCountOrOffset(loc, op.count(), srcType, dstType, Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
typeConverter, rewriter); typeConverter, rewriter);
// Create a mask with bits set at [0, Count - 1]. // Create a mask with bits set at [0, Count - 1].
@ -508,7 +508,7 @@ public:
// Shift `Base` by `Offset` and apply the mask on it. // Shift `Base` by `Offset` and apply the mask on it.
Value shiftedBase = Value shiftedBase =
rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset); rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask); rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
return success(); return success();
} }
@ -538,20 +538,20 @@ public:
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
// If branch weights exist, map them to 32-bit integer vector. // If branch weights exist, map them to 32-bit integer vector.
ElementsAttr branchWeights = nullptr; ElementsAttr branchWeights = nullptr;
if (auto weights = op.branch_weights()) { if (auto weights = op.getBranchWeights()) {
VectorType weightType = VectorType::get(2, rewriter.getI32Type()); VectorType weightType = VectorType::get(2, rewriter.getI32Type());
branchWeights = DenseElementsAttr::get(weightType, weights->getValue()); branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
} }
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, op.condition(), op.getTrueBlockArguments(), op, op.getCondition(), op.getTrueBlockArguments(),
op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
op.getFalseBlock()); op.getFalseBlock());
return success(); return success();
} }
}; };
/// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type /// Converts `spv.getCompositeExtract` to `llvm.extractvalue` if the container type
/// is an aggregate type (struct or array). Otherwise, converts to /// is an aggregate type (struct or array). Otherwise, converts to
/// `llvm.extractelement` that operates on vectors. /// `llvm.extractelement` that operates on vectors.
class CompositeExtractPattern class CompositeExtractPattern
@ -566,23 +566,23 @@ public:
if (!dstType) if (!dstType)
return failure(); return failure();
Type containerType = op.composite().getType(); Type containerType = op.getComposite().getType();
if (containerType.isa<VectorType>()) { if (containerType.isa<VectorType>()) {
Location loc = op.getLoc(); Location loc = op.getLoc();
IntegerAttr value = op.indices()[0].cast<IntegerAttr>(); IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
Value index = createI32ConstantOf(loc, rewriter, value.getInt()); Value index = createI32ConstantOf(loc, rewriter, value.getInt());
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
op, dstType, adaptor.composite(), index); op, dstType, adaptor.getComposite(), index);
return success(); return success();
} }
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
op, adaptor.composite(), LLVM::convertArrayToIndices(op.indices())); op, adaptor.getComposite(), LLVM::convertArrayToIndices(op.getIndices()));
return success(); return success();
} }
}; };
/// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type /// Converts `spv.getCompositeInsert` to `llvm.insertvalue` if the container type
/// is an aggregate type (struct or array). Otherwise, converts to /// is an aggregate type (struct or array). Otherwise, converts to
/// `llvm.insertelement` that operates on vectors. /// `llvm.insertelement` that operates on vectors.
class CompositeInsertPattern class CompositeInsertPattern
@ -597,19 +597,19 @@ public:
if (!dstType) if (!dstType)
return failure(); return failure();
Type containerType = op.composite().getType(); Type containerType = op.getComposite().getType();
if (containerType.isa<VectorType>()) { if (containerType.isa<VectorType>()) {
Location loc = op.getLoc(); Location loc = op.getLoc();
IntegerAttr value = op.indices()[0].cast<IntegerAttr>(); IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
Value index = createI32ConstantOf(loc, rewriter, value.getInt()); Value index = createI32ConstantOf(loc, rewriter, value.getInt());
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
op, dstType, adaptor.composite(), adaptor.object(), index); op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
return success(); return success();
} }
rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>( rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
op, adaptor.composite(), adaptor.object(), op, adaptor.getComposite(), adaptor.getObject(),
LLVM::convertArrayToIndices(op.indices())); LLVM::convertArrayToIndices(op.getIndices()));
return success(); return success();
} }
}; };
@ -647,14 +647,14 @@ public:
// this entry point's execution mode. We set it to be: // this entry point's execution mode. We set it to be:
// __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode} // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
ModuleOp module = op->getParentOfType<ModuleOp>(); ModuleOp module = op->getParentOfType<ModuleOp>();
spirv::ExecutionModeAttr executionModeAttr = op.execution_modeAttr(); spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
std::string moduleName; std::string moduleName;
if (module.getName().has_value()) if (module.getName().has_value())
moduleName = "_" + module.getName().value().str(); moduleName = "_" + module.getName()->str();
else else
moduleName = ""; moduleName = "";
std::string executionModeInfoName = llvm::formatv( std::string executionModeInfoName = llvm::formatv(
"__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.fn().str(), "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
static_cast<uint32_t>(executionModeAttr.getValue())); static_cast<uint32_t>(executionModeAttr.getValue()));
MLIRContext *context = rewriter.getContext(); MLIRContext *context = rewriter.getContext();
@ -669,7 +669,7 @@ public:
auto llvmI32Type = IntegerType::get(context, 32); auto llvmI32Type = IntegerType::get(context, 32);
SmallVector<Type, 2> fields; SmallVector<Type, 2> fields;
fields.push_back(llvmI32Type); fields.push_back(llvmI32Type);
ArrayAttr values = op.values(); ArrayAttr values = op.getValues();
if (!values.empty()) { if (!values.empty()) {
auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size()); auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
fields.push_back(arrayType); fields.push_back(arrayType);
@ -722,10 +722,10 @@ public:
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
// Currently, there is no support of initialization with a constant value in // Currently, there is no support of initialization with a constant value in
// SPIR-V dialect. Specialization constants are not considered as well. // SPIR-V dialect. Specialization constants are not considered as well.
if (op.initializer()) if (op.getInitializer())
return failure(); return failure();
auto srcType = op.type().cast<spirv::PointerType>(); auto srcType = op.getType().cast<spirv::PointerType>();
auto dstType = typeConverter.convertType(srcType.getPointeeType()); auto dstType = typeConverter.convertType(srcType.getPointeeType());
if (!dstType) if (!dstType)
return failure(); return failure();
@ -759,12 +759,12 @@ public:
? LLVM::Linkage::Private ? LLVM::Linkage::Private
: LLVM::Linkage::External; : LLVM::Linkage::External;
auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
op, dstType, isConstant, linkage, op.sym_name(), Attribute(), op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
/*alignment=*/0); /*alignment=*/0);
// Attach location attribute if applicable // Attach location attribute if applicable
if (op.locationAttr()) if (op.getLocationAttr())
newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr()); newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
return success(); return success();
} }
@ -781,7 +781,7 @@ public:
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Type fromType = operation.operand().getType(); Type fromType = operation.getOperand().getType();
Type toType = operation.getType(); Type toType = operation.getType();
auto dstType = this->typeConverter.convertType(toType); auto dstType = this->typeConverter.convertType(toType);
@ -839,8 +839,8 @@ public:
return failure(); return failure();
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>( rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
operation, dstType, predicate, operation.operand1(), operation, dstType, predicate, operation.getOperand1(),
operation.operand2()); operation.getOperand2());
return success(); return success();
} }
}; };
@ -860,8 +860,8 @@ public:
return failure(); return failure();
rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
operation, dstType, predicate, operation.operand1(), operation, dstType, predicate, operation.getOperand1(),
operation.operand2()); operation.getOperand2());
return success(); return success();
} }
}; };
@ -881,7 +881,7 @@ public:
Location loc = op.getLoc(); Location loc = op.getLoc();
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand()); Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
return success(); return success();
} }
@ -896,20 +896,20 @@ public:
LogicalResult LogicalResult
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
if (!op.memory_access()) { if (!op.getMemoryAccess()) {
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
this->typeConverter, /*alignment=*/0, this->typeConverter, /*alignment=*/0,
/*isVolatile=*/false, /*isVolatile=*/false,
/*isNonTemporal=*/false); /*isNonTemporal=*/false);
} }
auto memoryAccess = *op.memory_access(); auto memoryAccess = *op.getMemoryAccess();
switch (memoryAccess) { switch (memoryAccess) {
case spirv::MemoryAccess::Aligned: case spirv::MemoryAccess::Aligned:
case spirv::MemoryAccess::None: case spirv::MemoryAccess::None:
case spirv::MemoryAccess::Nontemporal: case spirv::MemoryAccess::Nontemporal:
case spirv::MemoryAccess::Volatile: { case spirv::MemoryAccess::Volatile: {
unsigned alignment = unsigned alignment =
memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0; memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
@ -946,7 +946,7 @@ public:
srcType.template cast<VectorType>(), minusOne)) srcType.template cast<VectorType>(), minusOne))
: rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne); : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType, rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
notOp.operand(), mask); notOp.getOperand(), mask);
return success(); return success();
} }
}; };
@ -1047,7 +1047,7 @@ public:
matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor, matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
// There is no support of loop control at the moment. // There is no support of loop control at the moment.
if (loopOp.loop_control() != spirv::LoopControl::None) if (loopOp.getLoopControl() != spirv::LoopControl::None)
return failure(); return failure();
Location loc = loopOp.getLoc(); Location loc = loopOp.getLoc();
@ -1077,7 +1077,7 @@ public:
rewriter.setInsertionPointToEnd(mergeBlock); rewriter.setInsertionPointToEnd(mergeBlock);
rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock); rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
rewriter.inlineRegionBefore(loopOp.body(), endBlock); rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
rewriter.replaceOp(loopOp, endBlock->getArguments()); rewriter.replaceOp(loopOp, endBlock->getArguments());
return success(); return success();
} }
@ -1096,14 +1096,14 @@ public:
// There is no support for `Flatten` or `DontFlatten` selection control at // There is no support for `Flatten` or `DontFlatten` selection control at
// the moment. This are just compiler hints and can be performed during the // the moment. This are just compiler hints and can be performed during the
// optimization passes. // optimization passes.
if (op.selection_control() != spirv::SelectionControl::None) if (op.getSelectionControl() != spirv::SelectionControl::None)
return failure(); return failure();
// `spv.mlir.selection` should have at least two blocks: one selection // `spv.mlir.selection` should have at least two blocks: one selection
// header block and one merge block. If no blocks are present, or control // header block and one merge block. If no blocks are present, or control
// flow branches straight to merge block (two blocks are present), the op is // flow branches straight to merge block (two blocks are present), the op is
// redundant and it is erased. // redundant and it is erased.
if (op.body().getBlocks().size() <= 2) { if (op.getBody().getBlocks().size() <= 2) {
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
} }
@ -1140,11 +1140,11 @@ public:
Block *trueBlock = condBrOp.getTrueBlock(); Block *trueBlock = condBrOp.getTrueBlock();
Block *falseBlock = condBrOp.getFalseBlock(); Block *falseBlock = condBrOp.getFalseBlock();
rewriter.setInsertionPointToEnd(currentBlock); rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock, rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
condBrOp.trueTargetOperands(), falseBlock, condBrOp.getTrueTargetOperands(), falseBlock,
condBrOp.falseTargetOperands()); condBrOp.getFalseTargetOperands());
rewriter.inlineRegionBefore(op.body(), continueBlock); rewriter.inlineRegionBefore(op.getBody(), continueBlock);
rewriter.replaceOp(op, continueBlock->getArguments()); rewriter.replaceOp(op, continueBlock->getArguments());
return success(); return success();
} }
@ -1167,8 +1167,8 @@ public:
if (!dstType) if (!dstType)
return failure(); return failure();
Type op1Type = operation.operand1().getType(); Type op1Type = operation.getOperand1().getType();
Type op2Type = operation.operand2().getType(); Type op2Type = operation.getOperand2().getType();
if (op1Type == op2Type) { if (op1Type == op2Type) {
rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
@ -1180,13 +1180,13 @@ public:
Value extended; Value extended;
if (isUnsignedIntegerOrVector(op2Type)) { if (isUnsignedIntegerOrVector(op2Type)) {
extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType, extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
adaptor.operand2()); adaptor.getOperand2());
} else { } else {
extended = rewriter.template create<LLVM::SExtOp>(loc, dstType, extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
adaptor.operand2()); adaptor.getOperand2());
} }
Value result = rewriter.template create<LLVMOp>( Value result = rewriter.template create<LLVMOp>(
loc, dstType, adaptor.operand1(), extended); loc, dstType, adaptor.getOperand1(), extended);
rewriter.replaceOp(operation, result); rewriter.replaceOp(operation, result);
return success(); return success();
} }
@ -1204,8 +1204,8 @@ public:
return failure(); return failure();
Location loc = tanOp.getLoc(); Location loc = tanOp.getLoc();
Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand()); Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand()); Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
return success(); return success();
} }
@ -1232,7 +1232,7 @@ public:
Location loc = tanhOp.getLoc(); Location loc = tanhOp.getLoc();
Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
Value multiplied = Value multiplied =
rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand()); rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
Value numerator = Value numerator =
@ -1255,7 +1255,7 @@ public:
auto srcType = varOp.getType(); auto srcType = varOp.getType();
// Initialization is supported for scalars and vectors only. // Initialization is supported for scalars and vectors only.
auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType(); auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
auto init = varOp.initializer(); auto init = varOp.getInitializer();
if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>()) if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
return failure(); return failure();
@ -1270,7 +1270,7 @@ public:
return success(); return success();
} }
Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size); Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
rewriter.create<LLVM::StoreOp>(loc, adaptor.initializer(), allocated); rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
rewriter.replaceOp(varOp, allocated); rewriter.replaceOp(varOp, allocated);
return success(); return success();
} }
@ -1305,7 +1305,7 @@ public:
// Convert SPIR-V Function Control to equivalent LLVM function attribute // Convert SPIR-V Function Control to equivalent LLVM function attribute
MLIRContext *context = funcOp.getContext(); MLIRContext *context = funcOp.getContext();
switch (funcOp.function_control()) { switch (funcOp.getFunctionControl()) {
#define DISPATCH(functionControl, llvmAttr) \ #define DISPATCH(functionControl, llvmAttr) \
case functionControl: \ case functionControl: \
newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \ newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
@ -1374,9 +1374,9 @@ public:
matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor, matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc(); Location loc = op.getLoc();
auto components = adaptor.components(); auto components = adaptor.getComponents();
auto vector1 = adaptor.vector1(); auto vector1 = adaptor.getVector1();
auto vector2 = adaptor.vector2(); auto vector2 = adaptor.getVector2();
int vector1Size = vector1.getType().cast<VectorType>().getNumElements(); int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
int vector2Size = vector2.getType().cast<VectorType>().getNumElements(); int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
if (vector1Size == vector2Size) { if (vector1Size == vector2Size) {
@ -1589,8 +1589,8 @@ void mlir::encodeBindAttribute(ModuleOp module) {
// SPIR-V module has a name, add it at the beginning. // SPIR-V module has a name, add it at the beginning.
auto moduleAndName = auto moduleAndName =
spvModule.getName().has_value() spvModule.getName().has_value()
? spvModule.getName().value().str() + "_" + op.sym_name().str() ? spvModule.getName()->str() + "_" + op.getSymName().str()
: op.sym_name().str(); : op.getSymName().str();
std::string name = std::string name =
llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName, llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
std::to_string(descriptorSet.getInt()), std::to_string(descriptorSet.getInt()),

View File

@ -88,19 +88,19 @@ struct CombineChainedAccessChain
LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp, LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>( auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
accessChainOp.base_ptr().getDefiningOp()); accessChainOp.getBasePtr().getDefiningOp());
if (!parentAccessChainOp) { if (!parentAccessChainOp) {
return failure(); return failure();
} }
// Combine indices. // Combine indices.
SmallVector<Value, 4> indices(parentAccessChainOp.indices()); SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
indices.append(accessChainOp.indices().begin(), indices.append(accessChainOp.getIndices().begin(),
accessChainOp.indices().end()); accessChainOp.getIndices().end());
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>( rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
accessChainOp, parentAccessChainOp.base_ptr(), indices); accessChainOp, parentAccessChainOp.getBasePtr(), indices);
return success(); return success();
} }
@ -126,23 +126,24 @@ void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) { OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
if (auto insertOp = composite().getDefiningOp<spirv::CompositeInsertOp>()) { if (auto insertOp =
if (indices() == insertOp.indices()) getComposite().getDefiningOp<spirv::CompositeInsertOp>()) {
return insertOp.object(); if (getIndices() == insertOp.getIndices())
return insertOp.getObject();
} }
if (auto constructOp = if (auto constructOp =
composite().getDefiningOp<spirv::CompositeConstructOp>()) { getComposite().getDefiningOp<spirv::CompositeConstructOp>()) {
auto type = constructOp.getType().cast<spirv::CompositeType>(); auto type = constructOp.getType().cast<spirv::CompositeType>();
if (indices().size() == 1 && if (getIndices().size() == 1 &&
constructOp.constituents().size() == type.getNumElements()) { constructOp.getConstituents().size() == type.getNumElements()) {
auto i = indices().begin()->cast<IntegerAttr>(); auto i = getIndices().begin()->cast<IntegerAttr>();
return constructOp.constituents()[i.getValue().getSExtValue()]; return constructOp.getConstituents()[i.getValue().getSExtValue()];
} }
} }
auto indexVector = auto indexVector =
llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) { llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) {
return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt()); return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
})); }));
return extractCompositeElement(operands[0], indexVector); return extractCompositeElement(operands[0], indexVector);
@ -154,7 +155,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) { OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "spv.Constant has no operands"); assert(operands.empty() && "spv.Constant has no operands");
return value(); return getValue();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -164,8 +165,8 @@ OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) { OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "spv.IAdd expects two operands"); assert(operands.size() == 2 && "spv.IAdd expects two operands");
// x + 0 = x // x + 0 = x
if (matchPattern(operand2(), m_Zero())) if (matchPattern(getOperand2(), m_Zero()))
return operand1(); return getOperand1();
// According to the SPIR-V spec: // According to the SPIR-V spec:
// //
@ -183,11 +184,11 @@ OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) { OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "spv.IMul expects two operands"); assert(operands.size() == 2 && "spv.IMul expects two operands");
// x * 0 == 0 // x * 0 == 0
if (matchPattern(operand2(), m_Zero())) if (matchPattern(getOperand2(), m_Zero()))
return operand2(); return getOperand2();
// x * 1 = x // x * 1 = x
if (matchPattern(operand2(), m_One())) if (matchPattern(getOperand2(), m_One()))
return operand1(); return getOperand1();
// According to the SPIR-V spec: // According to the SPIR-V spec:
// //
@ -204,7 +205,7 @@ OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) { OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
// x - x = 0 // x - x = 0
if (operand1() == operand2()) if (getOperand1() == getOperand2())
return Builder(getContext()).getIntegerAttr(getType(), 0); return Builder(getContext()).getIntegerAttr(getType(), 0);
// According to the SPIR-V spec: // According to the SPIR-V spec:
@ -226,7 +227,7 @@ OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) { if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
// x && true = x // x && true = x
if (rhs.value()) if (rhs.value())
return operand1(); return getOperand1();
// x && false = false // x && false = false
if (!rhs.value()) if (!rhs.value())
@ -262,7 +263,7 @@ OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
// x || false = x // x || false = x
if (!rhs.value()) if (!rhs.value())
return operand1(); return getOperand1();
} }
return Attribute(); return Attribute();
@ -339,8 +340,8 @@ struct ConvertSelectionOpToSelect
cast<spirv::StoreOp>(trueBlock->front())->getAttrs(); cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
auto selectOp = rewriter.create<spirv::SelectOp>( auto selectOp = rewriter.create<spirv::SelectOp>(
selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(), selectionOp.getLoc(), trueValue.getType(),
trueValue, falseValue); brConditionalOp.getCondition(), trueValue, falseValue);
rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue, rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
selectOp.getResult(), storeOpAttributes); selectOp.getResult(), storeOpAttributes);
@ -371,13 +372,13 @@ private:
// Returns a source value for the given block. // Returns a source value for the given block.
Value getSrcValue(Block *block) const { Value getSrcValue(Block *block) const {
auto storeOp = cast<spirv::StoreOp>(block->front()); auto storeOp = cast<spirv::StoreOp>(block->front());
return storeOp.value(); return storeOp.getValue();
} }
// Returns a destination value for the given block. // Returns a destination value for the given block.
Value getDstPtr(Block *block) const { Value getDstPtr(Block *block) const {
auto storeOp = cast<spirv::StoreOp>(block->front()); auto storeOp = cast<spirv::StoreOp>(block->front());
return storeOp.ptr(); return storeOp.getPtr();
} }
}; };
@ -406,14 +407,14 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
// "Before version 1.4, Result Type must be a pointer, scalar, or vector. // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
// Starting with version 1.4, Result Type can additionally be a composite type // Starting with version 1.4, Result Type can additionally be a composite type
// other than a vector." // other than a vector."
bool isScalarOrVector = trueBrStoreOp.value() bool isScalarOrVector = trueBrStoreOp.getValue()
.getType() .getType()
.cast<spirv::SPIRVType>() .cast<spirv::SPIRVType>()
.isScalarOrVector(); .isScalarOrVector();
// Check that each `spv.Store` uses the same pointer, memory access // Check that each `spv.Store` uses the same pointer, memory access
// attributes and a valid type of the value. // attributes and a valid type of the value.
if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) || if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) { !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
return failure(); return failure();
} }

View File

@ -106,7 +106,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
// Replace the values directly with the return operands. // Replace the values directly with the return operands.
assert(valuesToRepl.size() == 1 && assert(valuesToRepl.size() == 1 &&
"spv.ReturnValue expected to only handle one result"); "spv.ReturnValue expected to only handle one result");
valuesToRepl.front().replaceAllUsesWith(retValOp.value()); valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
} }
}; };
} // namespace } // namespace

File diff suppressed because it is too large Load Diff

View File

@ -94,16 +94,16 @@ OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
return nullptr; return nullptr;
spirv::ModuleOp firstModule = inputModules.front(); spirv::ModuleOp firstModule = inputModules.front();
auto addressingModel = firstModule.addressing_model(); auto addressingModel = firstModule.getAddressingModel();
auto memoryModel = firstModule.memory_model(); auto memoryModel = firstModule.getMemoryModel();
auto vceTriple = firstModule.vce_triple(); auto vceTriple = firstModule.getVceTriple();
// First check whether there are conflicts between addressing/memory model. // First check whether there are conflicts between addressing/memory model.
// Return early if so. // Return early if so.
for (auto module : inputModules) { for (auto module : inputModules) {
if (module.addressing_model() != addressingModel || if (module.getAddressingModel() != addressingModel ||
module.memory_model() != memoryModel || module.getMemoryModel() != memoryModel ||
module.vce_triple() != vceTriple) { module.getVceTriple() != vceTriple) {
module.emitError("input modules differ in addressing model, memory " module.emitError("input modules differ in addressing model, memory "
"model, and/or VCE triple"); "model, and/or VCE triple");
return nullptr; return nullptr;

View File

@ -40,7 +40,7 @@ public:
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
SmallVector<NamedAttribute, 4> globalVarAttrs; SmallVector<NamedAttribute, 4> globalVarAttrs;
auto ptrType = op.type().cast<spirv::PointerType>(); auto ptrType = op.getType().cast<spirv::PointerType>();
auto structType = VulkanLayoutUtils::decorateType( auto structType = VulkanLayoutUtils::decorateType(
ptrType.getPointeeType().cast<spirv::StructType>()); ptrType.getPointeeType().cast<spirv::StructType>());
@ -71,11 +71,11 @@ public:
LogicalResult matchAndRewrite(spirv::AddressOfOp op, LogicalResult matchAndRewrite(spirv::AddressOfOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto spirvModule = op->getParentOfType<spirv::ModuleOp>(); auto spirvModule = op->getParentOfType<spirv::ModuleOp>();
auto varName = op.variableAttr(); auto varName = op.getVariableAttr();
auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName); auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>( rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
op, varOp.type(), SymbolRefAttr::get(varName.getAttr())); op, varOp.getType(), SymbolRefAttr::get(varName.getAttr()));
return success(); return success();
} }
}; };
@ -121,12 +121,12 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
target.addLegalOp<func::FuncOp>(); target.addLegalOp<func::FuncOp>();
target.addDynamicallyLegalOp<spirv::GlobalVariableOp>( target.addDynamicallyLegalOp<spirv::GlobalVariableOp>(
[](spirv::GlobalVariableOp op) { [](spirv::GlobalVariableOp op) {
return VulkanLayoutUtils::isLegalType(op.type()); return VulkanLayoutUtils::isLegalType(op.getType());
}); });
// Change the type for the direct users. // Change the type for the direct users.
target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) { target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
return VulkanLayoutUtils::isLegalType(op.pointer().getType()); return VulkanLayoutUtils::isLegalType(op.getPointer().getType());
}); });
// Change the type for the indirect users. // Change the type for the indirect users.
@ -134,7 +134,8 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
spirv::StoreOp>([&](Operation *op) { spirv::StoreOp>([&](Operation *op) {
for (Value operand : op->getOperands()) { for (Value operand : op->getOperands()) {
auto addrOp = operand.getDefiningOp<spirv::AddressOfOp>(); auto addrOp = operand.getDefiningOp<spirv::AddressOfOp>();
if (addrOp && !VulkanLayoutUtils::isLegalType(addrOp.pointer().getType())) if (addrOp &&
!VulkanLayoutUtils::isLegalType(addrOp.getPointer().getType()))
return false; return false;
} }
return true; return true;

View File

@ -88,13 +88,13 @@ getInterfaceVariables(spirv::FuncOp funcOp,
// instructions in this function. // instructions in this function.
funcOp.walk([&](spirv::AddressOfOp addressOfOp) { funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
auto var = auto var =
module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable()); module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
// TODO: Per SPIR-V spec: "Before version 1.4, the interfaces // TODO: Per SPIR-V spec: "Before version 1.4, the interfaces
// storage classes are limited to the Input and Output storage classes. // storage classes are limited to the Input and Output storage classes.
// Starting with version 1.4, the interfaces storage classes are all // Starting with version 1.4, the interfaces storage classes are all
// storage classes used in declaring all global variables referenced by the // storage classes used in declaring all global variables referenced by the
// entry points call tree." We should consider the target environment here. // entry points call tree." We should consider the target environment here.
switch (var.type().cast<spirv::PointerType>().getStorageClass()) { switch (var.getType().cast<spirv::PointerType>().getStorageClass()) {
case spirv::StorageClass::Input: case spirv::StorageClass::Input:
case spirv::StorageClass::Output: case spirv::StorageClass::Output:
interfaceVarSet.insert(var.getOperation()); interfaceVarSet.insert(var.getOperation());
@ -105,7 +105,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
}); });
for (auto &var : interfaceVarSet) { for (auto &var : interfaceVarSet) {
interfaceVars.push_back(SymbolRefAttr::get( interfaceVars.push_back(SymbolRefAttr::get(
funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name())); funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
} }
return success(); return success();
} }
@ -223,7 +223,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
auto zero = auto zero =
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter); spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
auto loadPtr = rewriter.create<spirv::AccessChainOp>( auto loadPtr = rewriter.create<spirv::AccessChainOp>(
funcOp.getLoc(), replacement, zero.constant()); funcOp.getLoc(), replacement, zero.getConstant());
replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr); replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
} }
signatureConverter.remapInput(argType.index(), replacement); signatureConverter.remapInput(argType.index(), replacement);

View File

@ -63,7 +63,7 @@ void RewriteInsertsPass::runOnOperation() {
SmallVector<Value, 4> operands; SmallVector<Value, 4> operands;
// Collect inserted objects. // Collect inserted objects.
for (auto insertionOp : insertions) for (auto insertionOp : insertions)
operands.push_back(insertionOp.object()); operands.push_back(insertionOp.getObject());
OpBuilder builder(lastCompositeInsertOp); OpBuilder builder(lastCompositeInsertOp);
auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>( auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
@ -84,11 +84,13 @@ void RewriteInsertsPass::runOnOperation() {
LogicalResult RewriteInsertsPass::collectInsertionChain( LogicalResult RewriteInsertsPass::collectInsertionChain(
spirv::CompositeInsertOp op, spirv::CompositeInsertOp op,
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) { SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
auto indicesArrayAttr = op.indices().cast<ArrayAttr>(); auto indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
// TODO: handle nested composite object. // TODO: handle nested composite object.
if (indicesArrayAttr.size() == 1) { if (indicesArrayAttr.size() == 1) {
auto numElements = auto numElements = op.getComposite()
op.composite().getType().cast<spirv::CompositeType>().getNumElements(); .getType()
.cast<spirv::CompositeType>()
.getNumElements();
auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt(); auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
// Need a last index to collect a sequential chain. // Need a last index to collect a sequential chain.
@ -102,12 +104,12 @@ LogicalResult RewriteInsertsPass::collectInsertionChain(
if (index == 0) if (index == 0)
return success(); return success();
op = op.composite().getDefiningOp<spirv::CompositeInsertOp>(); op = op.getComposite().getDefiningOp<spirv::CompositeInsertOp>();
if (!op) if (!op)
return failure(); return failure();
--index; --index;
indicesArrayAttr = op.indices().cast<ArrayAttr>(); indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
if ((indicesArrayAttr.size() != 1) || if ((indicesArrayAttr.size() != 1) ||
(indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index)) (indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
return failure(); return failure();

View File

@ -642,7 +642,7 @@ static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
static spirv::GlobalVariableOp getPushConstantVariable(Block &body, static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
unsigned elementCount) { unsigned elementCount) {
for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
auto ptrType = varOp.type().dyn_cast<spirv::PointerType>(); auto ptrType = varOp.getType().dyn_cast<spirv::PointerType>();
if (!ptrType) if (!ptrType)
continue; continue;
@ -874,7 +874,7 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
// Special treatment for global variables, whose type requirements are // Special treatment for global variables, whose type requirements are
// conveyed by type attributes. // conveyed by type attributes.
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
valueTypes.push_back(globalVar.type()); valueTypes.push_back(globalVar.getType());
// Make sure the op's operands/results use types that are allowed by the // Make sure the op's operands/results use types that are allowed by the
// target environment. // target environment.

View File

@ -51,8 +51,8 @@ static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
AliasedResourceMap aliasedResources; AliasedResourceMap aliasedResources;
moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) { moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
if (varOp->getAttrOfType<UnitAttr>("aliased")) { if (varOp->getAttrOfType<UnitAttr>("aliased")) {
Optional<uint32_t> set = varOp.descriptor_set(); Optional<uint32_t> set = varOp.getDescriptorSet();
Optional<uint32_t> binding = varOp.binding(); Optional<uint32_t> binding = varOp.getBinding();
if (set && binding) if (set && binding)
aliasedResources[{*set, *binding}].push_back(varOp); aliasedResources[{*set, *binding}].push_back(varOp);
} }
@ -222,16 +222,16 @@ bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
} }
if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) { if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>(); auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()); auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable());
return shouldUnify(varOp); return shouldUnify(varOp);
} }
if (auto acOp = dyn_cast<spirv::AccessChainOp>(op)) if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
return shouldUnify(acOp.base_ptr().getDefiningOp()); return shouldUnify(acOp.getBasePtr().getDefiningOp());
if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
return shouldUnify(loadOp.ptr().getDefiningOp()); return shouldUnify(loadOp.getPtr().getDefiningOp());
if (auto storeOp = dyn_cast<spirv::StoreOp>(op)) if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
return shouldUnify(storeOp.ptr().getDefiningOp()); return shouldUnify(storeOp.getPtr().getDefiningOp());
return false; return false;
} }
@ -265,7 +265,7 @@ void ResourceAliasAnalysis::recordIfUnifiable(
// Collect the element types for all resources in the current set. // Collect the element types for all resources in the current set.
SmallVector<spirv::SPIRVType> elementTypes; SmallVector<spirv::SPIRVType> elementTypes;
for (spirv::GlobalVariableOp resource : resources) { for (spirv::GlobalVariableOp resource : resources) {
Type elementType = getRuntimeArrayElementType(resource.type()); Type elementType = getRuntimeArrayElementType(resource.getType());
if (!elementType) if (!elementType)
return; // Unexpected resource variable type. return; // Unexpected resource variable type.
@ -326,7 +326,7 @@ struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
// Rewrite the AddressOf op to get the address of the canoncical resource. // Rewrite the AddressOf op to get the address of the canoncical resource.
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>(); auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
auto srcVarOp = cast<spirv::GlobalVariableOp>( auto srcVarOp = cast<spirv::GlobalVariableOp>(
SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable())); SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
auto dstVarOp = analysis.getCanonicalResource(srcVarOp); auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp); rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
return success(); return success();
@ -339,13 +339,13 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
LogicalResult LogicalResult
matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor, matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>(); auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
if (!addressOp) if (!addressOp)
return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op"); return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>(); auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
auto srcVarOp = cast<spirv::GlobalVariableOp>( auto srcVarOp = cast<spirv::GlobalVariableOp>(
SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable())); SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
auto dstVarOp = analysis.getCanonicalResource(srcVarOp); auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp); spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
@ -356,7 +356,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
// We have the same bitwidth for source and destination element types. // We have the same bitwidth for source and destination element types.
// Thie indices keep the same. // Thie indices keep the same.
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>( rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.base_ptr(), adaptor.indices()); acOp, adaptor.getBasePtr(), adaptor.getIndices());
return success(); return success();
} }
@ -375,7 +375,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
auto ratioValue = rewriter.create<spirv::ConstantOp>( auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(ratio)); loc, i32Type, rewriter.getI32IntegerAttr(ratio));
auto indices = llvm::to_vector<4>(acOp.indices()); auto indices = llvm::to_vector<4>(acOp.getIndices());
Value oldIndex = indices.back(); Value oldIndex = indices.back();
indices.back() = indices.back() =
rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue); rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
@ -383,7 +383,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue)); rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>( rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.base_ptr(), indices); acOp, adaptor.getBasePtr(), indices);
return success(); return success();
} }
@ -399,13 +399,13 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
auto ratioValue = rewriter.create<spirv::ConstantOp>( auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(ratio)); loc, i32Type, rewriter.getI32IntegerAttr(ratio));
auto indices = llvm::to_vector<4>(acOp.indices()); auto indices = llvm::to_vector<4>(acOp.getIndices());
Value oldIndex = indices.back(); Value oldIndex = indices.back();
indices.back() = indices.back() =
rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue); rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>( rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.base_ptr(), indices); acOp, adaptor.getBasePtr(), indices);
return success(); return success();
} }
@ -420,13 +420,13 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
LogicalResult LogicalResult
matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto srcPtrType = loadOp.ptr().getType().cast<spirv::PointerType>(); auto srcPtrType = loadOp.getPtr().getType().cast<spirv::PointerType>();
auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>(); auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>();
auto dstPtrType = adaptor.ptr().getType().cast<spirv::PointerType>(); auto dstPtrType = adaptor.getPtr().getType().cast<spirv::PointerType>();
auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>(); auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>();
Location loc = loadOp.getLoc(); Location loc = loadOp.getLoc();
auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr()); auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
if (srcElemType == dstElemType) { if (srcElemType == dstElemType) {
rewriter.replaceOp(loadOp, newLoadOp->getResults()); rewriter.replaceOp(loadOp, newLoadOp->getResults());
return success(); return success();
@ -434,7 +434,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
if (areSameBitwidthScalarType(srcElemType, dstElemType)) { if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType, auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
newLoadOp.value()); newLoadOp.getValue());
rewriter.replaceOp(loadOp, castOp->getResults()); rewriter.replaceOp(loadOp, castOp->getResults());
return success(); return success();
@ -457,19 +457,19 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
components.reserve(ratio); components.reserve(ratio);
components.push_back(newLoadOp); components.push_back(newLoadOp);
auto acOp = adaptor.ptr().getDefiningOp<spirv::AccessChainOp>(); auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
if (!acOp) if (!acOp)
return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain"); return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain");
auto i32Type = rewriter.getI32Type(); auto i32Type = rewriter.getI32Type();
Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter); Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
auto indices = llvm::to_vector<4>(acOp.indices()); auto indices = llvm::to_vector<4>(acOp.getIndices());
for (int i = 1; i < ratio; ++i) { for (int i = 1; i < ratio; ++i) {
// Load all subsequent components belonging to this element. // Load all subsequent components belonging to this element.
indices.back() = rewriter.create<spirv::IAddOp>( indices.back() = rewriter.create<spirv::IAddOp>(
loc, i32Type, indices.back(), oneValue); loc, i32Type, indices.back(), oneValue);
auto componentAcOp = rewriter.create<spirv::AccessChainOp>( auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
loc, acOp.base_ptr(), indices); loc, acOp.getBasePtr(), indices);
// Assuming little endian, this reads lower-ordered bits of the number // Assuming little endian, this reads lower-ordered bits of the number
// to lower-numbered components of the vector. // to lower-numbered components of the vector.
components.push_back( components.push_back(
@ -504,19 +504,19 @@ struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto srcElemType = auto srcElemType =
storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType(); storeOp.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
auto dstElemType = auto dstElemType =
adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType(); adaptor.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()) if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
return rewriter.notifyMatchFailure(storeOp, "not scalar type"); return rewriter.notifyMatchFailure(storeOp, "not scalar type");
if (!areSameBitwidthScalarType(srcElemType, dstElemType)) if (!areSameBitwidthScalarType(srcElemType, dstElemType))
return rewriter.notifyMatchFailure(storeOp, "different bitwidth"); return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
Location loc = storeOp.getLoc(); Location loc = storeOp.getLoc();
Value value = adaptor.value(); Value value = adaptor.getValue();
if (srcElemType != dstElemType) if (srcElemType != dstElemType)
value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value); value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value, rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(), value,
storeOp->getAttrs()); storeOp->getAttrs());
return success(); return success();
} }

View File

@ -151,7 +151,7 @@ void UpdateVCEPass::runOnOperation() {
// Special treatment for global variables, whose type requirements are // Special treatment for global variables, whose type requirements are
// conveyed by type attributes. // conveyed by type attributes.
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
valueTypes.push_back(globalVar.type()); valueTypes.push_back(globalVar.getType());
// Requirements from values' types // Requirements from values' types
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;

View File

@ -46,20 +46,20 @@ Value spirv::Deserializer::getValue(uint32_t id) {
} }
if (auto varOp = getGlobalVariable(id)) { if (auto varOp = getGlobalVariable(id)) {
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>( auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation())); unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
return addressOfOp.pointer(); return addressOfOp.getPointer();
} }
if (auto constOp = getSpecConstant(id)) { if (auto constOp = getSpecConstant(id)) {
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
unknownLoc, constOp.default_value().getType(), unknownLoc, constOp.getDefaultValue().getType(),
SymbolRefAttr::get(constOp.getOperation())); SymbolRefAttr::get(constOp.getOperation()));
return referenceOfOp.reference(); return referenceOfOp.getReference();
} }
if (auto constCompositeOp = getSpecConstantComposite(id)) { if (auto constCompositeOp = getSpecConstantComposite(id)) {
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
unknownLoc, constCompositeOp.type(), unknownLoc, constCompositeOp.getType(),
SymbolRefAttr::get(constCompositeOp.getOperation())); SymbolRefAttr::get(constCompositeOp.getOperation()));
return referenceOfOp.reference(); return referenceOfOp.getReference();
} }
if (auto specConstOperationInfo = getSpecConstantOperation(id)) { if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
return materializeSpecConstantOperation( return materializeSpecConstantOperation(

View File

@ -1414,7 +1414,7 @@ Value spirv::Deserializer::materializeSpecConstantOperation(
auto specConstOperationOp = auto specConstOperationOp =
opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType); opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
Region &body = specConstOperationOp.body(); Region &body = specConstOperationOp.getBody();
// Move the new block into SpecConstantOperation's body. // Move the new block into SpecConstantOperation's body.
body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(), body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
Region::iterator(enclosedBlock)); Region::iterator(enclosedBlock));
@ -1983,17 +1983,17 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
assert((branchCondOp.getTrueBlock() == target || assert((branchCondOp.getTrueBlock() == target ||
branchCondOp.getFalseBlock() == target) && branchCondOp.getFalseBlock() == target) &&
"expected target to be either the true or false target"); "expected target to be either the true or false target");
if (target == branchCondOp.trueTarget()) if (target == branchCondOp.getTrueTarget())
opBuilder.create<spirv::BranchConditionalOp>( opBuilder.create<spirv::BranchConditionalOp>(
branchCondOp.getLoc(), branchCondOp.condition(), blockArgs, branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
branchCondOp.getFalseBlockArguments(), branchCondOp.getFalseBlockArguments(),
branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(), branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
branchCondOp.falseTarget()); branchCondOp.getFalseTarget());
else else
opBuilder.create<spirv::BranchConditionalOp>( opBuilder.create<spirv::BranchConditionalOp>(
branchCondOp.getLoc(), branchCondOp.condition(), branchCondOp.getLoc(), branchCondOp.getCondition(),
branchCondOp.getTrueBlockArguments(), blockArgs, branchCondOp.getTrueBlockArguments(), blockArgs,
branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(), branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
branchCondOp.getFalseBlock()); branchCondOp.getFalseBlock());
branchCondOp.erase(); branchCondOp.erase();

View File

@ -24,7 +24,7 @@ namespace mlir {
LogicalResult spirv::serialize(spirv::ModuleOp module, LogicalResult spirv::serialize(spirv::ModuleOp module,
SmallVectorImpl<uint32_t> &binary, SmallVectorImpl<uint32_t> &binary,
const SerializationOptions &options) { const SerializationOptions &options) {
if (!module.vce_triple()) if (!module.getVceTriple())
return module.emitError( return module.emitError(
"module must have 'vce_triple' attribute to be serializeable"); "module must have 'vce_triple' attribute to be serializeable");

View File

@ -58,7 +58,8 @@ visitInPrettyBlockOrder(Block *headerBlock,
namespace mlir { namespace mlir {
namespace spirv { namespace spirv {
LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) { if (auto resultID =
prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
valueIDMap[op.getResult()] = resultID; valueIDMap[op.getResult()] = resultID;
return success(); return success();
} }
@ -66,7 +67,7 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
} }
LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(), if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
/*isSpec=*/true)) { /*isSpec=*/true)) {
// Emit the OpDecorate instruction for SpecId. // Emit the OpDecorate instruction for SpecId.
if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) { if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
@ -75,8 +76,8 @@ LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
return failure(); return failure();
} }
specConstIDMap[op.sym_name()] = resultID; specConstIDMap[op.getSymName()] = resultID;
return processName(resultID, op.sym_name()); return processName(resultID, op.getSymName());
} }
return failure(); return failure();
} }
@ -84,7 +85,7 @@ LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
LogicalResult LogicalResult
Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
uint32_t typeID = 0; uint32_t typeID = 0;
if (failed(processType(op.getLoc(), op.type(), typeID))) { if (failed(processType(op.getLoc(), op.getType(), typeID))) {
return failure(); return failure();
} }
@ -94,7 +95,7 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
operands.push_back(typeID); operands.push_back(typeID);
operands.push_back(resultID); operands.push_back(resultID);
auto constituents = op.constituents(); auto constituents = op.getConstituents();
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>(); auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
@ -112,9 +113,9 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
encodeInstructionInto(typesGlobalValues, encodeInstructionInto(typesGlobalValues,
spirv::Opcode::OpSpecConstantComposite, operands); spirv::Opcode::OpSpecConstantComposite, operands);
specConstIDMap[op.sym_name()] = resultID; specConstIDMap[op.getSymName()] = resultID;
return processName(resultID, op.sym_name()); return processName(resultID, op.getSymName());
} }
LogicalResult LogicalResult
@ -199,7 +200,7 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
operands.push_back(resTypeID); operands.push_back(resTypeID);
auto funcID = getOrCreateFunctionID(op.getName()); auto funcID = getOrCreateFunctionID(op.getName());
operands.push_back(funcID); operands.push_back(funcID);
operands.push_back(static_cast<uint32_t>(op.function_control())); operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
operands.push_back(fnTypeID); operands.push_back(fnTypeID);
encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands); encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
@ -310,7 +311,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
// Get TypeID. // Get TypeID.
uint32_t resultTypeID = 0; uint32_t resultTypeID = 0;
SmallVector<StringRef, 4> elidedAttrs; SmallVector<StringRef, 4> elidedAttrs;
if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
return failure(); return failure();
} }
@ -320,7 +321,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
auto resultID = getNextID(); auto resultID = getNextID();
// Encode the name. // Encode the name.
auto varName = varOp.sym_name(); auto varName = varOp.getSymName();
elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
if (failed(processName(resultID, varName))) { if (failed(processName(resultID, varName))) {
return failure(); return failure();
@ -332,7 +333,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
operands.push_back(static_cast<uint32_t>(varOp.storageClass())); operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
// Encode initialization. // Encode initialization.
if (auto initializer = varOp.initializer()) { if (auto initializer = varOp.getInitializer()) {
auto initializerID = getVariableID(*initializer); auto initializerID = getVariableID(*initializer);
if (!initializerID) { if (!initializerID) {
return emitError(varOp.getLoc(), return emitError(varOp.getLoc(),
@ -364,7 +365,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) { LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
// Assign <id>s to all blocks so that branches inside the SelectionOp can // Assign <id>s to all blocks so that branches inside the SelectionOp can
// resolve properly. // resolve properly.
auto &body = selectionOp.body(); auto &body = selectionOp.getBody();
for (Block &block : body) for (Block &block : body)
getOrCreateBlockID(&block); getOrCreateBlockID(&block);
@ -390,7 +391,7 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
lastProcessedWasMergeInst = true; lastProcessedWasMergeInst = true;
encodeInstructionInto( encodeInstructionInto(
functionBody, spirv::Opcode::OpSelectionMerge, functionBody, spirv::Opcode::OpSelectionMerge,
{mergeID, static_cast<uint32_t>(selectionOp.selection_control())}); {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
return success(); return success();
}; };
if (failed( if (failed(
@ -420,7 +421,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// Assign <id>s to all blocks so that branches inside the LoopOp can resolve // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
// properly. We don't need to assign for the entry block, which is just for // properly. We don't need to assign for the entry block, which is just for
// satisfying MLIR region's structural requirement. // satisfying MLIR region's structural requirement.
auto &body = loopOp.body(); auto &body = loopOp.getBody();
for (Block &block : llvm::drop_begin(body)) for (Block &block : llvm::drop_begin(body))
getOrCreateBlockID(&block); getOrCreateBlockID(&block);
@ -452,7 +453,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
lastProcessedWasMergeInst = true; lastProcessedWasMergeInst = true;
encodeInstructionInto( encodeInstructionInto(
functionBody, spirv::Opcode::OpLoopMerge, functionBody, spirv::Opcode::OpLoopMerge,
{mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())}); {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
return success(); return success();
}; };
if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge))) if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
@ -483,12 +484,12 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
LogicalResult Serializer::processBranchConditionalOp( LogicalResult Serializer::processBranchConditionalOp(
spirv::BranchConditionalOp condBranchOp) { spirv::BranchConditionalOp condBranchOp) {
auto conditionID = getValueID(condBranchOp.condition()); auto conditionID = getValueID(condBranchOp.getCondition());
auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock()); auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock()); auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID}; SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
if (auto weights = condBranchOp.branch_weights()) { if (auto weights = condBranchOp.getBranchWeights()) {
for (auto val : weights->getValue()) for (auto val : weights->getValue())
arguments.push_back(val.cast<IntegerAttr>().getInt()); arguments.push_back(val.cast<IntegerAttr>().getInt());
} }
@ -509,26 +510,26 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
} }
LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
auto varName = addressOfOp.variable(); auto varName = addressOfOp.getVariable();
auto variableID = getVariableID(varName); auto variableID = getVariableID(varName);
if (!variableID) { if (!variableID) {
return addressOfOp.emitError("unknown result <id> for variable ") return addressOfOp.emitError("unknown result <id> for variable ")
<< varName; << varName;
} }
valueIDMap[addressOfOp.pointer()] = variableID; valueIDMap[addressOfOp.getPointer()] = variableID;
return success(); return success();
} }
LogicalResult LogicalResult
Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) { Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
auto constName = referenceOfOp.spec_const(); auto constName = referenceOfOp.getSpecConst();
auto constID = getSpecConstID(constName); auto constID = getSpecConstID(constName);
if (!constID) { if (!constID) {
return referenceOfOp.emitError( return referenceOfOp.emitError(
"unknown result <id> for specialization constant ") "unknown result <id> for specialization constant ")
<< constName; << constName;
} }
valueIDMap[referenceOfOp.reference()] = constID; valueIDMap[referenceOfOp.getReference()] = constID;
return success(); return success();
} }
@ -537,21 +538,21 @@ LogicalResult
Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) { Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
SmallVector<uint32_t, 4> operands; SmallVector<uint32_t, 4> operands;
// Add the ExecutionModel. // Add the ExecutionModel.
operands.push_back(static_cast<uint32_t>(op.execution_model())); operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
// Add the function <id>. // Add the function <id>.
auto funcID = getFunctionID(op.fn()); auto funcID = getFunctionID(op.getFn());
if (!funcID) { if (!funcID) {
return op.emitError("missing <id> for function ") return op.emitError("missing <id> for function ")
<< op.fn() << op.getFn()
<< "; function needs to be defined before spv.EntryPoint is " << "; function needs to be defined before spv.EntryPoint is "
"serialized"; "serialized";
} }
operands.push_back(funcID); operands.push_back(funcID);
// Add the name of the function. // Add the name of the function.
spirv::encodeStringLiteralInto(operands, op.fn()); spirv::encodeStringLiteralInto(operands, op.getFn());
// Add the interface values. // Add the interface values.
if (auto interface = op.interface()) { if (auto interface = op.getInterface()) {
for (auto var : interface.getValue()) { for (auto var : interface.getValue()) {
auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue()); auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
if (!id) { if (!id) {
@ -571,19 +572,19 @@ LogicalResult
Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) { Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
SmallVector<uint32_t, 4> operands; SmallVector<uint32_t, 4> operands;
// Add the function <id>. // Add the function <id>.
auto funcID = getFunctionID(op.fn()); auto funcID = getFunctionID(op.getFn());
if (!funcID) { if (!funcID) {
return op.emitError("missing <id> for function ") return op.emitError("missing <id> for function ")
<< op.fn() << op.getFn()
<< "; function needs to be serialized before ExecutionModeOp is " << "; function needs to be serialized before ExecutionModeOp is "
"serialized"; "serialized";
} }
operands.push_back(funcID); operands.push_back(funcID);
// Add the ExecutionMode. // Add the ExecutionMode.
operands.push_back(static_cast<uint32_t>(op.execution_mode())); operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
// Serialize values if any. // Serialize values if any.
auto values = op.values(); auto values = op.getValues();
if (values) { if (values) {
for (auto &intVal : values.getValue()) { for (auto &intVal : values.getValue()) {
operands.push_back(static_cast<uint32_t>( operands.push_back(static_cast<uint32_t>(
@ -598,7 +599,7 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
template <> template <>
LogicalResult LogicalResult
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) { Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
auto funcName = op.callee(); auto funcName = op.getCallee();
uint32_t resTypeID = 0; uint32_t resTypeID = 0;
Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType(); Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
@ -609,7 +610,7 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
auto funcCallID = getNextID(); auto funcCallID = getNextID();
SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID}; SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
for (auto value : op.arguments()) { for (auto value : op.getArguments()) {
auto valueID = getValueID(value); auto valueID = getValueID(value);
assert(valueID && "cannot find a value for spv.FunctionCall"); assert(valueID && "cannot find a value for spv.FunctionCall");
operands.push_back(valueID); operands.push_back(valueID);

View File

@ -119,7 +119,8 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
binary.clear(); binary.clear();
binary.reserve(moduleSize); binary.reserve(moduleSize);
spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
nextID);
binary.append(capabilities.begin(), capabilities.end()); binary.append(capabilities.begin(), capabilities.end());
binary.append(extensions.begin(), extensions.end()); binary.append(extensions.begin(), extensions.end());
binary.append(extendedSets.begin(), extendedSets.end()); binary.append(extendedSets.begin(), extendedSets.end());
@ -166,7 +167,7 @@ uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
} }
void Serializer::processCapability() { void Serializer::processCapability() {
for (auto cap : module.vce_triple()->getCapabilities()) for (auto cap : module.getVceTriple()->getCapabilities())
encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
{static_cast<uint32_t>(cap)}); {static_cast<uint32_t>(cap)});
} }
@ -186,7 +187,7 @@ void Serializer::processDebugInfo() {
void Serializer::processExtension() { void Serializer::processExtension() {
llvm::SmallVector<uint32_t, 16> extName; llvm::SmallVector<uint32_t, 16> extName;
for (spirv::Extension ext : module.vce_triple()->getExtensions()) { for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
extName.clear(); extName.clear();
spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
@ -1045,11 +1046,11 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
} else if (auto branchCondOp = } else if (auto branchCondOp =
dyn_cast<spirv::BranchConditionalOp>(terminator)) { dyn_cast<spirv::BranchConditionalOp>(terminator)) {
Optional<OperandRange> blockOperands; Optional<OperandRange> blockOperands;
if (branchCondOp.trueTarget() == block) { if (branchCondOp.getTrueTarget() == block) {
blockOperands = branchCondOp.trueTargetOperands(); blockOperands = branchCondOp.getTrueTargetOperands();
} else { } else {
assert(branchCondOp.falseTarget() == block); assert(branchCondOp.getFalseTarget() == block);
blockOperands = branchCondOp.falseTargetOperands(); blockOperands = branchCondOp.getFalseTargetOperands();
} }
assert(!blockOperands->empty() && assert(!blockOperands->empty() &&

View File

@ -1360,7 +1360,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & " os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
"static_cast<{0}::{1}>(1 << i);\n", "static_cast<{0}::{1}>(1 << i);\n",
enumAttr.getCppNamespace(), enumAttr.getEnumClassName(), enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
namedAttr.name); srcOp.getGetterName(namedAttr.name));
os << formatv( os << formatv(
" if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n", " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
enumAttr.getUnderlyingType()); enumAttr.getUnderlyingType());
@ -1368,7 +1368,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
// For IntEnumAttr, we just need to query the value as a whole. // For IntEnumAttr, we just need to query the value as a whole.
os << " {\n"; os << " {\n";
os << formatv(" auto tblgen_attrVal = this->{0}();\n", os << formatv(" auto tblgen_attrVal = this->{0}();\n",
namedAttr.name); srcOp.getGetterName(namedAttr.name));
} }
os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n", os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
enumAttr.getCppNamespace(), avail.getQueryFnName()); enumAttr.getCppNamespace(), avail.getQueryFnName());