forked from OSchip/llvm-project
[MLIR][SPIRV] ModuleCombiner: deduplicate global vars, spec consts, and funcs.
This commit extends the functionality of the SPIR-V module combiner library by adding new deduplication capabilities. In particular, implementation of deduplication of global variables and specialization constants, and functions is introduced. For global variables, 2 variables are considered duplicate if they either have the same descriptor set + binding or the same built_in attribute. For specialization constants, 2 spec constants are considered duplicate if they have the same spec_id attribute. 2 functions are deduplicated if they are identical. 2 functions are identical if they have the same prototype, attributes, and body. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D90951
This commit is contained in:
parent
9bd50abc4c
commit
341f3c1120
|
@ -28,7 +28,7 @@ class ModuleOp;
|
|||
/// combination process proceeds in 2 phases:
|
||||
///
|
||||
/// (1) resolve conflicts between pairs of ops from different modules
|
||||
/// (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)
|
||||
/// (2) deduplicate equivalent ops/sub-ops in the merged module.
|
||||
///
|
||||
/// For the conflict resolution phase, the following rules are employed to
|
||||
/// resolve such conflicts:
|
||||
|
@ -39,13 +39,22 @@ class ModuleOp;
|
|||
/// other symbol.
|
||||
/// - If none of the 2 conflicting ops are spv.func, then rename either.
|
||||
///
|
||||
/// In all cases, the references to the updated symbol are also updated to
|
||||
/// reflect the change.
|
||||
/// For deduplication, the following 3 cases are taken into consideration:
|
||||
///
|
||||
/// - If 2 spv.globalVariable's have either the same descriptor set + binding
|
||||
/// or the same build_in attribute value, then replace one of them using the
|
||||
/// other.
|
||||
/// - If 2 spv.specConstant's have the same spec_id attribute value, then
|
||||
/// replace one of them using the other.
|
||||
/// - If 2 spv.func's are identical replace one of them using the other.
|
||||
///
|
||||
/// In all cases, the references to the updated symbol (whether renamed or
|
||||
/// deduplicated) are also updated to reflect the change.
|
||||
///
|
||||
/// \param modules the list of modules to combine. Input modules are not
|
||||
/// modified.
|
||||
/// \param combinedMdouleBuilder an OpBuilder to be used for
|
||||
/// building up the combined module.
|
||||
// building up the combined module.
|
||||
/// \param symbRenameListener a listener that gets called everytime a symbol in
|
||||
/// one of the input modules is renamed. The arguments
|
||||
/// passed to the listener are: the input
|
||||
|
|
|
@ -12,10 +12,12 @@
|
|||
|
||||
#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
|
||||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -59,6 +61,59 @@ static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
|
|||
return success();
|
||||
}
|
||||
|
||||
template <typename KeyTy, typename SymbolOpTy>
|
||||
static SymbolOpTy
|
||||
emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp,
|
||||
DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) {
|
||||
auto result = deduplicationMap.try_emplace(key, symbolOp);
|
||||
|
||||
if (result.second)
|
||||
return SymbolOpTy();
|
||||
|
||||
return result.first->second;
|
||||
}
|
||||
|
||||
/// Computes a hash code to represent the argument SymbolOpInterface based on
|
||||
/// all the Op's attributes except for the symbol name.
|
||||
///
|
||||
/// \return the hash code computed from the Op's attributes as described above.
|
||||
///
|
||||
/// Note: We use the operation's name (not the symbol name) as part of the hash
|
||||
/// computation. This prevents, for example, mistakenly considering a global
|
||||
/// variable and a spec constant as duplicates because their descriptor set +
|
||||
/// binding and spec_id, repectively, happen to hash to the same value.
|
||||
static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
|
||||
llvm::hash_code hashCode(0);
|
||||
hashCode = llvm::hash_combine(symbolOp.getOperation()->getName());
|
||||
|
||||
for (auto attr : symbolOp.getOperation()->getAttrs()) {
|
||||
if (attr.first == SymbolTable::getSymbolAttrName())
|
||||
continue;
|
||||
hashCode = llvm::hash_combine(hashCode, attr);
|
||||
}
|
||||
|
||||
return hashCode;
|
||||
}
|
||||
|
||||
/// Computes a hash code from the argument Block.
|
||||
llvm::hash_code computeHash(Block *block) {
|
||||
// TODO: Consider extracting BlockEquivalenceData into a common header and
|
||||
// re-using it here.
|
||||
llvm::hash_code hash(0);
|
||||
|
||||
for (Operation &op : *block) {
|
||||
// TODO: Properly handle operations with regions.
|
||||
if (op.getNumRegions() > 0)
|
||||
return 0;
|
||||
|
||||
hash = llvm::hash_combine(
|
||||
hash, OperationEquivalence::computeHash(
|
||||
&op, OperationEquivalence::Flags::IgnoreOperands));
|
||||
}
|
||||
|
||||
return hash;
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace spirv {
|
||||
|
||||
|
@ -174,6 +229,48 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
|
|||
combinedModuleBuilder.insert(op.clone());
|
||||
}
|
||||
|
||||
// Deduplicate identical global variables, spec constants, and functions.
|
||||
DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
|
||||
SmallVector<SymbolOpInterface, 0> eraseList;
|
||||
|
||||
for (auto &op : combinedModule.getBlock().without_terminator()) {
|
||||
llvm::hash_code hashCode(0);
|
||||
SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
|
||||
|
||||
if (!symbolOp)
|
||||
continue;
|
||||
|
||||
hashCode = computeHash(symbolOp);
|
||||
|
||||
// A 0 hash code means the op is not suitable for deduplication and should
|
||||
// be skipped. An example of this is when a function has ops with regions
|
||||
// which are not properly supported yet.
|
||||
if (!hashCode)
|
||||
continue;
|
||||
|
||||
if (auto funcOp = dyn_cast<FuncOp>(op))
|
||||
for (auto &blk : funcOp)
|
||||
hashCode = llvm::hash_combine(hashCode, computeHash(&blk));
|
||||
|
||||
SymbolOpInterface replacementSymOp =
|
||||
emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp);
|
||||
|
||||
if (!replacementSymOp)
|
||||
continue;
|
||||
|
||||
if (failed(SymbolTable::replaceAllSymbolUses(
|
||||
symbolOp, replacementSymOp.getName(), combinedModule))) {
|
||||
symbolOp.emitError("unable to update all symbol uses for ")
|
||||
<< symbolOp.getName() << " to " << replacementSymOp.getName();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
eraseList.push_back(symbolOp);
|
||||
}
|
||||
|
||||
for (auto symbolOp : eraseList)
|
||||
symbolOp.erase();
|
||||
|
||||
return combinedModule;
|
||||
}
|
||||
|
||||
|
|
|
@ -39,10 +39,12 @@ spv.module Logical GLSL450 {
|
|||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo_1
|
||||
// CHECK-NEXT: spv.FAdd
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo_2
|
||||
// CHECK-NEXT: spv.ISub
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
@ -57,13 +59,15 @@ spv.module Logical GLSL450 {
|
|||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : f32) -> f32 "None" {
|
||||
spv.ReturnValue %arg0 : f32
|
||||
%0 = spv.FAdd %arg0, %arg0 : f32
|
||||
spv.ReturnValue %0 : f32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.func @foo(%arg0 : i32) -> i32 "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
%0 = spv.ISub %arg0, %arg0 : i32
|
||||
spv.ReturnValue %0 : i32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -578,9 +582,9 @@ spv.module Logical GLSL450 {
|
|||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.globalVariable @foo_1
|
||||
// CHECK-NEXT: spv.globalVariable @foo_1 bind(1, 0)
|
||||
|
||||
// CHECK-NEXT: spv.globalVariable @foo
|
||||
// CHECK-NEXT: spv.globalVariable @foo bind(2, 0)
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
|
@ -589,7 +593,26 @@ spv.module Logical GLSL450 {
|
|||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
spv.globalVariable @foo bind(2, 0) : !spv.ptr<f32, Input>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.globalVariable @foo_1 built_in("GlobalInvocationId")
|
||||
|
||||
// CHECK-NEXT: spv.globalVariable @foo built_in("LocalInvocationId")
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,244 @@
|
|||
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
// Deduplicate 2 global variables with the same descriptor set and binding.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.globalVariable @foo
|
||||
|
||||
// CHECK-NEXT: spv.func @use_foo
|
||||
// CHECK-NEXT: spv.mlir.addressof @foo
|
||||
// CHECK-NEXT: spv.Load
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @use_bar
|
||||
// CHECK-NEXT: spv.mlir.addressof @foo
|
||||
// CHECK-NEXT: spv.Load
|
||||
// CHECK-NEXT: spv.FAdd
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
|
||||
|
||||
spv.func @use_foo() -> f32 "None" {
|
||||
%0 = spv.mlir.addressof @foo : !spv.ptr<f32, Input>
|
||||
%1 = spv.Load "Input" %0 : f32
|
||||
spv.ReturnValue %1 : f32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
|
||||
|
||||
spv.func @use_bar() -> f32 "None" {
|
||||
%0 = spv.mlir.addressof @bar : !spv.ptr<f32, Input>
|
||||
%1 = spv.Load "Input" %0 : f32
|
||||
%2 = spv.FAdd %1, %1 : f32
|
||||
spv.ReturnValue %2 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Deduplicate 2 global variables with the same descriptor set and binding but different types.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.globalVariable @foo bind(1, 0)
|
||||
|
||||
// CHECK-NEXT: spv.globalVariable @bar bind(1, 0)
|
||||
|
||||
// CHECK-NEXT: spv.func @use_bar
|
||||
// CHECK-NEXT: spv.mlir.addressof @bar
|
||||
// CHECK-NEXT: spv.Load
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo bind(1, 0) : !spv.ptr<i32, Input>
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
|
||||
|
||||
spv.func @use_bar() -> f32 "None" {
|
||||
%0 = spv.mlir.addressof @bar : !spv.ptr<f32, Input>
|
||||
%1 = spv.Load "Input" %0 : f32
|
||||
spv.ReturnValue %1 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Deduplicate 2 global variables with the same built-in attribute.
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.globalVariable @foo built_in("GlobalInvocationId")
|
||||
// CHECK-NEXT: spv.func @use_bar
|
||||
// CHECK-NEXT: spv.mlir.addressof @foo
|
||||
// CHECK-NEXT: spv.Load
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @foo built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.globalVariable @bar built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
|
||||
|
||||
spv.func @use_bar() -> vector<3xi32> "None" {
|
||||
%0 = spv.mlir.addressof @bar : !spv.ptr<vector<3xi32>, Input>
|
||||
%1 = spv.Load "Input" %0 : vector<3xi32>
|
||||
spv.ReturnValue %1 : vector<3xi32>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.specConstant @foo spec_id(5)
|
||||
|
||||
// CHECK-NEXT: spv.func @use_foo()
|
||||
// CHECK-NEXT: %0 = spv.mlir.referenceof @foo
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @use_bar()
|
||||
// CHECK-NEXT: %0 = spv.mlir.referenceof @foo
|
||||
// CHECK-NEXT: spv.FAdd
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @foo spec_id(5) = 1. : f32
|
||||
|
||||
spv.func @use_foo() -> (f32) "None" {
|
||||
%0 = spv.mlir.referenceof @foo : f32
|
||||
spv.ReturnValue %0 : f32
|
||||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @bar spec_id(5) = 1. : f32
|
||||
|
||||
spv.func @use_bar() -> (f32) "None" {
|
||||
%0 = spv.mlir.referenceof @bar : f32
|
||||
%1 = spv.FAdd %0, %0 : f32
|
||||
spv.ReturnValue %1 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: spv.module Logical GLSL450 {
|
||||
// CHECK-NEXT: spv.specConstant @bar spec_id(5)
|
||||
|
||||
// CHECK-NEXT: spv.func @foo(%arg0: f32)
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @foo_different_body(%arg0: f32)
|
||||
// CHECK-NEXT: spv.mlir.referenceof
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @baz(%arg0: i32)
|
||||
// CHECK-NEXT: spv.ReturnValue
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @baz_no_return(%arg0: i32)
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @baz_no_return_different_control
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @baz_no_return_another_control
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @kernel
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// CHECK-NEXT: spv.func @kernel_different_attr
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
module {
|
||||
spv.module Logical GLSL450 {
|
||||
spv.specConstant @bar spec_id(5) = 1. : f32
|
||||
|
||||
spv.func @foo(%arg0: f32) -> (f32) "None" {
|
||||
spv.ReturnValue %arg0 : f32
|
||||
}
|
||||
|
||||
spv.func @foo_duplicate(%arg0: f32) -> (f32) "None" {
|
||||
spv.ReturnValue %arg0 : f32
|
||||
}
|
||||
|
||||
spv.func @foo_different_body(%arg0: f32) -> (f32) "None" {
|
||||
%0 = spv.mlir.referenceof @bar : f32
|
||||
spv.ReturnValue %arg0 : f32
|
||||
}
|
||||
|
||||
spv.func @baz(%arg0: i32) -> (i32) "None" {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
|
||||
spv.func @baz_no_return(%arg0: i32) "None" {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.func @baz_no_return_duplicate(%arg0: i32) -> () "None" {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.func @baz_no_return_different_control(%arg0: i32) -> () "Inline" {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.func @baz_no_return_another_control(%arg0: i32) -> () "Inline|Pure" {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.func @kernel(
|
||||
%arg0: f32,
|
||||
%arg1: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>) "None"
|
||||
attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
spv.func @kernel_different_attr(
|
||||
%arg0: f32,
|
||||
%arg1: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>) "None"
|
||||
attributes {spv.entry_point_abi = {local_size = dense<[64, 1, 1]> : vector<3xi32>}} {
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue