[mlir][spirv] Fix a few issues in ModuleCombiner

- Fixed symbol insertion into `symNameToModuleMap`. Insertion
  needs to happen whether symbols are renamed or not.
- Added check for the VCE triple and avoid dropping it.
- Disabled function deduplication. It requires more careful
  rules. Right now it can remove different functions.
- Added tests for symbol rename listener.
- And some other code/comment cleanups.

Reviewed By: ergawy

Differential Revision: https://reviews.llvm.org/D106886
This commit is contained in:
Lei Zhang 2021-07-28 10:30:54 -04:00
parent aa6340cf87
commit 23326b9f17
11 changed files with 312 additions and 213 deletions

View File

@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
#define MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/BuiltinOps.h"

View File

@ -467,8 +467,9 @@ def SPV_ModuleOp : SPV_Op<"module",
let builders = [
OpBuilder<(ins CArg<"Optional<StringRef>", "llvm::None">:$name)>,
OpBuilder<(ins "spirv::AddressingModel":$addressing_model,
"spirv::MemoryModel":$memory_model,
CArg<"Optional<StringRef>", "llvm::None">:$name)>
"spirv::MemoryModel":$memory_model,
CArg<"Optional<spirv::VerCapExtAttr>", "llvm::None">:$vce_triple,
CArg<"Optional<StringRef>", "llvm::None">:$name)>
];
// We need to ensure the block inside the region is properly terminated;

View File

@ -22,53 +22,54 @@ class OpBuilder;
namespace spirv {
class ModuleOp;
/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops
/// The listener function to receive symbol renaming events.
///
/// `originalModule` is the input spirv::ModuleOp that contains the renamed
/// symbol. `oldSymbol` and `newSymbol` are the original and renamed symbol.
/// Note that it's the responsibility of the caller to properly retain the
/// storage underlying the passed StringRefs if the listener callback outlives
/// this function call.
using SymbolRenameListener = function_ref<void(
spirv::ModuleOp originalModule, StringRef oldSymbol, StringRef newSymbol)>;
/// Combines a list of SPIR-V `inputModules` into one. Returns the combined
/// module on success; returns a null module otherwise.
//
/// \param inputModules the list of modules to combine. They won't be modified.
/// \param combinedMdouleBuilder an OpBuilder for building the combined module.
/// \param symbRenameListener a listener that gets called everytime a symbol in
/// one of the input modules is renamed.
///
/// To combine multiple SPIR-V modules, we move all the module-level ops
/// from all the input modules into one big combined module. To that end, the
/// combination process proceeds in 2 phases:
///
/// (1) resolve conflicts between pairs of ops from different modules
/// (2) deduplicate equivalent ops/sub-ops in the merged module.
/// 1. resolve conflicts between pairs of ops from different modules,
/// 2. deduplicate equivalent ops/sub-ops in the merged module.
///
/// For the conflict resolution phase, the following rules are employed to
/// resolve such conflicts:
///
/// - If 2 spv.func's have the same symbol name, then rename one of the
/// - If 2 spv.func's have the same symbol name, then rename one of the
/// functions.
/// - If an spv.func and another op have the same symbol name, then rename the
/// - If an spv.func and another op have the same symbol name, then rename the
/// other symbol.
/// - If none of the 2 conflicting ops are spv.func, then rename either.
/// - If none of the 2 conflicting ops are spv.func, then rename either.
///
/// For deduplication, the following 3 cases are taken into consideration:
///
/// - If 2 spv.GlobalVariable's have either the same descriptor set + binding
/// - If 2 spv.GlobalVariable's have either the same descriptor set + binding
/// or the same build_in attribute value, then replace one of them using the
/// other.
/// - If 2 spv.SpecConstant's have the same spec_id attribute value, then
/// - If 2 spv.SpecConstant's have the same spec_id attribute value, then
/// replace one of them using the other.
/// - If 2 spv.func's are identical replace one of them using the other.
/// - Deduplicating functions are not supported right now.
///
/// In all cases, the references to the updated symbol (whether renamed or
/// deduplicated) are also updated to reflect the change.
///
/// \param modules the list of modules to combine. Input modules are not
/// modified.
/// \param combinedMdouleBuilder an OpBuilder to be used for
// building up the combined module.
/// \param symbRenameListener a listener that gets called everytime a symbol in
/// one of the input modules is renamed. The arguments
/// passed to the listener are: the input
/// spirv::ModuleOp that contains the renamed symbol,
/// a StringRef to the old symbol name, and a
/// StringRef to the new symbol name. Note that it is
/// the responsibility of the caller to properly
/// retain the storage underlying the passed
/// StringRefs if the listener callback outlives this
/// function call.
///
/// \return the combined module.
OwningOpRef<spirv::ModuleOp>
combine(MutableArrayRef<ModuleOp> modules, OpBuilder &combinedModuleBuilder,
function_ref<void(ModuleOp, StringRef, StringRef)> symbRenameListener);
OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
OpBuilder &combinedModuleBuilder,
SymbolRenameListener symRenameListener);
} // namespace spirv
} // namespace mlir

View File

@ -310,7 +310,7 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
// Add a keyword to the module name to avoid symbolic conflict.
std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
auto spvModule = rewriter.create<spirv::ModuleOp>(
moduleOp.getLoc(), addressingModel, memoryModel.getValue(),
moduleOp.getLoc(), addressingModel, memoryModel.getValue(), llvm::None,
StringRef(spvModuleName));
// Move the region from the module op into the SPIR-V module.

View File

@ -2540,6 +2540,7 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
spirv::AddressingModel addressingModel,
spirv::MemoryModel memoryModel,
Optional<VerCapExtAttr> vceTriple,
Optional<StringRef> name) {
state.addAttribute(
"addressing_model",
@ -2548,10 +2549,11 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
static_cast<int32_t>(memoryModel)));
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(state.addRegion());
if (name) {
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(*name));
}
if (vceTriple)
state.addAttribute(getVCETripleAttrName(), *vceTriple);
if (name)
state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(*name));
}
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {

View File

@ -12,27 +12,33 @@
#include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
using namespace mlir;
static constexpr unsigned maxFreeID = 1 << 20;
/// Returns an unsed symbol in `module` for `oldSymbolName` by trying numeric
/// suffix in `lastUsedID`.
static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
spirv::ModuleOp combinedModule) {
spirv::ModuleOp module) {
SmallString<64> newSymName(oldSymName);
newSymName.push_back('_');
while (lastUsedID < maxFreeID) {
std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) {
if (!SymbolTable::lookupSymbolIn(module, possible)) {
newSymName += llvm::utostr(lastUsedID);
break;
}
@ -41,8 +47,8 @@ static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
return newSymName;
}
/// Check if a symbol with the same name as op already exists in source. If so,
/// rename op and update all its references in target.
/// Checks if a symbol with the same name as `op` already exists in `source`.
/// If so, renames `op` and updates all its references in `target`.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
spirv::ModuleOp target,
spirv::ModuleOp source,
@ -61,99 +67,67 @@ static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
return success();
}
template <typename KeyTy, typename SymbolOpTy>
static SymbolOpTy
emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp,
DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) {
auto result = deduplicationMap.try_emplace(key, symbolOp);
if (result.second)
return SymbolOpTy();
return result.first->second;
}
/// Computes a hash code to represent the argument SymbolOpInterface based on
/// all the Op's attributes except for the symbol name.
///
/// \return the hash code computed from the Op's attributes as described above.
/// Computes a hash code to represent `symbolOp` based on all its attributes
/// except for the symbol name.
///
/// Note: We use the operation's name (not the symbol name) as part of the hash
/// computation. This prevents, for example, mistakenly considering a global
/// variable and a spec constant as duplicates because their descriptor set +
/// binding and spec_id, respectively, happen to hash to the same value.
static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
llvm::hash_code hashCode(0);
hashCode = llvm::hash_combine(symbolOp->getName());
auto range =
llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) {
return attr.first != SymbolTable::getSymbolAttrName();
});
for (auto attr : symbolOp->getAttrs()) {
if (attr.first == SymbolTable::getSymbolAttrName())
continue;
hashCode = llvm::hash_combine(hashCode, attr);
}
return hashCode;
}
/// Computes a hash code from the argument Block.
llvm::hash_code computeHash(Block *block) {
// TODO: Consider extracting BlockEquivalenceData into a common header and
// re-using it here.
llvm::hash_code hash(0);
for (Operation &op : *block) {
// TODO: Properly handle operations with regions.
if (op.getNumRegions() > 0)
return 0;
hash = llvm::hash_combine(
hash, OperationEquivalence::computeHash(
&op, OperationEquivalence::Flags::IgnoreOperands));
}
return hash;
return llvm::hash_combine(
symbolOp->getName(),
llvm::hash_combine_range(range.begin(), range.end()));
}
namespace mlir {
namespace spirv {
// TODO Properly test symbol rename listener mechanism.
OwningOpRef<spirv::ModuleOp>
combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
OpBuilder &combinedModuleBuilder,
llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
symRenameListener) {
unsigned lastUsedID = 0;
if (modules.empty())
OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
OpBuilder &combinedModuleBuilder,
SymbolRenameListener symRenameListener) {
if (inputModules.empty())
return nullptr;
auto addressingModel = modules[0].addressing_model();
auto memoryModel = modules[0].memory_model();
spirv::ModuleOp firstModule = inputModules.front();
auto addressingModel = firstModule.addressing_model();
auto memoryModel = firstModule.memory_model();
auto vceTriple = firstModule.vce_triple();
// First check whether there are conflicts between addressing/memory model.
// Return early if so.
for (auto module : inputModules) {
if (module.addressing_model() != addressingModel ||
module.memory_model() != memoryModel ||
module.vce_triple() != vceTriple) {
module.emitError("input modules differ in addressing model, memory "
"model, and/or VCE triple");
return nullptr;
}
}
auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
modules[0].getLoc(), addressingModel, memoryModel);
firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
// In some cases, a symbol in the (current state of the) combined module is
// renamed in order to maintain the conflicting symbol in the input module
// renamed in order to enable the conflicting symbol in the input module
// being merged. For example, if the conflict is between a global variable in
// the current combined module and a function in the input module, the global
// variable is renamed. In order to notify listeners of the symbol updates in
// such cases, we need to keep track of the module from which the renamed
// symbol in the combined module originated. This map keeps such information.
DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;
llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
for (auto module : modules) {
if (module.addressing_model() != addressingModel ||
module.memory_model() != memoryModel) {
module.emitError(
"input modules differ in addressing model and/or memory model");
return nullptr;
}
unsigned lastUsedID = 0;
spirv::ModuleOp moduleClone = module.clone();
for (auto inputModule : inputModules) {
spirv::ModuleOp moduleClone = inputModule.clone();
// In the combined module, rename all symbols that conflict with symbols
// from the current input module. This renaming applies to all ops except
@ -161,65 +135,70 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
// non-spv.func, we rename that symbol instead and maintain the spv.func in
// the combined module name as it is.
for (auto &op : *combinedModule.getBody()) {
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
StringRef oldSymName = symbolOp.getName();
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;
if (!isa<FuncOp>(op) &&
failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
lastUsedID)))
StringRef oldSymName = symbolOp.getName();
if (!isa<FuncOp>(op) &&
failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
lastUsedID)))
return nullptr;
StringRef newSymName = symbolOp.getName();
if (symRenameListener && oldSymName != newSymName) {
spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
if (!originalModule) {
inputModule.emitError(
"unable to find original spirv::ModuleOp for symbol ")
<< oldSymName;
return nullptr;
StringRef newSymName = symbolOp.getName();
if (symRenameListener && oldSymName != newSymName) {
spirv::ModuleOp originalModule =
symNameToModuleMap.lookup(oldSymName);
if (!originalModule) {
module.emitError("unable to find original ModuleOp for symbol ")
<< oldSymName;
return nullptr;
}
symRenameListener(originalModule, oldSymName, newSymName);
// Since the symbol name is updated, there is no need to maintain the
// entry that associates the old symbol name with the original module.
symNameToModuleMap.erase(oldSymName);
// Instead, add a new entry to map the new symbol name to the original
// module in case it gets renamed again later.
symNameToModuleMap[newSymName] = originalModule;
}
symRenameListener(originalModule, oldSymName, newSymName);
// Since the symbol name is updated, there is no need to maintain the
// entry that associates the old symbol name with the original module.
symNameToModuleMap.erase(oldSymName);
// Instead, add a new entry to map the new symbol name to the original
// module in case it gets renamed again later.
symNameToModuleMap[newSymName] = originalModule;
}
}
// In the current input module, rename all symbols that conflict with
// symbols from the combined module. This includes renaming spv.funcs.
for (auto &op : *moduleClone.getBody()) {
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
StringRef oldSymName = symbolOp.getName();
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;
if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
lastUsedID)))
StringRef oldSymName = symbolOp.getName();
if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
lastUsedID)))
return nullptr;
StringRef newSymName = symbolOp.getName();
if (symRenameListener) {
if (oldSymName != newSymName)
symRenameListener(inputModule, oldSymName, newSymName);
// Insert the module associated with the symbol name.
auto emplaceResult =
symNameToModuleMap.try_emplace(newSymName, inputModule);
// If an entry with the same symbol name is already present, this must
// be a problem with the implementation, specially clean-up of the map
// while iterating over the combined module above.
if (!emplaceResult.second) {
inputModule.emitError("did not expect to find an entry for symbol ")
<< symbolOp.getName();
return nullptr;
StringRef newSymName = symbolOp.getName();
if (symRenameListener && oldSymName != newSymName) {
symRenameListener(module, oldSymName, newSymName);
// Insert the module associated with the symbol name.
auto emplaceResult =
symNameToModuleMap.try_emplace(symbolOp.getName(), module);
// If an entry with the same symbol name is already present, this must
// be a problem with the implementation, specially clean-up of the map
// while iterating over the combined module above.
if (!emplaceResult.second) {
module.emitError("did not expect to find an entry for symbol ")
<< symbolOp.getName();
return nullptr;
}
}
}
}
@ -234,30 +213,26 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
SmallVector<SymbolOpInterface, 0> eraseList;
for (auto &op : *combinedModule.getBody()) {
llvm::hash_code hashCode(0);
SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;
hashCode = computeHash(symbolOp);
// A 0 hash code means the op is not suitable for deduplication and should
// be skipped. An example of this is when a function has ops with regions
// which are not properly supported yet.
if (!hashCode)
// Do not support ops with operands or results.
// Global variables, spec constants, and functions won't have
// operands/results, but just for safety here.
if (op.getNumOperands() != 0 || op.getNumResults() != 0)
continue;
if (auto funcOp = dyn_cast<FuncOp>(op))
for (auto &blk : funcOp)
hashCode = llvm::hash_combine(hashCode, computeHash(&blk));
SymbolOpInterface replacementSymOp =
emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp);
if (!replacementSymOp)
// Deduplicating functions are not supported yet.
if (isa<FuncOp>(op))
continue;
auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
if (result.second)
continue;
SymbolOpInterface replacementSymOp = result.first->second;
if (failed(SymbolTable::replaceAllSymbolUses(
symbolOp, replacementSymOp.getName(), combinedModule))) {
symbolOp.emitError("unable to update all symbol uses for ")

View File

@ -1,9 +1,19 @@
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
// Combine modules without the same symbols
// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.SpecConstant @m1_sc
// CHECK-NEXT: spv.GlobalVariable @m1_gv bind(1, 0)
// CHECK-NEXT: spv.func @no_op
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// CHECK-NEXT: spv.EntryPoint "GLCompute" @no_op
// CHECK-NEXT: spv.ExecutionMode @no_op "LocalSize", 32, 1, 1
// CHECK-NEXT: spv.SpecConstant @m2_sc
// CHECK-NEXT: spv.GlobalVariable @m2_gv bind(0, 1)
// CHECK-NEXT: spv.func @variable_init_spec_constant
// CHECK-NEXT: spv.mlir.referenceof @m2_sc
// CHECK-NEXT: spv.Variable init
@ -15,10 +25,17 @@
module {
spv.module Logical GLSL450 {
spv.SpecConstant @m1_sc = 42.42 : f32
spv.GlobalVariable @m1_gv bind(1, 0): !spv.ptr<f32, Input>
spv.func @no_op() -> () "None" {
spv.Return
}
spv.EntryPoint "GLCompute" @no_op
spv.ExecutionMode @no_op "LocalSize", 32, 1, 1
}
spv.module Logical GLSL450 {
spv.SpecConstant @m2_sc = 42 : i32
spv.GlobalVariable @m2_gv bind(0, 1): !spv.ptr<f32, Input>
spv.func @variable_init_spec_constant() -> () "None" {
%0 = spv.mlir.referenceof @m2_sc : i32
%1 = spv.Variable init(%0) : !spv.ptr<i32, Function>
@ -33,7 +50,7 @@ module {
spv.module Physical64 GLSL450 {
}
// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}}
spv.module Logical GLSL450 {
}
}
@ -44,7 +61,19 @@ module {
spv.module Logical Simple {
}
// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}}
spv.module Logical GLSL450 {
}
}
// -----
module {
spv.module Logical GLSL450 {
}
// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}}
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]> {
}
}

View File

@ -215,7 +215,7 @@ spv.module Logical GLSL450 {
spv.func @foo(%arg0 : i32) -> i32 "None" {
spv.ReturnValue %arg0 : i32
}
spv.EntryPoint "GLCompute" @foo
spv.ExecutionMode @foo "ContractionOff"
}
@ -383,7 +383,7 @@ spv.module Logical GLSL450 {
spv.SpecConstant @foo = -5 : i32
spv.func @bar() -> i32 "None" {
%0 = spv.mlir.referenceof @foo : i32
%0 = spv.mlir.referenceof @foo : i32
spv.ReturnValue %0 : i32
}
}

View File

@ -21,7 +21,6 @@
// CHECK-NEXT: }
// CHECK-NEXT: }
module {
spv.module Logical GLSL450 {
spv.GlobalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
@ -42,7 +41,6 @@ spv.module Logical GLSL450 {
spv.ReturnValue %2 : f32
}
}
}
// -----
@ -62,7 +60,6 @@ spv.module Logical GLSL450 {
// CHECK-NEXT: }
// CHECK-NEXT: }
module {
spv.module Logical GLSL450 {
spv.GlobalVariable @foo bind(1, 0) : !spv.ptr<i32, Input>
}
@ -76,7 +73,6 @@ spv.module Logical GLSL450 {
spv.ReturnValue %1 : f32
}
}
}
// -----
@ -93,7 +89,6 @@ spv.module Logical GLSL450 {
// CHECK-NEXT: }
// CHECK-NEXT: }
module {
spv.module Logical GLSL450 {
spv.GlobalVariable @foo built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
}
@ -107,10 +102,11 @@ spv.module Logical GLSL450 {
spv.ReturnValue %1 : vector<3xi32>
}
}
}
// -----
// Deduplicate 2 spec constants with the same spec ID.
// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.SpecConstant @foo spec_id(5)
@ -128,7 +124,6 @@ spv.module Logical GLSL450 {
// CHECK-NEXT: }
// CHECK-NEXT: }
module {
spv.module Logical GLSL450 {
spv.SpecConstant @foo spec_id(5) = 1. : f32
@ -147,48 +142,82 @@ spv.module Logical GLSL450 {
spv.ReturnValue %1 : f32
}
}
// -----
// Don't deduplicate functions with similar ops but different operands.
// CHECK: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.func @foo(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
// CHECK-NEXT: %[[ADD:.+]] = spv.FAdd %[[ARG0]], %[[ARG1]] : f32
// CHECK-NEXT: %[[MUL:.+]] = spv.FMul %[[ADD]], %[[ARG2]] : f32
// CHECK-NEXT: spv.ReturnValue %[[MUL]] : f32
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @foo_1(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32)
// CHECK-NEXT: %[[ADD:.+]] = spv.FAdd %[[ARG0]], %[[ARG2]] : f32
// CHECK-NEXT: %[[MUL:.+]] = spv.FMul %[[ADD]], %[[ARG1]] : f32
// CHECK-NEXT: spv.ReturnValue %[[MUL]] : f32
// CHECK-NEXT: }
// CHECK-NEXT: }
spv.module Logical GLSL450 {
spv.func @foo(%a: f32, %b: f32, %c: f32) -> f32 "None" {
%add = spv.FAdd %a, %b: f32
%mul = spv.FMul %add, %c: f32
spv.ReturnValue %mul: f32
}
}
spv.module Logical GLSL450 {
spv.func @foo(%a: f32, %b: f32, %c: f32) -> f32 "None" {
%add = spv.FAdd %a, %c: f32
%mul = spv.FMul %add, %b: f32
spv.ReturnValue %mul: f32
}
}
// -----
// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.SpecConstant @bar spec_id(5)
// TODO: re-enable this test once we have better function deduplication.
// CHECK-NEXT: spv.func @foo(%arg0: f32)
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// XXXXX: module {
// XXXXX-NEXT: spv.module Logical GLSL450 {
// XXXXX-NEXT: spv.SpecConstant @bar spec_id(5)
// CHECK-NEXT: spv.func @foo_different_body(%arg0: f32)
// CHECK-NEXT: spv.mlir.referenceof
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// XXXXX-NEXT: spv.func @foo(%arg0: f32)
// XXXXX-NEXT: spv.ReturnValue
// XXXXX-NEXT: }
// CHECK-NEXT: spv.func @baz(%arg0: i32)
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// XXXXX-NEXT: spv.func @foo_different_body(%arg0: f32)
// XXXXX-NEXT: spv.mlir.referenceof
// XXXXX-NEXT: spv.ReturnValue
// XXXXX-NEXT: }
// CHECK-NEXT: spv.func @baz_no_return(%arg0: i32)
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// XXXXX-NEXT: spv.func @baz(%arg0: i32)
// XXXXX-NEXT: spv.ReturnValue
// XXXXX-NEXT: }
// CHECK-NEXT: spv.func @baz_no_return_different_control
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// XXXXX-NEXT: spv.func @baz_no_return(%arg0: i32)
// XXXXX-NEXT: spv.Return
// XXXXX-NEXT: }
// CHECK-NEXT: spv.func @baz_no_return_another_control
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// XXXXX-NEXT: spv.func @baz_no_return_different_control
// XXXXX-NEXT: spv.Return
// XXXXX-NEXT: }
// CHECK-NEXT: spv.func @kernel
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// XXXXX-NEXT: spv.func @baz_no_return_another_control
// XXXXX-NEXT: spv.Return
// XXXXX-NEXT: }
// CHECK-NEXT: spv.func @kernel_different_attr
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// XXXXX-NEXT: spv.func @kernel
// XXXXX-NEXT: spv.Return
// XXXXX-NEXT: }
// XXXXX-NEXT: spv.func @kernel_different_attr
// XXXXX-NEXT: spv.Return
// XXXXX-NEXT: }
// XXXXX-NEXT: }
// XXXXX-NEXT: }
module {
spv.module Logical GLSL450 {

View File

@ -0,0 +1,54 @@
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
module {
spv.module @Module1 Logical GLSL450 {
spv.GlobalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
spv.func @bar() -> () "None" {
spv.Return
}
spv.func @baz() -> () "None" {
spv.Return
}
spv.SpecConstant @sc = -5 : i32
}
spv.module @Module2 Logical GLSL450 {
spv.func @foo() -> () "None" {
spv.Return
}
spv.GlobalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
spv.func @baz() -> () "None" {
spv.Return
}
spv.SpecConstant @sc = -5 : i32
}
spv.module @Module3 Logical GLSL450 {
spv.func @foo() -> () "None" {
spv.Return
}
spv.GlobalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
spv.func @baz() -> () "None" {
spv.Return
}
spv.SpecConstant @sc = -5 : i32
}
}
// CHECK: [Module1] foo -> foo_1
// CHECK: [Module1] sc -> sc_2
// CHECK: [Module2] bar -> bar_3
// CHECK: [Module2] baz -> baz_4
// CHECK: [Module2] sc -> sc_5
// CHECK: [Module3] foo -> foo_6
// CHECK: [Module3] bar -> bar_7
// CHECK: [Module3] baz -> baz_8

View File

@ -37,7 +37,14 @@ void TestModuleCombinerPass::runOnOperation() {
auto modules = llvm::to_vector<4>(getOperation().getOps<spirv::ModuleOp>());
OpBuilder combinedModuleBuilder(modules[0]);
combinedModule = spirv::combine(modules, combinedModuleBuilder, nullptr);
auto listener = [](spirv::ModuleOp originalModule, StringRef oldSymbol,
StringRef newSymbol) {
llvm::outs() << "[" << originalModule.getName() << "] " << oldSymbol
<< " -> " << newSymbol << "\n";
};
combinedModule = spirv::combine(modules, combinedModuleBuilder, listener);
for (spirv::ModuleOp module : modules)
module.erase();