[mlir][Transform] Make applyToOne return a DiagnosedSilenceableFailure

This revision revisits the implementation of applyToOne and its handling
of recoverable errors as well as propagation of null handles.
The implementation is simplified to always require passing a vector<Operation*>
in which the results are returned, resulting in less template instantiation magic.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D129185
This commit is contained in:
Nicolas Vasilache 2022-07-07 07:08:22 -07:00
parent 7d1a295484
commit 5230710933
16 changed files with 624 additions and 334 deletions

View File

@ -15,6 +15,7 @@
namespace mlir {
namespace linalg {
class GenericOp;
class LinalgOp;
} // namespace linalg
} // namespace mlir

View File

@ -22,9 +22,15 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
let description = [{
Decomposes named complex operations, such as higher-dimensional
(depthwise) convolutions, into combinations of lower-dimensional equivalents
when possible. The operand handle must point to a list of such operations.
The returning handle points to the main produced computational operation,
such as the lower-dimensional convolution.
when possible.
Return modes:
=============
This operation ignores non-Linalg ops and drops them in the return.
If all the operations referred to by the `target` PDLOperation decompose
properly, the transform succeeds. Otherwise the transform silently fails.
The return handle points to only the subset of successfully produced
computational operations, which can be empty.
}];
let arguments = (ins PDL_Operation:$target);
@ -32,8 +38,10 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
::mlir::linalg::LinalgOp target, TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::linalg::LinalgOp target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}
@ -61,11 +69,16 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
TransformOpInterface, TransformEachOpTrait]> {
let description = [{
Transforms a named structued operation into the generic form with the
explicit attached region. The operand handle must point to a list of
structured operations, it is consumed by the transformation and is not
expected to be used afterwards. The resulting handle points to the list
of equivalent generic operations, in the same order as the original named
operations.
explicit attached region.
Return modes:
=============
This operation ignores non-Linalg ops and drops them in the return.
If all the operations referred to by the `target` PDLOperation generalize
properly, the transform succeeds. Otherwise the transform silently fails.
The return handle points to only the subset of successfully produced
equivalent generic operations, which can be empty or contain the original
ops if they were already in generic form.
}];
let arguments = (ins PDL_Operation:$target);
@ -73,8 +86,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
::mlir::linalg::LinalgOp target, TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::linalg::LinalgOp target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}
@ -84,6 +99,16 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
let description = [{
Interchanges the iterators of the operations pointed to by the target handle
using the iterator interchange attribute.
Return modes:
=============
This operation ignores non-linalg::Generic ops and drops them in the return.
This operation fails if the interchange attribute is invalid.
If all the operations referred to by the `target` PDLOperation interchange
properly, the transform succeeds.
If any interchange fails, the transform definitely fails.
The return handle points to only the subset of successfully produced
interchanged operations, which can be empty.
}];
let arguments =
@ -95,8 +120,10 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
let hasVerifier = 1;
let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
::mlir::linalg::LinalgOp target, TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::linalg::GenericOp target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}
@ -106,6 +133,16 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
let description = [{
Pads the operations pointed to by the target handle using the options
provides as operation attributes.
Return modes:
=============
This operation ignores non-Linalg ops and drops them in the return.
This operation may produce a definiteFailure if the padding fails for any
reason.
If all the operations referred to by the `target` PDLOperation pad
properly, the transform succeeds. Otherwise the transform silently fails.
The return handle points to only the subset of successfully produced
padded operations, which can be empty.
}];
let arguments =
@ -123,8 +160,10 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
let hasVerifier = 1;
let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
::mlir::linalg::LinalgOp target, TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::linalg::LinalgOp target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}
@ -135,11 +174,23 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
Indicates that ops of a specific kind in the given function should be
scalarized (i.e. their dynamic dimensions tiled by 1).
This operation returns the tiled op but not the loops.
Return modes:
=============
This operation ignores non-Linalg ops and drops them in the return.
This operation produces `definiteFailure` if the scalarization fails for any
reason.
If all the operations referred to by the `target` PDLOperation scalarize
properly, the transform succeeds. Otherwise the transform silently fails.
The return handle points to only the subset of successfully produced
tiled-by-1 operations, which can be empty.
This operation does not return handles to the tiled loop.
We make this design choice because it is hard to know ahead of time the
number of loops that will be produced (it depends on the number of dynamic
dimensions after multiple transformations have been applied).
Loops can always be recovered by navigating from the tiled operations if
needed.
}];
let arguments = (ins PDL_Operation:$target);
@ -148,8 +199,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
::mlir::linalg::LinalgOp target, TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::linalg::LinalgOp target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}
@ -206,7 +259,17 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
- use_alloc: whether to use an alloc op to allocate the temporary
tensor (default: do not use alloc op)
This op returns 4 handles to:
Return modes:
=============
This operation ignores non-Linalg ops and drops them in the return.
This operation produces `definiteFailure` if the splitting fails for any
reason.
If all the operations referred to by the `target` PDLOperation split
properly, the transform succeeds. Otherwise the transform silently fails.
The 4 returned handles points to only the subset of successfully produced
computational operations, which can all be empty.
This 4 returned handles point to:
- the init op (or tensor_alloc op if use_alloc = true),
- the fill op used to initialize the neutral element,
- the split op and
@ -316,15 +379,18 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension,
UnitAttr:$use_scaling_algorithm,
UnitAttr:$use_alloc);
let results = (outs PDL_Operation:$fill_op,
let results = (outs PDL_Operation:$init_or_alloc_op,
PDL_Operation:$fill_op,
PDL_Operation:$split_linalg_op,
PDL_Operation:$combining_linalg_op);
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
::mlir::linalg::LinalgOp target, TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::linalg::LinalgOp target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}
@ -372,6 +438,13 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
Note that this transformation is invalidating the handles to any payload IR
operation that is contained inside the vectorization target.
Return modes:
=============
This operation produces `definiteFailure` if vectorization fails for any
reason.
The operation always returns the handle to the target op that is expected
to be isolated from above.
}];
let arguments = (ins PDL_Operation:$target,
@ -381,8 +454,10 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
::mlir::FailureOr<Operation *> applyToOne(
::mlir::Operation *target, TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation *target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}

View File

@ -239,6 +239,8 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
///
/// Return failure if the permutation is not valid.
FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
GenericOp genericOp,
ArrayRef<unsigned> interchangeVector);

View File

@ -50,7 +50,7 @@ def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
outlined into a separate function. The provided name is used as a _base_
for forming actual function names following SymbolTable auto-renaming
scheme to avoid duplicate symbols. Expects that all ops in the Payload IR
have a SymbolTable ancestor (typically true because of the top-level
have a SymbolTable ancestor (typically true because of the top-level
module). Returns the handle to the list of outlined functions in the same
order as the operand handle.
}];
@ -68,28 +68,40 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
let summary = "Peels the last iteration of the loop";
let description = [{
Updates the given loop so that its step evenly divides its range and puts
the remaining iteration into a separate loop or a conditional. Note that
even though the Payload IR modification may be performed in-place, this
operation consumes the operand handle and produces a new one. Applies to
each loop associated with the operand handle individually. The results
follow the same order as the operand.
the remaining iteration into a separate loop or a conditional.
Note: If it can be proven statically that the step already evenly divides
the range, this op is a no-op. In the absence of sufficient static
information, this op may peel a loop, even if the step always divides the
range evenly at runtime.
In the absence of sufficient static information, this op may peel a loop,
even if the step always divides the range evenly at runtime.
Return modes:
=============
This operation ignores non-scf::ForOp ops and drops them in the return.
This operation always succeeds and returns the scf::ForOp with the
postcondition: "the loop trip count is divisible by the step".
This operation may return the same unmodified loop handle when peeling did
not modify the IR (i.e. the loop trip count was already divisible).
Note that even though the Payload IR modification may be performed
in-place, this operation consumes the operand handle and produces a new
one.
TODO: Return both the peeled loop and the remainder loop.
}];
let arguments =
(ins PDL_Operation:$target,
DefaultValuedAttr<BoolAttr, "false">:$fail_if_already_divisible);
// TODO: Return both the peeled loop and the remainder loop.
let results = (outs PDL_Operation:$transformed);
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
::mlir::scf::ForOp loop, TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::scf::ForOp target,
::llvm::SmallVector<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}
@ -102,10 +114,21 @@ def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
each of them. That is, performs some amount of reads from memory before the
loop rather than inside the loop, the same amount of writes into memory
after the loop, and updates each iteration to read the data for a following
iteration rather than the current one. The amount is specified by the
attributes. The values read and about to be stored are transferred as loop
iteration arguments. Currently supports memref and vector transfer
operations as memory reads/writes.
iteration rather than the current one.
The amount is specified by the attributes.
The values read and about to be stored are transferred as loop iteration
arguments. Currently supports memref and vector transfer operations as
memory reads/writes.
Return modes:
=============
This operation ignores non-scf::For ops and drops them in the return.
If all the operations referred to by the `target` PDLOperation pipeline
properly, the transform succeeds. Otherwise the transform silently fails.
The return handle points to only the subset of successfully produced
pipelined loops, which can be empty.
}];
let arguments = (ins PDL_Operation:$target,
@ -116,8 +139,10 @@ def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
::mlir::scf::ForOp loop, TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::scf::ForOp target,
::llvm::SmallVector<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}
@ -126,11 +151,18 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
TransformOpInterface, TransformEachOpTrait]> {
let summary = "Unrolls the given loop with the given unroll factor";
let description = [{
Unrolls each loop associated with the given handle to have up to the given
number of loop body copies per iteration. If the unroll factor is larger
than the loop trip count, the latter is used as the unroll factor instead.
Does not produce a new handle as the operation may result in the loop being
removed after a full unrolling.
Unrolls each loop associated with the given handle to have up to the given
number of loop body copies per iteration. If the unroll factor is larger
than the loop trip count, the latter is used as the unroll factor instead.
Return modes:
==============
This operation ignores non-scf::For ops and drops them in the return.
If all the operations referred to by the `target` PDLOperation unroll
properly, the transform succeeds. Otherwise the transform silently fails.
Does not return handles as the operation may result in the loop being
removed after a full unrolling.
}];
let arguments = (ins PDL_Operation:$target,
@ -139,8 +171,10 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
::mlir::LogicalResult applyToOne(
::mlir::scf::ForOp loop, TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::scf::ForOp target,
::llvm::SmallVector<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}

View File

@ -12,13 +12,14 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/ScopeExit.h"
namespace mlir {
/// The result of a transform IR operation application. This can have one of the
/// three states:
/// - success;
/// - silencable (recoverable) failure with yet-unreported diagnostic;
/// - silenceable (recoverable) failure with yet-unreported diagnostic;
/// - definite failure.
/// Silenceable failure is intended to communicate information about
/// transformations that did not apply but in a way that supports recovery,
@ -26,10 +27,10 @@ namespace mlir {
/// predictable way. They are associated with a Diagnostic that provides more
/// details on the failure. Silenceable failure can be discarded, turning the
/// result into success, or "reported", emitting the diagnostic and turning the
/// result into definite failure. Transform IR operations containing other
/// operations are allowed to do either with the results of the nested
/// transformations, but must propagate definite failures as their diagnostics
/// have been already reported to the user.
/// result into definite failure.
/// Transform IR operations containing other operations are allowed to do either
/// with the results of the nested transformations, but must propagate definite
/// failures as their diagnostics have been already reported to the user.
class LLVM_NODISCARD DiagnosedSilenceableFailure {
public:
explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
@ -51,12 +52,17 @@ public:
return DiagnosedSilenceableFailure(::mlir::failure());
}
/// Constructs a DiagnosedSilenceableFailure in the silencable failure state,
/// Constructs a DiagnosedSilenceableFailure in the silenceable failure state,
/// ready to emit the given diagnostic. This is considered a failure
/// regardless of the diagnostic severity.
static DiagnosedSilenceableFailure silencableFailure(Diagnostic &&diag) {
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag) {
return DiagnosedSilenceableFailure(std::forward<Diagnostic>(diag));
}
static DiagnosedSilenceableFailure
silenceableFailure(SmallVector<Diagnostic> &&diag) {
return DiagnosedSilenceableFailure(
std::forward<SmallVector<Diagnostic>>(diag));
}
/// Converts all kinds of failure into a LogicalResult failure, emitting the
/// diagnostic if necessary. Must not be called more than once.
@ -65,44 +71,72 @@ public:
assert(!reported && "attempting to report a diagnostic more than once");
reported = true;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
if (diagnostic) {
diagnostic->getLocation().getContext()->getDiagEngine().emit(
std::move(*diagnostic));
diagnostic.reset();
if (!diagnostics.empty()) {
for (auto &&diagnostic : diagnostics) {
diagnostic.getLocation().getContext()->getDiagEngine().emit(
std::move(diagnostic));
}
diagnostics.clear();
result = ::mlir::failure();
}
return result;
}
/// Returns `true` if this is a silencable failure.
bool isSilenceableFailure() const { return diagnostic.hasValue(); }
/// Returns `true` if this is a silenceable failure.
bool isDefiniteFailure() const { return result.failed(); }
/// Returns `true` if this is a silenceable failure.
bool isSilenceableFailure() const { return !diagnostics.empty(); }
/// Returns `true` if this is a success.
bool succeeded() const {
return !diagnostic.hasValue() && ::mlir::succeeded(result);
return diagnostics.empty() && ::mlir::succeeded(result);
}
/// Returns the diagnostic message without emitting it. Expects this object
/// to be a silencable failure.
std::string getMessage() const { return diagnostic->str(); }
/// to be a silenceable failure.
std::string getMessage() const {
std::string res;
for (auto &diagnostic : diagnostics) {
res.append(diagnostic.str());
res.append("\n");
}
return res;
}
/// Converts silencable failure into LogicalResult success without reporting
/// Returns a string representation of the failure mode (for error reporting).
std::string getStatusString() const {
if (succeeded())
return "success";
if (isSilenceableFailure())
return "silenceable failure";
return "definite failure";
}
/// Converts silenceable failure into LogicalResult success without reporting
/// the diagnostic, preserves the other states.
LogicalResult silence() {
if (diagnostic) {
diagnostic.reset();
if (!diagnostics.empty()) {
diagnostics.clear();
result = ::mlir::success();
}
return result;
}
/// Streams the given values into the diagnotic. Expects this object to be a
/// silencable failure.
/// Take the diagnostic and silence.
SmallVector<Diagnostic> &&takeDiagnostics() {
assert(!diagnostics.empty() && "expected a diagnostic to be present");
auto guard = llvm::make_scope_exit([&]() { diagnostics.clear(); });
return std::move(diagnostics);
}
/// Streams the given values into the last diagnotic.
/// Expects this object to be a silenceable failure.
template <typename T>
DiagnosedSilenceableFailure &operator<<(T &&value) & {
assert(isSilenceableFailure() &&
"can only append output in silencable failure state");
*diagnostic << std::forward<T>(value);
"can only append output in silenceable failure state");
diagnostics.back() << std::forward<T>(value);
return *this;
}
template <typename T>
@ -110,31 +144,36 @@ public:
return std::move(this->operator<<(std::forward<T>(value)));
}
/// Attaches a note to the diagnostic. Expects this object to be a silencable
/// failure.
/// Attaches a note to the last diagnostic.
/// Expects this object to be a silenceable failure.
Diagnostic &attachNote(Optional<Location> loc = llvm::None) {
assert(isSilenceableFailure() &&
"can only attach notes to silencable failures");
return diagnostic->attachNote(loc);
"can only attach notes to silenceable failures");
return diagnostics.back().attachNote(loc);
}
private:
explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic)
: diagnostic(std::move(diagnostic)), result(failure()) {}
: diagnostics(), result(failure()) {
diagnostics.emplace_back(std::move(diagnostic));
}
explicit DiagnosedSilenceableFailure(SmallVector<Diagnostic> &&diagnostics)
: diagnostics(std::move(diagnostics)), result(failure()) {}
/// The diagnostic associated with this object. If present, the object is
/// considered to be in the silencable failure state regardless of the
/// The diagnostics associated with this object. If non-empty, the object is
/// considered to be in the silenceable failure state regardless of the
/// `result` field.
Optional<Diagnostic> diagnostic;
SmallVector<Diagnostic, 1> diagnostics;
/// The "definite" logical state, either success or failure. Ignored if the
/// diagnostic message is present.
/// The "definite" logical state, either success or failure.
/// Ignored if the diagnostics message is present.
LogicalResult result;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// Whther the associated diagnostic have been reported. Diagnostic reporting
/// consumes the diagnostic, so we need a mechanism to differentiate a
/// reported diagnostic from a state where it was never created.
/// Whether the associated diagnostics have been reported.
/// Diagnostics reporting consumes the diagnostics, so we need a mechanism to
/// differentiate reported diagnostics from a state where it was never
/// created.
bool reported = false;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
};
@ -579,24 +618,45 @@ public:
};
/// Trait implementing the TransformOpInterface for operations applying a
/// transformation to a single operation handle and producing one or multiple
/// operation handles.
/// The op must implement a method with one of the following signatures:
/// - FailureOr<convertible-to-Operation*> applyToOne(OpTy, state)
/// - FailureOr<SmallVector<convertible-to-Operation*>>applyToOne(OpTy, state)
/// - LogicalResult applyToOne(OpTy, state)
/// transformation to a single operation handle and producing zero, one or
/// multiple operation handles.
/// The op must implement a method with the following signature:
/// - DiagnosedSilenceableFailure applyToOne(OpTy,
/// SmallVector<Operation*> &results, state)
/// to perform a transformation that is applied in turn to all payload IR
/// operations that correspond to the handle of the transform IR operation.
/// In the functions above, OpTy is either Operation * or a concrete payload IR
/// Op class that the transformation is applied to (NOT the class of the
/// transform IR op). The op is expected to have a single operand.
/// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class
/// that the transformation is applied to (and NOT the class of the transform IR
/// op).
/// The `applyToOne` method takes an empty `results` vector that it fills with
/// zero, one or multiple operations depending on the number of resultd expected
/// by the transform op.
/// The number of results must match the number of results of the transform op.
/// `applyToOne` is allowed to fill the `results` with all null elements to
/// signify that the transformation did not apply to the payload IR operations.
/// Such null elements are filtered out from results before return.
///
/// The transform op having this trait is expected to have a single operand.
template <typename OpTy>
class TransformEachOpTrait
: public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
public:
/// Calls `applyToOne` for every payload operation associated with the operand
/// of this transform IR op. If `applyToOne` returns ops, associates them with
/// the result of this transform op.
/// of this transform IR op, the following case disjunction happens:
/// 1. If not target payload ops are associated to the operand then fill the
/// results vector with the expected number of null elements and return
/// success. This is the corner case handling that allows propagating
/// the "no-op" case gracefully to improve usability.
/// 2. If any `applyToOne` returns definiteFailure, the transformation is
/// immediately considered definitely failed and we return.
/// 3. All applications of `applyToOne` are checked to return a number of
/// results expected by the transform IR op. If not, this is a definite
/// failure and we return early.
/// 4. If `applyToOne` produces ops, associate them with the result of this
/// transform op.
/// 5. If any `applyToOne` return silenceableFailure, the transformation is
/// considered silenceable.
/// 6. Otherwise the transformation is considered successful.
DiagnosedSilenceableFailure apply(TransformResults &transformResults,
TransformState &state);
@ -714,88 +774,58 @@ public:
namespace mlir {
namespace transform {
namespace detail {
/// Appends `result` to the vector assuming it corresponds to the success state
/// in `FailureOr<convertible-to-Operation*>`. If `result` is just a
/// `LogicalResult`, appends an empy vector.
template <typename Ty>
std::enable_if_t<std::is_same<Ty, LogicalResult>::value, LogicalResult>
appendTransformResultToVector(
Ty result, SmallVectorImpl<SmallVector<Operation *>> &results) {
results.push_back(SmallVector<Operation *>());
return result;
}
template <typename Ty>
std::enable_if_t<
llvm::conjunction<
llvm::negation<std::is_same<Ty, LogicalResult>>,
std::is_convertible<typename Ty::value_type, Operation *>>::value,
LogicalResult>
appendTransformResultToVector(
Ty result, SmallVectorImpl<SmallVector<Operation *>> &results) {
if (failed(result))
return failure();
results.push_back(SmallVector<Operation *>{*result});
return success();
}
template <typename ContainerTy>
std::enable_if_t<
llvm::conjunction<
llvm::negation<std::is_same<ContainerTy, LogicalResult>>,
llvm::negation<std::is_convertible<typename ContainerTy::value_type,
Operation *>>>::value,
LogicalResult>
appendTransformResultToVector(
ContainerTy resultContainer,
SmallVectorImpl<SmallVector<Operation *>> &results) {
if (failed(resultContainer))
return failure();
results.push_back(*resultContainer);
return success();
}
/// Applies a one-to-one or a one-to-many transform to each of the given
/// targets. Puts the results of transforms, if any, in `results` in the same
/// order. Fails if any of the application fails. Individual transforms must be
/// callable with one of the following signatures:
/// - FailureOr<convertible-to-Operation*>(OpTy)
/// - LogicalResult(OpTy)
/// - FailureOr<SmallVectorImpl<convertible-to-Operation*>>(
/// SmallVectorImpl<OpTy>)
/// - LogicalResult(SmallVectorImpl<OpTy>)
/// callable with the following signature:
/// - DiagnosedSilenceableFailure(OpTy,
/// SmallVector<Operation*> &results, state)
/// where OpTy is either
/// - Operation *, in which case the transform is always applied;
/// - a concrete Op class, in which case a check is performed whether
/// `targets` contains operations of the same class and a silencable failure
/// `targets` contains operations of the same class and a silenceable failure
/// is reported if it does not.
template <typename FnTy>
DiagnosedSilenceableFailure
applyTransformToEach(ArrayRef<Operation *> targets,
SmallVectorImpl<SmallVector<Operation *>> &results,
FnTy transform) {
DiagnosedSilenceableFailure applyTransformToEach(
Location loc, int expectedNumResults, ArrayRef<Operation *> targets,
SmallVectorImpl<SmallVector<Operation *>> &results, FnTy transform) {
SmallVector<Diagnostic> silenceableStack;
using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
static_assert(std::is_convertible<OpTy, Operation *>::value,
"expected transform function to take an operation");
using RetTy = typename llvm::function_traits<FnTy>::result_t;
static_assert(std::is_convertible<RetTy, LogicalResult>::value,
"expected transform function to return LogicalResult or "
"FailureOr<convertible-to-Operation*>");
for (Operation *target : targets) {
// Emplace back a placeholder for the returned new ops.
// This is filled with `expectedNumResults` if the op fails to apply.
results.push_back(SmallVector<Operation *>());
auto specificOp = dyn_cast<OpTy>(target);
if (!specificOp) {
Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
diag << "attempted to apply transform to the wrong op kind";
return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
Diagnostic diag(loc, DiagnosticSeverity::Error);
diag << "transform applied to the wrong op kind";
diag.attachNote(target->getLoc()) << "when applied to this op";
// Producing `expectedNumResults` nullptr is a silenceableFailure mode.
// TODO: encode this implicit `expectedNumResults` nullptr ==
// silenceableFailure with a proper trait.
results.back().assign(expectedNumResults, nullptr);
silenceableStack.push_back(std::move(diag));
continue;
}
auto result = transform(specificOp);
if (failed(appendTransformResultToVector(result, results)))
return DiagnosedSilenceableFailure::definiteFailure();
DiagnosedSilenceableFailure result = transform(specificOp, results.back());
if (result.isDefiniteFailure())
return result;
if (result.isSilenceableFailure())
for (auto &&diag : result.takeDiagnostics())
silenceableStack.push_back(std::move(diag));
}
if (!silenceableStack.empty()) {
return DiagnosedSilenceableFailure::silenceableFailure(
std::move(silenceableStack));
}
return DiagnosedSilenceableFailure::success();
}
/// Helper function to transform M ops with N results into N results of M ops.
/// Helper function: transpose MxN into NxM; assumes that the input is a valid.
static inline SmallVector<SmallVector<Operation *, 1>>
transposeResults(const SmallVector<SmallVector<Operation *>, 1> &m) {
SmallVector<SmallVector<Operation *, 1>> res;
@ -824,74 +854,95 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
decltype(&OpTy::applyToOne)>::template arg_t<0>;
ArrayRef<Operation *> targets =
state.getPayloadOps(this->getOperation()->getOperand(0));
// Handle the corner case where no target is specified.
// Step 1. Handle the corner case where no target is specified.
// This is typically the case when the matcher fails to apply and we need to
// propagate gracefully.
// In this case, we fill all results with an empty vector.
if (targets.empty()) {
SmallVector<Operation *> emptyResult;
SmallVector<Operation *> empty;
for (auto r : this->getOperation()->getResults())
transformResults.set(r.template cast<OpResult>(), emptyResult);
transformResults.set(r.template cast<OpResult>(), empty);
return DiagnosedSilenceableFailure::success();
}
// Step 2. Call applyToOne on each target and record newly produced ops in its
// corresponding results entry.
int expectedNumResults = this->getOperation()->getNumResults();
SmallVector<SmallVector<Operation *>, 1> results;
// In the multi-result case, collect the number of results each transform
// produced.
DiagnosedSilenceableFailure result = detail::applyTransformToEach(
targets, results, [&](TransformOpType specificOp) {
return static_cast<OpTy *>(this)->applyToOne(specificOp, state);
this->getOperation()->getLoc(), expectedNumResults, targets, results,
[&](TransformOpType specificOp, SmallVector<Operation *> &partialResult) {
auto res = static_cast<OpTy *>(this)->applyToOne(specificOp,
partialResult, state);
if (res.isDefiniteFailure())
return res;
// TODO: encode this implicit must always produce `expectedNumResults`
// and nullptr is fine with a proper trait.
if (static_cast<int>(partialResult.size()) != expectedNumResults) {
auto loc = this->getOperation()->getLoc();
auto diag = mlir::emitError(loc, "applications of ")
<< OpTy::getOperationName() << " expected to produce "
<< expectedNumResults << " results (actually produced "
<< partialResult.size() << ").";
diag.attachNote(loc)
<< "If you need variadic results, consider a generic `apply` "
<< "instead of the specialized `applyToOne`.";
diag.attachNote(loc)
<< "Producing " << expectedNumResults << " null results is "
<< "allowed if the use case warrants it.";
diag.attachNote(specificOp->getLoc()) << "when applied to this op";
return DiagnosedSilenceableFailure::definiteFailure();
}
// Check that all is null or none is null
// TODO: relax this behavior and encode with a proper trait.
if (llvm::any_of(partialResult, [](Operation *op) { return op; }) &&
llvm::any_of(partialResult, [](Operation *op) { return !op; })) {
auto loc = this->getOperation()->getLoc();
auto diag = mlir::emitError(loc, "unexpected application of ")
<< OpTy::getOperationName()
<< " produces both null and non null results.";
diag.attachNote(specificOp->getLoc()) << "when applied to this op";
return DiagnosedSilenceableFailure::definiteFailure();
}
return res;
});
// Propagate the failure (definite or silencable) if any.
if (!result.succeeded())
// Step 3. Propagate the definite failure if any and bail out.
if (result.isDefiniteFailure())
return result;
// Legitimately no results, bail early.
if (results.empty() && OpTy::template hasTrait<OpTrait::ZeroResults>())
return DiagnosedSilenceableFailure::success();
// Step 4. If there are no results, return early.
if (OpTy::template hasTrait<OpTrait::ZeroResults>())
return result;
// Ensure all applications return the same number of results.
// Variadic cases are much trickier to handle in a generic fashion.
int64_t nRes = results.empty() ? 0 : results[0].size();
if (llvm::any_of(results, [&](const auto &r) {
return static_cast<int64_t>(r.size()) != nRes;
})) {
return static_cast<OpTy *>(this)->emitSilenceableError()
<< "expected all applications of " << OpTy::getOperationName()
<< " to produce " << nRes
<< " results.\n If you need variadic results, consider using a "
"generic `apply` instead of the specialized `applyToOne`";
}
// Ensure the number of results agrees with what the transform op expects.
// Unless we see empty results, in which case we just want to propagate the
// emptiness.
if (this->getOperation()->getNumResults() != nRes) {
InFlightDiagnostic diag = static_cast<OpTy *>(this)->emitError()
<< "unexpected number of results (got " << nRes
<< " expected "
<< this->getOperation()->getNumResults() << ")";
return DiagnosedSilenceableFailure::definiteFailure();
}
// Perform transposition of M applications producing N results each into N
// results for each of the M applications.
// Step 5. Perform transposition of M applications producing N results each
// into N results for each of the M applications.
SmallVector<SmallVector<Operation *, 1>> transposedResults =
detail::transposeResults(results);
// Single result applies to M ops produces one single M-result.
// Step 6. Single result applies to M ops produces one single M-result.
if (OpTy::template hasTrait<OpTrait::OneResult>()) {
assert(transposedResults.size() == 1 && "Expected single result");
transformResults.set(
this->getOperation()->getResult(0).template cast<OpResult>(),
transposedResults[0]);
return DiagnosedSilenceableFailure::success();
// ApplyToOne may have returned silenceableFailure, propagate it.
return result;
}
// M ops, N results each.
// Step 7. Filter out empty results and set the transformResults.
for (const auto &it :
llvm::zip(this->getOperation()->getResults(), transposedResults)) {
transformResults.set(std::get<0>(it).template cast<OpResult>(),
std::get<1>(it));
SmallVector<Operation *, 1> filtered;
llvm::copy_if(std::get<1>(it), std::back_inserter(filtered),
[](Operation *op) { return op; });
transformResults.set(std::get<0>(it).template cast<OpResult>(), filtered);
}
return DiagnosedSilenceableFailure::success();
// Step 8. ApplyToOne may have returned silenceableFailure, propagate it.
return result;
}
template <typename OpTy>

View File

@ -62,11 +62,21 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> {
return diag;
}
/// Creates the silencable failure object with a diagnostic located at the
/// Creates the silenceable failure object with a diagnostic located at the
/// current operation.
DiagnosedSilenceableFailure emitSilenceableError() {
Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
/// Creates the default silenceable failure for a transform op that failed
/// to properly apply to a target.
DiagnosedSilenceableFailure emitDefaultSilenceableFailure(
Operation *target) {
Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
diag << $_op->getName() << " failed to apply";
diag.attachNote(target->getLoc()) << "when applied to this op";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
}];
}

View File

@ -76,19 +76,24 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
// DecomposeOp
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target,
TransformState &state) {
DiagnosedSilenceableFailure
transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
FailureOr<LinalgOp> windowed =
tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
if (succeeded(windowed))
return windowed;
if (succeeded(windowed)) {
results.push_back(*windowed);
return DiagnosedSilenceableFailure(success());
}
FailureOr<LinalgOp> depthwise =
tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
if (succeeded(depthwise))
return depthwise;
return reportUnknownTransformError(target);
if (succeeded(depthwise)) {
results.push_back(*depthwise);
return DiagnosedSilenceableFailure(success());
}
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target);
}
//===----------------------------------------------------------------------===//
@ -221,41 +226,46 @@ LogicalResult transform::FuseOp::verify() {
// GeneralizeOp
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target,
TransformState &state) {
DiagnosedSilenceableFailure
transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
// Exit early if no transformation is needed.
if (isa<GenericOp>(target))
return target;
if (isa<GenericOp>(target)) {
results.push_back(target);
return DiagnosedSilenceableFailure(success());
}
FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
if (succeeded(generic))
return generic;
return reportUnknownTransformError(target);
if (succeeded(generic)) {
results.push_back(generic->getOperation());
return DiagnosedSilenceableFailure(success());
}
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target);
}
//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp>
transform::InterchangeOp::applyToOne(LinalgOp target, TransformState &state) {
DiagnosedSilenceableFailure
transform::InterchangeOp::applyToOne(linalg::GenericOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
SmallVector<unsigned> interchangeVector =
extractUIntArray(getIteratorInterchange());
// Exit early if no transformation is needed.
if (interchangeVector.empty())
return target;
auto genericTarget = dyn_cast<GenericOp>(target.getOperation());
if (!genericTarget) {
InFlightDiagnostic diag = emitOpError()
<< "applies to " << GenericOp::getOperationName()
<< " ops";
diag.attachNote(target.getLoc()) << "attempted to apply to this op";
return diag;
if (interchangeVector.empty()) {
results.push_back(target);
return DiagnosedSilenceableFailure(success());
}
return tryApply<GenericOpInterchangePattern>(target, interchangeVector);
SimpleRewriter rewriter(target->getContext());
FailureOr<GenericOp> res =
interchangeGenericOp(rewriter, target, interchangeVector);
if (failed(res))
return DiagnosedSilenceableFailure::definiteFailure();
results.push_back(res->getOperation());
return DiagnosedSilenceableFailure(success());
}
LogicalResult transform::InterchangeOp::verify() {
@ -275,8 +285,10 @@ LogicalResult transform::InterchangeOp::verify() {
// PadOp
//===---------------------------------------------------------------------===//
FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target,
TransformState &state) {
DiagnosedSilenceableFailure
transform::PadOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
// Convert the integer packing flags to booleans.
SmallVector<bool> packPaddings;
for (int64_t packPadding : extractI64Array(getPackPaddings()))
@ -293,21 +305,19 @@ FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target,
paddingValues.push_back(
parseAttribute(attr.cast<StringAttr>(), elementType));
if (!paddingValues.back()) {
InFlightDiagnostic diag = emitOpError()
<< "expects a padding value that parses to "
<< elementType << ", got " << std::get<0>(it);
auto diag = this->emitOpError("expects a padding that parses to ")
<< elementType << ", got " << std::get<0>(it);
diag.attachNote(target.getLoc()) << "when applied to this op";
return diag;
return DiagnosedSilenceableFailure::definiteFailure();
}
continue;
}
// Otherwise, add the attribute directly.
if (attr.getType() != elementType) {
InFlightDiagnostic diag = emitOpError()
<< "expects a padding value of type "
<< elementType << ", got " << attr;
auto diag = this->emitOpError("expects a padding value of type ")
<< elementType << ", got " << attr;
diag.attachNote(target.getLoc()) << "when applied to this op";
return diag;
return DiagnosedSilenceableFailure::definiteFailure();
}
paddingValues.push_back(attr);
}
@ -327,13 +337,13 @@ FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target,
FailureOr<LinalgOp> result =
tryApply<LinalgPaddingPattern>(target, paddingOptions);
if (succeeded(result))
return result;
if (succeeded(result)) {
results.push_back(result->getOperation());
return DiagnosedSilenceableFailure(success());
}
InFlightDiagnostic diag = emitError()
<< "failed to apply pattern to target op";
diag.attachNote(target.getLoc()) << "target op";
return diag;
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target);
}
LogicalResult transform::PadOp::verify() {
@ -381,8 +391,10 @@ LogicalResult transform::PadOp::verify() {
// ScalarizeOp
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target,
TransformState &state) {
DiagnosedSilenceableFailure
transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
LinalgTilingOptions tilingOptions;
tilingOptions.scalarizeDynamicDims();
// Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
@ -394,9 +406,10 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target,
FailureOr<TiledLinalgOp> result =
pattern.returningMatchAndRewrite(target, rewriter);
if (failed(result))
return failure();
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
return result->op;
results.push_back(result->op);
return DiagnosedSilenceableFailure(success());
}
//===----------------------------------------------------------------------===//
@ -558,9 +571,10 @@ LogicalResult SplitOp::verify() {
// SplitReductionOp
//===----------------------------------------------------------------------===//
FailureOr<SmallVector<Operation *>>
transform::SplitReductionOp::applyToOne(LinalgOp target,
TransformState &state) {
DiagnosedSilenceableFailure
transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
ControlSplitReductionFn splitFn = [&](LinalgOp) {
return std::pair<int64_t, unsigned>(getSplitFactor(),
getInsertSplitDimension());
@ -572,10 +586,13 @@ transform::SplitReductionOp::applyToOne(LinalgOp target,
? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
: splitReduction(rewriter, target, splitFn, getUseAlloc());
if (failed(splitResult))
return getOperation()->emitError("failed to apply");
return SmallVector<Operation *>{splitResult->fillOp,
splitResult->splitLinalgOp,
splitResult->resultCombiningLinalgOp};
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
results.push_back(splitResult->initOrAlloc);
results.push_back(splitResult->fillOp);
results.push_back(splitResult->splitLinalgOp);
results.push_back(splitResult->resultCombiningLinalgOp);
return DiagnosedSilenceableFailure(success());
}
//===----------------------------------------------------------------------===//
@ -618,13 +635,14 @@ void TileOp::print(OpAsmPrinter &p) {
// VectorizeOp
//===----------------------------------------------------------------------===//
FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target,
TransformState &state) {
DiagnosedSilenceableFailure
transform::VectorizeOp::applyToOne(Operation *target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
InFlightDiagnostic diag = emitOpError()
<< "applies only to isolated-from-above targets";
auto diag = this->emitOpError("requires isolated-from-above targets");
diag.attachNote(target->getLoc()) << "non-isolated target";
return diag;
return DiagnosedSilenceableFailure::definiteFailure();
}
MLIRContext *ctx = getContext();
@ -642,8 +660,10 @@ FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target,
linalg::populatePadOpVectorizationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
return reportUnknownTransformError(target);
return target;
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
results.push_back(target);
return DiagnosedSilenceableFailure(success());
}
//===----------------------------------------------------------------------===//

View File

@ -127,18 +127,21 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
// LoopPeelOp
//===----------------------------------------------------------------------===//
FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop,
TransformState &state) {
DiagnosedSilenceableFailure
transform::LoopPeelOp::applyToOne(scf::ForOp target,
SmallVector<Operation *> &results,
transform::TransformState &state) {
scf::ForOp result;
IRRewriter rewriter(loop->getContext());
IRRewriter rewriter(target->getContext());
// This helper returns failure when peeling does not occur (i.e. when the IR
// is not modified). This is not a failure for the op as the postcondition:
// "the loop trip count is divisible by the step"
// is valid.
LogicalResult status =
scf::peelAndCanonicalizeForLoop(rewriter, loop, result);
if (failed(status)) {
if (getFailIfAlreadyDivisible())
return reportUnknownTransformError(loop);
return loop;
}
return result;
scf::peelAndCanonicalizeForLoop(rewriter, target, result);
// TODO: Return both the peeled loop and the remainder loop.
results.push_back(failed(status) ? target : result);
return DiagnosedSilenceableFailure(success());
}
//===----------------------------------------------------------------------===//
@ -181,8 +184,10 @@ loopScheduling(scf::ForOp forOp,
}
}
FailureOr<scf::ForOp>
transform::LoopPipelineOp::applyToOne(scf::ForOp loop, TransformState &state) {
DiagnosedSilenceableFailure
transform::LoopPipelineOp::applyToOne(scf::ForOp target,
SmallVector<Operation *> &results,
transform::TransformState &state) {
scf::PipeliningOption options;
options.getScheduleFn =
[this](scf::ForOp forOp,
@ -190,26 +195,33 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp loop, TransformState &state) {
loopScheduling(forOp, schedule, getIterationInterval(),
getReadLatency());
};
scf::ForLoopPipeliningPattern pattern(options, loop->getContext());
scf::ForLoopPipeliningPattern pattern(options, target->getContext());
SimpleRewriter rewriter(getContext());
rewriter.setInsertionPoint(loop);
rewriter.setInsertionPoint(target);
FailureOr<scf::ForOp> patternResult =
pattern.returningMatchAndRewrite(loop, rewriter);
if (failed(patternResult))
return reportUnknownTransformError(loop);
return patternResult;
pattern.returningMatchAndRewrite(target, rewriter);
if (succeeded(patternResult)) {
results.push_back(*patternResult);
return DiagnosedSilenceableFailure(success());
}
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target);
}
//===----------------------------------------------------------------------===//
// LoopUnrollOp
//===----------------------------------------------------------------------===//
LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop,
TransformState &state) {
if (failed(loopUnrollByFactor(loop, getFactor())))
return reportUnknownTransformError(loop);
return success();
DiagnosedSilenceableFailure
transform::LoopUnrollOp::applyToOne(scf::ForOp target,
SmallVector<Operation *> &results,
transform::TransformState &state) {
if (failed(loopUnrollByFactor(target, getFactor()))) {
Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
diag << "op failed to unroll";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
return DiagnosedSilenceableFailure(success());
}
//===----------------------------------------------------------------------===//

View File

@ -37,7 +37,7 @@ transform.with_pdl_patterns {
// -----
func.func @interchange_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
// expected-note @below {{attempted to apply to this op}}
// expected-note @below {{when applied to this op}}
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@ -54,7 +54,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @match_generic in %arg1
// expected-error @below {{applies to linalg.generic ops}}
// expected-error @below {{transform applied to the wrong op kind}}
transform.structured.interchange %0 { iterator_interchange = [1, 0]}
}
}

View File

@ -99,7 +99,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
// expected-error @below {{expects a padding value that parses to 'f32', got "foo"}}
// expected-error @below {{expects a padding that parses to 'f32', got "foo"}}
%1 = transform.structured.pad %0 {padding_values=["foo", 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
}
}
@ -109,7 +109,7 @@ transform.with_pdl_patterns {
func.func @pad(%arg0: tensor<24x12xf32>,
%arg1: tensor<12x25xf32>,
%arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
// expected-note @below {{target op}}
// expected-note @below {{when applied to this op}}
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>
}
@ -127,7 +127,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
// expected-error @below {{failed to apply pattern to target op}}
// expected-error @below {{transform.structured.pad failed to apply}}
%1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
}
}

View File

@ -31,7 +31,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%1:3 = transform.structured.split_reduction %0
%1:4 = transform.structured.split_reduction %0
{ split_factor = 4, insert_split_dimension = 2, use_scaling_algorithm, use_alloc}
}
}

View File

@ -30,6 +30,6 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%1:3 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
%1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
}
}

