From d629645fcdf30576b1d4dc9ea2639321c4b33eae Mon Sep 17 00:00:00 2001 From: gysit Date: Tue, 8 Mar 2022 15:56:40 +0000 Subject: [PATCH] [mlir][OpDSL] Add support for adding canonicalization patterns. Extend OpDSL with a `defines` method that can set the `hasCanonicalizer` flag for an OpDSL operation. If the flag is set via `defines(Canonicalizer)` the operation needs to implement the `getCanonicalizationPatterns` method. The revision specifies the flag for linalg.fill_tensor and adds an empty `FillTensorOp::getCanonicalizationPatterns` implementation. This revision is a preparation step to replace linalg.fill by its OpDSL counterpart linalg.fill_tensor. The two are only functionally equivalent if both specify the same canonicalization patterns. The revision is thus a prerequisite for the linalg.fill replacement. Depends On D120725 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D120726 --- mlir/docs/Dialects/Linalg/OpDSL.md | 4 ++ .../Linalg/IR/LinalgNamedStructuredOps.yaml | 2 + mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 ++ .../linalg/opdsl/lang/comprehension.py | 13 +++++ .../mlir/dialects/linalg/opdsl/lang/dsl.py | 20 +++++-- .../linalg/opdsl/ops/core_named_ops.py | 1 + .../test-linalg-ods-yaml-gen.yaml | 55 +++++++++++++++++++ .../opdsl/{interfaces.py => metadata.py} | 3 + .../mlir-linalg-ods-yaml-gen.cpp | 16 +++++- 9 files changed, 109 insertions(+), 9 deletions(-) rename mlir/test/python/dialects/linalg/opdsl/{interfaces.py => metadata.py} (86%) diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md index d7526bf9f3ba..99136f1472f1 100644 --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -55,6 +55,7 @@ def matmul(A=TensorDef(T1, S.M, S.K), them to the same data type as the accumulator/output. """ domain(D.m, D.n, D.k) + defines(Canonicalizer) implements(ContractionOpInterface) C[D.m, D.n] += TypeFn.cast_signed( U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n]) @@ -78,6 +79,9 @@ An explicit iteration domain dimension order can be declared for the op via Special identifying op interfaces can be declared for the op via `implements(interface1[, interface2...])`. +Extra method definitions can be declared for the op via +`defines(definition1[, definition2...])`. + ## Parameters Structured operations take two types of runtime parameters namely scalars and diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index 7511e268ae85..21f28cbd84c3 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -2877,6 +2877,8 @@ metadata: !LinalgOpMetadata the value operand, promoting it to the same data type as the output. implements: - LinalgFillOpInterface + defines: + - hasCanonicalizer structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index d4e46f7619d7..53ff45a53104 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -509,6 +509,10 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results, FoldInsertPadIntoFill>(context); } +// TODO: Add the FillOp patterns when transitioning to the OpDSL FillOp. +void FillTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) {} + //===----------------------------------------------------------------------===// // GenericOps //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 1de5449e27e3..47083de625de 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -689,6 +689,16 @@ ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface") FillOpInterface = OpInterfaceDef("LinalgFillOpInterface") +class OpDefinitionDef: + """A method that an op implements.""" + + def __init__(self, def_name: str): + self.def_name = def_name + + +Canonicalizer = OpDefinitionDef("hasCanonicalizer") + + class OpMetadataDef(YAMLObject): """Metadata about the op (generally not behavior impacting).""" yaml_tag = "!LinalgOpMetadata" @@ -699,6 +709,7 @@ class OpMetadataDef(YAMLObject): self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name self.doc = doc self.implements = [] # type: List[OpInterfaceDef] + self.defines = [] # type: List[OpDefinitionsDef] def to_yaml_custom_dict(self): d = dict( @@ -708,6 +719,8 @@ class OpMetadataDef(YAMLObject): ) if self.implements: d["implements"] = [intr.cpp_name for intr in self.implements] + if self.defines: + d["defines"] = [defi.def_name for defi in self.defines] return d diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index bd9042ac0aac..45b8d5ccd13d 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -149,13 +149,21 @@ def linalg_structured_op(dsl_func=None, return DefinedOpCallable(op_name, op_def) +def domain(*dimensions: DimDef): + if any(not isinstance(d, DimDef) for d in dimensions): + raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") + current_op_def().domain.extend(dimensions) + + def implements(*interfaces: OpInterfaceDef): + if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces): + raise ValueError( + f"Expected interfaces of type OpInterfaceDef but got {interfaces}") current_op_def().metadata.implements.extend(interfaces) -def domain(*dimensions: DimDef): - if current_op_def().domain: - raise ValueError(f"Expected only one set of domain dimensions per operator") - if any(not isinstance(dim, DimDef) for dim in dimensions): - raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") - current_op_def().domain.extend(dimensions) +def defines(*definitions: OpDefinitionDef): + if any(not isinstance(defi, OpDefinitionDef) for defi in definitions): + raise ValueError( + f"Expected definitions of type OpDefinitionDef but got {definitions}") + current_op_def().metadata.defines.extend(definitions) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 7798d7f9498e..39934131cb22 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -672,6 +672,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)): the value operand, promoting it to the same data type as the output. """ implements(FillOpInterface) + defines(Canonicalizer) O[None] = TypeFn.cast_signed(U, value) diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml index 3f6c76347014..a31984764ebb 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -333,3 +333,58 @@ structured_op: !LinalgStructuredOpConfig # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0)) # IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0)) # IMPL-NEXT: yields.push_back([[VAL1]]) + +# @linalg_structured_op +# def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)): +# """Title. + +# Detailed description. +# """ +# implements(FillOpInterface) +# defines(Canonicalizer) +# O[None] = TypeFn.cast(U, value) + +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: test5 + cpp_class_name: Test5Op + doc: |- + Title. + + Detailed description. + implements: + - LinalgFillOpInterface + defines: + - hasCanonicalizer +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: value + kind: scalar + type_var: T1 + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<() -> ()> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<() -> ()> + - affine_map<() -> ()> + iterator_types: [] + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: type + fn_name: cast + type_var: U + operands: + - !ScalarExpression + scalar_arg: value + +# ODS-LABEL: def Test5Op : LinalgStructuredBase_Op<"test5" +# ODS-NEXT: /*extraInterfaces=*/[LinalgFillOpInterface])> + +# ODS: let hasCanonicalizer = 1; diff --git a/mlir/test/python/dialects/linalg/opdsl/interfaces.py b/mlir/test/python/dialects/linalg/opdsl/metadata.py similarity index 86% rename from mlir/test/python/dialects/linalg/opdsl/interfaces.py rename to mlir/test/python/dialects/linalg/opdsl/metadata.py index ca9bd04cd967..a7502e9eb1aa 100644 --- a/mlir/test/python/dialects/linalg/opdsl/interfaces.py +++ b/mlir/test/python/dialects/linalg/opdsl/metadata.py @@ -7,11 +7,14 @@ from mlir.dialects.linalg.opdsl.lang import * # CHECK-LABEL: matmul # CHECK: implements: # CHECK-NEXT: - LinalgContractionOpInterface +# CHECK: defines: +# CHECK-NEXT: - hasCanonicalizer @linalg_structured_op def matmul( A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): implements(ContractionOpInterface) + defines(Canonicalizer) C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed( U, B[D.k, D.n]) diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index 1cf1247262e0..5cade2a24f43 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -53,6 +53,7 @@ struct LinalgOpMetadata { std::string cppClassName; Optional doc; SmallVector implements; + SmallVector defines; }; struct SerializedAffineMap { @@ -233,6 +234,7 @@ struct MappingTraits { io.mapRequired("cpp_class_name", info.cppClassName); io.mapOptional("doc", info.doc); io.mapOptional("implements", info.implements); + io.mapOptional("defines", info.defines); } }; @@ -499,7 +501,8 @@ static const char bannerFormat[] = R"FMT( // {3}: documentation (summary + description) // {4}: op attribute list // {5}: builder methods taking standalone attribute parameters -// {6}: additional methods for attributes used by indexing maps +// {6}: additional method defintions +// {7}: additional methods for attributes used by indexing maps static const char structuredOpOdsHeaderFormat[] = R"FMT( //===----------------------------------------------------------------------===// // Op definition for {0} @@ -573,6 +576,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments], ]; let hasCustomAssemblyFormat = 1; let hasFolder = 1; + {6} let extraClassDeclaration = structuredOpsBaseDecls # [{{ // Auto-generated. @@ -589,7 +593,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments], // Generic methods. static unsigned getNumRegionArgs(); std::string getLibraryCallName(); - {6} + {7} }]; } )FMT"; @@ -736,6 +740,12 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig, interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); + std::string definitionList; + for (const std::string &definition : opConfig.metadata->defines) { + static const char definitionFmt[] = "let {0} = 1;\n"; + definitionList.append(llvm::formatv(definitionFmt, definition)); + } + if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { return isAttribute(arg.kind); })) { @@ -794,7 +804,7 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig, os << llvm::formatv(structuredOpOdsHeaderFormat, opConfig.metadata->cppClassName, opConfig.metadata->name, interfaceNameList, doc, attrList, attrBuilder, - attrMethods); + definitionList, attrMethods); return success(); }