From 23962b0d634601416a8bd1a0da9b825d214bc5f4 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 18 Jun 2019 11:15:55 -0700 Subject: [PATCH] [spirv] Add spv.Variable This is a direct modelling of SPIR-V's OpVariable. The custom assembly format parsers/prints descriptor in a nicer way if presents. There are other common decorations that can appear on variables like builtin, which can be supported later. This CL additionally deduplicates the parser/printer/verifier declaration in op definitions by adding defaults to SPV_Op base. by adding PiperOrigin-RevId: 253828254 --- mlir/g3doc/Dialects/SPIR-V.md | 6 +- mlir/include/mlir/SPIRV/SPIRVBase.td | 13 +- mlir/include/mlir/SPIRV/SPIRVOps.td | 54 ++++++++ mlir/include/mlir/SPIRV/SPIRVStructureOps.td | 11 -- mlir/include/mlir/SPIRV/SPIRVTypes.h | 1 + mlir/lib/SPIRV/SPIRVOps.cpp | 132 ++++++++++++++++++- mlir/lib/SPIRV/SPIRVTypes.cpp | 4 + mlir/test/SPIRV/ops.mlir | 65 +++++++++ 8 files changed, 267 insertions(+), 19 deletions(-) diff --git a/mlir/g3doc/Dialects/SPIR-V.md b/mlir/g3doc/Dialects/SPIR-V.md index 686646d6a456..58400ef58bbc 100644 --- a/mlir/g3doc/Dialects/SPIR-V.md +++ b/mlir/g3doc/Dialects/SPIR-V.md @@ -43,10 +43,14 @@ Compared to the binary format, we adjust how certain module-level SPIR-V instructions are represented in the SPIR-V dialect. Notably, * Requirements for capabilities, extensions, extended instruction sets, - addressing model, and memory model is conveyed using op attributes. + addressing model, and memory model is conveyed using `spv.module` attributes. This is considered better because these information are for the exexcution environment. It's eaiser to probe them if on the module op itself. +* Annotations/decoration instrutions are "folded" into the instructions they + decorate and represented as attributes on those ops. This elimiates potential + forward references of SSA values, improves IR readability, and makes + querying the annotations more direct. * Various constant instructions are represented by the same `spv.constant` op. Those instructions are just for constants of different types; using one op to represent them reduces IR verbosity and makes transformations less diff --git a/mlir/include/mlir/SPIRV/SPIRVBase.td b/mlir/include/mlir/SPIRV/SPIRVBase.td index 64c692f2ce90..50ac64af2f01 100644 --- a/mlir/include/mlir/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/SPIRV/SPIRVBase.td @@ -171,6 +171,17 @@ def SPV_StorageClassAttr : // Base class for all SPIR-V ops. class SPV_Op traits = []> : - Op; + Op { + // For each SPIR-V op, the following static functions need to be defined + // in SPVOps.cpp: + // + // * static ParseResult parse(OpAsmParser *parser, + // OperationState *result) + // * static void print(OpAsmPrinter *p, op) + // * static LogicalResult verify( op) + let parser = [{ return ::parse$cppClass(parser, result); }]; + let printer = [{ return ::print(*this, p); }]; + let verifier = [{ return ::verify(*this); }]; +} #endif // SPIRV_BASE diff --git a/mlir/include/mlir/SPIRV/SPIRVOps.td b/mlir/include/mlir/SPIRV/SPIRVOps.td index a58a17968ab3..9fa9e176777a 100644 --- a/mlir/include/mlir/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/SPIRV/SPIRVOps.td @@ -58,6 +58,9 @@ def SPV_FMulOp : SPV_Op<"FMul", [NoSideEffect, SameOperandsAndResultType]> { let parser = [{ return impl::parseBinaryOp(parser, result); }]; let printer = [{ return impl::printBinaryOp(getOperation(), p); }]; + + // No additional verification needed in addition to the ODS-generated ones. + let verifier = [{ return success(); }]; } def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> { @@ -77,4 +80,55 @@ def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> { let verifier = [{ return verifyReturn(*this); }]; } +def SPV_VariableOp : SPV_Op<"Variable"> { + let summary = [{ + Allocate an object in memory, resulting in a pointer to it, which can be + used with OpLoad and OpStore + }]; + + let description = [{ + Result Type must be an OpTypePointer. Its Type operand is the type of object + in memory. + + Storage Class is the Storage Class of the memory holding the object. It + cannot be Generic. It must be the same as the Storage Class operand of the + Result Type. + + Initializer is optional. If Initializer is present, it will be the initial + value of the variable’s memory content. Initializer must be an from a + constant instruction or a global (module scope) OpVariable instruction. + Initializer must have the same type as the type pointed to by Result Type. + + ### Custom assembly form + + ``` {.ebnf} + variable-op ::= ssa-id `=` `spv.Variable` (`init(` ssa-use `)`)? + (`bind(` integer-literal, integer-literal `)`)? + attribute-dict? `:` spirv-pointer-type + ``` + + where `init` specifies initializer and `bind` specifies the descriptor set + and binding number. + + For example: + + ``` + %0 = spv.constant ... + + %1 = spv.Variable : !spv.ptr + %2 = spv.Variable init(%0): !spv.ptr + %3 = spv.Variable init(%0) bind(1, 2): !spv.ptr + ``` + }]; + + let arguments = (ins + SPV_StorageClassAttr:$storage_class, + SPV_Optional:$initializer + ); + + let results = (outs + SPV_AnyPtr:$pointer + ); +} + #endif // SPIRV_OPS diff --git a/mlir/include/mlir/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/SPIRV/SPIRVStructureOps.td index 2c61dfd6489f..16faf49434af 100644 --- a/mlir/include/mlir/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/SPIRV/SPIRVStructureOps.td @@ -61,12 +61,6 @@ def SPV_ModuleOp : SPV_Op<"module", []> { let results = (outs); let regions = (region SizedRegion<1>:$body); - - // Custom parser and printer implemented by static functions in SPVOps.cpp - let parser = [{ return parseModule(parser, result); }]; - let printer = [{ printModule(*this, p); }]; - - let verifier = [{ return verifyModule(*this); }]; } def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator]> { @@ -131,11 +125,6 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> { let results = (outs SPV_Type:$constant ); - - let parser = [{ return parseConstant(parser, result); }]; - let printer = [{ printConstant(*this, p); }]; - - let verifier = [{ return verifyConstant(*this); }]; } #endif // SPIRV_STRUCTURE_OPS diff --git a/mlir/include/mlir/SPIRV/SPIRVTypes.h b/mlir/include/mlir/SPIRV/SPIRVTypes.h index 03fecd974dc7..ddab2de84941 100644 --- a/mlir/include/mlir/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/SPIRV/SPIRVTypes.h @@ -72,6 +72,7 @@ public: Type getPointeeType(); StorageClass getStorageClass(); + StringRef getStorageClassStr(); }; // SPIR-V run-time array type diff --git a/mlir/lib/SPIRV/SPIRVOps.cpp b/mlir/lib/SPIRV/SPIRVOps.cpp index 9c4fa4dc60a5..5d9602df998f 100644 --- a/mlir/lib/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/SPIRV/SPIRVOps.cpp @@ -28,6 +28,9 @@ using namespace mlir; +static constexpr const char kBindingAttrName[] = "binding"; +static constexpr const char kDescriptorSetAttrName[] = "descriptor_set"; +static constexpr const char kStorageClassAttrName[] = "storage_class"; static constexpr const char kValueAttrName[] = "value"; //===----------------------------------------------------------------------===// @@ -59,7 +62,7 @@ static LogicalResult verifyModuleOnly(Operation *op) { // spv.constant //===----------------------------------------------------------------------===// -static ParseResult parseConstant(OpAsmParser *parser, OperationState *state) { +static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) { Attribute value; if (parser->parseAttribute(value, kValueAttrName, state->attributes)) return failure(); @@ -75,12 +78,12 @@ static ParseResult parseConstant(OpAsmParser *parser, OperationState *state) { return parser->addTypeToList(type, state->types); } -static void printConstant(spirv::ConstantOp constOp, OpAsmPrinter *printer) { +static void print(spirv::ConstantOp constOp, OpAsmPrinter *printer) { *printer << spirv::ConstantOp::getOperationName() << " " << constOp.value() << " : " << constOp.getType(); } -static LogicalResult verifyConstant(spirv::ConstantOp constOp) { +static LogicalResult verify(spirv::ConstantOp constOp) { auto opType = constOp.getType(); auto value = constOp.value(); auto valueType = value.getType(); @@ -136,7 +139,7 @@ static void ensureModuleEnd(Region *region, Builder builder, Location loc) { block.push_back(Operation::create(state)); } -static ParseResult parseModule(OpAsmParser *parser, OperationState *state) { +static ParseResult parseModuleOp(OpAsmParser *parser, OperationState *state) { Region *body = state->addRegion(); if (parser->parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) || @@ -149,7 +152,7 @@ static ParseResult parseModule(OpAsmParser *parser, OperationState *state) { return success(); } -static void printModule(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) { +static void print(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) { auto *op = moduleOp.getOperation(); *printer << spirv::ModuleOp::getOperationName(); printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, @@ -158,7 +161,7 @@ static void printModule(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) { printer->printOptionalAttrDict(op->getAttrs()); } -static LogicalResult verifyModule(spirv::ModuleOp moduleOp) { +static LogicalResult verify(spirv::ModuleOp moduleOp) { auto &op = *moduleOp.getOperation(); auto *dialect = op.getDialect(); auto &body = op.getRegion(0).front(); @@ -207,6 +210,123 @@ static LogicalResult verifyReturn(spirv::ReturnOp returnOp) { return success(); } +//===----------------------------------------------------------------------===// +// spv.Variable +//===----------------------------------------------------------------------===// + +static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) { + // Parse optional initializer + Optional initInfo; + if (succeeded(parser->parseOptionalKeyword("init"))) { + initInfo = OpAsmParser::OperandType(); + if (parser->parseLParen() || parser->parseOperand(*initInfo) || + parser->parseRParen()) + return failure(); + } + + // Parse optional descriptor binding + Attribute set, binding; + if (succeeded(parser->parseOptionalKeyword("bind"))) { + Type i32Type = parser->getBuilder().getIntegerType(32); + if (parser->parseLParen() || + parser->parseAttribute(set, i32Type, kDescriptorSetAttrName, + state->attributes) || + parser->parseComma() || + parser->parseAttribute(binding, i32Type, kBindingAttrName, + state->attributes) || + parser->parseRParen()) + return failure(); + } + + // Parse other attributes + if (parser->parseOptionalAttributeDict(state->attributes)) + return failure(); + + // Parse result pointer type + Type type; + if (parser->parseColon()) + return failure(); + auto loc = parser->getCurrentLocation(); + if (parser->parseType(type)) + return failure(); + + auto ptrType = type.dyn_cast(); + if (!ptrType) + return parser->emitError(loc, "expected spv.ptr type"); + state->addTypes(ptrType); + + // Resolve the initializer operand + SmallVector init; + if (initInfo) { + if (parser->resolveOperand(*initInfo, ptrType.getPointeeType(), init)) + return failure(); + state->addOperands(init); + } + + // TODO(antiagainst): The enum attribute should be integer backed so we don't + // have these excessive string conversions. + auto attr = parser->getBuilder().getStringAttr(ptrType.getStorageClassStr()); + state->addAttribute(kStorageClassAttrName, attr); + + return success(); +} + +static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) { + auto *op = varOp.getOperation(); + SmallVector elidedAttrs{kStorageClassAttrName}; + *printer << spirv::VariableOp::getOperationName(); + + // Print optional initializer + if (op->getNumOperands() > 0) { + *printer << " init("; + printer->printOperands(varOp.initializer()); + *printer << ")"; + } + + // Print optional descriptor binding + auto set = varOp.getAttr(kDescriptorSetAttrName); + auto binding = varOp.getAttr(kBindingAttrName); + if (set && binding) { + elidedAttrs.push_back(kDescriptorSetAttrName); + elidedAttrs.push_back(kBindingAttrName); + *printer << " bind(" << set << ", " << binding << ")"; + } + + printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); + *printer << " : " << varOp.getType(); +} + +static LogicalResult verify(spirv::VariableOp varOp) { + // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the + // object. It cannot be Generic. It must be the same as the Storage Class + // operand of the Result Type." + if (varOp.storage_class() == "Generic") + return varOp.emitOpError("storage class cannot be 'Generic'"); + + auto pointerType = varOp.pointer()->getType().cast(); + if (varOp.storage_class() != pointerType.getStorageClassStr()) + return varOp.emitOpError( + "storage class must match result pointer's storage class"); + + if (varOp.getNumOperands() != 0) { + // SPIR-V spec: "Initializer must be an from a constant instruction or + // a global (module scope) OpVariable instruction". + bool valid = false; + if (auto *initOp = varOp.getOperand(0)->getDefiningOp()) { + if (llvm::isa(initOp)) { + valid = true; + } else if (llvm::isa(initOp)) { + valid = llvm::isa_and_nonnull(initOp->getParentOp()); + } + } + if (!valid) + return varOp.emitOpError("initializer must be the result of a " + "spv.Constant or module-level spv.Variable op"); + } + + return success(); +} + namespace mlir { namespace spirv { diff --git a/mlir/lib/SPIRV/SPIRVTypes.cpp b/mlir/lib/SPIRV/SPIRVTypes.cpp index c69d2720b479..b273cdc99d5e 100644 --- a/mlir/lib/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/SPIRV/SPIRVTypes.cpp @@ -101,6 +101,10 @@ StorageClass PointerType::getStorageClass() { return getImpl()->getStorageClass(); } +StringRef PointerType::getStorageClassStr() { + return stringifyStorageClass(getStorageClass()); +} + //===----------------------------------------------------------------------===// // RuntimeArrayType //===----------------------------------------------------------------------===// diff --git a/mlir/test/SPIRV/ops.mlir b/mlir/test/SPIRV/ops.mlir index bfc5f586f515..c71afa21393c 100644 --- a/mlir/test/SPIRV/ops.mlir +++ b/mlir/test/SPIRV/ops.mlir @@ -65,3 +65,68 @@ func @return_mismatch_func_signature() -> () { } return } + +// ----- + +//===----------------------------------------------------------------------===// +// spv.Variable +//===----------------------------------------------------------------------===// + +func @variable_no_init(%arg0: f32) -> () { + // CHECK: spv.Variable : !spv.ptr + %0 = spv.Variable : !spv.ptr + return +} + +func @variable_init() -> () { + %0 = spv.constant 4.0 : f32 + // CHECK: spv.Variable init(%0) : !spv.ptr + %1 = spv.Variable init(%0) : !spv.ptr + return +} + +func @variable_bind() -> () { + // CHECK: spv.Variable bind(1, 2) : !spv.ptr + %0 = spv.Variable bind(1, 2) : !spv.ptr + return +} + +func @variable_init_bind() -> () { + %0 = spv.constant 4.0 : f32 + // CHECK: spv.Variable init(%0) {binding: 5 : i32} : !spv.ptr + %1 = spv.Variable init(%0) {binding: 5 : i32} : !spv.ptr + return +} + +// ----- + +func @expect_ptr_result_type(%arg0: f32) -> () { + // expected-error @+1 {{expected spv.ptr type}} + %0 = spv.Variable : f32 + return +} + +// ----- + +func @variable_init(%arg0: f32) -> () { + // expected-error @+1 {{op initializer must be the result of a spv.Constant or module-level spv.Variable op}} + %0 = spv.Variable init(%arg0) : !spv.ptr + return +} + +// ----- + +func @storage_class_mismatch() -> () { + %0 = spv.constant 5.0 : f32 + // expected-error @+1 {{storage class must match result pointer's storage class}} + %1 = "spv.Variable"(%0) {storage_class : "Uniform"} : (f32) -> !spv.ptr + return +} + +// ----- + +func @cannot_be_generic_storage_class(%arg0: f32) -> () { + // expected-error @+1 {{storage class cannot be 'Generic'}} + %0 = spv.Variable : !spv.ptr + return +}