View File

@ -176,7 +176,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
// expected-error @below {{applies only to isolated-from-above targets}}
// expected-error @below {{op requires isolated-from-above targets}}
%2 = transform.structured.vectorize %0
}
}

View File

@ -380,16 +380,8 @@ transform.with_pdl_patterns {
}
// -----
transform.sequence {
^bb0(%arg0: !pdl.operation):
// expected-error @below {{unexpected number of results (got 0 expected 3)}}
transform.test_wrong_number_of_results %arg0
}
// -----
func.func @foo() {
"op" () : () -> ()
// expected-note @below {{when applied to this op}}
"op" () : () -> ()
return
}
@ -406,7 +398,37 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @some in %arg1
// expected-error @below {{expected all applications of transform.test_wrong_number_of_multi_results to produce 1 results}}
// expected-error @below {{applications of transform.test_wrong_number_of_results expected to produce 3 results (actually produced 1).}}
// expected-note @below {{If you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}}
// expected-note @below {{Producing 3 null results is allowed if the use case warrants it.}}
transform.test_wrong_number_of_results %0
}
}
// -----
func.func @foo() {
"op" () : () -> ()
// expected-note @below {{when applied to this op}}
"op" () : () -> ()
return
}
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @some : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
transform.sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @some in %arg1
// expected-error @below {{applications of transform.test_wrong_number_of_multi_results expected to produce 1 results (actually produced 0)}}
// expected-note @below {{If you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}}
// expected-note @below {{Producing 1 null results is allowed if the use case warrants it.}}
transform.test_wrong_number_of_multi_results %0
}
}
@ -463,6 +485,31 @@ transform.with_pdl_patterns {
// -----
func.func @foo() {
// expected-note @below {{when applied to this op}}
"op" () : () -> ()
return
}
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @some : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
transform.sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @some in %arg1
// expected-error @below {{unexpected application of transform.test_mixed_null_and_non_null_results produces both null and non null results.}}
transform.test_mixed_null_and_non_null_results %0
}
}
// -----
// Expecting to match all operations by merging the handles that matched addi
// and subi separately.
func.func @foo(%arg0: index) {
@ -498,4 +545,3 @@ transform.with_pdl_patterns {
test_print_remark_at_operand %2, "matched"
}
}

View File

@ -226,28 +226,44 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
return DiagnosedSilenceableFailure::success();
}
FailureOr<SmallVector<Operation *>>
mlir::test::TestWrongNumberOfResultsOp::applyToOne(
Operation *, transform::TransformState &state) {
return SmallVector<Operation *>{};
DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
Operation *target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
OperationState opState(target->getLoc(), "foo");
results.push_back(OpBuilder(target).create(opState));
return DiagnosedSilenceableFailure::success();
}
FailureOr<SmallVector<Operation *>>
DiagnosedSilenceableFailure
mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
Operation *op, transform::TransformState &state) {
Operation *target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
static int count = 0;
if (count++ > 0)
return SmallVector<Operation *>{};
OperationState opState(op->getLoc(), "foo");
return SmallVector<Operation *>{OpBuilder(op).create(opState)};
if (count++ == 0) {
OperationState opState(target->getLoc(), "foo");
results.push_back(OpBuilder(target).create(opState));
}
return DiagnosedSilenceableFailure::success();
}
FailureOr<SmallVector<Operation *>>
DiagnosedSilenceableFailure
mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
Operation *op, transform::TransformState &state) {
OperationState opState(op->getLoc(), "foo");
return SmallVector<Operation *>{OpBuilder(op).create(opState),
OpBuilder(op).create(opState)};
Operation *target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
OperationState opState(target->getLoc(), "foo");
results.push_back(OpBuilder(target).create(opState));
results.push_back(OpBuilder(target).create(opState));
return DiagnosedSilenceableFailure::success();
}
DiagnosedSilenceableFailure
mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(
Operation *target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
OperationState opState(target->getLoc(), "foo");
results.push_back(nullptr);
results.push_back(OpBuilder(target).create(opState));
return DiagnosedSilenceableFailure::success();
}
namespace {

View File

@ -139,8 +139,10 @@ def TestWrongNumberOfResultsOp
let assemblyFormat = "$target attr-dict";
let cppNamespace = "::mlir::test";
let extraClassDeclaration = [{
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
::mlir::Operation *target, transform::TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation * target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}
@ -153,8 +155,10 @@ def TestWrongNumberOfMultiResultsOp
let assemblyFormat = "$target attr-dict";
let cppNamespace = "::mlir::test";
let extraClassDeclaration = [{
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
::mlir::Operation *target, transform::TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation * target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}
@ -168,8 +172,27 @@ def TestCorrectNumberOfMultiResultsOp
let assemblyFormat = "$target attr-dict";
let cppNamespace = "::mlir::test";
let extraClassDeclaration = [{
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
::mlir::Operation *target, transform::TransformState &state);
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation * target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}
def TestMixedNullAndNonNullResultsOp
: Op<Transform_Dialect, "test_mixed_null_and_non_null_results",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface]> {
let arguments = (ins PDL_Operation:$target);
let results = (outs PDL_Operation:$null,
PDL_Operation:$non_null);
let assemblyFormat = "$target attr-dict";
let cppNamespace = "::mlir::test";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation * target,
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
}