forked from OSchip/llvm-project
[mlir] Create a generic reduction detection utility
This patch introduces a generic reduction detection utility that works across different dialecs. It is mostly a generalization of the reduction detection algorithm in Affine. The reduction detection logic in Affine, Linalg and SCFToOpenMP have been replaced with this new generic utility. The utility takes some basic components of the potential reduction and returns: 1) the reduced value, and 2) a list with the combiner operations. The logic to match reductions involving multiple combiner operations disabled until we can properly test it. Reviewed By: ftynse, bondhugula, nicolasvasilache, pifon2a Differential Revision: https://reviews.llvm.org/D110303
This commit is contained in:
parent
a7cdcf25c1
commit
2a876a711d
|
@ -22,6 +22,7 @@ namespace mlir {
|
|||
class AffineExpr;
|
||||
class AffineForOp;
|
||||
class AffineMap;
|
||||
class BlockArgument;
|
||||
class MemRefType;
|
||||
class NestedPattern;
|
||||
class Operation;
|
||||
|
@ -83,6 +84,37 @@ bool isVectorizableLoopBody(AffineForOp loop, int *memRefDim,
|
|||
// TODO: extend this to check for memory-based dependence violation when we have
|
||||
// the support.
|
||||
bool isOpwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts);
|
||||
|
||||
/// Utility to match a generic reduction given a list of iteration-carried
|
||||
/// arguments, `iterCarriedArgs` and the position of the potential reduction
|
||||
/// argument within the list, `redPos`. If a reduction is matched, returns the
|
||||
/// reduced value and the topologically-sorted list of combiner operations
|
||||
/// involved in the reduction. Otherwise, returns a null value.
|
||||
///
|
||||
/// The matching algorithm relies on the following invariants, which are subject
|
||||
/// to change:
|
||||
/// 1. The first combiner operation must be a binary operation with the
|
||||
/// iteration-carried value and the reduced value as operands.
|
||||
/// 2. The iteration-carried value and combiner operations must be side
|
||||
/// effect-free, have single result and a single use.
|
||||
/// 3. Combiner operations must be immediately nested in the region op
|
||||
/// performing the reduction.
|
||||
/// 4. Reduction def-use chain must end in a terminator op that yields the
|
||||
/// next iteration/output values in the same order as the iteration-carried
|
||||
/// values in `iterCarriedArgs`.
|
||||
/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
|
||||
/// of the region op performing the reduction.
|
||||
///
|
||||
/// This utility is generic enough to detect reductions involving multiple
|
||||
/// combiner operations (disabled for now) across multiple dialects, including
|
||||
/// Linalg, Affine and SCF. For the sake of genericity, it does not return
|
||||
/// specific enum values for the combiner operations since its goal is also
|
||||
/// matching reductions without pre-defined semantics in core MLIR. It's up to
|
||||
/// each client to make sense out of the list of combiner operations. It's also
|
||||
/// up to each client to check for additional invariants on the expected
|
||||
/// reductions not covered by this generic matching.
|
||||
Value matchReduction(ArrayRef<BlockArgument> iterCarriedArgs, unsigned redPos,
|
||||
SmallVectorImpl<Operation *> &combinerOps);
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_ANALYSIS_LOOP_ANALYSIS_H
|
||||
|
|
|
@ -595,6 +595,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
|||
return 0;
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the output block arguments of the region.
|
||||
}],
|
||||
/*retTy=*/"Block::BlockArgListType",
|
||||
/*methodName=*/"getRegionOutputArgs",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
Block &entryBlock = this->getOperation()->getRegion(0).front();
|
||||
return entryBlock.getArguments().take_back(this->getNumOutputs());
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return the `opOperand` shape or an empty vector for scalars.
|
||||
|
|
|
@ -30,6 +30,7 @@ class MLIRContext;
|
|||
class Operation;
|
||||
class OperationName;
|
||||
class Type;
|
||||
class Value;
|
||||
|
||||
namespace detail {
|
||||
struct DiagnosticEngineImpl;
|
||||
|
@ -218,6 +219,9 @@ public:
|
|||
return *this << *val;
|
||||
}
|
||||
|
||||
/// Stream in a Value.
|
||||
Diagnostic &operator<<(Value val);
|
||||
|
||||
/// Stream in a range.
|
||||
template <typename T, typename ValueT = llvm::detail::ValueOfRange<T>>
|
||||
std::enable_if_t<!std::is_constructible<DiagnosticArgument, T>::value,
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Analysis/Utils.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
|
@ -22,7 +23,6 @@
|
|||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
@ -33,29 +33,6 @@ using namespace mlir;
|
|||
|
||||
using llvm::dbgs;
|
||||
|
||||
/// Returns true if `value` (transitively) depends on iteration arguments of the
|
||||
/// given `forOp`.
|
||||
static bool dependsOnIterArgs(Value value, AffineForOp forOp) {
|
||||
// Compute the backward slice of the value.
|
||||
SetVector<Operation *> slice;
|
||||
getBackwardSlice(value, &slice,
|
||||
[&](Operation *op) { return !forOp->isAncestor(op); });
|
||||
|
||||
// Check that none of the operands of the operations in the backward slice are
|
||||
// loop iteration arguments, and neither is the value itself.
|
||||
auto argRange = forOp.getRegionIterArgs();
|
||||
llvm::SmallPtrSet<Value, 8> iterArgs(argRange.begin(), argRange.end());
|
||||
if (iterArgs.contains(value))
|
||||
return true;
|
||||
|
||||
for (Operation *op : slice)
|
||||
for (Value operand : op->getOperands())
|
||||
if (iterArgs.contains(operand))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Get the value that is being reduced by `pos`-th reduction in the loop if
|
||||
/// such a reduction can be performed by affine parallel loops. This assumes
|
||||
/// floating-point operations are commutative. On success, `kind` will be the
|
||||
|
@ -63,18 +40,19 @@ static bool dependsOnIterArgs(Value value, AffineForOp forOp) {
|
|||
/// reduction is not supported, returns null.
|
||||
static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
|
||||
AtomicRMWKind &kind) {
|
||||
auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->back());
|
||||
Value yielded = yieldOp.operands()[pos];
|
||||
Operation *definition = yielded.getDefiningOp();
|
||||
if (!definition)
|
||||
return nullptr;
|
||||
if (!forOp.getRegionIterArgs()[pos].hasOneUse())
|
||||
return nullptr;
|
||||
if (!yielded.hasOneUse())
|
||||
SmallVector<Operation *> combinerOps;
|
||||
Value reducedVal =
|
||||
matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
|
||||
if (!reducedVal)
|
||||
return nullptr;
|
||||
|
||||
// Expected only one combiner operation.
|
||||
if (combinerOps.size() > 1)
|
||||
return nullptr;
|
||||
|
||||
Operation *combinerOp = combinerOps.back();
|
||||
Optional<AtomicRMWKind> maybeKind =
|
||||
TypeSwitch<Operation *, Optional<AtomicRMWKind>>(definition)
|
||||
TypeSwitch<Operation *, Optional<AtomicRMWKind>>(combinerOp)
|
||||
.Case<AddFOp>([](Operation *) { return AtomicRMWKind::addf; })
|
||||
.Case<MulFOp>([](Operation *) { return AtomicRMWKind::mulf; })
|
||||
.Case<AddIOp>([](Operation *) { return AtomicRMWKind::addi; })
|
||||
|
@ -88,14 +66,7 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
|
|||
return nullptr;
|
||||
|
||||
kind = *maybeKind;
|
||||
if (definition->getOperand(0) == forOp.getRegionIterArgs()[pos] &&
|
||||
!dependsOnIterArgs(definition->getOperand(1), forOp))
|
||||
return definition->getOperand(1);
|
||||
if (definition->getOperand(1) == forOp.getRegionIterArgs()[pos] &&
|
||||
!dependsOnIterArgs(definition->getOperand(0), forOp))
|
||||
return definition->getOperand(0);
|
||||
|
||||
return nullptr;
|
||||
return reducedVal;
|
||||
}
|
||||
|
||||
/// Returns true if `forOp' is a parallel loop. If `parallelReductions` is
|
||||
|
|
|
@ -15,11 +15,13 @@
|
|||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/AffineStructures.h"
|
||||
#include "mlir/Analysis/NestedMatcher.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include <type_traits>
|
||||
|
||||
|
@ -392,3 +394,105 @@ bool mlir::isOpwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Returns true if `value` (transitively) depends on iteration-carried values
|
||||
/// of the given `ancestorOp`.
|
||||
static bool dependsOnCarriedVals(Value value,
|
||||
ArrayRef<BlockArgument> iterCarriedArgs,
|
||||
Operation *ancestorOp) {
|
||||
// Compute the backward slice of the value.
|
||||
SetVector<Operation *> slice;
|
||||
getBackwardSlice(value, &slice,
|
||||
[&](Operation *op) { return !ancestorOp->isAncestor(op); });
|
||||
|
||||
// Check that none of the operands of the operations in the backward slice are
|
||||
// loop iteration arguments, and neither is the value itself.
|
||||
SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
|
||||
iterCarriedArgs.end());
|
||||
if (iterCarriedValSet.contains(value))
|
||||
return true;
|
||||
|
||||
for (Operation *op : slice)
|
||||
for (Value operand : op->getOperands())
|
||||
if (iterCarriedValSet.contains(operand))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Utility to match a generic reduction given a list of iteration-carried
|
||||
/// arguments, `iterCarriedArgs` and the position of the potential reduction
|
||||
/// argument within the list, `redPos`. If a reduction is matched, returns the
|
||||
/// reduced value and the topologically-sorted list of combiner operations
|
||||
/// involved in the reduction. Otherwise, returns a null value.
|
||||
///
|
||||
/// The matching algorithm relies on the following invariants, which are subject
|
||||
/// to change:
|
||||
/// 1. The first combiner operation must be a binary operation with the
|
||||
/// iteration-carried value and the reduced value as operands.
|
||||
/// 2. The iteration-carried value and combiner operations must be side
|
||||
/// effect-free, have single result and a single use.
|
||||
/// 3. Combiner operations must be immediately nested in the region op
|
||||
/// performing the reduction.
|
||||
/// 4. Reduction def-use chain must end in a terminator op that yields the
|
||||
/// next iteration/output values in the same order as the iteration-carried
|
||||
/// values in `iterCarriedArgs`.
|
||||
/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values
|
||||
/// of the region op performing the reduction.
|
||||
///
|
||||
/// This utility is generic enough to detect reductions involving multiple
|
||||
/// combiner operations (disabled for now) across multiple dialects, including
|
||||
/// Linalg, Affine and SCF. For the sake of genericity, it does not return
|
||||
/// specific enum values for the combiner operations since its goal is also
|
||||
/// matching reductions without pre-defined semantics in core MLIR. It's up to
|
||||
/// each client to make sense out of the list of combiner operations. It's also
|
||||
/// up to each client to check for additional invariants on the expected
|
||||
/// reductions not covered by this generic matching.
|
||||
Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
|
||||
unsigned redPos,
|
||||
SmallVectorImpl<Operation *> &combinerOps) {
|
||||
assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");
|
||||
|
||||
BlockArgument redCarriedVal = iterCarriedArgs[redPos];
|
||||
if (!redCarriedVal.hasOneUse())
|
||||
return nullptr;
|
||||
|
||||
// For now, the first combiner op must be a binary op.
|
||||
Operation *combinerOp = *redCarriedVal.getUsers().begin();
|
||||
if (combinerOp->getNumOperands() != 2)
|
||||
return nullptr;
|
||||
Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
|
||||
? combinerOp->getOperand(1)
|
||||
: combinerOp->getOperand(0);
|
||||
|
||||
Operation *redRegionOp =
|
||||
iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
|
||||
if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
|
||||
return nullptr;
|
||||
|
||||
// Traverse the def-use chain starting from the first combiner op until a
|
||||
// terminator is found. Gather all the combiner ops along the way in
|
||||
// topological order.
|
||||
while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
|
||||
if (!MemoryEffectOpInterface::hasNoEffect(combinerOp) ||
|
||||
combinerOp->getNumResults() != 1 || !combinerOp->hasOneUse() ||
|
||||
combinerOp->getParentOp() != redRegionOp)
|
||||
return nullptr;
|
||||
|
||||
combinerOps.push_back(combinerOp);
|
||||
combinerOp = *combinerOp->getUsers().begin();
|
||||
}
|
||||
|
||||
// Limit matching to single combiner op until we can properly test reductions
|
||||
// involving multiple combiners.
|
||||
if (combinerOps.size() != 1)
|
||||
return nullptr;
|
||||
|
||||
// Check that the yielded value is in the same position as in
|
||||
// `iterCarriedArgs`.
|
||||
Operation *terminatorOp = combinerOp;
|
||||
if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
|
||||
return nullptr;
|
||||
|
||||
return reducedVal;
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRSCFToOpenMP
|
|||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAnalysis
|
||||
MLIRLLVMIR
|
||||
MLIROpenMP
|
||||
MLIRSCF
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
|
||||
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
@ -34,10 +35,21 @@ static bool matchSimpleReduction(Block &block) {
|
|||
if (block.empty() || llvm::hasSingleElement(block) ||
|
||||
std::next(block.begin(), 2) != block.end())
|
||||
return false;
|
||||
return isa<OpTy...>(block.front()) &&
|
||||
|
||||
if (block.getNumArguments() != 2)
|
||||
return false;
|
||||
|
||||
SmallVector<Operation *, 4> combinerOps;
|
||||
Value reducedVal = matchReduction({block.getArguments()[1]},
|
||||
/*redPos=*/0, combinerOps);
|
||||
|
||||
if (!reducedVal || !reducedVal.isa<BlockArgument>() ||
|
||||
combinerOps.size() != 1)
|
||||
return false;
|
||||
|
||||
return isa<OpTy...>(combinerOps[0]) &&
|
||||
isa<scf::ReduceReturnOp>(block.back()) &&
|
||||
block.front().getOperands() == block.getArguments() &&
|
||||
block.back().getOperand(0) == block.front().getResult(0);
|
||||
block.front().getOperands() == block.getArguments();
|
||||
}
|
||||
|
||||
/// Matches a block containing a select-based min/max reduction. The types of
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/LoopAnalysis.h"
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
|
@ -110,46 +111,24 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
|
|||
return VectorType::get(st.getShape(), st.getElementType());
|
||||
}
|
||||
|
||||
/// Given an `outputOperand` of a LinalgOp, compute the intersection of the
|
||||
/// forward slice starting from `outputOperand` and the backward slice
|
||||
/// starting from the corresponding linalg.yield operand.
|
||||
/// This intersection is assumed to have a single binary operation that is
|
||||
/// the reduction operation. Multiple reduction operations would impose an
|
||||
/// Check whether `outputOperand` is a reduction with a single combiner
|
||||
/// operation. Return the combiner operation of the reduction, which is assumed
|
||||
/// to be a binary operation. Multiple reduction operations would impose an
|
||||
/// ordering between reduction dimensions and is currently unsupported in
|
||||
/// Linalg. This limitation is motivated by the fact that e.g.
|
||||
/// min(max(X)) != max(min(X))
|
||||
/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
|
||||
/// max(min(X))
|
||||
// TODO: use in LinalgOp verification, there is a circular dependency atm.
|
||||
static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) {
|
||||
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
|
||||
auto yieldOp = cast<YieldOp>(linalgOp->getRegion(0).front().getTerminator());
|
||||
unsigned yieldNum =
|
||||
unsigned outputPos =
|
||||
outputOperand->getOperandNumber() - linalgOp.getNumInputs();
|
||||
llvm::SetVector<Operation *> backwardSlice, forwardSlice;
|
||||
BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument(
|
||||
outputOperand->getOperandNumber());
|
||||
Value yieldVal = yieldOp->getOperand(yieldNum);
|
||||
getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) {
|
||||
return op->getParentOp() == linalgOp;
|
||||
});
|
||||
backwardSlice.insert(yieldVal.getDefiningOp());
|
||||
getForwardSlice(bbArg, &forwardSlice,
|
||||
[&](Operation *op) { return op->getParentOp() == linalgOp; });
|
||||
// Search for the (assumed unique) elementwiseMappable op at the intersection
|
||||
// of forward and backward slices.
|
||||
Operation *reductionOp = nullptr;
|
||||
for (Operation *op : llvm::reverse(backwardSlice)) {
|
||||
if (!forwardSlice.contains(op))
|
||||
continue;
|
||||
if (OpTrait::hasElementwiseMappableTraits(op)) {
|
||||
if (reductionOp) {
|
||||
// Reduction detection fails: found more than 1 elementwise-mappable op.
|
||||
return nullptr;
|
||||
}
|
||||
reductionOp = op;
|
||||
}
|
||||
}
|
||||
SmallVector<Operation *, 4> combinerOps;
|
||||
if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
|
||||
combinerOps.size() != 1)
|
||||
return nullptr;
|
||||
|
||||
// TODO: also assert no other subsequent ops break the reduction.
|
||||
return reductionOp;
|
||||
return combinerOps[0];
|
||||
}
|
||||
|
||||
/// If `value` of assumed VectorType has a shape different than `shape`, try to
|
||||
|
|
|
@ -131,6 +131,14 @@ Diagnostic &Diagnostic::operator<<(Operation &val) {
|
|||
return *this << os.str();
|
||||
}
|
||||
|
||||
/// Stream in a Value.
|
||||
Diagnostic &Diagnostic::operator<<(Value val) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream os(str);
|
||||
val.print(os);
|
||||
return *this << os.str();
|
||||
}
|
||||
|
||||
/// Outputs this diagnostic to a stream.
|
||||
void Diagnostic::print(raw_ostream &os) const {
|
||||
for (auto &arg : getArguments())
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
// RUN: mlir-opt %s -test-match-reduction -verify-diagnostics -split-input-file
|
||||
|
||||
// Verify that the generic reduction detection utility works on different
|
||||
// dialects.
|
||||
|
||||
// expected-remark@below {{Testing function}}
|
||||
func @linalg_red_add(%in0t : tensor<?xf32>, %out0t : tensor<1xf32>) {
|
||||
// expected-remark@below {{Reduction found in output #0!}}
|
||||
// expected-remark@below {{Reduced Value: <block argument> of type 'f32' at index: 0}}
|
||||
// expected-remark@below {{Combiner Op: %1 = addf %arg2, %arg3 : f32}}
|
||||
%red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||
affine_map<(d0) -> (0)>],
|
||||
iterator_types = ["reduction"]}
|
||||
ins(%in0t : tensor<?xf32>)
|
||||
outs(%out0t : tensor<1xf32>) {
|
||||
^bb0(%in0: f32, %out0: f32):
|
||||
%add = addf %in0, %out0 : f32
|
||||
linalg.yield %add : f32
|
||||
} -> tensor<1xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-remark@below {{Testing function}}
|
||||
func @affine_red_add(%in: memref<256x512xf32>, %out: memref<256xf32>) {
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
affine.for %i = 0 to 256 {
|
||||
// expected-remark@below {{Reduction found in output #0!}}
|
||||
// expected-remark@below {{Reduced Value: %1 = affine.load %arg0[%arg2, %arg3] : memref<256x512xf32>}}
|
||||
// expected-remark@below {{Combiner Op: %2 = addf %arg4, %1 : f32}}
|
||||
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
|
||||
%ld = affine.load %in[%i, %j] : memref<256x512xf32>
|
||||
%add = addf %red_iter, %ld : f32
|
||||
affine.yield %add : f32
|
||||
}
|
||||
affine.store %final_red, %out[%i] : memref<256xf32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// TODO: Iteration-carried values with multiple uses are not supported yet.
|
||||
// expected-remark@below {{Testing function}}
|
||||
func @linalg_red_max(%in0t: tensor<4x4xf32>, %out0t: tensor<4xf32>) {
|
||||
// expected-remark@below {{Reduction NOT found in output #0!}}
|
||||
%red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||
affine_map<(d0, d1) -> (d0)>],
|
||||
iterator_types = ["parallel", "reduction"]}
|
||||
ins(%in0t : tensor<4x4xf32>)
|
||||
outs(%out0t : tensor<4xf32>) {
|
||||
^bb0(%in0: f32, %out0: f32):
|
||||
%cmp = cmpf ogt, %in0, %out0 : f32
|
||||
%sel = select %cmp, %in0, %out0 : f32
|
||||
linalg.yield %sel : f32
|
||||
} -> tensor<4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-remark@below {{Testing function}}
|
||||
func @linalg_fused_red_add(%in0t: tensor<4x4xf32>, %out0t: tensor<4xf32>) {
|
||||
// expected-remark@below {{Reduction found in output #0!}}
|
||||
// expected-remark@below {{Reduced Value: %2 = subf %1, %arg2 : f32}}
|
||||
// expected-remark@below {{Combiner Op: %3 = addf %2, %arg3 : f32}}
|
||||
%red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||
affine_map<(d0, d1) -> (d0)>],
|
||||
iterator_types = ["parallel", "reduction"]}
|
||||
ins(%in0t : tensor<4x4xf32>)
|
||||
outs(%out0t : tensor<4xf32>) {
|
||||
^bb0(%in0: f32, %out0: f32):
|
||||
%mul = mulf %in0, %in0 : f32
|
||||
%sub = subf %mul, %in0 : f32
|
||||
%add = addf %sub, %out0 : f32
|
||||
linalg.yield %add : f32
|
||||
} -> tensor<4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-remark@below {{Testing function}}
|
||||
func @affine_no_red_rec(%in: memref<512xf32>) {
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
// %rec is the value loaded in the previous iteration.
|
||||
// expected-remark@below {{Reduction NOT found in output #0!}}
|
||||
%final_val = affine.for %j = 0 to 512 iter_args(%rec = %cst) -> (f32) {
|
||||
%ld = affine.load %in[%j] : memref<512xf32>
|
||||
%add = addf %ld, %rec : f32
|
||||
affine.yield %ld : f32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-remark@below {{Testing function}}
|
||||
func @affine_output_dep(%in: memref<512xf32>) {
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
// Reduction %red is not supported because it depends on another
|
||||
// loop-carried dependence.
|
||||
// expected-remark@below {{Reduction NOT found in output #0!}}
|
||||
// expected-remark@below {{Reduction NOT found in output #1!}}
|
||||
%final_red, %final_dep = affine.for %j = 0 to 512
|
||||
iter_args(%red = %cst, %dep = %cst) -> (f32, f32) {
|
||||
%ld = affine.load %in[%j] : memref<512xf32>
|
||||
%add = addf %dep, %red : f32
|
||||
affine.yield %add, %ld : f32, f32
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@ add_mlir_library(MLIRTestAnalysis
|
|||
TestAliasAnalysis.cpp
|
||||
TestCallGraph.cpp
|
||||
TestLiveness.cpp
|
||||
TestMatchReduction.cpp
|
||||
TestMemRefBoundCheck.cpp
|
||||
TestMemRefDependenceCheck.cpp
|
||||
TestMemRefStrideCalculation.cpp
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
//===- TestMatchReduction.cpp - Test the match reduction utility ----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains a test pass for the match reduction utility.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Analysis/LoopAnalysis.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
void printReductionResult(Operation *redRegionOp, unsigned numOutput,
|
||||
Value reducedValue,
|
||||
ArrayRef<Operation *> combinerOps) {
|
||||
if (reducedValue) {
|
||||
redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!";
|
||||
redRegionOp->emitRemark("Reduced Value: ") << reducedValue;
|
||||
for (Operation *combOp : combinerOps)
|
||||
redRegionOp->emitRemark("Combiner Op: ") << *combOp;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
redRegionOp->emitRemark("Reduction NOT found in output #")
|
||||
<< numOutput << "!";
|
||||
}
|
||||
|
||||
struct TestMatchReductionPass
|
||||
: public PassWrapper<TestMatchReductionPass, FunctionPass> {
|
||||
StringRef getArgument() const final { return "test-match-reduction"; }
|
||||
StringRef getDescription() const final {
|
||||
return "Test the match reduction utility.";
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
FuncOp func = getFunction();
|
||||
func->emitRemark("Testing function");
|
||||
|
||||
func.walk<WalkOrder::PreOrder>([](Operation *op) {
|
||||
if (isa<FuncOp>(op))
|
||||
return;
|
||||
|
||||
// Limit testing to ops with only one region.
|
||||
if (op->getNumRegions() != 1)
|
||||
return;
|
||||
|
||||
Region ®ion = op->getRegion(0);
|
||||
if (!region.hasOneBlock())
|
||||
return;
|
||||
|
||||
// We expect all the tested region ops to have 1 input by default. The
|
||||
// remaining arguments are assumed to be outputs/reductions and there must
|
||||
// be at least one.
|
||||
// TODO: Extend it to support more generic cases.
|
||||
Block ®ionEntry = region.front();
|
||||
auto args = regionEntry.getArguments();
|
||||
if (args.size() < 2)
|
||||
return;
|
||||
|
||||
auto outputs = args.drop_front();
|
||||
for (int i = 0, size = outputs.size(); i < size; ++i) {
|
||||
SmallVector<Operation *, 4> combinerOps;
|
||||
Value reducedValue = matchReduction(outputs, i, combinerOps);
|
||||
printReductionResult(op, i, reducedValue, combinerOps);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestMatchReductionPass() {
|
||||
PassRegistration<TestMatchReductionPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
|
@ -94,6 +94,7 @@ void registerTestLivenessPass();
|
|||
void registerTestLoopFusion();
|
||||
void registerTestLoopMappingPass();
|
||||
void registerTestLoopUnrollingPass();
|
||||
void registerTestMatchReductionPass();
|
||||
void registerTestMathAlgebraicSimplificationPass();
|
||||
void registerTestMathPolynomialApproximationPass();
|
||||
void registerTestMemRefDependenceCheck();
|
||||
|
@ -183,6 +184,7 @@ void registerTestPasses() {
|
|||
mlir::test::registerTestLoopFusion();
|
||||
mlir::test::registerTestLoopMappingPass();
|
||||
mlir::test::registerTestLoopUnrollingPass();
|
||||
mlir::test::registerTestMatchReductionPass();
|
||||
mlir::test::registerTestMathAlgebraicSimplificationPass();
|
||||
mlir::test::registerTestMathPolynomialApproximationPass();
|
||||
mlir::test::registerTestMemRefDependenceCheck();
|
||||
|
|
|
@ -4336,6 +4336,7 @@ cc_library(
|
|||
hdrs = ["include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":Analysis",
|
||||
":ConversionPassIncGen",
|
||||
":IR",
|
||||
":LLVMDialect",
|
||||
|
|
Loading…
Reference in New Issue