diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index 1acae7a7a389..a65350ccd430 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -112,7 +112,7 @@ def linalg_structured_op(dsl_func=None, if dsl_func is None: # Curry the keyword args in for delayed application. return functools.partial( - tc_def_op, op_name=op_name, op_class_name=op_class_name) + linalg_structured_op, op_name=op_name, op_class_name=op_class_name) # Determine default names by introspecting the function. if op_name is None: op_name = dsl_func.__name__ @@ -131,9 +131,10 @@ def linalg_structured_op(dsl_func=None, if isinstance(param_default, (TensorDef, ScalarDef, AttributeDef)): tc_model.add_operand(param_name, param_default.operand_def) else: - raise ValueError(f"@tc_def_op function parameters must be defaulted as " - f"TensorDef(...), ScalarDef(...), or AttributeDef(...): " - f"Found {param_name}: {param_default}") + raise ValueError( + f"@linalg_structured_op function parameters must be defaulted as " + f"TensorDef(...), ScalarDef(...), or AttributeDef(...): " + f"Found {param_name}: {param_default}") dsl_func_args.append(param_default) # Invoke the DSL func to finish populating the model. diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py index 71dc8a5474aa..d0c74270950e 100644 --- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py @@ -126,6 +126,11 @@ def soft_plus_poly( PrimFn.log(cast(U, const(1.0)) + cast(U, PrimFn.exp(I[D.m, D.n]))) +@linalg_structured_op(op_name="custom_op_name") +def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)): + O[D.n] = I[D.n] + + with Context() as ctx, Location.unknown(): module = Module.create() f16 = F16Type.get() @@ -392,5 +397,12 @@ with Context() as ctx, Location.unknown(): def test_f32_soft_plus(input, init_result): return soft_plus_poly(input, outs=[init_result]) + # Just check that we don't assert out on name mismatch. + # CHECK-LABEL: @test_non_default_op_name + @builtin.FuncOp.from_py_func( + RankedTensorType.get((42,), f32), RankedTensorType.get((42,), f32)) + def test_non_default_op_name(input, init_result): + return non_default_op_name(input, outs=[init_result]) + print(module)