[mlir] Fix wrong variable name in Linalg OpDSL
The name seems to have been left over from a renaming effort on an
unexercised codepath, which is difficult to catch in Python. Fix it and
add a test that exercises the codepath.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D114004
commit bca003dea8
parent e852cc0d5a
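The stale name only matters on the branch that curries keyword arguments when no function is passed, which is why nothing failed until that path ran: Python resolves names at call time, not when the decorator is defined. Below is a minimal, self-contained sketch of that pattern (a simplified stand-in for illustration, not the actual OpDSL sources; the toy ops and asserts are invented for the example):

import functools

def linalg_structured_op(dsl_func=None, *, op_name=None, op_class_name=None):
  if dsl_func is None:
    # Curry the keyword args in for delayed application. A stale name here
    # (the old `tc_def_op`) would raise NameError, but only once this
    # branch actually executes.
    return functools.partial(
        linalg_structured_op, op_name=op_name, op_class_name=op_class_name)
  # The real OpDSL builds an op definition here; this sketch just reports
  # which name was chosen so the two codepaths stay visible.
  return op_name or dsl_func.__name__

# The bare form never enters the curried branch, so the bug stayed hidden:
@linalg_structured_op
def default_name_op():
  pass

# The keyword form does enter it; this is the path the new test exercises:
@linalg_structured_op(op_name="custom_op_name")
def non_default_name_op():
  pass

assert default_name_op == "default_name_op"
assert non_default_name_op == "custom_op_name"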
@@ -112,7 +112,7 @@ def linalg_structured_op(dsl_func=None,
   if dsl_func is None:
     # Curry the keyword args in for delayed application.
     return functools.partial(
-        tc_def_op, op_name=op_name, op_class_name=op_class_name)
+        linalg_structured_op, op_name=op_name, op_class_name=op_class_name)
   # Determine default names by introspecting the function.
   if op_name is None:
     op_name = dsl_func.__name__
@@ -131,9 +131,10 @@
     if isinstance(param_default, (TensorDef, ScalarDef, AttributeDef)):
       tc_model.add_operand(param_name, param_default.operand_def)
     else:
-      raise ValueError(f"@tc_def_op function parameters must be defaulted as "
-                       f"TensorDef(...), ScalarDef(...), or AttributeDef(...): "
-                       f"Found {param_name}: {param_default}")
+      raise ValueError(
+          f"@linalg_structured_op function parameters must be defaulted as "
+          f"TensorDef(...), ScalarDef(...), or AttributeDef(...): "
+          f"Found {param_name}: {param_default}")
     dsl_func_args.append(param_default)
 
   # Invoke the DSL func to finish populating the model.
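The second hunk also rewords the diagnostic so it names the decorator's current spelling. As a hypothetical illustration (not part of the patch, and assuming the TensorDef/ScalarDef/AttributeDef symbols from the OpDSL are in scope), a DSL function whose parameter lacks one of those defaults would be expected to hit this ValueError:

@linalg_structured_op
def bad_op(I=1.0, O=TensorDef(T, S.N, output=True)):
  # 1.0 is not a TensorDef(...), ScalarDef(...), or AttributeDef(...)
  O[D.n] = I

The remaining hunks are in the Python OpDSL test, adding the test case mentioned in the commit message: an op declared with a non-default op_name, plus a function that emits it.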
@@ -126,6 +126,11 @@ def soft_plus_poly(
       PrimFn.log(cast(U, const(1.0)) + cast(U, PrimFn.exp(I[D.m, D.n])))
 
 
+@linalg_structured_op(op_name="custom_op_name")
+def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
+  O[D.n] = I[D.n]
+
+
 with Context() as ctx, Location.unknown():
   module = Module.create()
   f16 = F16Type.get()
@@ -392,5 +397,12 @@ with Context() as ctx, Location.unknown():
   def test_f32_soft_plus(input, init_result):
     return soft_plus_poly(input, outs=[init_result])
 
+  # Just check that we don't assert out on name mismatch.
+  # CHECK-LABEL: @test_non_default_op_name
+  @builtin.FuncOp.from_py_func(
+      RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32))
+  def test_non_default_op_name(input, init_result):
+    return non_default_op_name(input, outs=[init_result])
+
 
 print(module)