[mlir] Fix wrong variable name in Linalg OpDSL

The name seems to have been left over from a renaming effort on an unexercised
codepaths that are difficult to catch in Python. Fix it and add a test that
exercises the codepath.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D114004
This commit is contained in:
Alex Zinenko 2021-11-16 17:09:08 +01:00
parent e852cc0d5a
commit bca003dea8
2 changed files with 17 additions and 4 deletions

View File

@ -112,7 +112,7 @@ def linalg_structured_op(dsl_func=None,
if dsl_func is None:
# Curry the keyword args in for delayed application.
return functools.partial(
tc_def_op, op_name=op_name, op_class_name=op_class_name)
linalg_structured_op, op_name=op_name, op_class_name=op_class_name)
# Determine default names by introspecting the function.
if op_name is None:
op_name = dsl_func.__name__
@ -131,9 +131,10 @@ def linalg_structured_op(dsl_func=None,
if isinstance(param_default, (TensorDef, ScalarDef, AttributeDef)):
tc_model.add_operand(param_name, param_default.operand_def)
else:
raise ValueError(f"@tc_def_op function parameters must be defaulted as "
f"TensorDef(...), ScalarDef(...), or AttributeDef(...): "
f"Found {param_name}: {param_default}")
raise ValueError(
f"@linalg_structured_op function parameters must be defaulted as "
f"TensorDef(...), ScalarDef(...), or AttributeDef(...): "
f"Found {param_name}: {param_default}")
dsl_func_args.append(param_default)
# Invoke the DSL func to finish populating the model.

View File

@ -126,6 +126,11 @@ def soft_plus_poly(
PrimFn.log(cast(U, const(1.0)) + cast(U, PrimFn.exp(I[D.m, D.n])))
@linalg_structured_op(op_name="custom_op_name")
def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
O[D.n] = I[D.n]
with Context() as ctx, Location.unknown():
module = Module.create()
f16 = F16Type.get()
@ -392,5 +397,12 @@ with Context() as ctx, Location.unknown():
def test_f32_soft_plus(input, init_result):
return soft_plus_poly(input, outs=[init_result])
# Just check that we don't assert out on name mismatch.
# CHECK-LABEL: @test_non_default_op_name
@builtin.FuncOp.from_py_func(
RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32))
def test_non_default_op_name(input, init_result):
return non_default_op_name(input, outs=[init_result])
print(module)