Add spirv::GlobalVariableOp that allows module level definition of variables

FuncOps in MLIR use explicit capture. So global variables defined in
module scope need to have a symbol name and this should be used to
refer to the variable within the function. This deviates from SPIR-V
spec, which assigns an SSA value to variables at all scopes that can
be used to refer to the variable, which requires SPIR-V functions to
allow implicit capture. To handle this add a new op,
spirv::GlobalVariableOp that can be used to define module scope
variables.
Since instructions need an SSA value, an new spirv::AddressOfOp is
added to convert a symbol reference to an SSA value for use with other
instructions.
This also means the spirv::EntryPointOp instruction needs to change to
allow initializers to be specified using symbol reference instead of
SSA value
The current spirv::VariableOp which returns an SSA value (as defined
by SPIR-V spec) can still be used to define function-scope variables.
PiperOrigin-RevId: 263951109
This commit is contained in:
Mahesh Ravishankar 2019-08-17 10:19:48 -07:00 committed by A. Unique TensorFlower
parent c268666f15
commit d745101339
12 changed files with 801 additions and 238 deletions

View File

@ -146,67 +146,6 @@ def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
// -----
def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
let summary = [{
Declare an entry point, its execution model, and its interface.
}];
let description = [{
Execution Model is the execution model for the entry point and its
static call tree. See Execution Model.
Entry Point must be the Result <id> of an OpFunction instruction.
Name is a name string for the entry point. A module cannot have two
OpEntryPoint instructions with the same Execution Model and the same
Name string.
Interface is a list of <id> of global OpVariable instructions. These
declare the set of global variables from a module that form the
interface of this entry point. The set of Interface <id> must be equal
to or a superset of the global OpVariable Result <id> referenced by the
entry points static call tree, within the interfaces storage classes.
Before version 1.4, the interfaces storage classes are limited to the
Input and Output storage classes. Starting with version 1.4, the
interfaces storage classes are all storage classes used in declaring
all global variables referenced by the entry points call tree.
Interface <id> are forward references. Before version 1.4, duplication
of these <id> is tolerated. Starting with version 1.4, an <id> must not
appear more than once.
### Custom assembly form
``` {.ebnf}
execution-model ::= "Vertex" | "TesellationControl" |
<and other SPIR-V execution models...>
entry-point-op ::= ssa-id ` = spv.EntryPoint ` execution-model fn-name
(ssa-use ( `, ` ssa-use)* ` : `
pointer-type ( `, ` pointer-type)* )?
```
For example:
```
spv.EntryPoint "GLCompute" @foo
spv.EntryPoint "Kernel" @foo, %1, %2 : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
```
}];
let arguments = (ins
SPV_ExecutionModelAttr:$execution_model,
SymbolRefAttr:$fn,
Variadic<SPV_AnyPtr>:$interface
);
let results = (outs);
let autogenSerialization = 0;
}
// -----
def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> {
let summary = "Declare an execution mode for an entry point.";

View File

@ -30,6 +30,160 @@
include "mlir/SPIRV/SPIRVBase.td"
#endif // SPIRV_BASE
def SPV_AddressOfOp : SPV_Op<"_address_of", [NoSideEffect]> {
let summary = "Get the address of a global variable.";
let description = [{
Variables in module scope are defined using symbol names. This
instruction generates an SSA value that can be used to refer to
the symbol within function scope for use in instructions that
expect an SSA value. This operation has no equivalent SPIR-V
instruction. Since variables in module scope in SPIR-V dialect are
of pointer type, this instruction returns a pointer type as well,
and the type is same as the variable referenced.
### Custom assembly form
``` {.ebnf}
address-of-op ::= ssa-id `=` `spv.addressOf` `@`string-literal : pointer-type
```
For example:
```
%0 = spv.addressOf @var1 : !spv.ptr<f32, Input>
```
}];
let arguments = (ins
SymbolRefAttr:$variable
);
let results = (outs
SPV_AnyPtr:$pointer
);
let hasOpcode = 0;
}
def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
let summary = [{
Declare an entry point, its execution model, and its interface.
}];
let description = [{
Execution Model is the execution model for the entry point and its
static call tree. See Execution Model.
Entry Point must be the Result <id> of an OpFunction instruction.
Name is a name string for the entry point. A module cannot have two
OpEntryPoint instructions with the same Execution Model and the same
Name string.
Interface is a list of symbol references to spv.globalVariable
operations. These declare the set of global variables from a
module that form the interface of this entry point. The set of
Interface symbols must be equal to or a superset of the
spv.globalVariables referenced by the entry points static call
tree, within the interfaces storage classes. Before version 1.4,
the interfaces storage classes are limited to the Input and
Output storage classes. Starting with version 1.4, the interfaces
storage classes are all storage classes used in declaring all
global variables referenced by the entry points call tree.
### Custom assembly form
``` {.ebnf}
execution-model ::= "Vertex" | "TesellationControl" |
<and other SPIR-V execution models...>
entry-point-op ::= ssa-id `=` `spv.EntryPoint` execution-model
symbol-reference (`, ` symbol-reference)*
```
For example:
```
spv.EntryPoint "GLCompute" @foo
spv.EntryPoint "Kernel" @foo, @var1, @var2
```
}];
let arguments = (ins
SPV_ExecutionModelAttr:$execution_model,
SymbolRefAttr:$fn,
OptionalAttr<SymbolRefArrayAttr>:$interface
);
let results = (outs);
let autogenSerialization = 0;
}
def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [ModuleOnly]> {
let summary = [{
Allocate an object in memory at module scope. The object is
referenced using a symbol name.
}];
let description = [{
The variable type must be an OpTypePointer. Its type operand is the type of
object in memory.
Storage Class is the Storage Class of the memory holding the object. It
cannot be Generic. It must be the same as the Storage Class operand of
the variable types. Only those storage classes that are valid at module
scope (like Input, Output, StorageBuffer, etc.) are valid.
Initializer is optional. If Initializer is present, it will be
the initial value of the variables memory content. Initializer
must be an symbol defined from a constant instruction or other
spv.globalVariable operation in module scope. Initializer must
have the same type as the type of the defined symbol.
### Custom assembly form
``` {.ebnf}
variable-op ::= `spv.globalVariable` spirv-type string-literal
(`initializer(` symbol-reference `)`)?
(`bind(` integer-literal, integer-literal `)`)?
(`built_in(` string-literal `)`)?
attribute-dict?
```
where `initializer` specifies initializer and `bind` specifies the
descriptor set and binding number. `built_in` specifies SPIR-V
BuiltIn decoration associated with the op.
For example:
```
spv.Variable !spv.ptr<f32, Input> @var0
spv.Variable !spv.ptr<f32, Output> @var2 initializer(@var0)
spv.Variable !spv.ptr<f32, Uniform> @var bind(1, 2)
spv.Variable !spv.ptr<vector<3xi32>> @var3 built_in("GlobalInvocationID")
```
}];
let arguments = (ins
TypeAttr:$type,
StrAttr:$sym_name,
OptionalAttr<SymbolRefAttr>:$initializer
);
let results = (outs);
let hasOpcode = 0;
let extraClassDeclaration = [{
::mlir::spirv::StorageClass storageClass() {
return this->type().cast<::mlir::spirv::PointerType>().getStorageClass();
}
}];
}
def SPV_ModuleOp : SPV_Op<"module",
[SingleBlockImplicitTerminator<"ModuleEndOp">,
NativeOpTrait<"SymbolTable">]> {

View File

@ -872,6 +872,11 @@ def SymbolRefAttr : Attr<CPred<"$_self.isa<SymbolRefAttr>()">,
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
}
def SymbolRefArrayAttr :
TypedArrayAttrBase<SymbolRefAttr, "symbol ref array attribute"> {
let constBuilderCall = ?;
}
//===----------------------------------------------------------------------===//
// Derive attribute kinds

