[mlir][SPIRVToLLVM] Propagate location attribute from spv.GlobalVariable to llvm.mlir.global

This patch is mainly to propogate location attribute from spv.GlobalVariable to llvm.mlir.global.

It also contains three small changes.

1. Remove the restriction on UniformConstant In SPIRVToLLVM.cpp;
2. Remove the errorCheck on relaxedPrecision when deserializering SPIR-V in Deserializer.cpp
3. In SPIRVOps.cpp, let ConstantOp take signedInteger too.

Co-authered: Alan Liu <alanliu.yf@gmail.com> and Xinyi Liu <xyliuhelen@gmail.com>

Reviewed by:antiagainst

Differential revision: https://reviews.llvm.org/D110207
This commit is contained in:
Weiwei Li 2021-10-05 00:04:33 +08:00
parent 3fe771bf02
commit 1e4cfe5e4f
7 changed files with 79 additions and 15 deletions

View File

@ -379,12 +379,25 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
let arguments = (ins
TypeAttr:$type,
StrAttr:$sym_name,
OptionalAttr<FlatSymbolRefAttr>:$initializer
OptionalAttr<FlatSymbolRefAttr>:$initializer,
OptionalAttr<I32Attr>:$location,
OptionalAttr<I32Attr>:$binding,
OptionalAttr<I32Attr>:$descriptorSet,
OptionalAttr<StrAttr>:$builtin
);
let results = (outs);
let builders = [
OpBuilder<(ins "TypeAttr":$type,
"StringAttr":$sym_name,
CArg<"FlatSymbolRefAttr", "nullptr">:$initializer),
[{
$_state.addAttribute("type", type);
$_state.addAttribute(sym_nameAttrName($_state.name), sym_name);
if (initializer)
$_state.addAttribute(initializerAttrName($_state.name), initializer);
}]>,
OpBuilder<(ins "TypeAttr":$type, "ArrayRef<NamedAttribute>":$namedAttrs),
[{
$_state.addAttribute("type", type);
@ -393,7 +406,16 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
OpBuilder<(ins "Type":$type, "StringRef":$name,
"unsigned":$descriptorSet, "unsigned":$binding)>,
OpBuilder<(ins "Type":$type, "StringRef":$name,
"spirv::BuiltIn":$builtin)>
"spirv::BuiltIn":$builtin)>,
OpBuilder<(ins "Type":$type,
"StringRef":$sym_name,
CArg<"FlatSymbolRefAttr", "{}">:$initializer),
[{
$_state.addAttribute("type", TypeAttr::get(type));
$_state.addAttribute(sym_nameAttrName($_state.name), $_builder.getStringAttr(sym_name));
if (initializer)
$_state.addAttribute(initializerAttrName($_state.name), initializer);
}]>
];
let hasOpcode = 0;

View File

@ -733,17 +733,22 @@ public:
// required by SPIR-V runner.
// This is okay because multiple invocations are not supported yet.
auto storageClass = srcType.getStorageClass();
if (storageClass != spirv::StorageClass::Input &&
storageClass != spirv::StorageClass::Private &&
storageClass != spirv::StorageClass::Output &&
storageClass != spirv::StorageClass::StorageBuffer) {
switch (storageClass) {
case spirv::StorageClass::Input:
case spirv::StorageClass::Private:
case spirv::StorageClass::Output:
case spirv::StorageClass::StorageBuffer:
case spirv::StorageClass::UniformConstant:
break;
default:
return failure();
}
// LLVM dialect spec: "If the global value is a constant, storing into it is
// not allowed.". This corresponds to SPIR-V 'Input' storage class that is
// read-only.
bool isConstant = storageClass == spirv::StorageClass::Input;
// not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
// storage class that is read-only.
bool isConstant = (storageClass == spirv::StorageClass::Input) ||
(storageClass == spirv::StorageClass::UniformConstant);
// SPIR-V spec: "By default, functions and global variables are private to a
// module and cannot be accessed by other modules. However, a module may be
// written to export or import functions and global (module scope)
@ -752,9 +757,14 @@ public:
auto linkage = storageClass == spirv::StorageClass::Private
? LLVM::Linkage::Private
: LLVM::Linkage::External;
rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
op, dstType, isConstant, linkage, op.sym_name(), Attribute(),
/*alignment=*/0);
// Attach location attribute if applicable
if (op.locationAttr())
newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr());
return success();
}
};

View File

@ -92,7 +92,12 @@ static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
if (!integerValueAttr) {
return failure();
}
value = integerValueAttr.getInt();
if (integerValueAttr.getType().isSignlessInteger())
value = integerValueAttr.getInt();
else
value = integerValueAttr.getSInt();
return success();
}
@ -2066,8 +2071,7 @@ Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
Type type, StringRef name,
unsigned descriptorSet, unsigned binding) {
build(builder, state, TypeAttr::get(type), builder.getStringAttr(name),
nullptr);
build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
state.addAttribute(
spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
builder.getI32IntegerAttr(descriptorSet));
@ -2079,8 +2083,7 @@ void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
Type type, StringRef name,
spirv::BuiltIn builtin) {
build(builder, state, TypeAttr::get(type), builder.getStringAttr(name),
nullptr);
build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
state.addAttribute(
spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));

View File

@ -262,6 +262,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::NonWritable:
case spirv::Decoration::NoPerspective:
case spirv::Decoration::Restrict:
case spirv::Decoration::RelaxedPrecision:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";

View File

@ -241,6 +241,7 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
case spirv::Decoration::NonWritable:
case spirv::Decoration::NoPerspective:
case spirv::Decoration::Restrict:
case spirv::Decoration::RelaxedPrecision:
// For unit attributes, the args list has no values so we do nothing
if (auto unitAttr = attr.second.dyn_cast<UnitAttr>())
break;

View File

@ -67,6 +67,26 @@ spv.module @name Logical GLSL450 {
}
}
spv.module Logical GLSL450 {
// CHECK: llvm.mlir.global external @bar() {location = 1 : i32} : i32
// CHECK-LABEL: @foo
spv.GlobalVariable @bar {location = 1 : i32} : !spv.ptr<i32, Output>
spv.func @foo() "None" {
%0 = spv.mlir.addressof @bar : !spv.ptr<i32, Output>
spv.Return
}
}
spv.module Logical GLSL450 {
// CHECK: llvm.mlir.global external constant @bar() {location = 3 : i32} : f32
// CHECK-LABEL: @foo
spv.GlobalVariable @bar {descriptor_set = 0 : i32, location = 3 : i32} : !spv.ptr<f32, UniformConstant>
spv.func @foo() "None" {
%0 = spv.mlir.addressof @bar : !spv.ptr<f32, UniformConstant>
spv.Return
}
}
//===----------------------------------------------------------------------===//
// spv.Load
//===----------------------------------------------------------------------===//

View File

@ -49,3 +49,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.GlobalVariable @var bind(0, 0) {restrict} : !spv.ptr<!spv.struct<(!spv.array<4xf32, stride=4>[0])>, StorageBuffer>
}
// -----
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK: relaxed_precision
spv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spv.ptr<vector<4xf32>, Output>
}