[spirv] Fix gen_spirv_dialect.py and add spv.Unreachable

This CL fixed gen_spirv_dialect.py to support nested delimiters when
chunking existing ODS entries in .td files and to allow ops without
correspondence in the spec. This is needed to pull in the definition
of OpUnreachable.

PiperOrigin-RevId: 277486465
This commit is contained in:
Lei Zhang 2019-10-30 05:40:47 -07:00 committed by A. Unique TensorFlower
parent f3efb60ccc
commit 80213ba5f0
6 changed files with 137 additions and 7 deletions

View File

@ -174,6 +174,7 @@ def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>;
def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
def SPV_OpcodeAttr :
@ -209,7 +210,7 @@ def SPV_OpcodeAttr :
SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpPhi,
SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
SPV_OC_OpModuleProcessed
SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";

View File

@ -373,6 +373,29 @@ def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> {
// -----
def SPV_UnreachableOp : SPV_Op<"Unreachable", [InFunctionScope, Terminator]> {
let summary = "Declares that this block is not reachable in the CFG.";
let description = [{
This instruction must be the last instruction in a block.
### Custom assembly form
``` {.ebnf}
unreachable-op ::= `spv.Unreachable`
```
}];
let arguments = (ins);
let results = (outs);
let parser = [{ return parseNoIOOp(parser, result); }];
let printer = [{ printNoIOOp(getOperation(), p); }];
}
// -----
def SPV_ReturnValueOp : SPV_Op<"ReturnValue", [InFunctionScope, Terminator]> {
let summary = "Return a value from a function.";

View File

@ -2285,6 +2285,26 @@ static void print(spirv::UndefOp undefOp, OpAsmPrinter &printer) {
printer << spirv::UndefOp::getOperationName() << " : " << undefOp.getType();
}
//===----------------------------------------------------------------------===//
// spv.Unreachable
//===----------------------------------------------------------------------===//
static LogicalResult verify(spirv::UnreachableOp unreachableOp) {
auto *op = unreachableOp.getOperation();
auto *block = op->getBlock();
// Fast track: if this is in entry block, its invalid. Otherwise, if no
// predecessors, it's valid.
if (block->isEntryBlock())
return unreachableOp.emitOpError("cannot be used in reachable block");
if (block->hasNoPredecessors())
return success();
// TODO(antiagainst): further verification needs to analyze reachablility from
// the entry block.
return success();
}
//===----------------------------------------------------------------------===//
// spv.Variable
//===----------------------------------------------------------------------===//

View File

@ -14,4 +14,14 @@ spv.module "Logical" "GLSL450" {
// CHECK: spv.ReturnValue {{.*}} : i32
spv.ReturnValue %1 : i32
}
// CHECK-LABEL: @unreachable
func @unreachable() {
spv.Return
// CHECK-NOT: ^bb
^bb1:
// Unreachable blocks will be dropped during serialization.
// CHECK-NOT: spv.Unreachable
spv.Unreachable
}
}

View File

@ -676,3 +676,37 @@ func @missing_entry_block() -> () {
}
return
}
// -----
//===----------------------------------------------------------------------===//
// spv.Unreachable
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @unreachable_no_pred
func @unreachable_no_pred() {
spv.Return
^next:
// CHECK: spv.Unreachable
spv.Unreachable
}
// CHECK-LABEL: func @unreachable_with_pred
func @unreachable_with_pred() {
spv.Return
^parent:
spv.Branch ^unreachable
^unreachable:
// CHECK: spv.Unreachable
spv.Unreachable
}
// -----
func @unreachable() {
// expected-error @+1 {{cannot be used in reachable block}}
spv.Unreachable
}

View File

@ -505,6 +505,42 @@ def get_string_between(base, start, end):
return '', split[0]
def get_string_between_nested(base, start, end):
"""Extracts a substring with a nested start and end from a string.
Arguments:
- base: string to extract from.
- start: string to use as the start of the substring.
- end: string to use as the end of the substring.
Returns:
- The substring if found
- The part of the base after end of the substring. Is the base string itself
if the substring wasnt found.
"""
split = base.split(start, 1)
if len(split) == 2:
# Handle nesting delimiters
rest = split[1]
unmatched_start = 1
index = 0
while unmatched_start > 0 and index < len(rest):
if rest[index:].startswith(end):
unmatched_start -= 1
index += len(end)
elif rest[index:].startswith(start):
unmatched_start += 1
index += len(start)
else:
index += 1
assert index < len(rest), \
'cannot find end "{end}" while extracting substring '\
'starting with "{start}"'.format(start=start, end=end)
return rest[:index - len(end)].rstrip(end), rest[index:]
return '', split[0]
def extract_td_op_info(op_def):
"""Extracts potentially manually specified sections in op's definition.
@ -528,7 +564,7 @@ def extract_td_op_info(op_def):
inst_category = inst_category[0] if len(inst_category) == 1 else 'Op'
# Get category_args
op_tmpl_params = op_def.split('<', 1)[1].split('>', 1)[0]
op_tmpl_params = get_string_between_nested(op_def, '<', '>')[0]
opstringname, rest = get_string_between(op_tmpl_params, '"', '"')
category_args = rest.split('[', 1)[0]
@ -587,10 +623,12 @@ def update_td_op_definitions(path, instructions, docs, filter_list,
# For each existing op, extract the manually-written sections out to retain
# them when re-generating the ops. Also append the existing ops to filter
# list.
name_op_map = {} # Map from opname to its existing ODS definition
op_info_dict = {}
for op in ops:
info_dict = extract_td_op_info(op)
opname = info_dict['opname']
name_op_map[opname] = op
op_info_dict[opname] = info_dict
filter_list.append(opname)
filter_list = sorted(list(set(filter_list)))
@ -598,11 +636,15 @@ def update_td_op_definitions(path, instructions, docs, filter_list,
op_defs = []
for opname in filter_list:
# Find the grammar spec for this op
instruction = next(
inst for inst in instructions if inst['opname'] == opname)
op_defs.append(
get_op_definition(instruction, docs[opname],
op_info_dict.get(opname, {})))
try:
instruction = next(
inst for inst in instructions if inst['opname'] == opname)
op_defs.append(
get_op_definition(instruction, docs[opname],
op_info_dict.get(opname, {})))
except StopIteration:
# This is an op added by us; use the existing ODS definition.
op_defs.append(name_op_map[opname])
# Substitute the old op definitions
op_defs = [header] + op_defs + [footer]