forked from OSchip/llvm-project
[MLIR][SPIRV] Start module combiner.
This commit adds a new library that merges/combines a number of spv modules into a combined one. The library has a single entry point: combine(...). To combine a number of MLIR spv modules, we move all the module-level ops from all the input modules into one big combined module. To that end, the combination process can proceed in 2 phases: (1) resolving conflicts between pairs of ops from different modules (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO) This patch implements only the first phase. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D90477
This commit is contained in:
parent
13a56ca5a9
commit
27324f2855
|
@ -0,0 +1,69 @@
|
|||
//===- ModuleCombiner.h - MLIR SPIR-V Module Combiner -----------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file declares the entry point to the SPIR-V module combiner library.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
|
||||
#define MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
|
||||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace mlir {
|
||||
class OpBuilder;
|
||||
|
||||
namespace spirv {
|
||||
class ModuleOp;
|
||||
|
||||
/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops
|
||||
/// from all the input modules into one big combined module. To that end, the
|
||||
/// combination process proceeds in 2 phases:
|
||||
///
|
||||
/// (1) resolve conflicts between pairs of ops from different modules
|
||||
/// (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)
|
||||
///
|
||||
/// For the conflict resolution phase, the following rules are employed to
|
||||
/// resolve such conflicts:
|
||||
///
|
||||
/// - If 2 spv.func's have the same symbol name, then rename one of the
|
||||
/// functions.
|
||||
/// - If an spv.func and another op have the same symbol name, then rename the
|
||||
/// other symbol.
|
||||
/// - If none of the 2 conflicting ops are spv.func, then rename either.
|
||||
///
|
||||
/// In all cases, the references to the updated symbol are also updated to
|
||||
/// reflect the change.
|
||||
///
|
||||
/// \param modules the list of modules to combine. Input modules are not
|
||||
/// modified.
|
||||
/// \param combinedMdouleBuilder an OpBuilder to be used for
|
||||
/// building up the combined module.
|
||||
/// \param symbRenameListener a listener that gets called everytime a symbol in
|
||||
/// one of the input modules is renamed. The arguments
|
||||
/// passed to the listener are: the input
|
||||
/// spirv::ModuleOp that contains the renamed symbol,
|
||||
/// a StringRef to the old symbol name, and a
|
||||
/// StringRef to the new symbol name. Note that it is
|
||||
/// the responsibility of the caller to properly
|
||||
/// retain the storage underlying the passed
|
||||
/// StringRefs if the listener callback outlives this
|
||||
/// function call.
|
||||
///
|
||||
/// \return the combined module.
|
||||
OwningSPIRVModuleRef
|
||||
combine(llvm::MutableArrayRef<ModuleOp> modules,
|
||||
OpBuilder &combinedModuleBuilder,
|
||||
llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
|
||||
symbRenameListener);
|
||||
} // namespace spirv
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
|
|
@ -34,5 +34,6 @@ add_mlir_dialect_library(MLIRSPIRV
|
|||
MLIRTransforms
|
||||
)
|
||||
|
||||
add_subdirectory(Linking)
|
||||
add_subdirectory(Serialization)
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(ModuleCombiner)
|
|
@ -0,0 +1,6 @@
|
|||
add_mlir_dialect_library(MLIRSPIRVModuleCombiner
|
||||
ModuleCombiner.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
|
||||
)
|
|
@ -0,0 +1,181 @@
|
|||
//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements the the SPIR-V module combiner library.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
|
||||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static constexpr unsigned maxFreeID = 1 << 20;
|
||||
|
||||
static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
|
||||
spirv::ModuleOp combinedModule) {
|
||||
SmallString<64> newSymName(oldSymName);
|
||||
newSymName.push_back('_');
|
||||
|
||||
while (lastUsedID < maxFreeID) {
|
||||
std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
|
||||
|
||||
if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) {
|
||||
newSymName += llvm::utostr(lastUsedID);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return newSymName;
|
||||
}
|
||||
|
||||
/// Check if a symbol with the same name as op already exists in source. If so,
|
||||
/// rename op and update all its references in target.
|
||||
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
|
||||
spirv::ModuleOp target,
|
||||
spirv::ModuleOp source,
|
||||
unsigned &lastUsedID) {
|
||||
if (!SymbolTable::lookupSymbolIn(source, op.getName()))
|
||||
return success();
|
||||
|
||||
StringRef oldSymName = op.getName();
|
||||
SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target);
|
||||
|
||||
if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
|
||||
return op.emitError("unable to update all symbol uses for ")
|
||||
<< oldSymName << " to " << newSymName;
|
||||
|
||||
SymbolTable::setSymbolName(op, newSymName);
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace spirv {
|
||||
|
||||
// TODO Properly test symbol rename listener mechanism.
|
||||
|
||||
OwningSPIRVModuleRef
|
||||
combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
|
||||
OpBuilder &combinedModuleBuilder,
|
||||
llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
|
||||
symRenameListener) {
|
||||
unsigned lastUsedID = 0;
|
||||
|
||||
if (modules.empty())
|
||||
return nullptr;
|
||||
|
||||
auto addressingModel = modules[0].addressing_model();
|
||||
auto memoryModel = modules[0].memory_model();
|
||||
|
||||
auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
|
||||
modules[0].getLoc(), addressingModel, memoryModel);
|
||||
combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
|
||||
|
||||
// In some cases, a symbol in the (current state of the) combined module is
|
||||
// renamed in order to maintain the conflicting symbol in the input module
|
||||
// being merged. For example, if the conflict is between a global variable in
|
||||
// the current combined module and a function in the input module, the global
|
||||
// varaible is renamed. In order to notify listeners of the symbol updates in
|
||||
// such cases, we need to keep track of the module from which the renamed
|
||||
// symbol in the combined module originated. This map keeps such information.
|
||||
DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;
|
||||
|
||||
for (auto module : modules) {
|
||||
if (module.addressing_model() != addressingModel ||
|
||||
module.memory_model() != memoryModel) {
|
||||
module.emitError(
|
||||
"input modules differ in addressing model and/or memory model");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
spirv::ModuleOp moduleClone = module.clone();
|
||||
|
||||
// In the combined module, rename all symbols that conflict with symbols
|
||||
// from the current input module. This renmaing applies to all ops except
|
||||
// for spv.funcs. This way, if the conflicting op in the input module is
|
||||
// non-spv.func, we rename that symbol instead and maintain the spv.func in
|
||||
// the combined module name as it is.
|
||||
for (auto &op : combinedModule.getBlock().without_terminator()) {
|
||||
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
|
||||
StringRef oldSymName = symbolOp.getName();
|
||||
|
||||
if (!isa<FuncOp>(op) &&
|
||||
failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
|
||||
lastUsedID)))
|
||||
return nullptr;
|
||||
|
||||
StringRef newSymName = symbolOp.getName();
|
||||
|
||||
if (symRenameListener && oldSymName != newSymName) {
|
||||
spirv::ModuleOp originalModule =
|
||||
symNameToModuleMap.lookup(oldSymName);
|
||||
|
||||
if (!originalModule) {
|
||||
module.emitError("unable to find original ModuleOp for symbol ")
|
||||
<< oldSymName;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
symRenameListener(originalModule, oldSymName, newSymName);
|
||||
|
||||
// Since the symbol name is updated, there is no need to maintain the
|
||||
// entry that assocaites the old symbol name with the original module.
|
||||
symNameToModuleMap.erase(oldSymName);
|
||||
// Instead, add a new entry to map the new symbol name to the original
|
||||
// module in case it gets renamed again later.
|
||||
symNameToModuleMap[newSymName] = originalModule;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// In the current input module, rename all symbols that conflict with
|
||||
// symbols from the combined module. This includes renaming spv.funcs.
|
||||
for (auto &op : moduleClone.getBlock().without_terminator()) {
|
||||
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
|
||||
StringRef oldSymName = symbolOp.getName();
|
||||
|
||||
if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
|
||||
lastUsedID)))
|
||||
return nullptr;
|
||||
|
||||
StringRef newSymName = symbolOp.getName();
|
||||
|
||||
if (symRenameListener && oldSymName != newSymName) {
|
||||
symRenameListener(module, oldSymName, newSymName);
|
||||
|
||||
// Insert the module associated with the symbol name.
|
||||
auto emplaceResult =
|
||||
symNameToModuleMap.try_emplace(symbolOp.getName(), module);
|
||||
|
||||
// If an entry with the same symbol name is already present, this must
|
||||
// be a problem with the implementation, specially clean-up of the map
|
||||
// while iterating over the combined module above.
|
||||
if (!emplaceResult.second) {
|
||||
module.emitError("did not expect to find an entry for symbol ")
|
||||
<< symbolOp.getName();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clone all the module's ops to the combined module.
|
||||
for (auto &op : moduleClone.getBlock().without_terminator())
|
||||
combinedModuleBuilder.insert(op.clone());
|
||||
}
|
||||
|
||||
return combinedModule;
|
||||
}
|
||||
|
||||
} // namespace spirv
|
||||
} // namespace mlir
|
|
@ -0,0 +1,50 @@
|
|||
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.specConstant @m1_sc
|
||||
// CHECK-NEXT: spv.specConstant @m2_sc
|
||||
// CHECK-NEXT: spv.func @variable_init_spec_constant
|
||||
// CHECK-NEXT: spv._reference_of @m2_sc
|
||||
// CHECK-NEXT: spv.Variable init
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @m1_sc = 42.42 : f32
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @m2_sc = 42 : i32
|
||||
spv.func @variable_init_spec_constant() -> () "None" {
|
||||
%0 = spv._reference_of @m2_sc : i32
|
||||
%1 = spv.Variable init(%0) : !spv.ptr<i32, Function>
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
spv.module Physical64 GLSL450 {
|
||||
}
|
||||
|
||||
// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
|
||||
spv.module Logical GLSL450 {
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
spv.module Logical Simple {
|
||||
}
|
||||
|
||||
// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
|
||||
spv.module Logical GLSL450 {
|
||||
}
|
||||
}
|
|
@ -0,0 +1,682 @@
|
|||
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
// Test basic renaming of conflicting funcOps.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo_1
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : f32) -> f32 "None" {
|
||||
spv.ReturnValue %arg0 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test basic renaming of conflicting funcOps across 3 modules.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo_1
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo_2
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : f32) -> f32 "None" {
|
||||
spv.ReturnValue %arg0 : f32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test properly updating references to a renamed funcOp.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo_1
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @bar
|
||||
// CHECK-NEXT: spv.FunctionCall @foo_1
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : f32) -> f32 "None" {
|
||||
spv.ReturnValue %arg0 : f32
|
||||
}
|
||||
|
||||
spv.func @bar(%arg0 : f32) -> f32 "None" {
|
||||
%0 = spv.FunctionCall @foo(%arg0) : (f32) -> (f32)
|
||||
spv.ReturnValue %0 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test properly updating references to a renamed funcOp if the functionCallOp
|
||||
// preceeds the callee funcOp definition.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @bar
|
||||
// CHECK-NEXT: spv.FunctionCall @foo_1
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo_1
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @bar(%arg0 : f32) -> f32 "None" {
|
||||
%0 = spv.FunctionCall @foo(%arg0) : (f32) -> (f32)
|
||||
spv.ReturnValue %0 : f32
|
||||
}
|
||||
|
||||
spv.func @foo(%arg0 : f32) -> f32 "None" {
|
||||
spv.ReturnValue %arg0 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test properly updating entryPointOp and executionModeOp attached to renamed
|
||||
// funcOp.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo_1
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.EntryPoint "GLCompute" @foo_1
|
||||
// CHECK-NEXT: spv.ExecutionMode @foo_1 "ContractionOff"
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : f32) -> f32 "None" {
|
||||
spv.ReturnValue %arg0 : f32
|
||||
}
|
||||
|
||||
spv.EntryPoint "GLCompute" @foo
|
||||
spv.ExecutionMode @foo "ContractionOff"
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.EntryPoint "GLCompute" @fo
|
||||
// CHECK-NEXT: spv.ExecutionMode @foo "ContractionOff"
|
||||
|
||||
// CHECK-NEXT: spv.func @foo_1
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.EntryPoint "GLCompute" @foo_1
|
||||
// CHECK-NEXT: spv.ExecutionMode @foo_1 "ContractionOff"
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
|
||||
spv.EntryPoint "GLCompute" @foo
|
||||
spv.ExecutionMode @foo "ContractionOff"
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : f32) -> f32 "None" {
|
||||
spv.ReturnValue %arg0 : f32
|
||||
}
|
||||
|
||||
spv.EntryPoint "GLCompute" @foo
|
||||
spv.ExecutionMode @foo "ContractionOff"
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting funcOp and globalVariableOp.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.globalVariable @foo_1
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting funcOp and globalVariableOp and update the global variable's
|
||||
// references.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.globalVariable @foo_1
|
||||
// CHECK-NEXT: spv.func @bar
|
||||
// CHECK-NEXT: spv._address_of @foo_1
|
||||
// CHECK-NEXT: spv.Load
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
|
||||
spv.func @bar() -> f32 "None" {
|
||||
%0 = spv._address_of @foo : !spv.ptr<f32, Input>
|
||||
%1 = spv.Load "Input" %0 : f32
|
||||
spv.ReturnValue %1 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting globalVariableOp and funcOp and update the global variable's
|
||||
// references.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.globalVariable @foo_1
|
||||
// CHECK-NEXT: spv.func @bar
|
||||
// CHECK-NEXT: spv._address_of @foo_1
|
||||
// CHECK-NEXT: spv.Load
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
|
||||
spv.func @bar() -> f32 "None" {
|
||||
%0 = spv._address_of @foo : !spv.ptr<f32, Input>
|
||||
%1 = spv.Load "Input" %0 : f32
|
||||
spv.ReturnValue %1 : f32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting funcOp and specConstantOp.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.specConstant @foo_1
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @foo = -5 : i32
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting funcOp and specConstantOp and update the spec constant's
|
||||
// references.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.specConstant @foo_1
|
||||
// CHECK-NEXT: spv.func @bar
|
||||
// CHECK-NEXT: spv._reference_of @foo_1
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @foo = -5 : i32
|
||||
|
||||
spv.func @bar() -> i32 "None" {
|
||||
%0 = spv._reference_of @foo : i32
|
||||
spv.ReturnValue %0 : i32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting specConstantOp and funcOp and update the spec constant's
|
||||
// references.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.specConstant @foo_1
|
||||
// CHECK-NEXT: spv.func @bar
|
||||
// CHECK-NEXT: spv._reference_of @foo_1
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @foo = -5 : i32
|
||||
|
||||
spv.func @bar() -> i32 "None" {
|
||||
%0 = spv._reference_of @foo : i32
|
||||
spv.ReturnValue %0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting funcOp and specConstantCompositeOp.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.specConstant @bar
|
||||
// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar)
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @bar = -5 : i32
|
||||
spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting funcOp and specConstantCompositeOp and update the spec
|
||||
// constant's references.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.specConstant @bar
|
||||
// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar)
|
||||
// CHECK-NEXT: spv.func @baz
|
||||
// CHECK-NEXT: spv._reference_of @foo_1
|
||||
// CHECK-NEXT: spv.CompositeExtract
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @bar = -5 : i32
|
||||
spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
|
||||
|
||||
spv.func @baz() -> i32 "None" {
|
||||
%0 = spv._reference_of @foo : !spv.array<2 x i32>
|
||||
%1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
|
||||
spv.ReturnValue %1 : i32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting specConstantCompositeOp and funcOp and update the spec
|
||||
// constant's references.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.specConstant @bar
|
||||
// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar)
|
||||
// CHECK-NEXT: spv.func @baz
|
||||
// CHECK-NEXT: spv._reference_of @foo_1
|
||||
// CHECK-NEXT: spv.CompositeExtract
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @bar = -5 : i32
|
||||
spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
|
||||
|
||||
spv.func @baz() -> i32 "None" {
|
||||
%0 = spv._reference_of @foo : !spv.array<2 x i32>
|
||||
%1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
|
||||
spv.ReturnValue %1 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting spec constants and funcOps and update the spec constant's
|
||||
// references.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.specConstant @bar_1
|
||||
// CHECK-NEXT: spv.specConstantComposite @foo_2 (@bar_1, @bar_1)
|
||||
// CHECK-NEXT: spv.func @baz
|
||||
// CHECK-NEXT: spv._reference_of @foo_2
|
||||
// CHECK-NEXT: spv.CompositeExtract
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @bar
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @bar = -5 : i32
|
||||
spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
|
||||
|
||||
spv.func @baz() -> i32 "None" {
|
||||
%0 = spv._reference_of @foo : !spv.array<2 x i32>
|
||||
%1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
|
||||
spv.ReturnValue %1 : i32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
|
||||
spv.func @bar(%arg0 : f32) -> f32 "None" {
|
||||
spv.ReturnValue %arg0 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting globalVariableOps.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.globalVariable @foo_1
|
||||
|
||||
// CHECK-NEXT: spv.globalVariable @foo
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting globalVariableOp and specConstantOp.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.globalVariable @foo_1
|
||||
|
||||
// CHECK-NEXT: spv.specConstant @foo
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @foo = -5 : i32
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting specConstantOp and globalVariableOp.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.specConstant @foo_1
|
||||
|
||||
// CHECK-NEXT: spv.globalVariable @foo
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @foo = -5 : i32
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting globalVariableOp and specConstantCompositeOp.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.globalVariable @foo_1
|
||||
|
||||
// CHECK-NEXT: spv.specConstant @bar
|
||||
// CHECK-NEXT: spv.specConstantComposite @foo (@bar, @bar)
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @bar = -5 : i32
|
||||
spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Resolve conflicting globalVariableOp and specConstantComposite.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.specConstant @bar
|
||||
// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar)
|
||||
|
||||
// CHECK-NEXT: spv.globalVariable @foo
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @bar = -5 : i32
|
||||
spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
}
|
||||
}
|
|
@ -2,6 +2,7 @@
|
|||
add_mlir_library(MLIRSPIRVTestPasses
|
||||
TestAvailability.cpp
|
||||
TestEntryPointAbi.cpp
|
||||
TestModuleCombiner.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
||||
|
@ -14,5 +15,6 @@ add_mlir_library(MLIRSPIRVTestPasses
|
|||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRSPIRV
|
||||
MLIRSPIRVModuleCombiner
|
||||
MLIRSupport
|
||||
)
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
//===- TestModuleCombiner.cpp - Pass to test SPIR-V module combiner lib ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
|
||||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
class TestModuleCombinerPass
|
||||
: public PassWrapper<TestModuleCombinerPass,
|
||||
OperationPass<mlir::ModuleOp>> {
|
||||
public:
|
||||
TestModuleCombinerPass() = default;
|
||||
TestModuleCombinerPass(const TestModuleCombinerPass &) {}
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
mlir::spirv::OwningSPIRVModuleRef combinedModule;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void TestModuleCombinerPass::runOnOperation() {
|
||||
auto modules = llvm::to_vector<4>(getOperation().getOps<spirv::ModuleOp>());
|
||||
|
||||
OpBuilder combinedModuleBuilder(modules[0]);
|
||||
combinedModule = spirv::combine(modules, combinedModuleBuilder, nullptr);
|
||||
|
||||
for (spirv::ModuleOp module : modules)
|
||||
module.erase();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
void registerTestSpirvModuleCombinerPass() {
|
||||
PassRegistration<TestModuleCombinerPass> registration(
|
||||
"test-spirv-module-combiner", "Tests SPIR-V module combiner library");
|
||||
}
|
||||
} // namespace mlir
|
|
@ -79,6 +79,7 @@ void registerTestPrintNestingPass();
|
|||
void registerTestRecursiveTypesPass();
|
||||
void registerTestReducer();
|
||||
void registerTestSpirvEntryPointABIPass();
|
||||
void registerTestSpirvModuleCombinerPass();
|
||||
void registerTestSCFUtilsPass();
|
||||
void registerTestTraitsPass();
|
||||
void registerTestVectorConversions();
|
||||
|
@ -140,6 +141,7 @@ void registerTestPasses() {
|
|||
registerTestReducer();
|
||||
registerTestGpuParallelLoopMappingPass();
|
||||
registerTestSpirvEntryPointABIPass();
|
||||
registerTestSpirvModuleCombinerPass();
|
||||
registerTestSCFUtilsPass();
|
||||
registerTestTraitsPass();
|
||||
registerTestVectorConversions();
|
||||
|
|
Loading…
Reference in New Issue