View File

@ -136,26 +136,26 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
signatureConverter, newFuncOp))) {
return failure();
}
// Create spv.Variable ops for each of the arguments. These need to be bound
// by the runtime. For now use descriptor_set 0, and arg number as the binding
// number.
// Create spv.globalVariable ops for each of the arguments. These need to be
// bound by the runtime. For now use descriptor_set 0, and arg number as the
// binding number.
auto module = funcOp.getParentOfType<spirv::ModuleOp>();
if (!module) {
return funcOp.emitError("expected op to be within a spv.module");
}
OpBuilder builder(module.getOperation()->getRegion(0));
SmallVector<Value *, 4> interface;
SmallVector<Attribute, 4> interface;
for (auto &convertedArgType :
llvm::enumerate(signatureConverter.getConvertedTypes())) {
auto variableOp = builder.create<spirv::VariableOp>(
funcOp.getLoc(), convertedArgType.value(),
builder.getI32IntegerAttr(
static_cast<int32_t>(spirv::StorageClass::StorageBuffer)),
llvm::None);
std::string varName = funcOp.getName().str() + "_arg_" +
std::to_string(convertedArgType.index());
auto variableOp = builder.create<spirv::GlobalVariableOp>(
funcOp.getLoc(), builder.getTypeAttr(convertedArgType.value()),
builder.getStringAttr(varName), nullptr);
variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0));
variableOp.setAttr("binding",
builder.getI32IntegerAttr(convertedArgType.index()));
interface.push_back(variableOp.getResult());
interface.push_back(builder.getSymbolRefAttr(variableOp.sym_name()));
}
// Create an entry point instruction for this function.
// TODO(ravishankarm) : Add execution mode for the entry function
@ -164,7 +164,8 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
funcOp.getLoc(),
builder.getI32IntegerAttr(
static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
builder.getSymbolRefAttr(newFuncOp.getName()), interface);
builder.getSymbolRefAttr(newFuncOp.getName()),
builder.getArrayAttr(interface));
return success();
}
} // namespace mlir

View File

