[spirv] Add spv.Variable

This is a direct modelling of SPIR-V's OpVariable. The custom assembly format
parsers/prints descriptor in a nicer way if presents. There are other common
decorations that can appear on variables like builtin, which can be supported
later.

This CL additionally deduplicates the parser/printer/verifier declaration
in op definitions by adding defaults to SPV_Op base.
by adding

PiperOrigin-RevId: 253828254
This commit is contained in:
Lei Zhang 2019-06-18 11:15:55 -07:00 committed by Mehdi Amini
parent 847e15e3c2
commit 23962b0d63
8 changed files with 267 additions and 19 deletions

View File

@ -43,10 +43,14 @@ Compared to the binary format, we adjust how certain module-level SPIR-V
instructions are represented in the SPIR-V dialect. Notably,
* Requirements for capabilities, extensions, extended instruction sets,
addressing model, and memory model is conveyed using op attributes.
addressing model, and memory model is conveyed using `spv.module` attributes.
This is considered better because these information are for the
exexcution environment. It's eaiser to probe them if on the module op
itself.
* Annotations/decoration instrutions are "folded" into the instructions they
decorate and represented as attributes on those ops. This elimiates potential
forward references of SSA values, improves IR readability, and makes
querying the annotations more direct.
* Various constant instructions are represented by the same `spv.constant`
op. Those instructions are just for constants of different types; using one
op to represent them reduces IR verbosity and makes transformations less

View File

@ -171,6 +171,17 @@ def SPV_StorageClassAttr :
// Base class for all SPIR-V ops.
class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
Op<SPV_Dialect, mnemonic, traits>;
Op<SPV_Dialect, mnemonic, traits> {
// For each SPIR-V op, the following static functions need to be defined
// in SPVOps.cpp:
//
// * static ParseResult parse<op-c++-class-name>(OpAsmParser *parser,
// OperationState *result)
// * static void print(OpAsmPrinter *p, <op-c++-class-name> op)
// * static LogicalResult verify(<op-c++-class-name> op)
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(*this, p); }];
let verifier = [{ return ::verify(*this); }];
}
#endif // SPIRV_BASE

View File

@ -58,6 +58,9 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> {
let parser = [{ return impl::parseBinaryOp(parser, result); }];
let printer = [{ return impl::printBinaryOp(getOperation(), p); }];
// No additional verification needed in addition to the ODS-generated ones.
let verifier = [{ return success(); }];
}
def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
@ -77,4 +80,55 @@ def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
let verifier = [{ return verifyReturn(*this); }];
}
def SPV_VariableOp : SPV_Op<"Variable"> {
let summary = [{
Allocate an object in memory, resulting in a pointer to it, which can be
used with OpLoad and OpStore
}];
let description = [{
Result Type must be an OpTypePointer. Its Type operand is the type of object
in memory.
Storage Class is the Storage Class of the memory holding the object. It
cannot be Generic. It must be the same as the Storage Class operand of the
Result Type.
Initializer is optional. If Initializer is present, it will be the initial
value of the variables memory content. Initializer must be an <id> from a
constant instruction or a global (module scope) OpVariable instruction.
Initializer must have the same type as the type pointed to by Result Type.
### Custom assembly form
``` {.ebnf}
variable-op ::= ssa-id `=` `spv.Variable` (`init(` ssa-use `)`)?
(`bind(` integer-literal, integer-literal `)`)?
attribute-dict? `:` spirv-pointer-type
```
where `init` specifies initializer and `bind` specifies the descriptor set
and binding number.
For example:
```
%0 = spv.constant ...
%1 = spv.Variable : !spv.ptr<f32, Function>
%2 = spv.Variable init(%0): !spv.ptr<f32, Private>
%3 = spv.Variable init(%0) bind(1, 2): !spv.ptr<f32, Uniform>
```
}];
let arguments = (ins
SPV_StorageClassAttr:$storage_class,
SPV_Optional<AnyType>:$initializer
);
let results = (outs
SPV_AnyPtr:$pointer
);
}
#endif // SPIRV_OPS

View File

@ -61,12 +61,6 @@ def SPV_ModuleOp : SPV_Op<"module", []> {
let results = (outs);
let regions = (region SizedRegion<1>:$body);
// Custom parser and printer implemented by static functions in SPVOps.cpp
let parser = [{ return parseModule(parser, result); }];
let printer = [{ printModule(*this, p); }];
let verifier = [{ return verifyModule(*this); }];
}
def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator]> {
@ -131,11 +125,6 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
let results = (outs
SPV_Type:$constant
);
let parser = [{ return parseConstant(parser, result); }];
let printer = [{ printConstant(*this, p); }];
let verifier = [{ return verifyConstant(*this); }];
}
#endif // SPIRV_STRUCTURE_OPS

