[MLIR] AffineMap value type

This CL applies the same pattern as AffineExpr to AffineMap: a simple struct
that acts as the storage is allocated in the bump pointer. The AffineMap is
immutable and accessed everywhere by value.

PiperOrigin-RevId: 216445930
This commit is contained in:
Nicolas Vasilache 2018-10-09 16:39:24 -07:00 committed by jpienaar
parent 82e55750d2
commit 1d3e7e2616
28 changed files with 444 additions and 361 deletions

View File

@ -42,7 +42,7 @@ class HyperRectangularSet;
/// A mutable affine map. Its affine expressions are however unique.
struct MutableAffineMap {
public:
MutableAffineMap(AffineMap *map, MLIRContext *context);
MutableAffineMap(AffineMap map);
AffineExpr getResult(unsigned idx) const { return results[idx]; }
void setResult(unsigned idx, AffineExpr result) { results[idx] = result; }
@ -60,9 +60,9 @@ public:
//-simplify-affine-expr pass).
void simplify();
/// Get the AffineMap corresponding to this MutableAffineMap. Note that an
/// AffineMap * will be uniqued and stored in context, while a mutable one
/// AffineMap will be uniqued and stored in context, while a mutable one
/// isn't.
AffineMap *getAffineMap();
AffineMap getAffineMap();
private:
// Same meaning as AffineMap's fields.
@ -117,11 +117,10 @@ private:
// TODO(bondhugula): Some of these classes could go into separate files.
class AffineValueMap {
public:
AffineValueMap(const AffineApplyOp &op, MLIRContext *context);
AffineValueMap(const AffineBound &bound, MLIRContext *context);
AffineValueMap(AffineMap *map, MLIRContext *context);
AffineValueMap(AffineMap *map, ArrayRef<MLValue *> operands,
MLIRContext *context);
AffineValueMap(const AffineApplyOp &op);
AffineValueMap(const AffineBound &bound);
AffineValueMap(AffineMap map);
AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands);
~AffineValueMap();
@ -156,7 +155,7 @@ public:
unsigned getNumOperands() const;
SSAValue *getOperand(unsigned i) const;
ArrayRef<MLValue *> getOperands() const;
AffineMap *getAffineMap();
AffineMap getAffineMap();
private:
// A mutable affine map.

View File

@ -27,9 +27,16 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
namespace mlir {
namespace detail {
class AffineMapStorage;
} // end namespace detail
class AffineExpr;
class Attribute;
class MLIRContext;
@ -41,71 +48,91 @@ class MLIRContext;
/// is unique to this affine map.
class AffineMap {
public:
static AffineMap *get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results,
ArrayRef<AffineExpr> rangeSizes);
typedef detail::AffineMapStorage ImplType;
explicit AffineMap(ImplType *map = nullptr) : map(map) {}
static AffineMap Invalid() { return AffineMap(nullptr); }
static AffineMap get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results,
ArrayRef<AffineExpr> rangeSizes);
/// Returns a single constant result affine map.
static AffineMap *getConstantMap(int64_t val, MLIRContext *context);
static AffineMap getConstantMap(int64_t val, MLIRContext *context);
explicit operator bool() { return map; }
bool operator==(const AffineMap &other) const { return other.map == map; }
/// Returns true if the co-domain (or more loosely speaking, range) of this
/// map is bounded. Bounded affine maps have a size (extent) for each of
/// their range dimensions (more accurately co-domain dimensions).
bool isBounded() { return !rangeSizes.empty(); }
bool isBounded() const;
/// Returns true if this affine map is an identity affine map.
/// An identity affine map corresponds to an identity affine function on the
/// dimensional identifiers.
bool isIdentity();
bool isIdentity() const;
/// Returns true if this affine map is a single result constant function.
bool isSingleConstant();
bool isSingleConstant() const;
/// Returns the constant result of this map. This methods asserts that the map
/// has a single constant result.
int64_t getSingleConstantResult();
int64_t getSingleConstantResult() const;
// Prints affine map to 'os'.
void print(raw_ostream &os);
void dump();
void print(raw_ostream &os) const;
void dump() const;
unsigned getNumDims() { return numDims; }
unsigned getNumSymbols() { return numSymbols; }
unsigned getNumResults() { return numResults; }
unsigned getNumInputs() { return numDims + numSymbols; }
unsigned getNumDims() const;
unsigned getNumSymbols() const;
unsigned getNumResults() const;
unsigned getNumInputs() const;
ArrayRef<AffineExpr> getResults() { return results; }
ArrayRef<AffineExpr> getResults() const;
AffineExpr getResult(unsigned idx) const;
AffineExpr getResult(unsigned idx);
ArrayRef<AffineExpr> getRangeSizes() { return rangeSizes; }
ArrayRef<AffineExpr> getRangeSizes() const;
/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible. Returns false if the folding happens,
/// true otherwise.
bool constantFold(ArrayRef<Attribute *> operandConstants,
SmallVectorImpl<Attribute *> &results);
SmallVectorImpl<Attribute *> &results) const;
friend ::llvm::hash_code hash_value(AffineMap arg);
private:
AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
ArrayRef<AffineExpr> results, ArrayRef<AffineExpr> rangeSizes);
AffineMap(const AffineMap &) = delete;
void operator=(const AffineMap &) = delete;
unsigned numDims;
unsigned numSymbols;
unsigned numResults;
/// The affine expressions for this (multi-dimensional) map.
/// TODO: use trailing objects for this.
ArrayRef<AffineExpr> results;
/// The extents along each of the range dimensions if the map is bounded,
/// nullptr otherwise.
ArrayRef<AffineExpr> rangeSizes;
ImplType *map;
};
// Make AffineExpr hashable.
inline ::llvm::hash_code hash_value(AffineMap arg) {
return ::llvm::hash_value(arg.map);
}
} // end namespace mlir
namespace llvm {
// AffineExpr hash just like pointers
template <> struct DenseMapInfo<mlir::AffineMap> {
static mlir::AffineMap getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));
}
static mlir::AffineMap getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::AffineMap val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::AffineMap LHS, mlir::AffineMap RHS) {
return LHS == RHS;
}
};
} // namespace llvm
#endif // MLIR_IR_AFFINE_MAP_H

View File

@ -18,11 +18,11 @@
#ifndef MLIR_IR_ATTRIBUTES_H
#define MLIR_IR_ATTRIBUTES_H
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
class AffineMap;
class Function;
class FunctionType;
class MLIRContext;
@ -182,22 +182,21 @@ private:
class AffineMapAttr : public Attribute {
public:
static AffineMapAttr *get(AffineMap *value, MLIRContext *context);
static AffineMapAttr *get(AffineMap value);
AffineMap *getValue() const {
return value;
}
AffineMap getValue() const { return value; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::AffineMap;
}
private:
AffineMapAttr(AffineMap *value)
AffineMapAttr(AffineMap value)
: Attribute(Kind::AffineMap, /*isOrContainsFunction=*/false),
value(value) {}
~AffineMapAttr() = delete;
AffineMap *value;
AffineMap value;
};
class TypeAttr : public Attribute {

View File

@ -84,7 +84,7 @@ public:
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
ArrayRef<Type *> results);
MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap *> affineMapComposition = {},
ArrayRef<AffineMap> affineMapComposition = {},
unsigned memorySpace = 0);
VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
@ -96,7 +96,7 @@ public:
FloatAttr *getFloatAttr(double value);
StringAttr *getStringAttr(StringRef bytes);
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
AffineMapAttr *getAffineMapAttr(AffineMap *value);
AffineMapAttr *getAffineMapAttr(AffineMap map);
TypeAttr *getTypeAttr(Type *type);
FunctionAttr *getFunctionAttr(const Function *value);
@ -105,30 +105,30 @@ public:
AffineExpr getAffineSymbolExpr(unsigned position);
AffineExpr getAffineConstantExpr(int64_t constant);
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results,
ArrayRef<AffineExpr> rangeSizes);
AffineMap getAffineMap(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results,
ArrayRef<AffineExpr> rangeSizes);
// Special cases of affine maps and integer sets
/// Returns a single constant result affine map with 0 dimensions and 0
/// symbols. One constant result: () -> (val).
AffineMap *getConstantAffineMap(int64_t val);
AffineMap getConstantAffineMap(int64_t val);
// One dimension id identity map: (i) -> (i).
AffineMap *getDimIdentityMap();
AffineMap getDimIdentityMap();
// Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2).
AffineMap *getMultiDimIdentityMap(unsigned rank);
AffineMap getMultiDimIdentityMap(unsigned rank);
// One symbol identity map: ()[s] -> (s).
AffineMap *getSymbolIdentityMap();
AffineMap getSymbolIdentityMap();
/// Returns a map that shifts its (single) input dimension by 'shift'.
/// (d0) -> (d0 + shift)
AffineMap *getSingleDimShiftAffineMap(int64_t shift);
AffineMap getSingleDimShiftAffineMap(int64_t shift);
/// Returns an affine map that is a translation (shift) of all result
/// expressions in 'map' by 'shift'.
/// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2
/// returns: (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2)
AffineMap *getShiftedAffineMap(AffineMap *map, int64_t shift);
AffineMap getShiftedAffineMap(AffineMap map, int64_t shift);
// Integer set.
IntegerSet *getIntegerSet(unsigned dimCount, unsigned symbolCount,
@ -392,8 +392,8 @@ public:
// Creates a for statement. When step is not specified, it is set to 1.
ForStmt *createFor(Location *location, ArrayRef<MLValue *> lbOperands,
AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
AffineMap *ubMap, int64_t step = 1);
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
AffineMap ubMap, int64_t step = 1);
// Creates a for statement with known (constant) lower and upper bounds.
// Default step is 1.

