[mlir][OpDSL] Add support for adding canonicalization patterns.

Extend OpDSL with a `defines` method that can set the `hasCanonicalizer` flag for an OpDSL operation. If the flag is set via `defines(Canonicalizer)` the operation needs to implement the `getCanonicalizationPatterns` method. The revision specifies the flag for linalg.fill_tensor and adds an empty `FillTensorOp::getCanonicalizationPatterns` implementation.

This revision is a preparation step to replace linalg.fill by its OpDSL counterpart linalg.fill_tensor. The two are only functionally equivalent if both specify the same canonicalization patterns. The revision is thus a prerequisite for the linalg.fill replacement.

Depends On D120725

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D120726
This commit is contained in:
gysit 2022-03-08 15:56:40 +00:00
parent 8d7850705c
commit d629645fcd
9 changed files with 109 additions and 9 deletions

View File

@ -55,6 +55,7 @@ def matmul(A=TensorDef(T1, S.M, S.K),
them to the same data type as the accumulator/output.
"""
domain(D.m, D.n, D.k)
defines(Canonicalizer)
implements(ContractionOpInterface)
C[D.m, D.n] += TypeFn.cast_signed(
U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n])
@ -78,6 +79,9 @@ An explicit iteration domain dimension order can be declared for the op via
Special identifying op interfaces can be declared for the op via
`implements(interface1[, interface2...])`.
Extra method definitions can be declared for the op via
`defines(definition1[, definition2...])`.
## Parameters
Structured operations take two types of runtime parameters namely scalars and

View File

@ -2877,6 +2877,8 @@ metadata: !LinalgOpMetadata
the value operand, promoting it to the same data type as the output.
implements:
- LinalgFillOpInterface
defines:
- hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig

View File

@ -509,6 +509,10 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
FoldInsertPadIntoFill>(context);
}
// TODO: Add the FillOp patterns when transitioning to the OpDSL FillOp.
void FillTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {}
//===----------------------------------------------------------------------===//
// GenericOps
//===----------------------------------------------------------------------===//

View File

@ -689,6 +689,16 @@ ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
class OpDefinitionDef:
"""A method that an op implements."""
def __init__(self, def_name: str):
self.def_name = def_name
Canonicalizer = OpDefinitionDef("hasCanonicalizer")
class OpMetadataDef(YAMLObject):
"""Metadata about the op (generally not behavior impacting)."""
yaml_tag = "!LinalgOpMetadata"
@ -699,6 +709,7 @@ class OpMetadataDef(YAMLObject):
self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
self.doc = doc
self.implements = [] # type: List[OpInterfaceDef]
self.defines = [] # type: List[OpDefinitionsDef]
def to_yaml_custom_dict(self):
d = dict(
@ -708,6 +719,8 @@ class OpMetadataDef(YAMLObject):
)
if self.implements:
d["implements"] = [intr.cpp_name for intr in self.implements]
if self.defines:
d["defines"] = [defi.def_name for defi in self.defines]
return d

View File

@ -149,13 +149,21 @@ def linalg_structured_op(dsl_func=None,
return DefinedOpCallable(op_name, op_def)
def domain(*dimensions: DimDef):
if any(not isinstance(d, DimDef) for d in dimensions):
raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
current_op_def().domain.extend(dimensions)
def implements(*interfaces: OpInterfaceDef):
if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
raise ValueError(
f"Expected interfaces of type OpInterfaceDef but got {interfaces}")
current_op_def().metadata.implements.extend(interfaces)
def domain(*dimensions: DimDef):
if current_op_def().domain:
raise ValueError(f"Expected only one set of domain dimensions per operator")
if any(not isinstance(dim, DimDef) for dim in dimensions):
raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
current_op_def().domain.extend(dimensions)
def defines(*definitions: OpDefinitionDef):
if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
raise ValueError(
f"Expected definitions of type OpDefinitionDef but got {definitions}")
current_op_def().metadata.defines.extend(definitions)

View File

@ -672,6 +672,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
the value operand, promoting it to the same data type as the output.
"""
implements(FillOpInterface)
defines(Canonicalizer)
O[None] = TypeFn.cast_signed(U, value)

View File

@ -333,3 +333,58 @@ structured_op: !LinalgStructuredOpConfig
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
# IMPL-NEXT: yields.push_back([[VAL1]])
# @linalg_structured_op
# def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)):
# """Title.
# Detailed description.
# """
# implements(FillOpInterface)
# defines(Canonicalizer)
# O[None] = TypeFn.cast(U, value)
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: test5
cpp_class_name: Test5Op
doc: |-
Title.
Detailed description.
implements:
- LinalgFillOpInterface
defines:
- hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: value
kind: scalar
type_var: T1
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: U
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: value
# ODS-LABEL: def Test5Op : LinalgStructuredBase_Op<"test5"
# ODS-NEXT: /*extraInterfaces=*/[LinalgFillOpInterface])>
# ODS: let hasCanonicalizer = 1;

View File

@ -7,11 +7,14 @@ from mlir.dialects.linalg.opdsl.lang import *
# CHECK-LABEL: matmul
# CHECK: implements:
# CHECK-NEXT: - LinalgContractionOpInterface
# CHECK: defines:
# CHECK-NEXT: - hasCanonicalizer
@linalg_structured_op
def matmul(
A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
implements(ContractionOpInterface)
defines(Canonicalizer)
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
U, B[D.k, D.n])

View File

@ -53,6 +53,7 @@ struct LinalgOpMetadata {
std::string cppClassName;
Optional<std::string> doc;
SmallVector<std::string> implements;
SmallVector<std::string> defines;
};
struct SerializedAffineMap {
@ -233,6 +234,7 @@ struct MappingTraits<LinalgOpMetadata> {
io.mapRequired("cpp_class_name", info.cppClassName);
io.mapOptional("doc", info.doc);
io.mapOptional("implements", info.implements);
io.mapOptional("defines", info.defines);
}
};
@ -499,7 +501,8 @@ static const char bannerFormat[] = R"FMT(
// {3}: documentation (summary + description)
// {4}: op attribute list
// {5}: builder methods taking standalone attribute parameters
// {6}: additional methods for attributes used by indexing maps
// {6}: additional method defintions
// {7}: additional methods for attributes used by indexing maps
static const char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// Op definition for {0}
@ -573,6 +576,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
];
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
{6}
let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated.
@ -589,7 +593,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
// Generic methods.
static unsigned getNumRegionArgs();
std::string getLibraryCallName();
{6}
{7}
}];
}
)FMT";
@ -736,6 +740,12 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
std::string definitionList;
for (const std::string &definition : opConfig.metadata->defines) {
static const char definitionFmt[] = "let {0} = 1;\n";
definitionList.append(llvm::formatv(definitionFmt, definition));
}
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
return isAttribute(arg.kind);
})) {
@ -794,7 +804,7 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
os << llvm::formatv(structuredOpOdsHeaderFormat,
opConfig.metadata->cppClassName, opConfig.metadata->name,
interfaceNameList, doc, attrList, attrBuilder,
attrMethods);
definitionList, attrMethods);
return success();
}