[mlir] switch the transform loop extension to use types

Add types to the Loop (SCF) extension of the transform dialect.

See https://discourse.llvm.org/t/rfc-type-system-for-the-transform-dialect/65702

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D135587
This commit is contained in:
Alex Zinenko 2022-10-10 14:38:31 +00:00
parent 3e1f6d02f7
commit 59bb8af4c3
13 changed files with 84 additions and 69 deletions

View File

@ -9,8 +9,8 @@
#ifndef MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
#define MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/OpImplementation.h"
namespace mlir {

View File

@ -12,10 +12,12 @@
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
[NavigationTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
@ -30,12 +32,13 @@ def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
}];
let arguments =
(ins PDL_Operation:$target,
(ins TransformTypeInterface:$target,
DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
"1">:$num_loops);
let results = (outs PDL_Operation:$parent);
let results = (outs TransformTypeInterface:$parent);
let assemblyFormat = "$target attr-dict";
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
}
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
@ -55,11 +58,15 @@ def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
order as the operand handle.
}];
let arguments = (ins PDL_Operation:$target,
// Note that despite the name of the transform operation and related utility
// functions, the actual implementation does not require the operation to be
// a loop.
let arguments = (ins TransformTypeInterface:$target,
StrAttr:$func_name);
let results = (outs PDL_Operation:$transformed);
let results = (outs TransformTypeInterface:$transformed);
let assemblyFormat = "$target attr-dict";
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
}
def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
@ -90,12 +97,13 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
}];
let arguments =
(ins PDL_Operation:$target,
(ins Transform_ScfForOp:$target,
DefaultValuedAttr<BoolAttr, "false">:$fail_if_already_divisible);
// TODO: Return both the peeled loop and the remainder loop.
let results = (outs PDL_Operation:$transformed);
let results = (outs TransformTypeInterface:$transformed);
let assemblyFormat = "$target attr-dict";
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@ -131,12 +139,13 @@ def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
pipelined loops, which can be empty.
}];
let arguments = (ins PDL_Operation:$target,
let arguments = (ins Transform_ScfForOp:$target,
DefaultValuedAttr<I64Attr, "1">:$iteration_interval,
DefaultValuedAttr<I64Attr, "10">:$read_latency);
let results = (outs PDL_Operation:$transformed);
let results = (outs TransformTypeInterface:$transformed);
let assemblyFormat = "$target attr-dict";
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@ -165,10 +174,10 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
removed after a full unrolling.
}];
let arguments = (ins PDL_Operation:$target,
let arguments = (ins Transform_ScfForOp:$target,
ConfinedAttr<I64Attr, [IntPositive]>:$factor);
let assemblyFormat = "$target attr-dict";
let assemblyFormat = "$target attr-dict `:` type($target)";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(

View File

@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRSCFTransformOps
MLIRAffineDialect
MLIRFuncDialect
MLIRIR
MLIRPDLDialect
MLIRSCFDialect
MLIRSCFTransforms
MLIRSCFUtils

View File

@ -9,7 +9,6 @@
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
@ -239,8 +238,6 @@ public:
using Base::Base;
void init() {
declareDependentDialect<pdl::PDLDialect>();
declareGeneratedDialect<AffineDialect>();
declareGeneratedDialect<func::FuncDialect>();

View File

@ -5,7 +5,6 @@
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ..dialects import pdl
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@ -28,13 +27,14 @@ class GetParentForOp:
"""Extension for GetParentForOp."""
def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
num_loops: int = 1,
ip=None,
loc=None):
super().__init__(
pdl.OperationType.get(),
result_type,
_get_op_result_or_value(target),
num_loops=_get_int64_attr(num_loops, default_value=1),
ip=ip,
@ -45,13 +45,14 @@ class LoopOutlineOp:
"""Extension for LoopOutlineOp."""
def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
func_name: Union[str, StringAttr],
ip=None,
loc=None):
super().__init__(
pdl.OperationType.get(),
result_type,
_get_op_result_or_value(target),
func_name=(func_name if isinstance(func_name, StringAttr) else
StringAttr.get(func_name)),
@ -63,13 +64,14 @@ class LoopPeelOp:
"""Extension for LoopPeelOp."""
def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
fail_if_already_divisible: Union[bool, BoolAttr] = False,
ip=None,
loc=None):
super().__init__(
pdl.OperationType.get(),
result_type,
_get_op_result_or_value(target),
fail_if_already_divisible=(fail_if_already_divisible if isinstance(
fail_if_already_divisible, BoolAttr) else
@ -82,6 +84,7 @@ class LoopPipelineOp:
"""Extension for LoopPipelineOp."""
def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
iteration_interval: Optional[Union[int, IntegerAttr]] = None,
@ -89,7 +92,7 @@ class LoopPipelineOp:
ip=None,
loc=None):
super().__init__(
pdl.OperationType.get(),
result_type,
_get_op_result_or_value(target),
iteration_interval=_get_int64_attr(iteration_interval, default_value=1),
read_latency=_get_int64_attr(read_latency, default_value=10),

View File

@ -51,7 +51,8 @@ transform.with_pdl_patterns {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
transform.loop.peel %loops#0
%loop = transform.cast %loops#0 : !pdl.operation to !transform.op<"scf.for">
transform.loop.peel %loop : (!transform.op<"scf.for">) -> !pdl.operation
}
}

View File

@ -14,11 +14,11 @@ transform.with_pdl_patterns {
transform.sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%match_name = transform.structured.match ops{["arith.constant"]} in %arg1
transform.test_print_remark_at_operand %match_name, "matched op name"
transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation
transform.test_consume_operand %match_name
%match_attr = transform.structured.match ops{["arith.constant"]} attributes{my_attr} in %arg1
transform.test_print_remark_at_operand %match_attr, "matched attr name"
transform.test_print_remark_at_operand %match_attr, "matched attr name" : !pdl.operation
transform.test_consume_operand %match_attr
}
}
@ -38,7 +38,7 @@ transform.with_pdl_patterns {
^bb1(%arg1: !pdl.operation):
%match_name = transform.structured.match
ops{["arith.constant"]} filter_result_type = f32 in %arg1
transform.test_print_remark_at_operand %match_name, "matched op name"
transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation
transform.test_consume_operand %match_name
}
}
@ -69,7 +69,7 @@ transform.with_pdl_patterns {
ops{["linalg.generic"]}
attributes{iterator_types = ["parallel", "parallel", "parallel"]}
in %arg1
transform.test_print_remark_at_operand %match_attr, "matched complex attr"
transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation
transform.test_consume_operand %match_attr
%no_match = transform.structured.match

View File

@ -33,5 +33,5 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["memref.alloc"]} in %arg1
%1 = transform.memref.multibuffer %0 {factor = 2 : i64}
// Verify that the returned handle is usable.
transform.test_print_remark_at_operand %1, "transformed"
transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
}

View File

@ -21,12 +21,12 @@ transform.with_pdl_patterns {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
// CHECK: = transform.loop.get_parent_for
%1 = transform.loop.get_parent_for %0
%2 = transform.loop.get_parent_for %0 { num_loops = 2 }
%3 = transform.loop.get_parent_for %0 { num_loops = 3 }
transform.test_print_remark_at_operand %1, "third loop"
transform.test_print_remark_at_operand %2, "second loop"
transform.test_print_remark_at_operand %3, "first loop"
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
%2 = transform.loop.get_parent_for %0 { num_loops = 2 } : (!pdl.operation) -> !transform.op<"scf.for">
%3 = transform.loop.get_parent_for %0 { num_loops = 3 } : (!pdl.operation) -> !transform.op<"scf.for">
transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"scf.for">
transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"scf.for">
transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"scf.for">
}
}
@ -44,7 +44,7 @@ transform.with_pdl_patterns {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
// expected-error @below {{could not find an 'scf.for' parent}}
%1 = transform.loop.get_parent_for %0
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
}
}
@ -85,9 +85,9 @@ transform.with_pdl_patterns {
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
%1 = transform.loop.get_parent_for %0
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
// CHECK: = transform.loop.outline %{{.*}}
transform.loop.outline %1 {func_name = "foo"}
transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> !pdl.operation
}
}
@ -115,7 +115,7 @@ transform.with_pdl_patterns {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["scf.while"]} in %arg1
// expected-error @below {{failed to outline}}
transform.loop.outline %0 {func_name = "foo"}
transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation
}
}
@ -145,8 +145,8 @@ transform.with_pdl_patterns {
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
%1 = transform.loop.get_parent_for %0
transform.loop.peel %1
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
transform.loop.peel %1 : (!transform.op<"scf.for">) -> !pdl.operation
}
}
@ -181,10 +181,10 @@ transform.with_pdl_patterns {
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addf"]} in %arg1
%1 = transform.loop.get_parent_for %0
%2 = transform.loop.pipeline %1
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
%2 = transform.loop.pipeline %1 : (!transform.op<"scf.for">) -> !pdl.operation
// Verify that the returned handle is usable.
transform.test_print_remark_at_operand %2, "transformed"
transform.test_print_remark_at_operand %2, "transformed" : !pdl.operation
}
}
@ -208,8 +208,8 @@ transform.with_pdl_patterns {
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
%1 = transform.loop.get_parent_for %0
transform.loop.unroll %1 { factor = 4 }
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
transform.loop.unroll %1 { factor = 4 } : !transform.op<"scf.for">
}
}

View File

@ -23,7 +23,7 @@ transform.with_pdl_patterns {
// expected-note @below {{invalidated by this transform op that consumes its operand #0}}
test_consume_operand %1
// expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
test_print_remark_at_operand %0, "remark"
test_print_remark_at_operand %0, "remark" : !pdl.operation
}
}
@ -57,7 +57,7 @@ transform.with_pdl_patterns {
%2 = replicate num(%0) %1 : !pdl.operation, !pdl.operation
// expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}}
test_consume_operand %2
test_print_remark_at_operand %0, "remark"
test_print_remark_at_operand %0, "remark" : !pdl.operation
}
}

View File

@ -76,7 +76,7 @@ transform.with_pdl_patterns {
sequence %arg0 : !pdl.operation failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
test_print_remark_at_operand %0, "matched"
test_print_remark_at_operand %0, "matched" : !pdl.operation
}
pdl.pattern @some : benefit(1) {
@ -124,7 +124,7 @@ transform.with_pdl_patterns {
%f = pdl_match @const in %arg1 : (!pdl.operation) -> !pdl.operation
// CHECK: %{{.+}} = get_closest_isolated_parent %{{.+}}
%m = get_closest_isolated_parent %f : (!pdl.operation) -> !pdl.operation
test_print_remark_at_operand %m, "parent function"
test_print_remark_at_operand %m, "parent function" : !pdl.operation
}
}
@ -227,7 +227,7 @@ transform.with_pdl_patterns {
}, {
^bb2(%arg2: !pdl.operation):
%2 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation
transform.test_print_remark_at_operand %2, "still here"
transform.test_print_remark_at_operand %2, "still here" : !pdl.operation
// This alternative succeeds.
}, {
^bb2(%arg2: !pdl.operation):
@ -370,7 +370,7 @@ transform.with_pdl_patterns {
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.pdl_match @match_const in %arg1 : (!pdl.operation) -> !pdl.operation
%1 = transform.loop.get_parent_for %0
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !pdl.operation
// expected-error @below {{only isolated-from-above ops can be alternative scopes}}
alternatives %1 : !pdl.operation {
^bb2(%arg2: !pdl.operation):
@ -541,7 +541,7 @@ transform.with_pdl_patterns {
%0 = pdl_match @addi in %arg1 : (!pdl.operation) -> !pdl.operation
%1 = pdl_match @subi in %arg1 : (!pdl.operation) -> !pdl.operation
%2 = merge_handles %0, %1 : !pdl.operation
test_print_remark_at_operand %2, "matched"
test_print_remark_at_operand %2, "matched" : !pdl.operation
}
}
@ -675,7 +675,7 @@ transform.with_pdl_patterns {
^bb2(%arg2: !pdl.operation):
// expected-remark @below {{1}}
transform.test_print_number_of_associated_payload_ir_ops %arg2
transform.test_print_remark_at_operand %arg2, "transform applied"
transform.test_print_remark_at_operand %arg2, "transform applied" : !pdl.operation
}
}
}
@ -725,7 +725,7 @@ transform.with_pdl_patterns {
// expected-remark @below {{3}}
transform.test_print_number_of_associated_payload_ir_ops %results
transform.test_print_remark_at_operand %results, "transform applied"
transform.test_print_remark_at_operand %results, "transform applied" : !pdl.operation
}
}
@ -742,7 +742,7 @@ transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%addi = transform.structured.match ops{["arith.addi"]} in %arg1
%muli = get_producer_of_operand %addi[0] : (!pdl.operation) -> !pdl.operation
transform.test_print_remark_at_operand %muli, "found muli"
transform.test_print_remark_at_operand %muli, "found muli" : !pdl.operation
}
// -----

View File

@ -69,10 +69,11 @@ def TestPrintRemarkAtOperandOp
: Op<Transform_Dialect, "test_print_remark_at_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins
Arg<PDL_Operation, "",
Arg<TransformTypeInterface, "",
[TransformMappingRead, PayloadIRRead]>:$operand,
StrAttr:$message);
let assemblyFormat = "$operand `,` $message attr-dict";
let assemblyFormat =
"$operand `,` $message attr-dict `:` type($operand)";
let cppNamespace = "::mlir::test";
}

View File

@ -18,9 +18,10 @@ def run(f):
@run
def getParentLoop():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
[], pdl.OperationType.get())
with InsertionPoint(sequence.body):
loop.GetParentForOp(sequence.bodyTarget, num_loops=2)
loop.GetParentForOp(transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2)
transform.YieldOp()
# CHECK-LABEL: TEST: getParentLoop
# CHECK: = transform.loop.get_parent_for %
@ -29,9 +30,10 @@ def getParentLoop():
@run
def loopOutline():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
[], transform.OperationType.get("scf.for"))
with InsertionPoint(sequence.body):
loop.LoopOutlineOp(sequence.bodyTarget, func_name="foo")
loop.LoopOutlineOp(pdl.OperationType.get(), sequence.bodyTarget, func_name="foo")
transform.YieldOp()
# CHECK-LABEL: TEST: loopOutline
# CHECK: = transform.loop.outline %
@ -40,9 +42,10 @@ def loopOutline():
@run
def loopPeel():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
[], transform.OperationType.get("scf.for"))
with InsertionPoint(sequence.body):
loop.LoopPeelOp(sequence.bodyTarget)
loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget)
transform.YieldOp()
# CHECK-LABEL: TEST: loopPeel
# CHECK: = transform.loop.peel %
@ -50,9 +53,10 @@ def loopPeel():
@run
def loopPipeline():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
[], transform.OperationType.get("scf.for"))
with InsertionPoint(sequence.body):
loop.LoopPipelineOp(sequence.bodyTarget, iteration_interval=3)
loop.LoopPipelineOp(pdl.OperationType.get(), sequence.bodyTarget, iteration_interval=3)
transform.YieldOp()
# CHECK-LABEL: TEST: loopPipeline
# CHECK: = transform.loop.pipeline %
@ -62,7 +66,8 @@ def loopPipeline():
@run
def loopUnroll():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
[], transform.OperationType.get("scf.for"))
with InsertionPoint(sequence.body):
loop.LoopUnrollOp(sequence.bodyTarget, factor=42)
transform.YieldOp()