forked from OSchip/llvm-project
[mlir][spirv] Switch to kEmitAccessorPrefix_Predixed
Fixes https://github.com/llvm/llvm-project/issues/57887 Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D134580
This commit is contained in:
parent
cde3de5381
commit
90a1632d0b
|
@ -72,10 +72,6 @@ def SPIRV_Dialect : Dialect {
|
||||||
void printAttribute(
|
void printAttribute(
|
||||||
Attribute attr, DialectAsmPrinter &printer) const override;
|
Attribute attr, DialectAsmPrinter &printer) const override;
|
||||||
}];
|
}];
|
||||||
|
|
||||||
// TODO(https://github.com/llvm/llvm-project/issues/57887): Switch to
|
|
||||||
// _Prefixed accessors.
|
|
||||||
let emitAccessorPrefix = kEmitAccessorPrefix_Both;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -65,7 +65,7 @@ def SPV_BranchOp : SPV_Op<"Branch", [
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
/// Returns the block arguments.
|
/// Returns the block arguments.
|
||||||
operand_range getBlockArguments() { return targetOperands(); }
|
operand_range getBlockArguments() { return getTargetOperands(); }
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let autogenSerialization = 0;
|
let autogenSerialization = 0;
|
||||||
|
@ -161,22 +161,22 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [
|
||||||
|
|
||||||
/// Returns the number of arguments to the true target block.
|
/// Returns the number of arguments to the true target block.
|
||||||
unsigned getNumTrueBlockArguments() {
|
unsigned getNumTrueBlockArguments() {
|
||||||
return trueTargetOperands().size();
|
return getTrueTargetOperands().size();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of arguments to the false target block.
|
/// Returns the number of arguments to the false target block.
|
||||||
unsigned getNumFalseBlockArguments() {
|
unsigned getNumFalseBlockArguments() {
|
||||||
return falseTargetOperands().size();
|
return getFalseTargetOperands().size();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iterator and range support for true target block arguments.
|
// Iterator and range support for true target block arguments.
|
||||||
operand_range getTrueBlockArguments() {
|
operand_range getTrueBlockArguments() {
|
||||||
return trueTargetOperands();
|
return getTrueTargetOperands();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iterator and range support for false target block arguments.
|
// Iterator and range support for false target block arguments.
|
||||||
operand_range getFalseBlockArguments() {
|
operand_range getFalseBlockArguments() {
|
||||||
return falseTargetOperands();
|
return getFalseTargetOperands();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -394,9 +394,9 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
|
||||||
CArg<"FlatSymbolRefAttr", "nullptr">:$initializer),
|
CArg<"FlatSymbolRefAttr", "nullptr">:$initializer),
|
||||||
[{
|
[{
|
||||||
$_state.addAttribute("type", type);
|
$_state.addAttribute("type", type);
|
||||||
$_state.addAttribute(sym_nameAttrName($_state.name), sym_name);
|
$_state.addAttribute(getSymNameAttrName($_state.name), sym_name);
|
||||||
if (initializer)
|
if (initializer)
|
||||||
$_state.addAttribute(initializerAttrName($_state.name), initializer);
|
$_state.addAttribute(getInitializerAttrName($_state.name), initializer);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<(ins "TypeAttr":$type, "ArrayRef<NamedAttribute>":$namedAttrs),
|
OpBuilder<(ins "TypeAttr":$type, "ArrayRef<NamedAttribute>":$namedAttrs),
|
||||||
[{
|
[{
|
||||||
|
@ -412,9 +412,9 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
|
||||||
CArg<"FlatSymbolRefAttr", "{}">:$initializer),
|
CArg<"FlatSymbolRefAttr", "{}">:$initializer),
|
||||||
[{
|
[{
|
||||||
$_state.addAttribute("type", TypeAttr::get(type));
|
$_state.addAttribute("type", TypeAttr::get(type));
|
||||||
$_state.addAttribute(sym_nameAttrName($_state.name), $_builder.getStringAttr(sym_name));
|
$_state.addAttribute(getSymNameAttrName($_state.name), $_builder.getStringAttr(sym_name));
|
||||||
if (initializer)
|
if (initializer)
|
||||||
$_state.addAttribute(initializerAttrName($_state.name), initializer);
|
$_state.addAttribute(getInitializerAttrName($_state.name), initializer);
|
||||||
}]>
|
}]>
|
||||||
];
|
];
|
||||||
|
|
||||||
|
@ -424,7 +424,7 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
::mlir::spirv::StorageClass storageClass() {
|
::mlir::spirv::StorageClass storageClass() {
|
||||||
return this->type().cast<::mlir::spirv::PointerType>().getStorageClass();
|
return this->getType().cast<::mlir::spirv::PointerType>().getStorageClass();
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
@ -509,7 +509,7 @@ def SPV_ModuleOp : SPV_Op<"module",
|
||||||
|
|
||||||
bool isOptionalSymbol() { return true; }
|
bool isOptionalSymbol() { return true; }
|
||||||
|
|
||||||
Optional<StringRef> getName() { return sym_name(); }
|
Optional<StringRef> getName() { return getSymName(); }
|
||||||
|
|
||||||
static StringRef getVCETripleAttrName() { return "vce_triple"; }
|
static StringRef getVCETripleAttrName() { return "vce_triple"; }
|
||||||
}];
|
}];
|
||||||
|
|
|
@ -69,12 +69,12 @@ static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
|
||||||
builder.getIntegerAttr(targetType, targetBits / sourceBits);
|
builder.getIntegerAttr(targetType, targetBits / sourceBits);
|
||||||
auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
|
auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
|
||||||
auto lastDim = op->getOperand(op.getNumOperands() - 1);
|
auto lastDim = op->getOperand(op.getNumOperands() - 1);
|
||||||
auto indices = llvm::to_vector<4>(op.indices());
|
auto indices = llvm::to_vector<4>(op.getIndices());
|
||||||
// There are two elements if this is a 1-D tensor.
|
// There are two elements if this is a 1-D tensor.
|
||||||
assert(indices.size() == 2);
|
assert(indices.size() == 2);
|
||||||
indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
|
indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
|
||||||
Type t = typeConverter.convertType(op.component_ptr().getType());
|
Type t = typeConverter.convertType(op.getComponentPtr().getType());
|
||||||
return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
|
return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the shifted `targetBits`-bit value with the given offset.
|
/// Returns the shifted `targetBits`-bit value with the given offset.
|
||||||
|
@ -371,7 +371,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
|
||||||
// Assume that getElementPtr() works linearizely. If it's a scalar, the method
|
// Assume that getElementPtr() works linearizely. If it's a scalar, the method
|
||||||
// still returns a linearized accessing. If the accessing is not linearized,
|
// still returns a linearized accessing. If the accessing is not linearized,
|
||||||
// there will be offset issues.
|
// there will be offset issues.
|
||||||
assert(accessChainOp.indices().size() == 2);
|
assert(accessChainOp.getIndices().size() == 2);
|
||||||
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
|
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
|
||||||
srcBits, dstBits, rewriter);
|
srcBits, dstBits, rewriter);
|
||||||
Value spvLoadOp = rewriter.create<spirv::LoadOp>(
|
Value spvLoadOp = rewriter.create<spirv::LoadOp>(
|
||||||
|
@ -507,7 +507,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
|
||||||
// 6) store 32-bit value back
|
// 6) store 32-bit value back
|
||||||
// The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
|
// The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
|
||||||
// 4 to step 6 are done by AtomicOr as another atomic step.
|
// 4 to step 6 are done by AtomicOr as another atomic step.
|
||||||
assert(accessChainOp.indices().size() == 2);
|
assert(accessChainOp.getIndices().size() == 2);
|
||||||
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
|
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
|
||||||
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
|
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
|
||||||
|
|
||||||
|
|
|
@ -174,7 +174,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
|
||||||
// Create the block for the header.
|
// Create the block for the header.
|
||||||
auto *header = new Block();
|
auto *header = new Block();
|
||||||
// Insert the header.
|
// Insert the header.
|
||||||
loopOp.body().getBlocks().insert(getBlockIt(loopOp.body(), 1), header);
|
loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1), header);
|
||||||
|
|
||||||
// Create the new induction variable to use.
|
// Create the new induction variable to use.
|
||||||
Value adapLowerBound = adaptor.getLowerBound();
|
Value adapLowerBound = adaptor.getLowerBound();
|
||||||
|
@ -197,13 +197,13 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
|
||||||
|
|
||||||
// Move the blocks from the forOp into the loopOp. This is the body of the
|
// Move the blocks from the forOp into the loopOp. This is the body of the
|
||||||
// loopOp.
|
// loopOp.
|
||||||
rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
|
rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
|
||||||
getBlockIt(loopOp.body(), 2));
|
getBlockIt(loopOp.getBody(), 2));
|
||||||
|
|
||||||
SmallVector<Value, 8> args(1, adaptor.getLowerBound());
|
SmallVector<Value, 8> args(1, adaptor.getLowerBound());
|
||||||
args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
|
args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
|
||||||
// Branch into it from the entry.
|
// Branch into it from the entry.
|
||||||
rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
|
rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
|
||||||
rewriter.create<spirv::BranchOp>(loc, header, args);
|
rewriter.create<spirv::BranchOp>(loc, header, args);
|
||||||
|
|
||||||
// Generate the rest of the loop header.
|
// Generate the rest of the loop header.
|
||||||
|
@ -252,12 +252,12 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
|
||||||
auto selectionOp =
|
auto selectionOp =
|
||||||
rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
|
rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
|
||||||
auto *mergeBlock =
|
auto *mergeBlock =
|
||||||
rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
|
rewriter.createBlock(&selectionOp.getBody(), selectionOp.getBody().end());
|
||||||
rewriter.create<spirv::MergeOp>(loc);
|
rewriter.create<spirv::MergeOp>(loc);
|
||||||
|
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
auto *selectionHeaderBlock =
|
auto *selectionHeaderBlock =
|
||||||
rewriter.createBlock(&selectionOp.body().front());
|
rewriter.createBlock(&selectionOp.getBody().front());
|
||||||
|
|
||||||
// Inline `then` region before the merge block and branch to it.
|
// Inline `then` region before the merge block and branch to it.
|
||||||
auto &thenRegion = ifOp.getThenRegion();
|
auto &thenRegion = ifOp.getThenRegion();
|
||||||
|
@ -367,12 +367,12 @@ WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Move the while before block as the initial loop header block.
|
// Move the while before block as the initial loop header block.
|
||||||
rewriter.inlineRegionBefore(beforeRegion, loopOp.body(),
|
rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
|
||||||
getBlockIt(loopOp.body(), 1));
|
getBlockIt(loopOp.getBody(), 1));
|
||||||
|
|
||||||
// Move the while after block as the initial loop body block.
|
// Move the while after block as the initial loop body block.
|
||||||
rewriter.inlineRegionBefore(afterRegion, loopOp.body(),
|
rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
|
||||||
getBlockIt(loopOp.body(), 2));
|
getBlockIt(loopOp.getBody(), 2));
|
||||||
|
|
||||||
// Jump from the loop entry block to the loop header block.
|
// Jump from the loop entry block to the loop header block.
|
||||||
rewriter.setInsertionPointToEnd(&entryBlock);
|
rewriter.setInsertionPointToEnd(&entryBlock);
|
||||||
|
|
|
@ -89,7 +89,7 @@ createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
|
||||||
op->getAttrOfType<IntegerAttr>(descriptorSetName());
|
op->getAttrOfType<IntegerAttr>(descriptorSetName());
|
||||||
IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
|
IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
|
||||||
return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
|
return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
|
||||||
kernelModuleName.str(), op.sym_name().str(),
|
kernelModuleName.str(), op.getSymName().str(),
|
||||||
std::to_string(descriptorSet.getInt()),
|
std::to_string(descriptorSet.getInt()),
|
||||||
std::to_string(binding.getInt()));
|
std::to_string(binding.getInt()));
|
||||||
}
|
}
|
||||||
|
@ -126,14 +126,14 @@ static LogicalResult getKernelGlobalVariables(
|
||||||
/// Encodes the SPIR-V module's symbolic name into the name of the entry point
|
/// Encodes the SPIR-V module's symbolic name into the name of the entry point
|
||||||
/// function.
|
/// function.
|
||||||
static LogicalResult encodeKernelName(spirv::ModuleOp module) {
|
static LogicalResult encodeKernelName(spirv::ModuleOp module) {
|
||||||
StringRef spvModuleName = *module.sym_name();
|
StringRef spvModuleName = *module.getSymName();
|
||||||
// We already know that the module contains exactly one entry point function
|
// We already know that the module contains exactly one entry point function
|
||||||
// based on `getKernelGlobalVariables()` call. Update this function's name
|
// based on `getKernelGlobalVariables()` call. Update this function's name
|
||||||
// to:
|
// to:
|
||||||
// {spv_module_name}_{function_name}
|
// {spv_module_name}_{function_name}
|
||||||
auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
|
auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
|
||||||
StringRef funcName = entryPoint.fn();
|
StringRef funcName = entryPoint.getFn();
|
||||||
auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr());
|
auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
|
||||||
StringAttr newFuncName =
|
StringAttr newFuncName =
|
||||||
StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
|
StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
|
||||||
if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
|
if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
|
||||||
|
@ -236,7 +236,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
|
||||||
// LLVM dialect global variable.
|
// LLVM dialect global variable.
|
||||||
spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
|
spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
|
||||||
auto pointeeType =
|
auto pointeeType =
|
||||||
spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
|
spirvGlobal.getType().cast<spirv::PointerType>().getPointeeType();
|
||||||
auto dstGlobalType = typeConverter->convertType(pointeeType);
|
auto dstGlobalType = typeConverter->convertType(pointeeType);
|
||||||
if (!dstGlobalType)
|
if (!dstGlobalType)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
|
@ -228,14 +228,14 @@ static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
|
||||||
if (!dstType)
|
if (!dstType)
|
||||||
return failure();
|
return failure();
|
||||||
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
|
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
|
||||||
loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment,
|
loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
|
||||||
isVolatile, isNonTemporal);
|
isVolatile, isNonTemporal);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
auto storeOp = cast<spirv::StoreOp>(op);
|
auto storeOp = cast<spirv::StoreOp>(op);
|
||||||
spirv::StoreOpAdaptor adaptor(operands);
|
spirv::StoreOpAdaptor adaptor(operands);
|
||||||
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
|
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
|
||||||
adaptor.ptr(), alignment,
|
adaptor.getPtr(), alignment,
|
||||||
isVolatile, isNonTemporal);
|
isVolatile, isNonTemporal);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -305,19 +305,19 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
|
matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto dstType = typeConverter.convertType(op.component_ptr().getType());
|
auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
|
||||||
if (!dstType)
|
if (!dstType)
|
||||||
return failure();
|
return failure();
|
||||||
// To use GEP we need to add a first 0 index to go through the pointer.
|
// To use GEP we need to add a first 0 index to go through the pointer.
|
||||||
auto indices = llvm::to_vector<4>(adaptor.indices());
|
auto indices = llvm::to_vector<4>(adaptor.getIndices());
|
||||||
Type indexType = op.indices().front().getType();
|
Type indexType = op.getIndices().front().getType();
|
||||||
auto llvmIndexType = typeConverter.convertType(indexType);
|
auto llvmIndexType = typeConverter.convertType(indexType);
|
||||||
if (!llvmIndexType)
|
if (!llvmIndexType)
|
||||||
return failure();
|
return failure();
|
||||||
Value zero = rewriter.create<LLVM::ConstantOp>(
|
Value zero = rewriter.create<LLVM::ConstantOp>(
|
||||||
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
|
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
|
||||||
indices.insert(indices.begin(), zero);
|
indices.insert(indices.begin(), zero);
|
||||||
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.base_ptr(),
|
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.getBasePtr(),
|
||||||
indices);
|
indices);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -330,10 +330,10 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
|
matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto dstType = typeConverter.convertType(op.pointer().getType());
|
auto dstType = typeConverter.convertType(op.getPointer().getType());
|
||||||
if (!dstType)
|
if (!dstType)
|
||||||
return failure();
|
return failure();
|
||||||
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable());
|
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.getVariable());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -353,9 +353,9 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
|
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
|
||||||
Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
|
Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
|
||||||
typeConverter, rewriter);
|
typeConverter, rewriter);
|
||||||
Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
|
Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
|
||||||
typeConverter, rewriter);
|
typeConverter, rewriter);
|
||||||
|
|
||||||
// Create a mask with bits set outside [Offset, Offset + Count - 1].
|
// Create a mask with bits set outside [Offset, Offset + Count - 1].
|
||||||
|
@ -372,9 +372,9 @@ public:
|
||||||
// Extract unchanged bits from the `Base` that are outside of
|
// Extract unchanged bits from the `Base` that are outside of
|
||||||
// [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
|
// [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
|
||||||
Value baseAndMask =
|
Value baseAndMask =
|
||||||
rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
|
rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
|
||||||
Value insertShiftedByOffset =
|
Value insertShiftedByOffset =
|
||||||
rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
|
rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
|
||||||
rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
|
rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
|
||||||
insertShiftedByOffset);
|
insertShiftedByOffset);
|
||||||
return success();
|
return success();
|
||||||
|
@ -408,14 +408,14 @@ public:
|
||||||
auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
|
auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
|
||||||
|
|
||||||
if (srcType.isa<VectorType>()) {
|
if (srcType.isa<VectorType>()) {
|
||||||
auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
|
auto dstElementsAttr = constOp.getValue().cast<DenseIntElementsAttr>();
|
||||||
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
|
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
|
||||||
constOp, dstType,
|
constOp, dstType,
|
||||||
dstElementsAttr.mapValues(
|
dstElementsAttr.mapValues(
|
||||||
signlessType, [&](const APInt &value) { return value; }));
|
signlessType, [&](const APInt &value) { return value; }));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
auto srcAttr = constOp.value().cast<IntegerAttr>();
|
auto srcAttr = constOp.getValue().cast<IntegerAttr>();
|
||||||
auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
|
auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
|
||||||
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
|
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
|
||||||
return success();
|
return success();
|
||||||
|
@ -441,9 +441,9 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
|
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
|
||||||
Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
|
Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
|
||||||
typeConverter, rewriter);
|
typeConverter, rewriter);
|
||||||
Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
|
Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
|
||||||
typeConverter, rewriter);
|
typeConverter, rewriter);
|
||||||
|
|
||||||
// Create a constant that holds the size of the `Base`.
|
// Create a constant that holds the size of the `Base`.
|
||||||
|
@ -468,7 +468,7 @@ public:
|
||||||
Value amountToShiftLeft =
|
Value amountToShiftLeft =
|
||||||
rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
|
rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
|
||||||
Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
|
Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
|
||||||
loc, dstType, op.base(), amountToShiftLeft);
|
loc, dstType, op.getBase(), amountToShiftLeft);
|
||||||
|
|
||||||
// Shift the result right, filling the bits with the sign bit.
|
// Shift the result right, filling the bits with the sign bit.
|
||||||
Value amountToShiftRight =
|
Value amountToShiftRight =
|
||||||
|
@ -494,9 +494,9 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
|
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
|
||||||
Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
|
Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
|
||||||
typeConverter, rewriter);
|
typeConverter, rewriter);
|
||||||
Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
|
Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
|
||||||
typeConverter, rewriter);
|
typeConverter, rewriter);
|
||||||
|
|
||||||
// Create a mask with bits set at [0, Count - 1].
|
// Create a mask with bits set at [0, Count - 1].
|
||||||
|
@ -508,7 +508,7 @@ public:
|
||||||
|
|
||||||
// Shift `Base` by `Offset` and apply the mask on it.
|
// Shift `Base` by `Offset` and apply the mask on it.
|
||||||
Value shiftedBase =
|
Value shiftedBase =
|
||||||
rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
|
rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
|
||||||
rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
|
rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -538,20 +538,20 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
// If branch weights exist, map them to 32-bit integer vector.
|
// If branch weights exist, map them to 32-bit integer vector.
|
||||||
ElementsAttr branchWeights = nullptr;
|
ElementsAttr branchWeights = nullptr;
|
||||||
if (auto weights = op.branch_weights()) {
|
if (auto weights = op.getBranchWeights()) {
|
||||||
VectorType weightType = VectorType::get(2, rewriter.getI32Type());
|
VectorType weightType = VectorType::get(2, rewriter.getI32Type());
|
||||||
branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
|
branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
|
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
|
||||||
op, op.condition(), op.getTrueBlockArguments(),
|
op, op.getCondition(), op.getTrueBlockArguments(),
|
||||||
op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
|
op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
|
||||||
op.getFalseBlock());
|
op.getFalseBlock());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
|
/// Converts `spv.getCompositeExtract` to `llvm.extractvalue` if the container type
|
||||||
/// is an aggregate type (struct or array). Otherwise, converts to
|
/// is an aggregate type (struct or array). Otherwise, converts to
|
||||||
/// `llvm.extractelement` that operates on vectors.
|
/// `llvm.extractelement` that operates on vectors.
|
||||||
class CompositeExtractPattern
|
class CompositeExtractPattern
|
||||||
|
@ -566,23 +566,23 @@ public:
|
||||||
if (!dstType)
|
if (!dstType)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Type containerType = op.composite().getType();
|
Type containerType = op.getComposite().getType();
|
||||||
if (containerType.isa<VectorType>()) {
|
if (containerType.isa<VectorType>()) {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
|
IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
|
||||||
Value index = createI32ConstantOf(loc, rewriter, value.getInt());
|
Value index = createI32ConstantOf(loc, rewriter, value.getInt());
|
||||||
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
|
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
|
||||||
op, dstType, adaptor.composite(), index);
|
op, dstType, adaptor.getComposite(), index);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
|
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
|
||||||
op, adaptor.composite(), LLVM::convertArrayToIndices(op.indices()));
|
op, adaptor.getComposite(), LLVM::convertArrayToIndices(op.getIndices()));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
|
/// Converts `spv.getCompositeInsert` to `llvm.insertvalue` if the container type
|
||||||
/// is an aggregate type (struct or array). Otherwise, converts to
|
/// is an aggregate type (struct or array). Otherwise, converts to
|
||||||
/// `llvm.insertelement` that operates on vectors.
|
/// `llvm.insertelement` that operates on vectors.
|
||||||
class CompositeInsertPattern
|
class CompositeInsertPattern
|
||||||
|
@ -597,19 +597,19 @@ public:
|
||||||
if (!dstType)
|
if (!dstType)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Type containerType = op.composite().getType();
|
Type containerType = op.getComposite().getType();
|
||||||
if (containerType.isa<VectorType>()) {
|
if (containerType.isa<VectorType>()) {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
|
IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
|
||||||
Value index = createI32ConstantOf(loc, rewriter, value.getInt());
|
Value index = createI32ConstantOf(loc, rewriter, value.getInt());
|
||||||
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
|
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
|
||||||
op, dstType, adaptor.composite(), adaptor.object(), index);
|
op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
|
rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
|
||||||
op, adaptor.composite(), adaptor.object(),
|
op, adaptor.getComposite(), adaptor.getObject(),
|
||||||
LLVM::convertArrayToIndices(op.indices()));
|
LLVM::convertArrayToIndices(op.getIndices()));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -647,14 +647,14 @@ public:
|
||||||
// this entry point's execution mode. We set it to be:
|
// this entry point's execution mode. We set it to be:
|
||||||
// __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
|
// __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
|
||||||
ModuleOp module = op->getParentOfType<ModuleOp>();
|
ModuleOp module = op->getParentOfType<ModuleOp>();
|
||||||
spirv::ExecutionModeAttr executionModeAttr = op.execution_modeAttr();
|
spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
|
||||||
std::string moduleName;
|
std::string moduleName;
|
||||||
if (module.getName().has_value())
|
if (module.getName().has_value())
|
||||||
moduleName = "_" + module.getName().value().str();
|
moduleName = "_" + module.getName()->str();
|
||||||
else
|
else
|
||||||
moduleName = "";
|
moduleName = "";
|
||||||
std::string executionModeInfoName = llvm::formatv(
|
std::string executionModeInfoName = llvm::formatv(
|
||||||
"__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.fn().str(),
|
"__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
|
||||||
static_cast<uint32_t>(executionModeAttr.getValue()));
|
static_cast<uint32_t>(executionModeAttr.getValue()));
|
||||||
|
|
||||||
MLIRContext *context = rewriter.getContext();
|
MLIRContext *context = rewriter.getContext();
|
||||||
|
@ -669,7 +669,7 @@ public:
|
||||||
auto llvmI32Type = IntegerType::get(context, 32);
|
auto llvmI32Type = IntegerType::get(context, 32);
|
||||||
SmallVector<Type, 2> fields;
|
SmallVector<Type, 2> fields;
|
||||||
fields.push_back(llvmI32Type);
|
fields.push_back(llvmI32Type);
|
||||||
ArrayAttr values = op.values();
|
ArrayAttr values = op.getValues();
|
||||||
if (!values.empty()) {
|
if (!values.empty()) {
|
||||||
auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
|
auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
|
||||||
fields.push_back(arrayType);
|
fields.push_back(arrayType);
|
||||||
|
@ -722,10 +722,10 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
// Currently, there is no support of initialization with a constant value in
|
// Currently, there is no support of initialization with a constant value in
|
||||||
// SPIR-V dialect. Specialization constants are not considered as well.
|
// SPIR-V dialect. Specialization constants are not considered as well.
|
||||||
if (op.initializer())
|
if (op.getInitializer())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto srcType = op.type().cast<spirv::PointerType>();
|
auto srcType = op.getType().cast<spirv::PointerType>();
|
||||||
auto dstType = typeConverter.convertType(srcType.getPointeeType());
|
auto dstType = typeConverter.convertType(srcType.getPointeeType());
|
||||||
if (!dstType)
|
if (!dstType)
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -759,12 +759,12 @@ public:
|
||||||
? LLVM::Linkage::Private
|
? LLVM::Linkage::Private
|
||||||
: LLVM::Linkage::External;
|
: LLVM::Linkage::External;
|
||||||
auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
|
auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
|
||||||
op, dstType, isConstant, linkage, op.sym_name(), Attribute(),
|
op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
|
||||||
/*alignment=*/0);
|
/*alignment=*/0);
|
||||||
|
|
||||||
// Attach location attribute if applicable
|
// Attach location attribute if applicable
|
||||||
if (op.locationAttr())
|
if (op.getLocationAttr())
|
||||||
newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr());
|
newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -781,7 +781,7 @@ public:
|
||||||
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
|
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
Type fromType = operation.operand().getType();
|
Type fromType = operation.getOperand().getType();
|
||||||
Type toType = operation.getType();
|
Type toType = operation.getType();
|
||||||
|
|
||||||
auto dstType = this->typeConverter.convertType(toType);
|
auto dstType = this->typeConverter.convertType(toType);
|
||||||
|
@ -839,8 +839,8 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
|
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
|
||||||
operation, dstType, predicate, operation.operand1(),
|
operation, dstType, predicate, operation.getOperand1(),
|
||||||
operation.operand2());
|
operation.getOperand2());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -860,8 +860,8 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
|
rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
|
||||||
operation, dstType, predicate, operation.operand1(),
|
operation, dstType, predicate, operation.getOperand1(),
|
||||||
operation.operand2());
|
operation.getOperand2());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -881,7 +881,7 @@ public:
|
||||||
|
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
|
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
|
||||||
Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
|
Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
|
||||||
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
|
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -896,20 +896,20 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
|
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
if (!op.memory_access()) {
|
if (!op.getMemoryAccess()) {
|
||||||
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
|
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
|
||||||
this->typeConverter, /*alignment=*/0,
|
this->typeConverter, /*alignment=*/0,
|
||||||
/*isVolatile=*/false,
|
/*isVolatile=*/false,
|
||||||
/*isNonTemporal=*/false);
|
/*isNonTemporal=*/false);
|
||||||
}
|
}
|
||||||
auto memoryAccess = *op.memory_access();
|
auto memoryAccess = *op.getMemoryAccess();
|
||||||
switch (memoryAccess) {
|
switch (memoryAccess) {
|
||||||
case spirv::MemoryAccess::Aligned:
|
case spirv::MemoryAccess::Aligned:
|
||||||
case spirv::MemoryAccess::None:
|
case spirv::MemoryAccess::None:
|
||||||
case spirv::MemoryAccess::Nontemporal:
|
case spirv::MemoryAccess::Nontemporal:
|
||||||
case spirv::MemoryAccess::Volatile: {
|
case spirv::MemoryAccess::Volatile: {
|
||||||
unsigned alignment =
|
unsigned alignment =
|
||||||
memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
|
memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
|
||||||
bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
|
bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
|
||||||
bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
|
bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
|
||||||
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
|
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
|
||||||
|
@ -946,7 +946,7 @@ public:
|
||||||
srcType.template cast<VectorType>(), minusOne))
|
srcType.template cast<VectorType>(), minusOne))
|
||||||
: rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
|
: rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
|
||||||
rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
|
rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
|
||||||
notOp.operand(), mask);
|
notOp.getOperand(), mask);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -1047,7 +1047,7 @@ public:
|
||||||
matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
|
matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
// There is no support of loop control at the moment.
|
// There is no support of loop control at the moment.
|
||||||
if (loopOp.loop_control() != spirv::LoopControl::None)
|
if (loopOp.getLoopControl() != spirv::LoopControl::None)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Location loc = loopOp.getLoc();
|
Location loc = loopOp.getLoc();
|
||||||
|
@ -1077,7 +1077,7 @@ public:
|
||||||
rewriter.setInsertionPointToEnd(mergeBlock);
|
rewriter.setInsertionPointToEnd(mergeBlock);
|
||||||
rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
|
rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
|
||||||
|
|
||||||
rewriter.inlineRegionBefore(loopOp.body(), endBlock);
|
rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
|
||||||
rewriter.replaceOp(loopOp, endBlock->getArguments());
|
rewriter.replaceOp(loopOp, endBlock->getArguments());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1096,14 +1096,14 @@ public:
|
||||||
// There is no support for `Flatten` or `DontFlatten` selection control at
|
// There is no support for `Flatten` or `DontFlatten` selection control at
|
||||||
// the moment. This are just compiler hints and can be performed during the
|
// the moment. This are just compiler hints and can be performed during the
|
||||||
// optimization passes.
|
// optimization passes.
|
||||||
if (op.selection_control() != spirv::SelectionControl::None)
|
if (op.getSelectionControl() != spirv::SelectionControl::None)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// `spv.mlir.selection` should have at least two blocks: one selection
|
// `spv.mlir.selection` should have at least two blocks: one selection
|
||||||
// header block and one merge block. If no blocks are present, or control
|
// header block and one merge block. If no blocks are present, or control
|
||||||
// flow branches straight to merge block (two blocks are present), the op is
|
// flow branches straight to merge block (two blocks are present), the op is
|
||||||
// redundant and it is erased.
|
// redundant and it is erased.
|
||||||
if (op.body().getBlocks().size() <= 2) {
|
if (op.getBody().getBlocks().size() <= 2) {
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1140,11 +1140,11 @@ public:
|
||||||
Block *trueBlock = condBrOp.getTrueBlock();
|
Block *trueBlock = condBrOp.getTrueBlock();
|
||||||
Block *falseBlock = condBrOp.getFalseBlock();
|
Block *falseBlock = condBrOp.getFalseBlock();
|
||||||
rewriter.setInsertionPointToEnd(currentBlock);
|
rewriter.setInsertionPointToEnd(currentBlock);
|
||||||
rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
|
rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
|
||||||
condBrOp.trueTargetOperands(), falseBlock,
|
condBrOp.getTrueTargetOperands(), falseBlock,
|
||||||
condBrOp.falseTargetOperands());
|
condBrOp.getFalseTargetOperands());
|
||||||
|
|
||||||
rewriter.inlineRegionBefore(op.body(), continueBlock);
|
rewriter.inlineRegionBefore(op.getBody(), continueBlock);
|
||||||
rewriter.replaceOp(op, continueBlock->getArguments());
|
rewriter.replaceOp(op, continueBlock->getArguments());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1167,8 +1167,8 @@ public:
|
||||||
if (!dstType)
|
if (!dstType)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Type op1Type = operation.operand1().getType();
|
Type op1Type = operation.getOperand1().getType();
|
||||||
Type op2Type = operation.operand2().getType();
|
Type op2Type = operation.getOperand2().getType();
|
||||||
|
|
||||||
if (op1Type == op2Type) {
|
if (op1Type == op2Type) {
|
||||||
rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
|
rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
|
||||||
|
@ -1180,13 +1180,13 @@ public:
|
||||||
Value extended;
|
Value extended;
|
||||||
if (isUnsignedIntegerOrVector(op2Type)) {
|
if (isUnsignedIntegerOrVector(op2Type)) {
|
||||||
extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
|
extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
|
||||||
adaptor.operand2());
|
adaptor.getOperand2());
|
||||||
} else {
|
} else {
|
||||||
extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
|
extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
|
||||||
adaptor.operand2());
|
adaptor.getOperand2());
|
||||||
}
|
}
|
||||||
Value result = rewriter.template create<LLVMOp>(
|
Value result = rewriter.template create<LLVMOp>(
|
||||||
loc, dstType, adaptor.operand1(), extended);
|
loc, dstType, adaptor.getOperand1(), extended);
|
||||||
rewriter.replaceOp(operation, result);
|
rewriter.replaceOp(operation, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1204,8 +1204,8 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Location loc = tanOp.getLoc();
|
Location loc = tanOp.getLoc();
|
||||||
Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
|
Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
|
||||||
Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
|
Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
|
||||||
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
|
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1232,7 +1232,7 @@ public:
|
||||||
Location loc = tanhOp.getLoc();
|
Location loc = tanhOp.getLoc();
|
||||||
Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
|
Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
|
||||||
Value multiplied =
|
Value multiplied =
|
||||||
rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
|
rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
|
||||||
Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
|
Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
|
||||||
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
|
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
|
||||||
Value numerator =
|
Value numerator =
|
||||||
|
@ -1255,7 +1255,7 @@ public:
|
||||||
auto srcType = varOp.getType();
|
auto srcType = varOp.getType();
|
||||||
// Initialization is supported for scalars and vectors only.
|
// Initialization is supported for scalars and vectors only.
|
||||||
auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
|
auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
|
||||||
auto init = varOp.initializer();
|
auto init = varOp.getInitializer();
|
||||||
if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
|
if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -1270,7 +1270,7 @@ public:
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
|
Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
|
||||||
rewriter.create<LLVM::StoreOp>(loc, adaptor.initializer(), allocated);
|
rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
|
||||||
rewriter.replaceOp(varOp, allocated);
|
rewriter.replaceOp(varOp, allocated);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1305,7 +1305,7 @@ public:
|
||||||
|
|
||||||
// Convert SPIR-V Function Control to equivalent LLVM function attribute
|
// Convert SPIR-V Function Control to equivalent LLVM function attribute
|
||||||
MLIRContext *context = funcOp.getContext();
|
MLIRContext *context = funcOp.getContext();
|
||||||
switch (funcOp.function_control()) {
|
switch (funcOp.getFunctionControl()) {
|
||||||
#define DISPATCH(functionControl, llvmAttr) \
|
#define DISPATCH(functionControl, llvmAttr) \
|
||||||
case functionControl: \
|
case functionControl: \
|
||||||
newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
|
newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
|
||||||
|
@ -1374,9 +1374,9 @@ public:
|
||||||
matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
|
matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto components = adaptor.components();
|
auto components = adaptor.getComponents();
|
||||||
auto vector1 = adaptor.vector1();
|
auto vector1 = adaptor.getVector1();
|
||||||
auto vector2 = adaptor.vector2();
|
auto vector2 = adaptor.getVector2();
|
||||||
int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
|
int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
|
||||||
int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
|
int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
|
||||||
if (vector1Size == vector2Size) {
|
if (vector1Size == vector2Size) {
|
||||||
|
@ -1589,8 +1589,8 @@ void mlir::encodeBindAttribute(ModuleOp module) {
|
||||||
// SPIR-V module has a name, add it at the beginning.
|
// SPIR-V module has a name, add it at the beginning.
|
||||||
auto moduleAndName =
|
auto moduleAndName =
|
||||||
spvModule.getName().has_value()
|
spvModule.getName().has_value()
|
||||||
? spvModule.getName().value().str() + "_" + op.sym_name().str()
|
? spvModule.getName()->str() + "_" + op.getSymName().str()
|
||||||
: op.sym_name().str();
|
: op.getSymName().str();
|
||||||
std::string name =
|
std::string name =
|
||||||
llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
|
llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
|
||||||
std::to_string(descriptorSet.getInt()),
|
std::to_string(descriptorSet.getInt()),
|
||||||
|
|
|
@ -88,19 +88,19 @@ struct CombineChainedAccessChain
|
||||||
LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
|
LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
|
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
|
||||||
accessChainOp.base_ptr().getDefiningOp());
|
accessChainOp.getBasePtr().getDefiningOp());
|
||||||
|
|
||||||
if (!parentAccessChainOp) {
|
if (!parentAccessChainOp) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Combine indices.
|
// Combine indices.
|
||||||
SmallVector<Value, 4> indices(parentAccessChainOp.indices());
|
SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
|
||||||
indices.append(accessChainOp.indices().begin(),
|
indices.append(accessChainOp.getIndices().begin(),
|
||||||
accessChainOp.indices().end());
|
accessChainOp.getIndices().end());
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||||
accessChainOp, parentAccessChainOp.base_ptr(), indices);
|
accessChainOp, parentAccessChainOp.getBasePtr(), indices);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -126,23 +126,24 @@ void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (auto insertOp = composite().getDefiningOp<spirv::CompositeInsertOp>()) {
|
if (auto insertOp =
|
||||||
if (indices() == insertOp.indices())
|
getComposite().getDefiningOp<spirv::CompositeInsertOp>()) {
|
||||||
return insertOp.object();
|
if (getIndices() == insertOp.getIndices())
|
||||||
|
return insertOp.getObject();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto constructOp =
|
if (auto constructOp =
|
||||||
composite().getDefiningOp<spirv::CompositeConstructOp>()) {
|
getComposite().getDefiningOp<spirv::CompositeConstructOp>()) {
|
||||||
auto type = constructOp.getType().cast<spirv::CompositeType>();
|
auto type = constructOp.getType().cast<spirv::CompositeType>();
|
||||||
if (indices().size() == 1 &&
|
if (getIndices().size() == 1 &&
|
||||||
constructOp.constituents().size() == type.getNumElements()) {
|
constructOp.getConstituents().size() == type.getNumElements()) {
|
||||||
auto i = indices().begin()->cast<IntegerAttr>();
|
auto i = getIndices().begin()->cast<IntegerAttr>();
|
||||||
return constructOp.constituents()[i.getValue().getSExtValue()];
|
return constructOp.getConstituents()[i.getValue().getSExtValue()];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto indexVector =
|
auto indexVector =
|
||||||
llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
|
llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) {
|
||||||
return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
|
return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
|
||||||
}));
|
}));
|
||||||
return extractCompositeElement(operands[0], indexVector);
|
return extractCompositeElement(operands[0], indexVector);
|
||||||
|
@ -154,7 +155,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
|
||||||
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||||
assert(operands.empty() && "spv.Constant has no operands");
|
assert(operands.empty() && "spv.Constant has no operands");
|
||||||
return value();
|
return getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -164,8 +165,8 @@ OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||||
OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
|
||||||
assert(operands.size() == 2 && "spv.IAdd expects two operands");
|
assert(operands.size() == 2 && "spv.IAdd expects two operands");
|
||||||
// x + 0 = x
|
// x + 0 = x
|
||||||
if (matchPattern(operand2(), m_Zero()))
|
if (matchPattern(getOperand2(), m_Zero()))
|
||||||
return operand1();
|
return getOperand1();
|
||||||
|
|
||||||
// According to the SPIR-V spec:
|
// According to the SPIR-V spec:
|
||||||
//
|
//
|
||||||
|
@ -183,11 +184,11 @@ OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
|
||||||
OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
|
||||||
assert(operands.size() == 2 && "spv.IMul expects two operands");
|
assert(operands.size() == 2 && "spv.IMul expects two operands");
|
||||||
// x * 0 == 0
|
// x * 0 == 0
|
||||||
if (matchPattern(operand2(), m_Zero()))
|
if (matchPattern(getOperand2(), m_Zero()))
|
||||||
return operand2();
|
return getOperand2();
|
||||||
// x * 1 = x
|
// x * 1 = x
|
||||||
if (matchPattern(operand2(), m_One()))
|
if (matchPattern(getOperand2(), m_One()))
|
||||||
return operand1();
|
return getOperand1();
|
||||||
|
|
||||||
// According to the SPIR-V spec:
|
// According to the SPIR-V spec:
|
||||||
//
|
//
|
||||||
|
@ -204,7 +205,7 @@ OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
|
||||||
OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
|
||||||
// x - x = 0
|
// x - x = 0
|
||||||
if (operand1() == operand2())
|
if (getOperand1() == getOperand2())
|
||||||
return Builder(getContext()).getIntegerAttr(getType(), 0);
|
return Builder(getContext()).getIntegerAttr(getType(), 0);
|
||||||
|
|
||||||
// According to the SPIR-V spec:
|
// According to the SPIR-V spec:
|
||||||
|
@ -226,7 +227,7 @@ OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
|
if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
|
||||||
// x && true = x
|
// x && true = x
|
||||||
if (rhs.value())
|
if (rhs.value())
|
||||||
return operand1();
|
return getOperand1();
|
||||||
|
|
||||||
// x && false = false
|
// x && false = false
|
||||||
if (!rhs.value())
|
if (!rhs.value())
|
||||||
|
@ -262,7 +263,7 @@ OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
|
||||||
// x || false = x
|
// x || false = x
|
||||||
if (!rhs.value())
|
if (!rhs.value())
|
||||||
return operand1();
|
return getOperand1();
|
||||||
}
|
}
|
||||||
|
|
||||||
return Attribute();
|
return Attribute();
|
||||||
|
@ -339,8 +340,8 @@ struct ConvertSelectionOpToSelect
|
||||||
cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
|
cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
|
||||||
|
|
||||||
auto selectOp = rewriter.create<spirv::SelectOp>(
|
auto selectOp = rewriter.create<spirv::SelectOp>(
|
||||||
selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
|
selectionOp.getLoc(), trueValue.getType(),
|
||||||
trueValue, falseValue);
|
brConditionalOp.getCondition(), trueValue, falseValue);
|
||||||
rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
|
rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
|
||||||
selectOp.getResult(), storeOpAttributes);
|
selectOp.getResult(), storeOpAttributes);
|
||||||
|
|
||||||
|
@ -371,13 +372,13 @@ private:
|
||||||
// Returns a source value for the given block.
|
// Returns a source value for the given block.
|
||||||
Value getSrcValue(Block *block) const {
|
Value getSrcValue(Block *block) const {
|
||||||
auto storeOp = cast<spirv::StoreOp>(block->front());
|
auto storeOp = cast<spirv::StoreOp>(block->front());
|
||||||
return storeOp.value();
|
return storeOp.getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a destination value for the given block.
|
// Returns a destination value for the given block.
|
||||||
Value getDstPtr(Block *block) const {
|
Value getDstPtr(Block *block) const {
|
||||||
auto storeOp = cast<spirv::StoreOp>(block->front());
|
auto storeOp = cast<spirv::StoreOp>(block->front());
|
||||||
return storeOp.ptr();
|
return storeOp.getPtr();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -406,14 +407,14 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
|
||||||
// "Before version 1.4, Result Type must be a pointer, scalar, or vector.
|
// "Before version 1.4, Result Type must be a pointer, scalar, or vector.
|
||||||
// Starting with version 1.4, Result Type can additionally be a composite type
|
// Starting with version 1.4, Result Type can additionally be a composite type
|
||||||
// other than a vector."
|
// other than a vector."
|
||||||
bool isScalarOrVector = trueBrStoreOp.value()
|
bool isScalarOrVector = trueBrStoreOp.getValue()
|
||||||
.getType()
|
.getType()
|
||||||
.cast<spirv::SPIRVType>()
|
.cast<spirv::SPIRVType>()
|
||||||
.isScalarOrVector();
|
.isScalarOrVector();
|
||||||
|
|
||||||
// Check that each `spv.Store` uses the same pointer, memory access
|
// Check that each `spv.Store` uses the same pointer, memory access
|
||||||
// attributes and a valid type of the value.
|
// attributes and a valid type of the value.
|
||||||
if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
|
if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
|
||||||
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
|
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
|
@ -106,7 +106,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
|
||||||
// Replace the values directly with the return operands.
|
// Replace the values directly with the return operands.
|
||||||
assert(valuesToRepl.size() == 1 &&
|
assert(valuesToRepl.size() == 1 &&
|
||||||
"spv.ReturnValue expected to only handle one result");
|
"spv.ReturnValue expected to only handle one result");
|
||||||
valuesToRepl.front().replaceAllUsesWith(retValOp.value());
|
valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -94,16 +94,16 @@ OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
spirv::ModuleOp firstModule = inputModules.front();
|
spirv::ModuleOp firstModule = inputModules.front();
|
||||||
auto addressingModel = firstModule.addressing_model();
|
auto addressingModel = firstModule.getAddressingModel();
|
||||||
auto memoryModel = firstModule.memory_model();
|
auto memoryModel = firstModule.getMemoryModel();
|
||||||
auto vceTriple = firstModule.vce_triple();
|
auto vceTriple = firstModule.getVceTriple();
|
||||||
|
|
||||||
// First check whether there are conflicts between addressing/memory model.
|
// First check whether there are conflicts between addressing/memory model.
|
||||||
// Return early if so.
|
// Return early if so.
|
||||||
for (auto module : inputModules) {
|
for (auto module : inputModules) {
|
||||||
if (module.addressing_model() != addressingModel ||
|
if (module.getAddressingModel() != addressingModel ||
|
||||||
module.memory_model() != memoryModel ||
|
module.getMemoryModel() != memoryModel ||
|
||||||
module.vce_triple() != vceTriple) {
|
module.getVceTriple() != vceTriple) {
|
||||||
module.emitError("input modules differ in addressing model, memory "
|
module.emitError("input modules differ in addressing model, memory "
|
||||||
"model, and/or VCE triple");
|
"model, and/or VCE triple");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -40,7 +40,7 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
SmallVector<NamedAttribute, 4> globalVarAttrs;
|
SmallVector<NamedAttribute, 4> globalVarAttrs;
|
||||||
|
|
||||||
auto ptrType = op.type().cast<spirv::PointerType>();
|
auto ptrType = op.getType().cast<spirv::PointerType>();
|
||||||
auto structType = VulkanLayoutUtils::decorateType(
|
auto structType = VulkanLayoutUtils::decorateType(
|
||||||
ptrType.getPointeeType().cast<spirv::StructType>());
|
ptrType.getPointeeType().cast<spirv::StructType>());
|
||||||
|
|
||||||
|
@ -71,11 +71,11 @@ public:
|
||||||
LogicalResult matchAndRewrite(spirv::AddressOfOp op,
|
LogicalResult matchAndRewrite(spirv::AddressOfOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto spirvModule = op->getParentOfType<spirv::ModuleOp>();
|
auto spirvModule = op->getParentOfType<spirv::ModuleOp>();
|
||||||
auto varName = op.variableAttr();
|
auto varName = op.getVariableAttr();
|
||||||
auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
|
auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
|
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
|
||||||
op, varOp.type(), SymbolRefAttr::get(varName.getAttr()));
|
op, varOp.getType(), SymbolRefAttr::get(varName.getAttr()));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -121,12 +121,12 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
|
||||||
target.addLegalOp<func::FuncOp>();
|
target.addLegalOp<func::FuncOp>();
|
||||||
target.addDynamicallyLegalOp<spirv::GlobalVariableOp>(
|
target.addDynamicallyLegalOp<spirv::GlobalVariableOp>(
|
||||||
[](spirv::GlobalVariableOp op) {
|
[](spirv::GlobalVariableOp op) {
|
||||||
return VulkanLayoutUtils::isLegalType(op.type());
|
return VulkanLayoutUtils::isLegalType(op.getType());
|
||||||
});
|
});
|
||||||
|
|
||||||
// Change the type for the direct users.
|
// Change the type for the direct users.
|
||||||
target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
|
target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
|
||||||
return VulkanLayoutUtils::isLegalType(op.pointer().getType());
|
return VulkanLayoutUtils::isLegalType(op.getPointer().getType());
|
||||||
});
|
});
|
||||||
|
|
||||||
// Change the type for the indirect users.
|
// Change the type for the indirect users.
|
||||||
|
@ -134,7 +134,8 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
|
||||||
spirv::StoreOp>([&](Operation *op) {
|
spirv::StoreOp>([&](Operation *op) {
|
||||||
for (Value operand : op->getOperands()) {
|
for (Value operand : op->getOperands()) {
|
||||||
auto addrOp = operand.getDefiningOp<spirv::AddressOfOp>();
|
auto addrOp = operand.getDefiningOp<spirv::AddressOfOp>();
|
||||||
if (addrOp && !VulkanLayoutUtils::isLegalType(addrOp.pointer().getType()))
|
if (addrOp &&
|
||||||
|
!VulkanLayoutUtils::isLegalType(addrOp.getPointer().getType()))
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -88,13 +88,13 @@ getInterfaceVariables(spirv::FuncOp funcOp,
|
||||||
// instructions in this function.
|
// instructions in this function.
|
||||||
funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
|
funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
|
||||||
auto var =
|
auto var =
|
||||||
module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
|
module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
|
||||||
// TODO: Per SPIR-V spec: "Before version 1.4, the interface’s
|
// TODO: Per SPIR-V spec: "Before version 1.4, the interface’s
|
||||||
// storage classes are limited to the Input and Output storage classes.
|
// storage classes are limited to the Input and Output storage classes.
|
||||||
// Starting with version 1.4, the interface’s storage classes are all
|
// Starting with version 1.4, the interface’s storage classes are all
|
||||||
// storage classes used in declaring all global variables referenced by the
|
// storage classes used in declaring all global variables referenced by the
|
||||||
// entry point’s call tree." We should consider the target environment here.
|
// entry point’s call tree." We should consider the target environment here.
|
||||||
switch (var.type().cast<spirv::PointerType>().getStorageClass()) {
|
switch (var.getType().cast<spirv::PointerType>().getStorageClass()) {
|
||||||
case spirv::StorageClass::Input:
|
case spirv::StorageClass::Input:
|
||||||
case spirv::StorageClass::Output:
|
case spirv::StorageClass::Output:
|
||||||
interfaceVarSet.insert(var.getOperation());
|
interfaceVarSet.insert(var.getOperation());
|
||||||
|
@ -105,7 +105,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
|
||||||
});
|
});
|
||||||
for (auto &var : interfaceVarSet) {
|
for (auto &var : interfaceVarSet) {
|
||||||
interfaceVars.push_back(SymbolRefAttr::get(
|
interfaceVars.push_back(SymbolRefAttr::get(
|
||||||
funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name()));
|
funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -223,7 +223,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
|
||||||
auto zero =
|
auto zero =
|
||||||
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
|
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
|
||||||
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
|
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
|
||||||
funcOp.getLoc(), replacement, zero.constant());
|
funcOp.getLoc(), replacement, zero.getConstant());
|
||||||
replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
|
replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
|
||||||
}
|
}
|
||||||
signatureConverter.remapInput(argType.index(), replacement);
|
signatureConverter.remapInput(argType.index(), replacement);
|
||||||
|
|
|
@ -63,7 +63,7 @@ void RewriteInsertsPass::runOnOperation() {
|
||||||
SmallVector<Value, 4> operands;
|
SmallVector<Value, 4> operands;
|
||||||
// Collect inserted objects.
|
// Collect inserted objects.
|
||||||
for (auto insertionOp : insertions)
|
for (auto insertionOp : insertions)
|
||||||
operands.push_back(insertionOp.object());
|
operands.push_back(insertionOp.getObject());
|
||||||
|
|
||||||
OpBuilder builder(lastCompositeInsertOp);
|
OpBuilder builder(lastCompositeInsertOp);
|
||||||
auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
|
auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
|
||||||
|
@ -84,11 +84,13 @@ void RewriteInsertsPass::runOnOperation() {
|
||||||
LogicalResult RewriteInsertsPass::collectInsertionChain(
|
LogicalResult RewriteInsertsPass::collectInsertionChain(
|
||||||
spirv::CompositeInsertOp op,
|
spirv::CompositeInsertOp op,
|
||||||
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
|
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
|
||||||
auto indicesArrayAttr = op.indices().cast<ArrayAttr>();
|
auto indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
|
||||||
// TODO: handle nested composite object.
|
// TODO: handle nested composite object.
|
||||||
if (indicesArrayAttr.size() == 1) {
|
if (indicesArrayAttr.size() == 1) {
|
||||||
auto numElements =
|
auto numElements = op.getComposite()
|
||||||
op.composite().getType().cast<spirv::CompositeType>().getNumElements();
|
.getType()
|
||||||
|
.cast<spirv::CompositeType>()
|
||||||
|
.getNumElements();
|
||||||
|
|
||||||
auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
|
auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
|
||||||
// Need a last index to collect a sequential chain.
|
// Need a last index to collect a sequential chain.
|
||||||
|
@ -102,12 +104,12 @@ LogicalResult RewriteInsertsPass::collectInsertionChain(
|
||||||
if (index == 0)
|
if (index == 0)
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
op = op.composite().getDefiningOp<spirv::CompositeInsertOp>();
|
op = op.getComposite().getDefiningOp<spirv::CompositeInsertOp>();
|
||||||
if (!op)
|
if (!op)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
--index;
|
--index;
|
||||||
indicesArrayAttr = op.indices().cast<ArrayAttr>();
|
indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
|
||||||
if ((indicesArrayAttr.size() != 1) ||
|
if ((indicesArrayAttr.size() != 1) ||
|
||||||
(indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
|
(indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
|
@ -642,7 +642,7 @@ static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
|
||||||
static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
|
static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
|
||||||
unsigned elementCount) {
|
unsigned elementCount) {
|
||||||
for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
|
for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
|
||||||
auto ptrType = varOp.type().dyn_cast<spirv::PointerType>();
|
auto ptrType = varOp.getType().dyn_cast<spirv::PointerType>();
|
||||||
if (!ptrType)
|
if (!ptrType)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
@ -874,7 +874,7 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
|
||||||
// Special treatment for global variables, whose type requirements are
|
// Special treatment for global variables, whose type requirements are
|
||||||
// conveyed by type attributes.
|
// conveyed by type attributes.
|
||||||
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
|
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
|
||||||
valueTypes.push_back(globalVar.type());
|
valueTypes.push_back(globalVar.getType());
|
||||||
|
|
||||||
// Make sure the op's operands/results use types that are allowed by the
|
// Make sure the op's operands/results use types that are allowed by the
|
||||||
// target environment.
|
// target environment.
|
||||||
|
|
|
@ -51,8 +51,8 @@ static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
|
||||||
AliasedResourceMap aliasedResources;
|
AliasedResourceMap aliasedResources;
|
||||||
moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
|
moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
|
||||||
if (varOp->getAttrOfType<UnitAttr>("aliased")) {
|
if (varOp->getAttrOfType<UnitAttr>("aliased")) {
|
||||||
Optional<uint32_t> set = varOp.descriptor_set();
|
Optional<uint32_t> set = varOp.getDescriptorSet();
|
||||||
Optional<uint32_t> binding = varOp.binding();
|
Optional<uint32_t> binding = varOp.getBinding();
|
||||||
if (set && binding)
|
if (set && binding)
|
||||||
aliasedResources[{*set, *binding}].push_back(varOp);
|
aliasedResources[{*set, *binding}].push_back(varOp);
|
||||||
}
|
}
|
||||||
|
@ -222,16 +222,16 @@ bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
|
||||||
}
|
}
|
||||||
if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
|
if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
|
||||||
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
|
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
|
||||||
auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable());
|
auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable());
|
||||||
return shouldUnify(varOp);
|
return shouldUnify(varOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
|
if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
|
||||||
return shouldUnify(acOp.base_ptr().getDefiningOp());
|
return shouldUnify(acOp.getBasePtr().getDefiningOp());
|
||||||
if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
|
if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
|
||||||
return shouldUnify(loadOp.ptr().getDefiningOp());
|
return shouldUnify(loadOp.getPtr().getDefiningOp());
|
||||||
if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
|
if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
|
||||||
return shouldUnify(storeOp.ptr().getDefiningOp());
|
return shouldUnify(storeOp.getPtr().getDefiningOp());
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -265,7 +265,7 @@ void ResourceAliasAnalysis::recordIfUnifiable(
|
||||||
// Collect the element types for all resources in the current set.
|
// Collect the element types for all resources in the current set.
|
||||||
SmallVector<spirv::SPIRVType> elementTypes;
|
SmallVector<spirv::SPIRVType> elementTypes;
|
||||||
for (spirv::GlobalVariableOp resource : resources) {
|
for (spirv::GlobalVariableOp resource : resources) {
|
||||||
Type elementType = getRuntimeArrayElementType(resource.type());
|
Type elementType = getRuntimeArrayElementType(resource.getType());
|
||||||
if (!elementType)
|
if (!elementType)
|
||||||
return; // Unexpected resource variable type.
|
return; // Unexpected resource variable type.
|
||||||
|
|
||||||
|
@ -326,7 +326,7 @@ struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
|
||||||
// Rewrite the AddressOf op to get the address of the canoncical resource.
|
// Rewrite the AddressOf op to get the address of the canoncical resource.
|
||||||
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
|
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
|
||||||
auto srcVarOp = cast<spirv::GlobalVariableOp>(
|
auto srcVarOp = cast<spirv::GlobalVariableOp>(
|
||||||
SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
|
SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
|
||||||
auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
|
auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
|
||||||
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
|
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
|
||||||
return success();
|
return success();
|
||||||
|
@ -339,13 +339,13 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
|
matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
|
auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
|
||||||
if (!addressOp)
|
if (!addressOp)
|
||||||
return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
|
return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
|
||||||
|
|
||||||
auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
|
auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
|
||||||
auto srcVarOp = cast<spirv::GlobalVariableOp>(
|
auto srcVarOp = cast<spirv::GlobalVariableOp>(
|
||||||
SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
|
SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
|
||||||
auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
|
auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
|
||||||
|
|
||||||
spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
|
spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
|
||||||
|
@ -356,7 +356,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
|
||||||
// We have the same bitwidth for source and destination element types.
|
// We have the same bitwidth for source and destination element types.
|
||||||
// Thie indices keep the same.
|
// Thie indices keep the same.
|
||||||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||||
acOp, adaptor.base_ptr(), adaptor.indices());
|
acOp, adaptor.getBasePtr(), adaptor.getIndices());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -375,7 +375,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
|
||||||
auto ratioValue = rewriter.create<spirv::ConstantOp>(
|
auto ratioValue = rewriter.create<spirv::ConstantOp>(
|
||||||
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
|
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
|
||||||
|
|
||||||
auto indices = llvm::to_vector<4>(acOp.indices());
|
auto indices = llvm::to_vector<4>(acOp.getIndices());
|
||||||
Value oldIndex = indices.back();
|
Value oldIndex = indices.back();
|
||||||
indices.back() =
|
indices.back() =
|
||||||
rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
|
rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
|
||||||
|
@ -383,7 +383,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
|
||||||
rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
|
rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||||
acOp, adaptor.base_ptr(), indices);
|
acOp, adaptor.getBasePtr(), indices);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -399,13 +399,13 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
|
||||||
auto ratioValue = rewriter.create<spirv::ConstantOp>(
|
auto ratioValue = rewriter.create<spirv::ConstantOp>(
|
||||||
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
|
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
|
||||||
|
|
||||||
auto indices = llvm::to_vector<4>(acOp.indices());
|
auto indices = llvm::to_vector<4>(acOp.getIndices());
|
||||||
Value oldIndex = indices.back();
|
Value oldIndex = indices.back();
|
||||||
indices.back() =
|
indices.back() =
|
||||||
rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
|
rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||||
acOp, adaptor.base_ptr(), indices);
|
acOp, adaptor.getBasePtr(), indices);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -420,13 +420,13 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
|
matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto srcPtrType = loadOp.ptr().getType().cast<spirv::PointerType>();
|
auto srcPtrType = loadOp.getPtr().getType().cast<spirv::PointerType>();
|
||||||
auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>();
|
auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>();
|
||||||
auto dstPtrType = adaptor.ptr().getType().cast<spirv::PointerType>();
|
auto dstPtrType = adaptor.getPtr().getType().cast<spirv::PointerType>();
|
||||||
auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>();
|
auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>();
|
||||||
|
|
||||||
Location loc = loadOp.getLoc();
|
Location loc = loadOp.getLoc();
|
||||||
auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
|
auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
|
||||||
if (srcElemType == dstElemType) {
|
if (srcElemType == dstElemType) {
|
||||||
rewriter.replaceOp(loadOp, newLoadOp->getResults());
|
rewriter.replaceOp(loadOp, newLoadOp->getResults());
|
||||||
return success();
|
return success();
|
||||||
|
@ -434,7 +434,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
|
||||||
|
|
||||||
if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
|
if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
|
||||||
auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
|
auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
|
||||||
newLoadOp.value());
|
newLoadOp.getValue());
|
||||||
rewriter.replaceOp(loadOp, castOp->getResults());
|
rewriter.replaceOp(loadOp, castOp->getResults());
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
|
@ -457,19 +457,19 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
|
||||||
components.reserve(ratio);
|
components.reserve(ratio);
|
||||||
components.push_back(newLoadOp);
|
components.push_back(newLoadOp);
|
||||||
|
|
||||||
auto acOp = adaptor.ptr().getDefiningOp<spirv::AccessChainOp>();
|
auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
|
||||||
if (!acOp)
|
if (!acOp)
|
||||||
return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain");
|
return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain");
|
||||||
|
|
||||||
auto i32Type = rewriter.getI32Type();
|
auto i32Type = rewriter.getI32Type();
|
||||||
Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
|
Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
|
||||||
auto indices = llvm::to_vector<4>(acOp.indices());
|
auto indices = llvm::to_vector<4>(acOp.getIndices());
|
||||||
for (int i = 1; i < ratio; ++i) {
|
for (int i = 1; i < ratio; ++i) {
|
||||||
// Load all subsequent components belonging to this element.
|
// Load all subsequent components belonging to this element.
|
||||||
indices.back() = rewriter.create<spirv::IAddOp>(
|
indices.back() = rewriter.create<spirv::IAddOp>(
|
||||||
loc, i32Type, indices.back(), oneValue);
|
loc, i32Type, indices.back(), oneValue);
|
||||||
auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
|
auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
|
||||||
loc, acOp.base_ptr(), indices);
|
loc, acOp.getBasePtr(), indices);
|
||||||
// Assuming little endian, this reads lower-ordered bits of the number
|
// Assuming little endian, this reads lower-ordered bits of the number
|
||||||
// to lower-numbered components of the vector.
|
// to lower-numbered components of the vector.
|
||||||
components.push_back(
|
components.push_back(
|
||||||
|
@ -504,19 +504,19 @@ struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
|
||||||
matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
|
matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto srcElemType =
|
auto srcElemType =
|
||||||
storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
|
storeOp.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
|
||||||
auto dstElemType =
|
auto dstElemType =
|
||||||
adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
|
adaptor.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
|
||||||
if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
|
if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
|
||||||
return rewriter.notifyMatchFailure(storeOp, "not scalar type");
|
return rewriter.notifyMatchFailure(storeOp, "not scalar type");
|
||||||
if (!areSameBitwidthScalarType(srcElemType, dstElemType))
|
if (!areSameBitwidthScalarType(srcElemType, dstElemType))
|
||||||
return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
|
return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
|
||||||
|
|
||||||
Location loc = storeOp.getLoc();
|
Location loc = storeOp.getLoc();
|
||||||
Value value = adaptor.value();
|
Value value = adaptor.getValue();
|
||||||
if (srcElemType != dstElemType)
|
if (srcElemType != dstElemType)
|
||||||
value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
|
value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
|
||||||
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value,
|
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(), value,
|
||||||
storeOp->getAttrs());
|
storeOp->getAttrs());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -151,7 +151,7 @@ void UpdateVCEPass::runOnOperation() {
|
||||||
// Special treatment for global variables, whose type requirements are
|
// Special treatment for global variables, whose type requirements are
|
||||||
// conveyed by type attributes.
|
// conveyed by type attributes.
|
||||||
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
|
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
|
||||||
valueTypes.push_back(globalVar.type());
|
valueTypes.push_back(globalVar.getType());
|
||||||
|
|
||||||
// Requirements from values' types
|
// Requirements from values' types
|
||||||
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
|
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
|
||||||
|
|
|
@ -46,20 +46,20 @@ Value spirv::Deserializer::getValue(uint32_t id) {
|
||||||
}
|
}
|
||||||
if (auto varOp = getGlobalVariable(id)) {
|
if (auto varOp = getGlobalVariable(id)) {
|
||||||
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
|
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
|
||||||
unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation()));
|
unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
|
||||||
return addressOfOp.pointer();
|
return addressOfOp.getPointer();
|
||||||
}
|
}
|
||||||
if (auto constOp = getSpecConstant(id)) {
|
if (auto constOp = getSpecConstant(id)) {
|
||||||
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
|
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
|
||||||
unknownLoc, constOp.default_value().getType(),
|
unknownLoc, constOp.getDefaultValue().getType(),
|
||||||
SymbolRefAttr::get(constOp.getOperation()));
|
SymbolRefAttr::get(constOp.getOperation()));
|
||||||
return referenceOfOp.reference();
|
return referenceOfOp.getReference();
|
||||||
}
|
}
|
||||||
if (auto constCompositeOp = getSpecConstantComposite(id)) {
|
if (auto constCompositeOp = getSpecConstantComposite(id)) {
|
||||||
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
|
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
|
||||||
unknownLoc, constCompositeOp.type(),
|
unknownLoc, constCompositeOp.getType(),
|
||||||
SymbolRefAttr::get(constCompositeOp.getOperation()));
|
SymbolRefAttr::get(constCompositeOp.getOperation()));
|
||||||
return referenceOfOp.reference();
|
return referenceOfOp.getReference();
|
||||||
}
|
}
|
||||||
if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
|
if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
|
||||||
return materializeSpecConstantOperation(
|
return materializeSpecConstantOperation(
|
||||||
|
|
|
@ -1414,7 +1414,7 @@ Value spirv::Deserializer::materializeSpecConstantOperation(
|
||||||
auto specConstOperationOp =
|
auto specConstOperationOp =
|
||||||
opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
|
opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
|
||||||
|
|
||||||
Region &body = specConstOperationOp.body();
|
Region &body = specConstOperationOp.getBody();
|
||||||
// Move the new block into SpecConstantOperation's body.
|
// Move the new block into SpecConstantOperation's body.
|
||||||
body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
|
body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
|
||||||
Region::iterator(enclosedBlock));
|
Region::iterator(enclosedBlock));
|
||||||
|
@ -1983,17 +1983,17 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
|
||||||
assert((branchCondOp.getTrueBlock() == target ||
|
assert((branchCondOp.getTrueBlock() == target ||
|
||||||
branchCondOp.getFalseBlock() == target) &&
|
branchCondOp.getFalseBlock() == target) &&
|
||||||
"expected target to be either the true or false target");
|
"expected target to be either the true or false target");
|
||||||
if (target == branchCondOp.trueTarget())
|
if (target == branchCondOp.getTrueTarget())
|
||||||
opBuilder.create<spirv::BranchConditionalOp>(
|
opBuilder.create<spirv::BranchConditionalOp>(
|
||||||
branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
|
branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
|
||||||
branchCondOp.getFalseBlockArguments(),
|
branchCondOp.getFalseBlockArguments(),
|
||||||
branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
|
branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
|
||||||
branchCondOp.falseTarget());
|
branchCondOp.getFalseTarget());
|
||||||
else
|
else
|
||||||
opBuilder.create<spirv::BranchConditionalOp>(
|
opBuilder.create<spirv::BranchConditionalOp>(
|
||||||
branchCondOp.getLoc(), branchCondOp.condition(),
|
branchCondOp.getLoc(), branchCondOp.getCondition(),
|
||||||
branchCondOp.getTrueBlockArguments(), blockArgs,
|
branchCondOp.getTrueBlockArguments(), blockArgs,
|
||||||
branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
|
branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
|
||||||
branchCondOp.getFalseBlock());
|
branchCondOp.getFalseBlock());
|
||||||
|
|
||||||
branchCondOp.erase();
|
branchCondOp.erase();
|
||||||
|
|
|
@ -24,7 +24,7 @@ namespace mlir {
|
||||||
LogicalResult spirv::serialize(spirv::ModuleOp module,
|
LogicalResult spirv::serialize(spirv::ModuleOp module,
|
||||||
SmallVectorImpl<uint32_t> &binary,
|
SmallVectorImpl<uint32_t> &binary,
|
||||||
const SerializationOptions &options) {
|
const SerializationOptions &options) {
|
||||||
if (!module.vce_triple())
|
if (!module.getVceTriple())
|
||||||
return module.emitError(
|
return module.emitError(
|
||||||
"module must have 'vce_triple' attribute to be serializeable");
|
"module must have 'vce_triple' attribute to be serializeable");
|
||||||
|
|
||||||
|
|
|
@ -58,7 +58,8 @@ visitInPrettyBlockOrder(Block *headerBlock,
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace spirv {
|
namespace spirv {
|
||||||
LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
|
LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
|
||||||
if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
|
if (auto resultID =
|
||||||
|
prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
|
||||||
valueIDMap[op.getResult()] = resultID;
|
valueIDMap[op.getResult()] = resultID;
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -66,7 +67,7 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
|
LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
|
||||||
if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(),
|
if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
|
||||||
/*isSpec=*/true)) {
|
/*isSpec=*/true)) {
|
||||||
// Emit the OpDecorate instruction for SpecId.
|
// Emit the OpDecorate instruction for SpecId.
|
||||||
if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
|
if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
|
||||||
|
@ -75,8 +76,8 @@ LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
specConstIDMap[op.sym_name()] = resultID;
|
specConstIDMap[op.getSymName()] = resultID;
|
||||||
return processName(resultID, op.sym_name());
|
return processName(resultID, op.getSymName());
|
||||||
}
|
}
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -84,7 +85,7 @@ LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
|
||||||
LogicalResult
|
LogicalResult
|
||||||
Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
|
Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
|
||||||
uint32_t typeID = 0;
|
uint32_t typeID = 0;
|
||||||
if (failed(processType(op.getLoc(), op.type(), typeID))) {
|
if (failed(processType(op.getLoc(), op.getType(), typeID))) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,7 +95,7 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
|
||||||
operands.push_back(typeID);
|
operands.push_back(typeID);
|
||||||
operands.push_back(resultID);
|
operands.push_back(resultID);
|
||||||
|
|
||||||
auto constituents = op.constituents();
|
auto constituents = op.getConstituents();
|
||||||
|
|
||||||
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
|
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
|
||||||
auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
|
auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
|
||||||
|
@ -112,9 +113,9 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
|
||||||
|
|
||||||
encodeInstructionInto(typesGlobalValues,
|
encodeInstructionInto(typesGlobalValues,
|
||||||
spirv::Opcode::OpSpecConstantComposite, operands);
|
spirv::Opcode::OpSpecConstantComposite, operands);
|
||||||
specConstIDMap[op.sym_name()] = resultID;
|
specConstIDMap[op.getSymName()] = resultID;
|
||||||
|
|
||||||
return processName(resultID, op.sym_name());
|
return processName(resultID, op.getSymName());
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
|
@ -199,7 +200,7 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
|
||||||
operands.push_back(resTypeID);
|
operands.push_back(resTypeID);
|
||||||
auto funcID = getOrCreateFunctionID(op.getName());
|
auto funcID = getOrCreateFunctionID(op.getName());
|
||||||
operands.push_back(funcID);
|
operands.push_back(funcID);
|
||||||
operands.push_back(static_cast<uint32_t>(op.function_control()));
|
operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
|
||||||
operands.push_back(fnTypeID);
|
operands.push_back(fnTypeID);
|
||||||
encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
|
encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
|
||||||
|
|
||||||
|
@ -310,7 +311,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
|
||||||
// Get TypeID.
|
// Get TypeID.
|
||||||
uint32_t resultTypeID = 0;
|
uint32_t resultTypeID = 0;
|
||||||
SmallVector<StringRef, 4> elidedAttrs;
|
SmallVector<StringRef, 4> elidedAttrs;
|
||||||
if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
|
if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,7 +321,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
|
||||||
auto resultID = getNextID();
|
auto resultID = getNextID();
|
||||||
|
|
||||||
// Encode the name.
|
// Encode the name.
|
||||||
auto varName = varOp.sym_name();
|
auto varName = varOp.getSymName();
|
||||||
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
|
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
|
||||||
if (failed(processName(resultID, varName))) {
|
if (failed(processName(resultID, varName))) {
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -332,7 +333,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
|
||||||
operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
|
operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
|
||||||
|
|
||||||
// Encode initialization.
|
// Encode initialization.
|
||||||
if (auto initializer = varOp.initializer()) {
|
if (auto initializer = varOp.getInitializer()) {
|
||||||
auto initializerID = getVariableID(*initializer);
|
auto initializerID = getVariableID(*initializer);
|
||||||
if (!initializerID) {
|
if (!initializerID) {
|
||||||
return emitError(varOp.getLoc(),
|
return emitError(varOp.getLoc(),
|
||||||
|
@ -364,7 +365,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
|
||||||
LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
|
LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
|
||||||
// Assign <id>s to all blocks so that branches inside the SelectionOp can
|
// Assign <id>s to all blocks so that branches inside the SelectionOp can
|
||||||
// resolve properly.
|
// resolve properly.
|
||||||
auto &body = selectionOp.body();
|
auto &body = selectionOp.getBody();
|
||||||
for (Block &block : body)
|
for (Block &block : body)
|
||||||
getOrCreateBlockID(&block);
|
getOrCreateBlockID(&block);
|
||||||
|
|
||||||
|
@ -390,7 +391,7 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
|
||||||
lastProcessedWasMergeInst = true;
|
lastProcessedWasMergeInst = true;
|
||||||
encodeInstructionInto(
|
encodeInstructionInto(
|
||||||
functionBody, spirv::Opcode::OpSelectionMerge,
|
functionBody, spirv::Opcode::OpSelectionMerge,
|
||||||
{mergeID, static_cast<uint32_t>(selectionOp.selection_control())});
|
{mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
|
||||||
return success();
|
return success();
|
||||||
};
|
};
|
||||||
if (failed(
|
if (failed(
|
||||||
|
@ -420,7 +421,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
|
||||||
// Assign <id>s to all blocks so that branches inside the LoopOp can resolve
|
// Assign <id>s to all blocks so that branches inside the LoopOp can resolve
|
||||||
// properly. We don't need to assign for the entry block, which is just for
|
// properly. We don't need to assign for the entry block, which is just for
|
||||||
// satisfying MLIR region's structural requirement.
|
// satisfying MLIR region's structural requirement.
|
||||||
auto &body = loopOp.body();
|
auto &body = loopOp.getBody();
|
||||||
for (Block &block : llvm::drop_begin(body))
|
for (Block &block : llvm::drop_begin(body))
|
||||||
getOrCreateBlockID(&block);
|
getOrCreateBlockID(&block);
|
||||||
|
|
||||||
|
@ -452,7 +453,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
|
||||||
lastProcessedWasMergeInst = true;
|
lastProcessedWasMergeInst = true;
|
||||||
encodeInstructionInto(
|
encodeInstructionInto(
|
||||||
functionBody, spirv::Opcode::OpLoopMerge,
|
functionBody, spirv::Opcode::OpLoopMerge,
|
||||||
{mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())});
|
{mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
|
||||||
return success();
|
return success();
|
||||||
};
|
};
|
||||||
if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
|
if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
|
||||||
|
@ -483,12 +484,12 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
|
||||||
|
|
||||||
LogicalResult Serializer::processBranchConditionalOp(
|
LogicalResult Serializer::processBranchConditionalOp(
|
||||||
spirv::BranchConditionalOp condBranchOp) {
|
spirv::BranchConditionalOp condBranchOp) {
|
||||||
auto conditionID = getValueID(condBranchOp.condition());
|
auto conditionID = getValueID(condBranchOp.getCondition());
|
||||||
auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
|
auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
|
||||||
auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
|
auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
|
||||||
SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
|
SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
|
||||||
|
|
||||||
if (auto weights = condBranchOp.branch_weights()) {
|
if (auto weights = condBranchOp.getBranchWeights()) {
|
||||||
for (auto val : weights->getValue())
|
for (auto val : weights->getValue())
|
||||||
arguments.push_back(val.cast<IntegerAttr>().getInt());
|
arguments.push_back(val.cast<IntegerAttr>().getInt());
|
||||||
}
|
}
|
||||||
|
@ -509,26 +510,26 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
|
LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
|
||||||
auto varName = addressOfOp.variable();
|
auto varName = addressOfOp.getVariable();
|
||||||
auto variableID = getVariableID(varName);
|
auto variableID = getVariableID(varName);
|
||||||
if (!variableID) {
|
if (!variableID) {
|
||||||
return addressOfOp.emitError("unknown result <id> for variable ")
|
return addressOfOp.emitError("unknown result <id> for variable ")
|
||||||
<< varName;
|
<< varName;
|
||||||
}
|
}
|
||||||
valueIDMap[addressOfOp.pointer()] = variableID;
|
valueIDMap[addressOfOp.getPointer()] = variableID;
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
|
Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
|
||||||
auto constName = referenceOfOp.spec_const();
|
auto constName = referenceOfOp.getSpecConst();
|
||||||
auto constID = getSpecConstID(constName);
|
auto constID = getSpecConstID(constName);
|
||||||
if (!constID) {
|
if (!constID) {
|
||||||
return referenceOfOp.emitError(
|
return referenceOfOp.emitError(
|
||||||
"unknown result <id> for specialization constant ")
|
"unknown result <id> for specialization constant ")
|
||||||
<< constName;
|
<< constName;
|
||||||
}
|
}
|
||||||
valueIDMap[referenceOfOp.reference()] = constID;
|
valueIDMap[referenceOfOp.getReference()] = constID;
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -537,21 +538,21 @@ LogicalResult
|
||||||
Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
|
Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
|
||||||
SmallVector<uint32_t, 4> operands;
|
SmallVector<uint32_t, 4> operands;
|
||||||
// Add the ExecutionModel.
|
// Add the ExecutionModel.
|
||||||
operands.push_back(static_cast<uint32_t>(op.execution_model()));
|
operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
|
||||||
// Add the function <id>.
|
// Add the function <id>.
|
||||||
auto funcID = getFunctionID(op.fn());
|
auto funcID = getFunctionID(op.getFn());
|
||||||
if (!funcID) {
|
if (!funcID) {
|
||||||
return op.emitError("missing <id> for function ")
|
return op.emitError("missing <id> for function ")
|
||||||
<< op.fn()
|
<< op.getFn()
|
||||||
<< "; function needs to be defined before spv.EntryPoint is "
|
<< "; function needs to be defined before spv.EntryPoint is "
|
||||||
"serialized";
|
"serialized";
|
||||||
}
|
}
|
||||||
operands.push_back(funcID);
|
operands.push_back(funcID);
|
||||||
// Add the name of the function.
|
// Add the name of the function.
|
||||||
spirv::encodeStringLiteralInto(operands, op.fn());
|
spirv::encodeStringLiteralInto(operands, op.getFn());
|
||||||
|
|
||||||
// Add the interface values.
|
// Add the interface values.
|
||||||
if (auto interface = op.interface()) {
|
if (auto interface = op.getInterface()) {
|
||||||
for (auto var : interface.getValue()) {
|
for (auto var : interface.getValue()) {
|
||||||
auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
|
auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
|
||||||
if (!id) {
|
if (!id) {
|
||||||
|
@ -571,19 +572,19 @@ LogicalResult
|
||||||
Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
|
Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
|
||||||
SmallVector<uint32_t, 4> operands;
|
SmallVector<uint32_t, 4> operands;
|
||||||
// Add the function <id>.
|
// Add the function <id>.
|
||||||
auto funcID = getFunctionID(op.fn());
|
auto funcID = getFunctionID(op.getFn());
|
||||||
if (!funcID) {
|
if (!funcID) {
|
||||||
return op.emitError("missing <id> for function ")
|
return op.emitError("missing <id> for function ")
|
||||||
<< op.fn()
|
<< op.getFn()
|
||||||
<< "; function needs to be serialized before ExecutionModeOp is "
|
<< "; function needs to be serialized before ExecutionModeOp is "
|
||||||
"serialized";
|
"serialized";
|
||||||
}
|
}
|
||||||
operands.push_back(funcID);
|
operands.push_back(funcID);
|
||||||
// Add the ExecutionMode.
|
// Add the ExecutionMode.
|
||||||
operands.push_back(static_cast<uint32_t>(op.execution_mode()));
|
operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
|
||||||
|
|
||||||
// Serialize values if any.
|
// Serialize values if any.
|
||||||
auto values = op.values();
|
auto values = op.getValues();
|
||||||
if (values) {
|
if (values) {
|
||||||
for (auto &intVal : values.getValue()) {
|
for (auto &intVal : values.getValue()) {
|
||||||
operands.push_back(static_cast<uint32_t>(
|
operands.push_back(static_cast<uint32_t>(
|
||||||
|
@ -598,7 +599,7 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
|
||||||
template <>
|
template <>
|
||||||
LogicalResult
|
LogicalResult
|
||||||
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
|
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
|
||||||
auto funcName = op.callee();
|
auto funcName = op.getCallee();
|
||||||
uint32_t resTypeID = 0;
|
uint32_t resTypeID = 0;
|
||||||
|
|
||||||
Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
|
Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
|
||||||
|
@ -609,7 +610,7 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
|
||||||
auto funcCallID = getNextID();
|
auto funcCallID = getNextID();
|
||||||
SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
|
SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
|
||||||
|
|
||||||
for (auto value : op.arguments()) {
|
for (auto value : op.getArguments()) {
|
||||||
auto valueID = getValueID(value);
|
auto valueID = getValueID(value);
|
||||||
assert(valueID && "cannot find a value for spv.FunctionCall");
|
assert(valueID && "cannot find a value for spv.FunctionCall");
|
||||||
operands.push_back(valueID);
|
operands.push_back(valueID);
|
||||||
|
|
|
@ -119,7 +119,8 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
|
||||||
binary.clear();
|
binary.clear();
|
||||||
binary.reserve(moduleSize);
|
binary.reserve(moduleSize);
|
||||||
|
|
||||||
spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID);
|
spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
|
||||||
|
nextID);
|
||||||
binary.append(capabilities.begin(), capabilities.end());
|
binary.append(capabilities.begin(), capabilities.end());
|
||||||
binary.append(extensions.begin(), extensions.end());
|
binary.append(extensions.begin(), extensions.end());
|
||||||
binary.append(extendedSets.begin(), extendedSets.end());
|
binary.append(extendedSets.begin(), extendedSets.end());
|
||||||
|
@ -166,7 +167,7 @@ uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Serializer::processCapability() {
|
void Serializer::processCapability() {
|
||||||
for (auto cap : module.vce_triple()->getCapabilities())
|
for (auto cap : module.getVceTriple()->getCapabilities())
|
||||||
encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
|
encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
|
||||||
{static_cast<uint32_t>(cap)});
|
{static_cast<uint32_t>(cap)});
|
||||||
}
|
}
|
||||||
|
@ -186,7 +187,7 @@ void Serializer::processDebugInfo() {
|
||||||
|
|
||||||
void Serializer::processExtension() {
|
void Serializer::processExtension() {
|
||||||
llvm::SmallVector<uint32_t, 16> extName;
|
llvm::SmallVector<uint32_t, 16> extName;
|
||||||
for (spirv::Extension ext : module.vce_triple()->getExtensions()) {
|
for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
|
||||||
extName.clear();
|
extName.clear();
|
||||||
spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
|
spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
|
||||||
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
|
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
|
||||||
|
@ -1045,11 +1046,11 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
|
||||||
} else if (auto branchCondOp =
|
} else if (auto branchCondOp =
|
||||||
dyn_cast<spirv::BranchConditionalOp>(terminator)) {
|
dyn_cast<spirv::BranchConditionalOp>(terminator)) {
|
||||||
Optional<OperandRange> blockOperands;
|
Optional<OperandRange> blockOperands;
|
||||||
if (branchCondOp.trueTarget() == block) {
|
if (branchCondOp.getTrueTarget() == block) {
|
||||||
blockOperands = branchCondOp.trueTargetOperands();
|
blockOperands = branchCondOp.getTrueTargetOperands();
|
||||||
} else {
|
} else {
|
||||||
assert(branchCondOp.falseTarget() == block);
|
assert(branchCondOp.getFalseTarget() == block);
|
||||||
blockOperands = branchCondOp.falseTargetOperands();
|
blockOperands = branchCondOp.getFalseTargetOperands();
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(!blockOperands->empty() &&
|
assert(!blockOperands->empty() &&
|
||||||
|
|
|
@ -1360,7 +1360,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
|
||||||
os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
|
os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
|
||||||
"static_cast<{0}::{1}>(1 << i);\n",
|
"static_cast<{0}::{1}>(1 << i);\n",
|
||||||
enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
|
enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
|
||||||
namedAttr.name);
|
srcOp.getGetterName(namedAttr.name));
|
||||||
os << formatv(
|
os << formatv(
|
||||||
" if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
|
" if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
|
||||||
enumAttr.getUnderlyingType());
|
enumAttr.getUnderlyingType());
|
||||||
|
@ -1368,7 +1368,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
|
||||||
// For IntEnumAttr, we just need to query the value as a whole.
|
// For IntEnumAttr, we just need to query the value as a whole.
|
||||||
os << " {\n";
|
os << " {\n";
|
||||||
os << formatv(" auto tblgen_attrVal = this->{0}();\n",
|
os << formatv(" auto tblgen_attrVal = this->{0}();\n",
|
||||||
namedAttr.name);
|
srcOp.getGetterName(namedAttr.name));
|
||||||
}
|
}
|
||||||
os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
|
os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
|
||||||
enumAttr.getCppNamespace(), avail.getQueryFnName());
|
enumAttr.getCppNamespace(), avail.getQueryFnName());
|
||||||
|
|
Loading…
Reference in New Issue