View File

@ -72,6 +72,7 @@ public:
Type getPointeeType();
StorageClass getStorageClass();
StringRef getStorageClassStr();
};
// SPIR-V run-time array type

View File

@ -28,6 +28,9 @@
using namespace mlir;
static constexpr const char kBindingAttrName[] = "binding";
static constexpr const char kDescriptorSetAttrName[] = "descriptor_set";
static constexpr const char kStorageClassAttrName[] = "storage_class";
static constexpr const char kValueAttrName[] = "value";
//===----------------------------------------------------------------------===//
@ -59,7 +62,7 @@ static LogicalResult verifyModuleOnly(Operation *op) {
// spv.constant
//===----------------------------------------------------------------------===//
static ParseResult parseConstant(OpAsmParser *parser, OperationState *state) {
static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) {
Attribute value;
if (parser->parseAttribute(value, kValueAttrName, state->attributes))
return failure();
@ -75,12 +78,12 @@ static ParseResult parseConstant(OpAsmParser *parser, OperationState *state) {
return parser->addTypeToList(type, state->types);
}
static void printConstant(spirv::ConstantOp constOp, OpAsmPrinter *printer) {
static void print(spirv::ConstantOp constOp, OpAsmPrinter *printer) {
*printer << spirv::ConstantOp::getOperationName() << " " << constOp.value()
<< " : " << constOp.getType();
}
static LogicalResult verifyConstant(spirv::ConstantOp constOp) {
static LogicalResult verify(spirv::ConstantOp constOp) {
auto opType = constOp.getType();
auto value = constOp.value();
auto valueType = value.getType();
@ -136,7 +139,7 @@ static void ensureModuleEnd(Region *region, Builder builder, Location loc) {
block.push_back(Operation::create(state));
}
static ParseResult parseModule(OpAsmParser *parser, OperationState *state) {
static ParseResult parseModuleOp(OpAsmParser *parser, OperationState *state) {
Region *body = state->addRegion();
if (parser->parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
@ -149,7 +152,7 @@ static ParseResult parseModule(OpAsmParser *parser, OperationState *state) {
return success();
}
static void printModule(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) {
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) {
auto *op = moduleOp.getOperation();
*printer << spirv::ModuleOp::getOperationName();
printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
@ -158,7 +161,7 @@ static void printModule(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) {
printer->printOptionalAttrDict(op->getAttrs());
}
static LogicalResult verifyModule(spirv::ModuleOp moduleOp) {
static LogicalResult verify(spirv::ModuleOp moduleOp) {
auto &op = *moduleOp.getOperation();
auto *dialect = op.getDialect();
auto &body = op.getRegion(0).front();
@ -207,6 +210,123 @@ static LogicalResult verifyReturn(spirv::ReturnOp returnOp) {
return success();
}
//===----------------------------------------------------------------------===//
// spv.Variable
//===----------------------------------------------------------------------===//
static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) {
// Parse optional initializer
Optional<OpAsmParser::OperandType> initInfo;
if (succeeded(parser->parseOptionalKeyword("init"))) {
initInfo = OpAsmParser::OperandType();
if (parser->parseLParen() || parser->parseOperand(*initInfo) ||
parser->parseRParen())
return failure();
}
// Parse optional descriptor binding
Attribute set, binding;
if (succeeded(parser->parseOptionalKeyword("bind"))) {
Type i32Type = parser->getBuilder().getIntegerType(32);
if (parser->parseLParen() ||
parser->parseAttribute(set, i32Type, kDescriptorSetAttrName,
state->attributes) ||
parser->parseComma() ||
parser->parseAttribute(binding, i32Type, kBindingAttrName,
state->attributes) ||
parser->parseRParen())
return failure();
}
// Parse other attributes
if (parser->parseOptionalAttributeDict(state->attributes))
return failure();
// Parse result pointer type
Type type;
if (parser->parseColon())
return failure();
auto loc = parser->getCurrentLocation();
if (parser->parseType(type))
return failure();
auto ptrType = type.dyn_cast<spirv::PointerType>();
if (!ptrType)
return parser->emitError(loc, "expected spv.ptr type");
state->addTypes(ptrType);
// Resolve the initializer operand
SmallVector<Value *, 1> init;
if (initInfo) {
if (parser->resolveOperand(*initInfo, ptrType.getPointeeType(), init))
return failure();
state->addOperands(init);
}
// TODO(antiagainst): The enum attribute should be integer backed so we don't
// have these excessive string conversions.
auto attr = parser->getBuilder().getStringAttr(ptrType.getStorageClassStr());
state->addAttribute(kStorageClassAttrName, attr);
return success();
}
static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) {
auto *op = varOp.getOperation();
SmallVector<StringRef, 4> elidedAttrs{kStorageClassAttrName};
*printer << spirv::VariableOp::getOperationName();
// Print optional initializer
if (op->getNumOperands() > 0) {
*printer << " init(";
printer->printOperands(varOp.initializer());
*printer << ")";
}
// Print optional descriptor binding
auto set = varOp.getAttr(kDescriptorSetAttrName);
auto binding = varOp.getAttr(kBindingAttrName);
if (set && binding) {
elidedAttrs.push_back(kDescriptorSetAttrName);
elidedAttrs.push_back(kBindingAttrName);
*printer << " bind(" << set << ", " << binding << ")";
}
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
*printer << " : " << varOp.getType();
}
static LogicalResult verify(spirv::VariableOp varOp) {
// SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
// object. It cannot be Generic. It must be the same as the Storage Class
// operand of the Result Type."
if (varOp.storage_class() == "Generic")
return varOp.emitOpError("storage class cannot be 'Generic'");
auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>();
if (varOp.storage_class() != pointerType.getStorageClassStr())
return varOp.emitOpError(
"storage class must match result pointer's storage class");
if (varOp.getNumOperands() != 0) {
// SPIR-V spec: "Initializer must be an <id> from a constant instruction or
// a global (module scope) OpVariable instruction".
bool valid = false;
if (auto *initOp = varOp.getOperand(0)->getDefiningOp()) {
if (llvm::isa<spirv::ConstantOp>(initOp)) {
valid = true;
} else if (llvm::isa<spirv::VariableOp>(initOp)) {
valid = llvm::isa_and_nonnull<spirv::ModuleOp>(initOp->getParentOp());
}
}
if (!valid)
return varOp.emitOpError("initializer must be the result of a "
"spv.Constant or module-level spv.Variable op");
}
return success();
}
namespace mlir {
namespace spirv {

View File

@ -101,6 +101,10 @@ StorageClass PointerType::getStorageClass() {
return getImpl()->getStorageClass();
}
StringRef PointerType::getStorageClassStr() {
return stringifyStorageClass(getStorageClass());
}
//===----------------------------------------------------------------------===//
// RuntimeArrayType
//===----------------------------------------------------------------------===//

View File

@ -65,3 +65,68 @@ func @return_mismatch_func_signature() -> () {
}
return
}
// -----
//===----------------------------------------------------------------------===//
// spv.Variable
//===----------------------------------------------------------------------===//
func @variable_no_init(%arg0: f32) -> () {
// CHECK: spv.Variable : !spv.ptr<f32, Function>
%0 = spv.Variable : !spv.ptr<f32, Function>
return
}
func @variable_init() -> () {
%0 = spv.constant 4.0 : f32
// CHECK: spv.Variable init(%0) : !spv.ptr<f32, Private>
%1 = spv.Variable init(%0) : !spv.ptr<f32, Private>
return
}
func @variable_bind() -> () {
// CHECK: spv.Variable bind(1, 2) : !spv.ptr<f32, Uniform>
%0 = spv.Variable bind(1, 2) : !spv.ptr<f32, Uniform>
return
}
func @variable_init_bind() -> () {
%0 = spv.constant 4.0 : f32
// CHECK: spv.Variable init(%0) {binding: 5 : i32} : !spv.ptr<f32, Private>
%1 = spv.Variable init(%0) {binding: 5 : i32} : !spv.ptr<f32, Private>
return
}
// -----
func @expect_ptr_result_type(%arg0: f32) -> () {
// expected-error @+1 {{expected spv.ptr type}}
%0 = spv.Variable : f32
return
}
// -----
func @variable_init(%arg0: f32) -> () {
// expected-error @+1 {{op initializer must be the result of a spv.Constant or module-level spv.Variable op}}
%0 = spv.Variable init(%arg0) : !spv.ptr<f32, Private>
return
}
// -----
func @storage_class_mismatch() -> () {
%0 = spv.constant 5.0 : f32
// expected-error @+1 {{storage class must match result pointer's storage class}}
%1 = "spv.Variable"(%0) {storage_class : "Uniform"} : (f32) -> !spv.ptr<f32, Function>
return
}
// -----
func @cannot_be_generic_storage_class(%arg0: f32) -> () {
// expected-error @+1 {{storage class cannot be 'Generic'}}
%0 = spv.Variable : !spv.ptr<f32, Generic>
return
}