View File

@ -22,6 +22,7 @@
#ifndef MLIR_IR_OPIMPLEMENTATION_H
#define MLIR_IR_OPIMPLEMENTATION_H
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SMLoc.h"
@ -69,7 +70,7 @@ public:
virtual void printType(const Type *type) = 0;
virtual void printFunctionReference(const Function *func) = 0;
virtual void printAttribute(const Attribute *attr) = 0;
virtual void printAffineMap(AffineMap *map) = 0;
virtual void printAffineMap(AffineMap map) = 0;
virtual void printAffineExpr(AffineExpr expr) = 0;
/// If the specified operation has attributes, print out an attribute
@ -104,8 +105,8 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Attribute &attr) {
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const AffineMap &map) {
p.printAffineMap(&const_cast<AffineMap &>(map));
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, AffineMap map) {
p.printAffineMap(map);
return p;
}

View File

@ -86,11 +86,11 @@ class AffineApplyOp : public Op<AffineApplyOp, OpTrait::VariadicOperands,
OpTrait::VariadicResults> {
public:
/// Builds an affine apply op with the specified map and operands.
static void build(Builder *builder, OperationState *result, AffineMap *map,
static void build(Builder *builder, OperationState *result, AffineMap map,
ArrayRef<SSAValue *> operands);
/// Returns the affine map to be applied by this operation.
AffineMap *getAffineMap() const {
AffineMap getAffineMap() const {
return getAttrOfType<AffineMapAttr>("map")->getValue();
}

View File

@ -22,6 +22,7 @@
#ifndef MLIR_IR_STATEMENTS_H
#define MLIR_IR_STATEMENTS_H
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StmtBlock.h"
@ -29,7 +30,6 @@
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
class AffineMap;
class AffineBound;
class IntegerSet;
class AffineCondition;
@ -199,8 +199,8 @@ private:
class ForStmt : public Statement, public MLValue, public StmtBlock {
public:
static ForStmt *create(Location *location, ArrayRef<MLValue *> lbOperands,
AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
AffineMap *ubMap, int64_t step, MLIRContext *context);
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
AffineMap ubMap, int64_t step);
~ForStmt() {
// Explicitly erase statements instead of relying of 'StmtBlock' destructor
@ -235,20 +235,20 @@ public:
int64_t getStep() const { return step; }
/// Returns affine map for the lower bound.
AffineMap *getLowerBoundMap() const { return lbMap; }
AffineMap getLowerBoundMap() const { return lbMap; }
/// Returns affine map for the upper bound.
AffineMap *getUpperBoundMap() const { return ubMap; }
AffineMap getUpperBoundMap() const { return ubMap; }
/// Set lower bound.
void setLowerBound(ArrayRef<MLValue *> operands, AffineMap *map);
void setLowerBound(ArrayRef<MLValue *> operands, AffineMap map);
/// Set upper bound.
void setUpperBound(ArrayRef<MLValue *> operands, AffineMap *map);
void setUpperBound(ArrayRef<MLValue *> operands, AffineMap map);
/// Set the lower bound map without changing operands.
void setLowerBoundMap(AffineMap *map);
void setLowerBoundMap(AffineMap map);
/// Set the upper bound map without changing operands.
void setUpperBoundMap(AffineMap *map);
void setUpperBoundMap(AffineMap map);
/// Set loop step.
void setStep(int64_t step) {
@ -353,9 +353,9 @@ public:
private:
// Affine map for the lower bound.
AffineMap *lbMap;
AffineMap lbMap;
// Affine map for the upper bound.
AffineMap *ubMap;
AffineMap ubMap;
// Positive constant step. Since index is stored as an int64_t, we restrict
// step to the set of positive integers that int64_t can represent.
int64_t step;
@ -364,8 +364,8 @@ private:
// bound.
std::vector<StmtOperand> operands;
explicit ForStmt(Location *location, unsigned numOperands, AffineMap *lbMap,
AffineMap *ubMap, int64_t step, MLIRContext *context);
explicit ForStmt(Location *location, unsigned numOperands, AffineMap lbMap,
AffineMap ubMap, int64_t step);
};
/// AffineBound represents a lower or upper bound in the for statement.
@ -375,7 +375,7 @@ private:
class AffineBound {
public:
const ForStmt *getForStmt() const { return &stmt; }
AffineMap *getMap() const { return map; }
AffineMap getMap() const { return map; }
unsigned getNumOperands() const { return opEnd - opStart; }
const MLValue *getOperand(unsigned idx) const {
@ -411,12 +411,11 @@ private:
// the containing 'for' statement operands.
unsigned opStart, opEnd;
// Affine map for this bound.
AffineMap *map;
AffineMap map;
AffineBound(const ForStmt &stmt, const unsigned opStart, const unsigned opEnd,
const AffineMap *map)
: stmt(stmt), opStart(opStart), opEnd(opEnd),
map(const_cast<AffineMap *>(map)) {}
AffineBound(const ForStmt &stmt, unsigned opStart, unsigned opEnd,
AffineMap map)
: stmt(stmt), opStart(opStart), opEnd(opEnd), map(map) {}
friend class ForStmt;
};

View File

@ -408,7 +408,7 @@ public:
/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space.
static MemRefType *get(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap*> affineMapComposition,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace);
unsigned getRank() const { return getShape().size(); }
@ -426,7 +426,7 @@ public:
/// Returns an array of affine map pointers representing the memref affine
/// map composition.
ArrayRef<AffineMap*> getAffineMaps() const;
ArrayRef<AffineMap> getAffineMaps() const;
/// Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpace() const { return memorySpace; }
@ -446,12 +446,12 @@ private:
/// The number of affine maps in the 'affineMapList' array.
const unsigned numAffineMaps;
/// List of affine maps in the memref's layout/index map composition.
AffineMap *const *const affineMapList;
AffineMap const *affineMapList;
/// Memory space in which data referenced by memref resides.
const unsigned memorySpace;
MemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap*> affineMapList, unsigned memorySpace,
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
MLIRContext *context);
~MemRefType() = delete;
};

View File

@ -71,15 +71,15 @@ void promoteSingleIterationLoops(MLFunction *f);
/// Returns the lower bound of the cleanup loop when unrolling a loop
/// with the specified unroll factor.
AffineMap *getCleanupLoopLowerBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder);
AffineMap getCleanupLoopLowerBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder);
/// Returns the upper bound of an unrolled loop when unrolling with
/// the specified trip count, stride, and unroll factor.
AffineMap *getUnrolledLoopUpperBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder);
AffineMap getUnrolledLoopUpperBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder);
/// Skew the statements in the body of a 'for' statement with the specified
/// statement-wise delays.

View File

@ -25,11 +25,11 @@
#ifndef MLIR_TRANSFORMS_UTILS_H
#define MLIR_TRANSFORMS_UTILS_H
#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
class AffineMap;
class MLValue;
class SSAValue;
@ -43,7 +43,7 @@ class SSAValue;
// extended to add additional indices at any position.
bool replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
llvm::ArrayRef<SSAValue *> extraIndices,
AffineMap *indexRemap = nullptr);
AffineMap indexRemap = AffineMap::Invalid());
} // end namespace mlir
#endif // MLIR_TRANSFORMS_UTILS_H

View File

@ -166,12 +166,13 @@ forwardSubstituteMutableAffineMap(const AffineMapCompositionUpdate &mapUpdate,
map->setNumSymbols(mapUpdate.outputNumSymbols);
}
MutableAffineMap::MutableAffineMap(AffineMap *map, MLIRContext *context)
: numDims(map->getNumDims()), numSymbols(map->getNumSymbols()),
context(context) {
for (auto result : map->getResults())
MutableAffineMap::MutableAffineMap(AffineMap map)
: numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
// A map always has at leat 1 result by construction
context(map.getResult(0).getContext()) {
for (auto result : map.getResults())
results.push_back(result);
for (auto rangeSize : map->getRangeSizes())
for (auto rangeSize : map.getRangeSizes())
results.push_back(rangeSize);
}
@ -194,7 +195,7 @@ void MutableAffineMap::simplify() {
}
}
AffineMap *MutableAffineMap::getAffineMap() {
AffineMap MutableAffineMap::getAffineMap() {
return AffineMap::get(numDims, numSymbols, results, rangeSizes);
}
@ -209,17 +210,16 @@ MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
MLIRContext *context)
: numDims(numDims), numSymbols(numSymbols), context(context) {}
AffineValueMap::AffineValueMap(const AffineApplyOp &op, MLIRContext *context)
: map(op.getAffineMap(), context) {
AffineValueMap::AffineValueMap(const AffineApplyOp &op)
: map(op.getAffineMap()) {
for (auto *operand : op.getOperands())
operands.push_back(cast<MLValue>(const_cast<SSAValue *>(operand)));
for (unsigned i = 0, e = op.getNumResults(); i < e; i++)
results.push_back(cast<MLValue>(const_cast<SSAValue *>(op.getResult(i))));
}
AffineValueMap::AffineValueMap(AffineMap *map, ArrayRef<MLValue *> operands,
MLIRContext *context)
: map(map, context) {
AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<MLValue *> operands)
: map(map) {
for (MLValue *operand : operands) {
this->operands.push_back(operand);
}
@ -303,20 +303,20 @@ void AffineValueMap::forwardSubstitute(const AffineApplyOp &inputOp) {
// Gather dim and symbol positions from 'inputOp' on which
// 'inputResultsUsed' depend.
AffineMap *inputMap = inputOp.getAffineMap();
unsigned inputNumDims = inputMap->getNumDims();
AffineMap inputMap = inputOp.getAffineMap();
unsigned inputNumDims = inputMap.getNumDims();
DenseSet<unsigned> inputPositionsUsed;
AffineExprPositionGatherer gatherer(inputNumDims, &inputPositionsUsed);
for (unsigned i = 0; i < inputNumResults; ++i) {
if (inputResultsUsed.count(i) == 0)
continue;
gatherer.walkPostOrder(inputMap->getResult(i));
gatherer.walkPostOrder(inputMap.getResult(i));
}
// Build new output operands list and map update.
SmallVector<MLValue *, 4> outputOperands;
unsigned outputOperandPosition = 0;
AffineMapCompositionUpdate mapUpdate(inputOp.getAffineMap()->getResults());
AffineMapCompositionUpdate mapUpdate(inputOp.getAffineMap().getResults());
// Add dim operands from current map.
for (unsigned i = 0; i < currNumDims; ++i) {
@ -405,7 +405,7 @@ ArrayRef<MLValue *> AffineValueMap::getOperands() const {
return ArrayRef<MLValue *>(operands);
}
AffineMap *AffineValueMap::getAffineMap() { return map.getAffineMap(); }
AffineMap AffineValueMap::getAffineMap() { return map.getAffineMap(); }
AffineValueMap::~AffineValueMap() {}

View File

@ -44,10 +44,10 @@ AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
int64_t ub = forStmt.getConstantUpperBound();
loopSpan = ub - lb + 1;
} else {
auto *lbMap = forStmt.getLowerBoundMap();
auto *ubMap = forStmt.getUpperBoundMap();
auto lbMap = forStmt.getLowerBoundMap();
auto ubMap = forStmt.getUpperBoundMap();
// TODO(bondhugula): handle max/min of multiple expressions.
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
return nullptr;
// TODO(bondhugula): handle bounds with different operands.
@ -56,11 +56,11 @@ AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
return nullptr;
// ub_expr - lb_expr + 1
AffineExpr lbExpr(lbMap->getResult(0));
AffineExpr ubExpr(ubMap->getResult(0));
AffineExpr lbExpr(lbMap.getResult(0));
AffineExpr ubExpr(ubMap.getResult(0));
auto loopSpanExpr = simplifyAffineExpr(
ubExpr - lbExpr + 1, std::max(lbMap->getNumDims(), ubMap->getNumDims()),
std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols()));
ubExpr - lbExpr + 1, std::max(lbMap.getNumDims(), ubMap.getNumDims()),
std::max(lbMap.getNumSymbols(), ubMap.getNumSymbols()));
auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
if (!cExpr)
return loopSpanExpr.ceilDiv(step);

View File

@ -16,6 +16,7 @@
// =============================================================================
#include "mlir/IR/AffineMap.h"
#include "AffineMapDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/Support/MathExtras.h"
@ -87,13 +88,15 @@ private:
} // end anonymous namespace
AffineMap::AffineMap(unsigned numDims, unsigned numSymbols, unsigned numResults,
ArrayRef<AffineExpr> results,
ArrayRef<AffineExpr> rangeSizes)
: numDims(numDims), numSymbols(numSymbols), numResults(numResults),
results(results), rangeSizes(rangeSizes) {}
/// Returns a single constant result affine map.
AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
return get(/*dimCount=*/0, /*symbolCount=*/0,
{getAffineConstantExpr(val, context)}, {});
}
bool AffineMap::isIdentity() {
bool AffineMap::isBounded() const { return !map->rangeSizes.empty(); }
bool AffineMap::isIdentity() const {
if (getNumDims() != getNumResults())
return false;
ArrayRef<AffineExpr> results = getResults();
@ -105,28 +108,35 @@ bool AffineMap::isIdentity() {
return true;
}
/// Returns a single constant result affine map.
AffineMap *AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
return get(/*dimCount=*/0, /*symbolCount=*/0,
{getAffineConstantExpr(val, context)}, {});
}
bool AffineMap::isSingleConstant() {
bool AffineMap::isSingleConstant() const {
return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
}
int64_t AffineMap::getSingleConstantResult() {
int64_t AffineMap::getSingleConstantResult() const {
assert(isSingleConstant() && "map must have a single constant result");
return getResult(0).cast<AffineConstantExpr>().getValue();
}
AffineExpr AffineMap::getResult(unsigned idx) { return results[idx]; }
unsigned AffineMap::getNumDims() const { return map->numDims; }
unsigned AffineMap::getNumSymbols() const { return map->numSymbols; }
unsigned AffineMap::getNumResults() const { return map->numResults; }
unsigned AffineMap::getNumInputs() const {
return map->numDims + map->numSymbols;
}
ArrayRef<AffineExpr> AffineMap::getResults() const { return map->results; }
AffineExpr AffineMap::getResult(unsigned idx) const {
return map->results[idx];
}
ArrayRef<AffineExpr> AffineMap::getRangeSizes() const {
return map->rangeSizes;
}
/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible. Returns false if the folding happens,
/// true otherwise.
bool AffineMap::constantFold(ArrayRef<Attribute *> operandConstants,
SmallVectorImpl<Attribute *> &results) {
SmallVectorImpl<Attribute *> &results) const {
assert(getNumInputs() == operandConstants.size());
// Fold each of the result expressions.

View File

@ -0,0 +1,49 @@
//===- AffineMapDetail.h - MLIR Affine Map details Class --------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This holds implementation details of AffineMap.
//
//===----------------------------------------------------------------------===//
#ifndef AFFINEMAPDETAIL_H_
#define AFFINEMAPDETAIL_H_
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
namespace detail {
struct AffineMapStorage {
unsigned numDims;
unsigned numSymbols;
unsigned numResults;
/// The affine expressions for this (multi-dimensional) map.
/// TODO: use trailing objects for this.
ArrayRef<AffineExpr> results;
/// The extents along each of the range dimensions if the map is bounded,
/// nullptr otherwise.
ArrayRef<AffineExpr> rangeSizes;
};
} // end namespace detail
} // end namespace mlir
#endif // AFFINEMAPDETAIL_H_

View File

@ -64,7 +64,7 @@ public:
// Initializes module state, populating affine map state.
void initialize(const Module *module);
int getAffineMapId(AffineMap *affineMap) const {
int getAffineMapId(AffineMap affineMap) const {
auto it = affineMapIds.find(affineMap);
if (it == affineMapIds.end()) {
return -1;
@ -72,7 +72,7 @@ public:
return it->second;
}
ArrayRef<AffineMap *> getAffineMapIds() const { return affineMapsById; }
ArrayRef<AffineMap> getAffineMapIds() const { return affineMapsById; }
int getIntegerSetId(IntegerSet *integerSet) const {
auto it = integerSetIds.find(integerSet);
@ -85,7 +85,7 @@ public:
ArrayRef<IntegerSet *> getIntegerSetIds() const { return integerSetsById; }
private:
void recordAffineMapReference(AffineMap *affineMap) {
void recordAffineMapReference(AffineMap affineMap) {
if (affineMapIds.count(affineMap) == 0) {
affineMapIds[affineMap] = affineMapsById.size();
affineMapsById.push_back(affineMap);
@ -100,15 +100,15 @@ private:
}
// Return true if this map could be printed using the shorthand form.
static bool hasShorthandForm(AffineMap *boundMap) {
if (boundMap->isSingleConstant())
static bool hasShorthandForm(AffineMap boundMap) {
if (boundMap.isSingleConstant())
return true;
// Check if the affine map is single dim id or single symbol identity -
// (i)->(i) or ()[s]->(i)
return boundMap->getNumInputs() == 1 && boundMap->getNumResults() == 1 &&
(boundMap->getResult(0).isa<AffineDimExpr>() ||
boundMap->getResult(0).isa<AffineSymbolExpr>());
return boundMap.getNumInputs() == 1 && boundMap.getNumResults() == 1 &&
(boundMap.getResult(0).isa<AffineDimExpr>() ||
boundMap.getResult(0).isa<AffineSymbolExpr>());
}
// Visit functions.
@ -124,8 +124,8 @@ private:
void visitAttribute(const Attribute *attr);
void visitOperation(const Operation *op);
DenseMap<AffineMap *, int> affineMapIds;
std::vector<AffineMap *> affineMapsById;
DenseMap<AffineMap, int> affineMapIds;
std::vector<AffineMap> affineMapsById;
DenseMap<IntegerSet *, int> integerSetIds;
std::vector<IntegerSet *> integerSetsById;
@ -142,7 +142,7 @@ void ModuleState::visitType(const Type *type) {
visitType(result);
} else if (auto *memref = dyn_cast<MemRefType>(type)) {
// Visit affine maps in memref type.
for (auto *map : memref->getAffineMaps()) {
for (auto map : memref->getAffineMaps()) {
recordAffineMapReference(map);
}
}
@ -193,11 +193,11 @@ void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
}
void ModuleState::visitForStmt(const ForStmt *forStmt) {
AffineMap *lbMap = forStmt->getLowerBoundMap();
AffineMap lbMap = forStmt->getLowerBoundMap();
if (!hasShorthandForm(lbMap))
recordAffineMapReference(lbMap);
AffineMap *ubMap = forStmt->getUpperBoundMap();
AffineMap ubMap = forStmt->getUpperBoundMap();
if (!hasShorthandForm(ubMap))
recordAffineMapReference(ubMap);
@ -273,7 +273,7 @@ public:
void print(const CFGFunction *fn);
void print(const MLFunction *fn);
void printAffineMap(AffineMap *map);
void printAffineMap(AffineMap map);
void printAffineExpr(AffineExpr expr);
void printAffineConstraint(AffineExpr expr, bool isEq);
void printIntegerSet(IntegerSet *set);
@ -288,7 +288,7 @@ protected:
ArrayRef<const char *> elidedAttrs = {});
void printFunctionResultType(const FunctionType *type);
void printAffineMapId(int affineMapId) const;
void printAffineMapReference(AffineMap *affineMap);
void printAffineMapReference(AffineMap affineMap);
void printIntegerSetId(int integerSetId) const;
void printIntegerSetReference(IntegerSet *integerSet);
@ -321,14 +321,14 @@ void ModulePrinter::printAffineMapId(int affineMapId) const {
os << "#map" << affineMapId;
}
void ModulePrinter::printAffineMapReference(AffineMap *affineMap) {
void ModulePrinter::printAffineMapReference(AffineMap affineMap) {
int mapId = state.getAffineMapId(affineMap);
if (mapId >= 0) {
// Map will be printed at top of module so print reference to its id.
printAffineMapId(mapId);
} else {
// Map not in module state so print inline.
affineMap->print(os);
affineMap.print(os);
}
}
@ -352,7 +352,7 @@ void ModulePrinter::print(const Module *module) {
for (const auto &map : state.getAffineMapIds()) {
printAffineMapId(state.getAffineMapId(map));
os << " = ";
map->print(os);
map.print(os);
os << '\n';
}
for (const auto &set : state.getIntegerSetIds()) {
@ -678,40 +678,40 @@ void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
isEq ? os << " == 0" : os << " >= 0";
}
void ModulePrinter::printAffineMap(AffineMap *map) {
void ModulePrinter::printAffineMap(AffineMap map) {
// Dimension identifiers.
os << '(';
for (int i = 0; i < (int)map->getNumDims() - 1; ++i)
for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
os << 'd' << i << ", ";
if (map->getNumDims() >= 1)
os << 'd' << map->getNumDims() - 1;
if (map.getNumDims() >= 1)
os << 'd' << map.getNumDims() - 1;
os << ')';
// Symbolic identifiers.
if (map->getNumSymbols() != 0) {
if (map.getNumSymbols() != 0) {
os << '[';
for (unsigned i = 0; i < map->getNumSymbols() - 1; ++i)
for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
os << 's' << i << ", ";
if (map->getNumSymbols() >= 1)
os << 's' << map->getNumSymbols() - 1;
if (map.getNumSymbols() >= 1)
os << 's' << map.getNumSymbols() - 1;
os << ']';
}
// AffineMap should have at least one result.
assert(!map->getResults().empty());
assert(!map.getResults().empty());
// Result affine expressions.
os << " -> (";
interleaveComma(map->getResults(),
interleaveComma(map.getResults(),
[&](AffineExpr expr) { printAffineExpr(expr); });
os << ')';
if (!map->isBounded()) {
if (!map.isBounded()) {
return;
}
// Print range sizes for bounded affine maps.
os << " size (";
interleaveComma(map->getRangeSizes(),
interleaveComma(map.getRangeSizes(),
[&](AffineExpr expr) { printAffineExpr(expr); });
os << ')';
}
@ -851,7 +851,7 @@ public:
void printAttribute(const Attribute *attr) {
ModulePrinter::printAttribute(attr);
}
void printAffineMap(AffineMap *map) {
void printAffineMap(AffineMap map) {
return ModulePrinter::printAffineMapReference(map);
}
void printIntegerSet(IntegerSet *set) {
@ -1422,7 +1422,7 @@ void MLFunctionPrinter::printDimAndSymbolList(ArrayRef<StmtOperand> ops,
}
void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
AffineMap *map = bound.getMap();
AffineMap map = bound.getMap();
// Check if this bound should be printed using short-hand notation.
// The decision to restrict printing short-hand notation to trivial cases
@ -1430,11 +1430,11 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
// lossless way.
// Therefore, short-hand parsing and printing is only supported for
// zero-operand constant maps and single symbol operand identity maps.
if (map->getNumResults() == 1) {
AffineExpr expr = map->getResult(0);
if (map.getNumResults() == 1) {
AffineExpr expr = map.getResult(0);
// Print constant bound.
if (map->getNumDims() == 0 && map->getNumSymbols() == 0) {
if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
os << constExpr.getValue();
return;
@ -1443,7 +1443,7 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
// Print bound that consists of a single SSA symbol if the map is over a
// single symbol.
if (map->getNumDims() == 0 && map->getNumSymbols() == 1) {
if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
printOperand(bound.getOperand(0));
return;
@ -1456,7 +1456,7 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
// Print the map and its operands.
printAffineMapReference(map);
printDimAndSymbolList(bound.getStmtOperands(), map->getNumDims());
printDimAndSymbolList(bound.getStmtOperands(), map.getNumDims());
}
void MLFunctionPrinter::print(const IfStmt *stmt) {
@ -1496,7 +1496,7 @@ void Type::print(raw_ostream &os) const {
void Type::dump() const { print(llvm::errs()); }
void AffineMap::dump() {
void AffineMap::dump() const {
print(llvm::errs());
llvm::errs() << "\n";
}
@ -1516,9 +1516,9 @@ void AffineExpr::dump() const {
llvm::errs() << "\n";
}
void AffineMap::print(raw_ostream &os) {
void AffineMap::print(raw_ostream &os) const {
ModuleState state(/*no context is known*/ nullptr);
ModulePrinter(os, state).printAffineMap(this);
ModulePrinter(os, state).printAffineMap(*this);
}
void IntegerSet::print(raw_ostream &os) {

View File

@ -90,7 +90,7 @@ FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs,
}
MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap *> affineMapComposition,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace) {
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
}
@ -133,8 +133,8 @@ ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) {
return ArrayAttr::get(value, context);
}
AffineMapAttr *Builder::getAffineMapAttr(AffineMap *map) {
return AffineMapAttr::get(map, context);
AffineMapAttr *Builder::getAffineMapAttr(AffineMap map) {
return AffineMapAttr::get(map);
}
TypeAttr *Builder::getTypeAttr(Type *type) {
@ -149,9 +149,9 @@ FunctionAttr *Builder::getFunctionAttr(const Function *value) {
// Affine Expressions, Affine Maps, and Integet Sets.
//===----------------------------------------------------------------------===//
AffineMap *Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results,
ArrayRef<AffineExpr> rangeSizes) {
AffineMap Builder::getAffineMap(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results,
ArrayRef<AffineExpr> rangeSizes) {
return AffineMap::get(dimCount, symbolCount, results, rangeSizes);
}
@ -173,17 +173,17 @@ IntegerSet *Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount,
return IntegerSet::get(dimCount, symbolCount, constraints, isEq, context);
}
AffineMap *Builder::getConstantAffineMap(int64_t val) {
AffineMap Builder::getConstantAffineMap(int64_t val) {
return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
{getAffineConstantExpr(val)}, {});
}
AffineMap *Builder::getDimIdentityMap() {
AffineMap Builder::getDimIdentityMap() {
return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
{getAffineDimExpr(0)}, {});
}
AffineMap *Builder::getMultiDimIdentityMap(unsigned rank) {
AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
SmallVector<AffineExpr, 4> dimExprs;
dimExprs.reserve(rank);
for (unsigned i = 0; i < rank; ++i)
@ -191,25 +191,25 @@ AffineMap *Builder::getMultiDimIdentityMap(unsigned rank) {
return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs, {});
}
AffineMap *Builder::getSymbolIdentityMap() {
AffineMap Builder::getSymbolIdentityMap() {
return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
{getAffineSymbolExpr(0)}, {});
}
AffineMap *Builder::getSingleDimShiftAffineMap(int64_t shift) {
AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
// expr = d0 + shift.
auto expr = getAffineDimExpr(0) + shift;
return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr}, {});
}
AffineMap *Builder::getShiftedAffineMap(AffineMap *map, int64_t shift) {
AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
SmallVector<AffineExpr, 4> shiftedResults;
shiftedResults.reserve(map->getNumResults());
for (auto resultExpr : map->getResults()) {
shiftedResults.reserve(map.getNumResults());
for (auto resultExpr : map.getResults()) {
shiftedResults.push_back(resultExpr + shift);
}
return AffineMap::get(map->getNumDims(), map->getNumSymbols(), shiftedResults,
map->getRangeSizes());
return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
map.getRangeSizes());
}
//===----------------------------------------------------------------------===//
@ -278,19 +278,19 @@ OperationStmt *MLFuncBuilder::createOperation(Location *location,
ForStmt *MLFuncBuilder::createFor(Location *location,
ArrayRef<MLValue *> lbOperands,
AffineMap *lbMap,
AffineMap lbMap,
ArrayRef<MLValue *> ubOperands,
AffineMap *ubMap, int64_t step) {
auto *stmt = ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap,
step, context);
AffineMap ubMap, int64_t step) {
auto *stmt =
ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
block->getStatements().insert(insertPoint, stmt);
return stmt;
}
ForStmt *MLFuncBuilder::createFor(Location *location, int64_t lb, int64_t ub,
int64_t step) {
auto *lbMap = AffineMap::getConstantMap(lb, context);
auto *ubMap = AffineMap::getConstantMap(ub, context);
auto lbMap = AffineMap::getConstantMap(lb, context);
auto ubMap = AffineMap::getConstantMap(ub, context);
return createFor(location, {}, lbMap, {}, ubMap, step);
}

View File

@ -17,6 +17,7 @@
#include "mlir/IR/MLIRContext.h"
#include "AffineExprDetail.h"
#include "AffineMapDetail.h"
#include "AttributeListStorage.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@ -59,13 +60,13 @@ struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
}
};
struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
// Affine maps are uniqued based on their dim/symbol counts and affine
// expressions.
using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>,
ArrayRef<AffineExpr>>;
using DenseMapInfo<AffineMap *>::getHashValue;
using DenseMapInfo<AffineMap *>::isEqual;
using DenseMapInfo<AffineMap>::getHashValue;
using DenseMapInfo<AffineMap>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
@ -74,11 +75,11 @@ struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
}
static bool isEqual(const KeyTy &lhs, AffineMap *rhs) {
static bool isEqual(const KeyTy &lhs, AffineMap rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == std::make_tuple(rhs->getNumDims(), rhs->getNumSymbols(),
rhs->getResults(), rhs->getRangeSizes());
return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
rhs.getResults(), rhs.getRangeSizes());
}
};
@ -124,7 +125,7 @@ struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
// MemRefs are uniqued based on their element type, shape, affine map
// composition, and memory space.
using KeyTy =
std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap *>, unsigned>;
std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
using DenseMapInfo<MemRefType *>::getHashValue;
using DenseMapInfo<MemRefType *>::isEqual;
@ -222,7 +223,7 @@ public:
nullptr};
// Affine map uniquing.
using AffineMapSet = DenseSet<AffineMap *, AffineMapKeyInfo>;
using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
AffineMapSet affineMaps;
// Affine binary op expression uniquing. Figure out uniquing of dimensional
@ -267,7 +268,7 @@ public:
StringMap<StringAttr *> stringAttrs;
using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>;
ArrayAttrSet arrayAttrs;
DenseMap<AffineMap *, AffineMapAttr *> affineMapAttrs;
DenseMap<AffineMap, AffineMapAttr *> affineMapAttrs;
DenseMap<Type *, TypeAttr *> typeAttrs;
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
@ -558,7 +559,7 @@ UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
}
MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap *> affineMapComposition,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace) {
auto *context = elementType->getContext();
auto &impl = context->getImpl();
@ -581,7 +582,7 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
// Copy the affine map composition into the bump pointer.
// TODO(andydavis) Assert that the structure of the composition is valid.
affineMapComposition =
impl.copyInto(ArrayRef<AffineMap *>(affineMapComposition));
impl.copyInto(ArrayRef<AffineMap>(affineMapComposition));
// Initialize the memory using placement new.
new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace,
@ -675,8 +676,9 @@ ArrayAttr *ArrayAttr::get(ArrayRef<Attribute *> value, MLIRContext *context) {
return *existing.first = result;
}
AffineMapAttr *AffineMapAttr::get(AffineMap *value, MLIRContext *context) {
auto *&result = context->getImpl().affineMapAttrs[value];
AffineMapAttr *AffineMapAttr::get(AffineMap value) {
auto *context = value.getResult(0).getContext();
auto &result = context->getImpl().affineMapAttrs[value];
if (result)
return result;
@ -802,9 +804,9 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
// AffineMap and AffineExpr uniquing
//===----------------------------------------------------------------------===//
AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results,
ArrayRef<AffineExpr> rangeSizes) {
AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results,
ArrayRef<AffineExpr> rangeSizes) {
// The number of results can't be zero.
assert(!results.empty());
@ -814,25 +816,26 @@ AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
// Check if we already have this affine map.
auto key = std::make_tuple(dimCount, symbolCount, results, rangeSizes);
auto existing = impl.affineMaps.insert_as(nullptr, key);
auto existing = impl.affineMaps.insert_as(AffineMap(nullptr), key);
// If we already have it, return that value.
if (!existing.second)
return *existing.first;
// On the first use, we allocate them into the bump pointer.
auto *res = impl.allocator.Allocate<AffineMap>();
auto *res = impl.allocator.Allocate<detail::AffineMapStorage>();
// Copy the results and range sizes into the bump pointer.
results = impl.copyInto(results);
rangeSizes = impl.copyInto(rangeSizes);
// Initialize the memory using placement new.
new (res)
AffineMap(dimCount, symbolCount, results.size(), results, rangeSizes);
new (res) detail::AffineMapStorage{dimCount, symbolCount,
static_cast<unsigned>(results.size()),
results, rangeSizes};
// Cache and return it.
return *existing.first = res;
return *existing.first = AffineMap(res);
}
/// Simplify add expression. Return nullptr if it can't be simplified.

View File

@ -101,9 +101,9 @@ Attribute *AddIOp::constantFold(ArrayRef<Attribute *> operands,
//===----------------------------------------------------------------------===//
void AffineApplyOp::build(Builder *builder, OperationState *result,
AffineMap *map, ArrayRef<SSAValue *> operands) {
AffineMap map, ArrayRef<SSAValue *> operands) {
result->addOperands(operands);
result->types.append(map->getNumResults(), builder->getIndexType());
result->types.append(map.getNumResults(), builder->getIndexType());
result->addAttribute("map", builder->getAffineMapAttr(map));
}
@ -117,22 +117,22 @@ bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
parseDimAndSymbolList(parser, result->operands, numDims) ||
parser->parseOptionalAttributeDict(result->attributes))
return true;
auto *map = mapAttr->getValue();
auto map = mapAttr->getValue();
if (map->getNumDims() != numDims ||
numDims + map->getNumSymbols() != result->operands.size()) {
if (map.getNumDims() != numDims ||
numDims + map.getNumSymbols() != result->operands.size()) {
return parser->emitError(parser->getNameLoc(),
"dimension or symbol index mismatch");
}
result->types.append(map->getNumResults(), affineIntTy);
result->types.append(map.getNumResults(), affineIntTy);
return false;
}
void AffineApplyOp::print(OpAsmPrinter *p) const {
auto *map = getAffineMap();
*p << "affine_apply " << *map;
printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
auto map = getAffineMap();
*p << "affine_apply " << map;
printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p);
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
}
@ -143,15 +143,15 @@ bool AffineApplyOp::verify() const {
return emitOpError("requires an affine map");
// Check input and output dimensions match.
auto *map = affineMapAttr->getValue();
auto map = affineMapAttr->getValue();
// Verify that operand count matches affine map dimension and symbol count.
if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
return emitOpError(
"operand count and affine map dimension and symbol count must match");
// Verify that result count matches affine map result count.
if (getNumResults() != map->getNumResults())
if (getNumResults() != map.getNumResults())
return emitOpError("result count and affine map result count must match");
return false;
@ -183,8 +183,8 @@ bool AffineApplyOp::isValidSymbol() const {
bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants,
SmallVectorImpl<Attribute *> &results,
MLIRContext *context) const {
auto *map = getAffineMap();
if (map->constantFold(operandConstants, results))
auto map = getAffineMap();
if (map.constantFold(operandConstants, results))
return true;
// Return false on success.
return false;
@ -243,11 +243,11 @@ bool AllocOp::verify() const {
unsigned numSymbols = 0;
if (!memRefType->getAffineMaps().empty()) {
AffineMap *affineMap = memRefType->getAffineMaps()[0];
AffineMap affineMap = memRefType->getAffineMaps()[0];
// Store number of symbols used in affine map (used in subsequent check).
numSymbols = affineMap->getNumSymbols();
numSymbols = affineMap.getNumSymbols();
// Verify that the layout affine map matches the rank of the memref.
if (affineMap->getNumDims() != memRefType->getRank())
if (affineMap.getNumDims() != memRefType->getRank())
return emitOpError("affine map dimension count must equal memref rank");
}
unsigned numDynamicDims = memRefType->getNumDynamicDims();

View File

@ -262,17 +262,16 @@ bool OperationStmt::isReturn() const { return is<ReturnOp>(); }
//===----------------------------------------------------------------------===//
ForStmt *ForStmt::create(Location *location, ArrayRef<MLValue *> lbOperands,
AffineMap *lbMap, ArrayRef<MLValue *> ubOperands,
AffineMap *ubMap, int64_t step, MLIRContext *context) {
assert(lbOperands.size() == lbMap->getNumInputs() &&
AffineMap lbMap, ArrayRef<MLValue *> ubOperands,
AffineMap ubMap, int64_t step) {
assert(lbOperands.size() == lbMap.getNumInputs() &&
"lower bound operand count does not match the affine map");
assert(ubOperands.size() == ubMap->getNumInputs() &&
assert(ubOperands.size() == ubMap.getNumInputs() &&
"upper bound operand count does not match the affine map");
assert(step > 0 && "step has to be a positive integer constant");
unsigned numOperands = lbOperands.size() + ubOperands.size();
ForStmt *stmt =
new ForStmt(location, numOperands, lbMap, ubMap, step, context);
ForStmt *stmt = new ForStmt(location, numOperands, lbMap, ubMap, step);
unsigned i = 0;
for (unsigned e = lbOperands.size(); i != e; ++i)
@ -284,30 +283,31 @@ ForStmt *ForStmt::create(Location *location, ArrayRef<MLValue *> lbOperands,
return stmt;
}
ForStmt::ForStmt(Location *location, unsigned numOperands, AffineMap *lbMap,
AffineMap *ubMap, int64_t step, MLIRContext *context)
ForStmt::ForStmt(Location *location, unsigned numOperands, AffineMap lbMap,
AffineMap ubMap, int64_t step)
: Statement(Kind::For, location),
MLValue(MLValueKind::ForStmt, Type::getIndex(context)),
MLValue(MLValueKind::ForStmt,
Type::getIndex(lbMap.getResult(0).getContext())),
StmtBlock(StmtBlockKind::For), lbMap(lbMap), ubMap(ubMap), step(step) {
operands.reserve(numOperands);
}
const AffineBound ForStmt::getLowerBound() const {
return AffineBound(*this, 0, lbMap->getNumInputs(), lbMap);
return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap);
}
const AffineBound ForStmt::getUpperBound() const {
return AffineBound(*this, lbMap->getNumInputs(), getNumOperands(), ubMap);
return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap);
}
void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap *map) {
assert(lbOperands.size() == map->getNumInputs());
assert(map->getNumResults() >= 1 && "bound map has at least one result");
void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap map) {
assert(lbOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<MLValue *, 4> ubOperands(getUpperBoundOperands());
operands.clear();
operands.reserve(lbOperands.size() + ubMap->getNumInputs());
operands.reserve(lbOperands.size() + ubMap.getNumInputs());
for (auto *operand : lbOperands) {
operands.emplace_back(StmtOperand(this, operand));
}
@ -317,9 +317,9 @@ void ForStmt::setLowerBound(ArrayRef<MLValue *> lbOperands, AffineMap *map) {
this->lbMap = map;
}
void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap *map) {
assert(ubOperands.size() == map->getNumInputs());
assert(map->getNumResults() >= 1 && "bound map has at least one result");
void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap map) {
assert(ubOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
SmallVector<MLValue *, 4> lbOperands(getLowerBoundOperands());
@ -334,34 +334,30 @@ void ForStmt::setUpperBound(ArrayRef<MLValue *> ubOperands, AffineMap *map) {
this->ubMap = map;
}
void ForStmt::setLowerBoundMap(AffineMap *map) {
assert(lbMap->getNumDims() == map->getNumDims() &&
lbMap->getNumSymbols() == map->getNumSymbols());
assert(map->getNumResults() >= 1 && "bound map has at least one result");
void ForStmt::setLowerBoundMap(AffineMap map) {
assert(lbMap.getNumDims() == map.getNumDims() &&
lbMap.getNumSymbols() == map.getNumSymbols());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
this->lbMap = map;
}
void ForStmt::setUpperBoundMap(AffineMap *map) {
assert(ubMap->getNumDims() == map->getNumDims() &&
ubMap->getNumSymbols() == map->getNumSymbols());
assert(map->getNumResults() >= 1 && "bound map has at least one result");
void ForStmt::setUpperBoundMap(AffineMap map) {
assert(ubMap.getNumDims() == map.getNumDims() &&
ubMap.getNumSymbols() == map.getNumSymbols());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
this->ubMap = map;
}
bool ForStmt::hasConstantLowerBound() const {
return lbMap->isSingleConstant();
}
bool ForStmt::hasConstantLowerBound() const { return lbMap.isSingleConstant(); }
bool ForStmt::hasConstantUpperBound() const {
return ubMap->isSingleConstant();
}
bool ForStmt::hasConstantUpperBound() const { return ubMap.isSingleConstant(); }
int64_t ForStmt::getConstantLowerBound() const {
return lbMap->getSingleConstantResult();
return lbMap.getSingleConstantResult();
}
int64_t ForStmt::getConstantUpperBound() const {
return ubMap->getSingleConstantResult();
return ubMap.getSingleConstantResult();
}
void ForStmt::setConstantLowerBound(int64_t value) {
@ -373,21 +369,20 @@ void ForStmt::setConstantUpperBound(int64_t value) {
}
ForStmt::operand_range ForStmt::getLowerBoundOperands() {
return {operand_begin(),
operand_begin() + getLowerBoundMap()->getNumInputs()};
return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
ForStmt::operand_range ForStmt::getUpperBoundOperands() {
return {operand_begin() + getLowerBoundMap()->getNumInputs(), operand_end()};
return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
}
bool ForStmt::matchingBoundOperandList() const {
if (lbMap->getNumDims() != ubMap->getNumDims() ||
lbMap->getNumSymbols() != ubMap->getNumSymbols())
if (lbMap.getNumDims() != ubMap.getNumDims() ||
lbMap.getNumSymbols() != ubMap.getNumSymbols())
return false;
unsigned numOperands = lbMap->getNumInputs();
for (unsigned i = 0, e = lbMap->getNumInputs(); i < e; i++) {
unsigned numOperands = lbMap.getNumInputs();
for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
// Compare MLValue *'s.
if (getOperand(i) != getOperand(numOperands + i))
return false;
@ -419,11 +414,11 @@ bool ForStmt::constantFoldBound(bool lower) {
operandConstants.push_back(operandCst);
}
AffineMap *boundMap = lower ? getLowerBoundMap() : getUpperBoundMap();
assert(boundMap->getNumResults() >= 1 &&
AffineMap boundMap = lower ? getLowerBoundMap() : getUpperBoundMap();
assert(boundMap.getNumResults() >= 1 &&
"bound maps should have at least one result");
SmallVector<Attribute *, 4> foldedResults;
if (boundMap->constantFold(operandConstants, foldedResults))
if (boundMap.constantFold(operandConstants, foldedResults))
return true;
// Compute the max or min as applicable over the results.
@ -523,14 +518,14 @@ Statement *Statement::clone(DenseMap<const MLValue *, MLValue *> &operandMap,
}
if (auto *forStmt = dyn_cast<ForStmt>(this)) {
auto *lbMap = forStmt->getLowerBoundMap();
auto *ubMap = forStmt->getUpperBoundMap();
auto lbMap = forStmt->getLowerBoundMap();
auto ubMap = forStmt->getUpperBoundMap();
auto *newFor = ForStmt::create(
getLoc(),
ArrayRef<MLValue *>(operands).take_front(lbMap->getNumInputs()), lbMap,
ArrayRef<MLValue *>(operands).take_back(ubMap->getNumInputs()), ubMap,
forStmt->getStep(), context);
ArrayRef<MLValue *>(operands).take_front(lbMap.getNumInputs()), lbMap,
ArrayRef<MLValue *>(operands).take_back(ubMap.getNumInputs()), ubMap,
forStmt->getStep());
// Remember the induction variable mapping.
operandMap[forStmt] = newFor;

View File

@ -77,14 +77,14 @@ UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
}
MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap *> affineMapList,
unsigned memorySpace, MLIRContext *context)
ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
MLIRContext *context)
: Type(Kind::MemRef, context, shape.size()), elementType(elementType),
shapeElements(shape.data()), numAffineMaps(affineMapList.size()),
affineMapList(affineMapList.data()), memorySpace(memorySpace) {}
ArrayRef<AffineMap*> MemRefType::getAffineMaps() const {
return ArrayRef<AffineMap*>(affineMapList, numAffineMaps);
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
}
unsigned MemRefType::getNumDynamicDims() const {

View File

@ -69,7 +69,7 @@ public:
}
// A map from affine map identifier to AffineMap.
llvm::StringMap<AffineMap *> affineMapDefinitions;
llvm::StringMap<AffineMap> affineMapDefinitions;
// A map from integer set identifier to IntegerSet.
llvm::StringMap<IntegerSet *> integerSetDefinitions;
@ -200,8 +200,8 @@ public:
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
// Polyhedral structures.
AffineMap *parseAffineMapInline();
AffineMap *parseAffineMapReference();
AffineMap parseAffineMapInline();
AffineMap parseAffineMapReference();
IntegerSet *parseIntegerSetInline();
IntegerSet *parseIntegerSetReference();
@ -521,7 +521,7 @@ Type *Parser::parseMemRefType() {
return (emitError(typeLoc, "invalid memref element type"), nullptr);
// Parse semi-affine-map-composition.
SmallVector<AffineMap *, 2> affineMapComposition;
SmallVector<AffineMap, 2> affineMapComposition;
unsigned memorySpace = 0;
bool parsedMemorySpace = false;
@ -540,8 +540,8 @@ Type *Parser::parseMemRefType() {
// Parse affine map.
if (parsedMemorySpace)
return emitError("affine map after memory space in memref type");
auto *affineMap = parseAffineMapReference();
if (affineMap == nullptr)
auto affineMap = parseAffineMapReference();
if (!affineMap)
return ParseFailure;
affineMapComposition.push_back(affineMap);
}
@ -728,7 +728,7 @@ Attribute *Parser::parseAttribute() {
case Token::hash_identifier:
case Token::l_paren: {
// Try to parse affine map reference.
if (auto *affineMap = parseAffineMapReference())
if (auto affineMap = parseAffineMapReference())
return builder.getAffineMapAttr(affineMap);
return (emitError("expected constant attribute value"), nullptr);
}
@ -827,7 +827,7 @@ class AffineParser : public Parser {
public:
explicit AffineParser(ParserState &state) : Parser(state) {}
AffineMap *parseAffineMapInline();
AffineMap parseAffineMapInline();
IntegerSet *parseIntegerSetInline();
private:
@ -1223,22 +1223,22 @@ ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
/// dim-size ::= affine-expr | `min` `(` affine-expr ( `,` affine-expr)+ `)`
///
/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
AffineMap *AffineParser::parseAffineMapInline() {
AffineMap AffineParser::parseAffineMapInline() {
unsigned numDims = 0, numSymbols = 0;
// List of dimensional identifiers.
if (parseDimIdList(numDims))
return nullptr;
return AffineMap::Invalid();
// Symbols are optional.
if (getToken().is(Token::l_square)) {
if (parseSymbolIdList(numSymbols))
return nullptr;
return AffineMap::Invalid();
}
if (parseToken(Token::arrow, "expected '->' or '['") ||
parseToken(Token::l_paren, "expected '(' at start of affine map range"))
return nullptr;
return AffineMap::Invalid();
SmallVector<AffineExpr, 4> exprs;
auto parseElt = [&]() -> ParseResult {
@ -1252,7 +1252,7 @@ AffineMap *AffineParser::parseAffineMapInline() {
// affine expressions); the list cannot be empty.
// Grammar: multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false))
return nullptr;
return AffineMap::Invalid();
// Parse optional range sizes.
// range-sizes ::= (`size` `(` dim-size (`,` dim-size)* `)`)?
@ -1264,7 +1264,7 @@ AffineMap *AffineParser::parseAffineMapInline() {
// Location of the l_paren token (if it exists) for error reporting later.
auto loc = getToken().getLoc();
if (parseToken(Token::l_paren, "expected '(' at start of affine map range"))
return nullptr;
return AffineMap::Invalid();
auto parseRangeSize = [&]() -> ParseResult {
auto loc = getToken().getLoc();
@ -1281,30 +1281,30 @@ AffineMap *AffineParser::parseAffineMapInline() {
};
if (parseCommaSeparatedListUntil(Token::r_paren, parseRangeSize, false))
return nullptr;
return AffineMap::Invalid();
if (exprs.size() > rangeSizes.size())
return (emitError(loc, "fewer range sizes than range expressions"),
nullptr);
AffineMap::Invalid());
if (exprs.size() < rangeSizes.size())
return (emitError(loc, "more range sizes than range expressions"),
nullptr);
AffineMap::Invalid());
}
// Parsed a valid affine map.
return builder.getAffineMap(numDims, numSymbols, exprs, rangeSizes);
}
AffineMap *Parser::parseAffineMapInline() {
AffineMap Parser::parseAffineMapInline() {
return AffineParser(state).parseAffineMapInline();
}
AffineMap *Parser::parseAffineMapReference() {
AffineMap Parser::parseAffineMapReference() {
if (getToken().is(Token::hash_identifier)) {
// Parse affine map identifier and verify that it exists.
StringRef affineMapId = getTokenSpelling().drop_front();
if (getState().affineMapDefinitions.count(affineMapId) == 0)
return (emitError("undefined affine map id '" + affineMapId + "'"),
nullptr);
AffineMap::Invalid());
consumeToken(Token::hash_identifier);
return getState().affineMapDefinitions[affineMapId];
}
@ -2221,7 +2221,7 @@ private:
ParseResult parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
unsigned numDims, unsigned numOperands,
const char *affineStructName);
ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap *&map,
ParseResult parseBound(SmallVectorImpl<MLValue *> &operands, AffineMap &map,
bool isLower);
ParseResult parseIfStmt();
ParseResult parseElseClause(IfClause *elseClause);
@ -2261,7 +2261,7 @@ ParseResult MLFunctionParser::parseForStmt() {
// Parse lower bound.
SmallVector<MLValue *, 4> lbOperands;
AffineMap *lbMap = nullptr;
AffineMap lbMap;
if (parseBound(lbOperands, lbMap, /*isLower*/ true))
return ParseFailure;
@ -2270,7 +2270,7 @@ ParseResult MLFunctionParser::parseForStmt() {
// Parse upper bound.
SmallVector<MLValue *, 4> ubOperands;
AffineMap *ubMap = nullptr;
AffineMap ubMap;
if (parseBound(ubOperands, ubMap, /*isLower*/ false))
return ParseFailure;
@ -2388,7 +2388,7 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl<MLValue *> &operands,
/// shorthand-bound ::= ssa-id | `-`? integer-literal
///
ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
AffineMap *&map, bool isLower) {
AffineMap &map, bool isLower) {
// 'min' / 'max' prefixes are syntactic sugar. Ignore them.
if (isLower)
consumeIf(Token::kw_max);
@ -2401,7 +2401,7 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl<MLValue *> &operands,
if (!map)
return ParseFailure;
if (parseDimAndSymbolList(operands, map->getNumDims(), map->getNumInputs(),
if (parseDimAndSymbolList(operands, map.getNumDims(), map.getNumInputs(),
"affine map"))
return ParseFailure;
return ParseSuccess;
@ -2691,7 +2691,7 @@ ParseResult ModuleParser::parseAffineMapDef() {
StringRef affineMapId = getTokenSpelling().drop_front();
// Check for redefinitions.
auto *&entry = getState().affineMapDefinitions[affineMapId];
auto &entry = getState().affineMapDefinitions[affineMapId];
if (entry)
return emitError("redefinition of affine map id '" + affineMapId + "'");

View File

@ -103,9 +103,9 @@ static void createComposedAffineApplyOp(
unsigned rank = memrefType->getRank();
assert(indices.size() == rank);
// Create identity map with same number of dimensions as 'memrefType'.
auto *map = builder->getMultiDimIdentityMap(rank);
auto map = builder->getMultiDimIdentityMap(rank);
// Initialize AffineValueMap with identity map.
AffineValueMap valueMap(map, indices, builder->getContext());
AffineValueMap valueMap(map, indices);
for (auto *opStmt : affineApplyOps) {
assert(opStmt->is<AffineApplyOp>());

View File

@ -201,14 +201,14 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
if (unrollFactor == 1 || forStmt->getStatements().empty())
return false;
auto *lbMap = forStmt->getLowerBoundMap();
auto *ubMap = forStmt->getUpperBoundMap();
auto lbMap = forStmt->getLowerBoundMap();
auto ubMap = forStmt->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as an MLFunction in the general case). However, the right way to
// do such unrolling for an MLFunction would be to specialize the loop for the
// 'hotspot' case and unroll that hotspot.
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
return false;
// Same operand list for lower and upper bound for now.
@ -229,7 +229,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
DenseMap<const MLValue *, MLValue *> operandMap;
MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
auto *clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
assert(clLbMap &&
"cleanup loop lower bound map for single result bound maps can "
"always be determined");
@ -238,7 +238,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
promoteIfSingleIteration(cleanupForStmt);
// Adjust upper bound.
auto *unrolledUbMap =
auto unrolledUbMap =
getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder);
assert(unrolledUbMap &&
"upper bound map can alwayys be determined for an unrolled loop "
@ -267,7 +267,7 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
if (!forStmt->use_empty()) {
// iv' = iv + 1/2/3...unrollFactor-1;
auto d0 = builder.getAffineDimExpr(0);
auto *bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
->getResult(0);

View File

@ -156,14 +156,14 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0)
return false;
auto *lbMap = forStmt->getLowerBoundMap();
auto *ubMap = forStmt->getUpperBoundMap();
auto lbMap = forStmt->getLowerBoundMap();
auto ubMap = forStmt->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as an MLFunction in the general case). However, the right way to
// do such unrolling for an MLFunction would be to specialize the loop for the
// 'hotspot' case and unroll that hotspot.
if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
return false;
// Same operand list for lower and upper bound for now.
@ -221,7 +221,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
if (!forStmt->use_empty()) {
// iv' = iv + i, i = 1 to unrollJamFactor-1.
auto d0 = builder.getAffineDimExpr(0);
auto *bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
->getResult(0);

View File

@ -35,50 +35,50 @@ using namespace mlir;
/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
/// the specified trip count, stride, and unroll factor. Returns nullptr when
/// the trip count can't be expressed as an affine expression.
AffineMap *mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder) {
auto *lbMap = forStmt.getLowerBoundMap();
AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder) {
auto lbMap = forStmt.getLowerBoundMap();
// Single result lower bound map only.
if (lbMap->getNumResults() != 1)
return nullptr;
if (lbMap.getNumResults() != 1)
return AffineMap::Invalid();
// Sometimes, the trip count cannot be expressed as an affine expression.
auto tripCount = getTripCountExpr(forStmt);
if (!tripCount)
return nullptr;
return AffineMap::Invalid();
AffineExpr lb(lbMap->getResult(0));
AffineExpr lb(lbMap.getResult(0));
unsigned step = forStmt.getStep();
auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;
return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
{newUb}, {});
}
/// Returns the lower bound of the cleanup loop when unrolling a loop with lower
/// bound 'lb' and with the specified trip count, stride, and unroll factor.
/// Returns nullptr when the trip count can't be expressed as an affine
/// expression.
AffineMap *mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder) {
auto *lbMap = forStmt.getLowerBoundMap();
/// Returns an AffinMap with nullptr storage (that evaluates to false)
/// when the trip count can't be expressed as an affine expression.
AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
unsigned unrollFactor,
MLFuncBuilder *builder) {
auto lbMap = forStmt.getLowerBoundMap();
// Single result lower bound map only.
if (lbMap->getNumResults() != 1)
return nullptr;
if (lbMap.getNumResults() != 1)
return AffineMap::Invalid();
// Sometimes the trip count cannot be expressed as an affine expression.
AffineExpr tripCount(getTripCountExpr(forStmt));
if (!tripCount)
return nullptr;
return AffineMap::Invalid();
AffineExpr lb(lbMap->getResult(0));
AffineExpr lb(lbMap.getResult(0));
unsigned step = forStmt.getStep();
auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
return builder->getAffineMap(lbMap->getNumDims(), lbMap->getNumSymbols(),
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
{newLb}, {});
}
@ -91,7 +91,7 @@ bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
return false;
// TODO(mlir-team): there is no builder for a max.
if (forStmt->getLowerBoundMap()->getNumResults() != 1)
if (forStmt->getLowerBoundMap().getNumResults() != 1)
return false;
// Replaces all IV uses to its single iteration value.
@ -140,7 +140,7 @@ void mlir::promoteSingleIterationLoops(MLFunction *f) {
/// the pair specifies the delay applied to that group of statements. Returns
/// nullptr if the generated loop simplifies to a single iteration one.
static ForStmt *
generateLoop(AffineMap *lb, AffineMap *ub,
generateLoop(AffineMap lb, AffineMap ub,
const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
&stmtGroupQueue,
unsigned offset, ForStmt *srcForStmt, MLFuncBuilder *b) {
@ -296,7 +296,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
// of statements is paired with its delay.
std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
auto *origLbMap = forStmt->getLowerBoundMap();
auto origLbMap = forStmt->getLowerBoundMap();
uint64_t lbDelay = 0;
MLFuncBuilder b(forStmt);
for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {

View File

@ -100,7 +100,8 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
->getResult());
auto d0 = bInner.getAffineDimExpr(0);
auto *modTwoMap = bInner.getAffineMap(1, 0, {d0 % 2}, {});
auto modTwoMap =
bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {});
auto ivModTwoOp =
bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0)))

View File

@ -59,10 +59,10 @@ PassResult SimplifyAffineExpr::runOnMLFunction(MLFunction *f) {
void visitOperationStmt(OperationStmt *opStmt) {
for (auto attr : opStmt->getAttrs()) {
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr.second)) {
MutableAffineMap mMap(mapAttr->getValue(), context);
MutableAffineMap mMap(mapAttr->getValue());
mMap.simplify();
auto *map = mMap.getAffineMap();
opStmt->setAttr(attr.first, AffineMapAttr::get(map, context));
auto map = mMap.getAffineMap();
opStmt->setAttr(attr.first, AffineMapAttr::get(map));
}
}
}

View File

@ -48,14 +48,14 @@ static bool isMemRefDereferencingOp(const Operation &op) {
// extended to add additional indices at any position.
bool mlir::replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
ArrayRef<SSAValue *> extraIndices,
AffineMap *indexRemap) {
AffineMap indexRemap) {
unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
(void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
(void)newMemRefRank;
if (indexRemap) {
assert(indexRemap->getNumInputs() == oldMemRefRank);
assert(indexRemap->getNumResults() + extraIndices.size() == newMemRefRank);
assert(indexRemap.getNumInputs() == oldMemRefRank);
assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
} else {
assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
}