forked from OSchip/llvm-project
[mlir][Transform] Make applyToOne return a DiagnosedSilenceableFailure
This revision revisits the implementation of applyToOne and its handling of recoverable errors as well as propagation of null handles. The implementation is simplified to always require passing a vector<Operation*> in which the results are returned, resulting in less template instantiation magic. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D129185
This commit is contained in:
parent
7d1a295484
commit
5230710933
|
@ -15,6 +15,7 @@
|
|||
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
class GenericOp;
|
||||
class LinalgOp;
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
|
|
@ -22,9 +22,15 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
|
|||
let description = [{
|
||||
Decomposes named complex operations, such as higher-dimensional
|
||||
(depthwise) convolutions, into combinations of lower-dimensional equivalents
|
||||
when possible. The operand handle must point to a list of such operations.
|
||||
The returning handle points to the main produced computational operation,
|
||||
such as the lower-dimensional convolution.
|
||||
when possible.
|
||||
|
||||
Return modes:
|
||||
=============
|
||||
This operation ignores non-Linalg ops and drops them in the return.
|
||||
If all the operations referred to by the `target` PDLOperation decompose
|
||||
properly, the transform succeeds. Otherwise the transform silently fails.
|
||||
The return handle points to only the subset of successfully produced
|
||||
computational operations, which can be empty.
|
||||
}];
|
||||
|
||||
let arguments = (ins PDL_Operation:$target);
|
||||
|
@ -32,8 +38,10 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
|
|||
let assemblyFormat = "$target attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
|
||||
::mlir::linalg::LinalgOp target, TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::linalg::LinalgOp target,
|
||||
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -61,11 +69,16 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
|
|||
TransformOpInterface, TransformEachOpTrait]> {
|
||||
let description = [{
|
||||
Transforms a named structued operation into the generic form with the
|
||||
explicit attached region. The operand handle must point to a list of
|
||||
structured operations, it is consumed by the transformation and is not
|
||||
expected to be used afterwards. The resulting handle points to the list
|
||||
of equivalent generic operations, in the same order as the original named
|
||||
operations.
|
||||
explicit attached region.
|
||||
|
||||
Return modes:
|
||||
=============
|
||||
This operation ignores non-Linalg ops and drops them in the return.
|
||||
If all the operations referred to by the `target` PDLOperation generalize
|
||||
properly, the transform succeeds. Otherwise the transform silently fails.
|
||||
The return handle points to only the subset of successfully produced
|
||||
equivalent generic operations, which can be empty or contain the original
|
||||
ops if they were already in generic form.
|
||||
}];
|
||||
|
||||
let arguments = (ins PDL_Operation:$target);
|
||||
|
@ -73,8 +86,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
|
|||
let assemblyFormat = "$target attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
|
||||
::mlir::linalg::LinalgOp target, TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::linalg::LinalgOp target,
|
||||
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -84,6 +99,16 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
|
|||
let description = [{
|
||||
Interchanges the iterators of the operations pointed to by the target handle
|
||||
using the iterator interchange attribute.
|
||||
|
||||
Return modes:
|
||||
=============
|
||||
This operation ignores non-linalg::Generic ops and drops them in the return.
|
||||
This operation fails if the interchange attribute is invalid.
|
||||
If all the operations referred to by the `target` PDLOperation interchange
|
||||
properly, the transform succeeds.
|
||||
If any interchange fails, the transform definitely fails.
|
||||
The return handle points to only the subset of successfully produced
|
||||
interchanged operations, which can be empty.
|
||||
}];
|
||||
|
||||
let arguments =
|
||||
|
@ -95,8 +120,10 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
|
|||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
|
||||
::mlir::linalg::LinalgOp target, TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::linalg::GenericOp target,
|
||||
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -106,6 +133,16 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
|
|||
let description = [{
|
||||
Pads the operations pointed to by the target handle using the options
|
||||
provides as operation attributes.
|
||||
|
||||
Return modes:
|
||||
=============
|
||||
This operation ignores non-Linalg ops and drops them in the return.
|
||||
This operation may produce a definiteFailure if the padding fails for any
|
||||
reason.
|
||||
If all the operations referred to by the `target` PDLOperation pad
|
||||
properly, the transform succeeds. Otherwise the transform silently fails.
|
||||
The return handle points to only the subset of successfully produced
|
||||
padded operations, which can be empty.
|
||||
}];
|
||||
|
||||
let arguments =
|
||||
|
@ -123,8 +160,10 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
|
|||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
|
||||
::mlir::linalg::LinalgOp target, TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::linalg::LinalgOp target,
|
||||
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -135,11 +174,23 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
|
|||
Indicates that ops of a specific kind in the given function should be
|
||||
scalarized (i.e. their dynamic dimensions tiled by 1).
|
||||
|
||||
This operation returns the tiled op but not the loops.
|
||||
Return modes:
|
||||
=============
|
||||
This operation ignores non-Linalg ops and drops them in the return.
|
||||
This operation produces `definiteFailure` if the scalarization fails for any
|
||||
reason.
|
||||
If all the operations referred to by the `target` PDLOperation scalarize
|
||||
properly, the transform succeeds. Otherwise the transform silently fails.
|
||||
|
||||
The return handle points to only the subset of successfully produced
|
||||
tiled-by-1 operations, which can be empty.
|
||||
|
||||
This operation does not return handles to the tiled loop.
|
||||
We make this design choice because it is hard to know ahead of time the
|
||||
number of loops that will be produced (it depends on the number of dynamic
|
||||
dimensions after multiple transformations have been applied).
|
||||
Loops can always be recovered by navigating from the tiled operations if
|
||||
needed.
|
||||
}];
|
||||
|
||||
let arguments = (ins PDL_Operation:$target);
|
||||
|
@ -148,8 +199,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
|
|||
let assemblyFormat = "$target attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
|
||||
::mlir::linalg::LinalgOp target, TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::linalg::LinalgOp target,
|
||||
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -206,7 +259,17 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
|
|||
- use_alloc: whether to use an alloc op to allocate the temporary
|
||||
tensor (default: do not use alloc op)
|
||||
|
||||
This op returns 4 handles to:
|
||||
Return modes:
|
||||
=============
|
||||
This operation ignores non-Linalg ops and drops them in the return.
|
||||
This operation produces `definiteFailure` if the splitting fails for any
|
||||
reason.
|
||||
|
||||
If all the operations referred to by the `target` PDLOperation split
|
||||
properly, the transform succeeds. Otherwise the transform silently fails.
|
||||
The 4 returned handles points to only the subset of successfully produced
|
||||
computational operations, which can all be empty.
|
||||
This 4 returned handles point to:
|
||||
- the init op (or tensor_alloc op if use_alloc = true),
|
||||
- the fill op used to initialize the neutral element,
|
||||
- the split op and
|
||||
|
@ -316,15 +379,18 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
|
|||
DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension,
|
||||
UnitAttr:$use_scaling_algorithm,
|
||||
UnitAttr:$use_alloc);
|
||||
let results = (outs PDL_Operation:$fill_op,
|
||||
let results = (outs PDL_Operation:$init_or_alloc_op,
|
||||
PDL_Operation:$fill_op,
|
||||
PDL_Operation:$split_linalg_op,
|
||||
PDL_Operation:$combining_linalg_op);
|
||||
|
||||
let assemblyFormat = "$target attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
|
||||
::mlir::linalg::LinalgOp target, TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::linalg::LinalgOp target,
|
||||
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -372,6 +438,13 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
|
|||
|
||||
Note that this transformation is invalidating the handles to any payload IR
|
||||
operation that is contained inside the vectorization target.
|
||||
|
||||
Return modes:
|
||||
=============
|
||||
This operation produces `definiteFailure` if vectorization fails for any
|
||||
reason.
|
||||
The operation always returns the handle to the target op that is expected
|
||||
to be isolated from above.
|
||||
}];
|
||||
|
||||
let arguments = (ins PDL_Operation:$target,
|
||||
|
@ -381,8 +454,10 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
|
|||
let assemblyFormat = "$target attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::FailureOr<Operation *> applyToOne(
|
||||
::mlir::Operation *target, TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::Operation *target,
|
||||
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -239,6 +239,8 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
|
|||
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
|
||||
/// integers, in the range 0..`op.rank` without duplications
|
||||
/// (i.e. `[1,1,2]` is an invalid permutation).
|
||||
///
|
||||
/// Return failure if the permutation is not valid.
|
||||
FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
|
||||
GenericOp genericOp,
|
||||
ArrayRef<unsigned> interchangeVector);
|
||||
|
|
|
@ -50,7 +50,7 @@ def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
|
|||
outlined into a separate function. The provided name is used as a _base_
|
||||
for forming actual function names following SymbolTable auto-renaming
|
||||
scheme to avoid duplicate symbols. Expects that all ops in the Payload IR
|
||||
have a SymbolTable ancestor (typically true because of the top-level
|
||||
have a SymbolTable ancestor (typically true because of the top-level
|
||||
module). Returns the handle to the list of outlined functions in the same
|
||||
order as the operand handle.
|
||||
}];
|
||||
|
@ -68,28 +68,40 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
|
|||
let summary = "Peels the last iteration of the loop";
|
||||
let description = [{
|
||||
Updates the given loop so that its step evenly divides its range and puts
|
||||
the remaining iteration into a separate loop or a conditional. Note that
|
||||
even though the Payload IR modification may be performed in-place, this
|
||||
operation consumes the operand handle and produces a new one. Applies to
|
||||
each loop associated with the operand handle individually. The results
|
||||
follow the same order as the operand.
|
||||
the remaining iteration into a separate loop or a conditional.
|
||||
|
||||
Note: If it can be proven statically that the step already evenly divides
|
||||
the range, this op is a no-op. In the absence of sufficient static
|
||||
information, this op may peel a loop, even if the step always divides the
|
||||
range evenly at runtime.
|
||||
In the absence of sufficient static information, this op may peel a loop,
|
||||
even if the step always divides the range evenly at runtime.
|
||||
|
||||
Return modes:
|
||||
=============
|
||||
This operation ignores non-scf::ForOp ops and drops them in the return.
|
||||
|
||||
This operation always succeeds and returns the scf::ForOp with the
|
||||
postcondition: "the loop trip count is divisible by the step".
|
||||
This operation may return the same unmodified loop handle when peeling did
|
||||
not modify the IR (i.e. the loop trip count was already divisible).
|
||||
|
||||
Note that even though the Payload IR modification may be performed
|
||||
in-place, this operation consumes the operand handle and produces a new
|
||||
one.
|
||||
|
||||
TODO: Return both the peeled loop and the remainder loop.
|
||||
}];
|
||||
|
||||
let arguments =
|
||||
(ins PDL_Operation:$target,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$fail_if_already_divisible);
|
||||
// TODO: Return both the peeled loop and the remainder loop.
|
||||
let results = (outs PDL_Operation:$transformed);
|
||||
|
||||
let assemblyFormat = "$target attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
|
||||
::mlir::scf::ForOp loop, TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::scf::ForOp target,
|
||||
::llvm::SmallVector<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -102,10 +114,21 @@ def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
|
|||
each of them. That is, performs some amount of reads from memory before the
|
||||
loop rather than inside the loop, the same amount of writes into memory
|
||||
after the loop, and updates each iteration to read the data for a following
|
||||
iteration rather than the current one. The amount is specified by the
|
||||
attributes. The values read and about to be stored are transferred as loop
|
||||
iteration arguments. Currently supports memref and vector transfer
|
||||
operations as memory reads/writes.
|
||||
iteration rather than the current one.
|
||||
|
||||
The amount is specified by the attributes.
|
||||
|
||||
The values read and about to be stored are transferred as loop iteration
|
||||
arguments. Currently supports memref and vector transfer operations as
|
||||
memory reads/writes.
|
||||
|
||||
Return modes:
|
||||
=============
|
||||
This operation ignores non-scf::For ops and drops them in the return.
|
||||
If all the operations referred to by the `target` PDLOperation pipeline
|
||||
properly, the transform succeeds. Otherwise the transform silently fails.
|
||||
The return handle points to only the subset of successfully produced
|
||||
pipelined loops, which can be empty.
|
||||
}];
|
||||
|
||||
let arguments = (ins PDL_Operation:$target,
|
||||
|
@ -116,8 +139,10 @@ def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
|
|||
let assemblyFormat = "$target attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
|
||||
::mlir::scf::ForOp loop, TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::scf::ForOp target,
|
||||
::llvm::SmallVector<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -126,11 +151,18 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
|
|||
TransformOpInterface, TransformEachOpTrait]> {
|
||||
let summary = "Unrolls the given loop with the given unroll factor";
|
||||
let description = [{
|
||||
Unrolls each loop associated with the given handle to have up to the given
|
||||
number of loop body copies per iteration. If the unroll factor is larger
|
||||
than the loop trip count, the latter is used as the unroll factor instead.
|
||||
Does not produce a new handle as the operation may result in the loop being
|
||||
removed after a full unrolling.
|
||||
Unrolls each loop associated with the given handle to have up to the given
|
||||
number of loop body copies per iteration. If the unroll factor is larger
|
||||
than the loop trip count, the latter is used as the unroll factor instead.
|
||||
|
||||
Return modes:
|
||||
==============
|
||||
This operation ignores non-scf::For ops and drops them in the return.
|
||||
If all the operations referred to by the `target` PDLOperation unroll
|
||||
properly, the transform succeeds. Otherwise the transform silently fails.
|
||||
|
||||
Does not return handles as the operation may result in the loop being
|
||||
removed after a full unrolling.
|
||||
}];
|
||||
|
||||
let arguments = (ins PDL_Operation:$target,
|
||||
|
@ -139,8 +171,10 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
|
|||
let assemblyFormat = "$target attr-dict";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::LogicalResult applyToOne(
|
||||
::mlir::scf::ForOp loop, TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::scf::ForOp target,
|
||||
::llvm::SmallVector<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -12,13 +12,14 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// The result of a transform IR operation application. This can have one of the
|
||||
/// three states:
|
||||
/// - success;
|
||||
/// - silencable (recoverable) failure with yet-unreported diagnostic;
|
||||
/// - silenceable (recoverable) failure with yet-unreported diagnostic;
|
||||
/// - definite failure.
|
||||
/// Silenceable failure is intended to communicate information about
|
||||
/// transformations that did not apply but in a way that supports recovery,
|
||||
|
@ -26,10 +27,10 @@ namespace mlir {
|
|||
/// predictable way. They are associated with a Diagnostic that provides more
|
||||
/// details on the failure. Silenceable failure can be discarded, turning the
|
||||
/// result into success, or "reported", emitting the diagnostic and turning the
|
||||
/// result into definite failure. Transform IR operations containing other
|
||||
/// operations are allowed to do either with the results of the nested
|
||||
/// transformations, but must propagate definite failures as their diagnostics
|
||||
/// have been already reported to the user.
|
||||
/// result into definite failure.
|
||||
/// Transform IR operations containing other operations are allowed to do either
|
||||
/// with the results of the nested transformations, but must propagate definite
|
||||
/// failures as their diagnostics have been already reported to the user.
|
||||
class LLVM_NODISCARD DiagnosedSilenceableFailure {
|
||||
public:
|
||||
explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
|
||||
|
@ -51,12 +52,17 @@ public:
|
|||
return DiagnosedSilenceableFailure(::mlir::failure());
|
||||
}
|
||||
|
||||
/// Constructs a DiagnosedSilenceableFailure in the silencable failure state,
|
||||
/// Constructs a DiagnosedSilenceableFailure in the silenceable failure state,
|
||||
/// ready to emit the given diagnostic. This is considered a failure
|
||||
/// regardless of the diagnostic severity.
|
||||
static DiagnosedSilenceableFailure silencableFailure(Diagnostic &&diag) {
|
||||
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag) {
|
||||
return DiagnosedSilenceableFailure(std::forward<Diagnostic>(diag));
|
||||
}
|
||||
static DiagnosedSilenceableFailure
|
||||
silenceableFailure(SmallVector<Diagnostic> &&diag) {
|
||||
return DiagnosedSilenceableFailure(
|
||||
std::forward<SmallVector<Diagnostic>>(diag));
|
||||
}
|
||||
|
||||
/// Converts all kinds of failure into a LogicalResult failure, emitting the
|
||||
/// diagnostic if necessary. Must not be called more than once.
|
||||
|
@ -65,44 +71,72 @@ public:
|
|||
assert(!reported && "attempting to report a diagnostic more than once");
|
||||
reported = true;
|
||||
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
|
||||
if (diagnostic) {
|
||||
diagnostic->getLocation().getContext()->getDiagEngine().emit(
|
||||
std::move(*diagnostic));
|
||||
diagnostic.reset();
|
||||
if (!diagnostics.empty()) {
|
||||
for (auto &&diagnostic : diagnostics) {
|
||||
diagnostic.getLocation().getContext()->getDiagEngine().emit(
|
||||
std::move(diagnostic));
|
||||
}
|
||||
diagnostics.clear();
|
||||
result = ::mlir::failure();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Returns `true` if this is a silencable failure.
|
||||
bool isSilenceableFailure() const { return diagnostic.hasValue(); }
|
||||
/// Returns `true` if this is a silenceable failure.
|
||||
bool isDefiniteFailure() const { return result.failed(); }
|
||||
|
||||
/// Returns `true` if this is a silenceable failure.
|
||||
bool isSilenceableFailure() const { return !diagnostics.empty(); }
|
||||
|
||||
/// Returns `true` if this is a success.
|
||||
bool succeeded() const {
|
||||
return !diagnostic.hasValue() && ::mlir::succeeded(result);
|
||||
return diagnostics.empty() && ::mlir::succeeded(result);
|
||||
}
|
||||
|
||||
/// Returns the diagnostic message without emitting it. Expects this object
|
||||
/// to be a silencable failure.
|
||||
std::string getMessage() const { return diagnostic->str(); }
|
||||
/// to be a silenceable failure.
|
||||
std::string getMessage() const {
|
||||
std::string res;
|
||||
for (auto &diagnostic : diagnostics) {
|
||||
res.append(diagnostic.str());
|
||||
res.append("\n");
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Converts silencable failure into LogicalResult success without reporting
|
||||
/// Returns a string representation of the failure mode (for error reporting).
|
||||
std::string getStatusString() const {
|
||||
if (succeeded())
|
||||
return "success";
|
||||
if (isSilenceableFailure())
|
||||
return "silenceable failure";
|
||||
return "definite failure";
|
||||
}
|
||||
|
||||
/// Converts silenceable failure into LogicalResult success without reporting
|
||||
/// the diagnostic, preserves the other states.
|
||||
LogicalResult silence() {
|
||||
if (diagnostic) {
|
||||
diagnostic.reset();
|
||||
if (!diagnostics.empty()) {
|
||||
diagnostics.clear();
|
||||
result = ::mlir::success();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Streams the given values into the diagnotic. Expects this object to be a
|
||||
/// silencable failure.
|
||||
/// Take the diagnostic and silence.
|
||||
SmallVector<Diagnostic> &&takeDiagnostics() {
|
||||
assert(!diagnostics.empty() && "expected a diagnostic to be present");
|
||||
auto guard = llvm::make_scope_exit([&]() { diagnostics.clear(); });
|
||||
return std::move(diagnostics);
|
||||
}
|
||||
|
||||
/// Streams the given values into the last diagnotic.
|
||||
/// Expects this object to be a silenceable failure.
|
||||
template <typename T>
|
||||
DiagnosedSilenceableFailure &operator<<(T &&value) & {
|
||||
assert(isSilenceableFailure() &&
|
||||
"can only append output in silencable failure state");
|
||||
*diagnostic << std::forward<T>(value);
|
||||
"can only append output in silenceable failure state");
|
||||
diagnostics.back() << std::forward<T>(value);
|
||||
return *this;
|
||||
}
|
||||
template <typename T>
|
||||
|
@ -110,31 +144,36 @@ public:
|
|||
return std::move(this->operator<<(std::forward<T>(value)));
|
||||
}
|
||||
|
||||
/// Attaches a note to the diagnostic. Expects this object to be a silencable
|
||||
/// failure.
|
||||
/// Attaches a note to the last diagnostic.
|
||||
/// Expects this object to be a silenceable failure.
|
||||
Diagnostic &attachNote(Optional<Location> loc = llvm::None) {
|
||||
assert(isSilenceableFailure() &&
|
||||
"can only attach notes to silencable failures");
|
||||
return diagnostic->attachNote(loc);
|
||||
"can only attach notes to silenceable failures");
|
||||
return diagnostics.back().attachNote(loc);
|
||||
}
|
||||
|
||||
private:
|
||||
explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic)
|
||||
: diagnostic(std::move(diagnostic)), result(failure()) {}
|
||||
: diagnostics(), result(failure()) {
|
||||
diagnostics.emplace_back(std::move(diagnostic));
|
||||
}
|
||||
explicit DiagnosedSilenceableFailure(SmallVector<Diagnostic> &&diagnostics)
|
||||
: diagnostics(std::move(diagnostics)), result(failure()) {}
|
||||
|
||||
/// The diagnostic associated with this object. If present, the object is
|
||||
/// considered to be in the silencable failure state regardless of the
|
||||
/// The diagnostics associated with this object. If non-empty, the object is
|
||||
/// considered to be in the silenceable failure state regardless of the
|
||||
/// `result` field.
|
||||
Optional<Diagnostic> diagnostic;
|
||||
SmallVector<Diagnostic, 1> diagnostics;
|
||||
|
||||
/// The "definite" logical state, either success or failure. Ignored if the
|
||||
/// diagnostic message is present.
|
||||
/// The "definite" logical state, either success or failure.
|
||||
/// Ignored if the diagnostics message is present.
|
||||
LogicalResult result;
|
||||
|
||||
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
|
||||
/// Whther the associated diagnostic have been reported. Diagnostic reporting
|
||||
/// consumes the diagnostic, so we need a mechanism to differentiate a
|
||||
/// reported diagnostic from a state where it was never created.
|
||||
/// Whether the associated diagnostics have been reported.
|
||||
/// Diagnostics reporting consumes the diagnostics, so we need a mechanism to
|
||||
/// differentiate reported diagnostics from a state where it was never
|
||||
/// created.
|
||||
bool reported = false;
|
||||
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
|
||||
};
|
||||
|
@ -579,24 +618,45 @@ public:
|
|||
};
|
||||
|
||||
/// Trait implementing the TransformOpInterface for operations applying a
|
||||
/// transformation to a single operation handle and producing one or multiple
|
||||
/// operation handles.
|
||||
/// The op must implement a method with one of the following signatures:
|
||||
/// - FailureOr<convertible-to-Operation*> applyToOne(OpTy, state)
|
||||
/// - FailureOr<SmallVector<convertible-to-Operation*>>applyToOne(OpTy, state)
|
||||
/// - LogicalResult applyToOne(OpTy, state)
|
||||
/// transformation to a single operation handle and producing zero, one or
|
||||
/// multiple operation handles.
|
||||
/// The op must implement a method with the following signature:
|
||||
/// - DiagnosedSilenceableFailure applyToOne(OpTy,
|
||||
/// SmallVector<Operation*> &results, state)
|
||||
/// to perform a transformation that is applied in turn to all payload IR
|
||||
/// operations that correspond to the handle of the transform IR operation.
|
||||
/// In the functions above, OpTy is either Operation * or a concrete payload IR
|
||||
/// Op class that the transformation is applied to (NOT the class of the
|
||||
/// transform IR op). The op is expected to have a single operand.
|
||||
/// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class
|
||||
/// that the transformation is applied to (and NOT the class of the transform IR
|
||||
/// op).
|
||||
/// The `applyToOne` method takes an empty `results` vector that it fills with
|
||||
/// zero, one or multiple operations depending on the number of resultd expected
|
||||
/// by the transform op.
|
||||
/// The number of results must match the number of results of the transform op.
|
||||
/// `applyToOne` is allowed to fill the `results` with all null elements to
|
||||
/// signify that the transformation did not apply to the payload IR operations.
|
||||
/// Such null elements are filtered out from results before return.
|
||||
///
|
||||
/// The transform op having this trait is expected to have a single operand.
|
||||
template <typename OpTy>
|
||||
class TransformEachOpTrait
|
||||
: public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
|
||||
public:
|
||||
/// Calls `applyToOne` for every payload operation associated with the operand
|
||||
/// of this transform IR op. If `applyToOne` returns ops, associates them with
|
||||
/// the result of this transform op.
|
||||
/// of this transform IR op, the following case disjunction happens:
|
||||
/// 1. If not target payload ops are associated to the operand then fill the
|
||||
/// results vector with the expected number of null elements and return
|
||||
/// success. This is the corner case handling that allows propagating
|
||||
/// the "no-op" case gracefully to improve usability.
|
||||
/// 2. If any `applyToOne` returns definiteFailure, the transformation is
|
||||
/// immediately considered definitely failed and we return.
|
||||
/// 3. All applications of `applyToOne` are checked to return a number of
|
||||
/// results expected by the transform IR op. If not, this is a definite
|
||||
/// failure and we return early.
|
||||
/// 4. If `applyToOne` produces ops, associate them with the result of this
|
||||
/// transform op.
|
||||
/// 5. If any `applyToOne` return silenceableFailure, the transformation is
|
||||
/// considered silenceable.
|
||||
/// 6. Otherwise the transformation is considered successful.
|
||||
DiagnosedSilenceableFailure apply(TransformResults &transformResults,
|
||||
TransformState &state);
|
||||
|
||||
|
@ -714,88 +774,58 @@ public:
|
|||
namespace mlir {
|
||||
namespace transform {
|
||||
namespace detail {
|
||||
/// Appends `result` to the vector assuming it corresponds to the success state
|
||||
/// in `FailureOr<convertible-to-Operation*>`. If `result` is just a
|
||||
/// `LogicalResult`, appends an empy vector.
|
||||
template <typename Ty>
|
||||
std::enable_if_t<std::is_same<Ty, LogicalResult>::value, LogicalResult>
|
||||
appendTransformResultToVector(
|
||||
Ty result, SmallVectorImpl<SmallVector<Operation *>> &results) {
|
||||
results.push_back(SmallVector<Operation *>());
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Ty>
|
||||
std::enable_if_t<
|
||||
llvm::conjunction<
|
||||
llvm::negation<std::is_same<Ty, LogicalResult>>,
|
||||
std::is_convertible<typename Ty::value_type, Operation *>>::value,
|
||||
LogicalResult>
|
||||
appendTransformResultToVector(
|
||||
Ty result, SmallVectorImpl<SmallVector<Operation *>> &results) {
|
||||
if (failed(result))
|
||||
return failure();
|
||||
results.push_back(SmallVector<Operation *>{*result});
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename ContainerTy>
|
||||
std::enable_if_t<
|
||||
llvm::conjunction<
|
||||
llvm::negation<std::is_same<ContainerTy, LogicalResult>>,
|
||||
llvm::negation<std::is_convertible<typename ContainerTy::value_type,
|
||||
Operation *>>>::value,
|
||||
LogicalResult>
|
||||
appendTransformResultToVector(
|
||||
ContainerTy resultContainer,
|
||||
SmallVectorImpl<SmallVector<Operation *>> &results) {
|
||||
if (failed(resultContainer))
|
||||
return failure();
|
||||
results.push_back(*resultContainer);
|
||||
return success();
|
||||
}
|
||||
/// Applies a one-to-one or a one-to-many transform to each of the given
|
||||
/// targets. Puts the results of transforms, if any, in `results` in the same
|
||||
/// order. Fails if any of the application fails. Individual transforms must be
|
||||
/// callable with one of the following signatures:
|
||||
/// - FailureOr<convertible-to-Operation*>(OpTy)
|
||||
/// - LogicalResult(OpTy)
|
||||
/// - FailureOr<SmallVectorImpl<convertible-to-Operation*>>(
|
||||
/// SmallVectorImpl<OpTy>)
|
||||
/// - LogicalResult(SmallVectorImpl<OpTy>)
|
||||
/// callable with the following signature:
|
||||
/// - DiagnosedSilenceableFailure(OpTy,
|
||||
/// SmallVector<Operation*> &results, state)
|
||||
/// where OpTy is either
|
||||
/// - Operation *, in which case the transform is always applied;
|
||||
/// - a concrete Op class, in which case a check is performed whether
|
||||
/// `targets` contains operations of the same class and a silencable failure
|
||||
/// `targets` contains operations of the same class and a silenceable failure
|
||||
/// is reported if it does not.
|
||||
template <typename FnTy>
|
||||
DiagnosedSilenceableFailure
|
||||
applyTransformToEach(ArrayRef<Operation *> targets,
|
||||
SmallVectorImpl<SmallVector<Operation *>> &results,
|
||||
FnTy transform) {
|
||||
DiagnosedSilenceableFailure applyTransformToEach(
|
||||
Location loc, int expectedNumResults, ArrayRef<Operation *> targets,
|
||||
SmallVectorImpl<SmallVector<Operation *>> &results, FnTy transform) {
|
||||
SmallVector<Diagnostic> silenceableStack;
|
||||
using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
|
||||
static_assert(std::is_convertible<OpTy, Operation *>::value,
|
||||
"expected transform function to take an operation");
|
||||
using RetTy = typename llvm::function_traits<FnTy>::result_t;
|
||||
static_assert(std::is_convertible<RetTy, LogicalResult>::value,
|
||||
"expected transform function to return LogicalResult or "
|
||||
"FailureOr<convertible-to-Operation*>");
|
||||
for (Operation *target : targets) {
|
||||
// Emplace back a placeholder for the returned new ops.
|
||||
// This is filled with `expectedNumResults` if the op fails to apply.
|
||||
results.push_back(SmallVector<Operation *>());
|
||||
|
||||
auto specificOp = dyn_cast<OpTy>(target);
|
||||
if (!specificOp) {
|
||||
Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
|
||||
diag << "attempted to apply transform to the wrong op kind";
|
||||
return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
|
||||
Diagnostic diag(loc, DiagnosticSeverity::Error);
|
||||
diag << "transform applied to the wrong op kind";
|
||||
diag.attachNote(target->getLoc()) << "when applied to this op";
|
||||
// Producing `expectedNumResults` nullptr is a silenceableFailure mode.
|
||||
// TODO: encode this implicit `expectedNumResults` nullptr ==
|
||||
// silenceableFailure with a proper trait.
|
||||
results.back().assign(expectedNumResults, nullptr);
|
||||
silenceableStack.push_back(std::move(diag));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto result = transform(specificOp);
|
||||
if (failed(appendTransformResultToVector(result, results)))
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
DiagnosedSilenceableFailure result = transform(specificOp, results.back());
|
||||
if (result.isDefiniteFailure())
|
||||
return result;
|
||||
if (result.isSilenceableFailure())
|
||||
for (auto &&diag : result.takeDiagnostics())
|
||||
silenceableStack.push_back(std::move(diag));
|
||||
}
|
||||
if (!silenceableStack.empty()) {
|
||||
return DiagnosedSilenceableFailure::silenceableFailure(
|
||||
std::move(silenceableStack));
|
||||
}
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
/// Helper function to transform M ops with N results into N results of M ops.
|
||||
/// Helper function: transpose MxN into NxM; assumes that the input is a valid.
|
||||
static inline SmallVector<SmallVector<Operation *, 1>>
|
||||
transposeResults(const SmallVector<SmallVector<Operation *>, 1> &m) {
|
||||
SmallVector<SmallVector<Operation *, 1>> res;
|
||||
|
@ -824,74 +854,95 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
|
|||
decltype(&OpTy::applyToOne)>::template arg_t<0>;
|
||||
ArrayRef<Operation *> targets =
|
||||
state.getPayloadOps(this->getOperation()->getOperand(0));
|
||||
// Handle the corner case where no target is specified.
|
||||
|
||||
// Step 1. Handle the corner case where no target is specified.
|
||||
// This is typically the case when the matcher fails to apply and we need to
|
||||
// propagate gracefully.
|
||||
// In this case, we fill all results with an empty vector.
|
||||
if (targets.empty()) {
|
||||
SmallVector<Operation *> emptyResult;
|
||||
SmallVector<Operation *> empty;
|
||||
for (auto r : this->getOperation()->getResults())
|
||||
transformResults.set(r.template cast<OpResult>(), emptyResult);
|
||||
transformResults.set(r.template cast<OpResult>(), empty);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
// Step 2. Call applyToOne on each target and record newly produced ops in its
|
||||
// corresponding results entry.
|
||||
int expectedNumResults = this->getOperation()->getNumResults();
|
||||
SmallVector<SmallVector<Operation *>, 1> results;
|
||||
// In the multi-result case, collect the number of results each transform
|
||||
// produced.
|
||||
DiagnosedSilenceableFailure result = detail::applyTransformToEach(
|
||||
targets, results, [&](TransformOpType specificOp) {
|
||||
return static_cast<OpTy *>(this)->applyToOne(specificOp, state);
|
||||
this->getOperation()->getLoc(), expectedNumResults, targets, results,
|
||||
[&](TransformOpType specificOp, SmallVector<Operation *> &partialResult) {
|
||||
auto res = static_cast<OpTy *>(this)->applyToOne(specificOp,
|
||||
partialResult, state);
|
||||
if (res.isDefiniteFailure())
|
||||
return res;
|
||||
|
||||
// TODO: encode this implicit must always produce `expectedNumResults`
|
||||
// and nullptr is fine with a proper trait.
|
||||
if (static_cast<int>(partialResult.size()) != expectedNumResults) {
|
||||
auto loc = this->getOperation()->getLoc();
|
||||
auto diag = mlir::emitError(loc, "applications of ")
|
||||
<< OpTy::getOperationName() << " expected to produce "
|
||||
<< expectedNumResults << " results (actually produced "
|
||||
<< partialResult.size() << ").";
|
||||
diag.attachNote(loc)
|
||||
<< "If you need variadic results, consider a generic `apply` "
|
||||
<< "instead of the specialized `applyToOne`.";
|
||||
diag.attachNote(loc)
|
||||
<< "Producing " << expectedNumResults << " null results is "
|
||||
<< "allowed if the use case warrants it.";
|
||||
diag.attachNote(specificOp->getLoc()) << "when applied to this op";
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
}
|
||||
// Check that all is null or none is null
|
||||
// TODO: relax this behavior and encode with a proper trait.
|
||||
if (llvm::any_of(partialResult, [](Operation *op) { return op; }) &&
|
||||
llvm::any_of(partialResult, [](Operation *op) { return !op; })) {
|
||||
auto loc = this->getOperation()->getLoc();
|
||||
auto diag = mlir::emitError(loc, "unexpected application of ")
|
||||
<< OpTy::getOperationName()
|
||||
<< " produces both null and non null results.";
|
||||
diag.attachNote(specificOp->getLoc()) << "when applied to this op";
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
}
|
||||
return res;
|
||||
});
|
||||
// Propagate the failure (definite or silencable) if any.
|
||||
if (!result.succeeded())
|
||||
|
||||
// Step 3. Propagate the definite failure if any and bail out.
|
||||
if (result.isDefiniteFailure())
|
||||
return result;
|
||||
|
||||
// Legitimately no results, bail early.
|
||||
if (results.empty() && OpTy::template hasTrait<OpTrait::ZeroResults>())
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
// Step 4. If there are no results, return early.
|
||||
if (OpTy::template hasTrait<OpTrait::ZeroResults>())
|
||||
return result;
|
||||
|
||||
// Ensure all applications return the same number of results.
|
||||
// Variadic cases are much trickier to handle in a generic fashion.
|
||||
int64_t nRes = results.empty() ? 0 : results[0].size();
|
||||
if (llvm::any_of(results, [&](const auto &r) {
|
||||
return static_cast<int64_t>(r.size()) != nRes;
|
||||
})) {
|
||||
return static_cast<OpTy *>(this)->emitSilenceableError()
|
||||
<< "expected all applications of " << OpTy::getOperationName()
|
||||
<< " to produce " << nRes
|
||||
<< " results.\n If you need variadic results, consider using a "
|
||||
"generic `apply` instead of the specialized `applyToOne`";
|
||||
}
|
||||
// Ensure the number of results agrees with what the transform op expects.
|
||||
// Unless we see empty results, in which case we just want to propagate the
|
||||
// emptiness.
|
||||
if (this->getOperation()->getNumResults() != nRes) {
|
||||
InFlightDiagnostic diag = static_cast<OpTy *>(this)->emitError()
|
||||
<< "unexpected number of results (got " << nRes
|
||||
<< " expected "
|
||||
<< this->getOperation()->getNumResults() << ")";
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
}
|
||||
|
||||
// Perform transposition of M applications producing N results each into N
|
||||
// results for each of the M applications.
|
||||
// Step 5. Perform transposition of M applications producing N results each
|
||||
// into N results for each of the M applications.
|
||||
SmallVector<SmallVector<Operation *, 1>> transposedResults =
|
||||
detail::transposeResults(results);
|
||||
// Single result applies to M ops produces one single M-result.
|
||||
|
||||
// Step 6. Single result applies to M ops produces one single M-result.
|
||||
if (OpTy::template hasTrait<OpTrait::OneResult>()) {
|
||||
assert(transposedResults.size() == 1 && "Expected single result");
|
||||
transformResults.set(
|
||||
this->getOperation()->getResult(0).template cast<OpResult>(),
|
||||
transposedResults[0]);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
// ApplyToOne may have returned silenceableFailure, propagate it.
|
||||
return result;
|
||||
}
|
||||
// M ops, N results each.
|
||||
|
||||
// Step 7. Filter out empty results and set the transformResults.
|
||||
for (const auto &it :
|
||||
llvm::zip(this->getOperation()->getResults(), transposedResults)) {
|
||||
transformResults.set(std::get<0>(it).template cast<OpResult>(),
|
||||
std::get<1>(it));
|
||||
SmallVector<Operation *, 1> filtered;
|
||||
llvm::copy_if(std::get<1>(it), std::back_inserter(filtered),
|
||||
[](Operation *op) { return op; });
|
||||
transformResults.set(std::get<0>(it).template cast<OpResult>(), filtered);
|
||||
}
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
|
||||
// Step 8. ApplyToOne may have returned silenceableFailure, propagate it.
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
|
|
|
@ -62,11 +62,21 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
|
|||
return diag;
|
||||
}
|
||||
|
||||
/// Creates the silencable failure object with a diagnostic located at the
|
||||
/// Creates the silenceable failure object with a diagnostic located at the
|
||||
/// current operation.
|
||||
DiagnosedSilenceableFailure emitSilenceableError() {
|
||||
Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
|
||||
return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
|
||||
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
|
||||
}
|
||||
|
||||
/// Creates the default silenceable failure for a transform op that failed
|
||||
/// to properly apply to a target.
|
||||
DiagnosedSilenceableFailure emitDefaultSilenceableFailure(
|
||||
Operation *target) {
|
||||
Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
|
||||
diag << $_op->getName() << " failed to apply";
|
||||
diag.attachNote(target->getLoc()) << "when applied to this op";
|
||||
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
|
|
@ -76,19 +76,24 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
|
|||
// DecomposeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target,
|
||||
TransformState &state) {
|
||||
DiagnosedSilenceableFailure
|
||||
transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
|
||||
SmallVectorImpl<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
FailureOr<LinalgOp> windowed =
|
||||
tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
|
||||
if (succeeded(windowed))
|
||||
return windowed;
|
||||
|
||||
if (succeeded(windowed)) {
|
||||
results.push_back(*windowed);
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
FailureOr<LinalgOp> depthwise =
|
||||
tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
|
||||
if (succeeded(depthwise))
|
||||
return depthwise;
|
||||
|
||||
return reportUnknownTransformError(target);
|
||||
if (succeeded(depthwise)) {
|
||||
results.push_back(*depthwise);
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
results.assign(1, nullptr);
|
||||
return emitDefaultSilenceableFailure(target);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -221,41 +226,46 @@ LogicalResult transform::FuseOp::verify() {
|
|||
// GeneralizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target,
|
||||
TransformState &state) {
|
||||
DiagnosedSilenceableFailure
|
||||
transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
|
||||
SmallVectorImpl<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
// Exit early if no transformation is needed.
|
||||
if (isa<GenericOp>(target))
|
||||
return target;
|
||||
|
||||
if (isa<GenericOp>(target)) {
|
||||
results.push_back(target);
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
|
||||
if (succeeded(generic))
|
||||
return generic;
|
||||
|
||||
return reportUnknownTransformError(target);
|
||||
if (succeeded(generic)) {
|
||||
results.push_back(generic->getOperation());
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
results.assign(1, nullptr);
|
||||
return emitDefaultSilenceableFailure(target);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InterchangeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FailureOr<LinalgOp>
|
||||
transform::InterchangeOp::applyToOne(LinalgOp target, TransformState &state) {
|
||||
DiagnosedSilenceableFailure
|
||||
transform::InterchangeOp::applyToOne(linalg::GenericOp target,
|
||||
SmallVectorImpl<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
SmallVector<unsigned> interchangeVector =
|
||||
extractUIntArray(getIteratorInterchange());
|
||||
// Exit early if no transformation is needed.
|
||||
if (interchangeVector.empty())
|
||||
return target;
|
||||
|
||||
auto genericTarget = dyn_cast<GenericOp>(target.getOperation());
|
||||
if (!genericTarget) {
|
||||
InFlightDiagnostic diag = emitOpError()
|
||||
<< "applies to " << GenericOp::getOperationName()
|
||||
<< " ops";
|
||||
diag.attachNote(target.getLoc()) << "attempted to apply to this op";
|
||||
return diag;
|
||||
if (interchangeVector.empty()) {
|
||||
results.push_back(target);
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
|
||||
return tryApply<GenericOpInterchangePattern>(target, interchangeVector);
|
||||
SimpleRewriter rewriter(target->getContext());
|
||||
FailureOr<GenericOp> res =
|
||||
interchangeGenericOp(rewriter, target, interchangeVector);
|
||||
if (failed(res))
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
results.push_back(res->getOperation());
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
|
||||
LogicalResult transform::InterchangeOp::verify() {
|
||||
|
@ -275,8 +285,10 @@ LogicalResult transform::InterchangeOp::verify() {
|
|||
// PadOp
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target,
|
||||
TransformState &state) {
|
||||
DiagnosedSilenceableFailure
|
||||
transform::PadOp::applyToOne(linalg::LinalgOp target,
|
||||
SmallVectorImpl<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
// Convert the integer packing flags to booleans.
|
||||
SmallVector<bool> packPaddings;
|
||||
for (int64_t packPadding : extractI64Array(getPackPaddings()))
|
||||
|
@ -293,21 +305,19 @@ FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target,
|
|||
paddingValues.push_back(
|
||||
parseAttribute(attr.cast<StringAttr>(), elementType));
|
||||
if (!paddingValues.back()) {
|
||||
InFlightDiagnostic diag = emitOpError()
|
||||
<< "expects a padding value that parses to "
|
||||
<< elementType << ", got " << std::get<0>(it);
|
||||
auto diag = this->emitOpError("expects a padding that parses to ")
|
||||
<< elementType << ", got " << std::get<0>(it);
|
||||
diag.attachNote(target.getLoc()) << "when applied to this op";
|
||||
return diag;
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// Otherwise, add the attribute directly.
|
||||
if (attr.getType() != elementType) {
|
||||
InFlightDiagnostic diag = emitOpError()
|
||||
<< "expects a padding value of type "
|
||||
<< elementType << ", got " << attr;
|
||||
auto diag = this->emitOpError("expects a padding value of type ")
|
||||
<< elementType << ", got " << attr;
|
||||
diag.attachNote(target.getLoc()) << "when applied to this op";
|
||||
return diag;
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
}
|
||||
paddingValues.push_back(attr);
|
||||
}
|
||||
|
@ -327,13 +337,13 @@ FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target,
|
|||
|
||||
FailureOr<LinalgOp> result =
|
||||
tryApply<LinalgPaddingPattern>(target, paddingOptions);
|
||||
if (succeeded(result))
|
||||
return result;
|
||||
if (succeeded(result)) {
|
||||
results.push_back(result->getOperation());
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
|
||||
InFlightDiagnostic diag = emitError()
|
||||
<< "failed to apply pattern to target op";
|
||||
diag.attachNote(target.getLoc()) << "target op";
|
||||
return diag;
|
||||
results.assign(1, nullptr);
|
||||
return emitDefaultSilenceableFailure(target);
|
||||
}
|
||||
|
||||
LogicalResult transform::PadOp::verify() {
|
||||
|
@ -381,8 +391,10 @@ LogicalResult transform::PadOp::verify() {
|
|||
// ScalarizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target,
|
||||
TransformState &state) {
|
||||
DiagnosedSilenceableFailure
|
||||
transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
|
||||
SmallVectorImpl<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
LinalgTilingOptions tilingOptions;
|
||||
tilingOptions.scalarizeDynamicDims();
|
||||
// Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
|
||||
|
@ -394,9 +406,10 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target,
|
|||
FailureOr<TiledLinalgOp> result =
|
||||
pattern.returningMatchAndRewrite(target, rewriter);
|
||||
if (failed(result))
|
||||
return failure();
|
||||
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
|
||||
|
||||
return result->op;
|
||||
results.push_back(result->op);
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -558,9 +571,10 @@ LogicalResult SplitOp::verify() {
|
|||
// SplitReductionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FailureOr<SmallVector<Operation *>>
|
||||
transform::SplitReductionOp::applyToOne(LinalgOp target,
|
||||
TransformState &state) {
|
||||
DiagnosedSilenceableFailure
|
||||
transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
|
||||
SmallVectorImpl<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
ControlSplitReductionFn splitFn = [&](LinalgOp) {
|
||||
return std::pair<int64_t, unsigned>(getSplitFactor(),
|
||||
getInsertSplitDimension());
|
||||
|
@ -572,10 +586,13 @@ transform::SplitReductionOp::applyToOne(LinalgOp target,
|
|||
? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
|
||||
: splitReduction(rewriter, target, splitFn, getUseAlloc());
|
||||
if (failed(splitResult))
|
||||
return getOperation()->emitError("failed to apply");
|
||||
return SmallVector<Operation *>{splitResult->fillOp,
|
||||
splitResult->splitLinalgOp,
|
||||
splitResult->resultCombiningLinalgOp};
|
||||
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
|
||||
|
||||
results.push_back(splitResult->initOrAlloc);
|
||||
results.push_back(splitResult->fillOp);
|
||||
results.push_back(splitResult->splitLinalgOp);
|
||||
results.push_back(splitResult->resultCombiningLinalgOp);
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -618,13 +635,14 @@ void TileOp::print(OpAsmPrinter &p) {
|
|||
// VectorizeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target,
|
||||
TransformState &state) {
|
||||
DiagnosedSilenceableFailure
|
||||
transform::VectorizeOp::applyToOne(Operation *target,
|
||||
SmallVectorImpl<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
|
||||
InFlightDiagnostic diag = emitOpError()
|
||||
<< "applies only to isolated-from-above targets";
|
||||
auto diag = this->emitOpError("requires isolated-from-above targets");
|
||||
diag.attachNote(target->getLoc()) << "non-isolated target";
|
||||
return diag;
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
}
|
||||
|
||||
MLIRContext *ctx = getContext();
|
||||
|
@ -642,8 +660,10 @@ FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target,
|
|||
linalg::populatePadOpVectorizationPatterns(patterns);
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
|
||||
return reportUnknownTransformError(target);
|
||||
return target;
|
||||
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
|
||||
|
||||
results.push_back(target);
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -127,18 +127,21 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
|
|||
// LoopPeelOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop,
|
||||
TransformState &state) {
|
||||
DiagnosedSilenceableFailure
|
||||
transform::LoopPeelOp::applyToOne(scf::ForOp target,
|
||||
SmallVector<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
scf::ForOp result;
|
||||
IRRewriter rewriter(loop->getContext());
|
||||
IRRewriter rewriter(target->getContext());
|
||||
// This helper returns failure when peeling does not occur (i.e. when the IR
|
||||
// is not modified). This is not a failure for the op as the postcondition:
|
||||
// "the loop trip count is divisible by the step"
|
||||
// is valid.
|
||||
LogicalResult status =
|
||||
scf::peelAndCanonicalizeForLoop(rewriter, loop, result);
|
||||
if (failed(status)) {
|
||||
if (getFailIfAlreadyDivisible())
|
||||
return reportUnknownTransformError(loop);
|
||||
return loop;
|
||||
}
|
||||
return result;
|
||||
scf::peelAndCanonicalizeForLoop(rewriter, target, result);
|
||||
// TODO: Return both the peeled loop and the remainder loop.
|
||||
results.push_back(failed(status) ? target : result);
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -181,8 +184,10 @@ loopScheduling(scf::ForOp forOp,
|
|||
}
|
||||
}
|
||||
|
||||
FailureOr<scf::ForOp>
|
||||
transform::LoopPipelineOp::applyToOne(scf::ForOp loop, TransformState &state) {
|
||||
DiagnosedSilenceableFailure
|
||||
transform::LoopPipelineOp::applyToOne(scf::ForOp target,
|
||||
SmallVector<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
scf::PipeliningOption options;
|
||||
options.getScheduleFn =
|
||||
[this](scf::ForOp forOp,
|
||||
|
@ -190,26 +195,33 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp loop, TransformState &state) {
|
|||
loopScheduling(forOp, schedule, getIterationInterval(),
|
||||
getReadLatency());
|
||||
};
|
||||
|
||||
scf::ForLoopPipeliningPattern pattern(options, loop->getContext());
|
||||
scf::ForLoopPipeliningPattern pattern(options, target->getContext());
|
||||
SimpleRewriter rewriter(getContext());
|
||||
rewriter.setInsertionPoint(loop);
|
||||
rewriter.setInsertionPoint(target);
|
||||
FailureOr<scf::ForOp> patternResult =
|
||||
pattern.returningMatchAndRewrite(loop, rewriter);
|
||||
if (failed(patternResult))
|
||||
return reportUnknownTransformError(loop);
|
||||
return patternResult;
|
||||
pattern.returningMatchAndRewrite(target, rewriter);
|
||||
if (succeeded(patternResult)) {
|
||||
results.push_back(*patternResult);
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
results.assign(1, nullptr);
|
||||
return emitDefaultSilenceableFailure(target);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LoopUnrollOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop,
|
||||
TransformState &state) {
|
||||
if (failed(loopUnrollByFactor(loop, getFactor())))
|
||||
return reportUnknownTransformError(loop);
|
||||
return success();
|
||||
DiagnosedSilenceableFailure
|
||||
transform::LoopUnrollOp::applyToOne(scf::ForOp target,
|
||||
SmallVector<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
if (failed(loopUnrollByFactor(target, getFactor()))) {
|
||||
Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
|
||||
diag << "op failed to unroll";
|
||||
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
|
||||
}
|
||||
return DiagnosedSilenceableFailure(success());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -37,7 +37,7 @@ transform.with_pdl_patterns {
|
|||
// -----
|
||||
|
||||
func.func @interchange_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
// expected-note @below {{attempted to apply to this op}}
|
||||
// expected-note @below {{when applied to this op}}
|
||||
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ transform.with_pdl_patterns {
|
|||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @match_generic in %arg1
|
||||
// expected-error @below {{applies to linalg.generic ops}}
|
||||
// expected-error @below {{transform applied to the wrong op kind}}
|
||||
transform.structured.interchange %0 { iterator_interchange = [1, 0]}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -99,7 +99,7 @@ transform.with_pdl_patterns {
|
|||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @pdl_target in %arg1
|
||||
// expected-error @below {{expects a padding value that parses to 'f32', got "foo"}}
|
||||
// expected-error @below {{expects a padding that parses to 'f32', got "foo"}}
|
||||
%1 = transform.structured.pad %0 {padding_values=["foo", 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
|
||||
}
|
||||
}
|
||||
|
@ -109,7 +109,7 @@ transform.with_pdl_patterns {
|
|||
func.func @pad(%arg0: tensor<24x12xf32>,
|
||||
%arg1: tensor<12x25xf32>,
|
||||
%arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
|
||||
// expected-note @below {{target op}}
|
||||
// expected-note @below {{when applied to this op}}
|
||||
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
|
||||
func.return %0 : tensor<24x25xf32>
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ transform.with_pdl_patterns {
|
|||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @pdl_target in %arg1
|
||||
// expected-error @below {{failed to apply pattern to target op}}
|
||||
// expected-error @below {{transform.structured.pad failed to apply}}
|
||||
%1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ transform.with_pdl_patterns {
|
|||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @pdl_target in %arg1
|
||||
%1:3 = transform.structured.split_reduction %0
|
||||
%1:4 = transform.structured.split_reduction %0
|
||||
{ split_factor = 4, insert_split_dimension = 2, use_scaling_algorithm, use_alloc}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,6 +30,6 @@ transform.with_pdl_patterns {
|
|||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @pdl_target in %arg1
|
||||
%1:3 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
|
||||
%1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -176,7 +176,7 @@ transform.with_pdl_patterns {
|
|||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @pdl_target in %arg1
|
||||
// expected-error @below {{applies only to isolated-from-above targets}}
|
||||
// expected-error @below {{op requires isolated-from-above targets}}
|
||||
%2 = transform.structured.vectorize %0
|
||||
}
|
||||
}
|
||||
|
|
|
@ -380,16 +380,8 @@ transform.with_pdl_patterns {
|
|||
}
|
||||
// -----
|
||||
|
||||
transform.sequence {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
// expected-error @below {{unexpected number of results (got 0 expected 3)}}
|
||||
transform.test_wrong_number_of_results %arg0
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @foo() {
|
||||
"op" () : () -> ()
|
||||
// expected-note @below {{when applied to this op}}
|
||||
"op" () : () -> ()
|
||||
return
|
||||
}
|
||||
|
@ -406,7 +398,37 @@ transform.with_pdl_patterns {
|
|||
transform.sequence %arg0 {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @some in %arg1
|
||||
// expected-error @below {{expected all applications of transform.test_wrong_number_of_multi_results to produce 1 results}}
|
||||
// expected-error @below {{applications of transform.test_wrong_number_of_results expected to produce 3 results (actually produced 1).}}
|
||||
// expected-note @below {{If you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}}
|
||||
// expected-note @below {{Producing 3 null results is allowed if the use case warrants it.}}
|
||||
transform.test_wrong_number_of_results %0
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @foo() {
|
||||
"op" () : () -> ()
|
||||
// expected-note @below {{when applied to this op}}
|
||||
"op" () : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @some : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
|
||||
transform.sequence %arg0 {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @some in %arg1
|
||||
// expected-error @below {{applications of transform.test_wrong_number_of_multi_results expected to produce 1 results (actually produced 0)}}
|
||||
// expected-note @below {{If you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}}
|
||||
// expected-note @below {{Producing 1 null results is allowed if the use case warrants it.}}
|
||||
transform.test_wrong_number_of_multi_results %0
|
||||
}
|
||||
}
|
||||
|
@ -463,6 +485,31 @@ transform.with_pdl_patterns {
|
|||
|
||||
// -----
|
||||
|
||||
func.func @foo() {
|
||||
// expected-note @below {{when applied to this op}}
|
||||
"op" () : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @some : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
|
||||
transform.sequence %arg0 {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @some in %arg1
|
||||
// expected-error @below {{unexpected application of transform.test_mixed_null_and_non_null_results produces both null and non null results.}}
|
||||
transform.test_mixed_null_and_non_null_results %0
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Expecting to match all operations by merging the handles that matched addi
|
||||
// and subi separately.
|
||||
func.func @foo(%arg0: index) {
|
||||
|
@ -498,4 +545,3 @@ transform.with_pdl_patterns {
|
|||
test_print_remark_at_operand %2, "matched"
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -226,28 +226,44 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
|
|||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Operation *>>
|
||||
mlir::test::TestWrongNumberOfResultsOp::applyToOne(
|
||||
Operation *, transform::TransformState &state) {
|
||||
return SmallVector<Operation *>{};
|
||||
DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
|
||||
Operation *target, SmallVectorImpl<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
OperationState opState(target->getLoc(), "foo");
|
||||
results.push_back(OpBuilder(target).create(opState));
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Operation *>>
|
||||
DiagnosedSilenceableFailure
|
||||
mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
|
||||
Operation *op, transform::TransformState &state) {
|
||||
Operation *target, SmallVectorImpl<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
static int count = 0;
|
||||
if (count++ > 0)
|
||||
return SmallVector<Operation *>{};
|
||||
OperationState opState(op->getLoc(), "foo");
|
||||
return SmallVector<Operation *>{OpBuilder(op).create(opState)};
|
||||
if (count++ == 0) {
|
||||
OperationState opState(target->getLoc(), "foo");
|
||||
results.push_back(OpBuilder(target).create(opState));
|
||||
}
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Operation *>>
|
||||
DiagnosedSilenceableFailure
|
||||
mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
|
||||
Operation *op, transform::TransformState &state) {
|
||||
OperationState opState(op->getLoc(), "foo");
|
||||
return SmallVector<Operation *>{OpBuilder(op).create(opState),
|
||||
OpBuilder(op).create(opState)};
|
||||
Operation *target, SmallVectorImpl<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
OperationState opState(target->getLoc(), "foo");
|
||||
results.push_back(OpBuilder(target).create(opState));
|
||||
results.push_back(OpBuilder(target).create(opState));
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(
|
||||
Operation *target, SmallVectorImpl<Operation *> &results,
|
||||
transform::TransformState &state) {
|
||||
OperationState opState(target->getLoc(), "foo");
|
||||
results.push_back(nullptr);
|
||||
results.push_back(OpBuilder(target).create(opState));
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -139,8 +139,10 @@ def TestWrongNumberOfResultsOp
|
|||
let assemblyFormat = "$target attr-dict";
|
||||
let cppNamespace = "::mlir::test";
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
|
||||
::mlir::Operation *target, transform::TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::Operation * target,
|
||||
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -153,8 +155,10 @@ def TestWrongNumberOfMultiResultsOp
|
|||
let assemblyFormat = "$target attr-dict";
|
||||
let cppNamespace = "::mlir::test";
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
|
||||
::mlir::Operation *target, transform::TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::Operation * target,
|
||||
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -168,8 +172,27 @@ def TestCorrectNumberOfMultiResultsOp
|
|||
let assemblyFormat = "$target attr-dict";
|
||||
let cppNamespace = "::mlir::test";
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
|
||||
::mlir::Operation *target, transform::TransformState &state);
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::Operation * target,
|
||||
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
def TestMixedNullAndNonNullResultsOp
|
||||
: Op<Transform_Dialect, "test_mixed_null_and_non_null_results",
|
||||
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
|
||||
TransformEachOpTrait, TransformOpInterface]> {
|
||||
let arguments = (ins PDL_Operation:$target);
|
||||
let results = (outs PDL_Operation:$null,
|
||||
PDL_Operation:$non_null);
|
||||
let assemblyFormat = "$target attr-dict";
|
||||
let cppNamespace = "::mlir::test";
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
::mlir::Operation * target,
|
||||
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
|
||||
::mlir::transform::TransformState &state);
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue