forked from OSchip/llvm-project
[mlir][OpDSL] Add support for adding canonicalization patterns.
Extend OpDSL with a `defines` method that can set the `hasCanonicalizer` flag for an OpDSL operation. If the flag is set via `defines(Canonicalizer)` the operation needs to implement the `getCanonicalizationPatterns` method. The revision specifies the flag for linalg.fill_tensor and adds an empty `FillTensorOp::getCanonicalizationPatterns` implementation. This revision is a preparation step to replace linalg.fill by its OpDSL counterpart linalg.fill_tensor. The two are only functionally equivalent if both specify the same canonicalization patterns. The revision is thus a prerequisite for the linalg.fill replacement. Depends On D120725 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D120726
This commit is contained in:
parent
8d7850705c
commit
d629645fcd
|
@ -55,6 +55,7 @@ def matmul(A=TensorDef(T1, S.M, S.K),
|
|||
them to the same data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.m, D.n, D.k)
|
||||
defines(Canonicalizer)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.m, D.n] += TypeFn.cast_signed(
|
||||
U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n])
|
||||
|
@ -78,6 +79,9 @@ An explicit iteration domain dimension order can be declared for the op via
|
|||
Special identifying op interfaces can be declared for the op via
|
||||
`implements(interface1[, interface2...])`.
|
||||
|
||||
Extra method definitions can be declared for the op via
|
||||
`defines(definition1[, definition2...])`.
|
||||
|
||||
## Parameters
|
||||
|
||||
Structured operations take two types of runtime parameters namely scalars and
|
||||
|
|
|
@ -2877,6 +2877,8 @@ metadata: !LinalgOpMetadata
|
|||
the value operand, promoting it to the same data type as the output.
|
||||
implements:
|
||||
- LinalgFillOpInterface
|
||||
defines:
|
||||
- hasCanonicalizer
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !LinalgOperandDefConfig
|
||||
|
|
|
@ -509,6 +509,10 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
FoldInsertPadIntoFill>(context);
|
||||
}
|
||||
|
||||
// TODO: Add the FillOp patterns when transitioning to the OpDSL FillOp.
|
||||
void FillTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericOps
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -689,6 +689,16 @@ ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
|
|||
FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
|
||||
|
||||
|
||||
class OpDefinitionDef:
|
||||
"""A method that an op implements."""
|
||||
|
||||
def __init__(self, def_name: str):
|
||||
self.def_name = def_name
|
||||
|
||||
|
||||
Canonicalizer = OpDefinitionDef("hasCanonicalizer")
|
||||
|
||||
|
||||
class OpMetadataDef(YAMLObject):
|
||||
"""Metadata about the op (generally not behavior impacting)."""
|
||||
yaml_tag = "!LinalgOpMetadata"
|
||||
|
@ -699,6 +709,7 @@ class OpMetadataDef(YAMLObject):
|
|||
self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
|
||||
self.doc = doc
|
||||
self.implements = [] # type: List[OpInterfaceDef]
|
||||
self.defines = [] # type: List[OpDefinitionsDef]
|
||||
|
||||
def to_yaml_custom_dict(self):
|
||||
d = dict(
|
||||
|
@ -708,6 +719,8 @@ class OpMetadataDef(YAMLObject):
|
|||
)
|
||||
if self.implements:
|
||||
d["implements"] = [intr.cpp_name for intr in self.implements]
|
||||
if self.defines:
|
||||
d["defines"] = [defi.def_name for defi in self.defines]
|
||||
return d
|
||||
|
||||
|
||||
|
|
|
@ -149,13 +149,21 @@ def linalg_structured_op(dsl_func=None,
|
|||
return DefinedOpCallable(op_name, op_def)
|
||||
|
||||
|
||||
def domain(*dimensions: DimDef):
|
||||
if any(not isinstance(d, DimDef) for d in dimensions):
|
||||
raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
|
||||
current_op_def().domain.extend(dimensions)
|
||||
|
||||
|
||||
def implements(*interfaces: OpInterfaceDef):
|
||||
if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
|
||||
raise ValueError(
|
||||
f"Expected interfaces of type OpInterfaceDef but got {interfaces}")
|
||||
current_op_def().metadata.implements.extend(interfaces)
|
||||
|
||||
|
||||
def domain(*dimensions: DimDef):
|
||||
if current_op_def().domain:
|
||||
raise ValueError(f"Expected only one set of domain dimensions per operator")
|
||||
if any(not isinstance(dim, DimDef) for dim in dimensions):
|
||||
raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
|
||||
current_op_def().domain.extend(dimensions)
|
||||
def defines(*definitions: OpDefinitionDef):
|
||||
if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
|
||||
raise ValueError(
|
||||
f"Expected definitions of type OpDefinitionDef but got {definitions}")
|
||||
current_op_def().metadata.defines.extend(definitions)
|
||||
|
|
|
@ -672,6 +672,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
|
|||
the value operand, promoting it to the same data type as the output.
|
||||
"""
|
||||
implements(FillOpInterface)
|
||||
defines(Canonicalizer)
|
||||
O[None] = TypeFn.cast_signed(U, value)
|
||||
|
||||
|
||||
|
|
|
@ -333,3 +333,58 @@ structured_op: !LinalgStructuredOpConfig
|
|||
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
|
||||
# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
|
||||
# IMPL-NEXT: yields.push_back([[VAL1]])
|
||||
|
||||
# @linalg_structured_op
|
||||
# def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)):
|
||||
# """Title.
|
||||
|
||||
# Detailed description.
|
||||
# """
|
||||
# implements(FillOpInterface)
|
||||
# defines(Canonicalizer)
|
||||
# O[None] = TypeFn.cast(U, value)
|
||||
|
||||
--- !LinalgOpConfig
|
||||
metadata: !LinalgOpMetadata
|
||||
name: test5
|
||||
cpp_class_name: Test5Op
|
||||
doc: |-
|
||||
Title.
|
||||
|
||||
Detailed description.
|
||||
implements:
|
||||
- LinalgFillOpInterface
|
||||
defines:
|
||||
- hasCanonicalizer
|
||||
structured_op: !LinalgStructuredOpConfig
|
||||
args:
|
||||
- !LinalgOperandDefConfig
|
||||
name: value
|
||||
kind: scalar
|
||||
type_var: T1
|
||||
- !LinalgOperandDefConfig
|
||||
name: O
|
||||
kind: output_tensor
|
||||
type_var: U
|
||||
shape_map: affine_map<() -> ()>
|
||||
indexing_maps: !LinalgIndexingMapsConfig
|
||||
static_indexing_maps:
|
||||
- affine_map<() -> ()>
|
||||
- affine_map<() -> ()>
|
||||
iterator_types: []
|
||||
assignments:
|
||||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
type_var: U
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_arg: value
|
||||
|
||||
# ODS-LABEL: def Test5Op : LinalgStructuredBase_Op<"test5"
|
||||
# ODS-NEXT: /*extraInterfaces=*/[LinalgFillOpInterface])>
|
||||
|
||||
# ODS: let hasCanonicalizer = 1;
|
||||
|
|
|
@ -7,11 +7,14 @@ from mlir.dialects.linalg.opdsl.lang import *
|
|||
# CHECK-LABEL: matmul
|
||||
# CHECK: implements:
|
||||
# CHECK-NEXT: - LinalgContractionOpInterface
|
||||
# CHECK: defines:
|
||||
# CHECK-NEXT: - hasCanonicalizer
|
||||
@linalg_structured_op
|
||||
def matmul(
|
||||
A=TensorDef(T, S.M, S.K),
|
||||
B=TensorDef(T, S.K, S.N),
|
||||
C=TensorDef(U, S.M, S.N, output=True)):
|
||||
implements(ContractionOpInterface)
|
||||
defines(Canonicalizer)
|
||||
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
|
||||
U, B[D.k, D.n])
|
|
@ -53,6 +53,7 @@ struct LinalgOpMetadata {
|
|||
std::string cppClassName;
|
||||
Optional<std::string> doc;
|
||||
SmallVector<std::string> implements;
|
||||
SmallVector<std::string> defines;
|
||||
};
|
||||
|
||||
struct SerializedAffineMap {
|
||||
|
@ -233,6 +234,7 @@ struct MappingTraits<LinalgOpMetadata> {
|
|||
io.mapRequired("cpp_class_name", info.cppClassName);
|
||||
io.mapOptional("doc", info.doc);
|
||||
io.mapOptional("implements", info.implements);
|
||||
io.mapOptional("defines", info.defines);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -499,7 +501,8 @@ static const char bannerFormat[] = R"FMT(
|
|||
// {3}: documentation (summary + description)
|
||||
// {4}: op attribute list
|
||||
// {5}: builder methods taking standalone attribute parameters
|
||||
// {6}: additional methods for attributes used by indexing maps
|
||||
// {6}: additional method defintions
|
||||
// {7}: additional methods for attributes used by indexing maps
|
||||
static const char structuredOpOdsHeaderFormat[] = R"FMT(
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op definition for {0}
|
||||
|
@ -573,6 +576,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
|
|||
];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasFolder = 1;
|
||||
{6}
|
||||
|
||||
let extraClassDeclaration = structuredOpsBaseDecls # [{{
|
||||
// Auto-generated.
|
||||
|
@ -589,7 +593,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
|
|||
// Generic methods.
|
||||
static unsigned getNumRegionArgs();
|
||||
std::string getLibraryCallName();
|
||||
{6}
|
||||
{7}
|
||||
}];
|
||||
}
|
||||
)FMT";
|
||||
|
@ -736,6 +740,12 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
|
|||
|
||||
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
|
||||
|
||||
std::string definitionList;
|
||||
for (const std::string &definition : opConfig.metadata->defines) {
|
||||
static const char definitionFmt[] = "let {0} = 1;\n";
|
||||
definitionList.append(llvm::formatv(definitionFmt, definition));
|
||||
}
|
||||
|
||||
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
||||
return isAttribute(arg.kind);
|
||||
})) {
|
||||
|
@ -794,7 +804,7 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
|
|||
os << llvm::formatv(structuredOpOdsHeaderFormat,
|
||||
opConfig.metadata->cppClassName, opConfig.metadata->name,
|
||||
interfaceNameList, doc, attrList, attrBuilder,
|
||||
attrMethods);
|
||||
definitionList, attrMethods);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue