//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
#define DEBUG_TYPE "affine-analysis"
/// A utility function to check if a value is defined at the top level of
/// `region` or is an argument of `region`. A value of index type defined at the
/// top level of an `AffineScope` region is always a valid symbol for all
/// uses in that region.
static bool isTopLevelValue(Value value, Region *region) {
if (auto arg = value.dyn_cast<BlockArgument>())
return arg.getParentRegion() == region;
return value.getDefiningOp()->getParentRegion() == region;
}
/// Checks if `value`, known to be a legal affine dimension or symbol in the
/// `src` region, remains legal if the operation that uses it is inlined into
/// `dest`
/// with the given value mapping. `legalityCheck` is either `isValidDim` or
/// `isValidSymbol`, depending on the value being required to remain a valid
/// dimension or symbol.
static bool
remainsLegalAfterInline(Value value, Region *src, Region *dest,
const BlockAndValueMapping &mapping,
function_ref<bool(Value, Region *)> legalityCheck) {
// If the value is a valid dimension for any other reason than being
// a top-level value, it will remain valid: constants get inlined
// with the function, transitive affine applies also get inlined and
// will be checked themselves, etc.
if (!isTopLevelValue(value, src))
return true;
// If it's a top-level value because it's a block argument, i.e. a
// function argument, check whether the value replacing it after
// inlining is a valid dimension in the new region.
if (value.isa<BlockArgument>())
return legalityCheck(mapping.lookup(value), dest);
// If it's a top-level value because it's defined in the region,
// it can only be inlined if the defining op is a constant or a
// `dim`, which can appear anywhere and be valid, since the defining
// op won't be top-level anymore after inlining.
Attribute operandCst;
return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
value.getDefiningOp<memref::DimOp>();
}
/// Checks if all values known to be legal affine dimensions or symbols in `src`
/// remain so if their respective users are inlined into `dest`.
static bool
remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
const BlockAndValueMapping &mapping,
function_ref<bool(Value, Region *)> legalityCheck) {
return llvm::all_of(values, [&](Value v) {
return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
});
}
/// Checks if an affine read or write operation remains legal after inlining
/// from `src` to `dest`.
template <typename OpTy>
static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
const BlockAndValueMapping &mapping) {
static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
AffineWriteOpInterface>::value,
"only ops with affine read/write interface are supported");
AffineMap map = op.getAffineMap();
ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
ValueRange symbolOperands =
op.getMapOperands().take_back(map.getNumSymbols());
if (!remainsLegalAfterInline(
dimOperands, src, dest, mapping,
static_cast<bool (*)(Value, Region *)>(isValidDim)))
return false;
if (!remainsLegalAfterInline(
symbolOperands, src, dest, mapping,
static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
return false;
return true;
}
/// Checks if an affine apply operation remains legal after inlining from `src`
/// to `dest`.
template <>
bool remainsLegalAfterInline(AffineApplyOp op, Region *src, Region *dest,
const BlockAndValueMapping &mapping) {
// If it's a valid dimension, we need to check that it remains so.
if (isValidDim(op.getResult(), src))
return remainsLegalAfterInline(
op.getMapOperands(), src, dest, mapping,
static_cast<bool (*)(Value, Region *)>(isValidDim));
// Otherwise it must be a valid symbol, check that it remains so.
return remainsLegalAfterInline(
op.getMapOperands(), src, dest, mapping,
static_cast<bool (*)(Value, Region *)>(isValidSymbol));
}
//===----------------------------------------------------------------------===//
// AffineDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
/// This class defines the interface for handling inlining with affine
/// operations.
struct AffineInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
/// Returns true if the given region 'src' can be inlined into the region
/// 'dest' that is attached to an operation registered to the current dialect.
/// 'wouldBeCloned' is set if the region is cloned into its new location
/// rather than moved, indicating there may be other users.
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
// We can inline into affine loops and conditionals if this doesn't break
// affine value categorization rules.
Operation *destOp = dest->getParentOp();
if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
return false;
// Multi-block regions cannot be inlined into affine constructs, all of
// which require single-block regions.
if (!llvm::hasSingleElement(*src))
return false;
// Side-effecting operations that the affine dialect cannot understand
// should not be inlined.
Block &srcBlock = src->front();
for (Operation &op : srcBlock) {
// Ops with no side effects are fine.
if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
if (iface.hasNoEffect())
continue;
}
// Assuming the inlined region is valid, we only need to check if the
// inlining would change it.
bool remainsValid =
llvm::TypeSwitch<Operation *, bool>(&op)
.Case<AffineApplyOp, AffineReadOpInterface,
AffineWriteOpInterface>([&](auto op) {
return remainsLegalAfterInline(op, src, dest, valueMapping);
})
.Default([](Operation *) {
// Conservatively disallow inlining ops we cannot reason about.
return false;
});
if (!remainsValid)
return false;
}
return true;
}
/// Returns true if the given operation 'op', that is registered to this
/// dialect, can be inlined into the given region, false otherwise.
bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
// Always allow inlining affine operations into a region that is marked as
// affine scope, or into affine loops and conditionals. There are some edge
// cases when inlining *into* affine structures, but that is handled in the
// other 'isLegalToInline' hook above.
Operation *parentOp = region->getParentOp();
return parentOp->hasTrait<OpTrait::AffineScope>() ||
isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
}
/// Affine regions should be analyzed recursively.
bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// AffineDialect
//===----------------------------------------------------------------------===//
void AffineDialect::initialize() {
addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
>();
addInterfaces<AffineInlinerInterface>();
}
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *AffineDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<ConstantOp>(loc, type, value);
}
/// A utility function to check if a value is defined at the top level of an
/// op with trait `AffineScope`. If the value is defined in an unlinked region,
/// conservatively assume it is not top-level. A value of index type defined at
/// the top level is always a valid symbol.
bool mlir::isTopLevelValue(Value value) {
if (auto arg = value.dyn_cast<BlockArgument>()) {
// The block owning the argument may be unlinked, e.g. when the surrounding
// region has not yet been attached to an Op, at which point the parent Op
// is null.
Operation *parentOp = arg.getOwner()->getParentOp();
return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
}
// The defining Op may live in an unlinked block so its parent Op may be null.
Operation *parentOp = value.getDefiningOp()->getParentOp();
return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
}
/// Returns the closest region enclosing `op` that is held by an operation with
/// trait `AffineScope`; `nullptr` if there is no such region.
// TODO: getAffineScope should be publicly exposed for affine passes/utilities.
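// For example, for an operation nested anywhere inside a function body, this
// returns the function's body region, since FuncOp carries the `AffineScope`
// trait.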
static Region *getAffineScope(Operation *op) {
auto *curOp = op;
while (auto *parentOp = curOp->getParentOp()) {
if (parentOp->hasTrait<OpTrait::AffineScope>())
return curOp->getParentRegion();
curOp = parentOp;
}
return nullptr;
}
// A Value can be used as a dimension id iff it meets one of the following
// conditions:
// *) It is valid as a symbol.
// *) It is an induction variable.
// *) It is the result of an affine apply operation with dimension id
//    arguments.
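//
// For example (illustrative IR), in:
//   affine.for %i = 0 to 10 {
//     %v = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
//   }
// both the induction variable %i and the apply result %v are valid dimension
// identifiers.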
bool mlir::isValidDim(Value value) {
// The value must be an index type.
if (!value.getType().isIndex())
return false;
if (auto *defOp = value.getDefiningOp())
return isValidDim(value, getAffineScope(defOp));
// This value has to be a block argument for an op that has the
// `AffineScope` trait or for an affine.for or affine.parallel.
auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
isa<AffineForOp, AffineParallelOp>(parentOp));
}
// A value can be used as a dimension id iff it meets one of the following
// conditions:
// *) It is valid as a symbol.
// *) It is an induction variable.
// *) It is the result of an affine apply operation with dimension id operands.
bool mlir::isValidDim(Value value, Region *region) {
// The value must be an index type.
if (!value.getType().isIndex())
return false;
// All valid symbols are okay.
if (isValidSymbol(value, region))
return true;
auto *op = value.getDefiningOp();
if (!op) {
// This value has to be a block argument for an affine.for or an
// affine.parallel.
auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
return isa<AffineForOp, AffineParallelOp>(parentOp);
}
// Affine apply operation is ok if all of its operands are ok.
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
return applyOp.isValidDim(region);
// The dim op is okay if its operand memref/tensor is defined at the top
// level.
if (auto dimOp = dyn_cast<memref::DimOp>(op))
return isTopLevelValue(dimOp.memrefOrTensor());
return false;
}
/// Returns true if the 'index' dimension of the `memref` defined by
/// `memrefDefOp` is statically shaped or defined using a valid symbol
/// for `region`.
template <typename AnyMemRefDefOp>
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
Region *region) {
auto memRefType = memrefDefOp.getType();
// Statically shaped.
if (!memRefType.isDynamicDim(index))
return true;
// Get the position of the dimension among dynamic dimensions.
unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
region);
}
/// Returns true if the result of the dim op is a valid symbol for `region`.
static bool isDimOpValidSymbol(memref::DimOp dimOp, Region *region) {
// The dim op is okay if its operand memref is defined at the top level.
if (isTopLevelValue(dimOp.memrefOrTensor()))
return true;
// Conservatively handle remaining BlockArguments as non-valid symbols.
// E.g. scf.for iterArgs.
if (dimOp.memrefOrTensor().isa<BlockArgument>())
return false;
// The dim op is also okay if its operand memref is a view/subview whose
// corresponding size is a valid symbol.
Optional<int64_t> index = dimOp.getConstantIndex();
assert(index.hasValue() &&
"expect only `dim` operations with a constant index");
int64_t i = index.getValue();
return TypeSwitch<Operation *, bool>(dimOp.memrefOrTensor().getDefiningOp())
.Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
[&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
.Default([](Operation *) { return false; });
}
// A value can be used as a symbol (at all its use sites) iff it meets one of
// the following conditions:
// *) It is a constant.
// *) Its defining op or block arg appearance is immediately enclosed by an op
// with `AffineScope` trait.
// *) It is the result of an affine.apply operation with symbol operands.
// *) It is a result of the dim op on a memref whose corresponding size is a
// valid symbol.
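//
// For example (illustrative IR), at the top level of a function:
//   %c0 = constant 0 : index
//   %d = memref.dim %A, %c0 : memref<?xf32>
// the constant %c0 and the dim result %d (taken on a top-level memref) are
// both valid symbols.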
bool mlir::isValidSymbol(Value value) {
// The value must be an index type.
if (!value.getType().isIndex())
return false;
// Check that the value is a top level value.
if (isTopLevelValue(value))
return true;
if (auto *defOp = value.getDefiningOp())
return isValidSymbol(value, getAffineScope(defOp));
return false;
}
/// A value can be used as a symbol for `region` iff it meets one of the
/// following conditions:
/// *) It is a constant.
/// *) It is the result of an affine apply operation with symbol arguments.
/// *) It is a result of the dim op on a memref whose corresponding size is
/// a valid symbol.
/// *) It is defined at the top level of 'region' or is its argument.
/// *) It dominates `region`'s parent op.
/// If `region` is null, conservatively assume the symbol definition scope does
/// not exist and only accept the values that would be symbols regardless of
/// the surrounding region structure, i.e. the first three cases above.
bool mlir::isValidSymbol(Value value, Region *region) {
// The value must be an index type.
if (!value.getType().isIndex())
return false;
// A top-level value is a valid symbol.
if (region && ::isTopLevelValue(value, region))
return true;
auto *defOp = value.getDefiningOp();
if (!defOp) {
// A block argument that is not a top-level value is a valid symbol if it
// dominates region's parent op.
Operation *regionOp = region ? region->getParentOp() : nullptr;
if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
return isValidSymbol(value, parentOpRegion);
return false;
}
// Constant operation is ok.
Attribute operandCst;
if (matchPattern(defOp, m_Constant(&operandCst)))
return true;
// Affine apply operation is ok if all of its operands are ok.
if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
return applyOp.isValidSymbol(region);
// Dim op results could be valid symbols at any level.
if (auto dimOp = dyn_cast<memref::DimOp>(defOp))
return isDimOpValidSymbol(dimOp, region);
// Check for values dominating `region`'s parent op.
Operation *regionOp = region ? region->getParentOp() : nullptr;
if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
if (auto *parentRegion = region->getParentOp()->getParentRegion())
return isValidSymbol(value, parentRegion);
return false;
}
// Returns true if 'value' is a valid index to an affine operation (e.g.
// affine.load, affine.store, affine.dma_start, affine.dma_wait) where
// `region` provides the polyhedral symbol scope. Returns false otherwise.
static bool isValidAffineIndexOperand(Value value, Region *region) {
return isValidDim(value, region) || isValidSymbol(value, region);
}
/// Prints dimension and symbol list.
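/// For example, two dimension operands followed by one symbol operand print
/// as `(%arg0, %arg1)[%s]`.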
static void printDimAndSymbolList(Operation::operand_iterator begin,
Operation::operand_iterator end,
unsigned numDims, OpAsmPrinter &printer) {
OperandRange operands(begin, end);
printer << '(' << operands.take_front(numDims) << ')';
if (operands.size() > numDims)
printer << '[' << operands.drop_front(numDims) << ']';
}
/// Parses dimension and symbol list and returns true if parsing failed.
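/// For example, parsing `(%i, %j)[%N]` sets `numDims` to 2 and appends the
/// three resolved operands to `operands`.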
ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
SmallVectorImpl<Value> &operands,
unsigned &numDims) {
SmallVector<OpAsmParser::OperandType, 8> opInfos;
if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
return failure();
// Store number of dimensions for validation by caller.
numDims = opInfos.size();
// Parse the optional symbol operands.
auto indexTy = parser.getBuilder().getIndexType();
return failure(parser.parseOperandList(
opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
parser.resolveOperands(opInfos, indexTy, operands));
}
/// Utility function to verify that a set of operands are valid dimension and
/// symbol identifiers. The operands should be laid out such that the dimension
/// operands are before the symbol operands. This function returns failure if
/// there was an invalid operand. An operation is provided to emit any necessary
/// errors.
template <typename OpTy>
static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
unsigned numDims) {
unsigned opIt = 0;
for (auto operand : operands) {
if (opIt++ < numDims) {
if (!isValidDim(operand, getAffineScope(op)))
return op.emitOpError("operand cannot be used as a dimension id");
} else if (!isValidSymbol(operand, getAffineScope(op))) {
return op.emitOpError("operand cannot be used as a symbol");
}
}
return success();
}
//===----------------------------------------------------------------------===//
// AffineApplyOp
//===----------------------------------------------------------------------===//
AffineValueMap AffineApplyOp::getAffineValueMap() {
return AffineValueMap(getAffineMap(), getOperands(), getResult());
}
static ParseResult parseAffineApplyOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
auto indexTy = builder.getIndexType();
AffineMapAttr mapAttr;
unsigned numDims;
if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
parseDimAndSymbolList(parser, result.operands, numDims) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
auto map = mapAttr.getValue();
if (map.getNumDims() != numDims ||
numDims + map.getNumSymbols() != result.operands.size()) {
return parser.emitError(parser.getNameLoc(),
"dimension or symbol index mismatch");
}
result.types.append(map.getNumResults(), indexTy);
return success();
}
static void print(OpAsmPrinter &p, AffineApplyOp op) {
p << AffineApplyOp::getOperationName() << " " << op.mapAttr();
printDimAndSymbolList(op.operand_begin(), op.operand_end(),
op.getAffineMap().getNumDims(), p);
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"map"});
}
static LogicalResult verify(AffineApplyOp op) {
// Check input and output dimensions match.
auto map = op.map();
// Verify that operand count matches affine map dimension and symbol count.
if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols())
return op.emitOpError(
"operand count and affine map dimension and symbol count must match");
// Verify that the map only produces one result.
if (map.getNumResults() != 1)
return op.emitOpError("mapping must produce one value");
return success();
}
// The result of the affine apply operation can be used as a dimension id if all
// its operands are valid dimension ids.
bool AffineApplyOp::isValidDim() {
return llvm::all_of(getOperands(),
[](Value op) { return mlir::isValidDim(op); });
}
// The result of the affine apply operation can be used as a dimension id if all
// its operands are valid dimension ids with the parent operation of `region`
// defining the polyhedral scope for symbols.
bool AffineApplyOp::isValidDim(Region *region) {
return llvm::all_of(getOperands(),
[&](Value op) { return ::isValidDim(op, region); });
}
// The result of the affine apply operation can be used as a symbol if all its
// operands are symbols.
bool AffineApplyOp::isValidSymbol() {
return llvm::all_of(getOperands(),
[](Value op) { return mlir::isValidSymbol(op); });
}
// The result of the affine apply operation can be used as a symbol in `region`
// if all its operands are symbols in `region`.
bool AffineApplyOp::isValidSymbol(Region *region) {
return llvm::all_of(getOperands(), [&](Value operand) {
return mlir::isValidSymbol(operand, region);
});
}
OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
auto map = getAffineMap();
// Fold dims and symbols to existing values.
auto expr = map.getResult(0);
if (auto dim = expr.dyn_cast<AffineDimExpr>())
return getOperand(dim.getPosition());
if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
return getOperand(map.getNumDims() + sym.getPosition());
// Otherwise, default to folding the map.
SmallVector<Attribute, 1> result;
if (failed(map.constantFold(operands, result)))
return {};
return result[0];
}
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
/// When `dimOrSymbolPosition >= dims.size()`,
/// AffineSymbolExpr@[pos - dims.size()] is replaced.
/// Mutate `map`,`dims` and `syms` in place as follows:
/// 1. `dims` and `syms` are only appended to.
/// 2. `map` dim and symbols are gradually shifted to higher positions.
/// 3. Old `dim` and `sym` entries are replaced by nullptr.
/// This avoids the need for any bookkeeping.
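///
/// Example (illustrative): with `dims = [%a, %b]`, `syms = []`, and `%b`
/// defined by `%b = affine.apply affine_map<(d0) -> (d0 floordiv 4)>(%x)`,
/// replacing d1 substitutes `d2 floordiv 4` into `map` (the apply's d0
/// shifted by dims.size()), sets `dims[1]` to nullptr, and appends %x, giving
/// `dims = [%a, nullptr, %x]`.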
static LogicalResult replaceDimOrSym(AffineMap *map,
unsigned dimOrSymbolPosition,
SmallVectorImpl<Value> &dims,
SmallVectorImpl<Value> &syms) {
bool isDimReplacement = (dimOrSymbolPosition < dims.size());
unsigned pos = isDimReplacement ? dimOrSymbolPosition
: dimOrSymbolPosition - dims.size();
Value &v = isDimReplacement ? dims[pos] : syms[pos];
if (!v)
return failure();
auto affineApply = v.getDefiningOp<AffineApplyOp>();
if (!affineApply)
return failure();
// At this point we will perform a replacement of `v`, set the entry in `dim`
// or `sym` to nullptr immediately.
v = nullptr;
// Compute the map, dims and symbols coming from the AffineApplyOp.
AffineMap composeMap = affineApply.getAffineMap();
assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
AffineExpr composeExpr =
composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
ValueRange composeDims =
affineApply.getMapOperands().take_front(composeMap.getNumDims());
ValueRange composeSyms =
affineApply.getMapOperands().take_back(composeMap.getNumSymbols());
// Perform the replacement and append the dims and symbols where relevant.
MLIRContext *ctx = map->getContext();
AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
: getAffineSymbolExpr(pos, ctx);
*map = map->replace(toReplace, composeExpr, dims.size(), syms.size());
dims.append(composeDims.begin(), composeDims.end());
syms.append(composeSyms.begin(), composeSyms.end());
return success();
}
/// Iteratively composes into `map` any AffineApplyOp that produces one of
/// `operands`, until no such operand remains. Also canonicalizes the map and
/// operands and simplifies the map. `map` and `operands` are mutated in place.
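///
/// Example (illustrative): composing the map `(d0) -> (d0 + 1)` whose operand
/// is `%0 = affine.apply affine_map<(d0) -> (d0 * 2)>(%x)` yields the map
/// `(d0) -> (d0 * 2 + 1)` with operand %x.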
static void composeAffineMapAndOperands(AffineMap *map,
SmallVectorImpl<Value> *operands) {
if (map->getNumResults() == 0) {
canonicalizeMapAndOperands(map, operands);
*map = simplifyAffineMap(*map);
return;
}
MLIRContext *ctx = map->getContext();
SmallVector<Value, 4> dims(operands->begin(),
operands->begin() + map->getNumDims());
SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
operands->end());
// Iterate over dims and symbols coming from AffineApplyOp and replace until
// exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
// and `syms` can only increase by construction.
// The implementation uses a `while` loop to support the case of symbols
/// that may be constructed from dims; this may be overkill.
while (true) {
bool changed = false;
for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
break;
if (!changed)
break;
}
// Clear operands so we can fill them anew.
operands->clear();
// At this point we may have introduced null operands, prune them out before
// canonicalizing map and operands.
unsigned nDims = 0, nSyms = 0;
SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
dimReplacements.reserve(dims.size());
symReplacements.reserve(syms.size());
for (auto *container : {&dims, &syms}) {
bool isDim = (container == &dims);
auto &repls = isDim ? dimReplacements : symReplacements;
for (auto en : llvm::enumerate(*container)) {
Value v = en.value();
if (!v) {
assert((isDim ? !map->isFunctionOfDim(en.index())
: !map->isFunctionOfSymbol(en.index())) &&
"map is function of unexpected expr@pos");
repls.push_back(getAffineConstantExpr(0, ctx));
continue;
}
repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
: getAffineSymbolExpr(nSyms++, ctx));
operands->push_back(v);
}
}
*map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
nSyms);
// Canonicalize and simplify before returning.
canonicalizeMapAndOperands(map, operands);
*map = simplifyAffineMap(*map);
}
void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
SmallVectorImpl<Value> *operands) {
while (llvm::any_of(*operands, [](Value v) {
return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
})) {
composeAffineMapAndOperands(map, operands);
}
}
AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
AffineMap map,
ValueRange operands) {
AffineMap normalizedMap = map;
SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
assert(normalizedMap);
return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
}
AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
AffineExpr e, ValueRange values) {
return makeComposedAffineApply(
b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}).front(),
values);
}
// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols.
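//
// For example (illustrative), if the operand at dim position d1 of
// `affine_map<(d0, d1) -> (d0 + d1)>` is a valid symbol, the map becomes
// `affine_map<(d0)[s0] -> (d0 + s0)>` and that operand is moved to the
// symbol operand list.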
template <class MapOrSet>
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
SmallVectorImpl<Value> *operands) {
if (!mapOrSet || operands->empty())
return;
assert(mapOrSet->getNumInputs() == operands->size() &&
"map/set inputs must match number of operands");
auto *context = mapOrSet->getContext();
SmallVector<Value, 8> resultOperands;
resultOperands.reserve(operands->size());
SmallVector<Value, 8> remappedSymbols;
remappedSymbols.reserve(operands->size());
unsigned nextDim = 0;
unsigned nextSym = 0;
unsigned oldNumSyms = mapOrSet->getNumSymbols();
SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
if (i < mapOrSet->getNumDims()) {
if (isValidSymbol((*operands)[i])) {
// This is a valid symbol that appears as a dim, canonicalize it.
dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
remappedSymbols.push_back((*operands)[i]);
} else {
dimRemapping[i] = getAffineDimExpr(nextDim++, context);
resultOperands.push_back((*operands)[i]);
}
} else {
resultOperands.push_back((*operands)[i]);
}
}
resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
*operands = resultOperands;
*mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
oldNumSyms + nextSym);
assert(mapOrSet->getNumInputs() == operands->size() &&
"map/set inputs must match number of operands");
}
// Works for either an affine map or an integer set.
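// For example (illustrative), `affine_map<(d0, d1) -> (d0)>` with operands
// `(%x, %y)` canonicalizes to `affine_map<(d0) -> (d0)>` with operand `%x`:
// unused operands are dropped, and duplicate or constant operands are merged
// or inlined below.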
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
SmallVectorImpl<Value> *operands) {
static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
"Argument must be either of AffineMap or IntegerSet type");
if (!mapOrSet || operands->empty())
return;
assert(mapOrSet->getNumInputs() == operands->size() &&
"map/set inputs must match number of operands");
canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
// Check to see what dims are used.
llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
mapOrSet->walkExprs([&](AffineExpr expr) {
if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
usedDims[dimExpr.getPosition()] = true;
else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
usedSyms[symExpr.getPosition()] = true;
});
auto *context = mapOrSet->getContext();
SmallVector<Value, 8> resultOperands;
resultOperands.reserve(operands->size());
llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
unsigned nextDim = 0;
for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
if (usedDims[i]) {
// Remap dim positions for duplicate operands.
auto it = seenDims.find((*operands)[i]);
if (it == seenDims.end()) {
dimRemapping[i] = getAffineDimExpr(nextDim++, context);
resultOperands.push_back((*operands)[i]);
seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
} else {
dimRemapping[i] = it->second;
}
}
}
llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
unsigned nextSym = 0;
for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
if (!usedSyms[i])
continue;
// Handle constant operands (only needed for symbolic operands since
// constant operands in dimensional positions would have already been
// promoted to symbolic positions above).
IntegerAttr operandCst;
if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
m_Constant(&operandCst))) {
symRemapping[i] =
getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
continue;
}
// Remap symbol positions for duplicate operands.
auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
if (it == seenSymbols.end()) {
symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
symRemapping[i]));
} else {
symRemapping[i] = it->second;
}
}
*mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
nextDim, nextSym);
*operands = resultOperands;
}
void mlir::canonicalizeMapAndOperands(AffineMap *map,
SmallVectorImpl<Value> *operands) {
canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
}
void mlir::canonicalizeSetAndOperands(IntegerSet *set,
SmallVectorImpl<Value> *operands) {
canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
}
namespace {
/// Simplifies affine ops (apply, load, store, prefetch, min, max, and vector
/// load/store) by composing any AffineApplyOp that feeds their map operands
/// into the op's own map.
template <typename AffineOpTy>
struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
using OpRewritePattern<AffineOpTy>::OpRewritePattern;
/// Replace the affine op with another instance of it with the supplied
/// map and mapOperands.
void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
AffineMap map, ArrayRef<Value> mapOperands) const;
LogicalResult matchAndRewrite(AffineOpTy affineOp,
PatternRewriter &rewriter) const override {
static_assert(
llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
AffineVectorStoreOp, AffineVectorLoadOp>::value,
"affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
"expected");
auto map = affineOp.getAffineMap();
AffineMap oldMap = map;
auto oldOperands = affineOp.getMapOperands();
SmallVector<Value, 8> resultOperands(oldOperands);
composeAffineMapAndOperands(&map, &resultOperands);
canonicalizeMapAndOperands(&map, &resultOperands);
if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
resultOperands.begin()))
return failure();
replaceAffineOp(rewriter, affineOp, map, resultOperands);
return success();
}
};
// Specialize the template to account for the different build signatures of
// the affine load, store, prefetch, and vector ops.
template <>
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
ArrayRef<Value> mapOperands) const {
rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
mapOperands);
}
template <>
void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
ArrayRef<Value> mapOperands) const {
rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
prefetch, prefetch.memref(), map, mapOperands, prefetch.localityHint(),
prefetch.isWrite(), prefetch.isDataCache());
}
template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
ArrayRef<Value> mapOperands) const {
rewriter.replaceOpWithNewOp<AffineStoreOp>(
store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
}
template <>
void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
ArrayRef<Value> mapOperands) const {
rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
mapOperands);
}
template <>
void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
ArrayRef<Value> mapOperands) const {
rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
mapOperands);
}
// Generic version for ops that don't have extra operands.
template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
ArrayRef<Value> mapOperands) const {
rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
}
} // end anonymous namespace
void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SimplifyAffineOp<AffineApplyOp>>(context);
}
//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//
/// This is a common fold helper used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
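///
/// Example (illustrative): an operand produced by
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// is replaced with %0 directly, since the cast source is a ranked memref.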
static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto cast = operand.get().getDefiningOp<memref::CastOp>();
if (cast && operand.get() != ignore &&
!cast.getOperand().getType().isa<UnrankedMemRefType>()) {
operand.set(cast.getOperand());
folded = true;
}
}
return success(folded);
}
//===----------------------------------------------------------------------===//
// AffineDmaStartOp
//===----------------------------------------------------------------------===//
// TODO: Check that map operands are loop IVs or symbols.
void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
Value srcMemRef, AffineMap srcMap,
ValueRange srcIndices, Value destMemRef,
AffineMap dstMap, ValueRange destIndices,
Value tagMemRef, AffineMap tagMap,
ValueRange tagIndices, Value numElements,
Value stride, Value elementsPerStride) {
result.addOperands(srcMemRef);
result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap));
result.addOperands(srcIndices);
result.addOperands(destMemRef);
result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap));
result.addOperands(destIndices);
result.addOperands(tagMemRef);
result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
result.addOperands(tagIndices);
result.addOperands(numElements);
if (stride) {
result.addOperands({stride, elementsPerStride});
}
}
void AffineDmaStartOp::print(OpAsmPrinter &p) {
p << "affine.dma_start " << getSrcMemRef() << '[';
p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
p << "], " << getDstMemRef() << '[';
p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
p << "], " << getTagMemRef() << '[';
p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
p << "], " << getNumElements();
if (isStrided()) {
p << ", " << getStride();
p << ", " << getNumElementsPerStride();
}
p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
<< getTagMemRefType();
}
// Parse AffineDmaStartOp.
// Ex:
// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
// %stride, %num_elt_per_stride
// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
//
ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType srcMemRefInfo;
AffineMapAttr srcMapAttr;
SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
OpAsmParser::OperandType dstMemRefInfo;
AffineMapAttr dstMapAttr;
SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
OpAsmParser::OperandType tagMemRefInfo;
AffineMapAttr tagMapAttr;
SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
OpAsmParser::OperandType numElementsInfo;
SmallVector<OpAsmParser::OperandType, 2> strideInfo;
SmallVector<Type, 3> types;
auto indexType = parser.getBuilder().getIndexType();
// Parse and resolve the following list of operands:
// *) src memref followed by its affine map operands (in square brackets).
// *) dst memref followed by its affine map operands (in square brackets).
// *) tag memref followed by its affine map operands (in square brackets).
// *) number of elements transferred by the DMA operation.
if (parser.parseOperand(srcMemRefInfo) ||
parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
getSrcMapAttrName(), result.attributes) ||
parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
getDstMapAttrName(), result.attributes) ||
parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
getTagMapAttrName(), result.attributes) ||
parser.parseComma() || parser.parseOperand(numElementsInfo))
return failure();
// Parse optional stride and elements per stride.
if (parser.parseTrailingOperandList(strideInfo)) {
return failure();
}
if (!strideInfo.empty() && strideInfo.size() != 2) {
return parser.emitError(parser.getNameLoc(),
"expected two stride related operands");
}
bool isStrided = strideInfo.size() == 2;
if (parser.parseColonTypeList(types))
return failure();
if (types.size() != 3)
return parser.emitError(parser.getNameLoc(), "expected three types");
if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
parser.resolveOperand(numElementsInfo, indexType, result.operands))
return failure();
if (isStrided) {
if (parser.resolveOperands(strideInfo, indexType, result.operands))
return failure();
}
// Check that src/dst/tag operand counts match their map.numInputs.
if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
return parser.emitError(parser.getNameLoc(),
"memref operand count not equal to map.numInputs");
return success();
}
LogicalResult AffineDmaStartOp::verify() {
if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
return emitOpError("expected DMA source to be of memref type");
if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
return emitOpError("expected DMA destination to be of memref type");
if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
return emitOpError("expected DMA tag to be of memref type");
unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
getDstMap().getNumInputs() +
getTagMap().getNumInputs();
if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
return emitOpError("incorrect number of operands");
}
Region *scope = getAffineScope(*this);
for (auto idx : getSrcIndices()) {
if (!idx.getType().isIndex())
return emitOpError("src index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
return emitOpError("src index must be a dimension or symbol identifier");
}
for (auto idx : getDstIndices()) {
if (!idx.getType().isIndex())
return emitOpError("dst index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
return emitOpError("dst index must be a dimension or symbol identifier");
}
for (auto idx : getTagIndices()) {
if (!idx.getType().isIndex())
return emitOpError("tag index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
return emitOpError("tag index must be a dimension or symbol identifier");
}
return success();
}
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// dma_start(memrefcast) -> dma_start
return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// AffineDmaWaitOp
//===----------------------------------------------------------------------===//
// TODO: Check that map operands are loop IVs or symbols.
void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
Value tagMemRef, AffineMap tagMap,
ValueRange tagIndices, Value numElements) {
result.addOperands(tagMemRef);
result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
result.addOperands(tagIndices);
result.addOperands(numElements);
}
void AffineDmaWaitOp::print(OpAsmPrinter &p) {
p << "affine.dma_wait " << getTagMemRef() << '[';
SmallVector<Value, 2> operands(getTagIndices());
p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
p << "], ";
p.printOperand(getNumElements());
p << " : " << getTagMemRef().getType();
}
// Parse AffineDmaWaitOp.
// Ex:
// affine.dma_wait %tag[%index], %num_elements
// : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType tagMemRefInfo;
AffineMapAttr tagMapAttr;
SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
Type type;
auto indexType = parser.getBuilder().getIndexType();
OpAsmParser::OperandType numElementsInfo;
// Parse tag memref, its map operands, and dma size.
if (parser.parseOperand(tagMemRefInfo) ||
parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
getTagMapAttrName(), result.attributes) ||
parser.parseComma() || parser.parseOperand(numElementsInfo) ||
parser.parseColonType(type) ||
parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
parser.resolveOperand(numElementsInfo, indexType, result.operands))
return failure();
if (!type.isa<MemRefType>())
return parser.emitError(parser.getNameLoc(),
"expected tag to be of memref type");
if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
return parser.emitError(parser.getNameLoc(),
"tag memref operand count != to map.numInputs");
return success();
}
LogicalResult AffineDmaWaitOp::verify() {
if (!getOperand(0).getType().isa<MemRefType>())
return emitOpError("expected DMA tag to be of memref type");
Region *scope = getAffineScope(*this);
for (auto idx : getTagIndices()) {
if (!idx.getType().isIndex())
return emitOpError("index to dma_wait must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
return emitOpError("index must be a dimension or symbol identifier");
}
return success();
}
LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// dma_wait(memrefcast) -> dma_wait
return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// AffineForOp
//===----------------------------------------------------------------------===//
/// `bodyBuilder` is used to build the body of affine.for. If `iterArgs` and
/// `bodyBuilder` are empty/null, the default terminator op is included.
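///
/// A minimal usage sketch (illustrative; a caller-provided body builder must
/// create the terminator itself):
///   builder.create<AffineForOp>(
///       loc, /*lb=*/0, /*ub=*/128, /*step=*/8, /*iterArgs=*/llvm::None,
///       [](OpBuilder &b, Location loc, Value iv, ValueRange args) {
///         b.create<AffineYieldOp>(loc);
///       });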
void AffineForOp::build(OpBuilder &builder, OperationState &result,
ValueRange lbOperands, AffineMap lbMap,
ValueRange ubOperands, AffineMap ubMap, int64_t step,
ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
assert(((!lbMap && lbOperands.empty()) ||
lbOperands.size() == lbMap.getNumInputs()) &&
"lower bound operand count does not match the affine map");
assert(((!ubMap && ubOperands.empty()) ||
ubOperands.size() == ubMap.getNumInputs()) &&
"upper bound operand count does not match the affine map");
assert(step > 0 && "step has to be a positive integer constant");
for (Value val : iterArgs)
result.addTypes(val.getType());
// Add an attribute for the step.
result.addAttribute(getStepAttrName(),
builder.getIntegerAttr(builder.getIndexType(), step));
// Add the lower bound.
result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap));
result.addOperands(lbOperands);
// Add the upper bound.
result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap));
result.addOperands(ubOperands);
result.addOperands(iterArgs);
// Create a region and a block for the body. The argument of the region is
// the loop induction variable.
Region *bodyRegion = result.addRegion();
bodyRegion->push_back(new Block);
Block &bodyBlock = bodyRegion->front();
Value inductionVar = bodyBlock.addArgument(builder.getIndexType());
for (Value val : iterArgs)
bodyBlock.addArgument(val.getType());
// Create the default terminator if the builder is not provided and if the
// iteration arguments are not provided. Otherwise, leave this to the caller
// because we don't know which values to return from the loop.
if (iterArgs.empty() && !bodyBuilder) {
ensureTerminator(*bodyRegion, builder, result.location);
} else if (bodyBuilder) {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(&bodyBlock);
bodyBuilder(builder, result.location, inductionVar,
bodyBlock.getArguments().drop_front());
}
}
void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
int64_t ub, int64_t step, ValueRange iterArgs,
BodyBuilderFn bodyBuilder) {
auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
bodyBuilder);
}
static LogicalResult verify(AffineForOp op) {
// Check that the body defines a single block argument for the induction
// variable.
auto *body = op.getBody();
if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
return op.emitOpError(
"expected body to have a single index argument for the "
"induction variable");
// Verify that the bound operands are valid dimension/symbols.
// Lower bound.
if (op.getLowerBoundMap().getNumInputs() > 0)
if (failed(
verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
op.getLowerBoundMap().getNumDims())))
return failure();
// Upper bound.
if (op.getUpperBoundMap().getNumInputs() > 0)
if (failed(
verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
op.getUpperBoundMap().getNumDims())))
return failure();
unsigned opNumResults = op.getNumResults();
if (opNumResults == 0)
return success();
// If ForOp defines values, check that the number and types of the defined
// values match ForOp initial iter operands and backedge basic block
// arguments.
if (op.getNumIterOperands() != opNumResults)
return op.emitOpError(
"mismatch between the number of loop-carried values and results");
if (op.getNumRegionIterArgs() != opNumResults)
return op.emitOpError(
"mismatch between the number of basic block args and results");
return success();
}
/// Parse a for operation loop bounds.
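/// Accepted forms (illustrative):
///   %n                                       (single SSA operand)
///   10                                       (integer constant)
///   min affine_map<()[s0, s1] -> (s0, s1)>()[%a, %b]   (full form)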
static ParseResult parseBound(bool isLower, OperationState &result,
OpAsmParser &p) {
// 'min' / 'max' prefixes are generally syntactic sugar, but are required if
// the map has multiple results.
bool failedToParseMinMax =
failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
auto &builder = p.getBuilder();
auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
: AffineForOp::getUpperBoundAttrName();
// Parse ssa-id as identity map.
SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
if (p.parseOperandList(boundOpInfos))
return failure();
if (!boundOpInfos.empty()) {
// Check that only one operand was parsed.
if (boundOpInfos.size() > 1)
return p.emitError(p.getNameLoc(),
"expected only one loop bound operand");
// TODO: improve error message when SSA value is not of index type.
// Currently it is 'use of value ... expects different type than prior uses'
if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
result.operands))
return failure();
// Create an identity map using symbol id. This representation is optimized
// for storage. Analysis passes may expand it into a multi-dimensional map
// if desired.
AffineMap map = builder.getSymbolIdentityMap();
result.addAttribute(boundAttrName, AffineMapAttr::get(map));
return success();
}
// Get the attribute location.
llvm::SMLoc attrLoc = p.getCurrentLocation();
Attribute boundAttr;
if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
result.attributes))
return failure();
// Parse full form - affine map followed by dim and symbol list.
if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
unsigned currentNumOperands = result.operands.size();
unsigned numDims;
if (parseDimAndSymbolList(p, result.operands, numDims))
return failure();
auto map = affineMapAttr.getValue();
if (map.getNumDims() != numDims)
return p.emitError(
p.getNameLoc(),
"dim operand count and affine map dim count must match");
unsigned numDimAndSymbolOperands =
result.operands.size() - currentNumOperands;
if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
return p.emitError(
p.getNameLoc(),
"symbol operand count and affine map symbol count must match");
// If the map has multiple results, make sure that we parsed the min/max
// prefix.
if (map.getNumResults() > 1 && failedToParseMinMax) {
if (isLower) {
return p.emitError(attrLoc, "lower loop bound affine map with "
"multiple results requires 'max' prefix");
}
return p.emitError(attrLoc, "upper loop bound affine map with multiple "
"results requires 'min' prefix");
}
return success();
}
// Parse custom assembly form.
if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
result.attributes.pop_back();
result.addAttribute(
boundAttrName,
AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
return success();
}
return p.emitError(
p.getNameLoc(),
"expected valid affine map representation for loop bounds");
}
static ParseResult parseAffineForOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
OpAsmParser::OperandType inductionVariable;
// Parse the induction variable followed by '='.
if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
return failure();
// Parse loop bounds.
if (parseBound(/*isLower=*/true, result, parser) ||
parser.parseKeyword("to", " between bounds") ||
parseBound(/*isLower=*/false, result, parser))
return failure();
// Parse the optional loop step; we default to 1 if one is not present.
if (parser.parseOptionalKeyword("step")) {
result.addAttribute(
AffineForOp::getStepAttrName(),
builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
} else {
llvm::SMLoc stepLoc = parser.getCurrentLocation();
IntegerAttr stepAttr;
if (parser.parseAttribute(stepAttr, builder.getIndexType(),
AffineForOp::getStepAttrName().data(),
result.attributes))
return failure();
if (stepAttr.getValue().getSExtValue() < 0)
return parser.emitError(
stepLoc,
"expected step to be representable as a positive signed integer");
}
// Parse the optional initial iteration arguments.
SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
SmallVector<Type, 4> argTypes;
regionArgs.push_back(inductionVariable);
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
// Parse assignment list and results type list.
if (parser.parseAssignmentList(regionArgs, operands) ||
parser.parseArrowTypeList(result.types))
return failure();
// Resolve input operands.
for (auto operandType : llvm::zip(operands, result.types))
if (parser.resolveOperand(std::get<0>(operandType),
std::get<1>(operandType), result.operands))
return failure();
}
// Induction variable.
Type indexType = builder.getIndexType();
argTypes.push_back(indexType);
// Loop carried variables.
argTypes.append(result.types.begin(), result.types.end());
// Parse the body region.
Region *body = result.addRegion();
if (regionArgs.size() != argTypes.size())
return parser.emitError(
parser.getNameLoc(),
"mismatch between the number of loop-carried values and results");
if (parser.parseRegion(*body, regionArgs, argTypes))
return failure();
AffineForOp::ensureTerminator(*body, builder, result.location);
// Parse the optional attribute list.
return parser.parseOptionalAttrDict(result.attributes);
}
static void printBound(AffineMapAttr boundMap,
Operation::operand_range boundOperands,
const char *prefix, OpAsmPrinter &p) {
AffineMap map = boundMap.getValue();
// Check if this bound should be printed using custom assembly form.
// The decision to restrict printing custom assembly form to trivial cases
// comes from the desire to roundtrip MLIR binary -> text -> binary in a
// lossless way.
// Therefore, custom assembly form parsing and printing is only supported for
// zero-operand constant maps and single symbol operand identity maps.
if (map.getNumResults() == 1) {
AffineExpr expr = map.getResult(0);
// Print constant bound.
if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
p << constExpr.getValue();
return;
}
}
// Print bound that consists of a single SSA symbol if the map is over a
// single symbol.
if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
p.printOperand(*boundOperands.begin());
return;
}
}
} else {
// Map has multiple results. Print 'min' or 'max' prefix.
p << prefix << ' ';
}
// Print the map and its operands.
p << boundMap;
printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
map.getNumDims(), p);
}
unsigned AffineForOp::getNumIterOperands() {
AffineMap lbMap = getLowerBoundMapAttr().getValue();
AffineMap ubMap = getUpperBoundMapAttr().getValue();
return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
}
static void print(OpAsmPrinter &p, AffineForOp op) {
p << op.getOperationName() << ' ';
p.printOperand(op.getBody()->getArgument(0));
p << " = ";
printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
p << " to ";
printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);
if (op.getStep() != 1)
p << " step " << op.getStep();
bool printBlockTerminators = false;
if (op.getNumIterOperands() > 0) {
p << " iter_args(";
auto regionArgs = op.getRegionIterArgs();
auto operands = op.getIterOperands();
llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
p << std::get<0>(it) << " = " << std::get<1>(it);
});
p << ") -> (" << op.getResultTypes() << ")";
printBlockTerminators = true;
}
p.printRegion(op.region(),
/*printEntryBlockArgs=*/false, printBlockTerminators);
p.printOptionalAttrDict(op->getAttrs(),
/*elidedAttrs=*/{op.getLowerBoundAttrName(),
op.getUpperBoundAttrName(),
op.getStepAttrName()});
}
/// Fold the constant bounds of a loop.
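/// For example (illustrative), a lower bound map with two results that both
/// fold to constants, say 2 and 8, folds to the constant lower bound
/// max(2, 8) = 8; upper bounds take the min of their folded results.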
static LogicalResult foldLoopBounds(AffineForOp forOp) {
auto foldLowerOrUpperBound = [&forOp](bool lower) {
// Check to see if each of the operands is the result of a constant. If
// so, get the value. If not, ignore it.
SmallVector<Attribute, 8> operandConstants;
auto boundOperands =
lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
for (auto operand : boundOperands) {
Attribute operandCst;
matchPattern(operand, m_Constant(&operandCst));
operandConstants.push_back(operandCst);
}
AffineMap boundMap =
lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
assert(boundMap.getNumResults() >= 1 &&
"bound maps should have at least one result");
SmallVector<Attribute, 4> foldedResults;
if (failed(boundMap.constantFold(operandConstants, foldedResults)))
return failure();
// Compute the max or min as applicable over the results.
assert(!foldedResults.empty() && "bounds should have at least one result");
auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
: llvm::APIntOps::smin(maxOrMin, foldedResult);
}
lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
: forOp.setConstantUpperBound(maxOrMin.getSExtValue());
return success();
};
// Try to fold the lower bound.
bool folded = false;
if (!forOp.hasConstantLowerBound())
folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
// Try to fold the upper bound.
if (!forOp.hasConstantUpperBound())
folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
return success(folded);
}
/// Canonicalize the bounds of the given loop.
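/// For example (illustrative): a bound map ()[s0] -> (s0) whose operand is
/// produced by an affine.apply has the applied map composed into the bound,
/// and any duplicate result expressions introduced along the way are removed.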
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
auto lbMap = forOp.getLowerBoundMap();
auto ubMap = forOp.getUpperBoundMap();
auto prevLbMap = lbMap;
auto prevUbMap = ubMap;
composeAffineMapAndOperands(&lbMap, &lbOperands);
canonicalizeMapAndOperands(&lbMap, &lbOperands);
lbMap = removeDuplicateExprs(lbMap);
composeAffineMapAndOperands(&ubMap, &ubOperands);
canonicalizeMapAndOperands(&ubMap, &ubOperands);
ubMap = removeDuplicateExprs(ubMap);
// Any canonicalization change always leads to updated map(s).
if (lbMap == prevLbMap && ubMap == prevUbMap)
return failure();
if (lbMap != prevLbMap)
forOp.setLowerBound(lbOperands, lbMap);
if (ubMap != prevUbMap)
forOp.setUpperBound(ubOperands, ubMap);
return success();
}
namespace {
/// This is a pattern to fold trivially empty loops.
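/// For example (illustrative), a loop whose body contains only its implicit
/// terminator, such as
///
///   affine.for %i = 0 to 16 {
///   }
///
/// is erased.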
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
using OpRewritePattern<AffineForOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineForOp forOp,
PatternRewriter &rewriter) const override {
// Check that the body only contains a yield.
if (!llvm::hasSingleElement(*forOp.getBody()))
return failure();
rewriter.eraseOp(forOp);
return success();
}
};
} // end anonymous namespace
void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<AffineForEmptyLoopFolder>(context);
}
LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
bool folded = succeeded(foldLoopBounds(*this));
folded |= succeeded(canonicalizeLoopBounds(*this));
return success(folded);
}
AffineBound AffineForOp::getLowerBound() {
auto lbMap = getLowerBoundMap();
return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
}
AffineBound AffineForOp::getUpperBound() {
auto lbMap = getLowerBoundMap();
auto ubMap = getUpperBoundMap();
return AffineBound(AffineForOp(*this), lbMap.getNumInputs(),
lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap);
}
void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
assert(lbOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());
auto ubOperands = getUpperBoundOperands();
newOperands.append(ubOperands.begin(), ubOperands.end());
auto iterOperands = getIterOperands();
newOperands.append(iterOperands.begin(), iterOperands.end());
(*this)->setOperands(newOperands);
(*this)->setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
}
void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
assert(ubOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<Value, 4> newOperands(getLowerBoundOperands());
newOperands.append(ubOperands.begin(), ubOperands.end());
auto iterOperands = getIterOperands();
newOperands.append(iterOperands.begin(), iterOperands.end());
(*this)->setOperands(newOperands);
(*this)->setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
}
void AffineForOp::setLowerBoundMap(AffineMap map) {
auto lbMap = getLowerBoundMap();
assert(lbMap.getNumDims() == map.getNumDims() &&
lbMap.getNumSymbols() == map.getNumSymbols());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
(void)lbMap;
(*this)->setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
}
void AffineForOp::setUpperBoundMap(AffineMap map) {
auto ubMap = getUpperBoundMap();
assert(ubMap.getNumDims() == map.getNumDims() &&
ubMap.getNumSymbols() == map.getNumSymbols());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
(void)ubMap;
(*this)->setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
}
bool AffineForOp::hasConstantLowerBound() {
return getLowerBoundMap().isSingleConstant();
}
bool AffineForOp::hasConstantUpperBound() {
return getUpperBoundMap().isSingleConstant();
}
int64_t AffineForOp::getConstantLowerBound() {
return getLowerBoundMap().getSingleConstantResult();
}
int64_t AffineForOp::getConstantUpperBound() {
return getUpperBoundMap().getSingleConstantResult();
}
void AffineForOp::setConstantLowerBound(int64_t value) {
setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}
void AffineForOp::setConstantUpperBound(int64_t value) {
setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
return {operand_begin() + getLowerBoundMap().getNumInputs(),
operand_begin() + getLowerBoundMap().getNumInputs() +
getUpperBoundMap().getNumInputs()};
}
bool AffineForOp::matchingBoundOperandList() {
auto lbMap = getLowerBoundMap();
auto ubMap = getUpperBoundMap();
if (lbMap.getNumDims() != ubMap.getNumDims() ||
lbMap.getNumSymbols() != ubMap.getNumSymbols())
return false;
  unsigned numOperands = lbMap.getNumInputs();
  for (unsigned i = 0, e = numOperands; i < e; i++) {
    // Compare the operand Values.
if (getOperand(i) != getOperand(numOperands + i))
return false;
}
return true;
}
Region &AffineForOp::getLoopBody() { return region(); }
bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
return !region().isAncestor(value.getParentRegion());
}
LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
for (auto *op : ops)
op->moveBefore(*this);
return success();
}
/// Returns true if the provided value is the induction variable of an
/// AffineForOp.
bool mlir::isForInductionVar(Value val) {
return getForInductionVarOwner(val) != AffineForOp();
}
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, returns a null AffineForOp.
AffineForOp mlir::getForInductionVarOwner(Value val) {
auto ivArg = val.dyn_cast<BlockArgument>();
if (!ivArg || !ivArg.getOwner())
return AffineForOp();
auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
return dyn_cast<AffineForOp>(containingInst);
}
/// Extracts the induction variables from a list of AffineForOps and places
/// them in `ivs`.
void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
SmallVectorImpl<Value> *ivs) {
ivs->reserve(forInsts.size());
for (auto forInst : forInsts)
ivs->push_back(forInst.getInductionVar());
}
/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
/// operations.
template <typename BoundListTy, typename LoopCreatorTy>
static void buildAffineLoopNestImpl(
OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
ArrayRef<int64_t> steps,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
LoopCreatorTy &&loopCreatorFn) {
assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
// If there are no loops to be constructed, construct the body anyway.
OpBuilder::InsertionGuard guard(builder);
if (lbs.empty()) {
if (bodyBuilderFn)
bodyBuilderFn(builder, loc, ValueRange());
return;
}
// Create the loops iteratively and store the induction variables.
SmallVector<Value, 4> ivs;
ivs.reserve(lbs.size());
for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
// Callback for creating the loop body, always creates the terminator.
auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
ValueRange iterArgs) {
ivs.push_back(iv);
// In the innermost loop, call the body builder.
if (i == e - 1 && bodyBuilderFn) {
OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
}
nestedBuilder.create<AffineYieldOp>(nestedLoc);
};
// Delegate actual loop creation to the callback in order to dispatch
// between constant- and variable-bound loops.
auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
builder.setInsertionPointToStart(loop.getBody());
}
}
/// Creates an affine loop from the bounds known to be constants.
static AffineForOp
buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
int64_t ub, int64_t step,
AffineForOp::BodyBuilderFn bodyBuilderFn) {
return builder.create<AffineForOp>(loc, lb, ub, step, /*iterArgs=*/llvm::None,
bodyBuilderFn);
}
/// Creates an affine loop from the bounds that may or may not be constants.
static AffineForOp
buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
int64_t step,
AffineForOp::BodyBuilderFn bodyBuilderFn) {
auto lbConst = lb.getDefiningOp<ConstantIndexOp>();
auto ubConst = ub.getDefiningOp<ConstantIndexOp>();
if (lbConst && ubConst)
return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(),
ubConst.getValue(), step,
bodyBuilderFn);
return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
builder.getDimIdentityMap(), step,
/*iterArgs=*/llvm::None, bodyBuilderFn);
}
void mlir::buildAffineLoopNest(
OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
buildAffineLoopFromConstants);
}
void mlir::buildAffineLoopNest(
OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
ArrayRef<int64_t> steps,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
buildAffineLoopFromValues);
}
//===----------------------------------------------------------------------===//
// AffineIfOp
//===----------------------------------------------------------------------===//
namespace {
/// Remove else blocks that contain nothing other than a zero-operand yield.
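/// For example (illustrative), the trailing empty region in
///
///   affine.if #set(%i) {
///     "test.foo"() : () -> ()
///   } else {
///   }
///
/// is erased.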
struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
using OpRewritePattern<AffineIfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineIfOp ifOp,
PatternRewriter &rewriter) const override {
if (ifOp.elseRegion().empty() ||
!llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
return failure();
rewriter.startRootUpdate(ifOp);
rewriter.eraseBlock(ifOp.getElseBlock());
rewriter.finalizeRootUpdate(ifOp);
return success();
}
};
/// Removes an affine.if operation whose condition is trivially always true or
/// always false, promoting the surviving then/else block into the parent
/// block.
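/// For example (illustrative), with the trivially true set (0 == 0):
///
///   affine.if affine_set<() : (0 == 0)>() {
///     "test.foo"() : () -> ()
///   }
///
/// is replaced by the contents of its then block, while an op whose set is
/// the canonical empty integer set is erased or replaced by its else block.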
struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
using OpRewritePattern<AffineIfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineIfOp op,
PatternRewriter &rewriter) const override {
    // If the affine.if returns results, don't remove it.
    // TODO: Similar simplification can be done when affine.if returns results.
if (op.getNumResults() > 0)
return failure();
IntegerSet conditionSet = op.getIntegerSet();
Block *blockToMove;
if (conditionSet.isEmptyIntegerSet()) {
      // If the else region is absent, simply remove the affine.if operation.
if (!op.hasElse()) {
rewriter.eraseOp(op);
return success();
}
blockToMove = op.getElseBlock();
} else if (conditionSet.getNumEqualities() == 1 &&
conditionSet.getNumInequalities() == 0 &&
conditionSet.getConstraint(0) == 0) {
      // The condition (0 == 0) is trivially true.
blockToMove = op.getThenBlock();
} else {
return failure();
}
    // Remove the terminator from the block being moved; a terminator already
    // exists in the parent block.
Operation *blockTerminator = blockToMove->getTerminator();
rewriter.eraseOp(blockTerminator);
rewriter.mergeBlockBefore(blockToMove, op);
rewriter.eraseOp(op);
return success();
}
};
} // end anonymous namespace.
static LogicalResult verify(AffineIfOp op) {
// Verify that we have a condition attribute.
auto conditionAttr =
op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
if (!conditionAttr)
return op.emitOpError(
"requires an integer set attribute named 'condition'");
// Verify that there are enough operands for the condition.
IntegerSet condition = conditionAttr.getValue();
if (op.getNumOperands() != condition.getNumInputs())
return op.emitOpError(
"operand count and condition integer set dimension and "
"symbol count must match");
// Verify that the operands are valid dimension/symbols.
if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
condition.getNumDims())))
return failure();
return success();
}
static ParseResult parseAffineIfOp(OpAsmParser &parser,
OperationState &result) {
// Parse the condition attribute set.
IntegerSetAttr conditionAttr;
unsigned numDims;
if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
result.attributes) ||
parseDimAndSymbolList(parser, result.operands, numDims))
return failure();
// Verify the condition operands.
auto set = conditionAttr.getValue();
if (set.getNumDims() != numDims)
return parser.emitError(
parser.getNameLoc(),
"dim operand count and integer set dim count must match");
if (numDims + set.getNumSymbols() != result.operands.size())
return parser.emitError(
parser.getNameLoc(),
"symbol operand count and integer set symbol count must match");
if (parser.parseOptionalArrowTypeList(result.types))
return failure();
  // Create the regions for 'then' and 'else'. The latter must be created even
  // if it remains empty, for the operation to be valid.
result.regions.reserve(2);
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
// Parse the 'then' region.
if (parser.parseRegion(*thenRegion, {}, {}))
return failure();
AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
result.location);
// If we find an 'else' keyword then parse the 'else' region.
if (!parser.parseOptionalKeyword("else")) {
if (parser.parseRegion(*elseRegion, {}, {}))
return failure();
AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
result.location);
}
// Parse the optional attribute list.
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
return success();
}
static void print(OpAsmPrinter &p, AffineIfOp op) {
auto conditionAttr =
op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
p << "affine.if " << conditionAttr;
printDimAndSymbolList(op.operand_begin(), op.operand_end(),
conditionAttr.getValue().getNumDims(), p);
p.printOptionalArrowTypeList(op.getResultTypes());
p.printRegion(op.thenRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/op.getNumResults());
  // Print the 'else' region if it has any blocks.
auto &elseRegion = op.elseRegion();
if (!elseRegion.empty()) {
p << " else";
p.printRegion(elseRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/op.getNumResults());
}
// Print the attribute list.
p.printOptionalAttrDict(op->getAttrs(),
/*elidedAttrs=*/op.getConditionAttrName());
}
IntegerSet AffineIfOp::getIntegerSet() {
return (*this)
->getAttrOfType<IntegerSetAttr>(getConditionAttrName())
.getValue();
}
void AffineIfOp::setIntegerSet(IntegerSet newSet) {
(*this)->setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
}
void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
setIntegerSet(set);
(*this)->setOperands(operands);
}
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes, IntegerSet set, ValueRange args,
bool withElseRegion) {
assert(resultTypes.empty() || withElseRegion);
result.addTypes(resultTypes);
result.addOperands(args);
result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));
Region *thenRegion = result.addRegion();
thenRegion->push_back(new Block());
if (resultTypes.empty())
AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
Region *elseRegion = result.addRegion();
if (withElseRegion) {
elseRegion->push_back(new Block());
if (resultTypes.empty())
AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
}
}
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
IntegerSet set, ValueRange args, bool withElseRegion) {
AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
withElseRegion);
}
/// Canonicalize an affine if op's conditional (integer set + operands).
LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
auto set = getIntegerSet();
SmallVector<Value, 4> operands(getOperands());
canonicalizeSetAndOperands(&set, &operands);
// Any canonicalization change always leads to either a reduction in the
// number of operands or a change in the number of symbolic operands
// (promotion of dims to symbols).
if (operands.size() < getIntegerSet().getNumInputs() ||
set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
setConditional(set, operands);
return success();
}
return failure();
}
void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
}
//===----------------------------------------------------------------------===//
// AffineLoadOp
//===----------------------------------------------------------------------===//
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
AffineMap map, ValueRange operands) {
assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
result.addOperands(operands);
if (map)
result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
auto memrefType = operands[0].getType().cast<MemRefType>();
result.types.push_back(memrefType.getElementType());
}
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
Value memref, AffineMap map, ValueRange mapOperands) {
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
result.addOperands(memref);
result.addOperands(mapOperands);
auto memrefType = memref.getType().cast<MemRefType>();
result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
result.types.push_back(memrefType.getElementType());
}
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
Value memref, ValueRange indices) {
auto memrefType = memref.getType().cast<MemRefType>();
int64_t rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
auto map =
rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
build(builder, result, memref, map, indices);
}
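// Example of the custom assembly form (illustrative):
//
//   affine.load %0[%i + 3, %j] : memref<100x100xf32>
//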
static ParseResult parseAffineLoadOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
auto indexTy = builder.getIndexType();
MemRefType type;
OpAsmParser::OperandType memrefInfo;
AffineMapAttr mapAttr;
SmallVector<OpAsmParser::OperandType, 1> mapOperands;
return failure(
parser.parseOperand(memrefInfo) ||
parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
AffineLoadOp::getMapAttrName(),
result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperand(memrefInfo, type, result.operands) ||
parser.resolveOperands(mapOperands, indexTy, result.operands) ||
parser.addTypeToList(type.getElementType(), result.types));
}
static void print(OpAsmPrinter &p, AffineLoadOp op) {
p << "affine.load " << op.getMemRef() << '[';
if (AffineMapAttr mapAttr =
op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
p << ']';
p.printOptionalAttrDict(op->getAttrs(),
/*elidedAttrs=*/{op.getMapAttrName()});
p << " : " << op.getMemRefType();
}
/// Verify common indexing invariants of affine.load, affine.store,
/// affine.vector_load and affine.vector_store.
static LogicalResult
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
Operation::operand_range mapOperands,
MemRefType memrefType, unsigned numIndexOperands) {
if (mapAttr) {
AffineMap map = mapAttr.getValue();
if (map.getNumResults() != memrefType.getRank())
return op->emitOpError("affine map num results must equal memref rank");
if (map.getNumInputs() != numIndexOperands)
return op->emitOpError("expects as many subscripts as affine map inputs");
} else {
if (memrefType.getRank() != numIndexOperands)
return op->emitOpError(
"expects the number of subscripts to be equal to memref rank");
}
Region *scope = getAffineScope(op);
for (auto idx : mapOperands) {
if (!idx.getType().isIndex())
return op->emitOpError("index to load must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
return op->emitOpError("index must be a dimension or symbol identifier");
}
return success();
}
LogicalResult verify(AffineLoadOp op) {
auto memrefType = op.getMemRefType();
if (op.getType() != memrefType.getElementType())
return op.emitOpError("result type must match element type of memref");
if (failed(verifyMemoryOpIndexing(
op.getOperation(),
op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
op.getMapOperands(), memrefType,
/*numIndexOperands=*/op.getNumOperands() - 1)))
return failure();
return success();
}
void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SimplifyAffineOp<AffineLoadOp>>(context);
}
OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
  // load(memrefcast) -> load
if (succeeded(foldMemRefCast(*this)))
return getResult();
return OpFoldResult();
}
//===----------------------------------------------------------------------===//
// AffineStoreOp
//===----------------------------------------------------------------------===//
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
Value valueToStore, Value memref, AffineMap map,
ValueRange mapOperands) {
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
result.addOperands(valueToStore);
result.addOperands(memref);
result.addOperands(mapOperands);
result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
}
// Use identity map.
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
Value valueToStore, Value memref,
ValueRange indices) {
auto memrefType = memref.getType().cast<MemRefType>();
int64_t rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
auto map =
rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
build(builder, result, valueToStore, memref, map, indices);
}
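// Example of the custom assembly form (illustrative):
//
//   affine.store %v, %0[%i + 3, %j] : memref<100x100xf32>
//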
static ParseResult parseAffineStoreOp(OpAsmParser &parser,
OperationState &result) {
auto indexTy = parser.getBuilder().getIndexType();
MemRefType type;
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo;
AffineMapAttr mapAttr;
SmallVector<OpAsmParser::OperandType, 1> mapOperands;
return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
parser.parseOperand(memrefInfo) ||
parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
AffineStoreOp::getMapAttrName(),
result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperand(storeValueInfo, type.getElementType(),
result.operands) ||
parser.resolveOperand(memrefInfo, type, result.operands) ||
parser.resolveOperands(mapOperands, indexTy, result.operands));
}
static void print(OpAsmPrinter &p, AffineStoreOp op) {
p << "affine.store " << op.getValueToStore();
p << ", " << op.getMemRef() << '[';
if (AffineMapAttr mapAttr =
op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
p << ']';
p.printOptionalAttrDict(op->getAttrs(),
/*elidedAttrs=*/{op.getMapAttrName()});
p << " : " << op.getMemRefType();
}
LogicalResult verify(AffineStoreOp op) {
// First operand must have same type as memref element type.
auto memrefType = op.getMemRefType();
if (op.getValueToStore().getType() != memrefType.getElementType())
    return op.emitOpError(
        "first operand must have same type as memref element type");
if (failed(verifyMemoryOpIndexing(
op.getOperation(),
op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
op.getMapOperands(), memrefType,
/*numIndexOperands=*/op.getNumOperands() - 2)))
return failure();
return success();
}
void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SimplifyAffineOp<AffineStoreOp>>(context);
}
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
  // store(memrefcast) -> store
return foldMemRefCast(*this, getValueToStore());
}
//===----------------------------------------------------------------------===//
// AffineMinMaxOpBase
//===----------------------------------------------------------------------===//
template <typename T>
static LogicalResult verifyAffineMinMaxOp(T op) {
// Verify that operand count matches affine map dimension and symbol count.
if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
return op.emitOpError(
"operand count and affine map dimension and symbol count must match");
return success();
}
template <typename T>
static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
p << op.getOperationName() << ' ' << op->getAttr(T::getMapAttrName());
auto operands = op.getOperands();
unsigned numDims = op.map().getNumDims();
p << '(' << operands.take_front(numDims) << ')';
if (operands.size() != numDims)
p << '[' << operands.drop_front(numDims) << ']';
p.printOptionalAttrDict(op->getAttrs(),
/*elidedAttrs=*/{T::getMapAttrName()});
}
template <typename T>
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
auto indexType = builder.getIndexType();
SmallVector<OpAsmParser::OperandType, 8> dimInfos;
SmallVector<OpAsmParser::OperandType, 8> symInfos;
AffineMapAttr mapAttr;
return failure(
parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) ||
parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
parser.parseOperandList(symInfos,
OpAsmParser::Delimiter::OptionalSquare) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.resolveOperands(dimInfos, indexType, result.operands) ||
parser.resolveOperands(symInfos, indexType, result.operands) ||
parser.addTypeToList(indexType, result.types));
}
/// Fold an affine min or max operation with the given operands. The operand
/// list may contain nulls, which are interpreted as the operand not being a
/// constant.
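/// For example (illustrative): an affine.min with map (d0) -> (d0 + 4, 16)
/// whose operand is known to be the constant 8 folds its results to (12, 16)
/// and therefore folds to the constant 12.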
template <typename T>
static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
"expected affine min or max op");
// Fold the affine map.
// TODO: Fold more cases:
// min(some_affine, some_affine + constant, ...), etc.
SmallVector<int64_t, 2> results;
auto foldedMap = op.map().partialConstantFold(operands, &results);
// If some of the map results are not constant, try changing the map in-place.
if (results.empty()) {
// If the map is the same, report that folding did not happen.
if (foldedMap == op.map())
return {};
op->setAttr("map", AffineMapAttr::get(foldedMap));
return op.getResult();
}
// Otherwise, completely fold the op into a constant.
auto resultIt = std::is_same<T, AffineMinOp>::value
? std::min_element(results.begin(), results.end())
: std::max_element(results.begin(), results.end());
if (resultIt == results.end())
return {};
return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
}
/// Remove duplicated expressions in affine min/max ops.
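/// For example (illustrative):
///
///   %0 = affine.min affine_map<(d0) -> (d0 + 4, d0 + 4, 16)>(%i)
///
/// is rewritten to use affine_map<(d0) -> (d0 + 4, 16)> instead.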
template <typename T>
struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
using OpRewritePattern<T>::OpRewritePattern;
LogicalResult matchAndRewrite(T affineOp,
PatternRewriter &rewriter) const override {
AffineMap oldMap = affineOp.getAffineMap();
SmallVector<AffineExpr, 4> newExprs;
for (AffineExpr expr : oldMap.getResults()) {
// This is a linear scan over newExprs, but it should be fine given that
// we typically just have a few expressions per op.
if (!llvm::is_contained(newExprs, expr))
newExprs.push_back(expr);
}
if (newExprs.size() == oldMap.getNumResults())
return failure();
auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
newExprs, rewriter.getContext());
rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
return success();
}
};
/// Merge an affine min/max op into its consumer if the consumer is also an
/// affine min/max op of the same kind.
///
/// This pattern requires that the producer affine min/max op be bound to a
/// dimension/symbol that is used as a standalone expression in the consumer
/// affine op's map.
///
/// For example, a pattern like the following:
///
/// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
/// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
///
/// Can be turned into:
///
/// %1 = affine.min affine_map<
/// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
template <typename T>
struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
using OpRewritePattern<T>::OpRewritePattern;
LogicalResult matchAndRewrite(T affineOp,
PatternRewriter &rewriter) const override {
AffineMap oldMap = affineOp.getAffineMap();
ValueRange dimOperands =
affineOp.getMapOperands().take_front(oldMap.getNumDims());
ValueRange symOperands =
affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
auto newDimOperands = llvm::to_vector<8>(dimOperands);
auto newSymOperands = llvm::to_vector<8>(symOperands);
SmallVector<AffineExpr, 4> newExprs;
SmallVector<T, 4> producerOps;
    // Go over each expression to see whether it is a single dimension/symbol
    // whose corresponding operand is the result of another affine min/max op.
    // If so, it can be merged into this affine op.
for (AffineExpr expr : oldMap.getResults()) {
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
Value symValue = symOperands[symExpr.getPosition()];
if (auto producerOp = symValue.getDefiningOp<T>()) {
producerOps.push_back(producerOp);
continue;
}
} else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
Value dimValue = dimOperands[dimExpr.getPosition()];
if (auto producerOp = dimValue.getDefiningOp<T>()) {
producerOps.push_back(producerOp);
continue;
}
}
      // In the cases above, the expression is dropped in favor of the merged
      // expressions from the producer affine min/max op. Otherwise, keep the
      // existing expression.
newExprs.push_back(expr);
}
if (producerOps.empty())
return failure();
unsigned numUsedDims = oldMap.getNumDims();
unsigned numUsedSyms = oldMap.getNumSymbols();
// Now go over all producer affine ops and merge their expressions.
for (T producerOp : producerOps) {
AffineMap producerMap = producerOp.getAffineMap();
unsigned numProducerDims = producerMap.getNumDims();
unsigned numProducerSyms = producerMap.getNumSymbols();
// Collect all dimension/symbol values.
ValueRange dimValues =
producerOp.getMapOperands().take_front(numProducerDims);
ValueRange symValues =
producerOp.getMapOperands().take_back(numProducerSyms);
newDimOperands.append(dimValues.begin(), dimValues.end());
newSymOperands.append(symValues.begin(), symValues.end());
      // Shift the producer's expressions so that its dims and symbols do not
      // overlap with those already in use.
for (AffineExpr expr : producerMap.getResults()) {
newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
.shiftSymbols(numProducerSyms, numUsedSyms));
}
numUsedDims += numProducerDims;
numUsedSyms += numProducerSyms;
}
auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
rewriter.getContext());
auto newOperands =
llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
return success();
}
};
//===----------------------------------------------------------------------===//
// AffineMinOp
//===----------------------------------------------------------------------===//
//
// %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
//
OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
return foldMinMaxOp(*this, operands);
}
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<DeduplicateAffineMinMaxExpressions<AffineMinOp>,
MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>>(
context);
}
//===----------------------------------------------------------------------===//
// AffineMaxOp
//===----------------------------------------------------------------------===//
//
// %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
//
OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
return foldMinMaxOp(*this, operands);
}
void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>>(
context);
}
//===----------------------------------------------------------------------===//
// AffinePrefetchOp
//===----------------------------------------------------------------------===//
//
// affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
//
static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
auto indexTy = builder.getIndexType();
MemRefType type;
OpAsmParser::OperandType memrefInfo;
IntegerAttr hintInfo;
auto i32Type = parser.getBuilder().getIntegerType(32);
StringRef readOrWrite, cacheType;
AffineMapAttr mapAttr;
SmallVector<OpAsmParser::OperandType, 1> mapOperands;
if (parser.parseOperand(memrefInfo) ||
parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
AffinePrefetchOp::getMapAttrName(),
result.attributes) ||
parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
parser.parseComma() || parser.parseKeyword("locality") ||
parser.parseLess() ||
parser.parseAttribute(hintInfo, i32Type,
AffinePrefetchOp::getLocalityHintAttrName(),
result.attributes) ||
parser.parseGreater() || parser.parseComma() ||
parser.parseKeyword(&cacheType) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperand(memrefInfo, type, result.operands) ||
parser.resolveOperands(mapOperands, indexTy, result.operands))
return failure();
if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
return parser.emitError(parser.getNameLoc(),
"rw specifier has to be 'read' or 'write'");
result.addAttribute(
AffinePrefetchOp::getIsWriteAttrName(),
parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
if (!cacheType.equals("data") && !cacheType.equals("instr"))
return parser.emitError(parser.getNameLoc(),
"cache type has to be 'data' or 'instr'");
result.addAttribute(
AffinePrefetchOp::getIsDataCacheAttrName(),
parser.getBuilder().getBoolAttr(cacheType.equals("data")));
return success();
}
static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
AffineMapAttr mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
if (mapAttr) {
SmallVector<Value, 2> operands(op.getMapOperands());
p.printAffineMapOfSSAIds(mapAttr, operands);
}
p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
<< "locality<" << op.localityHint() << ">, "
<< (op.isDataCache() ? "data" : "instr");
p.printOptionalAttrDict(
op->getAttrs(),
/*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(),
op.getIsDataCacheAttrName(), op.getIsWriteAttrName()});
p << " : " << op.getMemRefType();
}
static LogicalResult verify(AffinePrefetchOp op) {
auto mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
if (mapAttr) {
AffineMap map = mapAttr.getValue();
if (map.getNumResults() != op.getMemRefType().getRank())
return op.emitOpError("affine.prefetch affine map num results must equal"
" memref rank");
if (map.getNumInputs() + 1 != op.getNumOperands())
return op.emitOpError("too few operands");
} else {
if (op.getNumOperands() != 1)
return op.emitOpError("too few operands");
}
Region *scope = getAffineScope(op);
for (auto idx : op.getMapOperands()) {
if (!isValidAffineIndexOperand(idx, scope))
return op.emitOpError("index must be a dimension or symbol identifier");
}
return success();
}
void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// prefetch(memrefcast) -> prefetch
results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
}
LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
// AffineParallelOp
//===----------------------------------------------------------------------===//
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes,
ArrayRef<AtomicRMWKind> reductions,
ArrayRef<int64_t> ranges) {
SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
return builder.getConstantAffineMap(value);
}));
SmallVector<int64_t> steps(ranges.size(), 1);
build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
/*ubArgs=*/{}, steps);
}
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes,
ArrayRef<AtomicRMWKind> reductions,
ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
ArrayRef<int64_t> steps) {
assert(llvm::all_of(lbMaps,
[lbMaps](AffineMap m) {
return m.getNumDims() == lbMaps[0].getNumDims() &&
m.getNumSymbols() == lbMaps[0].getNumSymbols();
}) &&
"expected all lower bounds maps to have the same number of dimensions "
"and symbols");
assert(llvm::all_of(ubMaps,
[ubMaps](AffineMap m) {
return m.getNumDims() == ubMaps[0].getNumDims() &&
m.getNumSymbols() == ubMaps[0].getNumSymbols();
}) &&
"expected all upper bounds maps to have the same number of dimensions "
"and symbols");
assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
"expected lower bound maps to have as many inputs as lower bound "
"operands");
assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
"expected upper bound maps to have as many inputs as upper bound "
"operands");
result.addTypes(resultTypes);
// Convert the reductions to integer attributes.
SmallVector<Attribute, 4> reductionAttrs;
for (AtomicRMWKind reduction : reductions)
reductionAttrs.push_back(
builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
result.addAttribute(getReductionsAttrName(),
builder.getArrayAttr(reductionAttrs));
  // Concatenates maps defined in the same input space (same dimensions and
  // symbols); returns an empty map if the list of maps is empty.
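  // For example (illustrative), concatenating (d0) -> (d0, 7) and
  // (d0) -> (d0 + 2) yields (d0) -> (d0, 7, d0 + 2) with groups [2, 1].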
auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
SmallVectorImpl<int32_t> &groups) {
if (maps.empty())
return AffineMap::get(builder.getContext());
SmallVector<AffineExpr> exprs;
groups.reserve(groups.size() + maps.size());
exprs.reserve(maps.size());
for (AffineMap m : maps) {
llvm::append_range(exprs, m.getResults());
groups.push_back(m.getNumResults());
}
return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
maps[0].getContext());
};
// Set up the bounds.
SmallVector<int32_t> lbGroups, ubGroups;
AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap));
result.addAttribute(getLowerBoundsGroupsAttrName(),
builder.getI32TensorAttr(lbGroups));
result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap));
result.addAttribute(getUpperBoundsGroupsAttrName(),
builder.getI32TensorAttr(ubGroups));
result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps));
result.addOperands(lbArgs);
result.addOperands(ubArgs);
// Create a region and a block for the body.
auto *bodyRegion = result.addRegion();
auto *body = new Block();
// Add all the block arguments.
for (unsigned i = 0, e = steps.size(); i < e; ++i)
body->addArgument(IndexType::get(builder.getContext()));
bodyRegion->push_back(body);
if (resultTypes.empty())
ensureTerminator(*bodyRegion, builder, result.location);
}
Region &AffineParallelOp::getLoopBody() { return region(); }
bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) {
return !region().isAncestor(value.getParentRegion());
}
LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
for (Operation *op : ops)
op->moveBefore(*this);
return success();
}
unsigned AffineParallelOp::getNumDims() { return steps().size(); }
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
return getOperands().take_front(lowerBoundsMap().getNumInputs());
}
AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
return getOperands().drop_front(lowerBoundsMap().getNumInputs());
}
AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
unsigned start = 0;
for (unsigned i = 0; i < pos; ++i)
start += lowerBoundsGroups().getValue<int32_t>(i);
return lowerBoundsMap().getSliceMap(
start, lowerBoundsGroups().getValue<int32_t>(pos));
}
AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
unsigned start = 0;
for (unsigned i = 0; i < pos; ++i)
start += upperBoundsGroups().getValue<int32_t>(i);
return upperBoundsMap().getSliceMap(
start, upperBoundsGroups().getValue<int32_t>(pos));
}
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands());
}
AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands());
}
Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
if (hasMinMaxBounds())
return llvm::None;
// Try to convert all the ranges to constant expressions.
SmallVector<int64_t, 8> out;
AffineValueMap rangesValueMap;
AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
&rangesValueMap);
out.reserve(rangesValueMap.getNumResults());
for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
auto expr = rangesValueMap.getResult(i);
auto cst = expr.dyn_cast<AffineConstantExpr>();
if (!cst)
return llvm::None;
out.push_back(cst.getValue());
}
return out;
}
Block *AffineParallelOp::getBody() { return &region().front(); }
OpBuilder AffineParallelOp::getBodyBuilder() {
return OpBuilder(getBody(), std::prev(getBody()->end()));
}
void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
assert(lbOperands.size() == map.getNumInputs() &&
"operands to map must match number of inputs");
auto ubOperands = getUpperBoundsOperands();
SmallVector<Value, 4> newOperands(lbOperands);
newOperands.append(ubOperands.begin(), ubOperands.end());
(*this)->setOperands(newOperands);
lowerBoundsMapAttr(AffineMapAttr::get(map));
}
void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
assert(ubOperands.size() == map.getNumInputs() &&
"operands to map must match number of inputs");
SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
newOperands.append(ubOperands.begin(), ubOperands.end());
(*this)->setOperands(newOperands);
upperBoundsMapAttr(AffineMapAttr::get(map));
}
void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
AffineMap lbMap = lowerBoundsMap();
assert(lbMap.getNumDims() == map.getNumDims() &&
lbMap.getNumSymbols() == map.getNumSymbols());
(void)lbMap;
lowerBoundsMapAttr(AffineMapAttr::get(map));
}
void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
AffineMap ubMap = upperBoundsMap();
assert(ubMap.getNumDims() == map.getNumDims() &&
ubMap.getNumSymbols() == map.getNumSymbols());
(void)ubMap;
upperBoundsMapAttr(AffineMapAttr::get(map));
}
SmallVector<int64_t, 8> AffineParallelOp::getSteps() {
SmallVector<int64_t, 8> result;
for (Attribute attr : steps()) {
result.push_back(attr.cast<IntegerAttr>().getInt());
}
return result;
}
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
static LogicalResult verify(AffineParallelOp op) {
auto numDims = op.getNumDims();
if (op.lowerBoundsGroups().getNumElements() != numDims ||
op.upperBoundsGroups().getNumElements() != numDims ||
op.steps().size() != numDims ||
op.getBody()->getNumArguments() != numDims) {
return op.emitOpError()
<< "the number of region arguments ("
<< op.getBody()->getNumArguments()
<< ") and the number of map groups for lower ("
<< op.lowerBoundsGroups().getNumElements() << ") and upper bound ("
<< op.upperBoundsGroups().getNumElements()
<< "), and the number of steps (" << op.steps().size()
<< ") must all match";
}
unsigned expectedNumLBResults = 0;
for (APInt v : op.lowerBoundsGroups())
expectedNumLBResults += v.getZExtValue();
if (expectedNumLBResults != op.lowerBoundsMap().getNumResults())
return op.emitOpError() << "expected lower bounds map to have "
<< expectedNumLBResults << " results";
unsigned expectedNumUBResults = 0;
for (APInt v : op.upperBoundsGroups())
expectedNumUBResults += v.getZExtValue();
if (expectedNumUBResults != op.upperBoundsMap().getNumResults())
return op.emitOpError() << "expected upper bounds map to have "
<< expectedNumUBResults << " results";
if (op.reductions().size() != op.getNumResults())
return op.emitOpError("a reduction must be specified for each output");
  // Verify reduction ops are all valid.
for (Attribute attr : op.reductions()) {
auto intAttr = attr.dyn_cast<IntegerAttr>();
if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
return op.emitOpError("invalid reduction attribute");
}
// Verify that the bound operands are valid dimension/symbols.
  // Lower bounds.
if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
op.lowerBoundsMap().getNumDims())))
return failure();
  // Upper bounds.
if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(),
op.upperBoundsMap().getNumDims())))
return failure();
return success();
}
LogicalResult AffineValueMap::canonicalize() {
SmallVector<Value, 4> newOperands{operands};
auto newMap = getAffineMap();
composeAffineMapAndOperands(&newMap, &newOperands);
if (newMap == getAffineMap() && newOperands == operands)
return failure();
reset(newMap, newOperands);
return success();
}
/// Canonicalize the bounds of the given loop.
static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
AffineValueMap lb = op.getLowerBoundsValueMap();
bool lbCanonicalized = succeeded(lb.canonicalize());
AffineValueMap ub = op.getUpperBoundsValueMap();
bool ubCanonicalized = succeeded(ub.canonicalize());
// Any canonicalization change always leads to updated map(s).
if (!lbCanonicalized && !ubCanonicalized)
return failure();
if (lbCanonicalized)
op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
if (ubCanonicalized)
op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
return success();
}
LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return canonicalizeLoopBounds(*this);
}
/// Prints a lower(upper) bound of an affine parallel loop with max(min)
/// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
/// identifies which of those expressions form max/min groups. `operands`
/// are the SSA values of dimensions and symbols and `keyword` is either "min"
/// or "max".
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
DenseIntElementsAttr group, ValueRange operands,
StringRef keyword) {
AffineMap map = mapAttr.getValue();
unsigned numDims = map.getNumDims();
ValueRange dimOperands = operands.take_front(numDims);
ValueRange symOperands = operands.drop_front(numDims);
unsigned start = 0;
for (llvm::APInt groupSize : group) {
if (start != 0)
p << ", ";
unsigned size = groupSize.getZExtValue();
if (size == 1) {
p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
++start;
} else {
p << keyword << '(';
AffineMap submap = map.getSliceMap(start, size);
p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
p << ')';
start += size;
}
}
}
static void print(OpAsmPrinter &p, AffineParallelOp op) {
p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
printMinMaxBound(p, op.lowerBoundsMapAttr(), op.lowerBoundsGroupsAttr(),
op.getLowerBoundsOperands(), "max");
p << ") to (";
printMinMaxBound(p, op.upperBoundsMapAttr(), op.upperBoundsGroupsAttr(),
op.getUpperBoundsOperands(), "min");
p << ')';
SmallVector<int64_t, 8> steps = op.getSteps();
bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
if (!elideSteps) {
p << " step (";
llvm::interleaveComma(steps, p);
p << ')';
}
if (op.getNumResults()) {
p << " reduce (";
llvm::interleaveComma(op.reductions(), p, [&](auto &attr) {
AtomicRMWKind sym =
*symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt());
p << "\"" << stringifyAtomicRMWKind(sym) << "\"";
});
p << ") -> (" << op.getResultTypes() << ")";
}
p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/op.getNumResults());
p.printOptionalAttrDict(
op->getAttrs(),
/*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(),
AffineParallelOp::getLowerBoundsMapAttrName(),
AffineParallelOp::getLowerBoundsGroupsAttrName(),
AffineParallelOp::getUpperBoundsMapAttrName(),
AffineParallelOp::getUpperBoundsGroupsAttrName(),
AffineParallelOp::getStepsAttrName()});
}
/// Given a list of lists of parsed operands, populates `uniqueOperands` with
/// unique operands. Also populates `replacements` with affine expressions of
/// `kind` that can be used to update affine maps previously accepting
/// `operands` so that they accept `uniqueOperands` instead.
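/// For example (illustrative): operand lists {%a, %b} and {%b, %c} with
/// `kind` == DimId produce `uniqueOperands` = {%a, %b, %c} and `replacements`
/// d0, d1, d1, d2.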
static void deduplicateAndResolveOperands(
OpAsmParser &parser,
ArrayRef<SmallVector<OpAsmParser::OperandType>> operands,
SmallVectorImpl<Value> &uniqueOperands,
SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
"expected operands to be dim or symbol expression");
Type indexType = parser.getBuilder().getIndexType();
for (const auto &list : operands) {
SmallVector<Value> valueOperands;
parser.resolveOperands(list, indexType, valueOperands);
for (Value operand : valueOperands) {
unsigned pos = std::distance(uniqueOperands.begin(),
llvm::find(uniqueOperands, operand));
if (pos == uniqueOperands.size())
uniqueOperands.push_back(operand);
replacements.push_back(
kind == AffineExprKind::DimId
? getAffineDimExpr(pos, parser.getBuilder().getContext())
: getAffineSymbolExpr(pos, parser.getBuilder().getContext()));
}
}
}
namespace {
enum class MinMaxKind { Min, Max };
} // namespace
/// Parses an affine map that can contain a min/max for groups of its results,
/// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
/// `result` attributes with the map (flat list of expressions) and the grouping
/// (list of integers that specify how many expressions to put into each
/// min/max) attributes. Deduplicates repeated operands.
///
/// parallel-bound ::= `(` parallel-group-list `)`
/// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
/// parallel-group ::= simple-group | min-max-group
/// simple-group ::= expr-of-ssa-ids
/// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
/// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-ids-list)?
///
/// Examples:
/// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
/// (%0, max(%1 - 2 * %2))
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
OperationState &result,
MinMaxKind kind) {
constexpr llvm::StringLiteral tmpAttrName = "__pseudo_bound_map";
StringRef mapName = kind == MinMaxKind::Min
? AffineParallelOp::getUpperBoundsMapAttrName()
: AffineParallelOp::getLowerBoundsMapAttrName();
StringRef groupsName = kind == MinMaxKind::Min
? AffineParallelOp::getUpperBoundsGroupsAttrName()
: AffineParallelOp::getLowerBoundsGroupsAttrName();
if (failed(parser.parseLParen()))
return failure();
if (succeeded(parser.parseOptionalRParen())) {
result.addAttribute(
mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
return success();
}
SmallVector<AffineExpr> flatExprs;
SmallVector<SmallVector<OpAsmParser::OperandType>> flatDimOperands;
SmallVector<SmallVector<OpAsmParser::OperandType>> flatSymOperands;
SmallVector<int32_t> numMapsPerGroup;
SmallVector<OpAsmParser::OperandType> mapOperands;
do {
if (succeeded(parser.parseOptionalKeyword(
kind == MinMaxKind::Min ? "min" : "max"))) {
mapOperands.clear();
AffineMapAttr map;
if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrName,
result.attributes,
OpAsmParser::Delimiter::Paren)))
return failure();
result.attributes.erase(tmpAttrName);
llvm::append_range(flatExprs, map.getValue().getResults());
auto operandsRef = llvm::makeArrayRef(mapOperands);
auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
SmallVector<OpAsmParser::OperandType> dims(dimsRef.begin(),
dimsRef.end());
auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
SmallVector<OpAsmParser::OperandType> syms(symsRef.begin(),
symsRef.end());
flatDimOperands.append(map.getValue().getNumResults(), dims);
flatSymOperands.append(map.getValue().getNumResults(), syms);
numMapsPerGroup.push_back(map.getValue().getNumResults());
} else {
if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
flatSymOperands.emplace_back(),
flatExprs.emplace_back())))
return failure();
numMapsPerGroup.push_back(1);
}
} while (succeeded(parser.parseOptionalComma()));
if (failed(parser.parseRParen()))
return failure();
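  // Shift the dims and symbols of each parsed expression into a common
  // numbering space; e.g., for two single-dim expressions (d0) and (d0), the
  // second one is renumbered to use d1.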
unsigned totalNumDims = 0;
unsigned totalNumSyms = 0;
for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
unsigned numDims = flatDimOperands[i].size();
unsigned numSyms = flatSymOperands[i].size();
flatExprs[i] = flatExprs[i]
.shiftDims(numDims, totalNumDims)
.shiftSymbols(numSyms, totalNumSyms);
totalNumDims += numDims;
totalNumSyms += numSyms;
}
// Deduplicate map operands.
SmallVector<Value> dimOperands, symOperands;
  SmallVector<AffineExpr> dimReplacements, symReplacements;
  deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
                                dimReplacements, AffineExprKind::DimId);
  deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
                                symReplacements, AffineExprKind::SymbolId);
result.operands.append(dimOperands.begin(), dimOperands.end());
result.operands.append(symOperands.begin(), symOperands.end());
Builder &builder = parser.getBuilder();
auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
parser.getBuilder().getContext());
  flatMap = flatMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
                                          dimOperands.size(),
                                          symOperands.size());
result.addAttribute(mapName, AffineMapAttr::get(flatMap));
result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
return success();
}
//
// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
//               `to` parallel-bound step? region attr-dict?
// step      ::= `step` `(` integer-literals `)`
//
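// For exposition, an illustrative loop nest with symbolic upper bounds and a
// sum reduction (%N, %M, and the loop body are assumed):
//
//   %sum = affine.parallel (%i, %j) = (0, 0) to (%N, %M) step (32, 32)
//            reduce ("addf") -> f32 {
//     %elem = ... : f32
//     affine.yield %elem : f32
//   }
//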
static ParseResult parseAffineParallelOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
auto indexType = builder.getIndexType();
SmallVector<OpAsmParser::OperandType, 4> ivs;
if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
parser.parseEqual() ||
parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
parser.parseKeyword("to") ||
parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
return failure();
AffineMapAttr stepsMapAttr;
NamedAttrList stepsAttrs;
SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands;
if (failed(parser.parseOptionalKeyword("step"))) {
SmallVector<int64_t, 4> steps(ivs.size(), 1);
result.addAttribute(AffineParallelOp::getStepsAttrName(),
builder.getI64ArrayAttr(steps));
} else {
if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
AffineParallelOp::getStepsAttrName(),
stepsAttrs,
OpAsmParser::Delimiter::Paren))
return failure();
// Convert steps from an AffineMap into an I64ArrayAttr.
SmallVector<int64_t, 4> steps;
auto stepsMap = stepsMapAttr.getValue();
for (const auto &result : stepsMap.getResults()) {
auto constExpr = result.dyn_cast<AffineConstantExpr>();
if (!constExpr)
return parser.emitError(parser.getNameLoc(),
"steps must be constant integers");
steps.push_back(constExpr.getValue());
}
result.addAttribute(AffineParallelOp::getStepsAttrName(),
builder.getI64ArrayAttr(steps));
}
  // Parse the optional clause of the form: `reduce ("addf", "maxf")`, where
  // the quoted strings are members of the enum AtomicRMWKind.
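  // Note that a reduce clause is expected to be followed by a matching arrow
  // type list, e.g., `reduce ("addf", "maxf") -> (f32, f32)`; the types
  // themselves are parsed further below.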
SmallVector<Attribute, 4> reductions;
if (succeeded(parser.parseOptionalKeyword("reduce"))) {
if (parser.parseLParen())
return failure();
do {
      // Parse a single quoted string via the attribute parsing, and then
      // verify that it is a member of the enum and convert it to its integer
      // representation.
StringAttr attrVal;
NamedAttrList attrStorage;
auto loc = parser.getCurrentLocation();
if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
attrStorage))
return failure();
llvm::Optional<AtomicRMWKind> reduction =
symbolizeAtomicRMWKind(attrVal.getValue());
if (!reduction)
return parser.emitError(loc, "invalid reduction value: ") << attrVal;
reductions.push_back(builder.getI64IntegerAttr(
static_cast<int64_t>(reduction.getValue())));
// While we keep getting commas, keep parsing.
} while (succeeded(parser.parseOptionalComma()));
if (parser.parseRParen())
return failure();
}
result.addAttribute(AffineParallelOp::getReductionsAttrName(),
builder.getArrayAttr(reductions));
  // Parse the return types of the reductions (if any).
if (parser.parseOptionalArrowTypeList(result.types))
return failure();
// Now parse the body.
Region *body = result.addRegion();
SmallVector<Type, 4> types(ivs.size(), indexType);
if (parser.parseRegion(*body, ivs, types) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
// Add a terminator if none was parsed.
AffineParallelOp::ensureTerminator(*body, builder, result.location);
return success();
}
//===----------------------------------------------------------------------===//
// AffineYieldOp
//===----------------------------------------------------------------------===//
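// For exposition, an illustrative use of affine.yield returning a
// loop-carried value from an affine.for (%init and %cst assumed defined
// above):
//
//   %res = affine.for %i = 0 to 10 iter_args(%acc = %init) -> (f32) {
//     %next = addf %acc, %cst : f32
//     affine.yield %next : f32
//   }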
static LogicalResult verify(AffineYieldOp op) {
auto *parentOp = op->getParentOp();
auto results = parentOp->getResults();
auto operands = op.getOperands();
if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
return op.emitOpError() << "only terminates affine.if/for/parallel regions";
if (parentOp->getNumResults() != op.getNumOperands())
return op.emitOpError() << "parent of yield must have same number of "
"results as the yield operands";
for (auto it : llvm::zip(results, operands)) {
if (std::get<0>(it).getType() != std::get<1>(it).getType())
return op.emitOpError()
<< "types mismatch between yield op and its parent";
}
return success();
}
//===----------------------------------------------------------------------===//
// AffineVectorLoadOp
//===----------------------------------------------------------------------===//
void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
VectorType resultType, AffineMap map,
ValueRange operands) {
assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
result.addOperands(operands);
if (map)
result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
result.types.push_back(resultType);
}
void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
VectorType resultType, Value memref,
AffineMap map, ValueRange mapOperands) {
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
result.addOperands(memref);
result.addOperands(mapOperands);
result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
result.types.push_back(resultType);
}
void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
VectorType resultType, Value memref,
ValueRange indices) {
auto memrefType = memref.getType().cast<MemRefType>();
int64_t rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
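  // For instance, a rank-2 memref gets the map (d0, d1) -> (d0, d1).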
auto map =
rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
build(builder, result, resultType, memref, map, indices);
}
void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
}
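// The custom assembly parsed below looks like, e.g.:
//
//   %1 = affine.vector_load %0[%i0 + 3, %i1 + 7]
//          : memref<100x100xf32>, vector<8xf32>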
static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
auto indexTy = builder.getIndexType();
MemRefType memrefType;
VectorType resultType;
OpAsmParser::OperandType memrefInfo;
AffineMapAttr mapAttr;
SmallVector<OpAsmParser::OperandType, 1> mapOperands;
return failure(
parser.parseOperand(memrefInfo) ||
parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
AffineVectorLoadOp::getMapAttrName(),
result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(memrefType) || parser.parseComma() ||
parser.parseType(resultType) ||
parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
parser.resolveOperands(mapOperands, indexTy, result.operands) ||
parser.addTypeToList(resultType, result.types));
}
static void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
p << "affine.vector_load " << op.getMemRef() << '[';
if (AffineMapAttr mapAttr =
op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
p << ']';
p.printOptionalAttrDict(op->getAttrs(),
/*elidedAttrs=*/{op.getMapAttrName()});
p << " : " << op.getMemRefType() << ", " << op.getType();
}
/// Verify common invariants of affine.vector_load and affine.vector_store.
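/// For example, a vector<8xi32> transfer on a memref with f32 elements is
/// invalid because the elemental types differ.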
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
VectorType vectorType) {
// Check that memref and vector element types match.
if (memrefType.getElementType() != vectorType.getElementType())
return op->emitOpError(
"requires memref and vector types of the same elemental type");
return success();
}
static LogicalResult verify(AffineVectorLoadOp op) {
MemRefType memrefType = op.getMemRefType();
if (failed(verifyMemoryOpIndexing(
op.getOperation(),
op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
op.getMapOperands(), memrefType,
/*numIndexOperands=*/op.getNumOperands() - 1)))
return failure();
if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
op.getVectorType())))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// AffineVectorStoreOp
//===----------------------------------------------------------------------===//
void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
Value valueToStore, Value memref, AffineMap map,
ValueRange mapOperands) {
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
result.addOperands(valueToStore);
result.addOperands(memref);
result.addOperands(mapOperands);
result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
}
// Use identity map.
void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
Value valueToStore, Value memref,
ValueRange indices) {
auto memrefType = memref.getType().cast<MemRefType>();
int64_t rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
// for zero-dimensional memrefs.
auto map =
rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
build(builder, result, valueToStore, memref, map, indices);
}
void AffineVectorStoreOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
}
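// The custom assembly parsed below looks like, e.g.:
//
//   affine.vector_store %v0, %0[%i0 + 3, %i1 + 7]
//     : memref<100x100xf32>, vector<8xf32>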
static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
OperationState &result) {
auto indexTy = parser.getBuilder().getIndexType();
MemRefType memrefType;
VectorType resultType;
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo;
AffineMapAttr mapAttr;
SmallVector<OpAsmParser::OperandType, 1> mapOperands;
return failure(
parser.parseOperand(storeValueInfo) || parser.parseComma() ||
parser.parseOperand(memrefInfo) ||
parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
AffineVectorStoreOp::getMapAttrName(),
result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(memrefType) || parser.parseComma() ||
parser.parseType(resultType) ||
parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
parser.resolveOperands(mapOperands, indexTy, result.operands));
}
static void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
p << "affine.vector_store " << op.getValueToStore();
p << ", " << op.getMemRef() << '[';
if (AffineMapAttr mapAttr =
op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
p << ']';
p.printOptionalAttrDict(op->getAttrs(),
/*elidedAttrs=*/{op.getMapAttrName()});
p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
}
static LogicalResult verify(AffineVectorStoreOp op) {
MemRefType memrefType = op.getMemRefType();
if (failed(verifyMemoryOpIndexing(
op.getOperation(),
op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
op.getMapOperands(), memrefType,
/*numIndexOperands=*/op.getNumOperands() - 2)))
return failure();
if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
op.getVectorType())))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"