forked from OSchip/llvm-project
[mlir][spirv] Fix a few issues in ModuleCombiner
- Fixed symbol insertion into `symNameToModuleMap`. Insertion needs to happen whether symbols are renamed or not. - Added check for the VCE triple and avoid dropping it. - Disabled function deduplication. It requires more careful rules. Right now it can remove different functions. - Added tests for symbol rename listener. - And some other code/comment cleanups. Reviewed By: ergawy Differential Revision: https://reviews.llvm.org/D106886
This commit is contained in:
parent
aa6340cf87
commit
23326b9f17
|
@ -13,6 +13,7 @@
|
|||
#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
|
||||
#define MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
|
||||
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
|
|
@ -467,8 +467,9 @@ def SPV_ModuleOp : SPV_Op<"module",
|
|||
let builders = [
|
||||
OpBuilder<(ins CArg<"Optional<StringRef>", "llvm::None">:$name)>,
|
||||
OpBuilder<(ins "spirv::AddressingModel":$addressing_model,
|
||||
"spirv::MemoryModel":$memory_model,
|
||||
CArg<"Optional<StringRef>", "llvm::None">:$name)>
|
||||
"spirv::MemoryModel":$memory_model,
|
||||
CArg<"Optional<spirv::VerCapExtAttr>", "llvm::None">:$vce_triple,
|
||||
CArg<"Optional<StringRef>", "llvm::None">:$name)>
|
||||
];
|
||||
|
||||
// We need to ensure the block inside the region is properly terminated;
|
||||
|
|
|
@ -22,53 +22,54 @@ class OpBuilder;
|
|||
namespace spirv {
|
||||
class ModuleOp;
|
||||
|
||||
/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops
|
||||
/// The listener function to receive symbol renaming events.
|
||||
///
|
||||
/// `originalModule` is the input spirv::ModuleOp that contains the renamed
|
||||
/// symbol. `oldSymbol` and `newSymbol` are the original and renamed symbol.
|
||||
/// Note that it's the responsibility of the caller to properly retain the
|
||||
/// storage underlying the passed StringRefs if the listener callback outlives
|
||||
/// this function call.
|
||||
using SymbolRenameListener = function_ref<void(
|
||||
spirv::ModuleOp originalModule, StringRef oldSymbol, StringRef newSymbol)>;
|
||||
|
||||
/// Combines a list of SPIR-V `inputModules` into one. Returns the combined
|
||||
/// module on success; returns a null module otherwise.
|
||||
//
|
||||
/// \param inputModules the list of modules to combine. They won't be modified.
|
||||
/// \param combinedMdouleBuilder an OpBuilder for building the combined module.
|
||||
/// \param symbRenameListener a listener that gets called everytime a symbol in
|
||||
/// one of the input modules is renamed.
|
||||
///
|
||||
/// To combine multiple SPIR-V modules, we move all the module-level ops
|
||||
/// from all the input modules into one big combined module. To that end, the
|
||||
/// combination process proceeds in 2 phases:
|
||||
///
|
||||
/// (1) resolve conflicts between pairs of ops from different modules
|
||||
/// (2) deduplicate equivalent ops/sub-ops in the merged module.
|
||||
/// 1. resolve conflicts between pairs of ops from different modules,
|
||||
/// 2. deduplicate equivalent ops/sub-ops in the merged module.
|
||||
///
|
||||
/// For the conflict resolution phase, the following rules are employed to
|
||||
/// resolve such conflicts:
|
||||
///
|
||||
/// - If 2 spv.func's have the same symbol name, then rename one of the
|
||||
/// - If 2 spv.func's have the same symbol name, then rename one of the
|
||||
/// functions.
|
||||
/// - If an spv.func and another op have the same symbol name, then rename the
|
||||
/// - If an spv.func and another op have the same symbol name, then rename the
|
||||
/// other symbol.
|
||||
/// - If none of the 2 conflicting ops are spv.func, then rename either.
|
||||
/// - If none of the 2 conflicting ops are spv.func, then rename either.
|
||||
///
|
||||
/// For deduplication, the following 3 cases are taken into consideration:
|
||||
///
|
||||
/// - If 2 spv.GlobalVariable's have either the same descriptor set + binding
|
||||
/// - If 2 spv.GlobalVariable's have either the same descriptor set + binding
|
||||
/// or the same build_in attribute value, then replace one of them using the
|
||||
/// other.
|
||||
/// - If 2 spv.SpecConstant's have the same spec_id attribute value, then
|
||||
/// - If 2 spv.SpecConstant's have the same spec_id attribute value, then
|
||||
/// replace one of them using the other.
|
||||
/// - If 2 spv.func's are identical replace one of them using the other.
|
||||
/// - Deduplicating functions are not supported right now.
|
||||
///
|
||||
/// In all cases, the references to the updated symbol (whether renamed or
|
||||
/// deduplicated) are also updated to reflect the change.
|
||||
///
|
||||
/// \param modules the list of modules to combine. Input modules are not
|
||||
/// modified.
|
||||
/// \param combinedMdouleBuilder an OpBuilder to be used for
|
||||
// building up the combined module.
|
||||
/// \param symbRenameListener a listener that gets called everytime a symbol in
|
||||
/// one of the input modules is renamed. The arguments
|
||||
/// passed to the listener are: the input
|
||||
/// spirv::ModuleOp that contains the renamed symbol,
|
||||
/// a StringRef to the old symbol name, and a
|
||||
/// StringRef to the new symbol name. Note that it is
|
||||
/// the responsibility of the caller to properly
|
||||
/// retain the storage underlying the passed
|
||||
/// StringRefs if the listener callback outlives this
|
||||
/// function call.
|
||||
///
|
||||
/// \return the combined module.
|
||||
OwningOpRef<spirv::ModuleOp>
|
||||
combine(MutableArrayRef<ModuleOp> modules, OpBuilder &combinedModuleBuilder,
|
||||
function_ref<void(ModuleOp, StringRef, StringRef)> symbRenameListener);
|
||||
OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
|
||||
OpBuilder &combinedModuleBuilder,
|
||||
SymbolRenameListener symRenameListener);
|
||||
} // namespace spirv
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -310,7 +310,7 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
|
|||
// Add a keyword to the module name to avoid symbolic conflict.
|
||||
std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
|
||||
auto spvModule = rewriter.create<spirv::ModuleOp>(
|
||||
moduleOp.getLoc(), addressingModel, memoryModel.getValue(),
|
||||
moduleOp.getLoc(), addressingModel, memoryModel.getValue(), llvm::None,
|
||||
StringRef(spvModuleName));
|
||||
|
||||
// Move the region from the module op into the SPIR-V module.
|
||||
|
|
|
@ -2540,6 +2540,7 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
|
|||
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
|
||||
spirv::AddressingModel addressingModel,
|
||||
spirv::MemoryModel memoryModel,
|
||||
Optional<VerCapExtAttr> vceTriple,
|
||||
Optional<StringRef> name) {
|
||||
state.addAttribute(
|
||||
"addressing_model",
|
||||
|
@ -2548,10 +2549,11 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
|
|||
static_cast<int32_t>(memoryModel)));
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
builder.createBlock(state.addRegion());
|
||||
if (name) {
|
||||
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
|
||||
builder.getStringAttr(*name));
|
||||
}
|
||||
if (vceTriple)
|
||||
state.addAttribute(getVCETripleAttrName(), *vceTriple);
|
||||
if (name)
|
||||
state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
|
||||
builder.getStringAttr(*name));
|
||||
}
|
||||
|
||||
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
|
||||
|
|
|
@ -12,27 +12,33 @@
|
|||
|
||||
#include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
|
||||
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static constexpr unsigned maxFreeID = 1 << 20;
|
||||
|
||||
/// Returns an unsed symbol in `module` for `oldSymbolName` by trying numeric
|
||||
/// suffix in `lastUsedID`.
|
||||
static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
|
||||
spirv::ModuleOp combinedModule) {
|
||||
spirv::ModuleOp module) {
|
||||
SmallString<64> newSymName(oldSymName);
|
||||
newSymName.push_back('_');
|
||||
|
||||
while (lastUsedID < maxFreeID) {
|
||||
std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
|
||||
|
||||
if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) {
|
||||
if (!SymbolTable::lookupSymbolIn(module, possible)) {
|
||||
newSymName += llvm::utostr(lastUsedID);
|
||||
break;
|
||||
}
|
||||
|
@ -41,8 +47,8 @@ static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
|
|||
return newSymName;
|
||||
}
|
||||
|
||||
/// Check if a symbol with the same name as op already exists in source. If so,
|
||||
/// rename op and update all its references in target.
|
||||
/// Checks if a symbol with the same name as `op` already exists in `source`.
|
||||
/// If so, renames `op` and updates all its references in `target`.
|
||||
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
|
||||
spirv::ModuleOp target,
|
||||
spirv::ModuleOp source,
|
||||
|
@ -61,99 +67,67 @@ static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
|
|||
return success();
|
||||
}
|
||||
|
||||
template <typename KeyTy, typename SymbolOpTy>
|
||||
static SymbolOpTy
|
||||
emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp,
|
||||
DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) {
|
||||
auto result = deduplicationMap.try_emplace(key, symbolOp);
|
||||
|
||||
if (result.second)
|
||||
return SymbolOpTy();
|
||||
|
||||
return result.first->second;
|
||||
}
|
||||
|
||||
/// Computes a hash code to represent the argument SymbolOpInterface based on
|
||||
/// all the Op's attributes except for the symbol name.
|
||||
///
|
||||
/// \return the hash code computed from the Op's attributes as described above.
|
||||
/// Computes a hash code to represent `symbolOp` based on all its attributes
|
||||
/// except for the symbol name.
|
||||
///
|
||||
/// Note: We use the operation's name (not the symbol name) as part of the hash
|
||||
/// computation. This prevents, for example, mistakenly considering a global
|
||||
/// variable and a spec constant as duplicates because their descriptor set +
|
||||
/// binding and spec_id, respectively, happen to hash to the same value.
|
||||
static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
|
||||
llvm::hash_code hashCode(0);
|
||||
hashCode = llvm::hash_combine(symbolOp->getName());
|
||||
auto range =
|
||||
llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
|
||||
return attr.first != SymbolTable::getSymbolAttrName();
|
||||
});
|
||||
|
||||
for (auto attr : symbolOp->getAttrs()) {
|
||||
if (attr.first == SymbolTable::getSymbolAttrName())
|
||||
continue;
|
||||
hashCode = llvm::hash_combine(hashCode, attr);
|
||||
}
|
||||
|
||||
return hashCode;
|
||||
}
|
||||
|
||||
/// Computes a hash code from the argument Block.
|
||||
llvm::hash_code computeHash(Block *block) {
|
||||
// TODO: Consider extracting BlockEquivalenceData into a common header and
|
||||
// re-using it here.
|
||||
llvm::hash_code hash(0);
|
||||
|
||||
for (Operation &op : *block) {
|
||||
// TODO: Properly handle operations with regions.
|
||||
if (op.getNumRegions() > 0)
|
||||
return 0;
|
||||
|
||||
hash = llvm::hash_combine(
|
||||
hash, OperationEquivalence::computeHash(
|
||||
&op, OperationEquivalence::Flags::IgnoreOperands));
|
||||
}
|
||||
|
||||
return hash;
|
||||
return llvm::hash_combine(
|
||||
symbolOp->getName(),
|
||||
llvm::hash_combine_range(range.begin(), range.end()));
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace spirv {
|
||||
|
||||
// TODO Properly test symbol rename listener mechanism.
|
||||
|
||||
OwningOpRef<spirv::ModuleOp>
|
||||
combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
|
||||
OpBuilder &combinedModuleBuilder,
|
||||
llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
|
||||
symRenameListener) {
|
||||
unsigned lastUsedID = 0;
|
||||
|
||||
if (modules.empty())
|
||||
OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
|
||||
OpBuilder &combinedModuleBuilder,
|
||||
SymbolRenameListener symRenameListener) {
|
||||
if (inputModules.empty())
|
||||
return nullptr;
|
||||
|
||||
auto addressingModel = modules[0].addressing_model();
|
||||
auto memoryModel = modules[0].memory_model();
|
||||
spirv::ModuleOp firstModule = inputModules.front();
|
||||
auto addressingModel = firstModule.addressing_model();
|
||||
auto memoryModel = firstModule.memory_model();
|
||||
auto vceTriple = firstModule.vce_triple();
|
||||
|
||||
// First check whether there are conflicts between addressing/memory model.
|
||||
// Return early if so.
|
||||
for (auto module : inputModules) {
|
||||
if (module.addressing_model() != addressingModel ||
|
||||
module.memory_model() != memoryModel ||
|
||||
module.vce_triple() != vceTriple) {
|
||||
module.emitError("input modules differ in addressing model, memory "
|
||||
"model, and/or VCE triple");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
|
||||
modules[0].getLoc(), addressingModel, memoryModel);
|
||||
firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
|
||||
combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
|
||||
|
||||
// In some cases, a symbol in the (current state of the) combined module is
|
||||
// renamed in order to maintain the conflicting symbol in the input module
|
||||
// renamed in order to enable the conflicting symbol in the input module
|
||||
// being merged. For example, if the conflict is between a global variable in
|
||||
// the current combined module and a function in the input module, the global
|
||||
// variable is renamed. In order to notify listeners of the symbol updates in
|
||||
// such cases, we need to keep track of the module from which the renamed
|
||||
// symbol in the combined module originated. This map keeps such information.
|
||||
DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;
|
||||
llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
|
||||
|
||||
for (auto module : modules) {
|
||||
if (module.addressing_model() != addressingModel ||
|
||||
module.memory_model() != memoryModel) {
|
||||
module.emitError(
|
||||
"input modules differ in addressing model and/or memory model");
|
||||
return nullptr;
|
||||
}
|
||||
unsigned lastUsedID = 0;
|
||||
|
||||
spirv::ModuleOp moduleClone = module.clone();
|
||||
for (auto inputModule : inputModules) {
|
||||
spirv::ModuleOp moduleClone = inputModule.clone();
|
||||
|
||||
// In the combined module, rename all symbols that conflict with symbols
|
||||
// from the current input module. This renaming applies to all ops except
|
||||
|
@ -161,65 +135,70 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
|
|||
// non-spv.func, we rename that symbol instead and maintain the spv.func in
|
||||
// the combined module name as it is.
|
||||
for (auto &op : *combinedModule.getBody()) {
|
||||
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
|
||||
StringRef oldSymName = symbolOp.getName();
|
||||
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
|
||||
if (!symbolOp)
|
||||
continue;
|
||||
|
||||
if (!isa<FuncOp>(op) &&
|
||||
failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
|
||||
lastUsedID)))
|
||||
StringRef oldSymName = symbolOp.getName();
|
||||
|
||||
if (!isa<FuncOp>(op) &&
|
||||
failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
|
||||
lastUsedID)))
|
||||
return nullptr;
|
||||
|
||||
StringRef newSymName = symbolOp.getName();
|
||||
|
||||
if (symRenameListener && oldSymName != newSymName) {
|
||||
spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
|
||||
|
||||
if (!originalModule) {
|
||||
inputModule.emitError(
|
||||
"unable to find original spirv::ModuleOp for symbol ")
|
||||
<< oldSymName;
|
||||
return nullptr;
|
||||
|
||||
StringRef newSymName = symbolOp.getName();
|
||||
|
||||
if (symRenameListener && oldSymName != newSymName) {
|
||||
spirv::ModuleOp originalModule =
|
||||
symNameToModuleMap.lookup(oldSymName);
|
||||
|
||||
if (!originalModule) {
|
||||
module.emitError("unable to find original ModuleOp for symbol ")
|
||||
<< oldSymName;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
symRenameListener(originalModule, oldSymName, newSymName);
|
||||
|
||||
// Since the symbol name is updated, there is no need to maintain the
|
||||
// entry that associates the old symbol name with the original module.
|
||||
symNameToModuleMap.erase(oldSymName);
|
||||
// Instead, add a new entry to map the new symbol name to the original
|
||||
// module in case it gets renamed again later.
|
||||
symNameToModuleMap[newSymName] = originalModule;
|
||||
}
|
||||
|
||||
symRenameListener(originalModule, oldSymName, newSymName);
|
||||
|
||||
// Since the symbol name is updated, there is no need to maintain the
|
||||
// entry that associates the old symbol name with the original module.
|
||||
symNameToModuleMap.erase(oldSymName);
|
||||
// Instead, add a new entry to map the new symbol name to the original
|
||||
// module in case it gets renamed again later.
|
||||
symNameToModuleMap[newSymName] = originalModule;
|
||||
}
|
||||
}
|
||||
|
||||
// In the current input module, rename all symbols that conflict with
|
||||
// symbols from the combined module. This includes renaming spv.funcs.
|
||||
for (auto &op : *moduleClone.getBody()) {
|
||||
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
|
||||
StringRef oldSymName = symbolOp.getName();
|
||||
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
|
||||
if (!symbolOp)
|
||||
continue;
|
||||
|
||||
if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
|
||||
lastUsedID)))
|
||||
StringRef oldSymName = symbolOp.getName();
|
||||
|
||||
if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
|
||||
lastUsedID)))
|
||||
return nullptr;
|
||||
|
||||
StringRef newSymName = symbolOp.getName();
|
||||
|
||||
if (symRenameListener) {
|
||||
if (oldSymName != newSymName)
|
||||
symRenameListener(inputModule, oldSymName, newSymName);
|
||||
|
||||
// Insert the module associated with the symbol name.
|
||||
auto emplaceResult =
|
||||
symNameToModuleMap.try_emplace(newSymName, inputModule);
|
||||
|
||||
// If an entry with the same symbol name is already present, this must
|
||||
// be a problem with the implementation, specially clean-up of the map
|
||||
// while iterating over the combined module above.
|
||||
if (!emplaceResult.second) {
|
||||
inputModule.emitError("did not expect to find an entry for symbol ")
|
||||
<< symbolOp.getName();
|
||||
return nullptr;
|
||||
|
||||
StringRef newSymName = symbolOp.getName();
|
||||
|
||||
if (symRenameListener && oldSymName != newSymName) {
|
||||
symRenameListener(module, oldSymName, newSymName);
|
||||
|
||||
// Insert the module associated with the symbol name.
|
||||
auto emplaceResult =
|
||||
symNameToModuleMap.try_emplace(symbolOp.getName(), module);
|
||||
|
||||
// If an entry with the same symbol name is already present, this must
|
||||
// be a problem with the implementation, specially clean-up of the map
|
||||
// while iterating over the combined module above.
|
||||
if (!emplaceResult.second) {
|
||||
module.emitError("did not expect to find an entry for symbol ")
|
||||
<< symbolOp.getName();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -234,30 +213,26 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
|
|||
SmallVector<SymbolOpInterface, 0> eraseList;
|
||||
|
||||
for (auto &op : *combinedModule.getBody()) {
|
||||
llvm::hash_code hashCode(0);
|
||||
SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
|
||||
|
||||
if (!symbolOp)
|
||||
continue;
|
||||
|
||||
hashCode = computeHash(symbolOp);
|
||||
|
||||
// A 0 hash code means the op is not suitable for deduplication and should
|
||||
// be skipped. An example of this is when a function has ops with regions
|
||||
// which are not properly supported yet.
|
||||
if (!hashCode)
|
||||
// Do not support ops with operands or results.
|
||||
// Global variables, spec constants, and functions won't have
|
||||
// operands/results, but just for safety here.
|
||||
if (op.getNumOperands() != 0 || op.getNumResults() != 0)
|
||||
continue;
|
||||
|
||||
if (auto funcOp = dyn_cast<FuncOp>(op))
|
||||
for (auto &blk : funcOp)
|
||||
hashCode = llvm::hash_combine(hashCode, computeHash(&blk));
|
||||
|
||||
SymbolOpInterface replacementSymOp =
|
||||
emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp);
|
||||
|
||||
if (!replacementSymOp)
|
||||
// Deduplicating functions are not supported yet.
|
||||
if (isa<FuncOp>(op))
|
||||
continue;
|
||||
|
||||
auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
|
||||
if (result.second)
|
||||
continue;
|
||||
|
||||
SymbolOpInterface replacementSymOp = result.first->second;
|
||||
|
||||
if (failed(SymbolTable::replaceAllSymbolUses(
|
||||
symbolOp, replacementSymOp.getName(), combinedModule))) {
|
||||
symbolOp.emitError("unable to update all symbol uses for ")
|
||||
|
|
|
@ -1,9 +1,19 @@
|
|||
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
// Combine modules without the same symbols
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.SpecConstant @m1_sc
|
||||
// CHECK-NEXT: spv.GlobalVariable @m1_gv bind(1, 0)
|
||||
// CHECK-NEXT: spv.func @no_op
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: spv.EntryPoint "GLCompute" @no_op
|
||||
// CHECK-NEXT: spv.ExecutionMode @no_op "LocalSize", 32, 1, 1
|
||||
|
||||
// CHECK-NEXT: spv.SpecConstant @m2_sc
|
||||
// CHECK-NEXT: spv.GlobalVariable @m2_gv bind(0, 1)
|
||||
// CHECK-NEXT: spv.func @variable_init_spec_constant
|
||||
// CHECK-NEXT: spv.mlir.referenceof @m2_sc
|
||||
// CHECK-NEXT: spv.Variable init
|
||||
|
@ -15,10 +25,17 @@
|
|||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.SpecConstant @m1_sc = 42.42 : f32
|
||||
spv.GlobalVariable @m1_gv bind(1, 0): !spv.ptr<f32, Input>
|
||||
spv.func @no_op() -> () "None" {
|
||||
spv.Return
|
||||
}
|
||||
spv.EntryPoint "GLCompute" @no_op
|
||||
spv.ExecutionMode @no_op "LocalSize", 32, 1, 1
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.SpecConstant @m2_sc = 42 : i32
|
||||
spv.GlobalVariable @m2_gv bind(0, 1): !spv.ptr<f32, Input>
|
||||
spv.func @variable_init_spec_constant() -> () "None" {
|
||||
%0 = spv.mlir.referenceof @m2_sc : i32
|
||||
%1 = spv.Variable init(%0) : !spv.ptr<i32, Function>
|
||||
|
@ -33,7 +50,7 @@ module {
|
|||
spv.module Physical64 GLSL450 {
|
||||
}
|
||||
|
||||
// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
|
||||
// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}}
|
||||
spv.module Logical GLSL450 {
|
||||
}
|
||||
}
|
||||
|
@ -44,7 +61,19 @@ module {
|
|||
spv.module Logical Simple {
|
||||
}
|
||||
|
||||
// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
|
||||
// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}}
|
||||
spv.module Logical GLSL450 {
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
}
|
||||
|
||||
// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}}
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> {
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -215,7 +215,7 @@ spv.module Logical GLSL450 {
|
|||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
|
||||
|
||||
spv.EntryPoint "GLCompute" @foo
|
||||
spv.ExecutionMode @foo "ContractionOff"
|
||||
}
|
||||
|
@ -383,7 +383,7 @@ spv.module Logical GLSL450 {
|
|||
spv.SpecConstant @foo = -5 : i32
|
||||
|
||||
spv.func @bar() -> i32 "None" {
|
||||
%0 = spv.mlir.referenceof @foo : i32
|
||||
%0 = spv.mlir.referenceof @foo : i32
|
||||
spv.ReturnValue %0 : i32
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.GlobalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
|
||||
|
@ -42,7 +41,6 @@ spv.module Logical GLSL450 {
|
|||
spv.ReturnValue %2 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -62,7 +60,6 @@ spv.module Logical GLSL450 {
|
|||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.GlobalVariable @foo bind(1, 0) : !spv.ptr<i32, Input>
|
||||
}
|
||||
|
@ -76,7 +73,6 @@ spv.module Logical GLSL450 {
|
|||
spv.ReturnValue %1 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -93,7 +89,6 @@ spv.module Logical GLSL450 {
|
|||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.GlobalVariable @foo built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
|
||||
}
|
||||
|
@ -107,10 +102,11 @@ spv.module Logical GLSL450 {
|
|||
spv.ReturnValue %1 : vector<3xi32>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Deduplicate 2 spec constants with the same spec ID.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.SpecConstant @foo spec_id(5)
|
||||
|
@ -128,7 +124,6 @@ spv.module Logical GLSL450 {
|
|||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.SpecConstant @foo spec_id(5) = 1. : f32
|
||||
|
||||
|
@ -147,48 +142,82 @@ spv.module Logical GLSL450 {
|
|||
spv.ReturnValue %1 : f32
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Don't deduplicate functions with similar ops but different operands.
|
||||
|
||||
// CHECK: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
|
||||
// CHECK-NEXT: %[[ADD:.+]] = spv.FAdd %[[ARG0]], %[[ARG1]] : f32
|
||||
// CHECK-NEXT: %[[MUL:.+]] = spv.FMul %[[ADD]], %[[ARG2]] : f32
|
||||
// CHECK-NEXT: spv.ReturnValue %[[MUL]] : f32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: spv.func @foo_1(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
|
||||
// CHECK-NEXT: %[[ADD:.+]] = spv.FAdd %[[ARG0]], %[[ARG2]] : f32
|
||||
// CHECK-NEXT: %[[MUL:.+]] = spv.FMul %[[ADD]], %[[ARG1]] : f32
|
||||
// CHECK-NEXT: spv.ReturnValue %[[MUL]] : f32
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%a: f32, %b: f32, %c: f32) -> f32 "None" {
|
||||
%add = spv.FAdd %a, %b: f32
|
||||
%mul = spv.FMul %add, %c: f32
|
||||
spv.ReturnValue %mul: f32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%a: f32, %b: f32, %c: f32) -> f32 "None" {
|
||||
%add = spv.FAdd %a, %c: f32
|
||||
%mul = spv.FMul %add, %b: f32
|
||||
spv.ReturnValue %mul: f32
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.SpecConstant @bar spec_id(5)
|
||||
// TODO: re-enable this test once we have better function deduplication.
|
||||
|
||||
// CHECK-NEXT: spv.func @foo(%arg0: f32)
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// XXXXX: module {
|
||||
// XXXXX-NEXT: spv.module Logical GLSL450 {
|
||||
// XXXXX-NEXT: spv.SpecConstant @bar spec_id(5)
|
||||
|
||||
// CHECK-NEXT: spv.func @foo_different_body(%arg0: f32)
|
||||
// CHECK-NEXT: spv.mlir.referenceof
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// XXXXX-NEXT: spv.func @foo(%arg0: f32)
|
||||
// XXXXX-NEXT: spv.ReturnValue
|
||||
// XXXXX-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @baz(%arg0: i32)
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// XXXXX-NEXT: spv.func @foo_different_body(%arg0: f32)
|
||||
// XXXXX-NEXT: spv.mlir.referenceof
|
||||
// XXXXX-NEXT: spv.ReturnValue
|
||||
// XXXXX-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @baz_no_return(%arg0: i32)
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
// XXXXX-NEXT: spv.func @baz(%arg0: i32)
|
||||
// XXXXX-NEXT: spv.ReturnValue
|
||||
// XXXXX-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @baz_no_return_different_control
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
// XXXXX-NEXT: spv.func @baz_no_return(%arg0: i32)
|
||||
// XXXXX-NEXT: spv.Return
|
||||
// XXXXX-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @baz_no_return_another_control
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
// XXXXX-NEXT: spv.func @baz_no_return_different_control
|
||||
// XXXXX-NEXT: spv.Return
|
||||
// XXXXX-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @kernel
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
// XXXXX-NEXT: spv.func @baz_no_return_another_control
|
||||
// XXXXX-NEXT: spv.Return
|
||||
// XXXXX-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @kernel_different_attr
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// XXXXX-NEXT: spv.func @kernel
|
||||
// XXXXX-NEXT: spv.Return
|
||||
// XXXXX-NEXT: }
|
||||
|
||||
// XXXXX-NEXT: spv.func @kernel_different_attr
|
||||
// XXXXX-NEXT: spv.Return
|
||||
// XXXXX-NEXT: }
|
||||
// XXXXX-NEXT: }
|
||||
// XXXXX-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
module {
|
||||
spv.module @Module1 Logical GLSL450 {
|
||||
spv.GlobalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
spv.func @bar() -> () "None" {
|
||||
spv.Return
|
||||
}
|
||||
spv.func @baz() -> () "None" {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.SpecConstant @sc = -5 : i32
|
||||
}
|
||||
|
||||
spv.module @Module2 Logical GLSL450 {
|
||||
spv.func @foo() -> () "None" {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.GlobalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
|
||||
|
||||
spv.func @baz() -> () "None" {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.SpecConstant @sc = -5 : i32
|
||||
}
|
||||
|
||||
spv.module @Module3 Logical GLSL450 {
|
||||
spv.func @foo() -> () "None" {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.GlobalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
|
||||
|
||||
spv.func @baz() -> () "None" {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.SpecConstant @sc = -5 : i32
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK: [Module1] foo -> foo_1
|
||||
// CHECK: [Module1] sc -> sc_2
|
||||
|
||||
// CHECK: [Module2] bar -> bar_3
|
||||
// CHECK: [Module2] baz -> baz_4
|
||||
// CHECK: [Module2] sc -> sc_5
|
||||
|
||||
// CHECK: [Module3] foo -> foo_6
|
||||
// CHECK: [Module3] bar -> bar_7
|
||||
// CHECK: [Module3] baz -> baz_8
|
|
@ -37,7 +37,14 @@ void TestModuleCombinerPass::runOnOperation() {
|
|||
auto modules = llvm::to_vector<4>(getOperation().getOps<spirv::ModuleOp>());
|
||||
|
||||
OpBuilder combinedModuleBuilder(modules[0]);
|
||||
combinedModule = spirv::combine(modules, combinedModuleBuilder, nullptr);
|
||||
|
||||
auto listener = [](spirv::ModuleOp originalModule, StringRef oldSymbol,
|
||||
StringRef newSymbol) {
|
||||
llvm::outs() << "[" << originalModule.getName() << "] " << oldSymbol
|
||||
<< " -> " << newSymbol << "\n";
|
||||
};
|
||||
|
||||
combinedModule = spirv::combine(modules, combinedModuleBuilder, listener);
|
||||
|
||||
for (spirv::ModuleOp module : modules)
|
||||
module.erase();
|
||||
|
|
Loading…
Reference in New Issue