[MLIR][SPIRV] ModuleCombiner: deduplicate global vars, spec consts, and funcs.

This commit extends the functionality of the SPIR-V module combiner
library by adding new deduplication capabilities. In particular, it
introduces deduplication of global variables, specialization constants,
and functions.

For global variables, 2 variables are considered duplicate if they either
have the same descriptor set + binding or the same built_in attribute.

For specialization constants, 2 spec constants are considered duplicate if
they have the same spec_id attribute.

2 functions are deduplicated if they are identical. 2 functions are
identical if they have the same prototype, attributes, and body.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D90951
This commit is contained in:
ergawy 2020-11-19 10:05:55 -05:00 committed by Lei Zhang
parent 9bd50abc4c
commit 341f3c1120
4 changed files with 382 additions and 9 deletions

View File

@ -28,7 +28,7 @@ class ModuleOp;
/// combination process proceeds in 2 phases:
///
/// (1) resolve conflicts between pairs of ops from different modules
/// (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)
/// (2) deduplicate equivalent ops/sub-ops in the merged module.
///
/// For the conflict resolution phase, the following rules are employed to
/// resolve such conflicts:
@ -39,13 +39,22 @@ class ModuleOp;
/// other symbol.
/// - If none of the 2 conflicting ops are spv.func, then rename either.
///
/// In all cases, the references to the updated symbol are also updated to
/// reflect the change.
/// For deduplication, the following 3 cases are taken into consideration:
///
/// - If 2 spv.globalVariable's have either the same descriptor set + binding
/// or the same built_in attribute value, then replace one of them using the
/// other.
/// - If 2 spv.specConstant's have the same spec_id attribute value, then
/// replace one of them using the other.
/// - If 2 spv.func's are identical replace one of them using the other.
///
/// In all cases, the references to the updated symbol (whether renamed or
/// deduplicated) are also updated to reflect the change.
///
/// \param modules the list of modules to combine. Input modules are not
/// modified.
/// \param combinedModuleBuilder an OpBuilder to be used for
/// building up the combined module.
/// building up the combined module.
/// \param symbRenameListener a listener that gets called every time a symbol in
/// one of the input modules is renamed. The arguments
/// passed to the listener are: the input

View File

@ -12,10 +12,12 @@
#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
@ -59,6 +61,59 @@ static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
return success();
}
/// Registers `symbolOp` under `key` in `deduplicationMap`, unless an op is
/// already registered under that key.
///
/// \param key the deduplication key to look up.
/// \param symbolOp the op to register when `key` is not present yet.
/// \param deduplicationMap the map from keys to the ops registered so far.
///
/// \return a default-constructed SymbolOpTy when `symbolOp` was newly
/// registered; otherwise the op already stored under `key`, which the caller
/// can use as the replacement for `symbolOp`.
template <typename KeyTy, typename SymbolOpTy>
static SymbolOpTy
emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp,
                              DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) {
  // A single lookup both probes for an existing entry and inserts if absent.
  auto insertion = deduplicationMap.try_emplace(key, symbolOp);
  const bool newlyInserted = insertion.second;
  return newlyInserted ? SymbolOpTy() : insertion.first->second;
}
/// Hashes a symbol op from its operation name and every attribute it carries
/// other than the symbol name, so that two ops differing only in their symbol
/// names hash alike.
///
/// \return the hash code computed as described above.
///
/// Note: The operation's name (not the symbol name) seeds the hash. This
/// prevents, for example, mistakenly considering a global variable and a spec
/// constant as duplicates because their descriptor set + binding and spec_id,
/// respectively, happen to hash to the same value.
static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
  auto *operation = symbolOp.getOperation();
  llvm::hash_code result = llvm::hash_combine(operation->getName());

  for (auto namedAttr : operation->getAttrs()) {
    // The symbol name is exactly what differs between duplicates; leave it
    // out of the hash.
    const bool isSymbolName =
        namedAttr.first == SymbolTable::getSymbolAttrName();
    if (!isSymbolName)
      result = llvm::hash_combine(result, namedAttr);
  }

  return result;
}
/// Computes a hash code from the operations in the argument Block.
///
/// Every operation is hashed with its operands ignored, and the per-operation
/// hashes are folded together in block order.
///
/// \return the combined hash code, or 0 as soon as an operation with at least
/// one region is encountered, since such operations are not handled yet. A
/// block without any operation also yields 0, the initial value of the hash.
///
/// Note: This helper is file-local (static), like the other helpers in this
/// file, so that it does not export a generic `computeHash` symbol.
static llvm::hash_code computeHash(Block *block) {
  // TODO: Consider extracting BlockEquivalenceData into a common header and
  // re-using it here.
  llvm::hash_code hash(0);
  for (Operation &op : *block) {
    // TODO: Properly handle operations with regions.
    if (op.getNumRegions() > 0)
      return 0;
    hash = llvm::hash_combine(
        hash, OperationEquivalence::computeHash(
                  &op, OperationEquivalence::Flags::IgnoreOperands));
  }
  return hash;
}
namespace mlir {
namespace spirv {
@ -174,6 +229,48 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
combinedModuleBuilder.insert(op.clone());
}
// Deduplicate identical global variables, spec constants, and functions.
//
// Every symbol op is keyed by a hash of its operation name and attributes
// (minus the symbol name); for functions, the hash of each block is folded in
// as well. The first op seen for a given hash is kept; every later op with the
// same hash has its uses redirected to the kept op and is then erased.
//
// NOTE(review): equality is decided by hash value alone and block hashes
// ignore operands, so distinct ops could in principle be merged — confirm this
// is acceptable for the intended inputs.
DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
SmallVector<SymbolOpInterface, 0> eraseList;

for (auto &op : combinedModule.getBlock().without_terminator()) {
  SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);

  if (!symbolOp)
    continue;

  llvm::hash_code hashCode = computeHash(symbolOp);

  // A 0 block hash means the block is not suitable for deduplication and the
  // whole function should be skipped. An example of this is when a function
  // has ops with regions which are not properly supported yet. The check has
  // to happen on the block hash itself: once it is folded into `hashCode`,
  // the 0 marker is no longer observable.
  bool canDeduplicate = true;
  if (auto funcOp = dyn_cast<FuncOp>(op)) {
    for (auto &blk : funcOp) {
      llvm::hash_code blockHash = computeHash(&blk);
      if (!blockHash) {
        canDeduplicate = false;
        break;
      }
      hashCode = llvm::hash_combine(hashCode, blockHash);
    }
  }

  if (!canDeduplicate)
    continue;

  SymbolOpInterface replacementSymOp =
      emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp);

  // A null replacement means this op is the first one with this hash; keep it.
  if (!replacementSymOp)
    continue;

  if (failed(SymbolTable::replaceAllSymbolUses(
          symbolOp, replacementSymOp.getName(), combinedModule))) {
    symbolOp.emitError("unable to update all symbol uses for ")
        << symbolOp.getName() << " to " << replacementSymOp.getName();
    return nullptr;
  }

  // Erasing is deferred so that the block is not mutated while iterating it.
  eraseList.push_back(symbolOp);
}

for (auto symbolOp : eraseList)
  symbolOp.erase();
return combinedModule;
}

View File

@ -39,10 +39,12 @@ spv.module Logical GLSL450 {
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @foo_1
// CHECK-NEXT: spv.FAdd
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @foo_2
// CHECK-NEXT: spv.ISub
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: }
@ -57,13 +59,15 @@ spv.module Logical GLSL450 {
spv.module Logical GLSL450 {
spv.func @foo(%arg0 : f32) -> f32 "None" {
spv.ReturnValue %arg0 : f32
%0 = spv.FAdd %arg0, %arg0 : f32
spv.ReturnValue %0 : f32
}
}
spv.module Logical GLSL450 {
spv.func @foo(%arg0 : i32) -> i32 "None" {
spv.ReturnValue %arg0 : i32
%0 = spv.ISub %arg0, %arg0 : i32
spv.ReturnValue %0 : i32
}
}
}
@ -578,9 +582,9 @@ spv.module Logical GLSL450 {
// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.globalVariable @foo_1
// CHECK-NEXT: spv.globalVariable @foo_1 bind(1, 0)
// CHECK-NEXT: spv.globalVariable @foo
// CHECK-NEXT: spv.globalVariable @foo bind(2, 0)
// CHECK-NEXT: }
module {
@ -589,7 +593,26 @@ spv.module Logical GLSL450 {
}
spv.module Logical GLSL450 {
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
spv.globalVariable @foo bind(2, 0) : !spv.ptr<f32, Input>
}
}
// -----
// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.globalVariable @foo_1 built_in("GlobalInvocationId")
// CHECK-NEXT: spv.globalVariable @foo built_in("LocalInvocationId")
// CHECK-NEXT: }
module {
spv.module Logical GLSL450 {
spv.globalVariable @foo built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
}
spv.module Logical GLSL450 {
spv.globalVariable @foo built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
}
}

View File

@ -0,0 +1,244 @@
// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
// Deduplicate 2 global variables with the same descriptor set and binding.
// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.globalVariable @foo
// CHECK-NEXT: spv.func @use_foo
// CHECK-NEXT: spv.mlir.addressof @foo
// CHECK-NEXT: spv.Load
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @use_bar
// CHECK-NEXT: spv.mlir.addressof @foo
// CHECK-NEXT: spv.Load
// CHECK-NEXT: spv.FAdd
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
module {
spv.module Logical GLSL450 {
spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
spv.func @use_foo() -> f32 "None" {
%0 = spv.mlir.addressof @foo : !spv.ptr<f32, Input>
%1 = spv.Load "Input" %0 : f32
spv.ReturnValue %1 : f32
}
}
spv.module Logical GLSL450 {
spv.globalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
spv.func @use_bar() -> f32 "None" {
%0 = spv.mlir.addressof @bar : !spv.ptr<f32, Input>
%1 = spv.Load "Input" %0 : f32
%2 = spv.FAdd %1, %1 : f32
spv.ReturnValue %2 : f32
}
}
}
// -----
// Do not deduplicate 2 global variables with the same descriptor set and binding but different types.
// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.globalVariable @foo bind(1, 0)
// CHECK-NEXT: spv.globalVariable @bar bind(1, 0)
// CHECK-NEXT: spv.func @use_bar
// CHECK-NEXT: spv.mlir.addressof @bar
// CHECK-NEXT: spv.Load
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
module {
spv.module Logical GLSL450 {
spv.globalVariable @foo bind(1, 0) : !spv.ptr<i32, Input>
}
spv.module Logical GLSL450 {
spv.globalVariable @bar bind(1, 0) : !spv.ptr<f32, Input>
spv.func @use_bar() -> f32 "None" {
%0 = spv.mlir.addressof @bar : !spv.ptr<f32, Input>
%1 = spv.Load "Input" %0 : f32
spv.ReturnValue %1 : f32
}
}
}
// -----
// Deduplicate 2 global variables with the same built-in attribute.
// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.globalVariable @foo built_in("GlobalInvocationId")
// CHECK-NEXT: spv.func @use_bar
// CHECK-NEXT: spv.mlir.addressof @foo
// CHECK-NEXT: spv.Load
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
module {
spv.module Logical GLSL450 {
spv.globalVariable @foo built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
}
spv.module Logical GLSL450 {
spv.globalVariable @bar built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
spv.func @use_bar() -> vector<3xi32> "None" {
%0 = spv.mlir.addressof @bar : !spv.ptr<vector<3xi32>, Input>
%1 = spv.Load "Input" %0 : vector<3xi32>
spv.ReturnValue %1 : vector<3xi32>
}
}
}
// -----
// Deduplicate 2 spec constants with the same spec_id attribute.
// CHECK: module {
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.specConstant @foo spec_id(5)
// CHECK-NEXT: spv.func @use_foo()
// CHECK-NEXT: %0 = spv.mlir.referenceof @foo
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @use_bar()
// CHECK-NEXT: %0 = spv.mlir.referenceof @foo
// CHECK-NEXT: spv.FAdd
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
module {
spv.module Logical GLSL450 {
spv.specConstant @foo spec_id(5) = 1. : f32
spv.func @use_foo() -> (f32) "None" {
%0 = spv.mlir.referenceof @foo : f32
spv.ReturnValue %0 : f32
}
}
spv.module Logical GLSL450 {
spv.specConstant @bar spec_id(5) = 1. : f32
spv.func @use_bar() -> (f32) "None" {
%0 = spv.mlir.referenceof @bar : f32
%1 = spv.FAdd %0, %0 : f32
spv.ReturnValue %1 : f32
}
}
}
// -----
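// Deduplicate identical functions; keep functions that differ in body, type, function control, or attributes.
// CHECK: module {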
// CHECK-NEXT: spv.module Logical GLSL450 {
// CHECK-NEXT: spv.specConstant @bar spec_id(5)
// CHECK-NEXT: spv.func @foo(%arg0: f32)
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @foo_different_body(%arg0: f32)
// CHECK-NEXT: spv.mlir.referenceof
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @baz(%arg0: i32)
// CHECK-NEXT: spv.ReturnValue
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @baz_no_return(%arg0: i32)
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @baz_no_return_different_control
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @baz_no_return_another_control
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @kernel
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// CHECK-NEXT: spv.func @kernel_different_attr
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
module {
spv.module Logical GLSL450 {
spv.specConstant @bar spec_id(5) = 1. : f32
spv.func @foo(%arg0: f32) -> (f32) "None" {
spv.ReturnValue %arg0 : f32
}
spv.func @foo_duplicate(%arg0: f32) -> (f32) "None" {
spv.ReturnValue %arg0 : f32
}
spv.func @foo_different_body(%arg0: f32) -> (f32) "None" {
%0 = spv.mlir.referenceof @bar : f32
spv.ReturnValue %arg0 : f32
}
spv.func @baz(%arg0: i32) -> (i32) "None" {
spv.ReturnValue %arg0 : i32
}
spv.func @baz_no_return(%arg0: i32) "None" {
spv.Return
}
spv.func @baz_no_return_duplicate(%arg0: i32) -> () "None" {
spv.Return
}
spv.func @baz_no_return_different_control(%arg0: i32) -> () "Inline" {
spv.Return
}
spv.func @baz_no_return_another_control(%arg0: i32) -> () "Inline|Pure" {
spv.Return
}
spv.func @kernel(
%arg0: f32,
%arg1: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>) "None"
attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
spv.Return
}
spv.func @kernel_different_attr(
%arg0: f32,
%arg1: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>) "None"
attributes {spv.entry_point_abi = {local_size = dense<[64, 1, 1]> : vector<3xi32>}} {
spv.Return
}
}
}