forked from OSchip/llvm-project
Add spv.specConstant and spv._reference_of
Similar to global variables, specialization constants also live in the module scope and can be referenced by instructions in functions in native SPIR-V. A direct modelling would be to allow functions in the SPIR-V dialect to implicit capture, but it means we are losing the ability to write passes for Functions. While in SPIR-V normally we want to process the module as a whole, it's not common to see multiple functions get used so we'd like to leave the door open for those cases. Therefore, similar to global variables, we introduce spv.specConstant to model three SPIR-V instructions: OpSpecConstantTrue, OpSpecConstantFalse, and OpSpecConstant. They do not return SSA value results; instead they have symbols and can only be referenced by the symbols. To use it in a function, we need to have another op spv._reference_of to turn the symbol into an SSA value. This breaks the tie and makes functions still explicit capture. Previously specialization constants were handled similarly as normal constants. That is incorrect given that specialization constant actually acts more like variable (without need to load and store). E.g., they cannot be de-duplicated like normal constants. This CL also refines various documents and comments. PiperOrigin-RevId: 264455172
This commit is contained in:
parent
82cf6051ee
commit
f4934bcc3e
|
@ -34,24 +34,25 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
|
|||
let summary = "Get the address of a global variable.";
|
||||
|
||||
let description = [{
|
||||
Variables in module scope are defined using symbol names. This
|
||||
instruction generates an SSA value that can be used to refer to
|
||||
the symbol within function scope for use in instructions that
|
||||
expect an SSA value. This operation has no equivalent SPIR-V
|
||||
instruction. Since variables in module scope in SPIR-V dialect are
|
||||
of pointer type, this instruction returns a pointer type as well,
|
||||
and the type is same as the variable referenced.
|
||||
Variables in module scope are defined using symbol names. This op generates
|
||||
an SSA value that can be used to refer to the symbol within function scope
|
||||
for use in ops that expect an SSA value. This operation has no corresponding
|
||||
SPIR-V instruction; it's merely used for modelling purpose in the SPIR-V
|
||||
dialect. Since variables in module scope in SPIR-V dialect are of pointer
|
||||
type, this op returns a pointer type as well, and the type is the same as
|
||||
the variable referenced.
|
||||
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
address-of-op ::= ssa-id `=` `spv.addressOf` `@`string-literal : pointer-type
|
||||
spv-address-of-op ::= ssa-id `=` `spv._address_of` symbol-ref-id
|
||||
`:` spirv-pointer-type
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
%0 = spv.addressOf @var1 : !spv.ptr<f32, Input>
|
||||
%0 = spv._address_of @global_var : !spv.ptr<f32, Input>
|
||||
```
|
||||
}];
|
||||
|
||||
|
@ -66,6 +67,54 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
|
|||
let hasOpcode = 0;
|
||||
}
|
||||
|
||||
def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
|
||||
let summary = "The op that declares a SPIR-V normal constant";
|
||||
|
||||
let description = [{
|
||||
This op declares a SPIR-V normal constant. SPIR-V has multiple constant
|
||||
instructions covering different constant types:
|
||||
|
||||
* `OpConstantTrue` and `OpConstantFalse` for boolean constants
|
||||
* `OpConstant` for scalar constants
|
||||
* `OpConstantComposite` for composite constants
|
||||
* `OpConstantNull` for null constants
|
||||
* ...
|
||||
|
||||
Having such a plethora of constant instructions renders IR transformations
|
||||
more tedious. Therefore, we use a single `spv.constant` op to represent
|
||||
them all. Note that conversion between those SPIR-V constant instructions
|
||||
and this op is purely mechanical; so it can be scoped to the binary
|
||||
(de)serialzation process.
|
||||
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
spv-constant-op ::= ssa-id `=` `spv.constant` attribute-value
|
||||
(`:` spirv-type)?
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
%0 = spv.constant true
|
||||
%1 = spv.constant dense<[2, 3]> : vector<2xf32>
|
||||
%2 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>>
|
||||
```
|
||||
|
||||
TODO(antiagainst): support constant structs
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyAttr:$value
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_Type:$constant
|
||||
);
|
||||
|
||||
let hasOpcode = 0;
|
||||
}
|
||||
|
||||
def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> {
|
||||
let summary = [{
|
||||
Declare an entry point, its execution model, and its interface.
|
||||
|
@ -81,11 +130,11 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> {
|
|||
OpEntryPoint instructions with the same Execution Model and the same
|
||||
Name string.
|
||||
|
||||
Interface is a list of symbol references to spv.globalVariable
|
||||
Interface is a list of symbol references to `spv.globalVariable`
|
||||
operations. These declare the set of global variables from a
|
||||
module that form the interface of this entry point. The set of
|
||||
Interface symbols must be equal to or a superset of the
|
||||
spv.globalVariables referenced by the entry point’s static call
|
||||
`spv.globalVariable`s referenced by the entry point’s static call
|
||||
tree, within the interface’s storage classes. Before version 1.4,
|
||||
the interface’s storage classes are limited to the Input and
|
||||
Output storage classes. Starting with version 1.4, the interface’s
|
||||
|
@ -140,14 +189,14 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope]> {
|
|||
Initializer is optional. If Initializer is present, it will be
|
||||
the initial value of the variable’s memory content. Initializer
|
||||
must be an symbol defined from a constant instruction or other
|
||||
spv.globalVariable operation in module scope. Initializer must
|
||||
`spv.globalVariable` operation in module scope. Initializer must
|
||||
have the same type as the type of the defined symbol.
|
||||
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
variable-op ::= `spv.globalVariable` spirv-type string-literal
|
||||
(`initializer(` symbol-reference `)`)?
|
||||
variable-op ::= `spv.globalVariable` spirv-type symbol-ref-id
|
||||
(`initializer(` symbol-ref-id `)`)?
|
||||
(`bind(` integer-literal, integer-literal `)`)?
|
||||
(`built_in(` string-literal `)`)?
|
||||
attribute-dict?
|
||||
|
@ -160,10 +209,10 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope]> {
|
|||
For example:
|
||||
|
||||
```
|
||||
spv.Variable !spv.ptr<f32, Input> @var0
|
||||
spv.Variable !spv.ptr<f32, Output> @var2 initializer(@var0)
|
||||
spv.Variable !spv.ptr<f32, Uniform> @var bind(1, 2)
|
||||
spv.Variable !spv.ptr<vector<3xi32>> @var3 built_in("GlobalInvocationID")
|
||||
spv.globalVariable @var0 : !spv.ptr<f32, Input> @var0
|
||||
spv.globalVariable @var1 initializer(@var0) : !spv.ptr<f32, Output>
|
||||
spv.globalVariable @var2 bind(1, 2) : !spv.ptr<f32, Uniform>
|
||||
spv.globalVariable @var3 built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
|
||||
```
|
||||
}];
|
||||
|
||||
|
@ -286,53 +335,81 @@ def SPV_ModuleEndOp : SPV_Op<"_module_end", [InModuleScope, Terminator]> {
|
|||
let hasOpcode = 0;
|
||||
}
|
||||
|
||||
def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
|
||||
let summary = "The op that declares a SPIR-V constant";
|
||||
def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> {
|
||||
let summary = "Reference a specialization constant.";
|
||||
|
||||
let description = [{
|
||||
This op declares a SPIR-V constant. SPIR-V has multiple constant
|
||||
instructions covering different constant types:
|
||||
|
||||
* `OpConstantTrue` and `OpConstantFalse` for boolean constants
|
||||
* `OpConstant` for scalar constants
|
||||
* `OpConstantComposite` for composite constants
|
||||
* `OpConstantNull` for null constants
|
||||
* ...
|
||||
|
||||
Having such a plethora of constant instructions renders IR transformations
|
||||
more tedious. Therefore, we use a single `spv.constant` op to represent
|
||||
them all. Note that conversion between those SPIR-V constant instructions
|
||||
and this op is purely mechanical; so it can be scoped to the binary
|
||||
(de)serialzation process.
|
||||
Specialization constant in module scope are defined using symbol names.
|
||||
This op generates an SSA value that can be used to refer to the symbol
|
||||
within function scope for use in ops that expect an SSA value.
|
||||
This operation has no corresponding SPIR-V instruction; it's merely used
|
||||
for modelling purpose in the SPIR-V dialect. This op's return type is
|
||||
the same as the specialization constant.
|
||||
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
spv-constant-op ::= ssa-id `=` `spv.constant` (`spec`)? attribute-value
|
||||
(`:` spirv-type)?
|
||||
spv-reference-of-op ::= ssa-id `=` `spv._reference_of` symbol-ref-id
|
||||
`:` spirv-scalar-type
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
%0 = spv.constant spec true
|
||||
%1 = spv.constant dense<[2, 3]> : vector<2xf32>
|
||||
%2 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>>
|
||||
%0 = spv._reference_of @spec_const : f32
|
||||
```
|
||||
|
||||
TODO(antiagainst): support constant structs
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyAttr:$value,
|
||||
UnitAttr:$is_spec_const
|
||||
SymbolRefAttr:$spec_const
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_Type:$constant
|
||||
SPV_Type:$reference
|
||||
);
|
||||
|
||||
let hasOpcode = 0;
|
||||
}
|
||||
|
||||
def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope]> {
|
||||
let summary = "The op that declares a SPIR-V specialization constant";
|
||||
|
||||
let description = [{
|
||||
This op declares a SPIR-V scalar specialization constant. SPIR-V has
|
||||
multiple constant instructions covering different scalar types:
|
||||
|
||||
* `OpSpecConstantTrue` and `OpSpecConstantFalse` for boolean constants
|
||||
* `OpSpecConstant` for scalar constants
|
||||
|
||||
Similar as `spv.constant`, this op represents all of the above cases.
|
||||
`OpSpecConstantComposite` and `OpSpecConstantOp` are modelled with
|
||||
separate ops.
|
||||
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
spv-spec-constant-op ::= `spv.specConstant` symbol-ref-id
|
||||
`=` attribute-value (`:` spirv-type)?
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
spv.specConstant @spec_const1 = true
|
||||
spv.specConstant @spec_const2 = 42 : i32
|
||||
```
|
||||
|
||||
TODO(antiagainst): support composite spec cosntants with another op
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
StrAttr:$sym_name,
|
||||
AnyAttr:$default_value
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
let hasOpcode = 0;
|
||||
}
|
||||
|
||||
#endif // SPIRV_STRUCTURE_OPS
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
|
@ -32,11 +33,12 @@ using namespace mlir;
|
|||
|
||||
// TODO(antiagainst): generate these strings using ODS.
|
||||
static constexpr const char kAlignmentAttrName[] = "alignment";
|
||||
static constexpr const char kDefaultValueAttrName[] = "default_value";
|
||||
static constexpr const char kFnNameAttrName[] = "fn";
|
||||
static constexpr const char kIndicesAttrName[] = "indices";
|
||||
static constexpr const char kInitializerAttrName[] = "initializer";
|
||||
static constexpr const char kInterfaceAttrName[] = "interface";
|
||||
static constexpr const char kIsSpecConstName[] = "is_spec_const";
|
||||
static constexpr const char kSpecConstAttrName[] = "spec_const";
|
||||
static constexpr const char kTypeAttrName[] = "type";
|
||||
static constexpr const char kValueAttrName[] = "value";
|
||||
static constexpr const char kValuesAttrName[] = "values";
|
||||
|
@ -469,11 +471,11 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
|
|||
auto varOp =
|
||||
moduleOp.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
|
||||
if (!varOp) {
|
||||
return addressOfOp.emitError("expected spv.globalVariable symbol");
|
||||
return addressOfOp.emitOpError("expected spv.globalVariable symbol");
|
||||
}
|
||||
if (addressOfOp.pointer()->getType() != varOp.type()) {
|
||||
return addressOfOp.emitError(
|
||||
"mismatch in result type and type of global variable referenced");
|
||||
return addressOfOp.emitOpError(
|
||||
"result type mismatch with the referenced global variable's type");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -583,9 +585,6 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) {
|
||||
if (succeeded(parser->parseOptionalKeyword("spec")))
|
||||
state->addAttribute(kIsSpecConstName, parser->getBuilder().getUnitAttr());
|
||||
|
||||
Attribute value;
|
||||
if (parser->parseAttribute(value, kValueAttrName, state->attributes))
|
||||
return failure();
|
||||
|
@ -602,8 +601,7 @@ static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) {
|
|||
}
|
||||
|
||||
static void print(spirv::ConstantOp constOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::ConstantOp::getOperationName()
|
||||
<< (constOp.is_spec_const() ? " spec " : " ") << constOp.value();
|
||||
*printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
|
||||
if (constOp.getType().isa<spirv::ArrayType>()) {
|
||||
*printer << " : " << constOp.getType();
|
||||
}
|
||||
|
@ -810,17 +808,16 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
|
|||
if (varOp.storageClass() == spirv::StorageClass::Generic)
|
||||
return varOp.emitOpError("storage class cannot be 'Generic'");
|
||||
|
||||
if (auto initializer =
|
||||
varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) {
|
||||
// Get the module
|
||||
if (auto init = varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) {
|
||||
auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>();
|
||||
// TODO: Currently only variable initialization with other variables is
|
||||
// supported. They could be constants as well, but this needs module-level
|
||||
// constants to have symbol name as well.
|
||||
if (!moduleOp.lookupSymbol<spirv::GlobalVariableOp>(
|
||||
initializer.getValue())) {
|
||||
return varOp.emitOpError(
|
||||
"initializer must be result of a spv.globalVariable op");
|
||||
auto *initOp = moduleOp.lookupSymbol(init.getValue());
|
||||
// TODO: Currently only variable initialization with specialization
|
||||
// constants and other variables is supported. They could be normal
|
||||
// constants in the module scope as well.
|
||||
if (!initOp || !(isa<spirv::GlobalVariableOp>(initOp) ||
|
||||
isa<spirv::SpecConstantOp>(initOp))) {
|
||||
return varOp.emitOpError("initializer must be result of a "
|
||||
"spv.specConstant or spv.globalVariable op");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1033,6 +1030,42 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv._reference_of
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseReferenceOfOp(OpAsmParser *parser,
|
||||
OperationState *state) {
|
||||
SymbolRefAttr constRefAttr;
|
||||
Type type;
|
||||
if (parser->parseAttribute(constRefAttr, Type(), kSpecConstAttrName,
|
||||
state->attributes) ||
|
||||
parser->parseColonType(type)) {
|
||||
return failure();
|
||||
}
|
||||
return parser->addTypeToList(type, state->types);
|
||||
}
|
||||
|
||||
static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::ReferenceOfOp::getOperationName() << " @"
|
||||
<< referenceOfOp.spec_const() << " : "
|
||||
<< referenceOfOp.reference()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
|
||||
auto moduleOp = referenceOfOp.getParentOfType<spirv::ModuleOp>();
|
||||
auto specConstOp =
|
||||
moduleOp.lookupSymbol<spirv::SpecConstantOp>(referenceOfOp.spec_const());
|
||||
if (!specConstOp) {
|
||||
return referenceOfOp.emitOpError("expected spv.specConstant symbol");
|
||||
}
|
||||
if (referenceOfOp.reference()->getType() !=
|
||||
specConstOp.default_value().getType()) {
|
||||
return referenceOfOp.emitOpError("result type mismatch with the referenced "
|
||||
"specialization constant's type");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.Return
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1084,6 +1117,50 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.specConstant
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseSpecConstantOp(OpAsmParser *parser,
|
||||
OperationState *state) {
|
||||
StringAttr nameAttr;
|
||||
Attribute valueAttr;
|
||||
|
||||
if (parser->parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
|
||||
state->attributes) ||
|
||||
parser->parseEqual() ||
|
||||
parser->parseAttribute(valueAttr, kDefaultValueAttrName,
|
||||
state->attributes))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::SpecConstantOp constOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::SpecConstantOp::getOperationName() << " @"
|
||||
<< constOp.sym_name() << " = ";
|
||||
printer->printAttribute(constOp.default_value());
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::SpecConstantOp constOp) {
|
||||
auto value = constOp.default_value();
|
||||
|
||||
switch (value.getKind()) {
|
||||
case StandardAttributes::Bool:
|
||||
case StandardAttributes::Integer:
|
||||
case StandardAttributes::Float: {
|
||||
// Make sure bitwidth is allowed.
|
||||
auto *dialect = static_cast<spirv::SPIRVDialect *>(constOp.getDialect());
|
||||
if (!dialect->isValidSPIRVType(value.getType()))
|
||||
return constOp.emitOpError("default value bitwidth disallowed");
|
||||
return success();
|
||||
}
|
||||
default:
|
||||
return constOp.emitOpError(
|
||||
"default value can only be a bool, integer, or float scalar");
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.StoreOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1220,17 +1297,27 @@ static LogicalResult verify(spirv::VariableOp varOp) {
|
|||
if (varOp.getNumOperands() != 0) {
|
||||
// SPIR-V spec: "Initializer must be an <id> from a constant instruction or
|
||||
// a global (module scope) OpVariable instruction".
|
||||
bool valid = false;
|
||||
if (auto *initOp = varOp.getOperand(0)->getDefiningOp()) {
|
||||
if (llvm::isa<spirv::ConstantOp>(initOp)) {
|
||||
valid = true;
|
||||
} else if (llvm::isa<spirv::VariableOp>(initOp)) {
|
||||
valid = llvm::isa_and_nonnull<spirv::ModuleOp>(initOp->getParentOp());
|
||||
}
|
||||
}
|
||||
if (!valid)
|
||||
auto *initOp = varOp.getOperand(0)->getDefiningOp();
|
||||
if (!initOp || !(isa<spirv::ConstantOp>(initOp) || // for normal constant
|
||||
isa<spirv::ReferenceOfOp>(initOp) || // for spec constant
|
||||
isa<spirv::AddressOfOp>(initOp)))
|
||||
return varOp.emitOpError("initializer must be the result of a "
|
||||
"spv.Constant or module-level spv.Variable op");
|
||||
"constant or spv.globalVariable op");
|
||||
}
|
||||
|
||||
// TODO(antiagainst): generate these strings using ODS.
|
||||
auto *op = varOp.getOperation();
|
||||
auto descriptorSetName =
|
||||
convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
|
||||
auto bindingName =
|
||||
convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
|
||||
auto builtInName =
|
||||
convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
|
||||
|
||||
for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
|
||||
if (op->getAttr(attr))
|
||||
return varOp.emitOpError("cannot have '")
|
||||
<< attr << "' attribute (only allowed in spv.globalVariable)";
|
||||
}
|
||||
|
||||
return success();
|
||||
|
|
|
@ -79,7 +79,7 @@ private:
|
|||
/// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`.
|
||||
LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Process SPIR-V OpName with `operands`
|
||||
/// Process SPIR-V OpName with `operands`.
|
||||
LogicalResult processName(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Method to process an OpDecorate instruction.
|
||||
|
@ -94,17 +94,27 @@ private:
|
|||
/// them to their handler method accordingly.
|
||||
LogicalResult processFunction(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Process the OpVariable instructions at current `offset` into `binary`. It
|
||||
/// is expected that this method is used for variables that are to be defined
|
||||
/// at module scope and will be deserialized into a spv.globalVariable
|
||||
/// Returns a symbol to be used for the specialization constant with the given
|
||||
/// result <id>. This tries to use the specialization constant's OpName if
|
||||
/// exists; otherwise creates one based on the <id>.
|
||||
std::string getSpecConstantSymbol(uint32_t id);
|
||||
|
||||
/// Gets the specialization constant with the given result <id>.
|
||||
spirv::SpecConstantOp getSpecConstant(uint32_t id) {
|
||||
return specConstMap.lookup(id);
|
||||
}
|
||||
|
||||
/// Processes the OpVariable instructions at current `offset` into `binary`.
|
||||
/// It is expected that this method is used for variables that are to be
|
||||
/// defined at module scope and will be deserialized into a spv.globalVariable
|
||||
/// instruction.
|
||||
LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Get the FuncOp associated with a result <id> of OpFunction.
|
||||
/// Gets the FuncOp associated with a result <id> of OpFunction.
|
||||
FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
|
||||
|
||||
/// Get the global variable associated with a result <id> of OpVariable
|
||||
spirv::GlobalVariableOp getVariable(uint32_t id) {
|
||||
/// Gets the global variable associated with a result <id> of OpVariable.
|
||||
spirv::GlobalVariableOp getGlobalVariable(uint32_t id) {
|
||||
return globalVariableMap.lookup(id);
|
||||
}
|
||||
|
||||
|
@ -142,10 +152,9 @@ private:
|
|||
LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands,
|
||||
bool isSpec);
|
||||
|
||||
/// Processes a SPIR-V Op{|Spec}ConstantComposite instruction with the given
|
||||
/// `operands`. `isSpec` indicates whether this is a specialization constant.
|
||||
LogicalResult processConstantComposite(ArrayRef<uint32_t> operands,
|
||||
bool isSpec);
|
||||
/// Processes a SPIR-V OpConstantComposite instruction with the given
|
||||
/// `operands`.
|
||||
LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
|
||||
LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
|
||||
|
@ -155,15 +164,11 @@ private:
|
|||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Get the Value associated with a result <id>.
|
||||
Value *getValue(uint32_t id) {
|
||||
if (auto varOp = getVariable(id)) {
|
||||
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
|
||||
unknownLoc, varOp.type(),
|
||||
opBuilder.getSymbolRefAttr(varOp.getOperation()));
|
||||
return addressOfOp.pointer();
|
||||
}
|
||||
return valueMap.lookup(id);
|
||||
}
|
||||
///
|
||||
/// This method inserts "casting" ops (`spv._address_of` and
|
||||
/// `spv._reference_of`) to turn an symbol into a SSA value for handling uses
|
||||
/// of module scope constants/variables in functions.
|
||||
Value *getValue(uint32_t id);
|
||||
|
||||
/// Slices the first instruction out of `binary` and returns its opcode and
|
||||
/// operands via `opcode` and `operands` respectively. Returns failure if
|
||||
|
@ -223,7 +228,10 @@ private:
|
|||
// Result <id> to function mapping.
|
||||
DenseMap<uint32_t, FuncOp> funcMap;
|
||||
|
||||
// Result <id> to variable mapping;
|
||||
// Result <id> to variable mapping.
|
||||
DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
|
||||
|
||||
// Result <id> to variable mapping.
|
||||
DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
|
||||
|
||||
// Result <id> to value mapping.
|
||||
|
@ -500,6 +508,14 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
|
|||
return success();
|
||||
}
|
||||
|
||||
std::string Deserializer::getSpecConstantSymbol(uint32_t id) {
|
||||
auto constName = nameMap.lookup(id).str();
|
||||
if (constName.empty()) {
|
||||
constName = "spirv_spec_const_" + std::to_string(id);
|
||||
}
|
||||
return constName;
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
|
||||
unsigned wordIndex = 0;
|
||||
if (operands.size() < 3) {
|
||||
|
@ -542,7 +558,7 @@ LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
|
|||
// Initializer.
|
||||
SymbolRefAttr initializer = nullptr;
|
||||
if (wordIndex < operands.size()) {
|
||||
auto initializerOp = getVariable(operands[wordIndex]);
|
||||
auto initializerOp = getGlobalVariable(operands[wordIndex]);
|
||||
if (!initializerOp) {
|
||||
return emitError(unknownLoc, "unknown <id> ")
|
||||
<< operands[wordIndex] << "used as initializer";
|
||||
|
@ -834,8 +850,8 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands,
|
|||
<< bitwidth;
|
||||
};
|
||||
|
||||
spirv::ConstantOp op;
|
||||
UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
|
||||
auto resultID = operands[1];
|
||||
|
||||
if (auto intType = resultType.dyn_cast<IntegerType>()) {
|
||||
auto bitwidth = intType.getWidth();
|
||||
if (failed(checkOperandSizeForBitwidth(bitwidth))) {
|
||||
|
@ -857,9 +873,21 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands,
|
|||
}
|
||||
|
||||
auto attr = opBuilder.getIntegerAttr(intType, value);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, intType, attr,
|
||||
isSpecConst);
|
||||
} else if (auto floatType = resultType.dyn_cast<FloatType>()) {
|
||||
|
||||
if (isSpec) {
|
||||
auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
|
||||
auto op =
|
||||
opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, attr);
|
||||
specConstMap[resultID] = op;
|
||||
} else {
|
||||
auto op = opBuilder.create<spirv::ConstantOp>(unknownLoc, intType, attr);
|
||||
valueMap[resultID] = op.getResult();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
if (auto floatType = resultType.dyn_cast<FloatType>()) {
|
||||
auto bitwidth = floatType.getWidth();
|
||||
if (failed(checkOperandSizeForBitwidth(bitwidth))) {
|
||||
return failure();
|
||||
|
@ -883,15 +911,22 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands,
|
|||
}
|
||||
|
||||
auto attr = opBuilder.getFloatAttr(floatType, value);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, floatType, attr,
|
||||
isSpecConst);
|
||||
} else {
|
||||
return emitError(unknownLoc, "OpConstant can only generate values of "
|
||||
"scalar integer or floating-point type");
|
||||
if (isSpec) {
|
||||
auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
|
||||
auto op =
|
||||
opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, attr);
|
||||
specConstMap[resultID] = op;
|
||||
} else {
|
||||
auto op =
|
||||
opBuilder.create<spirv::ConstantOp>(unknownLoc, floatType, attr);
|
||||
valueMap[resultID] = op.getResult();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
valueMap[operands[1]] = op.getResult();
|
||||
return success();
|
||||
return emitError(unknownLoc, "OpConstant can only generate values of "
|
||||
"scalar integer or floating-point type");
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processConstantBool(bool isTrue,
|
||||
|
@ -905,17 +940,23 @@ LogicalResult Deserializer::processConstantBool(bool isTrue,
|
|||
}
|
||||
|
||||
auto attr = opBuilder.getBoolAttr(isTrue);
|
||||
UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
|
||||
auto op = opBuilder.create<spirv::ConstantOp>(
|
||||
unknownLoc, opBuilder.getI1Type(), attr, isSpecConst);
|
||||
auto resultID = operands[1];
|
||||
if (isSpec) {
|
||||
auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
|
||||
auto op =
|
||||
opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName, attr);
|
||||
specConstMap[resultID] = op;
|
||||
} else {
|
||||
auto op = opBuilder.create<spirv::ConstantOp>(unknownLoc,
|
||||
opBuilder.getI1Type(), attr);
|
||||
valueMap[resultID] = op.getResult();
|
||||
}
|
||||
|
||||
valueMap[operands[1]] = op.getResult();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
Deserializer::processConstantComposite(ArrayRef<uint32_t> operands,
|
||||
bool isSpec) {
|
||||
Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
|
||||
if (operands.size() < 2) {
|
||||
return emitError(unknownLoc,
|
||||
"OpConstantComposite must have type <id> and result <id>");
|
||||
|
@ -952,15 +993,12 @@ Deserializer::processConstantComposite(ArrayRef<uint32_t> operands,
|
|||
}
|
||||
|
||||
spirv::ConstantOp op;
|
||||
UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
|
||||
if (auto vectorType = resultType.dyn_cast<VectorType>()) {
|
||||
auto attr = opBuilder.getDenseElementsAttr(vectorType, elements);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
|
||||
isSpecConst);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr);
|
||||
} else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
|
||||
auto attr = opBuilder.getArrayAttr(elements);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
|
||||
isSpecConst);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr);
|
||||
} else {
|
||||
return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
|
||||
<< resultType;
|
||||
|
@ -986,9 +1024,7 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
|
|||
if (resultType.isa<IntegerType>() || resultType.isa<FloatType>() ||
|
||||
resultType.isa<VectorType>()) {
|
||||
auto attr = opBuilder.getZeroAttr(resultType);
|
||||
UnitAttr isSpecConst;
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
|
||||
isSpecConst);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr);
|
||||
} else {
|
||||
return emitError(unknownLoc, "unsupported OpConstantNull type: ")
|
||||
<< resultType;
|
||||
|
@ -1002,6 +1038,22 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
|
|||
// Instruction
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Value *Deserializer::getValue(uint32_t id) {
|
||||
if (auto varOp = getGlobalVariable(id)) {
|
||||
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
|
||||
unknownLoc, varOp.type(),
|
||||
opBuilder.getSymbolRefAttr(varOp.getOperation()));
|
||||
return addressOfOp.pointer();
|
||||
}
|
||||
if (auto constOp = getSpecConstant(id)) {
|
||||
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
|
||||
unknownLoc, constOp.default_value().getType(),
|
||||
opBuilder.getSymbolRefAttr(constOp.getOperation()));
|
||||
return referenceOfOp.reference();
|
||||
}
|
||||
return valueMap.lookup(id);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
Deserializer::sliceInstruction(spirv::Opcode &opcode,
|
||||
ArrayRef<uint32_t> &operands,
|
||||
|
@ -1069,9 +1121,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
|
|||
case spirv::Opcode::OpSpecConstant:
|
||||
return processConstant(operands, /*isSpec=*/true);
|
||||
case spirv::Opcode::OpConstantComposite:
|
||||
return processConstantComposite(operands, /*isSpec=*/false);
|
||||
case spirv::Opcode::OpSpecConstantComposite:
|
||||
return processConstantComposite(operands, /*isSpec=*/true);
|
||||
return processConstantComposite(operands);
|
||||
case spirv::Opcode::OpConstantTrue:
|
||||
return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
|
||||
case spirv::Opcode::OpSpecConstantTrue:
|
||||
|
@ -1124,7 +1174,7 @@ Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
|
|||
}
|
||||
SmallVector<Attribute, 4> interface;
|
||||
while (wordIndex < words.size()) {
|
||||
auto arg = getVariable(words[wordIndex]);
|
||||
auto arg = getGlobalVariable(words[wordIndex]);
|
||||
if (!arg) {
|
||||
return emitError(unknownLoc, "undefined result <id> ")
|
||||
<< words[wordIndex] << " while decoding OpEntryPoint";
|
||||
|
|
|
@ -118,18 +118,24 @@ private:
|
|||
// Module structure
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult processMemoryModel();
|
||||
|
||||
LogicalResult processConstantOp(spirv::ConstantOp op);
|
||||
|
||||
uint32_t findFunctionID(StringRef fnName) const {
|
||||
return funcIDMap.lookup(fnName);
|
||||
uint32_t findSpecConstID(StringRef constName) const {
|
||||
return specConstIDMap.lookup(constName);
|
||||
}
|
||||
|
||||
uint32_t findVariableID(StringRef varName) const {
|
||||
return globalVarIDMap.lookup(varName);
|
||||
}
|
||||
|
||||
uint32_t findFunctionID(StringRef fnName) const {
|
||||
return funcIDMap.lookup(fnName);
|
||||
}
|
||||
|
||||
LogicalResult processMemoryModel();
|
||||
|
||||
LogicalResult processConstantOp(spirv::ConstantOp op);
|
||||
|
||||
LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
|
||||
|
||||
/// Emit OpName for the given `resultID`.
|
||||
LogicalResult processName(uint32_t resultID, StringRef name);
|
||||
|
||||
|
@ -190,17 +196,15 @@ private:
|
|||
/// and `valueAttr`. `constType` is needed here because we can interpret the
|
||||
/// `valueAttr` as a different type than the type of `valueAttr` itself; for
|
||||
/// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
|
||||
/// constants. If `isSpec` is true, then the constant will be serialized as
|
||||
/// a specialization constant.
|
||||
uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr,
|
||||
bool isSpec);
|
||||
/// constants.
|
||||
uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
|
||||
|
||||
/// Prepares bool ElementsAttr serialization. This method updates `opcode`
|
||||
/// with a proper OpConstant* instruction and pushes literal values for the
|
||||
/// constant to `operands`.
|
||||
LogicalResult prepareBoolVectorConstant(Location loc,
|
||||
DenseIntElementsAttr elementsAttr,
|
||||
bool isSpec, spirv::Opcode &opcode,
|
||||
spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands);
|
||||
|
||||
/// Prepares int ElementsAttr serialization. This method updates `opcode` with
|
||||
|
@ -208,7 +212,7 @@ private:
|
|||
/// constant to `operands`.
|
||||
LogicalResult prepareIntVectorConstant(Location loc,
|
||||
DenseIntElementsAttr elementsAttr,
|
||||
bool isSpec, spirv::Opcode &opcode,
|
||||
spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands);
|
||||
|
||||
/// Prepares float ElementsAttr serialization. This method updates `opcode`
|
||||
|
@ -216,14 +220,24 @@ private:
|
|||
/// constant to `operands`.
|
||||
LogicalResult prepareFloatVectorConstant(Location loc,
|
||||
DenseFPElementsAttr elementsAttr,
|
||||
bool isSpec, spirv::Opcode &opcode,
|
||||
spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands);
|
||||
|
||||
uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec);
|
||||
/// Prepares scalar attribute serialization. This method emits corresponding
|
||||
/// OpConstant* and returns the result <id> associated with it. Returns 0 if
|
||||
/// the attribute is not for a scalar bool/integer/float value. If `isSpec` is
|
||||
/// true, then the constant will be serialized as a specialization constant.
|
||||
uint32_t prepareConstantScalar(Location loc, Attribute valueAttr,
|
||||
bool isSpec = false);
|
||||
|
||||
uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec);
|
||||
uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr,
|
||||
bool isSpec = false);
|
||||
|
||||
uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec);
|
||||
uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr,
|
||||
bool isSpec = false);
|
||||
|
||||
uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
|
||||
bool isSpec = false);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operations
|
||||
|
@ -231,9 +245,10 @@ private:
|
|||
|
||||
uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); }
|
||||
|
||||
/// Process spv.addressOf operations.
|
||||
LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
|
||||
|
||||
LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp);
|
||||
|
||||
/// Main dispatch method for serializing an operation.
|
||||
LogicalResult processOperation(Operation *op);
|
||||
|
||||
|
@ -275,19 +290,22 @@ private:
|
|||
SmallVector<uint32_t, 0> typesGlobalValues;
|
||||
SmallVector<uint32_t, 0> functions;
|
||||
|
||||
/// Map from type used in SPIR-V module to their <id>s
|
||||
/// Map from type used in SPIR-V module to their <id>s.
|
||||
DenseMap<Type, uint32_t> typeIDMap;
|
||||
|
||||
/// Map from constant values to their <id>s
|
||||
/// Map from constant values to their <id>s.
|
||||
DenseMap<Attribute, uint32_t> constIDMap;
|
||||
|
||||
/// Map from specialization constant names to their <id>s.
|
||||
llvm::StringMap<uint32_t> specConstIDMap;
|
||||
|
||||
/// Map from GlobalVariableOps name to <id>s.
|
||||
llvm::StringMap<uint32_t> globalVarIDMap;
|
||||
|
||||
/// Map from FuncOps name to <id>s.
|
||||
llvm::StringMap<uint32_t> funcIDMap;
|
||||
|
||||
/// Map from GlobalVariableOps name to <id>s
|
||||
llvm::StringMap<uint32_t> globalVarIDMap;
|
||||
|
||||
/// Map from results of normal operations to their <id>s
|
||||
/// Map from results of normal operations to their <id>s.
|
||||
DenseMap<Value *, uint32_t> valueIDMap;
|
||||
};
|
||||
} // namespace
|
||||
|
@ -347,14 +365,22 @@ LogicalResult Serializer::processMemoryModel() {
|
|||
}
|
||||
|
||||
LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
|
||||
if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value(),
|
||||
op.is_spec_const())) {
|
||||
if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
|
||||
valueIDMap[op.getResult()] = resultID;
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
|
||||
if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(),
|
||||
/*isSpec=*/true)) {
|
||||
specConstIDMap[op.sym_name()] = resultID;
|
||||
return processName(resultID, op.sym_name());
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
|
||||
NamedAttribute attr) {
|
||||
auto attrName = attr.first.strref();
|
||||
|
@ -395,6 +421,8 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
|
|||
}
|
||||
|
||||
LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
|
||||
assert(!name.empty() && "unexpected empty string for OpName");
|
||||
|
||||
SmallVector<uint32_t, 4> nameOperands;
|
||||
nameOperands.push_back(resultID);
|
||||
if (failed(encodeStringLiteralInto(nameOperands, name))) {
|
||||
|
@ -616,8 +644,7 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
|
|||
}
|
||||
operands.push_back(elementTypeID);
|
||||
if (auto elementCountID = prepareConstantInt(
|
||||
loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()),
|
||||
/*isSpec=*/false)) {
|
||||
loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
|
||||
operands.push_back(elementCountID);
|
||||
}
|
||||
return processTypeDecoration(loc, arrayType, resultID);
|
||||
|
@ -692,17 +719,10 @@ Serializer::prepareFunctionType(Location loc, FunctionType type,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
uint32_t Serializer::prepareConstant(Location loc, Type constType,
|
||||
Attribute valueAttr, bool isSpec) {
|
||||
if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
|
||||
return prepareConstantFp(loc, floatAttr, isSpec);
|
||||
Attribute valueAttr) {
|
||||
if (auto id = prepareConstantScalar(loc, valueAttr)) {
|
||||
return id;
|
||||
}
|
||||
if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
|
||||
return prepareConstantInt(loc, intAttr, isSpec);
|
||||
}
|
||||
if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
|
||||
return prepareConstantBool(loc, boolAttr, isSpec);
|
||||
}
|
||||
|
||||
// This is a composite literal. We need to handle each component separately
|
||||
// and then emit an OpConstantComposite for the whole.
|
||||
|
||||
|
@ -723,25 +743,21 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
|
|||
|
||||
if (auto vectorAttr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
|
||||
if (vectorAttr.getType().getElementType().isInteger(1)) {
|
||||
if (failed(prepareBoolVectorConstant(loc, vectorAttr, isSpec, opcode,
|
||||
operands)))
|
||||
if (failed(prepareBoolVectorConstant(loc, vectorAttr, opcode, operands)))
|
||||
return 0;
|
||||
} else if (failed(prepareIntVectorConstant(loc, vectorAttr, isSpec, opcode,
|
||||
operands)))
|
||||
} else if (failed(
|
||||
prepareIntVectorConstant(loc, vectorAttr, opcode, operands)))
|
||||
return 0;
|
||||
} else if (auto vectorAttr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
if (failed(prepareFloatVectorConstant(loc, vectorAttr, isSpec, opcode,
|
||||
operands)))
|
||||
if (failed(prepareFloatVectorConstant(loc, vectorAttr, opcode, operands)))
|
||||
return 0;
|
||||
} else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
operands.reserve(arrayAttr.size() + 2);
|
||||
|
||||
auto elementType = constType.cast<spirv::ArrayType>().getElementType();
|
||||
for (Attribute elementAttr : arrayAttr)
|
||||
if (auto elementID =
|
||||
prepareConstant(loc, elementType, elementAttr, isSpec)) {
|
||||
if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
|
||||
operands.push_back(elementID);
|
||||
} else {
|
||||
return 0;
|
||||
|
@ -757,8 +773,8 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
|
|||
}
|
||||
|
||||
LogicalResult Serializer::prepareBoolVectorConstant(
|
||||
Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
|
||||
spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
|
||||
Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands) {
|
||||
auto type = elementsAttr.getType();
|
||||
assert(type.hasRank() && type.getRank() == 1 &&
|
||||
"spv.constant should have verified only vector literal uses "
|
||||
|
@ -773,15 +789,14 @@ LogicalResult Serializer::prepareBoolVectorConstant(
|
|||
// the splat value is zero.
|
||||
if (elementsAttr.isSplat()) {
|
||||
// We can use OpConstantNull if this bool ElementsAttr is splatting false.
|
||||
if (!isSpec && !elementsAttr.getSplatValue<bool>()) {
|
||||
if (!elementsAttr.getSplatValue<bool>()) {
|
||||
opcode = spirv::Opcode::OpConstantNull;
|
||||
return success();
|
||||
}
|
||||
|
||||
if (auto id = prepareConstantBool(
|
||||
loc, elementsAttr.getSplatValue<BoolAttr>(), isSpec)) {
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
if (auto id =
|
||||
prepareConstantBool(loc, elementsAttr.getSplatValue<BoolAttr>())) {
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
operands.append(count, id);
|
||||
return success();
|
||||
}
|
||||
|
@ -791,13 +806,12 @@ LogicalResult Serializer::prepareBoolVectorConstant(
|
|||
|
||||
// Otherwise, we need to process each element and compose them with
|
||||
// OpConstantComposite.
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
for (auto boolAttr : elementsAttr.getValues<BoolAttr>()) {
|
||||
// We are constructing an BoolAttr for each value here. But given that
|
||||
// we only use ElementsAttr for vectors with no more than 4 elements, it
|
||||
// should be fine here.
|
||||
if (auto elementID = prepareConstantBool(loc, boolAttr, isSpec)) {
|
||||
if (auto elementID = prepareConstantBool(loc, boolAttr)) {
|
||||
operands.push_back(elementID);
|
||||
} else {
|
||||
return failure();
|
||||
|
@ -807,8 +821,8 @@ LogicalResult Serializer::prepareBoolVectorConstant(
|
|||
}
|
||||
|
||||
LogicalResult Serializer::prepareIntVectorConstant(
|
||||
Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
|
||||
spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
|
||||
Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands) {
|
||||
auto type = elementsAttr.getType();
|
||||
assert(type.hasRank() && type.getRank() == 1 &&
|
||||
"spv.constant should have verified only vector literal uses "
|
||||
|
@ -826,14 +840,13 @@ LogicalResult Serializer::prepareIntVectorConstant(
|
|||
auto splatAttr = elementsAttr.getSplatValue<IntegerAttr>();
|
||||
|
||||
// We can use OpConstantNull if this int ElementsAttr is splatting 0.
|
||||
if (!isSpec && splatAttr.getValue().isNullValue()) {
|
||||
if (splatAttr.getValue().isNullValue()) {
|
||||
opcode = spirv::Opcode::OpConstantNull;
|
||||
return success();
|
||||
}
|
||||
|
||||
if (auto id = prepareConstantInt(loc, splatAttr, isSpec)) {
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
if (auto id = prepareConstantInt(loc, splatAttr)) {
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
operands.append(count, id);
|
||||
return success();
|
||||
}
|
||||
|
@ -842,15 +855,14 @@ LogicalResult Serializer::prepareIntVectorConstant(
|
|||
|
||||
// Otherwise, we need to process each element and compose them with
|
||||
// OpConstantComposite.
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
for (auto intAttr : elementsAttr.getValues<IntegerAttr>()) {
|
||||
// We are constructing an IntegerAttr for each value here. But given that
|
||||
// we only use ElementsAttr for vectors with no more than 4 elements, it
|
||||
// should be fine here.
|
||||
// TODO(antiagainst): revisit this if special extensions enabling large
|
||||
// vectors are supported.
|
||||
if (auto elementID = prepareConstantInt(loc, intAttr, isSpec)) {
|
||||
if (auto elementID = prepareConstantInt(loc, intAttr)) {
|
||||
operands.push_back(elementID);
|
||||
} else {
|
||||
return failure();
|
||||
|
@ -860,8 +872,8 @@ LogicalResult Serializer::prepareIntVectorConstant(
|
|||
}
|
||||
|
||||
LogicalResult Serializer::prepareFloatVectorConstant(
|
||||
Location loc, DenseFPElementsAttr elementsAttr, bool isSpec,
|
||||
spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
|
||||
Location loc, DenseFPElementsAttr elementsAttr, spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands) {
|
||||
auto type = elementsAttr.getType();
|
||||
assert(type.hasRank() && type.getRank() == 1 &&
|
||||
"spv.constant should have verified only vector literal uses "
|
||||
|
@ -872,14 +884,13 @@ LogicalResult Serializer::prepareFloatVectorConstant(
|
|||
|
||||
if (elementsAttr.isSplat()) {
|
||||
FloatAttr splatAttr = elementsAttr.getSplatValue<FloatAttr>();
|
||||
if (!isSpec && splatAttr.getValue().isZero()) {
|
||||
if (splatAttr.getValue().isZero()) {
|
||||
opcode = spirv::Opcode::OpConstantNull;
|
||||
return success();
|
||||
}
|
||||
|
||||
if (auto id = prepareConstantFp(loc, splatAttr, isSpec)) {
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
if (auto id = prepareConstantFp(loc, splatAttr)) {
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
operands.append(count, id);
|
||||
return success();
|
||||
}
|
||||
|
@ -887,10 +898,9 @@ LogicalResult Serializer::prepareFloatVectorConstant(
|
|||
return failure();
|
||||
}
|
||||
|
||||
opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
|
||||
: spirv::Opcode::OpConstantComposite;
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
for (auto fpAttr : elementsAttr.getValues<FloatAttr>()) {
|
||||
if (auto elementID = prepareConstantFp(loc, fpAttr, isSpec)) {
|
||||
if (auto elementID = prepareConstantFp(loc, fpAttr)) {
|
||||
operands.push_back(elementID);
|
||||
} else {
|
||||
return failure();
|
||||
|
@ -899,10 +909,28 @@ LogicalResult Serializer::prepareFloatVectorConstant(
|
|||
return success();
|
||||
}
|
||||
|
||||
uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
|
||||
bool isSpec) {
|
||||
if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
|
||||
return prepareConstantFp(loc, floatAttr, isSpec);
|
||||
}
|
||||
if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
|
||||
return prepareConstantInt(loc, intAttr, isSpec);
|
||||
}
|
||||
if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
|
||||
return prepareConstantBool(loc, boolAttr, isSpec);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
|
||||
bool isSpec) {
|
||||
if (auto id = findConstantID(boolAttr)) {
|
||||
return id;
|
||||
if (!isSpec) {
|
||||
// We can de-duplicate nomral contants, but not specialization constants.
|
||||
if (auto id = findConstantID(boolAttr)) {
|
||||
return id;
|
||||
}
|
||||
}
|
||||
|
||||
// Process the type for this bool literal
|
||||
|
@ -919,13 +947,19 @@ uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
|
|||
: spirv::Opcode::OpConstantFalse);
|
||||
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
|
||||
|
||||
return constIDMap[boolAttr] = resultID;
|
||||
if (!isSpec) {
|
||||
constIDMap[boolAttr] = resultID;
|
||||
}
|
||||
return resultID;
|
||||
}
|
||||
|
||||
uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
|
||||
bool isSpec) {
|
||||
if (auto id = findConstantID(intAttr)) {
|
||||
return id;
|
||||
if (!isSpec) {
|
||||
// We can de-duplicate nomral contants, but not specialization constants.
|
||||
if (auto id = findConstantID(intAttr)) {
|
||||
return id;
|
||||
}
|
||||
}
|
||||
|
||||
// Process the type for this integer literal
|
||||
|
@ -972,20 +1006,26 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
|
|||
} else {
|
||||
std::string valueStr;
|
||||
llvm::raw_string_ostream rss(valueStr);
|
||||
value.print(rss, /*isSigned*/ false);
|
||||
value.print(rss, /*isSigned=*/false);
|
||||
|
||||
emitError(loc, "cannot serialize ")
|
||||
<< bitwidth << "-bit integer literal: " << rss.str();
|
||||
return 0;
|
||||
}
|
||||
|
||||
return constIDMap[intAttr] = resultID;
|
||||
if (!isSpec) {
|
||||
constIDMap[intAttr] = resultID;
|
||||
}
|
||||
return resultID;
|
||||
}
|
||||
|
||||
uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
|
||||
bool isSpec) {
|
||||
if (auto id = findConstantID(floatAttr)) {
|
||||
return id;
|
||||
if (!isSpec) {
|
||||
// We can de-duplicate nomral contants, but not specialization constants.
|
||||
if (auto id = findConstantID(floatAttr)) {
|
||||
return id;
|
||||
}
|
||||
}
|
||||
|
||||
// Process the type for this float literal
|
||||
|
@ -1025,7 +1065,10 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
|
|||
return 0;
|
||||
}
|
||||
|
||||
return constIDMap[floatAttr] = resultID;
|
||||
if (!isSpec) {
|
||||
constIDMap[floatAttr] = resultID;
|
||||
}
|
||||
return resultID;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1043,12 +1086,31 @@ LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
|
||||
auto constName = referenceOfOp.spec_const();
|
||||
auto constID = findSpecConstID(constName);
|
||||
if (!constID) {
|
||||
return referenceOfOp.emitError(
|
||||
"unknown result <id> for specialization constant ")
|
||||
<< constName;
|
||||
}
|
||||
valueIDMap[referenceOfOp.reference()] = constID;
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Serializer::processOperation(Operation *op) {
|
||||
// First dispatch the methods that do not directly mirror an operation from
|
||||
// the SPIR-V spec
|
||||
if (auto constOp = dyn_cast<spirv::ConstantOp>(op)) {
|
||||
return processConstantOp(constOp);
|
||||
}
|
||||
if (auto specConstOp = dyn_cast<spirv::SpecConstantOp>(op)) {
|
||||
return processSpecConstantOp(specConstOp);
|
||||
}
|
||||
if (auto refOpOp = dyn_cast<spirv::ReferenceOfOp>(op)) {
|
||||
return processReferenceOfOp(refOpOp);
|
||||
}
|
||||
if (auto fnOp = dyn_cast<FuncOp>(op)) {
|
||||
return processFuncOp(fnOp);
|
||||
}
|
||||
|
|
|
@ -2,45 +2,29 @@
|
|||
|
||||
func @spirv_module() -> () {
|
||||
spv.module "Logical" "GLSL450" {
|
||||
// CHECK: spv.specConstant @sc_true = true
|
||||
spv.specConstant @sc_true = true
|
||||
// CHECK: spv.specConstant @sc_false = false
|
||||
spv.specConstant @sc_false = false
|
||||
|
||||
// CHECK: spv.constant spec true
|
||||
%0 = spv.constant spec true
|
||||
// CHECK: spv.constant spec false
|
||||
%1 = spv.constant spec false
|
||||
// CHECK: spv.specConstant @sc_int = -5 : i32
|
||||
spv.specConstant @sc_int = -5 : i32
|
||||
|
||||
// CHECK: spv.constant spec -5 : i32
|
||||
%2 = spv.constant spec -5 : i32
|
||||
// CHECK: spv.specConstant @sc_float = 1.000000e+00 : f32
|
||||
spv.specConstant @sc_float = 1. : f32
|
||||
|
||||
// CHECK: spv.constant spec 1.000000e+00 : f32
|
||||
%3 = spv.constant spec 1. : f32
|
||||
// CHECK-LABEL: @use
|
||||
func @use() -> (i32) {
|
||||
// We materialize a `spv._reference_of` op at every use of a
|
||||
// specialization constant in the deserializer. So two ops here.
|
||||
// CHECK: %[[USE1:.*]] = spv._reference_of @sc_int : i32
|
||||
// CHECK: %[[USE2:.*]] = spv._reference_of @sc_int : i32
|
||||
// CHECK: spv.IAdd %[[USE1]], %[[USE2]]
|
||||
|
||||
// Bool vector
|
||||
// CHECK: spv.constant spec dense<false> : vector<2xi1>
|
||||
%4 = spv.constant spec dense<false> : vector<2xi1>
|
||||
// CHECK: spv.constant spec dense<[true, true, true]> : vector<3xi1>
|
||||
%5 = spv.constant spec dense<true> : vector<3xi1>
|
||||
// CHECK: spv.constant spec dense<[false, true]> : vector<2xi1>
|
||||
%6 = spv.constant spec dense<[false, true]> : vector<2xi1>
|
||||
|
||||
// Integer vector
|
||||
// CHECK: spv.constant spec dense<0> : vector<2xi32>
|
||||
%7 = spv.constant spec dense<0> : vector<2xi32>
|
||||
// CHECK: spv.constant spec dense<1> : vector<3xi32>
|
||||
%8 = spv.constant spec dense<1> : vector<3xi32>
|
||||
// CHECK: spv.constant spec dense<[2, -3, 4]> : vector<3xi32>
|
||||
%9 = spv.constant spec dense<[2, -3, 4]> : vector<3xi32>
|
||||
|
||||
// Fp vector
|
||||
// CHECK: spv.constant spec dense<0.000000e+00> : vector<4xf32>
|
||||
%10 = spv.constant spec dense<0.> : vector<4xf32>
|
||||
// CHECK: spv.constant spec dense<-1.500000e+01> : vector<4xf32>
|
||||
%11 = spv.constant spec dense<-15.> : vector<4xf32>
|
||||
// CHECK: spv.constant spec dense<[7.500000e-01, -2.500000e-01, 1.000000e+01, 4.200000e+01]> : vector<4xf32>
|
||||
%12 = spv.constant spec dense<[0.75, -0.25, 10., 42.]> : vector<4xf32>
|
||||
|
||||
// Array
|
||||
// CHECK: spv.constant spec [dense<3.000000e+00> : vector<2xf32>, dense<[4.000000e+00, 5.000000e+00]> : vector<2xf32>] : !spv.array<2 x vector<2xf32>>
|
||||
%13 = spv.constant spec [dense<3.0> : vector<2xf32>, dense<[4., 5.]> : vector<2xf32>] : !spv.array<2 x vector<2xf32>>
|
||||
%0 = spv._reference_of @sc_int : i32
|
||||
%1 = spv.IAdd %0, %0 : i32
|
||||
spv.ReturnValue %1 : i32
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
// spv.AccessChain
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @access_chain_struct() -> () {
|
||||
%0 = spv.constant 1: i32
|
||||
%1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
|
||||
|
@ -935,37 +936,71 @@ func @umod_scalar(%arg: i32) -> i32 {
|
|||
// spv.Variable
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @variable_no_init(%arg0: f32) -> () {
|
||||
func @variable(%arg0: f32) -> () {
|
||||
// CHECK: spv.Variable : !spv.ptr<f32, Function>
|
||||
%0 = spv.Variable : !spv.ptr<f32, Function>
|
||||
return
|
||||
}
|
||||
|
||||
func @variable_init() -> () {
|
||||
// -----
|
||||
|
||||
func @variable_init_normal_constant() -> () {
|
||||
%0 = spv.constant 4.0 : f32
|
||||
// CHECK: spv.Variable init(%0) : !spv.ptr<f32, Function>
|
||||
%1 = spv.Variable init(%0) : !spv.ptr<f32, Function>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
spv.globalVariable @global : !spv.ptr<f32, Workgroup>
|
||||
func @variable_init_global_variable() -> () {
|
||||
%0 = spv._address_of @global : !spv.ptr<f32, Workgroup>
|
||||
// CHECK: spv.Variable init({{.*}}) : !spv.ptr<!spv.ptr<f32, Workgroup>, Function>
|
||||
%1 = spv.Variable init(%0) : !spv.ptr<!spv.ptr<f32, Workgroup>, Function>
|
||||
spv.Return
|
||||
}
|
||||
} attributes {
|
||||
capability = ["VariablePointers"],
|
||||
extension = ["SPV_KHR_variable_pointers"]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
spv.specConstant @sc = 42 : i32
|
||||
// CHECK-LABEL: @variable_init_spec_constant
|
||||
func @variable_init_spec_constant() -> () {
|
||||
%0 = spv._reference_of @sc : i32
|
||||
// CHECK: spv.Variable init(%0) : !spv.ptr<i32, Function>
|
||||
%1 = spv.Variable init(%0) : !spv.ptr<i32, Function>
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @variable_bind() -> () {
|
||||
// CHECK: spv.Variable bind(1, 2) : !spv.ptr<f32, Function>
|
||||
// expected-error @+1 {{cannot have 'descriptor_set' attribute (only allowed in spv.globalVariable)}}
|
||||
%0 = spv.Variable bind(1, 2) : !spv.ptr<f32, Function>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @variable_init_bind() -> () {
|
||||
%0 = spv.constant 4.0 : f32
|
||||
// CHECK: spv.Variable init(%0) {binding = 5 : i32} : !spv.ptr<f32, Function>
|
||||
// expected-error @+1 {{cannot have 'binding' attribute (only allowed in spv.globalVariable)}}
|
||||
%1 = spv.Variable init(%0) {binding = 5 : i32} : !spv.ptr<f32, Function>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @variable_builtin() -> () {
|
||||
// CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Function>
|
||||
// expected-error @+1 {{cannot have 'built_in' attribute (only allowed in spv.globalVariable)}}
|
||||
%1 = spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Function>
|
||||
// CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Function>
|
||||
%2 = spv.Variable {built_in = "GlobalInvocationID"} : !spv.ptr<vector<3xi32>, Function>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -980,7 +1015,7 @@ func @expect_ptr_result_type(%arg0: f32) -> () {
|
|||
// -----
|
||||
|
||||
func @variable_init(%arg0: f32) -> () {
|
||||
// expected-error @+1 {{op initializer must be the result of a spv.Constant or module-level spv.Variable op}}
|
||||
// expected-error @+1 {{op initializer must be the result of a constant or spv.globalVariable op}}
|
||||
%0 = spv.Variable init(%arg0) : !spv.ptr<f32, Function>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ spv.module "Logical" "VulkanKHR" {
|
|||
spv.module "Logical" "VulkanKHR" {
|
||||
spv.globalVariable @var1 : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
|
||||
func @foo() -> () {
|
||||
// expected-error @+1{{expected spv.globalVariable symbol}}
|
||||
// expected-error @+1 {{expected spv.globalVariable symbol}}
|
||||
%0 = spv._address_of @var2 : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
|
||||
}
|
||||
}
|
||||
|
@ -31,14 +31,13 @@ spv.module "Logical" "VulkanKHR" {
|
|||
spv.module "Logical" "VulkanKHR" {
|
||||
spv.globalVariable @var1 : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
|
||||
func @foo() -> () {
|
||||
// expected-error @+1{{mismatch in result type and type of global variable referenced}}
|
||||
// expected-error @+1 {{result type mismatch with the referenced global variable's type}}
|
||||
%0 = spv._address_of @var1 : !spv.ptr<f32, Input>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.constant
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -183,13 +182,23 @@ spv.module "Logical" "VulkanKHR" {
|
|||
spv.globalVariable @var0 : !spv.ptr<f32, Input>
|
||||
}
|
||||
|
||||
// TODO: Fix test case after initialization with constant is addressed
|
||||
// TODO: Fix test case after initialization with normal constant is addressed
|
||||
// spv.module "Logical" "VulkanKHR" {
|
||||
// %0 = spv.constant 4.0 : f32
|
||||
// // CHECK1: spv.Variable init(%0) : !spv.ptr<f32, Private>
|
||||
// spv.globalVariable @var1 init(%0) : !spv.ptr<f32, Private>
|
||||
// }
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "VulkanKHR" {
|
||||
spv.specConstant @sc = 4.0 : f32
|
||||
// CHECK: spv.globalVariable @var initializer(@sc) : !spv.ptr<f32, Private>
|
||||
spv.globalVariable @var initializer(@sc) : !spv.ptr<f32, Private>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "VulkanKHR" {
|
||||
// CHECK: spv.globalVariable @var0 bind(1, 2) : !spv.ptr<f32, Uniform>
|
||||
spv.globalVariable @var0 bind(1, 2) : !spv.ptr<f32, Uniform>
|
||||
|
@ -202,6 +211,8 @@ spv.module "Logical" "VulkanKHR" {
|
|||
// spv.globalVariable @var1 initializer(%0) {binding = 5 : i32} : !spv.ptr<f32, Private>
|
||||
// }
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "VulkanKHR" {
|
||||
// CHECK: spv.globalVariable @var1 built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Input>
|
||||
spv.globalVariable @var1 built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Input>
|
||||
|
@ -219,7 +230,7 @@ spv.module "Logical" "VulkanKHR" {
|
|||
// -----
|
||||
|
||||
spv.module "Logical" "VulkanKHR" {
|
||||
// expected-error @+1 {{op initializer must be result of a spv.globalVariable op}}
|
||||
// expected-error @+1 {{op initializer must be result of a spv.specConstant or spv.globalVariable op}}
|
||||
spv.globalVariable @var0 initializer(@var1) : !spv.ptr<f32, Private>
|
||||
}
|
||||
|
||||
|
@ -364,3 +375,95 @@ func @module_end_not_in_module() -> () {
|
|||
// expected-error @+1 {{op must appear in a 'spv.module' block}}
|
||||
spv._module_end
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv._reference_of
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
spv.specConstant @sc1 = false
|
||||
spv.specConstant @sc2 = 42 : i64
|
||||
spv.specConstant @sc3 = 1.5 : f32
|
||||
|
||||
// CHECK-LABEL: @reference
|
||||
func @reference() -> i1 {
|
||||
// CHECK: spv._reference_of @sc1 : i1
|
||||
%0 = spv._reference_of @sc1 : i1
|
||||
spv.ReturnValue %0 : i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @initialize
|
||||
func @initialize() -> i64 {
|
||||
// CHECK: spv._reference_of @sc2 : i64
|
||||
%0 = spv._reference_of @sc2 : i64
|
||||
%1 = spv.Variable init(%0) : !spv.ptr<i64, Function>
|
||||
%2 = spv.Load "Function" %1 : i64
|
||||
spv.ReturnValue %2 : i64
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @compute
|
||||
func @compute() -> f32 {
|
||||
// CHECK: spv._reference_of @sc3 : f32
|
||||
%0 = spv._reference_of @sc3 : f32
|
||||
%1 = spv.constant 6.0 : f32
|
||||
%2 = spv.FAdd %0, %1 : f32
|
||||
spv.ReturnValue %2 : f32
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @foo() -> () {
|
||||
// expected-error @+1 {{expected spv.specConstant symbol}}
|
||||
%0 = spv._reference_of @sc : i32
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
spv.specConstant @sc = 42 : i32
|
||||
func @foo() -> () {
|
||||
// expected-error @+1 {{result type mismatch with the referenced specialization constant's type}}
|
||||
%0 = spv._reference_of @sc : f32
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.specConstant
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
spv.specConstant @sc1 = false
|
||||
spv.specConstant @sc2 = 42 : i64
|
||||
spv.specConstant @sc3 = 1.5 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
// expected-error @+1 {{default value bitwidth disallowed}}
|
||||
spv.specConstant @sc = 15 : i4
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
// expected-error @+1 {{default value can only be a bool, integer, or float scalar}}
|
||||
spv.specConstant @sc = dense<[2, 3]> : vector<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @use_in_function() -> () {
|
||||
// expected-error @+1 {{op must appear in a 'spv.module' block}}
|
||||
spv.specConstant @sc = false
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue