forked from OSchip/llvm-project
[mlir] switch the transform loop extension to use types
Add types to the Loop (SCF) extension of the transform dialect. See https://discourse.llvm.org/t/rfc-type-system-for-the-transform-dialect/65702 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D135587
This commit is contained in:
parent
3e1f6d02f7
commit
59bb8af4c3
|
@ -9,8 +9,8 @@
|
|||
#ifndef MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
|
||||
#define MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
|
||||
|
||||
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -12,10 +12,12 @@
|
|||
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformEffects.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
|
||||
include "mlir/Dialect/PDL/IR/PDLTypes.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformTypes.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
|
||||
|
||||
def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
|
||||
[NavigationTransformOpTrait, MemoryEffectsOpInterface,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>]> {
|
||||
|
@ -30,12 +32,13 @@ def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
|
|||
}];
|
||||
|
||||
let arguments =
|
||||
(ins PDL_Operation:$target,
|
||||
(ins TransformTypeInterface:$target,
|
||||
DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
|
||||
"1">:$num_loops);
|
||||
let results = (outs PDL_Operation:$parent);
|
||||
let results = (outs TransformTypeInterface:$parent);
|
||||
|
||||
let assemblyFormat = "$target attr-dict";
|
||||
let assemblyFormat =
|
||||
"$target attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
|
||||
|
@ -55,11 +58,15 @@ def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
|
|||
order as the operand handle.
|
||||
}];
|
||||
|
||||
let arguments = (ins PDL_Operation:$target,
|
||||
// Note that despite the name of the transform operation and related utility
|
||||
// functions, the actual implementation does not require the operation to be
|
||||
// a loop.
|
||||
let arguments = (ins TransformTypeInterface:$target,
|
||||
StrAttr:$func_name);
|
||||
let results = (outs PDL_Operation:$transformed);
|
||||
let results = (outs TransformTypeInterface:$transformed);
|
||||
|
||||
let assemblyFormat = "$target attr-dict";
|
||||
let assemblyFormat =
|
||||
"$target attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
|
||||
|
@ -90,12 +97,13 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
|
|||
}];
|
||||
|
||||
let arguments =
|
||||
(ins PDL_Operation:$target,
|
||||
(ins Transform_ScfForOp:$target,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$fail_if_already_divisible);
|
||||
// TODO: Return both the peeled loop and the remainder loop.
|
||||
let results = (outs PDL_Operation:$transformed);
|
||||
let results = (outs TransformTypeInterface:$transformed);
|
||||
|
||||
let assemblyFormat = "$target attr-dict";
|
||||
let assemblyFormat =
|
||||
"$target attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
|
@ -131,12 +139,13 @@ def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
|
|||
pipelined loops, which can be empty.
|
||||
}];
|
||||
|
||||
let arguments = (ins PDL_Operation:$target,
|
||||
let arguments = (ins Transform_ScfForOp:$target,
|
||||
DefaultValuedAttr<I64Attr, "1">:$iteration_interval,
|
||||
DefaultValuedAttr<I64Attr, "10">:$read_latency);
|
||||
let results = (outs PDL_Operation:$transformed);
|
||||
let results = (outs TransformTypeInterface:$transformed);
|
||||
|
||||
let assemblyFormat = "$target attr-dict";
|
||||
let assemblyFormat =
|
||||
"$target attr-dict `:` functional-type(operands, results)";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
|
@ -165,10 +174,10 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
|
|||
removed after a full unrolling.
|
||||
}];
|
||||
|
||||
let arguments = (ins PDL_Operation:$target,
|
||||
let arguments = (ins Transform_ScfForOp:$target,
|
||||
ConfinedAttr<I64Attr, [IntPositive]>:$factor);
|
||||
|
||||
let assemblyFormat = "$target attr-dict";
|
||||
let assemblyFormat = "$target attr-dict `:` type($target)";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
|
|
|
@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRSCFTransformOps
|
|||
MLIRAffineDialect
|
||||
MLIRFuncDialect
|
||||
MLIRIR
|
||||
MLIRPDLDialect
|
||||
MLIRSCFDialect
|
||||
MLIRSCFTransforms
|
||||
MLIRSCFUtils
|
||||
|
|
|
@ -9,7 +9,6 @@
|
|||
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/PDL/IR/PDL.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
|
||||
|
@ -239,8 +238,6 @@ public:
|
|||
using Base::Base;
|
||||
|
||||
void init() {
|
||||
declareDependentDialect<pdl::PDLDialect>();
|
||||
|
||||
declareGeneratedDialect<AffineDialect>();
|
||||
declareGeneratedDialect<func::FuncDialect>();
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
|
||||
from ..dialects import pdl
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
|
@ -28,13 +27,14 @@ class GetParentForOp:
|
|||
"""Extension for GetParentForOp."""
|
||||
|
||||
def __init__(self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
num_loops: int = 1,
|
||||
ip=None,
|
||||
loc=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(),
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
num_loops=_get_int64_attr(num_loops, default_value=1),
|
||||
ip=ip,
|
||||
|
@ -45,13 +45,14 @@ class LoopOutlineOp:
|
|||
"""Extension for LoopOutlineOp."""
|
||||
|
||||
def __init__(self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
func_name: Union[str, StringAttr],
|
||||
ip=None,
|
||||
loc=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(),
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
func_name=(func_name if isinstance(func_name, StringAttr) else
|
||||
StringAttr.get(func_name)),
|
||||
|
@ -63,13 +64,14 @@ class LoopPeelOp:
|
|||
"""Extension for LoopPeelOp."""
|
||||
|
||||
def __init__(self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
fail_if_already_divisible: Union[bool, BoolAttr] = False,
|
||||
ip=None,
|
||||
loc=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(),
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
fail_if_already_divisible=(fail_if_already_divisible if isinstance(
|
||||
fail_if_already_divisible, BoolAttr) else
|
||||
|
@ -82,6 +84,7 @@ class LoopPipelineOp:
|
|||
"""Extension for LoopPipelineOp."""
|
||||
|
||||
def __init__(self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
iteration_interval: Optional[Union[int, IntegerAttr]] = None,
|
||||
|
@ -89,7 +92,7 @@ class LoopPipelineOp:
|
|||
ip=None,
|
||||
loc=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(),
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
iteration_interval=_get_int64_attr(iteration_interval, default_value=1),
|
||||
read_latency=_get_int64_attr(read_latency, default_value=10),
|
||||
|
|
|
@ -51,7 +51,8 @@ transform.with_pdl_patterns {
|
|||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1
|
||||
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
|
||||
transform.loop.peel %loops#0
|
||||
%loop = transform.cast %loops#0 : !pdl.operation to !transform.op<"scf.for">
|
||||
transform.loop.peel %loop : (!transform.op<"scf.for">) -> !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -14,11 +14,11 @@ transform.with_pdl_patterns {
|
|||
transform.sequence %arg0 : !pdl.operation failures(propagate) {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%match_name = transform.structured.match ops{["arith.constant"]} in %arg1
|
||||
transform.test_print_remark_at_operand %match_name, "matched op name"
|
||||
transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation
|
||||
transform.test_consume_operand %match_name
|
||||
|
||||
%match_attr = transform.structured.match ops{["arith.constant"]} attributes{my_attr} in %arg1
|
||||
transform.test_print_remark_at_operand %match_attr, "matched attr name"
|
||||
transform.test_print_remark_at_operand %match_attr, "matched attr name" : !pdl.operation
|
||||
transform.test_consume_operand %match_attr
|
||||
}
|
||||
}
|
||||
|
@ -38,7 +38,7 @@ transform.with_pdl_patterns {
|
|||
^bb1(%arg1: !pdl.operation):
|
||||
%match_name = transform.structured.match
|
||||
ops{["arith.constant"]} filter_result_type = f32 in %arg1
|
||||
transform.test_print_remark_at_operand %match_name, "matched op name"
|
||||
transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation
|
||||
transform.test_consume_operand %match_name
|
||||
}
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ transform.with_pdl_patterns {
|
|||
ops{["linalg.generic"]}
|
||||
attributes{iterator_types = ["parallel", "parallel", "parallel"]}
|
||||
in %arg1
|
||||
transform.test_print_remark_at_operand %match_attr, "matched complex attr"
|
||||
transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation
|
||||
transform.test_consume_operand %match_attr
|
||||
|
||||
%no_match = transform.structured.match
|
||||
|
|
|
@ -33,5 +33,5 @@ transform.sequence failures(propagate) {
|
|||
%0 = transform.structured.match ops{["memref.alloc"]} in %arg1
|
||||
%1 = transform.memref.multibuffer %0 {factor = 2 : i64}
|
||||
// Verify that the returned handle is usable.
|
||||
transform.test_print_remark_at_operand %1, "transformed"
|
||||
transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
|
||||
}
|
||||
|
|
|
@ -21,12 +21,12 @@ transform.with_pdl_patterns {
|
|||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
|
||||
// CHECK: = transform.loop.get_parent_for
|
||||
%1 = transform.loop.get_parent_for %0
|
||||
%2 = transform.loop.get_parent_for %0 { num_loops = 2 }
|
||||
%3 = transform.loop.get_parent_for %0 { num_loops = 3 }
|
||||
transform.test_print_remark_at_operand %1, "third loop"
|
||||
transform.test_print_remark_at_operand %2, "second loop"
|
||||
transform.test_print_remark_at_operand %3, "first loop"
|
||||
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
|
||||
%2 = transform.loop.get_parent_for %0 { num_loops = 2 } : (!pdl.operation) -> !transform.op<"scf.for">
|
||||
%3 = transform.loop.get_parent_for %0 { num_loops = 3 } : (!pdl.operation) -> !transform.op<"scf.for">
|
||||
transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"scf.for">
|
||||
transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"scf.for">
|
||||
transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"scf.for">
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -44,7 +44,7 @@ transform.with_pdl_patterns {
|
|||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
|
||||
// expected-error @below {{could not find an 'scf.for' parent}}
|
||||
%1 = transform.loop.get_parent_for %0
|
||||
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -85,9 +85,9 @@ transform.with_pdl_patterns {
|
|||
sequence %arg0 : !pdl.operation failures(propagate) {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
|
||||
%1 = transform.loop.get_parent_for %0
|
||||
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
|
||||
// CHECK: = transform.loop.outline %{{.*}}
|
||||
transform.loop.outline %1 {func_name = "foo"}
|
||||
transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,7 +115,7 @@ transform.with_pdl_patterns {
|
|||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["scf.while"]} in %arg1
|
||||
// expected-error @below {{failed to outline}}
|
||||
transform.loop.outline %0 {func_name = "foo"}
|
||||
transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -145,8 +145,8 @@ transform.with_pdl_patterns {
|
|||
sequence %arg0 : !pdl.operation failures(propagate) {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
|
||||
%1 = transform.loop.get_parent_for %0
|
||||
transform.loop.peel %1
|
||||
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
|
||||
transform.loop.peel %1 : (!transform.op<"scf.for">) -> !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -181,10 +181,10 @@ transform.with_pdl_patterns {
|
|||
sequence %arg0 : !pdl.operation failures(propagate) {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["arith.addf"]} in %arg1
|
||||
%1 = transform.loop.get_parent_for %0
|
||||
%2 = transform.loop.pipeline %1
|
||||
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
|
||||
%2 = transform.loop.pipeline %1 : (!transform.op<"scf.for">) -> !pdl.operation
|
||||
// Verify that the returned handle is usable.
|
||||
transform.test_print_remark_at_operand %2, "transformed"
|
||||
transform.test_print_remark_at_operand %2, "transformed" : !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -208,8 +208,8 @@ transform.with_pdl_patterns {
|
|||
sequence %arg0 : !pdl.operation failures(propagate) {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
|
||||
%1 = transform.loop.get_parent_for %0
|
||||
transform.loop.unroll %1 { factor = 4 }
|
||||
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
|
||||
transform.loop.unroll %1 { factor = 4 } : !transform.op<"scf.for">
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ transform.with_pdl_patterns {
|
|||
// expected-note @below {{invalidated by this transform op that consumes its operand #0}}
|
||||
test_consume_operand %1
|
||||
// expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
|
||||
test_print_remark_at_operand %0, "remark"
|
||||
test_print_remark_at_operand %0, "remark" : !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,7 +57,7 @@ transform.with_pdl_patterns {
|
|||
%2 = replicate num(%0) %1 : !pdl.operation, !pdl.operation
|
||||
// expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}}
|
||||
test_consume_operand %2
|
||||
test_print_remark_at_operand %0, "remark"
|
||||
test_print_remark_at_operand %0, "remark" : !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -76,7 +76,7 @@ transform.with_pdl_patterns {
|
|||
sequence %arg0 : !pdl.operation failures(propagate) {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
|
||||
test_print_remark_at_operand %0, "matched"
|
||||
test_print_remark_at_operand %0, "matched" : !pdl.operation
|
||||
}
|
||||
|
||||
pdl.pattern @some : benefit(1) {
|
||||
|
@ -124,7 +124,7 @@ transform.with_pdl_patterns {
|
|||
%f = pdl_match @const in %arg1 : (!pdl.operation) -> !pdl.operation
|
||||
// CHECK: %{{.+}} = get_closest_isolated_parent %{{.+}}
|
||||
%m = get_closest_isolated_parent %f : (!pdl.operation) -> !pdl.operation
|
||||
test_print_remark_at_operand %m, "parent function"
|
||||
test_print_remark_at_operand %m, "parent function" : !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -227,7 +227,7 @@ transform.with_pdl_patterns {
|
|||
}, {
|
||||
^bb2(%arg2: !pdl.operation):
|
||||
%2 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation
|
||||
transform.test_print_remark_at_operand %2, "still here"
|
||||
transform.test_print_remark_at_operand %2, "still here" : !pdl.operation
|
||||
// This alternative succeeds.
|
||||
}, {
|
||||
^bb2(%arg2: !pdl.operation):
|
||||
|
@ -370,7 +370,7 @@ transform.with_pdl_patterns {
|
|||
sequence %arg0 : !pdl.operation failures(propagate) {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.pdl_match @match_const in %arg1 : (!pdl.operation) -> !pdl.operation
|
||||
%1 = transform.loop.get_parent_for %0
|
||||
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !pdl.operation
|
||||
// expected-error @below {{only isolated-from-above ops can be alternative scopes}}
|
||||
alternatives %1 : !pdl.operation {
|
||||
^bb2(%arg2: !pdl.operation):
|
||||
|
@ -541,7 +541,7 @@ transform.with_pdl_patterns {
|
|||
%0 = pdl_match @addi in %arg1 : (!pdl.operation) -> !pdl.operation
|
||||
%1 = pdl_match @subi in %arg1 : (!pdl.operation) -> !pdl.operation
|
||||
%2 = merge_handles %0, %1 : !pdl.operation
|
||||
test_print_remark_at_operand %2, "matched"
|
||||
test_print_remark_at_operand %2, "matched" : !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -675,7 +675,7 @@ transform.with_pdl_patterns {
|
|||
^bb2(%arg2: !pdl.operation):
|
||||
// expected-remark @below {{1}}
|
||||
transform.test_print_number_of_associated_payload_ir_ops %arg2
|
||||
transform.test_print_remark_at_operand %arg2, "transform applied"
|
||||
transform.test_print_remark_at_operand %arg2, "transform applied" : !pdl.operation
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -725,7 +725,7 @@ transform.with_pdl_patterns {
|
|||
|
||||
// expected-remark @below {{3}}
|
||||
transform.test_print_number_of_associated_payload_ir_ops %results
|
||||
transform.test_print_remark_at_operand %results, "transform applied"
|
||||
transform.test_print_remark_at_operand %results, "transform applied" : !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -742,7 +742,7 @@ transform.sequence failures(propagate) {
|
|||
^bb1(%arg1: !pdl.operation):
|
||||
%addi = transform.structured.match ops{["arith.addi"]} in %arg1
|
||||
%muli = get_producer_of_operand %addi[0] : (!pdl.operation) -> !pdl.operation
|
||||
transform.test_print_remark_at_operand %muli, "found muli"
|
||||
transform.test_print_remark_at_operand %muli, "found muli" : !pdl.operation
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -69,10 +69,11 @@ def TestPrintRemarkAtOperandOp
|
|||
: Op<Transform_Dialect, "test_print_remark_at_operand",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
|
||||
let arguments = (ins
|
||||
Arg<PDL_Operation, "",
|
||||
Arg<TransformTypeInterface, "",
|
||||
[TransformMappingRead, PayloadIRRead]>:$operand,
|
||||
StrAttr:$message);
|
||||
let assemblyFormat = "$operand `,` $message attr-dict";
|
||||
let assemblyFormat =
|
||||
"$operand `,` $message attr-dict `:` type($operand)";
|
||||
let cppNamespace = "::mlir::test";
|
||||
}
|
||||
|
||||
|
|
|
@ -18,9 +18,10 @@ def run(f):
|
|||
|
||||
@run
|
||||
def getParentLoop():
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
|
||||
[], pdl.OperationType.get())
|
||||
with InsertionPoint(sequence.body):
|
||||
loop.GetParentForOp(sequence.bodyTarget, num_loops=2)
|
||||
loop.GetParentForOp(transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: getParentLoop
|
||||
# CHECK: = transform.loop.get_parent_for %
|
||||
|
@ -29,9 +30,10 @@ def getParentLoop():
|
|||
|
||||
@run
|
||||
def loopOutline():
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
|
||||
[], transform.OperationType.get("scf.for"))
|
||||
with InsertionPoint(sequence.body):
|
||||
loop.LoopOutlineOp(sequence.bodyTarget, func_name="foo")
|
||||
loop.LoopOutlineOp(pdl.OperationType.get(), sequence.bodyTarget, func_name="foo")
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: loopOutline
|
||||
# CHECK: = transform.loop.outline %
|
||||
|
@ -40,9 +42,10 @@ def loopOutline():
|
|||
|
||||
@run
|
||||
def loopPeel():
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
|
||||
[], transform.OperationType.get("scf.for"))
|
||||
with InsertionPoint(sequence.body):
|
||||
loop.LoopPeelOp(sequence.bodyTarget)
|
||||
loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: loopPeel
|
||||
# CHECK: = transform.loop.peel %
|
||||
|
@ -50,9 +53,10 @@ def loopPeel():
|
|||
|
||||
@run
|
||||
def loopPipeline():
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
|
||||
[], transform.OperationType.get("scf.for"))
|
||||
with InsertionPoint(sequence.body):
|
||||
loop.LoopPipelineOp(sequence.bodyTarget, iteration_interval=3)
|
||||
loop.LoopPipelineOp(pdl.OperationType.get(), sequence.bodyTarget, iteration_interval=3)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: loopPipeline
|
||||
# CHECK: = transform.loop.pipeline %
|
||||
|
@ -62,7 +66,8 @@ def loopPipeline():
|
|||
|
||||
@run
|
||||
def loopUnroll():
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
|
||||
[], transform.OperationType.get("scf.for"))
|
||||
with InsertionPoint(sequence.body):
|
||||
loop.LoopUnrollOp(sequence.bodyTarget, factor=42)
|
||||
transform.YieldOp()
|
||||
|
|
Loading…
Reference in New Issue