@ -32,11 +32,15 @@ using namespace mlir;
// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kIndicesAttrName[] = "indices";
static constexpr const char kInitializerAttrName[] = "initializer";
static constexpr const char kInterfaceAttrName[] = "interface";
static constexpr const char kIsSpecConstName[] = "is_spec_const";
static constexpr const char kTypeAttrName[] = "type";
static constexpr const char kValueAttrName[] = "value";
static constexpr const char kValuesAttrName[] = "values";
static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kVariableAttrName[] = "variable";
//===----------------------------------------------------------------------===//
// Common utility functions
@ -239,6 +243,71 @@ static void printNoIOOp(Operation *op, OpAsmPrinter *printer) {
printer->printOptionalAttrDict(op->getAttrs());
}
static ParseResult parseVariableDecorations(OpAsmParser *parser,
OperationState *state) {
auto builtInName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
if (succeeded(parser->parseOptionalKeyword("bind"))) {
Attribute set, binding;
// Parse optional descriptor binding
auto descriptorSetName = convertToSnakeCase(
stringifyDecoration(spirv::Decoration::DescriptorSet));
auto bindingName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
Type i32Type = parser->getBuilder().getIntegerType(32);
if (parser->parseLParen() ||
parser->parseAttribute(set, i32Type, descriptorSetName,
state->attributes) ||
parser->parseComma() ||
parser->parseAttribute(binding, i32Type, bindingName,
state->attributes) ||
parser->parseRParen()) {
return failure();
}
} else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) {
StringAttr builtIn;
if (parser->parseLParen() ||
parser->parseAttribute(builtIn, Type(), builtInName,
state->attributes) ||
parser->parseRParen()) {
return failure();
}
}
// Parse other attributes
if (parser->parseOptionalAttributeDict(state->attributes))
return failure();
return success();
}
static void printVariableDecorations(Operation *op, OpAsmPrinter *printer,
SmallVectorImpl<StringRef> &elidedAttrs) {
// Print optional descriptor binding
auto descriptorSetName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
auto bindingName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
if (descriptorSet && binding) {
elidedAttrs.push_back(descriptorSetName);
elidedAttrs.push_back(bindingName);
*printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
<< ")";
}
// Print BuiltIn attribute if present
auto builtInName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
*printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
elidedAttrs.push_back(builtInName);
}
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
//===----------------------------------------------------------------------===//
// spv.AccessChainOp
//===----------------------------------------------------------------------===//
@ -362,6 +431,53 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
return success();
}
//===----------------------------------------------------------------------===//
// spv._address_of
//===----------------------------------------------------------------------===//
static ParseResult parseAddressOfOp(OpAsmParser *parser,
OperationState *state) {
SymbolRefAttr varRefAttr;
Type type;
if (parser->parseAttribute(varRefAttr, Type(), kVariableAttrName,
state->attributes) ||
parser->parseColonType(type)) {
return failure();
}
auto ptrType = type.dyn_cast<spirv::PointerType>();
if (!ptrType) {
return parser->emitError(parser->getCurrentLocation(),
"expected spv.ptr type");
}
state->addTypes(ptrType);
return success();
}
static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter *printer) {
SmallVector<StringRef, 4> elidedAttrs;
*printer << spirv::AddressOfOp::getOperationName();
// Print symbol name.
*printer << " @" << addressOfOp.variable();
// Print the type.
*printer << " : " << addressOfOp.pointer();
}
static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
auto moduleOp = addressOfOp.getParentOfType<spirv::ModuleOp>();
auto varOp =
moduleOp.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
if (!varOp) {
return addressOfOp.emitError("expected spv.globalVariable symbol");
}
if (addressOfOp.pointer()->getType() != varOp.type()) {
return addressOfOp.emitError(
"mismatch in result type and type of global variable referenced");
}
return success();
}
//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//
@ -541,18 +657,28 @@ static ParseResult parseEntryPointOp(OpAsmParser *parser,
SmallVector<OpAsmParser::OperandType, 0> identifiers;
SmallVector<Type, 0> idTypes;
Attribute fn;
auto loc = parser->getCurrentLocation();
SymbolRefAttr fn;
if (parseEnumAttribute(execModel, parser, state) ||
parser->parseAttribute(fn, kFnNameAttrName, state->attributes) ||
parser->parseTrailingOperandList(identifiers) ||
parser->parseOptionalColonTypeList(idTypes) ||
parser->resolveOperands(identifiers, idTypes, loc, state->operands)) {
parser->parseAttribute(fn, Type(), kFnNameAttrName, state->attributes)) {
return failure();
}
if (!fn.isa<SymbolRefAttr>()) {
return parser->emitError(loc, "expected symbol reference attribute");
if (!parser->parseOptionalComma()) {
// Parse the interface variables
SmallVector<Attribute, 4> interfaceVars;
do {
// The name of the interface variable attribute isnt important
auto attrName = "var_symbol";
SymbolRefAttr var;
SmallVector<NamedAttribute, 1> attrs;
if (parser->parseAttribute(var, Type(), attrName, attrs)) {
return failure();
}
interfaceVars.push_back(var);
} while (!parser->parseOptionalComma());
state->attributes.push_back(
{parser->getBuilder().getIdentifier(kInterfaceAttrName),
parser->getBuilder().getArrayAttr(interfaceVars)});
}
return success();
}
@ -561,27 +687,16 @@ static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter *printer) {
*printer << spirv::EntryPointOp::getOperationName() << " \""
<< stringifyExecutionModel(entryPointOp.execution_model()) << "\" @"
<< entryPointOp.fn();
if (!entryPointOp.getNumOperands()) {
return;
if (auto interface = entryPointOp.interface()) {
*printer << ", ";
mlir::interleaveComma(interface.getValue().getValue(), printer->getStream(),
[&](Attribute a) { printer->printAttribute(a); });
}
*printer << ", ";
mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(),
[&](Value *a) { printer->printOperand(a); });
*printer << " : ";
mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(),
[&](const Value *a) { *printer << a->getType(); });
}
static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
// Verify that all the interface ops are created from VariableOp
for (auto interface : entryPointOp.interface()) {
if (!llvm::isa_and_nonnull<spirv::VariableOp>(interface->getDefiningOp())) {
return entryPointOp.emitOpError("interface operands to entry point must "
"be generated from a variable op");
}
// TODO: Before version 1.4 the variables can only have storage_class of
// Input or Output. That needs to be verified.
}
// Checks for fn and interface symbol reference are done in spirv::ModuleOp
// verification.
return success();
}
@ -627,6 +742,95 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) {
[&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); });
}
//===----------------------------------------------------------------------===//
// spv.globalVariable
//===----------------------------------------------------------------------===//
static ParseResult parseGlobalVariableOp(OpAsmParser *parser,
OperationState *state) {
// Parse variable type.
TypeAttr typeAttr;
auto loc = parser->getCurrentLocation();
if (parser->parseAttribute(typeAttr, Type(), kTypeAttrName,
state->attributes)) {
return failure();
}
auto ptrType = typeAttr.getValue().dyn_cast<spirv::PointerType>();
if (!ptrType) {
return parser->emitError(loc, "expected spv.ptr type");
}
// Parse variable name.
StringAttr nameAttr;
if (parser->parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
state->attributes)) {
return failure();
}
// Parse optional initializer
if (succeeded(parser->parseOptionalKeyword(kInitializerAttrName))) {
SymbolRefAttr initSymbol;
if (parser->parseLParen() ||
parser->parseAttribute(initSymbol, Type(), kInitializerAttrName,
state->attributes) ||
parser->parseRParen())
return failure();
}
if (parseVariableDecorations(parser, state)) {
return failure();
}
return success();
}
static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter *printer) {
auto *op = varOp.getOperation();
SmallVector<StringRef, 4> elidedAttrs{
spirv::attributeName<spirv::StorageClass>()};
*printer << spirv::GlobalVariableOp::getOperationName();
// Print variable type.
*printer << " " << varOp.type();
elidedAttrs.push_back(kTypeAttrName);
// Print variable name.
*printer << " @" << varOp.sym_name();
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
// Print optional initializer
if (auto initializer = varOp.initializer()) {
*printer << " " << kInitializerAttrName << "(@" << initializer.getValue()
<< ")";
elidedAttrs.push_back(kInitializerAttrName);
}
printVariableDecorations(op, printer, elidedAttrs);
}
static LogicalResult verify(spirv::GlobalVariableOp varOp) {
// SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
// object. It cannot be Generic. It must be the same as the Storage Class
// operand of the Result Type."
if (varOp.storageClass() == spirv::StorageClass::Generic)
return varOp.emitOpError("storage class cannot be 'Generic'");
if (auto initializer =
varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) {
// Get the module
auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>();
// TODO: Currently only variable initialization with other variables is
// supported. They could be constants as well, but this needs module-level
// constants to have symbol name as well.
if (!moduleOp.lookupSymbol<spirv::GlobalVariableOp>(
initializer.getValue())) {
return varOp.emitOpError(
"initializer must be result of a spv.globalVariable op");
}
}
return success();
}
//===----------------------------------------------------------------------===//
// spv.LoadOp
//===----------------------------------------------------------------------===//
@ -773,13 +977,33 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
for (auto &op : body) {
if (op.getDialect() == dialect) {
// For EntryPoint op, check that the function and execution model is not
// duplicated in EntryPointOps
// duplicated in EntryPointOps. Also verify that the interface specified
// comes from globalVariables here to make this check cheaper.
if (auto entryPointOp = llvm::dyn_cast<spirv::EntryPointOp>(op)) {
auto funcOp = table.lookup<FuncOp>(entryPointOp.fn());
if (!funcOp) {
return entryPointOp.emitError("function '")
<< entryPointOp.fn() << "' not found in 'spv.module'";
}
if (auto interface = entryPointOp.interface()) {
for (auto varRef : interface.getValue().getValue()) {
auto varSymRef = varRef.dyn_cast<SymbolRefAttr>();
if (!varSymRef) {
return entryPointOp.emitError(
"expected symbol reference for interface "
"specification instead of '")
<< varRef;
}
auto variableOp =
table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
if (!variableOp) {
return entryPointOp.emitError("expected spv.globalVariable "
"symbol reference instead of'")
<< varSymRef << "'";
}
}
}
auto key = std::pair<FuncOp, spirv::ExecutionModel>(
funcOp, entryPointOp.execution_model());
auto entryPtIt = entryPoints.find(key);
@ -898,42 +1122,9 @@ static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) {
return failure();
}
auto builtInName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
if (succeeded(parser->parseOptionalKeyword("bind"))) {
Attribute set, binding;
// Parse optional descriptor binding
auto descriptorSetName = convertToSnakeCase(
stringifyDecoration(spirv::Decoration::DescriptorSet));
auto bindingName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
Type i32Type = parser->getBuilder().getIntegerType(32);
if (parser->parseLParen() ||
parser->parseAttribute(set, i32Type, descriptorSetName,
state->attributes) ||
parser->parseComma() ||
parser->parseAttribute(binding, i32Type, bindingName,
state->attributes) ||
parser->parseRParen()) {
return failure();
}
} else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) {
Attribute builtIn;
if (parser->parseLParen() ||
parser->parseAttribute(builtIn, Type(), builtInName,
state->attributes) ||
parser->parseRParen()) {
return failure();
}
if (!builtIn.isa<StringAttr>()) {
return parser->emitError(parser->getCurrentLocation(),
"expected string value for built_in attribute");
}
}
// Parse other attributes
if (parser->parseOptionalAttributeDict(state->attributes))
if (parseVariableDecorations(parser, state)) {
return failure();
}
// Parse result pointer type
Type type;
@ -976,29 +1167,8 @@ static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) {
*printer << ")";
}
// Print optional descriptor binding
auto descriptorSetName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
auto bindingName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
auto descriptorSet = varOp.getAttrOfType<IntegerAttr>(descriptorSetName);
auto binding = varOp.getAttrOfType<IntegerAttr>(bindingName);
if (descriptorSet && binding) {
elidedAttrs.push_back(descriptorSetName);
elidedAttrs.push_back(bindingName);
*printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
<< ")";
}
printVariableDecorations(op, printer, elidedAttrs);
// Print BuiltIn attribute if present
auto builtInName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
if (auto builtin = varOp.getAttrOfType<StringAttr>(builtInName)) {
*printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
elidedAttrs.push_back(builtInName);
}
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
*printer << " : " << varOp.getType();
}
@ -1006,8 +1176,11 @@ static LogicalResult verify(spirv::VariableOp varOp) {
// SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
// object. It cannot be Generic. It must be the same as the Storage Class
// operand of the Result Type."
if (varOp.storage_class() == spirv::StorageClass::Generic)
return varOp.emitOpError("storage class cannot be 'Generic'");
if (varOp.storage_class() != spirv::StorageClass::Function) {
return varOp.emitOpError(
"can only be used to model function-level variables. Use "
"spv.globalVariable for module-level variables.");
}
auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>();
if (varOp.storage_class() != pointerType.getStorageClass())

View File

@ -90,9 +90,20 @@ private:
/// them to their handler method accordingly.
LogicalResult processFunction(ArrayRef<uint32_t> operands);
/// Process the OpVariable instructions at current `offset` into `binary`. It
/// is expected that this method is used for variables that are to be defined
/// at module scope and will be deserialized into a spv.globalVariable
/// instruction.
LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands);
/// Get the FuncOp associated with a result <id> of OpFunction.
FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
/// Get the global variable associated with a result <id> of OpVariable
spirv::GlobalVariableOp getVariable(uint32_t id) {
return globalVariableMap.lookup(id);
}
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
@ -138,7 +149,15 @@ private:
//===--------------------------------------------------------------------===//
/// Get the Value associated with a result <id>.
Value *getValue(uint32_t id) { return valueMap.lookup(id); }
Value *getValue(uint32_t id) {
if (auto varOp = getVariable(id)) {
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
unknownLoc, varOp.type(),
opBuilder.getSymbolRefAttr(varOp.getOperation()));
return addressOfOp.pointer();
}
return valueMap.lookup(id);
}
/// Slices the first instruction out of `binary` and returns its opcode and
/// operands via `opcode` and `operands` respectively. Returns failure if
@ -198,6 +217,9 @@ private:
// Result <id> to function mapping.
DenseMap<uint32_t, FuncOp> funcMap;
// Result <id> to variable mapping;
DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
// Result <id> to value mapping.
DenseMap<uint32_t, Value *> valueMap;
@ -452,6 +474,76 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
return success();
}
LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
unsigned wordIndex = 0;
if (operands.size() < 3) {
return emitError(
unknownLoc,
"OpVariable needs at least 3 operands, type, <id> and storage class");
}
// Result Type.
auto type = getType(operands[wordIndex]);
if (!type) {
return emitError(unknownLoc, "unknown result type <id> : ")
<< operands[wordIndex];
}
auto ptrType = type.dyn_cast<spirv::PointerType>();
if (!ptrType) {
return emitError(unknownLoc,
"expected a result type <id> to be a spv.ptr, found : ")
<< type;
}
wordIndex++;
// Result <id>.
auto variableID = operands[wordIndex];
auto variableName = nameMap.lookup(variableID).str();
if (variableName.empty()) {
variableName = "spirv_var_" + std::to_string(variableID);
}
wordIndex++;
// Storage class.
auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
if (ptrType.getStorageClass() != storageClass) {
return emitError(unknownLoc, "mismatch in storage class of pointer type ")
<< type << " and that specified in OpVariable instruction : "
<< stringifyStorageClass(storageClass);
}
wordIndex++;
// Initializer.
SymbolRefAttr initializer = nullptr;
if (wordIndex < operands.size()) {
auto initializerOp = getVariable(operands[wordIndex]);
if (!initializerOp) {
return emitError(unknownLoc, "unknown <id> ")
<< operands[wordIndex] << "used as initializer";
}
wordIndex++;
initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation());
}
if (wordIndex != operands.size()) {
return emitError(unknownLoc,
"found more operands than expected when deserializing "
"OpVariable instruction, only ")
<< wordIndex << " of " << operands.size() << " processed";
}
auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
unknownLoc, opBuilder.getTypeAttr(type),
opBuilder.getStringAttr(variableName), initializer);
// Decorations.
if (decorations.count(variableID)) {
for (auto attr : decorations[variableID].getAttrs()) {
varOp.setAttr(attr.first, attr.second);
}
}
globalVariableMap[variableID] = varOp;
return success();
}
LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) {
if (operands.size() < 2) {
return emitError(unknownLoc, "OpName needs at least 2 operands");
@ -887,6 +979,11 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
return success();
}
break;
case spirv::Opcode::OpVariable:
if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
return processGlobalVariable(operands);
}
break;
case spirv::Opcode::OpName:
return processName(operands);
case spirv::Opcode::OpTypeVoid:
@ -954,18 +1051,19 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
"and OpFunction with <id> ")
<< fnID << ": " << fnName << " vs. " << parsedFunc.getName();
}
SmallVector<Value *, 4> interface;
SmallVector<Attribute, 4> interface;
while (wordIndex < words.size()) {
auto arg = getValue(words[wordIndex]);
auto arg = getVariable(words[wordIndex]);
if (!arg) {
return emitError(unknownLoc, "undefined result <id> ")
<< words[wordIndex] << " while decoding OpEntryPoint";
}
interface.push_back(arg);
interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
wordIndex++;
}
opBuilder.create<spirv::EntryPointOp>(
unknownLoc, exec_model, opBuilder.getSymbolRefAttr(fnName), interface);
opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model,
opBuilder.getSymbolRefAttr(fnName),
opBuilder.getArrayAttr(interface));
return success();
}

