diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index ba95a761fbe8..de496a76d26d 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -146,67 +146,6 @@ def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> { // ----- -def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> { - let summary = [{ - Declare an entry point, its execution model, and its interface. - }]; - - let description = [{ - Execution Model is the execution model for the entry point and its - static call tree. See Execution Model. - - Entry Point must be the Result of an OpFunction instruction. - - Name is a name string for the entry point. A module cannot have two - OpEntryPoint instructions with the same Execution Model and the same - Name string. - - Interface is a list of of global OpVariable instructions. These - declare the set of global variables from a module that form the - interface of this entry point. The set of Interface must be equal - to or a superset of the global OpVariable Result referenced by the - entry point’s static call tree, within the interface’s storage classes. - Before version 1.4, the interface’s storage classes are limited to the - Input and Output storage classes. Starting with version 1.4, the - interface’s storage classes are all storage classes used in declaring - all global variables referenced by the entry point’s call tree. - - Interface are forward references. Before version 1.4, duplication - of these is tolerated. Starting with version 1.4, an must not - appear more than once. - - ### Custom assembly form - - ``` {.ebnf} - execution-model ::= "Vertex" | "TesellationControl" | - - - entry-point-op ::= ssa-id ` = spv.EntryPoint ` execution-model fn-name - (ssa-use ( `, ` ssa-use)* ` : ` - pointer-type ( `, ` pointer-type)* )? - ``` - - For example: - - ``` - spv.EntryPoint "GLCompute" @foo - spv.EntryPoint "Kernel" @foo, %1, %2 : !spv.ptr, !spv.ptr - - ``` - }]; - - let arguments = (ins - SPV_ExecutionModelAttr:$execution_model, - SymbolRefAttr:$fn, - Variadic:$interface - ); - - let results = (outs); - let autogenSerialization = 0; -} - -// ----- - def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> { let summary = "Declare an execution mode for an entry point."; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td index b44d8ef5d06a..d47563907428 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -30,6 +30,160 @@ include "mlir/SPIRV/SPIRVBase.td" #endif // SPIRV_BASE +def SPV_AddressOfOp : SPV_Op<"_address_of", [NoSideEffect]> { + let summary = "Get the address of a global variable."; + + let description = [{ + Variables in module scope are defined using symbol names. This + instruction generates an SSA value that can be used to refer to + the symbol within function scope for use in instructions that + expect an SSA value. This operation has no equivalent SPIR-V + instruction. Since variables in module scope in SPIR-V dialect are + of pointer type, this instruction returns a pointer type as well, + and the type is same as the variable referenced. + + ### Custom assembly form + + ``` {.ebnf} + address-of-op ::= ssa-id `=` `spv.addressOf` `@`string-literal : pointer-type + ``` + + For example: + + ``` + %0 = spv.addressOf @var1 : !spv.ptr + ``` + }]; + + let arguments = (ins + SymbolRefAttr:$variable + ); + + let results = (outs + SPV_AnyPtr:$pointer + ); + + let hasOpcode = 0; +} + +def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> { + let summary = [{ + Declare an entry point, its execution model, and its interface. + }]; + + let description = [{ + Execution Model is the execution model for the entry point and its + static call tree. See Execution Model. + + Entry Point must be the Result of an OpFunction instruction. + + Name is a name string for the entry point. A module cannot have two + OpEntryPoint instructions with the same Execution Model and the same + Name string. + + Interface is a list of symbol references to spv.globalVariable + operations. These declare the set of global variables from a + module that form the interface of this entry point. The set of + Interface symbols must be equal to or a superset of the + spv.globalVariables referenced by the entry point’s static call + tree, within the interface’s storage classes. Before version 1.4, + the interface’s storage classes are limited to the Input and + Output storage classes. Starting with version 1.4, the interface’s + storage classes are all storage classes used in declaring all + global variables referenced by the entry point’s call tree. + + ### Custom assembly form + + ``` {.ebnf} + execution-model ::= "Vertex" | "TesellationControl" | + + + entry-point-op ::= ssa-id `=` `spv.EntryPoint` execution-model + symbol-reference (`, ` symbol-reference)* + ``` + + For example: + + ``` + spv.EntryPoint "GLCompute" @foo + spv.EntryPoint "Kernel" @foo, @var1, @var2 + + ``` + }]; + + let arguments = (ins + SPV_ExecutionModelAttr:$execution_model, + SymbolRefAttr:$fn, + OptionalAttr:$interface + ); + + let results = (outs); + let autogenSerialization = 0; +} + + +def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [ModuleOnly]> { + let summary = [{ + Allocate an object in memory at module scope. The object is + referenced using a symbol name. + }]; + + let description = [{ + The variable type must be an OpTypePointer. Its type operand is the type of + object in memory. + + Storage Class is the Storage Class of the memory holding the object. It + cannot be Generic. It must be the same as the Storage Class operand of + the variable types. Only those storage classes that are valid at module + scope (like Input, Output, StorageBuffer, etc.) are valid. + + Initializer is optional. If Initializer is present, it will be + the initial value of the variable’s memory content. Initializer + must be an symbol defined from a constant instruction or other + spv.globalVariable operation in module scope. Initializer must + have the same type as the type of the defined symbol. + + ### Custom assembly form + + ``` {.ebnf} + variable-op ::= `spv.globalVariable` spirv-type string-literal + (`initializer(` symbol-reference `)`)? + (`bind(` integer-literal, integer-literal `)`)? + (`built_in(` string-literal `)`)? + attribute-dict? + ``` + + where `initializer` specifies initializer and `bind` specifies the + descriptor set and binding number. `built_in` specifies SPIR-V + BuiltIn decoration associated with the op. + + For example: + + ``` + spv.Variable !spv.ptr @var0 + spv.Variable !spv.ptr @var2 initializer(@var0) + spv.Variable !spv.ptr @var bind(1, 2) + spv.Variable !spv.ptr> @var3 built_in("GlobalInvocationID") + ``` + }]; + + let arguments = (ins + TypeAttr:$type, + StrAttr:$sym_name, + OptionalAttr:$initializer + ); + + let results = (outs); + + let hasOpcode = 0; + + let extraClassDeclaration = [{ + ::mlir::spirv::StorageClass storageClass() { + return this->type().cast<::mlir::spirv::PointerType>().getStorageClass(); + } + }]; +} + def SPV_ModuleOp : SPV_Op<"module", [SingleBlockImplicitTerminator<"ModuleEndOp">, NativeOpTrait<"SymbolTable">]> { diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 519222c91a41..3183a762da51 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -872,6 +872,11 @@ def SymbolRefAttr : Attr()">, let constBuilderCall = "$_builder.getSymbolRefAttr($0)"; } +def SymbolRefArrayAttr : + TypedArrayAttrBase { + let constBuilderCall = ?; +} + //===----------------------------------------------------------------------===// // Derive attribute kinds diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index 53a40dfa365e..035de4f815df 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -136,26 +136,26 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef operands, signatureConverter, newFuncOp))) { return failure(); } - // Create spv.Variable ops for each of the arguments. These need to be bound - // by the runtime. For now use descriptor_set 0, and arg number as the binding - // number. + // Create spv.globalVariable ops for each of the arguments. These need to be + // bound by the runtime. For now use descriptor_set 0, and arg number as the + // binding number. auto module = funcOp.getParentOfType(); if (!module) { return funcOp.emitError("expected op to be within a spv.module"); } OpBuilder builder(module.getOperation()->getRegion(0)); - SmallVector interface; + SmallVector interface; for (auto &convertedArgType : llvm::enumerate(signatureConverter.getConvertedTypes())) { - auto variableOp = builder.create( - funcOp.getLoc(), convertedArgType.value(), - builder.getI32IntegerAttr( - static_cast(spirv::StorageClass::StorageBuffer)), - llvm::None); + std::string varName = funcOp.getName().str() + "_arg_" + + std::to_string(convertedArgType.index()); + auto variableOp = builder.create( + funcOp.getLoc(), builder.getTypeAttr(convertedArgType.value()), + builder.getStringAttr(varName), nullptr); variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0)); variableOp.setAttr("binding", builder.getI32IntegerAttr(convertedArgType.index())); - interface.push_back(variableOp.getResult()); + interface.push_back(builder.getSymbolRefAttr(variableOp.sym_name())); } // Create an entry point instruction for this function. // TODO(ravishankarm) : Add execution mode for the entry function @@ -164,7 +164,8 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef operands, funcOp.getLoc(), builder.getI32IntegerAttr( static_cast(spirv::ExecutionModel::GLCompute)), - builder.getSymbolRefAttr(newFuncOp.getName()), interface); + builder.getSymbolRefAttr(newFuncOp.getName()), + builder.getArrayAttr(interface)); return success(); } } // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 4bea441c366e..9947c0254a9a 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -32,11 +32,15 @@ using namespace mlir; // TODO(antiagainst): generate these strings using ODS. static constexpr const char kAlignmentAttrName[] = "alignment"; +static constexpr const char kFnNameAttrName[] = "fn"; static constexpr const char kIndicesAttrName[] = "indices"; +static constexpr const char kInitializerAttrName[] = "initializer"; +static constexpr const char kInterfaceAttrName[] = "interface"; static constexpr const char kIsSpecConstName[] = "is_spec_const"; +static constexpr const char kTypeAttrName[] = "type"; static constexpr const char kValueAttrName[] = "value"; static constexpr const char kValuesAttrName[] = "values"; -static constexpr const char kFnNameAttrName[] = "fn"; +static constexpr const char kVariableAttrName[] = "variable"; //===----------------------------------------------------------------------===// // Common utility functions @@ -239,6 +243,71 @@ static void printNoIOOp(Operation *op, OpAsmPrinter *printer) { printer->printOptionalAttrDict(op->getAttrs()); } +static ParseResult parseVariableDecorations(OpAsmParser *parser, + OperationState *state) { + auto builtInName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); + if (succeeded(parser->parseOptionalKeyword("bind"))) { + Attribute set, binding; + // Parse optional descriptor binding + auto descriptorSetName = convertToSnakeCase( + stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); + Type i32Type = parser->getBuilder().getIntegerType(32); + if (parser->parseLParen() || + parser->parseAttribute(set, i32Type, descriptorSetName, + state->attributes) || + parser->parseComma() || + parser->parseAttribute(binding, i32Type, bindingName, + state->attributes) || + parser->parseRParen()) { + return failure(); + } + } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) { + StringAttr builtIn; + if (parser->parseLParen() || + parser->parseAttribute(builtIn, Type(), builtInName, + state->attributes) || + parser->parseRParen()) { + return failure(); + } + } + + // Parse other attributes + if (parser->parseOptionalAttributeDict(state->attributes)) + return failure(); + + return success(); +} + +static void printVariableDecorations(Operation *op, OpAsmPrinter *printer, + SmallVectorImpl &elidedAttrs) { + // Print optional descriptor binding + auto descriptorSetName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); + auto descriptorSet = op->getAttrOfType(descriptorSetName); + auto binding = op->getAttrOfType(bindingName); + if (descriptorSet && binding) { + elidedAttrs.push_back(descriptorSetName); + elidedAttrs.push_back(bindingName); + *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() + << ")"; + } + + // Print BuiltIn attribute if present + auto builtInName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); + if (auto builtin = op->getAttrOfType(builtInName)) { + *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; + elidedAttrs.push_back(builtInName); + } + + printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); +} + //===----------------------------------------------------------------------===// // spv.AccessChainOp //===----------------------------------------------------------------------===// @@ -362,6 +431,53 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) { return success(); } +//===----------------------------------------------------------------------===// +// spv._address_of +//===----------------------------------------------------------------------===// + +static ParseResult parseAddressOfOp(OpAsmParser *parser, + OperationState *state) { + SymbolRefAttr varRefAttr; + Type type; + if (parser->parseAttribute(varRefAttr, Type(), kVariableAttrName, + state->attributes) || + parser->parseColonType(type)) { + return failure(); + } + auto ptrType = type.dyn_cast(); + if (!ptrType) { + return parser->emitError(parser->getCurrentLocation(), + "expected spv.ptr type"); + } + state->addTypes(ptrType); + return success(); +} + +static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter *printer) { + SmallVector elidedAttrs; + *printer << spirv::AddressOfOp::getOperationName(); + + // Print symbol name. + *printer << " @" << addressOfOp.variable(); + + // Print the type. + *printer << " : " << addressOfOp.pointer(); +} + +static LogicalResult verify(spirv::AddressOfOp addressOfOp) { + auto moduleOp = addressOfOp.getParentOfType(); + auto varOp = + moduleOp.lookupSymbol(addressOfOp.variable()); + if (!varOp) { + return addressOfOp.emitError("expected spv.globalVariable symbol"); + } + if (addressOfOp.pointer()->getType() != varOp.type()) { + return addressOfOp.emitError( + "mismatch in result type and type of global variable referenced"); + } + return success(); +} + //===----------------------------------------------------------------------===// // spv.CompositeExtractOp //===----------------------------------------------------------------------===// @@ -541,18 +657,28 @@ static ParseResult parseEntryPointOp(OpAsmParser *parser, SmallVector identifiers; SmallVector idTypes; - Attribute fn; - auto loc = parser->getCurrentLocation(); - + SymbolRefAttr fn; if (parseEnumAttribute(execModel, parser, state) || - parser->parseAttribute(fn, kFnNameAttrName, state->attributes) || - parser->parseTrailingOperandList(identifiers) || - parser->parseOptionalColonTypeList(idTypes) || - parser->resolveOperands(identifiers, idTypes, loc, state->operands)) { + parser->parseAttribute(fn, Type(), kFnNameAttrName, state->attributes)) { return failure(); } - if (!fn.isa()) { - return parser->emitError(loc, "expected symbol reference attribute"); + + if (!parser->parseOptionalComma()) { + // Parse the interface variables + SmallVector interfaceVars; + do { + // The name of the interface variable attribute isnt important + auto attrName = "var_symbol"; + SymbolRefAttr var; + SmallVector attrs; + if (parser->parseAttribute(var, Type(), attrName, attrs)) { + return failure(); + } + interfaceVars.push_back(var); + } while (!parser->parseOptionalComma()); + state->attributes.push_back( + {parser->getBuilder().getIdentifier(kInterfaceAttrName), + parser->getBuilder().getArrayAttr(interfaceVars)}); } return success(); } @@ -561,27 +687,16 @@ static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter *printer) { *printer << spirv::EntryPointOp::getOperationName() << " \"" << stringifyExecutionModel(entryPointOp.execution_model()) << "\" @" << entryPointOp.fn(); - if (!entryPointOp.getNumOperands()) { - return; + if (auto interface = entryPointOp.interface()) { + *printer << ", "; + mlir::interleaveComma(interface.getValue().getValue(), printer->getStream(), + [&](Attribute a) { printer->printAttribute(a); }); } - *printer << ", "; - mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(), - [&](Value *a) { printer->printOperand(a); }); - *printer << " : "; - mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(), - [&](const Value *a) { *printer << a->getType(); }); } static LogicalResult verify(spirv::EntryPointOp entryPointOp) { - // Verify that all the interface ops are created from VariableOp - for (auto interface : entryPointOp.interface()) { - if (!llvm::isa_and_nonnull(interface->getDefiningOp())) { - return entryPointOp.emitOpError("interface operands to entry point must " - "be generated from a variable op"); - } - // TODO: Before version 1.4 the variables can only have storage_class of - // Input or Output. That needs to be verified. - } + // Checks for fn and interface symbol reference are done in spirv::ModuleOp + // verification. return success(); } @@ -627,6 +742,95 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) { [&](Attribute a) { *printer << a.cast().getInt(); }); } +//===----------------------------------------------------------------------===// +// spv.globalVariable +//===----------------------------------------------------------------------===// + +static ParseResult parseGlobalVariableOp(OpAsmParser *parser, + OperationState *state) { + // Parse variable type. + TypeAttr typeAttr; + auto loc = parser->getCurrentLocation(); + if (parser->parseAttribute(typeAttr, Type(), kTypeAttrName, + state->attributes)) { + return failure(); + } + auto ptrType = typeAttr.getValue().dyn_cast(); + if (!ptrType) { + return parser->emitError(loc, "expected spv.ptr type"); + } + + // Parse variable name. + StringAttr nameAttr; + if (parser->parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + state->attributes)) { + return failure(); + } + + // Parse optional initializer + if (succeeded(parser->parseOptionalKeyword(kInitializerAttrName))) { + SymbolRefAttr initSymbol; + if (parser->parseLParen() || + parser->parseAttribute(initSymbol, Type(), kInitializerAttrName, + state->attributes) || + parser->parseRParen()) + return failure(); + } + + if (parseVariableDecorations(parser, state)) { + return failure(); + } + + return success(); +} + +static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter *printer) { + auto *op = varOp.getOperation(); + SmallVector elidedAttrs{ + spirv::attributeName()}; + *printer << spirv::GlobalVariableOp::getOperationName(); + + // Print variable type. + *printer << " " << varOp.type(); + elidedAttrs.push_back(kTypeAttrName); + + // Print variable name. + *printer << " @" << varOp.sym_name(); + elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); + + // Print optional initializer + if (auto initializer = varOp.initializer()) { + *printer << " " << kInitializerAttrName << "(@" << initializer.getValue() + << ")"; + elidedAttrs.push_back(kInitializerAttrName); + } + printVariableDecorations(op, printer, elidedAttrs); +} + +static LogicalResult verify(spirv::GlobalVariableOp varOp) { + // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the + // object. It cannot be Generic. It must be the same as the Storage Class + // operand of the Result Type." + if (varOp.storageClass() == spirv::StorageClass::Generic) + return varOp.emitOpError("storage class cannot be 'Generic'"); + + if (auto initializer = + varOp.getAttrOfType(kInitializerAttrName)) { + // Get the module + auto moduleOp = varOp.getParentOfType(); + // TODO: Currently only variable initialization with other variables is + // supported. They could be constants as well, but this needs module-level + // constants to have symbol name as well. + if (!moduleOp.lookupSymbol( + initializer.getValue())) { + return varOp.emitOpError( + "initializer must be result of a spv.globalVariable op"); + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // spv.LoadOp //===----------------------------------------------------------------------===// @@ -773,13 +977,33 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) { for (auto &op : body) { if (op.getDialect() == dialect) { // For EntryPoint op, check that the function and execution model is not - // duplicated in EntryPointOps + // duplicated in EntryPointOps. Also verify that the interface specified + // comes from globalVariables here to make this check cheaper. if (auto entryPointOp = llvm::dyn_cast(op)) { auto funcOp = table.lookup(entryPointOp.fn()); if (!funcOp) { return entryPointOp.emitError("function '") << entryPointOp.fn() << "' not found in 'spv.module'"; } + if (auto interface = entryPointOp.interface()) { + for (auto varRef : interface.getValue().getValue()) { + auto varSymRef = varRef.dyn_cast(); + if (!varSymRef) { + return entryPointOp.emitError( + "expected symbol reference for interface " + "specification instead of '") + << varRef; + } + auto variableOp = + table.lookup(varSymRef.getValue()); + if (!variableOp) { + return entryPointOp.emitError("expected spv.globalVariable " + "symbol reference instead of'") + << varSymRef << "'"; + } + } + } + auto key = std::pair( funcOp, entryPointOp.execution_model()); auto entryPtIt = entryPoints.find(key); @@ -898,42 +1122,9 @@ static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) { return failure(); } - auto builtInName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); - if (succeeded(parser->parseOptionalKeyword("bind"))) { - Attribute set, binding; - // Parse optional descriptor binding - auto descriptorSetName = convertToSnakeCase( - stringifyDecoration(spirv::Decoration::DescriptorSet)); - auto bindingName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); - Type i32Type = parser->getBuilder().getIntegerType(32); - if (parser->parseLParen() || - parser->parseAttribute(set, i32Type, descriptorSetName, - state->attributes) || - parser->parseComma() || - parser->parseAttribute(binding, i32Type, bindingName, - state->attributes) || - parser->parseRParen()) { - return failure(); - } - } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) { - Attribute builtIn; - if (parser->parseLParen() || - parser->parseAttribute(builtIn, Type(), builtInName, - state->attributes) || - parser->parseRParen()) { - return failure(); - } - if (!builtIn.isa()) { - return parser->emitError(parser->getCurrentLocation(), - "expected string value for built_in attribute"); - } - } - - // Parse other attributes - if (parser->parseOptionalAttributeDict(state->attributes)) + if (parseVariableDecorations(parser, state)) { return failure(); + } // Parse result pointer type Type type; @@ -976,29 +1167,8 @@ static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) { *printer << ")"; } - // Print optional descriptor binding - auto descriptorSetName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet)); - auto bindingName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); - auto descriptorSet = varOp.getAttrOfType(descriptorSetName); - auto binding = varOp.getAttrOfType(bindingName); - if (descriptorSet && binding) { - elidedAttrs.push_back(descriptorSetName); - elidedAttrs.push_back(bindingName); - *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() - << ")"; - } + printVariableDecorations(op, printer, elidedAttrs); - // Print BuiltIn attribute if present - auto builtInName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); - if (auto builtin = varOp.getAttrOfType(builtInName)) { - *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; - elidedAttrs.push_back(builtInName); - } - - printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); *printer << " : " << varOp.getType(); } @@ -1006,8 +1176,11 @@ static LogicalResult verify(spirv::VariableOp varOp) { // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the // object. It cannot be Generic. It must be the same as the Storage Class // operand of the Result Type." - if (varOp.storage_class() == spirv::StorageClass::Generic) - return varOp.emitOpError("storage class cannot be 'Generic'"); + if (varOp.storage_class() != spirv::StorageClass::Function) { + return varOp.emitOpError( + "can only be used to model function-level variables. Use " + "spv.globalVariable for module-level variables."); + } auto pointerType = varOp.pointer()->getType().cast(); if (varOp.storage_class() != pointerType.getStorageClass()) diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 1aad7173dc6c..a3d71eda5d92 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -90,9 +90,20 @@ private: /// them to their handler method accordingly. LogicalResult processFunction(ArrayRef operands); + /// Process the OpVariable instructions at current `offset` into `binary`. It + /// is expected that this method is used for variables that are to be defined + /// at module scope and will be deserialized into a spv.globalVariable + /// instruction. + LogicalResult processGlobalVariable(ArrayRef operands); + /// Get the FuncOp associated with a result of OpFunction. FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); } + /// Get the global variable associated with a result of OpVariable + spirv::GlobalVariableOp getVariable(uint32_t id) { + return globalVariableMap.lookup(id); + } + //===--------------------------------------------------------------------===// // Type //===--------------------------------------------------------------------===// @@ -138,7 +149,15 @@ private: //===--------------------------------------------------------------------===// /// Get the Value associated with a result . - Value *getValue(uint32_t id) { return valueMap.lookup(id); } + Value *getValue(uint32_t id) { + if (auto varOp = getVariable(id)) { + auto addressOfOp = opBuilder.create( + unknownLoc, varOp.type(), + opBuilder.getSymbolRefAttr(varOp.getOperation())); + return addressOfOp.pointer(); + } + return valueMap.lookup(id); + } /// Slices the first instruction out of `binary` and returns its opcode and /// operands via `opcode` and `operands` respectively. Returns failure if @@ -198,6 +217,9 @@ private: // Result to function mapping. DenseMap funcMap; + // Result to variable mapping; + DenseMap globalVariableMap; + // Result to value mapping. DenseMap valueMap; @@ -452,6 +474,76 @@ LogicalResult Deserializer::processFunction(ArrayRef operands) { return success(); } +LogicalResult Deserializer::processGlobalVariable(ArrayRef operands) { + unsigned wordIndex = 0; + if (operands.size() < 3) { + return emitError( + unknownLoc, + "OpVariable needs at least 3 operands, type, and storage class"); + } + + // Result Type. + auto type = getType(operands[wordIndex]); + if (!type) { + return emitError(unknownLoc, "unknown result type : ") + << operands[wordIndex]; + } + auto ptrType = type.dyn_cast(); + if (!ptrType) { + return emitError(unknownLoc, + "expected a result type to be a spv.ptr, found : ") + << type; + } + wordIndex++; + + // Result . + auto variableID = operands[wordIndex]; + auto variableName = nameMap.lookup(variableID).str(); + if (variableName.empty()) { + variableName = "spirv_var_" + std::to_string(variableID); + } + wordIndex++; + + // Storage class. + auto storageClass = static_cast(operands[wordIndex]); + if (ptrType.getStorageClass() != storageClass) { + return emitError(unknownLoc, "mismatch in storage class of pointer type ") + << type << " and that specified in OpVariable instruction : " + << stringifyStorageClass(storageClass); + } + wordIndex++; + + // Initializer. + SymbolRefAttr initializer = nullptr; + if (wordIndex < operands.size()) { + auto initializerOp = getVariable(operands[wordIndex]); + if (!initializerOp) { + return emitError(unknownLoc, "unknown ") + << operands[wordIndex] << "used as initializer"; + } + wordIndex++; + initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation()); + } + if (wordIndex != operands.size()) { + return emitError(unknownLoc, + "found more operands than expected when deserializing " + "OpVariable instruction, only ") + << wordIndex << " of " << operands.size() << " processed"; + } + auto varOp = opBuilder.create( + unknownLoc, opBuilder.getTypeAttr(type), + opBuilder.getStringAttr(variableName), initializer); + + // Decorations. + if (decorations.count(variableID)) { + for (auto attr : decorations[variableID].getAttrs()) { + varOp.setAttr(attr.first, attr.second); + } + } + globalVariableMap[variableID] = varOp; + return success(); +} + LogicalResult Deserializer::processName(ArrayRef operands) { if (operands.size() < 2) { return emitError(unknownLoc, "OpName needs at least 2 operands"); @@ -887,6 +979,11 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return success(); } break; + case spirv::Opcode::OpVariable: + if (isa(opBuilder.getBlock()->getParentOp())) { + return processGlobalVariable(operands); + } + break; case spirv::Opcode::OpName: return processName(operands); case spirv::Opcode::OpTypeVoid: @@ -954,18 +1051,19 @@ Deserializer::processOp(ArrayRef words) { "and OpFunction with ") << fnID << ": " << fnName << " vs. " << parsedFunc.getName(); } - SmallVector interface; + SmallVector interface; while (wordIndex < words.size()) { - auto arg = getValue(words[wordIndex]); + auto arg = getVariable(words[wordIndex]); if (!arg) { return emitError(unknownLoc, "undefined result ") << words[wordIndex] << " while decoding OpEntryPoint"; } - interface.push_back(arg); + interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); wordIndex++; } - opBuilder.create( - unknownLoc, exec_model, opBuilder.getSymbolRefAttr(fnName), interface); + opBuilder.create(unknownLoc, exec_model, + opBuilder.getSymbolRefAttr(fnName), + opBuilder.getArrayAttr(interface)); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index d06363a1a8c5..575d995bf456 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -125,9 +125,19 @@ private: return funcIDMap.lookup(fnName); } + uint32_t findVariableID(StringRef varName) const { + return globalVarIDMap.lookup(varName); + } + + /// Emit OpName for the given `resultID`. + LogicalResult processName(uint32_t resultID, StringRef name); + /// Processes a SPIR-V function op. LogicalResult processFuncOp(FuncOp op); + /// Process a SPIR-V GlobalVariableOp + LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp); + /// Process attributes that translate to decorations on the result LogicalResult processDecoration(Location loc, uint32_t resultID, NamedAttribute attr); @@ -215,6 +225,9 @@ private: uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); } + /// Process spv.addressOf operations. + LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp); + /// Main dispatch method for serializing an operation. LogicalResult processOperation(Operation *op); @@ -265,6 +278,9 @@ private: /// Map from FuncOps name to s. llvm::StringMap funcIDMap; + /// Map from GlobalVariableOps name to s + llvm::StringMap globalVarIDMap; + /// Map from results of normal operations to their s DenseMap valueIDMap; }; @@ -372,6 +388,15 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args); } +LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { + SmallVector nameOperands; + nameOperands.push_back(resultID); + if (failed(encodeStringLiteralInto(nameOperands, name))) { + return failure(); + } + return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); +} + namespace { template <> LogicalResult Serializer::processTypeDecoration( @@ -416,10 +441,9 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { encodeInstructionInto(functions, spirv::Opcode::OpFunction, operands); // Add function name. - SmallVector nameOperands; - nameOperands.push_back(funcID); - encodeStringLiteralInto(nameOperands, op.getName()); - encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); + if (failed(processName(funcID, op.getName()))) { + return failure(); + } // Declare the parameters. for (auto arg : op.getArguments()) { @@ -450,6 +474,61 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {}); } +LogicalResult +Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) { + // Get TypeID. + uint32_t resultTypeID = 0; + SmallVector elidedAttrs; + if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) { + return failure(); + } + elidedAttrs.push_back("type"); + SmallVector operands; + operands.push_back(resultTypeID); + auto resultID = getNextID(); + + // Encode the name. + auto varName = varOp.sym_name(); + elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); + if (failed(processName(resultID, varName))) { + return failure(); + } + globalVarIDMap[varName] = resultID; + operands.push_back(resultID); + + // Encode StorageClass. + operands.push_back(static_cast(varOp.storageClass())); + + // Encode initialization. + if (auto initializer = varOp.initializer()) { + auto initializerID = findVariableID(initializer.getValue()); + if (!initializerID) { + return emitError(varOp.getLoc(), + "invalid usage of undefined variable as initializer"); + } + operands.push_back(initializerID); + elidedAttrs.push_back("initializer"); + } + + if (failed(encodeInstructionInto(functions, spirv::Opcode::OpVariable, + operands))) { + elidedAttrs.push_back("initializer"); + return failure(); + } + + // Encode decorations. + for (auto attr : varOp.getAttrs()) { + if (llvm::any_of(elidedAttrs, + [&](StringRef elided) { return attr.first.is(elided); })) { + continue; + } + if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { + return failure(); + } + } + return success(); +} + //===----------------------------------------------------------------------===// // Type //===----------------------------------------------------------------------===// @@ -912,6 +991,17 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, // Operation //===----------------------------------------------------------------------===// +LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { + auto varName = addressOfOp.variable(); + auto variableID = findVariableID(varName); + if (!variableID) { + return addressOfOp.emitError("unknown result for variable ") + << varName; + } + valueIDMap[addressOfOp.pointer()] = variableID; + return success(); +} + LogicalResult Serializer::processOperation(Operation *op) { // First dispatch the methods that do not directly mirror an operation from // the SPIR-V spec @@ -924,6 +1014,12 @@ LogicalResult Serializer::processOperation(Operation *op) { if (isa(op)) { return success(); } + if (auto varOp = dyn_cast(op)) { + return processGlobalVariableOp(varOp); + } + if (auto addressOfOp = dyn_cast(op)) { + return processAddressOfOp(addressOfOp); + } return dispatchToAutogenSerialization(op); } @@ -947,14 +1043,16 @@ Serializer::processOp(spirv::EntryPointOp op) { encodeStringLiteralInto(operands, op.fn()); // Add the interface values. - for (auto val : op.interface()) { - auto id = findValueID(val); - if (!id) { - return op.emitError("referencing unintialized variable . " - "spv.EntryPoint is at the end of spv.module. All " - "referenced variables should already be defined"); + if (auto interface = op.interface()) { + for (auto var : interface.getValue()) { + auto id = findVariableID(var.cast().getValue()); + if (!id) { + return op.emitError("referencing undefined global variable." + "spv.EntryPoint is at the end of spv.module. All " + "referenced variables should already be defined"); + } + operands.push_back(id); } - operands.push_back(id); } return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands); diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir index 671a38bd1faf..dc7964d0c95c 100644 --- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir @@ -1,8 +1,8 @@ // RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s // CHECK: spv.module "Logical" "VulkanKHR" { -// CHECK-NEXT: [[VAR1:%.*]] = spv.Variable bind(0, 0) : !spv.ptr -// CHECK-NEXT: [[VAR2:%.*]] = spv.Variable bind(0, 1) : !spv.ptr, StorageBuffer> +// CHECK-NEXT: spv.globalVariable !spv.ptr [[VAR1:@.*]] bind(0, 0) +// CHECK-NEXT: spv.globalVariable !spv.ptr, StorageBuffer> [[VAR2:@.*]] bind(0, 1) // CHECK-NEXT: func @kernel_1 // CHECK-NEXT: spv.Return // CHECK: spv.EntryPoint "GLCompute" @kernel_1, [[VAR1]], [[VAR2]] diff --git a/mlir/test/Dialect/SPIRV/Serialization/entry_interface.mlir b/mlir/test/Dialect/SPIRV/Serialization/entry_interface.mlir index 970e7236d31f..8c22b39c81bc 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/entry_interface.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/entry_interface.mlir @@ -2,16 +2,16 @@ func @spirv_loadstore() -> () { spv.module "Logical" "VulkanKHR" { - // CHECK: [[VAR1:%.*]] = spv.Variable : !spv.ptr - // CHECK-NEXT: [[VAR2:%.*]] = spv.Variable : !spv.ptr + // CHECK: spv.globalVariable !spv.ptr @var2 + // CHECK-NEXT: spv.globalVariable !spv.ptr @var3 // CHECK-NEXT: func @noop({{%.*}}: !spv.ptr, {{%.*}}: !spv.ptr) - // CHECK: spv.EntryPoint "GLCompute" @noop, [[VAR1]], [[VAR2]] - %2 = spv.Variable : !spv.ptr - %3 = spv.Variable : !spv.ptr + // CHECK: spv.EntryPoint "GLCompute" @noop, @var2, @var3 + spv.globalVariable !spv.ptr @var2 + spv.globalVariable !spv.ptr @var3 func @noop(%arg0 : !spv.ptr, %arg1 : !spv.ptr) -> () { spv.Return } - spv.EntryPoint "GLCompute" @noop, %2, %3 : !spv.ptr, !spv.ptr + spv.EntryPoint "GLCompute" @noop, @var2, @var3 spv.ExecutionMode @noop "ContractionOff" } return diff --git a/mlir/test/Dialect/SPIRV/Serialization/variables.mlir b/mlir/test/Dialect/SPIRV/Serialization/variables.mlir index 15cc891fc808..3fbd0a94b819 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/variables.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/variables.mlir @@ -1,15 +1,15 @@ // RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s -// CHECK: {{%.*}} = spv.Variable bind(1, 0) : !spv.ptr -// CHECK-NEXT: {{%.*}} = spv.Variable bind(0, 1) : !spv.ptr -// CHECK-NEXT: {{%.*}} = spv.Variable built_in("GlobalInvocationId") : !spv.ptr, Input> -// CHECK-NEXT: {{%.*}} = spv.Variable built_in("GlobalInvocationId") : !spv.ptr, Input> +// CHECK: spv.globalVariable !spv.ptr @var0 bind(1, 0) +// CHECK-NEXT: spv.globalVariable !spv.ptr @var1 bind(0, 1) +// CHECK-NEXT: spv.globalVariable !spv.ptr, Input> @var2 built_in("GlobalInvocationId") +// CHECK-NEXT: spv.globalVariable !spv.ptr, Input> @var3 built_in("GlobalInvocationId") func @spirv_variables() -> () { spv.module "Logical" "VulkanKHR" { - %2 = spv.Variable bind(1, 0) : !spv.ptr - %3 = spv.Variable bind(0, 1): !spv.ptr - %4 = spv.Variable {built_in = "GlobalInvocationId"} : !spv.ptr, Input> - %5 = spv.Variable built_in("GlobalInvocationId") : !spv.ptr, Input> + spv.globalVariable !spv.ptr @var0 bind(1, 0) + spv.globalVariable !spv.ptr @var1 bind(0, 1) + spv.globalVariable !spv.ptr, Input> @var2 {built_in = "GlobalInvocationId"} + spv.globalVariable !spv.ptr, Input> @var3 built_in("GlobalInvocationId") } return } \ No newline at end of file diff --git a/mlir/test/Dialect/SPIRV/Serialization/variables_init.mlir b/mlir/test/Dialect/SPIRV/Serialization/variables_init.mlir index 65e0703b46af..d6dd83ed45f0 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/variables_init.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/variables_init.mlir @@ -2,10 +2,10 @@ func @spirv_variables() -> () { spv.module "Logical" "VulkanKHR" { - // CHECK: [[INIT:%.*]] = spv.constant 4.000000e+00 : f32 - // CHECK: {{%.*}} = spv.Variable init([[INIT]]) bind(1, 0) : !spv.ptr - %0 = spv.constant 4.0 : f32 - %2 = spv.Variable init(%0) bind(1, 0) : !spv.ptr + // CHECK: spv.globalVariable !spv.ptr @var1 + // CHECK-NEXT: spv.globalVariable !spv.ptr @var2 initializer(@var1) bind(1, 0) + spv.globalVariable !spv.ptr @var1 + spv.globalVariable !spv.ptr @var2 initializer(@var1) bind(1, 0) } return } \ No newline at end of file diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index ac9ddfd07948..052dc6871679 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -3,11 +3,10 @@ //===----------------------------------------------------------------------===// // spv.AccessChain //===----------------------------------------------------------------------===// - func @access_chain_struct() -> () { %0 = spv.constant 1: i32 %1 = spv.Variable : !spv.ptr>, Function> - // CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr>, Function> + // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr>, Function> %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function> return } @@ -16,7 +15,7 @@ func @access_chain_struct() -> () { func @access_chain_1D_array(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr, Function> - // CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr, Function> + // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr, Function> %1 = spv.AccessChain %0[%arg0] : !spv.ptr, Function> return } @@ -25,7 +24,7 @@ func @access_chain_1D_array(%arg0 : i32) -> () { func @access_chain_2D_array_1(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> - // CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr>, Function> + // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr>, Function> %1 = spv.AccessChain %0[%arg0, %arg0] : !spv.ptr>, Function> %2 = spv.Load "Function" %1 ["Volatile"] : f32 return @@ -35,7 +34,7 @@ func @access_chain_2D_array_1(%arg0 : i32) -> () { func @access_chain_2D_array_2(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> - // CHECK: {{.*}} = spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr>, Function> + // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr>, Function> %1 = spv.AccessChain %0[%arg0] : !spv.ptr>, Function> %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4xf32> return @@ -115,6 +114,18 @@ func @access_chain_invalid_accessing_type(%index0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> %1 = spv.AccessChain %0[%index, %index0, %index0] : !spv.ptr>, Function> return + +// ----- + +spv.module "Logical" "VulkanKHR" { + spv.globalVariable !spv.ptr>, Input> @var1 + func @access_chain() -> () { + %0 = spv.constant 1: i32 + %1 = spv._address_of @var1 : !spv.ptr>, Input> + // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr>, Input> + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Input> + spv.Return + } } // ----- @@ -276,15 +287,15 @@ spv.module "Logical" "VulkanKHR" { } spv.module "Logical" "VulkanKHR" { - %2 = spv.Variable : !spv.ptr - %3 = spv.Variable : !spv.ptr + spv.globalVariable !spv.ptr @var2 + spv.globalVariable !spv.ptr @var3 func @do_something(%arg0 : !spv.ptr, %arg1 : !spv.ptr) -> () { %1 = spv.Load "Input" %arg0 : f32 spv.Store "Output" %arg1, %1 : f32 spv.Return } - // CHECK: spv.EntryPoint "GLCompute" @do_something, {{%.*}}, {{%.*}} : !spv.ptr, !spv.ptr - spv.EntryPoint "GLCompute" @do_something, %2, %3 : !spv.ptr, !spv.ptr + // CHECK: spv.EntryPoint "GLCompute" @do_something, @var2, @var3 + spv.EntryPoint "GLCompute" @do_something, @var2, @var3 } // ----- @@ -293,7 +304,7 @@ spv.module "Logical" "VulkanKHR" { func @do_nothing() -> () { spv.Return } - // expected-error @+1 {{custom op 'spv.EntryPoint' expected symbol reference attribute}} + // expected-error @+1 {{invalid kind of constant specified}} spv.EntryPoint "GLCompute" "do_nothing" } @@ -380,6 +391,74 @@ spv.module "Logical" "VulkanKHR" { // ----- +//===----------------------------------------------------------------------===// +// spv.globalVariable +//===----------------------------------------------------------------------===// + +spv.module "Logical" "VulkanKHR" { + // CHECK: spv.globalVariable !spv.ptr @var0 + spv.globalVariable !spv.ptr @var0 +} + +// TODO: Fix test case after initialization with constant is addressed +// spv.module "Logical" "VulkanKHR" { +// %0 = spv.constant 4.0 : f32 +// // CHECK1: spv.Variable init(%0) : !spv.ptr +// spv.globalVariable !spv.ptr @var1 init(%0) +// } + +spv.module "Logical" "VulkanKHR" { + // CHECK: spv.globalVariable !spv.ptr @var0 bind(1, 2) + spv.globalVariable !spv.ptr @var0 bind(1, 2) +} + +// TODO: Fix test case after initialization with constant is addressed +// spv.module "Logical" "VulkanKHR" { +// %0 = spv.constant 4.0 : f32 +// // CHECK1: spv.globalVariable !spv.ptr @var1 initializer(%0) {binding = 5 : i32} : !spv.ptr +// spv.globalVariable !spv.ptr @var1 initializer(%0) {binding = 5 : i32} : +// } + +spv.module "Logical" "VulkanKHR" { + // CHECK: spv.globalVariable !spv.ptr, Input> @var1 built_in("GlobalInvocationID") + spv.globalVariable !spv.ptr, Input> @var1 built_in("GlobalInvocationID") + // CHECK: spv.globalVariable !spv.ptr, Input> @var2 built_in("GlobalInvocationID") + spv.globalVariable !spv.ptr, Input> @var2 {built_in = "GlobalInvocationID"} +} + +// ----- + +spv.module "Logical" "VulkanKHR" { + // expected-error @+1 {{expected spv.ptr type}} + spv.globalVariable f32 @var0 +} + +// ----- + +spv.module "Logical" "VulkanKHR" { + // expected-error @+1 {{op initializer must be result of a spv.globalVariable op}} + spv.globalVariable !spv.ptr @var0 initializer(@var1) +} + +// ----- + +spv.module "Logical" "VulkanKHR" { + // expected-error @+1 {{storage class cannot be 'Generic'}} + spv.globalVariable !spv.ptr @var0 +} + +// ----- + +spv.module "Logical" "VulkanKHR" { + func @foo() { + // expected-error @+1 {{op failed to verify that op can only be used in a 'spv.module' block}} + spv.globalVariable !spv.ptr @var0 + spv.Return + } +} + +// ----- + //===----------------------------------------------------------------------===// // spv.FAdd //===----------------------------------------------------------------------===// @@ -499,7 +578,7 @@ func @iadd_scalar(%arg: i32) -> i32 { //===----------------------------------------------------------------------===// func @iequal_scalar(%arg0: i32, %arg1: i32) -> i1 { - // CHECK: {{.*}} = spv.IEqual {{.*}}, {{.*}} : i32 + // CHECK: spv.IEqual {{.*}}, {{.*}} : i32 %0 = spv.IEqual %arg0, %arg1 : i32 return %0 : i1 } @@ -507,7 +586,7 @@ func @iequal_scalar(%arg0: i32, %arg1: i32) -> i1 { // ----- func @iequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { - // CHECK: {{.*}} = spv.IEqual {{.*}}, {{.*}} : vector<4xi32> + // CHECK: spv.IEqual {{.*}}, {{.*}} : vector<4xi32> %0 = spv.IEqual %arg0, %arg1 : vector<4xi32> return %0 : vector<4xi1> } @@ -519,7 +598,7 @@ func @iequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> //===----------------------------------------------------------------------===// func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { - // CHECK: {{.*}} = spv.INotEqual {{.*}}, {{.*}} : vector<4xi32> + // CHECK: spv.INotEqual {{.*}}, {{.*}} : vector<4xi32> %0 = spv.INotEqual %arg0, %arg1 : vector<4xi32> return %0 : vector<4xi1> } @@ -670,6 +749,19 @@ func @aligned_load_incorrect_attributes() -> () { // ----- +spv.module "Logical" "VulkanKHR" { + spv.globalVariable !spv.ptr @var0 + // CHECK_LABEL: @simple_load + func @simple_load() -> () { + // CHECK: spv.Load "Input" {{%.*}} : f32 + %0 = spv._address_of @var0 : !spv.ptr + %1 = spv.Load "Input" %0 : f32 + spv.Return + } +} + +// ----- + //===----------------------------------------------------------------------===// // spv.Return //===----------------------------------------------------------------------===// @@ -708,7 +800,7 @@ func @sdiv_scalar(%arg: i32) -> i32 { //===----------------------------------------------------------------------===// func @sgt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { - // CHECK: {{.*}} = spv.SGreaterThan {{.*}}, {{.*}} : vector<4xi32> + // CHECK: spv.SGreaterThan {{.*}}, {{.*}} : vector<4xi32> %0 = spv.SGreaterThan %arg0, %arg1 : vector<4xi32> return %0 : vector<4xi1> } @@ -720,7 +812,7 @@ func @sgt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { //===----------------------------------------------------------------------===// func @sge_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { - // CHECK: {{.*}} = spv.SGreaterThanEqual {{.*}}, {{.*}} : vector<4xi32> + // CHECK: spv.SGreaterThanEqual {{.*}}, {{.*}} : vector<4xi32> %0 = spv.SGreaterThanEqual %arg0, %arg1 : vector<4xi32> return %0 : vector<4xi1> } @@ -732,7 +824,7 @@ func @sge_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { //===----------------------------------------------------------------------===// func @slt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { - // CHECK: {{.*}} = spv.SLessThan {{.*}}, {{.*}} : vector<4xi32> + // CHECK: spv.SLessThan {{.*}}, {{.*}} : vector<4xi32> %0 = spv.SLessThan %arg0, %arg1 : vector<4xi32> return %0 : vector<4xi1> } @@ -744,7 +836,7 @@ func @slt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { //===----------------------------------------------------------------------===// func @slte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { - // CHECK: {{.*}} = spv.SLessThanEqual {{.*}}, {{.*}} : vector<4xi32> + // CHECK: spv.SLessThanEqual {{.*}}, {{.*}} : vector<4xi32> %0 = spv.SLessThanEqual %arg0, %arg1 : vector<4xi32> return %0 : vector<4xi1> } @@ -883,6 +975,18 @@ func @aligned_store_incorrect_attributes(%arg0 : f32) -> () { // ----- +spv.module "Logical" "VulkanKHR" { + spv.globalVariable !spv.ptr @var0 + func @simple_store(%arg0 : f32) -> () { + %0 = spv._address_of @var0 : !spv.ptr + // CHECK: spv.Store "Input" {{%.*}}, {{%.*}} : f32 + spv.Store "Input" %0, %arg0 : f32 + spv.Return + } +} + +// ----- + //===----------------------------------------------------------------------===// // spv.UDiv //===----------------------------------------------------------------------===// @@ -900,7 +1004,7 @@ func @udiv_scalar(%arg: i32) -> i32 { //===----------------------------------------------------------------------===// func @ugt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { - // CHECK: {{.*}} = spv.UGreaterThan {{.*}}, {{.*}} : vector<4xi32> + // CHECK: spv.UGreaterThan {{.*}}, {{.*}} : vector<4xi32> %0 = spv.UGreaterThan %arg0, %arg1 : vector<4xi32> return %0 : vector<4xi1> } @@ -912,7 +1016,7 @@ func @ugt_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { //===----------------------------------------------------------------------===// func @ugte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { - // CHECK: {{.*}} = spv.UGreaterThanEqual {{.*}}, {{.*}} : vector<4xi32> + // CHECK: spv.UGreaterThanEqual {{.*}}, {{.*}} : vector<4xi32> %0 = spv.UGreaterThanEqual %arg0, %arg1 : vector<4xi32> return %0 : vector<4xi1> } @@ -924,7 +1028,7 @@ func @ugte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { //===----------------------------------------------------------------------===// func @ult_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { - // CHECK: {{.*}} = spv.ULessThan {{.*}}, {{.*}} : vector<4xi32> + // CHECK: spv.ULessThan {{.*}}, {{.*}} : vector<4xi32> %0 = spv.ULessThan %arg0, %arg1 : vector<4xi32> return %0 : vector<4xi1> } @@ -936,7 +1040,7 @@ func @ult_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { //===----------------------------------------------------------------------===// func @ulte_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vector<4xi1> { - // CHECK: {{.*}} = spv.ULessThanEqual {{.*}}, {{.*}} : vector<4xi32> + // CHECK: spv.ULessThanEqual {{.*}}, {{.*}} : vector<4xi32> %0 = spv.ULessThanEqual %arg0, %arg1 : vector<4xi32> return %0 : vector<4xi1> } @@ -967,29 +1071,29 @@ func @variable_no_init(%arg0: f32) -> () { func @variable_init() -> () { %0 = spv.constant 4.0 : f32 - // CHECK: spv.Variable init(%0) : !spv.ptr - %1 = spv.Variable init(%0) : !spv.ptr + // CHECK: spv.Variable init(%0) : !spv.ptr + %1 = spv.Variable init(%0) : !spv.ptr return } func @variable_bind() -> () { - // CHECK: spv.Variable bind(1, 2) : !spv.ptr - %0 = spv.Variable bind(1, 2) : !spv.ptr + // CHECK: spv.Variable bind(1, 2) : !spv.ptr + %0 = spv.Variable bind(1, 2) : !spv.ptr return } func @variable_init_bind() -> () { %0 = spv.constant 4.0 : f32 - // CHECK: spv.Variable init(%0) {binding = 5 : i32} : !spv.ptr - %1 = spv.Variable init(%0) {binding = 5 : i32} : !spv.ptr + // CHECK: spv.Variable init(%0) {binding = 5 : i32} : !spv.ptr + %1 = spv.Variable init(%0) {binding = 5 : i32} : !spv.ptr return } func @variable_builtin() -> () { - // CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr, Input> - %1 = spv.Variable built_in("GlobalInvocationID") : !spv.ptr, Input> - // CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr, Input> - %2 = spv.Variable {built_in = "GlobalInvocationID"} : !spv.ptr, Input> + // CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr, Function> + %1 = spv.Variable built_in("GlobalInvocationID") : !spv.ptr, Function> + // CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr, Function> + %2 = spv.Variable {built_in = "GlobalInvocationID"} : !spv.ptr, Function> return } @@ -1005,23 +1109,14 @@ func @expect_ptr_result_type(%arg0: f32) -> () { func @variable_init(%arg0: f32) -> () { // expected-error @+1 {{op initializer must be the result of a spv.Constant or module-level spv.Variable op}} - %0 = spv.Variable init(%arg0) : !spv.ptr - return -} - -// ----- - -func @storage_class_mismatch() -> () { - %0 = spv.constant 5.0 : f32 - // expected-error @+1 {{storage class must match result pointer's storage class}} - %1 = "spv.Variable"(%0) {storage_class = 2: i32} : (f32) -> !spv.ptr + %0 = spv.Variable init(%arg0) : !spv.ptr return } // ----- func @cannot_be_generic_storage_class(%arg0: f32) -> () { - // expected-error @+1 {{storage class cannot be 'Generic'}} + // expected-error @+1 {{op can only be used to model function-level variables. Use spv.globalVariable for module-level variables}} %0 = spv.Variable : !spv.ptr return }