Fix gen_spirv_dialect.py regarding 1D/2D/3D Dim symbol name

PiperOrigin-RevId: 281131561
This commit is contained in:
Lei Zhang 2019-11-18 12:47:54 -08:00 committed by A. Unique TensorFlower
parent 6c77e59bfd
commit 1f475e316c
1 changed files with 12 additions and 2 deletions

View File

@ -152,6 +152,15 @@ def gen_operand_kind_enum_attr(operand_kind):
if 'enumerants' not in operand_kind:
return '', ''
# Returns a symbol for the given case in the given kind. This function
# handles Dim specially to avoid having numbers as the start of symbols,
# which does not play well with C++ and the MLIR parser.
def get_case_symbol(kind_name, case_name):
if kind_name == 'Dim':
if case_name == '1D' or case_name == '2D' or case_name == '3D':
return 'Dim{}'.format(case_name)
return case_name
kind_name = operand_kind['kind']
is_bit_enum = operand_kind['category'] == 'BitEnum'
kind_category = 'Bit' if is_bit_enum else 'I32'
@ -162,13 +171,14 @@ def gen_operand_kind_enum_attr(operand_kind):
max_len = max([len(symbol) for (symbol, _) in kind_cases])
# Generate the definition for each enum case
fmt_str = 'def SPV_{acronym}_{symbol} {colon:>{offset}} '\
fmt_str = 'def SPV_{acronym}_{case} {colon:>{offset}} '\
'{category}EnumAttrCase<"{symbol}", {value}>;'
case_defs = [
fmt_str.format(
category=kind_category,
acronym=kind_acronym,
symbol=case[0],
case=case[0],
symbol=get_case_symbol(kind_name, case[0]),
value=case[1],
colon=':',
offset=(max_len + 1 - len(case[0]))) for case in kind_cases