diff --git a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h new file mode 100644 index 000000000000..b7ecd57d103d --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h @@ -0,0 +1,69 @@ +//===- ModuleCombiner.h - MLIR SPIR-V Module Combiner -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the entry point to the SPIR-V module combiner library. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPIRV_MODULECOMBINER_H_ +#define MLIR_DIALECT_SPIRV_MODULECOMBINER_H_ + +#include "mlir/Dialect/SPIRV/SPIRVModule.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +class OpBuilder; + +namespace spirv { +class ModuleOp; + +/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops +/// from all the input modules into one big combined module. To that end, the +/// combination process proceeds in 2 phases: +/// +/// (1) resolve conflicts between pairs of ops from different modules +/// (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO) +/// +/// For the conflict resolution phase, the following rules are employed to +/// resolve such conflicts: +/// +/// - If 2 spv.func's have the same symbol name, then rename one of the +/// functions. +/// - If an spv.func and another op have the same symbol name, then rename the +/// other symbol. +/// - If none of the 2 conflicting ops are spv.func, then rename either. +/// +/// In all cases, the references to the updated symbol are also updated to +/// reflect the change. +/// +/// \param modules the list of modules to combine. Input modules are not +/// modified. +/// \param combinedMdouleBuilder an OpBuilder to be used for +/// building up the combined module. +/// \param symbRenameListener a listener that gets called everytime a symbol in +/// one of the input modules is renamed. The arguments +/// passed to the listener are: the input +/// spirv::ModuleOp that contains the renamed symbol, +/// a StringRef to the old symbol name, and a +/// StringRef to the new symbol name. Note that it is +/// the responsibility of the caller to properly +/// retain the storage underlying the passed +/// StringRefs if the listener callback outlives this +/// function call. +/// +/// \return the combined module. +OwningSPIRVModuleRef +combine(llvm::MutableArrayRef modules, + OpBuilder &combinedModuleBuilder, + llvm::function_ref + symbRenameListener); +} // namespace spirv +} // namespace mlir + +#endif // MLIR_DIALECT_SPIRV_MODULECOMBINER_H_ diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt index 10f06fdb8861..f37182121fed 100644 --- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt @@ -34,5 +34,6 @@ add_mlir_dialect_library(MLIRSPIRV MLIRTransforms ) +add_subdirectory(Linking) add_subdirectory(Serialization) add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt new file mode 100644 index 000000000000..4cc016812701 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(ModuleCombiner) diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt new file mode 100644 index 000000000000..22756fab23e5 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect_library(MLIRSPIRVModuleCombiner + ModuleCombiner.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV + ) diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp new file mode 100644 index 000000000000..7687ab27e753 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp @@ -0,0 +1,181 @@ +//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the the SPIR-V module combiner library. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/ModuleCombiner.h" + +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/SymbolTable.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringExtras.h" + +using namespace mlir; + +static constexpr unsigned maxFreeID = 1 << 20; + +static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID, + spirv::ModuleOp combinedModule) { + SmallString<64> newSymName(oldSymName); + newSymName.push_back('_'); + + while (lastUsedID < maxFreeID) { + std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str(); + + if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) { + newSymName += llvm::utostr(lastUsedID); + break; + } + } + + return newSymName; +} + +/// Check if a symbol with the same name as op already exists in source. If so, +/// rename op and update all its references in target. +static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, + spirv::ModuleOp target, + spirv::ModuleOp source, + unsigned &lastUsedID) { + if (!SymbolTable::lookupSymbolIn(source, op.getName())) + return success(); + + StringRef oldSymName = op.getName(); + SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target); + + if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target))) + return op.emitError("unable to update all symbol uses for ") + << oldSymName << " to " << newSymName; + + SymbolTable::setSymbolName(op, newSymName); + return success(); +} + +namespace mlir { +namespace spirv { + +// TODO Properly test symbol rename listener mechanism. + +OwningSPIRVModuleRef +combine(llvm::MutableArrayRef modules, + OpBuilder &combinedModuleBuilder, + llvm::function_ref + symRenameListener) { + unsigned lastUsedID = 0; + + if (modules.empty()) + return nullptr; + + auto addressingModel = modules[0].addressing_model(); + auto memoryModel = modules[0].memory_model(); + + auto combinedModule = combinedModuleBuilder.create( + modules[0].getLoc(), addressingModel, memoryModel); + combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody()); + + // In some cases, a symbol in the (current state of the) combined module is + // renamed in order to maintain the conflicting symbol in the input module + // being merged. For example, if the conflict is between a global variable in + // the current combined module and a function in the input module, the global + // varaible is renamed. In order to notify listeners of the symbol updates in + // such cases, we need to keep track of the module from which the renamed + // symbol in the combined module originated. This map keeps such information. + DenseMap symNameToModuleMap; + + for (auto module : modules) { + if (module.addressing_model() != addressingModel || + module.memory_model() != memoryModel) { + module.emitError( + "input modules differ in addressing model and/or memory model"); + return nullptr; + } + + spirv::ModuleOp moduleClone = module.clone(); + + // In the combined module, rename all symbols that conflict with symbols + // from the current input module. This renmaing applies to all ops except + // for spv.funcs. This way, if the conflicting op in the input module is + // non-spv.func, we rename that symbol instead and maintain the spv.func in + // the combined module name as it is. + for (auto &op : combinedModule.getBlock().without_terminator()) { + if (auto symbolOp = dyn_cast(op)) { + StringRef oldSymName = symbolOp.getName(); + + if (!isa(op) && + failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone, + lastUsedID))) + return nullptr; + + StringRef newSymName = symbolOp.getName(); + + if (symRenameListener && oldSymName != newSymName) { + spirv::ModuleOp originalModule = + symNameToModuleMap.lookup(oldSymName); + + if (!originalModule) { + module.emitError("unable to find original ModuleOp for symbol ") + << oldSymName; + return nullptr; + } + + symRenameListener(originalModule, oldSymName, newSymName); + + // Since the symbol name is updated, there is no need to maintain the + // entry that assocaites the old symbol name with the original module. + symNameToModuleMap.erase(oldSymName); + // Instead, add a new entry to map the new symbol name to the original + // module in case it gets renamed again later. + symNameToModuleMap[newSymName] = originalModule; + } + } + } + + // In the current input module, rename all symbols that conflict with + // symbols from the combined module. This includes renaming spv.funcs. + for (auto &op : moduleClone.getBlock().without_terminator()) { + if (auto symbolOp = dyn_cast(op)) { + StringRef oldSymName = symbolOp.getName(); + + if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule, + lastUsedID))) + return nullptr; + + StringRef newSymName = symbolOp.getName(); + + if (symRenameListener && oldSymName != newSymName) { + symRenameListener(module, oldSymName, newSymName); + + // Insert the module associated with the symbol name. + auto emplaceResult = + symNameToModuleMap.try_emplace(symbolOp.getName(), module); + + // If an entry with the same symbol name is already present, this must + // be a problem with the implementation, specially clean-up of the map + // while iterating over the combined module above. + if (!emplaceResult.second) { + module.emitError("did not expect to find an entry for symbol ") + << symbolOp.getName(); + return nullptr; + } + } + } + } + + // Clone all the module's ops to the combined module. + for (auto &op : moduleClone.getBlock().without_terminator()) + combinedModuleBuilder.insert(op.clone()); + } + + return combinedModule; +} + +} // namespace spirv +} // namespace mlir diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir new file mode 100644 index 000000000000..07fd41e4fe86 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir @@ -0,0 +1,50 @@ +// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @m1_sc +// CHECK-NEXT: spv.specConstant @m2_sc +// CHECK-NEXT: spv.func @variable_init_spec_constant +// CHECK-NEXT: spv._reference_of @m2_sc +// CHECK-NEXT: spv.Variable init +// CHECK-NEXT: spv.Return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.specConstant @m1_sc = 42.42 : f32 +} + +spv.module Logical GLSL450 { + spv.specConstant @m2_sc = 42 : i32 + spv.func @variable_init_spec_constant() -> () "None" { + %0 = spv._reference_of @m2_sc : i32 + %1 = spv.Variable init(%0) : !spv.ptr + spv.Return + } +} +} + +// ----- + +module { +spv.module Physical64 GLSL450 { +} + +// expected-error @+1 {{input modules differ in addressing model and/or memory model}} +spv.module Logical GLSL450 { +} +} + +// ----- + +module { +spv.module Logical Simple { +} + +// expected-error @+1 {{input modules differ in addressing model and/or memory model}} +spv.module Logical GLSL450 { +} +} diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir new file mode 100644 index 000000000000..f5535c483171 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir @@ -0,0 +1,682 @@ +// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s + +// Test basic renaming of conflicting funcOps. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } +} +} + +// ----- + +// Test basic renaming of conflicting funcOps across 3 modules. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_2 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} +} + +// ----- + +// Test properly updating references to a renamed funcOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv.FunctionCall @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } + + spv.func @bar(%arg0 : f32) -> f32 "None" { + %0 = spv.FunctionCall @foo(%arg0) : (f32) -> (f32) + spv.ReturnValue %0 : f32 + } +} +} + +// ----- + +// Test properly updating references to a renamed funcOp if the functionCallOp +// preceeds the callee funcOp definition. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv.FunctionCall @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @bar(%arg0 : f32) -> f32 "None" { + %0 = spv.FunctionCall @foo(%arg0) : (f32) -> (f32) + spv.ReturnValue %0 : f32 + } + + spv.func @foo(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } +} +} + +// ----- + +// Test properly updating entryPointOp and executionModeOp attached to renamed +// funcOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.EntryPoint "GLCompute" @foo_1 +// CHECK-NEXT: spv.ExecutionMode @foo_1 "ContractionOff" +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } + + spv.EntryPoint "GLCompute" @foo + spv.ExecutionMode @foo "ContractionOff" +} +} + +// ----- + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.EntryPoint "GLCompute" @fo +// CHECK-NEXT: spv.ExecutionMode @foo "ContractionOff" + +// CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.EntryPoint "GLCompute" @foo_1 +// CHECK-NEXT: spv.ExecutionMode @foo_1 "ContractionOff" +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } + + spv.EntryPoint "GLCompute" @foo + spv.ExecutionMode @foo "ContractionOff" +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } + + spv.EntryPoint "GLCompute" @foo + spv.ExecutionMode @foo "ContractionOff" +} +} + +// ----- + +// Resolve conflicting funcOp and globalVariableOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.globalVariable @foo_1 +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} +} + +// ----- + +// Resolve conflicting funcOp and globalVariableOp and update the global variable's +// references. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.globalVariable @foo_1 +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv._address_of @foo_1 +// CHECK-NEXT: spv.Load +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr + + spv.func @bar() -> f32 "None" { + %0 = spv._address_of @foo : !spv.ptr + %1 = spv.Load "Input" %0 : f32 + spv.ReturnValue %1 : f32 + } +} +} + +// ----- + +// Resolve conflicting globalVariableOp and funcOp and update the global variable's +// references. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.globalVariable @foo_1 +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv._address_of @foo_1 +// CHECK-NEXT: spv.Load +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr + + spv.func @bar() -> f32 "None" { + %0 = spv._address_of @foo : !spv.ptr + %1 = spv.Load "Input" %0 : f32 + spv.ReturnValue %1 : f32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} +} + +// ----- + +// Resolve conflicting funcOp and specConstantOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.specConstant @foo_1 +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.specConstant @foo = -5 : i32 +} +} + +// ----- + +// Resolve conflicting funcOp and specConstantOp and update the spec constant's +// references. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.specConstant @foo_1 +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv._reference_of @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.specConstant @foo = -5 : i32 + + spv.func @bar() -> i32 "None" { + %0 = spv._reference_of @foo : i32 + spv.ReturnValue %0 : i32 + } +} +} + +// ----- + +// Resolve conflicting specConstantOp and funcOp and update the spec constant's +// references. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @foo_1 +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv._reference_of @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.specConstant @foo = -5 : i32 + + spv.func @bar() -> i32 "None" { + %0 = spv._reference_of @foo : i32 + spv.ReturnValue %0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} +} + +// ----- + +// Resolve conflicting funcOp and specConstantCompositeOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.specConstant @bar +// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar) +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.specConstant @bar = -5 : i32 + spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32> +} +} + +// ----- + +// Resolve conflicting funcOp and specConstantCompositeOp and update the spec +// constant's references. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.specConstant @bar +// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar) +// CHECK-NEXT: spv.func @baz +// CHECK-NEXT: spv._reference_of @foo_1 +// CHECK-NEXT: spv.CompositeExtract +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.specConstant @bar = -5 : i32 + spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32> + + spv.func @baz() -> i32 "None" { + %0 = spv._reference_of @foo : !spv.array<2 x i32> + %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32> + spv.ReturnValue %1 : i32 + } +} +} + +// ----- + +// Resolve conflicting specConstantCompositeOp and funcOp and update the spec +// constant's references. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @bar +// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar) +// CHECK-NEXT: spv.func @baz +// CHECK-NEXT: spv._reference_of @foo_1 +// CHECK-NEXT: spv.CompositeExtract +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.specConstant @bar = -5 : i32 + spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32> + + spv.func @baz() -> i32 "None" { + %0 = spv._reference_of @foo : !spv.array<2 x i32> + %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32> + spv.ReturnValue %1 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} +} + +// ----- + +// Resolve conflicting spec constants and funcOps and update the spec constant's +// references. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @bar_1 +// CHECK-NEXT: spv.specConstantComposite @foo_2 (@bar_1, @bar_1) +// CHECK-NEXT: spv.func @baz +// CHECK-NEXT: spv._reference_of @foo_2 +// CHECK-NEXT: spv.CompositeExtract +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.specConstant @bar = -5 : i32 + spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32> + + spv.func @baz() -> i32 "None" { + %0 = spv._reference_of @foo : !spv.array<2 x i32> + %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32> + spv.ReturnValue %1 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } + + spv.func @bar(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } +} +} + +// ----- + +// Resolve conflicting globalVariableOps. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.globalVariable @foo_1 + +// CHECK-NEXT: spv.globalVariable @foo +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} + +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} +} + +// ----- + +// Resolve conflicting globalVariableOp and specConstantOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.globalVariable @foo_1 + +// CHECK-NEXT: spv.specConstant @foo +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} + +spv.module Logical GLSL450 { + spv.specConstant @foo = -5 : i32 +} +} + +// ----- + +// Resolve conflicting specConstantOp and globalVariableOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @foo_1 + +// CHECK-NEXT: spv.globalVariable @foo +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.specConstant @foo = -5 : i32 +} + +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} +} + +// ----- + +// Resolve conflicting globalVariableOp and specConstantCompositeOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.globalVariable @foo_1 + +// CHECK-NEXT: spv.specConstant @bar +// CHECK-NEXT: spv.specConstantComposite @foo (@bar, @bar) +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} + +spv.module Logical GLSL450 { + spv.specConstant @bar = -5 : i32 + spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32> +} +} + +// ----- + +// Resolve conflicting globalVariableOp and specConstantComposite. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @bar +// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar) + +// CHECK-NEXT: spv.globalVariable @foo +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.specConstant @bar = -5 : i32 + spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32> +} + +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} +} diff --git a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt index 204a63337730..6c74d2f26357 100644 --- a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIRSPIRVTestPasses TestAvailability.cpp TestEntryPointAbi.cpp + TestModuleCombiner.cpp EXCLUDE_FROM_LIBMLIR @@ -14,5 +15,6 @@ add_mlir_library(MLIRSPIRVTestPasses MLIRIR MLIRPass MLIRSPIRV + MLIRSPIRVModuleCombiner MLIRSupport ) diff --git a/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp new file mode 100644 index 000000000000..b321954c87f3 --- /dev/null +++ b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp @@ -0,0 +1,48 @@ +//===- TestModuleCombiner.cpp - Pass to test SPIR-V module combiner lib ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/ModuleCombiner.h" + +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/SPIRVTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Module.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +class TestModuleCombinerPass + : public PassWrapper> { +public: + TestModuleCombinerPass() = default; + TestModuleCombinerPass(const TestModuleCombinerPass &) {} + void runOnOperation() override; + +private: + mlir::spirv::OwningSPIRVModuleRef combinedModule; +}; +} // namespace + +void TestModuleCombinerPass::runOnOperation() { + auto modules = llvm::to_vector<4>(getOperation().getOps()); + + OpBuilder combinedModuleBuilder(modules[0]); + combinedModule = spirv::combine(modules, combinedModuleBuilder, nullptr); + + for (spirv::ModuleOp module : modules) + module.erase(); +} + +namespace mlir { +void registerTestSpirvModuleCombinerPass() { + PassRegistration registration( + "test-spirv-module-combiner", "Tests SPIR-V module combiner library"); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 196bda69dbaf..b5506a5a34a0 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -79,6 +79,7 @@ void registerTestPrintNestingPass(); void registerTestRecursiveTypesPass(); void registerTestReducer(); void registerTestSpirvEntryPointABIPass(); +void registerTestSpirvModuleCombinerPass(); void registerTestSCFUtilsPass(); void registerTestTraitsPass(); void registerTestVectorConversions(); @@ -140,6 +141,7 @@ void registerTestPasses() { registerTestReducer(); registerTestGpuParallelLoopMappingPass(); registerTestSpirvEntryPointABIPass(); + registerTestSpirvModuleCombinerPass(); registerTestSCFUtilsPass(); registerTestTraitsPass(); registerTestVectorConversions();