View File

@ -125,9 +125,19 @@ private:
return funcIDMap.lookup(fnName);
}
uint32_t findVariableID(StringRef varName) const {
return globalVarIDMap.lookup(varName);
}
/// Emit OpName for the given `resultID`.
LogicalResult processName(uint32_t resultID, StringRef name);
/// Processes a SPIR-V function op.
LogicalResult processFuncOp(FuncOp op);
/// Process a SPIR-V GlobalVariableOp
LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
/// Process attributes that translate to decorations on the result <id>
LogicalResult processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr);
@ -215,6 +225,9 @@ private:
uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); }
/// Process spv.addressOf operations.
LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
/// Main dispatch method for serializing an operation.
LogicalResult processOperation(Operation *op);
@ -265,6 +278,9 @@ private:
/// Map from FuncOps name to <id>s.
llvm::StringMap<uint32_t> funcIDMap;
/// Map from GlobalVariableOps name to <id>s
llvm::StringMap<uint32_t> globalVarIDMap;
/// Map from results of normal operations to their <id>s
DenseMap<Value *, uint32_t> valueIDMap;
};
@ -372,6 +388,15 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
}
LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
SmallVector<uint32_t, 4> nameOperands;
nameOperands.push_back(resultID);
if (failed(encodeStringLiteralInto(nameOperands, name))) {
return failure();
}
return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
}
namespace {
template <>
LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
@ -416,10 +441,9 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
encodeInstructionInto(functions, spirv::Opcode::OpFunction, operands);
// Add function name.
SmallVector<uint32_t, 4> nameOperands;
nameOperands.push_back(funcID);
encodeStringLiteralInto(nameOperands, op.getName());
encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
if (failed(processName(funcID, op.getName()))) {
return failure();
}
// Declare the parameters.
for (auto arg : op.getArguments()) {
@ -450,6 +474,61 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {});
}
LogicalResult
Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
// Get TypeID.
uint32_t resultTypeID = 0;
SmallVector<StringRef, 4> elidedAttrs;
if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
return failure();
}
elidedAttrs.push_back("type");
SmallVector<uint32_t, 4> operands;
operands.push_back(resultTypeID);
auto resultID = getNextID();
// Encode the name.
auto varName = varOp.sym_name();
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
if (failed(processName(resultID, varName))) {
return failure();
}
globalVarIDMap[varName] = resultID;
operands.push_back(resultID);
// Encode StorageClass.
operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
// Encode initialization.
if (auto initializer = varOp.initializer()) {
auto initializerID = findVariableID(initializer.getValue());
if (!initializerID) {
return emitError(varOp.getLoc(),
"invalid usage of undefined variable as initializer");
}
operands.push_back(initializerID);
elidedAttrs.push_back("initializer");
}
if (failed(encodeInstructionInto(functions, spirv::Opcode::OpVariable,
operands))) {
elidedAttrs.push_back("initializer");
return failure();
}
// Encode decorations.
for (auto attr : varOp.getAttrs()) {
if (llvm::any_of(elidedAttrs,
[&](StringRef elided) { return attr.first.is(elided); })) {
continue;
}
if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
return failure();
}
}
return success();
}
//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//
@ -912,6 +991,17 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
// Operation
//===----------------------------------------------------------------------===//
LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
auto varName = addressOfOp.variable();
auto variableID = findVariableID(varName);
if (!variableID) {
return addressOfOp.emitError("unknown result <id> for variable ")
<< varName;
}
valueIDMap[addressOfOp.pointer()] = variableID;
return success();
}
LogicalResult Serializer::processOperation(Operation *op) {
// First dispatch the methods that do not directly mirror an operation from
// the SPIR-V spec
@ -924,6 +1014,12 @@ LogicalResult Serializer::processOperation(Operation *op) {
if (isa<spirv::ModuleEndOp>(op)) {
return success();
}
if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
return processGlobalVariableOp(varOp);
}
if (auto addressOfOp = dyn_cast<spirv::AddressOfOp>(op)) {
return processAddressOfOp(addressOfOp);
}
return dispatchToAutogenSerialization(op);
}
@ -947,14 +1043,16 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
encodeStringLiteralInto(operands, op.fn());
// Add the interface values.
for (auto val : op.interface()) {
auto id = findValueID(val);
if (!id) {
return op.emitError("referencing unintialized variable <id>. "
"spv.EntryPoint is at the end of spv.module. All "
"referenced variables should already be defined");
if (auto interface = op.interface()) {
for (auto var : interface.getValue()) {
auto id = findVariableID(var.cast<SymbolRefAttr>().getValue());
if (!id) {
return op.emitError("referencing undefined global variable."
"spv.EntryPoint is at the end of spv.module. All "
"referenced variables should already be defined");
}
operands.push_back(id);
}
operands.push_back(id);
}
return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint,
operands);

View File

@ -1,8 +1,8 @@
// RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s
// CHECK: spv.module "Logical" "VulkanKHR" {
// CHECK-NEXT: [[VAR1:%.*]] = spv.Variable bind(0, 0) : !spv.ptr<f32, StorageBuffer>
// CHECK-NEXT: [[VAR2:%.*]] = spv.Variable bind(0, 1) : !spv.ptr<!spv.array<12 x f32>, StorageBuffer>
// CHECK-NEXT: spv.globalVariable !spv.ptr<f32, StorageBuffer> [[VAR1:@.*]] bind(0, 0)
// CHECK-NEXT: spv.globalVariable !spv.ptr<!spv.array<12 x f32>, StorageBuffer> [[VAR2:@.*]] bind(0, 1)
// CHECK-NEXT: func @kernel_1
// CHECK-NEXT: spv.Return
// CHECK: spv.EntryPoint "GLCompute" @kernel_1, [[VAR1]], [[VAR2]]

View File

@ -2,16 +2,16 @@
func @spirv_loadstore() -> () {
spv.module "Logical" "VulkanKHR" {
// CHECK: [[VAR1:%.*]] = spv.Variable : !spv.ptr<f32, Input>
// CHECK-NEXT: [[VAR2:%.*]] = spv.Variable : !spv.ptr<f32, Output>
// CHECK: spv.globalVariable !spv.ptr<f32, Input> @var2
// CHECK-NEXT: spv.globalVariable !spv.ptr<f32, Output> @var3
// CHECK-NEXT: func @noop({{%.*}}: !spv.ptr<f32, Input>, {{%.*}}: !spv.ptr<f32, Output>)
// CHECK: spv.EntryPoint "GLCompute" @noop, [[VAR1]], [[VAR2]]
%2 = spv.Variable : !spv.ptr<f32, Input>
%3 = spv.Variable : !spv.ptr<f32, Output>
// CHECK: spv.EntryPoint "GLCompute" @noop, @var2, @var3
spv.globalVariable !spv.ptr<f32, Input> @var2
spv.globalVariable !spv.ptr<f32, Output> @var3
func @noop(%arg0 : !spv.ptr<f32, Input>, %arg1 : !spv.ptr<f32, Output>) -> () {
spv.Return
}
spv.EntryPoint "GLCompute" @noop, %2, %3 : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
spv.EntryPoint "GLCompute" @noop, @var2, @var3
spv.ExecutionMode @noop "ContractionOff"
}
return

View File

@ -1,15 +1,15 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
// CHECK: {{%.*}} = spv.Variable bind(1, 0) : !spv.ptr<f32, Input>
// CHECK-NEXT: {{%.*}} = spv.Variable bind(0, 1) : !spv.ptr<f32, Output>
// CHECK-NEXT: {{%.*}} = spv.Variable built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-NEXT: {{%.*}} = spv.Variable built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK: spv.globalVariable !spv.ptr<f32, Input> @var0 bind(1, 0)
// CHECK-NEXT: spv.globalVariable !spv.ptr<f32, Output> @var1 bind(0, 1)
// CHECK-NEXT: spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var2 built_in("GlobalInvocationId")
// CHECK-NEXT: spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var3 built_in("GlobalInvocationId")
func @spirv_variables() -> () {
spv.module "Logical" "VulkanKHR" {
%2 = spv.Variable bind(1, 0) : !spv.ptr<f32, Input>
%3 = spv.Variable bind(0, 1): !spv.ptr<f32, Output>
%4 = spv.Variable {built_in = "GlobalInvocationId"} : !spv.ptr<vector<3xi32>, Input>
%5 = spv.Variable built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
spv.globalVariable !spv.ptr<f32, Input> @var0 bind(1, 0)
spv.globalVariable !spv.ptr<f32, Output> @var1 bind(0, 1)
spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var2 {built_in = "GlobalInvocationId"}
spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var3 built_in("GlobalInvocationId")
}
return
}

View File

@ -2,10 +2,10 @@
func @spirv_variables() -> () {
spv.module "Logical" "VulkanKHR" {
// CHECK: [[INIT:%.*]] = spv.constant 4.000000e+00 : f32
// CHECK: {{%.*}} = spv.Variable init([[INIT]]) bind(1, 0) : !spv.ptr<f32, Input>
%0 = spv.constant 4.0 : f32
%2 = spv.Variable init(%0) bind(1, 0) : !spv.ptr<f32, Input>
// CHECK: spv.globalVariable !spv.ptr<f32, Input> @var1
// CHECK-NEXT: spv.globalVariable !spv.ptr<f32, Input> @var2 initializer(@var1) bind(1, 0)
spv.globalVariable !spv.ptr<f32, Input> @var1
spv.globalVariable !spv.ptr<f32, Input> @var2 initializer(@var1) bind(1, 0)
}
return
}

View File

@ -3,11 +3,10 @@
//===----------------------------------------------------------------------===//
// spv.AccessChain
//===----------------------------------------------------------------------===//
func @access_chain_struct() -> () {
%0 = spv.constant 1: i32
%1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
// CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.struct<f32, !spv.array<4 x f32>>, Function>
// CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.struct<f32, !spv.array<4 x f32>>, Function>
%2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
return
}
@ -16,7 +15,7 @@ func @access_chain_struct() -> () {
func @access_chain_1D_array(%arg0 : i32) -> () {
%0 = spv.Variable : !spv.ptr<!spv.array<4xf32>, Function>
// CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x f32>, Function>
// CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x f32>, Function>
%1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4xf32>, Function>
return
}
@ -25,7 +24,7 @@ func @access_chain_1D_array(%arg0 : i32) -> () {
func @access_chain_2D_array_1(%arg0 : i32) -> () {
%0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
// CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
// CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
%1 = spv.AccessChain %0[%arg0, %arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
%2 = spv.Load "Function" %1 ["Volatile"] : f32
return
@ -35,7 +34,7 @@ func @access_chain_2D_array_1(%arg0 : i32) -> () {
func @access_chain_2D_array_2(%arg0 : i32) -> () {
%0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
// CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
// CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
%1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
%2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4xf32>
return
@ -115,6 +114,18 @@ func @access_chain_invalid_accessing_type(%index0 : i32) -> () {
%0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
%1 = spv.AccessChain %0[%index, %index0, %index0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
return
// -----
spv.module "Logical" "VulkanKHR" {
spv.globalVariable !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input> @var1
func @access_chain() -> () {
%0 = spv.constant 1: i32
%1 = spv._address_of @var1 : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
// CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.struct<f32, !spv.array<4 x f32>>, Input>
%2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
spv.Return
}
}
// -----
@ -276,15 +287,15 @@ spv.module "Logical" "VulkanKHR" {
}
spv.module "Logical" "VulkanKHR" {
%2 = spv.Variable : !spv.ptr<f32, Input>
%3 = spv.Variable : !spv.ptr<f32, Output>
spv.globalVariable !spv.ptr<f32, Input> @var2
spv.globalVariable !spv.ptr<f32, Output> @var3
func @do_something(%arg0 : !spv.ptr<f32, Input>, %arg1 : !spv.ptr<f32, Output>) -> () {
%1 = spv.Load "Input" %arg0 : f32
spv.Store "Output" %arg1, %1 : f32
spv.Return
}
// CHECK: spv.EntryPoint "GLCompute" @do_something, {{%.*}}, {{%.*}} : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
spv.EntryPoint "GLCompute" @do_something, %2, %3 : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
// CHECK: spv.EntryPoint "GLCompute" @do_something, @var2, @var3
spv.EntryPoint "GLCompute" @do_something, @var2, @var3
}
// -----
@ -293,7 +304,7 @@ spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
// expected-error @+1 {{custom op 'spv.EntryPoint' expected symbol reference attribute}}
// expected-error @+1 {{invalid kind of constant specified}}
spv.EntryPoint "GLCompute" "do_nothing"
}
@ -380,6 +391,74 @@ spv.module "Logical" "VulkanKHR" {
// -----
//===----------------------------------------------------------------------===//
// spv.globalVariable
//===----------------------------------------------------------------------===//
spv.module "Logical" "VulkanKHR" {
// CHECK: spv.globalVariable !spv.ptr<f32, Input> @var0
spv.globalVariable !spv.ptr<f32, Input> @var0
}
// TODO: Fix test case after initialization with constant is addressed
// spv.module "Logical" "VulkanKHR" {
// %0 = spv.constant 4.0 : f32
// // CHECK1: spv.Variable init(%0) : !spv.ptr<f32, Private>
// spv.globalVariable !spv.ptr<f32, Private> @var1 init(%0)
// }
spv.module "Logical" "VulkanKHR" {
// CHECK: spv.globalVariable !spv.ptr<f32, Uniform> @var0 bind(1, 2)
spv.globalVariable !spv.ptr<f32, Uniform> @var0 bind(1, 2)
}
// TODO: Fix test case after initialization with constant is addressed
// spv.module "Logical" "VulkanKHR" {
// %0 = spv.constant 4.0 : f32
// // CHECK1: spv.globalVariable !spv.ptr<f32, Private> @var1 initializer(%0) {binding = 5 : i32} : !spv.ptr<f32, Private>
// spv.globalVariable !spv.ptr<f32, Private> @var1 initializer(%0) {binding = 5 : i32} :
// }
spv.module "Logical" "VulkanKHR" {
// CHECK: spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var1 built_in("GlobalInvocationID")
spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var1 built_in("GlobalInvocationID")
// CHECK: spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var2 built_in("GlobalInvocationID")
spv.globalVariable !spv.ptr<vector<3xi32>, Input> @var2 {built_in = "GlobalInvocationID"}
}
// -----
spv.module "Logical" "VulkanKHR" {
// expected-error @+1 {{expected spv.ptr type}}
spv.globalVariable f32 @var0
}
// -----
spv.module "Logical" "VulkanKHR" {
// expected-error @+1 {{op initializer must be result of a spv.globalVariable op}}
spv.globalVariable !spv.ptr<f32, Private> @var0 initializer(@var1)
}
// -----
spv.module "Logical" "VulkanKHR" {
// expected-error @+1 {{storage class cannot be 'Generic'}}
spv.globalVariable !spv.ptr<f32, Generic> @var0
}
// -----
spv.module "Logical" "VulkanKHR" {
func @foo() {
// expected-error @+1 {{op failed to verify that op can only be used in a 'spv.module' block}}
spv.globalVariable !spv.ptr<f32, Input> @var0
spv.Return
}
}
// -----
//===----------------------------------------------------------------------===//
// spv.FAdd
//===----------------------------------------------------------------------===//
@ -499,7 +578,7 @@ func @iadd_scalar(%arg: i32) -> i32 {
//===----------------------------------------------------------------------===//
func @iequal_scalar(%arg0: i32, %arg1: i32) -> i1 {
// CHECK: {{.*}} = spv.IEqual {{.*}}, {{.*}} : i32
// CHECK: spv.IEqual {{.*}}, {{.*}} : i32
%0 = spv.IEqual %arg0, %arg1 : i32
return %0 : i1
}
@ -507,7 +586,7 @@ func @iequal_scalar(%arg0: i32, %arg1: i32) -> i1 {
// -----
func @iequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
// CHECK: {{.*}} = spv.IEqual {{.*}}, {{.*}} : vector<4xi32>
// CHECK: spv.IEqual {{.*}}, {{.*}} : vector<4xi32>
%0 = spv.IEqual %arg0, %arg1 : vector<4xi32>
return %0 : vector<4xi1>
}
@ -519,7 +598,7 @@ func @iequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1>
//===----------------------------------------------------------------------===//
func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
// CHECK: {{.*}} = spv.INotEqual {{.*}}, {{.*}} : vector<4xi32>
// CHECK: spv.INotEqual {{.*}}, {{.*}} : vector<4xi32>
%0 = spv.INotEqual %arg0, %arg1 : vector<4xi32>
return %0 : vector<4xi1>
}
@ -670,6 +749,19 @@ func @aligned_load_incorrect_attributes() -> () {
// -----
spv.module "Logical" "VulkanKHR" {
spv.globalVariable !spv.ptr<f32, Input> @var0
// CHECK_LABEL: @simple_load
func @simple_load() -> () {
// CHECK: spv.Load "Input" {{%.*}} : f32
%0 = spv._address_of @var0 : !spv.ptr<f32, Input>
%1 = spv.Load "Input" %0 : f32
spv.Return
}
}
// -----
//===----------------------------------------------------------------------===//
// spv.Return
//===----------------------------------------------------------------------===//
@ -708,7 +800,7 @@ func @sdiv_scalar(%arg: i32) -> i32 {
//===----------------------------------------------------------------------===//
func @sgt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
// CHECK: {{.*}} = spv.SGreaterThan {{.*}}, {{.*}} : vector<4xi32>
// CHECK: spv.SGreaterThan {{.*}}, {{.*}} : vector<4xi32>
%0 = spv.SGreaterThan %arg0, %arg1 : vector<4xi32>
return %0 : vector<4xi1>
}
@ -720,7 +812,7 @@ func @sgt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
//===----------------------------------------------------------------------===//
func @sge_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
// CHECK: {{.*}} = spv.SGreaterThanEqual {{.*}}, {{.*}} : vector<4xi32>
// CHECK: spv.SGreaterThanEqual {{.*}}, {{.*}} : vector<4xi32>
%0 = spv.SGreaterThanEqual %arg0, %arg1 : vector<4xi32>
return %0 : vector<4xi1>
}
@ -732,7 +824,7 @@ func @sge_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
//===----------------------------------------------------------------------===//
func @slt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
// CHECK: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : vector<4xi32>
// CHECK: spv.SLessThan {{.*}}, {{.*}} : vector<4xi32>
%0 = spv.SLessThan %arg0, %arg1 : vector<4xi32>
return %0 : vector<4xi1>
}
@ -744,7 +836,7 @@ func @slt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
//===----------------------------------------------------------------------===//
func @slte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
// CHECK: {{.*}} = spv.SLessThanEqual {{.*}}, {{.*}} : vector<4xi32>
// CHECK: spv.SLessThanEqual {{.*}}, {{.*}} : vector<4xi32>
%0 = spv.SLessThanEqual %arg0, %arg1 : vector<4xi32>
return %0 : vector<4xi1>
}
@ -883,6 +975,18 @@ func @aligned_store_incorrect_attributes(%arg0 : f32) -> () {
// -----
spv.module "Logical" "VulkanKHR" {
spv.globalVariable !spv.ptr<f32, Input> @var0
func @simple_store(%arg0 : f32) -> () {
%0 = spv._address_of @var0 : !spv.ptr<f32, Input>
// CHECK: spv.Store "Input" {{%.*}}, {{%.*}} : f32
spv.Store "Input" %0, %arg0 : f32
spv.Return
}
}
// -----
//===----------------------------------------------------------------------===//
// spv.UDiv
//===----------------------------------------------------------------------===//
@ -900,7 +1004,7 @@ func @udiv_scalar(%arg: i32) -> i32 {
//===----------------------------------------------------------------------===//
func @ugt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
// CHECK: {{.*}} = spv.UGreaterThan {{.*}}, {{.*}} : vector<4xi32>
// CHECK: spv.UGreaterThan {{.*}}, {{.*}} : vector<4xi32>
%0 = spv.UGreaterThan %arg0, %arg1 : vector<4xi32>
return %0 : vector<4xi1>
}
@ -912,7 +1016,7 @@ func @ugt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
//===----------------------------------------------------------------------===//
func @ugte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
// CHECK: {{.*}} = spv.UGreaterThanEqual {{.*}}, {{.*}} : vector<4xi32>
// CHECK: spv.UGreaterThanEqual {{.*}}, {{.*}} : vector<4xi32>
%0 = spv.UGreaterThanEqual %arg0, %arg1 : vector<4xi32>
return %0 : vector<4xi1>
}
@ -924,7 +1028,7 @@ func @ugte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
//===----------------------------------------------------------------------===//
func @ult_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
// CHECK: {{.*}} = spv.ULessThan {{.*}}, {{.*}} : vector<4xi32>
// CHECK: spv.ULessThan {{.*}}, {{.*}} : vector<4xi32>
%0 = spv.ULessThan %arg0, %arg1 : vector<4xi32>
return %0 : vector<4xi1>
}
@ -936,7 +1040,7 @@ func @ult_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
//===----------------------------------------------------------------------===//
func @ulte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> {
// CHECK: {{.*}} = spv.ULessThanEqual {{.*}}, {{.*}} : vector<4xi32>
// CHECK: spv.ULessThanEqual {{.*}}, {{.*}} : vector<4xi32>
%0 = spv.ULessThanEqual %arg0, %arg1 : vector<4xi32>
return %0 : vector<4xi1>
}
@ -967,29 +1071,29 @@ func @variable_no_init(%arg0: f32) -> () {
func @variable_init() -> () {
%0 = spv.constant 4.0 : f32
// CHECK: spv.Variable init(%0) : !spv.ptr<f32, Private>
%1 = spv.Variable init(%0) : !spv.ptr<f32, Private>
// CHECK: spv.Variable init(%0) : !spv.ptr<f32, Function>
%1 = spv.Variable init(%0) : !spv.ptr<f32, Function>
return
}
func @variable_bind() -> () {
// CHECK: spv.Variable bind(1, 2) : !spv.ptr<f32, Uniform>
%0 = spv.Variable bind(1, 2) : !spv.ptr<f32, Uniform>
// CHECK: spv.Variable bind(1, 2) : !spv.ptr<f32, Function>
%0 = spv.Variable bind(1, 2) : !spv.ptr<f32, Function>
return
}
func @variable_init_bind() -> () {
%0 = spv.constant 4.0 : f32
// CHECK: spv.Variable init(%0) {binding = 5 : i32} : !spv.ptr<f32, Private>
%1 = spv.Variable init(%0) {binding = 5 : i32} : !spv.ptr<f32, Private>
// CHECK: spv.Variable init(%0) {binding = 5 : i32} : !spv.ptr<f32, Function>
%1 = spv.Variable init(%0) {binding = 5 : i32} : !spv.ptr<f32, Function>
return
}
func @variable_builtin() -> () {
// CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Input>
%1 = spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Input>
// CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Input>
%2 = spv.Variable {built_in = "GlobalInvocationID"} : !spv.ptr<vector<3xi32>, Input>
// CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Function>
%1 = spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Function>
// CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Function>
%2 = spv.Variable {built_in = "GlobalInvocationID"} : !spv.ptr<vector<3xi32>, Function>
return
}
@ -1005,23 +1109,14 @@ func @expect_ptr_result_type(%arg0: f32) -> () {
func @variable_init(%arg0: f32) -> () {
// expected-error @+1 {{op initializer must be the result of a spv.Constant or module-level spv.Variable op}}
%0 = spv.Variable init(%arg0) : !spv.ptr<f32, Private>
return
}
// -----
func @storage_class_mismatch() -> () {
%0 = spv.constant 5.0 : f32
// expected-error @+1 {{storage class must match result pointer's storage class}}
%1 = "spv.Variable"(%0) {storage_class = 2: i32} : (f32) -> !spv.ptr<f32, Function>
%0 = spv.Variable init(%arg0) : !spv.ptr<f32, Function>
return
}
// -----
func @cannot_be_generic_storage_class(%arg0: f32) -> () {
// expected-error @+1 {{storage class cannot be 'Generic'}}
// expected-error @+1 {{op can only be used to model function-level variables. Use spv.globalVariable for module-level variables}}
%0 = spv.Variable : !spv.ptr<f32, Generic>
return
}