forked from OSchip/llvm-project
[spirv] Add spv.Variable
This is a direct modelling of SPIR-V's OpVariable. The custom assembly format parsers/prints descriptor in a nicer way if presents. There are other common decorations that can appear on variables like builtin, which can be supported later. This CL additionally deduplicates the parser/printer/verifier declaration in op definitions by adding defaults to SPV_Op base. by adding PiperOrigin-RevId: 253828254
This commit is contained in:
parent
847e15e3c2
commit
23962b0d63
|
@ -43,10 +43,14 @@ Compared to the binary format, we adjust how certain module-level SPIR-V
|
|||
instructions are represented in the SPIR-V dialect. Notably,
|
||||
|
||||
* Requirements for capabilities, extensions, extended instruction sets,
|
||||
addressing model, and memory model is conveyed using op attributes.
|
||||
addressing model, and memory model is conveyed using `spv.module` attributes.
|
||||
This is considered better because these information are for the
|
||||
exexcution environment. It's eaiser to probe them if on the module op
|
||||
itself.
|
||||
* Annotations/decoration instrutions are "folded" into the instructions they
|
||||
decorate and represented as attributes on those ops. This elimiates potential
|
||||
forward references of SSA values, improves IR readability, and makes
|
||||
querying the annotations more direct.
|
||||
* Various constant instructions are represented by the same `spv.constant`
|
||||
op. Those instructions are just for constants of different types; using one
|
||||
op to represent them reduces IR verbosity and makes transformations less
|
||||
|
|
|
@ -171,6 +171,17 @@ def SPV_StorageClassAttr :
|
|||
|
||||
// Base class for all SPIR-V ops.
|
||||
class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<SPV_Dialect, mnemonic, traits>;
|
||||
Op<SPV_Dialect, mnemonic, traits> {
|
||||
// For each SPIR-V op, the following static functions need to be defined
|
||||
// in SPVOps.cpp:
|
||||
//
|
||||
// * static ParseResult parse<op-c++-class-name>(OpAsmParser *parser,
|
||||
// OperationState *result)
|
||||
// * static void print(OpAsmPrinter *p, <op-c++-class-name> op)
|
||||
// * static LogicalResult verify(<op-c++-class-name> op)
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
let printer = [{ return ::print(*this, p); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
#endif // SPIRV_BASE
|
||||
|
|
|
@ -58,6 +58,9 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> {
|
|||
|
||||
let parser = [{ return impl::parseBinaryOp(parser, result); }];
|
||||
let printer = [{ return impl::printBinaryOp(getOperation(), p); }];
|
||||
|
||||
// No additional verification needed in addition to the ODS-generated ones.
|
||||
let verifier = [{ return success(); }];
|
||||
}
|
||||
|
||||
def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
|
||||
|
@ -77,4 +80,55 @@ def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
|
|||
let verifier = [{ return verifyReturn(*this); }];
|
||||
}
|
||||
|
||||
def SPV_VariableOp : SPV_Op<"Variable"> {
|
||||
let summary = [{
|
||||
Allocate an object in memory, resulting in a pointer to it, which can be
|
||||
used with OpLoad and OpStore
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Result Type must be an OpTypePointer. Its Type operand is the type of object
|
||||
in memory.
|
||||
|
||||
Storage Class is the Storage Class of the memory holding the object. It
|
||||
cannot be Generic. It must be the same as the Storage Class operand of the
|
||||
Result Type.
|
||||
|
||||
Initializer is optional. If Initializer is present, it will be the initial
|
||||
value of the variable’s memory content. Initializer must be an <id> from a
|
||||
constant instruction or a global (module scope) OpVariable instruction.
|
||||
Initializer must have the same type as the type pointed to by Result Type.
|
||||
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
variable-op ::= ssa-id `=` `spv.Variable` (`init(` ssa-use `)`)?
|
||||
(`bind(` integer-literal, integer-literal `)`)?
|
||||
attribute-dict? `:` spirv-pointer-type
|
||||
```
|
||||
|
||||
where `init` specifies initializer and `bind` specifies the descriptor set
|
||||
and binding number.
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
%0 = spv.constant ...
|
||||
|
||||
%1 = spv.Variable : !spv.ptr<f32, Function>
|
||||
%2 = spv.Variable init(%0): !spv.ptr<f32, Private>
|
||||
%3 = spv.Variable init(%0) bind(1, 2): !spv.ptr<f32, Uniform>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPV_StorageClassAttr:$storage_class,
|
||||
SPV_Optional<AnyType>:$initializer
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_AnyPtr:$pointer
|
||||
);
|
||||
}
|
||||
|
||||
#endif // SPIRV_OPS
|
||||
|
|
|
@ -61,12 +61,6 @@ def SPV_ModuleOp : SPV_Op<"module", []> {
|
|||
let results = (outs);
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
// Custom parser and printer implemented by static functions in SPVOps.cpp
|
||||
let parser = [{ return parseModule(parser, result); }];
|
||||
let printer = [{ printModule(*this, p); }];
|
||||
|
||||
let verifier = [{ return verifyModule(*this); }];
|
||||
}
|
||||
|
||||
def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator]> {
|
||||
|
@ -131,11 +125,6 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
|
|||
let results = (outs
|
||||
SPV_Type:$constant
|
||||
);
|
||||
|
||||
let parser = [{ return parseConstant(parser, result); }];
|
||||
let printer = [{ printConstant(*this, p); }];
|
||||
|
||||
let verifier = [{ return verifyConstant(*this); }];
|
||||
}
|
||||
|
||||
#endif // SPIRV_STRUCTURE_OPS
|
||||
|
|
|
@ -72,6 +72,7 @@ public:
|
|||
Type getPointeeType();
|
||||
|
||||
StorageClass getStorageClass();
|
||||
StringRef getStorageClassStr();
|
||||
};
|
||||
|
||||
// SPIR-V run-time array type
|
||||
|
|
|
@ -28,6 +28,9 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
static constexpr const char kBindingAttrName[] = "binding";
|
||||
static constexpr const char kDescriptorSetAttrName[] = "descriptor_set";
|
||||
static constexpr const char kStorageClassAttrName[] = "storage_class";
|
||||
static constexpr const char kValueAttrName[] = "value";
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -59,7 +62,7 @@ static LogicalResult verifyModuleOnly(Operation *op) {
|
|||
// spv.constant
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseConstant(OpAsmParser *parser, OperationState *state) {
|
||||
static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) {
|
||||
Attribute value;
|
||||
if (parser->parseAttribute(value, kValueAttrName, state->attributes))
|
||||
return failure();
|
||||
|
@ -75,12 +78,12 @@ static ParseResult parseConstant(OpAsmParser *parser, OperationState *state) {
|
|||
return parser->addTypeToList(type, state->types);
|
||||
}
|
||||
|
||||
static void printConstant(spirv::ConstantOp constOp, OpAsmPrinter *printer) {
|
||||
static void print(spirv::ConstantOp constOp, OpAsmPrinter *printer) {
|
||||
*printer << spirv::ConstantOp::getOperationName() << " " << constOp.value()
|
||||
<< " : " << constOp.getType();
|
||||
}
|
||||
|
||||
static LogicalResult verifyConstant(spirv::ConstantOp constOp) {
|
||||
static LogicalResult verify(spirv::ConstantOp constOp) {
|
||||
auto opType = constOp.getType();
|
||||
auto value = constOp.value();
|
||||
auto valueType = value.getType();
|
||||
|
@ -136,7 +139,7 @@ static void ensureModuleEnd(Region *region, Builder builder, Location loc) {
|
|||
block.push_back(Operation::create(state));
|
||||
}
|
||||
|
||||
static ParseResult parseModule(OpAsmParser *parser, OperationState *state) {
|
||||
static ParseResult parseModuleOp(OpAsmParser *parser, OperationState *state) {
|
||||
Region *body = state->addRegion();
|
||||
|
||||
if (parser->parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
|
||||
|
@ -149,7 +152,7 @@ static ParseResult parseModule(OpAsmParser *parser, OperationState *state) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static void printModule(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) {
|
||||
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) {
|
||||
auto *op = moduleOp.getOperation();
|
||||
*printer << spirv::ModuleOp::getOperationName();
|
||||
printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
|
||||
|
@ -158,7 +161,7 @@ static void printModule(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) {
|
|||
printer->printOptionalAttrDict(op->getAttrs());
|
||||
}
|
||||
|
||||
static LogicalResult verifyModule(spirv::ModuleOp moduleOp) {
|
||||
static LogicalResult verify(spirv::ModuleOp moduleOp) {
|
||||
auto &op = *moduleOp.getOperation();
|
||||
auto *dialect = op.getDialect();
|
||||
auto &body = op.getRegion(0).front();
|
||||
|
@ -207,6 +210,123 @@ static LogicalResult verifyReturn(spirv::ReturnOp returnOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.Variable
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) {
|
||||
// Parse optional initializer
|
||||
Optional<OpAsmParser::OperandType> initInfo;
|
||||
if (succeeded(parser->parseOptionalKeyword("init"))) {
|
||||
initInfo = OpAsmParser::OperandType();
|
||||
if (parser->parseLParen() || parser->parseOperand(*initInfo) ||
|
||||
parser->parseRParen())
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Parse optional descriptor binding
|
||||
Attribute set, binding;
|
||||
if (succeeded(parser->parseOptionalKeyword("bind"))) {
|
||||
Type i32Type = parser->getBuilder().getIntegerType(32);
|
||||
if (parser->parseLParen() ||
|
||||
parser->parseAttribute(set, i32Type, kDescriptorSetAttrName,
|
||||
state->attributes) ||
|
||||
parser->parseComma() ||
|
||||
parser->parseAttribute(binding, i32Type, kBindingAttrName,
|
||||
state->attributes) ||
|
||||
parser->parseRParen())
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Parse other attributes
|
||||
if (parser->parseOptionalAttributeDict(state->attributes))
|
||||
return failure();
|
||||
|
||||
// Parse result pointer type
|
||||
Type type;
|
||||
if (parser->parseColon())
|
||||
return failure();
|
||||
auto loc = parser->getCurrentLocation();
|
||||
if (parser->parseType(type))
|
||||
return failure();
|
||||
|
||||
auto ptrType = type.dyn_cast<spirv::PointerType>();
|
||||
if (!ptrType)
|
||||
return parser->emitError(loc, "expected spv.ptr type");
|
||||
state->addTypes(ptrType);
|
||||
|
||||
// Resolve the initializer operand
|
||||
SmallVector<Value *, 1> init;
|
||||
if (initInfo) {
|
||||
if (parser->resolveOperand(*initInfo, ptrType.getPointeeType(), init))
|
||||
return failure();
|
||||
state->addOperands(init);
|
||||
}
|
||||
|
||||
// TODO(antiagainst): The enum attribute should be integer backed so we don't
|
||||
// have these excessive string conversions.
|
||||
auto attr = parser->getBuilder().getStringAttr(ptrType.getStorageClassStr());
|
||||
state->addAttribute(kStorageClassAttrName, attr);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) {
|
||||
auto *op = varOp.getOperation();
|
||||
SmallVector<StringRef, 4> elidedAttrs{kStorageClassAttrName};
|
||||
*printer << spirv::VariableOp::getOperationName();
|
||||
|
||||
// Print optional initializer
|
||||
if (op->getNumOperands() > 0) {
|
||||
*printer << " init(";
|
||||
printer->printOperands(varOp.initializer());
|
||||
*printer << ")";
|
||||
}
|
||||
|
||||
// Print optional descriptor binding
|
||||
auto set = varOp.getAttr(kDescriptorSetAttrName);
|
||||
auto binding = varOp.getAttr(kBindingAttrName);
|
||||
if (set && binding) {
|
||||
elidedAttrs.push_back(kDescriptorSetAttrName);
|
||||
elidedAttrs.push_back(kBindingAttrName);
|
||||
*printer << " bind(" << set << ", " << binding << ")";
|
||||
}
|
||||
|
||||
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||
*printer << " : " << varOp.getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::VariableOp varOp) {
|
||||
// SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
|
||||
// object. It cannot be Generic. It must be the same as the Storage Class
|
||||
// operand of the Result Type."
|
||||
if (varOp.storage_class() == "Generic")
|
||||
return varOp.emitOpError("storage class cannot be 'Generic'");
|
||||
|
||||
auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>();
|
||||
if (varOp.storage_class() != pointerType.getStorageClassStr())
|
||||
return varOp.emitOpError(
|
||||
"storage class must match result pointer's storage class");
|
||||
|
||||
if (varOp.getNumOperands() != 0) {
|
||||
// SPIR-V spec: "Initializer must be an <id> from a constant instruction or
|
||||
// a global (module scope) OpVariable instruction".
|
||||
bool valid = false;
|
||||
if (auto *initOp = varOp.getOperand(0)->getDefiningOp()) {
|
||||
if (llvm::isa<spirv::ConstantOp>(initOp)) {
|
||||
valid = true;
|
||||
} else if (llvm::isa<spirv::VariableOp>(initOp)) {
|
||||
valid = llvm::isa_and_nonnull<spirv::ModuleOp>(initOp->getParentOp());
|
||||
}
|
||||
}
|
||||
if (!valid)
|
||||
return varOp.emitOpError("initializer must be the result of a "
|
||||
"spv.Constant or module-level spv.Variable op");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace spirv {
|
||||
|
||||
|
|
|
@ -101,6 +101,10 @@ StorageClass PointerType::getStorageClass() {
|
|||
return getImpl()->getStorageClass();
|
||||
}
|
||||
|
||||
StringRef PointerType::getStorageClassStr() {
|
||||
return stringifyStorageClass(getStorageClass());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RuntimeArrayType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -65,3 +65,68 @@ func @return_mismatch_func_signature() -> () {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.Variable
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @variable_no_init(%arg0: f32) -> () {
|
||||
// CHECK: spv.Variable : !spv.ptr<f32, Function>
|
||||
%0 = spv.Variable : !spv.ptr<f32, Function>
|
||||
return
|
||||
}
|
||||
|
||||
func @variable_init() -> () {
|
||||
%0 = spv.constant 4.0 : f32
|
||||
// CHECK: spv.Variable init(%0) : !spv.ptr<f32, Private>
|
||||
%1 = spv.Variable init(%0) : !spv.ptr<f32, Private>
|
||||
return
|
||||
}
|
||||
|
||||
func @variable_bind() -> () {
|
||||
// CHECK: spv.Variable bind(1, 2) : !spv.ptr<f32, Uniform>
|
||||
%0 = spv.Variable bind(1, 2) : !spv.ptr<f32, Uniform>
|
||||
return
|
||||
}
|
||||
|
||||
func @variable_init_bind() -> () {
|
||||
%0 = spv.constant 4.0 : f32
|
||||
// CHECK: spv.Variable init(%0) {binding: 5 : i32} : !spv.ptr<f32, Private>
|
||||
%1 = spv.Variable init(%0) {binding: 5 : i32} : !spv.ptr<f32, Private>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @expect_ptr_result_type(%arg0: f32) -> () {
|
||||
// expected-error @+1 {{expected spv.ptr type}}
|
||||
%0 = spv.Variable : f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @variable_init(%arg0: f32) -> () {
|
||||
// expected-error @+1 {{op initializer must be the result of a spv.Constant or module-level spv.Variable op}}
|
||||
%0 = spv.Variable init(%arg0) : !spv.ptr<f32, Private>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @storage_class_mismatch() -> () {
|
||||
%0 = spv.constant 5.0 : f32
|
||||
// expected-error @+1 {{storage class must match result pointer's storage class}}
|
||||
%1 = "spv.Variable"(%0) {storage_class : "Uniform"} : (f32) -> !spv.ptr<f32, Function>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cannot_be_generic_storage_class(%arg0: f32) -> () {
|
||||
// expected-error @+1 {{storage class cannot be 'Generic'}}
|
||||
%0 = spv.Variable : !spv.ptr<f32, Generic>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue